In [1]:
import re
import warnings

import datasets
import pandas as pd
import numpy as np
import sympy

import gadgets

In [2]:
# Load every split of the MathQA dataset through the Hugging Face `datasets` library.
math_qa = datasets.load_dataset("math_qa")
# Show the DatasetDict summary (splits, feature names, row counts).
math_qa

Found cached dataset math_qa (/var/tmp/xkadlci2/.cache/huggingface/datasets/math_qa/default/0.1.0/67fc1cc5d22b185002c6fd16e19e4d5215eae01fb04d656bed83204ba6ee55ff)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
        num_rows: 29837
    })
    test: Dataset({
        features: ['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
        num_rows: 2985
    })
    validation: Dataset({
        features: ['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
        num_rows: 4475
    })
})

In [3]:
# One DataFrame per split. `reset_index()` keeps the within-split row number
# as an "index" column; `split_name` records which split each row came from.
df = {}
for name in ("train", "validation", "test"):
    frame = math_qa[name].to_pandas().reset_index()
    frame["split_name"] = name
    df[name] = frame

# Stack all splits into one frame with a fresh 0..N-1 index.
df_all = pd.concat(df.values()).reset_index(drop=True)

In [4]:
# Separate concatenation of the same split frames, built before the cleaning
# cells below start overwriting columns of `df_all`.
df_all_orig = pd.concat(df.values()).reset_index(drop=True)

In [5]:
# Decimal constants that appear in the formulas (spelled `const_0_<digits>`),
# mapped to an equivalent division of integer constants.
replacement_map = {
    "const_0_2778": "divide(const_10, const_36)", # convert km/h to m/s
    "const_0_25": "divide(const_1, const_4)",
    "const_0_33": "divide(const_1, const_3)",
    "const_0_6": "divide(const_6, const_10)",
    "const_0_3937": "divide(const_100, const_254)", # convert cm to inch
}

def replace_constants(formula: str, replacement_map: dict[str, str]) -> str:
    """
    Replace decimal constants in a formula string.

    Each key of `replacement_map` is replaced by its mapped expression, but
    only where the key is a complete constant, i.e. it is not immediately
    followed by another digit. Any `const_0_<digits>` constant still present
    afterwards triggers a warning and is rewritten to decimal notation
    (`const_0.<digits>`).

    Args:
        formula: formula text, e.g. "multiply(n0, const_0_25)".
        replacement_map: constant name -> replacement expression.

    Returns:
        The formula with all replacements applied.
    """
    for const, replacement in replacement_map.items():
        # The negative lookahead keeps a short key such as "const_0_6" from
        # clobbering the prefix of a longer constant such as "const_0_65".
        pattern = re.escape(const) + r"(?!\d)"
        # A callable replacement keeps any backslash in `replacement` literal.
        formula = re.sub(pattern, lambda _match: replacement, formula)

    if "const_0_" in formula:
        warnings.warn(f"formula still contains decimal constants: '{formula}'")
        formula = formula.replace("const_0_", "const_0.")
    return formula

formula_columns = ("annotated_formula", "linear_formula")

# Swap the known decimal constants for exact fractions in both formula columns.
for column in formula_columns:
    df_all[column] = df_all[column].apply(
        lambda formula: replace_constants(formula, replacement_map)
    )

# No `const_0_...` spelling may survive the replacement step.
for column in formula_columns:
    assert df_all[column].str.contains("const_0_").sum() == 0

# replace thousands separator from comma to underscore
# avoiding conflict with argument separator
for column in formula_columns:
    df_all[column] = df_all[column].str.replace(
        r"\d,\d",
        lambda m: m.group().replace(",", "_"),
        regex=True,
    )

# Split the "|"-separated linear formula into a list of stripped steps.
# The type check skips the split when the column already holds lists.
first_linear_formula = df_all["linear_formula"].iloc[0]
if isinstance(first_linear_formula, str):
    trimmed = df_all["linear_formula"].str.strip().str.strip("|")
    df_all["linear_formula"] = trimmed.str.split("|").apply(
        lambda steps: [step.strip() for step in steps]
    )

In [6]:
# Inspect the cleaned frame (the recorded output shows the first and last rows).
df_all

Unnamed: 0,index,Problem,Rationale,options,correct,annotated_formula,linear_formula,category,split_name
0,0,the banker ' s gain of a certain sum due 3 yea...,"""explanation : t = 3 years r = 10 % td = ( bg ...","a ) rs . 400 , b ) rs . 300 , c ) rs . 500 , d...",a,"divide(multiply(const_100, divide(multiply(36,...","[multiply(n2,const_100), multiply(n0,n1), divi...",gain,train
1,1,average age of students of an adult school is ...,"""explanation : let the original no . of studen...","a ) 1200 , b ) 120 , c ) 360 , d ) 240 , e ) n...",d,"multiply(divide(subtract(multiply(add(32, 4), ...","[add(n2,n3), multiply(n1,n2), multiply(n1,#0),...",general,train
2,2,sophia finished 2 / 3 of a book . she calculat...,let xx be the total number of pages in the boo...,"a ) 229 , b ) 270 , c ) 877 , d ) 266 , e ) 281",b,"divide(90, subtract(const_1, divide(2, 3)))","[divide(n0,n1), subtract(const_1,#0), divide(n...",general,train
3,3,120 is what percent of 50 ?,"""50 * x = 120 - - > x = 2.4 - - > 2.4 expresse...","a ) 5 % , b ) 240 % , c ) 50 % , d ) 2 % , e )...",b,"multiply(divide(120, 50), const_100)","[divide(n0,n1), multiply(#0,const_100)]",gain,train
4,4,there are 10 girls and 20 boys in a classroom ...,"if girls is 10 and boys is 20 , then 10 / 20 ....","a ) 1 / 2 , b ) 1 / 3 , c ) 1 / 5 , d ) 10 / 3...",a,"divide(10, 20)","[divide(n0,n1)]",other,train
...,...,...,...,...,...,...,...,...,...
37292,2980,find the area of a parallelogram with base 20 ...,"""area of a parallelogram = base * height = 20 ...","a ) 100 cm 2 , b ) 250 cm 2 , c ) 800 cm 2 , d...",c,"multiply(20, 40)","[multiply(n0,n1)]",geometry,test
37293,2981,"in a garden , there are 10 rows and 14 columns...","""explanation : each row contains 14 plants . t...","a ) 20 m , b ) 22 m , c ) 24 m , d ) 26 m , e ...",e,"add(add(multiply(subtract(14, const_1), 2), di...","[divide(n0,n2), subtract(n1,const_1), multiply...",physics,test
37294,2982,a can do a piece of work in 6 hours ; b and c ...,"""a ' s 1 hour work = 1 / 6 ; ( b + c ) ' s 1 h...","a ) 8 hours , b ) 6 hours , c ) 14 hours , d )...",d,"divide(const_1, subtract(divide(const_1, 4), s...","[divide(const_1,n1), divide(const_1,n2), divid...",physics,test
37295,2983,a train 250 m long running at 72 kmph crosses ...,"""d = 72 * 5 / 18 = 30 = 600 â € “ 250 = 350 m ...","a ) 350 m , b ) 200 m , c ) 250 m , d ) 270 m ...",a,"subtract(multiply(30, multiply(72, divide(cons...","[multiply(n1,divide(const_10, const_36)), mult...",physics,test


In [7]:
has_doubled_b = df_all["options"].str.contains("b ) b )", regex=False)
print(len(df_all[has_doubled_b]))

# for some reason, some examples have options choices repeated
# ("a ) a ) ..."): collapse each doubled label into a single one.
deduplicated = df_all["options"]
for letter in "abcde":
    deduplicated = deduplicated.str.replace(
        f"{letter} ) {letter} )", f"{letter} )", regex=False
    )
df_all["options"] = deduplicated

# The count printed after the clean-up should be 0.
print(len(df_all[df_all["options"].str.contains("b ) b )", regex=False)]))

571
0


In [8]:
import ast

# Splits an options string of the form "a ) ... , b ) ... , c ) ... , d ) ... , e ) ..."
# into one named group per choice letter.
options_regex = re.compile(r"a \) (?P<a>.*), b \) (?P<b>.*), c \) (?P<c>.*), d \) (?P<d>.*) , e \) (?P<e>.*).*")
# Optionally signed number with an optional decimal part, optionally followed by
# one or more "/" or ":" separated numbers (fractions and ratios).
number_re = re.compile(r"[-+]?\d+[,.]?\d*(\s?[\/:]\s?\d+\.?\d*)*")

def parse_options(options: str) -> dict[str, str]:
    """
    Split a raw options string into a dict keyed by choice letter.

    Input that starts with "[" is treated as the literal of a list of options
    and is first joined with " , " into the usual single-string form.

    Raises:
        ValueError: if the text does not match `options_regex`.
    """
    text = options.strip()
    if text.startswith("["):
        text = " , ".join(ast.literal_eval(text))

    parsed = options_regex.match(text)
    if parsed is not None:
        return parsed.groupdict()
    raise ValueError(f"Could not parse options {text}")

def extract_number_from_option(string: str) -> str | None:
    """
    Pull the first number-like token out of the text of one option.

    Returns:
        The matched number with inner whitespace removed and ":" rewritten to
        "/" (so "1 : 2" becomes "1/2"), or None when the option is empty or
        contains nothing that `number_re` matches.
    """
    text = string.strip()
    if text == "":
        return None

    # Normalise before matching: drop commas, glue a detached sign to its
    # number, and turn the Unicode minus sign into an ASCII hyphen. The order
    # matters: the Unicode minus must be converted before "- " is collapsed.
    for old, new in ((",", ""), ("+ ", "+"), ("−", "-"), ("- ", "-")):
        text = text.replace(old, new)

    found = number_re.search(text)
    if found is None:
        return None

    compact = "".join(found.group().split())
    return compact.replace(":", "/")


In [9]:
# Shared calculator gadget instance; the cells below call it to evaluate expression strings.
calc = gadgets.gadget.Calculator()

In [10]:
def parse_float(string: str) -> float:
    if string is None:
        return None
    
    try:
        return float(string)
    except Exception:
        pass
    
    try:
        return calc._float_eval(string)
    except Exception:
        pass

    warnings.warn(f"Could not parse '{string}' as float")
    return float('nan')

In [11]:
# Parse the raw options text into {"a": ..., ..., "e": ...}. Skipped when the
# column already holds dicts, so the cell can be run again.
first_options_value = df_all["options"].iloc[0]
if not isinstance(first_options_value, dict):
    df_all["options"] = df_all["options"].apply(parse_options)

# Derived columns: number text per choice, its float value, and the float
# value of the choice marked as correct.
if "correct_float" not in df_all.columns:
    df_all["options_num"] = df_all["options"].apply(
        lambda choices: {letter: extract_number_from_option(text) for letter, text in choices.items()}
    )
    df_all["options_float"] = df_all["options_num"].apply(
        lambda numbers: {letter: parse_float(number) for letter, number in numbers.items()}
    )
    df_all["correct_float"] = df_all.apply(
        lambda row: row["options_float"][row["correct"]],
        axis=1,
    )



In [12]:
import lark

# grammar for parsing nested function calls like fn1(fn2(arg1, arg2), arg3) into a tree
#   call     - a function name followed by one or more comma-separated arguments
#   argument - either a nested call or a plain value
#   value    - a run of letters, digits, "." and "_" ending in a letter or digit
# NOTE(review): the `value` rule contains an empty alternative (`"_" | )`) —
# confirm that it is intentional.
grammar = """
    ?call : CNAME "(" argument ("," argument)* ")"
    ?!argument : call | value
    !value : ( LETTER | DIGIT | "." | "_" | )* (DIGIT | LETTER)

    %import common.WS
    %import common.CNAME
    %import common.DIGIT
    %import common.LETTER

    %ignore WS
"""

# Parser whose start rule is `call`, i.e. the input text must begin with a function call.
parser = lark.Lark(grammar, start="call")

In [13]:
# Sanity check: parse one formula from the dataset and print its source text.
idx = 210
sample_formula = df_all["annotated_formula"].iloc[idx]
tree = parser.parse(sample_formula)
print(sample_formula)

subtract(52000, multiply(const_60, const_100))


In [14]:
from typing import Any, Type
import abc


class Operation(abc.ABC):
    min_args: int = 1
    max_args: int | None = None
    exact_num_args: int | None = None
    _str: str | None = None

    @classmethod
    def template(cls, args: list[Any]) -> str:
        cls.validate(args)
        return cls._template(args)

    @classmethod
    def _template(cls, args: list[Any]) -> str:
        if cls._str is None:
            raise NotImplementedError
        return cls._str.format(*args)
    
    @classmethod
    def validate(cls, args: list[Any]) -> None:
        if cls.exact_num_args is not None and len(args) != cls.exact_num_args:
            raise ValueError(f"Expected {cls.exact_num_args} arguments for {cls.__name__}, got {len(args)}")
        if cls.min_args is not None and len(args) < cls.min_args:
            raise ValueError(f"Expected at least {cls.min_args} arguments for {cls.__name__}, got {len(args)}")
        if cls.max_args is not None and len(args) > cls.max_args:
            raise ValueError(f"Expected at most {cls.max_args} arguments for {cls.__name__}, got {len(args)}")
        

class InfixOperation(Operation):
    """Operation rendered by joining its arguments with ` <operator> `; needs at least two arguments."""

    operator: str | None = None
    min_args: int = 2

    @classmethod
    def _template(cls, args: list[Any]) -> str:
        """Render e.g. ["1", "2", "3"] with operator "+" as "1 + 2 + 3"."""
        rendered = [str(arg) for arg in args]
        assert cls.operator is not None
        separator = f" {cls.operator} "
        return separator.join(rendered)

class PrefixOperation(Operation):
    """Operation rendered as a function call: `fn_name(arg1, arg2, ...)`."""

    fn_name: str | None = None

    @classmethod
    def _template(cls, args: list[Any]) -> str:
        """Render e.g. ["1", "2"] with fn_name "max" as "max(1, 2)"."""
        joined = ", ".join(str(arg) for arg in args)
        assert cls.fn_name is not None
        return f"{cls.fn_name}({joined})"

class Add(InfixOperation):
    """Addition of two or more arguments, templated as `a + b + ...`."""
    operator = "+"

class Subtract(InfixOperation):
    """Subtraction of two or more arguments, templated as `a - b - ...`."""
    operator = "-"
    
class Multiply(InfixOperation):
    """Multiplication of two or more arguments, templated as `a * b * ...`."""
    operator = "*"
    
class Divide(InfixOperation):
    """Division, templated as `a / b`; more than two arguments are chained with a warning."""
    operator = "/"

    @classmethod
    def validate(cls, args: list[Any]) -> None:
        # Chained division is accepted but flagged before the usual arity checks run.
        if len(args) > 2:
            warnings.warn(f"More than 2 arguments for {cls.__name__}, will be templated as a/b/c/d...")
        return super().validate(args)

class Power(InfixOperation):
    """Exponentiation, templated as `a ** b`; more than two arguments are chained with a warning."""
    operator = "**"

    @classmethod
    def validate(cls, args: list[Any]) -> None:
        # Chained powers are accepted but flagged before the usual arity checks run.
        if len(args) > 2:
            warnings.warn(f"More than 2 arguments for {cls.__name__}, will be templated as a**b**c**d...")
        return super().validate(args)

class Modulo(InfixOperation):
    """Remainder of exactly two arguments, templated as `a % b`."""
    exact_num_args = 2
    operator = "%"

class Sqrt(Operation):
    """Square root, templated as `x ** (1/2)`."""
    exact_num_args = 1
    _str = "{} ** (1/2)"

class Inverse(Operation):
    """Reciprocal, templated as `1 / x`."""
    exact_num_args = 1
    _str = "1 / {}"

class Negate(Operation):
    """Sign flip, templated as `-x`."""
    exact_num_args = 1
    _str = "-{}"
    
class Square(Operation):
    """Second power, templated as `x ** 2` (also used for `square_area`)."""
    exact_num_args = 1
    _str = "{} ** 2"

class SurfaceCube(Operation):
    """Surface area of a cube with edge x: `6 * (x ** 2)`."""
    exact_num_args = 1
    _str = "6 * ({0} ** 2)"
    
class VolumeCube(Operation):
    """Volume of a cube with edge x: `x ** 3`."""
    exact_num_args = 1
    _str = "{} ** 3"
    
class CircleArea(Operation):
    """Area of a circle with radius r: `pi * (r ** 2)`."""
    exact_num_args = 1
    _str = "pi * ({} ** 2)"

class CircleCircumface(Operation):
    """Circumference of a circle with radius r: `2 * pi * r`."""
    exact_num_args = 1
    _str = "2 * pi * {}"
    
class SquarePerimeter(Operation):
    """Perimeter of a square with side x: `4 * x`."""
    exact_num_args = 1
    _str = "4 * {}"

class RectanglePerimeter(Operation):
    """Perimeter of a rectangle with sides a and b: `2 * (a + b)`."""
    exact_num_args = 2
    _str = "2 * ({} + {})"

class Speed(Operation):
    """First argument divided by the second (presumably distance / time — confirm)."""
    exact_num_args = 2
    _str = "{} / {}"
    
class NegateProbability(Operation):
    """Complement of a probability p: `1 - p`."""
    exact_num_args = 1
    _str = "1 - {}"

class TrianglePerimeter(Operation):
    """Sum of the three sides: `a + b + c`."""
    exact_num_args = 3
    _str = "{} + {} + {}"

class CubeEdgeByVolume(Operation):
    """Cube root, templated as `x ** (1/3)`."""
    exact_num_args = 1
    _str = "{} ** (1/3)"

class SquareEdgeByPerimeter(Operation):
    """Side of a square from its perimeter: `x / 4`."""
    exact_num_args = 1
    _str = "{} / 4"
    
class SquareEdgeByArea(Operation):
    """Side of a square from its area: `x ** (1/2)`."""
    exact_num_args = 1
    _str = "{} ** (1/2)"

class VolumeRectangularPrism(Operation):
    """Product of the three edges: `a * b * c`."""
    exact_num_args = 3
    _str = "{} * {} * {}"

class SurfaceRectangularPrism(Operation):
    """Surface area of a box with edges a, b, c: `2 * (a * b + b * c + a * c)`."""
    exact_num_args = 3
    _str = "2 * ({0} * {1} + {1} * {2} + {0} * {2})"

class TriangleAreaFromThreeSides(Operation):
    """Triangle area from its three side lengths."""
    exact_num_args = 3

    @classmethod
    def _template(cls, args: list[Any]) -> str:
        # Heron's formula
        a, b, c = args
        # `s` is the semi-perimeter expression; it is wrapped in parentheses wherever it is reused.
        s = f"({a} + {b} + {c}) / 2"
        return f"(({s}) * ({s} - {a}) * ({s} - {b}) * ({s} - {c})) ** (1/2)"

class VolumeCylinder(Operation):
    """Volume of a cylinder with radius r and height h: `pi * (r ** 2) * h`."""
    exact_num_args = 2
    _str = "pi * ({0} ** 2) * {1}"

class TriangleAreaFromBaseAndHeight(Operation):
    """Half the product of the two arguments: `(a * b) / 2`."""
    exact_num_args = 2
    _str = "({0} * {1}) / 2"

class RhombusAreaByDiagonals(Operation):
    """Half the product of the two diagonals: `(a * b) / 2`."""
    exact_num_args = 2
    _str = "({0} * {1}) / 2"

class QuadrilateralArea(Operation):
    """Templated as `(a + b) * c / 2`."""
    exact_num_args = 3
    _str = "({0} + {1}) * {2} / 2"

class SurfaceCylinder(Operation):
    """Surface area of a cylinder with radius r and height h: `2 * pi * r * (r + h)`."""
    exact_num_args = 2
    _str = "2 * pi * {0} * ({0} + {1})"

class VolumeCone(Operation):
    """Volume of a cone with radius r and height h: `pi * (r ** 2) * h / 3`."""
    exact_num_args = 2
    _str = "pi * ({0} ** 2) * {1} / 3"

class VolumeSphere(Operation):
    """Volume of a sphere with radius r: `4/3 * pi * (r ** 3)`."""
    exact_num_args = 1
    _str = "4/3 * pi * ({} ** 3)"

class SurfaceSphere(Operation):
    """Surface area of a sphere with radius r: `4 * pi * (r ** 2)`."""
    exact_num_args = 1
    _str = "4 * pi * ({} ** 2)"

class Max(PrefixOperation):
    """Rendered as `max(...)`."""
    fn_name = "max"
    
class Min(PrefixOperation):
    """Rendered as `min(...)`."""
    fn_name = "min"

class Log(PrefixOperation):
    """Rendered as `log(x)` or `log(x, y)`; takes one or two arguments."""
    fn_name = "log"
    min_args = 1
    max_args = 2

class Factorial(PrefixOperation):
    """Rendered as `factorial(x)`."""
    exact_num_args = 1
    fn_name = "factorial"

class Floor(PrefixOperation):
    """Rendered as `floor(x)`."""
    exact_num_args = 1
    fn_name = "floor"

class Ceiling(PrefixOperation):
    """Rendered as `ceiling(x)`."""
    exact_num_args = 1
    fn_name = "ceiling"
    
class Binomial(PrefixOperation):
    """Rendered as `binomial(a, b)`."""
    exact_num_args = 2
    fn_name = "binomial"

class Permutation(Operation):
    """Templated as `factorial(a) / factorial(a - b)`."""
    exact_num_args = 2
    _str = "factorial({0}) / factorial({0} - {1})"

class GCD(PrefixOperation):
    """Rendered as `gcd(...)` with at least two arguments."""
    min_args = 2
    fn_name = "gcd"

class LCM(PrefixOperation):
    """Rendered as `lcm(...)` with at least two arguments."""
    min_args = 2
    fn_name = "lcm"

class Diagonal(Operation):
    """Templated as `(a ** 2 + b ** 2) ** (1/2)`."""
    exact_num_args = 2
    _str = "({0} ** 2 + {1} ** 2) ** (1/2)"


# Maps a function name as it appears in the dataset formulas to the Operation
# class that renders it. Several names share one class (e.g. "reminder" and
# "modulo", "choose" and "binomial", "rectangle_area" and "multiply").
OPERATIONS: dict[str, Type[Operation]] = {
    "add": Add,
    "subtract": Subtract,
    "multiply": Multiply,
    "divide": Divide,
    "power": Power,
    "sqrt": Sqrt,
    "reminder": Modulo, # someone made a typo when generating the dataset
    "modulo": Modulo,
    "inverse": Inverse,
    "max": Max,
    "min": Min,
    "factorial": Factorial,
    "negate": Negate,
    "floor": Floor,
    "ceiling": Ceiling,
    "choose": Binomial,
    "binomial": Binomial,
    "permutation": Permutation,
    "rectangle_area": Multiply,
    "square_area": Square,
    "surface_cube": SurfaceCube,
    "volume_cube": VolumeCube,
    "gcd": GCD,
    "lcm": LCM,
    "log": Log,
    "circle_area": CircleArea,
    "circumface": CircleCircumface,
    "square_perimeter": SquarePerimeter,
    "rectangle_perimeter": RectanglePerimeter,
    "speed": Speed,
    "negate_prob": NegateProbability,
    "triangle_perimeter": TrianglePerimeter,
    "cube_edge_by_volume": CubeEdgeByVolume,
    "square_edge_by_perimeter": SquareEdgeByPerimeter,
    "square_edge_by_area": SquareEdgeByArea,
    "volume_rectangular_prism": VolumeRectangularPrism,
    "surface_rectangular_prism": SurfaceRectangularPrism,
    "triangle_area_three_edges": TriangleAreaFromThreeSides,
    "volume_cylinder": VolumeCylinder,
    "triangle_area": TriangleAreaFromBaseAndHeight,
    "rhombus_area": RhombusAreaByDiagonals,
    "quadrilateral_area": QuadrilateralArea,
    "surface_cylinder": SurfaceCylinder,
    "volume_cone": VolumeCone,
    "volume_sphere": VolumeSphere,
    "surface_sphere": SurfaceSphere,
    "diagonal": Diagonal,
}


def to_calc_input(fn_name: str, args: list[Any]) -> str:
    """
    Render a formula function call as a calculator input string.

    The arguments are converted to strings and handed to the Operation
    registered for `fn_name` in `OPERATIONS`.

    Raises:
        ValueError: if `fn_name` is not registered, or if the operation
            rejects the number of arguments.
    """
    rendered_args = list(map(str, args))

    operation = OPERATIONS.get(fn_name)
    if operation is None:
        raise ValueError(f"Unknown function {fn_name}")

    return operation.template(rendered_args)


class TreeEvaluator():
    """
    Evaluates parsed formula trees with the calculator gadget.

    Results are memoised per subtree, so a subexpression that occurs several
    times in one formula is only evaluated once.
    """

    def __init__(self) -> None:
        # subtree -> (calculator inputs or None, calculator output, sympy value)
        self.cache = {}

    def _eval_tree(self, tree: lark.Tree | lark.Token) -> tuple[str | None, str, sympy.Expr]:
        """
        Evaluate one `value` or `call` subtree without consulting the cache.

        Returns:
            A triple (inputs, output, value): `inputs` is the expression sent
            to the calculator, or None for a plain value; `output` is what
            `calc(...)` returned for it; `value` is the result of
            `calc.evaluate(...)`.

        Raises:
            ValueError: for a bare token, an unexpected tree type, or an
                unknown function name.
        """
        if isinstance(tree, lark.Token):
            # A bare token has no `.data`; fail with a clear message instead
            # of an AttributeError further down.
            raise ValueError(f"Unexpected bare token {tree!r}, expected a 'value' or 'call' subtree")
        if tree.data == "value":
            children: list[lark.Token] = tree.children
            value_literal = "".join(token.value for token in children)
            # if the value is a constant known by general knowledge and not in the problem assignment,
            # it is prefixed with "const_" in the formula
            value_literal = value_literal.removeprefix("const_")
            # remove leading zeros
            # NOTE(review): a literal such as "0_0" turns into "_0" here, which
            # does not look like a valid number — confirm how it should be handled.
            value_literal = value_literal.lstrip("0")
            if value_literal.startswith(".") or value_literal == "":
                value_literal = "0" + value_literal
            return None, calc(value_literal), calc.evaluate(value_literal)

        if tree.data != "call":
            raise ValueError(f"Unexpected tree root type {tree.data}")
        
        fn_call_node: lark.Token
        args_nodes: list[lark.Tree]

        fn_call_node, *args_nodes = tree.children
        fn_name = fn_call_node.value

        # The operation depends only on the function name, so look it up once
        # rather than once per argument. Unknown names are reported by
        # `to_calc_input` after the arguments have been evaluated.
        operation = OPERATIONS.get(fn_name, None)
        is_prefix_call = operation is not None and issubclass(operation, PrefixOperation)

        args = []
        for arg_node in args_nodes:
            _, value_str, value_expr = self.eval_tree(arg_node)
            # Keep only the part of the calculator output before " = ".
            value_str = value_str.split(" = ")[0]
            if isinstance(value_expr, sympy.core.function.Application):
                # argument is a function call, no need to wrap in parentheses
                args.append(value_str)
            elif is_prefix_call:
                # current operation is a function call, no need to wrap in parentheses
                args.append(value_str)
            elif isinstance(value_expr, (sympy.Float, sympy.Integer, sympy.NumberSymbol)):
                if value_expr < 0:
                    # negative numbers need to be wrapped in parentheses
                    # otherwise some precedence rules could interfere
                    args.append("(" + value_str + ")")
                else:
                    args.append(value_str)
            elif isinstance(value_expr, sympy.Rational):
                # fractions need to be wrapped in parentheses
                # otherwise some precedence rules could interfere
                args.append("(" + value_str + ")")
            elif isinstance(value_expr, (sympy.Mul, sympy.Pow, sympy.Add)):
                # expressions need to be wrapped in parentheses
                # otherwise some precedence rules could interfere
                args.append("(" + value_str + ")")
            else:
                warnings.warn(f"weird value type {type(value_expr)} for {value_expr} (string: '{value_str}') in {fn_name}(...)")
                # not sure if this is needed, but better safe than sorry lol
                args.append("(" + value_str + ")")
        
        inputs = to_calc_input(fn_name, args)
        result_full = calc(inputs)
        return inputs, result_full, calc.evaluate(inputs)


    def eval_tree(self, tree: lark.Tree | lark.Token) -> tuple[str | None, str, sympy.Expr]:
        """Cached front end of `_eval_tree`; returns the same (inputs, output, value) triple."""
        # cache results to avoid recalculating the same subexpressions
        # it could be done in a more elegant way when traversing the tree,
        # but this is the easiest way to do it and still is around 10x
        # faster on math_qa dataset than the most naive implementation
        if tree not in self.cache:
            self.cache[tree] = self._eval_tree(tree)
        return self.cache[tree]
        

def dfs(tree):
    """Yield every `lark.Tree` node under `tree` children-first, with `tree` itself last; non-tree nodes yield nothing."""
    if not isinstance(tree, lark.tree.Tree):
        return
    for child in tree.children:
        yield from dfs(child)
    yield tree

import timeout_decorator

@timeout_decorator.timeout(10, use_signals=False)
def formula_to_steps(formula: str | lark.Tree, drop_repeated: bool) -> tuple[list[gadgets.datatypes.Interaction], sympy.Expr, str]:
    """
    Turn a nested formula into an ordered list of calculator interactions.

    Args:
        formula: formula text, or an already parsed tree.
        drop_repeated: when True, an interaction equal to one already in the
            list is not added again.

    Returns:
        (steps, value of the last evaluated call, model markup of the chain).

    The call is wrapped in a timeout of 10; callers in this notebook handle
    `timeout_decorator.TimeoutError`.
    """
    tree = formula if isinstance(formula, lark.Tree) else parser.parse(formula)

    evaluator = TreeEvaluator()
    steps = []
    final_result_expr = None

    # Children are visited before their parent, so each call's arguments are
    # already evaluated (and cached) when the call itself is reached.
    for subtree in dfs(tree):
        inputs_str, result_str, result_expr = evaluator.eval_tree(subtree)
        if inputs_str is None:
            # Plain value: nothing was sent to the calculator.
            continue

        interaction = gadgets.datatypes.Interaction(
            gadget_id="calculator",
            inputs=inputs_str,
            outputs=result_str,
        )
        final_result_expr = result_expr

        if drop_repeated and interaction in steps:
            continue
        steps.append(interaction)

    final_result_str = calc.format_sympy_number(final_result_expr)
    markup = gadgets.markup.to_model_markup(chain=steps, result=final_result_str)
    return steps, final_result_expr, str(markup)


In [15]:
# Smoke test on a hand-written formula: nested calls, a three-argument add,
# a literal with a leading zero and "_" separators, and a decimal literal.
expression = "negate(divide(add(2, 07_000_400, 4), subtract(0.6, divide(4000, 3))))"
formula_to_steps(expression, drop_repeated=True)

([Interaction(gadget_id='calculator', inputs='2 + 7_000_400 + 4', outputs='7_000_406'),
  Interaction(gadget_id='calculator', inputs='4_000 / 3', outputs='4_000/3 = around 1_333.333333'),
  Interaction(gadget_id='calculator', inputs='0.6 - (4_000/3)', outputs='-1_332.733333'),
  Interaction(gadget_id='calculator', inputs='7_000_406 / (-1_332.733333)', outputs='-5_252.668202'),
  Interaction(gadget_id='calculator', inputs='-(-5_252.668202)', outputs='5_252.668202')],
 5252.66820200000,
 '<gadget id="calculator">2 + 7_000_400 + 4</gadget>\n<output>7_000_406</output>\n<gadget id="calculator">4_000 / 3</gadget>\n<output>4_000/3 = around 1_333.333333</output>\n<gadget id="calculator">0.6 - (4_000/3)</gadget>\n<output>-1_332.733333</output>\n<gadget id="calculator">7_000_406 / (-1_332.733333)</gadget>\n<output>-5_252.668202</output>\n<gadget id="calculator">-(-5_252.668202)</gadget>\n<output>5_252.668202</output>\n<result>5_252.668202</result>')

In [16]:
# Same check on a formula taken from the dataset, this time passing the parsed tree.
idx = 1200
tree = parser.parse(df_all["annotated_formula"].iloc[idx])
print(df_all["annotated_formula"].iloc[idx])
formula_to_steps(tree, drop_repeated=True)

subtract(180, multiply(divide(3, const_60), const_1000))


([Interaction(gadget_id='calculator', inputs='3 / 60', outputs='1/20 = around 0.05'),
  Interaction(gadget_id='calculator', inputs='(1/20) * 1_000', outputs='50'),
  Interaction(gadget_id='calculator', inputs='180 - 50', outputs='130')],
 130,
 '<gadget id="calculator">3 / 60</gadget>\n<output>1/20 = around 0.05</output>\n<gadget id="calculator">(1/20) * 1_000</gadget>\n<output>50</output>\n<gadget id="calculator">180 - 50</gadget>\n<output>130</output>\n<result>130</result>')

In [17]:
# Literal with an oddly placed digit separator ("2_0000"); the recorded output renders it as 20_000.
formula_to_steps("divide(2_0000, 122)", drop_repeated=True)

([Interaction(gadget_id='calculator', inputs='20_000 / 122', outputs='10_000/61 = around 163.934426')],
 10000/61,
 '<gadget id="calculator">20_000 / 122</gadget>\n<output>10_000/61 = around 163.934426</output>\n<result>10_000/61 = around 163.934426</result>')

In [18]:
# A formula that is too expensive to evaluate: the call is expected to hit
# the timeout of `formula_to_steps`.
try:
    formula_to_steps("power(44444445, negate(88888885))", drop_repeated=True)
except timeout_decorator.TimeoutError:
    print("Timeout!")

Timeout!


In [19]:
# Find the dataset rows that contain the formula which timed out above.
df_all["annotated_formula"][df_all["annotated_formula"].str.contains("power(444", regex=False)]

29733    power(44444445, negate(88888885))
37285    power(44444445, negate(88888885))
Name: annotated_formula, dtype: object

In [20]:
from tqdm.auto import tqdm
# Register tqdm's progress-bar integration with pandas.
tqdm.pandas()

def try_stepify(formula: str) -> tuple[list[gadgets.datatypes.Interaction], sympy.Expr, str] | tuple[None, None, None]:
    """
    Run `formula_to_steps` without letting any failure propagate.

    Returns:
        The (steps, result, markup) triple on success, or (None, None, None)
        when the conversion times out or raises; the reason is printed.
    """
    try:
        return formula_to_steps(formula, drop_repeated=True)
    except timeout_decorator.TimeoutError:
        # Name the formula, as the generic failure branch does, so a timeout
        # in the log can be traced back to its row.
        print(f"Timeout! Could not stepify {formula}")
    except Exception as e:
        print(f"Failed to stepify {formula}: {e}")
    return None, None, None


In [21]:
# The wrapper reports the timeout and returns (None, None, None) instead of raising.
try_stepify("power(44444445, negate(88888885))")

Timeout!


(None, None, None)

In [22]:
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor as Pool


# Stepify every formula in a pool of worker processes. Each result is a
# (steps, sympy result, markup chain) triple, or three Nones on failure; the
# triples are transposed into one tuple per new column.
with Pool() as pool:
    stepified = tqdm(pool.map(try_stepify, df_all["annotated_formula"]), total=len(df_all))
    steps_column, result_column, chain_column = zip(*stepified)
    df_all["steps"] = steps_column
    df_all["calculated_result_sympy"] = result_column
    df_all["chain"] = chain_column

  0%|          | 0/37297 [00:00<?, ?it/s]

Failed to stepify original_price_before_loss(60, 300): Unknown function original_price_before_loss
Failed to stepify p_after_gain(10, 10): Unknown function p_after_gain




Failed to stepify multiply(divide(20, subtract(subtract(const_100, 80), 20)), const_100): invalid syntax (<string>, line 1)
Failed to stepify p_after_gain(219, 2017): Unknown function p_after_gain
Failed to stepify sqrt(add(power(sqrt(subtract(289, multiply(const_2, 220))), const_2), multiply(const_4, 220))): invalid syntax (<string>, line 1)
Failed to stepify rhombus_perimeter(sqrt(add(power(divide(24, const_2), const_2), power(divide(10, const_2), const_2)))): Unknown function rhombus_perimeter
Failed to stepify original_price_before_loss(40, 420): Unknown function original_price_before_loss
Failed to stepify divide(0.1, power(0.01, 5)): Cannot convert complex to float
Failed to stepify sqrt(add(power(sqrt(subtract(289, multiply(const_2, 276))), const_2), multiply(const_4, 276))): invalid syntax (<string>, line 1)
Failed to stepify divide(59.32, subtract(2, floor(2))): Cannot convert complex to float
Failed to stepify divide(original_price_before_loss(20, 60), divide(original_price_b



Failed to stepify original_price_before_gain(20, 270): Unknown function original_price_before_gain
Failed to stepify multiply(divide(15.06, 0.0000001), const_100): invalid syntax (<string>, line 1)




Failed to stepify multiply(subtract(divide(multiply(const_2, 40), subtract(80, multiply(const_2, 40))), divide(40, subtract(80, 40))), const_60): invalid syntax (<string>, line 1)
Failed to stepify sqrt(add(power(sqrt(subtract(5, multiply(const_2, 266))), const_2), multiply(const_4, 266))): invalid syntax (<string>, line 1)
Failed to stepify log(divide(log(subtract(1, multiply(add(const_4, const_1), const_1000))), log(add(const_4, const_1)))): invalid syntax (<string>, line 1)




Failed to stepify p_after_gain(11, 11): Unknown function p_after_gainFailed to stepify subtract(negate(9), multiply(subtract(1, 3), divide(subtract(1, 3), subtract(1, 1)))): invalid syntax (<string>, line 1)

Failed to stepify divide(0.1, power(0.0001, 3)): Cannot convert complex to float
Failed to stepify stream_speed(13, 9): Unknown function stream_speed
Failed to stepify divide(subtract(2, sqrt(subtract(power(2, 3), multiply(5, 5)))), 3): invalid syntax (<string>, line 1)
Failed to stepify sqrt(subtract(power(divide(24, 4), const_2), power(12, const_2))): Cannot convert complex to float
Failed to stepify original_price_before_gain(20, 250): Unknown function original_price_before_gain




Failed to stepify multiply(divide(15.06, 0.00000001), const_100): invalid syntax (<string>, line 1)
Failed to stepify stream_speed(speed(25, 20), divide(25, multiply(const_1, const_60))): Unknown function stream_speed
Failed to stepify divide(add(5, const_2), 00): Cannot convert complex to float
Failed to stepify divide(original_price_before_loss(20, 90), divide(original_price_before_gain(20, 40), 20)): Unknown function original_price_before_loss
Failed to stepify divide(subtract(2, sqrt(subtract(power(2, 3), multiply(5, 5)))), 3): invalid syntax (<string>, line 1)




Failed to stepify tangent(0.232323): Unknown function tangent
Failed to stepify divide(const_4, add(0_0, const_10)): invalid syntax (<string>, line 1)
Failed to stepify p_after_gain(13, 13): Unknown function p_after_gain
Failed to stepify divide(triangle_area_three_edges(8, 15, 27), divide(triangle_perimeter(8, 15, 27), const_2)): invalid syntax (<string>, line 1)Failed to stepify divide(59.32, subtract(2, floor(2))): Cannot convert complex to float

Failed to stepify log(subtract(divide(10, 2), 10)): Cannot convert complex to float
Failed to stepify divide(10, subtract(const_1, divide(add(10, 10), 20))): Cannot convert complex to float
Failed to stepify sqrt(subtract(power(divide(40, 4), const_2), power(12, const_2))): Cannot convert complex to float
Failed to stepify stream_speed(multiply(3, add(3, 1)), 1): Unknown function stream_speed
Failed to stepify stream_speed(21, 11): Unknown function stream_speed
Failed to stepify rhombus_perimeter(const_4): Unknown function rhombus_perimete



Failed to stepify rectangle_perimeter(18, const_deg_to_rad): invalid syntax (<string>, line 1)Failed to stepify sqrt(add(power(sqrt(subtract(13, multiply(const_2, 2028))), const_2), multiply(const_4, 2028))): invalid syntax (<string>, line 1)

Failed to stepify divide(const_4, add(0_0, const_10)): invalid syntax (<string>, line 1)
Failed to stepify multiply(divide(add(200, 150), speed_in_still_water(divide(const_10, const_36), subtract(45, 40))), const_2): Unknown function speed_in_still_water
Failed to stepify stream_speed(15, 11): Unknown function stream_speed
Failed to stepify divide(original_price_before_loss(20, 70), divide(original_price_before_gain(20, 60), 20)): Unknown function original_price_before_loss
Failed to stepify original_price_before_gain(20, 200): Unknown function original_price_before_gain
Failed to stepify multiply(add(const_1, divide(40, const_100)), original_price_before_gain(40, 2500)): Unknown function original_price_before_gain




Failed to stepify divide(const_4, add(0_0, const_10)): invalid syntax (<string>, line 1)
Failed to stepify multiply(cosine(divide(multiply(60, divide(add(19, const_3), add(const_4, const_3))), multiply(const_60, const_3))), 19): Unknown function cosine
Failed to stepify subtract(p_after_gain(divide(25, 4), 5000), p_after_gain(4, 5000)): Unknown function p_after_gain
Failed to stepify divide(0.1, power(0.01, 4)): Cannot convert complex to float




Failed to stepify divide(add(divide(subtract(360, multiply(divide(8, const_100), 1_000)), subtract(divide(8, const_100), divide(8, const_100))), divide(subtract(360, multiply(divide(8, const_100), 1_000)), subtract(divide(8, const_100), divide(8, const_100)))), 1_000): invalid syntax (<string>, line 1)Failed to stepify add(2001, divide(add(divide(90, const_100), subtract(4.50, 4.20)), subtract(divide(30, const_100), subtract(4.50, 4.20)))): invalid syntax (<string>, line 1)

Failed to stepify multiply(8796, power(223, 8796)): invalid syntax (<string>, line 1)
Failed to stepify original_price_before_gain(20, 300): Unknown function original_price_before_gain
Failed to stepify original_price_before_gain(20, 290): Unknown function original_price_before_gain
Failed to stepify original_price_before_loss(35, 6500): Unknown function original_price_before_loss
Failed to stepify divide(subtract(divide(multiply(divide(8, const_100), 4), const_3), divide(4, const_100)), subtract(divide(8, const_10



Failed to stepify divide(const_4, add(0_0, const_10)): invalid syntax (<string>, line 1)
Failed to stepify original_price_before_loss(40, 420): Unknown function original_price_before_loss
Failed to stepify sqrt(add(power(sqrt(subtract(13, multiply(const_2, 2028))), const_2), multiply(const_4, 2028))): invalid syntax (<string>, line 1)
Failed to stepify rhombus_perimeter(sqrt(add(power(divide(72, const_2), const_2), power(divide(30, const_2), const_2)))): Unknown function rhombus_perimeter
Failed to stepify p_after_gain(14, 14): Unknown function p_after_gain
Failed to stepify multiply(17200, power(223, 17200)): invalid syntax (<string>, line 1)
Failed to stepify original_price_before_loss(50, 5000): Unknown function original_price_before_loss
Failed to stepify stream_speed(19, 9): Unknown function stream_speed
Failed to stepify divide(original_price_before_loss(20, 40), divide(original_price_before_gain(20, 60), 20)): Unknown function original_price_before_loss
Failed to stepify sqrt(ad



Failed to stepify subtract(negate(20), multiply(subtract(6, 12), divide(subtract(6, 12), subtract(0_2, 6)))): invalid syntax (<string>, line 1)
Failed to stepify sqrt(add(power(sqrt(subtract(28, multiply(const_2, 192))), const_2), multiply(const_4, 192))): invalid syntax (<string>, line 1)
Timeout!
Failed to stepify sqrt(add(power(sqrt(subtract(289, multiply(const_2, 468))), const_2), multiply(const_4, 468))): invalid syntax (<string>, line 1)
Failed to stepify stream_speed(11, 5): Unknown function stream_speed
Failed to stepify divide(59.32, subtract(2, floor(2))): Cannot convert complex to float




Failed to stepify multiply(divide(10, subtract(subtract(const_100, 90), 10)), const_100): invalid syntax (<string>, line 1)
Failed to stepify multiply(divide(add(divide(subtract(subtract(const_1000, 5), add(add(multiply(multiply(5, 5), const_10), multiply(5, 5)), 5)), 1_2), const_1), 0), add(subtract(const_1000, 5), add(add(multiply(multiply(5, 5), const_10), multiply(5, 5)), 5))): invalid syntax (<string>, line 1)
Failed to stepify sqrt(subtract(power(divide(32, 4), const_2), power(12, const_2))): Cannot convert complex to float
Timeout!


In [23]:
# Share of rows with no steps (formulas that failed to stepify), formatted as a percentage.
f"{df_all['steps'].isna().mean():%}"

'0.335148%'

In [24]:
def try_float(x):
    """Convert a value exposing ``evalf()`` into a plain Python float.

    Parameters
    ----------
    x : object
        Value to convert. It must provide an ``evalf()`` method whose result
        is accepted by ``float()``.

    Returns
    -------
    float
        The converted number, or NaN when ``x`` has no ``evalf()`` or the
        conversion raises (for example on a complex result).
    """
    try:
        return float(x.evalf())
    except Exception:
        # Best-effort conversion: any ordinary failure maps to NaN.
        # `Exception` (not a bare `except`) keeps KeyboardInterrupt and
        # SystemExit propagating, so a long `.apply(try_float)` can be stopped.
        return float("nan")

In [25]:
rtol = 0.05

# Number of rows whose computed result is close (np.isclose with relative
# tolerance `rtol`) to at least one answer option, correct or not.
matches_any_option = pd.Series([
    np.isclose(option_values, result_value, rtol=rtol, equal_nan=False).any()
    for result_value, option_values in zip(
        df_all["calculated_result_sympy"].apply(try_float),
        df_all["options_float"].apply(lambda opts: np.array(list(opts.values()), dtype=float)),
    )
]).sum()

# Per-row flag: the computed result is close to the option marked as correct.
df_all["matches_selected_option"] = np.isclose(
    df_all["calculated_result_sympy"].apply(try_float),
    df_all["correct_float"],
    rtol=rtol,
)

# Report both counts together with the share of rows that do not match.
print(
    "matches any choice",
    matches_any_option,
    "losst:",
    1 - (matches_any_option / len(df_all)),
)
print(
    "matches selected option",
    df_all["matches_selected_option"].sum(),
    "lost:",
    1 - (df_all["matches_selected_option"].sum() / len(df_all)),
)

matches any choice 27888 losst: 0.25227230072123763
matches selected option 25999 lost: 0.30291980588251066


In [26]:
import textwrap

# Print a fixed-seed sample of 20 rows whose computed result is not close
# (np.isclose with atol=0.1) to the option marked correct, for manual review.
for row in df_all[~np.isclose(df_all["calculated_result_sympy"].apply(try_float), df_all["correct_float"], atol=0.1)].sample(20, random_state=0).itertuples():
    print("INDEX:", row.index, "SPLIT:", row.split_name)
    print("PROMPT: ")
    print("    " + "\n    ".join(textwrap.wrap(row.Problem)))
    print("FORMULA:")
    print("    " + row.annotated_formula)
    print("LINEAR:")
    for step in row.linear_formula:
        print("   ", step)
    print("STEPS:")
    # `steps` is missing (None/NaN) for some rows; iterating such a value
    # would raise TypeError, so treat it as "no steps".
    steps = row.steps
    if steps is None or isinstance(steps, float):
        steps = []
    for step in steps:
        print("   ", step)
    print("RESULT: ", row.calculated_result_sympy)
    print("CORRECT:", row.correct_float)
    # Compare the result with every option first and reduce with .any()
    # afterwards. Reducing the option array first would compare the result
    # with a single boolean instead of with the option values.
    option_values = np.array(list(row.options_float.values()), dtype=float)
    print("MATCHES OTHER:", np.isclose(option_values, try_float(row.calculated_result_sympy)).any())
    print("OPTIONS:", row.options)
    print("OPTIONS FLOAT:", row.options_float)
    print()

INDEX: 27662 SPLIT: train
PROMPT: 
    10 + 45
FORMULA:
    multiply(divide(10, 45), const_100)
LINEAR:
    divide(n0,n1)
    multiply(#0,const_100)
STEPS:
    gadget_id='calculator' inputs='10 / 45' outputs='2/9 = around 0.222222'
    gadget_id='calculator' inputs='(2/9) * 100' outputs='200/9 = around 22.222222'
RESULT:  200/9
CORRECT: 55.0
MATCHES OTHER: False
OPTIONS: {'a': '8 ', 'b': '55 ', 'c': '87 ', 'd': '90', 'e': '2'}
OPTIONS FLOAT: {'a': 8.0, 'b': 55.0, 'c': 87.0, 'd': 90.0, 'e': 2.0}

INDEX: 492 SPLIT: train
PROMPT: 
    a train passes a station platform in 42 sec and a man standing on the
    platform in 12 sec . if the speed of the train is 54 km / hr . what is
    the length of the platform ?
FORMULA:
    multiply(12, multiply(54, divide(const_10, const_36)))
LINEAR:
    multiply(n2,divide(const_10, const_36))
    multiply(n1,#0)
STEPS:
    gadget_id='calculator' inputs='10 / 36' outputs='5/18 = around 0.277778'
    gadget_id='calculator' inputs='54 * (5/18)' outputs='15'

In [27]:
# Tally, over all rows, the function names that appear in `linear_formula`
# steps but are not listed in OPERATIONS.
(df_all["linear_formula"]
    .apply(lambda steps: tuple(sorted(
        fn_name
        for fn_name in (step.split("(")[0] for step in steps)
        if fn_name not in OPERATIONS
    )))
    .explode()
    .value_counts()
)

original_price_before_loss    19
original_price_before_gain    18
p_after_gain                  14
stream_speed                  10
rhombus_perimeter              3
tangent                        1
sine                           1
speed_in_still_water           1
cosine                         1
Name: linear_formula, dtype: int64

In [28]:
# Keep the dataset's original formulas next to the rewritten ones.
# `df_all_orig` is a separate concatenation of the same splits, taken before
# the constant replacement was applied to `df_all`.
df_all["annotated_formula_orig"] = df_all_orig["annotated_formula"]
df_all["linear_formula_orig"] = df_all_orig["linear_formula"]

In [29]:
# Raw text of the option marked as correct, looked up row by row.
df_all["correct_num_str"] = [
    options[correct_key]
    for options, correct_key in zip(df_all["options"], df_all["correct"])
]
# Float value of each computed result (NaN where the conversion fails).
df_all["calulated_result_float"] = df_all["calculated_result_sympy"].apply(try_float)
#removed[["correct_num_str", "correct_float", "calulated_result_float", "calculated_result_sympy"]].sample(30, random_state=0)

In [30]:
# Keep only the rows flagged as matching the option marked correct.
df_export = df_all.loc[df_all["matches_selected_option"]]
df_export.shape

(25999, 20)

In [31]:
# Display the filtered frame (the notebook shows a truncated preview).
df_export

Unnamed: 0,index,Problem,Rationale,options,correct,annotated_formula,linear_formula,category,split_name,options_num,options_float,correct_float,steps,calculated_result_sympy,chain,matches_selected_option,annotated_formula_orig,linear_formula_orig,correct_num_str,calulated_result_float
0,0,the banker ' s gain of a certain sum due 3 yea...,"""explanation : t = 3 years r = 10 % td = ( bg ...","{'a': 'rs . 400 ', 'b': 'rs . 300 ', 'c': 'rs ...",a,"divide(multiply(const_100, divide(multiply(36,...","[multiply(n2,const_100), multiply(n0,n1), divi...",gain,train,"{'a': '400', 'b': '300', 'c': '500', 'd': '350...","{'a': 400.0, 'b': 300.0, 'c': 500.0, 'd': 350....",400.0,[gadget_id='calculator' inputs='36 * 100' outp...,400,"<gadget id=""calculator"">36 * 100</gadget>\n<ou...",True,"divide(multiply(const_100, divide(multiply(36,...","multiply(n2,const_100)|multiply(n0,n1)|divide(...",rs . 400,400.0
2,2,sophia finished 2 / 3 of a book . she calculat...,let xx be the total number of pages in the boo...,"{'a': '229 ', 'b': '270 ', 'c': '877 ', 'd': '...",b,"divide(90, subtract(const_1, divide(2, 3)))","[divide(n0,n1), subtract(const_1,#0), divide(n...",general,train,"{'a': '229', 'b': '270', 'c': '877', 'd': '266...","{'a': 229.0, 'b': 270.0, 'c': 877.0, 'd': 266....",270.0,[gadget_id='calculator' inputs='2 / 3' outputs...,270,"<gadget id=""calculator"">2 / 3</gadget>\n<outpu...",True,"divide(90, subtract(const_1, divide(2, 3)))","divide(n0,n1)|subtract(const_1,#0)|divide(n2,#1)",270,270.0
3,3,120 is what percent of 50 ?,"""50 * x = 120 - - > x = 2.4 - - > 2.4 expresse...","{'a': '5 % ', 'b': '240 % ', 'c': '50 % ', 'd'...",b,"multiply(divide(120, 50), const_100)","[divide(n0,n1), multiply(#0,const_100)]",gain,train,"{'a': '5', 'b': '240', 'c': '50', 'd': '2', 'e...","{'a': 5.0, 'b': 240.0, 'c': 50.0, 'd': 2.0, 'e...",240.0,[gadget_id='calculator' inputs='120 / 50' outp...,240,"<gadget id=""calculator"">120 / 50</gadget>\n<ou...",True,"multiply(divide(120, 50), const_100)","divide(n0,n1)|multiply(#0,const_100)|",240 %,240.0
4,4,there are 10 girls and 20 boys in a classroom ...,"if girls is 10 and boys is 20 , then 10 / 20 ....","{'a': '1 / 2 ', 'b': '1 / 3 ', 'c': '1 / 5 ', ...",a,"divide(10, 20)","[divide(n0,n1)]",other,train,"{'a': '1/2', 'b': '1/3', 'c': '1/5', 'd': '10/...","{'a': 0.5, 'b': 0.3333333333333333, 'c': 0.2, ...",0.5,[gadget_id='calculator' inputs='10 / 20' outpu...,1/2,"<gadget id=""calculator"">10 / 20</gadget>\n<out...",True,"divide(10, 20)","divide(n0,n1)",1 / 2,0.5
5,5,an empty fuel tank with a capacity of 218 gall...,"""say there are a gallons of fuel a in the tank...","{'a': '122 ', 'b': '150 ', 'c': '100 ', 'd': '...",a,"divide(subtract(multiply(218, divide(16, const...","[divide(n2,const_100), divide(n1,const_100), m...",gain,train,"{'a': '122', 'b': '150', 'c': '100', 'd': '80'...","{'a': 122.0, 'b': 150.0, 'c': 100.0, 'd': 80.0...",122.0,[gadget_id='calculator' inputs='16 / 100' outp...,122,"<gadget id=""calculator"">16 / 100</gadget>\n<ou...",True,"divide(subtract(multiply(218, divide(16, const...","divide(n2,const_100)|divide(n1,const_100)|mult...",122,122.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
37291,2979,there are 690 male and female participants in ...,female = x male = 690 - x x / 2 + 690 - x / 4 ...,"{'a': '75 ', 'b': '100 ', 'c': '115 ', 'd': '1...",c,"divide(subtract(multiply(divide(690, const_3),...","[divide(n0,const_3), multiply(#0,const_4), sub...",general,test,"{'a': '75', 'b': '100', 'c': '115', 'd': '175'...","{'a': 75.0, 'b': 100.0, 'c': 115.0, 'd': 175.0...",115.0,[gadget_id='calculator' inputs='690 / 3' outpu...,115,"<gadget id=""calculator"">690 / 3</gadget>\n<out...",True,"divide(subtract(multiply(divide(690, const_3),...","divide(n0,const_3)|multiply(#0,const_4)|subtra...",115,115.0
37292,2980,find the area of a parallelogram with base 20 ...,"""area of a parallelogram = base * height = 20 ...","{'a': '100 cm 2 ', 'b': '250 cm 2 ', 'c': '800...",c,"multiply(20, 40)","[multiply(n0,n1)]",geometry,test,"{'a': '100', 'b': '250', 'c': '800', 'd': '296...","{'a': 100.0, 'b': 250.0, 'c': 800.0, 'd': 296....",800.0,[gadget_id='calculator' inputs='20 * 40' outpu...,800,"<gadget id=""calculator"">20 * 40</gadget>\n<out...",True,"multiply(20, 40)","multiply(n0,n1)|",800 cm 2,800.0
37294,2982,a can do a piece of work in 6 hours ; b and c ...,"""a ' s 1 hour work = 1 / 6 ; ( b + c ) ' s 1 h...","{'a': '8 hours ', 'b': '6 hours ', 'c': '14 ho...",d,"divide(const_1, subtract(divide(const_1, 4), s...","[divide(const_1,n1), divide(const_1,n2), divid...",physics,test,"{'a': '8', 'b': '6', 'c': '14', 'd': '12', 'e'...","{'a': 8.0, 'b': 6.0, 'c': 14.0, 'd': 12.0, 'e'...",12.0,[gadget_id='calculator' inputs='1 / 4' outputs...,12,"<gadget id=""calculator"">1 / 4</gadget>\n<outpu...",True,"divide(const_1, subtract(divide(const_1, 4), s...","divide(const_1,n1)|divide(const_1,n2)|divide(c...",12 hours,12.0
37295,2983,a train 250 m long running at 72 kmph crosses ...,"""d = 72 * 5 / 18 = 30 = 600 â € “ 250 = 350 m ...","{'a': '350 m ', 'b': '200 m ', 'c': '250 m ', ...",a,"subtract(multiply(30, multiply(72, divide(cons...","[multiply(n1,divide(const_10, const_36)), mult...",physics,test,"{'a': '350', 'b': '200', 'c': '250', 'd': '270...","{'a': 350.0, 'b': 200.0, 'c': 250.0, 'd': 270....",350.0,[gadget_id='calculator' inputs='10 / 36' outpu...,350,"<gadget id=""calculator"">10 / 36</gadget>\n<out...",True,"subtract(multiply(30, multiply(72, const_0_277...","multiply(n1,const_0_2778)|multiply(n2,#0)|subt...",350 m,350.0


In [52]:
def export_df(part):
    """Build the frame that gets written out for one split.

    Selects a fixed set of columns from ``part`` and renames them to the
    exported names. The formula columns are taken from the ``*_orig``
    columns, not from the rewritten ones.

    Parameters
    ----------
    part : pandas.DataFrame
        Rows to export; must contain every source column listed below.

    Returns
    -------
    pandas.DataFrame
        New frame with the exported columns, in the order listed below.
    """
    # exported column name -> source column in `part` (order is preserved)
    source_column = {
        "problem": "Problem",
        "rationale": "Rationale",
        "correct": "correct",
        "options": "options",
        "options_float": "options_float",
        "annotated_formula": "annotated_formula_orig",
        "linear_formula": "linear_formula_orig",
        "chain": "chain",
        "index": "index",
    }
    return pd.DataFrame({
        exported: part[source] for exported, source in source_column.items()
    })

In [54]:
import os

# Make sure the target directory exists before any split is written.
os.makedirs("../data/math_qa/processed", exist_ok=True)

# One JSON-lines file per split name, restricted to the rows in `df_export`.
for name in df_all["split_name"].unique():
    split = export_df(df_export.loc[df_export["split_name"] == name])
    split.to_json(
        f"../data/math_qa/processed/{name}.jsonl",
        orient="records",
        lines=True,
        force_ascii=False,
    )

In [45]:
# Number of exported rows per split.
df_export["split_name"].value_counts()

train         20868
validation     3102
test           2029
Name: split_name, dtype: int64

In [56]:
# Load the `MU-NLPC/Calc-math_qa` dataset through `datasets.load_dataset`.
# NOTE(review): presumably the published copy of the splits exported above,
# loaded here for comparison — confirm.
ds = datasets.load_dataset("MU-NLPC/Calc-math_qa")

Using custom data configuration MU-NLPC--Calc-math_qa-c908af51527daa56
Found cached dataset json (/var/tmp/xkadlci2/.cache/huggingface/datasets/MU-NLPC___json/MU-NLPC--Calc-math_qa-c908af51527daa56/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

In [58]:
# Inspect the first exported row, one field per line.
df_export.iloc[0]

index                                                                      0
Problem                    the banker ' s gain of a certain sum due 3 yea...
Rationale                  "explanation : t = 3 years r = 10 % td = ( bg ...
options                    {'a': 'rs . 400 ', 'b': 'rs . 300 ', 'c': 'rs ...
correct                                                                    a
annotated_formula          divide(multiply(const_100, divide(multiply(36,...
linear_formula             [multiply(n2,const_100), multiply(n0,n1), divi...
category                                                                gain
split_name                                                             train
options_num                {'a': '400', 'b': '300', 'c': '500', 'd': '350...
options_float              {'a': 400.0, 'b': 300.0, 'c': 500.0, 'd': 350....
correct_float                                                          400.0
steps                      [gadget_id='calculator' inputs='36 * 100' outp...