In [1]:
# Environment
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/, https://download.pytorch.org/whl/cu113
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 7.2 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.14-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 2.9 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.4 MB)
[K     |████████████████████████████████| 2.4 MB 44.9 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/tor

In [16]:
# Dataset
from torch_geometric.datasets import TUDataset

# Three Tox21 AhR splits stored under the same root directory.
# Split mapping used by this notebook: the "testing" split serves as the
# validation set and the "evaluation" split as the held-out test set.
training_dataset = TUDataset(root='./dataset', name='Tox21_AhR_training')
validation_dataset = TUDataset(root='./dataset', name='Tox21_AhR_testing')
test_dataset = TUDataset(root='./dataset', name='Tox21_AhR_evaluation')

# Dataset summaries, one per line.
print(training_dataset, validation_dataset, test_dataset, sep='\n')

# First graph of every split, one per line (the node-feature width of `x`
# is not the same across the three splits).
print(training_dataset[0], validation_dataset[0], test_dataset[0], sep='\n')

Tox21_AhR_training(8169)
Tox21_AhR_testing(272)
Tox21_AhR_evaluation(607)
Data(edge_index=[2, 52], x=[25, 50], edge_attr=[52, 4], y=[1])
Data(edge_index=[2, 44], x=[20, 51], edge_attr=[44, 4], y=[1])
Data(edge_index=[2, 118], x=[53, 53], edge_attr=[118, 4], y=[1])


In [17]:
# Pytorch-lightning datamodule
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
from torch_geometric.data import Data, Batch
from torch.nn.functional import pad

class CustomData(LightningDataModule):
  """LightningDataModule serving three pre-split graph datasets.

  The collate step rebuilds every graph so that its node-feature matrix has
  exactly ``feature_size`` columns: narrower matrices are right-padded with
  zeros and wider ones are truncated. This lets splits whose ``x`` widths
  differ be batched together and fed to a model with a fixed input size.

  Args:
    training_set: dataset used by ``train_dataloader`` (shuffled).
    validation_set: dataset used by ``val_dataloader``.
    test_set: dataset used by ``test_dataloader``.
    batch_size: number of graphs per batch.
    num_workers: number of DataLoader worker processes.
    feature_size: number of node-feature columns every graph is resized to.
      The default of 53 reproduces the previously hard-coded width.
  """

  def __init__(self, training_set, validation_set, test_set, batch_size=128, num_workers=1, feature_size=53):
    super().__init__()
    self.training_set = training_set
    self.validation_set = validation_set
    self.test_set = test_set
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.feature_size = feature_size

  def _resize_features(self, x):
    """Return ``x`` with exactly ``self.feature_size`` columns (zero-pad or cut)."""
    missing = self.feature_size - x.size(1)
    if missing > 0:
      # Pad only the right side of the last dimension with zeros.
      x = pad(x, (0, missing), 'constant', 0.)
    return x[:, :self.feature_size]

  def _make_loader(self, dataset, shuffle=False):
    """Build a DataLoader over ``dataset`` with the shared batching settings."""
    return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, collate_fn=self.collate_function)

  def collate_function(self, batch):
    """Merge a list of graphs into a single ``Batch``.

    Each graph is copied into a new ``Data`` object with resized node
    features; ``y`` gets a leading dimension and is cast to float.
    """
    graphs = [Data(edge_index=data.edge_index,
                   x=self._resize_features(data.x),
                   edge_attr=data.edge_attr,
                   y=data.y.unsqueeze(0).float()) for data in batch]
    return Batch.from_data_list(graphs)

  def train_dataloader(self):
    """Shuffled loader over the training set."""
    return self._make_loader(self.training_set, shuffle=True)

  def val_dataloader(self):
    """Unshuffled loader over the validation set."""
    return self._make_loader(self.validation_set)

  def test_dataloader(self):
    """Unshuffled loader over the test set."""
    return self._make_loader(self.test_set)

In [18]:
# PyTorch and PyTorch Geometric module
from torch.nn import Module, Linear
from torch_geometric.nn import GATConv, global_mean_pool

class CustomGAT(Module):
  """Graph-level predictor: two GATConv layers, mean pooling, linear head.

  Only ``batch.x``, ``batch.edge_index`` and ``batch.batch`` are read in the
  forward pass; edge attributes are not used. No activation function is
  applied between the two attention layers.

  Args:
    input_size: number of input node-feature columns.
    label_size: number of values produced per graph.
    layer_size: output width of both GATConv layers.
    dropout: dropout value forwarded to both GATConv layers.
    heads: number of attention heads forwarded to both GATConv layers.
  """

  def __init__(self, input_size, label_size, layer_size=64, dropout=0.1, heads=2):
    super().__init__()
    self.input_size, self.label_size = input_size, label_size
    self.layer_size = layer_size
    self.heads = heads
    self.dropout = dropout
    self.setup()

  def setup(self):
    """Create the two attention layers and the final linear layer."""
    # Both attention layers share these options; with concat=False the linear
    # head below is sized for an output of width layer_size.
    conv_options = dict(dropout=self.dropout, heads=self.heads, concat=False)
    self.first_layer = GATConv(self.input_size, self.layer_size, **conv_options)
    self.last_layer = GATConv(self.layer_size, self.layer_size, **conv_options)
    self.ffnn = Linear(self.layer_size, self.label_size)

  def convert_graph_into_single_vector(self, graph_hidden, batch_index):
    """Mean-pool node representations into one vector per graph."""
    pooled = global_mean_pool(graph_hidden, batch_index)
    return pooled

  def forward(self, batch):
    """Return the linear head's output for every graph in ``batch``."""
    edges = batch.edge_index
    node_hidden = self.last_layer(self.first_layer(batch.x, edges), edges)
    graph_vector = self.convert_graph_into_single_vector(node_hidden, batch.batch)
    return self.ffnn(graph_vector)

In [19]:
# Pytorch-lightning module
from pytorch_lightning import LightningModule
from torch.nn import  BCEWithLogitsLoss
from torch.optim import Adam

class CustomModel(LightningModule):
  """LightningModule that trains a wrapped model with BCE-with-logits loss.

  Args:
    model: module called as ``model(batch)``; its output is passed to the loss
      together with ``batch.y``.
    learning_rate: learning rate handed to the Adam optimizer.
  """

  def __init__(self, model, learning_rate=1e-3):
    super().__init__()
    self.lr = learning_rate
    self.model = model

  def forward(self, batch, mode):
    """Run the model, compute the loss and log it as ``"{mode}_loss"``.

    Returns:
      Tuple ``(loss, model_output, batch.y)``.
    """
    target = batch.y
    logits = self.model(batch)
    loss = self.loss_function(logits, target)
    # Logged once per epoch (not per step) and shown in the progress bar.
    self.log(f"{mode}_loss", loss, batch_size=target.size(0), prog_bar=True, on_step=False, on_epoch=True)
    return loss, logits, target

  def _run_step(self, batch, mode):
    """Shared body of the train/val/test steps: forward pass packed as a dict."""
    loss, predict, answer = self(batch, mode)
    return {'loss':loss, 'predict':predict, 'answer':answer}

  def training_step(self, batch, batch_idx):
    """Training step; the loss is logged under ``train_loss``."""
    return self._run_step(batch, 'train')

  def validation_step(self, batch, batch_idx):
    """Validation step; the loss is logged under ``val_loss``."""
    return self._run_step(batch, 'val')

  def test_step(self, batch, batch_idx):
    """Test step; the loss is logged under ``test_loss``."""
    return self._run_step(batch, 'test')

  def predict_step(self, batch, batch_idx):
    """Return the raw output of the wrapped model (no loss, no logging)."""
    return self.model(batch)

  def loss_function(self, output, target):
    """Mean-reduced binary cross-entropy on logits."""
    criterion = BCEWithLogitsLoss(reduction='mean')
    return criterion(output, target)

  def configure_optimizers(self):
    """Adam over all parameters with the configured learning rate."""
    return Adam(self.parameters(), lr=self.lr)

In [20]:
# Training
from pytorch_lightning import Trainer, seed_everything

import warnings
# Silences every Python warning for the rest of the session; switch the action
# to 'default' to see them again while debugging.
warnings.filterwarnings(action='ignore')

# Seed the random number generators before anything stochastic happens, so that
# weight initialisation, batch shuffling and dropout (and therefore the reported
# metric) are repeatable across "Restart & Run All". workers=True extends the
# seeding to DataLoader worker processes.
seed_everything(42, workers=True)

data_module = CustomData(training_dataset, validation_dataset, test_dataset)

# 53 input node features, 1 output value per graph.
gat = CustomGAT(53, 1)
model = CustomModel(gat)

# Trains on CPU; use Trainer(max_epochs=1, accelerator='gpu', devices=[0])
# instead to run on the first GPU.
trainer = Trainer(max_epochs=1, accelerator='cpu')

trainer.fit(model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/lightning_logs

  | Name  | Type      | Params
------------------------------------
0 | model | CustomGAT | 15.7 K
------------------------------------
15.7 K    Trainable params
0         Non-trainable params
15.7 K    Total params
0.063     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [34]:
# Evaluation
import torch
from torchmetrics.functional import auroc

# Naming note: `y` holds the model outputs for the test set and `x` holds the
# ground-truth labels (not the other way round).

# One output tensor per test batch, joined along the first dimension.
outputs = trainer.predict(model, dataloaders=data_module.test_dataloader())
y = torch.cat(outputs, dim=0)

# Labels are gathered from a second pass over the (unshuffled) test loader and
# converted to 32-bit integers.
x = torch.cat(tuple(batch.y for batch in data_module.test_dataloader()), dim=0).to(torch.int32)

# auroc takes (predictions, targets) in that order.
evaluation = auroc(y, x)
print(f"auc-roc: {evaluation}")

Predicting: 64it [00:00, ?it/s]

auc-roc: 0.4389864206314087
