In [27]:
import dgl
from dgl.data import citation_graph as citegrh
import torch
import numpy as np
from typing import List

In [41]:
import importlib
import packages.transformer.data as transformer_data
importlib.reload(transformer_data)

from packages.transformer.data import construct_batch, TransformerGraphBundleInput

In [4]:
data = citegrh.load_cora()
graph = data[0]
adj = graph.adj(scipy_fmt='coo')
graph = dgl.graph((adj.row, adj.col)).to('cuda')

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [5]:
adj = (graph.adj(scipy_fmt='coo'))

In [10]:
features = torch.tensor(data.features, device='cuda')
labels = torch.tensor(data.labels, device='cuda')

  features = torch.tensor(data.features, device='cuda')


In [7]:
train_mask = torch.BoolTensor(data.train_mask)
sampler = dgl.dataloading.MultiLayerNeighborSampler([3, 3])
train_nids = (torch.arange(0, graph.number_of_nodes())[train_mask]).to('cuda')
dataloader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=64,
    shuffle=True,
    drop_last=False,
    num_workers=0)



In [8]:
loader_iter = iter(dataloader)

In [9]:
input_nodes_fb, output_nodes_fb, mfgs_fb = next(loader_iter)
input_nodes_sb, output_nodes_sb, mfgs_sb = next(loader_iter)

In [12]:
input_graph_bundle_fb = construct_batch(output_nodes_fb, input_nodes_fb, mfgs_fb, features, labels, 'cpu')
input_graph_bundle_sb = construct_batch(output_nodes_sb, input_nodes_sb, mfgs_sb, features, labels, 'cpu')

In [22]:
def pad_graph_bundle(graph_bundle: TransformerGraphBundleInput) -> None: # WARNING: mutates graph bundle object
    src_mask = graph_bundle.src_mask.squeeze(0) 
    size_subgraph = src_mask.shape[1]
    padded_src_mask = torch.zeros((src_mask.shape[0], 512, 512))
    padded_src_mask[:, : size_subgraph, : size_subgraph] = src_mask

    src_feats = graph_bundle.src_feats.squeeze(0)
    padded_src_feats = torch.zeros((512, src_feats.shape[-1]))
    padded_src_feats[: size_subgraph, :src_feats.shape[-1]] = src_feats
    graph_bundle.src_feats = padded_src_feats.unsqueeze(0)
    graph_bundle.src_mask = padded_src_mask.unsqueeze(0)


In [42]:
def stack_graph_bundles(graph_bundles: List[TransformerGraphBundleInput]) -> TransformerGraphBundleInput:
    src_masks = torch.cat([graph_bundle.src_mask for graph_bundle in graph_bundles])
    src_feats = torch.cat([graph_bundle.src_feats for graph_bundle in graph_bundles])
    trg_labels = torch.cat([graph_bundle.trg_labels for graph_bundle in graph_bundles])
    train_inds = torch.cat([graph_bundle.train_inds for graph_bundle in graph_bundles])
    return TransformerGraphBundleInput(src_feats, trg_labels, src_masks, train_inds, 'cpu')

In [23]:
pad_graph_bundle(input_graph_bundle_fb)

In [29]:
pad_graph_bundle(input_graph_bundle_sb)


In [43]:
stacked_graph_bundle = stack_graph_bundles([input_graph_bundle_fb, input_graph_bundle_sb])

In [45]:
stacked_graph_bundle.ntokens
print(stacked_graph_bundle.ntokens)
print(stacked_graph_bundle.train_inds.shape)
print(stacked_graph_bundle.src_feats.shape)
print(stacked_graph_bundle.src_mask.shape)

128
torch.Size([2, 64])
torch.Size([2, 512, 1433])
torch.Size([2, 2, 512, 512])
