In [1]:
import sys
sys.path.insert(0, '../src')

In [2]:
from classify_attention_patterns import load_model
from argparse import Namespace
from run_glue import load_and_cache_examples, set_seed
from model_bert import BertForSequenceClassification, BertForMaskedLM
from config_bert import BertConfig
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import BertTokenizer
import numpy as np
import torch
import random
from collections import Counter

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators with ``seed``.

    This definition comes after ``from run_glue import ... set_seed`` in the same
    cell, so it replaces the imported name for the rest of the notebook.
    """
    # Same order as before: stdlib ``random`` first, then NumPy, then torch.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

loading region bounding boxes for computing carbon emissions region, this may take a moment...
 454/454... rate=467.44 Hz, eta=0:00:00, total=0:00:00, wall=15:37 CETT
Done!


In [3]:
# Load the attention-pattern classifier together with its label mapping and the
# size bounds it was saved with (semantics of `min_max_size` are defined by
# `load_model`; later cells read index 0 as a minimum token count).
classifier_checkpoint = "../models/head_classifier/classify_attention_patters.tar"
head_classifier_model, label2id, min_max_size = load_model(classifier_checkpoint)

# Inference only: switch to eval mode, then move the classifier to the GPU.
head_classifier_model = head_classifier_model.eval()
head_classifier_model = head_classifier_model.cuda()

# Reverse lookup used to turn predicted class ids back into label names.
id2label = dict((idx, label) for label, idx in label2id.items())

# Fine-tuned Model - Super Survivors

In [4]:
# Fine-tuned models with pruned heads/MLPs ("super survivors").
# For every GLUE task and seed: load the fine-tuned checkpoint, prune it with the
# stored masks, run up to `max_examples` evaluation examples through it and let
# the head classifier label each head's attention map. Both the raw attention and
# a weight-normed variant are classified. Label counts are pooled over all seeds
# of a task and printed as fractions.
set_seed(1337)  # seed random / NumPy / torch before any sampling happens
max_examples = 100  # classified examples per (task, seed) model
for task in ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE"]:
    attention_counter = Counter()
    normed_attention_counter = Counter()
    for seed in ["seed_1337", "seed_42", "seed_86", "seed_71", "seed_166"]:
        # Load the fine-tuned model with attention outputs enabled.
        model_path = f"../models/finetuned/{task}/{seed}/"
        tokenizer = BertTokenizer.from_pretrained(model_path)
        config = BertConfig.from_pretrained(model_path)
        config.output_attentions = True
        transformer_model = BertForSequenceClassification.from_pretrained(model_path,  config=config)

        # Prune every head / MLP whose mask entry is not 1 (nonzero in `1 - mask`).
        mask_path = f"../masks/heads_mlps_super/{task}/{seed}/"
        head_mask = np.load(f"{mask_path}/head_mask.npy")
        mlp_mask = np.load(f"{mask_path}/mlp_mask.npy")
        head_mask = torch.from_numpy(head_mask)
        heads_to_prune = {
            layer: [idx[0] for idx in (1 - head_mask[layer].long()).nonzero().tolist()]
            for layer in range(len(head_mask))
        }
        mlps_to_prune = [idx[0] for idx in (1 - torch.from_numpy(mlp_mask).long()).nonzero().tolist()]

        transformer_model.prune_heads(heads_to_prune)
        transformer_model.prune_mlps(mlps_to_prune)
        transformer_model = transformer_model.eval()
        transformer_model.cuda()

        # Evaluation data, visited one example at a time in random order.
        args = Namespace(data_dir=f"../data/glue/{task}/", local_rank=-1,
                         model_name_or_path=model_path,
                         overwrite_cache=False, model_type="bert", max_seq_length=128)
        eval_dataset = load_and_cache_examples(args, task.lower(), tokenizer, evaluate=True)
        eval_sampler = RandomSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)

        # One list per layer, filled lazily with one Counter per head.
        # NOTE(review): 12 is hard-coded; assumes the model returns at most 12 layers - confirm.
        attention_head_types = [[] for _ in range(12)]
        normed_attention_head_types = [[] for _ in range(12)]

        k = 0  # number of examples classified so far (skipped ones do not count)
        for batch in eval_dataloader:
            batch = tuple(t.to("cuda:0") for t in batch)
            # batch[1] is passed to the model as `attention_mask`; its sum is used as the token count.
            n_tokens = int(batch[1].sum())
            if n_tokens < min_max_size[0]:
                continue
            input_data = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            with torch.no_grad():
                # The project model returns five values; the last three are used here.
                _, _, attentions, allalpha_f, attention_mask = transformer_model(**input_data)

                # Crop size for the classifier input; never smaller than min_max_size[0] + 20.
                # NOTE(review): for shorter inputs the crop extends past n_tokens - presumably
                # intentional so the classifier gets a minimum input size; confirm.
                num_classifier_tokens = max(n_tokens, min_max_size[0]+20)

                for layer in range(len(attentions)):
                    if attentions[layer] is None:
                        continue

                    # Weight-normed attention: norm over the last dim of allalpha_f,
                    # plus the returned mask, softmax-normalised over the last dim.
                    normed_alpha_f = allalpha_f[layer].norm(dim=-1)
                    weight_normed_attention = (normed_alpha_f+attention_mask).softmax(dim=-1)

                    # Classify the raw and the weight-normed maps with identical code.
                    for attention_map, head_types in (
                        (attentions[layer], attention_head_types[layer]),
                        (weight_normed_attention, normed_attention_head_types[layer]),
                    ):
                        head_attentions = attention_map.transpose(0, 1)
                        head_attentions = head_attentions[:,:,:num_classifier_tokens,:num_classifier_tokens]
                        logits = head_classifier_model(head_attentions)
                        label_ids = torch.argmax(logits, dim=-1)
                        labels = [id2label[int(label_id.item())] for label_id in label_ids]
                        if len(head_types) == 0:
                            head_types.extend(Counter() for _ in labels)
                        for i, label in enumerate(labels):
                            head_types[i][label] += 1
                del attentions, allalpha_f, attention_mask
            k += 1
            if k == max_examples:
                break

        # Pool the per-head counters of this seed into the per-task totals.
        for layer_counters in attention_head_types:
            for head_type_ctr in layer_counters:
                attention_counter += head_type_ctr
        for layer_counters in normed_attention_head_types:
            for head_type_ctr in layer_counters:
                normed_attention_counter += head_type_ctr

    # Report label frequencies as fractions, most common label first.
    print(task)
    attention_total = sum(attention_counter.values())
    normed_attention_total = sum(normed_attention_counter.values())
    print("attention:", {label: count/attention_total for label, count in attention_counter.most_common()})
    print("weight normed:", {label: count/normed_attention_total for label, count in normed_attention_counter.most_common()})

CoLA
attention: {'other': 0.4611764705882353, 'mix': 0.24670588235294116, 'vertical': 0.10911764705882353, 'diagonal': 0.09170588235294118, 'block': 0.09129411764705882}
weight normed: {'block': 0.789235294117647, 'diagonal': 0.11829411764705883, 'other': 0.07758823529411765, 'mix': 0.01488235294117647}
SST-2
attention: {'other': 0.5504210526315789, 'block': 0.2391578947368421, 'diagonal': 0.11105263157894738, 'mix': 0.08294736842105263, 'vertical': 0.016421052631578947}
weight normed: {'block': 0.798, 'other': 0.152, 'diagonal': 0.040736842105263155, 'mix': 0.009263157894736843}
MRPC
attention: {'vertical': 0.39321428571428574, 'other': 0.24957142857142858, 'block': 0.21857142857142858, 'mix': 0.06992857142857142, 'diagonal': 0.06871428571428571}
weight normed: {'block': 0.9432142857142857, 'diagonal': 0.03578571428571429, 'other': 0.020785714285714286, 'mix': 0.00021428571428571427}
STS-B
attention: {'other': 0.6716666666666666, 'block': 0.24308333333333335, 'diagonal': 0.0325, 'mix'

# Fine-tuned Model - All

In [9]:
# Fine-tuned models without pruning (all heads).
# For every GLUE task and seed: load the fine-tuned checkpoint, run up to
# `max_examples` evaluation examples through it and let the head classifier label
# each head's attention map. Both the raw attention and a weight-normed variant
# are classified. Label counts are pooled over all seeds of a task and printed as
# fractions.
set_seed(1337)  # seed random / NumPy / torch before any sampling happens
max_examples = 100  # classified examples per (task, seed) model
for task in ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE"]:
    attention_counter = Counter()
    normed_attention_counter = Counter()
    for seed in ["seed_1337", "seed_42", "seed_86", "seed_71", "seed_166"]:
        # Load the fine-tuned model with attention outputs enabled.
        model_path = f"../models/finetuned/{task}/{seed}/"
        tokenizer = BertTokenizer.from_pretrained(model_path)
        config = BertConfig.from_pretrained(model_path)
        config.output_attentions = True
        transformer_model = BertForSequenceClassification.from_pretrained(model_path,  config=config)
        transformer_model = transformer_model.eval()
        transformer_model.cuda()

        # Evaluation data, visited one example at a time in random order.
        args = Namespace(data_dir=f"../data/glue/{task}/", local_rank=-1,
                         model_name_or_path=model_path,
                         overwrite_cache=False, model_type="bert", max_seq_length=128)
        eval_dataset = load_and_cache_examples(args, task.lower(), tokenizer, evaluate=True)
        eval_sampler = RandomSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)

        # One list per layer, filled lazily with one Counter per head.
        # NOTE(review): 12 is hard-coded; assumes the model returns at most 12 layers - confirm.
        attention_head_types = [[] for _ in range(12)]
        normed_attention_head_types = [[] for _ in range(12)]

        k = 0  # number of examples classified so far (skipped ones do not count)
        for batch in eval_dataloader:
            batch = tuple(t.to("cuda:0") for t in batch)
            # batch[1] is passed to the model as `attention_mask`; its sum is used as the token count.
            n_tokens = int(batch[1].sum())
            if n_tokens < min_max_size[0]:
                continue
            input_data = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            with torch.no_grad():
                # The project model returns five values; the last three are used here.
                _, _, attentions, allalpha_f, attention_mask = transformer_model(**input_data)

                # Crop size for the classifier input; never smaller than min_max_size[0] + 20.
                # NOTE(review): for shorter inputs the crop extends past n_tokens - presumably
                # intentional so the classifier gets a minimum input size; confirm.
                num_classifier_tokens = max(n_tokens, min_max_size[0]+20)

                for layer in range(len(attentions)):
                    if attentions[layer] is None:
                        continue

                    # Weight-normed attention: norm over the last dim of allalpha_f,
                    # plus the returned mask, softmax-normalised over the last dim.
                    normed_alpha_f = allalpha_f[layer].norm(dim=-1)
                    weight_normed_attention = (normed_alpha_f+attention_mask).softmax(dim=-1)

                    # Classify the raw and the weight-normed maps with identical code.
                    for attention_map, head_types in (
                        (attentions[layer], attention_head_types[layer]),
                        (weight_normed_attention, normed_attention_head_types[layer]),
                    ):
                        head_attentions = attention_map.transpose(0, 1)
                        head_attentions = head_attentions[:,:,:num_classifier_tokens,:num_classifier_tokens]
                        logits = head_classifier_model(head_attentions)
                        label_ids = torch.argmax(logits, dim=-1)
                        labels = [id2label[int(label_id.item())] for label_id in label_ids]
                        if len(head_types) == 0:
                            head_types.extend(Counter() for _ in labels)
                        for i, label in enumerate(labels):
                            head_types[i][label] += 1
                del attentions, allalpha_f, attention_mask
            k += 1
            if k == max_examples:
                break

        # Pool the per-head counters of this seed into the per-task totals.
        for layer_counters in attention_head_types:
            for head_type_ctr in layer_counters:
                attention_counter += head_type_ctr
        for layer_counters in normed_attention_head_types:
            for head_type_ctr in layer_counters:
                normed_attention_counter += head_type_ctr

    # Report label frequencies as fractions, most common label first.
    print(task)
    attention_total = sum(attention_counter.values())
    normed_attention_total = sum(normed_attention_counter.values())
    print("attention:", {label: count/attention_total for label, count in attention_counter.most_common()})
    print("weight normed:", {label: count/normed_attention_total for label, count in normed_attention_counter.most_common()})

CoLA
attention: {'other': 0.4894027777777778, 'vertical': 0.229875, 'mix': 0.16377777777777777, 'block': 0.065125, 'diagonal': 0.051819444444444446}
weight normed: {'block': 0.8005416666666667, 'other': 0.1351527777777778, 'diagonal': 0.04770833333333333, 'mix': 0.01622222222222222, 'vertical': 0.000375}
SST-2
attention: {'other': 0.3145833333333333, 'mix': 0.30145833333333333, 'vertical': 0.21038888888888888, 'diagonal': 0.08898611111111111, 'block': 0.08458333333333333}
weight normed: {'block': 0.6183611111111111, 'other': 0.2445277777777778, 'diagonal': 0.077875, 'mix': 0.058375, 'vertical': 0.0008611111111111111}
MRPC
attention: {'other': 0.29868055555555556, 'mix': 0.2895277777777778, 'vertical': 0.233375, 'diagonal': 0.11504166666666667, 'block': 0.063375}
weight normed: {'block': 0.7511666666666666, 'other': 0.14125, 'diagonal': 0.07773611111111112, 'mix': 0.028819444444444446, 'vertical': 0.0010277777777777778}
STS-B
attention: {'other': 0.5713055555555555, 'mix': 0.16594444444

# Pre-trained Super Survivors

In [10]:
# Pre-trained weights with pruned heads/MLPs ("super survivors").
# Same procedure as the fine-tuned super-survivor analysis, except that the model
# weights are loaded from 'bert-base-uncased'. The tokenizer, the config and the
# pruning masks still come from the per-task / per-seed fine-tuned directories.
# Label counts are pooled over all seeds of a task and printed as fractions.
set_seed(1337)  # seed random / NumPy / torch before any model init or sampling
max_examples = 100  # classified examples per (task, seed) model
for task in ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE"]:
    attention_counter = Counter()
    normed_attention_counter = Counter()
    for seed in ["seed_1337", "seed_42", "seed_86", "seed_71", "seed_166"]:
        # Tokenizer and config come from the fine-tuned checkpoint directory,
        # the weights from the pre-trained 'bert-base-uncased' checkpoint.
        model_path = f"../models/finetuned/{task}/{seed}/"
        tokenizer = BertTokenizer.from_pretrained(model_path)
        config = BertConfig.from_pretrained(model_path)
        config.output_attentions = True
        transformer_model = BertForSequenceClassification.from_pretrained('bert-base-uncased',  config=config)

        # Prune every head / MLP whose mask entry is not 1 (nonzero in `1 - mask`).
        mask_path = f"../masks/heads_mlps_super/{task}/{seed}/"
        head_mask = np.load(f"{mask_path}/head_mask.npy")
        mlp_mask = np.load(f"{mask_path}/mlp_mask.npy")
        head_mask = torch.from_numpy(head_mask)
        heads_to_prune = {
            layer: [idx[0] for idx in (1 - head_mask[layer].long()).nonzero().tolist()]
            for layer in range(len(head_mask))
        }
        mlps_to_prune = [idx[0] for idx in (1 - torch.from_numpy(mlp_mask).long()).nonzero().tolist()]

        transformer_model.prune_heads(heads_to_prune)
        transformer_model.prune_mlps(mlps_to_prune)
        transformer_model = transformer_model.eval()
        transformer_model.cuda()

        # Evaluation data, visited one example at a time in random order.
        args = Namespace(data_dir=f"../data/glue/{task}/", local_rank=-1,
                         model_name_or_path=model_path,
                         overwrite_cache=False, model_type="bert", max_seq_length=128)
        eval_dataset = load_and_cache_examples(args, task.lower(), tokenizer, evaluate=True)
        eval_sampler = RandomSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)

        # One list per layer, filled lazily with one Counter per head.
        # NOTE(review): 12 is hard-coded; assumes the model returns at most 12 layers - confirm.
        attention_head_types = [[] for _ in range(12)]
        normed_attention_head_types = [[] for _ in range(12)]

        k = 0  # number of examples classified so far (skipped ones do not count)
        for batch in eval_dataloader:
            batch = tuple(t.to("cuda:0") for t in batch)
            # batch[1] is passed to the model as `attention_mask`; its sum is used as the token count.
            n_tokens = int(batch[1].sum())
            if n_tokens < min_max_size[0]:
                continue
            input_data = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            with torch.no_grad():
                # The project model returns five values; the last three are used here.
                _, _, attentions, allalpha_f, attention_mask = transformer_model(**input_data)

                # Crop size for the classifier input; never smaller than min_max_size[0] + 20.
                # NOTE(review): for shorter inputs the crop extends past n_tokens - presumably
                # intentional so the classifier gets a minimum input size; confirm.
                num_classifier_tokens = max(n_tokens, min_max_size[0]+20)

                for layer in range(len(attentions)):
                    if attentions[layer] is None:
                        continue

                    # Weight-normed attention: norm over the last dim of allalpha_f,
                    # plus the returned mask, softmax-normalised over the last dim.
                    normed_alpha_f = allalpha_f[layer].norm(dim=-1)
                    weight_normed_attention = (normed_alpha_f+attention_mask).softmax(dim=-1)

                    # Classify the raw and the weight-normed maps with identical code.
                    for attention_map, head_types in (
                        (attentions[layer], attention_head_types[layer]),
                        (weight_normed_attention, normed_attention_head_types[layer]),
                    ):
                        head_attentions = attention_map.transpose(0, 1)
                        head_attentions = head_attentions[:,:,:num_classifier_tokens,:num_classifier_tokens]
                        logits = head_classifier_model(head_attentions)
                        label_ids = torch.argmax(logits, dim=-1)
                        labels = [id2label[int(label_id.item())] for label_id in label_ids]
                        if len(head_types) == 0:
                            head_types.extend(Counter() for _ in labels)
                        for i, label in enumerate(labels):
                            head_types[i][label] += 1
                del attentions, allalpha_f, attention_mask
            k += 1
            if k == max_examples:
                break

        # Pool the per-head counters of this seed into the per-task totals.
        for layer_counters in attention_head_types:
            for head_type_ctr in layer_counters:
                attention_counter += head_type_ctr
        for layer_counters in normed_attention_head_types:
            for head_type_ctr in layer_counters:
                normed_attention_counter += head_type_ctr

    # Report label frequencies as fractions, most common label first.
    print(task)
    attention_total = sum(attention_counter.values())
    normed_attention_total = sum(normed_attention_counter.values())
    print("attention:", {label: count/attention_total for label, count in attention_counter.most_common()})
    print("weight normed:", {label: count/normed_attention_total for label, count in normed_attention_counter.most_common()})

CoLA
attention: {'other': 0.4672352941176471, 'mix': 0.26794117647058824, 'vertical': 0.10170588235294117, 'diagonal': 0.0838235294117647, 'block': 0.07929411764705882}
weight normed: {'block': 0.7966470588235294, 'diagonal': 0.10970588235294118, 'other': 0.08041176470588235, 'mix': 0.013235294117647059}
SST-2
attention: {'other': 0.5573684210526316, 'block': 0.23326315789473684, 'diagonal': 0.11168421052631579, 'mix': 0.07810526315789473, 'vertical': 0.019578947368421053}
weight normed: {'block': 0.8006315789473685, 'other': 0.1491578947368421, 'diagonal': 0.04347368421052632, 'mix': 0.006736842105263158}
MRPC
attention: {'vertical': 0.40035714285714286, 'other': 0.2547857142857143, 'block': 0.21471428571428572, 'diagonal': 0.06785714285714285, 'mix': 0.062285714285714285}
weight normed: {'block': 0.9435714285714286, 'diagonal': 0.03578571428571429, 'other': 0.020428571428571428, 'mix': 0.00021428571428571427}
STS-B
attention: {'other': 0.6811666666666667, 'block': 0.22441666666666665

# Pre-trained All

In [11]:
# Pre-trained weights without pruning (all heads).
# Same procedure as the fine-tuned "all heads" analysis, except that the model
# weights are loaded from 'bert-base-uncased'. The tokenizer and the config still
# come from the per-task / per-seed fine-tuned directories.
# Label counts are pooled over all seeds of a task and printed as fractions.
set_seed(1337)  # seed random / NumPy / torch before any model init or sampling
max_examples = 100  # classified examples per (task, seed) model
for task in ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE"]:
    attention_counter = Counter()
    normed_attention_counter = Counter()
    for seed in ["seed_1337", "seed_42", "seed_86", "seed_71", "seed_166"]:
        # Tokenizer and config come from the fine-tuned checkpoint directory,
        # the weights from the pre-trained 'bert-base-uncased' checkpoint.
        model_path = f"../models/finetuned/{task}/{seed}/"
        tokenizer = BertTokenizer.from_pretrained(model_path)
        config = BertConfig.from_pretrained(model_path)
        config.output_attentions = True
        transformer_model = BertForSequenceClassification.from_pretrained('bert-base-uncased',  config=config)
        transformer_model = transformer_model.eval()
        transformer_model.cuda()

        # Evaluation data, visited one example at a time in random order.
        args = Namespace(data_dir=f"../data/glue/{task}/", local_rank=-1,
                         model_name_or_path=model_path,
                         overwrite_cache=False, model_type="bert", max_seq_length=128)
        eval_dataset = load_and_cache_examples(args, task.lower(), tokenizer, evaluate=True)
        eval_sampler = RandomSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)

        # One list per layer, filled lazily with one Counter per head.
        # NOTE(review): 12 is hard-coded; assumes the model returns at most 12 layers - confirm.
        attention_head_types = [[] for _ in range(12)]
        normed_attention_head_types = [[] for _ in range(12)]

        k = 0  # number of examples classified so far (skipped ones do not count)
        for batch in eval_dataloader:
            batch = tuple(t.to("cuda:0") for t in batch)
            # batch[1] is passed to the model as `attention_mask`; its sum is used as the token count.
            n_tokens = int(batch[1].sum())
            if n_tokens < min_max_size[0]:
                continue
            input_data = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            with torch.no_grad():
                # The project model returns five values; the last three are used here.
                _, _, attentions, allalpha_f, attention_mask = transformer_model(**input_data)

                # Crop size for the classifier input; never smaller than min_max_size[0] + 20.
                # NOTE(review): for shorter inputs the crop extends past n_tokens - presumably
                # intentional so the classifier gets a minimum input size; confirm.
                num_classifier_tokens = max(n_tokens, min_max_size[0]+20)

                for layer in range(len(attentions)):
                    if attentions[layer] is None:
                        continue

                    # Weight-normed attention: norm over the last dim of allalpha_f,
                    # plus the returned mask, softmax-normalised over the last dim.
                    normed_alpha_f = allalpha_f[layer].norm(dim=-1)
                    weight_normed_attention = (normed_alpha_f+attention_mask).softmax(dim=-1)

                    # Classify the raw and the weight-normed maps with identical code.
                    for attention_map, head_types in (
                        (attentions[layer], attention_head_types[layer]),
                        (weight_normed_attention, normed_attention_head_types[layer]),
                    ):
                        head_attentions = attention_map.transpose(0, 1)
                        head_attentions = head_attentions[:,:,:num_classifier_tokens,:num_classifier_tokens]
                        logits = head_classifier_model(head_attentions)
                        label_ids = torch.argmax(logits, dim=-1)
                        labels = [id2label[int(label_id.item())] for label_id in label_ids]
                        if len(head_types) == 0:
                            head_types.extend(Counter() for _ in labels)
                        for i, label in enumerate(labels):
                            head_types[i][label] += 1
                del attentions, allalpha_f, attention_mask
            k += 1
            if k == max_examples:
                break

        # Pool the per-head counters of this seed into the per-task totals.
        for layer_counters in attention_head_types:
            for head_type_ctr in layer_counters:
                attention_counter += head_type_ctr
        for layer_counters in normed_attention_head_types:
            for head_type_ctr in layer_counters:
                normed_attention_counter += head_type_ctr

    # Report label frequencies as fractions, most common label first.
    print(task)
    attention_total = sum(attention_counter.values())
    normed_attention_total = sum(normed_attention_counter.values())
    print("attention:", {label: count/attention_total for label, count in attention_counter.most_common()})
    print("weight normed:", {label: count/normed_attention_total for label, count in normed_attention_counter.most_common()})

CoLA
attention: {'other': 0.48944444444444446, 'vertical': 0.26025, 'mix': 0.14822222222222223, 'block': 0.05431944444444444, 'diagonal': 0.04776388888888889}
weight normed: {'block': 0.8306111111111111, 'other': 0.11386111111111111, 'diagonal': 0.043166666666666666, 'mix': 0.012319444444444444, 'vertical': 4.1666666666666665e-05}
SST-2
attention: {'other': 0.3076527777777778, 'mix': 0.27318055555555554, 'vertical': 0.25622222222222224, 'diagonal': 0.08556944444444445, 'block': 0.077375}
weight normed: {'block': 0.6592638888888889, 'other': 0.2195138888888889, 'diagonal': 0.07645833333333334, 'mix': 0.04452777777777778, 'vertical': 0.00023611111111111112}
MRPC
attention: {'other': 0.29683333333333334, 'vertical': 0.2928611111111111, 'mix': 0.2599722222222222, 'diagonal': 0.09270833333333334, 'block': 0.057625}
weight normed: {'block': 0.7734861111111111, 'other': 0.13669444444444445, 'diagonal': 0.06644444444444444, 'mix': 0.022319444444444444, 'vertical': 0.0010555555555555555}
STS-B
