# HSTrans - Drug-Side Effect Prediction Training

This notebook contains the complete training pipeline for the HSTrans model, a Transformer-based approach for predicting drug-side effect interactions.

## 🎯 Project Overview
- **Model**: HSTrans (Hierarchical Transformer)
- **Task**: Drug-Side Effect Prediction
- **Architecture**: Cross-attention Transformer with substructure encoding
- **Training**: 10-fold cross-validation

## 📋 Table of Contents
1. [Setup & Dependencies](#setup)
2. [Data Loading & Preparation](#data)
3. [Model Architecture](#model)
4. [Training Functions](#training)
5. [Cross-Validation Training](#cv-training)
6. [Results Analysis](#results)

## 🛠️ 1. Setup & Dependencies

First, let's install all required dependencies and set up the environment.

In [None]:
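# Install required packages
# Version specifiers are quoted so the shell does not treat ">" as output redirection
!pip install "torch>=1.9.0" "numpy>=1.19.0" "pandas>=1.2.0" "scipy>=1.6.0"
!pip install "scikit-learn>=0.24.0" "matplotlib>=3.3.0" "rdkit-pypi>=2021.9.1"
# seaborn and gdown are used later in this notebook (plot styling and download_data)
!pip install "subword-nmt>=0.3.7" "networkx>=2.5" tqdm seaborn gdown

print("✅ All dependencies installed successfully!")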

In [None]:
# Import essential libraries
import os
import json
import pickle
import random
import argparse
import warnings
from pathlib import Path
from tqdm.auto import tqdm

# Scientific computing
import numpy as np
import pandas as pd
import scipy
from scipy import io
from math import sqrt

# Machine Learning
from sklearn.metrics import (roc_auc_score, average_precision_score, 
                             precision_score, recall_score, accuracy_score)
from sklearn.model_selection import StratifiedKFold
from scipy import stats

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch import optim

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# RDKit for chemistry
from rdkit import Chem
import networkx as nx
import codecs
from subword_nmt.apply_bpe import BPE

# Settings
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")
if torch.cuda.is_available():
    print(f"📊 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 📁 2. Data Loading & Preparation

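Let's mount Google Drive, set up the file structure, and fetch the data. Note that the default `WORK_DIR` below lives on the Colab VM's local disk, which is cleared when the runtime resets; point it at a folder under `/content/drive/MyDrive` if checkpoints and results should persist.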

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Create working directory
WORK_DIR = '/content/HSTrans'
DATA_DIR = f'{WORK_DIR}/data'
SUB_DIR = f'{DATA_DIR}/sub'
RESULTS_DIR = f'{WORK_DIR}/results'
CHECKPOINTS_DIR = f'{WORK_DIR}/checkpoints'
PREDICT_DIR = f'{WORK_DIR}/predictResult'

# Create directories
for path in [WORK_DIR, DATA_DIR, SUB_DIR, RESULTS_DIR, CHECKPOINTS_DIR, PREDICT_DIR]:
    Path(path).mkdir(parents=True, exist_ok=True)

print(f"📂 Working directory: {WORK_DIR}")
print("✅ Directory structure created!")

In [None]:
# Download data (replace with your actual data source)
# For this example, we'll create a function to download from GitHub or other sources
def download_data():
    """Download required data files"""
    import gdown
    
    # Replace these with your actual file IDs or URLs
    data_files = {
        'raw_frequency_750.mat': 'YOUR_FILE_ID',  # Replace with actual ID
        'drug_SMILES_750.csv': 'YOUR_FILE_ID',   # Replace with actual ID  
        'mask_mat_750.mat': 'YOUR_FILE_ID',      # Replace with actual ID
        'side_effect_label_750.mat': 'YOUR_FILE_ID', # Replace with actual ID
        'drug_side.pkl': 'YOUR_FILE_ID',         # Replace with actual ID
        'drug_codes_chembl_freq_1500.txt': 'YOUR_FILE_ID', # Replace with actual ID
        'subword_units_map_chembl_freq_1500.csv': 'YOUR_FILE_ID' # Replace with actual ID
    }
    
    print("📥 Downloading data files...")
    for filename, file_id in data_files.items():
        if not os.path.exists(f'{DATA_DIR}/{filename}'):
            try:
                url = f'https://drive.google.com/uc?id={file_id}'
                gdown.download(url, f'{DATA_DIR}/{filename}', quiet=False)
                print(f"✅ Downloaded {filename}")
            except Exception as e:
                print(f"⚠️  Could not download {filename}: {e}")
                print(f"Please manually upload {filename} to {DATA_DIR}")
        else:
            print(f"✅ {filename} already exists")

# Uncomment to download data
# download_data()

print("\n📋 Please ensure the following files are in your data directory:")
print("   - raw_frequency_750.mat")
print("   - drug_SMILES_750.csv")
print("   - mask_mat_750.mat")
print("   - side_effect_label_750.mat")
print("   - drug_side.pkl")
print("   - drug_codes_chembl_freq_1500.txt")
print("   - subword_units_map_chembl_freq_1500.csv")
print("\n💡 You can upload these files manually to Google Drive or use the download_data() function")

In [None]:
# Check if data files exist
required_files = [
    'raw_frequency_750.mat',
    'drug_SMILES_750.csv', 
    'mask_mat_750.mat',
    'side_effect_label_750.mat',
    'drug_side.pkl',
    'drug_codes_chembl_freq_1500.txt',
    'subword_units_map_chembl_freq_1500.csv'
]

missing_files = []
for file in required_files:
    if os.path.exists(f'{DATA_DIR}/{file}'):
        print(f"✅ {file}")
    else:
        print(f"❌ {file} - MISSING")
        missing_files.append(file)

if missing_files:
    print(f"\n⚠️  {len(missing_files)} files are missing. Please upload them to continue.")
else:
    print("\n🎉 All data files are ready!")

## 🏗️ 3. Model Architecture

Now let's define the HSTrans model architecture and utility functions.

In [None]:
# Utility functions for metrics
def rmse(y, f):
    return sqrt(((y - f) ** 2).mean())

def mse(y, f):
    return ((y - f) ** 2).mean()

def pearson(y, f):
    return np.corrcoef(y, f)[0, 1]

def spearman(y, f):
    return stats.spearmanr(y, f)[0]

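def MAE(y, f):
    # Import the function directly; a bare `import sklearn` does not guarantee
    # that the `sklearn.metrics` submodule is loaded
    from sklearn.metrics import mean_absolute_error
    return mean_absolute_error(y, f)

print("🔧 Utility functions loaded!")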
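
As a quick sanity check of the metric helpers above, the following minimal sketch re-declares `rmse` and `pearson` so it runs on its own and applies them to a toy pair of arrays. The values in `y_true` and `y_pred` are made up for illustration and are not taken from the HSTrans data.

```python
import numpy as np
from math import sqrt

def rmse(y, f):
    # Root mean squared error between targets y and predictions f
    return sqrt(((y - f) ** 2).mean())

def pearson(y, f):
    # Pearson correlation coefficient between y and f
    return np.corrcoef(y, f)[0, 1]

# Toy data (illustrative only): three exact predictions, one off by 2
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.0, 2.0, 3.0, 6.0])

toy_rmse = rmse(y_true, y_pred)        # sqrt(mean([0, 0, 0, 4])) = 1.0
toy_pearson = pearson(y_true, y_pred)
print(f"RMSE: {toy_rmse:.3f}, Pearson: {toy_pearson:.3f}")
```

Both helpers expect NumPy arrays of equal length; plain Python lists would fail in `rmse` because of the element-wise subtraction.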

In [None]:
# SMILES processing functions
def atom_features(atom):
    HYB_list = [Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2,
                Chem.rdchem.HybridizationType.UNSPECIFIED, Chem.rdchem.HybridizationType.OTHER]
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                          ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
                                           'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
                                           'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
                                           'Pt', 'Hg', 'Pb', 'Sm', 'Tc', 'Gd', 'Unknown']) +
                    one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    one_of_k_encoding_unk(atom.GetExplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
                    one_of_k_encoding(atom.GetFormalCharge(), [-4, -3, -2, -1, 0, 1, 2, 3, 4]) +
                    one_of_k_encoding(atom.GetHybridization(), HYB_list) +
                    [atom.GetIsAromatic()])

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

def load_drug_smile(file):
    reader = pd.read_csv(file)
    drug_dict = {}
    drug_smile = []
    
    for idx, row in reader.iterrows():
        name = str(row.iloc[0])
        smile = str(row.iloc[1])
        if name not in drug_dict:
            pos = len(drug_dict)
            drug_dict[name] = pos
        drug_smile.append(smile)
    
    return drug_dict, drug_smile

def drug2emb_encoder(smile):
    vocab_path = f'{DATA_DIR}/drug_codes_chembl_freq_1500.txt'
    sub_csv = pd.read_csv(f'{DATA_DIR}/subword_units_map_chembl_freq_1500.csv')

    bpe_codes_drug = codecs.open(vocab_path)
    dbpe = BPE(bpe_codes_drug, merges=-1, separator='')
    idx2word_d = sub_csv['index'].values
    words2idx_d = dict(zip(idx2word_d, range(0, len(idx2word_d))))

    max_d = 50
    t1 = dbpe.process_line(smile).split()
    try:
        i1 = np.asarray([words2idx_d[i] for i in t1])
    except KeyError:
        # Fall back to a single padding index if any subword is out of vocabulary
        i1 = np.array([0])

    l = len(i1)
    if l < max_d:
        i = np.pad(i1, (0, max_d - l), 'constant', constant_values=0)
        input_mask = ([1] * l) + ([0] * (max_d - l))
    else:
        i = i1[:max_d]
        input_mask = [1] * max_d

    return i, np.asarray(input_mask)

print("🧪 SMILES processing functions loaded!")
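
The padding and masking step at the end of `drug2emb_encoder`, and the fallback behaviour of `one_of_k_encoding_unk`, can be illustrated without RDKit or the BPE vocabulary. This is a minimal sketch: `pad_with_mask` and `one_hot_with_fallback` are hypothetical helper names introduced here, and the token ids are invented.

```python
import numpy as np

def pad_with_mask(token_ids, max_len=50):
    """Pad or truncate ids to max_len; mask is 1 for real tokens, 0 for padding."""
    ids = np.asarray(token_ids, dtype=int)[:max_len]
    n = len(ids)
    padded = np.pad(ids, (0, max_len - n), "constant", constant_values=0)
    mask = np.array([1] * n + [0] * (max_len - n))
    return padded, mask

def one_hot_with_fallback(x, allowed):
    """Same idea as one_of_k_encoding_unk: unknown values map to the last slot."""
    if x not in allowed:
        x = allowed[-1]
    return [x == s for s in allowed]

# Three made-up subword ids padded to length 8
padded, mask = pad_with_mask([17, 4, 256], max_len=8)
print(padded.tolist())  # [17, 4, 256, 0, 0, 0, 0, 0]
print(mask.tolist())    # [1, 1, 1, 0, 0, 0, 0, 0]

# "Xe" is not in the allowed set, so the last ("Unknown") slot is set
unknown_atom = one_hot_with_fallback("Xe", ["C", "N", "Unknown"])
print(unknown_atom)     # [False, False, True]
```

Index 0 doubles as the padding id, which is why `identify_sub` later skips any `sub_id == 0`.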

In [None]:
# Transformer architecture components
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta

class Embeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
        super(Embeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

print("🏗️ Transformer components loaded!")
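
The custom `LayerNorm` above normalizes each vector along the last axis to zero mean and unit (biased) variance before applying the learned scale and shift. A NumPy sketch of the same arithmetic, with `layer_norm` as a hypothetical stand-alone function and random input chosen only for illustration:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-12):
    # Mean and biased variance over the last (hidden) axis
    u = x.mean(axis=-1, keepdims=True)
    s = ((x - u) ** 2).mean(axis=-1, keepdims=True)
    # Normalize, then apply scale (gamma) and shift (beta)
    return gamma * (x - u) / np.sqrt(s + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 5, 8))  # (batch, seq, hidden)
out = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8))

mean_ok = np.allclose(out.mean(axis=-1), 0.0, atol=1e-9)
var_ok = np.allclose(out.var(axis=-1), 1.0, atol=1e-6)
print(mean_ok, var_ok)
```

With `gamma` at ones and `beta` at zeros (the initial values in the module), the output of every position has mean close to 0 and variance close to 1.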

In [None]:
# Attention mechanisms
import math  # required by SelfAttention.forward for the 1/sqrt(d) scaling
class SelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
        super(SelfAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.query2 = nn.Linear(hidden_size, self.all_head_size)
        self.key2 = nn.Linear(hidden_size, self.all_head_size)
        self.value2 = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, fusion):
        if fusion:
            mixed_query_layer = self.query(hidden_states[0])
            mixed_key_layer = self.key(hidden_states[0])
            mixed_value_layer = self.value(hidden_states[0])

            mixed_query_layer1 = self.query2(hidden_states[1])
            mixed_key_layer1 = self.key2(hidden_states[1])
            mixed_value_layer1 = self.value2(hidden_states[1])

            query_layer = self.transpose_for_scores(mixed_query_layer)
            key_layer = self.transpose_for_scores(mixed_key_layer)
            value_layer = self.transpose_for_scores(mixed_value_layer)

            query_layer1 = self.transpose_for_scores(mixed_query_layer1)
            key_layer1 = self.transpose_for_scores(mixed_key_layer1)
            value_layer1 = self.transpose_for_scores(mixed_value_layer1)

            attention_scores = torch.matmul(query_layer, key_layer1.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            attention_scores = attention_scores + attention_mask

            attention_scores1 = torch.matmul(query_layer1, key_layer.transpose(-1, -2))
            attention_scores1 = attention_scores1 / math.sqrt(self.attention_head_size)
            attention_scores1 = attention_scores1 + attention_mask
            
            attention_probs = nn.Softmax(dim=-1)(attention_scores)
            attention_probs1 = nn.Softmax(dim=-1)(attention_scores1)

            attention_probs = self.dropout(attention_probs)
            attention_probs1 = self.dropout(attention_probs1)

            context_layer = torch.matmul(attention_probs1, value_layer)
            context_layer1 = torch.matmul(attention_probs, value_layer1)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            context_layer1 = context_layer1.permute(0, 2, 1, 3).contiguous()

            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            new_context_layer_shape1 = context_layer1.size()[:-2] + (self.all_head_size,)

            context_layer = context_layer.view(*new_context_layer_shape)
            context_layer1 = context_layer1.view(*new_context_layer_shape1)

            context_layer = torch.cat((context_layer.unsqueeze(0), context_layer1.unsqueeze(0)), 0)
        else:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

            query_layer = self.transpose_for_scores(mixed_query_layer)
            key_layer = self.transpose_for_scores(mixed_key_layer)
            value_layer = self.transpose_for_scores(mixed_value_layer)

            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            attention_scores = attention_scores + attention_mask

            attention_probs = nn.Softmax(dim=-1)(attention_scores)
            attention_probs = self.dropout(attention_probs)

            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer

print("👁️ Attention mechanisms loaded!")
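
The non-fusion branch of `SelfAttention` is standard scaled dot-product attention with an additive mask: padded positions receive a large negative score so that softmax assigns them almost no weight. The single-head NumPy sketch below shows the idea under simplifying assumptions (no learned projections, no dropout, one sequence); `masked_attention` is a hypothetical name.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_attention(q, k, v, keep_mask):
    """q, k, v: (seq, d); keep_mask: (seq,) with 1 = real token, 0 = padding."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    # Same additive mask as in the model: (1 - mask) * -10000
    scores = scores + (1.0 - keep_mask)[None, :] * -10000.0
    probs = softmax(scores, axis=-1)
    return probs @ v, probs

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))
keep_mask = np.array([1.0, 1.0, 0.0, 0.0])  # last two positions are padding

context, probs = masked_attention(q, k, v, keep_mask)
print(probs.round(3))
```

Each row of `probs` sums to 1, and the columns for the two padded key positions are effectively zero. In the fusion branch the same computation is applied twice, with queries from one sequence and keys from the other.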

In [None]:
# Encoder layers
import copy
import math

class SelfOutput(nn.Module):
    def __init__(self, hidden_size, hidden_dropout_prob):
        super(SelfOutput, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class Attention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
        super(Attention, self).__init__()
        self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
        self.output = SelfOutput(hidden_size, hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask, fusion):
        self_output = self.self(input_tensor, attention_mask, fusion)
        if fusion:
            input_tensor = torch.cat((input_tensor[0].unsqueeze(0), input_tensor[1].unsqueeze(0)), 0)
        attention_output = self.output(self_output, input_tensor)
        return attention_output

class Intermediate(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super(Intermediate, self).__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = F.relu(hidden_states)
        return hidden_states

class Output(nn.Module):
    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
        super(Output, self).__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class Encoder(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super(Encoder, self).__init__()
        self.attention = Attention(hidden_size, num_attention_heads,
                                   attention_probs_dropout_prob, hidden_dropout_prob)
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask, fusion):
        attention_output = self.attention(hidden_states, attention_mask, fusion)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

class Encoder_MultipleLayers(nn.Module):
    def __init__(self, n_layer, hidden_size, intermediate_size,
                 num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
        super(Encoder_MultipleLayers, self).__init__()
        layer = Encoder(hidden_size, intermediate_size, num_attention_heads,
                        attention_probs_dropout_prob, hidden_dropout_prob)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)])

    def forward(self, hidden_states, attention_mask, fusion, output_all_encoded_layers=True):
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, fusion)
        return hidden_states

print("🔄 Encoder layers loaded!")

In [None]:
# HSTrans Model
class Trans(nn.Module):
    def __init__(self):
        super(Trans, self).__init__()
        
        self.relu = nn.ReLU()

        input_dim_drug = 2586
        transformer_emb_size_drug = 200
        transformer_dropout_rate = 0.1
        transformer_n_layer_drug = 8
        transformer_intermediate_size_drug = 512
        transformer_num_attention_heads_drug = 8
        transformer_attention_probs_dropout = 0.1
        transformer_hidden_dropout_rate = 0.1

        # Embedding layers
        self.embDrug = Embeddings(input_dim_drug,
                              transformer_emb_size_drug,
                              50,
                              transformer_dropout_rate)

        self.embSide = Embeddings(input_dim_drug,
                              transformer_emb_size_drug,
                              50,
                              transformer_dropout_rate)

        # Transformer encoders
        self.encoderDrug = Encoder_MultipleLayers(transformer_n_layer_drug,
                                              transformer_emb_size_drug,
                                              transformer_intermediate_size_drug,
                                              transformer_num_attention_heads_drug,
                                              transformer_attention_probs_dropout,
                                              transformer_hidden_dropout_rate)

        self.encoderSide = Encoder_MultipleLayers(transformer_n_layer_drug,
                                              transformer_emb_size_drug,
                                              transformer_intermediate_size_drug,
                                              transformer_num_attention_heads_drug,
                                              transformer_attention_probs_dropout,
                                              transformer_hidden_dropout_rate)

        # Cross Attention Encoder
        cross_attention_n_layer = 2
        self.crossAttentionencoder = Encoder_MultipleLayers(cross_attention_n_layer,
                                                             transformer_emb_size_drug,
                                                             transformer_intermediate_size_drug,
                                                             transformer_num_attention_heads_drug,
                                                             transformer_attention_probs_dropout,
                                                             transformer_hidden_dropout_rate)

        # Residual Fusion Layers
        self.residual_fusion_drug = nn.Sequential(
            nn.Linear(transformer_emb_size_drug * 2, transformer_emb_size_drug),
            nn.LayerNorm(transformer_emb_size_drug),
            nn.ReLU(),
            nn.Dropout(transformer_dropout_rate),
            nn.Linear(transformer_emb_size_drug, transformer_emb_size_drug)
        )
        
        self.residual_fusion_side = nn.Sequential(
            nn.Linear(transformer_emb_size_drug * 2, transformer_emb_size_drug),
            nn.LayerNorm(transformer_emb_size_drug),
            nn.ReLU(),
            nn.Dropout(transformer_dropout_rate),
            nn.Linear(transformer_emb_size_drug, transformer_emb_size_drug)
        )
        
        # Gating mechanisms
        self.gate_drug = nn.Sequential(
            nn.Linear(transformer_emb_size_drug * 2, transformer_emb_size_drug),
            nn.Sigmoid()
        )
        
        self.gate_side = nn.Sequential(
            nn.Linear(transformer_emb_size_drug * 2, transformer_emb_size_drug),
            nn.Sigmoid()
        )

        self.position_embeddings = nn.Embedding(500, 200)
        self.dropout = 0.3

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(6912, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 64),
            nn.ReLU(True),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(True),
            nn.Linear(32, 1)
        )

        self.icnn = nn.Conv2d(1, 3, 3, padding=0)
        self.CrossAttention = True

    def forward(self, Drug, SE, DrugMask, SEMsak):
        batch = Drug.size(0)
        device = next(self.parameters()).device

        # Drug encoding
        Drug = Drug.long().to(device)
        DrugMask = DrugMask.long().to(device)
        DrugMask = DrugMask.unsqueeze(1).unsqueeze(2)
        DrugMask = (1.0 - DrugMask) * -10000.0
        emb = self.embDrug(Drug)
        encoded_layers = self.encoderDrug(emb.float(), DrugMask.float(), False)
        x_d = encoded_layers

        # Side effect encoding
        SE = SE.long().to(device)
        SEMsak = SEMsak.long().to(device)
        SEMsak = SEMsak.unsqueeze(1).unsqueeze(2)
        SEMsak = (1.0 - SEMsak) * -10000.0
        embE = self.embSide(SE)
        encoded_layers = self.encoderSide(embE.float(), SEMsak.float(), False)
        x_e = encoded_layers

        if self.CrossAttention:
            x_d_original = x_d.clone()
            x_e_original = x_e.clone()
            combined_mask = DrugMask.float()
            
            cross_output = self.crossAttentionencoder([x_d.float(), x_e.float()], combined_mask, True)
            x_d_cross = cross_output[0]
            x_e_cross = cross_output[1]
            
            batch_size, seq_len, hidden_size = x_d_original.shape
            
            x_d_flat = x_d_original.view(-1, hidden_size)
            x_d_cross_flat = x_d_cross.view(-1, hidden_size)
            x_e_flat = x_e_original.view(-1, hidden_size)
            x_e_cross_flat = x_e_cross.view(-1, hidden_size)
            
            x_d_concat = torch.cat([x_d_flat, x_d_cross_flat], dim=-1)
            x_e_concat = torch.cat([x_e_flat, x_e_cross_flat], dim=-1)
            
            gate_d = self.gate_drug(x_d_concat)
            gate_e = self.gate_side(x_e_concat)
            
            x_d_fused = self.residual_fusion_drug(x_d_concat)
            x_e_fused = self.residual_fusion_side(x_e_concat)
            
            x_d = (gate_d * x_d_fused + (1 - gate_d) * x_d_flat).view(batch_size, seq_len, hidden_size)
            x_e = (gate_e * x_e_fused + (1 - gate_e) * x_e_flat).view(batch_size, seq_len, hidden_size)

        # Interaction
        d_aug = torch.unsqueeze(x_d, 2).repeat(1, 1, 50, 1)
        e_aug = torch.unsqueeze(x_e, 1).repeat(1, 50, 1, 1)

        i = d_aug * e_aug
        i_v = i.permute(0, 3, 1, 2)
        i_v = torch.sum(i_v, dim=1)
        i_v = torch.unsqueeze(i_v, 1)
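        # Pass the module's training flag; F.dropout defaults to training=True
        # and would otherwise keep dropping values during evaluation
        i_v = F.dropout(i_v, p=self.dropout, training=self.training)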

        f = self.icnn(i_v)
        f = f.view(int(batch), -1)

        score = self.decoder(f)

        return score, Drug, SE

print("🚀 HSTrans model loaded!")
model = Trans()
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"📊 Total trainable parameters: {total_params:,}")
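
The decoder's input width of 6912 follows from the shapes in `Trans.forward`: the drug and side-effect sequences are both 50 tokens long, their interaction map is 50 x 50, and `self.icnn` is a 3 x 3 convolution with 3 output channels and no padding. A small sketch of that arithmetic, where `conv2d_out` is a hypothetical helper implementing the usual convolution output-size formula:

```python
def conv2d_out(size, kernel, padding=0, stride=1):
    # Output length along one spatial axis of a 2D convolution
    return (size + 2 * padding - kernel) // stride + 1

seq_len = 50        # tokens per drug and per side effect
out_channels = 3    # nn.Conv2d(1, 3, 3, padding=0)
kernel = 3

side = conv2d_out(seq_len, kernel)           # 50 - 3 + 1 = 48
flat_features = out_channels * side * side   # 3 * 48 * 48
print(side, flat_features)  # 48 6912
```

If the sequence length or the convolution settings change, the first `nn.Linear` in the decoder has to be resized accordingly.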

## 🎯 4. Training Functions

Let's define the training, evaluation, and data processing functions.

In [None]:
# Data processing functions
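# Note: `addition_negative_number` must be 'all' or an int; the '' default
# raises a TypeError in the `1 + addition_negative_number` branch below
def Extract_positive_negative_samples(DAL, addition_negative_number=''):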
    k = 0
    interaction_target = np.zeros((DAL.shape[0] * DAL.shape[1], 3)).astype(int)
    for i in range(DAL.shape[0]):
        for j in range(DAL.shape[1]):
            interaction_target[k, 0] = i
            interaction_target[k, 1] = j
            interaction_target[k, 2] = DAL[i, j]
            k = k + 1
    data_shuffle = interaction_target[interaction_target[:, 2].argsort()]
    number_positive = len(np.nonzero(data_shuffle[:, 2])[0])
    final_positive_sample = data_shuffle[interaction_target.shape[0] - number_positive::]
    negative_sample = data_shuffle[0:interaction_target.shape[0] - number_positive]
    a = np.arange(interaction_target.shape[0] - number_positive)
    a = list(a)
    if addition_negative_number == 'all':
        b = random.sample(a, (interaction_target.shape[0] - number_positive))
    else:
        b = random.sample(a, (1 + addition_negative_number) * number_positive)
    final_negtive_sample = negative_sample[b[0:number_positive], :]
    addition_negative_sample = negative_sample[b[number_positive::], :]
    final_positive_sample = np.concatenate((final_positive_sample, final_negtive_sample), axis=0)
    return addition_negative_sample, final_positive_sample, final_negtive_sample

def identify_sub(data, k):
    print(f'🔍 Extracting effective substructures for fold {k}...')
    drug_smile = [item[1] for item in data]
    side_id = [item[0] for item in data]
    labels = [item[2] for item in data]

    # Get SMILE-sub index
    sub_dict = {}
    for i in tqdm(range(len(drug_smile)), desc="Processing drugs"):
        drug_sub, mask = drug2emb_encoder(drug_smile[i])
        drug_sub = drug_sub.tolist()
        sub_dict[i] = drug_sub

    # Save temporary file
    with open(f'{SUB_DIR}/my_dict_{k}.pkl', 'wb') as f:
        pickle.dump(sub_dict, f)
    
    with open(f'{SUB_DIR}/my_dict_{k}.pkl', 'rb') as f:
        sub_dict = pickle.load(f)

    SE_sub = np.zeros((994, 2686))
    for j in tqdm(range(len(drug_smile)), desc="Building substructure matrix"):
        sideID = side_id[j]
        label = float(labels[j])
        for sub_id in sub_dict[j]:
            if sub_id == 0:
                continue
            SE_sub[int(sideID)][int(sub_id)] += label

    np.save(f"{SUB_DIR}/SE_sub_{k}.npy", SE_sub)
    SE_sub = np.load(f"{SUB_DIR}/SE_sub_{k}.npy", allow_pickle=True)

    n = np.sum(SE_sub)
    SE_sum = np.sum(SE_sub, axis=1)
    SE_p = SE_sum / n
    Sub_sum = np.sum(SE_sub, axis=0)
    Sub_p = Sub_sum / n
    SE_sub_p = SE_sub / n

    freq = np.zeros((994, 2686))
    for i in tqdm(range(994), desc="Calculating frequencies"):
        for j in range(2686):
            freq[i][j] = ((SE_sub_p[i][j] - SE_p[i] * Sub_p[j]) / (sqrt((SE_p[i] * Sub_p[j] / n)
                                                                        * (1 - SE_p[i]) *
                                                                        (1 - Sub_p[j])))) + 1e-5
    np.save(f"{SUB_DIR}/freq_{k}.npy", freq)
    freq = np.load(f"{SUB_DIR}/freq_{k}.npy", allow_pickle=True)
    non_nan_values = freq[~np.isnan(freq)]
    percentile_95 = np.percentile(non_nan_values, 95)
    print(f"📊 95th percentile: {percentile_95:.4f}")

    l = []
    SE_sub_index = np.zeros((994, 50))
    for i in tqdm(range(994), desc="Extracting top substructures"):
        k_count = 0
        sorted_indices = np.argsort(freq[i])[::-1]
        filtered_indices = sorted_indices[freq[i][sorted_indices] > percentile_95]
        l.append(len(filtered_indices))
        for j in filtered_indices:
            if k_count < 50:
                SE_sub_index[i][k_count] = j
                k_count = k_count + 1
            else:
                continue

    np.save(f"{SUB_DIR}/SE_sub_index_50_{k}.npy", SE_sub_index)
    np.save(f"{SUB_DIR}/SE_sub_mask_50_{k}.npy", (SE_sub_index > 0).astype(int))
    np.save(f"{WORK_DIR}/len_sub.npy", l)
    print(f"✅ Substructure extraction completed for fold {k}")

print("📊 Data processing functions loaded!")
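
The nested loop in `identify_sub` that fills `freq` visits 994 x 2686 cells in pure Python. The same score can be computed with NumPy broadcasting. The sketch below is one possible vectorized form, not the notebook's implementation: `association_scores` is a hypothetical name and the 3 x 3 count matrix is invented for illustration.

```python
import numpy as np

def association_scores(counts):
    """Score each (side effect, substructure) pair against independence.

    Mirrors the loop formula:
    (p_ij - p_i * p_j) / sqrt((p_i * p_j / n) * (1 - p_i) * (1 - p_j)) + 1e-5
    """
    n = counts.sum()
    joint = counts / n                      # p_ij
    row_p = counts.sum(axis=1) / n          # p_i (side effects)
    col_p = counts.sum(axis=0) / n          # p_j (substructures)
    expected = np.outer(row_p, col_p)       # p_i * p_j
    variance = (expected / n) * np.outer(1 - row_p, 1 - col_p)
    # Empty rows or columns give 0/0; keep NaN so it can be filtered later
    with np.errstate(divide="ignore", invalid="ignore"):
        return (joint - expected) / np.sqrt(variance) + 1e-5

# Toy co-occurrence counts (illustrative only)
toy_counts = np.array([[4.0, 0.0, 1.0],
                       [1.0, 5.0, 0.0],
                       [0.0, 2.0, 7.0]])
scores = association_scores(toy_counts)
threshold = np.percentile(scores[~np.isnan(scores)], 95)
print(scores.round(3))
print(f"95th percentile: {threshold:.4f}")
```

Pairs that co-occur more often than independence would predict get positive scores, and the 95th-percentile cut-off then keeps only the strongest associations, as in the loop version.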

In [None]:
# Dataset class
class Data_Encoder(data.Dataset):
    def __init__(self, list_IDs, labels, df_dti, k):
        self.labels = labels
        self.list_IDs = list_IDs
        self.df = df_dti
        self.k = k

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        index = self.list_IDs[index]
        d = self.df.iloc[index]['Drug_smile']
        s = int(self.df.iloc[index]['SE_id'])

        d_v, input_mask_d = drug2emb_encoder(d)

        # Load pre-computed side effect substructures
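        # Use the fold index passed to the constructor so the files match
        # the ones written by identify_sub(data, k)
        SE_index = np.load(f"{SUB_DIR}/SE_sub_index_50_{self.k}.npy").astype(int)
        SE_mask = np.load(f"{SUB_DIR}/SE_sub_mask_50_{self.k}.npy")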
        s_v = SE_index[s, :]
        input_mask_s = SE_mask[s, :]
        y = self.labels[index]
        
        return d_v, s_v, input_mask_d, input_mask_s, y

print("📦 Dataset class loaded!")
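
`Data_Encoder.__getitem__` reads both `.npy` files from disk for every sample. One possible way to avoid the repeated I/O is to cache the arrays per fold. The sketch below assumes the file naming used by `identify_sub`; `load_fold_arrays` is a hypothetical helper, and the toy arrays are written to a temporary directory purely for demonstration.

```python
import tempfile
from functools import lru_cache
from pathlib import Path

import numpy as np

@lru_cache(maxsize=None)
def load_fold_arrays(sub_dir, fold):
    # Read the index and mask arrays once per (directory, fold) pair
    index = np.load(Path(sub_dir) / f"SE_sub_index_50_{fold}.npy").astype(int)
    mask = np.load(Path(sub_dir) / f"SE_sub_mask_50_{fold}.npy")
    return index, mask

with tempfile.TemporaryDirectory() as tmp:
    # Toy stand-ins for the real (994, 50) arrays
    toy_index = np.array([[3, 7, 0], [5, 0, 0]])
    np.save(Path(tmp) / "SE_sub_index_50_0.npy", toy_index)
    np.save(Path(tmp) / "SE_sub_mask_50_0.npy", (toy_index > 0).astype(int))

    first = load_fold_arrays(tmp, 0)
    second = load_fold_arrays(tmp, 0)   # served from the cache, no disk read
    cache_hit = first[0] is second[0]
    row_ids = first[0][1].tolist()
    row_mask = first[1][1].tolist()

print(cache_hit, row_ids, row_mask)  # True [5, 0, 0] [1, 0, 0]
```

Inside the dataset, `__getitem__` could then call such a helper with `SUB_DIR` and `self.k`. With multiple `DataLoader` workers each process keeps its own cache, so the files are still read once per worker.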

In [None]:
# Training and evaluation functions
def loss_fun(output, label):
    loss = torch.sum((output - label) ** 2)
    return loss

def trainfun(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    avg_loss = []
    pbar = tqdm(train_loader, desc=f"Epoch {epoch} Training")

    for batch_idx, (Drug, SE, DrugMask, SEMask, Label) in enumerate(pbar):
        Drug = Drug.to(device)
        SE = SE.to(device)
        DrugMask = DrugMask.to(device)
        SEMask = SEMask.to(device)
        Label = torch.FloatTensor([int(item) for item in Label]).to(device)

        optimizer.zero_grad()
        out, _, _ = model(Drug, SE, DrugMask, SEMask)

        # Sum of squared errors over the batch; all tensors stay on `device`
        loss = loss_fun(out.flatten(), Label)
        loss.backward()
        optimizer.step()
        avg_loss.append(loss.item())
        
        pbar.set_postfix({'loss': f'{loss.item():.6f}'})

    return sum(avg_loss) / len(avg_loss)

def predict(model, device, test_loader):
    total_preds = torch.Tensor()
    total_labels = torch.Tensor()

    model.eval()
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    with torch.no_grad():
        for batch_idx, (Drug, SE, DrugMask, SEMask, Label) in enumerate(tqdm(test_loader, desc="Predicting")):
            Drug = Drug.to(device)
            SE = SE.to(device)
            DrugMask = DrugMask.to(device)
            SEMask = SEMask.to(device)
            Label = torch.FloatTensor([int(item) for item in Label]).to(device)
            out, _, _ = model(Drug, SE, DrugMask, SEMask)

            location = torch.where(Label != 0)
            pred = out[location]
            label = Label[location]

            total_preds = torch.cat((total_preds, pred.detach().cpu()), 0)
            total_labels = torch.cat((total_labels, label.detach().cpu()), 0)

    return total_labels.numpy().flatten(), total_preds.numpy().flatten()

def evaluate(model, device, test_loader):
    total_preds = torch.Tensor()
    total_label = torch.Tensor()
    singleDrug_auc = []
    singleDrug_aupr = []
    model.eval()

    with torch.no_grad():
        for Drug, SE, DrugMask, SEMask, Label in tqdm(test_loader, desc="Evaluating"):
            Drug = Drug.to(device)
            SE = SE.to(device)
            DrugMask = DrugMask.to(device)
            SEMask = SEMask.to(device)
            # Labels stay on the CPU: they are only needed for the metrics below
            Label = torch.FloatTensor([int(item) for item in Label])
            output, _, _ = model(Drug, SE, DrugMask, SEMask)
            pred = output.detach().cpu().flatten()

            total_preds = torch.cat((total_preds, pred), 0)
            total_label = torch.cat((total_label, Label), 0)

            # Per-batch ranking metrics, computed on raw scores rather than
            # thresholded predictions. A batch with a single class is skipped
            # because AUC is undefined in that case.
            label = (Label.numpy() != 0).astype(int)
            if np.unique(label).size == 2:
                singleDrug_auc.append(roc_auc_score(label, pred.numpy()))
                singleDrug_aupr.append(average_precision_score(label, pred.numpy()))

    drugAUC = float(np.mean(singleDrug_auc)) if singleDrug_auc else float('nan')
    drugAUPR = float(np.mean(singleDrug_aupr)) if singleDrug_aupr else float('nan')

    scores = total_preds.numpy()
    label01 = (total_label.numpy() != 0).astype(int)
    pred_binary = (scores > 0.5).astype(int)

    # scikit-learn metrics expect (y_true, y_pred) in that order
    precision = precision_score(label01, pred_binary, zero_division=0)
    recall = recall_score(label01, pred_binary, zero_division=0)
    accuracy = accuracy_score(label01, pred_binary)

    # Overall ranking metrics on raw scores; undefined if only one class is present
    if np.unique(label01).size == 2:
        auc_all = roc_auc_score(label01, scores)
        aupr_all = average_precision_score(label01, scores)
    else:
        auc_all = aupr_all = float('nan')

    return auc_all, aupr_all, drugAUC, drugAUPR, precision, recall, accuracy

print("🏋️ Training functions loaded!")

In [None]:
# Main training function
def main(training_generator, testing_generator, modeling, lr, num_epoch, weight_decay, 
         log_interval, cuda_name, save_model, k, save_every=5, resume_path=None):
    
    print('\n' + '='*80)
    print(f'🚀 Starting training for fold {k}')
    print(f'Model: {modeling.__name__}')
    print(f'Learning rate: {lr}')
    print(f'Epochs: {num_epoch}')
    print(f'Weight decay: {weight_decay}')
    print('='*80)

    # Device setup
    device = torch.device(cuda_name if torch.cuda.is_available() else 'cpu')
    print(f'🖥️  Using device: {device}')

    # Model initialization
    model = modeling().to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'📊 Total trainable parameters: {total_params:,}')

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Resume from checkpoint if specified
    start_epoch = 0
    if resume_path is not None and os.path.exists(resume_path):
        try:
            ckpt = torch.load(resume_path, map_location=device, weights_only=False)
            if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
                model.load_state_dict(ckpt['model_state_dict'])
                if 'optimizer_state_dict' in ckpt:
                    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
                start_epoch = int(ckpt.get('epoch', 0))
                if 'random_state' in ckpt:
                    # map_location may move this tensor to the GPU, but
                    # torch.set_rng_state needs a CPU ByteTensor
                    torch.set_rng_state(ckpt['random_state'].cpu())
                if 'numpy_random_state' in ckpt:
                    np.random.set_state(ckpt['numpy_random_state'])
                print(f"🔄 Resuming from {resume_path} at epoch {start_epoch}")
            else:
                model.load_state_dict(ckpt)
                print(f"🔄 Loaded model weights from legacy checkpoint {resume_path}")
        except Exception as e:
            print(f"⚠️  Could not resume from {resume_path}: {e}")

    history = []
    train_losses = []

    # Load existing history if resuming
    metrics_file = os.path.join(RESULTS_DIR, 'train_metrics_per_epoch.json')
    if start_epoch > 0 and os.path.exists(metrics_file):
        try:
            with open(metrics_file, 'r', encoding='utf-8') as f:
                history = json.load(f)
        except Exception:
            history = []

    # Training loop
    for epoch in range(start_epoch, num_epoch):
        train_loss = trainfun(model=model, device=device,
                              train_loader=training_generator,
                              optimizer=optimizer, epoch=epoch + 1, 
                              log_interval=log_interval)
        train_losses.append(train_loss)

        # Save checkpoint (skipped entirely when save_model is False)
        if save_model and (((epoch + 1) % save_every == 0) or (epoch == num_epoch - 1)):
            os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
            ckpt_obj = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'random_state': torch.get_rng_state(),
                'numpy_random_state': np.random.get_state(),
            }
            ckpt_path = os.path.join(CHECKPOINTS_DIR, f'{k}_{epoch + 1}.pth')
            torch.save(ckpt_obj, ckpt_path)
            torch.save(ckpt_obj, os.path.join(CHECKPOINTS_DIR, f'latest_{k}.pth'))

        # Evaluate on test set
        model.eval()
        with torch.no_grad():
            y_true, y_pred = predict(model=model, device=device, test_loader=testing_generator)
        ep_mse = mse(y_true, y_pred)
        ep_rmse = rmse(y_true, y_pred)
        ep_scc = spearman(y_true, y_pred)
        
        print(f"Epoch {epoch+1}/{num_epoch} - TrainLoss: {train_loss:.6f} - MSE: {ep_mse:.6f} - RMSE: {ep_rmse:.6f} - SCC: {ep_scc:.6f}")
        
        history.append({
            'epoch': int(epoch+1),
            'train_loss': float(train_loss),
            'MSE': float(ep_mse),
            'RMSE': float(ep_rmse),
            'SCC': float(ep_scc)
        })
        
        # Save history after each epoch (create the directory on first use)
        os.makedirs(RESULTS_DIR, exist_ok=True)
        with open(metrics_file, 'w', encoding='utf-8') as f:
            json.dump(history, f, ensure_ascii=False, indent=2)

    print("\n🔮 Making predictions...")
    test_labels, test_preds = predict(model=model, device=device, test_loader=testing_generator)

    # Save predictions
    os.makedirs(PREDICT_DIR, exist_ok=True)
    np.save(f'{PREDICT_DIR}/total_labels_{k}.npy', test_labels)
    np.save(f'{PREDICT_DIR}/total_preds_{k}.npy', test_preds)

    # Calculate metrics
    test_MSE = mse(test_labels, test_preds)
    test_RMSE = rmse(test_labels, test_preds)
    test_SCC = spearman(test_labels, test_preds)

    print("\n📊 Evaluating performance...")
    auc_all, aupr_all, drugAUC, drugAUPR, precision, recall, accuracy = evaluate(
        model=model, device=device, test_loader=testing_generator)

    print(f'\n🎯 Test Results (Regression): MSE: {test_MSE:.5f}\tRMSE: {test_RMSE:.5f}\tSCC: {test_SCC:.5f}')
    print(f'📈 Classification Metrics: AUC: {auc_all:.5f}\tAUPR: {aupr_all:.5f}\tDrug AUC: {drugAUC:.5f}\tDrug AUPR: {drugAUPR:.5f}')
    print(f'🎯 Precision: {precision:.5f}\tRecall: {recall:.5f}\tAccuracy: {accuracy:.5f}')

    # Save final metrics
    os.makedirs(RESULTS_DIR, exist_ok=True)
    metrics_path = os.path.join(RESULTS_DIR, f'metrics_fold_{k}.json')
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump({
            'fold': k,
            'MSE': float(test_MSE),
            'RMSE': float(test_RMSE),
            'SCC': float(test_SCC),
            'AUC_all': float(auc_all),
            'AUPR_all': float(aupr_all),
            'AUC_drug': float(drugAUC),
            'AUPR_drug': float(drugAUPR),
            'Precision': float(precision),
            'Recall': float(recall),
            'Accuracy': float(accuracy)
        }, f, ensure_ascii=False, indent=2)
    
    print(f'✅ Metrics saved to {metrics_path}')
    
    return {
        'MSE': test_MSE, 'RMSE': test_RMSE, 'SCC': test_SCC,
        'AUC_all': auc_all, 'AUPR_all': aupr_all, 
        'AUC_drug': drugAUC, 'AUPR_drug': drugAUPR,
        'Precision': precision, 'Recall': recall, 'Accuracy': accuracy
    }

print("🎯 Main training function loaded!")

## 🔄 5. Cross-Validation Training

Now let's set up and run the 10-fold cross-validation training.
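
As a rough illustration of what the stratified split does, the sketch below runs `StratifiedKFold` on toy frequency labels. The arrays `toy_labels` and `toy_pairs` are made up for this example and only stand in for the notebook's `data_y` and `data_x`; the splitter settings mirror the ones used in the training cell.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy frequency labels (classes 1-5, 20 samples each); illustrative only.
toy_labels = np.array([1, 2, 3, 4, 5] * 20)
toy_pairs = np.arange(len(toy_labels)).reshape(-1, 1)  # placeholder features

kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
fold_sizes = []
seen_test = []
for train_idx, test_idx in kfold.split(toy_pairs, toy_labels):
    # Count how many samples of each class land in the held-out fold.
    class_counts = np.bincount(toy_labels[test_idx], minlength=6)[1:]
    fold_sizes.append(len(test_idx))
    seen_test.extend(test_idx.tolist())
    print(len(train_idx), len(test_idx), class_counts.tolist())
```

Each sample appears in exactly one held-out fold, and the class proportions of the full set are preserved in every fold as closely as the counts allow.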

In [None]:
# Training configuration
class Config:
    # Model parameters
    model_type = 0  # Trans model
    lr = 1e-4
    weight_decay = 0.01
    num_epoch = 100  # Reduced for Colab demo
    log_interval = 40
    cuda_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    save_model = True
    batch_size = 128
    save_every = 10
    resume_path = None
    
    # Data files
    raw_file = f'{DATA_DIR}/raw_frequency_750.mat'
    SMILES_file = f'{DATA_DIR}/drug_SMILES_750.csv'
    mask_mat_file = f'{DATA_DIR}/mask_mat_750.mat'
    side_effect_label = f'{DATA_DIR}/side_effect_label_750.mat'
    drug_side_file = f'{DATA_DIR}/drug_side.pkl'

config = Config()

print("⚙️ Configuration:")
print("  Model: Trans")
print(f"  Learning rate: {config.lr}")
print(f"  Epochs: {config.num_epoch}")
print(f"  Batch size: {config.batch_size}")
print(f"  Device: {config.cuda_name}")
print(f"  Weight decay: {config.weight_decay}")

In [None]:
# Load data for cross-validation
print("📂 Loading data...")

# Load drug-side effect interactions
if not os.path.exists(config.drug_side_file):
    # Stop here: every later cell depends on this file
    raise FileNotFoundError(
        f"File not found: {config.drug_side_file}. "
        "Please ensure all data files are uploaded to Google Drive")
with open(config.drug_side_file, 'rb') as f:
    drug_side = pickle.load(f)
print(f"✅ Loaded drug-side interactions: {drug_side.shape}")

# Load drug SMILES
if not os.path.exists(config.SMILES_file):
    raise FileNotFoundError(f"File not found: {config.SMILES_file}")
drug_dict, drug_smile = load_drug_smile(config.SMILES_file)
print(f"✅ Loaded {len(drug_smile)} drug SMILES")

# Extract positive and negative samples
print("\n🔍 Extracting positive and negative samples...")
addition_negative_sample, final_positive_sample, final_negative_sample = Extract_positive_negative_samples(
    drug_side, addition_negative_number='all')

addition_negative_sample = np.vstack((addition_negative_sample, final_negative_sample))
final_sample = final_positive_sample
X = final_sample[:, 0::]
final_target = final_sample[:, final_sample.shape[1] - 1]
y = final_target

print(f"✅ Extracted {len(final_positive_sample)} positive samples")
print(f"✅ Extracted {len(addition_negative_sample)} negative samples")

# Prepare data for cross-validation
data = []
data_x = []
data_y = []

for i in range(X.shape[0]):
    data_x.append((X[i, 1], X[i, 0]))
    data_y.append((int(float(X[i, 2]))))
    data.append((X[i, 1], drug_smile[X[i, 0]], X[i, 2]))

print(f"✅ Prepared {len(data)} samples for cross-validation")

In [None]:
# Precompute substructures (run once)
print("🧩 Computing substructure indices...")
identify_sub(data, 0)

# Copy the computed indices for all folds
import shutil
for fold in range(1, 10):
    shutil.copy(f'{SUB_DIR}/SE_sub_index_50_0.npy', f'{SUB_DIR}/SE_sub_index_50_{fold}.npy')
    shutil.copy(f'{SUB_DIR}/SE_sub_mask_50_0.npy', f'{SUB_DIR}/SE_sub_mask_50_{fold}.npy')

# Also create the 32 index files that the code expects
shutil.copy(f'{SUB_DIR}/SE_sub_index_50_0.npy', f'{SUB_DIR}/SE_sub_index_50_32.npy')
shutil.copy(f'{SUB_DIR}/SE_sub_mask_50_0.npy', f'{SUB_DIR}/SE_sub_mask_50_32.npy')

print("✅ Substructure indices prepared for all folds")

In [None]:
# Run cross-validation training
fold_results = []
modeling = Trans

# Set up cross-validation
kfold = StratifiedKFold(n_splits=10, random_state=1, shuffle=True)
params = {
    'batch_size': config.batch_size,
    'shuffle': True
}

print("\n🚀 Starting 10-fold cross-validation training...")
print("="*80)

for fold, (train_idx, test_idx) in enumerate(kfold.split(data_x, data_y)):
    print(f"\n📁 Fold {fold + 1}/10")
    print("-" * 40)
    
    # Split data
    data_train = np.array(data)[train_idx]
    data_test = np.array(data)[test_idx]

    # Create DataFrames
    df_train = pd.DataFrame(data=data_train.tolist(), columns=['SE_id', 'Drug_smile', 'Label'])
    df_test = pd.DataFrame(data=data_test.tolist(), columns=['SE_id', 'Drug_smile', 'Label'])

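    print(f"📊 Train samples: {len(df_train)}, Test samples: {len(df_test)}")
    # Labels may be strings after the np.array(data) conversion, so convert
    # them to numbers before counting the non-zero (positive) entries
    train_pos = int((pd.to_numeric(df_train['Label']) > 0).sum())
    test_pos = int((pd.to_numeric(df_test['Label']) > 0).sum())
    print(f"📈 Train positives: {train_pos}, Test positives: {test_pos}")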

    # Create datasets and dataloaders
    training_set = Data_Encoder(df_train.index.values, df_train.Label.values, df_train, fold)
    testing_set = Data_Encoder(df_test.index.values, df_test.Label.values, df_test, fold)

    training_generator = torch.utils.data.DataLoader(training_set, **params)
    testing_generator = torch.utils.data.DataLoader(testing_set, **params)

    # Train model
    try:
        result = main(
            training_generator=training_generator,
            testing_generator=testing_generator,
            modeling=modeling,
            lr=config.lr,
            num_epoch=config.num_epoch,
            weight_decay=config.weight_decay,
            log_interval=config.log_interval,
            cuda_name=config.cuda_name,
            save_model=config.save_model,
            k=fold,
            save_every=config.save_every,
            resume_path=config.resume_path
        )
        
        result['fold'] = fold
        fold_results.append(result)
        print(f"✅ Fold {fold + 1} completed successfully!")
        
    except Exception as e:
        print(f"❌ Fold {fold + 1} failed: {e}")
        import traceback
        traceback.print_exc()
        continue

print("\n🎉 Cross-validation training completed!")

## 📊 6. Results Analysis

Let's analyze and visualize the results from all folds.
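
If the kernel restarts after training, `fold_results` is lost, but the per-fold `metrics_fold_{k}.json` files written by `main` remain on disk. The sketch below shows one way to reload and summarize them with the standard library. The helper names `load_fold_metrics` and `summarize` are hypothetical, and the demo values are made up purely to exercise the code.

```python
import json
import statistics
import tempfile
from pathlib import Path

def load_fold_metrics(results_dir):
    """Read every metrics_fold_<k>.json file in results_dir, ordered by fold."""
    records = []
    for path in Path(results_dir).glob("metrics_fold_*.json"):
        with open(path, encoding="utf-8") as f:
            records.append(json.load(f))
    return sorted(records, key=lambda r: r["fold"])

def summarize(records, metric):
    """Mean and population standard deviation of one metric across folds."""
    values = [r[metric] for r in records]
    return statistics.fmean(values), statistics.pstdev(values)

# Demo with two made-up folds written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for fold, rmse in enumerate([0.5, 0.7]):
        with open(Path(tmp) / f"metrics_fold_{fold}.json", "w", encoding="utf-8") as f:
            json.dump({"fold": fold, "RMSE": rmse}, f)
    demo_records = load_fold_metrics(tmp)
    demo_mean, demo_std = summarize(demo_records, "RMSE")
```

`statistics.pstdev` is used because it matches the default behaviour of `np.std` (population standard deviation), which is what the aggregation cell reports.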

In [None]:
# Aggregate results across folds
if fold_results:
    print("📊 Aggregating results across all folds...")
    
    # Calculate mean and std for each metric
    metrics = ['MSE', 'RMSE', 'SCC', 'AUC_all', 'AUPR_all', 'AUC_drug', 'AUPR_drug', 'Precision', 'Recall', 'Accuracy']
    summary = {}
    
    for metric in metrics:
        # Cast to built-in floats so the summary is JSON-serializable
        values = [float(result[metric]) for result in fold_results if metric in result]
        if values:
            summary[metric] = {
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'values': values
            }
    
    # Create summary table
    print("\n🎯 Final Results Summary:")
    print("=" * 60)
    # Use `stat`, not `stats`, so the scipy.stats import is not shadowed
    for metric, stat in summary.items():
        print(f"{metric:15s}: {stat['mean']:.4f} ± {stat['std']:.4f}")
    
    # Save summary
    summary_path = os.path.join(RESULTS_DIR, 'cv_summary.json')
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"\n💾 Summary saved to {summary_path}")
    
else:
    print("❌ No results to analyze. Please run the training first.")

In [None]:
# Visualize results
if fold_results:
    # Create visualization plots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('HSTrans Cross-Validation Results', fontsize=16, fontweight='bold')
    
    # Regression metrics
    reg_metrics = ['MSE', 'RMSE', 'SCC']
    for i, metric in enumerate(reg_metrics):
        values = [result[metric] for result in fold_results if metric in result]
        axes[0, i].boxplot(values, labels=[metric])
        axes[0, i].set_title(f'{metric} Across Folds')
        axes[0, i].set_ylabel(metric)
        axes[0, i].grid(True, alpha=0.3)
    
    # Classification metrics
    cls_metrics = ['AUC_all', 'AUPR_all', 'Accuracy']
    for i, metric in enumerate(cls_metrics):
        values = [result[metric] for result in fold_results if metric in result]
        axes[1, i].boxplot(values, labels=[metric])
        axes[1, i].set_title(f'{metric} Across Folds')
        axes[1, i].set_ylabel(metric)
        axes[1, i].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, 'cv_results_boxplot.png'), dpi=300, bbox_inches='tight')
    plt.show()
    
    # Training curves (if available)
    metrics_file = os.path.join(RESULTS_DIR, 'train_metrics_per_epoch.json')
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r', encoding='utf-8') as f:
            history = json.load(f)
        
        if history:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            
            epochs = [h['epoch'] for h in history]
            train_losses = [h['train_loss'] for h in history]
            mse_values = [h['MSE'] for h in history]
            scc_values = [h['SCC'] for h in history]
            
            ax1.plot(epochs, train_losses, 'b-', label='Train Loss')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            
            ax2.plot(epochs, mse_values, 'r-', label='MSE')
            ax2.plot(epochs, scc_values, 'g-', label='SCC')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Metric Value')
            ax2.set_title('Validation Metrics')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.savefig(os.path.join(RESULTS_DIR, 'training_curves.png'), dpi=300, bbox_inches='tight')
            plt.show()
    
else:
    print("❌ No results to visualize. Please run the training first.")

In [None]:
# Archive the results and copy them to Google Drive
import shutil

# Create results archive
archive_path = '/content/HSTrans_Results.zip'
shutil.make_archive('/content/HSTrans_Results', 'zip', WORK_DIR)

# Copy to Google Drive
drive_results_path = '/content/drive/MyDrive/HSTrans_Results.zip'
shutil.copy(archive_path, drive_results_path)

print(f"✅ Results archived to: {archive_path}")
print(f"✅ Results copied to Google Drive: {drive_results_path}")
print("\n📁 You can now download the results from your Google Drive!")

# List all saved files
print("\n📋 Saved files:")
for root, dirs, files in os.walk(RESULTS_DIR):
    for file in files:
        if file.endswith(('.json', '.png', '.npy')):
            print(f"  📄 {os.path.join(root, file)}")

## 🎉 Training Complete!

### 📋 What We Accomplished

1. **✅ Setup & Dependencies**: Installed all required packages and set up environment
2. **✅ Data Preparation**: Loaded and processed drug-side effect interaction data
3. **✅ Model Architecture**: Implemented HSTrans Transformer model with cross-attention
4. **✅ Training Pipeline**: Created comprehensive training functions with checkpointing
5. **✅ Cross-Validation**: Ran 10-fold cross-validation training
6. **✅ Results Analysis**: Visualized and aggregated performance metrics

### 📊 Key Features of This Colab Notebook

- **🖥️ GPU Support**: Automatic GPU detection and utilization
- **💾 Checkpointing**: Save/resume training progress
- **📈 Progress Tracking**: Real-time training progress with tqdm
- **🔄 Cross-Validation**: Robust 10-fold CV for reliable evaluation
- **📊 Visualization**: Comprehensive result visualization
- **☁️ Cloud Storage**: Automatic backup to Google Drive

### 🎯 Model Design

The HSTrans model uses:
- **Cross-attention mechanism** for drug-side effect interaction
- **Substructure encoding** for molecular representation
- **Residual fusion** with gated mechanisms
- **Multi-layer Transformer** architecture
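
To make the cross-attention idea concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention with a padding mask, in which drug substructure embeddings attend over side-effect substructure slots. It is not the HSTrans implementation: the function `masked_cross_attention`, the embedding size, and the random inputs are all assumptions made for illustration. The mask plays the same role as the `SE_sub_mask` arrays, where padded slots are marked with 0.

```python
import numpy as np

def masked_cross_attention(queries, keys, values, key_mask):
    """Single-head scaled dot-product attention with a key padding mask.

    queries: (n_q, d); keys and values: (n_k, d); key_mask: (n_k,), 1 = real token.
    """
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)                   # (n_q, n_k)
    scores = np.where(key_mask[None, :] > 0, scores, -1e9)   # hide padded keys
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ values, weights

rng = np.random.default_rng(0)
drug_tokens = rng.normal(size=(4, 8))   # 4 drug substructure embeddings (toy)
se_tokens = rng.normal(size=(6, 8))     # 6 side-effect substructure slots (toy)
se_mask = np.array([1, 1, 1, 0, 0, 0])  # last three slots are padding

context, attn = masked_cross_attention(drug_tokens, se_tokens, se_tokens, se_mask)
print(context.shape, attn.shape)
```

Each row of `attn` sums to 1 and assigns zero weight to the padded slots, so padding never contributes to the context vectors.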

### 🔧 How to Use This Notebook

1. **Upload your data files** to the data directory or update the `download_data()` function
2. **Configure training parameters** in the `Config` class
3. **Run all cells sequentially** to complete training
4. **Check results** in the generated plots and saved files
5. **Download results** from Google Drive

### 📞 Need Help?

- Check **GPU availability** at the beginning
- Ensure all **data files are uploaded** correctly
- Monitor **memory usage** during training
- Use **checkpoint resuming** if training gets interrupted
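
For the checkpoint-resuming tip, the training loop writes a `latest_{k}.pth` file for each fold. A small helper along the following lines could locate it. `find_resume_path` is a hypothetical name that does not exist in the notebook, and the demo uses an empty stand-in file rather than a real checkpoint.

```python
import os
import tempfile

def find_resume_path(checkpoints_dir, fold):
    """Return the path of latest_<fold>.pth if it exists, else None."""
    candidate = os.path.join(checkpoints_dir, f"latest_{fold}.pth")
    return candidate if os.path.exists(candidate) else None

# Demo on a temporary directory with an empty stand-in checkpoint for fold 3.
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "latest_3.pth"), "wb").close()
    found = find_resume_path(tmp, 3)
    missing = find_resume_path(tmp, 4)
    found_name = os.path.basename(found)
```

In the notebook this would be used as `config.resume_path = find_resume_path(CHECKPOINTS_DIR, fold)`. Note that the cross-validation loop passes the same `config.resume_path` to every fold, so it should only be set for the fold that was interrupted.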

**Happy Training! 🚀**