In [1]:
import pytorch_lightning as pl
from torch import optim
import wandb
import torch
import os

from torch.nn.functional import binary_cross_entropy

import torch_geometric as tg
import torchmetrics
from pytorch_lightning.loggers.wandb import WandbLogger

from GraphCoAttention.datasets.HeterogenousDDI import HeteroDrugDrugInteractionData, HeteroQM9
# from GraphCoAttention.nn.models.CoAttention import CoAttention
from GraphCoAttention.nn.models.HeterogenousCoAttention import HeteroGNN
# from GraphCoAttention.nn.conv.GATConv import GATConv


In [2]:
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
import torch_geometric as tg

from torch_geometric.nn import GATConv, HeteroConv, Linear, GATv2Conv
from torch_geometric.nn.glob import global_mean_pool, global_add_pool
from torch.nn import LeakyReLU

In [3]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size, num_node_types, num_heads):
        super().__init__()

        self.batch_size = batch_size
        self.hidden_channels = hidden_channels
        self.heads = num_heads

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('x_i', 'inner_edge_i', 'x_i'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                ('x_j', 'inner_edge_j', 'x_j'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                ('x_i', 'outer_edge_ij', 'x_j'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                ('x_j', 'outer_edge_ji', 'x_i'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                ('x_i', 'inner_edge_i', 'x_i'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                ('x_j', 'inner_edge_j', 'x_j'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(self.hidden_channels, outer_out_channels)

        self.lin_i = Linear(self.hidden_channels, inner_out_channels)
        self.lin_j = Linear(self.hidden_channels, inner_out_channels)
        # self.hlin = tg.nn.HeteroLinear(hidden_channels, out_channels, num_node_types=num_node_types)

    def forward(self, x_dict, edge_index_dict, d):

        # x_dict, edge_index_dict = x_dict, edge_index_dict

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: torch.tanh(torch.sum(x.view(-1, self.heads, self.hidden_channels), dim=1))
                      for key, x in x_dict.items()}

            # [print(key, x.shape) for key, x in x_dict.items()]
            # [print(key, x.view(-1, self.heads, self.hidden_channels).shape) for key, x in x_dict.items()]
            # [print(key, torch.mean(x.view(-1, self.heads, self.hidden_channels), dim=1).shape) for key, x in x_dict.items()]

        # p_i = F.leaky_relu(global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1))
        # p_j = F.leaky_relu(global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1))

        # p_i = global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1).sigmoid()
        # p_j = global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1).sigmoid()

        p_i = global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1).tanh()
        p_j = global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1).tanh()

        y_i_ = self.lin_i(p_i)
        y_j_ = self.lin_j(p_j)

        x = torch.cat([p_i, p_j], dim=1)
        x = torch.sum(x, dim=1)

        logits = self.lin(x).sigmoid()
        return logits, y_i_, y_j_

In [4]:
class Learner(pl.LightningModule):
    def __init__(self, root_dir, hidden_dim=25, n_cycles=16, n_head=1, dropout=0.1, lr=0.001, bs=2):
        super().__init__()

        print(root_dir)
        exit()
        self.root_dir = root_dir

        self.ddi_dataset = HeteroDrugDrugInteractionData(root=self.root_dir).shuffle()[:10]  # .shuffle()
        self.qm9_dataset = HeteroQM9(root=self.root_dir).shuffle()[:10]  # .shuffle()

        # self.dataset = self.dataset[:10]

        self.num_node_types = len(self.qm9_dataset[0].x_dict)
        self.num_workers = 32
        self.n_cycles = n_cycles
        self.n_head = n_head
        self.dropout = dropout
        self.batch_size = bs
        self.lr = lr

        # self.num_features = self.dataset.num_features
        self.hidden_dim = hidden_dim

        wandb.config.hidden_dim = self.hidden_dim
        wandb.config.n_layers = self.n_cycles
        wandb.config.n_head = self.n_head
        wandb.config.dropout = self.dropout

        # self.encoder = GATConv(self.num_features, self.hidden_dim, heads=self.n_head, dropout=self.dropout)
        #
        # self.inner = GATConv(self.hidden_dim * self.n_head, self.hidden_dim, heads=self.n_head,
        #                      add_self_loops=True, bipartite=False, dropout=self.dropout)
        # self.outer = GATConv(self.hidden_dim * self.n_head, self.hidden_dim, heads=self.n_head, add_self_loops=True,
        #                      concat=False, bipartite=True, dropout=self.dropout)
        #
        # self.update = tg.nn.dense.Linear(self.hidden_dim * self.n_head + self.hidden_dim, self.hidden_dim * self.n_head)
        # # self.update = GATConv(self.hidden_dim*self.n_head+self.hidden_dim, self.hidden_dim, heads=self.n_head,
        # #                       add_self_loops=True, bipartite=False, dropout=self.dropout)
        #
        # self.readout = tg.nn.dense.Linear(in_channels=2 * self.hidden_dim, out_channels=1)
        # # self.readout = GATConv(self.hidden_dim*self.n_head, self.hidden_dim, heads=1,
        # #                        add_self_loops=True, bipartite=False, dropout=self.dropout)

        self.HeterogenousCoAttention = HeteroGNN(hidden_channels=self.hidden_dim, outer_out_channels=1,
                                                 inner_out_channels=15, num_layers=self.n_cycles,
                                                 batch_size=self.batch_size, num_node_types=self.num_node_types,
                                                 num_heads=self.n_head)

        # self.CoAttention = CoAttention(hidden_channels=self.hidden_dim, encoder=self.encoder,
        #                                outer=self.outer, inner=self.inner,
        #                                update=self.update, readout=self.readout,
        #                                n_cycles=self.n_cycles, batch_size=self.batch_size, n_head=self.n_head)

        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        self.mse_loss = torch.nn.L1Loss()

    def forward(self, batch, *args, **kwargs):

        y_ij, y_i_, y_j_ = self.HeterogenousCoAttention(batch.x_dict, batch.edge_index_dict, batch)

        # logits = self.CoAttention(data)
        # logits = torch.sigmoid(torch.mean(logits))
        return y_ij, y_i_, y_j_

    def training_step(self, data, batch_idx):

        _, y_i_, y_j_ = self(data['QM9'])
        mse1 = self.mse_loss(input=y_i_.flatten(), target=data['QM9']['y_i'].y_norm)
        mse2 = self.mse_loss(input=y_j_.flatten(), target=data['QM9']['y_j'].y_norm)
        mse = mse1 + mse2

        y_ij, _, _ = self(data['DDI'])
        y_pred = y_ij.squeeze()
        y_true = data['DDI'].binary_y.float()
        bce = self.bce_loss(input=y_pred, target=y_true)

        loss = mse + bce
        
        wandb.log({"train/mse1_loss": mse1.cpu().detach()})
        wandb.log({"train/mse2_loss": mse2.cpu().detach()})
        wandb.log({"train/bce_loss": bce.cpu().detach()})

        wandb.log({"train/loss": loss.cpu().detach()})

        wandb.log({'train/y_i_pred': y_i_.flatten().cpu().detach()})
        wandb.log({'train/y_i_true': data['QM9']['y_i'].y_norm.cpu().detach()})
        wandb.log({'train/y_j_pred': y_j_.flatten().cpu().detach()})
        wandb.log({'train/y_j_true': data['QM9']['y_j'].y_norm.cpu().detach()})

        wandb.log({"train/y_pred": wandb.Histogram(y_pred.cpu().detach())})
        wandb.log({"train/y_true": wandb.Histogram(y_true.cpu().detach())})

        return {'loss': loss}  # , 'train_accuracy': acc, 'train_f1': f1}

    def validation_step(self, val_batch, batch_idx, loader_idx):

        y_ij, y_i_, y_j_ = self(val_batch)

        if loader_idx == 1:
            y_pred = y_ij.squeeze()
            y_true = val_batch.binary_y.float()
            bce = self.bce_loss(input=y_pred, target=y_true)
            wandb.log({"val/loss": bce})
            loss = bce

        if loader_idx == 0:
            mse1 = self.mse_loss(input=y_i_.flatten(), target=val_batch['y_i'].y_norm)
            mse2 = self.mse_loss(input=y_j_.flatten(), target=val_batch['y_j'].y_norm)
            mse = mse1 + mse2
            wandb.log({"val/loss": mse})
            loss = mse

        # print(val_batch, loader_idx, batch_idx)
        # print(ddi_batch, type(ddi_batch))

        # y_ij, y_i_, y_j_ = self(val_batch)
        # y_pred = y_ij.squeeze()
        # y_true = val_batch.binary_y.float()
        #
        # mse1 = self.mse_loss(input=y_i_.flatten(), target=val_batch['y_i'].y_norm)
        # mse2 = self.mse_loss(input=y_j_.flatten(), target=val_batch['y_j'].y_norm)
        # mse = mse1 + mse2
        # bce = self.bce_loss(input=y_pred, target=y_true)
        # loss = mse
        # # self.log('validation_loss', bce_loss)
        # # self.log('Predicted', y_pred)
        # # self.log('Actual', y_true)
        # wandb.log({"val/loss": loss})

        return {'loss': loss}

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.28, 0.93), weight_decay=0.01)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, '25,35', gamma=0.1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        qm9_dataloader = tg.loader.DataLoader(list(self.qm9_dataset), batch_size=self.batch_size,
                                              num_workers=self.num_workers, pin_memory=False, shuffle=True)

        ddi_dataloader = tg.loader.DataLoader(list(self.ddi_dataset), batch_size=self.batch_size,
                                              num_workers=self.num_workers, pin_memory=False, shuffle=True)

        loaders = {"QM9": qm9_dataloader, 'DDI': ddi_dataloader}
        return loaders

    def val_dataloader(self):
        qm9_dataloader = tg.loader.DataLoader(list(self.qm9_dataset), batch_size=self.batch_size,
                                              num_workers=self.num_workers, pin_memory=False, shuffle=True)

        ddi_dataloader = tg.loader.DataLoader(list(self.ddi_dataset), batch_size=self.batch_size,
                                              num_workers=self.num_workers, pin_memory=False, shuffle=True)
        # loaders = {"QM9": qm9_dataloader, 'DDI': ddi_dataloader}
        loaders = [qm9_dataloader, ddi_dataloader]
        return loaders





In [5]:
os.getcwd()
os.listdir(os.path.join('..', '' 'GraphCoAttention', 'data', 'processed'))

['pre_transform.pt',
 'heterogenous_decagon_ps_ns_V4.pt',
 'data_v3.pt',
 'heterogenous_qm9_norm.pt',
 'pre_filter.pt',
 'pyg_molecules.pt']

In [6]:

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
data_dir = os.path.join('..', '' 'GraphCoAttention', 'datasets', 'data')
wandb.init()
wandb_logger = WandbLogger(project='flux', log_model='all')
trainer = pl.Trainer(gpus=[0], max_epochs=2000, check_val_every_n_epoch=500, accumulate_grad_batches=1)
learner = Learner(data_dir, bs=10, lr=0.001, n_cycles=40, hidden_dim=225, n_head=5)
trainer.fit(Learner(data_dir, bs=4, lr=0.001, n_cycles=3, hidden_dim=4, n_head=2))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkatharina_z[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


../GraphCoAttention/datasets/data
100% [....................................................] 35688667 / 35688667

Processing...


        ./._bio-decagon-combo.csv      STITCH 2 Polypharmacy Side Effect  \
0                    CID000002173  CID000003345                 C0151714   
1                    CID000002173  CID000003345                 C0035344   
2                    CID000002173  CID000003345                 C0004144   
3                    CID000002173  CID000003345                 C0002063   
4                    CID000002173  CID000003345                 C0004604   
...                           ...           ...                      ...   
4649437              CID000003461  CID000003954                 C0035410   
4649438              CID000003461  CID000003954                 C0043096   
4649439              CID000003461  CID000003954                 C0003962   
4649440              CID000003461  CID000003954                 C0038999   
4649441                       NaN           NaN                      NaN   

                   Side Effect Name  
0                   hypermagnesemia  
1        re

100%|█████████████████████████████| 4649442/4649442 [00:31<00:00, 149210.67it/s]
100%|███████████████████████████████████████| 1000/1000 [13:44<00:00,  1.21it/s]
100%|█████████████████████████████████████| 1000/1000 [00:00<00:00, 2438.37it/s]
100%|█████████████████████████████████████| 1000/1000 [00:00<00:00, 2317.20it/s]


Saving...


Done!


100% [....................................................] 86144227 / 86144227

Processing...
Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting ../GraphCoAttention/datasets/data/raw/qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|██████████████████████████████████| 133885/133885 [02:35<00:00, 861.24it/s]
Done!


tensor([[    0.0000,    13.2100,   -10.5499,     3.1865,    13.7363,    35.3641,
             1.2177, -1101.4878, -1101.4098, -1101.3840, -1102.0229,     6.4690,
           -17.1722,   -17.2868,   -17.3897,   -16.1519,   157.7118,   157.7100,
           157.7070]])


133885it [02:57, 753.53it/s]
100%|█████████████████████████████████| 133139/133139 [01:19<00:00, 1674.21it/s]


Saving...


Done!


../GraphCoAttention/datasets/data


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                    | Type              | Params
--------------------------------------------------------------
0 | HeterogenousCoAttention | HeteroGNN         | 539   
1 | bce_loss                | BCEWithLogitsLoss | 0     
2 | mse_loss                | L1Loss            | 0     
--------------------------------------------------------------
539       Trainable params
0         Non-trainable params
539       Total params
0.002     Total estimated model params size (MB)


Validation sanity check:   0%|                            | 0/4 [00:00<?, ?it/s]

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

In [34]:
qm9_dataloader = tg.loader.DataLoader(list(learner.qm9_dataset), batch_size=100,
                                      num_workers=learner.num_workers, pin_memory=False, shuffle=True)

ddi_dataloader = tg.loader.DataLoader(list(learner.ddi_dataset), batch_size=100,
                                      num_workers=learner.num_workers, pin_memory=False, shuffle=True)

loaders = {"QM9": qm9_dataloader, 'DDI': ddi_dataloader}

In [35]:
d = next(iter(loaders['DDI']))

In [36]:
d

Batch(
  binary_y=[100],
  [1mx_i[0m={
    x=[2499, 9],
    batch=[2499],
    ptr=[101]
  },
  [1mx_j[0m={
    x=[2694, 9],
    batch=[2694],
    ptr=[101]
  },
  [1m(x_i, inner_edge_i, x_i)[0m={
    edge_index=[2, 5348],
    edge_attr=[5348, 3]
  },
  [1m(x_j, inner_edge_j, x_j)[0m={
    edge_index=[2, 5720],
    edge_attr=[3281, 3]
  },
  [1m(x_i, outer_edge_ij, x_j)[0m={
    edge_index=[2, 63630],
    edge_attr=[3281, 3]
  },
  [1m(x_j, outer_edge_ji, x_i)[0m={ edge_index=[2, 63630] }
)

In [37]:
d['binary_y']

tensor([1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1,
        0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0,
        1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0,
        1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1,
        0, 1, 1, 0])