In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, Dataset
import random, math
from torch.utils.data import random_split
from pytorch_lightning import Trainer

from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
# One seed shared by every global random-number generator the notebook uses.
seed = 42

# Seed, in order: Python's `random`, NumPy's legacy global RNG, torch's CPU
# generator, and the generators of all visible CUDA devices.
for _seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seeder(seed)
del _seeder

In [3]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import TGCN

class TGCNRecommender(torch.nn.Module):
    """A single TGCN recurrent layer followed by a per-node linear read-out.

    The recurrent state is kept outside the module. Each ``forward`` call takes
    the previous hidden state and returns the updated one, so the caller can
    thread it through successive graph snapshots.

    Args:
        node_features: Number of input features per node (passed to TGCN as
            its input size).
        hidden_dim: Size of the TGCN hidden state (passed to TGCN as its
            output size, and used as the linear head's input size).
        out_dim: Number of values predicted per node. Defaults to 1, i.e. one
            score per node.
    """

    def __init__(self, node_features, hidden_dim, out_dim=1):
        super().__init__()
        # Recurrent graph layer that produces the new hidden state.
        self.tgcn = TGCN(node_features, hidden_dim)
        # Prediction head applied to each node's hidden state.
        self.linear = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, edge_weight=None, h=None):
        """Advance the recurrent state by one snapshot and predict per node.

        Args:
            x: Node features for this snapshot, assumed (nodes, node_features).
            edge_index: Graph connectivity passed straight to TGCN.
            edge_weight: Optional per-edge weights passed to TGCN; ``None``
                means unweighted.
            h: Hidden state from the previous snapshot, assumed
                (nodes, hidden_dim). ``None`` lets TGCN create its initial
                state.

        Returns:
            Tuple ``(out, h_new)``. ``h_new`` is the updated hidden state
            returned by TGCN. ``out`` is the linear head applied to it
            (last dimension ``out_dim``).
        """
        # TGCN returns the NEW hidden state rather than a separate output.
        h_new = self.tgcn(x, edge_index, edge_weight, h)
        out = self.linear(h_new)
        return out, h_new

# Example usage: run the model over two consecutive mock graph snapshots,
# threading the hidden state from the first snapshot into the second.
num_nodes = 100
node_features = 8
hidden_dim = 32
num_edges = 200  # size of the random mock edge list

model = TGCNRecommender(node_features, hidden_dim)

# Mock data for 2 time steps, drawn from torch's global RNG.
x_t0 = torch.randn(num_nodes, node_features)
x_t1 = torch.randn(num_nodes, node_features)
# Random directed edges shared by both snapshots; may include duplicates
# and self-loops.
edge_index = torch.randint(0, num_nodes, (2, num_edges))

# Step 0: no previous hidden state, so pass None and let TGCN create one.
out0, h0 = model(x_t0, edge_index, None, None)

# Step 1: feed the hidden state produced at t0 into t1.
out1, h1 = model(x_t1, edge_index, None, h0)

# Check the expected shapes instead of only printing them.
assert out1.shape == (num_nodes, 1), out1.shape
assert h1.shape == (num_nodes, hidden_dim), h1.shape
print(f"Output shape: {out1.shape}")

ModuleNotFoundError: No module named 'urllib3'