# Train Router_NoSynth on NoSynthQA

In [1]:
import random
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ast
import torch
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import Linear
from torch.optim import Adam
from collections import deque

from router import Router


In [2]:
# Pick the fastest available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

In [3]:
# Load the training metadata: the dataset source, the embedding (x) and the
# label vector (y) are the only columns needed.
train_columns = ['source', 'embedding', 'label']
train_df = pd.read_csv("train_metadata.csv", usecols=train_columns)
train_df.shape

(92960, 3)

In [4]:
# Remove ComSciQA.
# .copy() makes train_df an independent frame, so the column assignments in the
# following cells (embedding/label conversion) don't write into a slice of the
# unfiltered data and trigger pandas' SettingWithCopyWarning.
train_df = train_df[train_df['source'] != 'ComSciQA'].copy()
train_df.shape

(59420, 3)

In [5]:
# The embedding column comes out of the CSV as text; parse each cell back into
# a Python list. (Labels are parsed separately by expand_labels.)
train_df['embedding'] = train_df['embedding'].apply(ast.literal_eval)

In [6]:
# Append a fallback-LLM column to each label vector: rows where no LLM is marked
# correct (all zeros) get the fallback bit set, e.g. [0,0,0] -> [0,0,0,1].
def expand_labels(label):
    """Return ``label`` as a list with one extra "fallback LLM" entry appended.

    ``label`` is the string form of a 0/1 list as read from the CSV
    (e.g. ``'[1, 0, 1]'``). An already-parsed list or tuple is also accepted,
    so re-running the cell does not crash. The appended entry is 1 when no
    entry of ``label`` is set, otherwise 0.

    >>> expand_labels('[0, 0, 0]')
    [0, 0, 0, 1]
    >>> expand_labels('[1, 0, 1]')
    [1, 0, 1, 0]
    """
    parsed = ast.literal_eval(label) if isinstance(label, str) else label
    parsed = list(parsed)
    # Decide on the parsed values rather than the exact string, so spacing
    # variants such as '[0,0,0]' are also routed to the fallback.
    fallback = 0 if any(parsed) else 1
    return parsed + [fallback]
    
# Parse every label string and append the fallback column via expand_labels.
train_df['label'] = train_df['label'].map(expand_labels)

In [7]:
# Materialise the list-valued columns as dense numpy arrays (one row per sample).
embeddings = np.array(train_df['embedding'].tolist())
labels = np.array(train_df['label'].tolist())

# Shuffle with a fixed seed, applying the same permutation to both arrays so
# each embedding stays paired with its label.
np.random.seed(42)
indices = np.arange(len(embeddings))
shuffled_indices = np.random.permutation(indices)
embeddings, labels = embeddings[shuffled_indices], labels[shuffled_indices]

In [8]:
# Sanity check: True if any row has every label entry equal to zero,
# i.e. a sample for which no model (not even the fallback) is chosen.
contains_zero = (labels == 0).all(axis=1).any()
contains_zero

np.False_

In [9]:
# 80/10/10 split: hold out 20% first, then halve the held-out part into
# validation and test sets.
train_embeddings, temp_embeddings, train_labels, temp_labels = train_test_split(
    embeddings, labels, test_size=0.2, random_state=42
)
val_embeddings, test_embeddings, val_labels, test_labels = train_test_split(
    temp_embeddings, temp_labels, test_size=0.5, random_state=42
)

# Convert every split to float32 tensors (targets too: the loss consumes float labels).
train_embeddings, val_embeddings, test_embeddings = (
    torch.tensor(split, dtype=torch.float32)
    for split in (train_embeddings, val_embeddings, test_embeddings)
)
train_labels, val_labels, test_labels = (
    torch.tensor(split, dtype=torch.float32)
    for split in (train_labels, val_labels, test_labels)
)

# Pair each split's inputs with its targets.
train_data, val_data, test_data = (
    TensorDataset(inputs, targets)
    for inputs, targets in (
        (train_embeddings, train_labels),
        (val_embeddings, val_labels),
        (test_embeddings, test_labels),
    )
)

In [10]:
# Report the input-tensor shape of each split.
for split_name, split_tensor in (
    ('Train', train_embeddings),
    ('Test', test_embeddings),
    ('Validation', val_embeddings),
):
    print(f'{split_name} Shape:', split_tensor.shape)

Train Shape: torch.Size([47536, 1024])
Test Shape: torch.Size([5942, 1024])
Validation Shape: torch.Size([5942, 1024])


## Training

In [11]:
# Load a trained router checkpoint.
def load_router(device, model_path="router.pth", input_dim=1024, output_dim=4):
    """Build a ``Router`` and load the state dict saved at ``model_path``.

    Args:
        device: device the router and its weights are placed on.
        model_path: path of a ``state_dict`` saved with ``torch.save``.
        input_dim: router input size; defaults to the 1024 used in this notebook.
        output_dim: number of router outputs; defaults to 4.

    Returns:
        The router, left in training mode (call ``.eval()`` before inference).
    """
    router = Router(input_dim=input_dim, output_dim=output_dim).to(device)
    # A state_dict only holds tensors in plain containers; weights_only=True
    # refuses arbitrary pickled objects instead of executing them.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    router.load_state_dict(state_dict)
    router.train()
    return router

In [None]:
def _count_router_hits(logits, targets, threshold=0.5):
    """Return how many rows have at least one label that is both predicted and true.

    ``logits`` are raw model outputs (the training loss is BCEWithLogitsLoss).
    They go through a sigmoid and are thresholded at the probability
    ``threshold`` to get 0/1 predictions.
    """
    predicted = (torch.sigmoid(logits.detach()) > threshold).float()
    return ((predicted * targets).sum(dim=1) > 0).sum().item()


def train_model(model, train_data, val_data, device, patience=5, lr=0.001, epochs=10, batch_size=32, ckpt_dir='checkpoints/router'):
    """Train a multi-label router with BCE-with-logits loss and early stopping.

    "Accuracy" is a hit rate: a sample counts as correct when at least one label
    predicted positive is also positive in the target, not exact-match accuracy.
    A checkpoint is written to ``ckpt_dir`` whenever validation accuracy improves,
    and once more when early stopping fires.

    Args:
        model: module mapping a batch of inputs to one raw logit per label
            (assumed, since the loss is ``BCEWithLogitsLoss``).
        train_data: dataset whose ``[:]`` slice returns ``(inputs, targets)``
            tensors (e.g. a ``TensorDataset``), targets being 0/1 float vectors.
        val_data: dataset of the same form, used for validation.
        device: device to run on.
        patience: number of consecutive epochs without a new lowest validation
            loss tolerated before training stops.
        lr: Adam learning rate.
        epochs: maximum number of epochs.
        batch_size: mini-batch size for both loaders.
        ckpt_dir: checkpoint directory; created if it does not exist.

    Returns:
        ``(all_train_loss, all_train_acc, all_val_loss, all_val_acc)``, one entry
        per completed epoch.
    """
    os.makedirs(ckpt_dir, exist_ok=True)

    # Per-label positive weight = #neg / #pos to counter label imbalance.
    # clamp(min=1) keeps the weight finite for a label with no positives.
    targets = train_data[:][1]
    num_positives = targets.sum(dim=0)
    num_negatives = len(targets) - num_positives
    pos_weight = num_negatives / num_positives.clamp(min=1)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    optimizer = Adam(model.parameters(), lr=lr)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    best_val_acc = 0
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    all_train_loss = []
    all_train_acc = []
    all_val_loss = []
    all_val_acc = []

    def save_checkpoint(epoch, epoch_train_loss, epoch_train_acc, epoch_val_acc):
        # File name records the epoch-mean training loss and both accuracies.
        name = f'epoch={epoch}_loss={epoch_train_loss:.3f}_tacc={epoch_train_acc:.3f}_vacc={epoch_val_acc:.3f}.pth'
        torch.save(model.state_dict(), os.path.join(ckpt_dir, name))

    for epoch in range(epochs):
        # Training pass
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch.float())
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += _count_router_hits(outputs, y_batch)
            train_total += y_batch.size(0)

        train_acc = train_correct / train_total
        train_loss = train_loss / len(train_loader)
        all_train_acc.append(train_acc)
        all_train_loss.append(train_loss)

        # Validation pass
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch.float())
                val_loss += criterion(outputs, y_batch).item()
                val_correct += _count_router_hits(outputs, y_batch)
                val_total += y_batch.size(0)

        val_acc = val_correct / val_total
        val_loss = val_loss / len(val_loader)
        all_val_acc.append(val_acc)
        all_val_loss.append(val_loss)
        if epoch % 5 == 0:
            print(f"Epoch [{epoch}/{epochs}] - Train Loss: {train_loss:.5f}, Val Loss: {val_loss:.5f}, Train Accuracy: {train_acc:.5f}, Val Accuracy: {val_acc:.5f}")

        # Early stopping: count consecutive epochs whose validation loss is not a
        # new minimum. (The current loss must be compared against the best
        # *previous* loss, otherwise it always "fails" to beat itself.)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("Early stopping triggered")
            print(f'Highest Validation Accuracy: {best_val_acc}, Lowest Validation Loss: {min(all_val_loss)}')
            save_checkpoint(epoch, train_loss, train_acc, val_acc)
            return all_train_loss, all_train_acc, all_val_loss, all_val_acc

        # Save model if validation accuracy improves
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint(epoch, train_loss, train_acc, val_acc)

    print(f'Highest Validation Accuracy: {best_val_acc}, Lowest Validation Loss: {min(all_val_loss, default=float("nan"))}')
    return all_train_loss, all_train_acc, all_val_loss, all_val_acc



In [13]:
def evaluate_model(model, test_data, device, batch_size=32):
    """Compute and print the router hit rate on ``test_data``.

    A sample counts as correct when at least one label predicted positive is
    also positive in the target. Predictions come from thresholding
    ``sigmoid(logits)`` at 0.5, because the model is trained with
    ``BCEWithLogitsLoss`` and so is assumed to output raw logits.

    Args:
        model: module mapping a batch of inputs to one logit per label.
        test_data: dataset yielding ``(inputs, targets)`` pairs with 0/1 float targets.
        device: device to run on.
        batch_size: evaluation batch size (default 32).

    Returns:
        The fraction of correctly routed samples.
    """
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            logits = model(X_batch.float())
            # Threshold the probability, not the raw logit.
            predicted = (torch.sigmoid(logits) > 0.5).float()
            matches = (predicted * y_batch).sum(dim=1) > 0
            test_correct += matches.sum().item()
            test_total += y_batch.size(0)

    test_acc = test_correct / test_total
    print('Test Accuracy:', test_acc)
    return test_acc

In [20]:
# Metrics collected per hyperparameter config, keyed by the config's label string.
results = dict()

In [None]:
# Hyperparameter sweep: every config trains for up to 500 epochs with patience 50
# and batch size 16; only the learning rate varies.
hyperparameter_configs = [
    {"epochs": 500, "patience": 50, "batch": 16, "lr": 0.0001},
    {"epochs": 500, "patience": 50, "batch": 16, "lr": 0.0005},
    {"epochs": 500, "patience": 50, "batch": 16, "lr": 0.001},
    {"epochs": 500, "patience": 50, "batch": 16, "lr": 0.005},
]

# A unique label per config, used as checkpoint sub-directory and `results` key.
for config in hyperparameter_configs:
    config["label"] = f"epochs={config['epochs']}_patience={config['patience']}_batch={config['batch']}_lr={config['lr']}"

# Directory to store model checkpoints
save_dir = "./checkpoints_noSynth"
os.makedirs(save_dir, exist_ok=True)

for config in hyperparameter_configs:
    model_dir = os.path.join(save_dir, config['label'])
    # An existing directory means an earlier run already handled this config
    # (possibly an interrupted one); it is skipped and gets no `results` entry.
    if os.path.isdir(model_dir):
        print(f"Model exists. Skip training with {config}")
        continue
    os.makedirs(model_dir, exist_ok=True)
    print(f"Training with {config}")

    # Seed right before building the model so weight initialisation and
    # DataLoader shuffling are reproducible for each config.
    torch.manual_seed(42)
    model = Router(input_dim=len(embeddings[0]), output_dim=len(labels[0])).to(DEVICE)

    all_train_loss, all_train_acc, all_val_loss, all_val_acc = train_model(
        model,
        train_data,
        val_data,
        DEVICE,
        patience=config['patience'],
        lr=config['lr'],
        epochs=config['epochs'],
        batch_size=config['batch'],
        ckpt_dir=model_dir,
    )

    test_acc = evaluate_model(model, test_data, DEVICE)

    # Row layout must match the column order used when building results_df.
    results[config['label']] = [config['epochs'], config['patience'], config['lr'], config['batch'], all_train_loss, all_train_acc, all_val_loss, all_val_acc, test_acc]
    print('-----'*20)


In [None]:
# One row per config trained in this run: its label followed by the stored metrics.
formatted_results = [(key, *values) for key, values in results.items()]
results_df = pd.DataFrame(formatted_results, columns=['Model_Parameters', 'Epochs', 'Patience', 'Learning_Rate', 'Batch_Size', 'Train_Loss', 'Training_Accuracy', 'Validation_Loss', 'Validation_Accuracy', 'Test_Accuracy'])
summary_path = "./checkpoints_noSynth/results_summary_noSynth.csv"
# `results` only holds configs trained in this run (already-trained ones are
# skipped), so an empty sweep must not overwrite an earlier summary with a
# header-only file. A partial re-run still replaces the file with only the
# newly trained configs.
if formatted_results:
    results_df.to_csv(summary_path)
else:
    print(f"No new results to save; leaving {summary_path} untouched")