In [6]:
%load_ext autoreload
%autoreload 2

import torch as t
import numpy as np

from memory_module import TGNPLMemory
from msg_func import TGNPLMessage
from msg_agg import *

from neighbor_loader import LastNeighborLoader, LastNeighborLoaderTGNPL

# Test TGNPLMemory

In [None]:
# Hyper-parameters for the first TGNPLMemory test. The interactions added later
# use node ids 8 and 9 as products.
num_nodes, num_prods = 10, 2
raw_msg_dim, state_dim, time_dim = 1, 10, 1
# The message module is sized by the full per-node memory width, state_dim + num_prods
# (presumably state plus one inventory slot per product — confirm in memory_module).
memory_dim = state_dim + num_prods
message_module = TGNPLMessage(raw_msg_dim, memory_dim, time_dim)
aggregator_module = MeanAggregator()

In [None]:
# Test initialization: inventory tracking on, debug output on.
memory_opts = dict(
    state_updater_cell="gru",
    use_inventory=True,
    debt_penalty=10,
    consumption_reward=5,
    debug=True,
)
mem = TGNPLMemory(num_nodes, num_prods, raw_msg_dim, state_dim, time_dim,
                  message_module, aggregator_module, **memory_opts)

In [None]:
# Query 5 distinct random nodes before any interaction has been recorded.
np.random.seed(0)
sampled_ids = np.random.choice(num_nodes, size=5, replace=False)
n_id = t.from_numpy(sampled_ids)
print(n_id)
memory, last_update, loss = mem(n_id)

In [None]:
# last_update as returned by mem(n_id) for the sampled nodes; no interactions
# have been added yet, so these should still be the initial values.
last_update

In [None]:
# Add a batch of four interactions, all at time 1.
# Expectations:
#  1) get_updated_memory prints the 6 nodes involved (0, 1, 2, 3, 8, 9), but
#     their memories are not updated yet;
#  2) _update_msg_store fills all three message stores.
src = t.tensor([0, 0, 1, 2])
dst = t.tensor([3, 3, 3, 0])
prod = t.tensor([8, 8, 8, 9])
time = t.ones(4, dtype=t.long)
raw_msg = t.tensor([10.0, 20.0, 5.0, 13.0]).unsqueeze(1)  # shape (4, 1)
mem.update_state(src, dst, prod, time, raw_msg)

In [None]:
# Message store "s" (presumably source side) after update_state; it is expected
# to contain entries for the newly added interactions.
mem.msg_s_store

In [None]:
# Message store "d" (presumably destination side) after update_state.
mem.msg_d_store

In [None]:
# Message store "p" (presumably product side) after update_state.
mem.msg_p_store

In [None]:
# Query every node again; only nodes that took part in an interaction
# are expected to have changed.
all_nodes = np.arange(num_nodes)
n_id = t.from_numpy(all_nodes)
memory, last_update, loss = mem(n_id)  # exercises TGNPLMemory.forward

In [None]:
# Rows for node 2 and product 9: expected to share the same state part but
# differ in the inventory part of the memory.
# Node 2 supplied product 9 (the 13-unit interaction 2 -> 0).
# Product 9 itself holds no inventory.
memory[[2,9]]

In [None]:
# Rows for node 3 and product 8: expected to share the same state part but
# differ in the inventory part of the memory.
# Node 3 received exactly 35 of product 8 (10 + 20 + 5 over three interactions).
# Product 8 itself holds no inventory.
memory[[3,8]]

In [None]:
# Nodes 4-7 appear in no interaction, so their memory should be unaffected.
memory[[4,5,6,7]]

In [None]:
# last_update should only have advanced (to time 1) for nodes that appeared in
# the transactions: 0, 1, 2, 3, 8 and 9.
last_update

## Test attention weight learning

In [None]:
# Hyper-parameters for the attention-weight test. The loop below uses node ids
# 3, 4 and 5 as products.
num_nodes, num_prods = 6, 3
raw_msg_dim, state_dim, time_dim = 1, 2, 1
# Message module sized by the full memory width, state_dim + num_prods.
memory_dim = state_dim + num_prods
message_module = TGNPLMessage(raw_msg_dim, memory_dim, time_dim)
aggregator_module = MeanAggregator()

In [None]:
# Expected: raw_msg_dim + (3 * memory_dim) + time_dim, where memory_dim is the
# state_dim + num_prods passed to TGNPLMessage; with the values set for this
# test that is 1 + 3 * 5 + 1 = 17.
message_module.out_channels

In [None]:
n_id = t.arange(num_nodes, dtype=t.long)  # every node id, 0 .. num_nodes - 1
print(n_id)

In [None]:
# Fresh memory for the attention test (debug output off), its optimizer, and a
# listing of the trainable parameters.
memory_opts = dict(
    state_updater_cell="gru",
    use_inventory=True,
    debt_penalty=10,
    consumption_reward=5,
    debug=False,
)
mem = TGNPLMemory(num_nodes, num_prods, raw_msg_dim, state_dim, time_dim,
                  message_module, aggregator_module, **memory_opts)
opt = t.optim.Adam(mem.parameters())
trainable = [(pname, p) for pname, p in mem.named_parameters() if p.requires_grad]
for pname, p in trainable:
    print(pname, p.data.shape)

In [None]:
# Train the memory on a repeating two-step interaction pattern and print how
# the product-to-product attention weights evolve:
#   even i: 1 sells 5 to 2
#   odd i:  1 buys 3 and 4 from 0
# Anomaly detection is scoped to this cell instead of being left on globally
# for the rest of the kernel session.
with t.autograd.set_detect_anomaly(True):
    for i in range(1, 31):
        # Clear gradients from the previous iteration; without this, every
        # Adam step would use gradients accumulated over all prior iterations.
        opt.zero_grad()
        if i % 2 == 0:
            # 1 sells 5 to 2
            src = t.tensor([1])
            dst = t.tensor([2])
            prod = t.tensor([5])
            time = t.tensor([i])
            raw_msg = t.tensor([[1.0]])
        else:
            # 1 buys 3 and 4 from 0
            src = t.tensor([0, 0])
            dst = t.tensor([1, 1])
            prod = t.tensor([3, 4])
            time = t.tensor([i, i])
            raw_msg = t.tensor([[2.0], [4.0]])

        print('iter', i)
        mem.update_state(src, dst, prod, time, raw_msg)
        memory, last_update, loss = mem(n_id)
        print('loss', loss)

        # Diagnostic only: recompute the attention weights from the product rows
        # of the memory (presumably rows num_firms: are the products). No graph is
        # needed because these values are only printed.
        with t.no_grad():
            prod_emb = mem.memory[mem.num_firms:, :mem.state_dim]
            output_emb = mem.output_l(prod_emb)  # num_products x emb_dim
            input_emb = mem.input_l(prod_emb)    # num_products x emb_dim
            att_weights = t.relu(output_emb @ input_emb.T)  # num_products x num_products
        print('att weights', att_weights)

        loss.backward()
        opt.step()

## Test Neighbor Loader

In [2]:
# 9 node ids in total (the inserts below use 6, 7, 8 as products); size=2
# gives a (9, 2) neighbor buffer, as the next cell shows.
num_loader_nodes, neighbor_size = 9, 2
neighbor_loader = LastNeighborLoaderTGNPL(num_loader_nodes, size=neighbor_size)

In [3]:
# Test init: print the shapes of the loader's per-node buffers.
for buf in (neighbor_loader.neighbors, neighbor_loader.e_id, neighbor_loader._assoc):
    print(buf.shape)
# Short alias kept for interactive inspection of the loader's internals.
self = neighbor_loader

torch.Size([9, 2])
torch.Size([9, 2])
torch.Size([9])


In [7]:
# Test insert: three interactions src -> dst involving product prod.
# Uses the `t` alias from the import cell; `torch` itself is never imported.
src = t.tensor([0, 1, 2])
dst = t.tensor([3, 4, 5])
prod = t.tensor([6, 7, 8])

n_id = t.cat([src, dst, prod]).unique()  # every node touched by this batch
neighbor_loader.insert(src, dst, prod)

# Nodes 0-5 each gain one neighbor (their product) and products 6-8 each gain
# two (both participants). For nodes 0-5 the second slot has not been written
# yet; the recorded output shows arbitrary values there (looks like
# uninitialized memory) — TODO confirm the loader masks these slots.
print(n_id, neighbor_loader.neighbors[n_id])

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) tensor([[              6, 159321811882214],
        [              7, 158772056065462],
        [              8,  88214333292640],
        [              6,       110377216],
        [              7,               0],
        [              8,       112948880],
        [              3,               0],
        [              4,               1],
        [              5,               2]])


In [8]:
# Test __call__: gather the neighborhood of the nodes and products from the insert above.
f_id = t.cat([src, dst]).unique()
p_id = prod.unique()

n_id, edge_index, e_id = neighbor_loader(f_id, p_id)

# Ground truth: 6 edges 0-6, 3-6, 1-7, 4-7, 2-8, 5-8, each listed in both
# directions, so edge_index should have 12 columns.
print(n_id, n_id.shape)
print(edge_index, edge_index.shape)
print(e_id, e_id.shape)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) torch.Size([9])
tensor([[6, 7, 8, 6, 7, 8, 3, 0, 4, 1, 5, 2],
        [0, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8, 8]]) torch.Size([2, 12])
tensor([0, 1, 2, 3, 4, 5, 3, 0, 4, 1, 5, 2]) torch.Size([12])


In [9]:
# Test insert: one more interaction 0 -> 1 via product 6, on top of the first batch.
src = t.tensor([0])
dst = t.tensor([1])
prod = t.tensor([6])

n_id = t.cat([src, dst, prod]).unique()  # nodes touched by this insert
neighbor_loader.insert(src, dst, prod)

# Expected: 0 and 1 now each list product 6 among their neighbors, and
# product 6's two slots hold 0 and 1.
print(n_id, neighbor_loader.neighbors[n_id])

tensor([0, 1, 6]) tensor([[6, 6],
        [6, 7],
        [1, 0]])


In [10]:
# Test __call__ after the second insert.
f_id = t.cat([src, dst]).unique()
p_id = prod.unique()

n_id, edge_index, e_id = neighbor_loader(f_id, p_id)

# Ground truth: 3 distinct edges 0-6, 1-6, 1-7 (6 columns in edge_index).
# NOTE(review): the recorded output lists 6->0 twice (e_ids 6 and 0, since 0-6
# was inserted in both insert calls) and 1-7 only as 7->1 — confirm this is
# the intended semantics.
print(n_id, n_id.shape)
print(edge_index, edge_index.shape)
print(e_id, e_id.shape)

tensor([0, 1, 6, 7]) torch.Size([4])
tensor([[2, 2, 2, 3, 1, 0],
        [0, 0, 1, 1, 2, 2]]) torch.Size([2, 6])
tensor([6, 0, 7, 1, 7, 6]) torch.Size([6])
