In [2]:
import sys
# Put ../src at the front of the module search path so it takes precedence over
# installed packages. Presumably this is where the project-local modules imported
# in the next cell (classify_attention_patterns, run_glue, model_bert, config_bert)
# live -- TODO confirm. The path is relative to the kernel's working directory.
sys.path.insert(0, '../src')

In [3]:
from classify_attention_patterns import load_model
from argparse import Namespace
from run_glue import load_and_cache_examples, set_seed
from model_bert import BertForSequenceClassification, BertForMaskedLM
from config_bert import BertConfig
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import BertTokenizer
import numpy as np
import torch
import random
from collections import Counter
import pathlib

def set_seed(seed):
    """Seed Python's ``random``, NumPy's legacy global RNG and PyTorch.

    Note: this definition shadows the ``set_seed`` imported from ``run_glue``
    earlier in the same cell; from here on the notebook uses this version,
    which takes the integer seed directly.

    Args:
        seed: Integer seed handed unchanged to each of the three libraries.
    """
    # Same order as before: random -> numpy -> torch.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

loading region bounding boxes for computing carbon emissions region, this may take a moment...
 454/454... rate=522.13 Hz, eta=0:00:00, total=0:00:00, wall=11:04 CETT
Done!


In [4]:
import numpy as np
import pathlib
import matplotlib.pyplot as plt
import time
import torch
import pathlib
import numpy as np
import pathlib
import json
import numpy as np
import time
import matplotlib.pyplot as plt

def plot_matrix(mean_matrix, x_labels, y_labels, save_path, x_title='', y_title='', colour_map=plt.cm.OrRd):
    """Draw a matrix as a heatmap and write the figure to ``save_path``.

    The x axis (ticks, tick labels and title) is placed along the top edge of
    the plot. Nothing is displayed inline: the figure is closed after saving.

    Args:
        mean_matrix: 2-D array-like handed to ``Axes.imshow``.
        x_labels: One tick label per column.
        y_labels: One tick label per row.
        save_path: Destination of the figure. When it is a ``str`` or a
            ``pathlib`` path, missing parent directories are created first;
            any other object is passed to ``Figure.savefig`` untouched.
        x_title: Optional title for the x axis.
        y_title: Optional title for the y axis.
        colour_map: Matplotlib colormap used to colour the cells.
    """
    # Make sure the target folder exists; otherwise savefig raises
    # FileNotFoundError when the output directory has not been created yet.
    if isinstance(save_path, (str, pathlib.PurePath)):
        pathlib.Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(9,8))
    try:
        ax.set_xlabel(x_title, labelpad=20)
        ax.set_ylabel(y_title)
        ax.xaxis.set_ticks_position('top')
        ax.xaxis.set_label_position('top')
        ax.set_xticks(range(len(x_labels)))
        ax.set_yticks(range(len(y_labels)))
        ax.set_xticklabels(x_labels)
        ax.set_yticklabels(y_labels)
        ax.imshow(mean_matrix, cmap=colour_map)
        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving failed: this function is
        # called in a long loop and open figures accumulate in memory.
        plt.close(fig)


In [5]:
# Generate attention heatmaps: for every fine-tuned model (task x seed), draw
# random evaluation examples, pick a random layer/head, and store the
# norm-weighted attention map both as a PDF figure and as a .npy array.

# Seed all RNGs first so that example order and layer/head picks are repeatable.
set_seed(1337)

save_dir = pathlib.Path("./heatmap_attention_mask/")
image_dir = save_dir / "images"
data_dir = save_dir / "data"
# The output folders must exist before the loop writes into them.
for out_dir in (image_dir, data_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# Use the first GPU when there is one; fall back to CPU so the cell still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MIN_TOKENS = 16            # examples with fewer non-padding tokens are skipped
MAX_PLOTS_PER_MODEL = 15   # number of heatmaps produced per (task, seed) model

# Running index over all saved heatmaps; part of every output file name.
mask_id = 0
for task in ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI", "QNLI", "RTE"]:
    for seed in ["seed_1337", "seed_42", "seed_86", "seed_71", "seed_166"]:
        # Load model, tokenizer and config from the fine-tuned checkpoint.
        model_path = f"../models/finetuned/{task}/{seed}/"
        tokenizer = BertTokenizer.from_pretrained(model_path)
        config = BertConfig.from_pretrained(model_path)
        config.output_attentions = True
        # NOTE: the weights used to be loaded from 'bert-base-uncased', which
        # ignored the fine-tuned checkpoint this loop iterates over; only the
        # tokenizer and config came from model_path.
        transformer_model = BertForSequenceClassification.from_pretrained(model_path, config=config)
        transformer_model = transformer_model.eval()
        transformer_model.to(device)

        args = Namespace(data_dir=f"../data/glue/{task}/", local_rank=-1,
                         model_name_or_path=model_path,
                         overwrite_cache=False, model_type="bert", max_seq_length=128)
        eval_dataset = load_and_cache_examples(args, task.lower(), tokenizer, evaluate=True)
        eval_sampler = RandomSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)

        k = 0  # heatmaps written so far for this model
        print(task, seed)
        for batch in eval_dataloader:
            batch = tuple(t.to(device) for t in batch)
            input_data = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}

            # Sum of the attention mask, as a plain int so that it can be used
            # for comparisons and slicing without further device round-trips.
            n_tokens = int(batch[1].sum().item())
            if n_tokens < MIN_TOKENS:
                continue
            subwords = tokenizer.convert_ids_to_tokens(batch[0][0, :n_tokens].tolist())

            with torch.no_grad():
                # The model is project-local; only the fourth output is used here.
                _, _, _, allalpha_f, _ = transformer_model(**input_data)

                layer = random.choice(range(len(allalpha_f)))
                head = random.choice(range(allalpha_f[layer].size(1)))

                # Norm over the last axis of the selected head's tensor; the
                # printed shape is the resulting 2-D map before cropping.
                weight_normed_attention = allalpha_f[layer][0, head].norm(dim=-1)
                print(layer, head, weight_normed_attention.shape)
                # Crop to the first n_tokens rows and columns.
                attention = weight_normed_attention[:n_tokens, :n_tokens].cpu().numpy()
                del allalpha_f

            file_stem = f"{task}_{seed}_{mask_id}"
            plot_matrix(attention, subwords, subwords, str(image_dir / f"{file_stem}.pdf"))
            np.save(str(data_dir / f"{file_stem}.npy"), attention)
            mask_id += 1

            k += 1
            if k == MAX_PLOTS_PER_MODEL:
                break

CoLA seed_1337
9 8 torch.Size([128, 128])
11 5 torch.Size([128, 128])
9 9 torch.Size([128, 128])
11 2 torch.Size([128, 128])
5 6 torch.Size([128, 128])
10 5 torch.Size([128, 128])
4 6 torch.Size([128, 128])
11 3 torch.Size([128, 128])
10 5 torch.Size([128, 128])
1 6 torch.Size([128, 128])
10 6 torch.Size([128, 128])
1 8 torch.Size([128, 128])
11 10 torch.Size([128, 128])
5 6 torch.Size([128, 128])
6 11 torch.Size([128, 128])
CoLA seed_42
4 10 torch.Size([128, 128])
0 5 torch.Size([128, 128])
9 1 torch.Size([128, 128])
7 3 torch.Size([128, 128])
2 5 torch.Size([128, 128])
10 11 torch.Size([128, 128])
9 7 torch.Size([128, 128])
9 4 torch.Size([128, 128])
3 8 torch.Size([128, 128])
1 10 torch.Size([128, 128])
8 9 torch.Size([128, 128])
9 0 torch.Size([128, 128])
1 6 torch.Size([128, 128])
10 9 torch.Size([128, 128])
0 10 torch.Size([128, 128])
CoLA seed_86
0 10 torch.Size([128, 128])
3 3 torch.Size([128, 128])
1 4 torch.Size([128, 128])
3 7 torch.Size([128, 128])
11 9 torch.Size([128, 128

In [6]:
!mkdir heatmap_attention_mask/images heatmap_attention_mask/data

In [7]:
# Move every .npy file lying directly in heatmap_attention_mask/ into its data/ subfolder.
# NOTE(review): the generation cell writes its .npy files into data/ already, so this
# glob may match nothing (mv then only prints an error) -- looks like a one-off
# clean-up of files from an earlier run; confirm whether it is still needed.
!mv heatmap_attention_mask/*.npy heatmap_attention_mask/data

In [8]:
# Move every .pdf file lying directly in heatmap_attention_mask/ into its images/ subfolder.
# NOTE(review): the generation cell writes its figures into images/ already, so this
# glob may match nothing (mv then only prints an error) -- looks like a one-off
# clean-up of files from an earlier run; confirm whether it is still needed.
!mv heatmap_attention_mask/*.pdf heatmap_attention_mask/images

In [9]:
# Create an empty labels.txt next to the saved figures if it does not exist yet
# (touch leaves the contents of an existing file untouched).
# Presumably the file is filled in later with labels for the heatmaps -- TODO confirm.
!touch heatmap_attention_mask/images/labels.txt