# Null experiments

This notebook includes 2 other experiments conducted on the KuaiRec dataset. They include:
- [Metapath2Vec for Non-personalized Recommendation (Autoplay)](#Non-personalized-Recommendation)  
    This experiment checks whether adjusting the W2V video embeddings with Metapath2Vec can improve non-personalized recommendation performance. Long story short, it does not: standalone W2V embeddings are much better for autoplay.
- [Frozen Metapath2Vec for Personalized Recommendation](#Frozen-Metapath2Vec)  
    This experiment checks if it is possible to learn the user embeddings with Metapath2Vec while keeping the video embeddings frozen (same as original W2V video embeddings). This would be attractive because it could allow us to have learned user embeddings while only storing 1 set of video/track embeddings. Long story short, it does not work at all. Unsurprisingly, it is necessary to adjust both video/track and user embeddings jointly to improve personalized recommendation performance.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import pickle
import random
import itertools
import copy
from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE
from gensim.models.word2vec import Word2Vec
from scipy.stats import wilcoxon

import torch
from torch_geometric.nn.models import MetaPath2Vec
from torch_geometric.nn import MetaPath2Vec
from torch_geometric.data import HeteroData
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch_geometric.index import index2ptr
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import sort_edge_index

## Non-personalized-Recommendation

In [None]:
# Build the W2V training corpus from the KuaiRec big matrix: keep only positive
# interactions (watch_ratio >= 2, the threshold mentioned by the dataset authors)
# and collect, per user, the watched video ids in chronological order.
train = pd.read_csv('KuaiRec/big_matrix.csv')
train = train.loc[train['watch_ratio'] >= 2]
train_sequences = (
    train
    .sort_values(['user_id', 'time'])  # ascending on both keys (the default)
    .groupby('user_id')['video_id']
    .apply(list)
    .reset_index()
)
train_sequences.to_parquet('KuaiRec/train_sequences.parquet')

In [None]:
# Take fully observed test set positive interactions & create val and test pairs for non-personalized recommendation
# as done in Deezer's W2V Hyperparams matter paper (https://github.com/deezer/w2v_reco_hyperparameters_matter/blob/master/src/main.py)
SPLIT_SEED = 42  # fixed seed: without it the saved val/test split changes on every run

test = pd.read_csv('KuaiRec/small_matrix.csv')
test = test[test['watch_ratio'] >= 2].sort_values(['user_id', 'time']).reset_index(drop=True)

# Consecutive (previous, next) video pairs inside each user's chronological history.
# A user with a single positive interaction produces an empty list, which `explode`
# turns into NaN; those rows are dropped so the integer conversion below cannot fail.
pairs = (
    test.groupby('user_id')['video_id']
    .apply(lambda x: list(zip(x[:-1], x[1:])))
    .explode()
    .dropna()
    .tolist()
)

# Seeded in-place shuffle, then an 80/20 split (80% test, 20% validation).
np.random.default_rng(SPLIT_SEED).shuffle(pairs)
split_idx = int(len(pairs) * 0.8)
test_pairs = pairs[:split_idx]
val_pairs = pairs[split_idx:]
np.save('KuaiRec/val_pairs.npy', np.array(val_pairs, dtype=int))
np.save('KuaiRec/test_pairs.npy', np.array(test_pairs, dtype=int))

In [None]:
# Evaluation function for val & test based on Deezer's W2V paper, returns hitrate@k & ndcg@k for KNN recommendations
# using the W2V trained embeddings
def evaluate(train, test, embedding_dim, window_size, epochs, sg=1, min_count=1, k=20):
    """Train a Word2Vec model on `train` and score next-item retrieval on `test`.

    Args:
        train: iterable of token sequences (each video id as a string).
        test: sequence of (previous_item, next_item) pairs; both ids are passed
            through `str()` before being looked up in the vocabulary.
        embedding_dim: Word2Vec `vector_size`.
        window_size: Word2Vec `window`.
        epochs: number of Word2Vec training epochs.
        sg: Word2Vec `sg` flag.
        min_count: Word2Vec `min_count`.
        k: number of nearest neighbours retrieved per query item.

    Returns:
        Dict with keys 'HR@k' and 'NDCG@k'. Both sums are divided by `len(test)`,
        so pairs whose query item is out of vocabulary count as misses. Each hit
        adds 1/k and the total is multiplied by 1000, so 'HR@k' is a scaled value
        rather than a plain hit fraction. Both metrics are 0.0 when `test` is empty.
    """
    model = Word2Vec(
            vector_size=embedding_dim,
            window=window_size,
            workers=multiprocessing.cpu_count(),
            sg=sg,
            min_count=min_count,
            compute_loss=True,
        )
    model.build_vocab(train)
    model.train(
            corpus_iterable=train,
            total_examples=len(train),
            epochs=epochs,
        )
    vocab = list(model.wv.index_to_key)
    embedding = [model.wv[elem] for elem in vocab]
    # token -> row position; also used for O(1) membership tests (scanning the
    # `vocab` list for every pair is linear in the vocabulary size).
    mapping = {elem: i for i, elem in enumerate(vocab)}

    neigh = NearestNeighbors()
    neigh.fit(embedding)

    # Only one retrieved item can be relevant, so the ideal DCG is 1 and
    # NDCG@k = 1 / log2(1 + j) where j is the 1-based position of the hit.
    discounts = 1 / np.log2(np.arange(2, k + 2))

    hrk_score = 0.0
    ndcg_score = 0.0
    for pair_items in tqdm(test):
        query, target = str(pair_items[0]), str(pair_items[1])
        if query not in mapping:
            continue
        emb_0 = embedding[mapping[query]].reshape(1, -1)
        # Ask for k+1 neighbours and drop the first one (the query item itself).
        emb_neighbors = neigh.kneighbors(emb_0, k+1)[1].flatten()[1:]
        neighbors = [vocab[x] for x in emb_neighbors]
        if target in neighbors:
            hrk_score += 1/k
            ndcg_score += discounts[neighbors.index(target)]

    n_pairs = len(test)
    if n_pairs == 0:
        # Nothing to score: avoid a ZeroDivisionError and report zeros.
        return {'HR@%i' % k: 0.0, 'NDCG@%i' % k: 0.0}

    hrk_score = hrk_score / n_pairs
    ndcg_score = ndcg_score / n_pairs

    return {'HR@%i' % k: 1000*hrk_score, 'NDCG@%i' % k: ndcg_score}

In [None]:
# Reload the chronological watch sequences and cast every video id to str,
# which is the token form fed to Word2Vec.
train = pd.read_parquet('KuaiRec/train_sequences.parquet')
sequences = [[str(video_id) for video_id in watched] for watched in train['video_id']]
val = np.load('KuaiRec/val_pairs.npy')

In [None]:
# Best W2V non-personalized recommendation test set performance
# (embedding_dim=50, window_size=15, epochs=75).
test = np.load('KuaiRec/test_pairs.npy')
evaluate(sequences, test, embedding_dim=50, window_size=15, epochs=75)

In [None]:
# Train the best W2V configuration once more and keep the model around:
# its video/track vectors are used as the initialization for Metapath2vec.
train = pd.read_parquet('KuaiRec/train_sequences.parquet')
sequences = [[str(video_id) for video_id in watched] for watched in train['video_id']]
val = np.load('KuaiRec/val_pairs.npy')

model = Word2Vec(
    vector_size=50,
    window=15,
    workers=multiprocessing.cpu_count(),
    sg=1,
    min_count=1,
    compute_loss=True,
)
model.build_vocab(sequences)
model.train(corpus_iterable=sequences, total_examples=len(sequences), epochs=75)

In [None]:
df_watch = pd.read_csv('KuaiRec/big_matrix.csv')

# Index every user and video that appears in the matrix *before* filtering, so
# nodes without any positive interaction still get an index.
unique_users = df_watch['user_id'].unique()
unique_videos = df_watch['video_id'].unique()

# Keep only positive interactions (watch_ratio >= 2), in chronological order per user.
df_watch = df_watch[df_watch['watch_ratio'] >= 2].sort_values(['user_id', 'time'])

user2idx = {uid: i for i, uid in enumerate(unique_users)}
video2idx = {vid: i for i, vid in enumerate(unique_videos)}
idx2video = {i: vid for vid, i in video2idx.items()}


# Build user->video edge index
user_col = df_watch['user_id'].map(user2idx).values
video_col = df_watch['video_id'].map(video2idx).values
edge_index_uv = np.vstack([user_col, video_col])


def _parse_friend_list(raw):
    """Parse a serialized friend list such as '[1, 2, 3]' into a list of ints.

    Empty lists ('[]') and non-string cells yield an empty list. Splitting the
    stripped text of '[]' on ',' gives [''], which a later `dropna()` would not
    remove and which cannot be cast to int, so blank tokens are skipped here.
    """
    if not isinstance(raw, str):
        return []
    return [int(token) for token in raw.strip('[]').split(',') if token.strip()]


df_social = pd.read_csv('KuaiRec/social_network.csv')
df_social['friend_list'] = df_social['friend_list'].apply(_parse_friend_list)
# Empty friend lists explode to NaN and are removed by dropna().
df_social = df_social.explode('friend_list').dropna()
df_social['friend_list'] = df_social['friend_list'].astype(int)

# Build user-user edge index
# Ids missing from user2idx map to NaN; such edges are dropped, because a NaN
# endpoint would turn the whole index into floats and corrupt the long tensor
# built from it later.
mapped_a = df_social['user_id'].map(user2idx).to_numpy(dtype=float)
mapped_b = df_social['friend_list'].map(user2idx).to_numpy(dtype=float)
both_known = ~(np.isnan(mapped_a) | np.isnan(mapped_b))
userA = mapped_a[both_known].astype(np.int64)
userB = mapped_b[both_known].astype(np.int64)

edge_index_uu = np.vstack([userA, userB])

In [None]:
# Convert the numpy edge indices to long tensors first.
uv_edges = torch.tensor(edge_index_uv, dtype=torch.long)
# Reverse direction: flip the two rows (source <-> target) and copy into a
# C-contiguous array before handing it to torch.
vu_edges = torch.tensor(np.flip(edge_index_uv, axis=0).copy(order='C'), dtype=torch.long)
uu_edges = torch.tensor(edge_index_uu, dtype=torch.long)

# Heterogeneous graph with 'user' and 'video' nodes and three relations.
data = HeteroData()
data['user'].num_nodes = len(user2idx)
data['video'].num_nodes = len(video2idx)
data['user', 'watches', 'video'].edge_index = uv_edges
data['video', 'watched_by', 'user'].edge_index = vu_edges
data['user', 'follows', 'user'].edge_index = uu_edges

# Embedding size shared by the W2V model and the graph embeddings.
w2v_dim = model.vector_size

In [None]:
# Create a zero tensor for all video embeddings & update them with W2V trained ones
model_vocab = list(model.wv.index_to_key)
video_emb = np.zeros((len(video2idx), w2v_dim), dtype=np.float32)
for vid, idx in video2idx.items():
    vid_key = str(vid)  # W2V tokens are strings
    # Membership against the keyed vectors is a hash lookup; scanning the
    # `model_vocab` list here would be linear in the vocabulary size per video.
    if vid_key in model.wv:
        video_emb[idx] = model.wv[vid_key]

# Set data video embeddings to the W2V embeddings (rows stay zero for videos
# that are not in the W2V vocabulary)
data['video'].x = torch.tensor(video_emb, dtype=torch.float32)

num_users = len(user2idx)

# user_id -> list of watched video ids (as W2V string keys), in df_watch row order.
# Built with a groupby instead of DataFrame.iterrows(): iterrows creates one
# Series per row and does not preserve column dtypes, so an id could be rendered
# as e.g. '148.0' and silently miss the vocabulary.
user_to_videos = defaultdict(list)
user_to_videos.update(
    df_watch['video_id'].astype(str).groupby(df_watch['user_id']).apply(list).to_dict()
)

# Initialize user embeddings to their average video embeddings from training watch history
user_emb = torch.zeros((num_users, w2v_dim), dtype=torch.float32)
for uid, u_idx in user2idx.items():
    vids_watched = user_to_videos[uid]
    if not vids_watched:
        continue  # user_emb remains zero if no videos

    sum_vec = np.zeros(w2v_dim, dtype=np.float32)
    count = 0
    for vid_str in vids_watched:
        if vid_str in model.wv:  # If the video ID is in the W2V vocab
            sum_vec += model.wv[vid_str]
            count += 1
    if count > 0:
        user_emb[u_idx] = torch.tensor(sum_vec / count, dtype=torch.float32)

# Set data user embeddings to the user's average W2V video embeddings
data['user'].x = user_emb

In [None]:
# Metapath used for the random walks: user -> user -> video -> user.
meta_path_uuv = [
    ('user', 'follows', 'user'),
    ('user', 'watches', 'video'),
    ('video', 'watched_by', 'user')
]
# `meta_path_uuv` is the only metapath defined in this notebook; the previous
# right-hand side (`meta_path_uv`) was an undefined name and raised NameError.
chosen_meta_path = meta_path_uuv

In [None]:
def evaluate_video_pairs(video_emb, pairs, video2idx, idx2video, k=20):
    """Score next-video retrieval with a euclidean KNN over `video_emb`.

    Args:
        video_emb: np.array of shape [num_videos, embedding_dim].
        pairs: iterable of (prev_video_id, next_video_id); the ids must be of the
            same kind as the keys of `video2idx`.
        video2idx: dict mapping video_id -> row index in `video_emb`.
        idx2video: dict mapping row index -> video_id.
        k: number of neighbours kept per query.

    Returns:
        Dict with 'HR@k' and 'NDCG@k', averaged over the pairs whose two ids are
        both present in `video2idx` (other pairs are skipped entirely). Each hit
        adds 1/k and the average is multiplied by 1000. Both values are 0.0 when
        no pair could be evaluated.
    """
    hr_key = 'HR@%i' % k
    ndcg_key = 'NDCG@%i' % k

    knn = NearestNeighbors(n_neighbors=k+1, metric='euclidean')  # or 'cosine'
    knn.fit(video_emb)

    hit_sum = 0.0
    gain_sum = 0.0
    n_scored = 0

    for prev_vid, next_vid in tqdm(pairs):
        if not (prev_vid in video2idx and next_vid in video2idx):
            continue
        n_scored += 1

        query = video_emb[video2idx[prev_vid]].reshape(1, -1)
        # k+1 neighbours are requested and the first one is always discarded
        # (it is expected to be the query video itself).
        _, nn_idx = knn.kneighbors(query, n_neighbors=k+1)
        candidates = [idx2video[idx] for idx in nn_idx.flatten()[1:]]

        if next_vid not in candidates:
            continue
        hit_sum += 1.0 / k
        # 0-based rank r -> gain 1 / log2(r + 2)
        gain_sum += 1.0 / np.log2(candidates.index(next_vid) + 2)

    if n_scored == 0:
        return {hr_key: 0.0, ndcg_key: 0.0}

    return {hr_key: 1000*(hit_sum / n_scored), ndcg_key: gain_sum / n_scored}


In [None]:
def get_embeddings(model_metapath):
    """Return `(user_emb, video_emb)` as numpy arrays.

    Switches the model to eval mode, then calls it with 'user' and 'video'
    (in that order) under `torch.no_grad()` and moves each result to the CPU.
    """
    model_metapath.eval()
    with torch.no_grad():
        per_type = [
            model_metapath(node_type).cpu().numpy()
            for node_type in ('user', 'video')
        ]
    return per_type[0], per_type[1]

def run_metapath2vec_training_and_eval(data, meta_path, hp, val_pairs, video2idx, idx2video):
    """Train a MetaPath2Vec model and report its best HR@20 on `val_pairs`.

    Args:
        data: graph object providing `edge_index_dict` plus `data['video'].x`
            and `data['user'].x`, which are copied into the model as the
            initial embeddings.
        meta_path: list of (src_type, relation, dst_type) edge types.
        hp: dict with keys 'embedding_dim', 'walk_length', 'context_size',
            'walks_per_node', 'num_negative_samples', 'epochs', 'lr'.
        val_pairs: (prev_video, next_video) pairs scored after every epoch.
        video2idx, idx2video: id <-> row index mappings for the video embeddings.

    Returns:
        (best_metrics, best_epoch): the metrics dict of the epoch with the highest
        HR@20 and that epoch's 1-based number. If no epoch scores above 0.0 the
        initial `{'HR@20': 0.0, 'NDCG@20': 0.0}` and epoch 0 are returned.

    Training stops early once two evaluated epochs have passed without a new best
    HR@20 (the counter is reset only by a strict improvement over the best so far).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_metapath = MetaPath2Vec(
        data.edge_index_dict,
        embedding_dim=hp['embedding_dim'],
        metapath=meta_path,
        walk_length=hp['walk_length'],
        context_size=hp['context_size'],
        walks_per_node=hp['walks_per_node'],
        num_negative_samples=hp['num_negative_samples'],
        sparse=True,
    ).to(device)

    # Overwrite the model's embeddings with the pre-computed ones stored on `data`
    # (video first, then user).
    with torch.no_grad():
        for node_type in ('video', 'user'):
            model_metapath(node_type).data.copy_(data[node_type].x.to(device))

    loader = model_metapath.loader(batch_size=128, shuffle=True, num_workers=4)
    optimizer = torch.optim.SparseAdam(model_metapath.parameters(), lr=hp['lr'])

    best_metrics = {'HR@20': 0.0, 'NDCG@20': 0.0}
    best_epoch = 0
    epochs_without_best = 0

    for epoch in range(1, hp['epochs'] + 1):
        # One pass over the random-walk loader.
        model_metapath.train()
        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model_metapath.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

        # Score the current video embeddings on the validation pairs.
        _, video_emb = get_embeddings(model_metapath)
        val_metrics = evaluate_video_pairs(video_emb, val_pairs, video2idx, idx2video, k=20)

        if val_metrics['HR@20'] > best_metrics['HR@20']:
            best_metrics = val_metrics
            best_epoch = epoch
            epochs_without_best = 0
            continue

        epochs_without_best += 1
        if epochs_without_best >= 2:
            break

    return best_metrics, best_epoch

In [None]:
def hyperparam_search(data, val_pairs, video2idx, idx2video):
    """Grid-search MetaPath2Vec hyperparameters by validation HR@20.

    Every combination of the grids below is trained with
    `run_metapath2vec_training_and_eval`, except those where
    `walk_length + 1 < context_size`, which are skipped. The embedding size is
    the notebook-level `w2v_dim`.

    Returns:
        (best_config, best_metrics) where `best_config` is
        `(meta_path, hp, best_epoch)` of the run with the highest HR@20, or
        `None` if no run scored above 0.0.
    """
    max_epochs = 15  # max epochs for each run
    grid = itertools.product(
        # metapaths
        [[('user', 'follows', 'user'), ('user', 'watches', 'video'), ('video', 'watched_by', 'user')]],
        [6, 10, 20],   # walk lengths
        [2, 5, 10],    # context sizes
        [2, 5, 10],    # how many random walks each node starts
        [2, 5, 10],    # negative sampling
        [0.01],        # learning rates
    )

    best_config = None
    best_metrics = {'HR@20': 0.0, 'NDCG@20': 0.0}

    for meta_path, wlen, csize, wpn, neg, lr in grid:
        if wlen + 1 < csize:
            continue

        hp = dict(
            embedding_dim=w2v_dim,
            walk_length=wlen,
            context_size=csize,
            walks_per_node=wpn,
            num_negative_samples=neg,
            epochs=max_epochs,
            lr=lr,
        )
        metrics, best_epoch = run_metapath2vec_training_and_eval(
            data, meta_path, hp, val_pairs, video2idx, idx2video
        )
        # Keep the run only on a strict improvement.
        if metrics['HR@20'] > best_metrics['HR@20']:
            best_metrics = metrics
            best_config = (meta_path, hp, best_epoch)

    return best_config, best_metrics

In [None]:
# Run the grid search against the validation pairs and report the winner.
best_config, best_val_metrics = hyperparam_search(
    data=data,
    val_pairs=val,
    video2idx=video2idx,
    idx2video=idx2video,
)
print("Best config:", best_config)
print("Best val metrics:", best_val_metrics)

In [None]:
# Best metapath2vec embedding performance on test set
# NOTE(review): context_size=6 and num_negative_samples=8 are not values of the
# grids in hyperparam_search ([2, 5, 10]) — confirm where this config came from.
hp = {
    'embedding_dim': w2v_dim,
    'walk_length': 6,
    'context_size': 6,
    'walks_per_node': 2,
    'num_negative_samples': 8,
    'epochs': 10,
    'lr': 0.01
}

# First run: train from the W2V initialization and select the best epoch on `val`.
metrics, best_epoch = run_metapath2vec_training_and_eval(
    data, meta_path_uuv, hp, val, video2idx, idx2video
)
print(best_epoch, metrics)
# Second run: retrains from scratch with `test` passed as the `val_pairs` argument.
# NOTE(review): the helper uses that argument for early stopping and for picking
# the best epoch, so the number printed below is selected on the test pairs
# (optimistic). It also overwrites `metrics` and `best_epoch` from the first run.
metrics, best_epoch = run_metapath2vec_training_and_eval(
    data, meta_path_uuv, hp, test, video2idx, idx2video
)
print(metrics)

## Frozen-Metapath2Vec
Here we test whether we can achieve similar performance while freezing the video embeddings (zero gradient). Only the user embeddings would then adapt to the already pretrained W2V video embeddings, which would allow us to avoid storing two sets of track/video embeddings (one for non-personalized and one for personalized recommendation). But as we can see below, this fails miserably: performance drops to almost 0, indicating that changes in the video embeddings are necessary.

In [None]:
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader

from torch_geometric.index import index2ptr
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import sort_edge_index

# Small constant added inside the logarithms of the loss to avoid log(0).
EPS = 1e-15

def sample(
    rowptr: Tensor,
    col: Tensor,
    rowcount: Tensor,
    subset: Tensor,
    num_neighbors: int,
    dummy_idx: int
) -> Tensor:
    r"""Draws :obj:`num_neighbors` random neighbors for every node in
    :obj:`subset` from the CSR structure (:obj:`rowptr`, :obj:`col`).

    Entries of :obj:`subset` that are :obj:`>= dummy_idx`, and nodes whose
    :obj:`rowcount` is zero, get :obj:`dummy_idx` as their result. The output
    has shape :obj:`[subset.size(0), num_neighbors]`."""

    # Remember which inputs were dummy / out-of-range before clamping them into
    # the valid row range so that the lookups below cannot go out of bounds.
    invalid = subset >= dummy_idx
    rows = subset.clamp(min=0, max=rowptr.size(0) - 2)

    degree = rowcount[rows]  # number of neighbors per row
    without_neighbor = invalid | (degree == 0)

    # Uniform draw in [0, degree) per row, shifted to that row's start in `col`.
    draws = torch.rand((rows.size(0), num_neighbors), device=rows.device)
    draws *= degree.to(draws.dtype).view(-1, 1)
    positions = draws.to(torch.long) + rowptr[rows].view(-1, 1)
    positions = positions.clamp(max=col.numel() - 1)  # stay inside `col`

    if col.numel() > 0:
        out = col[positions]
    else:
        out = positions
    out[without_neighbor] = dummy_idx
    return out


class FrozenMetaPath2Vec(nn.Module):
    r"""
    A custom MetaPath2Vec model that learns separate embeddings for "user" and
    "video" node types, following the paper:

    "metapath2vec: Scalable Representation Learning for Heterogeneous Networks"
    (Dong, Chawla, & Swami, KDD'17).

    This version:
      • Uses random walks based on a specified :obj:`metapath`
      • Samples positive and negative context windows
      • Stores separate `Embedding` layers for "user" and "video"
        so you can freeze the video embeddings and train only the user embeddings

    Node id layout:
        Adjacency structures and samplers work with *local* per-type indices
        (users in ``[0, user_count)``, videos in ``[0, video_count)``). The random
        walks returned by the samplers, and consumed by :meth:`loss`, use *global*
        ids: users keep their index, videos are shifted by ``user_count``, and
        ``dummy_idx = user_count + video_count`` marks "no node" (a walk that ran
        into a node without neighbors). Dummy entries are embedded as a zero
        vector, so they carry no gradient.

    Args:
        edge_index_dict (Dict[EdgeType, Tensor]): Dictionary holding edge
            indices for each :obj:`(src_type, rel_type, dst_type)`.
        user_count (int): Number of user nodes.
        video_count (int): Number of video nodes.
        embedding_dim (int): Embedding dimension size.
        metapath (List[EdgeType]): The metapath described as a list of
            `(src_node_type, rel_type, dst_node_type)`.
        walk_length (int): The length of each random walk.
        context_size (int): The actual context size for positive samples.
        walks_per_node (int, optional): Number of random walks per node.
            (default: :obj:`1`)
        num_negative_samples (int, optional): Number of negative samples for
            each positive example. (default: :obj:`1`)
        sparse (bool, optional): If set to True, use sparse gradients for
            embedding parameters. (default: :obj:`False`)

    Raises:
        ValueError: on an unknown node type, an empty metapath, or a metapath
            whose consecutive edge types do not connect.

    **How to freeze video embeddings**:
        1. Copy in your pre-trained vectors:
           ```
           with torch.no_grad():
               model.video_embedding.weight.copy_(pretrained_video_emb)
           ```
        2. Set
           ```
           model.video_embedding.weight.requires_grad = False
           ```
        3. Construct your optimizer only over remaining parameters:
           ```
           optimizer = torch.optim.SparseAdam(
               filter(lambda p: p.requires_grad, model.parameters()),
               lr=0.01
           )
           ```
    """

    def __init__(
        self,
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
        user_count: int,
        video_count: int,
        embedding_dim: int,
        metapath: List[Tuple[str, str, str]],
        walk_length: int,
        context_size: int,
        walks_per_node: int = 1,
        num_negative_samples: int = 1,
        sparse: bool = False,
    ):
        super().__init__()

        self.user_count = user_count
        self.video_count = video_count
        self.embedding_dim = embedding_dim

        # Two distinct embeddings:
        #   user_embedding: for nodes in [0 .. user_count-1]
        #   video_embedding: for nodes in [0 .. video_count-1]
        self.user_embedding = nn.Embedding(user_count, embedding_dim, sparse=sparse)
        self.video_embedding = nn.Embedding(video_count, embedding_dim, sparse=sparse)

        # Metapath hyperparams:
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples

        # Id bookkeeping (see "Node id layout" in the class docstring).
        self._type_sizes = {"user": user_count, "video": video_count}
        self._type_offsets = {"user": 0, "video": user_count}
        # Sentinel used while walking in local id space: it is >= every valid
        # local index of either type.
        self._local_dummy = max(user_count, video_count)
        # Sentinel used in global id space: one past the last video id.
        self.dummy_idx = user_count + video_count

        # Preprocess adjacency for each edge type:
        # rowptr_dict[(src, rel, dst)], col_dict[(src, rel, dst)], etc.
        # These are plain dicts (not buffers), so `.to(device)` does not move them.
        self.rowptr_dict = {}
        self.col_dict = {}
        self.rowcount_dict = {}

        for keys, edge_index in edge_index_dict.items():
            src_type, _, dst_type = keys
            if src_type not in self._type_sizes:
                raise ValueError(f"Unknown src_type: {src_type}")
            if dst_type not in self._type_sizes:
                raise ValueError(f"Unknown dst_type: {dst_type}")
            src_size = self._type_sizes[src_type]
            dst_size = self._type_sizes[dst_type]

            # Sort edges and get rowptr, col for sampling:
            row, col = sort_edge_index(edge_index, num_nodes=max(src_size, dst_size))
            rowptr = index2ptr(row, size=src_size)
            self.rowptr_dict[keys] = rowptr
            self.col_dict[keys] = col
            self.rowcount_dict[keys] = rowptr[1:] - rowptr[:-1]

        if len(metapath) == 0:
            raise ValueError("Invalid metapath: it must contain at least one edge type.")

        # Sanity check on metapath continuity:
        for edge_type1, edge_type2 in zip(metapath[:-1], metapath[1:]):
            # E.g., (user, "rel1", video), (video, "rel2", user) => OK
            if edge_type1[-1] != edge_type2[0]:
                raise ValueError(
                    "Invalid metapath: destination node type of one edge "
                    "does not match source node type of next edge."
                )

        # Node type visited at every step of a walk: the source type of the first
        # edge, then the destination type of each edge followed (cycling through
        # the metapath). Needed to translate local indices into global ids.
        step_types = [metapath[0][0]]
        step_types += [metapath[i % len(metapath)][-1] for i in range(walk_length)]
        for node_type in step_types:
            if node_type not in self._type_sizes:
                raise ValueError(f"Unsupported node type in metapath: {node_type}")
        self._step_types = step_types

        # The paper requires: walk_length + 1 >= context_size
        assert walk_length + 1 >= context_size, \
            "'walk_length + 1' must be >= 'context_size'."

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets user and video embedding parameters."""
        # Both tables are re-initialized here (this also runs at the end of
        # __init__), so pre-trained vectors must be copied in afterwards.
        self.user_embedding.reset_parameters()
        self.video_embedding.reset_parameters()

    def forward(self, node_type: str, batch: Optional[Tensor] = None) -> Tensor:
        r"""
        Returns the embedding for a given node type (`"user"` or `"video"`).
        If batch is None, returns all embeddings of that type; otherwise the
        rows selected by `batch` (local indices of that type).
        """
        if node_type == "user":
            emb = self.user_embedding.weight
        elif node_type == "video":
            emb = self.video_embedding.weight
        else:
            raise ValueError(f"Unsupported node type: {node_type}")

        if batch is None:
            return emb
        else:
            return emb.index_select(0, batch)

    def loader(self, node_type: str, **kwargs):
        r"""
        Returns a DataLoader over the nodes of a particular type, e.g. "user".
        We'll create random walks starting from each node of that type.
        The DataLoader yields `(pos_rw, neg_rw)` pairs of global node ids.

        Raises:
            ValueError: if `node_type` is not "user"/"video", or if it differs
                from the source type of the first metapath edge (walks always
                start by following that edge).
        """
        if node_type not in ("user", "video"):
            raise ValueError("'loader(node_type=...)' must be 'user' or 'video'.")
        if node_type != self._step_types[0]:
            raise ValueError(
                f"'loader(node_type={node_type!r})' does not match the metapath, "
                f"which starts at node type {self._step_types[0]!r}."
            )

        num_nodes = self._type_sizes[node_type]
        return DataLoader(range(num_nodes), collate_fn=self._sample, **kwargs)

    def _sample(self, batch_list: List[int]) -> Tuple[Tensor, Tensor]:
        r"""Given a list of node IDs of a certain type, sample positive and negative
        random walks. This is passed to the DataLoader as `collate_fn`."""
        batch = torch.tensor(batch_list, dtype=torch.long)
        return self._pos_sample(batch), self._neg_sample(batch)

    def _to_global(self, local_ids: Tensor, step: int) -> Tensor:
        r"""Translates local indices observed at walk position `step` into global
        ids: adds the offset of that step's node type and maps the local dummy
        sentinel to `self.dummy_idx`."""
        offset = self._type_offsets[self._step_types[step]]
        return (local_ids + offset).masked_fill(
            local_ids >= self._local_dummy, self.dummy_idx
        )

    def _pos_sample(self, batch: Tensor) -> Tensor:
        r"""
        Samples positive random walks along the metapath, starting from `batch`
        (local indices of the metapath's start type). Each start node is repeated
        `walks_per_node` times. Returns context windows of global ids.
        """
        batch = batch.repeat(self.walks_per_node)
        rws = [self._to_global(batch, 0)]  # global ids at step0, step1, ...

        # Follow the edges in `self.metapath` repeatedly for `walk_length` steps.
        # The walk itself stays in local id space because the adjacency
        # structures are indexed by local ids.
        for i in range(self.walk_length):
            edge_type = self.metapath[i % len(self.metapath)]
            batch = sample(
                self.rowptr_dict[edge_type],
                self.col_dict[edge_type],
                self.rowcount_dict[edge_type],
                batch,
                num_neighbors=1,
                dummy_idx=self._local_dummy,
            ).view(-1)
            rws.append(self._to_global(batch, i + 1))

        return self._build_context_windows(rws)

    def _neg_sample(self, batch: Tensor) -> Tensor:
        r"""
        Negative sampling: after the start node, pick a uniformly random node of
        the type expected at each step of the metapath, ignoring adjacency.
        Returns context windows of global ids.
        """
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
        rws = [self._to_global(batch, 0)]

        for i in range(self.walk_length):
            num_nodes = self._type_sizes[self._step_types[i + 1]]
            local_ids = torch.randint(0, num_nodes, (batch.size(0),), dtype=torch.long)
            rws.append(self._to_global(local_ids, i + 1))

        return self._build_context_windows(rws)

    def _build_context_windows(self, rws: List[Tensor]) -> Tensor:
        r"""
        Takes a list of Tensors, each shape [num_walks], representing the node
        IDs at each step in the walk. Then builds context windows of size
        `self.context_size`.
        """
        # Stack steps into shape [num_walks, walk_length+1]
        rw_stack = torch.stack(rws, dim=-1)
        total_steps = len(rws)  # walk_length + 1

        # We'll make (total_steps - context_size + 1) windows per walk.
        num_walks_per_rw = total_steps - self.context_size + 1
        out = []
        for j in range(num_walks_per_rw):
            # shape: [num_walks, context_size]
            window = rw_stack[:, j : j + self.context_size]
            out.append(window)
        # shape: [num_windows * num_walks, context_size]
        return torch.cat(out, dim=0)

    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:
        r"""Computes the negative-sampling loss for the given positive and negative
        random walks (context windows of global ids). The first column of each
        window is scored against every other column via a dot product."""
        # Positive samples:
        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()
        # Embed them:
        h_start = self._embed_nodes(start).view(pos_rw.size(0), 1, self.embedding_dim)
        h_rest = self._embed_nodes(rest.view(-1)).view(pos_rw.size(0), -1, self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()

        # Negative samples:
        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()
        h_start = self._embed_nodes(start).view(neg_rw.size(0), 1, self.embedding_dim)
        h_rest = self._embed_nodes(rest.view(-1)).view(neg_rw.size(0), -1, self.embedding_dim)

        out = (h_start * h_rest).sum(dim=-1).view(-1)
        neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()

        return pos_loss + neg_loss

    def _embed_nodes(self, nodes: Tensor) -> Tensor:
        r"""
        Looks up embeddings for a 1-D batch of *global* node ids:
        ids < user_count are users, ids in [user_count, dummy_idx) are videos
        (after subtracting user_count), and ids >= dummy_idx stay at the zero
        vector.
        """
        emb = torch.zeros(
            (nodes.size(0), self.embedding_dim),
            device=nodes.device,
            dtype=self.user_embedding.weight.dtype,
        )

        user_mask = nodes < self.user_count
        if user_mask.any():
            emb[user_mask] = self.user_embedding(nodes[user_mask])

        video_mask = (nodes >= self.user_count) & (nodes < self.dummy_idx)
        if video_mask.any():
            emb[video_mask] = self.video_embedding(nodes[video_mask] - self.user_count)

        return emb

    def test(
        self,
        train_z: Tensor,
        train_y: Tensor,
        test_z: Tensor,
        test_y: Tensor,
        solver: str = "lbfgs",
        *args,
        **kwargs
    ) -> float:
        r"""
        Evaluates latent space quality via a logistic regression downstream task:
        fits on (train_z, train_y) and returns the score on (test_z, test_y).
        (Kept for parity with PyG’s MetaPath2Vec.)
        """
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression(solver=solver, *args, **kwargs).fit(
            train_z.detach().cpu().numpy(),
            train_y.detach().cpu().numpy()
        )
        return clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy())

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(\n"
                f"  user_count={self.user_count}, video_count={self.video_count}, "
                f"embedding_dim={self.embedding_dim}\n)")


# Use the GPU when one is available, otherwise fall back to the CPU (same policy
# as run_metapath2vec_training_and_eval); a hard-coded 'cuda' fails on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_metapath = FrozenMetaPath2Vec(
    data.edge_index_dict,
    user_count=data['user']['num_nodes'], video_count=data['video']['num_nodes'],
    embedding_dim=w2v_dim,
    metapath=meta_path_uuv,
    walk_length=6, context_size=3,
    walks_per_node=2, num_negative_samples=5,
    sparse=True
).to(device)

# Initialize both tables from the pre-computed embeddings stored on `data`.
with torch.no_grad():
    model_metapath.video_embedding.weight.copy_(data['video'].x.to(device))
    model_metapath.user_embedding.weight.copy_(data['user'].x.to(device))
# Freeze the video table: only the user embeddings receive gradients.
model_metapath.video_embedding.weight.requires_grad = False

# CPU reference copy used to verify that the frozen table never changes.
original_video_w2v_cpu = data['video'].x.clone().cpu()

assert torch.allclose(
    model_metapath.video_embedding.weight.detach().cpu(),
    original_video_w2v_cpu,
    atol=1e-6
), "Mismatch even before training!"

# The optimizer only sees parameters that still require gradients.
optimizer = torch.optim.SparseAdam(
    filter(lambda p: p.requires_grad, model_metapath.parameters()), lr=0.01
)
loader = model_metapath.loader(node_type='user', batch_size=128, shuffle=True, num_workers=4)

def train_one_epoch():
    """Run one pass over `loader` and return the summed batch losses."""
    model_metapath.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:
        pos_rw, neg_rw = pos_rw.to(device), neg_rw.to(device)
        optimizer.zero_grad()
        loss = model_metapath.loss(pos_rw, neg_rw)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss

for epoch in range(1, 10):  # epochs 1..9
    epoch_loss = train_one_epoch()
    print(f"Epoch {epoch}, Loss: {epoch_loss:.4f}")

    # Evaluate on validation:
    model_metapath.eval()
    user_emb = model_metapath('user').cpu().detach().numpy()
    video_emb = model_metapath('video').cpu().detach().numpy()

    # NOTE(review): `evaluate_user_reco`, `val_data` and `test_data` are not defined
    # in the visible part of this notebook; they must exist in the kernel before
    # this cell runs — confirm where they come from.
    val_metrics = evaluate_user_reco(
        user_emb=user_emb,
        video_emb=video_emb,
        val_data=val_data,
        test_data=test_data,
        user2idx=user2idx,
        video2idx=video2idx,
        idx2video=idx2video,
        is_validation=True,   # i.e., we're evaluating on val_data
        top_k=100
    )
    print("Val metrics:", val_metrics)

    # Per-epoch check that the frozen video table is still the original W2V table.
    current_video_emb = model_metapath.video_embedding.weight.detach().cpu()
    if torch.allclose(current_video_emb, original_video_w2v_cpu, atol=1e-6):
        print("Video embeddings are unchanged after epoch", epoch)
    else:
        print("WARNING: Video embeddings changed after epoch", epoch)

final_video_emb = model_metapath.video_embedding.weight.detach().cpu()
assert torch.allclose(final_video_emb, original_video_w2v_cpu, atol=1e-6), (
    "Video embeddings diverged from the original W2V!"
)
print("Training finished; video embeddings remained fixed.")