In [None]:
%load_ext autoreload
%autoreload 2

import checklist
import pickle
from checklist.test_suite import TestSuite
import checklist
import numpy as np
import pandas as pd
from collections import OrderedDict
import pickle
import dill
from pathlib import Path
from tqdm.auto import tqdm

from suite_utils import get_test_info

In [None]:
# Notebook-wide switches; the cells below read them to pick prediction files.
CHECKS = True  # when False, the plain baseline prediction file of each task is used
TRAIN_CONFIG = "all"  # tag embedded in the prediction file names
METHOD = "dropout"  # method tag; "default" selects the un-suffixed paths when PLUS_IID is off
PLUS_IID=True  # selects the "+iid" prediction files / folders
FROM_IID=True  # together with PLUS_IID, selects the "iid-...+iid" prediction files / folders

In [None]:
def get_test_pass_rate(suite, print_results=True):
    """Return the pass rate (100 - fail_rate) of every test in ``suite``.

    Args:
        suite: object exposing ``tests``, a mapping from test name to a test
            whose ``get_stats()`` result is indexable by ``'fail_rate'``.
        print_results: when true, print one ``"<name>; <pass rate>"`` line per test.

    Returns:
        OrderedDict mapping test name to its pass rate, or to ``None`` when
        the lookup raised ``KeyError``.
    """
    pass_rates = OrderedDict()
    for test_name, test in suite.tests.items():
        try:
            pass_rate = 100 - test.get_stats()["fail_rate"]
        except KeyError:
            # No fail rate available for this test: record a placeholder.
            pass_rate = None
        pass_rates[test_name] = pass_rate
        if print_results:
            print(f"{test_name}; {pass_rate}")
    return pass_rates

def fix_run_idxs(suite, ids):
    """Point every test of ``suite`` at the example ids listed in ``ids``.

    For each test name in ``suite.tests`` this sets, in place:
      * ``result_indexes`` to ``ids[name]`` as given,
      * ``run_idxs`` to the distinct ids in first-seen order (numpy array),
      * ``suite.test_ranges[name]`` to a ``(start, end)`` span, where the
        spans of consecutive tests are laid out back to back from 0.

    Args:
        suite: object with a ``tests`` mapping and a ``test_ranges`` mapping.
        ids: mapping from test name to a sequence of ids.
    """
    offset = 0
    for test_name in suite.tests:
        test_ids = ids[test_name]
        test = suite.tests[test_name]
        test.result_indexes = test_ids
        # OrderedDict.fromkeys removes duplicates while keeping first-seen order.
        test.run_idxs = np.array(list(OrderedDict.fromkeys(test_ids)))
        end = offset + len(test_ids)
        suite.test_ranges[test_name] = (offset, end)
        offset = end

In [None]:
def get_out_results(suite, data_path, task, split, method="default", plus_iid=False, from_iid=False, print_results=True):
    if plus_iid:
        outPath = data_path/"-".join([f"{split}+iid",method])
        if from_iid:
            outPath = data_path/"-".join([f"iid-{split}+iid",method])
    elif task=="squad" or method != "default":
        outPath = data_path/"-".join([split,method])
    else:
        outPath = data_path/split
    results = {}
    for result in outPath.rglob("*"):
        pred_path = result
        if task=="sa":
            suite.run_from_file(pred_path, overwrite=True)
        elif task == "qqp":
            suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')
        elif task == "squad":
            suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')
        if split == "funcOut":
            test_name = result.name[:-3].replace("|", "/")
            try:
                results[test_name] = 100 - suite.tests[test_name].get_stats()["fail_rate"]
            except KeyError:
                results[test_name] = "None"
        elif split == "classOut":
            class_name = result.name[:-3]
            for test_name, test in suite.tests.items():
                if test.capability == class_name:
                    try:
                        results[test_name] = 100 - suite.tests[test_name].get_stats()["fail_rate"]
                    except KeyError:
                        results[test_name] = "None"
        elif split == "aspectOut":
            aspect_name = result.name[:-3]
            for test_name, test in suite.tests.items():
                if test.form_test_info()["type"] == aspect_name:
#                     print(test.summary())
                    try:
                        results[test_name] = 100 - suite.tests[test_name].get_stats()["fail_rate"]
                    except KeyError:
                        results[test_name] = "None"
    if print_results:
        for test in suite.tests:
            print(test, ";", results[test])
    return results

In [None]:
def get_seed_results(suite, data_path, task, print_results=True):
    outPath = data_path/"seeds"
    results = {}
    for result in outPath.rglob("*"):
        pred_path = result
        if task=="sa":
            suite.run_from_file(pred_path, overwrite=True)
        elif task == "qqp":
            suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')
        elif task == "squad":
            suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')
        for k, v in suite.tests.items():
            try:
                results.setdefault(k, []).append(100 - v.get_stats()['fail_rate'])
            except KeyError:
                results.setdefault(k, []).append(None)
    if print_results:
        for k, v in results.items():
            print(f"{k}; {';'.join([str(pass_rate) for pass_rate in v])}")
    return results

In [None]:
def get_random_results(data_path, task):
    results = {}
    for result in data_path.rglob("*random*"):
        pred_path = result
        if task=="sa":
            suite.run_from_file(pred_path, overwrite=True)
        elif task == "qqp":
            suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')
        elif task == "squad":
            suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')
        for k, v in suite.tests.items():
            try:
                results.setdefault(k, []).append(100 - v.get_stats()['fail_rate'])
            except KeyError:
                print(f"{k}; None")
    for k, v in results.items():
        print(f"{k}; {';'.join([str(pass_rate) for pass_rate in v])}")

In [None]:
def get_funcs_dic(suite):
    """Map each test name in ``suite.tests`` to ``(capability, type)``.

    ``capability`` is read from the test's attribute of that name and
    ``type`` from ``form_test_info()["type"]``.
    """
    funcs = {}
    for test_name, test in suite.tests.items():
        funcs[test_name] = (test.capability, test.form_test_info()["type"])
    return funcs

In [None]:
# Load the CheckList sentiment suite; when the pickle is missing, download
# and unpack the release archive, then load it again.
try:
    suite_path = './data/release_data/sentiment/sentiment_suite.pkl'
    suite = TestSuite.from_file(suite_path)
except FileNotFoundError:
    print("Downloading CheckList release files...")
    # IPython shell escapes: fetch the archive into ./data and extract it there.
    ! wget -P "./data" https://github.com/marcotcr/checklist/raw/master/release_data.tar.gz
    ! tar xvzf data/release_data.tar.gz -C ./data
    # Second attempt, now that the archive has been extracted.
    suite_path = './data/release_data/sentiment/sentiment_suite.pkl'
    suite = TestSuite.from_file(suite_path)

## Sentiment

In [None]:
pred_path = "./data/sa/predictions/SST-2/BERT-SST2"

In [None]:
suite.run_from_file(pred_path, overwrite=True)

In [None]:
suite.summary()

In [None]:
suite.visual_summary_table()

In [None]:
get_test_info(suite)

In [None]:
sa_dic = get_funcs_dic(suite)

In [None]:
# Persist the test-name -> (capability, type) mapping; the context manager
# guarantees the file handle is flushed and closed.
with open("./data/sa/func_dic.pkl", "wb") as file:
    pickle.dump(sa_dic, file)

In [None]:
get_test_pass_rate(suite)

In [None]:
pred_path = "./data/release_data/sentiment/predictions/bert"

In [None]:
suite.run_from_file(pred_path, overwrite=True)

In [None]:
suite.visual_summary_table()

In [None]:
get_test_pass_rate(suite)

### Our test set

In [None]:
with open("./data/sa/predictions/checklist/test_ids.pkl", "rb") as file:
    test_ids = pickle.load(file)

In [None]:
fix_run_idxs(suite, test_ids)

In [None]:
# Choose the SA prediction file that matches the notebook configuration.
if not CHECKS:
    pred_path = './data/sa/predictions/checklist/BERT-SST2'
elif PLUS_IID and FROM_IID:
    pred_path = f'./data/sa/predictions/checklist/BERT-SST2-iid-{TRAIN_CONFIG}+iid-{METHOD}'
elif PLUS_IID:
    pred_path = f'./data/sa/predictions/checklist/BERT-SST2-{TRAIN_CONFIG}+iid-{METHOD}'
elif METHOD == "default":
    pred_path = f'./data/sa/predictions/checklist/BERT-SST2-{TRAIN_CONFIG}'
else:
    pred_path = f'./data/sa/predictions/checklist/BERT-SST2-{TRAIN_CONFIG}-{METHOD}'

In [None]:
suite.run_from_file('./data/sa/predictions/checklist/BERT-SST2', overwrite=True)

In [None]:
suite.run_from_file(pred_path, overwrite=True)

In [None]:
suite.visual_summary_table()

In [None]:
get_test_pass_rate(suite)

### Held-out results

In [None]:
data_path = Path("./data/sa/predictions/checklist/")

In [None]:
get_out_results(suite, data_path, "sa", "funcOut", METHOD, PLUS_IID, FROM_IID)

In [None]:
get_out_results(suite, data_path, "sa", "classOut", METHOD, PLUS_IID, FROM_IID)

In [None]:
get_out_results(suite, data_path, "sa", "aspectOut", METHOD, PLUS_IID, FROM_IID)

## QQP

In [None]:
suite_path = './data/release_data/qqp/qqp_suite.pkl'
suite = TestSuite.from_file(suite_path)

In [None]:
# NOTE(review): when the notebook is run top to bottom, `suite` here is the
# QQP suite loaded in the cell above, yet the destination is an absolute,
# machine-specific path ending in suites/sa/suite.pkl — confirm this cell is
# still needed and that the target is the intended one.
with open("/home/peluz/Projects/univie/specification-learning-selection/specification-learning-selection-code/data/suites/sa/suite.pkl", "wb") as file:
    dill.dump(suite, file, protocol=4)

In [None]:
pred_path = './data/qqp/predictions/qqp/BERT-qqp'
suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')

In [None]:
suite.summary()

In [None]:
suite.visual_summary_table()

In [None]:
get_test_info(suite)

In [None]:
# Replace the stored labels of this one test with 0. The following cell
# saves the suite back to suite_path, so the change is persisted on disk.
suite.tests["Order does matter for asymmetric relations"].labels = 0

In [None]:
suite.save(suite_path)

In [None]:
qqp_dic = get_funcs_dic(suite)

In [None]:
# Persist the test-name -> (capability, type) mapping; the context manager
# guarantees the file handle is flushed and closed.
with open("./data/qqp/func_dic.pkl", "wb") as file:
    pickle.dump(qqp_dic, file)

In [None]:
suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')

In [None]:
get_test_pass_rate(suite)

In [None]:
pred_path = './data/release_data/qqp/predictions/bert'
suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')

In [None]:
get_test_pass_rate(suite)

### Our test set

In [None]:
with open("./data/qqp/predictions/checklist/test_ids.pkl", "rb") as file:
    test_ids = pickle.load(file)

In [None]:
# Choose the QQP prediction file that matches the notebook configuration.
if not CHECKS:
    pred_path = './data/qqp/predictions/checklist/BERT-qqp'
elif PLUS_IID and FROM_IID:
    pred_path = f'./data/qqp/predictions/checklist/BERT-qqp-iid-{TRAIN_CONFIG}+iid-{METHOD}'
elif PLUS_IID:
    pred_path = f'./data/qqp/predictions/checklist/BERT-qqp-{TRAIN_CONFIG}+iid-{METHOD}'
elif METHOD == "default":
    pred_path = f'./data/qqp/predictions/checklist/BERT-qqp-{TRAIN_CONFIG}'
else:
    pred_path = f'./data/qqp/predictions/checklist/BERT-qqp-{TRAIN_CONFIG}-{METHOD}'

In [None]:
fix_run_idxs(suite, test_ids)

In [None]:
suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')

In [None]:
suite.visual_summary_table()

In [None]:
get_test_pass_rate(suite)

### Held-out results

In [None]:
data_path = Path("./data/qqp/predictions/checklist/")

In [None]:
get_out_results(suite, data_path, "qqp", "funcOut", METHOD, plus_iid=PLUS_IID, from_iid=FROM_IID)

In [None]:
get_out_results(suite, data_path, "qqp", "classOut", METHOD, plus_iid=PLUS_IID, from_iid=FROM_IID)

In [None]:
get_out_results(suite, data_path, "qqp", "aspectOut", METHOD, plus_iid=PLUS_IID, from_iid=FROM_IID)

## SQuAD

In [None]:
suite_path = 'data/release_data/squad/squad_suite.pkl'
suite = TestSuite.from_file(suite_path)

In [None]:
pred_path = 'data/squad/predictions/squad/BERT-squad'
suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')

In [None]:
suite.summary()

In [None]:
suite.visual_summary_table()

In [None]:
squad_dic = get_funcs_dic(suite)

In [None]:
# Persist the test-name -> (capability, type) mapping; the context manager
# guarantees the file handle is flushed and closed.
with open("./data/squad/func_dic.pkl", "wb") as file:
    pickle.dump(squad_dic, file)

In [None]:
get_test_info(suite)

In [None]:
get_test_pass_rate(suite)

In [None]:
pred_path = 'data/release_data/squad/predictions/bert'
suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')
get_test_pass_rate(suite)

### Our test set

In [None]:
# Choose the SQuAD prediction file that matches the notebook configuration.
if not CHECKS:
    pred_path = './data/squad/predictions/checklist/BERT-squad'
elif PLUS_IID and FROM_IID:
    pred_path = f'./data/squad/predictions/checklist/BERT-squad-iid-{TRAIN_CONFIG}+iid-{METHOD}'
elif PLUS_IID:
    pred_path = f'./data/squad/predictions/checklist/BERT-squad-{TRAIN_CONFIG}+iid-{METHOD}'
elif METHOD == "default":
    pred_path = f'./data/squad/predictions/checklist/BERT-squad-{TRAIN_CONFIG}'
else:
    pred_path = f'./data/squad/predictions/checklist/BERT-squad-{TRAIN_CONFIG}-{METHOD}'

In [None]:
import pickle

with open("./data/squad/predictions/checklist/test_ids.pkl", "rb") as file:
    test_ids = pickle.load(file)

In [None]:
fix_run_idxs(suite, test_ids)

In [None]:
# NOTE(review): this hard-coded path replaces the pred_path derived from
# CHECKS / PLUS_IID / FROM_IID / METHOD a few cells earlier, so the run below
# ignores that configuration — confirm this override is intended.
pred_path = "./data/squad/predictions/checklist/iid-funcOut+iid-fish/Basic coref, his | herOut"

In [None]:
suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')

In [None]:
suite.visual_summary_table()

In [None]:
get_test_pass_rate(suite)

### Held-out results

In [None]:
data_path = Path("./data/squad/predictions/checklist/")

In [None]:
get_out_results(suite, data_path, "squad", "funcOut", METHOD, plus_iid=PLUS_IID, from_iid=FROM_IID)

In [None]:
get_out_results(suite, data_path, "squad", "classOut", METHOD, plus_iid=PLUS_IID, from_iid=FROM_IID)

In [None]:
get_out_results(suite, data_path, "squad", "aspectOut", METHOD, plus_iid=PLUS_IID, from_iid=FROM_IID)

## Generate CSVs with fine-grained results

In [None]:
def get_results(task):
    """Build the fine-grained results table for one task.

    The first row holds the pass rates of the baseline predictions
    (method ``"default"``, config ``"IID"``, score ``"standard"``). Then, for
    every config in ``IID-T`` / ``IID+T`` / ``IID-(IID+T)``, every method and
    every score in ``seen`` / ``funcOut`` / ``classOut`` / ``aspectOut``, one
    row of per-test pass rates is added. When the ``seen`` predictions of a
    (config, method) pair cannot be loaded, that pair is skipped entirely.

    Args:
        task: ``"sa"``, ``"qqp"``, or anything else (treated as SQuAD paths).

    Returns:
        pandas.DataFrame with columns ``method``, ``config``, ``score``
        followed by one column per test. For ``"sa"``, columns containing any
        missing value are dropped.
    """
    if task == "sa":
        suite_path = './data/release_data/sentiment/sentiment_suite.pkl'
        pred_path = './data/sa/predictions/checklist/BERT-SST2'
        task_data = "SST2"
    elif task == "qqp":
        suite_path = './data/release_data/qqp/qqp_suite.pkl'
        pred_path = './data/qqp/predictions/checklist/BERT-qqp'
        task_data = task
    else:
        suite_path = './data/release_data/squad/squad_suite.pkl'
        pred_path = './data/squad/predictions/checklist/BERT-squad'
        task_data = task
    suite = TestSuite.from_file(suite_path)

    def run_predictions(path):
        # Load one prediction file using the format that belongs to the task.
        if task=="sa":
            suite.run_from_file(path, overwrite=True)
        elif task == "qqp":
            suite.run_from_file(path, overwrite=True, file_format='binary_conf')
        elif task == "squad":
            suite.run_from_file(path, overwrite=True, file_format='pred_only')

    # Rows are collected as plain dicts and turned into a frame once at the
    # end: DataFrame.append no longer exists in pandas >= 2.0.
    rows = []

    def add_row(scores, method, config, score):
        row = dict(scores)
        row["method"] = method
        row["config"] = config
        row["score"] = score
        rows.append(row)

    with open(f"./data/{task}/predictions/checklist/test_ids.pkl", "rb") as file:
        test_ids = pickle.load(file)
    fix_run_idxs(suite, test_ids)
    run_predictions(pred_path)
    add_row(get_test_pass_rate(suite, print_results=False), "default", "IID", "standard")

    for config in ["IID-T", "IID+T", "IID-(IID+T)"]:
        print(f"Checking config {config}")
        if config == "IID-T":
            plus_iid = from_iid = False
        elif config == "IID+T":
            plus_iid = True
            from_iid = False
        else:
            plus_iid = from_iid = True
        for method in ["default", "l2", "dropout", "freeze", "lp-ft", "irm", "dro", "fish"]:
            print(f"method {method}")
            for score in ["seen", "funcOut", "classOut", "aspectOut"]:
                print(f"score {score}")
                if score == "seen":
                    if plus_iid:
                        pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-all+iid-{method}'
                        if from_iid:
                            pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-iid-all+iid-{method}'
                    elif method == "default":
                        pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-all'
                    else:
                        pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-all-{method}'
                    try:
                        run_predictions(pred_path)
                    except Exception:
                        # Best effort: no usable "seen" predictions for this
                        # method, so skip its remaining scores as well.
                        # (Exception, not bare except, so Ctrl-C still works.)
                        break
                    scores = get_test_pass_rate(suite, print_results=False)
                else:
                    scores = get_out_results(suite, Path(f"./data/{task}/predictions/checklist/"), task, score, method, plus_iid=plus_iid, from_iid=from_iid, print_results=False)
                add_row(scores, method, config, score)

    # Identification columns first, then the suite's tests, then any extra
    # name that only showed up in a held-out result.
    leading = ["method", "config", "score", *suite.tests.keys()]
    results = pd.DataFrame(rows)
    extras = [column for column in results.columns if column not in leading]
    results = results.reindex(columns=leading + extras)
    if task == "sa":
        # Keep only the columns that have a value in every row.
        results = results.dropna(axis=1)
    return results

### SA

In [None]:
sa_results = get_results("sa")

In [None]:
sa_results

In [None]:
sa_results.mean(axis=1)

In [None]:
sa_results.to_csv('data/sa/results_fine.csv', index=False)

### QQP

In [None]:
qqp_results = get_results("qqp")

In [None]:
qqp_results

In [None]:
qqp_results.mean(axis=1)

In [None]:
qqp_results.to_csv('data/qqp/results_fine.csv', index=False)

### SQuAD

In [None]:
squad_results = get_results("squad")

In [None]:
squad_results

In [None]:
squad_results.mean(axis=1)

In [None]:
squad_results.to_csv('data/squad/results_fine.csv', index=False)

## Generate suite hits

In [None]:
def get_test_hits(suite):
    """Return a per-test-case hit vector for every test in ``suite``.

    Each vector has ``get_stats().testcases_run`` entries: 1.0 by default,
    0.0 at the indices from ``fail_idxs()`` and NaN at the indices from
    ``filtered_idxs()`` (NaN wins where both apply). The tests named
    ``'"used to" should reduce'`` and ``"reducers"`` are left out.

    Returns:
        OrderedDict mapping test name to a 1-D float numpy array.
    """
    skipped = ('"used to" should reduce', "reducers")
    hits_by_test = OrderedDict()
    for test_name, test in suite.tests.items():
        if test_name not in skipped:
            filtered_idx = test.filtered_idxs()
            failed_idx = test.fail_idxs()
            outcome = np.ones(test.get_stats().testcases_run)
            outcome[failed_idx] = 0
            # Written last so that filtered cases override a failure mark.
            outcome[filtered_idx] = np.nan
            hits_by_test[test_name] = outcome
    return hits_by_test

In [None]:
def get_out_hits(suite, data_path, task, split, method="default", plus_iid=False, from_iid=False):
    """Collect per-test-case hit vectors from the held-out files of one split.

    Every path under the selected folder is fed to ``suite.run_from_file``;
    its name minus the last three characters names what was held out: one
    test for ``"funcOut"`` (``"|"`` stands for ``"/"``), a capability for
    ``"classOut"``, a test type for ``"aspectOut"``. For each affected test a
    vector is stored with 1.0 per test case, 0.0 at failing indices and NaN
    at filtered indices. The tests ``'"used to" should reduce'`` and
    ``"reducers"`` are skipped.

    Args:
        suite: CheckList-style suite (``tests`` mapping, ``run_from_file``).
        data_path: ``pathlib.Path`` of the directory holding the split folders.
        task: ``"sa"``, ``"qqp"`` or ``"squad"``; selects the file format.
        split: ``"funcOut"``, ``"classOut"`` or ``"aspectOut"``.
        method: method tag appended to the folder name.
        plus_iid: use the ``"<split>+iid-<method>"`` folder.
        from_iid: with ``plus_iid``, use ``"iid-<split>+iid-<method>"`` instead.

    Returns:
        dict mapping test name to a 1-D float numpy array.
    """
    if not plus_iid:
        if task == "squad" or method != "default":
            folder = "-".join([split, method])
        else:
            folder = split
    elif from_iid:
        folder = "-".join([f"iid-{split}+iid", method])
    else:
        folder = "-".join([f"{split}+iid", method])
    held_out_dir = data_path / folder

    skipped = ('"used to" should reduce', "reducers")
    hits_by_test = {}
    for pred_path in held_out_dir.rglob("*"):
        if task == "sa":
            suite.run_from_file(pred_path, overwrite=True)
        elif task == "qqp":
            suite.run_from_file(pred_path, overwrite=True, file_format='binary_conf')
        elif task == "squad":
            suite.run_from_file(pred_path, overwrite=True, file_format='pred_only')

        # Work out which tests this prediction file speaks for.
        held_out = pred_path.name[:-3]
        if split == "funcOut":
            affected = [held_out.replace("|", "/")]
        elif split == "classOut":
            affected = [name for name, test in suite.tests.items()
                        if test.capability == held_out]
        elif split == "aspectOut":
            affected = [name for name, test in suite.tests.items()
                        if test.form_test_info()["type"] == held_out]
        else:
            affected = []

        for test_name in affected:
            if test_name in skipped:
                continue
            test = suite.tests[test_name]
            filtered_idx = test.filtered_idxs()
            failed_idx = test.fail_idxs()
            outcome = np.ones(test.get_stats().testcases_run)
            outcome[failed_idx] = 0
            # Written last so that filtered cases override a failure mark.
            outcome[filtered_idx] = np.nan
            hits_by_test[test_name] = outcome
    return hits_by_test

In [None]:
def get_all_hits(task):
    """Collect per-test-case hit vectors for every config/method/score of a task.

    Args:
        task: ``"sa"``, ``"qqp"``, or anything else (treated as SQuAD paths).

    Returns:
        dict with
          * ``"baseline"``: hits of the baseline predictions, and
          * ``(config, method)``: dict mapping each score (``"seen"``,
            ``"funcOut"``, ``"classOut"``, ``"aspectOut"``) to its hits.
        A (config, method) pair whose ``"seen"`` predictions cannot be loaded
        is skipped and has no entry.
    """
    if task == "sa":
        suite_path = './data/release_data/sentiment/sentiment_suite.pkl'
        pred_path = './data/sa/predictions/checklist/BERT-SST2'
        task_data = "SST2"
    elif task == "qqp":
        suite_path = './data/release_data/qqp/qqp_suite.pkl'
        pred_path = './data/qqp/predictions/checklist/BERT-qqp'
        task_data = task
    else:
        suite_path = './data/release_data/squad/squad_suite.pkl'
        pred_path = './data/squad/predictions/checklist/BERT-squad'
        task_data = task
    suite = TestSuite.from_file(suite_path)

    def run_predictions(path):
        # Load one prediction file using the format that belongs to the task.
        if task=="sa":
            suite.run_from_file(path, overwrite=True)
        elif task == "qqp":
            suite.run_from_file(path, overwrite=True, file_format='binary_conf')
        elif task == "squad":
            suite.run_from_file(path, overwrite=True, file_format='pred_only')

    results = {}

    with open(f"./data/{task}/predictions/checklist/test_ids.pkl", "rb") as file:
        test_ids = pickle.load(file)
    fix_run_idxs(suite, test_ids)
    run_predictions(pred_path)
    results["baseline"] = get_test_hits(suite)

    for config in ["iid-t", "iid+t", "iid-iid+t"]:
        print(f"Checking config {config}")
        if config == "iid-t":
            plus_iid = from_iid = False
        elif config == "iid+t":
            plus_iid = True
            from_iid = False
        else:
            plus_iid = from_iid = True
        for method in ["default", "l2", "dropout", "freeze", "lp-ft", "irm", "dro", "fish"]:
            print(f"method {method}")
            for score in ["seen", "funcOut", "classOut", "aspectOut"]:
                print(f"score {score}")
                if score == "seen":
                    if plus_iid:
                        pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-all+iid-{method}'
                        if from_iid:
                            pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-iid-all+iid-{method}'
                    elif method == "default":
                        pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-all'
                    else:
                        pred_path = f'./data/{task}/predictions/checklist/BERT-{task_data}-all-{method}'
                    try:
                        run_predictions(pred_path)
                    except Exception:
                        # Best effort: no usable "seen" predictions for this
                        # method, so skip its remaining scores as well.
                        # (Exception, not bare except, so Ctrl-C still works.)
                        break
                    scores = get_test_hits(suite)
                else:
                    scores = get_out_hits(suite, Path(f"./data/{task}/predictions/checklist/"), task, score, method, plus_iid=plus_iid, from_iid=from_iid)
                results.setdefault((config, method), {})[score] = scores
    return results

In [None]:
sa_hits = get_all_hits("sa")

In [None]:
qqp_hits = get_all_hits("qqp")

In [None]:
squad_hits = get_all_hits("squad")

In [None]:
all_hits = {
    "sa": sa_hits,
    "qqp": qqp_hits,
    "squad": squad_hits
}

In [None]:
with open("./data/suite_hits.pkl", "wb") as file:
    pickle.dump(all_hits, file, protocol=4)

In [None]:
np.mean([np.nanmean(v) for v in all_hits["sa"][("iid-t", "l2")]["classOut"].values()])

In [None]:
for k, v in all_hits["sa"][("iid-t", "l2")]["classOut"].items():
    print(k)
    print(v)
    print(all_hits["sa"][("iid-t", "l2")]["funcOut"][k])
    print()

In [None]:
all_hits["sa"][("iid-t", "l2")]["funcOut"]

In [None]:
np.sum([sum(~np.isnan(v)) for k, v in all_hits["sa"][("iid-t", "l2")]["classOut"].items()])