# Molecular Toxicity Prediction (Tox21) using GCN

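In this notebook, we implement and train a Graph Convolutional Network (GCN) to predict molecular toxicity on the Tox21 dataset. Each molecule is represented as a graph (atoms as nodes, bonds as edges) and carries binary labels for 12 toxicity assays, some of which are missing, so the task is multi-task binary classification with partially observed labels.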

## 1. Installing Required Libraries

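The following cell installs PyTorch Geometric, RDKit, and the plotting/metrics libraries, but only when the notebook runs on Google Colab. When running locally, install these packages into the kernel's environment before starting the notebook.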

In [None]:
# Check if notebook is running in Google Colab
import os
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    # Install PyTorch Geometric and dependencies in Colab
    !pip install torch-geometric
    !pip install rdkit-pypi
    !pip install scikit-learn matplotlib seaborn

## 2. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import MoleculeNet, TUDataset
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import ToUndirected

from rdkit import Chem
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import train_test_split

# Set visualization style. The matplotlib style name 'seaborn-whitegrid' was
# deprecated in matplotlib 3.6 and removed in 3.8 (plt.style.use raises OSError),
# so configure the equivalent white-grid theme and muted palette via seaborn.
sns.set_theme(style='whitegrid', palette='muted')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 3. Loading and Preparing Tox21 Data

In [None]:
# Tox21 ships with PyG through the MoleculeNet wrapper; there is no standalone
# `torch_geometric.datasets.Tox21` class, so importing one fails.
# Raw and processed files are stored under data/tox21/.
# MoleculeNet builds graphs from SMILES with both edge directions already
# present, so ToUndirected acts only as a safety net here.
dataset = MoleculeNet(root='data', name='Tox21', transform=ToUndirected())

# Each graph carries a (1, num_tasks) label row, so read the task count from y
# directly (dataset.num_classes may iterate over the whole dataset to infer it).
NUM_TASKS = dataset[0].y.size(-1)

print(f"Total graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of tasks (toxicity endpoints): {NUM_TASKS}")

# View an example graph
data_example = dataset[0]
print(data_example)
print(f"Shape of edge_index: {data_example.edge_index.shape}")
print(f"Number of nodes: {data_example.num_nodes}")
# MoleculeNet node features are integer category codes; cast to float before GCNConv.
print(f"Node feature dtype: {data_example.x.dtype}")
print(f"Labels (y): {data_example.y}")

# Not every molecule was measured in every assay: missing labels are NaN and
# must be masked out of the loss and the evaluation metrics.
all_labels = torch.cat([graph.y for graph in dataset], dim=0)
print(f"Fraction of missing labels: {torch.isnan(all_labels).float().mean().item():.1%}")

# Split into train/test (80/20). A dedicated seeded generator makes the split
# reproducible regardless of other consumers of the global RNG; the global seed
# still makes DataLoader shuffling reproducible.
SEED = 42
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8

torch.manual_seed(SEED)
train_len = int(TRAIN_FRACTION * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_len, len(dataset) - train_len],
    generator=torch.Generator().manual_seed(SEED),
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Batch size: {train_loader.batch_size}, number of batches in train: {len(train_loader)}")



## 4. Define GCN Model

In [None]:
class GCN(nn.Module):
    """Graph Convolutional Network for multi-task, graph-level prediction.

    Three GCNConv layers with ReLU produce node embeddings. These are averaged
    per graph (global mean pooling) and mapped by a linear head to one logit
    per task.

    Args:
        num_node_features: Size of each node's input feature vector.
        num_classes: Number of output logits per graph (one per task/endpoint).
        hidden_channels: Width of the hidden GCN layers.
        dropout: Dropout probability applied to the pooled graph embedding
            during training.
    """

    def __init__(self, num_node_features, num_classes, hidden_channels=64, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.dropout = dropout
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        """Return raw logits of shape (num_graphs, num_classes).

        No sigmoid is applied. This pairs with a loss such as
        BCEWithLogitsLoss, which must be masked where labels are missing (NaN).
        """
        # MoleculeNet node features are integer codes; GCNConv needs floats.
        # .float() is a no-op if x is already floating point.
        x = data.x.float()
        edge_index, batch = data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        # Collapse node embeddings into one vector per graph in the batch.
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

## 5. Training and Evaluation

In [None]:
# TODO: Implement model training and evaluation functions

## 6. Results Visualization

In [None]:
# TODO: Visualization of results (ROC curve, PR curve, etc.)

## 7. Conclusion and Findings

**TODO:** Add an analysis of the results (per-task ROC-AUC / PR-AUC and their variation across endpoints) and the conclusions drawn from them.