# Molecular Toxicity Prediction (Tox21) using GIN

In this notebook, we implement and train a Graph Isomorphism Network (GIN) to predict molecular toxicity on the Tox21 dataset. In theory, GIN is more expressive than GCN: its sum-based neighbourhood aggregation can distinguish graph structures that mean-style aggregation cannot.

## 1. Environment Setup in Colab

Run the following code to install PyTorch Geometric and other dependencies:

In [None]:
import os
import shutil

# This cell is intended to be run only in Google Colab
if 'COLAB_GPU' in os.environ:
    print("Running in Colab...")

    # 1. Remove existing folder if it already exists
    repo_name = 'gnn-molecule-prediction'
    repo_path = os.path.join('/content', repo_name)

    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)
        print(f"Removed: {repo_path}")
    else:
        print(f"No existing directory found.")

    # 2. Clone the GitHub repository
    %cd /content
    !git clone https://github.com/sth-s/gnn-molecule-prediction.git

    # 3. Change working directory to the project root
    %cd gnn-molecule-prediction

    # 4. Install dependencies via pip
    !pip install torch torchvision torchaudio
    !pip install torch-geometric scikit-learn matplotlib seaborn deepchem

    print("Dependencies installed.")

else:
    os.chdir('../')
    print(f"Changed working directory to: {os.getcwd()}")


## 2. Import Libraries

In [None]:
# PyTorch and PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import GINConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import OneCycleLR

# Chemistry and data processing
import deepchem as dc
from deepchem.feat.graph_data import GraphData
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluation and splitting
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve
from sklearn.model_selection import train_test_split

# Our custom data loader
from src.data_utils import load_tox21
# Import our custom model and training functions
from src.model import GIN
from src.train import train_epoch, evaluate

# Plot styling: apply the whitegrid theme first, then switch the colour
# cycle to the 'muted' palette (kept as two calls on purpose).
sns.set_theme(style="whitegrid")
sns.set_palette('muted')

# Select the compute device: GPU when CUDA is usable, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


## 3. Loading and Preparing Tox21 Data

### 3.1 About the Data Loading Module

We use the custom `load_tox21` function from the `src.data_utils` module, which provides the following features:

- **Automatic data downloading**: if the `tox21.csv` file is missing, the function automatically downloads it from the official source
- **SMILES to graphs conversion**: each molecule is converted into a graph with node and edge attributes
- **Result caching**: results are saved to a cache to speed up subsequent runs
- **Flexible configuration**: you can specify the data path, file name, target columns, and other parameters

Detailed information about the `load_tox21` function and usage examples are available in the project's README.

In [None]:
# Build the Tox21 graph dataset with the project loader, then print a few
# sanity checks about the returned object and its first element.
tox21_options = dict(
    root="data/Tox21",       # root directory for the data
    filename="tox21.csv",    # name of the CSV file
    smiles_col="smiles",     # column containing SMILES strings
    mol_id_col="mol_id",     # column containing molecule IDs
    cache_file="data.pt",    # name of the cache file
    recreate=False,          # reuse an existing processed file if available
    auto_download=True,      # download the CSV automatically if it is missing
)
dataset = load_tox21(**tox21_options)

print(f"Type of dataset object: {type(dataset)}")
try:
    slices_report = f"dataset._slices: {dataset._slices}"
except AttributeError:
    slices_report = "dataset has no attribute _slices"
print(slices_report)
print(f"Total graphs (len(dataset)): {len(dataset)}")

# Probe the first graph; any failure is reported instead of aborting the cell
if len(dataset) == 0:
    print("Dataset is empty (len(dataset) is 0).")
else:
    print("Attempting to access dataset[0]...")
    first_item = None
    try:
        first_item = dataset[0]
        print(f"dataset[0] type: {type(first_item)}")
        if first_item is None:
            print("dataset[0] is None.")
        else:
            print(f"Number of node features: {first_item.x.shape[1]}")
            if not hasattr(first_item, 'mol_id'):
                print("Mol ID attribute not found in first_item.")
            else:
                print(f"Mol ID for the first molecule: {first_item.mol_id}")

            # Report labels and label mask when the graph carries them
            for attr_name, caption in (
                ('y', 'Task labels (y)'),
                ('y_mask', 'Task mask (y_mask)'),
            ):
                if hasattr(first_item, attr_name):
                    attr_value = getattr(first_item, attr_name)
                    print(f"{caption} shape: {attr_value.shape}")
                    print(f"{caption}: {attr_value}")
    except Exception as err:
        print(f"Error accessing dataset[0] or its attributes: {err}")


### 3.2 Verify `mol_id` and Dataset Properties

In [None]:
# Second sanity check: print the first graph and its key attributes
print("\nRunning second verification cell:")
if len(dataset) == 0:
    print("Dataset is empty in second verification cell.")
else:
    # Fetch the first graph; on failure report it and fall back to None
    try:
        data_example = dataset[0]
    except Exception as err:
        data_example = None
        print(f"Error getting dataset[0] in second cell: {err}")

    if data_example is None:
        print("data_example (dataset[0]) is None, cannot print details.")
    else:
        print("Example graph from the dataset (after potential reprocessing):")
        print(data_example)

        if not hasattr(data_example, 'mol_id'):
            print("\nmol_id attribute not found in the first molecule (data_example).")
        else:
            print(f"\nMol ID for the first molecule: {data_example.mol_id}")

        print(f"Edge index dimensions: {data_example.edge_index.shape}")
        print(f"Number of atoms (nodes): {data_example.num_nodes}")
        print(f"Task labels (y): {data_example.y}")
        print(f"Task mask (y_mask): {data_example.y_mask}")

In [None]:
# Show one example graph, then summarise graph sizes over the first
# (at most) 1000 graphs of the dataset.
if len(dataset) > 0:
    data_example = dataset[0]
    print("Example graph from the dataset:")
    print(data_example)
    print(f"Edge index dimensions: {data_example.edge_index.shape}")
    print(f"Number of atoms (nodes): {data_example.num_nodes}")
    print(f"Task labels (y): {data_example.y}")

    # Collect node and edge counts for the sampled graphs
    sample_size = min(1000, len(dataset))
    sampled_graphs = [dataset[pos] for pos in range(sample_size)]
    nodes_count = [graph.num_nodes for graph in sampled_graphs]
    edges_count = [graph.edge_index.shape[1] for graph in sampled_graphs]

    # Edge counts are halved when reported as bonds -- presumably because
    # edge_index stores every bond in both directions (TODO confirm).
    atoms_mean, atoms_std = np.mean(nodes_count), np.std(nodes_count)
    bonds_mean, bonds_std = np.mean(edges_count) / 2, np.std(edges_count) / 2
    atoms_min, atoms_max = np.min(nodes_count), np.max(nodes_count)
    bonds_min, bonds_max = np.min(edges_count) / 2, np.max(edges_count) / 2

    print(f"\nGraph statistics (based on a sample of {len(nodes_count)} molecules):")
    print(f"Average number of atoms: {atoms_mean:.2f} ± {atoms_std:.2f}")
    print(f"Average number of bonds: {bonds_mean:.2f} ± {bonds_std:.2f}")
    print(f"Min/max atoms: {atoms_min}/{atoms_max}")
    print(f"Min/max bonds: {bonds_min:.0f}/{bonds_max:.0f}")

In [None]:
# Histograms of atom and bond counts; skipped when the statistics cell
# has not produced `nodes_count`.
if 'nodes_count' in locals():
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Each edge count is halved to express it as a number of bonds
    bond_counts = [directed_edges / 2 for directed_edges in edges_count]
    panels = (
        (nodes_count, 'skyblue', 'Distribution of Atom Counts', 'Number of atoms'),
        (bond_counts, 'salmon', 'Distribution of Bond Counts', 'Number of bonds'),
    )
    for ax, (values, colour, title, x_label) in zip(axes, panels):
        ax.hist(values, bins=30, alpha=0.7, color=colour)
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel('Number of molecules')

    plt.tight_layout()
    plt.show()

In [None]:
# Random 80/20 split of the dataset into training and test subsets.
# Seeding the global torch RNG makes the split reproducible.
torch.manual_seed(42)
n_total = len(dataset)
train_len = int(0.8 * n_total)
split_sizes = [train_len, n_total - train_len]
train_dataset, test_dataset = torch.utils.data.random_split(dataset, split_sizes)

# Mini-batch loaders: only the training data is shuffled
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Total graphs: {len(dataset)}")
print(f"Training set: {len(train_dataset)} graphs")
print(f"Test set: {len(test_dataset)} graphs")
print(f"Batch size: {train_loader.batch_size}, number of batches in training set: {len(train_loader)}")

## 4. Define GIN Model

In [None]:
# Instantiate the GIN model from src/model.py, sized from the dataset
num_node_features = dataset.num_node_features
num_outputs = dataset.num_classes
model = GIN(
    in_channels=num_node_features,
    hidden_channels=64,
    num_classes=num_outputs,
    num_layers=2,
    dropout=0.5,
)
model = model.to(device)

# Adam optimizer with the base learning rate
base_lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

# One-cycle learning-rate schedule.
# The schedule spans num_epochs * len(train_loader) steps in total.
num_epochs = 20
total_steps = num_epochs * len(train_loader)
max_lr = base_lr * 10  # peak learning rate (10x the base rate)
scheduler = OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=total_steps,
    pct_start=0.3,            # first 30% of the steps ramp the lr up
    anneal_strategy='cos',    # cosine annealing
    div_factor=10.0,          # initial lr = max_lr / div_factor
    final_div_factor=1000.0,  # minimum lr = initial lr / final_div_factor
)

print(f"Model created with {num_node_features} input features and {num_outputs} output classes")
print(f"Optimizer: {optimizer.__class__.__name__} with base lr={base_lr}")
print(f"Scheduler: OneCycleLR with max_lr={max_lr}")
print(model)

## 5. Training and Evaluation

In [None]:
# Training loop with per-epoch evaluation and best-checkpoint tracking,
# using train_epoch / evaluate from src/train.py.
#
# NOTE(review): the "validation" metrics that drive model selection are
# computed on `test_loader`, and the final metrics are computed on that same
# loader, so the reported test scores are optimistically biased. A separate
# validation split would be needed for an unbiased estimate.
import copy

# NOTE(review): the scheduler is presumably stepped inside train_epoch; if so,
# this value must not exceed the epoch count used for its total_steps -- confirm.
num_epochs = 20
best_val_roc = 0
best_model_state = None
history = {
    'train_loss': [],
    'train_roc_auc': [],
    'train_pr_auc': [],
    'val_loss': [],
    'val_roc': [],
    'val_pr': [],
    'lr': []
}

print("Starting training with OneCycleLR scheduler...")
for epoch in range(1, num_epochs + 1):
    # Train for one epoch; the scheduler is handed to train_epoch
    train_results, lr_history = train_epoch(model, train_loader, optimizer, device, scheduler)
    train_loss = train_results['loss']
    train_roc = train_results['roc_auc']
    train_pr = train_results['pr_auc']

    # Evaluate the model on the held-out loader
    val_metrics = evaluate(model, test_loader, device)
    val_loss = val_metrics['loss']
    val_roc = val_metrics['roc_auc']
    val_pr = val_metrics['pr_auc']

    # Save metrics history (one entry per epoch)
    history['train_loss'].append(train_loss)
    history['train_roc_auc'].append(train_roc)
    history['train_pr_auc'].append(train_pr)
    history['val_loss'].append(val_loss)
    history['val_roc'].append(val_roc)
    history['val_pr'].append(val_pr)
    # lr_history may be empty or None; only extend when it has entries
    if lr_history:
        history['lr'].extend(lr_history)

    # Print progress with the current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch:02d}:")
    print(f"  Train: Loss = {train_loss:.4f}, ROC AUC = {train_roc:.4f}, PR AUC = {train_pr:.4f}")
    print(f"  Valid: Loss = {val_loss:.4f}, ROC AUC = {val_roc:.4f}, PR AUC = {val_pr:.4f}, LR = {current_lr:.6f}")

    # Snapshot the best model so far. A deep copy is required: state_dict()
    # returns references to the live parameter tensors, so a shallow
    # dict.copy() would keep changing as training continues and the
    # "best" state would end up equal to the last-epoch weights.
    if val_roc > best_val_roc:
        best_val_roc = val_roc
        best_model_state = copy.deepcopy(model.state_dict())
        print(f"  → New best model saved! ROC AUC: {best_val_roc:.4f}")

# Restore the best snapshot for the final evaluation
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Loaded best model with validation ROC AUC: {best_val_roc:.4f}")

# Perform final evaluation
final_metrics = evaluate(model, test_loader, device)
print("\nFinal Evaluation Results:")
print(f"Test ROC AUC: {final_metrics['roc_auc']:.4f}")
print(f"Test PR AUC: {final_metrics['pr_auc']:.4f}")
print(f"Test Loss: {final_metrics['loss']:.4f}")

# Persist the best snapshot; skip when no epoch ever improved on the initial
# best_val_roc (otherwise a file containing None would be written)
if best_model_state is not None:
    torch.save(best_model_state, 'best_gin_model.pt')
else:
    print("No best model state was recorded; skipping checkpoint save.")

## 6. Results Visualization

In [None]:
# Learning curves: training loss (left) and validation AUC metrics (right)
fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: training loss per epoch
train_loss_curve = history['train_loss']
loss_ax.plot(range(1, len(train_loss_curve) + 1), train_loss_curve, 'o-', label='Training Loss')

# Right panel: validation ROC AUC and PR AUC per epoch
for history_key, curve_label in (('val_roc', 'ROC AUC'), ('val_pr', 'PR AUC')):
    curve = history[history_key]
    metric_ax.plot(range(1, len(curve) + 1), curve, 'o-', label=curve_label)

# Shared decoration for both panels
for ax, title, y_label in (
    (loss_ax, 'Training Loss per Epoch', 'Loss'),
    (metric_ax, 'Validation Metrics per Epoch', 'AUC'),
):
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.grid(True)
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Per-task evaluation on the test loader: collects masked predictions and
# losses for every task, then computes ROC AUC, PR AUC and mean loss per task.
model.eval()

# Per-task lists of label / probability arrays, one array per batch
all_true = {}
all_pred = {}
all_batch_losses = []
all_batch_masks = []

# Binary cross entropy with reduction='none' keeps one loss value per
# (sample, task) entry so that it can be masked and averaged per task
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        out = model(batch)  # Raw logits (before sigmoid)
        out_probs = torch.sigmoid(out).detach().cpu()
        mask = batch.y_mask.cpu()
        y = batch.y.cpu()

        # Per-entry loss, zeroed where the label is masked out
        batch_loss = loss_fn(out, batch.y)
        masked_loss = batch_loss * batch.y_mask
        all_batch_losses.append(masked_loss.cpu())
        all_batch_masks.append(mask)

        for task in range(out_probs.size(1)):
            # Select only valid entries. The mask column is cast to bool so
            # it is always used as a boolean selector: an integer 0/1 mask
            # would otherwise be interpreted as row indices.
            # NOTE(review): y_mask is presumably a 0/1 validity mask -- confirm.
            idx = mask[:, task].bool()
            if idx.sum() == 0:
                continue

            true = y[idx, task].numpy()
            pred = out_probs[idx, task].numpy()

            all_true.setdefault(task, []).append(true)
            all_pred.setdefault(task, []).append(pred)

# Mean test loss per task. The denominator is the number of valid labels
# taken from the mask itself, not the number of non-zero loss values.
all_losses = torch.cat(all_batch_losses, dim=0)  # Combine all batch losses
all_masks = torch.cat(all_batch_masks, dim=0)
valid_samples_per_task = all_masks.bool().sum(dim=0).float()
# A task without any valid label yields 0/0 = NaN here
task_losses = all_losses.sum(dim=0) / valid_samples_per_task

# Calculate AUC for each task
task_metrics = {}
for task in all_true:
    t = np.concatenate(all_true[task])
    p = np.concatenate(all_pred[task])

    # ROC AUC is undefined when only one class is present, so skip the task
    if len(np.unique(t)) < 2:
        continue

    # Calculate ROC AUC
    roc_auc = roc_auc_score(t, p)

    # Calculate PR AUC as the area under the precision-recall curve
    prec, rec, _ = precision_recall_curve(t, p)
    pr_auc = auc(rec, prec)

    # Store task metrics and raw predictions for plotting
    task_metrics[task] = {
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'true': t,
        'pred': p,
        'loss': task_losses[task].item() if task < len(task_losses) else float('nan')
    }

In [None]:
# 2x2 diagnostics figure: train-vs-validation loss, ROC AUC and PR AUC per
# epoch, plus the learning rate recorded per optimization step.
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# (axis, train key, validation key, metric name, validation colour, y-limits)
comparison_panels = (
    (axes[0, 0], 'train_loss', 'val_loss', 'Loss', 'red', None),
    # ROC AUC axis runs from 0.5 (random) to 1.0 (perfect)
    (axes[0, 1], 'train_roc_auc', 'val_roc', 'ROC AUC', 'green', [0.5, 1.0]),
    # PR AUC ranges from 0 to 1
    (axes[1, 0], 'train_pr_auc', 'val_pr', 'PR AUC', 'purple', [0.0, 1.0]),
)
for ax, train_key, val_key, metric_name, val_colour, y_limits in comparison_panels:
    train_curve = history[train_key]
    val_curve = history[val_key]
    ax.plot(range(1, len(train_curve) + 1), train_curve, 'o-',
            label=f'Training {metric_name}', color='blue')
    ax.plot(range(1, len(val_curve) + 1), val_curve, 'o-',
            label=f'Validation {metric_name}', color=val_colour)
    ax.set_title(f'Train vs Validation {metric_name}', fontsize=14)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(metric_name, fontsize=12)
    ax.grid(True, alpha=0.3)
    if y_limits is not None:
        ax.set_ylim(y_limits)
    ax.legend(fontsize=11)

# Learning rate schedule, drawn on a log scale for readability
lr_ax = axes[1, 1]
lr_curve = history['lr']
lr_ax.plot(range(1, len(lr_curve) + 1), lr_curve, '-', color='orange')
lr_ax.set_title('Learning Rate Schedule', fontsize=14)
lr_ax.set_xlabel('Optimization Step', fontsize=12)
lr_ax.set_ylabel('Learning Rate', fontsize=12)
lr_ax.grid(True, alpha=0.3)
lr_ax.set_yscale('log')

plt.tight_layout()
plt.show()

## 7. Conclusion and Findings

*TODO: Add an analysis of the results and conclusions.*