In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import get_linear_schedule_with_warmup
import torch
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.nn.utils import clip_grad_norm_
import os
import json
import numpy as np 
from collections import defaultdict

In [None]:
# DATA_DIR holds the SCAN task files read below; figures are written under OUTPUT_DIR.
DATA_DIR, OUTPUT_DIR = "/kaggle/input/experiment2", "/kaggle/working/"

In [None]:
class SCANDataset(Dataset):
    """Map-style dataset of ``(input_text, target_text)`` pairs, tokenized per item.

    Args:
        data: indexable sequence of ``(input_text, target_text)`` string pairs,
            e.g. the list returned by ``parse_dataset``.
        tokenizer: callable that, given one string with ``return_tensors="pt"``,
            returns an object exposing ``input_ids`` and ``attention_mask``
            tensors of shape (1, seq_len).
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string and return its 1-D ``(input_ids, attention_mask)``."""
        encoding = self.tokenizer(text, return_tensors="pt", padding=False)
        # squeeze(0) removes only the batch dimension. A bare squeeze() would turn a
        # single-token encoding (shape (1, 1)) into a 0-d tensor that pad_sequence
        # cannot batch.
        return encoding.input_ids.squeeze(0), encoding.attention_mask.squeeze(0)

    def __getitem__(self, idx):
        """Return the tokenized pair at ``idx`` as a dict of 1-D tensors."""
        input_text, target_text = self.data[idx]
        input_ids, input_attention_mask = self._encode(input_text)
        target_ids, target_attention_mask = self._encode(target_text)
        return {
            'input_ids': input_ids,
            'input_attention_mask': input_attention_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_attention_mask,
        }
    
        
def parse_dataset(file_path):
    """Parse a SCAN-style text file into ``(command, actions)`` string pairs.

    A usable line has the form ``IN: <command> OUT: <actions>``. Lines that do
    not contain exactly one ``OUT:`` marker (including blank lines) are skipped.
    ``IN:`` markers are removed from the command part and both parts are stripped.

    Args:
        file_path: path of the text file to read.

    Returns:
        list[tuple[str, str]]: the parsed pairs, in file order.
    """
    pairs = []
    # Explicit encoding: the platform default is locale-dependent.
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            # Exactly one marker is required, the same rule as a two-way split/unpack.
            if line.count("OUT:") != 1:
                continue
            in_part, _, out_part = line.partition("OUT:")
            pairs.append((in_part.replace("IN:", "").strip(), out_part.strip()))
    return pairs


def collate_fn(batch):
    """Right-pad a list of ``SCANDataset`` items into batch tensors.

    Every field is padded with 0 up to the longest sequence of that field in
    the batch.

    Args:
        batch: list of dicts with ``input_ids``, ``input_attention_mask``,
            ``target_ids`` and ``target_attention_mask`` (tensors or int lists).

    Returns:
        dict with ``input_ids``, ``input_attention_mask``, ``labels`` (the
        padded ``target_ids``) and ``label_attention_mask``, each of shape
        (batch_size, longest_length_of_that_field).
    """
    def _pad(key):
        # as_tensor reuses existing tensors instead of copying them (torch.tensor()
        # on a tensor copies and warns); atleast_1d makes 0-d single-token items
        # paddable.
        seqs = [torch.atleast_1d(torch.as_tensor(item[key])) for item in batch]
        return pad_sequence(seqs, batch_first=True, padding_value=0)

    return {
        'input_ids': _pad('input_ids'),
        'input_attention_mask': _pad('input_attention_mask'),
        'labels': _pad('target_ids'),
        'label_attention_mask': _pad('target_attention_mask'),
    }

def split_dataset(data, train_ratio=0.8):
    """Cut ``data`` in order into a training prefix and a validation suffix.

    The boundary index is ``int(train_ratio * len(data))``; nothing is shuffled.
    """
    boundary = int(train_ratio * len(data))
    return data[:boundary], data[boundary:]


def create_dataloader(parsed_data, tokenizer, batch_size, shuffle):
    """Wrap ``(command, action)`` pairs in a ``SCANDataset`` and batch them with ``collate_fn``."""
    return DataLoader(
        SCANDataset(parsed_data, tokenizer),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

def plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies, save_path='metrics.png'):
    """Draw per-epoch loss (left) and accuracy (right) curves, save to ``save_path``, then show."""
    epochs = range(1, len(train_losses) + 1)
    # (y-label, train label, validation label, title, train values, validation values)
    panels = (
        ("Loss", "Training Loss", "Validation Loss", "Training and Validation Loss",
         train_losses, val_losses),
        ("Accuracy", "Training Accuracy", "Validation Accuracy", "Training and Validation Accuracy",
         train_accuracies, val_accuracies),
    )

    plt.figure(figsize=(10, 5))
    for position, (ylabel, train_label, val_label, title, train_values, val_values) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_values, label=train_label, color='blue')
        plt.plot(epochs, val_values, label=val_label, color='orange')
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

In [None]:
def create_dictionaries_and_histograms(data):
    """
    Group command/action pairs by word count and plot both length distributions.

    Args:
        data (list of tuples): ``(command, action)`` string pairs.

    Returns:
        tuple: ``(command_length_dict, action_sequence_length_dict)``.
            - command_length_dict: maps a command's word count to the list of
              ``(command, action)`` tuples with that count.
            - action_sequence_length_dict: maps an action sequence's word count
              to the list of ``(command, action)`` tuples with that count.
            Keys appear in first-seen order.

    Plots:
        Side-by-side bar charts of how many pairs fall into each command length
        and each action-sequence length.
    """
    by_command_length = defaultdict(list)
    by_action_length = defaultdict(list)
    for command, action in data:
        pair = (command, action)
        by_command_length[len(command.split())].append(pair)
        by_action_length[len(action.split())].append(pair)

    # Hand back plain dicts so callers never get auto-created empty buckets.
    command_length_dict = dict(by_command_length)
    action_sequence_length_dict = dict(by_action_length)

    # (buckets, bar colour, x label, title) for each panel, left to right.
    panels = (
        (command_length_dict, 'blue', "Command Length (words)", "Histogram of Command Lengths"),
        (action_sequence_length_dict, 'green', "Action Sequence Length (words)",
         "Histogram of Action Sequence Lengths"),
    )
    plt.figure(figsize=(12, 6))
    for position, (buckets, colour, xlabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.bar(buckets.keys(), [len(pairs) for pairs in buckets.values()], color=colour, alpha=0.7)
        plt.xlabel(xlabel)
        plt.ylabel("Number of Lines")
        plt.title(title)

    plt.tight_layout()
    plt.show()

    return command_length_dict, action_sequence_length_dict

# Explore tasks_test_length.txt: bucket its pairs by word count and plot the histograms.
file_path = os.path.join(DATA_DIR, "tasks_test_length.txt")
parsed_data = parse_dataset(file_path)
command_dict, action_dict = create_dictionaries_and_histograms(parsed_data)

In [None]:
def _run_epoch(model, loader, device, tgt_pad_idx, desc, optimizer=None, grad_clip=1.0):
    """Run one pass over ``loader``; update weights only when ``optimizer`` is given.

    Returns:
        tuple[float, float]: mean per-batch loss and token-level accuracy over
        non-padding label positions (each 0.0 when there is nothing to count).
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    grad_mode = torch.enable_grad() if training else torch.no_grad()
    with grad_mode, tqdm(total=len(loader), desc=desc) as pbar:
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["input_attention_mask"].to(device)
            labels = batch["labels"].to(device)
            # Hugging Face seq2seq losses ignore label positions equal to -100; without
            # this the padding after EOS would be trained on and counted in the loss.
            loss_labels = labels.masked_fill(labels == tgt_pad_idx, -100)

            if training:
                optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=loss_labels)
            loss = outputs.loss
            if training:
                loss.backward()
                clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            total_loss += loss.item()
            predictions = outputs.logits.argmax(dim=-1)
            non_pad = labels != tgt_pad_idx
            correct += ((predictions == labels) & non_pad).sum().item()
            total += non_pad.sum().item()

            pbar.set_postfix(loss=loss.item(), accuracy=correct / total if total else 0.0)
            pbar.update()

    mean_loss = total_loss / len(loader) if len(loader) else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy


def train_and_validate(
        model, train_loader, val_loader, device,
        num_epochs=15, lr=7e-4, grad_clip=1.0, tgt_pad_idx=0, patience=3, lr_factor=0.5, weight_decay=0.05,
        restore_best=True,
):
    """Train and validate a seq2seq model with early stopping and a step LR schedule.

    Each epoch runs one training pass and one validation pass. Both passes mask
    padding out of the loss and report token accuracy over non-padding labels.
    The learning rate is multiplied by 0.9 every 5 epochs. Training stops early
    once the validation loss has not improved for ``patience`` consecutive epochs.
    Loss/accuracy curves are plotted at the end.

    Args:
        model: seq2seq model returning ``.loss`` and ``.logits`` when called with
            ``input_ids``, ``attention_mask`` and ``labels``.
        train_loader, val_loader: loaders yielding ``collate_fn``-style batches.
        device: device to train on.
        num_epochs: maximum number of epochs.
        lr: initial AdamW learning rate.
        grad_clip: max gradient norm.
        tgt_pad_idx: label padding id (excluded from loss and accuracy).
        patience: epochs without validation-loss improvement before stopping.
        lr_factor: currently unused; kept for call compatibility.
        weight_decay: AdamW decoupled weight decay.
        restore_best: if True, load the weights from the epoch with the best
            validation accuracy back into ``model`` before returning.

    Returns:
        float: the best validation token accuracy observed.
    """
    model.to(device)

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    best_val_loss = float('inf')  # drives early stopping
    best_val_acc = 0.0  # drives best-checkpoint selection
    best_model_state = None
    patience_counter = 0

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.9)

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch [{epoch}/{num_epochs}]")

        train_loss, train_acc = _run_epoch(
            model, train_loader, device, tgt_pad_idx, "Training",
            optimizer=optimizer, grad_clip=grad_clip,
        )
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        print(f"Training Loss: {train_losses[-1]:.4f}, Accuracy: {train_accuracies[-1]:.4f}")

        val_loss, val_acc = _run_epoch(model, val_loader, device, tgt_pad_idx, "Validation")
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        print(f"Validation Loss: {val_losses[-1]:.4f}, Accuracy: {val_accuracies[-1]:.4f}")

        # StepLR counts epochs itself. Passing a metric to step() would be taken as
        # the epoch index and keep the LR at its initial value for small losses.
        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns live references; clone so later updates don't overwrite it.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}. No improvement in validation loss for {patience} epochs.")
                break

    if restore_best and best_model_state is not None:
        model.load_state_dict(best_model_state)

    plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies)
    return best_val_acc

In [None]:
def _strip_sequence(seq, eos_idx, pad_idx):
    """Return ``seq`` cut just before its first ``eos_idx`` (if any), with ``pad_idx`` tokens removed."""
    eos_positions = (seq == eos_idx).nonzero(as_tuple=True)[0]
    if eos_positions.numel() > 0:
        seq = seq[: eos_positions[0]]
    return seq[seq != pad_idx]


def _teacher_forced_predictions(model, input_ids, attention_mask, labels, pad_idx, eos_idx):
    """Argmax predictions from a single forward pass given the reference ``labels``.

    In each row, every token after the first predicted EOS is replaced by ``pad_idx``.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
    predictions = logits.argmax(dim=-1)
    is_eos = (predictions == eos_idx).long()
    # Number of EOS tokens strictly before each position; > 0 means "after the first EOS".
    after_first_eos = (is_eos.cumsum(dim=1) - is_eos) > 0
    return predictions.masked_fill(after_first_eos, pad_idx)


def _oracle_length_predictions(model, input_ids, attention_mask, labels, pad_idx, eos_idx):
    """Greedily decode each row to exactly the length of its own reference sequence.

    Before a row's reference EOS position, EOS is never emitted (the runner-up
    token is used instead). At that position EOS is forced, and every later
    position is ``pad_idx``.
    """
    batch_size, max_len = labels.shape
    is_eos = labels == eos_idx
    # Reference EOS position per row (first occurrence). Rows without an EOS end
    # right after their last non-padding token.
    oracle_end = torch.where(
        is_eos.any(dim=1),
        is_eos.long().argmax(dim=1),
        (labels != pad_idx).sum(dim=1),
    )
    predictions = torch.full_like(labels, pad_idx)
    if batch_size == 0 or max_len == 0:
        return predictions

    # Positions past every row's oracle end stay padding, so decoding can stop there.
    num_steps = min(max_len, int(oracle_end.max().item()) + 1)
    for pos in range(num_steps):
        # The tokens decoded so far are passed as ``labels`` (HF T5 shifts them right
        # to form decoder inputs), so the last output position scores token ``pos``;
        # the trailing pad entry is only a placeholder.
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=predictions[:, : pos + 1].contiguous(),
        )
        top2 = output.logits[:, -1, :].topk(2, dim=-1).indices
        next_token = torch.where(top2[:, 0] == eos_idx, top2[:, 1], top2[:, 0])
        next_token = torch.where(oracle_end == pos, torch.full_like(next_token, eos_idx), next_token)
        next_token = torch.where(oracle_end < pos, torch.full_like(next_token, pad_idx), next_token)
        predictions[:, pos] = next_token
    return predictions


def evaluate_model(model, data_loader, device, tgt_pad_idx, eos_idx, oracle=False):
    """Measure token- and sequence-level accuracy of a seq2seq ``model``.

    Args:
        model: model called as ``model(input_ids=..., attention_mask=..., labels=...)``;
            only ``.logits`` of its output is read.
        data_loader: iterable of batches with ``input_ids``, ``input_attention_mask``
            and ``labels`` tensors (as produced by ``collate_fn``).
        device: device the batch tensors are moved to.
        tgt_pad_idx: label padding id; padded positions are excluded from token accuracy.
        eos_idx: end-of-sequence id.
        oracle: if False, predictions come from one forward pass on the reference
            labels, with tokens after the first predicted EOS replaced by padding.
            If True, each row is decoded greedily, token by token, to exactly its
            reference length. EOS is suppressed before the reference EOS position
            and forced at it.

    Returns:
        tuple: ``(token_accuracy, sequence_accuracy)``; each is 0 when nothing was counted.
        A sequence counts as correct when prediction and reference match exactly
        after cutting at the first EOS and removing padding.
    """
    model.eval()
    test_correct, test_total, sequence_correct, total_sequences = 0, 0, 0, 0
    decode = _oracle_length_predictions if oracle else _teacher_forced_predictions

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            input_attention_mask = batch["input_attention_mask"].to(device)

            predictions = decode(model, input_ids, input_attention_mask, labels, tgt_pad_idx, eos_idx)

            # Token-level accuracy over non-padding reference positions.
            non_pad_mask = labels != tgt_pad_idx
            test_correct += ((predictions == labels) & non_pad_mask).sum().item()
            test_total += non_pad_mask.sum().item()

            # Sequence-level accuracy.
            for pred_seq, target_seq in zip(predictions, labels):
                if torch.equal(
                    _strip_sequence(pred_seq, eos_idx, tgt_pad_idx),
                    _strip_sequence(target_seq, eos_idx, tgt_pad_idx),
                ):
                    sequence_correct += 1
            total_sequences += labels.size(0)

    token_accuracy = test_correct / test_total if test_total > 0 else 0
    sequence_accuracy = sequence_correct / total_sequences if total_sequences > 0 else 0
    return token_accuracy, sequence_accuracy



In [None]:
# T5 has one dropout setting: ``dropout_rate`` applies to all of its dropout layers.
# ``attention_dropout`` / ``activation_dropout`` are not T5Config parameters, so
# passing them only attached unused attributes to the config.
dropout_rate = 0.2
config = T5Config.from_pretrained('t5-small', dropout_rate=dropout_rate)

tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Optional experiment: register the SCAN action names as whole tokens. If enabled,
# the model's token embeddings must be resized after loading it.
#custom_tokens = ['I_WALK', 'I_TURN_LEFT', 'I_RUN', 'I_LOOK', 'I_JUMP', 'I_TURN_RIGHT']
#tokenizer.add_tokens(custom_tokens)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Special-token ids used for padding masks and EOS handling in training/evaluation.
tgt_pad_idx = tokenizer.pad_token_id
eos_idx = tokenizer.eos_token_id

train_file_path = os.path.join(DATA_DIR, "tasks_train_length.txt")
test_file_path = os.path.join(DATA_DIR, "tasks_test_length.txt")

# Parse the datasets
train_data = parse_dataset(train_file_path)
test_data = parse_dataset(test_file_path)

# Hold out the last 20% of the training file (in file order) for validation.
train_set, val_set = split_dataset(train_data)

# Only the training loader is shuffled. Validation batches stay fixed so the
# per-epoch validation loss (a mean over batches) is reproducible and
# early stopping compares like with like.
train_loader = create_dataloader(train_set, tokenizer, batch_size=16, shuffle=True)
val_loader = create_dataloader(val_set, tokenizer, batch_size=16, shuffle=False)
test_loader = create_dataloader(test_data, tokenizer, batch_size=16, shuffle=False)

# Bucket the test pairs by command length and by action-sequence length (in words).
# test_data already holds the parsed tasks_test_length.txt, so reuse it rather
# than reading the file again.
command_dict, action_dict = create_dictionaries_and_histograms(test_data)

# Order the buckets by length.
sorted_command_dict = dict(sorted(command_dict.items()))
sorted_action_dict = dict(sorted(action_dict.items()))

command_lengths = sorted_command_dict.keys()
action_lengths = sorted_action_dict.keys()

# One unshuffled loader per length bucket, aligned with command_lengths / action_lengths.
command_data_loaders = [create_dataloader(pairs, tokenizer, 16, shuffle=False) for pairs in sorted_command_dict.values()]
action_data_loaders = [create_dataloader(pairs, tokenizer, 16, shuffle=False) for pairs in sorted_action_dict.values()]

# Per-length accuracy results, one entry per length bucket in ascending order.
# "oracle" lists hold the results of evaluate_model(..., oracle=True).
token_accuracies_command_seq, sequence_accuracies_command_seq = [], []
token_accuracies_oracle_command_seq, sequence_accuracies_oracle_command_seq = [], []

token_accuracies_action_seq, sequence_accuracies_action_seq = [], []
token_accuracies_oracle_action_seq, sequence_accuracies_oracle_action_seq = [], []

    
model = T5ForConditionalGeneration.from_pretrained('t5-small', config=config).to(device)
#model.resize_token_embeddings(len(tokenizer))  # only needed if custom tokens were added
# Pass the tokenizer's pad id explicitly instead of relying on the default of 0.
train_and_validate(model, train_loader, val_loader, device, num_epochs=10, lr=3e-5, tgt_pad_idx=tgt_pad_idx)

def _evaluate_length_buckets(kind, lengths, loaders, token_plain, sequence_plain, token_oracle, sequence_oracle):
    """Evaluate every length bucket without and with the oracle, appending to the result lists.

    ``kind`` is the lower-case bucket name used in the log lines ("command" or
    "action"). Reads the notebook globals ``model``, ``device``, ``tgt_pad_idx``
    and ``eos_idx``.
    """
    title = kind.capitalize()
    for length, loader in zip(lengths, loaders):
        print(f"Evaluating {kind} sequences of length {length}")
        for use_oracle, token_list, sequence_list, tag in (
            (False, token_plain, sequence_plain, "No oracle"),
            (True, token_oracle, sequence_oracle, "With oracle"),
        ):
            token_accuracy, sequence_accuracy = evaluate_model(
                model, loader, device, tgt_pad_idx, eos_idx, oracle=use_oracle
            )
            token_list.append(token_accuracy)
            sequence_list.append(sequence_accuracy)
            print(f"{title} Length {length} - {tag}: Token Accuracy: {token_accuracy}, Sequence Accuracy: {sequence_accuracy}")


# Evaluate command sequences
_evaluate_length_buckets(
    "command", command_lengths, command_data_loaders,
    token_accuracies_command_seq, sequence_accuracies_command_seq,
    token_accuracies_oracle_command_seq, sequence_accuracies_oracle_command_seq,
)

# Evaluate action sequences
_evaluate_length_buckets(
    "action", action_lengths, action_data_loaders,
    token_accuracies_action_seq, sequence_accuracies_action_seq,
    token_accuracies_oracle_action_seq, sequence_accuracies_oracle_action_seq,
)


# Whole test set: single-pass evaluation first, then step-by-step oracle decoding.
token_accuracy, sequence_accuracy = evaluate_model(
    model, test_loader, device, tgt_pad_idx, eos_idx, oracle=False
)
token_accuracy_oracle, sequence_accuracy_oracle = evaluate_model(
    model, test_loader, device, tgt_pad_idx, eos_idx, oracle=True
)

print("Without oracle :", token_accuracy, sequence_accuracy)
print("With oracle:", token_accuracy_oracle, sequence_accuracy_oracle)

In [None]:
def _plot_accuracy_pair(action_values, command_values, level, oracle_note, filename):
    """Save and show one figure: accuracy by action length (left) and by command length (right).

    ``action_values`` / ``command_values`` are fractions in [0, 1], one per bucket
    of ``sorted_action_dict`` / ``sorted_command_dict``. They are scaled to
    percent to match the y-axis label.
    """
    _, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=False)
    panels = (
        (axes[0], sorted_action_dict, action_values,
         'Ground-Truth Action Sequence Length (in words)', 'Action Sequence Length'),
        (axes[1], sorted_command_dict, command_values,
         'Command Length (in words)', 'Command Length'),
    )
    for ax, buckets, values, xlabel, measured_by in panels:
        ax.bar([f"{k}" for k in buckets.keys()], [100 * v for v in values], color='#D6A9D0')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Accuracy on New Commands (%)')
        ax.set_title(f'{level} Accuracy by {measured_by} ({oracle_note})')
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, filename))
    plt.show()


def plot_accuracies():
    """Plot per-length token- and sequence-level accuracies, with and without the oracle.

    Reads the per-length result lists filled by the evaluation loops and saves one
    PNG per (level, oracle) combination under ``OUTPUT_DIR``.
    """
    _plot_accuracy_pair(token_accuracies_action_seq, token_accuracies_command_seq,
                        'Token-Level', 'no oracle', "token_level_results_without_oracle.png")
    _plot_accuracy_pair(token_accuracies_oracle_action_seq, token_accuracies_oracle_command_seq,
                        'Token-Level', 'oracle', "token_level_results_with_oracle.png")
    _plot_accuracy_pair(sequence_accuracies_oracle_action_seq, sequence_accuracies_oracle_command_seq,
                        'Sequence-Level', 'oracle', "sequence_level_results_with_oracle.png")
    # The non-oracle sequence-level results are computed too; plot them as well.
    _plot_accuracy_pair(sequence_accuracies_action_seq, sequence_accuracies_command_seq,
                        'Sequence-Level', 'no oracle', "sequence_level_results_without_oracle.png")


# Call the function to generate and plot the results
plot_accuracies()