**Note:** The implementation and preprocessing code were taken from https://github.com/hmohebbi/ValueZeroing.

# The Device

In [8]:
import torch

# Pick the first CUDA device when one is available, otherwise fall back to the CPU.
_use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if _use_gpu else "cpu")
if _use_gpu:
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

We will use the GPU: Quadro RTX 8000


# utils

In [9]:
import numpy as np

# BLiMP task identifiers handled by this notebook (order is significant for iteration).
BLIMP_TASKS = [
    "ana",
    "dna",
    "dnaa",
    "rpsv",
    "darn",
    "NA",
]

# Every task is a two-way classification.
NUM_LABELS = {task: 2 for task in BLIMP_TASKS}

# Grammatical-number label -> class id (used when not training with the MLM head).
blimp_to_label = dict(singular=0, plural=1)

# Short model name -> HuggingFace hub checkpoint.
MODEL_PATH = {
    'bert': 'bert-base-uncased',
    'roberta': 'roberta-base',
    'electra': 'google/electra-base-generator',
    'deberta': 'microsoft/deberta-v3-base',
}

def blimp_to_features(data, tokenizer, max_length, input_masking, mlm, max_cues=10):
    """Convert BLiMP-style examples into model-ready feature dicts.

    Each example's ``sentence_good`` is a sequence of words. Every word is
    tokenized separately (without special tokens). The word-level
    ``cue_indices`` and ``target_index`` are mapped to the position of that
    word's first sub-token, then shifted by one for the prepended CLS token.

    Args:
        data: iterable of example dicts with keys ``sentence_good``,
            ``cue_indices``, ``target_index``, ``good_word``, ``bad_word``
            and, when ``mlm`` is False, ``labels``.
        tokenizer: object exposing ``encode``, ``convert_tokens_to_ids`` and
            ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id`` / ``mask_token_id``.
        max_length: pad every sequence to this many tokens, or ``None`` for no padding.
        input_masking: if True, replace the target token with the mask token.
        mlm: if True the label is the good word's token id; otherwise it is
            ``blimp_to_label[example['labels']]``.
        max_cues: width that ``cue_indices`` is right-padded to with ``-1``
            (a fixed width lets the collator stack them into a 2-D tensor).

    Returns:
        A single feature dict when ``data`` holds exactly one example,
        otherwise a list of feature dicts.

    Raises:
        ValueError: if an example's ``target_index`` does not point at a word,
            a sequence is longer than ``max_length``, or an example has more
            than ``max_cues`` cue words.
    """
    all_features = []
    for example in data:
        words = example['sentence_good']
        cue_words = set(example['cue_indices'])
        tokens = []
        cue_indices = []
        # Reset per example so a missing target cannot silently reuse the previous one.
        target_index = None
        # token to id
        for w_ind, word in enumerate(words):
            ids = tokenizer.encode(word, add_special_tokens=False)
            if w_ind in cue_words:
                cue_indices.append(len(tokens))
            if w_ind == example['target_index']:
                target_index = len(tokens)
            tokens.extend(ids)
        if target_index is None:
            raise ValueError(
                f"target_index {example['target_index']!r} does not point at a word of {words!r}"
            )

        tokens = [tokenizer.cls_token_id] + tokens + [tokenizer.sep_token_id]
        # Shift everything by one because of the prepended CLS token.
        cue_indices = [x + 1 for x in cue_indices]
        target_index += 1
        if input_masking:
            tokens[target_index] = tokenizer.mask_token_id

        # padding
        length = len(tokens)
        if max_length is None:
            pad = 0
        elif length > max_length:
            # Without this check the lists would have inconsistent lengths.
            raise ValueError(f"sequence of {length} tokens exceeds max_length={max_length}")
        else:
            pad = max_length - length

        good_token_id = tokenizer.convert_tokens_to_ids(example['good_word'])
        inputs = {}
        inputs['input_ids'] = tokens + [tokenizer.pad_token_id] * pad
        inputs['attention_mask'] = [1] * length + [0] * pad
        inputs['token_type_ids'] = [0] * (length + pad)
        inputs['target_index'] = target_index
        inputs['labels'] = good_token_id if mlm else blimp_to_label[example['labels']]
        inputs['good_token_id'] = good_token_id
        inputs['bad_token_id'] = tokenizer.convert_tokens_to_ids(example['bad_word'])

        # As a 2d tensor, all rows need the same length, so pad the cue list with -1.
        if len(cue_indices) > max_cues:
            raise ValueError(f"{len(cue_indices)} cue indices exceed max_cues={max_cues}")
        inputs['cue_indices'] = cue_indices + [-1] * (max_cues - len(cue_indices))

        all_features.append(inputs)
    return all_features[0] if len(all_features) == 1 else all_features

# Task id -> preprocessing routine; every BLiMP task shares the same feature builder.
PREPROCESS_FUNC = dict.fromkeys(
    ("ana", "dna", "dnaa", "rpsv", "darn", "NA"),
    blimp_to_features,
)

In [10]:
# --- Experiment selection ---
MODEL_NAME = 'roberta'
TASK = "NA"
SPLIT = "test"
CHECKPOINT = "full"
FIXED = False
METRIC = 'cosine'

# --- Runtime ---
SELECTED_GPU = 0
SEED = 42
SAVE_SCORES = False

# --- Preprocessing ---
MAX_LENGTH = 32
INPUT_MASKING = True
MLM = True

# --- Training hyper-parameters ---
NUM_TRAIN_EPOCHS = 1       # also part of the checkpoint filename
PER_DEVICE_BATCH_SIZE = 1  # the analysis cells index [0], i.e. expect one sample per batch
LEARNING_RATE = 3e-5
LR_SCHEDULER_TYPE = "linear"
WARMUP_RATIO = 0.1

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import pickle

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from sklearn.metrics.pairwise import cosine_distances

from datasets import (
    load_dataset,
    load_from_disk,
    load_metric,
)

from modeling.customized_modeling_bert import BertForMaskedLM
from modeling.customized_modeling_roberta import RobertaForMaskedLM
# from modeling.customized_modeling_electra import ElectraForMaskedLM
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AdamW,
    get_scheduler,
    default_data_collator,
    set_seed,
)
set_seed(SEED)

# Load Dataset
if TASK not in BLIMP_TASKS:
    # Raise instead of calling exit(): exit() raises SystemExit, which can end
    # an interactive session rather than report the problem.
    raise NotImplementedError(f"Task {TASK!r} is not implemented yet!")
data_path = f"./BLIMP Dataset/{MODEL_NAME}/"
eval_data = load_from_disk(data_path)[SPLIT]

num_labels = NUM_LABELS[TASK]

# Load Tokenizer & Model
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME], num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])

if MODEL_NAME == "bert":
    model = BertForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
elif MODEL_NAME == "roberta":
    model = RobertaForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
# elif MODEL_NAME == "electra":
#     model = ElectraForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
else:
    # Raise instead of printing. Otherwise model.to() below would hit a NameError,
    # or silently reuse a `model` left over from an earlier cell run.
    raise ValueError(f"model doesn't exist: {MODEL_NAME!r} (supported: 'bert', 'roberta')")

model.to(device)

# Preprocessing
eval_dataset = PREPROCESS_FUNC[TASK](eval_data, tokenizer, max_length=None, input_masking=INPUT_MASKING, mlm=MLM)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=PER_DEVICE_BATCH_SIZE)
num_examples = len(eval_dataset)

# metric & Loss
metric = load_metric("accuracy")
loss_fct = CrossEntropyLoss()

# Components of the checkpoint filename.
tag = "forseqclassification_"
tag += "pretrained" if FIXED else "finetuned"
if MLM:
    tag += "_MLM"

masking_tag = "masked" if INPUT_MASKING else "full"


In [11]:
import warnings
# Hide FutureWarnings to keep the cell output readable.
warnings.filterwarnings("ignore", category=FutureWarning)

# Metric name -> pairwise distance function; callers take .diagonal() of its result.
DISTANCE_FUNC = dict(cosine=cosine_distances)

In [158]:
# Restore the fine-tuned weights. map_location lets a checkpoint saved on a GPU
# be loaded on whichever device was selected (including the CPU fallback).
model.load_state_dict(torch.load(f'{MODEL_NAME}_full_{tag}_epoch{NUM_TRAIN_EPOCHS}.pt', map_location=device))

<All keys matched successfully>

# Value Zeroing Importance Scores

In [None]:
# rollout
def compute_joint_attention(att_mat, res=True):
    """Roll out per-layer attention/attribution matrices across layers.

    Args:
        att_mat: array indexed as (layer, query_token, key_token).
        res: if True, first add the identity to every layer (a residual term)
            and re-normalize each row to sum to 1.

    Returns:
        Float array with the same shape. Entry ``i`` is
        ``att_mat[i] @ att_mat[i-1] @ ... @ att_mat[0]`` (after the optional
        residual step).
    """
    if res:
        identity = np.eye(att_mat.shape[1])[None, ...]
        att_mat = att_mat + identity
        att_mat = att_mat / att_mat.sum(axis=-1, keepdims=True)

    rollout = np.zeros(att_mat.shape)
    rollout[0] = att_mat[0]
    for layer in range(1, rollout.shape[0]):
        # Accumulate: this layer's matrix applied on top of all previous layers.
        rollout[layer] = att_mat[layer] @ rollout[layer - 1]
    return rollout

Computing Value Zeroing scores.

In [None]:
# Per-example Value Zeroing scores, each of shape (num_layers, seq_length, seq_length).
# score_matrix[l, i, t] is the (row-normalized) change in token i's layer-l output when
# token t is value-zeroed.
all_valuezeroing_scores = []
# The same scores rolled out across layers (compute_joint_attention, no residual term).
all_rollout_valuezeroing_scores = []

# The customized models keep their encoder under a model-specific attribute.
if MODEL_NAME == 'bert':
    base_model = model.bert
elif MODEL_NAME == 'roberta':
    base_model = model.roberta
else:
    raise ValueError(f"Value Zeroing is only wired up for 'bert' and 'roberta', got {MODEL_NAME!r}")

# Disable dropout so the zeroed value vector is the only perturbation being measured.
model.eval()

progress_bar = tqdm(range(num_examples))
for step, inputs in enumerate(eval_dataloader):

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        # token_type_ids are not passed to the model
                        output_hidden_states=True, output_attentions=False)

    # (num_layers + 1, seq_length, hidden); squeeze(1) drops the batch axis (batch size 1).
    org_hidden_states = torch.stack(outputs['hidden_states']).squeeze(1)
    input_shape = inputs['input_ids'].size()
    batch_size, seq_length = input_shape

    # Depends only on the example (not on the layer or zeroed token), so build it once.
    extended_blanking_attention_mask = base_model.get_extended_attention_mask(
        inputs['attention_mask'], input_shape).to(device)

    ## layerwise zeroing value
    score_matrix = np.zeros((config.num_hidden_layers, seq_length, seq_length))
    for l, layer_module in enumerate(base_model.encoder.layer):
        layer_input = org_hidden_states[l].unsqueeze(0)  # previous layer's original output
        org_output = org_hidden_states[l + 1].detach().cpu().numpy()
        for t in range(seq_length):
            with torch.no_grad():
                # zero_value_index=t asks the customized layer to zero token t's value vector.
                layer_outputs = layer_module(layer_input,
                                             attention_mask=extended_blanking_attention_mask,
                                             output_attentions=False,
                                             zero_value_index=t,
                                             )

            # [0] drops only the batch axis; squeeze() would also collapse a
            # length-1 sequence axis.
            hidden_states = layer_outputs[0][0].detach().cpu().numpy()

            # Per-token distance between the value-zeroed and the original outputs.
            distances = DISTANCE_FUNC[METRIC](hidden_states, org_output).diagonal()
            score_matrix[l, :, t] = distances

    score_matrix = score_matrix / np.sum(score_matrix, axis=-1, keepdims=True)

    all_valuezeroing_scores.append(score_matrix)
    all_rollout_valuezeroing_scores.append(compute_joint_attention(score_matrix, res=False))

    progress_bar.update(1)

progress_bar.close()

# Alignment Metrics

Here, we compute two alignment metrics: Dot Product and Average Precision. First, let's define helper functions to compute Average Precision.

In [None]:
# Cleaning the tensors from -1 padding
def CI_cleaner(CI):
    """Return the cue indices of a 1-D tensor up to (excluding) the first ``-1``.

    Cue-index rows are right-padded with ``-1`` during preprocessing. When a row
    has no padding at all (every slot holds a real cue), it is returned unchanged
    instead of raising an IndexError.
    """
    pad_positions = torch.nonzero(CI == -1, as_tuple=True)[0]
    if pad_positions.numel() == 0:
        return CI
    return CI[:pad_positions[0].item()]

# Calculating Precision
def precision(TP, FP):
    """Return TP / (TP + FP)."""
    return TP / (TP + FP)

# Calculating Recall
def recall(TP, FN):
    """Return TP / (TP + FN)."""
    return TP / (TP + FN)

# Calculating Average Precision
def avg_precision(topk, CI):
    """Average precision of a ranking against a set of relevant items.

    Computed as sum_k (R_k - R_{k-1}) * P_k over the ranked items, where P_k
    and R_k are the precision and recall after the first k items.

    Args:
        topk: ranked items (most important first); any iterable.
        CI: relevant items (cue indices); must support ``len`` and ``in``.

    Returns:
        The average precision. Returns 0.0 when ``CI`` is empty (recall is
        undefined there).
    """
    if len(CI) == 0:
        return 0.0
    AP, TP, FP, FN = 0, 0, 0, len(CI)
    previous_recall = 0  # recall before any item is ranked
    for item in topk:
        if item in CI:
            TP += 1
            FN -= 1
        else:
            FP += 1

        current_recall = recall(TP, FN)
        AP += (current_recall - previous_recall) * precision(TP, FP)
        previous_recall = current_recall

    return AP

# Toy inputs for checking the AP helpers: a ranking of token positions and a
# -1-padded cue-index row whose only real cue is position 3.
topk = torch.tensor((1, 0, 3, 4, 2))
CI = torch.tensor([3] + [-1] * 3)
# To check by hand (expected AP = 1/3):
# print(avg_precision(topk, CI_cleaner(CI)))

In [126]:
# One entry per encoder layer (12 for the base-size models).
diagram_layers = range(config.num_hidden_layers)
APs_vz = {f'layer{layer}': [] for layer in diagram_layers}

sum_vz_scores = 0

model.eval()
for i, batch_sample in enumerate(tqdm(eval_dataloader)):

    # [0]: every batch holds exactly one sample.
    CI = CI_cleaner(batch_sample['cue_indices'][0])
    mask_index = batch_sample['target_index'][0].item()  # mask_index = target_index
    seq_len = int(batch_sample['attention_mask'][0].sum())

    # Contribution of every token to the target token's representation, per layer.
    # shape: [num_layers, seq_len]
    vz_importance = torch.from_numpy(all_valuezeroing_scores[i][:, mask_index])

    ### Average Precision
    vz_importance_topk = torch.topk(vz_importance, k=seq_len, largest=True, dim=1).indices
    # Drop the target position itself; it occurs exactly once in every row.
    keep = vz_importance_topk != mask_index
    vz_importance_topk_filtered = vz_importance_topk[keep].view(vz_importance_topk.size(0), -1)

    for layer, layer_importance in enumerate(vz_importance_topk_filtered):
        APs_vz[f'layer{layer}'].append(avg_precision(layer_importance, CI))

    ### Dot Product
    # Remove the target column, then re-normalize each layer's scores to sum to 1.
    vz_scores = torch.cat((vz_importance[:, :mask_index], vz_importance[:, mask_index + 1:]), dim=1)
    vz_scores = vz_scores / vz_scores.sum(dim=-1, keepdim=True)

    # Only cues located after the removed target column shift left by one. Cues
    # before it keep their index, so each cue is adjusted individually.
    shifted_CI = torch.where(CI > mask_index, CI - 1, CI)
    # Sum over all cue indices (i.e. evidence); shape [num_layers, 1].
    CI_scores = vz_scores[:, shifted_CI].sum(dim=1, keepdim=True)

    sum_vz_scores += CI_scores

print("### Dot Product ###")
print(sum_vz_scores / len(eval_dataloader))

print("### Average Precision ###")
temp_list = [sum(APs_vz[f'layer{layer}']) / len(APs_vz[f'layer{layer}']) for layer in diagram_layers]

print(temp_list)


  0%|          | 0/2208 [00:00<?, ?it/s]

### Dot Product ###
tensor([[0.1210],
        [0.1553],
        [0.1273],
        [0.0903],
        [0.1152],
        [0.1271],
        [0.1091],
        [0.1366],
        [0.1186],
        [0.1195],
        [0.1318],
        [0.1573]], dtype=torch.float64)
### Average Precision ###
[0.33297120160731064, 0.35506991249036896, 0.3675264035153822, 0.2629552616098428, 0.31504201560673706, 0.35902219496619014, 0.3075139568833522, 0.3653852097191506, 0.3403612887286204, 0.307320248072889, 0.33685834170319817, 0.3594334375077499]


In [127]:
# Per-layer Dot Product scores from the output above, as a flat 1-D tensor.
torch.tensor([0.1210, 0.1553, 0.1273, 0.0903, 0.1152, 0.1271,
              0.1091, 0.1366, 0.1186, 0.1195, 0.1318, 0.1573])

tensor([0.1210, 0.1553, 0.1273, 0.0903, 0.1152, 0.1271, 0.1091, 0.1366, 0.1186,
        0.1195, 0.1318, 0.1573])