In [7]:
import pytorch_lightning as pl
from torch import optim
import wandb
import torch
import os

from torch.nn.functional import binary_cross_entropy

import torch_geometric as tg
import torchmetrics
from pytorch_lightning.loggers.wandb import WandbLogger

import torch
from torch import nn
from torch.nn import Parameter, Sequential, ReLU, GRU
from torch.nn import functional as F
import torch_geometric as tg

from GraphCoAttention.datasets.HeterogenousDDI import HeteroDrugDrugInteractionData, HeteroQM9


from torch_geometric.nn import GATConv, HeteroConv, Linear, GATv2Conv, NNConv, Set2Set
from torch_geometric.nn.glob import global_mean_pool, global_add_pool
from torch.nn import LeakyReLU

# from GraphCoAttention.data.MultipartiteData import BipartitePairData

In [8]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, outer_out_channels, inner_out_channels,
                 num_layers, batch_size, num_node_types, num_heads):
        super().__init__()

        self.batch_size = batch_size
        self.hidden_channels = hidden_channels
        self.heads = num_heads

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('x_i', 'inner_edge_i', 'x_i'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                ('x_j', 'inner_edge_j', 'x_j'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                # ('x_i', 'outer_edge_ij', 'x_j'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                # ('x_j', 'outer_edge_ji', 'x_i'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                # ('x_i', 'inner_edge_i', 'x_i'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
                # ('x_j', 'inner_edge_j', 'x_j'): GATv2Conv(-1, self.hidden_channels, heads=num_heads),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(self.hidden_channels, outer_out_channels)

        self.lin_i = Linear(self.hidden_channels, inner_out_channels)
        self.lin_j = Linear(self.hidden_channels, inner_out_channels)
        # self.hlin = tg.nn.HeteroLinear(hidden_channels, out_channels, num_node_types=num_node_types)

    def forward(self, x_dict, edge_index_dict, d):

        # x_dict, edge_index_dict = x_dict, edge_index_dict

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: torch.tanh(torch.sum(x.view(-1, self.heads, self.hidden_channels), dim=1))
                      for key, x in x_dict.items()}

            # [print(key, x.shape) for key, x in x_dict.items()]
            # [print(key, x.view(-1, self.heads, self.hidden_channels).shape) for key, x in x_dict.items()]
            # [print(key, torch.mean(x.view(-1, self.heads, self.hidden_channels), dim=1).shape) for key, x in x_dict.items()]

        # p_i = F.leaky_relu(global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1))
        # p_j = F.leaky_relu(global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1))

        # p_i = global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1).sigmoid()
        # p_j = global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1).sigmoid()

        p_i = global_add_pool(x_dict['x_i'], batch=d['x_i'].batch, size=self.batch_size).unsqueeze(1).tanh()
        p_j = global_add_pool(x_dict['x_j'], batch=d['x_j'].batch, size=self.batch_size).unsqueeze(1).tanh()

        y_i_ = self.lin_i(p_i)
        y_j_ = self.lin_j(p_j)

        x = torch.cat([p_i, p_j], dim=1)
        x = torch.sum(x, dim=1)

        logits = self.lin(x).sigmoid()
        return logits, y_i_, y_j_

In [9]:
class Learner(pl.LightningModule):
    def __init__(self, root_dir, lr=0.001):
        super().__init__()
        self.root_dir = root_dir

        # self.dataset = HeteroDrugDrugInteractionData(root=self.root_dir)
        self.dataset = HeteroQM9(root=self.root_dir)
        self.dataset = self.dataset.shuffle()
        
        self.dataset = self.dataset[:10]

        self.num_workers = 32
        self.lr = lr
        # self.num_node_types = len(self.dataset[0].x_dict)
        self.n_cycles = 16
        self.dropout = 0.1
        self.batch_size = 2
        self.lr = 0.001
        self.hidden_dim = 25
        
        self.HeterogenousCoAttention = HeteroGNN(hidden_channels=self.hidden_dim, outer_out_channels=1,
                                         inner_out_channels=15, num_layers=self.n_cycles,
                                         batch_size=self.batch_size, num_node_types=self.num_node_types,
                                         num_heads=self.n_head)

        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        self.mse_loss = torch.nn.MSELoss()
        
        
    def forward(self, batch, *args, **kwargs):

        y_ij, y_i_, y_j_ = self.Net(batch.x_dict, batch.edge_index_dict, batch)

        # logits = self.CoAttention(data)
        # logits = torch.sigmoid(torch.mean(logits))
        return y_ij, y_i_, y_j_

    def training_step(self, data, batch_idx):
        y_ij, y_i_, y_j_ = self(data)
        y_pred = y_ij.squeeze()
        y_true = data.binary_y.float()

        mse1 = self.mse_loss(input=y_i_.flatten(), target=data['y_i'].y)
        mse2 = self.mse_loss(input=y_j_.flatten(), target=data['y_j'].y)
        mse = mse1 + mse2
        bce = self.bce_loss(input=y_pred, target=y_true)
        loss = mse

        # self.log('train_loss', bce)
        wandb.log({"train/loss": loss})
        wandb.log({'train/y_pred': y_pred})
        wandb.log({'train/y_true': y_true})
        return {'loss': loss}  # , 'train_accuracy': acc, 'train_f1': f1}

    def validation_step(self, val_batch, batch_idx):

        # print(val_batch.binary_y.float())

        y_ij, y_i_, y_j_ = self(val_batch)
        y_pred = y_ij.squeeze()
        y_true = val_batch.binary_y.float()

        mse1 = self.mse_loss(input=y_i_.flatten(), target=val_batch['y_i'].y)
        mse2 = self.mse_loss(input=y_j_.flatten(), target=val_batch['y_j'].y)
        mse = mse1 + mse2
        bce = self.bce_loss(input=y_pred, target=y_true)
        loss = bce + mse
        # self.log('validation_loss', bce_loss)
        # self.log('Predicted', y_pred)
        # self.log('Actual', y_true)
        wandb.log({"val/loss": loss})
        return {'loss': loss}

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.28, 0.93), weight_decay=0.01)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, '25,35', gamma=0.1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return tg.loader.DataLoader(list(self.dataset),
                                    num_workers=self.num_workers, pin_memory=False, shuffle=True)

    def val_dataloader(self):
        return tg.loader.DataLoader(list(self.dataset), 
                                    num_workers=self.num_workers, pin_memory=False, shuffle=True)


if __name__ == '__main__':
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    data_dir = os.path.join('GraphCoAttention', 'data')
    wandb.init()
    wandb_logger = WandbLogger(project='flux', log_model='all')
    trainer = pl.Trainer(gpus=[0], max_epochs=2000, check_val_every_n_epoch=500, accumulate_grad_batches=1)
    trainer.fit(Learner(data_dir))

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


AttributeError: 'Learner' object has no attribute 'num_node_types'

Exception in thread ChkStopThr:
Traceback (most recent call last):
  File "/home/ray/anaconda3/envs/st/lib/python3.8/threading.py", line 932, in _bootstrap_inner
Exception in thread NetStatThr:
Traceback (most recent call last):
  File "/home/ray/anaconda3/envs/st/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
    self.run()
  File "/home/ray/anaconda3/envs/st/lib/python3.8/threading.py", line 870, in run
  File "/home/ray/anaconda3/envs/st/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ray/anaconda3/envs/st/lib/python3.8/site-packages/wandb/sdk/wandb_run.py", line 152, in check_network_status
    status_response = self._interface.communicate_network_status()
  File "/home/ray/anaconda3/envs/st/lib/python3.8/site-packages/wandb/sdk/interface/interface.py", line 125, in communicate_network_status
    resp = self._communicate_network_status(status)
  File "/home/ray/anaconda3/envs/st/lib/python3.8/site-