### EXPT 2: Model V2 with balancing data (experimenting with different ratios)

In [2]:
import pandas as pd
import numpy as np

import torch
import torch.optim as optim

from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, confusion_matrix

from mil_model.dataloader import create_dataloaders
from mil_model.preprocessing import load_labels, load_dataset_json, build_bags
from mil_model.train_utils import split_data, get_model_weights, get_pos_neg, balance_data
from mil_model.loss import FocalLoss
from mil_model.model import MILModel

In [3]:
# process data
labels, gene_id, transcript_id, transcript_pos = load_labels('data/data.info.txt')
data_list = load_dataset_json('data/dataset0.json')
bags = build_bags(data_list)

df = pd.DataFrame({
    'gene_id': gene_id,
    'transcript_id': transcript_id,
    'transcript_position': transcript_pos,
    'bags': bags,
    'label': labels
})

In [4]:
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# balance data 
bags_pos, labels_pos, bags_neg, labels_neg = get_pos_neg(bags, labels)

# expt with diff desired ratios and see training metrics 
desired_ratio = [1,4]

bags_resampled, labels_resampled = balance_data(bags_pos, labels_pos, bags_neg, labels_neg, desired_ratio)

In [6]:
# split data
bags_train, labels_train, bags_val, labels_val, bags_test, labels_test = split_data(bags, labels)

In [7]:
# build dataloaders
batch_size = 16 # because we want the output tensor to only process one bag at a time 

train_loader, val_loader, test_loader = create_dataloaders(bags_train, labels_train, bags_val, labels_val, bags_test, labels_test, batch_size = batch_size)

In [8]:
# compute weights
class_weights = get_model_weights(labels_train, device)

In [9]:
# parameter for training
input_dim = 9
hidden_dim = 128 # og was 64
learning_rate = 1e-4

num_epochs = 10 # change epochs 
threshold = 0.5

weight_decay = 1e-5
alpha = 0.25
gamma = 4

criterion = FocalLoss(alpha=alpha, gamma=gamma)
model = MILModel(input_dim, hidden_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay)

In [11]:
# training 

best_loss = np.inf
epochs_without_improvement = 0
patience = 10 # stop after 10 epochs of no improvement

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for bags, labels, masks in train_loader:
        bags, labels, masks = bags.to(device), labels.to(device), masks.to(device)  # Move to device
        # reset gradients
        optimizer.zero_grad()
        outputs = model(bags, masks)
        # calculate loss and metrics
        loss = criterion(outputs, labels.float())
        # backward pass
        loss.backward()
        # step
        optimizer.step()
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(train_loader)

    # Early stopping
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        epochs_without_improvement = 0
        # Save your model here
    else:
        epochs_without_improvement += 1
    
    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_epoch_loss:.4f}')
    
    # Validation
    model.eval()
    with torch.no_grad():
        all_labels = []
        all_outputs = []
        val_losses = []
        for bags, labels, masks in val_loader:
            bags, labels, masks = bags.to(device), labels.to(device), masks.to(device) 
            outputs = model(bags, masks)
            outputs = torch.sigmoid(outputs)

            loss = criterion(outputs, labels.float())
            val_losses.append(loss.item())

            all_labels.extend(labels.cpu().numpy())
            all_outputs.extend(outputs.cpu().numpy())

        roc_auc = roc_auc_score(all_labels, all_outputs)
        pr_auc = average_precision_score(all_labels, all_outputs)
        acc = accuracy_score(all_labels, (np.array(all_outputs) > threshold).astype(int))
        cm = confusion_matrix(all_labels, (np.array(all_outputs) > threshold).astype(int))
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {avg_epoch_loss:.4f}, "
              f"Val Loss: {np.mean(val_losses):.4f}, "
              f"Val ROC-AUC: {roc_auc:.4f}, "
              f"Val PR-AUC: {pr_auc:.4f}, "
              f"Val Accuracy: {acc:.4f}")

Epoch [1/10], Loss: 0.0041
Epoch 1/10, Train Loss: 0.0041, Val Loss: 0.0220, Val ROC-AUC: 0.7880, Val PR-AUC: 0.2749, Val Accuracy: 0.9560
Epoch [2/10], Loss: 0.0035
Epoch 2/10, Train Loss: 0.0035, Val Loss: 0.0201, Val ROC-AUC: 0.7843, Val PR-AUC: 0.2492, Val Accuracy: 0.9560
Epoch [3/10], Loss: 0.0033
Epoch 3/10, Train Loss: 0.0033, Val Loss: 0.0239, Val ROC-AUC: 0.8203, Val PR-AUC: 0.2837, Val Accuracy: 0.9564
Epoch [4/10], Loss: 0.0033
Epoch 4/10, Train Loss: 0.0033, Val Loss: 0.0215, Val ROC-AUC: 0.8238, Val PR-AUC: 0.3049, Val Accuracy: 0.9560
Epoch [5/10], Loss: 0.0031
Epoch 5/10, Train Loss: 0.0031, Val Loss: 0.0229, Val ROC-AUC: 0.8140, Val PR-AUC: 0.2998, Val Accuracy: 0.9564
Epoch [6/10], Loss: 0.0031
Epoch 6/10, Train Loss: 0.0031, Val Loss: 0.0206, Val ROC-AUC: 0.8285, Val PR-AUC: 0.3027, Val Accuracy: 0.9560
Epoch [7/10], Loss: 0.0030
Epoch 7/10, Train Loss: 0.0030, Val Loss: 0.0227, Val ROC-AUC: 0.8259, Val PR-AUC: 0.3112, Val Accuracy: 0.9567
Epoch [8/10], Loss: 0.0030


In [None]:
torch.save(model.state_dict(), 'weights/expt3.pth')