# TGN with MA features, NDCGLoss2++, LSTM instead of GRU for memory update

The code was adapted from https://github.com/shenyangHuang/TGB and https://github.com/allegro/allRank

In [None]:
!pip install pandas -q
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install py-tgb -q
!pip install modules


Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m71.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.4/154.4 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m70.3 

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
from torch_geometric.nn import TransformerConv
import numpy as np

class NodePredictor(torch.nn.Module):
    """Two-layer MLP head that turns a node embedding into per-class scores.

    Args:
        in_dim: Width of the incoming node embedding; also used as the
            hidden width.
        out_dim: Number of scores produced per node.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Attribute names and creation order are kept as-is so state_dict keys
        # and seeded weight initialisation stay the same.
        self.lin_node = Linear(in_features=in_dim, out_features=in_dim)
        self.out = Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, node_embed):
        """Return raw scores for ``node_embed``; no softmax is applied."""
        hidden = torch.relu(self.lin_node(node_embed))
        return self.out(hidden)

class GraphAttentionEmbedding(torch.nn.Module):
    """Attention-based embedding layer over a temporal neighbourhood graph.

    Wraps a ``TransformerConv`` with two heads of width ``out_channels // 2``
    whose edge features are the encoded relative time concatenated with the
    raw edge message.

    Reference:
    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            # One slot per message feature plus one per time-encoding channel.
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Embed nodes ``x`` using edges ``edge_index`` with times ``t`` and messages ``msg``."""
        source_nodes = edge_index[0]
        # Time between the source node's last update and the edge event,
        # cast to the feature dtype before encoding.
        elapsed = (last_update[source_nodes] - t).to(x.dtype)
        edge_attr = torch.cat((self.time_enc(elapsed), msg), dim=-1)
        return self.conv(x, edge_index, edge_attr)


class MAFeatures:
    """Per-node running average of label vectors, used as extra input features.

    The first label vector seen for a node is stored as-is. Every later one is
    blended in as ``(old * (window - 1) + new) / window``. Nodes that were
    never updated yield an all-zero vector of length ``num_class``.

    Args:
        num_class: Length of each label/feature vector.
        window: Averaging window used in the blend above.
    """

    def __init__(self, num_class, window=7):
        self.num_class = num_class
        self.window = window
        self.dict = {}

    def reset(self):
        """Forget every tracked node."""
        self.dict = {}

    def update_dict(self, node_id, label_vec):
        """Fold ``label_vec`` into the running average kept for ``node_id``."""
        previous = self.dict.get(node_id)
        if previous is None:
            # Store a private copy: the caller's buffer may be a view into a
            # larger array/tensor and must not be able to change our state.
            self.dict[node_id] = np.array(label_vec, copy=True)
        else:
            total = previous * (self.window - 1) + label_vec
            self.dict[node_id] = total / self.window

    def query_dict(self, node_id):
        """Return the stored average for ``node_id``, or zeros if unseen."""
        stored = self.dict.get(node_id)
        if stored is None:
            return np.zeros(self.num_class, dtype=np.float32)
        return stored

    def batch_query(self, node_ids):
        """Return a float32 array of shape ``(len(node_ids), num_class)``.

        An empty ``node_ids`` yields an empty ``(0, num_class)`` array instead
        of the ``ValueError`` that ``np.stack`` raises on an empty list.
        """
        ids = [int(n) for n in node_ids]
        if not ids:
            return np.zeros((0, self.num_class), dtype=np.float32)
        feats = [self.query_dict(i) for i in ids]
        return np.stack(feats, axis=0).astype(np.float32)


In [None]:
import torch

PADDED_Y_VALUE = -1
DEFAULT_EPS = 1e-10


def lambdaLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, weighing_scheme=None, k=None, sigma=1., mu=5.,
               reduction="sum", reduction_log="binary"):
    """
    LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
    Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]; integer labels are cast to the dtype of y_pred
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param weighing_scheme: name of a module-level weighing-scheme function (looked up via globals()), or None for unit weights
    :param k: rank at which the loss is truncated; None keeps every rank
    :param sigma: score difference weight used in the sigmoid function
    :param mu: optional weight used in NDCGLoss2++ weighing scheme
    :param reduction: losses reduction method, could be either a sum or a mean
    :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
    :return: loss value, a torch.Tensor
    :raises ValueError: if reduction or reduction_log is not one of the supported names
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    # Padded entries are overwritten with -inf below, which an integer tensor
    # cannot hold, so integer labels are promoted to the prediction dtype.
    y_true = y_true.clone() if y_true.is_floating_point() else y_true.to(y_pred.dtype)

    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs)

    if weighing_scheme != "ndcgLoss1_scheme":
        padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

    # Only pairs whose two ranks both fall inside the top-k contribute.
    ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
    ndcg_at_k_mask[:k, :k] = 1

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
    if weighing_scheme is None:
        weights = 1.
    else:
        weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore

    # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
    # (-inf) - (-inf) between two padded items is NaN and survives the clamp.
    # masked_fill is out-of-place, so its result has to be assigned back.
    scores_diffs = scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.)
    weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
    if reduction_log == "natural":
        losses = torch.log(weighted_probas)
    elif reduction_log == "binary":
        losses = torch.log2(weighted_probas)
    else:
        raise ValueError("Reduction logarithm base can be either natural or binary")

    if reduction == "sum":
        loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
    elif reduction == "mean":
        loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
    else:
        raise ValueError("Reduction method can be either sum or mean")

    return loss


def ndcgLoss1_scheme(G, D, *args):
    """NDCG-Loss1 weights: gain over discount per item, with a trailing pair axis."""
    gain_over_discount = torch.div(G, D)
    return gain_over_discount.unsqueeze(2)


def ndcgLoss2_scheme(G, D, *args):
    """NDCG-Loss2 pair weights.

    For a pair of ranks at distance ``d`` the weight is
    ``|1 / D[d - 1] - 1 / D[d]| * |G_i - G_j|``; the diagonal is zero.
    Discounts are read from the first row of ``D``.
    """
    ranks = torch.arange(1, G.shape[1] + 1, device=G.device)
    rank_gap = (ranks.unsqueeze(1) - ranks.unsqueeze(0)).abs()

    discounts = D[0]
    # On the diagonal ``rank_gap - 1`` is -1 and wraps to the last discount;
    # those entries are zeroed right after, so the wrap-around is harmless.
    inv_near = torch.pow(discounts[rank_gap - 1].abs(), -1.)
    inv_far = torch.pow(discounts[rank_gap].abs(), -1.)
    deltas = (inv_near - inv_far).abs()
    deltas.fill_diagonal_(0)

    gain_gap = (G.unsqueeze(2) - G.unsqueeze(1)).abs()
    return deltas.unsqueeze(0) * gain_gap


def lambdaRank_scheme(G, D, *args):
    """LambdaRank pair weights: ``|1/D_i - 1/D_j| * |G_i - G_j|``."""
    inv_discount = torch.pow(D, -1.)
    discount_gap = (inv_discount.unsqueeze(2) - inv_discount.unsqueeze(1)).abs()
    gain_gap = (G.unsqueeze(2) - G.unsqueeze(1)).abs()
    return discount_gap * gain_gap


def ndcgLoss2PP_scheme(G, D, *args):
    """NDCG-Loss2++ pair weights: ``mu * ndcgLoss2 + lambdaRank``, with ``mu = args[0]``."""
    mu = args[0]
    loss2_weights = ndcgLoss2_scheme(G, D)
    lambda_rank_weights = lambdaRank_scheme(G, D)
    return mu * loss2_weights + lambda_rank_weights


def rankNet_scheme(G, D, *args):
    """RankNet weighting: every pair gets the same unit weight."""
    unit_weight = 1.0
    return unit_weight


def rankNetWeightedByGTDiff_scheme(G, D, *args):
    """RankNet pair weights scaled by the absolute label gap; labels are ``args[1]``."""
    labels = args[1]
    return (labels.unsqueeze(2) - labels.unsqueeze(1)).abs()


def rankNetWeightedByGTDiffPowed_scheme(G, D, *args):
    """RankNet pair weights scaled by the absolute gap of squared labels; labels are ``args[1]``."""
    squared = torch.pow(args[1], 2)
    return (squared.unsqueeze(2) - squared.unsqueeze(1)).abs()

In [None]:
# Choose hyperparameters
lr = 0.0001  # learning rate handed to the Adam optimizer
batch_size = 200  # number of events per TemporalDataLoader batch
global_hidden_dim = 100  # shared size for memory_dim, time_dim and embedding_dim
nb_neighbors = 10  # `size` argument of LastNeighborLoader
window_ma = 7  # `window` argument of the MAFeatures tracker

epochs = 50  # number of train / validation / test rounds

In [None]:
import torch
from torch_geometric.nn import TGNMemory
from torch_geometric.utils import scatter

class LSTMTGNMemory(TGNMemory):
    """TGN memory whose per-node state is advanced by an ``LSTMCell``.

    ``self.memory`` (inherited) holds the LSTM hidden state and the extra
    ``cell_memory`` buffer holds the LSTM cell state.

    Both ``_update_memory`` (persisting state) and ``_get_updated_memory``
    (the read-only path) go through the same LSTM step. Overriding only
    ``_update_memory`` is not enough: in PyG, ``TGNMemory.forward`` in
    training mode appears to call ``_get_updated_memory``, whose inherited
    version uses the parent's GRU. Training-time embeddings would then come
    from the GRU while evaluation reads LSTM-written memory.
    NOTE(review): this depends on PyG internals - confirm when upgrading
    torch_geometric.

    The GRU created by the parent constructor is left in place, since the
    parent may still reference it, but it is no longer used for updates.
    """

    def __init__(self, num_nodes, raw_msg_dim, memory_dim, time_dim, message_module, aggregator_module):
        super().__init__(num_nodes, raw_msg_dim, memory_dim, time_dim, message_module, aggregator_module)

        # LSTM cell that replaces the GRU for memory updates.
        self.lstm = torch.nn.LSTMCell(input_size=message_module.out_channels, hidden_size=memory_dim)

        # Per-node LSTM cell state, stored next to the hidden state in `memory`.
        self.register_buffer("cell_memory", torch.empty(num_nodes, memory_dim))

        # `cell_memory` was created uninitialised; zero it with the rest of the state.
        self.reset_state()

    def reset_state(self):
        """Reset the inherited state and zero the cell state."""
        super().reset_state()
        # This can run during the parent constructor, before `cell_memory`
        # has been registered.
        if hasattr(self, 'cell_memory'):
            self.cell_memory.fill_(0)

    def detach(self):
        """Detach hidden and cell state from the autograd graph."""
        super().detach()
        self.cell_memory.detach_()

    def _lstm_step(self, n_id):
        """Compute the new ``(hidden, cell, last_update)`` for nodes ``n_id`` without storing them."""
        # Message aggregation
        msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store, self.msg_s_module)
        msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store, self.msg_d_module)

        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)

        # Map global node ids to positions inside `n_id`.
        assoc = torch.full((self.num_nodes,), -1, dtype=torch.long, device=n_id.device)
        assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)
        local_idx = assoc[idx]

        aggr_msg = self.aggr_module(msg, local_idx, t, n_id.size(0))

        # LSTM update
        h_new, c_new = self.lstm(aggr_msg, (self.memory[n_id], self.cell_memory[n_id]))

        # Time update: only nodes that actually had a pending message move
        # their timestamp; the others keep their stored value.
        latest_time = scatter(t, local_idx, dim=0, dim_size=n_id.size(0), reduce='max')
        node_has_msg = torch.zeros(n_id.size(0), dtype=torch.bool, device=n_id.device)
        node_has_msg[local_idx] = True

        last_update = self.last_update[n_id]
        last_update[node_has_msg] = latest_time[node_has_msg]

        return h_new, c_new, last_update

    def _get_updated_memory(self, n_id):
        """Return ``(memory, last_update)`` for ``n_id`` after an LSTM step, without persisting it."""
        h_new, _, last_update = self._lstm_step(n_id)
        return h_new, last_update

    def _update_memory(self, n_id):
        """Run an LSTM step for ``n_id`` and write hidden, cell and timestamps back."""
        h_new, c_new, last_update = self._lstm_step(n_id)

        self.memory[n_id] = h_new
        self.cell_memory[n_id] = c_new
        self.last_update[n_id] = last_update

In [None]:
from tqdm import tqdm
import torch
import timeit
import matplotlib.pyplot as plt

from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed

# ---------------------------------------------------------------------------
# Experiment setup: seeding, dataset and split loading, loaders, model parts
# and optimizer. Everything defined here is module-level state shared by
# train() and test() below.
# ---------------------------------------------------------------------------
seed = 1
print ("setting random seed to be", seed)
torch.manual_seed(seed)
set_random_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name = "tgbn-genre"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask.to(device)
val_mask = dataset.val_mask.to(device)
test_mask = dataset.test_mask.to(device)

eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData()
data = data.to(device)

# Split the event stream with the dataset-provided masks.
train_data = data[train_mask].to(device)
val_data = data[val_mask].to(device)
test_data = data[test_mask].to(device)

evaluator = Evaluator(name=name)
# Moving-average label features, appended to the node embedding before prediction.
ma_tracker = MAFeatures(num_class=num_classes, window=window_ma)

# Ensure to only sample actual destination nodes as negatives.
# NOTE(review): no negative sampling happens in the visible code, and these two
# values are not read anywhere else in it - presumably a leftover; confirm.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())


train_loader = TemporalDataLoader(train_data, batch_size=batch_size)
val_loader = TemporalDataLoader(val_data, batch_size=batch_size)
test_loader = TemporalDataLoader(test_data, batch_size=batch_size)

neighbor_loader = LastNeighborLoader(data.num_nodes, size=nb_neighbors, device=device)

# One shared width for the memory, the time encoding and the GNN output.
memory_dim = time_dim = embedding_dim = global_hidden_dim

#  memory = TGNMemory( # OLD CODE: GRU memory update
#      data.num_nodes,
#      data.msg.size(-1),
#      memory_dim,
#      time_dim,
#      message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
#      aggregator_module=LastAggregator(),
#  ).to(device)

memory = LSTMTGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)


# The GNN reuses the memory module's time encoder (same object, shared weights).
gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)

# The predictor input is the node embedding concatenated with num_classes MA features.
node_pred = NodePredictor(in_dim=embedding_dim + num_classes, out_dim=num_classes).to(device)

# A set union de-duplicates parameters shared between modules (e.g. the time
# encoder), so each parameter reaches the optimizer once.
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters()) | set(node_pred.parameters()),
    lr=lr,
)

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)


def plot_curve(scores, out_name):
    """Plot ``scores`` as a line and save the figure to ``<out_name>.pdf``."""
    target_path = out_name + ".pdf"
    line_colour = "#e34a33"

    plt.plot(scores, color=line_colour)
    plt.ylabel("score")
    plt.savefig(target_path)
    # Close the figure so successive calls do not draw on top of each other.
    plt.close()


def process_edges(src, dst, t, msg):
    """Feed a slice of events into the global memory and neighbor loader.

    Does nothing when the slice is empty.
    """
    if src.nelement() == 0:
        return
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)


def train():
    """Run one training epoch over ``train_loader``.

    Memory, neighbor loader and moving-average tracker are reset first. Events
    are streamed in batches. Whenever a batch's last timestamp passes the
    current label time, the labels for that period are fetched, the nodes are
    embedded and scored, and one optimizer step is taken.

    Returns:
        dict with ``"ce"`` (mean LambdaLoss per label timestamp; the key name
        is kept for backward compatibility) and ``eval_metric`` (mean
        evaluator score). Both are NaN if no label timestamp was reached.
    """
    memory.train()
    gnn.train()
    node_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    ma_tracker.reset()

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    total_score = 0
    num_label_ts = 0

    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                # Same guard as test(): stop instead of failing on None.
                break
            label_srcs, labels = label_tuple[1], label_tuple[2]
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            # Modified for node property prediction:
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)

            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]

            # moving-average features (computed BEFORE updating with current labels)
            ma_feats_np = ma_tracker.batch_query(label_srcs.detach().cpu().numpy())
            ma_feats = torch.from_numpy(ma_feats_np).to(device)
            z = torch.cat([z, ma_feats], dim=1)

            # loss and metric computation
            pred = node_pred(z)

            loss = lambdaLoss(
                y_pred=pred,
                y_true=labels.to(device),
                weighing_scheme='ndcgLoss2PP_scheme'
            )

            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

            # AFTER using current labels, update moving-average tracker.
            # Ids and labels are moved to the host once (np_true already holds
            # the labels) instead of one .item()/.cpu() call per node.
            for nid, label_row in zip(label_srcs.tolist(), np_true):
                ma_tracker.update_dict(nid, label_row)

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        memory.detach()

    if num_label_ts == 0:
        # No label timestamp was reached, so there is nothing to average;
        # report NaN rather than raising ZeroDivisionError.
        return {"ce": float("nan"), eval_metric: float("nan")}

    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict


@torch.no_grad()
def test(loader):
    """Evaluate on the events served by ``loader`` without gradient tracking.

    Memory, neighbor loader and moving-average tracker are NOT reset here, so
    evaluation continues from whatever state the previous pass left behind.
    The loop stops as soon as the dataset returns no labels for a query time.

    Args:
        loader: iterable of event batches (e.g. ``val_loader`` or ``test_loader``).

    Returns:
        dict mapping ``eval_metric`` to the mean evaluator score over the label
        timestamps seen, or NaN if none was reached.
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()

    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break
            label_srcs, labels = label_tuple[1], label_tuple[2]
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            # Modified for node property prediction:
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]

            # moving-average features (queried BEFORE seeing the current labels)
            ma_feats_np = ma_tracker.batch_query(label_srcs.detach().cpu().numpy())
            ma_feats = torch.from_numpy(ma_feats_np).to(device)
            z = torch.cat([z, ma_feats], dim=1)

            # metric computation
            pred = node_pred(z)
            np_pred = pred.cpu().detach().numpy()
            np_true = labels.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

            # Update MA tracker with current labels for subsequent timesteps.
            # Ids and labels are moved to the host once (np_true already holds
            # the labels) instead of one .item()/.cpu() call per node.
            for nid, label_row in zip(label_srcs.tolist(), np_true):
                ma_tracker.update_dict(nid, label_row)

        process_edges(src, dst, t, msg)

    metric_dict = {}
    if num_label_ts == 0:
        # No label timestamp was evaluated; report NaN rather than raising
        # ZeroDivisionError.
        metric_dict[eval_metric] = float("nan")
    else:
        metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict


# Per-epoch metric histories; index i holds the score of epoch i + 1.
train_curve = []
val_curve = []
test_curve = []
max_val_score = 0  # best validation score so far; used to pick the reported test score
best_test_idx = 0  # 0-based index (epoch - 1) of the epoch with the best validation score
for epoch in range(1, epochs + 1):
    # train() resets memory, neighbor loader and MA tracker before streaming the training events.
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    # test() does not reset any state, so validation continues from the end of training.
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_curve.append(val_dict[eval_metric])
    if (val_dict[eval_metric] > max_val_score):
        max_val_score = val_dict[eval_metric]
        best_test_idx = epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    # The test pass likewise continues from the state left by validation.
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")
    # Reset the dataset's label-time state so the next epoch starts from the first label again.
    dataset.reset_label_time()


# # code for plotting
# plot_curve(train_curve, "train_curve")
# plot_curve(val_curve, "val_curve")
# plot_curve(test_curve, "test_curve")

# Report the test score of the epoch that achieved the best validation score.
max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print ("best val score: ", max_val_score)
print ("best validation epoch   : ", best_test_idx + 1)
print ("best test score: ", max_test_score)

setting random seed to be 1


17858396it [01:04, 276410.02it/s]
2741936it [00:07, 385468.34it/s]
100%|██████████| 62505/62505 [10:39<00:00, 97.78it/s] 


------------------------------------
training Epoch: 01
{'ce': 4181.9982040405275, 'ndcg': np.float64(0.36792309855575844)}
Training takes--- 639.262616316 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.75it/s]


{'ndcg': np.float64(0.38730832112337005)}
Validation takes--- 119.88037349199999 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.64it/s]


{'ndcg': np.float64(0.3853531596391973)}
Test takes--- 117.03509413300003 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:44<00:00, 97.01it/s] 


------------------------------------
training Epoch: 02
{'ce': 3689.0365585693357, 'ndcg': np.float64(0.4297447891890143)}
Training takes--- 644.336956587 seconds ---


100%|██████████| 13394/13394 [02:01<00:00, 110.69it/s]


{'ndcg': np.float64(0.4352275719932199)}
Validation takes--- 121.03272802899983 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 112.57it/s]


{'ndcg': np.float64(0.43768577179659424)}
Test takes--- 118.13994267899989 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:43<00:00, 97.11it/s] 


------------------------------------
training Epoch: 03
{'ce': 3390.636919842529, 'ndcg': np.float64(0.4718997734738662)}
Training takes--- 643.64905178 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 110.87it/s]


{'ndcg': np.float64(0.4572788138065822)}
Validation takes--- 120.82607129300004 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.43it/s]


{'ndcg': np.float64(0.45991097178726126)}
Test takes--- 117.24707520500033 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:38<00:00, 97.90it/s] 


------------------------------------
training Epoch: 04
{'ce': 3213.5372873596193, 'ndcg': np.float64(0.493506276072485)}
Training takes--- 638.5058164449997 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.75it/s]


{'ndcg': np.float64(0.4725041720951601)}
Validation takes--- 117.77163852500007 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.73it/s]


{'ndcg': np.float64(0.47645608670939943)}
Test takes--- 114.91734353900029 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.33it/s] 


------------------------------------
training Epoch: 05
{'ce': 3091.1833522766115, 'ndcg': np.float64(0.5063896928391988)}
Training takes--- 635.7042688420001 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.86it/s]


{'ndcg': np.float64(0.48272652630593543)}
Validation takes--- 118.69626102999973 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.63it/s]


{'ndcg': np.float64(0.486111918850336)}
Test takes--- 117.04519488699952 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:38<00:00, 97.94it/s] 


------------------------------------
training Epoch: 06
{'ce': 3009.147756665039, 'ndcg': np.float64(0.5138874030931987)}
Training takes--- 638.2446699860002 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.13it/s]


{'ndcg': np.float64(0.49093616400839407)}
Validation takes--- 120.54360248900048 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.32it/s]


{'ndcg': np.float64(0.49439760854051623)}
Test takes--- 115.32621344499967 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:37<00:00, 98.03it/s] 


------------------------------------
training Epoch: 07
{'ce': 2952.557901672363, 'ndcg': np.float64(0.5189483517672706)}
Training takes--- 637.6137227939998 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.43it/s]


{'ndcg': np.float64(0.4957084656873508)}
Validation takes--- 120.22491074499976 seconds ---


 99%|█████████▉| 13299/13394 [01:58<00:00, 111.82it/s]


{'ndcg': np.float64(0.4997944217593414)}
Test takes--- 118.93278959300005 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:42<00:00, 97.26it/s] 


------------------------------------
training Epoch: 08
{'ce': 2912.240256866455, 'ndcg': np.float64(0.5219953143132667)}
Training takes--- 642.6847132550001 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.71it/s]


{'ndcg': np.float64(0.4990266190480628)}
Validation takes--- 117.81192817200008 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.96it/s]


{'ndcg': np.float64(0.5035942546411201)}
Test takes--- 115.68357447800008 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.30it/s] 


------------------------------------
training Epoch: 09
{'ce': 2881.6684244628905, 'ndcg': np.float64(0.5245042458663688)}
Training takes--- 635.8505678620004 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.63it/s]


{'ndcg': np.float64(0.5004974166376698)}
Validation takes--- 117.89928232700004 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.18it/s]


{'ndcg': np.float64(0.5051165117101236)}
Test takes--- 115.46467185300025 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:34<00:00, 98.55it/s] 


------------------------------------
training Epoch: 10
{'ce': 2857.733162414551, 'ndcg': np.float64(0.5266310254560755)}
Training takes--- 634.2645298339994 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.72it/s]


{'ndcg': np.float64(0.5039339813106105)}
Validation takes--- 117.80697301100008 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.69it/s]


{'ndcg': np.float64(0.507720923767984)}
Test takes--- 114.95489909199932 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.36it/s] 


------------------------------------
training Epoch: 11
{'ce': 2837.9741048706055, 'ndcg': np.float64(0.5283178703343598)}
Training takes--- 635.4649954709985 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.97it/s]


{'ndcg': np.float64(0.5046980548498298)}
Validation takes--- 117.5497097110001 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.40it/s]


{'ndcg': np.float64(0.5092697962839398)}
Test takes--- 115.2511762769991 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:33<00:00, 98.60it/s] 


------------------------------------
training Epoch: 12
{'ce': 2820.653523272705, 'ndcg': np.float64(0.5298576001240635)}
Training takes--- 633.967595483 seconds ---


100%|██████████| 13394/13394 [01:55<00:00, 115.98it/s]


{'ndcg': np.float64(0.5064736179768045)}
Validation takes--- 115.5084806649993 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.16it/s]


{'ndcg': np.float64(0.5105261834080582)}
Test takes--- 114.48865454299994 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:32<00:00, 98.89it/s]


------------------------------------
training Epoch: 13
{'ce': 2806.1514885498045, 'ndcg': np.float64(0.5312574764374752)}
Training takes--- 632.0557527370001 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 114.78it/s]


{'ndcg': np.float64(0.5067490419683807)}
Validation takes--- 116.71816541699991 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.17it/s]


{'ndcg': np.float64(0.5114305567220054)}
Test takes--- 114.48529612300081 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.42it/s] 


------------------------------------
training Epoch: 14
{'ce': 2793.4049685546875, 'ndcg': np.float64(0.5324505402835353)}
Training takes--- 635.0941465530013 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.53it/s]


{'ndcg': np.float64(0.5082415489853085)}
Validation takes--- 117.99918986100056 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.72it/s]


{'ndcg': np.float64(0.5130291183359812)}
Test takes--- 115.93494573799944 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:34<00:00, 98.57it/s]


------------------------------------
training Epoch: 15
{'ce': 2781.6632088928222, 'ndcg': np.float64(0.5335533500390768)}
Training takes--- 634.1145635180001 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.24it/s]


{'ndcg': np.float64(0.5091518758342348)}
Validation takes--- 118.3134674110006 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.28it/s]


{'ndcg': np.float64(0.5142427318585058)}
Test takes--- 114.37681960400005 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:30<00:00, 99.16it/s]


------------------------------------
training Epoch: 16
{'ce': 2771.4682409423826, 'ndcg': np.float64(0.5343794241492141)}
Training takes--- 630.3859490390005 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.80it/s]


{'ndcg': np.float64(0.5100790008397854)}
Validation takes--- 119.82548690400108 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.27it/s]


{'ndcg': np.float64(0.5152743837432613)}
Test takes--- 114.38621392900131 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.14it/s] 


------------------------------------
training Epoch: 17
{'ce': 2761.8299852783202, 'ndcg': np.float64(0.535177703003092)}
Training takes--- 636.8826072200009 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.78it/s]


{'ndcg': np.float64(0.5105635597227994)}
Validation takes--- 119.84706790900054 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.34it/s]


{'ndcg': np.float64(0.5167647652357614)}
Test takes--- 116.31829345700135 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:41<00:00, 97.48it/s] 


------------------------------------
training Epoch: 18
{'ce': 2752.9850010620116, 'ndcg': np.float64(0.5361049392960142)}
Training takes--- 641.2433639539995 seconds ---


100%|██████████| 13394/13394 [02:02<00:00, 109.40it/s]


{'ndcg': np.float64(0.5107956336248418)}
Validation takes--- 122.45956697800102 seconds ---


 99%|█████████▉| 13299/13394 [01:59<00:00, 111.09it/s]


{'ndcg': np.float64(0.5171348179557709)}
Test takes--- 119.72164374100066 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:43<00:00, 97.17it/s] 


------------------------------------
training Epoch: 19
{'ce': 2744.6880583007814, 'ndcg': np.float64(0.5367188915777575)}
Training takes--- 643.2595675730008 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.64it/s]


{'ndcg': np.float64(0.5111332656580793)}
Validation takes--- 120.00185527999929 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.07it/s]


{'ndcg': np.float64(0.5174888788843548)}
Test takes--- 116.58666482899935 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.15it/s] 


------------------------------------
training Epoch: 20
{'ce': 2736.950809893799, 'ndcg': np.float64(0.5374331912564428)}
Training takes--- 636.8361457780011 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.96it/s]


{'ndcg': np.float64(0.5117549887241618)}
Validation takes--- 117.55412279599841 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.38it/s]


{'ndcg': np.float64(0.5176728529740038)}
Test takes--- 116.27255796200188 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:38<00:00, 97.84it/s] 


------------------------------------
training Epoch: 21
{'ce': 2729.9544872314455, 'ndcg': np.float64(0.5380272089108847)}
Training takes--- 638.8642282180008 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.63it/s]


{'ndcg': np.float64(0.5121612605411706)}
Validation takes--- 118.9384333909984 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.18it/s]


{'ndcg': np.float64(0.5186278878341108)}
Test takes--- 116.47552764599823 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:31<00:00, 99.03it/s] 


------------------------------------
training Epoch: 22
{'ce': 2723.2361229370117, 'ndcg': np.float64(0.5385485702108765)}
Training takes--- 631.1712380330027 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.91it/s]


{'ndcg': np.float64(0.512724134895731)}
Validation takes--- 117.60576091400071 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.98it/s]


{'ndcg': np.float64(0.5187848892478127)}
Test takes--- 115.66738347499995 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.14it/s]


------------------------------------
training Epoch: 23
{'ce': 2716.6076707519533, 'ndcg': np.float64(0.5390087078197194)}
Training takes--- 636.886244389003 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 114.08it/s]


{'ndcg': np.float64(0.5133089649761386)}
Validation takes--- 117.43340347199774 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.58it/s]


{'ndcg': np.float64(0.5197538315801494)}
Test takes--- 114.07734406500094 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:29<00:00, 99.31it/s]


------------------------------------
training Epoch: 24
{'ce': 2710.70874576416, 'ndcg': np.float64(0.5396285673685455)}
Training takes--- 629.4180291429984 seconds ---


100%|██████████| 13394/13394 [01:55<00:00, 115.69it/s]


{'ndcg': np.float64(0.513852200836215)}
Validation takes--- 115.80154389100062 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.49it/s]


{'ndcg': np.float64(0.5201259495394316)}
Test takes--- 114.17405227299969 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.36it/s] 


------------------------------------
training Epoch: 25
{'ce': 2704.949270465088, 'ndcg': np.float64(0.540047095055951)}
Training takes--- 635.4811498359995 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.33it/s]


{'ndcg': np.float64(0.5148192604115452)}
Validation takes--- 118.21175304099961 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 114.96it/s]


{'ndcg': np.float64(0.5212893570749784)}
Test takes--- 115.68761399199866 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.17it/s] 


------------------------------------
training Epoch: 26
{'ce': 2699.3254629211424, 'ndcg': np.float64(0.5404316777314729)}
Training takes--- 636.7283008449995 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.00it/s]


{'ndcg': np.float64(0.5148972340751605)}
Validation takes--- 118.5582370009979 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.45it/s]


{'ndcg': np.float64(0.5213466676342978)}
Test takes--- 117.23257242300315 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.35it/s] 


------------------------------------
training Epoch: 27
{'ce': 2693.971320300293, 'ndcg': np.float64(0.5407717617582766)}
Training takes--- 635.5451224339995 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 114.75it/s]


{'ndcg': np.float64(0.5153224966968198)}
Validation takes--- 116.74847504200079 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.04it/s]


{'ndcg': np.float64(0.5213148819258088)}
Test takes--- 114.60831826900176 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:31<00:00, 99.04it/s] 


------------------------------------
training Epoch: 28
{'ce': 2688.9624670959474, 'ndcg': np.float64(0.5412665366351385)}
Training takes--- 631.1004437580013 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 114.56it/s]


{'ndcg': np.float64(0.5148575333055518)}
Validation takes--- 116.94373507499768 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.19it/s]


{'ndcg': np.float64(0.5203804938450627)}
Test takes--- 114.46574515699831 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:30<00:00, 99.15it/s] 


------------------------------------
training Epoch: 29
{'ce': 2684.1020813873292, 'ndcg': np.float64(0.5415842396169068)}
Training takes--- 630.4566267530026 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 114.74it/s]


{'ndcg': np.float64(0.5137052487542878)}
Validation takes--- 116.75899786000082 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.78it/s]


{'ndcg': np.float64(0.5188955057801796)}
Test takes--- 114.86909603800086 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:33<00:00, 98.72it/s] 


------------------------------------
training Epoch: 30
{'ce': 2679.3804178771975, 'ndcg': np.float64(0.5418968736686053)}
Training takes--- 633.1938943470013 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 114.19it/s]


{'ndcg': np.float64(0.5140303216959891)}
Validation takes--- 117.31896754500121 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.62it/s]


{'ndcg': np.float64(0.5198093951078029)}
Test takes--- 115.02679349400205 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:34<00:00, 98.57it/s] 


------------------------------------
training Epoch: 31
{'ce': 2674.7121172424318, 'ndcg': np.float64(0.5422311515460622)}
Training takes--- 634.1299607590008 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.84it/s]


{'ndcg': np.float64(0.5142569031559475)}
Validation takes--- 117.67825017400173 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.28it/s]


{'ndcg': np.float64(0.5200921407451831)}
Test takes--- 115.36882311399677 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:31<00:00, 98.93it/s] 


------------------------------------
training Epoch: 32
{'ce': 2670.2920359802247, 'ndcg': np.float64(0.5424883470072115)}
Training takes--- 631.8118551740008 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 114.69it/s]


{'ndcg': np.float64(0.5147403223918194)}
Validation takes--- 116.81138614700103 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 112.94it/s]


{'ndcg': np.float64(0.5211597173043312)}
Test takes--- 117.7617507670002 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:38<00:00, 97.82it/s] 


------------------------------------
training Epoch: 33
{'ce': 2665.9775551696775, 'ndcg': np.float64(0.5428811482205922)}
Training takes--- 638.9666152080026 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 114.21it/s]


{'ndcg': np.float64(0.5152727083626968)}
Validation takes--- 117.29887236200011 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.27it/s]


{'ndcg': np.float64(0.5219159512252439)}
Test takes--- 114.3830412870011 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:31<00:00, 99.01it/s] 


------------------------------------
training Epoch: 34
{'ce': 2661.6463364868164, 'ndcg': np.float64(0.5430660568888986)}
Training takes--- 631.3466250380006 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 113.80it/s]


{'ndcg': np.float64(0.5157868739045854)}
Validation takes--- 117.725389301002 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.80it/s]


{'ndcg': np.float64(0.5219231891404964)}
Test takes--- 114.84853885399934 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:31<00:00, 98.97it/s] 


------------------------------------
training Epoch: 35
{'ce': 2657.681897338867, 'ndcg': np.float64(0.5434092213711027)}
Training takes--- 631.5771259329995 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 115.06it/s]


{'ndcg': np.float64(0.5162159252440771)}
Validation takes--- 116.43052641099712 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.36it/s]


{'ndcg': np.float64(0.5225935826333986)}
Test takes--- 114.30295687800026 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:29<00:00, 99.31it/s] 


------------------------------------
training Epoch: 36
{'ce': 2653.680057781982, 'ndcg': np.float64(0.5436900508895915)}
Training takes--- 629.4191908950015 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 112.84it/s]


{'ndcg': np.float64(0.5168128381539457)}
Validation takes--- 118.71815244199752 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.63it/s]


{'ndcg': np.float64(0.5228046837994133)}
Test takes--- 116.02465050699902 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:32<00:00, 98.75it/s] 


------------------------------------
training Epoch: 37
{'ce': 2649.502627209473, 'ndcg': np.float64(0.5438211136272044)}
Training takes--- 632.9519296149992 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.37it/s]


{'ndcg': np.float64(0.5168263092748245)}
Validation takes--- 118.16839936299948 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.54it/s]


{'ndcg': np.float64(0.5231245498450122)}
Test takes--- 115.10929650600156 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.17it/s] 


------------------------------------
training Epoch: 38
{'ce': 2645.717973150635, 'ndcg': np.float64(0.5441965357942652)}
Training takes--- 636.6924857400008 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.17it/s]


{'ndcg': np.float64(0.5169948654302293)}
Validation takes--- 120.49940993900236 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.61it/s]


{'ndcg': np.float64(0.5234065493342615)}
Test takes--- 117.06385531400156 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:33<00:00, 98.74it/s]


------------------------------------
training Epoch: 39
{'ce': 2641.8269653320312, 'ndcg': np.float64(0.5443818551248862)}
Training takes--- 633.0606591239994 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.30it/s]


{'ndcg': np.float64(0.5172621001942437)}
Validation takes--- 118.24881659100356 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.11it/s]


{'ndcg': np.float64(0.5242220617069165)}
Test takes--- 115.54096437900444 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:38<00:00, 97.89it/s] 


------------------------------------
training Epoch: 40
{'ce': 2638.238241693115, 'ndcg': np.float64(0.5447968331087383)}
Training takes--- 638.5217525420012 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.68it/s]


{'ndcg': np.float64(0.517168948381993)}
Validation takes--- 119.95318279699859 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.05it/s]


{'ndcg': np.float64(0.5243445159160633)}
Test takes--- 116.6075680110007 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:37<00:00, 98.05it/s] 


------------------------------------
training Epoch: 41
{'ce': 2634.8153390441894, 'ndcg': np.float64(0.5448757128756864)}
Training takes--- 637.4769006709976 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 114.54it/s]


{'ndcg': np.float64(0.5171634905581493)}
Validation takes--- 116.95694963999995 seconds ---


 99%|█████████▉| 13299/13394 [01:55<00:00, 115.55it/s]


{'ndcg': np.float64(0.5245054080032505)}
Test takes--- 115.09727936799754 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:39<00:00, 97.75it/s] 


------------------------------------
training Epoch: 42
{'ce': 2631.35123427124, 'ndcg': np.float64(0.5451865207460659)}
Training takes--- 639.4824859879955 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 111.98it/s]


{'ndcg': np.float64(0.5174786418204096)}
Validation takes--- 119.63130099699629 seconds ---


 99%|█████████▉| 13299/13394 [01:56<00:00, 114.07it/s]


{'ndcg': np.float64(0.524456696806745)}
Test takes--- 116.59135430000606 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:38<00:00, 97.94it/s] 


------------------------------------
training Epoch: 43
{'ce': 2627.7414395996093, 'ndcg': np.float64(0.5453841122278557)}
Training takes--- 638.1814797940024 seconds ---


100%|██████████| 13394/13394 [01:59<00:00, 112.24it/s]


{'ndcg': np.float64(0.5163976694121625)}
Validation takes--- 119.35880845000065 seconds ---


 99%|█████████▉| 13299/13394 [01:57<00:00, 113.55it/s]


{'ndcg': np.float64(0.5231524402707783)}
Test takes--- 117.11982448300114 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:35<00:00, 98.35it/s]


------------------------------------
training Epoch: 44
{'ce': 2624.364042666626, 'ndcg': np.float64(0.5456127939556633)}
Training takes--- 635.5530503409973 seconds ---


100%|██████████| 13394/13394 [01:55<00:00, 115.97it/s]


{'ndcg': np.float64(0.5165156465567527)}
Validation takes--- 115.52305165400321 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.91it/s]


{'ndcg': np.float64(0.5232268716588193)}
Test takes--- 114.74527270199906 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:33<00:00, 98.68it/s] 


------------------------------------
training Epoch: 45
{'ce': 2621.0885615600587, 'ndcg': np.float64(0.5457978219891128)}
Training takes--- 633.4598765990013 seconds ---


100%|██████████| 13394/13394 [01:56<00:00, 115.15it/s]


{'ndcg': np.float64(0.5145322576289088)}
Validation takes--- 116.33700453699566 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.91it/s]


{'ndcg': np.float64(0.5217738576574792)}
Test takes--- 114.74108471699583 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:31<00:00, 98.96it/s] 


------------------------------------
training Epoch: 46
{'ce': 2618.053621017456, 'ndcg': np.float64(0.5459609326995294)}
Training takes--- 631.6138277330028 seconds ---


100%|██████████| 13394/13394 [01:57<00:00, 114.00it/s]


{'ndcg': np.float64(0.515813045973166)}
Validation takes--- 117.5132233819968 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 116.43it/s]


{'ndcg': np.float64(0.522251818705113)}
Test takes--- 114.22944870599895 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.15it/s] 


------------------------------------
training Epoch: 47
{'ce': 2614.683027999878, 'ndcg': np.float64(0.546232120224213)}
Training takes--- 636.8264549129963 seconds ---


100%|██████████| 13394/13394 [01:58<00:00, 113.07it/s]


{'ndcg': np.float64(0.5164363604882882)}
Validation takes--- 118.48528845300461 seconds ---


 99%|█████████▉| 13299/13394 [01:54<00:00, 115.97it/s]


{'ndcg': np.float64(0.5230207146525233)}
Test takes--- 114.67763105400081 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:36<00:00, 98.14it/s] 


------------------------------------
training Epoch: 48
{'ce': 2611.7155628845217, 'ndcg': np.float64(0.546315441024285)}
Training takes--- 636.8888959249962 seconds ---


100%|██████████| 13394/13394 [02:01<00:00, 110.29it/s]


{'ndcg': np.float64(0.5161599520834871)}
Validation takes--- 121.47198261200538 seconds ---


 99%|█████████▉| 13299/13394 [01:59<00:00, 111.74it/s]


{'ndcg': np.float64(0.5230298801290433)}
Test takes--- 119.02135442000144 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:50<00:00, 96.10it/s] 


------------------------------------
training Epoch: 49
{'ce': 2608.804053338623, 'ndcg': np.float64(0.5464860429491316)}
Training takes--- 650.4133615159953 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 110.72it/s]


{'ndcg': np.float64(0.5163342955882831)}
Validation takes--- 120.99924449399987 seconds ---


 99%|█████████▉| 13299/13394 [01:59<00:00, 111.59it/s]


{'ndcg': np.float64(0.5233691438164435)}
Test takes--- 119.18294119799975 seconds ---
------------------------------------


100%|██████████| 62505/62505 [10:51<00:00, 95.93it/s] 


------------------------------------
training Epoch: 50
{'ce': 2605.7523179870605, 'ndcg': np.float64(0.5467671327750111)}
Training takes--- 651.613068106999 seconds ---


100%|██████████| 13394/13394 [02:00<00:00, 111.33it/s]


{'ndcg': np.float64(0.5155810154144402)}
Validation takes--- 120.33526552099647 seconds ---


 99%|█████████▉| 13299/13394 [01:59<00:00, 111.37it/s]

{'ndcg': np.float64(0.5222871567394357)}
Test takes--- 119.42224629900011 seconds ---
------------------------------------
------------------------------------
------------------------------------
best val score:  0.5174786418204096
best validation epoch   :  42
best test score:  0.524456696806745





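(Note: NDCG-Loss2++ was used for training, not cross-entropy. We hadn't updated the string in metric_dict when we ran the training, so the 'ce' entries in the logs above should read 'ndcgloss2++'; the values reported under 'ce' are the NDCG-Loss2++ training loss.)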