In [32]:
# !pip install torch_geometric==2.3.1
# !pip install mamba-ssm

# import torch

# !pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install git+https://github.com/pyg-team/pytorch_geometric.git

In [4]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index
import torch.nn as nn

In [33]:
class MLP(nn.Module):
    """Residual feed-forward block: LayerNorm -> Linear -> GELU -> Linear, plus a skip connection.

    Both linear layers map ``dim_h -> dim_h``, so the output has the same shape
    as the input and can be added back to it.

    Args:
        dim_h: Size of the last (feature) dimension of the input.
        drop_rate: Dropout probability applied after the activation and after
            the second linear layer.
    """

    def __init__(self, dim_h, drop_rate=0.):
        super().__init__()
        self.dim_h = dim_h
        self.drop_rate = drop_rate
        self.layer_norm = nn.LayerNorm(dim_h)
        self.dense1 = nn.Linear(dim_h, dim_h)
        self.dropout1 = nn.Dropout(drop_rate)
        self.dense2 = nn.Linear(dim_h, dim_h)
        self.dropout2 = nn.Dropout(drop_rate)

    def forward(self, inputs, training: Optional[bool] = None):
        """Apply the block to ``inputs`` and return a tensor of the same shape.

        Args:
            inputs: Tensor whose last dimension is ``dim_h``.
            training: Whether to route activations through the dropout layers.
                ``None`` (the default) follows the module's own mode, so
                ``model.train()`` / ``model.eval()`` control dropout as usual.
                An explicit ``True``/``False`` behaves as before; note that
                ``nn.Dropout`` itself is still a no-op while the module is in
                eval mode, even if ``True`` is passed.

        Returns:
            ``inputs`` plus the transformed features (residual connection).
        """
        # Previously the flag defaulted to False, which silently disabled
        # dropout for every caller that did not thread the flag through.
        use_dropout = self.training if training is None else training

        x = self.layer_norm(inputs)
        x = F.gelu(self.dense1(x))
        if use_dropout:
            x = self.dropout1(x)
        x = self.dense2(x)
        if use_dropout:
            x = self.dropout2(x)
        return x + inputs

In [34]:
class GREDMamba(torch.nn.Module):
    """Aggregate node features per hop distance, refine them with an MLP, then run Mamba.

    Args:
        dim_h: Feature width of the node embeddings (also Mamba's ``d_model``).
        d_state: Forwarded to ``Mamba(d_state=...)``.
        d_conv: Forwarded to ``Mamba(d_conv=...)``.
        expand: Forwarded to ``Mamba(expand=...)``.
        drop_rate: Dropout probability forwarded to the internal ``MLP``.

    The defaults reproduce the previously hard-coded configuration.
    """

    def __init__(self, dim_h, d_state=16, d_conv=4, expand=1, drop_rate=0.):
        super().__init__()

        self.mlp = MLP(dim_h, drop_rate=drop_rate)
        # The attribute keeps its historical name `self_attn` (it holds a Mamba
        # block, not attention) so parameter names in saved state dicts stay stable.
        self.self_attn = Mamba(
            d_model=dim_h,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )

    def forward(self, inputs, dist_masks):
        """Run the layer.

        Args:
            inputs: Node features, shape (batch_size, num_nodes, dim_h).
            dist_masks: Per-hop masks, shape (batch_size, K+1, num_nodes, num_nodes).

        Returns:
            Tensor of shape ((K+1) * batch_size, num_nodes, dim_h).
        """
        # (batch, K+1, N, N) -> (K+1, batch, N, N); the matmul then broadcasts
        # `inputs` over the leading hop axis, summing features of masked nodes.
        x = torch.swapaxes(dist_masks, 0, 1) @ inputs
        # Shape of x: (K+1, batch_size, num_nodes, dim_h)

        # Forward this module's mode so the MLP's dropout follows train()/eval().
        x = self.mlp(x, training=self.training)

        num_hops, batch_size, num_nodes, dim_h = x.shape
        x = x.reshape(num_hops * batch_size, num_nodes, dim_h)
        # Shape of x: ((K+1) * batch_size, num_nodes, dim_h)

        # NOTE(review): with this layout Mamba presumably treats dim 1
        # (num_nodes) as the sequence axis, i.e. it scans across nodes within
        # each hop. If the intent is a GRED-style recurrence over the K+1 hop
        # axis per node, the tensor needs to be permuted first -- TODO confirm.
        x = self.self_attn(x)
        return x

In [35]:
import numpy as np
import networkx as nx

def create_random_graph(batch_size, num_nodes, dim_h, K, p=0.5, seed=None):
    """Create random node features and hop-distance masks for one Erdos-Renyi graph.

    A single graph topology is sampled and shared by every element of the
    batch; only the node features differ between batch elements.

    Args:
        batch_size: Number of batch elements to produce.
        num_nodes: Number of nodes in the graph.
        dim_h: Feature width of each node.
        K: Largest shortest-path distance that gets its own mask.
        p: Edge probability passed to ``nx.erdos_renyi_graph``.
        seed: Optional integer seed. When given, both the graph and the
            features are reproducible; when ``None`` the global RNGs are used.

    Returns:
        Tuple ``(inputs, dist_masks)`` of float64 NumPy arrays:
        ``inputs`` has shape (batch_size, num_nodes, dim_h) with values in [0, 1);
        ``dist_masks`` has shape (batch_size, K+1, num_nodes, num_nodes), where
        ``dist_masks[b, d, i, j]`` is 1 if the shortest path from i to j has
        length exactly d, and 0 otherwise (including unreachable pairs).
    """
    G = nx.erdos_renyi_graph(num_nodes, p=p, seed=seed)

    # Fall back to NumPy's global RNG when unseeded, as before.
    rng = np.random if seed is None else np.random.RandomState(seed)
    inputs = rng.rand(batch_size, num_nodes, dim_h)

    # Start from zeros: pairs that are unreachable or farther than K stay 0.
    hop_masks = np.zeros((K + 1, num_nodes, num_nodes))

    # Replace with floyd_warshall as per GRED implementation
    for i, dists in nx.all_pairs_shortest_path_length(G):
        for j, d in dists.items():
            if d <= K:
                hop_masks[d, i, j] = 1

    # Previously only batch element 0 was filled, leaving the rest all-zero;
    # replicate the shared masks across the whole batch instead.
    dist_masks = np.tile(hop_masks, (batch_size, 1, 1, 1))

    return inputs, dist_masks

In [36]:
# Smoke test: push one random 10-node graph (dim_h=16, K=3) through GREDMamba.
inputs, dist_masks = create_random_graph(batch_size=1, num_nodes=10, dim_h=16, K=3)

# Build float32 copies of both NumPy arrays directly on the GPU.
inputs_tensor = torch.tensor(inputs, dtype=torch.float32, device='cuda')
dist_masks_tensor = torch.tensor(dist_masks, dtype=torch.float32, device='cuda')

# The feature width of the inputs determines the model's hidden size.
model = GREDMamba(dim_h=inputs_tensor.size(-1))
model = model.to('cuda')

y = model(inputs_tensor, dist_masks_tensor)
