In [21]:
import dgl
from dgl.data import citation_graph as citegrh
import torch
import numpy as np
from typing import List

In [41]:
import importlib
import packages.transformer.data as transformer_data
import packages.transformer.encoder_decoder as enc_dec
import packages.transformer.attention as attention
import packages.transformer.utils as utils
# Re-import the local transformer modules so edits made on disk are picked up
# without restarting the kernel. The `from ... import` lines that follow must
# run after these reloads so the imported names come from the reloaded modules.
# NOTE(review): the reload order matters if these modules import each other;
# confirm that dependencies are reloaded before their dependents.
importlib.reload(transformer_data)
importlib.reload(utils)
importlib.reload(attention)
importlib.reload(enc_dec)

from packages.transformer.data import construct_batch, TransformerGraphBundleInput
from packages.transformer.encoder_decoder import make_model

In [6]:
# Load the Cora citation dataset. `data` is kept around because later cells read
# its features, labels and train mask.
data = citegrh.load_cora()
graph = data[0]
# Rebuild the graph from the row/col index arrays of its COO adjacency, so the
# new graph is constructed from edge indices alone, then move it to the GPU.
adj = graph.adj(scipy_fmt='coo')
graph = dgl.graph((adj.row, adj.col)).to('cuda')

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [7]:
# Adjacency of the rebuilt graph, requested in SciPy COO format.
adj = graph.adj(scipy_fmt='coo')

In [8]:
# Move node features and labels to the GPU.
# `torch.as_tensor` accepts both NumPy arrays and existing tensors; unlike
# `torch.tensor` it does not emit the "copy construct from a tensor" UserWarning
# when the dataset already stores tensors.
features = torch.as_tensor(data.features, device='cuda')
labels = torch.as_tensor(data.labels, device='cuda')

  features = torch.tensor(data.features, device='cuda')


In [9]:
# Boolean mask used below to pick the training nodes out of all node IDs.
train_mask = torch.BoolTensor(data.train_mask)
# Neighbor sampler configured with the fanout list [3, 3]
# (presumably one entry per layer, i.e. two layers with fanout 3 — confirm against the DGL docs).
sampler = dgl.dataloading.MultiLayerNeighborSampler([3, 3])
# IDs of the training nodes: select from 0..num_nodes-1 with the mask on the CPU,
# then move the result to 'cuda' (the graph was moved to 'cuda' when it was built).
train_nids = (torch.arange(0, graph.number_of_nodes())[train_mask]).to('cuda')
# Mini-batches of 64 seed nodes, shuffled; the last partial batch is kept
# (drop_last=False) and loading happens in the main process (num_workers=0).
dataloader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=64,
    shuffle=True,
    drop_last=False,
    num_workers=0)



In [10]:
# Keep an explicit iterator so individual mini-batches can be pulled with next().
loader_iter = iter(dataloader)

In [11]:
# Pull two consecutive mini-batches; each unpacks into (input nodes, output nodes, mfgs).
# The `_fb` / `_sb` suffixes presumably stand for "first batch" / "second batch".
input_nodes_fb, output_nodes_fb, mfgs_fb = next(loader_iter)
input_nodes_sb, output_nodes_sb, mfgs_sb = next(loader_iter)

In [12]:
# Build one graph bundle per mini-batch. Note the argument order: output nodes
# come first here, whereas the dataloader yields input nodes first.
# NOTE(review): `features` / `labels` were created on 'cuda' while 'cpu' is passed
# as the last argument (presumably the target device) — confirm against construct_batch.
input_graph_bundle_fb = construct_batch(output_nodes_fb, input_nodes_fb, mfgs_fb, features, labels, 'cpu')
input_graph_bundle_sb = construct_batch(output_nodes_sb, input_nodes_sb, mfgs_sb, features, labels, 'cpu')

In [13]:
def pad_graph_bundle(graph_bundle: "TransformerGraphBundleInput", max_size: int = 512) -> None: # WARNING: mutates graph bundle object
    """Zero-pad a single-subgraph bundle in place to a fixed number of nodes.

    The bundle's ``src_mask`` is expected to have shape ``(1, C, n, n)`` and its
    ``src_feats`` shape ``(1, n, F)``, where ``n`` is the subgraph size. After the
    call they have shapes ``(1, C, max_size, max_size)`` and ``(1, max_size, F)``;
    the original values sit in the leading ``n`` rows/columns and everything else
    is zero. A fixed size lets several bundles be concatenated along dim 0.

    Args:
        graph_bundle: Bundle whose ``src_mask`` and ``src_feats`` attributes are
            replaced by their padded versions. No other attribute is touched.
        max_size: Number of nodes to pad to. Defaults to 512.

    Raises:
        ValueError: If the subgraph has more than ``max_size`` nodes. The bundle
            is left unmodified in that case.
    """
    src_mask = graph_bundle.src_mask.squeeze(0)
    src_feats = graph_bundle.src_feats.squeeze(0)

    size_subgraph = src_mask.shape[1]
    if size_subgraph > max_size:
        raise ValueError(
            f"subgraph has {size_subgraph} nodes, which exceeds max_size={max_size}"
        )

    # new_zeros keeps the dtype and device of the source tensors, so padding does
    # not silently cast the mask/features or move them to another device.
    padded_src_mask = src_mask.new_zeros((src_mask.shape[0], max_size, max_size))
    padded_src_mask[:, :size_subgraph, :size_subgraph] = src_mask

    padded_src_feats = src_feats.new_zeros((max_size, src_feats.shape[-1]))
    padded_src_feats[:size_subgraph] = src_feats

    # Restore the leading dimension that squeeze(0) removed.
    graph_bundle.src_feats = padded_src_feats.unsqueeze(0)
    graph_bundle.src_mask = padded_src_mask.unsqueeze(0)


In [14]:
def stack_graph_bundles(graph_bundles: List[TransformerGraphBundleInput]) -> TransformerGraphBundleInput:
    """Merge several graph bundles into one batched bundle.

    The ``src_mask``, ``src_feats``, ``trg_labels`` and ``train_inds`` tensors of
    the given bundles are each concatenated along the first dimension, in list
    order, and wrapped in a new ``TransformerGraphBundleInput`` built with
    ``'cpu'`` as its last argument. The input bundles are not modified.

    Args:
        graph_bundles: Bundles to merge. For each attribute, the tensors must
            agree in every dimension except the first.

    Returns:
        A new bundle holding the concatenated tensors.
    """
    field_names = ("src_mask", "src_feats", "trg_labels", "train_inds")
    stacked = {
        name: torch.cat([getattr(bundle, name) for bundle in graph_bundles])
        for name in field_names
    }
    return TransformerGraphBundleInput(
        stacked["src_feats"],
        stacked["trg_labels"],
        stacked["src_mask"],
        stacked["train_inds"],
        'cpu',
    )

In [15]:
# Pads the first bundle in place (the function returns None).
pad_graph_bundle(input_graph_bundle_fb)

In [16]:
# Pads the second bundle in place (the function returns None).
pad_graph_bundle(input_graph_bundle_sb)


In [17]:
# Combine the two padded bundles into one batch; both must already be padded to
# the same size so their tensors can be concatenated.
stacked_graph_bundle = stack_graph_bundles([input_graph_bundle_fb, input_graph_bundle_sb])

In [18]:
# Summary of the stacked batch: token count and the shapes of the stacked tensors.
print(stacked_graph_bundle.ntokens)
print(stacked_graph_bundle.train_inds.shape)
print(stacked_graph_bundle.src_feats.shape)
print(stacked_graph_bundle.src_mask.shape)

128
torch.Size([2, 64])
torch.Size([2, 512, 1433])
torch.Size([2, 2, 512, 512])


In [42]:
# Build the model from the feature dimension and the number of distinct labels + 1
# (+1 for the padding index, though I don't think it's necessary).
# NOTE(review): N=2 is presumably the number of stacked layers — confirm against make_model.
model = make_model(features.shape[1], len(labels.unique()) + 1, N=2)

  nn.init.xavier_uniform(p)


In [44]:
# Run the stacked batch through the model. Call the module itself rather than
# `.forward(...)` so that any registered forward hooks are executed as well.
val = model(stacked_graph_bundle.src_feats, stacked_graph_bundle.src_mask, stacked_graph_bundle.train_inds)
print(val.shape)

torch.Size([2, 64, 512])


In [4]:
# Scratch check: apply one mask per batch element to per-head scores via broadcasting.
B = 2
H = 8
X = torch.randn((B, H, 512, 512))
M = torch.zeros((B, 512, 512)) # an arbitrary mask (all zeros, so every position gets filled)

# unsqueeze(1) turns M into (B, 1, 512, 512) so it broadcasts across the H dimension.
X.masked_fill(M.unsqueeze(1) == 0, -1e9)
# Explicit-loop version that computes the same values (but writes them into X), kept for reference:
# for batch_i in range(X.shape[0]): # looping over [1...B]
#     batch_mask = M[batch_i]
#     for j in range(X.shape[1]): # looping over [1...H]
#         X[batch_i, j] = X[batch_i, j].masked_fill(batch_mask == 0, -1e9)

# Without unsqueeze(1) the (B, 512, 512) mask does not broadcast against (B, H, 512, 512):
# X.masked_fill(M ==0, -1e9)

tensor([[[[-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
     

In [33]:
def batched_index_select(input, dim, index):
    """Select entries along ``dim`` separately for every batch element.

    For ``input`` of shape ``(B, ...)`` and ``index`` of shape ``(B, K)``, the
    result has the shape of ``input`` with dimension ``dim`` replaced by ``K``;
    batch element ``b`` is gathered with ``index[b]``. For example, with
    ``input`` of shape ``(B, S, D)`` and ``dim=1`` the result is ``(B, K, D)``
    and ``out[b, k] == input[b, index[b, k]]``.

    Args:
        input: Tensor whose first dimension is the batch dimension.
        dim: Dimension to select along. Negative values count from the end,
            as in ``torch.gather``.
        index: Integer tensor with ``B * K`` elements (typically shape
            ``(B, K)``); it must be viewable without a copy.

    Returns:
        The gathered tensor.

    Raises:
        IndexError: If ``dim`` is out of range for ``input``.
    """
    ndim = input.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-dimensional input")
    # Normalise negative dims so the comparison below can find the target axis.
    dim %= ndim

    # Reshape the index to (B, 1, ..., K, ..., 1): K (inferred via -1) sits at
    # `dim`, every other non-batch axis has size 1.
    view_shape = [input.shape[0]] + [-1 if axis == dim else 1 for axis in range(1, ndim)]
    # Broadcast the singleton axes up to the sizes of `input`; -1 keeps the batch
    # axis and the `dim` axis as they are.
    expand_shape = list(input.shape)
    expand_shape[0] = -1
    expand_shape[dim] = -1
    index = index.view(view_shape).expand(expand_shape)
    return torch.gather(input, dim, index)

In [36]:
# Self-contained shape check for batched_index_select: pick BS node embeddings
# per batch element out of SG_SIZE candidates.
B = 2
SG_SIZE = 512
BS = 64
EMBED_DIM = 300
node_embeds = torch.rand((B, SG_SIZE, EMBED_DIM))
index = torch.randint(0, SG_SIZE, (B, BS))

embeds = batched_index_select(node_embeds, 1, index)
print(embeds.shape)
# Plain advanced indexing is not equivalent: node_embeds[:, index] applies every
# row of `index` to every batch element and has shape (B, B, BS, EMBED_DIM).

torch.Size([2, 64, 300])
