### Experiment 2: Model V2 (multi-head attention pooling + dropout layer) with manually tuned hyperparameters and early stopping

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, confusion_matrix

from mil_model.dataloader import create_dataloaders
from mil_model.preprocessing import load_labels, load_dataset_json, build_bags
from mil_model.train_utils import split_data, get_model_weights
from mil_model.loss import FocalLoss
from mil_model.model_v2 import MILModel

In [2]:
# Load the labels plus gene/transcript metadata, read the raw dataset,
# and group its contents into bags for multiple-instance learning.
labels, gene_id, transcript_id, transcript_pos = load_labels('data/data.info.txt')
data_list = load_dataset_json('data/dataset0.json')
bags = build_bags(data_list)

# Tabular view: one row per entry returned by load_labels, with its bag and label.
# NOTE(review): assumes build_bags preserves the same ordering as load_labels — confirm.
df = pd.DataFrame(dict(
    gene_id=gene_id,
    transcript_id=transcript_id,
    transcript_position=transcript_pos,
    bags=bags,
    label=labels,
))

In [3]:
# Fix the random seeds and pick the compute device.
# Seeding here, before the split / dataloader / model cells, makes numpy- and
# torch-driven randomness repeatable across Restart & Run All. That covers weight
# init and dropout, plus the data split and loader shuffling if those draw from
# these generators (not visible from this notebook — confirm in mil_model).
# Note: some cuDNN kernels remain non-deterministic on GPU even with fixed seeds.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Partition the bags and their labels into train / validation / test splits.
(bags_train, labels_train,
 bags_val, labels_val,
 bags_test, labels_test) = split_data(bags, labels)

In [5]:
# Build one loader per split. Each batch holds exactly one bag, so every model
# output corresponds to a single bag.
batch_size = 1

train_loader, val_loader, test_loader = create_dataloaders(
    bags_train, labels_train,
    bags_val, labels_val,
    bags_test, labels_test,
    batch_size=batch_size,
)

In [6]:
# Class weights come from the training labels only, so the val/test splits do not influence them.
# NOTE(review): `class_weights` is not passed to the FocalLoss built in the
# hyperparameter cell — confirm whether it is meant to be used.
class_weights = get_model_weights(
    labels_train,
    device,
)

In [8]:
# Hyperparameters for this run, grouped by what they configure.

# model architecture
input_dim = 9
hidden_dim = 128  # previously 64
num_heads = 4  # candidate for further experiments
dropout_rate = 0.2  # candidate for further experiments

# optimisation / training schedule
learning_rate = 1e-4
weight_decay = 1e-5
num_epochs = 1  # adjust the number of epochs as needed
threshold = 0.5  # probability cut-off used for validation accuracy / confusion matrix

# focal-loss parameters
alpha, gamma = 0.25, 2

# NOTE(review): class_weights (computed in an earlier cell) is not passed to the loss — confirm intended.
criterion = FocalLoss(alpha=alpha, gamma=gamma)
model = MILModel(input_dim, hidden_dim, num_heads, dropout_rate).to(device)
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=weight_decay,
)

In [9]:
# Train with early stopping on the validation loss.
# After the loop, `model` holds the weights from the epoch with the lowest
# validation loss (not the last epoch), so the next cell saves the best checkpoint.

best_val_loss = np.inf
best_state = None  # copy of model weights from the best validation epoch
epochs_without_improvement = 0
patience = 10  # stop after this many consecutive epochs without val-loss improvement

for epoch in range(1, num_epochs + 1):
    # ---- training pass ----
    model.train()
    epoch_loss = 0.0

    # Loop variables are prefixed with `batch_` so they do not overwrite the
    # notebook-level `bags` / `labels` from the data-processing cell.
    for batch_bags, batch_labels, batch_masks in train_loader:
        batch_bags = batch_bags.to(device)
        batch_labels = batch_labels.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        logits = model(batch_bags, batch_masks)  # raw outputs; sigmoid is applied only for metrics
        loss = criterion(logits, batch_labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_labels, val_probs, val_losses = [], [], []
    with torch.no_grad():
        for batch_bags, batch_labels, batch_masks in val_loader:
            batch_bags = batch_bags.to(device)
            batch_labels = batch_labels.to(device)
            batch_masks = batch_masks.to(device)

            logits = model(batch_bags, batch_masks)
            # The loss must receive the same raw outputs as during training.
            # Feeding it sigmoid probabilities would make val and train losses incomparable.
            val_losses.append(criterion(logits, batch_labels.float()).item())

            # Flatten each batch so the metric inputs are 1-D regardless of output shape.
            val_labels.append(batch_labels.cpu().numpy().ravel())
            val_probs.append(torch.sigmoid(logits).cpu().numpy().ravel())

    val_labels = np.concatenate(val_labels)
    val_probs = np.concatenate(val_probs)
    val_preds = (val_probs > threshold).astype(int)
    avg_val_loss = float(np.mean(val_losses))

    roc_auc = roc_auc_score(val_labels, val_probs)
    pr_auc = average_precision_score(val_labels, val_probs)
    acc = accuracy_score(val_labels, val_preds)
    cm = confusion_matrix(val_labels, val_preds)  # kept for inspection after the cell

    print(f"Epoch {epoch}/{num_epochs}, "
          f"Train Loss: {avg_epoch_loss:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, "
          f"Val ROC-AUC: {roc_auc:.4f}, "
          f"Val PR-AUC: {pr_auc:.4f}, "
          f"Val Accuracy: {acc:.4f}")

    # ---- early stopping on validation loss ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0
        # state_dict() returns live references, so clone to snapshot this epoch's weights.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the best-validation weights before the model is saved or evaluated.
if best_state is not None:
    model.load_state_dict(best_state)

  bag = torch.tensor(bag, dtype=torch.float32)


KeyboardInterrupt: 

In [None]:
# Persist the trained weights. Create the output directory first, because
# torch.save fails with FileNotFoundError when `weights/` does not exist yet.
weights_path = Path('weights') / 'expt2.pth'
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)