# TGN with moving-average (MA) label features, trained with cross-entropy loss

This code is adapted from the TGB repository: https://github.com/shenyangHuang/TGB

In [None]:
# Dependencies. Note: TGB's example code imports a local `modules/` directory
# from its repository; that is NOT the unrelated PyPI package named "modules",
# and nothing in this notebook imports it, so it is not installed here.
!pip install pandas -q
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install py-tgb -q


Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m34.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.4/154.4 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m108.6 

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
from torch_geometric.nn import TransformerConv
import numpy as np

class NodePredictor(torch.nn.Module):
    """Two-layer MLP head that maps node representations to class scores.

    The hidden layer keeps the input width (``in_dim -> in_dim``) and is
    followed by a ReLU; the final layer projects to ``out_dim`` classes.
    The output is raw, unnormalised logits -- no softmax/log-softmax is
    applied here.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.lin_node = Linear(in_dim, in_dim)
        self.out = Linear(in_dim, out_dim)

    def forward(self, node_embed):
        hidden = torch.relu(self.lin_node(node_embed))
        return self.out(hidden)

class GraphAttentionEmbedding(torch.nn.Module):
    """Temporal graph-attention embedding module used by TGN.

    For every edge, the time encoding of ``last_update[source] - t`` is
    concatenated with the raw edge message to form the edge features, and a
    two-head ``TransformerConv`` aggregates neighbour states with them.
    With two heads of ``out_channels // 2`` each, the output width is
    ``2 * (out_channels // 2)`` (assumes PyG's default ``concat=True`` --
    so an even ``out_channels`` is expected).

    Reference:
    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_feature_dim = time_enc.out_channels + msg_dim
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            edge_dim=edge_feature_dim,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        source_nodes = edge_index[0]
        # Cast to the feature dtype before encoding (timestamps may be integer).
        elapsed = (last_update[source_nodes] - t).to(x.dtype)
        edge_attr = torch.cat((self.time_enc(elapsed), msg), dim=-1)
        return self.conv(x, edge_index, edge_attr)


class MAFeatures:
    """Per-node running average of label vectors, used as input features.

    Despite the ``window`` name this is a recursive (exponential) average:
    the first label seen for a node is stored as-is, and every later label
    is folded in as ``(old * (window - 1) + label) / window``, i.e. with
    weight ``1 / window`` on the newest label. Nodes that were never
    updated query as an all-zero vector.

    Args:
        num_class: length of every label vector.
        window: smoothing constant; larger values give history more weight.
    """

    def __init__(self, num_class, window=7):
        self.num_class = num_class
        self.window = window
        self.dict = {}  # node id (int) -> averaged label vector (np.ndarray)

    def reset(self):
        """Forget every tracked node."""
        self.dict = {}

    def update_dict(self, node_id, label_vec):
        """Fold ``label_vec`` into the running average kept for ``node_id``."""
        if node_id not in self.dict:
            # Store a private ndarray copy: the caller's array may be a view
            # into a buffer that is reused or mutated later, and a plain list
            # would make the next update do list repetition/concatenation
            # instead of arithmetic.
            self.dict[node_id] = np.array(label_vec)
            return
        total = self.dict[node_id] * (self.window - 1) + label_vec
        self.dict[node_id] = total / self.window

    def query_dict(self, node_id):
        """Return the averaged vector for ``node_id`` or zeros if unseen."""
        if node_id in self.dict:
            return self.dict[node_id]
        return np.zeros(self.num_class, dtype=np.float32)

    def batch_query(self, node_ids):
        """Stack features for ``node_ids`` into a float32 ``(len, num_class)`` array."""
        feats = [self.query_dict(int(n)) for n in node_ids]
        if not feats:
            # np.stack raises on an empty sequence; return a well-shaped empty batch.
            return np.zeros((0, self.num_class), dtype=np.float32)
        return np.stack(feats, axis=0).astype(np.float32)


In [None]:
# Choose hyperparameters

# --- optimisation ---
lr = 0.0001       # Adam learning rate
batch_size = 200  # events per TemporalDataLoader batch
epochs = 50

# --- model size ---
global_hidden_dim = 784  # shared width of memory, time encoding and embeddings
nb_neighbors = 30        # neighbours kept per node by LastNeighborLoader

# --- moving-average label features ---
window_ma = 7

In [None]:
from tqdm import tqdm
import torch
import timeit
import matplotlib.pyplot as plt

from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed

seed = 1
print("setting random seed to be", seed)
torch.manual_seed(seed)
set_random_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name = "tgbn-genre"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask.to(device)
val_mask = dataset.val_mask.to(device)
test_mask = dataset.test_mask.to(device)

eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)

evaluator = Evaluator(name=name)
ma_tracker = MAFeatures(num_class=num_classes, window=window_ma)

train_data = data[train_mask]
val_data = data[val_mask]
test_data = data[test_mask]

# Range of destination node ids. Carried over from the TGB link-prediction
# example, where it bounds negative sampling; no negatives are sampled in this
# node-property-prediction notebook.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())


train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)

neighbor_loader = LastNeighborLoader(data.num_nodes, size=nb_neighbors, device=device)

memory_dim = time_dim = embedding_dim = global_hidden_dim

# NOTE: module construction order below determines the RNG draws used for
# parameter initialisation; keep it stable for reproducibility.
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)

# The MA features are concatenated to the GNN embedding, hence the extra width.
node_pred = NodePredictor(in_dim=embedding_dim + num_classes, out_dim=num_classes).to(device)

# `memory.time_enc` is shared with `gnn`, so its parameters are yielded by both
# .parameters() iterators and must be deduplicated. dict.fromkeys dedups by
# identity while preserving first-seen order; a `set` would also dedup, but its
# iteration order depends on object ids and changes between runs, making the
# optimizer's parameter order (and saved state_dicts) non-reproducible.
optimizer = torch.optim.Adam(
    list(
        dict.fromkeys(
            [*memory.parameters(), *gnn.parameters(), *node_pred.parameters()]
        )
    ),
    lr=lr,
)

criterion = torch.nn.CrossEntropyLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)


def plot_curve(scores, out_name):
    """Plot ``scores`` as a single line and save it to ``out_name + ".pdf"``.

    The figure is closed afterwards so consecutive calls start from a fresh
    figure instead of drawing on top of the previous curve.
    """
    pdf_path = out_name + ".pdf"
    plt.plot(scores, color="#e34a33")
    plt.ylabel("score")
    plt.savefig(pdf_path)
    plt.close()


def process_edges(src, dst, t, msg):
    """Ingest a batch of events into TGN memory and the neighbour loader.

    Empty batches (``src`` with no elements) are a no-op.
    """
    if src.nelement() == 0:
        return
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)


def train():
    """Run one training epoch over ``train_loader``.

    Memory, the neighbour loader and the moving-average tracker are reset
    first. Events are streamed in time order. Whenever a batch's last
    timestamp passes the next label timestamp:

    1. events before that timestamp are ingested;
    2. the labelled nodes are embedded (neighbour sampling, memory read, GNN);
    3. their MA features, which reflect only earlier label timestamps, are
       concatenated;
    4. one optimiser step is taken on the cross-entropy loss;
    5. the now-revealed labels are folded into the MA tracker.

    Remaining events of the batch are ingested afterwards.

    Returns:
        dict with the mean loss under ``"ce"`` and the mean ``eval_metric``,
        both averaged over label timestamps.
    """
    memory.train()
    gnn.train()
    node_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    ma_tracker.reset()

    total_loss = 0
    label_t = dataset.get_label_time()  # timestamp of the first label batch
    total_score = 0
    num_label_ts = 0

    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_srcs, labels = label_tuple[1], label_tuple[2]
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)
            labels = labels.to(device)  # moved once; reused by the loss

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Keep only the edges from the next day; they are ingested below.
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            # Embed the labelled nodes: sample their neighbours, read memory
            # for them, run the GNN, then select the rows of the labelled nodes.
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]

            # Moving-average features are queried BEFORE this timestamp's
            # labels are added, so they never contain the target label.
            src_np = label_srcs.detach().cpu().numpy()
            ma_feats = torch.from_numpy(ma_tracker.batch_query(src_np)).to(device)
            z = torch.cat([z, ma_feats], dim=1)

            # loss and metric computation
            pred = node_pred(z)
            loss = criterion(pred, labels)

            np_pred = pred.detach().cpu().numpy()
            np_true = labels.detach().cpu().numpy()
            result_dict = evaluator.eval(
                {
                    "y_true": np_true,
                    "y_pred": np_pred,
                    "eval_metric": [eval_metric],
                }
            )
            total_score += result_dict[eval_metric]
            num_label_ts += 1

            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

            # AFTER using the current labels, reveal them to the MA tracker.
            # One host transfer for the whole label batch (instead of a
            # per-node .item()/.cpu(), each of which synchronises the device);
            # each row is copied so stored vectors don't alias this buffer.
            labels_np = labels.detach().cpu().numpy()
            for nid, vec in zip(src_np, labels_np):
                ma_tracker.update_dict(int(nid), vec.copy())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        memory.detach()

    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict


@torch.no_grad()
def test(loader):
    """Evaluate the model on ``loader`` (validation or test split).

    Memory, the neighbour loader and the MA tracker are NOT reset, so
    evaluation continues from the state left by the preceding split. The
    same per-label-timestamp procedure as ``train`` is used, minus the
    optimisation step. Labels are revealed to the MA tracker only after
    each prediction. Iteration stops early once ``dataset.get_node_label``
    returns ``None``.

    Returns:
        dict mapping ``eval_metric`` to its mean over label timestamps.
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()

    label_t = dataset.get_label_time()  # next label timestamp for this split
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break
            label_srcs, labels = label_tuple[1], label_tuple[2]
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Keep only the edges from the next day; they are ingested below.
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            # Embed the labelled nodes: sample their neighbours, read memory
            # for them, run the GNN, then select the rows of the labelled nodes.
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]

            # MA features from earlier label timestamps only (already under
            # the function-level no_grad, so no extra context is needed).
            src_np = label_srcs.detach().cpu().numpy()
            ma_feats = torch.from_numpy(ma_tracker.batch_query(src_np)).to(device)
            z = torch.cat([z, ma_feats], dim=1)

            # metric computation
            pred = node_pred(z)
            np_pred = pred.detach().cpu().numpy()
            np_true = labels.detach().cpu().numpy()

            result_dict = evaluator.eval(
                {
                    "y_true": np_true,
                    "y_pred": np_pred,
                    "eval_metric": [eval_metric],
                }
            )
            total_score += result_dict[eval_metric]
            num_label_ts += 1

            # Reveal the current labels to the MA tracker for later timestamps.
            # Single host transfer instead of a device sync per labelled node;
            # rows are copied so stored vectors don't alias this buffer.
            labels_np = labels.detach().cpu().numpy()
            for nid, vec in zip(src_np, labels_np):
                ma_tracker.update_dict(int(nid), vec.copy())

        process_edges(src, dst, t, msg)

    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict


train_curve = []
val_curve = []
test_curve = []
# Model selection: report the test score of the epoch with the best
# validation score (strict ">" keeps the earliest epoch on ties).
max_val_score = 0
best_test_idx = 0  # 0-based index into test_curve
for epoch in range(1, epochs + 1):
    # ---- training ----
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print(f"Training takes--- {timeit.default_timer() - start_time} seconds ---")

    # ---- validation (continues from the post-training state) ----
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_score = val_dict[eval_metric]
    val_curve.append(val_score)
    if val_score > max_val_score:
        max_val_score = val_score
        best_test_idx = epoch - 1
    print(f"Validation takes--- {timeit.default_timer() - start_time} seconds ---")

    # ---- test (continues from the post-validation state) ----
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print(f"Test takes--- {timeit.default_timer() - start_time} seconds ---")
    print("------------------------------------")
    # Rewind the dataset's label pointer so the next epoch starts at the first label.
    dataset.reset_label_time()


# # code for plotting
# plot_curve(train_curve, "train_curve")
# plot_curve(val_curve, "val_curve")
# plot_curve(test_curve, "test_curve")

max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print("best val score: ", max_val_score)
print("best validation epoch   : ", best_test_idx + 1)
print("best test score: ", max_test_score)

setting random seed to be 1


17858396it [01:03, 281917.14it/s]
2741936it [00:07, 383488.80it/s]
100%|██████████| 62505/62505 [09:44<00:00, 106.87it/s]


------------------------------------
training Epoch: 01
{'ce': 4.521856108856201, 'ndcg': np.float64(0.37157434051068366)}
Training takes--- 584.8939420599999 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 114.03it/s]


{'ndcg': np.float64(0.36719611551810566)}
Validation takes--- 117.48027660900016 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.38it/s]


{'ndcg': np.float64(0.3640999508878806)}
Test takes--- 115.26215200699994 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.24it/s]


------------------------------------
training Epoch: 02
{'ce': 4.36019150390625, 'ndcg': np.float64(0.40828068355617086)}
Training takes--- 582.895898664 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.89it/s]


{'ndcg': np.float64(0.4261283761527538)}
Validation takes--- 117.62499643599972 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.72it/s]


{'ndcg': np.float64(0.4301098877871103)}
Test takes--- 115.93073998399996 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:41<00:00, 107.40it/s]


------------------------------------
training Epoch: 03
{'ce': 4.157965354728699, 'ndcg': np.float64(0.4578263350797834)}
Training takes--- 582.0121052459999 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.09it/s]


{'ndcg': np.float64(0.45656977195015025)}
Validation takes--- 118.47198430400022 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.70it/s]


{'ndcg': np.float64(0.46084551112178584)}
Test takes--- 116.96484033800016 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:44<00:00, 106.87it/s]


------------------------------------
training Epoch: 04
{'ce': 4.02698044052124, 'ndcg': np.float64(0.4844837979557951)}
Training takes--- 584.8896495210001 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.66it/s]


{'ndcg': np.float64(0.47497922255574504)}
Validation takes--- 118.92957269099998 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.51it/s]


{'ndcg': np.float64(0.47928917182266134)}
Test takes--- 117.16413573299997 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.29it/s]


------------------------------------
training Epoch: 05
{'ce': 3.936997595024109, 'ndcg': np.float64(0.5015335758215188)}
Training takes--- 582.5931881699998 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.95it/s]


{'ndcg': np.float64(0.4857372275373781)}
Validation takes--- 117.56869201300015 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.53it/s]


{'ndcg': np.float64(0.48911939714329644)}
Test takes--- 116.12609269600034 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:39<00:00, 107.94it/s]


------------------------------------
training Epoch: 06
{'ce': 3.8786503314971923, 'ndcg': np.float64(0.5110321209717532)}
Training takes--- 579.0682790670007 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.50it/s]


{'ndcg': np.float64(0.4916451157529126)}
Validation takes--- 119.08641776600052 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.84it/s]


{'ndcg': np.float64(0.4953544813311159)}
Test takes--- 115.81485873300062 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:40<00:00, 107.71it/s]


------------------------------------
training Epoch: 07
{'ce': 3.839206276893616, 'ndcg': np.float64(0.5172152162842174)}
Training takes--- 580.3133500909998 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.83it/s]


{'ndcg': np.float64(0.4957971148309815)}
Validation takes--- 118.73753571799989 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.58it/s]


{'ndcg': np.float64(0.4994945725947972)}
Test takes--- 115.07064519999949 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:41<00:00, 107.53it/s]


------------------------------------
training Epoch: 08
{'ce': 3.81198913936615, 'ndcg': np.float64(0.5210417104455003)}
Training takes--- 581.2943427579994 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.60it/s]


{'ndcg': np.float64(0.498176105034062)}
Validation takes--- 117.9313643659998 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.47it/s]


{'ndcg': np.float64(0.5019473071534402)}
Test takes--- 116.18514924800002 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.29it/s]


------------------------------------
training Epoch: 09
{'ce': 3.7921169610977175, 'ndcg': np.float64(0.5235177201685296)}
Training takes--- 582.6180201360003 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.71it/s]


{'ndcg': np.float64(0.500457509544526)}
Validation takes--- 117.81689057699987 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.73it/s]


{'ndcg': np.float64(0.5046089880793734)}
Test takes--- 115.91721170999972 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.31it/s]


------------------------------------
training Epoch: 10
{'ce': 3.777357727432251, 'ndcg': np.float64(0.5256840031710102)}
Training takes--- 582.5130851040003 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.41it/s]


{'ndcg': np.float64(0.502617114365475)}
Validation takes--- 120.24282094500086 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 112.28it/s]


{'ndcg': np.float64(0.506114535015331)}
Test takes--- 118.45363462400019 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:51<00:00, 105.73it/s]


------------------------------------
training Epoch: 11
{'ce': 3.7647378747940063, 'ndcg': np.float64(0.5274768022080081)}
Training takes--- 591.1762473519993 seconds ---


100%|██████████| 13394/13394 [02:05<00:00, 107.07it/s]


{'ndcg': np.float64(0.5038359420920824)}
Validation takes--- 125.11463265099883 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 112.85it/s]


{'ndcg': np.float64(0.507569649378847)}
Test takes--- 117.84412163699926 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.24it/s]


------------------------------------
training Epoch: 12
{'ce': 3.7538866096496584, 'ndcg': np.float64(0.5288926857048982)}
Training takes--- 582.870528554 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.09it/s]


{'ndcg': np.float64(0.5048301342658675)}
Validation takes--- 120.59744090999993 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.55it/s]


{'ndcg': np.float64(0.5091089377529502)}
Test takes--- 116.10508043699883 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:48<00:00, 106.23it/s]


------------------------------------
training Epoch: 13
{'ce': 3.7439945640563965, 'ndcg': np.float64(0.530394460387942)}
Training takes--- 588.4046057029991 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.10it/s]


{'ndcg': np.float64(0.5058221094459877)}
Validation takes--- 119.50049280800158 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.90it/s]


{'ndcg': np.float64(0.5098491760267668)}
Test takes--- 116.76484201700077 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.72it/s]


------------------------------------
training Epoch: 14
{'ce': 3.735355686187744, 'ndcg': np.float64(0.5317158936073585)}
Training takes--- 585.6835901930008 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.78it/s]


{'ndcg': np.float64(0.5066687926210903)}
Validation takes--- 118.80132236200006 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.13it/s]


{'ndcg': np.float64(0.5112932673380404)}
Test takes--- 117.56382733099963 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.75it/s]


------------------------------------
training Epoch: 15
{'ce': 3.72817746181488, 'ndcg': np.float64(0.5326742172441776)}
Training takes--- 585.5684580020006 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.72it/s]


{'ndcg': np.float64(0.5073420740668543)}
Validation takes--- 119.9081912090005 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 112.81it/s]


{'ndcg': np.float64(0.5117745428513919)}
Test takes--- 117.89744414500092 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:47<00:00, 106.47it/s]


------------------------------------
training Epoch: 16
{'ce': 3.7207003215789793, 'ndcg': np.float64(0.5337593286205001)}
Training takes--- 587.0938963240005 seconds ---


100%|██████████| 13394/13394 [02:02<00:00, 109.55it/s]


{'ndcg': np.float64(0.5080270583424209)}
Validation takes--- 122.29077470700031 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 112.20it/s]


{'ndcg': np.float64(0.512143605638542)}
Test takes--- 118.53292636699916 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:47<00:00, 106.45it/s]


------------------------------------
training Epoch: 17
{'ce': 3.714026598358154, 'ndcg': np.float64(0.5345667108754465)}
Training takes--- 587.1711019590002 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.68it/s]


{'ndcg': np.float64(0.5083305331046507)}
Validation takes--- 119.9593802219988 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.55it/s]


{'ndcg': np.float64(0.5130004427078141)}
Test takes--- 117.1236576859992 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:46<00:00, 106.65it/s]


------------------------------------
training Epoch: 18
{'ce': 3.707283125114441, 'ndcg': np.float64(0.5354002036546132)}
Training takes--- 586.1199172069992 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.09it/s]


{'ndcg': np.float64(0.5084650182534296)}
Validation takes--- 119.51651450400095 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.54it/s]


{'ndcg': np.float64(0.5117067346449233)}
Test takes--- 117.13268674700157 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:43<00:00, 107.09it/s]


------------------------------------
training Epoch: 19
{'ce': 3.702365365219116, 'ndcg': np.float64(0.5362352037327713)}
Training takes--- 583.6615314430001 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.82it/s]


{'ndcg': np.float64(0.5090933640887079)}
Validation takes--- 119.80628196300131 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.25it/s]


{'ndcg': np.float64(0.5133028654104652)}
Test takes--- 116.41014784599975 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:43<00:00, 107.18it/s]


------------------------------------
training Epoch: 20
{'ce': 3.6962806829452513, 'ndcg': np.float64(0.5371539437029665)}
Training takes--- 583.1690444509986 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.80it/s]


{'ndcg': np.float64(0.50889703563668)}
Validation takes--- 118.76799616000062 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.57it/s]


{'ndcg': np.float64(0.5129575640042783)}
Test takes--- 116.08142068699817 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:43<00:00, 107.13it/s]


------------------------------------
training Epoch: 21
{'ce': 3.6915340682983397, 'ndcg': np.float64(0.5378295268981788)}
Training takes--- 583.4803151509986 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.90it/s]


{'ndcg': np.float64(0.5098813021911028)}
Validation takes--- 118.66207257000133 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.88it/s]


{'ndcg': np.float64(0.5141751708573208)}
Test takes--- 116.79120076499748 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:43<00:00, 107.15it/s]


------------------------------------
training Epoch: 22
{'ce': 3.686108701133728, 'ndcg': np.float64(0.5384060225959201)}
Training takes--- 583.3749815850024 seconds ---


100%|██████████| 13394/13394 [02:01<00:00, 110.55it/s]


{'ndcg': np.float64(0.5101989651813674)}
Validation takes--- 121.1795583460007 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.27it/s]


{'ndcg': np.float64(0.5144272688086063)}
Test takes--- 116.39264975599872 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:44<00:00, 107.00it/s]


------------------------------------
training Epoch: 23
{'ce': 3.681128923225403, 'ndcg': np.float64(0.5390573057205665)}
Training takes--- 584.1821340630013 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.59it/s]


{'ndcg': np.float64(0.5103524371314104)}
Validation takes--- 120.05530675600312 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.93it/s]


{'ndcg': np.float64(0.515238104005058)}
Test takes--- 116.72827805500128 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.75it/s]


------------------------------------
training Epoch: 24
{'ce': 3.6766774551391603, 'ndcg': np.float64(0.5398403726727711)}
Training takes--- 585.5444334400017 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.46it/s]


{'ndcg': np.float64(0.5106813984477402)}
Validation takes--- 119.11774692999825 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.12it/s]


{'ndcg': np.float64(0.5147655116667402)}
Test takes--- 116.53544766699997 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:44<00:00, 106.87it/s]


------------------------------------
training Epoch: 25
{'ce': 3.6726275173187255, 'ndcg': np.float64(0.5402243907928971)}
Training takes--- 584.8617464180024 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.49it/s]


{'ndcg': np.float64(0.5108785862279904)}
Validation takes--- 120.15639414600082 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.46it/s]


{'ndcg': np.float64(0.5151896712496061)}
Test takes--- 116.1966249649995 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.69it/s]


------------------------------------
training Epoch: 26
{'ce': 3.6679245140075682, 'ndcg': np.float64(0.5408378947103631)}
Training takes--- 585.8965328850027 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.90it/s]


{'ndcg': np.float64(0.51011790775451)}
Validation takes--- 119.7167260219976 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.87it/s]


{'ndcg': np.float64(0.5133634869298911)}
Test takes--- 116.7979199950023 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.77it/s]


------------------------------------
training Epoch: 27
{'ce': 3.6634274269104004, 'ndcg': np.float64(0.5412400793527355)}
Training takes--- 585.443305904002 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.47it/s]


{'ndcg': np.float64(0.5110413007506336)}
Validation takes--- 120.17630724300034 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 112.04it/s]


{'ndcg': np.float64(0.5157404648229967)}
Test takes--- 118.70167248000143 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.76it/s]


------------------------------------
training Epoch: 28
{'ce': 3.6591497261047365, 'ndcg': np.float64(0.5421723903478419)}
Training takes--- 585.4646152820023 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.24it/s]


{'ndcg': np.float64(0.5110229352086618)}
Validation takes--- 120.43450930099789 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.26it/s]


{'ndcg': np.float64(0.5146842087748136)}
Test takes--- 117.42652191600064 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.73it/s]


------------------------------------
training Epoch: 29
{'ce': 3.654435998916626, 'ndcg': np.float64(0.5426226526150997)}
Training takes--- 585.6462285100024 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.97it/s]


{'ndcg': np.float64(0.5107617882989708)}
Validation takes--- 119.6441239429987 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.35it/s]


{'ndcg': np.float64(0.5142548729569816)}
Test takes--- 116.30902565799988 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.73it/s]


------------------------------------
training Epoch: 30
{'ce': 3.6506032470703125, 'ndcg': np.float64(0.5431029083385468)}
Training takes--- 585.6443214589999 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.46it/s]


{'ndcg': np.float64(0.509643996474926)}
Validation takes--- 119.12558258900026 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.92it/s]


{'ndcg': np.float64(0.5125000410168296)}
Test takes--- 115.72952095400251 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:44<00:00, 106.89it/s]


------------------------------------
training Epoch: 31
{'ce': 3.6473770292282106, 'ndcg': np.float64(0.5433370276040811)}
Training takes--- 584.8003505710003 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.30it/s]


{'ndcg': np.float64(0.5085218607351726)}
Validation takes--- 119.29088444499939 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.15it/s]


{'ndcg': np.float64(0.512596499884624)}
Test takes--- 116.50754961699931 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:46<00:00, 106.64it/s]


------------------------------------
training Epoch: 32
{'ce': 3.6419107097625734, 'ndcg': np.float64(0.5443993842510274)}
Training takes--- 586.1736024089987 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.33it/s]


{'ndcg': np.float64(0.5113635911455997)}
Validation takes--- 119.26315868999882 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.11it/s]


{'ndcg': np.float64(0.5157411933204837)}
Test takes--- 115.53957298000023 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.81it/s]


------------------------------------
training Epoch: 33
{'ce': 3.637394211578369, 'ndcg': np.float64(0.5448747735459832)}
Training takes--- 585.201901856999 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.04it/s]


{'ndcg': np.float64(0.5102627018971826)}
Validation takes--- 118.50963329700244 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.48it/s]


{'ndcg': np.float64(0.5134101861928899)}
Test takes--- 116.16964676499992 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:44<00:00, 106.89it/s]


------------------------------------
training Epoch: 34
{'ce': 3.6333824098587035, 'ndcg': np.float64(0.5455616972627974)}
Training takes--- 584.7916762759996 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.07it/s]


{'ndcg': np.float64(0.5093951937079708)}
Validation takes--- 118.47848489999888 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.59it/s]


{'ndcg': np.float64(0.5110200346248656)}
Test takes--- 116.0564811860022 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:43<00:00, 107.20it/s]


------------------------------------
training Epoch: 35
{'ce': 3.6306360551834107, 'ndcg': np.float64(0.546035824013853)}
Training takes--- 583.0854063469997 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.87it/s]


{'ndcg': np.float64(0.5124851764278913)}
Validation takes--- 118.88294475400107 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.11it/s]


{'ndcg': np.float64(0.5170196562587372)}
Test takes--- 115.53568930699839 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.29it/s]


------------------------------------
training Epoch: 36
{'ce': 3.6259703422546385, 'ndcg': np.float64(0.5464349750033852)}
Training takes--- 582.5913938179983 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.93it/s]


{'ndcg': np.float64(0.5108356377776584)}
Validation takes--- 118.62753602399971 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.78it/s]


{'ndcg': np.float64(0.5145653456131781)}
Test takes--- 115.87031711199961 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:41<00:00, 107.45it/s]


------------------------------------
training Epoch: 37
{'ce': 3.620151674461365, 'ndcg': np.float64(0.5476918989572873)}
Training takes--- 581.7195241430018 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.20it/s]


{'ndcg': np.float64(0.5122925323934062)}
Validation takes--- 118.34870714700082 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.46it/s]


{'ndcg': np.float64(0.5164492102084571)}
Test takes--- 116.1882211330012 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:40<00:00, 107.62it/s]


------------------------------------
training Epoch: 38
{'ce': 3.615781950187683, 'ndcg': np.float64(0.5484349983701082)}
Training takes--- 580.8129542150018 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.86it/s]


{'ndcg': np.float64(0.5115752260790327)}
Validation takes--- 118.69635718499921 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.00it/s]


{'ndcg': np.float64(0.5142454672521821)}
Test takes--- 115.6471930780026 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:41<00:00, 107.44it/s]


------------------------------------
training Epoch: 39
{'ce': 3.612002512168884, 'ndcg': np.float64(0.5488457516043017)}
Training takes--- 581.7804267030006 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.94it/s]


{'ndcg': np.float64(0.5130699405612995)}
Validation takes--- 118.61565094999969 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.68it/s]


{'ndcg': np.float64(0.5175909032773565)}
Test takes--- 115.97222358399813 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.27it/s]


------------------------------------
training Epoch: 40
{'ce': 3.6052696321487425, 'ndcg': np.float64(0.5501695581406361)}
Training takes--- 582.6998658839984 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.74it/s]


{'ndcg': np.float64(0.5121514136477621)}
Validation takes--- 118.82696029899671 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.01it/s]


{'ndcg': np.float64(0.5160783031457202)}
Test takes--- 116.6566273770004 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:42<00:00, 107.28it/s]


------------------------------------
training Epoch: 41
{'ce': 3.601056960487366, 'ndcg': np.float64(0.5509718577953866)}
Training takes--- 582.6725948149979 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.61it/s]


{'ndcg': np.float64(0.5114622364974417)}
Validation takes--- 118.9651272839983 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.21it/s]


{'ndcg': np.float64(0.5132986119788612)}
Test takes--- 116.45106222100003 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:45<00:00, 106.72it/s]


------------------------------------
training Epoch: 42
{'ce': 3.59669154548645, 'ndcg': np.float64(0.5515844901695813)}
Training takes--- 585.6828877490043 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.53it/s]


{'ndcg': np.float64(0.5115231875575365)}
Validation takes--- 119.04920042599406 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.30it/s]


{'ndcg': np.float64(0.5146478376045436)}
Test takes--- 116.35643246999825 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:51<00:00, 105.61it/s]


------------------------------------
training Epoch: 43
{'ce': 3.5928194278717043, 'ndcg': np.float64(0.5528481761223386)}
Training takes--- 591.8905101940036 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.51it/s]


{'ndcg': np.float64(0.5106913019386538)}
Validation takes--- 119.07423906699842 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.74it/s]


{'ndcg': np.float64(0.5134432681545394)}
Test takes--- 115.90576241300005 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:47<00:00, 106.38it/s]


------------------------------------
training Epoch: 44
{'ce': 3.5874018119812012, 'ndcg': np.float64(0.5536440331309497)}
Training takes--- 587.6030343050006 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.77it/s]


{'ndcg': np.float64(0.5109381851024017)}
Validation takes--- 119.86236064299737 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 112.57it/s]


{'ndcg': np.float64(0.5130711854662373)}
Test takes--- 118.14935878899996 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:52<00:00, 105.56it/s]


------------------------------------
training Epoch: 45
{'ce': 3.580384783363342, 'ndcg': np.float64(0.5550829656027337)}
Training takes--- 592.126381121001 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.90it/s]


{'ndcg': np.float64(0.5111740132312341)}
Validation takes--- 119.72656464299507 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 112.74it/s]


{'ndcg': np.float64(0.5156358313313684)}
Test takes--- 117.96904716000427 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:01<00:00, 103.90it/s]


------------------------------------
training Epoch: 46
{'ce': 3.5769771697998047, 'ndcg': np.float64(0.5552062798382349)}
Training takes--- 601.5928304909976 seconds ---


100%|██████████| 13394/13394 [02:04<00:00, 107.97it/s]


{'ndcg': np.float64(0.5116972922191428)}
Validation takes--- 124.07418990999577 seconds ---


 99%|█████████▉| 13299/13394 [01:59<00:00, 111.42it/s]


{'ndcg': np.float64(0.5161843372407373)}
Test takes--- 119.36930477800342 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:57<00:00, 104.55it/s]


------------------------------------
training Epoch: 47
{'ce': 3.5708628547668457, 'ndcg': np.float64(0.5568250180130417)}
Training takes--- 597.8415470429973 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.81it/s]


{'ndcg': np.float64(0.5113571683550042)}
Validation takes--- 119.81474676200014 seconds ---


 99%|█████████▉| 13299/13394 [01:59<00:00, 111.58it/s]


{'ndcg': np.float64(0.516703527890401)}
Test takes--- 119.1859686490061 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:59<00:00, 104.31it/s]


------------------------------------
training Epoch: 48
{'ce': 3.5630626585006713, 'ndcg': np.float64(0.5582617469348368)}
Training takes--- 599.216663113999 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 110.75it/s]


{'ndcg': np.float64(0.5107720780261132)}
Validation takes--- 120.96406836499955 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 112.62it/s]


{'ndcg': np.float64(0.5154243745791295)}
Test takes--- 118.09424460700393 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:51<00:00, 105.59it/s]


------------------------------------
training Epoch: 49
{'ce': 3.5585101306915283, 'ndcg': np.float64(0.559283220288495)}
Training takes--- 591.9577943480035 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.56it/s]


{'ndcg': np.float64(0.5101370707772315)}
Validation takes--- 119.01318311999785 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.71it/s]


{'ndcg': np.float64(0.5150775546167746)}
Test takes--- 116.96113008399698 seconds ---
------------------------------------


100%|██████████| 62505/62505 [09:48<00:00, 106.13it/s]


------------------------------------
training Epoch: 50
{'ce': 3.5499275661468506, 'ndcg': np.float64(0.561102052130957)}
Training takes--- 588.9801423999961 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.67it/s]


{'ndcg': np.float64(0.5100366762914512)}
Validation takes--- 119.96814754599473 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 113.71it/s]

{'ndcg': np.float64(0.5148715528347158)}
Test takes--- 116.95856902100059 seconds ---
------------------------------------
------------------------------------
------------------------------------
best val score:  0.5130699405612995
best validation epoch   :  39
best test score:  0.5175909032773565



