In [1]:
%load_ext autoreload
%autoreload 2

import torch as t
import numpy as np

from memory_module import TGNPLMemory, StaticMemory
from msg_func import TGNPLMessage
from msg_agg import *

from neighbor_loader import LastNeighborLoader, LastNeighborLoaderTGNPL

# Test TGNPLMemory

In [2]:
num_nodes = 10
num_prods = 2
raw_msg_dim = 1
state_dim = 5
time_dim = 1
# Per-node memory width: learned state plus one column per product
# (the memory outputs below have state_dim + num_prods columns).
memory_dim = state_dim + num_prods
message_module = TGNPLMessage(raw_msg_dim, memory_dim, time_dim)
aggregator_module = MeanAggregator()

In [27]:
# test initialization
mem = TGNPLMemory(num_nodes,
        num_prods,
        raw_msg_dim,
        state_dim,
        time_dim,
        message_module,
        aggregator_module,
        state_updater_cell="gru",
        use_inventory=True,
        debt_penalty=10,
        consumption_reward=5,
        debug=True)

23 5


In [28]:
# List every trainable parameter of the memory module with its shape.
for name, param in mem.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data.shape)

prod_bilinear torch.Size([5, 5])
msg_s_module.lin.weight torch.Size([23, 23])
msg_s_module.lin.bias torch.Size([23])
msg_s_module.layer_norm.weight torch.Size([23])
msg_s_module.layer_norm.bias torch.Size([23])
msg_d_module.lin.weight torch.Size([23, 23])
msg_d_module.lin.bias torch.Size([23])
msg_d_module.layer_norm.weight torch.Size([23])
msg_d_module.layer_norm.bias torch.Size([23])
msg_p_module.lin.weight torch.Size([23, 23])
msg_p_module.lin.bias torch.Size([23])
msg_p_module.layer_norm.weight torch.Size([23])
msg_p_module.layer_norm.bias torch.Size([23])
time_enc.lin.weight torch.Size([1, 1])
time_enc.lin.bias torch.Size([1])
state_updater.weight_ih torch.Size([15, 23])
state_updater.weight_hh torch.Size([15, 5])
state_updater.bias_ih torch.Size([15])
state_updater.bias_hh torch.Size([15])
init_memory.weight torch.Size([10, 5])


In [29]:
# Before any interactions have been added, the state returned by the forward
# pass should match the initial (embedding) memory printed here.
n_id = t.arange(num_nodes, dtype=t.long)
print(n_id)
initial_memory = mem.init_memory(n_id)
print('init memory', initial_memory)
memory, last_update, loss = mem(n_id)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
init memory tensor([[ 1.2116,  0.3773,  1.0512, -0.9904,  0.2710],
        [-0.5007,  0.8516,  0.4899, -0.1371, -0.6343],
        [-0.2270, -2.2572, -0.2835, -0.5244,  0.9657],
        [ 0.0474, -0.0585, -1.1629,  0.9649, -0.0190],
        [-1.0272,  2.2202, -0.3713,  1.7193,  0.4966],
        [ 0.8658, -2.5166, -2.2544,  0.2747, -0.1048],
        [ 1.4423, -0.5555,  0.1374,  0.0451,  0.7957],
        [-0.4425, -0.6752,  0.0149, -0.8949, -0.1105],
        [ 0.3926, -0.4009,  1.2658,  2.0982, -0.7427],
        [ 0.6506,  0.9090,  0.2477,  1.0854,  0.1493]],
       grad_fn=<EmbeddingBackward0>)
msg_s tensor([], size=(0, 23), grad_fn=<ReluBackward0>)
msg_d tensor([], size=(0, 23), grad_fn=<ReluBackward0>)
msg_p tensor([], size=(0, 23), grad_fn=<ReluBackward0>)
aggr tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [30]:
# try adding interactions
# 1) get_updated_memory will print 6 nodes [0, 1, 2, 3, 8, 9], but memories shouldn't be updated yet
# 2) _update_msg_store should update all three msg stores
# Index tensors are built directly as int64 rather than float -> .long()
# (the old `t.Tensor(np.ones(4)).long()` even detoured through float64 numpy).
src = t.tensor([0, 0, 1, 2])
dst = t.tensor([3, 3, 3, 0])
prod = t.tensor([8, 8, 8, 9])
time = t.ones(4, dtype=t.long)
# Raw message: one float feature per transaction, shape (num_events, 1)
raw_msg = t.tensor([10., 20., 5., 13.]).reshape(-1, 1)
mem.update_state(src, dst, prod, time, raw_msg)

msg_s tensor([], size=(0, 23), grad_fn=<ReluBackward0>)
msg_d tensor([], size=(0, 23), grad_fn=<ReluBackward0>)
msg_p tensor([], size=(0, 23), grad_fn=<ReluBackward0>)
aggr tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       grad_fn=<DivBackward0>)
Total consumed: tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], grad_fn=<IndexBackward0>)
Total loss: tensor(0., grad_fn=<

In [31]:
# Source-side message store: dict keyed by node id, each value a
# (src, dst, prod, t, raw_msg) tuple of the pending transactions that node sent
# (node 0 holds both of its sales below).
mem.msg_s_store

{0: (tensor([0, 0]),
  tensor([3, 3]),
  tensor([8, 8]),
  tensor([1, 1]),
  tensor([[10.],
          [20.]])),
 1: (tensor([1]), tensor([3]), tensor([8]), tensor([1]), tensor([[5.]])),
 2: (tensor([2]), tensor([0]), tensor([9]), tensor([1]), tensor([[13.]])),
 3: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 7: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype

In [32]:
# Destination-side message store: same tuple layout, keyed by the receiving node
# (node 3 holds the three transactions it received).
mem.msg_d_store

{0: (tensor([2]), tensor([0]), tensor([9]), tensor([1]), tensor([[13.]])),
 1: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 2: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 3: (tensor([0, 0, 1]),
  tensor([3, 3, 3]),
  tensor([8, 8, 8]),
  tensor([1, 1, 1]),
  tensor([[10.],
          [20.],
          [ 5.]])),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64

In [33]:
# Product-side message store; presumably keyed by product node (8, 9) — the
# truncated output below only shows the (empty) entries for nodes 0-6, so verify.
mem.msg_p_store

{0: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 1: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 2: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 3: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=

In [34]:
# now get memory again - only nodes with interactions should've changed
# mem(n_id) runs forward and returns (memory, last_update, loss) for the nodes in n_id.
memory, last_update, loss = mem(n_id)  # test .forward()

msg_s tensor([[1.4807, 0.0000, 0.5545, 1.3220, 0.0000, 1.8267, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.7232, 0.0000, 0.0000, 1.0196, 1.5632, 0.0000, 1.1032,
         0.0000, 0.0000, 0.2157, 0.0000, 0.0000],
        [1.3461, 0.0000, 0.6229, 1.4381, 0.0000, 1.7520, 0.0000, 0.0000, 0.0000,
         0.0430, 0.0292, 0.5518, 0.0000, 0.0000, 0.8611, 1.6331, 0.0000, 1.2646,
         0.0000, 0.0000, 0.3054, 0.0000, 0.0000],
        [1.4023, 0.0000, 0.1903, 0.5603, 0.0000, 2.2669, 0.0000, 0.0000, 0.0000,
         0.1327, 0.0000, 0.5829, 0.0000, 0.4303, 0.9183, 1.5209, 0.0818, 1.2602,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.5099, 0.0000, 0.9403, 1.6272, 0.0000, 1.3254, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4027, 0.0000, 0.0000, 0.7330, 1.5194, 0.0000, 1.1682,
         0.3515, 0.0000, 0.2631, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
msg_d tensor([[1.5099, 0.0000, 0.9403, 1.6272, 0.0000, 1.3254, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4027, 

In [35]:
# 2 supplied product 9
# product 9 has no inventory: its columns after the first `state_dim`
# (the per-product inventory columns) must all be zero.
assert (memory[9, state_dim:] == 0).all(), "product node 9 should hold no inventory"
memory[[2,9]]

tensor([[ 3.6012e-01, -1.3115e+00, -3.8810e-01, -2.2612e-01,  9.3875e-01,
         -1.0334e+02, -1.2031e+02],
        [ 5.9465e-01,  6.6270e-01,  1.1274e-01,  9.3827e-01,  7.8881e-01,
          0.0000e+00,  0.0000e+00]], grad_fn=<IndexBackward0>)

In [36]:
# 3 received exactly 35 (= 10 + 20 + 5) of product 8
# product 8 has no inventory
# Product 8's inventory shows up in column `state_dim` (first inventory column) in the output below.
assert t.isclose(memory[3, state_dim], t.tensor(35.)), "node 3 should hold 35 units of product 8"
assert (memory[8, state_dim:] == 0).all(), "product node 8 should hold no inventory"
memory[[3,8]]

tensor([[ 0.2005,  0.1731, -1.1316,  0.7166,  0.7719, 35.0000,  0.0000],
        [ 0.3877, -0.2104,  0.9158,  1.9632,  0.4339,  0.0000,  0.0000]],
       grad_fn=<IndexBackward0>)

In [37]:
# should be unaffected: nodes 4-7 took part in no transaction, so their state
# equals the initial embedding and their inventory columns stay zero.
untouched = t.tensor([4, 5, 6, 7])
assert t.allclose(memory[untouched, :state_dim], mem.init_memory(untouched))
assert (memory[untouched, state_dim:] == 0).all()
memory[[4,5,6,7]]

tensor([[-1.0272,  2.2202, -0.3713,  1.7193,  0.4966,  0.0000,  0.0000],
        [ 0.8658, -2.5166, -2.2544,  0.2747, -0.1048,  0.0000,  0.0000],
        [ 1.4423, -0.5555,  0.1374,  0.0451,  0.7957,  0.0000,  0.0000],
        [-0.4425, -0.6752,  0.0149, -0.8949, -0.1105,  0.0000,  0.0000]],
       grad_fn=<IndexBackward0>)

In [38]:
# should only be updated for nodes in transactions: 0-3 (firms) and 8, 9
# (products) interacted at time 1; the rest keep the sentinel -1.
assert (last_update[[0, 1, 2, 3, 8, 9]] == 1).all()
assert (last_update[[4, 5, 6, 7]] == -1).all()
last_update

tensor([ 1,  1,  1,  1, -1, -1, -1, -1,  1,  1])

In [39]:
# store latest memory, reset message store
# Both calls are private TGNPLMemory helpers; the following cells inspect
# mem.memory and mem.msg_s_store to confirm their effect.
mem._update_memory(n_id)
mem._reset_message_store()

msg_s tensor([[1.4807, 0.0000, 0.5545, 1.3220, 0.0000, 1.8267, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.7232, 0.0000, 0.0000, 1.0196, 1.5632, 0.0000, 1.1032,
         0.0000, 0.0000, 0.2157, 0.0000, 0.0000],
        [1.3461, 0.0000, 0.6229, 1.4381, 0.0000, 1.7520, 0.0000, 0.0000, 0.0000,
         0.0430, 0.0292, 0.5518, 0.0000, 0.0000, 0.8611, 1.6331, 0.0000, 1.2646,
         0.0000, 0.0000, 0.3054, 0.0000, 0.0000],
        [1.4023, 0.0000, 0.1903, 0.5603, 0.0000, 2.2669, 0.0000, 0.0000, 0.0000,
         0.1327, 0.0000, 0.5829, 0.0000, 0.4303, 0.9183, 1.5209, 0.0818, 1.2602,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.5099, 0.0000, 0.9403, 1.6272, 0.0000, 1.3254, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4027, 0.0000, 0.0000, 0.7330, 1.5194, 0.0000, 1.1682,
         0.3515, 0.0000, 0.2631, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
msg_d tensor([[1.5099, 0.0000, 0.9403, 1.6272, 0.0000, 1.3254, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4027, 

In [40]:
# Stored memory after _update_memory: one row per node, first state_dim columns
# are the learned state, the remaining num_prods columns look like per-product
# inventory (cf. node 3's 35 units) — confirm against TGNPLMemory.
mem.memory

tensor([[ 9.6774e-01,  3.5577e-01,  7.3739e-01, -9.0616e-01,  8.4783e-01,
         -2.0483e+02, -2.2548e+02],
        [-8.9614e-02,  7.6891e-01,  3.3754e-01, -1.8818e-01,  4.9022e-01,
         -3.4139e+01, -3.9746e+01],
        [ 3.6012e-01, -1.3115e+00, -3.8810e-01, -2.2612e-01,  9.3875e-01,
         -1.0334e+02, -1.2031e+02],
        [ 2.0049e-01,  1.7311e-01, -1.1316e+00,  7.1657e-01,  7.7188e-01,
          3.5000e+01,  0.0000e+00],
        [-1.0272e+00,  2.2202e+00, -3.7127e-01,  1.7193e+00,  4.9665e-01,
          0.0000e+00,  0.0000e+00],
        [ 8.6581e-01, -2.5166e+00, -2.2544e+00,  2.7472e-01, -1.0477e-01,
          0.0000e+00,  0.0000e+00],
        [ 1.4423e+00, -5.5553e-01,  1.3739e-01,  4.5061e-02,  7.9573e-01,
          0.0000e+00,  0.0000e+00],
        [-4.4254e-01, -6.7520e-01,  1.4906e-02, -8.9491e-01, -1.1055e-01,
          0.0000e+00,  0.0000e+00],
        [ 3.8768e-01, -2.1044e-01,  9.1579e-01,  1.9632e+00,  4.3385e-01,
          0.0000e+00,  0.0000e+00],
        [ 

In [47]:
# should be equal when node hasn't had interactions
# Compare only the state columns (first `state_dim`); previously hard-coded as 5.
untouched = mem.last_update == -1
assert t.allclose(mem.memory[untouched, :state_dim], mem.init_memory(n_id[untouched]))
mem.memory[:, :state_dim] == mem.init_memory(n_id)

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [False, False, False, False, False],
        [False, False, False, False, False]])

In [48]:
# should be empty: after _reset_message_store() every node's entry holds
# zero-length tensors.
assert all(entry[0].numel() == 0 for entry in mem.msg_s_store.values())
mem.msg_s_store

{0: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 1: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 2: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 3: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=

In [49]:
# Stored last-update times; should match the `last_update` returned by forward earlier.
mem.last_update

tensor([ 1,  1,  1,  1, -1, -1, -1, -1,  1,  1])

## Test attention weight learning

In [15]:
num_nodes = 6
num_prods = 3
raw_msg_dim = 1
state_dim = 2
time_dim = 1
# Per-node memory width: learned state plus one column per product.
memory_dim = state_dim + num_prods
message_module = TGNPLMessage(raw_msg_dim, memory_dim, time_dim)
aggregator_module = MeanAggregator()

In [16]:
# Message width = raw_msg_dim + (3 * memory_dim) + time_dim,
# with memory_dim = state_dim + num_prods (1 + 3 * 5 + 1 = 17 here).
assert message_module.out_channels == raw_msg_dim + 3 * (state_dim + num_prods) + time_dim
message_module.out_channels

17

In [17]:
# All node ids (firms and products) for this smaller test graph.
n_id = t.arange(num_nodes, dtype=t.long)
print(n_id)

tensor([0, 1, 2, 3, 4, 5])


In [23]:
# test directly learning attention weights
mem_options = dict(
    state_updater_cell="gru",
    use_inventory=True,
    learn_att_direct=True,
    debt_penalty=10,
    consumption_reward=1,
    debug=False,
)
mem = TGNPLMemory(num_nodes, num_prods, raw_msg_dim, state_dim, time_dim,
                  message_module, aggregator_module, **mem_options)
opt = t.optim.Adam(mem.parameters(), lr=1e-2)
# List every trainable parameter with its shape.
for name, param in mem.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data.shape)

17 2
att_weights torch.Size([3, 3])
msg_s_module.lin.weight torch.Size([17, 17])
msg_s_module.lin.bias torch.Size([17])
msg_s_module.layer_norm.weight torch.Size([17])
msg_s_module.layer_norm.bias torch.Size([17])
msg_d_module.lin.weight torch.Size([17, 17])
msg_d_module.lin.bias torch.Size([17])
msg_d_module.layer_norm.weight torch.Size([17])
msg_d_module.layer_norm.bias torch.Size([17])
msg_p_module.lin.weight torch.Size([17, 17])
msg_p_module.lin.bias torch.Size([17])
msg_p_module.layer_norm.weight torch.Size([17])
msg_p_module.layer_norm.bias torch.Size([17])
time_enc.lin.weight torch.Size([1, 1])
time_enc.lin.bias torch.Size([1])
state_updater.weight_ih torch.Size([6, 17])
state_updater.weight_hh torch.Size([6, 2])
state_updater.bias_ih torch.Size([6])
state_updater.bias_hh torch.Size([6])


In [24]:
# Firm 1 needs products 3 and 4 (bought from firm 0) to make product 5 (sold to firm 2).
# Expectation: att weights should be nonzero at (row 3, col 1) and/or (row 3, col 2)
# counting from 1 — presumably att_weights[2, 0] / att_weights[2, 1] 0-based. This works.
num_iters = 30
# Anomaly detection slows autograd considerably; scope it to this loop instead
# of leaving it switched on for the rest of the session.
with t.autograd.set_detect_anomaly(True):
    for i in range(1, num_iters + 1):
        opt.zero_grad()
        if i % 2 == 0:
            # 1 sells 5 to 2
            src = t.tensor([1])
            dst = t.tensor([2])
            prod = t.tensor([5])
            time = t.tensor([i])
            raw_msg = t.tensor([[1.]])
        else:
            # 1 buys 3 and 4 from 0
            src = t.tensor([0, 0])
            dst = t.tensor([1, 1])
            prod = t.tensor([3, 4])
            time = t.tensor([i, i])
            raw_msg = t.tensor([[2.], [4.]])

        print('iter', i)
        mem.update_state(src, dst, prod, time, raw_msg)
        memory, last_update, loss = mem(n_id)
        print('loss', loss)
        att_weights = mem.get_prod_attention()
        print('att weights', att_weights)
        loss.backward()
        opt.step()
        # presumably detaches stored state so the next step does not backprop into this one
        mem.detach()

iter 1
loss tensor(162., grad_fn=<SumBackward0>)
att weights tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], grad_fn=<ReluBackward0>)
iter 2
loss tensor(167.3800, grad_fn=<SumBackward0>)
att weights tensor([[0.9900, 0.9900, 0.9900],
        [0.9900, 0.9900, 0.9900],
        [1.0000, 1.0000, 1.0000]], grad_fn=<ReluBackward0>)
iter 3
loss tensor(352.0037, grad_fn=<SumBackward0>)
att weights tensor([[0.9800, 0.9800, 0.9800],
        [0.9800, 0.9800, 0.9800],
        [1.0074, 1.0074, 0.9926]], grad_fn=<ReluBackward0>)
iter 4
loss tensor(360.1991, grad_fn=<SumBackward0>)
att weights tensor([[0.9704, 0.9704, 0.9704],
        [0.9704, 0.9704, 0.9704],
        [1.0160, 1.0160, 0.9843]], grad_fn=<ReluBackward0>)
iter 5
loss tensor(541.2927, grad_fn=<SumBackward0>)
att weights tensor([[0.9609, 0.9609, 0.9609],
        [0.9609, 0.9609, 0.9609],
        [1.0251, 1.0251, 0.9754]], grad_fn=<ReluBackward0>)
iter 6
loss tensor(549.2872, grad_fn=<SumBackward0>)
att weights tensor([[0

In [25]:
# test learning attention weights with product memories
mem_options = dict(
    state_updater_cell="gru",
    use_inventory=True,
    learn_att_direct=False,
    debt_penalty=10,
    consumption_reward=1,
    debug=False,
)
mem = TGNPLMemory(num_nodes, num_prods, raw_msg_dim, state_dim, time_dim,
                  message_module, aggregator_module, **mem_options)
opt = t.optim.Adam(mem.parameters(), lr=1e-2)
# List every trainable parameter with its shape.
for name, param in mem.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data.shape)

17 2
prod_bilinear torch.Size([2, 2])
msg_s_module.lin.weight torch.Size([17, 17])
msg_s_module.lin.bias torch.Size([17])
msg_s_module.layer_norm.weight torch.Size([17])
msg_s_module.layer_norm.bias torch.Size([17])
msg_d_module.lin.weight torch.Size([17, 17])
msg_d_module.lin.bias torch.Size([17])
msg_d_module.layer_norm.weight torch.Size([17])
msg_d_module.layer_norm.bias torch.Size([17])
msg_p_module.lin.weight torch.Size([17, 17])
msg_p_module.lin.bias torch.Size([17])
msg_p_module.layer_norm.weight torch.Size([17])
msg_p_module.layer_norm.bias torch.Size([17])
time_enc.lin.weight torch.Size([1, 1])
time_enc.lin.bias torch.Size([1])
state_updater.weight_ih torch.Size([6, 17])
state_updater.weight_hh torch.Size([6, 2])
state_updater.bias_ih torch.Size([6])
state_updater.bias_hh torch.Size([6])


In [26]:
# Firm 1 needs products 3 and 4 (bought from firm 0) to make product 5 (sold to firm 2).
# Expectation: att weights should be nonzero at (row 3, col 1) and/or (row 3, col 2)
# counting from 1 — presumably att_weights[2, 0] / att_weights[2, 1] 0-based.
# TODO: this doesn't work yet with learn_att_direct=False (see output).
num_iters = 30
# Anomaly detection slows autograd considerably; scope it to this loop instead
# of leaving it switched on for the rest of the session.
with t.autograd.set_detect_anomaly(True):
    for i in range(1, num_iters + 1):
        opt.zero_grad()
        if i % 2 == 0:
            # 1 sells 5 to 2
            src = t.tensor([1])
            dst = t.tensor([2])
            prod = t.tensor([5])
            time = t.tensor([i])
            raw_msg = t.tensor([[1.]])
        else:
            # 1 buys 3 and 4 from 0
            src = t.tensor([0, 0])
            dst = t.tensor([1, 1])
            prod = t.tensor([3, 4])
            time = t.tensor([i, i])
            raw_msg = t.tensor([[2.], [4.]])

        print('iter', i)
        mem.update_state(src, dst, prod, time, raw_msg)
        memory, last_update, loss = mem(n_id)
        print('loss', loss)
        att_weights = mem.get_prod_attention()
        print('att weights', att_weights)
        loss.backward()
        opt.step()
        # presumably detaches stored state so the next step does not backprop into this one
        mem.detach()

iter 1
loss tensor(0.4768, grad_fn=<SumBackward0>)
att weights tensor([[0.0044, 0.0044, 0.0000],
        [0.0044, 0.0044, 0.0000],
        [0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
iter 2
loss tensor(0.4148, grad_fn=<SumBackward0>)
att weights tensor([[0.0031, 0.0031, 0.0015],
        [0.0031, 0.0031, 0.0015],
        [0.0015, 0.0015, 0.0005]], grad_fn=<ReluBackward0>)
iter 3
loss tensor(3.2606, grad_fn=<SumBackward0>)
att weights tensor([[0.1207, 0.0161, 0.0000],
        [0.0161, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
iter 4
loss tensor(3.2100, grad_fn=<SumBackward0>)
att weights tensor([[0.1169, 0.0131, 0.0000],
        [0.0131, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0202]], grad_fn=<ReluBackward0>)
iter 5
loss tensor(15.2349, grad_fn=<SumBackward0>)
att weights tensor([[0.1803, 0.1137, 0.0000],
        [0.1137, 0.0697, 0.0000],
        [0.0000, 0.0000, 0.0142]], grad_fn=<ReluBackward0>)
iter 6
loss tensor(14.6644, grad_fn=<SumBack

# Test Neighbor Loader

In [24]:
# Test init
# Uses `t` (imported at the top as `import torch as t`); the bare name `torch`
# is not imported until the data section, so it failed on a fresh kernel.
num_nodes = 9
num_neighbors = 2
num_firms = 6  # node ids [0, num_firms) are firms, [num_firms, num_nodes) are products
all_nid = t.arange(num_nodes)
fid = t.arange(num_firms)  # firm ids
pid = t.arange(num_firms, num_nodes)  # prod ids
neighbor_loader = LastNeighborLoaderTGNPL(num_nodes, size=num_neighbors)
print(neighbor_loader.neighbors.shape)
print(neighbor_loader.e_id.shape)
print(neighbor_loader._assoc.shape)
# Short alias used by the inspection cells below
self = neighbor_loader

torch.Size([9, 2])
torch.Size([9, 2])
torch.Size([9])


In [25]:
# Test insert
src = t.tensor([0, 1, 2])
dst = t.tensor([3, 4, 5])
prod = t.tensor([6, 7, 8])
neighbor_loader.insert(src, dst, prod)

# Ground truth neighbors
# 0: 6; 1: 7; 2: 8; 3: 6; 4: 7; 5: 8; 6: 3,0; 7: 4,1; 8: 5,2
# Slots whose e_id is -1 appear to be unfilled; their `neighbors` value is not
# meaningful (note the garbage values in the output).
for i in all_nid:
    print(i, self.neighbors[i], self.e_id[i])

tensor(0) tensor([6, 0]) tensor([ 0, -1])
tensor(1) tensor([7, 1]) tensor([ 2, -1])
tensor(2) tensor([8, 0]) tensor([ 4, -1])
tensor(3) tensor([              6, 140133897100152]) tensor([ 1, -1])
tensor(4) tensor([  7, 129]) tensor([ 3, -1])
tensor(5) tensor([              8, 140133897100152]) tensor([ 5, -1])
tensor(6) tensor([3, 0]) tensor([1, 0])
tensor(7) tensor([4, 1]) tensor([3, 2])
tensor(8) tensor([5, 2]) tensor([5, 4])


In [26]:
# Test _call_
n_id, edge_index, e_id = neighbor_loader(fid, pid)
# Ground truth edges (x2, for all)
# 0-6, 3-6, 1-7, 4-7, 2-8, 5-8
print(n_id, n_id.shape)
print(edge_index.shape, e_id.shape)
# One line per returned edge: endpoints (indices into n_id) and its edge id
for col, (edge_src, edge_dst) in enumerate(edge_index.t()):
    print(f'{edge_src} -> {edge_dst}; e_id = {e_id[col]}')

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) torch.Size([9])
torch.Size([2, 12]) torch.Size([12])
6 -> 0; e_id = 0
7 -> 1; e_id = 2
8 -> 2; e_id = 4
6 -> 3; e_id = 1
7 -> 4; e_id = 3
8 -> 5; e_id = 5
3 -> 6; e_id = 1
0 -> 6; e_id = 0
4 -> 7; e_id = 3
1 -> 7; e_id = 2
5 -> 8; e_id = 5
2 -> 8; e_id = 4


In [27]:
# Test insert again
src = t.tensor([0])
dst = t.tensor([1])
prod = t.tensor([6])
neighbor_loader.insert(src, dst, prod)

# Ground truth neighbors - 0 and 1 have extra neighbor, 6 replaced old neighbors with new neighbors
# 0: 6,6; 1: 6,7; 2: 8; 3: 6; 4: 7; 5: 8; 6: 1,0; 7: 4,1; 8: 5,2
for i in all_nid:
    print(i, self.neighbors[i], self.e_id[i])  # This is correct

tensor(0) tensor([6, 6]) tensor([6, 0])
tensor(1) tensor([6, 7]) tensor([7, 2])
tensor(2) tensor([8, 0]) tensor([ 4, -1])
tensor(3) tensor([              6, 140133897100152]) tensor([ 1, -1])
tensor(4) tensor([  7, 129]) tensor([ 3, -1])
tensor(5) tensor([              8, 140133897100152]) tensor([ 5, -1])
tensor(6) tensor([1, 0]) tensor([7, 6])
tensor(7) tensor([4, 1]) tensor([3, 2])
tensor(8) tensor([5, 2]) tensor([5, 4])


In [28]:
# Test _call_
n_id, edge_index, e_id = neighbor_loader(fid, pid)
# Ground truth edges (x2, for all except first 0-6 and 3-6)
# 0-6, 3-6, 1-7, 4-7, 2-8, 5-8, 0-6, 1-6
print(n_id, n_id.shape)
print(edge_index.shape, e_id.shape)
# One line per returned edge: endpoints (indices into n_id) and its edge id
for col, (edge_src, edge_dst) in enumerate(edge_index.t()):
    print(f'{edge_src} -> {edge_dst}; e_id = {e_id[col]}')  # This is correct

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) torch.Size([9])
torch.Size([2, 14]) torch.Size([14])
6 -> 0; e_id = 6
6 -> 0; e_id = 0
6 -> 1; e_id = 7
7 -> 1; e_id = 2
8 -> 2; e_id = 4
6 -> 3; e_id = 1
7 -> 4; e_id = 3
8 -> 5; e_id = 5
1 -> 6; e_id = 7
0 -> 6; e_id = 6
4 -> 7; e_id = 3
1 -> 7; e_id = 2
5 -> 8; e_id = 5
2 -> 8; e_id = 4


In [31]:
# Test calling a subset of nodes: no firms, only product 6.
# The returned edge_index indexes into n_id (here n_id = [0, 1, 6]).
n_id, edge_index, e_id = neighbor_loader(t.tensor([], dtype=t.long), t.tensor([6]))
print(n_id)
for i in range(edge_index.size(1)):
    print(f'{edge_index[0,i]} -> {edge_index[1,i]}; e_id = {e_id[i]}')  # This is correct - reindexing matches up

tensor([0, 1, 6])
1 -> 2; e_id = 7
0 -> 2; e_id = 6


# Test data

In [2]:
import wandb
import math
import timeit
from tqdm import tqdm
import json

import os
import os.path as osp
from pathlib import Path
import numpy as np

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader

from torch_geometric.nn import TransformerConv

# internal imports
from tgb.utils.utils import *
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictorTGNPL
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import TGNPLMessage
from modules.msg_agg import MeanAggregator
from modules.neighbor_loader import LastNeighborLoaderTGNPL
from modules.memory_module import TGNPLMemory, StaticMemory
from modules.early_stopping import  EarlyStopMonitor
from modules.hyper_edgebank import HyperEdgeBankPredictor, test_edgebank
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset, PyGLinkPropPredDatasetHyper

In [3]:
import getpass

DATA = 'tgbl-hypergraph'
# Root of the TGB checkout; override with the TGB_ROOT environment variable.
# getpass.getuser() replaces os.getlogin(), which raises OSError when the kernel
# has no controlling terminal (e.g. a Jupyter server started by a service).
TGB_ROOT = Path(os.environ.get("TGB_ROOT", f"/lfs/turing1/0/{getpass.getuser()}/supply-chains/TGB"))
META_PATH = TGB_ROOT / "tgb" / "datasets" / DATA.replace('-', '_') / f"{DATA}_meta.json"
with open(META_PATH, "r") as file:
    METADATA = json.load(file)
# One entry per node id (firms and products)
NUM_NODES = len(METADATA["id2entity"])
METADATA.keys()

dict_keys(['product_threshold', 'id2entity', 'train_max_ts', 'val_max_ts', 'test_max_ts'])

In [4]:
# Last timestamp of the training split (matches the max of data.t[train_mask] below)
METADATA['train_max_ts']

272

In [5]:
# Last timestamp of the test split (matches the max of data.t[test_mask] below)
METADATA['test_max_ts']

364

In [6]:
# Node ids below `product_threshold` are firms; the remaining ids are products
# (product ids in data.prod start at this threshold).
NUM_FIRMS = METADATA["product_threshold"]
NUM_PRODUCTS = NUM_NODES - NUM_FIRMS
NUM_FIRMS, NUM_PRODUCTS

(7257, 2852)

In [7]:
# load dataset
# Everything stays on CPU for this inspection notebook.
device = "cpu"
dataset = PyGLinkPropPredDatasetHyper(name=DATA, root="datasets")
# Split masks over all events; used below to index `data` into train/val/test.
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)

# for evaluation
neg_sampler = dataset.negative_sampler
dataset.load_val_ns()  # load validation negative samples
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)

Dataset tgbl-hypergraph url not found, download not supported yet.
file found, skipping download
Dataset directory is  /lfs/turing1/0/serinac/supply-chains/TGB/tgb/datasets/tgbl_hypergraph
loading processed file


In [8]:
# Event stream: one row per (src, dst, prod, t) transaction with a 1-d msg
data

TemporalData(src=[305277], dst=[305277], t=[305277], msg=[305277, 1], prod=[305277], y=[305277])

In [9]:
# First few source (firm) ids
data.src[:10]

tensor([2246, 6686, 3109, 3791, 2124, 4738, 5023, 6768, 1081,  708])

In [10]:
# First few destination (firm) ids
data.dst[:10]

tensor([1727, 4512, 5068, 6941, 2609, 4240, 1410, 4512, 3544, 5851])

In [11]:
# First few product ids (all >= NUM_FIRMS)
data.prod[:10]

tensor([8448, 8055, 9212, 7338, 8490, 9519, 9689, 8055, 7342, 8440])

In [12]:
# First few event timestamps
data.t[:10]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [13]:
# Training timestamps; the last should equal METADATA['train_max_ts']
data.t[train_mask]

tensor([  0,   0,   0,  ..., 272, 272, 272])

In [14]:
# Test timestamps; the last should equal METADATA['test_max_ts']
data.t[test_mask]

tensor([318, 318, 318,  ..., 364, 364, 364])

In [15]:
BATCH_SIZE = 200
# Split the event stream, then wrap each split in a temporal batch loader.
train_data, val_data, test_data = (data[mask] for mask in (train_mask, val_mask, test_mask))
train_loader, val_loader, test_loader = (
    TemporalDataLoader(split, batch_size=BATCH_SIZE)
    for split in (train_data, val_data, test_data)
)
print(len(train_loader))

1074


In [16]:
# Id ranges per role. Negative sampling should only draw actual source, product,
# or destination nodes, so check that firms (src/dst) live in [0, NUM_FIRMS)
# and products in [NUM_FIRMS, NUM_NODES).
min_src_idx, max_src_idx = int(data.src.min()), int(data.src.max())
min_prod_idx, max_prod_idx = int(data.prod.min()), int(data.prod.max())
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
print(min_src_idx, max_src_idx)
print(min_prod_idx, max_prod_idx)
print(min_dst_idx, max_dst_idx)
# Checked after printing so the offending range is visible if one fails.
assert 0 <= min_src_idx and max_src_idx < NUM_FIRMS, "src ids must be firm ids"
assert 0 <= min_dst_idx and max_dst_idx < NUM_FIRMS, "dst ids must be firm ids"
assert NUM_FIRMS <= min_prod_idx and max_prod_idx < NUM_NODES, "prod ids must be product ids"

0 7256
7257 10108
3 7253


## Test edgebank

In [17]:
# EdgeBank-style baseline over (src, dst, prod) triplets.
# NOTE(review): semantics of `consecutive=True` not visible here — confirm in hyper_edgebank.
edgebank = HyperEdgeBankPredictor(NUM_FIRMS, NUM_PRODUCTS, consecutive=True)

In [18]:
# Round trip: triplet -> flat index -> triplet must recover the training edges exactly.
idx = edgebank.convert_triplet_to_index(train_data.src, train_data.dst, train_data.prod)
print(idx[:10])
src, dst, prod = edgebank.convert_index_to_triplet(idx)
for expected, recovered in ((train_data.src, src), (train_data.dst, dst), (train_data.prod, prod)):
    assert (expected == recovered).all()

tensor([ 46490307739, 138392770326,  64361316967,  78481986337,  43967793637,
         98074310174, 103964873924, 140089921374,  22383525657,  14670138747])


In [19]:
# Memorize the training triplets (prints total and unique edge counts)
edgebank.fit(train_data.src, train_data.dst, train_data.prod)

Fit on 214674 edges; found 38756 unique


In [20]:
# Scores for the first 10 validation triplets (output shows per-triplet counts)
edgebank.predict(val_data.src[:10], val_data.dst[:10], val_data.prod[:10])

tensor([ 11.,  61.,   1., 213.,   0., 172.,   3.,   1., 131.,   7.])

In [21]:
# How many validation triplets get a positive count-based score, out of all of them
y_pred = edgebank.predict(val_data.src, val_data.dst, val_data.prod, use_counts=True)
torch.sum(y_pred > 0), len(y_pred)

(tensor(38147), 44991)

In [22]:
# Validation metric for EdgeBank with use_counts=False (presumably seen/unseen scores)
test_edgebank(val_loader, neg_sampler, "val", evaluator, metric, edgebank, use_counts=False)

100%|███████████████████████████████████████████████████████████████████████| 225/225 [00:02<00:00, 92.14it/s]


0.3399311304092407

In [23]:
# Same evaluation with count-based scores (use_counts=True)
test_edgebank(val_loader, neg_sampler, "val", evaluator, metric, edgebank, use_counts=True)

100%|███████████████████████████████████████████████████████████████████████| 225/225 [00:04<00:00, 51.55it/s]


0.6819248199462891