In [1]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC, LRGBDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index
import torch.nn as nn

from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj
from scipy.sparse.csgraph import floyd_warshall
import scipy.sparse.csgraph as csg

from torch.utils.data import Dataset

import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm


In [103]:
SPLIT = "train"
N_SAMPLES_MAX = 2500
K = 25
SAVE_DIR = f"../data/precomputed_dist_masks/pept/split={SPLIT}_n={N_SAMPLES_MAX}_k={K}.pt"

In [104]:
class CustomLRGBDataset(Dataset):
    def __init__(
        self,
        path,
        name,
        precomputed_masks_path_train=None,
        precomputed_masks_path_val=None,
        split="train",
        transform=None,
        n_samples_max=N_SAMPLES_MAX,
    ):
        super().__init__()
        self.tg_dataset = LRGBDataset(path, name=name, split=split, transform=transform)
        precomputed_masks_path = precomputed_masks_path_train if split == "train" else precomputed_masks_path_val
        if precomputed_masks_path is not None:
            with open(precomputed_masks_path, "rb") as f:
                self.precomputed_masks = torch.load(f) # (M, 4) == (graph_idx, node1, node2, dist)
        else: 
            self.precomputed_masks = None
        
        self.n_samples_max = n_samples_max

    def __getitem__(self, idx):
        if idx >= self.n_samples_max:
            raise IndexError
        g = self.tg_dataset[idx]
        if self.precomputed_masks is not None:
            g.dist_mask = self.precomputed_masks[self.precomputed_masks[:,0] == idx]
        return g

    def __len__(self):
        return self.n_samples_max


In [105]:
def floyd_warshall(adj_matrix, K):
    shortest_paths = csg.floyd_warshall(adj_matrix, directed=False)
    k_matrix = np.transpose((np.arange(K+1)[::-1] == shortest_paths[...,None]).astype(int), (2, 0, 1))

    return k_matrix 

train_dataset = CustomLRGBDataset(path = "/home/ricardob/data/SeqForGraphs/datasets/lrgb", name="Peptides-struct", split=SPLIT)

max_num_modes = 0
for i, data in enumerate(train_dataset):
    max_num_modes = max(max_num_modes, len(data.x))

# Precompute and save dist_masks
dist_masks = []
for i, data in enumerate(train_dataset):
    adj_matrix = to_dense_adj(data.edge_index, max_num_nodes=max_num_modes).squeeze(0).numpy()
    k_matrix = floyd_warshall(adj_matrix, K)
    dist_masks.append( torch.tensor(k_matrix, dtype=torch.float32) )

# Save each dist_mask to disk
dist_mask = torch.stack(dist_masks, dim = 0) 

In [106]:
dist_mask.shape #(N, K+1, num_nodes, num_nodes) 


torch.Size([2500, 26, 434, 434])

In [107]:
# Compressed represenation
compressed_dist_mask = []
for k in range(K+1):
    k_masks = torch.nonzero(dist_mask[:,k,...] == 1.0) #(n, node_1, node_2)
    k_masks = torch.cat([k_masks, torch.tensor([K-k]).repeat( k_masks.shape[0],1)], dim=1) 
    compressed_dist_mask.append( k_masks  )
compressed_dist_mask = torch.cat(compressed_dist_mask, dim=0)
compressed_dist_mask.shape # (M,4) == (graph_id, node_1, node_2, k)

torch.Size([39934726, 4])

In [108]:
# create dir from fname
os.makedirs(osp.dirname(SAVE_DIR), exist_ok=True)
torch.save(compressed_dist_mask, SAVE_DIR)