**Note:** This notebook uses code from the Value Zeroing method: https://github.com/hmohebbi/ValueZeroing.

# The Device

In [1]:
import torch

# GPU
# Use the first CUDA device when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

We will use the GPU: Quadro RTX 8000


# utils

In [7]:
import numpy as np

# Task identifiers recognised by this notebook.
BLIMP_TASKS = ["ana", "dna", "dnaa", "rpsv", "darn", "NA"]

# Every listed task has two labels.
NUM_LABELS = {task: 2 for task in BLIMP_TASKS}

# Class index for each label string.
blimp_to_label = dict(singular=0, plural=1)

# Short model name -> Hugging Face checkpoint identifier.
MODEL_PATH = dict(
    bert='bert-base-uncased',
    roberta='roberta-base',
    electra='google/electra-base-generator',
    deberta='microsoft/deberta-v3-base',
)

def blimp_to_features(data, tokenizer, max_length, input_masking, mlm, max_cues=10):
    """Turn BLiMP-style examples into padded feature dicts.

    Each word of ``example['sentence_good']`` is encoded separately with
    ``tokenizer.encode(..., add_special_tokens=False)`` so that the word-level
    ``cue_indices`` / ``target_index`` can be mapped to the position of the
    first sub-token of the corresponding word.  CLS/SEP ids are then added
    around the sequence and all positions are shifted by one.

    Args:
        data: Iterable of examples with the keys ``sentence_good``,
            ``cue_indices``, ``target_index``, ``good_word``, ``bad_word``
            and (only read when ``mlm`` is false) ``labels``.
            ``sentence_good`` is expected to be a sequence of words; a plain
            string would be iterated character by character.
        tokenizer: Object providing ``encode``, ``convert_tokens_to_ids`` and
            the ``cls``/``sep``/``pad``/``mask`` token ids.
        max_length: Length to pad to, or ``None`` for no padding.  Sequences
            longer than ``max_length`` are NOT truncated.
        input_masking: If true, the token at the target position is replaced
            by the mask token id (only that single position).
        mlm: If true, ``labels`` is the id of ``good_word``; otherwise it is
            ``blimp_to_label[example['labels']]``.
        max_cues: Width the ``cue_indices`` list is right-padded to with -1.

    Returns:
        A list of feature dicts, or the single dict itself when ``data``
        yields exactly one example.

    Raises:
        ValueError: If an example's ``target_index`` does not match any word
            position of its sentence.
    """
    all_features = []
    for example in data:
        words = example['sentence_good']
        cue_word_ids = example['cue_indices']
        target_word_id = example['target_index']

        tokens = []
        cue_indices = []
        # Reset for every example so a missing target can never silently
        # reuse the position found for the previous example.
        target_index = None
        for w_ind, word in enumerate(words):
            if w_ind in cue_word_ids:
                cue_indices.append(len(tokens))
            if w_ind == target_word_id:
                target_index = len(tokens)
            tokens.extend(tokenizer.encode(word, add_special_tokens=False))

        if target_index is None:
            raise ValueError(
                f"target_index {target_word_id!r} does not match any word "
                f"position in a sentence of {len(words)} words"
            )

        tokens = [tokenizer.cls_token_id] + tokens + [tokenizer.sep_token_id]
        # The leading CLS token shifts every position by one.
        cue_indices = [idx + 1 for idx in cue_indices]
        target_index += 1
        if input_masking:
            tokens[target_index] = tokenizer.mask_token_id

        # Padding. pad_len is clamped at zero so that input_ids,
        # attention_mask and token_type_ids always share one length, even
        # when the sequence is longer than max_length.
        length = len(tokens)
        pad_len = 0 if max_length is None else max(0, max_length - length)
        good_token_id = tokenizer.convert_tokens_to_ids(example['good_word'])

        inputs = {}
        inputs['input_ids'] = tokens + [tokenizer.pad_token_id] * pad_len
        inputs['attention_mask'] = [1] * length + [0] * pad_len
        inputs['token_type_ids'] = [0] * (length + pad_len)
        inputs['target_index'] = target_index
        inputs['labels'] = good_token_id if mlm else blimp_to_label[example['labels']]
        inputs['good_token_id'] = good_token_id
        inputs['bad_token_id'] = tokenizer.convert_tokens_to_ids(example['bad_word'])

        # All rows must have the same length to be stacked into a 2-D tensor,
        # so the list is right-padded with -1 up to max_cues entries.
        inputs['cue_indices'] = cue_indices + [-1] * (max_cues - len(cue_indices))

        all_features.append(inputs)
    return all_features[0] if len(all_features) == 1 else all_features

# All tasks share the same preprocessing routine.
PREPROCESS_FUNC = dict.fromkeys(
    ('ana', 'dna', 'dnaa', 'rpsv', 'darn', 'NA'),
    blimp_to_features,
)

In [11]:
# --- What to run ---
TASK = "NA"
MODEL_NAME = 'roberta'
SELECTED_GPU = 0
SEED = 42

# --- Input construction ---
MAX_LENGTH = 32
INPUT_MASKING = True
MLM = True

# --- Optimisation ---
NUM_TRAIN_EPOCHS = 1
PER_DEVICE_BATCH_SIZE = 64
LEARNING_RATE = 3e-5
LR_SCHEDULER_TYPE = "linear"
WARMUP_RATIO = 0.1

# --- Checkpoint naming ("pretrained" when True, "finetuned" when False) ---
FIXED = False

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss

from datasets import (
    load_dataset,
    load_from_disk,
    load_metric,
)

from transformers import BertForMaskedLM, RobertaForMaskedLM
# from customized_modeling_bert import BertForMaskedLM
# from modeling.customized_modeling_roberta import RobertaForMaskedLM
# from modeling.customized_modeling_electra import ElectraForMaskedLM
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AdamW,
    get_scheduler,
    default_data_collator,
    set_seed,
)
set_seed(SEED)

# Load Dataset
if TASK not in BLIMP_TASKS:
    # Raise instead of print + exit(): exit() is not a reliable way to stop a
    # notebook cell, and an exception points at the unsupported task.
    raise NotImplementedError(f"Not implemented yet! Unsupported task: {TASK!r}")
data_path = f"./BLIMP Dataset/{MODEL_NAME}/"
data = load_from_disk(data_path)
train_data = data['train'].shuffle(SEED)
eval_data = data['test']
num_labels = NUM_LABELS[TASK]

# Download Tokenizer & Model
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME], num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])

if MODEL_NAME == "bert":
    model = BertForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
elif MODEL_NAME == "roberta":
    model = RobertaForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
# elif MODEL_NAME == "electra":
#     model = ElectraForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config)
else:
    # Previously this branch only printed a message, after which
    # model.to(device) failed with a NameError (or silently reused a model
    # left over from an earlier cell run).
    raise ValueError(f"model doesn't exist: {MODEL_NAME!r}")

model.to(device)

# Preprocessing
train_dataset = PREPROCESS_FUNC[TASK](train_data, tokenizer, MAX_LENGTH, input_masking=INPUT_MASKING, mlm=MLM)
eval_dataset = PREPROCESS_FUNC[TASK](eval_data, tokenizer, MAX_LENGTH, input_masking=INPUT_MASKING, mlm=MLM)

train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=PER_DEVICE_BATCH_SIZE)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=PER_DEVICE_BATCH_SIZE)

# One optimizer update per batch (no gradient accumulation).
num_update_steps_per_epoch = len(train_dataloader)
max_train_steps = NUM_TRAIN_EPOCHS * num_update_steps_per_epoch

# Optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# NOTE(review): WARMUP_RATIO * max_train_steps can be fractional; it is passed
# through unchanged to keep the schedule identical to earlier runs.
lr_scheduler = get_scheduler(
    name=LR_SCHEDULER_TYPE,
    optimizer=optimizer,
    num_warmup_steps=WARMUP_RATIO * max_train_steps,
    num_training_steps=max_train_steps,
)

# metric & Loss
metric = load_metric("accuracy")
loss_fct = CrossEntropyLoss()

# Tag embedded in the checkpoint file name.
tag = "forseqclassification_"
tag += "pretrained" if FIXED else "finetuned"
if MLM:
    tag += "_MLM"

Loading cached shuffled indices for dataset at BLIMP Dataset/roberta/train/cache-3c138be14de5eff0.arrow


# Training

In [12]:
# Train
progress_bar = tqdm(range(max_train_steps))
completed_steps = 0
for epoch in range(NUM_TRAIN_EPOCHS):
    # ---- Training pass ----
    model.train()
    for batch in train_dataloader:
        # Auxiliary fields are removed before the forward pass so that only
        # real model inputs remain in `batch`.
        good_token_id = batch.pop('good_token_id').to(device)
        bad_token_id = batch.pop('bad_token_id').to(device)
        cue_indices = batch.pop('cue_indices').to(device)
        target_index = batch.pop('target_index').to(device)  # size: [batch]
        # The loss below is built from the good/bad logits, so the stored
        # labels are dropped rather than forwarded to the model.
        batch.pop('labels')
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)  # outputs.logits size: [batch, seq_len, vocab_size]
        # The Value Zeroing code modified BertForMaskedLM so that the model
        # returned the logits at the target index directly. Here the original
        # masked-LM class is used, so the logits at the target index are
        # selected from its output below.
        row_ids = torch.arange(outputs.logits.size(0), device=outputs.logits.device)
        logits = outputs.logits[row_ids, target_index]  # size: [batch, vocab_size]
        good_logits = logits[row_ids, good_token_id]  # size: [batch]
        bad_logits = logits[row_ids, bad_token_id]  # size: [batch]

        # Two-way classification: column 0 is the good word, column 1 the bad
        # word, so the correct class is always 0.
        logits_of_interest = torch.stack([good_logits, bad_logits], dim=1)  # size: [batch, 2]
        labels = torch.zeros(logits_of_interest.shape[0], dtype=torch.int64, device=device)  # size: [batch]
        loss = loss_fct(logits_of_interest, labels)

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        completed_steps += 1

    # ---- Evaluation pass ----
    model.eval()
    for batch in eval_dataloader:
        # Always pop the auxiliary fields: leaving good/bad token ids in the
        # batch when MLM is off would forward them to the model as unexpected
        # keyword arguments.
        good_token_id = batch.pop('good_token_id').to(device)
        bad_token_id = batch.pop('bad_token_id').to(device)
        target_index = batch.pop('target_index').to(device)  # size: [batch]
        # Keep the popped labels; the non-MLM branch needs them as references
        # (reading batch['labels'] after the pop raised KeyError).
        references = batch.pop('labels').to(device)
        batch.pop('cue_indices')
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        row_ids = torch.arange(outputs.logits.size(0), device=outputs.logits.device)
        logits = outputs.logits[row_ids, target_index]  # size: [batch, vocab_size]

        if MLM:
            good_logits = logits[row_ids, good_token_id]
            bad_logits = logits[row_ids, bad_token_id]
            logits_of_interest = torch.stack([good_logits, bad_logits], dim=1)
            # Index 0 (the good word) is the correct prediction.
            labels = torch.zeros(logits_of_interest.shape[0], dtype=torch.int64, device=device)
            predictions = torch.argmax(logits_of_interest, dim=-1)
            metric.add_batch(predictions=predictions, references=labels)
        else:
            # NOTE(review): this compares an argmax over the selected logits
            # with the stored labels; confirm this is meaningful for the
            # model class in use.
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=references)

    eval_metric = metric.compute()
    print(f"epoch {epoch}: {eval_metric}")


  0%|          | 0/35 [00:00<?, ?it/s]

epoch 0: {'accuracy': 0.9904891304347826}


In [66]:
# torch.save(model.state_dict(), f'{MODEL_NAME}_full_{tag}_epoch{NUM_TRAIN_EPOCHS}.pt')

# Evaluation

In [158]:
# Restore saved weights. map_location remaps the stored tensors onto the
# active device, so a checkpoint written on a GPU can also be loaded on a
# CPU-only machine.
model.load_state_dict(torch.load(f'{MODEL_NAME}_full_{tag}_epoch{NUM_TRAIN_EPOCHS}.pt', map_location=device))

<All keys matched successfully>

In [5]:
# Evaluate the current model on the evaluation set and report accuracy.
model.eval()
for batch in eval_dataloader:
    # Always pop the auxiliary fields so only real model inputs are forwarded;
    # leaving good/bad token ids in the batch when MLM is off would pass them
    # to the model as unexpected keyword arguments.
    good_token_id = batch.pop('good_token_id').to(device)
    bad_token_id = batch.pop('bad_token_id').to(device)
    target_index = batch.pop('target_index').to(device)  # size: [batch]
    # Keep the popped labels; the non-MLM branch needs them as references
    # (reading batch['labels'] after the pop raised KeyError).
    references = batch.pop('labels').to(device)
    batch.pop('cue_indices')
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    # Select, for every row, the logits at that row's target position.
    row_ids = torch.arange(outputs.logits.size(0), device=outputs.logits.device)
    logits = outputs.logits[row_ids, target_index]  # size: [batch, vocab_size]

    if MLM:
        good_logits = logits[row_ids, good_token_id]
        bad_logits = logits[row_ids, bad_token_id]
        # Column 0 is the good word, column 1 the bad word; class 0 is correct.
        logits_of_interest = torch.stack([good_logits, bad_logits], dim=1)
        labels = torch.zeros(logits_of_interest.shape[0], dtype=torch.int64, device=device)
        predictions = torch.argmax(logits_of_interest, dim=-1)
        metric.add_batch(predictions=predictions, references=labels)
    else:
        # NOTE(review): this compares an argmax over the selected logits with
        # the stored labels; confirm this is meaningful for the model class
        # in use.
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=references)

eval_metric = metric.compute()
print(f"epoch: {eval_metric}")


epoch: {'accuracy': 0.9904891304347826}
