In [6]:
# Using distilbert to finetune on squad, doing different versions of r and a full finetune
# Load model directly
from transformers import AutoTokenizer, DistilBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
model.modules

<bound method Module.modules of DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Drop

In [8]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, weight, bias, r, alpha):
        super(LoRALayer, self).__init__()
        self.weight = weight
        self.weight.requires_grad = False
        self.bias = bias
        self.r = r
        self.alpha = alpha
        out_features = self.weight.shape[0]
        in_features = self.weight.shape[1]
        self.A = nn.Parameter(self.weight.new_zeros(self.r, in_features))
        self.B = nn.Parameter(self.weight.new_zeros(out_features, r))
    
    def forward(self, x):
        result = x @ self.weight.T
        result = torch.add(result, self.bias)
        result = torch.add(result, x @ (self.A.T @ self.B.T))
        return result

In [9]:
#replace all the attention layers in model with LoRA layers
r = 4
alpha = 0
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and "_lin" in str(name):
        lora_layer = LoRALayer(module.weight, module.bias, r, alpha)
        # Replace the module directly in the parent's _modules dictionary
        parent_name, child_name = name.rsplit('.', 1)
        parent_module = dict(model.named_modules())[parent_name]
        parent_module._modules[child_name] = lora_layer

In [10]:
for _, param in model.named_parameters():
    param.requires_grad = False

for name, module in model.named_modules():
    if isinstance(module, LoRALayer):
        for param in module.parameters():
            param.requires_grad = True
    if name == "qa_outputs":
        for param in module.parameters():
            param.requires_grad = True

for name, param in model.named_parameters():
    if "_lin" in name or "qa_outputs" in name: assert param.requires_grad == True
    else: assert param.requires_grad == False

In [11]:
from datasets import load_dataset

dataset = load_dataset("rajpurkar/squad")

# Function to tokenize the data suitable for question answering
def prepare_train_features(examples):
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=512,
        stride=128,
        return_overflowing_tokens=False,
        return_offsets_mapping=True,
        padding="max_length"
    )
 
    # We need to find where the answers are in the tokenized context
    start_positions = []
    end_positions = []
 
    for i, offsets in enumerate(tokenized_examples["offset_mapping"]):
        # We assume that each question has exactly one answer
        start_char = examples["answers"][i]["answer_start"][0]
        end_char = start_char + len(examples["answers"][i]["text"][0]) - 1
 
        # Convert character start and end positions to token start and end positions
        sequence_ids = tokenized_examples.sequence_ids(i)
 
        # Find start and end token index for the answers
        start_index = next(
            (idx for idx, (offset, seq_id) in enumerate(zip(offsets, sequence_ids)) if seq_id == 1 and offset[0] <= start_char < offset[1]),
            -1
        )
        end_index = next(
            (idx for idx, (offset, seq_id) in enumerate(zip(offsets, sequence_ids)) if seq_id == 1 and offset[0] < end_char <= offset[1]),
            -1
        )

        start_positions.append(start_index)
        end_positions.append(end_index)
 
    # Update tokenized examples with the start and end positions
    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
 
    return tokenized_examples

# Apply the function to the train dataset
train_dataset = dataset['train'].map(prepare_train_features, batched=True)
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'])
val_dataset = dataset['validation'].map(prepare_train_features, batched=True)
val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'])

Map: 100%|███████████████████████| 87599/87599 [00:17<00:00, 4939.91 examples/s]


In [12]:
def find_none_positions(dataset, columns=['start_positions', 'end_positions']):
    none_positions = {}
    for column in columns:
        none_positions[column] = [i for i, val in enumerate(dataset[column]) if val is None]
    return none_positions

# Check the train dataset
none_positions_train = find_none_positions(train_dataset)
print("Train dataset None positions:")
print(none_positions_train)

# Check the validation dataset
none_positions_val = find_none_positions(val_dataset)
print("Validation dataset None positions:")
print(none_positions_val)

Train dataset None positions:
{'start_positions': [], 'end_positions': []}
Validation dataset None positions:
{'start_positions': [], 'end_positions': []}


In [13]:
for val in train_dataset['start_positions']:
    print(val)

tensor(130)
tensor(52)
tensor(81)
tensor(95)
tensor(33)
tensor(63)
tensor(98)
tensor(123)
tensor(39)
tensor(182)
tensor(36)
tensor(44)
tensor(59)
tensor(84)
tensor(149)
tensor(107)
tensor(22)
tensor(43)
tensor(74)
tensor(47)
tensor(111)
tensor(28)
tensor(47)
tensor(135)
tensor(90)
tensor(133)
tensor(230)
tensor(38)
tensor(167)
tensor(15)
tensor(118)
tensor(78)
tensor(90)
tensor(84)
tensor(39)
tensor(42)
tensor(57)
tensor(86)
tensor(100)
tensor(33)
tensor(46)
tensor(67)
tensor(119)
tensor(147)
tensor(31)
tensor(95)
tensor(136)
tensor(53)
tensor(188)
tensor(20)
tensor(80)
tensor(100)
tensor(133)
tensor(33)
tensor(20)
tensor(45)
tensor(43)
tensor(65)
tensor(30)
tensor(22)
tensor(201)
tensor(216)
tensor(232)
tensor(37)
tensor(12)
tensor(25)
tensor(59)
tensor(18)
tensor(28)
tensor(37)
tensor(93)
tensor(149)
tensor(168)
tensor(29)
tensor(71)
tensor(310)
tensor(339)
tensor(23)
tensor(16)
tensor(45)
tensor(84)
tensor(24)
tensor(88)
tensor(24)
tensor(58)
tensor(86)
tensor(267)
tensor(71)
tensor

In [14]:
example = train_dataset[0]
decoded_text = tokenizer.decode(example['input_ids'])
print(decoded_text)
answer_tokens = example['input_ids'][example['start_positions']:example['end_positions']+1]
decoded_answer = tokenizer.decode(answer_tokens)
print(decoded_answer)

[CLS] to whom did the virgin mary allegedly appear in 1858 in lourdes france? [SEP] architecturally, the school has a catholic character. atop the main building's gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 

In [15]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
optimizer = AdamW(model.parameters(), lr=5e-5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def train(epochs, model, optimizer, dataloader):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({'Training Loss': loss.item()})
        avg_loss = total_loss / num_batches
        print(f"Average Training Loss for Epoch {epoch+1}: {avg_loss}")
def eval(model, dataloader):
    model.eval()
    total_eval_loss = 0
    num_eval_batches = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            total_eval_loss += loss.item()
            num_eval_batches += 1

    avg_eval_loss = total_eval_loss / num_eval_batches
    print(f"Average Validation Loss: {avg_eval_loss}")

train(5, model, optimizer, train_loader)
eval(model, val_loader)

Epoch 1:   0%|                    | 0/5475 [13:40<?, ?it/s, Training Loss=0.804]

KeyboardInterrupt: 

In [None]:
import os
os._exit(00)

: 

: 

: 