In [1]:
import copy
from typing import Callable, Dict, Tuple

import torch
from torch import Tensor
from torch.nn import GRUCell, Linear

from torch_geometric.nn.inits import zeros
from torch_geometric.utils import scatter
from torch_geometric.utils._scatter import scatter_argmax

TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]


class TGNMemory(torch.nn.Module):
    r"""The Temporal Graph Network (TGN) memory model from the
    `"Temporal Graph Networks for Deep Learning on Dynamic Graphs"
    <https://arxiv.org/abs/2006.10637>`_ paper.

    .. note::

        For an example of using TGN, see `examples/tgn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        tgn.py>`_.

    Args:
        num_nodes (int): The number of nodes to save memories for.
        raw_msg_dim (int): The raw message dimensionality.
        memory_dim (int): The hidden memory dimensionality.
        time_dim (int): The time encoding dimensionality.
        message_module (torch.nn.Module): The message function which
            combines source and destination node memory embeddings, the raw
            message and the time encoding.
        aggregator_module (torch.nn.Module): The message aggregator function
            which aggregates messages to the same destination into a single
            representation.
    """
    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,
                 time_dim: int, message_module: Callable,
                 aggregator_module: Callable):
        super().__init__()

        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim

        self.msg_s_module = message_module
        self.msg_d_module = copy.deepcopy(message_module)
        self.aggr_module = aggregator_module
        self.time_enc = TimeEncoder(time_dim)
        self.gru = GRUCell(message_module.out_channels, memory_dim)

        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))
        last_update = torch.empty(self.num_nodes, dtype=torch.long)
        self.register_buffer('last_update', last_update)
        self.register_buffer('_assoc', torch.empty(num_nodes,
                                                   dtype=torch.long))

        self.msg_s_store = {}
        self.msg_d_store = {}

        self.reset_parameters()

    @property
    def device(self) -> torch.device:
        return self.time_enc.lin.weight.device

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if hasattr(self.msg_s_module, 'reset_parameters'):
            self.msg_s_module.reset_parameters()
        if hasattr(self.msg_d_module, 'reset_parameters'):
            self.msg_d_module.reset_parameters()
        if hasattr(self.aggr_module, 'reset_parameters'):
            self.aggr_module.reset_parameters()
        self.time_enc.reset_parameters()
        self.gru.reset_parameters()
        self.reset_state()

    def reset_state(self):
        """Resets the memory to its initial state."""
        zeros(self.memory)
        zeros(self.last_update)
        self._reset_message_store()

    def detach(self):
        """Detaches the memory from gradient computation."""
        self.memory.detach_()

    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        """Returns, for all nodes :obj:`n_id`, their current memory and their
        last updated timestamp.
        """
        if self.training:
            memory, last_update = self._get_updated_memory(n_id)
        else:
            memory, last_update = self.memory[n_id], self.last_update[n_id]

        return memory, last_update

    def update_state(self, src: Tensor, dst: Tensor, t: Tensor,
                     raw_msg: Tensor):
        """Updates the memory with newly encountered interactions
        :obj:`(src, dst, t, raw_msg)`.
        """
        n_id = torch.cat([src, dst]).unique()

        if self.training:
            self._update_memory(n_id)
            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
        else:
            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)
            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)
            self._update_memory(n_id)

    def _reset_message_store(self):
        i = self.memory.new_empty((0, ), device=self.device, dtype=torch.long)
        msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device)
        # Message store format: (src, dst, t, msg)
        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}
        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}

    # def _update_memory(self, n_id: Tensor):
    #     memory, last_update = self._get_updated_memory(n_id)
    #     self.memory[n_id] = memory
    #     self.last_update[n_id] = last_update

    def _update_memory(self, n_id):
        memory, last_update = self._get_updated_memory(n_id)
        # self.memory[n_id] = memory
        # self.last_update[n_id] = last_update

        self.memory[n_id.long()] = memory
        self.last_update[n_id.float().long()] = last_update.long()


    def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:
        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)

        # Compute messages (src -> dst).
        msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,
                                                     self.msg_s_module)

        # Compute messages (dst -> src).
        msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store,
                                                     self.msg_d_module)

        # Aggregate messages.
        idx = torch.cat([src_s, src_d], dim=0)
        msg = torch.cat([msg_s, msg_d], dim=0)
        t = torch.cat([t_s, t_d], dim=0)
        aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))

        # Get local copy of updated memory.
        memory = self.gru(aggr, self.memory[n_id])

        # Get local copy of updated `last_update`.
        dim_size = self.last_update.size(0)
        last_update = scatter(t, idx, 0, dim_size, reduce='max')[n_id]
        return memory, last_update


    def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor,
                          raw_msg: Tensor, msg_store: TGNMessageStoreType):
        n_id, perm = src.sort()
        n_id, count = n_id.unique_consecutive(return_counts=True)
        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):
            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])

    def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType,
                     msg_module: Callable):
        data = [msg_store[i] for i in n_id.tolist()]
        src, dst, t, raw_msg = list(zip(*data))
        src = torch.cat(src, dim=0)
        dst = torch.cat(dst, dim=0)
        t = torch.cat(t, dim=0)
        # Filter out empty tensors to avoid `invalid configuration argument`.
        # TODO Investigate why this is needed.
        raw_msg = [m for i, m in enumerate(raw_msg) if m.numel() > 0 or i == 0]
        raw_msg = torch.cat(raw_msg, dim=0)

        #--------------------------adding here----------------------------------#
        t = t - self.last_update[src]
        t = t.float()
        t_enc = self.time_enc(t)
        #_______________________________________________________________________#

        # t_enc = self.time_enc(t - self.last_update[src])

        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)

        return msg, t, src, dst

    def train(self, mode: bool = True):
        """Sets the module in training mode."""
        if self.training and not mode:
            # Flush message store to memory in case we just entered eval mode.
            self._update_memory(
                torch.arange(self.num_nodes, device=self.memory.device))
            self._reset_message_store()
        super().train(mode)


class IdentityMessage(torch.nn.Module):
    def __init__(self, raw_msg_dim: int, memory_dim: int, time_dim: int):
        super().__init__()
        self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim

    def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor,
                t_enc: Tensor):
        return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)


class LastAggregator(torch.nn.Module):
    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        argmax = scatter_argmax(t, index, dim=0, dim_size=dim_size)
        out = msg.new_zeros((dim_size, msg.size(-1)))
        mask = argmax < msg.size(0)  # Filter items with at least one entry.
        out[mask] = msg[argmax[mask]]
        return out


class MeanAggregator(torch.nn.Module):
    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):
        return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean')


class TimeEncoder(torch.nn.Module):
    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.lin = Linear(1, out_channels)

    def reset_parameters(self):
        self.lin.reset_parameters()

    def forward(self, t: Tensor) -> Tensor:
        return self.lin(t.view(-1, 1)).cos()


class LastNeighborLoader:
    def __init__(self, num_nodes: int, size: int, device=None):
        self.size = size

        self.neighbors = torch.empty((num_nodes, size), dtype=torch.long,
                                     device=device)
        self.e_id = torch.empty((num_nodes, size), dtype=torch.long,
                                device=device)
        self._assoc = torch.empty(num_nodes, dtype=torch.long, device=device)

        self.reset_state()

    def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        neighbors = self.neighbors[n_id]
        nodes = n_id.view(-1, 1).repeat(1, self.size)
        e_id = self.e_id[n_id]

        # Filter invalid neighbors (identified by `e_id < 0`).
        mask = e_id >= 0
        neighbors, nodes, e_id = neighbors[mask], nodes[mask], e_id[mask]

        # Relabel node indices.
        n_id = torch.cat([n_id, neighbors]).unique()
        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)
        neighbors, nodes = self._assoc[neighbors], self._assoc[nodes]

        return n_id, torch.stack([neighbors, nodes]), e_id

    def insert(self, src: Tensor, dst: Tensor):
        # Inserts newly encountered interactions into an ever-growing
        # (undirected) temporal graph.

        # Collect central nodes, their neighbors and the current event ids.
        neighbors = torch.cat([src, dst], dim=0)
        nodes = torch.cat([dst, src], dim=0)
        e_id = torch.arange(self.cur_e_id, self.cur_e_id + src.size(0),
                            device=src.device).repeat(2)
        self.cur_e_id += src.numel()

        # Convert newly encountered interaction ids so that they point to
        # locations of a "dense" format of shape [num_nodes, size].
        nodes, perm = nodes.sort()
        neighbors, e_id = neighbors[perm], e_id[perm]

        n_id = nodes.unique()
        self._assoc[n_id] = torch.arange(n_id.numel(), device=n_id.device)

        dense_id = torch.arange(nodes.size(0), device=nodes.device) % self.size
        dense_id += self._assoc[nodes].mul_(self.size)

        dense_e_id = e_id.new_full((n_id.numel() * self.size, ), -1)
        dense_e_id[dense_id] = e_id
        dense_e_id = dense_e_id.view(-1, self.size)

        dense_neighbors = e_id.new_empty(n_id.numel() * self.size)
        dense_neighbors[dense_id] = neighbors
        dense_neighbors = dense_neighbors.view(-1, self.size)

        # Collect new and old interactions...
        e_id = torch.cat([self.e_id[n_id, :self.size], dense_e_id], dim=-1)
        neighbors = torch.cat(
            [self.neighbors[n_id, :self.size], dense_neighbors], dim=-1)

        # And sort them based on `e_id`.
        e_id, perm = e_id.topk(self.size, dim=-1)
        self.e_id[n_id] = e_id
        self.neighbors[n_id] = torch.gather(neighbors, 1, perm)

    def reset_state(self):
        self.cur_e_id = 0
        self.e_id.fill_(-1)


In [2]:
import pandas as pd
import numpy as np

import torch
from torch.nn import Linear
from sklearn.metrics import average_precision_score, roc_auc_score
from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
# from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn import TransformerConv
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

import os
from torch.nn import BCEWithLogitsLoss, MSELoss

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
raw_data = pd.read_csv('weeklyAggregatedCitibike_for_train.csv')
raw_data["target"] = 0
raw_data

Unnamed: 0,start_zone_encoded,end_zone_encoded,weight,time,target
0,4,12,2,0,0
1,4,13,15,0,0
2,4,17,1,0,0
3,4,25,3,0,0
4,4,33,13,0,0
...,...,...,...,...,...
3134446,263,247,1,551,0
3134447,263,249,6,551,0
3134448,263,256,3,551,0
3134449,263,260,2,551,0


In [5]:
from torch import Tensor
from torch_geometric.data import TemporalData

data_temporal = TemporalData(
    src=Tensor(raw_data["start_zone_encoded"]).long(),
    dst=Tensor(raw_data["end_zone_encoded"]).long(),
    t=Tensor(raw_data["time"]),
    msg=Tensor(raw_data["weight"]).unsqueeze(-1)
)

data_temporal.y = Tensor(raw_data["target"])

In [6]:
data_temporal

TemporalData(src=[3134451], dst=[3134451], t=[3134451], msg=[3134451, 1], y=[3134451])

In [7]:
print("Source nodes:", data_temporal.src)
print("Destination nodes:", data_temporal.dst)
print("Timestamps:", data_temporal.t)
print("Messages or edge features:", data_temporal.msg)
print("Target values:", data_temporal.y)

Source nodes: tensor([  4,   4,   4,  ..., 263, 263, 263])
Destination nodes: tensor([ 12,  13,  17,  ..., 256, 260, 262])
Timestamps: tensor([  0.,   0.,   0.,  ..., 551., 551., 551.])
Messages or edge features: tensor([[  2.],
        [ 15.],
        [  1.],
        ...,
        [  3.],
        [  2.],
        [212.]])
Target values: tensor([0., 0., 0.,  ..., 0., 0., 0.])


In [8]:
print("Source nodes:", len(data_temporal.src))
print("Destination nodes:", len(data_temporal.dst))
print("Timestamps:", len(data_temporal.t))
print("Messages or edge features:", len(data_temporal.msg))
print("Target values:", len(data_temporal.y))

Source nodes: 3134451
Destination nodes: 3134451
Timestamps: 3134451
Messages or edge features: 3134451
Target values: 3134451


In [9]:
print("Source nodes:", data_temporal.src.unique())
print("Destination nodes:", data_temporal.dst.unique())
print("Timestamps:", data_temporal.t.unique())
print("Messages or edge features:", data_temporal.msg.unique())
print("Target values:", data_temporal.y.unique())

# source_unique = [x for x in data_temporal.src.unique() if x not in data_temporal.dst.unique()]
# destination_unique = [x for x in data_temporal.dst.unique() if x not in data_temporal.src.unique()]

# print("Source not in destination:",source_unique)
# print("Destination not in source:",destination_unique)

Source nodes: tensor([  2,   4,   7,   8,  12,  13,  14,  17,  18,  20,  21,  22,  24,  25,
         26,  28,  31,  32,  33,  34,  35,  36,  37,  39,  40,  41,  42,  43,
         45,  47,  48,  49,  50,  52,  54,  55,  57,  59,  60,  61,  62,  63,
         65,  66,  67,  68,  69,  70,  71,  72,  74,  75,  76,  77,  78,  79,
         80,  81,  82,  83,  85,  87,  88,  89,  90,  91,  92,  93,  94,  96,
         97, 100, 102, 105, 106, 107, 111, 112, 113, 114, 116, 117, 119, 120,
        123, 125, 126, 127, 128, 129, 133, 134, 136, 137, 138, 140, 141, 142,
        143, 144, 145, 146, 147, 148, 149, 151, 152, 153, 157, 158, 159, 160,
        161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 173, 174, 177, 179,
        181, 182, 185, 186, 188, 189, 190, 193, 194, 195, 196, 198, 202, 206,
        207, 209, 211, 212, 213, 217, 220, 222, 223, 224, 225, 226, 227, 228,
        229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242,
        243, 244, 246, 247, 248, 249, 254, 255, 25

In [10]:
print(data_temporal.num_nodes)

264


In [11]:
data_temporal = data_temporal.to(device)

In [12]:
data_temporal

TemporalData(src=[3134451], dst=[3134451], t=[3134451], msg=[3134451, 1], y=[3134451])

In [13]:
train_data, val_data, test_data = data_temporal.train_val_test_split(val_ratio=0.15, test_ratio=0.15)

In [14]:
train_loader = TemporalDataLoader(
    train_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
val_loader = TemporalDataLoader(
    val_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
test_loader = TemporalDataLoader(
    test_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)

In [15]:
train_data.src

tensor([  4,   4,   4,  ..., 263, 263, 263])

In [16]:
neighbor_loader = LastNeighborLoader(data_temporal.num_nodes, size=10, device=device)

In [17]:
class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                    dropout=0.1, edge_dim=edge_dim)

    def forward(self, x, last_update, edge_index, t, msg):
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)


In [18]:
class EdgeWeightPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)  # Outputs a single value for edge weight

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = h.relu()
        return self.lin_final(h)

In [19]:
memory_dim = time_dim = embedding_dim = 100

In [20]:
data_temporal.num_nodes

264

In [21]:
memory = TGNMemory(
    data_temporal.num_nodes,
    data_temporal.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data_temporal.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=data_temporal.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

In [22]:
# link_pred = LinkPredictor(in_channels=embedding_dim).to(device)
edge_weight_pred = EdgeWeightPredictor(in_channels=embedding_dim).to(device)

optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters())
    | set(edge_weight_pred.parameters()), lr=0.0001)

# Assuming the existence of the link is still a binary classification problem
bce_loss = BCEWithLogitsLoss()
mse_loss = MSELoss()  # Assuming weights are a regression problem

criterion = torch.nn.BCEWithLogitsLoss()
weight_criterion = torch.nn.MSELoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data_temporal.num_nodes, dtype=torch.long, device=device)

In [23]:
def train():
    memory.train()
    gnn.train()
    # link_pred.train()
    edge_weight_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    # total_weight_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id.long())
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # n_id = n_id.long()

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data_temporal.t[e_id].to(device),
                data_temporal.msg[e_id].to(device))

        # pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        # neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        weight_pred = edge_weight_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        weight_loss = weight_criterion(weight_pred, batch.msg.view(-1, 1))

        # loss = criterion(pos_out, torch.ones_like(pos_out))
        # loss += criterion(neg_out, torch.zeros_like(neg_out))
        # loss += weight_loss  # Include edge weight loss in the total loss

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        weight_loss.backward()
        optimizer.step()
        memory.detach()
        total_loss += float(weight_loss) * batch.num_events

    return total_loss / train_data.num_events


In [24]:
@torch.no_grad()
def test(loader):
    memory.eval()
    gnn.eval()
    # link_pred.eval()
    edge_weight_pred.eval()

    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.

    # aps = []
    # weight_aps, weight_aucs = [], []  # To store APS and AUC for weights
    all_predictions = []    # To store prediction results and node info

    for batch in loader:
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id.long())
        assoc[n_id] = torch.arange(n_id.size(0), device=device)
        # n_id = n_id.long()

        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data_temporal.t[e_id].to(device),
                data_temporal.msg[e_id].to(device))
        # outputs = link_pred(z[assoc[batch.src.long()]], z[assoc[batch.dst.long()]])
        # pos_out, pos_weights = outputs[:, 0], outputs[:, 1]
        # neg_out, neg_weights = link_pred(z[assoc[batch.src.long()]], z[assoc[batch.neg_dst.long()]]).split(1, dim=1)

        # y_pred = torch.cat([pos_out.sigmoid().unsqueeze(1), pos_weights.unsqueeze(1)], dim=1).cpu()

        # pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        # neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        weight_pred = edge_weight_pred(z[assoc[batch.src]], z[assoc[batch.dst]])

        # if torch.cuda.is_available():
        #   y_true = torch.cat(
        #     [torch.ones(pos_out.size(0), 1).to(device),
        #      batch.msg.to(device)], dim=0).cpu()
            # batch.msg.to(device).unsqueeze(1)], dim=0).cpu()
        # else:
        #   y_true = torch.cat(
        #     [torch.ones(pos_out.size(0), 1),
        #      batch.msg], dim=0).cpu()
            # batch.msg.unsqueeze(1)], dim=0).cpu()

        # if torch.cuda.is_available():
        #   y_true = torch.cat(
        #       [torch.ones(pos_out.size(0), 1).to(device),
        #        batch.msg.to(device)], dim=0).cpu()
        # else:
        #   y_true = torch.cat(
        #     [torch.ones(pos_out.size(0), 1),
        #      batch.msg], dim=0).cpu()  # Assumes batch.msg is properly scaled

        # y_true = torch.cat(
        #     [torch.ones(pos_out.size(0), 1),
        #      batch.msg.unsqueeze(1)], dim=0).cpu()  # Assumes batch.msg is properly scaled

        # src_nodes = batch.src.cpu().numpy()
        # dst_nodes = batch.dst.cpu().numpy()
        src_nodes = batch.src.cpu().numpy()
        dst_nodes = batch.dst.cpu().numpy()
        all_predictions.append((weight_pred, src_nodes, dst_nodes))

        # Calculate APS and AUC for link existence
        # aps.append(average_precision_score(y_true[:, 0], weight_pred[:, 0]))
        # aucs.append(roc_auc_score(y_true[:, 0], weight_pred[:, 0]))

        # Calculate APS and AUC for weights
        # weight_aps.append(average_precision_score(y_true[:, 1], y_pred[:, 1]))  # Added for weight APS
        # weight_aucs.append(roc_auc_score(y_true[:, 1], y_pred[:, 1]))  # Added for weight AUC

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

    return all_predictions


In [25]:
for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
    val_preds = test(val_loader)

Epoch: 01, Loss: 10169.4605
Epoch: 02, Loss: 8861.9021
Epoch: 03, Loss: 8358.6989
Epoch: 04, Loss: 8165.7161
Epoch: 05, Loss: 7934.1791
Epoch: 06, Loss: 8326.5592
Epoch: 07, Loss: 8052.3375
Epoch: 08, Loss: 7508.4522
Epoch: 09, Loss: 7368.9030
Epoch: 10, Loss: 7407.0726
Epoch: 11, Loss: 7014.4421
Epoch: 12, Loss: 7123.7813
Epoch: 13, Loss: 6582.5375
Epoch: 14, Loss: 6514.0757
Epoch: 15, Loss: 7060.3285
Epoch: 16, Loss: 7934.2216
Epoch: 17, Loss: 6658.9212
Epoch: 18, Loss: 7664.7158
Epoch: 19, Loss: 6683.9084
Epoch: 20, Loss: 7187.7544
Epoch: 21, Loss: 5789.3716
Epoch: 22, Loss: 6010.8711
Epoch: 23, Loss: 5724.8119
Epoch: 24, Loss: 7477.5374
Epoch: 25, Loss: 5685.0806
Epoch: 26, Loss: 6033.4723
Epoch: 27, Loss: 6832.1084
Epoch: 28, Loss: 5632.5736
Epoch: 29, Loss: 5529.2994
Epoch: 30, Loss: 5194.0642
Epoch: 31, Loss: 5179.9064
Epoch: 32, Loss: 4850.7872
Epoch: 33, Loss: 6428.4975
Epoch: 34, Loss: 6698.3564
Epoch: 35, Loss: 6380.6649
Epoch: 36, Loss: 5714.1553
Epoch: 37, Loss: 5505.9633


In [26]:
test_preds = test(test_loader)

In [27]:
len(test_preds)

2329

In [28]:
# import torch

# # Assume memory, gnn, and edge_weight_pred are your trained model instances
# torch.save(memory.state_dict(), 'memory.pth')
# torch.save(gnn.state_dict(), 'gnn.pth')
# torch.save(edge_weight_pred.state_dict(), 'edge_weight_pred.pth')

# # Assuming the model classes are available as MemoryModel, GNNModel, EdgeWeightModel

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# memory.load_state_dict(torch.load('memory.pth', map_location=device))
# memory.eval()  # Set the model to evaluation mode

# gnn.load_state_dict(torch.load('gnn.pth', map_location=device))
# gnn.eval()  # Set the model to evaluation mode

# edge_weight_pred.load_state_dict(torch.load('edge_weight_pred.pth', map_location=device))
# edge_weight_pred.eval()  # Set the model to evaluation mode


EdgeWeightPredictor(
  (lin_src): Linear(in_features=100, out_features=100, bias=True)
  (lin_dst): Linear(in_features=100, out_features=100, bias=True)
  (lin_final): Linear(in_features=100, out_features=1, bias=True)
)