In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import get_linear_schedule_with_warmup
import torch
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.nn.utils import clip_grad_norm_
import os
import numpy as np
import json

In [None]:
# Directory that holds the tasks_train_simple_p*.txt / tasks_test_simple_p*.txt split files.
DATA_DIR = "/kaggle/input/experiment1"
# Directory where the results JSON and the summary accuracy figure are written.
OUTPUT_DIR = "/kaggle/working/"

In [None]:
class SCANDataset(Dataset):
    """Map-style dataset that tokenizes ``(input_text, target_text)`` pairs on access.

    Args:
        data: Indexable collection of ``(input_text, target_text)`` pairs.
        tokenizer: Callable such that
            ``tokenizer(text, return_tensors="pt", padding=False).input_ids``
            is a tensor whose leading dimension is a batch dimension of size 1.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` (no padding, no truncation) and return a 1-D id tensor."""
        token_ids = self.tokenizer(text, return_tensors="pt", padding=False).input_ids
        # Remove only the batch dimension. A bare ``squeeze()`` would also
        # collapse a one-token sequence of shape (1, 1) into a 0-d tensor,
        # which cannot be length-checked or padded during collation.
        return token_ids.squeeze(0)

    def __getitem__(self, idx):
        """Return ``{'input_ids': 1-D tensor, 'target_ids': 1-D tensor}`` for item ``idx``."""
        input_text, target_text = self.data[idx]
        return {
            'input_ids': self._encode(input_text),
            'target_ids': self._encode(target_text),
        }


def parse_dataset(file_path):
    """Read an ``IN: ... OUT: ...`` text file into a list of (input, output) pairs.

    A line is kept only when the marker ``'OUT:'`` occurs exactly once in it;
    every other line (blank lines, lines without the marker, lines with the
    marker repeated) is skipped silently. Every ``'IN:'`` occurrence is removed
    from the part before the marker, and both parts are whitespace-stripped.

    Args:
        file_path: Path of the text file to read.

    Returns:
        List of ``(input_text, output_text)`` tuples in file order.
    """
    pairs = []
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            # Exactly one marker is required to get a clean two-way split.
            if raw_line.count('OUT:') != 1:
                continue
            command_part, _, action_part = raw_line.partition('OUT:')
            command = command_part.replace('IN:', '').strip()
            pairs.append((command, action_part.strip()))
    return pairs
    

# Collate function for padding sequences
def collate_fn(batch):
    """Pad a list of dataset items into rectangular batch tensors.

    Args:
        batch: List of dicts with keys ``'input_ids'`` and ``'target_ids'``,
            each a sequence of token ids (tensor or list).

    Returns:
        dict with
            ``'input_ids'``: (batch, longest input) ids, right-padded with 0.
            ``'labels'``: (batch, longest target) ids, right-padded with 0.
            ``'attention_mask'``: same shape as ``'input_ids'``; 1 on real
                tokens and 0 on padding, so a model can ignore the padding.
    """
    def _as_1d(ids):
        # as_tensor reuses an existing tensor instead of copying it (and avoids
        # the copy-construct warning torch.tensor() emits on tensor input);
        # reshape(-1) turns a stray 0-d item into a length-1 sequence.
        return torch.as_tensor(ids).reshape(-1)

    input_seqs = [_as_1d(item['input_ids']) for item in batch]
    target_seqs = [_as_1d(item['target_ids']) for item in batch]

    # pad_sequence already pads every row to the longest sequence in the batch.
    input_padded = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    target_padded = pad_sequence(target_seqs, batch_first=True, padding_value=0)

    # Build the mask from the true lengths rather than from the id values.
    input_lengths = torch.tensor([seq.size(0) for seq in input_seqs])
    positions = torch.arange(input_padded.size(1))
    attention_mask = (positions.unsqueeze(0) < input_lengths.unsqueeze(1)).long()

    return {
        'input_ids': input_padded,
        'labels': target_padded,
        'attention_mask': attention_mask,
    }

# Function to split dataset into training and validation
def split_dataset(data, train_ratio=0.8):
    """Cut ``data`` into a leading training part and a trailing validation part.

    The split is positional (no shuffling): the first
    ``int(train_ratio * len(data))`` items form the training subset and the
    remainder forms the validation subset.

    Returns:
        ``(train_data, val_data)``
    """
    cut = int(len(data) * train_ratio)
    return data[:cut], data[cut:]


def create_dataloader(parsed_data, tokenizer, batch_size, shuffle):
    """Wrap parsed (input, target) pairs in a SCANDataset and return a DataLoader.

    Batches are assembled with ``collate_fn``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": collate_fn,
    }
    return DataLoader(SCANDataset(parsed_data, tokenizer), **loader_options)

def plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies, save_path='metrics.png'):
    """Draw loss and accuracy curves side by side, save the figure, then show it.

    Args:
        train_losses, val_losses: Per-epoch loss values.
        train_accuracies, val_accuracies: Per-epoch accuracy values.
        save_path: File the figure is written to (overwritten if it exists).
    """
    epochs = range(1, len(train_losses) + 1)
    # (metric name, training curve, validation curve) for the left and right panel.
    panels = (
        ("Loss", train_losses, val_losses),
        ("Accuracy", train_accuracies, val_accuracies),
    )

    plt.figure(figsize=(10, 5))
    for position, (metric, train_curve, val_curve) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_curve, label=f"Training {metric}", color='blue')
        plt.plot(epochs, val_curve, label=f"Validation {metric}", color='orange')
        plt.xlabel("Epochs")
        plt.ylabel(metric)
        plt.title(f"Training and Validation {metric}")
        plt.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()


### Learning Rate Finder

In [None]:
class LRFinder:
    """Learning-rate range test.

    Trains ``model`` for a few batches while multiplying the learning rate by a
    constant factor after every step, recording (learning rate, loss) pairs in
    ``self.history``. The test performs real optimizer steps, so the model's
    weights are modified in place.

    Args:
        model: Model called as ``model(input_ids=..., labels=...)`` whose output
            exposes a scalar ``.loss``.
        optimizer: Optimizer over the model's parameters; the ``lr`` of every
            parameter group is overwritten during the test.
        device: Device the batches are moved to.
    """

    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.history = {"lr": [], "loss": []}

    def _set_lr(self, lr):
        """Write ``lr`` into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def range_test(self, train_loader, start_lr=1e-6, end_lr=1e-2, num_iters=100, clip_grad_norm=None):
        """Sweep the learning rate geometrically from ``start_lr`` towards ``end_lr``.

        Args:
            train_loader: Iterable of dict batches with ``'input_ids'`` and
                ``'labels'`` tensors (an optional ``'attention_mask'`` tensor is
                forwarded to the model when present).
            start_lr: Learning rate used for the first batch.
            end_lr: Learning rate reached after ``num_iters`` multiplications.
            num_iters: Maximum number of batches to run; must be positive.
            clip_grad_norm: If truthy, maximum gradient norm applied before
                each optimizer step.

        Returns:
            ``self.history``; new records are appended to any existing ones.

        Raises:
            ValueError: If ``num_iters`` is not positive.
        """
        if num_iters <= 0:
            raise ValueError("num_iters must be a positive integer")

        lr_multiplier = (end_lr / start_lr) ** (1 / num_iters)
        lr = start_lr
        self._set_lr(lr)

        self.model.train()

        for batch_idx, batch in enumerate(train_loader):
            if batch_idx >= num_iters:
                break

            model_kwargs = {
                "input_ids": batch["input_ids"].to(self.device),
                "labels": batch["labels"].to(self.device),
            }
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                model_kwargs["attention_mask"] = attention_mask.to(self.device)

            self.optimizer.zero_grad()
            loss = self.model(**model_kwargs).loss
            loss.backward()

            if clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad_norm)

            self.optimizer.step()

            loss_value = loss.item()
            self.history["lr"].append(lr)
            self.history["loss"].append(loss_value)

            # Report the learning rate that produced this loss (not the next one).
            print(f"Batch [{batch_idx+1}/{num_iters}] - Loss: {loss_value} - LR: {lr}")

            # Advance the learning rate for the next batch.
            lr *= lr_multiplier
            self._set_lr(lr)

            # Early stopping for diverging loss
            if loss_value > 100 or torch.isnan(loss):
                print("Loss diverged, stopping early.")
                break

        return self.history

    def plot_lr_find(self):
        """Plot the recorded loss against the learning rate on a log-scaled x axis."""
        lrs = np.array(self.history['lr'])
        losses = np.array(self.history['loss'])
        plt.figure(figsize=(8, 6))
        plt.plot(lrs, losses)
        plt.xscale('log')
        plt.xlabel("Learning Rate (log scale)")
        plt.ylabel("Loss")
        plt.title("LR Finder")
        plt.show()

# Step 1: Build the model, optimizer, tokenizer and training loader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout_rate = 0.05
config = T5Config.from_pretrained(
    't5-small',
    attention_dropout=dropout_rate,
    activation_dropout=dropout_rate,
    dropout_rate=dropout_rate,
)
model = T5ForConditionalGeneration.from_pretrained('t5-small', config=config).to(device)
optimizer = AdamW(model.parameters(), lr=1e-6, weight_decay=0.05)
tokenizer = T5Tokenizer.from_pretrained('t5-small')

train_file_path = os.path.join(DATA_DIR, "tasks_train_simple_p64.txt")
train_data = parse_dataset(train_file_path)
train_set, val_set = split_dataset(train_data)
train_loader = create_dataloader(train_set, tokenizer=tokenizer, batch_size=32, shuffle=True)

# Step 2: Initialize the LR Finder
lr_finder = LRFinder(model, optimizer, device=device)

# Step 3: Run the range test (the sweep trains `model` in place)
history = lr_finder.range_test(train_loader, start_lr=1e-6, end_lr=1e-2, num_iters=100)

# Step 4: Plot the results and choose the best learning rate
lr_finder.plot_lr_find()

#### Training and Evaluation Function

In [None]:
def _forward_with_labels(model, batch, device):
    """Move one collated batch to ``device`` and run the model with its labels.

    An ``'attention_mask'`` entry is forwarded to the model when the batch has one.
    Returns ``(model_outputs, labels_on_device)``.
    """
    labels = batch["labels"].to(device)
    model_kwargs = {"input_ids": batch["input_ids"].to(device), "labels": labels}
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.to(device)
    return model(**model_kwargs), labels


def _token_match_counts(logits, labels, pad_idx):
    """Return ``(num_correct, num_counted)`` over label positions that are not ``pad_idx``."""
    non_pad_mask = labels != pad_idx
    correct_tokens = (torch.argmax(logits, dim=-1) == labels) & non_pad_mask
    return correct_tokens.sum().item(), non_pad_mask.sum().item()


def train_and_validate(
        model, train_loader, val_loader, device,
        num_epochs=10, lr=5e-4, grad_clip=1.0, tgt_pad_idx=0, patience=3, lr_factor=0.5, weight_decay=0.05
):
    """Train and validate the model with early stopping, weight decay and a stepped learning rate.

    Accuracy is token-level and computed from the logits of the same forward
    pass that receives the labels (teacher forcing).

    Args:
        model: Model called as ``model(input_ids=..., labels=...)`` whose output
            exposes ``.loss`` and ``.logits``.
        train_loader: Iterable of training batches (dicts with ``'input_ids'``
            and ``'labels'``; ``'attention_mask'`` is used when present).
        val_loader: Iterable of validation batches in the same format.
        device: Device used for the model and the batches.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate for AdamW.
        grad_clip: Maximum gradient norm.
        tgt_pad_idx: Label id treated as padding and excluded from accuracy.
        patience: Number of consecutive epochs without validation-loss
            improvement after which training stops.
        lr_factor: Currently unused; kept so existing callers keep working.
        weight_decay: AdamW weight decay.

    Returns:
        Best validation accuracy seen. On return, ``model`` holds the weights
        of the epoch that achieved it (if any epoch scored above 0).
    """
    model.to(device)

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    best_val_loss = float('inf')  # Best validation loss, drives early stopping
    best_val_acc = 0.0  # Best validation accuracy, drives checkpoint selection
    best_model_state = None  # Snapshot of the weights at best_val_acc
    patience_counter = 0  # Epochs since the validation loss last improved

    optimizer = AdamW(model.parameters(), lr, weight_decay=weight_decay)

    # Multiply the learning rate by 0.9 every 5 epochs.
    scheduler = StepLR(optimizer, step_size=5, gamma=0.9)

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch [{epoch}/{num_epochs}]")

        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        with tqdm(total=len(train_loader), desc="Training") as pbar:
            for batch in train_loader:
                optimizer.zero_grad()

                outputs, labels = _forward_with_labels(model, batch, device)
                loss = outputs.loss

                loss.backward()
                clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

                train_loss += loss.item()

                correct, counted = _token_match_counts(outputs.logits, labels, tgt_pad_idx)
                train_correct += correct
                train_total += counted

                # max(..., 1) keeps the ratio defined when nothing was counted yet.
                pbar.set_postfix(loss=loss.item(), accuracy=train_correct / max(train_total, 1))
                pbar.update()

        train_losses.append(train_loss / max(len(train_loader), 1))
        train_accuracies.append(train_correct / max(train_total, 1))
        print(f"Training Loss: {train_losses[-1]:.4f}, Accuracy: {train_accuracies[-1]:.4f}")

        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc="Validation") as pbar:
                for batch in val_loader:
                    outputs, labels = _forward_with_labels(model, batch, device)
                    loss = outputs.loss
                    val_loss += loss.item()

                    correct, counted = _token_match_counts(outputs.logits, labels, tgt_pad_idx)
                    val_correct += correct
                    val_total += counted

                    pbar.set_postfix(loss=loss.item(), accuracy=val_correct / max(val_total, 1))
                    pbar.update()

        val_losses.append(val_loss / max(len(val_loader), 1))
        val_accuracies.append(val_correct / max(val_total, 1))
        print(f"Validation Loss: {val_losses[-1]:.4f}, Accuracy: {val_accuracies[-1]:.4f}")

        # StepLR counts epochs itself. Passing the validation loss here would be
        # interpreted as the (deprecated) epoch index and freeze the schedule.
        scheduler.step()

        # Snapshot the weights if validation accuracy is the best so far.
        if val_accuracies[-1] > best_val_acc:
            best_val_acc = val_accuracies[-1]
            # state_dict() returns references to the live parameters, so clone
            # them; otherwise later updates would overwrite the snapshot.
            best_model_state = {
                name: tensor.detach().clone() for name, tensor in model.state_dict().items()
            }

        # Early stopping check
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            patience_counter = 0  # Reset patience counter if validation loss improved
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}. No improvement in validation loss for {patience} epochs.")
                break  # Stop training if patience is exceeded

    # Hand back the best checkpoint rather than whatever the last epoch left.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies)
    return best_val_acc

In [None]:
def _strip_after_eos(sequence, eos_idx, pad_idx):
    """Cut ``sequence`` before its first ``eos_idx`` (if any) and drop ``pad_idx`` entries."""
    eos_positions = (sequence == eos_idx).nonzero(as_tuple=True)[0]
    if eos_positions.numel() > 0:
        sequence = sequence[:eos_positions[0]]
    return sequence[sequence != pad_idx]


def evaluate_model(model, data_loader, device, tgt_pad_idx, eos_idx, oracle=False): # Note: use this only for oracle=True
    """Compute token-level and sequence-level accuracy of ``model`` on ``data_loader``.

    Two prediction modes:

    * ``oracle=False``: one forward pass with the labels supplied to the model,
      so every predicted token is conditioned on the gold prefix (teacher
      forcing). The original author's note says this function is meant to be
      used with ``oracle=True``.
    * ``oracle=True``: greedy step-by-step decoding in which the first target
      token is given and an end-of-sequence prediction is replaced by the
      second-best token while the step is still before the label's last
      non-pad position (length oracle).

    Args:
        model: Model called with ``input_ids`` plus either ``labels`` or
            ``decoder_input_ids``; its output exposes ``.logits``.
        data_loader: Iterable of dict batches with ``'input_ids'`` and
            ``'labels'`` (``'attention_mask'`` is forwarded when present).
        device: Device the batches are moved to.
        tgt_pad_idx: Padding id in the labels.
        eos_idx: End-of-sequence id.
        oracle: Selects the prediction mode described above.

    Returns:
        ``(token_accuracy, sequence_accuracy)``. Token accuracy ignores label
        positions holding padding or EOS; sequence accuracy compares the
        prediction and the label after trimming at the first EOS and removing
        padding. Both are 0 when there is nothing to count.
    """
    model.eval()
    test_correct, test_total, sequence_correct, total_sequences = 0, 0, 0, 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating", unit="batch"):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            batch_size, max_len = labels.size()

            encoder_kwargs = {"input_ids": input_ids}
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                encoder_kwargs["attention_mask"] = attention_mask.to(device)

            if not oracle:
                # Single teacher-forced query.
                logits = model(labels=labels, **encoder_kwargs).logits
                predictions = torch.argmax(logits, dim=-1)  # shape = (batch_size, seq_len)
            else:
                predictions = torch.full((batch_size, max_len), tgt_pad_idx, dtype=torch.long, device=device)

                # The first target token is handed to the model as a hint.
                predictions[:, 0] = labels[:, 0]

                # The decoder must see its start token first, exactly as when
                # labels are shifted right for training; without it every
                # decoder input sits one position too early. Falls back to the
                # pad id when the model does not advertise a start token.
                start_token_id = getattr(getattr(model, "config", None), "decoder_start_token_id", None)
                if start_token_id is None:
                    start_token_id = tgt_pad_idx
                start_column = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)

                # Number of non-pad label tokens per sequence (loop-invariant).
                target_lengths = (labels != tgt_pad_idx).sum(dim=1)

                for step in range(1, max_len):
                    decoder_input_ids = torch.cat([start_column, predictions[:, :step]], dim=1)
                    logits = model(decoder_input_ids=decoder_input_ids, **encoder_kwargs).logits

                    # Logits of the last decoder position predict the token at `step`.
                    step_logits = logits[:, -1, :]  # shape = (batch_size, vocab_size)
                    top2_tokens = step_logits.topk(2, dim=-1).indices

                    # Default to the most probable token (cloned so the masked
                    # assignment below does not write into top2_tokens).
                    next_tokens = top2_tokens[:, 0].clone()

                    # Length oracle: refuse EOS before the label's final position.
                    premature_eos_mask = (next_tokens == eos_idx) & (step < target_lengths - 1)
                    next_tokens[premature_eos_mask] = top2_tokens[premature_eos_mask, 1]

                    predictions[:, step] = next_tokens

                    # Stop once every sequence emitted EOS at this step.
                    if torch.all(next_tokens == eos_idx):
                        break

            # Token-level accuracy over label positions that are neither pad nor EOS.
            valid_mask = (labels != tgt_pad_idx) & (labels != eos_idx)
            correct_tokens = (predictions == labels) & valid_mask
            test_correct += correct_tokens.sum().item()
            test_total += valid_mask.sum().item()

            # Sequence-level accuracy
            for pred, target in zip(predictions, labels):
                pred_seq = _strip_after_eos(pred, eos_idx, tgt_pad_idx)
                target_seq = _strip_after_eos(target, eos_idx, tgt_pad_idx)
                if torch.equal(pred_seq, target_seq):
                    sequence_correct += 1

            total_sequences += batch_size

    token_accuracy = test_correct / test_total if test_total > 0 else 0
    sequence_accuracy = sequence_correct / total_sequences if total_sequences > 0 else 0

    return token_accuracy, sequence_accuracy


In [None]:
# Percentages p used in the split file names tasks_{train,test}_simple_p{p}.txt
train_fractions = [1, 2, 4, 8, 16, 32, 64]
random_seeds = [42, 123, 456]

results = {}

token_accuracies = []  # Test token accuracies for plotting (averaged over the random seeds)
sequence_accuracies = []  # Test sequence accuracies for plotting (averaged over the random seeds)

dropout_rate = 0.05
config = T5Config.from_pretrained('t5-small',
                                  attention_dropout=dropout_rate,
                                  activation_dropout=dropout_rate,
                                  dropout_rate=dropout_rate)

tokenizer = T5Tokenizer.from_pretrained('t5-small')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tgt_pad_idx = tokenizer.pad_token_id
eos_idx = tokenizer.eos_token_id

for train_frac in train_fractions:
    # Load datasets
    train_file_path = os.path.join(DATA_DIR, f"tasks_train_simple_p{train_frac}.txt")
    test_file_path = os.path.join(DATA_DIR, f"tasks_test_simple_p{train_frac}.txt")

    # Parse the datasets
    train_data = parse_dataset(train_file_path)
    test_data = parse_dataset(test_file_path)

    train_set, val_set = split_dataset(train_data)

    # Create dataloaders for training, validation, and test sets.
    # Only the training loader is shuffled: validation metrics are averaged per
    # batch, so a fixed order keeps them independent of the random state.
    train_loader = create_dataloader(train_set, batch_size=32, shuffle=True, tokenizer=tokenizer)
    val_loader = create_dataloader(val_set, batch_size=32, shuffle=False, tokenizer=tokenizer)
    test_loader = create_dataloader(test_data, batch_size=32, shuffle=False, tokenizer=tokenizer)

    token_accuracies_for_fractions = []  # Store token accuracies for each seed
    sequence_accuracies_for_fractions = []  # Store sequence accuracies for each seed

    for seed in random_seeds:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Fresh pretrained weights for every (fraction, seed) run
        model = T5ForConditionalGeneration.from_pretrained('t5-small', config=config).to(device)

        # Train model; the pad id is passed explicitly so accuracy masking does
        # not rely on the function's default matching this tokenizer.
        train_and_validate(model, train_loader, val_loader, device, num_epochs=10, tgt_pad_idx=tgt_pad_idx)

        token_accuracy, sequence_accuracy = evaluate_model(model, test_loader, device, tgt_pad_idx, eos_idx)  # Evaluate model
        token_accuracies_for_fractions.append(token_accuracy)
        sequence_accuracies_for_fractions.append(sequence_accuracy)

        print(f"Training with {train_frac}% of data on random seed {seed} - Token Accuracy: {token_accuracy:.4f}, Sequence Accuracy: {sequence_accuracy:.4f}")

        # Release the model before the next run
        model.to('cpu')
        del model
        torch.cuda.empty_cache()

    # Average accuracies over all seeds for the current train_frac
    average_token_accuracy = sum(token_accuracies_for_fractions) / len(token_accuracies_for_fractions)
    average_sequence_accuracy = sum(sequence_accuracies_for_fractions) / len(sequence_accuracies_for_fractions)

    token_accuracies.append(average_token_accuracy)
    sequence_accuracies.append(average_sequence_accuracy)

    results[train_frac] = {
        'average_token_accuracy': average_token_accuracy,
        'average_sequence_accuracy': average_sequence_accuracy
    }


print("Final Results:")
for i, train_frac in enumerate(train_fractions):
    print(f"Training Fraction: {train_frac}% - Token Accuracy: {token_accuracies[i]:.4f}, Sequence Accuracy: {sequence_accuracies[i]:.4f}")

# json.dump writes the integer fraction keys as strings.
output_file = os.path.join(OUTPUT_DIR, 'training_results.json')
with open(output_file, 'w') as f:
    json.dump(results, f, indent=4)

print(f"Results saved to {output_file}")

In [None]:
# One bar chart per metric, sharing the same x labels and styling.
fraction_labels = [f"{f}%" for f in train_fractions]
panel_specs = (
    ('Token Level Accuracy', token_accuracies),
    ('Sequence Level Accuracy', sequence_accuracies),
)

fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=False)

for ax, (panel_title, panel_values) in zip(axes, panel_specs):
    ax.bar(fraction_labels, panel_values, color='#D6A9D0')
    ax.set_xlabel('Commands used (%)')
    ax.set_ylabel('Accuracy')
    ax.set_title(panel_title)
    ax.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "accuracies_exp1.png"))

plt.show()