# Load Library

In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [2]:
!pip install huggingface_hub
!pip install transformers datasets torch
!pip install --upgrade torch
!pip install --upgrade pip
!pip install --disable-pip-version-check \
    torch \
    torchdata \
    transformers[torch] \
    evaluate \
    rouge_score \
    loralib \
    datasets \

!pip install 'accelerate>=0.26.0' --quiet

[0m

In [3]:
import torch
torch.cuda.is_available()

True

In [4]:
nvidiagpu = !nvidia-smi
nvidiagpu

['Tue Dec 24 10:19:08 2024       ',
 '+---------------------------------------------------------------------------------------+',
 '| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |',
 '|-----------------------------------------+----------------------+----------------------+',
 '| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |',
 '| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |',
 '|                                         |                      |               MIG M. |',
 '|   0  NVIDIA GeForce RTX 3090        Off | 00000000:0B:00.0 Off |                  N/A |',
 '|  0%   34C    P0             108W / 420W |      3MiB / 24576MiB |      4%      Default |',
 '|                                         |                      |                  N/A |',
 '+-----------------------------------------+----------------------+----------------------+',
 '                      

# Load Data & EDA

In [5]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import pandas as pd
import numpy as np

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.autograd.set_detect_anomaly(True)
f"Using device: {device}"

'Using device: cuda'

In [7]:
from datasets import load_dataset

dataset = load_dataset("super_glue", "rte", trust_remote_code=True)
dataset


Using the latest cached version of the module from /root/.cache/huggingface/modules/datasets_modules/datasets/super_glue/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed (last modified on Tue Dec 24 09:36:11 2024) since it couldn't be found locally at super_glue, or remotely on the Hugging Face Hub.


DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'idx', 'label'],
        num_rows: 2490
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'idx', 'label'],
        num_rows: 277
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'idx', 'label'],
        num_rows: 3000
    })
})

In [8]:
import hashlib
import datetime
import random

LABEL_MAP = {0: "Entailment", 1: "Neutral", 2: "Contradiction"}
TEMPLATE_VARIANTS = [
    "Given the premise and hypothesis below, identify whether the hypothesis logically follows from the premise.",
    "Determine the logical relationship between the following premise and hypothesis.",
    "Does the hypothesis follow, contradict, or remain neutral to the premise provided below?",
    "Classify the relationship between the provided premise and hypothesis as entailment, contradiction, or neutral.",
    "Based on the premise, decide if the hypothesis is entailed, neutral, or contradicting it.",
    "Analyze the premise and hypothesis to classify their logical connection.",
    "Evaluate whether the hypothesis is supported, unrelated, or contradicted by the premise.",
    "Read the premise and hypothesis carefully and classify their relationship."
]

def generate_unique_id(premise, hypothesis):
    return hashlib.md5(f"{premise}{hypothesis}".encode()).hexdigest()

def generate_metadata(sample, unique_id):
    return {
        "idx": sample.get("idx", None),
        "source": "SuperGLUE RTE",
        "timestamp": datetime.datetime.now().isoformat(),
        "unique_id": unique_id,
        "lengths": {
            "premise": len(sample["premise"].split()),
            "hypothesis": len(sample["hypothesis"].split())
        },
    }

def process_superglue_rte(sample):
    label = LABEL_MAP.get(sample["label"], str(sample["label"]).capitalize())
    # Replace '-1' with 'Neutral' in the dataset
    if label == "-1":
        label = "Neutral"
        
    unique_id = generate_unique_id(sample["premise"], sample["hypothesis"])
    metadata = generate_metadata(sample, unique_id)
    instruction = random.choice(TEMPLATE_VARIANTS)
    return {
        "instruction": instruction,
        "input": {
            "premise": sample["premise"],
            "hypothesis": sample["hypothesis"]
        },
        "output": label,
        "metadata": metadata
    }

def process_superglue_dataset(task_name, dataset):
    if task_name != "rte":
        raise ValueError(f"Task '{task_name}' is not supported.")
    return [process_superglue_rte(sample) for sample in dataset]

trainData = process_superglue_dataset('rte', dataset['train'])
testData = process_superglue_dataset('rte', dataset['test'])
valData = process_superglue_dataset('rte', dataset['validation'])


In [9]:
trainData[0]


{'instruction': 'Does the hypothesis follow, contradict, or remain neutral to the premise provided below?',
 'input': {'premise': 'No Weapons of Mass Destruction Found in Iraq Yet.',
  'hypothesis': 'Weapons of Mass Destruction Found in Iraq.'},
 'output': 'Neutral',
 'metadata': {'idx': 0,
  'source': 'SuperGLUE RTE',
  'timestamp': '2024-12-24T10:19:22.897152',
  'unique_id': 'd66c49c494a8aa9999ea35c06205542c',
  'lengths': {'premise': 9, 'hypothesis': 7}}}

# Data preprocessing

# Tokenize

In [10]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

In [11]:
from transformers import AutoModelForSeq2SeqLM
import torch
import torch.nn as nn

class LVModel(nn.Module):
    def __init__(self, base_model):
        super(LVModel, self).__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(p=0.3)
        self.layer_norm = nn.LayerNorm(self.base_model.config.d_model)
        self.classifier = nn.Linear(self.base_model.config.d_model, 3)
        
        # Ensure weight sharing is maintained
        self.base_model.shared = self.base_model.encoder.embed_tokens
        self.base_model.decoder.embed_tokens = self.base_model.encoder.embed_tokens

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, labels=None):
        outputs = self.base_model(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            decoder_input_ids=decoder_input_ids, 
            labels=labels
        )
        logits = self.dropout(outputs.logits)
        
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            return {"loss": loss, "logits": logits}
        else:
            return {"logits": logits}

    def save_pretrained(self, path):
        # Save the model configuration
        self.base_model.config.save_pretrained(path)
        
        # Save the model weights
        state_dict = self.state_dict()
        
        # Remove duplicate weights
        if 'base_model.decoder.embed_tokens.weight' in state_dict:
            del state_dict['base_model.decoder.embed_tokens.weight']
        if 'base_model.shared.weight' in state_dict:
            del state_dict['base_model.shared.weight']
            
        torch.save(state_dict, f"{path}/FlanT5.bin")

# Load base model
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model = LVModel(base_model=base_model)
model.to(device)


LVModel(
  (base_model): T5ForConditionalGeneration(
    (shared): Embedding(32128, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=384, bias=False)
                (k): Linear(in_features=512, out_features=384, bias=False)
                (v): Linear(in_features=512, out_features=384, bias=False)
                (o): Linear(in_features=384, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 6)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseGatedActDense(
                (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                (wi_1): Linear(i

In [12]:
def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()

    return \
        f"Trainable model parameters: {trainable_model_params}\n" +\
        f"All model parameters: {all_model_params}\n" +\
        f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(model))

Trainable model parameters: 76963715
All model parameters: 76963715
Percentage of trainable model parameters: 100.00%


In [13]:
from transformers import T5Tokenizer
import torch

# Hàm chuẩn bị dữ liệu
def prepare_data_for_training(data, tokenizer, max_length=128):
    """
    Hàm chuẩn bị dữ liệu đầu vào cho mô hình học sâu.
    
    Args:
    - data (list): Danh sách các ví dụ dữ liệu chứa tiền đề, giả thuyết và nhãn.
    - tokenizer (T5Tokenizer): Tokenizer dùng để token hóa văn bản.
    - max_length (int): Độ dài tối đa của chuỗi token.

    Returns:
    - inputs_tensor (torch.Tensor): Tensor đầu vào (input_ids) cho mô hình.
    - labels_tensor (torch.Tensor): Tensor nhãn cho mô hình.
    """
    inputs = []
    labels = []

    # Mã hóa dữ liệu
    for example in data:
        premise = example['input']['premise']
        hypothesis = example['input']['hypothesis']
        
        # Hợp nhất tiền đề và giả thuyết với dấu phân cách "<sep>"
        input_text = f"{premise} <sep> {hypothesis}"
        
        # Token hóa và chuẩn hóa độ dài chuỗi
        encoding = tokenizer(input_text, truncation=True, padding='max_length', max_length=max_length, return_tensors="pt")
        
        # Thêm input_ids vào danh sách inputs
        inputs.append(encoding['input_ids'])
        
        # Mã hóa nhãn thành số
        label = example['output']
        label_map = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
        labels.append(label_map.get(label, -1))  # Nếu không tìm thấy nhãn, gán -1

    # Chuyển inputs và labels thành tensor
    inputs_tensor = torch.cat(inputs, dim=0)
    labels_tensor = torch.tensor(labels)

    return inputs_tensor, labels_tensor
def prepare_data_for_training(data, tokenizer, max_length=128):
    inputs = tokenizer([item['input']['premise'] + " " + item['input']['hypothesis'] for item in data], 
                       padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    labels = tokenizer([item['output'] for item in data], 
                       padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    return inputs['input_ids'], labels['input_ids']
# Gọi hàm chuẩn bị dữ liệu
trainDataset = prepare_data_for_training(trainData, tokenizer)
testDataset = prepare_data_for_training(testData, tokenizer)
valDataset =  prepare_data_for_training(valData, tokenizer)


In [14]:
trainDataset[0]

tensor([[  465, 30785,     7,  ...,     0,     0,     0],
        [   71,   286,    13,  ...,     0,     0,     0],
        [ 1347,  6873,    77,  ...,     0,     0,     0],
        ...,
        [15971,    31,     7,  ...,     0,     0,     0],
        [12805, 28666,  2501,  ...,  3677, 11095,     1],
        [ 9299,    19,  9909,  ...,     0,     0,     0]])

In [15]:
def tokenize_dataset(data, tokenizer, max_length=128):
    """
    Tokenizes the dataset for training.
    
    Args:
        data (list): List of examples containing premise and hypothesis
        tokenizer: The tokenizer to use
        max_length (int): Maximum sequence length
    """
    tokenized_data = []
    
    for example in data:
        premise = example['input']['premise']
        hypothesis = example['input']['hypothesis']
        label = example['output'].lower()
        
        # Combine premise and hypothesis
        input_text = f"{premise} </s> {hypothesis}"
        
        # Tokenize input
        inputs = tokenizer(
            input_text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Tokenize output/label
        labels = tokenizer(
            label,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        tokenized_data.append({
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': labels['input_ids'].squeeze()
        })
    
    return tokenized_data

In [16]:
tokenized_train = tokenize_dataset(trainData, tokenizer)
tokenized_test = tokenize_dataset(testData, tokenizer)
tokenized_val = tokenize_dataset(valData, tokenizer)

In [17]:

from torch.utils.data import DataLoader, Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

class CustomDataset(Dataset):
    def __init__(self, tokenized_data):
        self.data = tokenized_data  # Now accepting the list directly

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'labels': item['labels']
        }
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 100

# Create data loaders with the fixed dataset class
train_dataset = CustomDataset(tokenized_train)
val_dataset = CustomDataset(tokenized_val)
test_dataset = CustomDataset(tokenized_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Optimizer and Loss Function
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = CrossEntropyLoss(ignore_index=-100)
def compute_accuracy(logits, labels):
    predictions = torch.argmax(logits, dim=-1)
    valid_mask = (labels != -100)  # Ignore padding tokens
    accuracy = (predictions[valid_mask] == labels[valid_mask]).float().mean()
    return accuracy.item()
# First, add this class before your training loop
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, mode='min'):
        """
        patience (int): How many epochs to wait before stopping when loss is not improving
        min_delta (float): Minimum change in monitored value to qualify as an improvement
        mode (str): 'min' for loss, 'max' for accuracy
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.min_validation_loss = float('inf')

    def __call__(self, current_value):
        if self.best_loss is None:
            self.best_loss = current_value
            return False

        if self.mode == 'min':
            if current_value <= self.best_loss - self.min_delta:
                self.best_loss = current_value
                self.counter = 0
            else:
                self.counter += 1
        else:  # mode == 'max'
            if current_value >= self.best_loss + self.min_delta:
                self.best_loss = current_value
                self.counter = 0
            else:
                self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True
            return True
        return False


# Training Loop
def train_one_epoch(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0
    total_accuracy = 0
    num_batches = len(data_loader)
    
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs["loss"]
        logits = outputs["logits"]
        
        # Compute accuracy
        accuracy = compute_accuracy(logits, labels)
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy
        
    return total_loss / num_batches, total_accuracy / num_batches

# Validation Loop
def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0
    total_accuracy = 0
    num_batches = len(data_loader)
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs["loss"]
            logits = outputs["logits"]
            
            # Compute accuracy
            accuracy = compute_accuracy(logits, labels)
            
            total_loss += loss.item()
            total_accuracy += accuracy
            
    return total_loss / num_batches, total_accuracy / num_batches

# Training and Validation Process
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

early_stopping = EarlyStopping(patience=5, min_delta=1e-4)
best_model_state = None
best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    
    # Training
    train_loss, train_accuracy = train_one_epoch(model, train_loader, optimizer, device)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    
    # Validation
    val_loss, val_accuracy = evaluate(model, val_loader, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    
    print(f"Training Loss: {train_loss:.4f} | Training Accuracy: {train_accuracy:.4f}")
    print(f"Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = model.state_dict().copy()
    
    # Early stopping check
    if early_stopping(val_loss):
        print(f"Early stopping triggered after epoch {epoch + 1}")
        break

# Load best model after training
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Loaded best model with validation loss: {best_val_loss:.4f}")




Epoch 1/100


Training:   0%|                                                                                 | 0/156 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:19<00:00,  8.17it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 22.57it/s]


Training Loss: 12.8236 | Training Accuracy: 0.1767
Validation Loss: 4.1894 | Validation Accuracy: 0.9800

Epoch 2/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:18<00:00,  8.50it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.77it/s]


Training Loss: 6.1399 | Training Accuracy: 0.5991
Validation Loss: 0.7472 | Validation Accuracy: 0.9924

Epoch 3/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.94it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.67it/s]


Training Loss: 4.2581 | Training Accuracy: 0.6628
Validation Loss: 0.1230 | Validation Accuracy: 0.9958

Epoch 4/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.98it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.59it/s]


Training Loss: 3.5644 | Training Accuracy: 0.6880
Validation Loss: 0.0318 | Validation Accuracy: 0.9958

Epoch 5/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.94it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.28it/s]


Training Loss: 3.2106 | Training Accuracy: 0.6958
Validation Loss: 0.0138 | Validation Accuracy: 0.9958

Epoch 6/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.91it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 22.57it/s]


Training Loss: 3.0252 | Training Accuracy: 0.6978
Validation Loss: 0.0094 | Validation Accuracy: 0.9961

Epoch 7/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.89it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.49it/s]


Training Loss: 2.9303 | Training Accuracy: 0.6995
Validation Loss: 0.0074 | Validation Accuracy: 0.9967

Epoch 8/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.69it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.43it/s]


Training Loss: 2.8699 | Training Accuracy: 0.7030
Validation Loss: 0.0071 | Validation Accuracy: 0.9959

Epoch 9/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.95it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.59it/s]


Training Loss: 2.8335 | Training Accuracy: 0.7078
Validation Loss: 0.0070 | Validation Accuracy: 0.9958

Epoch 10/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.97it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.97it/s]


Training Loss: 2.8222 | Training Accuracy: 0.7134
Validation Loss: 0.0066 | Validation Accuracy: 0.9962

Epoch 11/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.85it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.52it/s]


Training Loss: 2.8036 | Training Accuracy: 0.7212
Validation Loss: 0.0065 | Validation Accuracy: 0.9962

Epoch 12/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.95it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 28.05it/s]


Training Loss: 2.7798 | Training Accuracy: 0.7308
Validation Loss: 0.0065 | Validation Accuracy: 0.9967

Epoch 13/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.85it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.14it/s]


Training Loss: 2.7855 | Training Accuracy: 0.7408
Validation Loss: 0.0064 | Validation Accuracy: 0.9962

Epoch 14/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.87it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.90it/s]


Training Loss: 2.7805 | Training Accuracy: 0.7481
Validation Loss: 0.0064 | Validation Accuracy: 0.9958

Epoch 15/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  9.01it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 28.20it/s]


Training Loss: 2.7882 | Training Accuracy: 0.7562
Validation Loss: 0.0062 | Validation Accuracy: 0.9962

Epoch 16/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.96it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.65it/s]


Training Loss: 2.7703 | Training Accuracy: 0.7683
Validation Loss: 0.0062 | Validation Accuracy: 0.9963

Epoch 17/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.97it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.99it/s]


Training Loss: 2.7623 | Training Accuracy: 0.7773
Validation Loss: 0.0062 | Validation Accuracy: 0.9967

Epoch 18/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.74it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.60it/s]


Training Loss: 2.7718 | Training Accuracy: 0.7867
Validation Loss: 0.0061 | Validation Accuracy: 0.9961

Epoch 19/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.80it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.96it/s]


Training Loss: 2.7537 | Training Accuracy: 0.7941
Validation Loss: 0.0061 | Validation Accuracy: 0.9964

Epoch 20/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:18<00:00,  8.62it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 22.48it/s]


Training Loss: 2.7728 | Training Accuracy: 0.8042
Validation Loss: 0.0060 | Validation Accuracy: 0.9963

Epoch 21/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.75it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.59it/s]


Training Loss: 2.7596 | Training Accuracy: 0.8107
Validation Loss: 0.0063 | Validation Accuracy: 0.9958

Epoch 22/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:18<00:00,  8.58it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 25.69it/s]


Training Loss: 2.7631 | Training Accuracy: 0.8208
Validation Loss: 0.0061 | Validation Accuracy: 0.9967

Epoch 23/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.87it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.52it/s]


Training Loss: 2.7694 | Training Accuracy: 0.8244
Validation Loss: 0.0061 | Validation Accuracy: 0.9967

Epoch 24/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:18<00:00,  8.47it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 22.44it/s]


Training Loss: 2.7626 | Training Accuracy: 0.8343
Validation Loss: 0.0062 | Validation Accuracy: 0.9962

Epoch 25/100


Training: 100%|███████████████████████████████████████████████████████████████████████| 156/156 [00:17<00:00,  8.91it/s]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 27.70it/s]

Training Loss: 2.7699 | Training Accuracy: 0.8419
Validation Loss: 0.0062 | Validation Accuracy: 0.9959
Early stopping triggered after epoch 25
Loaded best model with validation loss: 0.0060





In [19]:
# Visualize metrics
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create subplots
fig = make_subplots(rows=2, cols=1, 
                    subplot_titles=('Loss', 'Accuracy'),
                    vertical_spacing=0.15)

# Add loss traces
fig.add_trace(
    go.Scatter(x=list(range(EPOCHS)), y=train_losses, name="Training Loss", 
               mode='lines+markers', line=dict(color='blue')),
    row=1, col=1
)
fig.add_trace(
    go.Scatter(x=list(range(EPOCHS)), y=val_losses, name="Validation Loss", 
               mode='lines+markers', line=dict(color='red')),
    row=1, col=1
)

# Add accuracy traces
fig.add_trace(
    go.Scatter(x=list(range(EPOCHS)), y=train_accuracies, name="Training Accuracy", 
               mode='lines+markers', line=dict(color='green')),
    row=2, col=1
)
fig.add_trace(
    go.Scatter(x=list(range(EPOCHS)), y=val_accuracies, name="Validation Accuracy", 
               mode='lines+markers', line=dict(color='orange')),
    row=2, col=1
)

# Update layout
fig.update_layout(
    height=800,
    title_text="Training Metrics",
    showlegend=True,
    template='plotly_dark'
)

fig.show()


In [20]:
model.save_pretrained("t5_lv_model")


In [24]:
# Load base model first
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model = LVModel(base_model=base_model)

# Load the saved state dictionary
state_dict = torch.load("./t5_lv_model/FlanT5.bin")

# Get the encoder embedding weights
encoder_embed_weight = state_dict['base_model.encoder.embed_tokens.weight']

# Add the shared embeddings back to the state dict
state_dict['base_model.shared.weight'] = encoder_embed_weight
state_dict['base_model.decoder.embed_tokens.weight'] = encoder_embed_weight

# Load the modified state dict
model.load_state_dict(state_dict)
model.to(device)

# Test model
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
test_loss, test_accuracy = evaluate(model, test_loader, device)
print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.4f}")


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 188/188 [00:06<00:00, 27.79it/s]

Test Loss: 0.0050 | Test Accuracy: 0.9993



