In [2]:
import pprint
import string
import itertools

import datasets
import pandas as pd
import numpy as np
import torch
import joblib
import scipy.sparse
import sklearn.feature_extraction


ds_names = [
    "MU-NLPC/Calc-gsm8k",
    "MU-NLPC/Calc-aqua_rat",
    "MU-NLPC/Calc-math_qa",
    "MU-NLPC/Calc-ape210k",
    "MU-NLPC/Calc-mawps",
    "MU-NLPC/Calc-svamp",
    "MU-NLPC/Calc-asdiv_a",
]

In [3]:
keep_symbols = set(string.ascii_lowercase.lower() + " ")
dss = {}
split_names = set()

for full_name in ds_names:
    ds = datasets.load_dataset(full_name)
    ds_name = full_name.split("/")[-1].lower()
    for split_name, split in ds.items():
        split_names.add(split_name)
        key = ds_name, split_name
        dss[key] = split.to_pandas()[["question", "chain", "result"]]
        dss[key]["question_simplified"] = (
            dss[key]["question"]
            .str.encode("ascii", errors="ignore")
            .str.decode("ascii")
            .str.lower()
            .str.split()
            .str.join(" ")
            .apply(lambda text: "".join([c for c in text if c in keep_symbols]))
            .str.split()
            .str.join(" ")
        )

In [4]:
bow_ngrams_vectorizer = sklearn.feature_extraction.text.CountVectorizer(binary=True, dtype=np.int32, ngram_range=(1, 2))

bow_ngrams_vectorizer.fit(
    itertools.chain.from_iterable(ds["question_simplified"] for ds in dss.values())
)

bows = {}

for key, ds in dss.items():
    bows[key] = bow_ngrams_vectorizer.transform(ds["question_simplified"])

In [5]:
def pairwise_jaccard_sim(bows_1: scipy.sparse.csr_matrix, bows_2: scipy.sparse.csr_matrix) -> np.ndarray:
    """
    Computes the Jaccard distance between each row of X matrix and each row of Y matrix.
    """
    sizes_of_1 = bows_1.getnnz(axis=1).astype(np.float32)
    sizes_of_2 = bows_2.getnnz(axis=1).astype(np.float32)
    intersect = (bows_1 @ bows_2.T).toarray().astype(np.float32)
    union = sizes_of_1.reshape(-1, 1) + sizes_of_2.reshape(1, -1) - intersect
    with np.errstate(divide='ignore', invalid='ignore'):
        result = intersect / union
        np.nan_to_num(result, nan=0, posinf=0, neginf=0, copy=False)
    return result


def get_highest_k_matches(scores: torch.Tensor, k: int):
    top_in_rows = torch.topk(k=k, dim=1, sorted=False, largest=True, input=scores)
    top_in_cols = torch.topk(k=k, dim=1, sorted=False, largest=True, input=scores.T)
    return top_in_rows, top_in_cols


def check_leak(bows_1, bows_2, top_k=10):
    scores = pairwise_jaccard_sim(bows_1, bows_2)
    return get_highest_k_matches(torch.tensor(scores), k=top_k)

In [6]:
check_leaks = []
for ds_name_1, ds_split_name_1 in dss.keys():
    for ds_name_2, ds_split_name_2 in dss.keys():
        if ds_split_name_1 == "train" and ds_split_name_2 != "train":
            check_leaks.append(((ds_name_1, ds_split_name_1), (ds_name_2, ds_split_name_2)))


pprint.pprint(check_leaks)
print(len(check_leaks))

[(('calc-gsm8k', 'train'), ('calc-gsm8k', 'test')),
 (('calc-gsm8k', 'train'), ('calc-aqua_rat', 'test')),
 (('calc-gsm8k', 'train'), ('calc-aqua_rat', 'validation')),
 (('calc-gsm8k', 'train'), ('calc-math_qa', 'test')),
 (('calc-gsm8k', 'train'), ('calc-math_qa', 'validation')),
 (('calc-gsm8k', 'train'), ('calc-ape210k', 'test')),
 (('calc-gsm8k', 'train'), ('calc-ape210k', 'validation')),
 (('calc-gsm8k', 'train'), ('calc-mawps', 'validation')),
 (('calc-gsm8k', 'train'), ('calc-mawps', 'test')),
 (('calc-gsm8k', 'train'), ('calc-svamp', 'test')),
 (('calc-gsm8k', 'train'), ('calc-asdiv_a', 'test')),
 (('calc-aqua_rat', 'train'), ('calc-gsm8k', 'test')),
 (('calc-aqua_rat', 'train'), ('calc-aqua_rat', 'test')),
 (('calc-aqua_rat', 'train'), ('calc-aqua_rat', 'validation')),
 (('calc-aqua_rat', 'train'), ('calc-math_qa', 'test')),
 (('calc-aqua_rat', 'train'), ('calc-math_qa', 'validation')),
 (('calc-aqua_rat', 'train'), ('calc-ape210k', 'test')),
 (('calc-aqua_rat', 'train'), ('ca

In [7]:
candidates = {}

with joblib.Parallel(n_jobs=-1) as parallel:
    jobs = (joblib.delayed(check_leak)(bows[ds_1], bows[ds_2]) for ds_1, ds_2 in check_leaks)
    results = parallel(jobs)
    for (ds_train, ds_eval), leak_candidates in zip(check_leaks, results):
        candidates[ds_train, ds_eval] = leak_candidates


In [17]:
threshold = 0.5
print_examples = False

for (ds_train, ds_eval), (train_sim, eval_sim) in candidates.items():
    if ds_eval[1] == "validation":
        continue
    is_mostly_formula_problem = (dss[ds_eval]["question_simplified"].apply(len) / dss[ds_eval]["question"].apply(len)) < 0.5
    # example of mostly_formula_problem is: Solve 2x + 3x^2 + 8/5 = 1295
    # on those examples, we don't want to check for similarity on words
    sus_mask = (eval_sim.values > threshold) # has shape (len_eval, top_k)
    sus_mask[is_mostly_formula_problem] = False
    suspicious_frac = sus_mask.any(dim=1).float().mean().item()
    if suspicious_frac > 0.05:
        print(f"{suspicious_frac:.2%} of {'/'.join(ds_eval):<30} examples appear similar to some examples in {'/'.join(ds_train)}")
        sus_mask_in_train = (train_sim.values > threshold).any(dim=1).float().mean().item()
        print(f"-> {sus_mask_in_train:.2%} of {'/'.join(ds_train):<27} examples would have to be dropped")
        print()
        if not print_examples:
            continue
        all_sus_eval_idxs, train_nth_similar = sus_mask.nonzero(as_tuple=True)
        sample = torch.randint(0, len(all_sus_eval_idxs), (10,))
        sampled_sus_eval_idxs = all_sus_eval_idxs[sample]
        sampled_train_nth_similar = train_nth_similar[sample]
        sampled_eval_questions = dss[ds_eval]["question"].iloc[sampled_sus_eval_idxs]
        sampled_train_questions = dss[ds_train]["question"].iloc[eval_sim.indices[sampled_sus_eval_idxs, sampled_train_nth_similar]]
        sampled_similarities = eval_sim.values[sampled_sus_eval_idxs, sampled_train_nth_similar]
        for eval_question, train_question, similarity in zip(sampled_eval_questions, sampled_train_questions, sampled_similarities):
            print("  eval: ", eval_question)
            print("  train:", train_question)
            print(f"  {similarity=:.2f}")
            print()

        print()
        print("-" * 100)


30.71% of calc-aqua_rat/test             examples appear similar to some examples in calc-aqua_rat/train
-> 1.36% of calc-aqua_rat/train         examples would have to be dropped

97.49% of calc-math_qa/test              examples appear similar to some examples in calc-aqua_rat/train
-> 24.56% of calc-aqua_rat/train         examples would have to be dropped

7.87% of calc-aqua_rat/test             examples appear similar to some examples in calc-math_qa/train
-> 0.99% of calc-math_qa/train          examples would have to be dropped

85.17% of calc-math_qa/test              examples appear similar to some examples in calc-math_qa/train
-> 42.85% of calc-math_qa/train          examples would have to be dropped

59.77% of calc-ape210k/test              examples appear similar to some examples in calc-ape210k/train
-> 20.94% of calc-ape210k/train          examples would have to be dropped

77.50% of calc-mawps/test                examples appear similar to some examples in calc-mawps/train

In [9]:
# Data leaks:
# aqua_rat train -> math_qa test + validation # math_qa is basically whole a subset of train aqua_rat
# math_qa train -> math_qa test + validation
# ape210k train -> ape210k test + validation
# mawps train -> mawps test + validation

# Fair evaluation for models trained on aquarat+ape210k+gsm8k+mathqa:
# - don't eval on mathqa at all -> remove completely from latex table
# - evaluation on gsm8k is ok
# - need to evaluate on svamp and mawps -> we don't need to filter anything
# - drop a lot of ape210k eval samples
# - drop some aqua_rat eval samples

In [10]:
dss["calc-ape210k", "test"]

Unnamed: 0,question,chain,result,question_simplified
0,Wang Yan's family bought a washing machine an...,"\n<gadget id=""calculator"">3 / 5</gadget>\n<out...",3_750,wang yans family bought a washing machine and ...
1,There are 5 baskets of apples with the same w...,"\n<gadget id=""calculator"">5 - 3</gadget>\n<out...",25,there are baskets of apples with the same weig...
2,"Aunt Wang types 60 words per minute, how many...","\n<gadget id=""calculator"">60 * 15</gadget>\n<o...",900,aunt wang types words per minute how many word...
3,"The number A is 42, the number B is (3/7) of ...","\n<gadget id=""calculator"">3 / 7</gadget>\n<out...",18,the number a is the number b is of the number ...
4,Uncle Zhang surrounded a semi-circular duck h...,"\n<gadget id=""calculator"">18.84 / 3.14</gadget...",56.52,uncle zhang surrounded a semicircular duck hou...
...,...,...,...,...
4862,The sum of (4/5) of a number and 25% of 200 i...,"\n<gadget id=""calculator"">25 / 100</gadget>\n<...",120,the sum of of a number and of is find this number
4863,"Each bookshelf has 3 layers, and each layer h...","\n<gadget id=""calculator"">50 * 3 * 4</gadget>\...",600,each bookshelf has layers and each layer holds...
4864,There are 6 children participating in a party...,"\n<gadget id=""calculator"">6 - 1</gadget>\n<out...",15,there are children participating in a party an...
4865,The ratio of the number of people who subscrib...,"\n<gadget id=""calculator"">20 / 4</gadget>\n<ou...",35,the ratio of the number of people who subscrib...


In [11]:
(candidates[("calc-ape210k", "train"), ("calc-ape210k", "test")][1].values < 0.50).all(dim=1).float().sum()

tensor(1785.)

In [12]:
ape_ok = (candidates[("calc-ape210k", "train"), ("calc-ape210k", "test")][1].values < 0.50).all(dim=1)

In [13]:
import pathlib

for file in pathlib.Path("../predictions/full_ape210k/").glob("*"):
    with open(file) as f:
        lines = f.readlines()
        assert len(lines) == len(ape_ok)

    with open(file.parent.parent / file.name, "w") as f:
        for line, ok in zip(lines, ape_ok):
            if ok:
                f.write(line)

In [14]:
datasets.load_dataset("mu-nlpc/calc-ape210k", split="test").to_pandas()[ape_ok.numpy()]

Unnamed: 0,id,question,question_chinese,chain,result,result_float,equation
0,971711,Wang Yan's family bought a washing machine an...,王艳家买了一台洗衣机和一台电冰箱，一共花了6000元，电冰箱的价钱是洗衣机的(3/5)，求洗...,"\n<gadget id=""calculator"">3 / 5</gadget>\n<out...",3_750,3750.00,x=6000/(1+(3/5))
6,636507,"There are 5 plum trees, each of which produce...",有李树5棵，每棵产李子60.8千克，桃树8棵，每棵产桃子47.5千克，收获哪种水果比较重？比...,"\n<gadget id=""calculator"">47.5 * 8</gadget>\n<...",76,76.00,x=(47.5*8)-(60.8*5)
7,298954,If the number of male students in a class is 6...,某班男生人数是全班人数的60%，那么这个班男生人数是女生人数的多少,"\n<gadget id=""calculator"">60 / 100</gadget>\n<...",3/2,1.50,x=60%/(1-60%)
8,241525,The distance between A and B is 120 kilometer...,甲、乙两地相距120千米，客车和货车同时从甲地出发驶向乙地，客车到达乙地后立即沿原路返回，在...,"\n<gadget id=""calculator"">120 / 3</gadget>\n<o...",80,80.00,x=(120+(120/3))/2
13,899977,"The length and width of a cuboid are 4 meters,...",一个长方体的长和宽都是4米，高是5米，如果底面积扩大5倍，要使体积不变，高应该是多少厘米．,"\n<gadget id=""calculator"">5 / 5</gadget>\n<out...",100,100.00,(5/5)*100
...,...,...,...,...,...,...,...
4853,432324,Put a row of flowers (both ends) on the path ...,在公园小路上放一排花(两端都放)．每两盘花之间的距离是4米，需要25盘花却少了8盘花．要把现...,"\n<gadget id=""calculator"">25 - 1</gadget>\n<ou...",6,6.00,x=4*(25-1)/(25-8-1)
4856,198804,"When a road repair team builds a road, the ra...",修路队修一条路，已经修的与总长的比是1：3，再修150米，则正好修完全长的50%．这条路全长...,"\n<gadget id=""calculator"">50 / 100</gadget>\n<...",900,900.00,x=150/(50%-(1/3))
4857,1018629,The price of a product first increases by 10%...,一件商品先涨价10%，再降价10%，这件商品的价格和原价比较，是涨了还是降了？变化幅度是多少？,"\n<gadget id=""calculator"">10 / 100</gadget>\n<...",1/100,0.01,x=1-((1+10%)*(1-10%))
4858,370131,"When Xiao Ming calculated △*1.6+0.8, he accide...",小明在计算△*1.6+0.8时，不小心算成了△*1.6-0.8这样两个得数之间相差多少？,"\n<gadget id=""calculator"">0.8 + 0.8</gadget>\n...",1.6,1.60,x=0.8+0.8


In [15]:
aqua_ok = (candidates[("calc-aqua_rat", "train"), ("calc-aqua_rat", "test")][1].values < 0.5).all(dim=1)

In [16]:
import pathlib

for file in pathlib.Path("../predictions/full_aqua_rat/").glob("*"):
    with open(file) as f:
        lines = f.readlines()
        assert len(lines) == len(aqua_ok)

    with open(file.parent.parent / file.name, "w") as f:
        for line, ok in zip(lines, aqua_ok):
            if ok:
                f.write(line)