In [47]:
import pytorch_lightning as pl
from torch import optim
import wandb
import torch
import os

from torch.nn.functional import binary_cross_entropy

import torch_geometric as tg
import torchmetrics
from pytorch_lightning.loggers.wandb import WandbLogger

from GraphCoAttention.datasets.HeterogenousDDI import HeteroDrugDrugInteractionData, HeteroQM9
# from GraphCoAttention.nn.models.CoAttention import CoAttention
from GraphCoAttention.nn.models.HeterogenousCoAttention import HeteroGNN
# from GraphCoAttention.nn.conv.GATConv import GATConv


In [48]:
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
import torch_geometric as tg

from torch_geometric.nn import GATConv, HeteroConv, Linear, GATv2Conv
from torch_geometric.nn.glob import global_mean_pool, global_add_pool
from torch.nn import LeakyReLU

In [49]:
class HeteroGNN(torch.nn.Module):
    """Stacked heterogeneous GATv2 network over a pair of graphs (``x_i``, ``x_j``).

    Every layer runs one ``GATv2Conv`` per edge type (two intra-graph "inner"
    relations and two cross-graph "outer" relations) and sums the messages that
    arrive at the same destination node type. After the last layer each graph is
    sum-pooled into one vector, from which three heads are computed:

    * a pair-level output from the summed pooled vectors of both graphs, and
    * one per-graph output for ``x_i`` and one for ``x_j``.

    Args:
        hidden_channels: Per-head output width of every ``GATv2Conv``; also the
            node embedding width after the heads are summed.
        outer_out_channels: Output width of the pair-level head.
        inner_out_channels: Output width of each per-graph head.
        num_layers: Number of ``HeteroConv`` layers.
        batch_size: Number of graph pairs per batch; passed as ``size`` to the
            pooling call so the pooled tensors always have this many rows.
        num_node_types: Accepted for interface compatibility; not used.
        num_heads: Number of attention heads per ``GATv2Conv``.
    """

    # Edge types handled by every layer. Listed exactly once: the original dict
    # literal repeated the two inner-edge keys, which only built extra GATv2Conv
    # modules that were overwritten immediately.
    EDGE_TYPES = (
        ('x_i', 'inner_edge_i', 'x_i'),
        ('x_j', 'inner_edge_j', 'x_j'),
        ('x_i', 'outer_edge_ij', 'x_j'),
        ('x_j', 'outer_edge_ji', 'x_i'),
    )

    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size, num_node_types, num_heads):
        super().__init__()

        self.batch_size = batch_size
        self.hidden_channels = hidden_channels
        self.heads = num_heads

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # in_channels=-1 defers the input width to the first forward pass.
            conv = HeteroConv({
                edge_type: GATv2Conv(-1, self.hidden_channels, heads=num_heads)
                for edge_type in self.EDGE_TYPES
            }, aggr='sum')
            self.convs.append(conv)

        # Pair-level head.
        self.lin = Linear(self.hidden_channels, outer_out_channels)

        # Per-graph heads.
        self.lin_i = Linear(self.hidden_channels, inner_out_channels)
        self.lin_j = Linear(self.hidden_channels, inner_out_channels)

    def forward(self, x_dict, edge_index_dict, d):
        """Run the network on one heterogeneous batch.

        Args:
            x_dict: Node features keyed by node type (``'x_i'``, ``'x_j'``).
            edge_index_dict: Edge indices keyed by edge type.
            d: The batch object; ``d['x_i'].batch`` and ``d['x_j'].batch`` are
                used to assign nodes to graphs when pooling.

        Returns:
            Tuple ``(y_ij, y_i_, y_j_)``:

            * ``y_ij`` -- pair-level output **after a sigmoid**, i.e. values in
              (0, 1), not raw logits. Pair it with a probability-based loss
              (e.g. ``BCELoss``) rather than ``BCEWithLogitsLoss``.
            * ``y_i_`` / ``y_j_`` -- per-graph head outputs, computed from the
              pooled vectors that carry an extra singleton dimension at dim 1.
        """
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            # The conv output concatenates the heads along the feature axis;
            # split the heads back out, sum over them and squash with tanh so
            # every layer hands ``hidden_channels`` features to the next one.
            x_dict = {key: torch.tanh(torch.sum(x.view(-1, self.heads, self.hidden_channels), dim=1))
                      for key, x in x_dict.items()}

        # Sum-pool the nodes of each graph, then squash. ``unsqueeze(1)`` adds
        # the axis along which the two graphs are stacked below.
        p_i = global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1).tanh()
        p_j = global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1).tanh()

        y_i_ = self.lin_i(p_i)
        y_j_ = self.lin_j(p_j)

        # Order-invariant combination of the two graphs: stack, then sum.
        pair = torch.sum(torch.cat([p_i, p_j], dim=1), dim=1)

        pair_prob = self.lin(pair).sigmoid()
        return pair_prob, y_i_, y_j_

In [50]:
class Learner(pl.LightningModule):
    """Multi-task LightningModule that trains one ``HeteroGNN`` on two datasets.

    * QM9 batches supervise the two per-graph regression heads (L1 loss against
      ``y_norm``).
    * DDI batches supervise the pair-level binary head (binary cross-entropy
      against ``binary_y``).

    Args:
        root_dir: Dataset root passed to both dataset constructors.
        hidden_dim: Hidden width handed to ``HeteroGNN``.
        n_cycles: Number of ``HeteroConv`` layers.
        n_head: Number of attention heads.
        dropout: Recorded in the wandb config; not consumed by the model here.
        lr: AdamW learning rate.
        bs: Batch size for every dataloader (also fixes the pooled batch size
            inside ``HeteroGNN``).
    """

    def __init__(self, root_dir, hidden_dim=25, n_cycles=16, n_head=1, dropout=0.1, lr=0.001, bs=2):
        super().__init__()
        self.root_dir = root_dir

        # Small shuffled subsets (first 100 samples after shuffling) of each dataset.
        self.ddi_dataset = HeteroDrugDrugInteractionData(root=self.root_dir).shuffle()[:100]
        self.qm9_dataset = HeteroQM9(root=self.root_dir).shuffle()[:100]

        self.num_node_types = len(self.qm9_dataset[0].x_dict)
        self.num_workers = 32
        self.n_cycles = n_cycles
        self.n_head = n_head
        self.dropout = dropout
        self.batch_size = bs
        self.lr = lr

        self.hidden_dim = hidden_dim

        # wandb.config is only writable once a run exists; guarding lets the
        # module be constructed for inspection without calling wandb.init().
        if wandb.run is not None:
            wandb.config.hidden_dim = self.hidden_dim
            wandb.config.n_layers = self.n_cycles
            wandb.config.n_head = self.n_head
            wandb.config.dropout = self.dropout

        self.HeterogenousCoAttention = HeteroGNN(hidden_channels=self.hidden_dim, outer_out_channels=1,
                                                 inner_out_channels=15, num_layers=self.n_cycles,
                                                 batch_size=self.batch_size, num_node_types=self.num_node_types,
                                                 num_heads=self.n_head)

        # HeteroGNN already applies a sigmoid to its pair-level output, so the
        # loss must take probabilities. BCEWithLogitsLoss would apply a second
        # sigmoid and squash every prediction into roughly (0.5, 0.73).
        self.bce_loss = torch.nn.BCELoss()
        # NOTE: the attribute keeps its historical name, but this is an L1
        # (mean absolute error) loss, not a mean squared error.
        self.mse_loss = torch.nn.L1Loss()

    @staticmethod
    def _wandb_log(metrics):
        """Send ``metrics`` to wandb in one call; do nothing if no run is active."""
        if wandb.run is not None:
            wandb.log(metrics)

    def _make_loader(self, dataset, shuffle):
        """Build a PyG ``DataLoader`` over ``dataset`` with the module's settings."""
        return tg.loader.DataLoader(list(dataset), batch_size=self.batch_size,
                                    num_workers=self.num_workers, pin_memory=False, shuffle=shuffle)

    def forward(self, batch, *args, **kwargs):
        """Return ``(y_ij, y_i_, y_j_)`` from the underlying ``HeteroGNN`` for ``batch``."""
        y_ij, y_i_, y_j_ = self.HeterogenousCoAttention(batch.x_dict, batch.edge_index_dict, batch)
        return y_ij, y_i_, y_j_

    def training_step(self, data, batch_idx):
        """Compute the joint loss on one QM9 batch and one DDI batch.

        Args:
            data: Mapping with keys ``'QM9'`` and ``'DDI'``, one batch each (as
                produced by the dict returned from ``train_dataloader``).
            batch_idx: Index of the batch (unused).

        Returns:
            ``{'loss': loss}`` where ``loss`` is the sum of both regression
            losses and the binary cross-entropy.
        """
        qm9_batch = data['QM9']
        ddi_batch = data['DDI']

        # Regression branch (QM9).
        _, y_i_, y_j_ = self(qm9_batch)
        y_i_true = qm9_batch['y_i'].y_norm
        y_j_true = qm9_batch['y_j'].y_norm
        mse1 = self.mse_loss(input=y_i_.flatten(), target=y_i_true)
        mse2 = self.mse_loss(input=y_j_.flatten(), target=y_j_true)
        mse = mse1 + mse2

        # Classification branch (DDI). flatten() rather than squeeze() so a
        # batch of one stays 1-D and matches the target's shape.
        y_ij, _, _ = self(ddi_batch)
        y_pred = y_ij.flatten()
        y_true = ddi_batch.binary_y.float()
        bce = self.bce_loss(input=y_pred, target=y_true)

        loss = mse + bce

        # One wandb.log call per optimisation step: every separate call would
        # advance wandb's step counter and scatter the metrics across steps.
        self._wandb_log({
            "train/mse1_loss": mse1.cpu().detach(),
            "train/mse2_loss": mse2.cpu().detach(),
            "train/bce_loss": bce.cpu().detach(),
            "train/loss": loss.cpu().detach(),
            'train/y_i_pred': y_i_.flatten().cpu().detach(),
            'train/y_i_true': y_i_true.cpu().detach(),
            'train/y_j_pred': y_j_.flatten().cpu().detach(),
            'train/y_j_true': y_j_true.cpu().detach(),
            "train/y_pred": wandb.Histogram(y_pred.cpu().detach()),
            "train/y_true": wandb.Histogram(y_true.cpu().detach()),
        })

        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx, loader_idx):
        """Compute the task loss for one validation batch.

        Args:
            val_batch: Batch from one of the validation loaders.
            batch_idx: Index of the batch (unused).
            loader_idx: ``0`` for the QM9 loader (regression loss), ``1`` for
                the DDI loader (binary cross-entropy), matching the order of
                the list returned by ``val_dataloader``.

        Returns:
            ``{'loss': loss}`` for the task selected by ``loader_idx``.

        Raises:
            ValueError: If ``loader_idx`` is neither 0 nor 1.
        """
        y_ij, y_i_, y_j_ = self(val_batch)

        if loader_idx == 0:
            mse1 = self.mse_loss(input=y_i_.flatten(), target=val_batch['y_i'].y_norm)
            mse2 = self.mse_loss(input=y_j_.flatten(), target=val_batch['y_j'].y_norm)
            loss = mse1 + mse2
            task_key = "val/mse_loss"
        elif loader_idx == 1:
            y_pred = y_ij.flatten()
            y_true = val_batch.binary_y.float()
            loss = self.bce_loss(input=y_pred, target=y_true)
            task_key = "val/bce_loss"
        else:
            # Previously this fell through to an unbound ``loss``.
            raise ValueError(f"Unexpected validation dataloader index: {loader_idx}")

        # "val/loss" is kept for existing dashboards; the task-specific key
        # separates the two losses, which are not on a comparable scale.
        logged = loss.cpu().detach()
        self._wandb_log({"val/loss": logged, task_key: logged})

        return {'loss': loss}

    def configure_optimizers(self):
        """Return AdamW plus a step scheduler that decays the LR 10x at epochs 25 and 35."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.28, 0.93), weight_decay=0.01)
        # Milestones must be a list of ints. The former string '25,35' was
        # counted character by character, so no epoch ever matched a milestone
        # and the learning rate silently never decayed.
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35], gamma=0.1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Return ``{'QM9': loader, 'DDI': loader}``; both loaders shuffle."""
        loaders = {"QM9": self._make_loader(self.qm9_dataset, shuffle=True),
                   'DDI': self._make_loader(self.ddi_dataset, shuffle=True)}
        return loaders

    def val_dataloader(self):
        """Return ``[qm9_loader, ddi_loader]``; the list index is ``loader_idx``.

        NOTE(review): these loaders wrap the same subsets used for training, so
        the validation numbers are not a held-out estimate.
        """
        # Validation order does not need to be random; a fixed order keeps the
        # runs comparable.
        loaders = [self._make_loader(self.qm9_dataset, shuffle=False),
                   self._make_loader(self.ddi_dataset, shuffle=False)]
        return loaders

In [9]:
# Debug aid: set before any CUDA work so CUDA errors are easier to localise.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Dataset root, relative to the working directory.
data_dir = os.path.join('GraphCoAttention', 'data')

trainer = pl.Trainer(
    gpus=[0],
    max_epochs=2000,
    check_val_every_n_epoch=500,
    accumulate_grad_batches=1,
)

# Module instance used by the inspection cells further down.
learner = Learner(
    data_dir,
    bs=20,
    lr=0.001,
    n_cycles=40,
    hidden_dim=225,
    n_head=5,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [51]:
# Reuse the Learner's own loader construction instead of re-typing it here, so
# this inspection cell cannot drift from what training actually iterates over.
loaders = learner.train_dataloader()
qm9_dataloader = loaders['QM9']
ddi_dataloader = loaders['DDI']

In [52]:
# Pull the first batch from the DDI loader for interactive inspection.
d = next(iter(loaders['DDI']))

In [53]:
# Show the batch summary: node stores and edge stores with their tensor shapes.
d

Batch(
  binary_y=[5],
  [1mx_i[0m={
    x=[122, 9],
    batch=[122],
    ptr=[6]
  },
  [1mx_j[0m={
    x=[100, 9],
    batch=[100],
    ptr=[6]
  },
  [1m(x_i, inner_edge_i, x_i)[0m={
    edge_index=[2, 260],
    edge_attr=[260, 3]
  },
  [1m(x_j, inner_edge_j, x_j)[0m={
    edge_index=[2, 214],
    edge_attr=[130, 3]
  },
  [1m(x_i, outer_edge_ij, x_j)[0m={
    edge_index=[2, 2449],
    edge_attr=[130, 3]
  },
  [1m(x_j, outer_edge_ji, x_i)[0m={ edge_index=[2, 2449] }
)

In [54]:
# Inspect the 'x_i' node store: features `x`, the `batch` vector later passed to pooling, and `ptr`.
d['x_i']

{'x': tensor([[5., 0., 4.,  ..., 2., 0., 0.],
        [7., 0., 2.,  ..., 1., 0., 0.],
        [5., 0., 3.,  ..., 1., 1., 1.],
        ...,
        [7., 0., 2.,  ..., 1., 0., 0.],
        [7., 0., 2.,  ..., 1., 0., 0.],
        [7., 0., 2.,  ..., 2., 0., 0.]]), 'batch': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4]), 'ptr': tensor([  0,  33,  55,  82, 109, 122])}

In [55]:
# Inspect the binary labels of this batch (the BCE target used in training_step).
d['binary_y']

tensor([1, 1, 1, 1, 1])

In [None]:
# Manual check of the QM9 regression branch, mirroring Learner.training_step.
# At cell scope there is no `self` or `data`: use the `learner` instance and a
# real batch from `loaders` (the model needs the per-node `batch` vectors that
# only a collated batch provides).
qm9_batch = next(iter(loaders['QM9']))
_, y_i_, y_j_ = learner(qm9_batch)
mse1 = learner.mse_loss(input=y_i_.flatten(), target=qm9_batch['y_i'].y_norm)
mse2 = learner.mse_loss(input=y_j_.flatten(), target=qm9_batch['y_j'].y_norm)
mse = mse1 + mse2

In [None]:


        # Manual check of the DDI classification branch, mirroring
        # Learner.training_step. At cell scope there is no `self` or `data`:
        # use the `learner` instance and a batch drawn from `loaders`.
        ddi_batch = next(iter(loaders['DDI']))
        y_ij, _, _ = learner(ddi_batch)
        y_pred = y_ij.squeeze()
        y_true = ddi_batch.binary_y.float()
        bce = learner.bce_loss(input=y_pred, target=y_true)

In [None]:

# Full training run. The statements below were indented without an enclosing
# block, which is an IndentationError; they now sit at cell level.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
data_dir = os.path.join('GraphCoAttention', 'data')

# Start the wandb run first: Learner writes to wandb.config / wandb.log.
wandb.init()
wandb_logger = WandbLogger(project='flux', log_model='all')

# Pass the logger to the Trainer; it was created but never used before.
trainer = pl.Trainer(gpus=[0], max_epochs=2000, check_val_every_n_epoch=500, accumulate_grad_batches=1,
                     logger=wandb_logger)
trainer.fit(Learner(data_dir, bs=5, lr=0.001, n_cycles=40, hidden_dim=225, n_head=5))