In [1]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np

from modules.graphmixer import GraphMixer
from neighbor_loader import LastNeighborLoaderGraphmixer

import matplotlib.pyplot as plt

# Test Neighbor Loader

In [2]:
# Test init
num_nodes = 9
num_neighbors = 2
edge_feat_dim = 2
all_nid = torch.arange(9).long()
fid = torch.arange(6).long()  # firm ids
pid = torch.arange(6,9).long()  # prod ids
neighbor_loader = LastNeighborLoaderGraphmixer(num_nodes, num_neighbors, edge_feat_dim)
print(neighbor_loader.neighbors.shape)
print(neighbor_loader.e_id.shape)
print(neighbor_loader.msg.shape)
print(neighbor_loader._assoc.shape)
print(neighbor_loader.edge_feat_dim)

torch.Size([9, 2])
torch.Size([9, 2])
torch.Size([9, 2, 2])
torch.Size([9])
2


In [3]:
# Test insert
src = torch.Tensor([0, 1, 2]).to(torch.long)
dst = torch.Tensor([3, 4, 5]).to(torch.long)
prod = torch.Tensor([6, 7, 8]).to(torch.long)
t = torch.Tensor([11, 22, 33]).to(torch.long)
msg = torch.Tensor([[74, 75], [234, 235], [13, 14]]).to(torch.float)
neighbor_loader.insert(src, dst, prod, t, msg)
self = neighbor_loader 

# Ground truth neighbors
# 0: 6 (11); 1: 7 (22); 2: 8 (33); 3: 6 (11); 4: 7 (22); 5: 8 (33); 6: 3,0 (11, 11); 7: 4,1 (22, 22); 8: 5,2 (33, 33)
for i in all_nid:
    print(i, self.neighbors[i], self.e_id[i], self.t_id[i], self.msg[i])

tensor(0) tensor([6, 0]) tensor([ 0, -1]) tensor([11, -1]) tensor([[74., 75.],
        [-1., -1.]])
tensor(1) tensor([              7, 139771620186544]) tensor([ 2, -1]) tensor([22, -1]) tensor([[234., 235.],
        [ -1.,  -1.]])
tensor(2) tensor([              8, 139774613494808]) tensor([ 4, -1]) tensor([33, -1]) tensor([[13., 14.],
        [-1., -1.]])
tensor(3) tensor([        6, 104574592]) tensor([ 1, -1]) tensor([11, -1]) tensor([[74., 75.],
        [-1., -1.]])
tensor(4) tensor([              7, 139774482081392]) tensor([ 3, -1]) tensor([22, -1]) tensor([[234., 235.],
        [ -1.,  -1.]])
tensor(5) tensor([              8, 139774613494648]) tensor([ 5, -1]) tensor([33, -1]) tensor([[13., 14.],
        [-1., -1.]])
tensor(6) tensor([3, 0]) tensor([1, 0]) tensor([11, 11]) tensor([[74., 75.],
        [74., 75.]])
tensor(7) tensor([4, 1]) tensor([3, 2]) tensor([22, 22]) tensor([[234., 235.],
        [234., 235.]])
tensor(8) tensor([5, 2]) tensor([5, 4]) tensor([33, 33]) tensor(

In [4]:
# Test _call_
neighbors, _, e_id, t_id, msg = neighbor_loader(all_nid)
# Ground truth edges (x2, for all)
# 0-6, 3-6, 1-7, 4-7, 2-8, 5-8
print(neighbors, neighbors.shape)
print(e_id, e_id.shape)
print(t_id, t_id.shape)
print(msg, msg.shape)

tensor([[ 6, -1],
        [ 7, -1],
        [ 8, -1],
        [ 6, -1],
        [ 7, -1],
        [ 8, -1],
        [ 3,  0],
        [ 4,  1],
        [ 5,  2]]) torch.Size([9, 2])
tensor([[ 0, -1],
        [ 2, -1],
        [ 4, -1],
        [ 1, -1],
        [ 3, -1],
        [ 5, -1],
        [ 1,  0],
        [ 3,  2],
        [ 5,  4]]) torch.Size([9, 2])
tensor([[11, -1],
        [22, -1],
        [33, -1],
        [11, -1],
        [22, -1],
        [33, -1],
        [11, 11],
        [22, 22],
        [33, 33]]) torch.Size([9, 2])
tensor([[[ 74.,  75.],
         [ -1.,  -1.]],

        [[234., 235.],
         [ -1.,  -1.]],

        [[ 13.,  14.],
         [ -1.,  -1.]],

        [[ 74.,  75.],
         [ -1.,  -1.]],

        [[234., 235.],
         [ -1.,  -1.]],

        [[ 13.,  14.],
         [ -1.,  -1.]],

        [[ 74.,  75.],
         [ 74.,  75.]],

        [[234., 235.],
         [234., 235.]],

        [[ 13.,  14.],
         [ 13.,  14.]]]) torch.Size([9, 2, 2])


In [5]:
# Test insert again
src = torch.Tensor([0]).to(torch.long)
dst = torch.Tensor([1]).to(torch.long)
prod = torch.Tensor([6]).to(torch.long)
t = torch.Tensor([44]).to(torch.long)
msg = torch.Tensor([[665, 666]]).to(torch.float)
neighbor_loader.insert(src, dst, prod, t, msg)

# Ground truth neighbors - 0 and 1 have extra neighbor, 6 replaced old neighbors with new neighbors
# 0: 6,6 (44, 11); 1: 6,7 (44, 22); 2: 8 (33); 3: 6 (11); 4: 7 (22); 5: 8 (33); 6: 1,0 (44, 44); 7: 4,1 (22, 22); 8: 5,2 (33, 33)
for i in all_nid:
    print(i, self.neighbors[i], self.e_id[i], self.t_id[i], self.msg[i])  # This is correct   

tensor(0) tensor([6, 6]) tensor([6, 0]) tensor([44, 11]) tensor([[665., 666.],
        [ 74.,  75.]])
tensor(1) tensor([6, 7]) tensor([7, 2]) tensor([44, 22]) tensor([[665., 666.],
        [234., 235.]])
tensor(2) tensor([              8, 139774613494808]) tensor([ 4, -1]) tensor([33, -1]) tensor([[13., 14.],
        [-1., -1.]])
tensor(3) tensor([        6, 104574592]) tensor([ 1, -1]) tensor([11, -1]) tensor([[74., 75.],
        [-1., -1.]])
tensor(4) tensor([              7, 139774482081392]) tensor([ 3, -1]) tensor([22, -1]) tensor([[234., 235.],
        [ -1.,  -1.]])
tensor(5) tensor([              8, 139774613494648]) tensor([ 5, -1]) tensor([33, -1]) tensor([[13., 14.],
        [-1., -1.]])
tensor(6) tensor([1, 0]) tensor([7, 6]) tensor([44, 44]) tensor([[665., 666.],
        [665., 666.]])
tensor(7) tensor([4, 1]) tensor([3, 2]) tensor([22, 22]) tensor([[234., 235.],
        [234., 235.]])
tensor(8) tensor([5, 2]) tensor([5, 4]) tensor([33, 33]) tensor([[13., 14.],
        [13

In [6]:
# Test _call_
neighbors, _, e_id, t_id, msg = neighbor_loader(all_nid)
# Ground truth edges (x2, for all except first 0-6 and 3-6)
# 0-6, 3-6, 1-7, 4-7, 2-8, 5-8, 0-6, 1-6
print(neighbors, neighbors.shape)
print(e_id, e_id.shape)
print(t_id, t_id.shape)
print(msg, msg.shape)

tensor([[ 6,  6],
        [ 6,  7],
        [ 8, -1],
        [ 6, -1],
        [ 7, -1],
        [ 8, -1],
        [ 1,  0],
        [ 4,  1],
        [ 5,  2]]) torch.Size([9, 2])
tensor([[ 6,  0],
        [ 7,  2],
        [ 4, -1],
        [ 1, -1],
        [ 3, -1],
        [ 5, -1],
        [ 7,  6],
        [ 3,  2],
        [ 5,  4]]) torch.Size([9, 2])
tensor([[44, 11],
        [44, 22],
        [33, -1],
        [11, -1],
        [22, -1],
        [33, -1],
        [44, 44],
        [22, 22],
        [33, 33]]) torch.Size([9, 2])
tensor([[[665., 666.],
         [ 74.,  75.]],

        [[665., 666.],
         [234., 235.]],

        [[ 13.,  14.],
         [ -1.,  -1.]],

        [[ 74.,  75.],
         [ -1.,  -1.]],

        [[234., 235.],
         [ -1.,  -1.]],

        [[ 13.,  14.],
         [ -1.,  -1.]],

        [[665., 666.],
         [665., 666.]],

        [[234., 235.],
         [234., 235.]],

        [[ 13.,  14.],
         [ 13.,  14.]]]) torch.Size([9, 2, 2])


In [7]:
# Test calling a subset of nodes
neighbors, _, e_id, t_id, msg = neighbor_loader(torch.Tensor([6]).long())
print(neighbors, neighbors.shape) # these are node 6's neighbors - correct!
print(e_id, e_id.shape)
print(t_id, t_id.shape)
print(msg, msg.shape)

tensor([[1, 0]]) torch.Size([1, 2])
tensor([[7, 6]]) torch.Size([1, 2])
tensor([[44, 44]]) torch.Size([1, 2])
tensor([[[665., 666.],
         [665., 666.]]]) torch.Size([1, 2, 2])


In [8]:
# Test reset_state and _call_
neighbor_loader.reset_state()
neighbors, _, e_id, t_id, msg = neighbor_loader(all_nid)

# All reset to -1
print(neighbors, neighbors.shape)
print(e_id, e_id.shape)
print(t_id, t_id.shape)
print(msg, msg.shape)

tensor([[-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1]]) torch.Size([9, 2])
tensor([[-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1]]) torch.Size([9, 2])
tensor([[-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1]]) torch.Size([9, 2])
tensor([[[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]],

        [[-1., -1.],
         [-1., -1.]]]) torch.Size([9, 2, 2])


# Test GraphMixer Model
Since most work are handled by neighbor loader, we only perform simple sanity checks here. 

In [3]:
num_nodes = 10
num_prods = 3

node_raw_features = torch.eye(num_nodes)
edge_feat_dim = 1 # in our case, it is weight
time_dim = 1
num_neighbors = 2 

neighbor_loader = LastNeighborLoaderGraphmixer(num_nodes, num_neighbors, edge_feat_dim)
all_nid = torch.arange(num_nodes).long() 

In [4]:
# test initialization
gm = GraphMixer(node_raw_features,
                edge_feat_dim,
                time_feat_dim=time_dim,
                num_tokens=num_neighbors,
                num_layers=2,
                token_dim_expansion_factor=0.5,
                channel_dim_expansion_factor=4.0,
                dropout=0.1,
                time_gap=2000, # NOT USED 
                debug=True) 
gm.neighbor_sampler = neighbor_loader

In [5]:
# Check model parameters
for name, param in gm.named_parameters():
    if param.requires_grad:
        print(name, param.data.shape)

projection_layer.weight torch.Size([1, 2])
projection_layer.bias torch.Size([1])
mlp_mixers.0.token_norm.weight torch.Size([2])
mlp_mixers.0.token_norm.bias torch.Size([2])
mlp_mixers.0.token_feedforward.ffn.0.weight torch.Size([1, 2])
mlp_mixers.0.token_feedforward.ffn.0.bias torch.Size([1])
mlp_mixers.0.token_feedforward.ffn.3.weight torch.Size([2, 1])
mlp_mixers.0.token_feedforward.ffn.3.bias torch.Size([2])
mlp_mixers.0.channel_norm.weight torch.Size([1])
mlp_mixers.0.channel_norm.bias torch.Size([1])
mlp_mixers.0.channel_feedforward.ffn.0.weight torch.Size([4, 1])
mlp_mixers.0.channel_feedforward.ffn.0.bias torch.Size([4])
mlp_mixers.0.channel_feedforward.ffn.3.weight torch.Size([1, 4])
mlp_mixers.0.channel_feedforward.ffn.3.bias torch.Size([1])
mlp_mixers.1.token_norm.weight torch.Size([2])
mlp_mixers.1.token_norm.bias torch.Size([2])
mlp_mixers.1.token_feedforward.ffn.0.weight torch.Size([1, 2])
mlp_mixers.1.token_feedforward.ffn.0.bias torch.Size([1])
mlp_mixers.1.token_feedfor

In [6]:
# query an empty graph: expect input of zeros to the model
node_interact_times = torch.ones_like(all_nid)
print("The input shapes to compute_node_temporal_embeddings are:", all_nid.shape, node_interact_times.shape)
node_embeddings = gm.compute_node_temporal_embeddings(all_nid, node_interact_times)


The input shapes to compute_node_temporal_embeddings are: torch.Size([10]) torch.Size([10])
node_interact_times (shape, value) torch.Size([10, 1]) tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])
neighbor_times (shape, value) torch.Size([10, 2]) tensor([[-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1]])
input to time encoder tensor([[2, 2],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2]])
output of time encoder (zeroed out if applicable) torch.Size([10, 2, 1]) tensor([[[0.],
         [0.]],

        [[0.],
         [0.]],

        [[0.],
         [0.]],

        [[0.],
         [0.]],

        [[0.],
         [0.]],

        [[0.],
         [0.]],

        [[0.],
         [0.]],

        [[0

In [7]:
# insert some transactions
src = torch.Tensor([0, 1, 2]).to(torch.long)
dst = torch.Tensor([3, 4, 5]).to(torch.long)
prod = torch.Tensor([6, 7, 8]).to(torch.long)
t = torch.Tensor([11, 22, 33]).to(torch.long)
msg = torch.Tensor([[74], [234], [13]]).to(torch.float)
neighbor_loader.insert(src, dst, prod, t, msg)

In [8]:
# query the new graph
node_interact_times = torch.full(all_nid.shape, 2)
node_embeddings = gm.compute_node_temporal_embeddings(all_nid, node_interact_times)


node_interact_times (shape, value) torch.Size([10, 1]) tensor([[2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2]])
neighbor_times (shape, value) torch.Size([10, 2]) tensor([[11, -1],
        [22, -1],
        [33, -1],
        [11, -1],
        [22, -1],
        [33, -1],
        [11, 11],
        [22, 22],
        [33, 33],
        [-1, -1]])
input to time encoder tensor([[ -9,   3],
        [-20,   3],
        [-31,   3],
        [ -9,   3],
        [-20,   3],
        [-31,   3],
        [ -9,  -9],
        [-20, -20],
        [-31, -31],
        [  3,   3]])
output of time encoder (zeroed out if applicable) torch.Size([10, 2, 1]) tensor([[[-0.9111],
         [ 0.0000]],

        [[ 0.4081],
         [ 0.0000]],

        [[ 0.9147],
         [ 0.0000]],

        [[-0.9111],
         [ 0.0000]],

        [[ 0.4081],
         [ 0.0000]],

        [[ 0.9147],
         [ 0.0000]],

        [[-0.9111],
         [-0.91