In [1]:
import torch
from torch import nn
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from peft import get_peft_model, LoraConfig, TaskType

from datasets import load_dataset
from torchmetrics import Accuracy
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader

In [None]:
# Raw GLUE RTE splits: sentence1/sentence2 pairs with an integer `label`.
dataset = load_dataset('glue', 'rte')
train_dataset, validation_dataset, test_dataset = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)

# FLAN-T5 checkpoint used for tokenization (consumed by preprocess_function).
model_name = 'google/flan-t5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    """Tokenize a batch of RTE examples into fixed-length encoder inputs.

    Each (sentence1, sentence2) pair is rendered as
    ``"premise: <sentence1> hypothesis: <sentence2>"`` and encoded with the
    module-level ``tokenizer``, truncated and padded to exactly 512 tokens.
    The integer ``label`` column is passed through unchanged as ``labels``.

    Args:
        examples: batched mapping (as given by ``Dataset.map(batched=True)``)
            with ``sentence1``, ``sentence2`` and ``label`` lists.

    Returns:
        The tokenizer output with an added ``labels`` entry.
    """
    prompts = []
    for premise, hypothesis in zip(examples['sentence1'], examples['sentence2']):
        prompts.append(f"premise: {premise} hypothesis: {hypothesis}")

    encoded = tokenizer(prompts, max_length=512, truncation=True, padding='max_length')
    encoded["labels"] = examples["label"]
    return encoded

def _prepare_split(split):
    """Tokenize a raw RTE split and format it as torch tensors for a DataLoader.

    Args:
        split: an unprocessed ``datasets.Dataset`` with ``sentence1``,
            ``sentence2``, ``idx`` and ``label`` columns.

    Returns:
        The mapped dataset, formatted to yield only ``input_ids``,
        ``attention_mask`` and ``labels`` tensors.
    """
    encoded = split.map(preprocess_function, batched=True)
    encoded = encoded.remove_columns(['sentence1', 'sentence2', 'idx'])
    encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    return encoded


# Build from the raw `dataset` splits rather than from `train_dataset` /
# `validation_dataset` themselves, so re-running this cell does not try to
# re-map already-processed data (which no longer has `sentence1`).
train_dataset = _prepare_split(dataset['train'])
validation_dataset = _prepare_split(dataset['validation'])

# Set to an int (e.g. 1000) for a quick debugging run; None trains on the full split.
TRAIN_SUBSET = None
if TRAIN_SUBSET is not None:
    train_dataset = train_dataset.select(range(TRAIN_SUBSET))

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation does not benefit from random order; keep it deterministic.
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

In [None]:
# Model: FLAN-T5 with LoRA adapters on the attention "q"/"v" projections,
# plus a linear classification head applied to pooled encoder states.
# `model_name` comes from the data-preparation cell; the tokenizer is not used
# in this cell, so it is not reloaded here.
SEED = 42
# Seed before any random initialisation so LoRA/head init, dropout and the
# training DataLoader's shuffle order are repeatable across runs.
torch.manual_seed(SEED)

model = T5ForConditionalGeneration.from_pretrained(model_name)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q", "v"],
)

model = get_peft_model(model, lora_config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Size the head from the dataset's ClassLabel rather than a hard-coded count,
# so every logit corresponds to a real RTE class and argmax cannot predict a
# label that does not exist.
NUM_LABELS = dataset['train'].features['label'].num_classes
classification_head = nn.Linear(model.config.d_model, NUM_LABELS).to(device)

# Optimise only parameters that require gradients (the LoRA adapters, since
# get_peft_model freezes the base weights) plus the classification head.
trainable_params = [p for p in model.parameters() if p.requires_grad]
trainable_params += list(classification_head.parameters())
# torch.optim.AdamW replaces the deprecated transformers.AdamW;
# weight_decay=0.0 matches the transformers default previously in effect.
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5, weight_decay=0.0)
loss_fn = nn.CrossEntropyLoss()


def mean_pool(hidden_states, attention_mask):
    """Average encoder states over non-padding positions.

    T5 has no [CLS] token: position 0 is just the first input token, which
    here is always the start of the fixed "premise:" prefix. Averaging over
    the real tokens gives the head a summary of the whole premise/hypothesis
    pair instead.

    Args:
        hidden_states: encoder output, assumed (batch, seq_len, d_model).
        attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.

    Returns:
        Tensor of shape (batch, d_model).
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    # clamp guards against division by zero for an all-padding row
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


epochs = 3
model.train()
classification_head.train()

for epoch in range(epochs):
    total_loss = 0.0
    for batch in train_dataloader:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Classification only needs the encoder; the decoder is never run.
        outputs = model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = mean_pool(outputs.last_hidden_state, attention_mask)
        logits = classification_head(pooled)

        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_dataloader)}")


In [None]:
# Save the fine-tuned weights. PEFT's `save_pretrained` writes only the LoRA
# adapter and its config; the classification head is a separate nn.Module and
# must be saved explicitly, otherwise the trained classifier is lost.
OUTPUT_DIR = "./lora_trained/lora-t5-rte"
model.save_pretrained(OUTPUT_DIR)
torch.save(classification_head.state_dict(), f"{OUTPUT_DIR}/classification_head.pt")

# Reloading for inference (needs `from peft import PeftModel`):
# base_model = T5ForConditionalGeneration.from_pretrained(model_name)
# peft_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR).to(device)
# head_state = torch.load(f"{OUTPUT_DIR}/classification_head.pt", map_location=device)
# out_features, in_features = head_state["weight"].shape
# head = nn.Linear(in_features, out_features).to(device)
# head.load_state_dict(head_state)