In [1]:
%load_ext autoreload
%autoreload 2

import torch as t
import numpy as np

from memory_module import TGNPLMemory, StaticMemory
from inventory_module import TGNPLInventory
from msg_func import TGNPLMessage
from msg_agg import *

from neighbor_loader import LastNeighborLoader, LastNeighborLoaderTGNPL

import matplotlib.pyplot as plt

# Test TGNPLMemory

In [2]:
num_nodes = 10
num_prods = 2
raw_msg_dim = 1
memory_dim = 5
time_dim = 1
message_module = TGNPLMessage(raw_msg_dim, memory_dim, time_dim)
aggregator_module = MeanAggregator()

In [3]:
# test initialization
mem = TGNPLMemory(num_nodes,
        num_prods,
        raw_msg_dim,
        memory_dim,
        time_dim,
        message_module,
        aggregator_module,
        memory_updater_cell="gru",
        debug=True)

In [4]:
for name, param in mem.named_parameters():
    if param.requires_grad:
        print(name, param.data.shape)

msg_s_module.lin.weight torch.Size([17, 17])
msg_s_module.lin.bias torch.Size([17])
msg_s_module.layer_norm.weight torch.Size([17])
msg_s_module.layer_norm.bias torch.Size([17])
msg_d_module.lin.weight torch.Size([17, 17])
msg_d_module.lin.bias torch.Size([17])
msg_d_module.layer_norm.weight torch.Size([17])
msg_d_module.layer_norm.bias torch.Size([17])
msg_p_module.lin.weight torch.Size([17, 17])
msg_p_module.lin.bias torch.Size([17])
msg_p_module.layer_norm.weight torch.Size([17])
msg_p_module.layer_norm.bias torch.Size([17])
time_enc.lin.weight torch.Size([1, 1])
time_enc.lin.bias torch.Size([1])
memory_updater.weight_ih torch.Size([15, 17])
memory_updater.weight_hh torch.Size([15, 5])
memory_updater.bias_ih torch.Size([15])
memory_updater.bias_hh torch.Size([15])
init_memory.weight torch.Size([10, 5])


In [5]:
# before any interactions have been added
# init memory should match memory
n_id = t.arange(0, num_nodes, dtype=t.long)
print(n_id)
print('init memory', mem.init_memory(n_id))
memory, last_update, loss = mem(n_id)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
init memory tensor([[ 0.4914,  0.7131, -2.8262, -2.5365,  0.3427],
        [-0.0779,  0.6779,  1.2663, -0.5405, -1.4563],
        [-2.1379,  0.2348, -1.1929,  0.5336, -0.2746],
        [ 0.4992,  0.8994,  0.8674,  0.6404,  2.3216],
        [-0.0807, -0.0544, -1.4065, -0.1904,  0.6187],
        [ 0.8137,  1.2687,  0.9926,  0.7229, -0.0804],
        [-1.1660, -0.5671,  0.2012, -1.1957,  0.2869],
        [-1.3550,  0.0471,  0.1788, -0.9293,  0.6383],
        [-0.2679,  0.8475, -0.9793,  1.2555, -1.0411],
        [-0.5698, -0.0139,  0.1173, -0.6805,  1.5255]],
       grad_fn=<EmbeddingBackward0>)
msg_s tensor([], size=(0, 17), grad_fn=<ReluBackward0>)
msg_d tensor([], size=(0, 17), grad_fn=<ReluBackward0>)
msg_p tensor([], size=(0, 17), grad_fn=<ReluBackward0>)
aggr tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.,

In [6]:
# try adding interactions
# 1) get_updated_memory will print 6 nodes [0, 1, 2, 3, 8, 9], but memories shouldn't be updated yet
# 2) _update_msg_store should update all three msg stores
src = t.Tensor([0, 0, 1, 2]).long()
dst = t.Tensor([3, 3, 3, 0]).long()
prod = t.Tensor([8, 8, 8, 9]).long()
time = t.Tensor(np.ones(4)).long()
raw_msg = t.Tensor([10, 20, 5, 13]).reshape(-1, 1)
mem.update_state(src, dst, prod, time, raw_msg)

msg_s tensor([], size=(0, 17), grad_fn=<ReluBackward0>)
msg_d tensor([], size=(0, 17), grad_fn=<ReluBackward0>)
msg_p tensor([], size=(0, 17), grad_fn=<ReluBackward0>)
aggr tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       grad_fn=<DivBackward0>)
memory tensor([[ 0.4914,  0.7131, -2.8262, -2.5365,  0.3427],
        [-0.0779,  0.6779,  1.2663, -0.5405, -1.4563],
        [-2.1379,  0.2348, -1.1929,  0.5336, -0.2746],
        [ 0.4992,  0.8994,  0.8674,  0.6404,  2.3216],
        [-0.2679,  0.8475, -0.9793,  1.2555, -1.0411],
        [-0.5698, -0.0139,  0.1173, -0.6805, 

In [7]:
mem.msg_s_store

{0: (tensor([0, 0]),
  tensor([3, 3]),
  tensor([8, 8]),
  tensor([1, 1]),
  tensor([[10.],
          [20.]])),
 1: (tensor([1]), tensor([3]), tensor([8]), tensor([1]), tensor([[5.]])),
 2: (tensor([2]), tensor([0]), tensor([9]), tensor([1]), tensor([[13.]])),
 3: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 7: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype

In [8]:
mem.msg_d_store

{0: (tensor([2]), tensor([0]), tensor([9]), tensor([1]), tensor([[13.]])),
 1: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 2: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 3: (tensor([0, 0, 1]),
  tensor([3, 3, 3]),
  tensor([8, 8, 8]),
  tensor([1, 1, 1]),
  tensor([[10.],
          [20.],
          [ 5.]])),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64

In [9]:
mem.msg_p_store

{0: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 1: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 2: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 3: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=

In [10]:
# now get memory again - only nodes with interactions should've changed
memory, last_update, loss = mem(n_id)  # test .forward()

msg_s tensor([[0.0000, 0.9801, 0.0000, 0.0000, 0.0000, 0.1751, 0.3486, 1.1994, 1.4253,
         0.3512, 1.6287, 0.2486, 0.5403, 0.0000, 0.0036, 0.0000, 0.0000],
        [0.0000, 1.1626, 0.0000, 0.0000, 0.0000, 0.1152, 0.4266, 1.1952, 0.9019,
         0.2167, 1.7127, 0.2165, 1.1999, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6544, 1.2596, 0.4593, 1.9518, 0.8186,
         0.0000, 1.2689, 0.1537, 0.0000, 0.0000, 0.2257, 0.0000, 0.0000],
        [0.0000, 1.1877, 0.0000, 0.0000, 0.0000, 0.0000, 0.1473, 0.4406, 0.0000,
         0.0000, 1.8169, 0.0000, 2.3194, 0.0000, 0.1493, 0.0070, 0.0000]],
       grad_fn=<ReluBackward0>)
msg_d tensor([[0.0000, 1.1877, 0.0000, 0.0000, 0.0000, 0.0000, 0.1473, 0.4406, 0.0000,
         0.0000, 1.8169, 0.0000, 2.3194, 0.0000, 0.1493, 0.0070, 0.0000],
        [0.0000, 0.9801, 0.0000, 0.0000, 0.0000, 0.1751, 0.3486, 1.1994, 1.4253,
         0.3512, 1.6287, 0.2486, 0.5403, 0.0000, 0.0036, 0.0000, 0.0000],
        [0.0000, 1.1626, 

In [12]:
# had interactions
memory[[0,1,2,3,8,9]]

tensor([[ 0.8160,  0.6607, -2.3491, -1.7818,  0.2681],
        [ 0.4392,  0.3641,  0.8449, -0.3775,  0.3455],
        [-0.1819,  0.3413, -0.9443,  0.3417,  0.2412],
        [ 0.8868,  0.5851,  0.3247,  0.6831,  2.0582],
        [ 0.2796,  0.4626, -0.7522,  0.7771, -0.1053],
        [ 0.7553,  0.2005,  0.1919, -0.5854,  1.1414]],
       grad_fn=<IndexBackward0>)

In [13]:
# didn't have interactions - should still be using init memory
memory[[4,5,6,7]]

tensor([[-0.0807, -0.0544, -1.4065, -0.1904,  0.6187],
        [ 0.8137,  1.2687,  0.9926,  0.7229, -0.0804],
        [-1.1660, -0.5671,  0.2012, -1.1957,  0.2869],
        [-1.3550,  0.0471,  0.1788, -0.9293,  0.6383]],
       grad_fn=<IndexBackward0>)

In [14]:
# should only be updated for nodes in transactions
last_update

tensor([ 1,  1,  1,  1, -1, -1, -1, -1,  1,  1])

In [15]:
# store latest memory, reset message store
mem._update_memory(n_id)
mem._reset_message_store()

msg_s tensor([[0.0000, 0.9801, 0.0000, 0.0000, 0.0000, 0.1751, 0.3486, 1.1994, 1.4253,
         0.3512, 1.6287, 0.2486, 0.5403, 0.0000, 0.0036, 0.0000, 0.0000],
        [0.0000, 1.1626, 0.0000, 0.0000, 0.0000, 0.1152, 0.4266, 1.1952, 0.9019,
         0.2167, 1.7127, 0.2165, 1.1999, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6544, 1.2596, 0.4593, 1.9518, 0.8186,
         0.0000, 1.2689, 0.1537, 0.0000, 0.0000, 0.2257, 0.0000, 0.0000],
        [0.0000, 1.1877, 0.0000, 0.0000, 0.0000, 0.0000, 0.1473, 0.4406, 0.0000,
         0.0000, 1.8169, 0.0000, 2.3194, 0.0000, 0.1493, 0.0070, 0.0000]],
       grad_fn=<ReluBackward0>)
msg_d tensor([[0.0000, 1.1877, 0.0000, 0.0000, 0.0000, 0.0000, 0.1473, 0.4406, 0.0000,
         0.0000, 1.8169, 0.0000, 2.3194, 0.0000, 0.1493, 0.0070, 0.0000],
        [0.0000, 0.9801, 0.0000, 0.0000, 0.0000, 0.1751, 0.3486, 1.1994, 1.4253,
         0.3512, 1.6287, 0.2486, 0.5403, 0.0000, 0.0036, 0.0000, 0.0000],
        [0.0000, 1.1626, 

In [16]:
# should be equal when node hasn't had interactions
mem.memory[:, :5] == mem.init_memory(n_id)

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [False, False, False, False, False],
        [False, False, False, False, False]])

In [17]:
# should be empty
mem.msg_s_store

{0: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 1: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 2: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 3: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 4: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 5: (tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], dtype=torch.int64),
  tensor([], size=(0, 1))),
 6: (tensor([], dtype=

In [18]:
# shouldn't have changed
mem.last_update

tensor([ 1,  1,  1,  1, -1, -1, -1, -1,  1,  1])

# Test Neighbor Loader

In [19]:
# Test init
num_nodes = 9
num_neighbors = 2
all_nid = torch.arange(9).long()
fid = torch.arange(6).long()  # firm ids
pid = torch.arange(6,9).long()  # prod ids
neighbor_loader = LastNeighborLoaderTGNPL(num_nodes, size=num_neighbors)
print(neighbor_loader.neighbors.shape)
print(neighbor_loader.e_id.shape)
print(neighbor_loader._assoc.shape)
self = neighbor_loader

torch.Size([9, 2])
torch.Size([9, 2])
torch.Size([9])


In [20]:
# Test insert
src = torch.Tensor([0, 1, 2]).to(torch.long)
dst = torch.Tensor([3, 4, 5]).to(torch.long)
prod = torch.Tensor([6, 7, 8]).to(torch.long)
neighbor_loader.insert(src, dst, prod)

# Ground truth neighbors
# 0: 6; 1: 7; 2: 8; 3: 6; 4: 7; 5: 8; 6: 3,0; 7: 4,1; 8: 5,2
for i in all_nid:
    print(i, self.neighbors[i], self.e_id[i])

tensor(0) tensor([        6, 139976432]) tensor([ 0, -1])
tensor(1) tensor([7, 0]) tensor([ 2, -1])
tensor(2) tensor([              8, 139644797086728]) tensor([ 4, -1])
tensor(3) tensor([              6, 139644808583664]) tensor([ 1, -1])
tensor(4) tensor([        7, 140002544]) tensor([ 3, -1])
tensor(5) tensor([        8, 139966464]) tensor([ 5, -1])
tensor(6) tensor([3, 0]) tensor([1, 0])
tensor(7) tensor([4, 1]) tensor([3, 2])
tensor(8) tensor([5, 2]) tensor([5, 4])


In [21]:
# Test _call_
n_id, edge_index, e_id = neighbor_loader(fid, pid)
# Ground truth edges (x2, for all)
# 0-6, 3-6, 1-7, 4-7, 2-8, 5-8
print(n_id, n_id.shape)
print(edge_index.shape, e_id.shape)
for i in range(edge_index.size(1)):
    print(f'{edge_index[0,i]} -> {edge_index[1,i]}; e_id = {e_id[i]}')

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) torch.Size([9])
torch.Size([2, 12]) torch.Size([12])
6 -> 0; e_id = 0
7 -> 1; e_id = 2
8 -> 2; e_id = 4
6 -> 3; e_id = 1
7 -> 4; e_id = 3
8 -> 5; e_id = 5
3 -> 6; e_id = 1
0 -> 6; e_id = 0
4 -> 7; e_id = 3
1 -> 7; e_id = 2
5 -> 8; e_id = 5
2 -> 8; e_id = 4


In [22]:
# Test insert again
src = torch.Tensor([0]).to(torch.long)
dst = torch.Tensor([1]).to(torch.long)
prod = torch.Tensor([6]).to(torch.long)
neighbor_loader.insert(src, dst, prod)

# Ground truth neighbors - 0 and 1 have extra neighbor, 6 replaced old neighbors with new neighbors
# 0: 6,6; 1: 6,7; 2: 8; 3: 6; 4: 7; 5: 8; 6: 1,0; 7: 4,1; 8: 5,2
for i in all_nid:
    print(i, self.neighbors[i], self.e_id[i])  # This is correct

tensor(0) tensor([6, 6]) tensor([6, 0])
tensor(1) tensor([6, 7]) tensor([7, 2])
tensor(2) tensor([              8, 139644797086728]) tensor([ 4, -1])
tensor(3) tensor([              6, 139644808583664]) tensor([ 1, -1])
tensor(4) tensor([        7, 140002544]) tensor([ 3, -1])
tensor(5) tensor([        8, 139966464]) tensor([ 5, -1])
tensor(6) tensor([1, 0]) tensor([7, 6])
tensor(7) tensor([4, 1]) tensor([3, 2])
tensor(8) tensor([5, 2]) tensor([5, 4])


In [23]:
# Test _call_
n_id, edge_index, e_id = neighbor_loader(fid, pid)
# Ground truth edges (x2, for all except first 0-6 and 3-6)
# 0-6, 3-6, 1-7, 4-7, 2-8, 5-8, 0-6, 1-6
print(n_id, n_id.shape)
print(edge_index.shape, e_id.shape)
for i in range(edge_index.size(1)):
    print(f'{edge_index[0,i]} -> {edge_index[1,i]}; e_id = {e_id[i]}')  # This is correct

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) torch.Size([9])
torch.Size([2, 14]) torch.Size([14])
6 -> 0; e_id = 6
6 -> 0; e_id = 0
6 -> 1; e_id = 7
7 -> 1; e_id = 2
8 -> 2; e_id = 4
6 -> 3; e_id = 1
7 -> 4; e_id = 3
8 -> 5; e_id = 5
1 -> 6; e_id = 7
0 -> 6; e_id = 6
4 -> 7; e_id = 3
1 -> 7; e_id = 2
5 -> 8; e_id = 5
2 -> 8; e_id = 4


In [24]:
# Test calling a subset of nodes
n_id, edge_index, e_id = neighbor_loader(torch.Tensor([]).long(), torch.Tensor([6]).long())
print(n_id)
for i in range(edge_index.size(1)):
    print(f'{edge_index[0,i]} -> {edge_index[1,i]}; e_id = {e_id[i]}')  # This is correct - reindexing matches up

tensor([0, 1, 6])
1 -> 2; e_id = 7
0 -> 2; e_id = 6


# Test inventory module

In [12]:
num_firms = 3
num_prods = 3
module = TGNPLInventory(num_firms, num_prods, learn_att_direct=True)
print(module.inventory)
print(module.att_weights)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
Parameter containing:
tensor([[0.5877, 0.1735, 0.7540],
        [0.1679, 0.9280, 0.9383],
        [0.4877, 0.2799, 0.7242]], requires_grad=True)


In [3]:
src = t.Tensor([1]).long()
dst = t.Tensor([2]).long()
prod = t.Tensor([5]).long()
raw_msg = t.Tensor([1]).reshape(-1, 1)
module._compute_totals_per_firm_and_product(src, prod, raw_msg)

tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [4]:
module._compute_totals_per_firm_and_product(dst, prod, raw_msg)

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.]])

In [5]:
loss = module(src, dst, prod, raw_msg)
print(module.inventory)

tensor([[1.0000, 1.0000, 1.0000],
        [0.0773, 0.8896, 0.8109],
        [1.0000, 1.0000, 2.0000]], grad_fn=<ClampBackward1>)


In [6]:
loss

tensor(-1.2222, grad_fn=<DivBackward0>)

In [13]:
# need products 3 and 4 to make 5
# att weights should have nonzero in 3,1 and/or 3,2 - this works
opt = torch.optim.Adam(module.parameters())
losses = []
for i in range(0, 5001):
    opt.zero_grad()
    if (i % 2) == 0:
        # 1 sells 5 to 2
        src = t.Tensor([1]).long()
        dst = t.Tensor([2]).long()
        prod = t.Tensor([5]).long()
        raw_msg = t.Tensor([1]).reshape(-1, 1)
    else:
        # 1 buys 3 and 4 from 0
        src = t.Tensor([0, 0]).long()
        dst = t.Tensor([1, 1]).long()
        prod = t.Tensor([3, 4]).long()
        raw_msg = t.Tensor([2, 4]).reshape(-1, 1)
    if (i%100) == 0:
        print('i=%d -> loss = %.3f' % (i, float(loss)))
        print(module.att_weights.data)
        print()
    loss = module(src, dst, prod, raw_msg)
    loss.backward(retain_graph=False)
    opt.step()
    module.detach()
    losses.append(float(loss))

i=0 -> loss = -1.259
tensor([[0.5877, 0.1735, 0.7540],
        [0.1679, 0.9280, 0.9383],
        [0.4877, 0.2799, 0.7242]])

i=100 -> loss = 44.634
tensor([[0.5176, 0.1034, 0.6839],
        [0.0977, 0.8579, 0.8682],
        [0.5585, 0.3507, 0.6566]])

i=200 -> loss = 38.898
tensor([[0.4467, 0.0326, 0.6130],
        [0.0269, 0.7871, 0.7974],
        [0.6291, 0.4213, 0.5856]])

i=300 -> loss = 34.278
tensor([[ 0.3760, -0.0068,  0.5423],
        [-0.0068,  0.7163,  0.7266],
        [ 0.6998,  0.4919,  0.5147]])

i=400 -> loss = 30.457
tensor([[ 0.3052, -0.0068,  0.4715],
        [-0.0068,  0.6456,  0.6559],
        [ 0.7704,  0.5626,  0.4439]])

i=500 -> loss = 26.637
tensor([[ 0.2345, -0.0068,  0.4008],
        [-0.0068,  0.5748,  0.5851],
        [ 0.8411,  0.6333,  0.3731]])

i=600 -> loss = 22.818
tensor([[ 0.1637, -0.0068,  0.3301],
        [-0.0068,  0.5041,  0.5144],
        [ 0.9118,  0.7040,  0.3023]])

i=700 -> loss = 18.998
tensor([[ 0.0930, -0.0068,  0.2593],
        [-0.0068,

In [18]:
# test learning attention weights with product embeddings
# use one-hot, should produce the same results
module = TGNPLInventory(num_firms, num_prods, learn_att_direct=False, emb_dim=3)
print(module.prod_bilinear)
embs = torch.eye(3)
embs

Parameter containing:
tensor([[0.3364, 0.5904, 0.8218],
        [0.6400, 0.1977, 0.7878],
        [0.6339, 0.0069, 0.1946]], requires_grad=True)


tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

In [21]:
# need products 3 and 4 to make 5
# att weights should have nonzero in 3,1 and/or 3,2 - this works
opt = torch.optim.Adam(module.parameters())
losses = []
for i in range(0, 5001):
    opt.zero_grad()
    if (i % 2) == 0:
        # 1 sells 5 to 2
        src = t.Tensor([1]).long()
        dst = t.Tensor([2]).long()
        prod = t.Tensor([5]).long()
        raw_msg = t.Tensor([1]).reshape(-1, 1)
    else:
        # 1 buys 3 and 4 from 0
        src = t.Tensor([0, 0]).long()
        dst = t.Tensor([1, 1]).long()
        prod = t.Tensor([3, 4]).long()
        raw_msg = t.Tensor([2, 4]).reshape(-1, 1)
    if (i%100) == 0:
        print('i=%d -> loss = %.3f' % (i, float(loss)))
        print(module.prod_bilinear.data)
        print()
    loss = module(src, dst, prod, raw_msg, embs)
    loss.backward(retain_graph=False)
    opt.step()
    module.detach()
    losses.append(float(loss))

i=0 -> loss = -1.856
tensor([[0.3364, 0.5904, 0.8218],
        [0.6400, 0.1977, 0.7878],
        [0.6339, 0.0069, 0.1946]])

i=100 -> loss = 39.378
tensor([[0.2663, 0.5203, 0.7517],
        [0.5699, 0.1276, 0.7177],
        [0.7047, 0.0777, 0.1356]])

i=200 -> loss = 33.642
tensor([[0.1955, 0.4495, 0.6809],
        [0.4991, 0.0568, 0.6468],
        [0.7753, 0.1483, 0.0627]])

i=300 -> loss = 28.148
tensor([[ 0.1247,  0.3787,  0.6101],
        [ 0.4283, -0.0058,  0.5761],
        [ 0.8460,  0.2190, -0.0051]])

i=400 -> loss = 23.690
tensor([[ 0.0540,  0.3080,  0.5393],
        [ 0.3575, -0.0067,  0.5053],
        [ 0.9166,  0.2896, -0.0069]])

i=500 -> loss = 19.378
tensor([[-0.0060,  0.2372,  0.4686],
        [ 0.2868, -0.0067,  0.4346],
        [ 0.9873,  0.3603, -0.0069]])

i=600 -> loss = 15.558
tensor([[-0.0065,  0.1665,  0.3979],
        [ 0.2161, -0.0067,  0.3639],
        [ 1.0580,  0.4310, -0.0069]])

i=700 -> loss = 11.739
tensor([[-0.0065,  0.0958,  0.3271],
        [ 0.1453,

# Test data + HyperEdgebank

In [4]:
import wandb
import math
import timeit
from tqdm import tqdm
import json

import os
import os.path as osp
from pathlib import Path
import numpy as np

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader

from torch_geometric.nn import TransformerConv

# internal imports
from tgb.utils.utils import *
from tgb.linkproppred.evaluate import Evaluator
from modules.decoder import LinkPredictorTGNPL
from modules.emb_module import GraphAttentionEmbedding
from modules.msg_func import TGNPLMessage
from modules.msg_agg import MeanAggregator
from modules.neighbor_loader import LastNeighborLoaderTGNPL
from modules.memory_module import TGNPLMemory, StaticMemory
from modules.early_stopping import  EarlyStopMonitor
from modules.hyper_edgebank import HyperEdgeBankPredictor, test_edgebank
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset, PyGLinkPropPredDatasetHyper

## Logistic data

In [75]:
DATA = 'tgbl-hypergraph'
with open(f"/lfs/turing1/0/{os.getlogin()}/supply-chains/TGB/tgb/datasets/{DATA.replace('-', '_')}/{DATA}_meta.json","r") as file:
    METADATA = json.load(file)
    NUM_NODES = len(METADATA["id2entity"])
METADATA.keys()

dict_keys(['product_threshold', 'id2entity', 'train_max_ts', 'val_max_ts', 'test_max_ts'])

In [76]:
METADATA['train_max_ts']

272

In [77]:
METADATA['test_max_ts']

364

In [78]:
NUM_FIRMS = METADATA["product_threshold"]
NUM_PRODUCTS = NUM_NODES - NUM_FIRMS
NUM_FIRMS, NUM_PRODUCTS

(7257, 2852)

In [79]:
# load dataset
device = "cpu"
dataset = PyGLinkPropPredDatasetHyper(name=DATA, root="datasets")
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)

# for evaluation
neg_sampler = dataset.negative_sampler
dataset.load_val_ns()  # load validation negative samples
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)

Dataset tgbl-hypergraph url not found, download not supported yet.
file found, skipping download
Dataset directory is  /lfs/turing1/0/serinac/supply-chains/TGB/tgb/datasets/tgbl_hypergraph
loading processed file


In [80]:
data

TemporalData(src=[305277], dst=[305277], t=[305277], msg=[305277, 1], prod=[305277], y=[305277])

In [81]:
data.src[:10]

tensor([2246, 6686, 3109, 3791, 2124, 4738, 5023, 6768, 1081,  708])

In [82]:
data.dst[:10]

tensor([1727, 4512, 5068, 6941, 2609, 4240, 1410, 4512, 3544, 5851])

In [83]:
data.prod[:10]

tensor([8448, 8055, 9212, 7338, 8490, 9519, 9689, 8055, 7342, 8440])

In [84]:
data.t[:10]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [85]:
data.t[train_mask]

tensor([  0,   0,   0,  ..., 272, 272, 272])

In [86]:
data.t[test_mask]

tensor([318, 318, 318,  ..., 364, 364, 364])

In [87]:
BATCH_SIZE = 200
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
print(len(train_loader))

1074


In [88]:
# Ensure to only sample actual source, product, or destination nodes as negatives.
min_src_idx, max_src_idx = int(data.src.min()), int(data.src.max())
min_prod_idx, max_prod_idx = int(data.prod.min()), int(data.prod.max())
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
print(min_src_idx, max_src_idx)
print(min_prod_idx, max_prod_idx)
print(min_dst_idx, max_dst_idx)

0 7256
7257 10108
3 7253


## Test edgebank

In [89]:
edgebank = HyperEdgeBankPredictor(NUM_FIRMS, NUM_PRODUCTS, consecutive=True)

In [90]:
idx = edgebank.convert_triplet_to_index(train_data.src, train_data.dst, train_data.prod)
print(idx[:10])
src, dst, prod = edgebank.convert_index_to_triplet(idx)
assert (train_data.src == src).all()
assert (train_data.dst == dst).all()
assert (train_data.prod == prod).all()

tensor([ 46490307739, 138392770326,  64361316967,  78481986337,  43967793637,
         98074310174, 103964873924, 140089921374,  22383525657,  14670138747])


In [91]:
edgebank.fit(train_data.src, train_data.dst, train_data.prod)

Fit on 214674 edges; found 38756 unique


In [92]:
edgebank.predict(val_data.src[:10], val_data.dst[:10], val_data.prod[:10])

tensor([ 11.,  61.,   1., 213.,   0., 172.,   3.,   1., 131.,   7.])

In [93]:
y_pred = edgebank.predict(val_data.src, val_data.dst, val_data.prod, use_counts=True)
torch.sum(y_pred > 0), len(y_pred)

(tensor(38147), 44991)

In [94]:
test_edgebank(val_loader, neg_sampler, "val", evaluator, metric, edgebank, use_counts=False,
              use_prev_sampling=True)

100%|██████████████████████████████████████████████████████████████████████████████████| 225/225 [01:35<00:00,  2.36it/s]


0.3399311304092407

In [95]:
test_edgebank(val_loader, neg_sampler, "val", evaluator, metric, edgebank, use_counts=True,
              use_prev_sampling=True)

100%|██████████████████████████████████████████████████████████████████████████████████| 225/225 [00:13<00:00, 16.61it/s]


0.6819248199462891

## Synthetic data

In [5]:
DATA = 'tgbl-hypergraph_synthetic'
with open(f"/lfs/turing1/0/{os.getlogin()}/supply-chains/TGB/tgb/datasets/{DATA.replace('-', '_')}/{DATA}_meta.json","r") as file:
    METADATA = json.load(file)
    NUM_NODES = len(METADATA["id2entity"])
METADATA.keys()

dict_keys(['product_threshold', 'id2entity', 'train_max_ts', 'val_max_ts', 'test_max_ts'])

In [6]:
METADATA['train_max_ts']

105

In [7]:
METADATA['test_max_ts']

149

In [8]:
NUM_FIRMS = METADATA["product_threshold"]
NUM_PRODUCTS = NUM_NODES - NUM_FIRMS
NUM_FIRMS, NUM_PRODUCTS

(21, 97)

In [9]:
# load dataset
device = "cpu"
dataset = PyGLinkPropPredDatasetHyper(name=DATA, root="datasets", use_prev_sampling=False)
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask
data = dataset.get_TemporalData()
data = data.to(device)

# for evaluation
neg_sampler = dataset.negative_sampler
dataset.load_val_ns()  # load validation negative samples
metric = dataset.eval_metric
evaluator = Evaluator(name=DATA)

Dataset tgbl-hypergraph_synthetic url not found, download not supported yet.
file found, skipping download
Dataset directory is  /lfs/turing1/0/serinac/supply-chains/TGB/tgb/datasets/tgbl_hypergraph_synthetic
loading processed file


In [10]:
data

TemporalData(src=[74887], dst=[74887], t=[74887], msg=[74887, 1], prod=[74887], y=[74887])

In [11]:
data.src[:10]

tensor([7, 7, 7, 3, 3, 3, 3, 3, 3, 3])

In [12]:
data.dst[:10]

tensor([ 9, 12,  9, 12,  7, 16, 12, 18, 13,  2])

In [13]:
data.prod[:10]

tensor([ 66,  97,  66, 104, 104, 104, 104, 104, 104, 104])

In [14]:
data.t[:10]

tensor([1, 2, 3, 3, 3, 3, 3, 3, 3, 3])

In [15]:
data.t[train_mask]

tensor([  1,   2,   3,  ..., 105, 105, 105])

In [16]:
data.t[test_mask]

tensor([128, 128, 128,  ..., 149, 149, 149])

In [17]:
BATCH_SIZE = 10000
train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)
print(len(train_loader))

6


In [18]:
# check that we can use neg_sampler
batch = next(iter(val_loader))
pos_src, pos_prod, pos_dst, pos_t, pos_msg = batch.src, batch.prod, batch.dst, batch.t, batch.msg
neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_prod, pos_t, split_mode='val')

In [19]:
# Ensure to only sample actual source, product, or destination nodes as negatives.
min_src_idx, max_src_idx = int(data.src.min()), int(data.src.max())
min_prod_idx, max_prod_idx = int(data.prod.min()), int(data.prod.max())
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
print(min_src_idx, max_src_idx)
print(min_prod_idx, max_prod_idx)
print(min_dst_idx, max_dst_idx)

0 20
21 117
0 20


## Test edgebank

In [20]:
edgebank = HyperEdgeBankPredictor(NUM_FIRMS, NUM_PRODUCTS, consecutive=True)

In [21]:
idx = edgebank.convert_triplet_to_index(train_data.src, train_data.dst, train_data.prod)
print(idx[:10])
src, dst, prod = edgebank.convert_index_to_triplet(idx)
assert (train_data.src == src).all()
assert (train_data.dst == dst).all()
assert (train_data.prod == prod).all()

tensor([15177, 15499, 15177,  7358,  6873,  7746,  7358,  7940,  7455,  6388])


In [22]:
edgebank.fit(train_data.src, train_data.dst, train_data.prod)

Fit on 52736 edges; found 458 unique


In [23]:
edgebank.predict(val_data.src[:10], val_data.dst[:10], val_data.prod[:10])

tensor([106., 106., 106., 105., 315., 210., 315., 105., 315., 210.])

In [24]:
y_pred = edgebank.predict(val_data.src, val_data.dst, val_data.prod, use_counts=True)
torch.sum(y_pred > 0), len(y_pred)

(tensor(11100), 11100)

In [30]:
test_edgebank(val_loader, neg_sampler, "val", evaluator, metric, edgebank, use_counts=False,
              use_prev_sampling=False)

100%|██████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.14it/s]


0.38609811663627625

In [31]:
test_edgebank(val_loader, neg_sampler, "val", evaluator, metric, edgebank, use_counts=True,
              use_prev_sampling=False)

100%|██████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00,  2.51s/it]


0.5654792785644531

In [33]:
dataset.load_test_ns()
test_edgebank(test_loader, neg_sampler, "test", evaluator, metric, edgebank, use_counts=False,
              use_prev_sampling=False)

100%|██████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.07it/s]


0.4747615158557892

In [34]:
dataset.load_test_ns()
test_edgebank(test_loader, neg_sampler, "test", evaluator, metric, edgebank, use_counts=True,
              use_prev_sampling=False)

100%|██████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.10s/it]


0.6792834997177124