In [1]:
# !pip install torch_geometric==2.3.1
!pip install mamba-ssm

import torch

!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

Found existing installation: torch_scatter 2.1.2+pt22cu121
Uninstalling torch_scatter-2.1.2+pt22cu121:
  Successfully uninstalled torch_scatter-2.1.2+pt22cu121
Found existing installation: torch_sparse 0.6.18+pt22cu121
Uninstalling torch_sparse-0.6.18+pt22cu121:
  Successfully uninstalled torch_sparse-0.6.18+pt22cu121
Found existing installation: torch_geometric 2.6.0
Uninstalling torch_geometric-2.6.0:
  Successfully uninstalled torch_geometric-2.6.0
Found existing installation: torch_cluster 1.6.3+pt22cu121
Uninstalling torch_cluster-1.6.3+pt22cu121:
  Successfully uninstalled torch_cluster-1.6.3+pt22cu121
Looking in links: https://data.pyg.org/whl/torch-2.2.1+cu121.html
Collecting torch-scatter
  Using cached https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt22cu121
Looking in links: https://data.pyg.org/whl/torch-2.2.1+cu121.htm

In [2]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index
import torch.nn as nn

from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj
from scipy.sparse.csgraph import floyd_warshall
import scipy.sparse.csgraph as csg

import numpy as np

In [3]:
class MLP(nn.Module):
    """Pre-norm residual feed-forward block: ``inputs + FFN(LayerNorm(inputs))``.

    FFN = Linear -> GELU -> Dropout -> Linear -> Dropout, every layer of width ``dim_h``.

    Args:
        dim_h: size of the last (feature) dimension of the input.
        drop_rate: dropout probability applied after each linear layer.
    """

    def __init__(self, dim_h, drop_rate=0.):
        super(MLP, self).__init__()
        self.dim_h = dim_h
        self.drop_rate = drop_rate
        self.layer_norm = nn.LayerNorm(dim_h)
        self.dense1 = nn.Linear(dim_h, dim_h)
        self.dropout1 = nn.Dropout(drop_rate)
        self.dense2 = nn.Linear(dim_h, dim_h)
        self.dropout2 = nn.Dropout(drop_rate)

    def forward(self, inputs, training=None):
        """Apply the residual FFN over the last dimension of ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``dim_h``; leading dimensions
                are arbitrary (LayerNorm and Linear act on the last axis only).
            training: ``None`` (default) lets dropout follow the module's own
                train/eval mode, as set by ``model.train()`` / ``model.eval()``.
                ``False`` disables dropout unconditionally. ``True`` enables it,
                but ``nn.Dropout`` is still an identity while the module is in
                eval mode.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Previously the default was False, so dropout never ran unless a caller
        # remembered to pass training=True; nn.Dropout already honours self.training.
        use_dropout = training is None or training

        x = self.layer_norm(inputs)
        x = F.gelu(self.dense1(x))
        if use_dropout:
            x = self.dropout1(x)
        x = self.dense2(x)
        if use_dropout:
            x = self.dropout2(x)
        return x + inputs

In [6]:
class GREDMamba(torch.nn.Module):
    """Distance-bucketed aggregation layer (work in progress toward a GRED + Mamba layer).

    For every hop bucket in ``dist_masks`` it sums the features of the nodes in
    that bucket and passes the result through a residual ``MLP``. A Mamba block
    is constructed as ``self.self_attn`` but is not yet applied in ``forward``
    (see the TODO there), so its parameters are currently unused.

    Args:
        dim_h: node feature width; used as the MLP width and as Mamba's ``d_model``.
    """

    def __init__(self, dim_h):
        super().__init__()
        # Keep registration order (mlp, then self_attn) so parameter
        # initialisation draws from the RNG in the same order.
        self.mlp = MLP(dim_h)
        self.self_attn = Mamba(d_model=dim_h, d_state=16, d_conv=4, expand=1)

    def forward(self, inputs, dist_masks):
        """Aggregate node features per hop bucket and refine them with the MLP.

        Args:
            inputs: dense node features, shape (batch_size, num_nodes, dim_h).
            dist_masks: 0/1 masks, shape (batch_size, K+1, num_nodes, num_nodes);
                ``dist_masks[b, k, n, m]`` marks node m as a member of bucket k
                for node n.

        Returns:
            Tensor of shape (K+1, batch_size, num_nodes, dim_h): one aggregated
            embedding per hop bucket, node and graph.
        """
        # Move the hop axis first so the batched matmul broadcasts over it:
        # (K+1, B, N, N) @ (B, N, D) -> (K+1, B, N, D).
        hop_major_masks = dist_masks.transpose(0, 1)
        per_hop = torch.matmul(hop_major_masks, inputs)
        # TODO: run self.self_attn over each node's K+1 hop embeddings, e.g.
        # permute to (B, N, K+1, D) and flatten to (B*N, K+1, D) first.
        return self.mlp(per_hop)

In [None]:
# path, subset = '/scratch/ssd004/scratch/tsepaole/ZINC_full/', False
path, subset = '', True

K = 4  # largest shortest-path distance (in hops) that gets its own mask
# NOTE(review): Mamba's fused kernels need CUDA once `self_attn` is used in
# GREDMamba.forward; the CPU fallback only lets this shape smoke test run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def shortest_path_masks(adj_matrix, max_hops):
    """Build one-hot hop masks from a batch of dense adjacency matrices.

    Args:
        adj_matrix: array of shape (batch_size, num_nodes, num_nodes); every
            nonzero entry is treated as an undirected edge of length 1.
        max_hops: largest hop distance to encode (``K``).

    Returns:
        float32 array of shape (batch_size, max_hops+1, num_nodes, num_nodes)
        with ``out[b, j, n, m] == 1`` iff the shortest-path distance between
        nodes n and m of graph b equals ``max_hops - j``. Index 0 is the
        farthest-hop mask and the last index is the 0-hop (self) mask. Pairs
        farther than ``max_hops`` apart, or unreachable, are 0 in every mask.
        Zero-padded rows/columns are isolated nodes, so they appear only on the
        diagonal of the 0-hop mask.
    """
    hops = np.arange(max_hops, -1, -1)  # max_hops, ..., 1, 0
    masks = []
    for adj in np.asarray(adj_matrix):
        # unweighted=True counts hops even where to_dense_adj summed duplicate
        # edges into weights > 1.
        dist = csg.floyd_warshall(adj, directed=False, unweighted=True)
        masks.append(dist[None, :, :] == hops[:, None, None])
    return np.asarray(masks, dtype=np.float32)


train_dataset = ZINC(path, subset=subset, split='train')
val_dataset = ZINC(path, subset=subset, split='val')
test_dataset = ZINC(path, subset=subset, split='test')

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

model = GREDMamba(1).to(device)

for data in train_loader:
    # Shape of inputs: (batch_size, num_nodes, dim_h); mask flags real (non-padding) nodes.
    # NOTE(review): raw node features are fed as floats with dim_h=1 — presumably a
    # placeholder for a learned embedding; confirm.
    inputs, mask = to_dense_batch(data.x, data.batch)
    inputs = inputs.float().to(device)
    mask = mask.float().to(device)

    adj_matrix = to_dense_adj(data.edge_index, batch=data.batch, max_num_nodes=inputs.shape[1])

    # Shape of dist_mask: (batch_size, K+1, num_nodes, num_nodes)
    dist_mask = torch.from_numpy(shortest_path_masks(adj_matrix.numpy(), K)).to(device)

    out = model(inputs, dist_mask)