In [None]:
# Using distilbert to finetune on squad, doing different versions of r and a full finetune
# Load model directly
from transformers import AutoTokenizer, DistilBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert/distilbert-base-uncased")

In [None]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, weight, bias, r, alpha):
        super(LoRALayer, self).__init__()
        self.weight = weight
        self.weight.requires_grad = False
        self.bias = bias
        self.r = r
        self.alpha = alpha
        out_features = self.weight.shape[0]
        in_features = self.weight.shape[1]
        self.A = nn.Parameter(self.weight.new_zeros(self.r, in_features))
        self.B = nn.Parameter(self.weight.new_zeros(out_features, r))
    
    def forward(self, x):
        result = x @ self.weight.T
        result = torch.add(result, self.bias)
        result = torch.add(result, x @ (self.A.T @ self.B.T))
        return result

In [None]:
#replace all the attention layers in model with LoRA layers
r = 4
alpha = 0
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and "_lin" in str(name):
        lora_layer = LoRALayer(module.weight, module.bias, r, alpha)
        # Replace the module directly in the parent's _modules dictionary
        parent_name, child_name = name.rsplit('.', 1)
        parent_module = dict(model.named_modules())[parent_name]
        parent_module._modules[child_name] = lora_layer

In [None]:
for _, param in model.named_parameters():
    param.requires_grad = False

for name, module in model.named_modules():
    if isinstance(module, LoRALayer):
        for param in module.parameters():
            param.requires_grad = True

for name, param in model.named_parameters():
    if "_lin" in name: assert param.requires_grad == True
    else: assert param.requires_grad == False

In [None]:
from datasets import load_dataset

dataset = load_dataset("rajpurkar/squad")

# Function to tokenize the data suitable for question answering
def prepare_train_features(examples):
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=512,
        stride=128,
        return_overflowing_tokens=False,
        return_offsets_mapping=True,
        padding="max_length"
    )
    
    # We need to find where the answers are in the tokenized context
    start_positions = []
    end_positions = []
    
    for i, offsets in enumerate(tokenized_examples["offset_mapping"]):
        # We assume that each question has exactly one answer
        start_char = examples["answers"][i]["answer_start"][0]
        end_char = start_char + len(examples["answers"][i]["text"][0]) - 1
        
        # Convert character start and end positions to token start and end positions
        sequence_ids = tokenized_examples.sequence_ids(i)
        
        # Find start and end token index for the answers
        start_index = next(
            (idx for idx, (offset, seq_id) in enumerate(zip(offsets, sequence_ids)) if seq_id == 1 and offset[0] <= start_char < offset[1]),
            None
        )
        end_index = next(
            (idx for idx, (offset, seq_id) in enumerate(zip(offsets, sequence_ids)) if seq_id == 1 and offset[0] < end_char <= offset[1]),
            None
        )
        
        start_positions.append(start_index)
        end_positions.append(end_index)
    
    # Update tokenized examples with the start and end positions
    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    
    return tokenized_examples

# Apply the function to the train dataset
train_dataset = dataset['train'].map(prepare_train_features, batched=True)
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'])
val_dataset = dataset['validation'].map(prepare_train_features, batched=True)
val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'])

In [None]:
example = train_dataset[0]
decoded_text = tokenizer.decode(example['input_ids'])
print(decoded_text)
answer_tokens = example['input_ids'][example['start_positions']:example['end_positions']+1]
decoded_answer = tokenizer.decode(answer_tokens)
print(decoded_answer)

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
optimizer = AdamW(model.parameters(), lr=5e-5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def train(epochs, model, optimizer, dataloader):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({'Training Loss': loss.item()})
        avg_loss = total_loss / num_batches
        print(f"Average Training Loss for Epoch {epoch+1}: {avg_loss}")
def eval(model, dataloader):
    model.eval()
    total_eval_loss = 0
    num_eval_batches = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            total_eval_loss += loss.item()
            num_eval_batches += 1

    avg_eval_loss = total_eval_loss / num_eval_batches
    print(f"Average Validation Loss: {avg_eval_loss}")

train(5, model, optimizer, train_loader)
eval(model, val_loader)