In [6]:
import sys
# Put ../src/ at the front of the import search path; the `model_bert` import in
# the next cell presumably resolves from there (verify against the repo layout).
sys.path.insert(0,'../src/')

In [10]:
import json
import pathlib
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from model_bert import BertForSequenceClassification


In [11]:
def load_head_data(experiments_path):
    """Read every saved head mask found under ``experiments_path``.

    The tree is walked two levels deep (``<task>/<seed>/``) and
    ``head_mask.npy`` is loaded from each seed directory.

    Args:
        experiments_path: ``pathlib.Path`` of the directory holding the task folders.

    Returns:
        Nested dict ``{task: {seed: {"head_mask": ndarray}}}`` keyed by the
        ``.stem`` of the task and seed directory names.
    """
    return {
        task_dir.stem: {
            seed_dir.stem: {"head_mask": np.load(seed_dir / "head_mask.npy")}
            for seed_dir in task_dir.iterdir()
        }
        for task_dir in experiments_path.iterdir()
    }
def load_mlp_data(experiments_path):
    """Read every saved MLP mask and MLP importance array under ``experiments_path``.

    The tree is walked two levels deep (``<task>/<seed>/``); each seed directory
    must contain ``mlp_mask.npy`` and ``mlp_importance.npy``.

    Args:
        experiments_path: ``pathlib.Path`` of the directory holding the task folders.

    Returns:
        Nested dict ``{task: {seed: {"mlp_mask": ndarray, "mlp_importance": ndarray}}}``
        keyed by the ``.stem`` of the task and seed directory names.
    """
    mlp_data = {}
    for task_dir in experiments_path.iterdir():
        per_seed = {}
        mlp_data[task_dir.stem] = per_seed
        for seed_dir in task_dir.iterdir():
            entry = {}
            # Load in the same order as the keys appear: mask first, then importance.
            for key in ("mlp_mask", "mlp_importance"):
                entry[key] = np.load(seed_dir / f"{key}.npy")
            per_seed[seed_dir.stem] = entry
    return mlp_data

In [13]:
# Both loaders read from the same experiment tree, so the path is defined once.
experiments_path = pathlib.Path("../masks/heads_mlps")
heads = load_head_data(experiments_path)
mlps = load_mlp_data(experiments_path)

In [53]:

def distib_weights(weights):
    """Summarise the magnitude of a collection of weight tensors.

    Args:
        weights: iterable of dicts mapping a parameter name to a tensor. All
            tensors are concatenated along dim 0, so they must be compatible
            for ``torch.cat`` (the callers in this notebook pass flattened
            1-D tensors).

    Returns:
        ``(mean, std)`` of the absolute values of the concatenated tensors, as
        Python floats. When ``weights`` contributes no tensors at all the
        result is ``(0.0, nan)``: a mean of zero and an undefined spread.
    """
    # Detach so that parameters which still require grad do not drag an
    # autograd graph through the concatenation; only the values matter here.
    tensors = [
        tensor.detach()
        for weight_obj in weights
        for tensor in weight_obj.values()
    ]
    if not tensors:
        # Same values the single-zero placeholder tensor used to yield, without
        # allocating it or triggering torch's degrees-of-freedom warning.
        return 0.0, float("nan")
    magnitudes = torch.cat(tensors, dim=0).abs()
    return magnitudes.mean().cpu().item(), magnitudes.std().cpu().item()



from collections import Counter

# Number of rows (columns for the output projection) that belong to one
# attention head inside the fused Q/K/V/output matrices.
# NOTE(review): hard-coded; assumes every head is 64-dimensional — confirm
# against the model config before using this with a different architecture.
HEAD_DIM = 64


def _split_head_weights(model, head_mask):
    """Collect the per-head attention weights, split by the head mask.

    Args:
        model: loaded classifier exposing ``model.bert.encoder.layer``.
        head_mask: mask indexed as ``[layer][head]``; an entry equal to 1 marks
            the head as "good", any other value as "bad".

    Returns:
        ``(good, bad)``: two lists holding one ``{name: flattened tensor}``
        dict per head.
    """
    good, bad = [], []
    for layer_id, layer_mask in enumerate(head_mask):
        attention_component = model.bert.encoder.layer[layer_id].attention
        attention_heads = attention_component.self
        for head_id, is_good in enumerate(layer_mask):
            head_slice = slice(head_id * HEAD_DIM, (head_id + 1) * HEAD_DIM)
            weights = {}
            for name in ("query", "key", "value"):
                projection = getattr(attention_heads, name)
                weights[name] = projection.weight.data[head_slice].reshape(-1)
                weights[f"{name}_bias"] = projection.bias.data[head_slice].reshape(-1)
            # Important to index second dimension, as it corresponds to individual
            # head projecting back up to d from d_h
            weights["output"] = attention_component.output.dense.weight.data[:, head_slice].reshape(-1)
            # We never prune output bias
            (good if is_good == 1 else bad).append(weights)
    return good, bad


def _split_mlp_weights(model, mlp_mask):
    """Collect the per-layer MLP weights, split by the MLP mask.

    Args:
        model: loaded classifier exposing ``model.bert.encoder.layer``.
        mlp_mask: mask with one entry per layer; an entry equal to 1 marks the
            layer's MLP as "good", any other value as "bad".

    Returns:
        ``(good, bad)``: two lists holding one ``{name: flattened tensor}``
        dict per layer.
    """
    good, bad = [], []
    for layer_id, is_good in enumerate(mlp_mask):
        layer = model.bert.encoder.layer[layer_id]
        # `.data` keeps these consistent with the head weights: plain values
        # with no autograd history attached.
        weights = {
            "intermediate": layer.intermediate.dense.weight.data.reshape(-1),
            "intermediate_bias": layer.intermediate.dense.bias.data.reshape(-1),
            "output": layer.output.dense.weight.data.reshape(-1),
            "output_bias": layer.output.dense.bias.data.reshape(-1),
        }
        (good if is_good == 1 else bad).append(weights)
    return good, bad


def _report_good_vs_bad(good_label, bad_label, good_weights, bad_weights):
    """Print mean/std of |weight| for both groups; return whether good mean > bad mean.

    An empty group is reported by ``distib_weights`` with mean 0.0, so the
    comparison is made against 0.0 in that case.
    """
    good_mean, good_std = distib_weights(good_weights)
    print(f"{good_label} - Mean: {good_mean}, Std: {good_std}")
    bad_mean, bad_std = distib_weights(bad_weights)
    print(f"{bad_label} - Mean: {bad_mean}, Std: {bad_std}")
    good_is_larger = good_mean > bad_mean
    print(f"Good > Bad ? {good_is_larger}")
    return good_is_larger


# Per task: number of seeds for which the surviving ("good") components have a
# larger mean absolute weight than the pruned ("bad") ones.
head_counter = Counter()
mlp_counter = Counter()
head_mlp_counter = Counter()

for task in heads:
    for seed in heads[task]:
        head_mask = heads[task][seed]["head_mask"]
        mlp_mask = mlps[task][seed]["mlp_mask"]
        model_dir = pathlib.Path("../models/finetuned/") / task / seed
        model = BertForSequenceClassification.from_pretrained(model_dir)

        good_head_weights, bad_head_weights = _split_head_weights(model, head_mask)
        good_mlp_weights, bad_mlp_weights = _split_mlp_weights(model, mlp_mask)

        print(f"Task - {task} Seed - {seed}")
        print("----------------------------------------------------")

        if _report_good_vs_bad("Good Heads", "Bad Heads", good_head_weights, bad_head_weights):
            head_counter[task] += 1

        if _report_good_vs_bad("Good MLPs", "Bad MLPs", good_mlp_weights, bad_mlp_weights):
            mlp_counter[task] += 1

        if _report_good_vs_bad(
            "Good Head+MLP",
            "Bad Head+MLPs",
            good_head_weights + good_mlp_weights,
            bad_head_weights + bad_mlp_weights,
        ):
            head_mlp_counter[task] += 1

        print("------------------------------------------------------------")
        

        

Task - WNLI Seed - seed_71
----------------------------------------------------
Good Heads - Mean: 0.0, Std: nan
Bad Heads - Mean: 0.03006366267800331, Std: 0.02456958033144474
Good > Bad ? False
Good MLPs - Mean: 0.0, Std: nan
Bad MLPs - Mean: 0.030465181916952133, Std: 0.024271881207823753
Good > Bad ? False
Good Head+MLP - Mean: 0.0, Std: nan
Bad Head+MLPs - Mean: 0.03033130243420601, Std: 0.024372264742851257
Good > Bad ? False
------------------------------------------------------------
Task - WNLI Seed - seed_1337
----------------------------------------------------
Good Heads - Mean: 0.0, Std: nan
Bad Heads - Mean: 0.030063651502132416, Std: 0.024569561704993248
Good > Bad ? False
Good MLPs - Mean: 0.0, Std: nan
Bad MLPs - Mean: 0.03046516887843609, Std: 0.024271847680211067
Good > Bad ? False
Good Head+MLP - Mean: 0.0, Std: nan
Bad Head+MLPs - Mean: 0.03033129870891571, Std: 0.02437223494052887
Good > Bad ? False
------------------------------------------------------------
Task

In [54]:
head_counter

Counter({'WNLI': 1, 'MNLI': 1, 'RTE': 1, 'SST-2': 3, 'STS-B': 1})

In [55]:
mlp_counter

Counter({'MRPC': 4,
         'MNLI': 1,
         'QQP': 4,
         'RTE': 2,
         'SST-2': 4,
         'STS-B': 4,
         'CoLA': 5,
         'QNLI': 4})

In [56]:
head_mlp_counter

Counter({'MRPC': 3,
         'QQP': 4,
         'RTE': 1,
         'SST-2': 4,
         'STS-B': 3,
         'CoLA': 2,
         'QNLI': 2})