In [1]:
# Render matplotlib figures inline beneath the cells that create them.
%matplotlib inline

In [2]:
import numpy as np
import pathlib
import matplotlib.pyplot as plt
import time
import json
import itertools
from statsmodels.stats.contingency_tables import cochrans_q
import krippendorff

In [3]:
def _load_mask_arrays(experiments_path, array_names):
    """Load ``<name>.npy`` for every ``<task>/<seed>`` directory under ``experiments_path``.

    Expected layout: ``experiments_path/<task>/<seed>/<name>.npy`` for each name
    in ``array_names``. Stray files at the task or seed level (e.g. ``.DS_Store``)
    are skipped instead of crashing the load. Directories are visited in sorted
    order so the resulting dicts are deterministic across filesystems.

    Args:
        experiments_path: root directory (``pathlib.Path`` or ``str``).
        array_names: names of the ``.npy`` files to load from each seed directory.

    Returns:
        ``{task_stem: {seed_stem: {name: np.ndarray}}}``.
    """
    root = pathlib.Path(experiments_path)
    data = {}
    for task_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        data[task_dir.stem] = {
            seed_dir.stem: {name: np.load(seed_dir / f"{name}.npy") for name in array_names}
            for seed_dir in sorted(p for p in task_dir.iterdir() if p.is_dir())
        }
    return data


def load_head_data(experiments_path):
    """Load the attention-head masks: ``{task: {seed: {"head_mask": array}}}``."""
    return _load_mask_arrays(experiments_path, ("head_mask",))


def load_mlp_data(experiments_path):
    """Load MLP masks and importances: ``{task: {seed: {"mlp_mask": ..., "mlp_importance": ...}}}``."""
    return _load_mask_arrays(experiments_path, ("mlp_mask", "mlp_importance"))

In [4]:
# Base runs live in heads_mlps; HANS runs live in heads_mlps_hans and replace
# any task entry of the same name from the base runs.
masks_root = pathlib.Path("../masks")

heads = load_head_data(masks_root / "heads_mlps")
hans_heads = load_head_data(masks_root / "heads_mlps_hans")
heads.update(hans_heads)

mlps = load_mlp_data(masks_root / "heads_mlps")
hans_mlps = load_mlp_data(masks_root / "heads_mlps_hans")
mlps.update(hans_mlps)


In [8]:

def cochrans_q_masks(masks):
    """Run Cochran's Q test treating each mask as one treatment.

    ``masks`` is a sequence of equal-length (flattened) masks; it is stacked into
    a (positions, masks) table so every mask becomes a column.
    Returns the statsmodels result object (with ``.pvalue``).
    """
    table = np.asarray(masks).T
    return cochrans_q(table)


def krippendorff_alpha_tasks_separate(data, mask="head_mask"):
    """Report Krippendorff's alpha across seeds for each task on its own, in sorted task order."""
    ordered_tasks = sorted(data)
    for task_name in ordered_tasks:
        krippendorff_alpha_tasks(data, [task_name], mask=mask)
        
def krippendorff_alpha_tasks(data, tasks, mask="head_mask"):
    """Print and return Krippendorff's alpha over the seed masks of ``tasks``.

    Each (task, seed) mask is flattened into one row of the reliability data,
    so rows act as "coders" and mask positions as the rated units. The
    krippendorff package's default level of measurement is used.

    Args:
        data: nested mapping ``data[task][seed][mask]`` -> mask array.
        tasks: task names whose seed masks are pooled together.
        mask: key of the mask array inside each seed entry.

    Returns:
        The alpha value (previously only printed, now also returned so results
        can be collected into a table). It may be NaN, as seen for the WNLI MLP
        masks, presumably when the expected disagreement is zero.
    """
    # NOTE(review): seed names come from the first task and are assumed to exist
    # for every other task in ``tasks``; a missing seed raises KeyError.
    seeds = sorted(data[tasks[0]].keys())
    reliability_data = [
        data[task][seed][mask].reshape(-1)
        for task in tasks
        for seed in seeds
    ]
    alpha = krippendorff.alpha(reliability_data)
    print(','.join(tasks))
    print("---------")
    print(f'alpha: {alpha}')
    return alpha
    
def print_p_value_tasks_separate(data, mask="head_mask"):
    """Print the Cochran's Q seed-agreement report for each task on its own, in sorted task order."""
    ordered_tasks = sorted(data)
    for task_name in ordered_tasks:
        print_p_value_tasks(data, [task_name], mask=mask)
        
def print_p_value_tasks(data, tasks, mask="head_mask", significance=0.05, include_same_seed=False):
    """Print Cochran's Q results for the seed masks of one or more tasks.

    Every (task, seed) mask is flattened and used as one treatment of
    Cochran's Q. Two things are reported:

    1. an omnibus test over all selected masks at once, and
    2. a test for each pair of masks, listing the pairs whose null hypothesis
       ("the masks are similar") is not rejected.

    With a single task, every pair of its seeds is tested. With several tasks,
    only cross-task pairs are tested: every seed of one task against every seed
    of each other task, in both directions. Cross-task pairs that share a seed
    name are skipped unless ``include_same_seed`` is True (the original loop
    skipped them; presumably because same-seed runs share initialization —
    confirm).

    Args:
        data: nested mapping ``data[task][seed][mask]`` -> mask array.
        tasks: task names to include; each contributes all of its own seeds.
        mask: key of the mask array inside each seed entry.
        significance: p-value threshold below which the null is rejected.
        include_same_seed: also compare same-named seeds across tasks.

    Returns:
        dict with the omnibus ``"pvalue"``, the ``"similar_pairs"`` list of
        ``("task-seed", "task-seed")`` tuples, and ``"n_pairs"`` (pairs tested).
    """
    # Flatten every (task, seed) mask and remember its label, task and seed so
    # pairs can be filtered and named without index arithmetic.
    labels, masks, task_of, seed_of = [], [], [], []
    for task_idx, task in enumerate(tasks):
        for seed in sorted(data[task].keys()):
            labels.append(f"{task}-{seed}")
            masks.append(data[task][seed][mask].reshape(-1))
            task_of.append(task_idx)
            seed_of.append(seed)

    test_result = cochrans_q_masks(masks)
    print(','.join(tasks))
    print("---------------------------------------------------------------------")
    print(f'p-value: {test_result.pvalue}')
    rejected = test_result.pvalue < significance
    print('Null hypothesis (all seeds are similar) is rejected.' if rejected
          else 'Null hypothesis (all seeds are similar) is not rejected.')

    candidates = itertools.combinations(range(len(masks)), 2)
    if len(tasks) == 1:
        pairs = list(candidates)
    else:
        pairs = [
            (a, b) for a, b in candidates
            if task_of[a] != task_of[b]
            and (include_same_seed or seed_of[a] != seed_of[b])
        ]

    similar_pairs = []
    for a, b in pairs:
        pvalue = cochrans_q_masks([masks[a], masks[b]]).pvalue
        # When two masks agree at every position, Cochran's Q is 0/0 and the
        # p-value comes out NaN; that is perfect agreement, so count it as
        # "not rejected" rather than letting NaN >= threshold drop the pair.
        if np.isnan(pvalue) or pvalue >= significance:
            similar_pairs.append((labels[a], labels[b]))

    print(f"Total mask pairs where Null hypothesis is not rejected - {len(similar_pairs)}")
    print(f"Total mask pairs - {len(pairs)}")
    # Avoid ZeroDivisionError when there is nothing to compare (e.g. a single seed).
    percentage = len(similar_pairs) / len(pairs) if pairs else float("nan")
    print(f"Percentage - {percentage}")
    print("\nSimilar Mask Pairs:\n")
    print("\t".join(",".join(p) for p in similar_pairs))
    print("\n\n")
    return {"pvalue": test_result.pvalue, "similar_pairs": similar_pairs, "n_pairs": len(pairs)}

# Agreement between seeds within each task

## Heads

In [None]:
print_p_value_tasks_separate(data=heads)

In [10]:
krippendorff_alpha_tasks_separate(data=heads, mask="head_mask")

CoLA
---------
alpha: 0.22920627475942013
HANS
---------
alpha: 0.1511915220733583
HANS_MNLI
---------
alpha: 0.28187521458222786
MNLI
---------
alpha: 0.25860365198711066
MNLI_TWO
---------
alpha: 0.2586712903250721
MNLI_TWO_HALF
---------
alpha: 0.27296884185773074
MRPC
---------
alpha: 0.21637124992829693
QNLI
---------
alpha: 0.30249353262431733
QQP
---------
alpha: 0.1923483546980309
RTE
---------
alpha: 0.1810617760617761
SST-2
---------
alpha: 0.32545842217484
STS-B
---------
alpha: 0.2226213727678572
WNLI
---------
alpha: -0.24826388888888884


## MLPs

In [None]:
print_p_value_tasks_separate(data=mlps, mask="mlp_mask")

In [13]:
krippendorff_alpha_tasks_separate(data=mlps, mask="mlp_mask")

CoLA
---------
alpha: 0.4537037037037036
HANS
---------
alpha: 0.3413631022326674
HANS_MNLI
---------
alpha: 0.31835686777920413
MNLI
---------
alpha: 0.18055555555555558
MNLI_TWO
---------
alpha: 0.2804878048780487
MNLI_TWO_HALF
---------
alpha: 0.4319640564826701
MRPC
---------
alpha: 0.03804347826086951
QNLI
---------
alpha: 0.16193181818181823
QQP
---------
alpha: 0.03146374829001364
RTE
---------
alpha: -0.0993788819875776
SST-2
---------
alpha: 0.06349206349206349
STS-B
---------
alpha: 0.09870740305522918
WNLI
---------
alpha: nan


  return 1 - np.sum(o * d) / np.sum(e * d)


# Pairwise task-to-task comparison

## Heads

In [None]:
tasks = sorted(heads)
for task_pair in itertools.combinations(tasks, 2):
    print_p_value_tasks(heads, list(task_pair))

In [14]:
tasks = sorted(heads)
for first_task, second_task in itertools.combinations(tasks, 2):
    krippendorff_alpha_tasks(heads, [first_task, second_task], mask="head_mask")

CoLA,HANS
---------
alpha: 0.12605966529226342
CoLA,HANS_MNLI
---------
alpha: 0.15390838813984287
CoLA,MNLI
---------
alpha: 0.1472484134521177
CoLA,MNLI_TWO
---------
alpha: 0.14212075685326764
CoLA,MNLI_TWO_HALF
---------
alpha: 0.14118472263574544
CoLA,MRPC
---------
alpha: 0.13809026972718175
CoLA,QNLI
---------
alpha: 0.17694109929350343
CoLA,QQP
---------
alpha: 0.1261541914627099
CoLA,RTE
---------
alpha: 0.135281470108642
CoLA,SST-2
---------
alpha: 0.19279541253541732
CoLA,STS-B
---------
alpha: 0.14422753716871406
CoLA,WNLI
---------
alpha: -0.009957825396038844
HANS,HANS_MNLI
---------
alpha: 0.17639652014652096
HANS,MNLI
---------
alpha: 0.15112170220719234
HANS,MNLI_TWO
---------
alpha: 0.1591999239058064
HANS,MNLI_TWO_HALF
---------
alpha: 0.16878590515113767
HANS,MRPC
---------
alpha: 0.14627980692708387
HANS,QNLI
---------
alpha: 0.1579596739596747
HANS,QQP
---------
alpha: 0.1233765941364734
HANS,RTE
---------
alpha: 0.14104061594315065
HANS,SST-2
---------
alpha: 0.1

## MLPs

In [None]:
# Derive the task list from `mlps` itself instead of reusing `tasks` leaked from
# the heads section, so this cell works on a fresh kernel and uses the MLP tasks.
tasks = sorted(mlps.keys())
for t1, t2 in itertools.combinations(tasks, 2):
    print_p_value_tasks(mlps, [t1, t2], mask="mlp_mask")

In [17]:
tasks = sorted(mlps)
for task_pair in itertools.combinations(tasks, 2):
    krippendorff_alpha_tasks(mlps, list(task_pair), mask="mlp_mask")

CoLA,HANS
---------
alpha: 0.22340111995284406
CoLA,HANS_MNLI
---------
alpha: 0.21501754385964922
CoLA,MNLI
---------
alpha: 0.2309992283950617
CoLA,MNLI_TWO
---------
alpha: 0.23728654970760243
CoLA,MNLI_TWO_HALF
---------
alpha: 0.17047953216374256
CoLA,MRPC
---------
alpha: 0.16700000000000004
CoLA,QNLI
---------
alpha: 0.11688311688311681
CoLA,QQP
---------
alpha: 0.1050849145873799
CoLA,RTE
---------
alpha: 0.07444444444444442
CoLA,SST-2
---------
alpha: 0.046904730297800845
CoLA,STS-B
---------
alpha: 0.1933392278219862
CoLA,WNLI
---------
alpha: -0.046202768424990825
HANS,HANS_MNLI
---------
alpha: 0.29772079772079785
HANS,MNLI
---------
alpha: 0.13049295049912513
HANS,MNLI_TWO
---------
alpha: 0.2573599240265908
HANS,MNLI_TWO_HALF
---------
alpha: 0.26543209876543206
HANS,MRPC
---------
alpha: 0.12598140308983674
HANS,QNLI
---------
alpha: 0.23404030811438215
HANS,QQP
---------
alpha: 0.19013888888888886
HANS,RTE
---------
alpha: 0.07431527913455627
HANS,SST-2
---------
alpha:

In [37]:

# For each MNLI seed: the set of flattened head indices whose mask value is 1.
sets = [
    set(np.flatnonzero(seed_entry['head_mask'].reshape(-1) == 1))
    for seed_entry in heads['MNLI'].values()
]

In [38]:
# Heads with mask 1 in every MNLI seed vs. in at least one MNLI seed.
intersection = set.intersection(*sets)
union = set.union(*sets)

In [39]:
# Jaccard index of the MNLI head masks across seeds; NaN instead of
# ZeroDivisionError when no seed has any head with mask value 1.
len(intersection) / len(union) if union else float("nan")

0.291970802919708