In [1]:
import torch
from torch import nn
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from peft import get_peft_model, LoraConfig, TaskType

from datasets import load_dataset
from torchmetrics import Accuracy
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader

In [3]:
dataset = load_dataset('glue', 'mnli')

train_dataset = dataset['train']
validation_dataset = dataset['validation_matched']
test_dataset = dataset['test_matched']

model_name = 'google/flan-t5-small'

tokenizer = T5Tokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    inputs = ["premise: " + premise + " hypothesis: " + hypothesis for premise, hypothesis in zip(examples['premise'], examples['hypothesis'])]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')

    model_inputs["labels"] = examples["label"]    
    return model_inputs

train_dataset = train_dataset.map(preprocess_function)
validation_dataset = validation_dataset.map(preprocess_function)

train_dataset = train_dataset.remove_columns(['premise', 'hypothesis', 'idx'])
validation_dataset = validation_dataset.remove_columns(['premise', 'hypothesis', 'idx'])

train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
validation_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=True)

Map:   0%|          | 0/392702 [00:00<?, ? examples/s]

In [3]:
# Load the tokenizer and model
model_name = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16, 
    lora_dropout=0.1,  
    target_modules=["q", "v"]  
)

model = get_peft_model(model, lora_config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
classification_head = nn.Linear(model.config.d_model, 3).to(device)
optimizer = AdamW(list(model.parameters()) + list(classification_head.parameters()), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()


epochs = 3 
model.train()
classification_head.train()

for epoch in range(epochs):
    total_loss = 0
    for batch in validation_loader:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state[:, 0, :] 
        logits = classification_head(hidden_state)

        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(validation_loader)}")




OutOfMemoryError: CUDA out of memory. Tried to allocate 384.00 MiB. GPU 0 has a total capacity of 12.00 GiB of which 0 bytes is free. Of the allocated memory 24.54 GiB is allocated by PyTorch, and 241.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [4]:
model.eval()
classification_head.eval()

accuracy_metric = Accuracy(num_classes=3, task="multiclass").to(device)

accuracy_metric.reset() 

with torch.no_grad():
    for batch in validation_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = classification_head(hidden_state)

        predictions = torch.argmax(logits, dim=-1)

        # Update accuracy metric
        accuracy_metric.update(predictions, labels)

# Compute and print accuracy
accuracy = accuracy_metric.compute()
print(f"Accuracy: {accuracy.item() * 100:.2f}%")

Accuracy: 29.69%


In [5]:
model.save_pretrained("./lora_trained/lora-t5-mnli")  # This saves the LoRA matrices to the "lora_model" directory

# Later, to load the saved LoRA model, you can do this:

# Load the base model again
# base_model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

# # Load the saved LoRA model and combine it with the base model
# peft_model = PeftModel.from_pretrained(base_model, "lora_model").to(device)