# Molecular Toxicity Prediction (Tox21) using GCN

In this notebook, we implement and train a Graph Convolutional Network for predicting molecular toxicity based on the Tox21 dataset.

## 1. Environment Setup in Colab

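Run the following cell in Colab to clone the project repository and install `condacolab`, the first step toward installing PyTorch Geometric and the other dependencies. The runtime restarts once `condacolab.install()` finishes: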

In [2]:
import os

# This code should run only in Colab
if 'COLAB_GPU' in os.environ:
    # 1) Remove any previous copy to avoid nested folders
    if os.path.exists('gnn-molecule-prediction'):
        !rm -rf gnn-molecule-prediction
    # 2) Clone your repository into /content
    !git clone https://github.com/sth-s/gnn-molecule-prediction.git
    # 3) Navigate to the project root
    %cd gnn-molecule-prediction
    # 4) Install Conda support
    !pip install -q condacolab
    import condacolab
    condacolab.install()  # this will restart the runtime
else:
    print("This cell is intended to be run only in Colab")

Cloning into 'gnn-molecule-prediction'...
remote: Enumerating objects: 18, done.
remote: Counting objects: 100% (18/18), done.

## 2. Import Libraries

In [4]:
# PyTorch and PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

# Chemistry and data processing
from rdkit import Chem
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluation and splitting
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import train_test_split

# Set visualization style
sns.set_theme(style="whitegrid")
sns.set_palette('muted')

# Check CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cpu


## 3. Loading and Preparing Tox21 Data

In [6]:
from torch_geometric.datasets import MoleculeNet
from torch_geometric.transforms import ToUndirected

# Load the Tox21 dataset (downloaded to the data/Tox21 folder)
dataset = MoleculeNet(root='data/Tox21', name='Tox21', transform=ToUndirected())

print(f"Total graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of tasks (toxicity endpoints): {dataset.num_classes}")

# View an example graph
data_example = dataset[0]
print(data_example)
print(f"Shape of edge_index: {data_example.edge_index.shape}")
print(f"Number of nodes: {data_example.num_nodes}")
print(f"Labels (y): {data_example.y}")

# Split into train/test (80/20)
torch.manual_seed(42)
train_len = int(0.8 * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_len, len(dataset) - train_len]
)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Batch size: {train_loader.batch_size}, number of batches in train: {len(train_loader)}")



Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz
Extracting data/Tox21/tox21/raw/tox21.csv.gz
Processing...
Done!


Total graphs: 7831
Number of node features: 9
Number of tasks (toxicity endpoints): 12
Data(x=[16, 9], edge_index=[2, 34], edge_attr=[34, 3], smiles='CCOc1ccc2nc(S(N)(=O)=O)sc2c1', y=[1, 12])
Shape of edge_index: torch.Size([2, 34])
Number of nodes: 16
Labels (y): tensor([[0., 0., 1., nan, nan, 0., 0., 1., 0., 0., 0., 0.]])
Batch size: 32, number of batches in train: 196
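
The label tensor printed above contains `nan` entries: not every compound in Tox21 was measured on every assay, so some of the 12 targets are simply missing. A loss computed directly on such targets would itself become `nan`. One common way to handle this is to mask out the missing entries before averaging. The sketch below illustrates the idea in NumPy so it can be run on its own; the function name `masked_bce_with_logits` and the all-zero logits are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def masked_bce_with_logits(logits, labels):
    """Mean binary cross-entropy over the labelled entries only."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    mask = ~np.isnan(labels)          # True where an assay result exists
    z, y = logits[mask], labels[mask]
    # Numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
    per_entry = np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return float(per_entry.mean()), int(mask.sum())

# Labels of the first molecule, as printed above; the logits are made-up values
labels = np.array([[0., 0., 1., np.nan, np.nan, 0., 0., 1., 0., 0., 0., 0.]])
logits = np.zeros_like(labels)        # stands in for an untrained model
loss, n_labelled = masked_bce_with_logits(logits, labels)
print(f"loss={loss:.4f} over {n_labelled} labelled tasks")
# prints: loss=0.6931 over 10 labelled tasks
```

The same boolean-mask pattern carries over to `torch` tensors (`~torch.isnan(y)`), which is likely what a training loop for this notebook would use. A batch with no labelled entries at all would need separate handling, since the mean of an empty selection is undefined.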


## 4. Define GCN Model

In [None]:
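class GCN(nn.Module):
    # Baseline sketch, not a tuned architecture; hidden_channels=64 is an assumed default
    def __init__(self, num_node_features, num_classes, hidden_channels=64):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        # MoleculeNet node features are integer-coded, so cast to float first
        x = data.x.float()
        x = F.relu(self.conv1(x, data.edge_index))
        x = F.relu(self.conv2(x, data.edge_index))
        # Average node embeddings per molecule, then emit one logit per task
        x = global_mean_pool(x, data.batch)
        return self.lin(x)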
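
For intuition about what a graph convolution followed by mean pooling computes, here is a small NumPy sketch of the standard GCN propagation rule, `relu(D^-1/2 (A + I) D^-1/2 X W)`, applied to a toy graph. It is a simplified illustration rather than the `GCNConv` implementation: it assumes undirected edges (both directions listed, as `ToUndirected` produces), omits the bias term, and uses dense matrices. The helper names `gcn_layer` and `mean_pool` are hypothetical.

```python
import numpy as np

def gcn_layer(x, edge_index, weight):
    """One GCN propagation step: relu(D^-1/2 (A + I) D^-1/2 X W)."""
    n = x.shape[0]
    adj = np.zeros((n, n))
    adj[edge_index[0], edge_index[1]] = 1.0   # edge_index holds (source, target) pairs
    adj += np.eye(n)                          # self-loops keep each node's own features
    deg_inv_sqrt = 1.0 / np.sqrt(adj.sum(axis=1))
    adj_norm = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]
    return np.maximum(adj_norm @ x @ weight, 0.0)

def mean_pool(h, batch, num_graphs):
    """Average node embeddings per graph; batch[i] is the graph id of node i."""
    out = np.zeros((num_graphs, h.shape[1]))
    np.add.at(out, batch, h)
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return out / counts

# Toy graph: a 3-node path 0-1-2 with both edge directions listed
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
x = np.eye(3)                 # one-hot node features
weight = np.ones((3, 2))      # made-up weights, 3 input -> 2 output channels
h = gcn_layer(x, edge_index, weight)
graph_embedding = mean_pool(h, np.array([0, 0, 0]), num_graphs=1)
print(h.shape, graph_embedding.shape)
# prints: (3, 2) (1, 2)
```

Each node's output mixes its own features with those of its neighbours, weighted by the inverse square roots of their degrees; pooling then reduces a variable number of atoms to one fixed-size vector per molecule.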

## 5. Training and Evaluation

In [None]:
# TODO: Implement model training and evaluation functions
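
Because Tox21 is a multi-task dataset with missing labels, evaluation is usually reported as ROC-AUC per task, averaged over tasks, with unlabelled entries excluded. The sketch below shows one way this could be done. To keep it self-contained it uses a small rank-based AUC written in NumPy as a stand-in for `sklearn.metrics.roc_auc_score`; the function names and the example arrays are made up for illustration.

```python
import numpy as np

def binary_auc(y_true, y_score):
    """ROC-AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

def mean_task_auc(y_true, y_score):
    """Average AUC over tasks, skipping NaN labels and single-class tasks."""
    aucs = []
    for t in range(y_true.shape[1]):
        mask = ~np.isnan(y_true[:, t])
        yt, ys = y_true[mask, t], y_score[mask, t]
        if (yt == 1).any() and (yt == 0).any():   # AUC is undefined otherwise
            aucs.append(binary_auc(yt, ys))
    return float(np.mean(aucs)) if aucs else float("nan")

# Made-up labels and scores for 4 molecules and 3 tasks
y_true = np.array([[0, 1, np.nan],
                   [1, 0, np.nan],
                   [0, np.nan, 0],
                   [1, 1, 0]], dtype=float)
y_score = np.array([[0.1, 0.9, 0.5],
                    [0.8, 0.2, 0.5],
                    [0.3, 0.5, 0.1],
                    [0.7, 0.1, 0.2]])
print(mean_task_auc(y_true, y_score))
# prints: 0.75
```

In this example the first task is ranked perfectly (AUC 1.0), the second has one positive below the negative (AUC 0.5), and the third is skipped because its labelled entries are all negative.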

## 6. Results Visualization

In [None]:
# TODO: Visualization of results (ROC curve, PR curve, etc.)

## 7. Conclusion and Findings

TODO: Add analysis of results and conclusions.