In [1]:
import dgl
import numpy as np
import torch
import scipy as sp

In [29]:
import importlib
import packages.utils.sp_utils as sp_utils
importlib.reload(sp_utils)


<module 'packages.utils.sp_utils' from '/home/fsamir/gnn/packages/utils/sp_utils.py'>

In [14]:
# Local imports 
from packages.data_management.pkl_io import save_pkl, load_pkl_from_path
from packages.utils.sp_utils import select_submatrix, convert_scipy_sparse_to_torch
from packages.transformer.data import retrieve_features_for_minibatch, retrieve_labels_for_minibatch, TransformerGraphBundleInput

In [2]:
def load_reddit_adj(path="data/reddit_cpr", fname="reddit_adj_coo.npz"):
    """Read the Reddit adjacency matrix from a SciPy sparse ``.npz`` archive.

    Args:
        path: Directory that holds the archive.
        fname: Name of the archive inside ``path``.

    Returns:
        The object produced by ``scipy.sparse.load_npz`` for ``{path}/{fname}``.
    """
    npz_location = f'{path}/{fname}'
    return sp.sparse.load_npz(npz_location)

def load_reddit_feats(path="data/reddit_cpr", fname="reddit_feats.npy"):
    """Load the node-feature array from a NumPy ``.npy`` file.

    Args:
        path: Directory that holds the file.
        fname: Name of the file inside ``path``.

    Returns:
        The array read by ``np.load`` from ``{path}/{fname}``.
    """
    feats_location = f"{path}/{fname}"
    with open(feats_location, "rb") as handle:
        return np.load(handle)

def load_reddit_labels(path="data/reddit_cpr", fname="reddit_labels.npy"):
    """Load the node-label array from a NumPy ``.npy`` file.

    Args:
        path: Directory that holds the file.
        fname: Name of the file inside ``path``.

    Returns:
        The array read by ``np.load`` from ``{path}/{fname}``.
    """
    labels_location = f"{path}/{fname}"
    with open(labels_location, "rb") as handle:
        return np.load(handle)

def load_reddit_masks(path="data/reddit_cpr", fname="reddit_masks.npy"):
    """Load the masks array from a NumPy ``.npy`` file.

    Args:
        path: Directory that holds the file.
        fname: Name of the file inside ``path``.

    Returns:
        The array read by ``np.load`` from ``{path}/{fname}``.
    """
    masks_location = f"{path}/{fname}"
    with open(masks_location, "rb") as handle:
        return np.load(handle)

In [3]:
# Load the adjacency matrix from its default location and show its shape.
adj_sparse = load_reddit_adj()
print(adj_sparse.shape)

(232965, 232965)


In [4]:
# Load the node-feature array from its default location and show its shape.
feats = load_reddit_feats()
print(feats.shape)

(232965, 602)


In [5]:
# Load the masks array from its default location and show its shape.
masks = load_reddit_masks()
print(masks.shape)

(3, 232965)


In [8]:
# Load the label array from its default location and show its shape.
# NOTE: `labels` is also read as a notebook-level global inside construct_batch.
labels = load_reddit_labels()
print(labels.shape)

(232965,)


In [6]:
# One id per column of `masks`; the first mask row picks the ids used for training.
# Assumes `masks` holds booleans (an integer array would index by value) — TODO confirm.
all_ids = np.arange(masks.shape[1])
train_ids = all_ids[masks[0]]

In [7]:
# Build the DGL graph from the COO edge list. The node count is passed
# explicitly: without it DGL infers the count from the largest node id that
# appears in an edge, so trailing nodes with no edges would be dropped and the
# graph would no longer line up with the per-node arrays (features, labels, masks).
graph = dgl.graph((adj_sparse.row, adj_sparse.col), num_nodes=adj_sparse.shape[0])


In [8]:
# Neighbour sampler configured with the fanout list [5, 5], wrapped in a
# shuffling loader that walks the training ids 64 at a time in the main process.
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 5])
dataloader = dgl.dataloading.DataLoader(
    graph,
    train_ids,
    sampler,
    batch_size=64,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)

In [27]:
def construct_batch(target_nodes, subgraph_nodes, mfg, sparse_adj, features, device, node_labels=None):
    """Assemble one ``TransformerGraphBundleInput`` from a sampled minibatch.

    Args:
        target_nodes: Ids of the nodes to predict for. Used for its length and
            for the label lookup.
        subgraph_nodes: Ids of every node in the sampled subgraph. Only its
            length is used, to size the range of local indices.
        mfg: Indexable collection of DGL blocks; only ``mfg[0]`` is read.
        sparse_adj: Full-graph adjacency handed to ``select_submatrix``.
        features: Full-graph features handed to ``retrieve_features_for_minibatch``.
        device: Device on which the local index tensor and identity are created.
        node_labels: Full-graph labels handed to ``retrieve_labels_for_minibatch``.
            When omitted, the notebook-level ``labels`` variable is used, which
            is what this function read before the parameter existed.

    Returns:
        ``TransformerGraphBundleInput(features, labels, adjacencies, output_node_inds)``
        where every argument has had a leading axis of size 1 added, and
        ``adjacencies`` stacks the first- and second-layer matrices.
    """
    if node_labels is None:
        # Backward-compatible fallback to the notebook global.
        node_labels = labels

    local_inds = torch.arange(subgraph_nodes.shape[0], device=device)

    first_block = mfg[0]
    src_node_ids = first_block.srcdata[dgl.NID]
    num_dst_nodes = first_block.dstdata[dgl.NID].shape[0]

    # NOTE(review): both slices assume the block lists its dst nodes first among
    # its src nodes (and the target nodes first of all). DGL's minibatch guide
    # describes this ordering — confirm for the installed version.
    two_hop_local_inds = local_inds[num_dst_nodes:]
    output_node_local_inds = local_inds[: target_nodes.shape[0]]

    # TODO: confirm select_submatrix returns the dense adjacency among the
    # sampled nodes in local-index order.
    first_layer_adj = select_submatrix(sparse_adj, src_node_ids, local_inds, device)
    # Add self-connections.
    first_layer_adj = first_layer_adj + torch.eye(first_layer_adj.shape[0], device=device)

    # Second layer: same matrix with every column of a two-hop-only node zeroed.
    second_layer_adj = first_layer_adj.detach().clone()
    second_layer_adj[:, two_hop_local_inds] = 0
    # Zeroing whole columns also removed the two-hop nodes' own self-connections,
    # so restore just those. All other nodes still carry the self-connection
    # inherited from the first-layer matrix; adding a full identity a second time
    # (as this code used to) doubled their diagonal entries.
    second_layer_adj[two_hop_local_inds, two_hop_local_inds] = 1

    minibatch_adjacencies = torch.stack((first_layer_adj, second_layer_adj)).unsqueeze(0)
    all_minibatch_feats = retrieve_features_for_minibatch(src_node_ids, features).unsqueeze(0)
    minibatch_labels = retrieve_labels_for_minibatch(target_nodes, node_labels).unsqueeze(0)
    output_node_inds = output_node_local_inds.unsqueeze(0)

    return TransformerGraphBundleInput(
        all_minibatch_feats, minibatch_labels, minibatch_adjacencies, output_node_inds
    )

In [9]:
# Pull a single minibatch from the loader: the input-node ids, the output-node
# ids, and the list of message-flow graphs produced by the sampler.
dataloader_iter = iter(dataloader)
input_nodes, output_nodes, mfgs = next(dataloader_iter) # input nodes give us the requisite features; the mfgs give us the requisite attention mask

In [16]:
# Reload a previously saved minibatch structure from disk and display it.
# NOTE(review): "data/reddit_cpr/mfg" appears to be the file that the save_pkl
# call in a later cell writes, so on a fresh top-to-bottom run it must already
# exist. Presumably this is a pickle — only load files you produced yourself.
mfg = load_pkl_from_path("data/reddit_cpr/mfg")
print(mfg)

In [13]:
# Persist the sampled minibatch so it can be reloaded without re-sampling.
# np.save writes the .npy header (dtype and shape) that np.load relies on;
# writing the array straight to the file handle stored only the raw buffer,
# leaving ".npy" files that np.load cannot read back.
with open("data/reddit_cpr/subgraph_nodes.npy", "wb") as f:
    np.save(f, input_nodes.detach().cpu().numpy())

with open("data/reddit_cpr/target_nodes.npy", "wb") as f:
    np.save(f, output_nodes.detach().cpu().numpy())

save_pkl("mfg", mfgs, "data/reddit_cpr/")


# print(input_nodes)
# print(mfgs)

In [32]:
# Convert the SciPy adjacency for use on the CPU. The name `adj_sparse` is
# rebound to the converted object, so the guard keeps this cell safe to re-run:
# without it a second execution would pass the already-converted object back
# into the converter.
# NOTE(review): cells that read `adj_sparse.row` / `.col` expect the SciPy
# object and must run before this one.
if sp.sparse.issparse(adj_sparse):
    adj_sparse = convert_scipy_sparse_to_torch(adj_sparse, 'cpu')

In [33]:
# Build the model input for the sampled minibatch on CPU. Argument order is
# output (target) nodes first, then the full input-node set; the whole `mfgs`
# list is passed and construct_batch indexes into it.
minibatch = construct_batch(output_nodes, input_nodes, mfgs, adj_sparse, feats, 'cpu')

: 

: 