In [1]:
import warnings
from typing import Iterable

import datasets
import pandas as pd
import numpy as np
import sympy
import lark
import pandarallel
from tqdm.auto import tqdm

import gadgets

tqdm.pandas()
pandarallel.pandarallel.initialize(progress_bar=True)

INFO: Pandarallel will run on 40 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
df = {
    "train": pd.read_json("../data/ape210k/my_train.ape.jsonl", lines=True),
    "valid": pd.read_json("../data/ape210k/my_valid.ape.jsonl", lines=True),
    "test": pd.read_json("../data/ape210k/my_tests.ape.jsonl", lines=True),
}

In [3]:
grammar = """
?start: expr

?expr: neg

?atom: num
    | implicit_mul
    | "(" expr ")"

implicit_mul: num ( "(" expr ")" )+
            | "(" expr ")" ( "(" expr ")" )+

?neg: add
    | "-" neg -> neg
    | "-" add -> neg
?add: sub
    | sub ("+" sub)+ -> add
?sub: mul
    | mul ("-" mul)+ -> sub
?mul: div
    | div ("*" div)+ -> mul
?div: pow
    | pow ("/" pow)+ -> div
?pow: perc
    | perc ("**" perc)+ -> pow
?perc: atom "%" -> perc
     | atom
?num: SIGNED_NUMBER

%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""



In [30]:
class TreeEvaluator:

    def __init__(self, calc: gadgets.gadget.Calculator, parser: lark.Lark) -> None:
        self.cache = {}
        self.calc = calc
        self.parser = parser

    def eval_tree(self, tree: lark.Tree | lark.Token) -> tuple[str, sympy.Expr]:
        if tree not in self.cache:
            self.cache[tree] = self._eval_tree(tree)
        return self.cache[tree]

    def _eval_tree(self, tree: lark.Tree | lark.Token) -> tuple[str, sympy.Expr]:
        if isinstance(tree, lark.Token):
            if tree.type in ("SIGNED_NUMBER", "NUMBER"):
                return None, self.calc.evaluate(tree.value)
            else:
                raise ValueError(f"unknown token {tree}")
        
        assert isinstance(tree.data, str)
        operation = tree.data
        args_nodes = tree.children
        args = [self._format_arg(self.eval_tree(arg_node)[1]) for arg_node in args_nodes]
        inputs = self._format_op(operation, args)
        return inputs, self.calc.evaluate(inputs)
    
    def _format_op(self, op: str, args: list[str]) -> str:
        if op == "neg":
            assert len(args) == 1
            return "-" + args[0]
        if op == "add" or op == "implicit_add":
            return " + ".join(args)  
        if op == "sub":
            return " - ".join(args)
        if op == "mul" or op == "implicit_mul":
            return " * ".join(args)
        if op == "div":
            return " / ".join(args)
        if op == "pow":
            return " ** ".join(args)
        if op == "perc":
            assert len(args) == 1
            return f"{args[0]} / 100"
        raise ValueError(f"unknown operation {op}")

    def _format_arg(self, value_expr: sympy.Number) -> str:
        value_str = self.calc.format_sympy_number(value_expr, add_approx=False)
        if isinstance(value_expr, sympy.core.function.Application):
            return value_str
        if isinstance(value_expr, (sympy.Float, sympy.Integer, sympy.NumberSymbol)):
            if value_expr < 0:
                return "(" + value_str + ")"
            return value_str
        if isinstance(value_expr, sympy.Rational):
            return "(" + value_str + ")"
        if isinstance(value_expr, (sympy.Mul, sympy.Pow, sympy.Add)):
            return "(" + value_str + ")"
        warnings.warn(f"weird value type {type(value_expr)} for {value_expr} (string: '{value_str}')")
        return "(" + value_str + ")"

    def dfs(self, tree: lark.Tree | lark.Token) -> Iterable[lark.Tree | lark.Token]:
        if isinstance(tree, lark.Tree):
            for child in tree.children:
                yield from self.dfs(child)
        yield tree

    def expr_to_steps(self, expr: str, drop_repeated: bool = True) -> tuple[list[gadgets.datatypes.Interaction], sympy.Expr]:
        tree = self.parser.parse(expr)
        steps = []
        for subtree in self.dfs(tree):
            step = self.eval_tree(subtree)
            if step is None:
                continue
            if step[0] is None:
                continue
            inputs, output_expr = step
            interaction = gadgets.datatypes.Interaction(
                gadget_id="calculator",
                inputs=inputs,
                outputs=self.calc.format_sympy_number(output_expr),
            )
            if drop_repeated and interaction in steps:
                continue
            steps.append(interaction)
            
        _, result = self.eval_tree(tree)
        return steps, result
    

In [31]:
calc = gadgets.gadget.Calculator()
parser = lark.Lark(grammar)

TreeEvaluator(calc, parser).expr_to_steps("1-(-1/-5)-(1/4000)%")

([Interaction(gadget_id='calculator', inputs='(-1) / (-5)', outputs='1/5 = around 0.2'),
  Interaction(gadget_id='calculator', inputs='1 / 4_000', outputs='1/4_000 = around 0.00025'),
  Interaction(gadget_id='calculator', inputs='(1/4_000) / 100', outputs='1/400_000 = around 0.000002'),
  Interaction(gadget_id='calculator', inputs='1 - (1/5) - (1/400_000)', outputs='319_999/400_000 = around 0.799998')],
 319999/400000)

In [32]:
df["valid"].columns

Index(['id', 'question_chinese', 'question_english_mt', 'equation', 'result',
       'chain_list', 'my_result', 'chain_markup', 'result_eval', 'result_eq'],
      dtype='object')

In [33]:
def try_stepify(expr: str) -> tuple[list[gadgets.datatypes.Interaction], sympy.Expr] | tuple[None, None]:
    evaluator = TreeEvaluator(calc, parser)
    expr = expr.removeprefix("x=").replace(":", "/").lstrip("+")
    try:
        chain, result = evaluator.expr_to_steps(expr)
        result_str = calc.format_sympy_number(result)
        return chain, result, str(gadgets.markup.to_model_markup(chain=chain, result=result_str))
    except Exception:
        print(f"failed to parse {expr}")
        return None, None, None

def try_result_eval(expr: str) -> sympy.Expr:
    evaluator = TreeEvaluator(calc, parser)
    expr = expr.replace(":", "/").lstrip("+")
    try:
        inputs, output = evaluator.expr_to_steps(expr)
        return output
    except Exception:
        print(f"failed to parse {expr}")
        return None


In [34]:
for split in df.keys():
    df[split]["chain_list"], df[split]["my_result"], df[split]["chain_markup"] = zip(*df[split]["equation"].parallel_apply(try_stepify))

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=5013), Label(value='0 / 5013'))), …

failed to parse 1-(1/+(2048/1))
failed to parse (-(9/-((9/1)-(10/1)))+(10/-((9/1)-(10/1))))-((-(9/-((9/1)-(10/1)))+(10/-((9/1)-(10/1))))/2)+((-(9/-((9/1)-(10/1)))+(10/-((9/1)-(10/1))))/10)
failed to parse 1*(2004/(+(3-2)+1))+1
failed to parse 6.3+6.3*(3/+(2/1))
failed to parse 10.75/(1+(+(3/1)/4)+(1/2))
failed to parse 1003+((+((2006**2+1)/(2006**2-1))+(2/(2005*2007)))-((+((2006**2+1)/(2006**2-1))+(2/(2005*2007)))/2007))
failed to parse (((+(2005/2005)+2005)*2005/2)/2005)
failed to parse (1+(10/1))*+(10/1)/2
failed to parse (((+(1990/1990)+1990)*1990/2)/1990)
failed to parse 1+(1/2)-(1/+(128/1))
failed to parse 102-101+100/+(4+3-2-1)*+(4+3-2-1)
failed to parse 1-(1/+(256/1))
failed to parse (16/(4-(16/10)*2))*((16/10)++(4-(16/10)*2))
failed to parse 2*+(999-1)*((+(999-1)+999-(999-1))/2)
failed to parse (1/2)-(1/+(512/1))
failed to parse (1+(1991/1))*+(1991/1)/2
failed to parse (39*(22*+(+((+(+((+(+((+(+((3/1)/1)/1)/1)/1)/1)/1)/1)/1)/1)/1)/1)-(22+25+34+39)/2)/2-(22+25+34+39)/2)/2
failed

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=125), Label(value='0 / 125'))), HB…

failed to parse 1.8/(1.2/(60*100000))/100000
failed to parse (1007-1)*+(1007-1)


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=125), Label(value='0 / 125'))), HB…

failed to parse (+(100/100)+(5050/100))*100/2
failed to parse 1-(1/+(512/1))
failed to parse (2/+(3/1))
failed to parse (2950+25-+(1998-1997))*25-(2949+25-+(1998-1997))*25


In [35]:
for split in df.keys():
    df[split]["result_eval"] = df[split]["result"].parallel_apply(try_result_eval)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=5013), Label(value='0 / 5013'))), …

failed to parse 25%%
failed to parse 100%%


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=125), Label(value='0 / 125'))), HB…

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=125), Label(value='0 / 125'))), HB…

In [36]:
import math
def eq(x, y, tol=1e-5):
    if x is None or y is None:
        return False
    return math.isclose(x.evalf(), y.evalf(), abs_tol=tol)

for name, split in reversed(df.items()):
    df[name]["result_eq"] = [eq(x, y) for x, y in zip(split["my_result"], split["result_eval"])]
    print(name, df[name]["result_eq"].mean())

test 0.9746
valid 0.9742
train 0.9746269103387734


In [55]:
df["train"].columns

Index(['id', 'question_chinese', 'question_english_mt', 'equation', 'result',
       'chain_list', 'my_result', 'chain_markup', 'result_eval', 'result_eq'],
      dtype='object')

In [94]:
def export_split(split: pd.DataFrame):
    orig_size = len(split)
    # drop unparseable
    split = split.dropna()
    # drop with inconsistent results
    split = split[split["result_eq"]]
    # drop with ambiguous result caused by implicit multiplication vs compound fraction
    split = split[~split["result"].str.contains("\d\(", regex=True)]
    print(len(split) / orig_size)
    return pd.DataFrame({
        "id": split["id"],
        "question_chinese": split["question_chinese"],
        "question_english_mt": split["question_english_mt"],
        "equation": split["equation"],
        "result_orig_format": split["result"],
        "result_new_format": split["my_result"].apply(lambda x: calc.format_sympy_number(x, add_approx=False)),
        "result_float": split["my_result"].apply(lambda x: float(x.evalf())),
        "chain": split["chain_markup"],
    })

In [95]:
for split in df.keys():
    export_split(df[split]).to_json(f"../data/ape210k/processed/{split}.jsonl", lines=True, orient="records", force_ascii=False)

0.9735196121463628
0.9734
0.9734


In [97]:
ds = datasets.load_dataset(
    "json",
    data_files={
        "train": "../data/ape210k/processed/train.jsonl",
        "validation": "../data/ape210k/processed/valid.jsonl",
        "test": "../data/ape210k/processed/test.jsonl",
    }
)

Using custom data configuration default-4fbe7afbe04d68b2


Downloading and preparing dataset json/default to /var/tmp/xkadlci2/.cache/huggingface/datasets/json/default-4fbe7afbe04d68b2/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /var/tmp/xkadlci2/.cache/huggingface/datasets/json/default-4fbe7afbe04d68b2/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [119]:
ds = datasets.load_dataset("MU-NLPC/Calc-ape210k")

Downloading readme:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

Using custom data configuration MU-NLPC--Calc-ape210k-9d633ba24ff13754


Downloading and preparing dataset json/MU-NLPC--Calc-ape210k to /var/tmp/xkadlci2/.cache/huggingface/datasets/MU-NLPC___json/MU-NLPC--Calc-ape210k-9d633ba24ff13754/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/137M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.41M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.42M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /var/tmp/xkadlci2/.cache/huggingface/datasets/MU-NLPC___json/MU-NLPC--Calc-ape210k-9d633ba24ff13754/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [120]:
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'question_chinese', 'question_english_mt', 'equation', 'result_orig_format', 'result_new_format', 'result_float', 'chain'],
        num_rows: 195179
    })
    test: Dataset({
        features: ['id', 'question_chinese', 'question_english_mt', 'equation', 'result_orig_format', 'result_new_format', 'result_float', 'chain'],
        num_rows: 4867
    })
    validation: Dataset({
        features: ['id', 'question_chinese', 'question_english_mt', 'equation', 'result_orig_format', 'result_new_format', 'result_float', 'chain'],
        num_rows: 4867
    })
})