In [3]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC, LRGBDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index
import torch.nn as nn

from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj
from scipy.sparse.csgraph import floyd_warshall
import scipy.sparse.csgraph as csg

from torch.utils.data import Dataset

import numpy as np
import os
from collections import Counter

In [17]:
# ZINC split to precompute masks for; forwarded to the dataset wrapper below.
SPLIT = "train"
# Default cap on how many graphs CustomZINCDataset exposes.
N_SAMPLES_MAX = 12000
# Largest hop distance that gets its own mask; levels 0..K are built (K + 1 in total).
K = 8
# NOTE: despite the name, this is the output *file* path passed to torch.save;
# only its parent directory is created. SPLIT, N_SAMPLES_MAX and K are baked
# into the file name.
SAVE_DIR = f"./data/precomputed_dist_masks/zinc/split={SPLIT}_n={N_SAMPLES_MAX}_k={K}.pt"

In [18]:
class CustomZINCDataset(Dataset):
    """Capped view over PyG's ``ZINC`` split, optionally carrying precomputed masks.

    Args:
        path: Root directory handed to ``ZINC``.
        precomputed_masks_path_train: File read with ``torch.load`` when
            ``split == "train"``. ``None`` disables masks.
        precomputed_masks_path_val: File read with ``torch.load`` for every
            split other than ``"train"``. ``None`` disables masks.
        split: Split name forwarded to ``ZINC``.
        transform: Transform forwarded to ``ZINC``.
        n_samples_max: Upper bound on the number of graphs exposed. ``None``
            exposes the whole split.

    The mask tensor must be 2-D with the graph index in column 0; that is the
    only column this class interprets, the others are passed through as-is.
    NOTE(review): the mask layout was previously described here as
    ``(graph_idx, node1, node2, dist)``, but the file this notebook saves comes
    from ``torch.nonzero`` on an ``(N, K+1, n, n)`` tensor, i.e. its columns are
    ``(graph_idx, k, node1, node2)`` - confirm which layout consumers expect.
    """

    def __init__(
        self,
        path,
        precomputed_masks_path_train=None,
        precomputed_masks_path_val=None,
        split="train",
        transform=None,
        n_samples_max=N_SAMPLES_MAX,
    ):
        super().__init__()
        self.tg_dataset = ZINC(path, split=split, transform=transform)

        # Every split other than "train" falls back to the validation masks.
        if split == "train":
            precomputed_masks_path = precomputed_masks_path_train
        else:
            precomputed_masks_path = precomputed_masks_path_val

        if precomputed_masks_path is not None:
            # NOTE: torch.load unpickles its input - only point this at trusted files.
            with open(precomputed_masks_path, "rb") as f:
                self.precomputed_masks = torch.load(f)  # (M, 4), column 0 == graph_idx
        else:
            self.precomputed_masks = None

        self.n_samples_max = n_samples_max

    def __getitem__(self, idx):
        """Return graph ``idx``; attach its mask rows as ``dist_mask`` if masks are loaded.

        Negative indices count from the end of the capped view.

        Raises:
            IndexError: if ``idx`` falls outside ``[-len(self), len(self))``.
        """
        n_graphs = len(self)
        if idx < 0:
            # Resolve relative to the capped length so that the graph returned
            # and the mask rows selected refer to the same sample.
            idx = idx + n_graphs
        if not 0 <= idx < n_graphs:
            raise IndexError(f"index {idx} out of range for dataset of size {n_graphs}")

        g = self.tg_dataset[idx]
        if self.precomputed_masks is not None:
            # Boolean scan over all M rows; rows keep their stored order.
            g.dist_mask = self.precomputed_masks[self.precomputed_masks[:, 0] == idx]
        return g

    def __len__(self):
        """Number of graphs exposed: the cap, but never more than the split holds."""
        n_available = len(self.tg_dataset)
        if self.n_samples_max is None:
            return n_available
        # Reporting the raw cap would let samplers request indices that do not exist.
        return max(0, min(self.n_samples_max, n_available))


In [19]:
def floyd_warshall(adj_matrix, K):
    """Build one-hot hop-distance masks for distances ``0..K``.

    Args:
        adj_matrix: Dense square adjacency matrix, treated as undirected.
        K: Largest distance that receives its own mask.

    Returns:
        Integer array of shape ``(K + 1, n, n)``; entry ``[k, i, j]`` is 1 when
        the shortest-path length from ``i`` to ``j`` equals ``k`` and 0
        otherwise. Pairs whose distance matches no level in ``0..K`` are 0 in
        every mask.
    """
    hop_dist = csg.floyd_warshall(adj_matrix, directed=False)
    # Put the level axis first so broadcasting yields (K + 1, n, n) directly.
    levels = np.arange(K + 1).reshape(-1, 1, 1)
    return (hop_dist[None, ...] == levels).astype(int)

# Build the capped ZINC split without precomputed masks (they are produced further down).
# NOTE(review): the dataset root is a hard-coded absolute path on one machine.
train_dataset = CustomZINCDataset(path="/usr0/home/manhbaon/data/SeqForGraphs/datasets/zinc", split=SPLIT)

# One pass over the split to collect basic statistics.
# `max_num_modes` (largest node count) and the loop variable `data` keep their
# names because later cells read them.
max_num_modes = 0
feature_values = set()
edge_values = set()
# counter of targets
targets = Counter()
for data in train_dataset:
    max_num_modes = max(max_num_modes, len(data.x))
    # .tolist() yields plain Python numbers, so the printed sets stay readable
    # regardless of how the installed NumPy renders its scalar types.
    feature_values.update(data.x.flatten().tolist())
    edge_values.update(data.edge_attr.flatten().tolist())
    for class_labels in data.y:
        # Indices at which the label entry equals exactly 1.
        # NOTE(review): the recorded run printed an empty Counter, i.e. no target
        # entry matched - this tally may be a leftover from a multi-label dataset.
        class_indices = torch.where(class_labels == 1)[0].tolist()
        targets.update(class_indices)

print(feature_values)
print(edge_values)
# print relative target counts
print(targets)
# Hoisted: the total is the same for every key.
total_target_count = sum(targets.values())
print({k: v / total_target_count for k, v in targets.items()})

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20}
{1, 2, 3}
Counter()
{}


In [20]:
# Node-feature shape of the graph left in `data` by the preceding loop (its last sample).
data.x.shape

torch.Size([20, 1])

In [21]:
# Precompute the dense hop-distance masks, one (K + 1, n, n) tensor per graph.
# Every adjacency matrix is padded to n = max_num_modes so the per-graph
# tensors share a shape and can be stacked.
dist_masks = [
    torch.tensor(
        floyd_warshall(
            to_dense_adj(graph.edge_index, max_num_nodes=max_num_modes).squeeze(0).numpy(),
            K,
        ),
        dtype=torch.float32,
    )
    for graph in train_dataset
]

# Combine into a single tensor with the graph index as leading axis
# (nothing is written to disk in this cell).
dist_mask = torch.stack(dist_masks, dim=0)

In [17]:
# Layout: (N graphs, K + 1 hop levels, num_nodes, num_nodes), num_nodes being the padded size.
# NOTE(review): the recorded output shows N = 2000 while N_SAMPLES_MAX is 12000 -
# it looks like it stems from an earlier run; re-run the notebook to refresh it.
dist_mask.shape

torch.Size([2000, 9, 37, 37])

In [22]:
# Compressed representation: keep only the coordinates of the 1-entries.

# Variant A - one pass per hop level, with an explicit level column appended.
# NOTE(review): the appended value is K - k, not k itself, so the rows are
# (graph_id, node_1, node_2, K - k). Confirm whether the inversion is intended.
per_level_rows = []
for k in range(K + 1):
    hits = torch.nonzero(dist_mask[:, k, ...] == 1.0)  # (graph_id, node_1, node_2)
    level_column = torch.tensor([K - k]).repeat(hits.shape[0], 1)
    per_level_rows.append(torch.cat([hits, level_column], dim=1))
compressed_dist_mask2 = torch.cat(per_level_rows, dim=0)
print(compressed_dist_mask2.shape)  # (M, 4)

# Variant B - a single nonzero over the full tensor. Columns follow the axis
# order of dist_mask, i.e. (graph_id, k, node_1, node_2).
compressed_dist_mask = torch.nonzero(dist_mask == 1.0)
compressed_dist_mask.shape

torch.Size([5678206, 4])


torch.Size([5678206, 4])

In [23]:
# Row-level consistency check between the two compressed encodings.
# compressed_dist_mask rows are (graph, k, node_1, node_2), while
# compressed_dist_mask2 rows are (graph, node_1, node_2, K - k). Bring the
# former into the latter's layout and compare the sorted sets of rows.
# A torch.unique over the flattened tensors would only show that the same
# scalar values occur somewhere, not that the rows agree.
remapped_rows = torch.cat(
    [compressed_dist_mask[:, [0, 2, 3]], K - compressed_dist_mask[:, 1:2]], dim=1
)
rows_direct = torch.unique(remapped_rows, dim=0)
rows_looped = torch.unique(compressed_dist_mask2, dim=0)
# Guard on shape first so differing row counts yield False instead of a broadcast error.
torch.tensor(rows_direct.shape == rows_looped.shape and bool((rows_direct == rows_looped).all()))

tensor(True)

In [20]:
# idxs = (compressed_dist_mask[:,:2] == torch.Tensor([0,0])).all(dim=1).nonzero()
# compressed_dist_mask[idxs]

In [21]:
# Spot check on graph 3: the number of rows at hop level k = 0 must equal the padded
# node count. Column 0 of compressed_dist_mask is the graph index, column 1 the level.
# presumably every diagonal entry, padded nodes included, is at distance 0 - TODO confirm
graph3_level0 = (compressed_dist_mask[:, 0] == 3) & (compressed_dist_mask[:, 1] == 0)
len(graph3_level0.nonzero()) == max_num_modes

True

In [24]:
# SAVE_DIR holds the full output file path; make sure its parent directory exists.
output_dir = osp.dirname(SAVE_DIR)
os.makedirs(output_dir, exist_ok=True)
# Persist the single-pass encoding, whose rows are (graph_id, k, node_1, node_2).
torch.save(compressed_dist_mask, SAVE_DIR)