### Benchmark Creation
- logic operators: ==, !=, or, and
- arity of 3: a, b, c unique boolean values

In [None]:
from collections import OrderedDict
import random
import pickle
import pandas as pd
import re
from tqdm import tqdm
import copy
import numpy as np
import torch

class CFG(OrderedDict):
    """Ordered mapping from a grammar symbol to its raw production string.

    Every constructor argument is one rule written as ``'LHS -> alt | alt'``.
    All spaces are stripped before the rule is split on ``'->'``, so the
    stored value is the right-hand side with its alternatives still joined
    by ``'|'`` (e.g. ``'A -> a | b'`` is stored as ``{'A': 'a|b'}``).
    """

    def __init__(self, *args):
        rules = (rule.replace(' ', '').split('->') for rule in args)
        super().__init__(rules)

    def __repr__(self):
        # One "LHS -> RHS" line per rule, in insertion order.
        lines = [f'{lhs} -> {rhs}' for lhs, rhs in self.items()]
        return '\n'.join(lines)

    def getProductions(self, symbol):
        """Return the alternatives of ``symbol`` as a list of strings."""
        alternatives = self[symbol]
        return alternatives.split('|')

# Recursive depth-first expansion of the grammar; one random alternative is
# picked for every symbol that gets expanded.
def generateSentence(cfg, start='S'):
    """Generate one random sentence from ``cfg`` and return it space-joined.

    ``cfg`` must support ``symbol in cfg`` and ``cfg.getProductions(symbol)``.
    A production is scanned character by character: characters that are keys
    of ``cfg`` are expanded recursively, every other character is treated as
    terminal text of the symbol currently being expanded.

    The terminal text of a symbol is emitted as a single token by its parent
    once the symbol has been fully expanded. Consequently terminal characters
    written directly in the ``start`` production are never emitted, because
    the value returned by the outermost expansion is discarded.
    """
    tokens = []

    def expand(symbol):
        terminals = []
        production = random.choice(cfg.getProductions(symbol))
        for ch in production:
            if ch not in cfg:
                terminals.append(ch)
                continue
            emitted = expand(ch)
            if emitted:
                tokens.append(emitted)
        return ''.join(terminals)

    expand(start)
    return ' '.join(tokens)

# Grammar of the benchmark programs (adapted from an example CFG found online):
# two parenthesised (in)equality clauses over a, b, c joined by "and" / "or".
L = [
    'S -> CLAUSES',
    'CLAUSES -> CLAUSE CONJ CLAUSE',
    'CLAUSE -> LPR VAR EQ VAR RPR',
    'CONJ -> "and" | "or"',
    'EQ -> "==" | "!="',
    'VAR -> "a" | "b" | "c"',
    'LPR -> "("',
    'RPR -> ")"',
]

# Single-character aliases for the multi-letter symbols. generateSentence
# scans productions one character at a time, so every symbol has to be one
# character long. Order matters for plain str.replace substitution: CLAUSES
# has to be rewritten before its prefix CLAUSE.
table = OrderedDict(
    CLAUSES='A',
    CLAUSE='B',
    CONJ='C',
    EQ='D',
    VAR='E',
    LPR='F',
    RPR='G',
)

# Matches exactly one clause "( X == Y )" or "( X != Y )" (whitespace is
# optional) and captures the two operand names X and Y.
conj_re = re.compile(r"""
    ^
    \s*
    \(
    \s*(\w+?)\s*(?:==|!=)\s*(\w+?)\s*
    \)
    \s*$""", flags=re.VERBOSE)

In [None]:
SKIP = True
if not SKIP:
    for i in range(len(L)):
        L[i] = L[i].replace('\"', '')
        for key in table:
            L[i] = L[i].replace(key, table[key])

    cfg = CFG(*L)
    primitive_clauses = set([])
    equivalent_clauses = set([])
    for _ in range(10000): # 10000 will get you all the programs.
        clause = generateSentence(cfg)
        if "a == a" in clause or \
            "a != a" in clause or \
            "b == b" in clause or \
            "b != b" in clause or \
            "c == c" in clause or \
            "c != c" in clause:
            continue

        if "and" in clause and \
            clause.split(" and ")[0] == clause.split(" and ")[1]:
            continue
        if "or" in clause and \
            clause.split(" or ")[0] == clause.split(" or ")[1]:
            continue

        if "and" in clause:
            variables = set([])
            for c in clause.split(" and "):
                if "==" in c:
                    for t in c.strip("() ").split(" == "):
                        variables.add(t)
                elif "!=" in c:
                    for t in c.strip("() ").split(" != "):
                        variables.add(t)
            if len(variables) < 3:
                continue

            left = sorted(clause.split(" and ")[0])
            right = sorted(clause.split(" and ")[1])

            if ("and", tuple(left), tuple(right)) in equivalent_clauses or \
                ("and", tuple(right), tuple(left)) in equivalent_clauses:
                continue
            else:
                equivalent_clauses.add(("and", tuple(left), tuple(right)))
                equivalent_clauses.add(("and", tuple(right), tuple(left)))

        if "or" in clause:
            variables = set([])
            for c in clause.split(" or "):
                if "==" in c:
                    for t in c.strip("() ").split(" == "):
                        variables.add(t)
                elif "!=" in c:
                    for t in c.strip("() ").split(" != "):
                        variables.add(t)
            if len(variables) < 3:
                continue

            if ("or", tuple(left), tuple(right)) in equivalent_clauses or \
                ("or", tuple(right), tuple(left)) in equivalent_clauses:
                continue
            else:
                equivalent_clauses.add(("or", tuple(left), tuple(right)))
                equivalent_clauses.add(("or", tuple(right), tuple(left)))

        primitive_clauses.add(clause)


    primitive_clauses = list(primitive_clauses)
    random.shuffle(primitive_clauses)
    training_clauses = primitive_clauses[:20]
    eval_clauses = primitive_clauses[20:]
    pickle.dump(primitive_clauses, open("./cfg_all.pkl", 'wb'))
    pickle.dump(training_clauses, open("./cfg_train.pkl", 'wb'))
    pickle.dump(eval_clauses, open("./cfg_test.pkl", 'wb'))

In [None]:
def parse(clauses):
    """Break a program string into a list of per-clause descriptions.

    The string is split on its ``and`` / ``or`` connective. Every piece that
    matches ``conj_re`` yields a dict with the left operand (``"L"``), the
    right operand (``"R"``) and the comparison operator (``"EQ"``, either
    ``"=="`` or ``"!="``). Pieces that do not match are skipped.
    """
    parsed = []
    for piece in re.split(r"\s*(?:and|or)\s*", clauses):
        match = conj_re.search(piece)
        if not match:
            continue
        lhs, rhs = match.groups()
        operator = "==" if "==" in piece else "!="
        parsed.append({
            "L": lhs,
            "R": rhs,
            "EQ": operator,
        })
    return parsed

def sample_demonstration_for_clauses(
    clauses,
    vocab,
    final_value=None
):
    """Sample tokens for the variables of a program so it evaluates as asked.

    Args:
        clauses: program string such as ``"( a == b ) and ( b != c )"``, or a
            single clause without connective.
        vocab: set of candidate tokens. Every freshly drawn token is removed
            from it (in place when ``vocab`` is a ``set``), so tokens are not
            reused by later draws that share the same set.
        final_value: truth value the program must evaluate to. When ``None``
            a value is drawn uniformly from ``{True, False}``.

    Returns:
        dict mapping each variable name that occurs in ``clauses`` to a token.

    Raises:
        AssertionError: if the sampled assignment does not realise the
            per-clause truth values or the requested ``final_value``.
    """
    if final_value is None:
        final_value = random.choice([True, False])

    # Step 1: choose a truth value ("VAL") for every clause so that the
    # connective yields final_value.
    data = parse(clauses)
    if "and" in clauses:
        if final_value:
            data[0]["VAL"] = True
            data[1]["VAL"] = True
        else:
            # Exactly one of the two clauses is made false.
            data[0]["VAL"] = random.random() >= 0.5
            data[1]["VAL"] = not data[0]["VAL"]
    elif "or" in clauses:
        if final_value:
            data[0]["VAL"] = random.random() >= 0.5
            # If the first clause is already true the second one is free.
            data[1]["VAL"] = random.choice([True, False]) if data[0]["VAL"] else True
        else:
            data[0]["VAL"] = False
            data[1]["VAL"] = False
    else:
        data[0]["VAL"] = final_value

    # Step 2: assign tokens clause by clause. "x == y" being true and
    # "x != y" being false both mean the two operands must share a token.
    value_assignment = {}
    for d in data:
        left, right = d['L'], d['R']
        operands_equal = (d["EQ"] == "==") == d["VAL"]
        if operands_equal:
            if left in value_assignment:
                value_assignment[right] = value_assignment[left]
            elif right in value_assignment:
                value_assignment[left] = value_assignment[right]
            else:
                token = random.choice(list(vocab))
                vocab -= {token}
                value_assignment[left] = token
                value_assignment[right] = token
        else:
            # Only operands without a token get a fresh one. Used tokens have
            # been removed from vocab, so a fresh token differs from all
            # tokens assigned so far.
            if left in value_assignment:
                unassigned = (right,)
            elif right in value_assignment:
                unassigned = (left,)
            else:
                unassigned = (left, right)
            for name in unassigned:
                token = random.choice(list(vocab))
                vocab -= {token}
                value_assignment[name] = token

    # Step 3: sanity checks on the result.
    for d in data:
        tokens_match = value_assignment[d['L']] == value_assignment[d['R']]
        assert tokens_match == ((d["EQ"] == "==") == d["VAL"])

    if "and" in clauses:
        assert final_value == (data[0]["VAL"] and data[1]["VAL"])
    elif "or" in clauses:
        assert final_value == (data[0]["VAL"] or data[1]["VAL"])

    return value_assignment

def sample_demonstrations_for_clauses(
    clauses,
    vocab,
    final_values,
):
    """Sample one demonstration of ``clauses`` per entry of ``final_values``.

    The same ``vocab`` object is handed to every call of
    ``sample_demonstration_for_clauses``. Each returned assignment dict is
    extended with the program text under ``'clause'`` and the requested truth
    value under ``'output'``.

    Returns:
        list of demonstration dicts, in the order of ``final_values``.
    """
    demos = []
    for target in final_values:
        assignment = sample_demonstration_for_clauses(clauses, vocab, target)
        assignment['clause'] = clauses
        assignment['output'] = target
        demos.append(assignment)
    return demos

def set_seed(seed: int):
    """Seed the ``random``, NumPy and PyTorch random number generators."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

In [None]:
seed = 42
set_seed(seed)

n_training_examples = 20000
n_eval_examples = 1000
n_test_examples = 1000
n_training_program = 15
n_fewshot = 6

#################
#
# DO NOT CHANGE
#
#################
FALSE_TOKEN_ID = 0
TRUE_TOKEN_ID = 1
INPUT_PREFIX_TOKEN_ID = 2
OUTPUT_PREFIX_TOKEN_ID = 3
SEPARATOR_TOKEN_ID = 4
PADDING_TOKEN_ID = 5
BOS_TOKEN_ID = 6
EOS_TOKEN_ID = 7

vocab = set(range(10, 50257))  # reserve the first 10 for special tokens.
n_examples = n_fewshot + 1
# NOTE: pickle executes code on load; only read files written by this notebook.
with open("./cfg_train.pkl", 'rb') as fh:
    training_clauses = pickle.load(fh)
with open("./cfg_test.pkl", 'rb') as fh:
    eval_clauses = pickle.load(fh)
if n_training_program is not None:
    training_clauses = random.sample(training_clauses, k=n_training_program)


def _encode_demonstrations(demonstrations):
    """Flatten demonstrations into parallel (input_ids, output_ids) lists.

    Each demonstration becomes
    ``[INPUT_PREFIX, a, b, c, OUTPUT_PREFIX, answer, SEPARATOR]``. In
    ``output_ids`` every position of a demonstration except the answer is
    set to -100.
    """
    input_ids = [BOS_TOKEN_ID]
    output_ids = [BOS_TOKEN_ID]
    for d in demonstrations:
        output = TRUE_TOKEN_ID if d['output'] else FALSE_TOKEN_ID
        input_ids += [
            INPUT_PREFIX_TOKEN_ID, d['a'], d['b'], d['c'],
            OUTPUT_PREFIX_TOKEN_ID, output, SEPARATOR_TOKEN_ID,
        ]
        output_ids += [-100, -100, -100, -100, -100, output, -100]
        assert len(input_ids) == len(output_ids)
    input_ids.append(EOS_TOKEN_ID)
    output_ids.append(EOS_TOKEN_ID)
    return input_ids, output_ids


def _build_split(clause_pool, n_samples):
    """Generate ``n_samples`` examples whose program is drawn from ``clause_pool``.

    Returns three parallel lists: input ids, output ids and program strings.
    """
    all_input_ids = []
    all_output_ids = []
    all_clauses = []
    for _ in tqdm(range(n_samples)):
        # Keep this order of random draws (program, target values, tokens)
        # so a given seed keeps producing the same data set.
        clauses = random.choice(clause_pool)
        example_vocab = copy.deepcopy(vocab)  # sampling consumes tokens
        final_values = [random.choice([True, False]) for _ in range(n_examples)]
        demonstrations = sample_demonstrations_for_clauses(
            clauses,
            example_vocab,
            final_values=final_values,
        )
        input_ids, output_ids = _encode_demonstrations(demonstrations)
        all_input_ids.append(input_ids)
        all_output_ids.append(output_ids)
        all_clauses.append(clauses)
    return all_input_ids, all_output_ids, all_clauses


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` and close the file afterwards."""
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)


all_train_input_ids, all_train_output_ids, all_train_clauses = _build_split(
    training_clauses, n_training_examples
)

# The dev split reuses the *training* programs; only the test split below is
# built from the held-out programs.
all_eval_input_ids, all_eval_output_ids, all_eval_clauses = _build_split(
    training_clauses, n_eval_examples
)

train_data = {
    "input_ids" : all_train_input_ids,
    "output_ids" : all_train_output_ids,
    "clauses" : all_train_clauses,
}
dev_data = {
    "input_ids" : all_eval_input_ids,
    "output_ids" : all_eval_output_ids,
    "clauses" : all_eval_clauses,
}
_dump_pickle(train_data, f"./train_data.n_rule.{n_training_program}.n_shot.{n_fewshot}.pkl")
_dump_pickle(dev_data, f"./dev_data.n_rule.{n_training_program}.n_shot.{n_fewshot}.pkl")

all_test_input_ids, all_test_output_ids, all_test_clauses = _build_split(
    eval_clauses, n_test_examples
)

test_data = {
    "input_ids" : all_test_input_ids,
    "output_ids" : all_test_output_ids,
    "clauses" : all_test_clauses,
}
_dump_pickle(test_data, f"./test_data.n_rule.{n_training_program}.n_shot.{n_fewshot}.pkl")