In [2]:
import pytorch_lightning as pl
from torch import optim
import wandb
import torch
import os

from torch.nn.functional import binary_cross_entropy

import torch_geometric as tg
import torchmetrics
from pytorch_lightning.loggers.wandb import WandbLogger

from GraphCoAttention.datasets.HeterogenousDDI import HeteroDrugDrugInteractionData
from GraphCoAttention.nn.models.HeterogenousCoAttention import HeteroGNN

In [3]:
class Learner(pl.LightningModule):
    """Lightning module that trains a ``HeteroGNN`` with a binary cross-entropy objective.

    The module owns its data: it loads ``HeteroDrugDrugInteractionData`` from
    ``root_dir``, shuffles it, optionally truncates it to a small subset, and
    serves that same subset from both the training and validation loaders.

    Args:
        root_dir: Root directory handed to ``HeteroDrugDrugInteractionData``.
        hidden_dim: ``hidden_channels`` passed to ``HeteroGNN``.
        n_cycles: ``num_layers`` passed to ``HeteroGNN``.
        n_head: ``num_heads`` passed to ``HeteroGNN``.
        dropout: Stored and written to ``wandb.config`` only.
            NOTE(review): it is not forwarded to ``HeteroGNN`` -- confirm
            whether the model is supposed to receive it.
        lr: Learning rate for AdamW.
        bs: Batch size for both data loaders and for ``HeteroGNN``.
        num_workers: Worker processes used by each data loader.
        subset_size: Number of (shuffled) samples to keep; ``None`` keeps the
            whole dataset. The default of 10 is a debugging-sized subset.
    """

    def __init__(self, root_dir, hidden_dim=25, n_cycles=16, n_head=1, dropout=0.1, lr=0.001, bs=2,
                 num_workers=32, subset_size=10):
        super().__init__()
        self.root_dir = root_dir

        self.dataset = HeteroDrugDrugInteractionData(root=self.root_dir)
        # NOTE(review): the shuffle is unseeded here, so the retained subset
        # differs between runs unless a global seed is set beforehand.
        self.dataset = self.dataset.shuffle()

        if subset_size is not None:
            self.dataset = self.dataset[:subset_size]

        self.num_node_types = len(self.dataset[0].x_dict)
        self.num_workers = num_workers
        self.n_cycles = n_cycles
        self.n_head = n_head
        self.dropout = dropout
        self.batch_size = bs
        self.lr = lr
        self.hidden_dim = hidden_dim

        # Requires an active wandb run (wandb.init must already have happened).
        wandb.config.hidden_dim = self.hidden_dim
        wandb.config.n_layers = self.n_cycles
        wandb.config.n_head = self.n_head
        wandb.config.dropout = self.dropout

        self.HeterogenousCoAttention = HeteroGNN(hidden_channels=self.hidden_dim, out_channels=1,
                                                 num_layers=self.n_cycles,
                                                 batch_size=self.batch_size, num_node_types=self.num_node_types,
                                                 num_heads=self.n_head)

        # Operates on raw logits, so forward() must not apply a sigmoid.
        self.bce_loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, batch, *args, **kwargs):
        """Return the raw (pre-sigmoid) model output for a heterogeneous batch."""
        return self.HeterogenousCoAttention(batch.x_dict, batch.edge_index_dict, batch)

    def _loss_and_outputs(self, batch):
        """Run the model on ``batch`` and return ``(loss, y_pred, y_true)``."""
        y_true = batch.binary_y.float()
        # Reshape to the target's shape instead of calling squeeze(): squeeze()
        # would also drop the batch dimension of a size-1 batch, and
        # BCEWithLogitsLoss rejects mismatched input/target shapes.
        y_pred = self(batch).reshape(y_true.shape)
        loss = self.bce_loss(input=y_pred, target=y_true)
        return loss, y_pred, y_true

    def training_step(self, data, batch_idx):
        """Compute the training loss for one batch and log it to wandb."""
        bce, y_pred, y_true = self._loss_and_outputs(data)
        # One log call per step: every wandb.log call advances wandb's step
        # counter, so separate calls would put each metric on its own step.
        wandb.log({
            "train/loss": bce.detach(),
            'train/y_pred': y_pred.detach(),
            'train/y_true': y_true,
        })
        return {'loss': bce}

    def validation_step(self, val_batch, batch_idx):
        """Compute the validation loss for one batch and log it to wandb."""
        bce_loss, _, _ = self._loss_and_outputs(val_batch)
        wandb.log({"val/loss": bce_loss.detach()})
        return {'loss': bce_loss}

    def configure_optimizers(self):
        """Return AdamW plus a step schedule that divides the LR by 10 at epochs 25 and 35."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.28, 0.93), weight_decay=0.01)
        # Milestones must be a list of ints. A string such as '25,35' is
        # counted character by character, never matches an epoch number, and
        # silently disables the decay.
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35], gamma=0.1)
        return [optimizer], [scheduler]

    def _make_loader(self, shuffle):
        """Build a PyG ``DataLoader`` over the module's dataset."""
        return tg.loader.DataLoader(list(self.dataset), batch_size=self.batch_size,
                                    num_workers=self.num_workers, pin_memory=False, shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(shuffle=True)

    def val_dataloader(self):
        """Return the validation loader.

        It iterates over the same samples as the training loader (there is no
        held-out split), in a fixed order.
        """
        return self._make_loader(shuffle=False)


In [4]:

# Training entry point: build the logger and trainer, then fit the Learner.
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # uncomment for synchronous CUDA error reporting
data_dir = os.path.join('GraphCoAttention', 'data')
wandb_logger = WandbLogger(project='flux', log_model='all')
# Pass the logger explicitly: without `logger=`, the Trainer falls back to its
# default logger and `wandb_logger` (including `log_model='all'`) is never used.
trainer = pl.Trainer(gpus=[0], max_epochs=2000, check_val_every_n_epoch=500, accumulate_grad_batches=1,
                     logger=wandb_logger)
model = Learner(data_dir, bs=2, lr=0.0005, n_cycles=30, hidden_dim=10, n_head=4)
trainer.fit(model)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


100% [....................................................] 35688667 / 35688667

Processing...


        ./._bio-decagon-combo.csv      STITCH 2 Polypharmacy Side Effect  \
0                    CID000002173  CID000003345                 C0151714   
1                    CID000002173  CID000003345                 C0035344   
2                    CID000002173  CID000003345                 C0004144   
3                    CID000002173  CID000003345                 C0002063   
4                    CID000002173  CID000003345                 C0004604   
...                           ...           ...                      ...   
4649437              CID000003461  CID000003954                 C0035410   
4649438              CID000003461  CID000003954                 C0043096   
4649439              CID000003461  CID000003954                 C0003962   
4649440              CID000003461  CID000003954                 C0038999   
4649441                       NaN           NaN                      NaN   

                   Side Effect Name  
0                   hypermagnesemia  
1        re

100%|█████████████████████████████| 4649442/4649442 [00:26<00:00, 176720.40it/s]
100%|███████████████████████████████████████| 1000/1000 [14:36<00:00,  1.14it/s]
100%|█████████████████████████████████████| 1000/1000 [00:00<00:00, 2997.92it/s]
100%|█████████████████████████████████████| 1000/1000 [00:00<00:00, 2790.91it/s]


Saving...


Done!
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msyntensor[0m (use `wandb login --relogin` to force relogin)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                    | Type              | Params
--------------------------------------------------------------
0 | HeterogenousCoAttention | HeteroGNN         | 19.2 K
1 | bce_loss                | BCEWithLogitsLoss | 0     
--------------------------------------------------------------
19.2 K    Trainable params
0         Non-trainable params
19.2 K    Total params
0.077     Total estimated model params size (MB)


Epoch 499:  50%|██████      | 5/10 [00:03<00:03,  1.29it/s, loss=0.611, v_num=0]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                         | 0/5 [00:00<?, ?it/s][A
Epoch 499:  70%|████████▍   | 7/10 [00:06<00:02,  1.02it/s, loss=0.611, v_num=0][A
Validating:  40%|█████████████▏                   | 2/5 [00:04<00:06,  2.03s/it][A
Epoch 499:  90%|██████████▊ | 9/10 [00:09<00:01,  1.07s/it, loss=0.611, v_num=0][A
Validating:  80%|██████████████████████████▍      | 4/5 [00:07<00:01,  1.57s/it][A
Epoch 499: 100%|███████████| 10/10 [00:11<00:00,  1.18s/it, loss=0.611, v_num=0][A
Epoch 999:  60%|███████▏    | 6/10 [00:03<00:02,  1.54it/s, loss=0.611, v_num=0][A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                         | 0/5 [00:00<?, ?it/s][A
Validating:  20%|██████▌                          | 1/5 [00:01<00:06,  1.71s/it][A
Epoch 999:  80%|█████████▌  | 8/10 [00:05<00:01,  1.39it/s, loss=0.611, v_num=0][A
Validating: