# Molecular Toxicity Prediction (Tox21) using GCN

In this notebook, we implement and train a Graph Convolutional Network (GCN) to predict molecular toxicity on the Tox21 dataset.

## 1. Environment Setup in Colab

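When running in Google Colab, the following cell clones the project repository and installs its dependencies (including PyTorch Geometric) from `environment.yml` via `condacolab`. Installing `condacolab` restarts the runtime, so run the cell a second time after the restart to finish the environment update. Outside Colab, the cell only prints a message.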

In [None]:
import os
import shutil

# This cell is intended to be run only in Google Colab.
# The presence of the COLAB_GPU environment variable is used as the Colab check.
#
# Execution note: condacolab.install() restarts the runtime (see step 4), so the
# statements after it -- including `conda env update` -- are not executed in
# that same run. Re-run this cell after the restart to reach the environment
# update. (Assumes install() skips the restart once conda is already present;
# NOTE(review): confirm against the condacolab documentation.)
if 'COLAB_GPU' in os.environ:
    print("Running in Colab...")

    # 1. Remove existing folder if it already exists, so the clone below always
    #    starts from a fresh copy (this also runs again on the post-restart pass).
    #    repo_name must match the directory name produced by the clone URL below.
    repo_name = 'gnn-molecule-prediction'
    repo_path = os.path.join('/content', repo_name)

    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)
        print(f"Removed existing directory: {repo_path}")
    else:
        print(f"No existing directory found: {repo_path}")

    # 2. Clone the GitHub repository into /content
    %cd /content
    !git clone https://github.com/sth-s/gnn-molecule-prediction.git

    # 3. Change working directory to the project root (relative to /content),
    #    so later imports such as `src.data_utils` and environment.yml resolve
    %cd gnn-molecule-prediction

    # 4. Install Conda support in Colab and restart runtime
    !pip install -q condacolab
    import condacolab
    condacolab.install()  # This will automatically restart the runtime
    # Only reached when install() does not restart the runtime (i.e. on a re-run)
    !conda env update -n base -f environment.yml

else:
    print("This cell is intended to be run only in Google Colab")


## 2. Import Libraries

In [None]:
# PyTorch and PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

# Chemistry and data processing
from rdkit import Chem
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluation and splitting
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import train_test_split

# Our custom data loader
from src.data_utils import load_tox21

# Global reproducibility: seed NumPy and PyTorch once, up front, so that every
# stochastic step later in the notebook (splitting, weight init, shuffling)
# is repeatable under Restart & Run All. torch.manual_seed seeds the RNG on
# all devices, CUDA included.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set visualization style
sns.set_theme(style="whitegrid")
sns.set_palette('muted')

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## 3. Loading and Preparing Tox21 Data

### 3.1 About the Data Loading Module

We use the custom `load_tox21` function from the `src.data_utils` module, which provides the following features:

- **Automatic data downloading**: if the `tox21.csv` file is missing, the module will automatically download it from the official source
- **SMILES to graphs conversion**: each molecule is converted into a graph with node and edge attributes
- **Result caching**: results are saved to a cache to speed up subsequent runs
- **Flexible configuration**: you can specify the data path, file name, target columns, and other parameters

Detailed information about the `load_tox21` function and usage examples are available in the project's README.

In [None]:
# Load the Tox21 dataset using our custom function.
# If the tox21.csv file is missing, it will be automatically downloaded.
dataset = load_tox21(
    root="data/Tox21",           # root directory for the data
    filename="tox21.csv",        # name of the CSV file
    smiles_col="smiles",         # column containing SMILES strings
    cache_file="data.pt",        # name of the cache file
    recreate=False,              # use cache if it exists
    auto_download=True,          # automatically download if missing
)

# Summarize the loaded dataset.
# The number of tasks is read from the first graph's label tensor: y.numel()
# is the number of label entries per graph whether y is stored as
# [num_tasks] or [1, num_tasks]. (PyG's inferred `num_classes` can instead
# return the number of distinct label values when y is one-dimensional,
# which is not the number of toxicity endpoints.)
if len(dataset) > 0:
    first_graph = dataset[0]
    num_node_features = first_graph.x.shape[1]
    num_tasks = first_graph.y.numel()
else:
    num_node_features = 0
    num_tasks = 0

print(f"Total graphs: {len(dataset)}")
print(f"Number of node features: {num_node_features}")
print(f"Number of tasks (toxicity endpoints): {num_tasks}")

In [None]:
# Analysis of graph properties in the dataset.
# Size statistics are computed over the first STATS_SAMPLE_SIZE graphs only,
# which keeps this cell fast on the full dataset.
STATS_SAMPLE_SIZE = 1000

if len(dataset) > 0:
    data_example = dataset[0]
    print("Example graph from the dataset:")
    print(data_example)
    print(f"Edge index dimensions: {data_example.edge_index.shape}")
    print(f"Number of atoms (nodes): {data_example.num_nodes}")
    print(f"Task labels (y): {data_example.y}")

    # Dataset statistics over the leading sample of graphs
    sample_graphs = [dataset[i] for i in range(min(STATS_SAMPLE_SIZE, len(dataset)))]
    nodes_count = [graph.num_nodes for graph in sample_graphs]
    # edge_index has one column per directed edge. The bond figures below divide
    # by 2, assuming every bond is stored in both directions
    # (TODO confirm against load_tox21's graph construction).
    edges_count = [graph.edge_index.shape[1] for graph in sample_graphs]

    print(f"\nGraph statistics (based on a sample of {len(nodes_count)} molecules):")
    print(f"Average number of atoms: {np.mean(nodes_count):.2f} ± {np.std(nodes_count):.2f}")
    print(f"Average number of bonds: {np.mean(edges_count)/2:.2f} ± {np.std(edges_count)/2:.2f}")
    print(f"Min/max atoms: {np.min(nodes_count)}/{np.max(nodes_count)}")
    print(f"Min/max bonds: {np.min(edges_count)/2:.0f}/{np.max(edges_count)/2:.0f}")

In [None]:
# Side-by-side histograms of molecule sizes: atoms per molecule (left) and
# bonds per molecule (right). Runs only when the statistics cell has produced
# `nodes_count` / `edges_count`.
if 'nodes_count' in locals():
    # Halve the edge counts to get bond counts, as in the statistics cell.
    bond_counts = [e / 2 for e in edges_count]

    # (values, bar colour, panel title, x-axis label) for each panel, left to right
    panels = [
        (nodes_count, 'skyblue', 'Distribution of Atom Counts', 'Number of atoms'),
        (bond_counts, 'salmon', 'Distribution of Bond Counts', 'Number of bonds'),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (values, color, title, xlabel) in zip(axes, panels):
        ax.hist(values, bins=30, alpha=0.7, color=color)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Number of molecules')

    plt.tight_layout()
    plt.show()

In [None]:
# Split into train/test (80/20) and wrap both parts in PyG DataLoaders.
# The split uses its own seeded generator, so it is reproducible no matter what
# has consumed the global RNG before this point. torch.manual_seed is still set
# so that the global RNG used later (e.g. for DataLoader shuffling) is seeded.
SPLIT_SEED = 42       # seed for the train/test split (and the global torch RNG)
TRAIN_FRACTION = 0.8  # fraction of graphs assigned to the training set
BATCH_SIZE = 32       # graphs per mini-batch

torch.manual_seed(SPLIT_SEED)  # for reproducibility
train_len = int(TRAIN_FRACTION * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_len, len(dataset) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total graphs: {len(dataset)}")
print(f"Training set: {len(train_dataset)} graphs")
print(f"Test set: {len(test_dataset)} graphs")
print(f"Batch size: {train_loader.batch_size}, number of batches in training set: {len(train_loader)}")

## 4. Define GCN Model

In [None]:
class GCN(nn.Module):
    """Graph Convolutional Network for graph-level (per-molecule) prediction.

    Applies ``num_layers`` GCNConv layers (each followed by ReLU and dropout),
    averages the node embeddings of every graph with ``global_mean_pool``, and
    maps the pooled vector to ``num_classes`` raw scores with a linear head.

    Args:
        num_node_features: Size of each node feature vector (``data.x.shape[1]``).
        num_classes: Number of outputs per graph (e.g. one logit per toxicity task).
        hidden_channels: Width of the hidden node embeddings.
        num_layers: Number of GCNConv layers; must be at least 1.
        dropout: Dropout probability applied after every convolution.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, num_node_features, num_classes, hidden_channels=64,
                 num_layers=3, dropout=0.2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout = dropout
        # First layer consumes raw node features; the rest map hidden -> hidden.
        in_dims = [num_node_features] + [hidden_channels] * (num_layers - 1)
        self.convs = nn.ModuleList(GCNConv(d, hidden_channels) for d in in_dims)
        self.head = nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        """Return raw scores of shape ``[num_graphs, num_classes]``.

        ``data`` is a PyG ``Data`` or mini-batch providing ``x``, ``edge_index``
        and, for batches, ``batch``. No sigmoid is applied, so pair the output
        with ``nn.BCEWithLogitsLoss`` for multi-label targets.
        """
        x = data.x.float()  # GCNConv requires floating-point node features
        edge_index = data.edge_index
        batch = getattr(data, 'batch', None)
        if batch is None:
            # A single un-batched graph: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = global_mean_pool(x, batch)  # [num_graphs, hidden_channels]
        return self.head(x)

## 5. Training and Evaluation

In [None]:
# TODO: Implement model training and evaluation functions

## 6. Results Visualization

In [None]:
# TODO: Visualization of results (ROC curve, PR curve, etc.)

## 7. Conclusion and Findings

**TODO:** Add analysis of results and conclusions.