In [13]:
import json
import os
import numpy as np 
from collections import defaultdict

In [16]:
# Input files: the balanced CLadder question set and the meta-model
# descriptions that each question points to through its `model_id`.
path = "../data/cladder-v1-q-balanced.json"
meta_path = "../data/cladder-v1-meta-models.json"

# JSON text is UTF-8, so the encoding is passed explicitly; otherwise open()
# falls back to the platform's locale encoding and can mis-decode non-ASCII text.
with open(path, "r", encoding="utf-8") as f:
    data = json.load(f)

with open(meta_path, "r", encoding="utf-8") as f:
    meta_data = json.load(f)


In [17]:
# Peek at the schema: print every field of the first meta-model record,
# then (after a separator line) every field of one question record.
first_model = meta_data[0]
for field in first_model:
    print (f"{field}: {first_model[field]}")

print (80 * '--')
for field, value in data[100].items():
    print (f"{field}: {value}")

model_id: 0
story_id: alarm
graph_id: mediation
spec_id: 0
spec: {'X': [0.6, 0.8], 'V2': {'X': -0.2}, 'Y': {'X': 0.3, 'V2': 0.4}}
seed: 101
builder: difficulty
difficulty: easy
equation_type: bernoulli
background: Imagine a self-contained, hypothetical world with only the following conditions, and without any unmentioned factors or causal relationships: Husband has a direct effect on wife and alarm clock. Wife has a direct effect on alarm clock.
variable_mapping: {'Xname': 'husband', 'X1': 'alarm set by husband', 'X0': 'alarm not set by husband', 'V2name': 'wife', 'V21': 'alarm set by wife', 'V20': 'alarm not set by wife', 'Yname': 'alarm clock', 'Y1': 'ringing alarm', 'Y0': 'silent alarm'}
structure: X->V2,X->Y,V2->Y
params: {'p(X)': 0.7887065011221108, 'p(V2 | X)': [0.7416866188819118, 0.2351932407152126], 'p(Y | X, V2)': [[0.08430222457648492, 0.5394610521458689], [0.4061509701126924, 0.8620283206949243]]}
groundtruth: {'ATE(Y | X)': 0.09148280511411633, 'ETT(Y | X)': 0.091482805114

Each data sample should have `question`, `steps`, and `answer` fields.

In [46]:
# Bucket the usable samples by their rung (1, 2 or 3).
rung_ds = defaultdict(list)

# model_id -> background text, taken from the meta-model records.
background_info = {model['model_id']: model['background'] for model in meta_data}

# NOTE(review): the original comment here said the last 3 steps should be
# different (i.e. use different variables); nothing in this cell checks
# that -- confirm whether a check is still wanted.
for sample in data:
    reasoning = sample['reasoning']
    if reasoning is None:
        # Samples without a reasoning chain are skipped.
        continue

    # Sanity checks on the fields used to build the example.
    assert sample['question'] is not None
    assert sample['given_info'] is not None
    assert sample['answer'] in ['yes', 'no']

    meta = sample['meta']
    rung = meta['rung']
    assert rung in (1, 2, 3)
    background = background_info[meta['model_id']]

    # Collect step0..step5 and append the 'end' text to the final step.
    chain = [reasoning[f'step{idx}'] for idx in range(6)]
    chain[-1] = chain[-1] + ', ' + reasoning['end']

    # The question text is: background, given info, then the question itself.
    example = {
        'question': f"{background}\n{sample['given_info']}\n{sample['question']}",
        'steps': chain,
        'answer': sample['answer'],
        'rung': rung,
    }
    rung_ds[rung].append(example)

    

Train on rungs 1 and 2; validate and test on rung 3 (the hardest rung).

In [47]:
# Split into training / validation / test:
#   rungs below the held-out rung -> train
#   held-out rung                 -> first `valid_size` samples for validation, rest for test

valid_size = 100
held_out_rung = 3

np.random.seed(42)

# Flatten the training rungs into a new list, in the same order as before
# (rung_ds insertion order) so the seeded shuffle gives the same result.
train_ds = [
    sample
    for rung, samples in rung_ds.items()
    if rung < held_out_rung
    for sample in samples
]
np.random.shuffle(train_ds)

# np.random.shuffle works in place. Shuffle a copy so rung_ds stays untouched
# and re-running this cell always reproduces the same split.
held_out_ds = list(rung_ds[held_out_rung])
np.random.shuffle(held_out_ds)
valid_ds = held_out_ds[:valid_size]
test_ds = held_out_ds[valid_size:]

# Every held-out sample must land in exactly one of the two splits.
assert len(valid_ds) + len(test_ds) == len(held_out_ds)

print (f"train size: {len(train_ds)}")
print (f"valid size: {len(valid_ds)}")
print (f"test size: {len(test_ds)}")

train size: 4740
valid size: 100
test size: 3692


In [54]:
# Output locations for the three CLadder splits.
train_ds_path = '../data/cladder_train_gen.json'
valid_ds_path = '../data/cladder_valid_gen.json'
test_ds_path = '../data/cladder_test_gen.json'

# Write each split to its own JSON file, in train / valid / test order.
for out_path, split in (
    (train_ds_path, train_ds),
    (valid_ds_path, valid_ds),
    (test_ds_path, test_ds),
):
    with open(out_path, "w") as f:
        json.dump(split, f)

# Look at GSM8K: keep only examples with a fixed number of steps and see whether that helps performance

In [8]:
# Locations of the GSM train / validation / test splits.
gsm_train_path = '../data/gsm_train.json'
gsm_valid_path = '../data/gsm_valid.json'
gsm_test_path = '../data/gsm_test.json'

# JSON text is UTF-8, so the encoding is passed explicitly; otherwise open()
# falls back to the platform's locale encoding and can mis-decode non-ASCII text.
with open(gsm_train_path, "r", encoding="utf-8") as f:
    gsm_train = json.load(f)
with open(gsm_valid_path, "r", encoding="utf-8") as f:
    gsm_valid = json.load(f)
with open(gsm_test_path, "r", encoding="utf-8") as f:
    gsm_test = json.load(f)

In [None]:

# Histogram of "number of steps per example" for each GSM split.
train_steps = defaultdict(int)
val_steps = defaultdict(int)
test_steps = defaultdict(int)

for split, histogram in (
    (gsm_train, train_steps),
    (gsm_valid, val_steps),
    (gsm_test, test_steps),
):
    for example in split:
        histogram[len(example['steps'])] += 1

# Print (step_count, frequency) pairs, most frequent first.
for histogram in (train_steps, val_steps, test_steps):
    print(sorted(histogram.items(), key=lambda item: item[1], reverse=True))

[(2, 143578), (3, 104249), (1, 62908), (4, 48198), (5, 17906), (6, 5666), (7, 2359), (8, 577), (9, 126), (10, 43), (11, 8), (12, 1), (13, 1)]
[(2, 155), (3, 140), (4, 91), (5, 53), (1, 32), (6, 18), (7, 8), (8, 3)]
[(3, 364), (2, 357), (4, 290), (5, 138), (1, 83), (6, 57), (7, 21), (8, 9)]


In [12]:
# Keep only the examples that have exactly `steps` steps.
steps = 3
new_train_ds = [example for example in gsm_train if len(example['steps']) == steps]
new_val_ds = [example for example in gsm_valid if len(example['steps']) == steps]
new_test_ds = [example for example in gsm_test if len(example['steps']) == steps]


# Build the file names from `steps` so that changing the filter above can never
# overwrite the 3-step files with data of a different step count.
new_train_ds_path = f'../data/gsm_train_{steps}.json'
new_val_ds_path = f'../data/gsm_valid_{steps}.json'
new_test_ds_path = f'../data/gsm_test_{steps}.json'
with open(new_train_ds_path, "w") as f:
    json.dump(new_train_ds, f)
with open(new_val_ds_path, "w") as f:
    json.dump(new_val_ds, f)
with open(new_test_ds_path, "w") as f:
    json.dump(new_test_ds, f)

