## Loading and Configs

In [None]:
import gc
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# import ray
# from ray import tune
# from ray.tune.search.hyperopt import HyperOptSearch

# ray.shutdown()  # engage new ray session
# ray.init()

pd.set_option('display.max_rows', 100)

In [None]:
# Create the checkpoint directory; exist_ok keeps the cell idempotent on re-runs.
os.makedirs("lstm_starter_v1", exist_ok=True)

In [None]:
# Load the six MovieLens-20M CSV files from the Kaggle input directory.
# NOTE(review): `tags` is not referenced again in the visible notebook — confirm it is
# needed before keeping it in memory.
genome_scores = pd.read_csv("/kaggle/input/movielens-20m-dataset/genome_scores.csv")
genome_tags = pd.read_csv("/kaggle/input/movielens-20m-dataset/genome_tags.csv")
links = pd.read_csv("/kaggle/input/movielens-20m-dataset/link.csv")
movies = pd.read_csv("/kaggle/input/movielens-20m-dataset/movie.csv")
rating = pd.read_csv("/kaggle/input/movielens-20m-dataset/rating.csv")
tags = pd.read_csv("/kaggle/input/movielens-20m-dataset/tag.csv")

In [None]:
# Development switch: when True, CFG picks the smaller sample_size and fewer epochs.
IS_DEVELOP = True

In [None]:
class CFG:
    """Single place for every tunable of the run: device, seed, data and model sizes, optimisation."""

    # ---- base settings ----
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = 42
    hyper_trials = 0  # falsy -> the plain training loop runs instead of the tuning branch

    # ---- train settings ----
    seq_len = 10
    sample_size = 1_000_000 if not IS_DEVELOP else 100_000
    bs = 128
    emb_dim = 128
    hidden_dim = 256
    lstm_layers = 2
    dropout = 0.2
    epochs = 5 if not IS_DEVELOP else 3
    lr = 0.001
    wd = 0.01
    top_k = 10
    temperature = 0.85  # training logits are divided by this value
    candidate_pool_size = 100

    # ---- saving & inference ----
    patience = 4  # epochs without validation improvement before early stopping
    save_model_path = "/kaggle/working/lstm_starter_v1/state.pth"


# seed everything
def set_seed(seed=73):
    """Seed every random number generator the notebook draws from.

    Parameters:
      seed: int, value applied to Python's `random`, NumPy and PyTorch
            (CPU generator and all CUDA devices).
    """
    random.seed(seed)  # Python's stdlib RNG; without this, `random.*` calls stay unseeded
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
# Apply the configured seed once, before the data sampling and model construction cells below.
set_seed(CFG.seed)

In [None]:
# basic preprocess
# Inner-join the genome scores with genome_tags (on tagId), then links and movies (on movieId).
df = (
    genome_scores
    .merge(genome_tags, how='inner', on='tagId')
    .merge(links, how='inner', on='movieId')
    .merge(movies, how='inner', on='movieId')
)
unique_movie_ids = df['movieId'].unique()

# The four source frames are merged into `df`; drop them and reclaim the memory.
del genome_scores, genome_tags, links, movies
gc.collect()

print("PREPROCESS READY!!!")

In [None]:
# make a table: one row per movie, one column per tag, values are tag relevance
df['movieId'] = df['movieId'].astype(str)
movie_tag_df = df.pivot_table(index='movieId', columns='tag', values='relevance', fill_value=0)

# encode every rated movie as a contiguous integer id
rating['movieId'] = rating['movieId'].astype(str)
movie_encoder = LabelEncoder()
rating['movieId_enc'] = movie_encoder.fit_transform(rating['movieId'])

# only contain those that are present in movie data
movie_tag_df = movie_tag_df.loc[movie_tag_df.index.intersection(movie_encoder.classes_)]

num_movies = len(movie_encoder.classes_)
# print(num_movies) # 26744

# Align the tag tensor with the encoder's id space: row i holds the tags of the movie
# whose movieId_enc is i (all-zero when the movie has no tag data). Candidate indices
# retrieved from this tensor are later compared with encoded targets and used to index
# the movie embedding table, so both must share one index space.
tag_rows_enc = movie_encoder.transform(movie_tag_df.index.to_numpy())
tag_matrix = np.zeros((num_movies, movie_tag_df.shape[1]), dtype=np.float32)
tag_matrix[tag_rows_enc] = movie_tag_df.to_numpy(dtype=np.float32)

movie_tag_tensor = torch.from_numpy(tag_matrix).to(CFG.device)  # (num_movies, tag_dim)
# raw movie id -> row in movie_tag_tensor (== movieId_enc); only movies with tag data are keys
raw2idx = dict(zip(movie_tag_df.index, tag_rows_enc.tolist()))

In [None]:
def batch_retrieve_candidate_pool(seq_raw_batch, rating_seq_batch, movie_tag_tensor, raw2idx, top_n):
    """
    For each sequence, retrieve the movies whose tag vectors are most cosine-similar
    to a rating-weighted sum of the tag vectors of the movies in that sequence.

    Parameters:
      seq_raw_batch: list of lists of raw movie IDs, shape (B, seq_len)
      rating_seq_batch: torch.Tensor of shape (B, seq_len)
      movie_tag_tensor: torch.Tensor of shape (num_movies, tag_dim)
      raw2idx: dictionary mapping raw movie id to row index in movie_tag_tensor
      top_n: int, number of candidates to retrieve

    Returns:
      candidate_indices: torch.Tensor of shape (B, k) with k = min(top_n, num_movies);
        row indices into movie_tag_tensor, ordered by decreasing similarity.

    Raises:
      KeyError: if a raw movie id of the batch is not a key of raw2idx.
    """
    device = rating_seq_batch.device

    # Convert raw movie ids to row indices for each sample: (B, seq_len)
    indices = torch.tensor(
        [[raw2idx[movie_id] for movie_id in seq] for seq in seq_raw_batch],
        dtype=torch.long,
        device=device,
    )

    # Gather tag vectors and weight each by its rating: (B, seq_len, tag_dim)
    weighted_tags = movie_tag_tensor[indices] * rating_seq_batch.unsqueeze(2)

    # Sum over the sequence to form the user profile, then L2-normalise: (B, tag_dim)
    user_profiles = F.normalize(weighted_tags.sum(dim=1), p=2, dim=1, eps=1e-8)

    # L2-normalise every movie's tag vector: (num_movies, tag_dim)
    movie_tag_norm = F.normalize(movie_tag_tensor, p=2, dim=1, eps=1e-8)

    # Cosine similarities between each profile and every movie: (B, num_movies)
    sims = torch.matmul(user_profiles, movie_tag_norm.t())

    # torch.topk raises when k exceeds the dimension size, so cap it at num_movies.
    # NOTE(review): movies already present in the sequence are not excluded from the pool.
    k = min(top_n, sims.size(1))
    _, candidate_indices = torch.topk(sims, k=k, dim=1)

    return candidate_indices  # (B, k)


def custom_collate(batch):
    """Collate dataset items: stack the four tensor fields and keep the raw id
    sequences (fifth field) as a plain Python list.

    Returns a 5-tuple: (seq_movies, seq_ratings, tag_seq, target, seq_raw).
    """
    def field(position):
        # The `position`-th element of every item, in batch order.
        return [sample[position] for sample in batch]

    stacked = tuple(torch.stack(field(position)) for position in range(4))
    return (*stacked, field(4))

In [None]:
class MovieRatingTagDataset(Dataset):
    """Sliding-window next-movie dataset built from per-user rating histories.

    For every user the ratings are sorted by `timestamp`; each window of `seq_len`
    consecutive ratings becomes one sample whose target is the encoded id of the
    movie rated right after the window. Windows containing a movie that is absent
    from `movie_tag_features` are dropped.

    Parameters:
      rating_df: DataFrame with columns userId, movieId, movieId_enc, rating, timestamp
      movie_tag_features: DataFrame indexed by raw movie id, one column per tag
      seq_len: int, window length
    """

    def __init__(self, rating_df, movie_tag_features, seq_len=5):
        # Keep a single float32 copy of the tag table and let each sample store row
        # numbers only; storing the tag vectors themselves per sample multiplies
        # memory by seq_len * tag_dim for every window.
        self.tag_matrix = movie_tag_features.to_numpy(dtype=np.float32)
        self.tag_row = {raw_id: row for row, raw_id in enumerate(movie_tag_features.index)}

        self.samples = []
        for _, group in rating_df.groupby('userId'):
            group = group.sort_values('timestamp')
            movies_enc = group['movieId_enc'].tolist()
            ratings_list = group['rating'].tolist()
            movies_raw = group['movieId'].tolist()
            for start in range(len(movies_enc) - seq_len):
                stop = start + seq_len
                seq_raw = movies_raw[start:stop]
                # check if all ids are in movie_tag table
                if not all(raw_id in self.tag_row for raw_id in seq_raw):
                    continue
                tag_rows = [self.tag_row[raw_id] for raw_id in seq_raw]
                self.samples.append(
                    (movies_enc[start:stop], ratings_list[start:stop], tag_rows, movies_enc[stop], seq_raw)
                )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (movie ids, ratings, tag vectors, target id, raw id list) for sample `idx`."""
        seq_movies, seq_ratings, tag_rows, target, seq_raw = self.samples[idx]
        return (
            torch.tensor(seq_movies, dtype=torch.long),
            torch.tensor(seq_ratings, dtype=torch.float),
            # list indexing copies the rows, so the tensor does not alias tag_matrix
            torch.from_numpy(self.tag_matrix[tag_rows]),
            torch.tensor(target, dtype=torch.long),
            seq_raw  # raw movieId sequence (list of str)
        )

In [None]:
# model
class MovieRatingTagLSTM(nn.Module):
    """Sequence encoder that maps (movie, rating, tag) sequences to a unit-norm embedding.

    Per time step the movie embedding, a linear projection of the rating and an
    encoded tag vector are concatenated, passed through a bidirectional LSTM and
    self-attention, mean-pooled over time and projected to `proj_dim`.
    `get_movie_embeddings` projects the movie embedding table into the same space.

    Parameters:
      num_movies: int, size of the movie embedding table
      tag_dim: int, length of a movie's tag vector
      emb_dim: int, size of each per-step feature embedding
      hidden_dim: int, LSTM hidden size per direction (2 * hidden_dim must be divisible by 8 heads)
      lstm_layers: int, number of stacked LSTM layers
      dropout: float, dropout for the LSTM and the attention weights
      proj_dim: int, size of the shared comparison space
    """

    def __init__(self, num_movies, tag_dim, emb_dim=64, hidden_dim=128, lstm_layers=4, dropout=0.3, proj_dim=128):
        super().__init__()
        self.movie_emb = nn.Embedding(num_movies, emb_dim)
        self.rating_fc = nn.Linear(1, emb_dim)
        self.tag_encoder = nn.Sequential(
            nn.Linear(tag_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        input_dim = emb_dim * 3  # concatenated features: movie, rating, tag
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=lstm_layers, batch_first=True, dropout=dropout, bidirectional=True)
        # batch_first=True is required: the LSTM output is (B, L, H). With the default
        # (L, B, H) layout attention would run across the batch dimension and mix
        # different samples together. The flag does not change parameter names/shapes.
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2, num_heads=8, dropout=dropout, batch_first=True
        )
        self.fc1 = nn.Linear(hidden_dim * 2, 1024)
        self.fc2 = nn.Linear(1024, 512)
        # Project sequence representation into a space where we can compare via cosine similarity.
        self.seq_projection = nn.Linear(512, proj_dim)
        # Movie embeddings are projected into the same space and L2-normalised.
        self.movie_projection = nn.Linear(emb_dim, proj_dim)

    def forward_embedding(self, movie_seq, rating_seq, tag_seq):
        """Encode a batch of sequences.

        Parameters:
          movie_seq: LongTensor (B, L) of encoded movie ids
          rating_seq: FloatTensor (B, L) of ratings
          tag_seq: FloatTensor (B, L, tag_dim) of tag vectors

        Returns:
          FloatTensor (B, proj_dim), each row L2-normalised.
        """
        movie_vec = self.movie_emb(movie_seq)                  # (B, L, emb_dim)
        rating_vec = self.rating_fc(rating_seq.unsqueeze(-1))  # (B, L, emb_dim)
        tag_vec = self.tag_encoder(tag_seq)                    # (B, L, emb_dim)
        x = torch.cat([movie_vec, rating_vec, tag_vec], dim=-1)  # (B, L, emb_dim*3)
        lstm_out, _ = self.lstm(x)                             # (B, L, hidden_dim*2)
        attn_output, _ = self.multihead_attn(lstm_out, lstm_out, lstm_out)  # (B, L, hidden_dim*2)
        attn_output = torch.mean(attn_output, dim=1)           # pool over time -> (B, hidden_dim*2)
        x = F.relu(self.fc1(attn_output))
        x = F.relu(self.fc2(x))
        seq_embed = self.seq_projection(x)                     # (B, proj_dim)
        seq_embed = F.normalize(seq_embed, p=2, dim=1)
        return seq_embed

    def forward(self, movie_seq, rating_seq, tag_seq):
        """Alias of `forward_embedding` so the module can also be called directly."""
        return self.forward_embedding(movie_seq, rating_seq, tag_seq)

    def get_movie_embeddings(self):
        """Return the projected, L2-normalised embedding of every movie, (num_movies, proj_dim).

        Computed under torch.no_grad(): the result carries no gradient.
        """
        with torch.no_grad():
            movie_embeds = self.movie_emb.weight
            proj_movie_embeds = self.movie_projection(movie_embeds)  # (num_movies, proj_dim)
            proj_movie_embeds = F.normalize(proj_movie_embeds, p=2, dim=1)
        return proj_movie_embeds

In [None]:
# train functions
def train_one_epoch_vectorized(model, optimizer, criterion, train_loader, device, top_n, movie_tag_tensor, raw2idx):
    """
    Run one training epoch of candidate-pool classification.

    For each batch a pool of `top_n` candidate movies is retrieved per sample; only
    samples whose target lies inside their own pool contribute to the loss, which is
    a classification over pool positions with logits divided by CFG.temperature.

    Parameters:
      model: module exposing forward_embedding(...) and get_movie_embeddings()
      optimizer: optimizer stepping the model's parameters
      criterion: loss taking (logits, class indices)
      train_loader: iterable yielding (movie_seq, rating_seq, tag_seq, target, seq_raw)
      device: device the batch tensors are moved to
      top_n: int, candidate pool size
      movie_tag_tensor, raw2idx: forwarded to batch_retrieve_candidate_pool

    Returns:
      avg_loss: float, mean loss per contributing sample (0.0 if no sample contributed)
    """
    model.train()
    total_loss = 0.0
    total = 0  # number of samples that contributed to the loss

    # movie embeddings, computed once per training epoch
    # NOTE(review): get_movie_embeddings() runs under torch.no_grad(), so this is a
    # detached snapshot — no gradient flows into the candidate side and it is not
    # refreshed after optimizer steps within the epoch. Confirm this is intended.
    all_movie_embeds = model.get_movie_embeddings()

    for movie_seq, rating_seq, tag_seq, target, seq_raw in tqdm(train_loader, desc="Train Epoch"):
        movie_seq = movie_seq.to(device)
        rating_seq = rating_seq.to(device)
        tag_seq = tag_seq.to(device)
        target = target.to(device)

        # retrieve batch-wise candidate pool.
        candidate_pools = batch_retrieve_candidate_pool(seq_raw, rating_seq, movie_tag_tensor, raw2idx, top_n).to(device)

        # check if target is in candidate pool
        eq = (candidate_pools == target.unsqueeze(1))  # (B, top_n)
        target_mask = eq.any(dim=1)  # (B,)

        # nothing to learn from this batch: skip before paying for the forward pass
        if not target_mask.any():
            continue

        # get embeddings
        optimizer.zero_grad()
        seq_embed = model.forward_embedding(movie_seq, rating_seq, tag_seq)  # (B, proj_dim)
        candidate_embeds = all_movie_embeds[candidate_pools]  # (B, top_n, proj_dim)

        # similarity of each sequence to each of its candidates: (B, top_n)
        batch_logits = torch.bmm(seq_embed.unsqueeze(1), candidate_embeds.transpose(1, 2)).squeeze(1)

        # keep valid samples only; the class label is the target's position in the pool
        train_logits = batch_logits[target_mask] / CFG.temperature
        train_targets = eq[target_mask].float().argmax(dim=1)

        loss = criterion(train_logits, train_targets)
        # The graph is rebuilt every batch, so retain_graph is not needed.
        loss.backward()
        optimizer.step()

        # assumes criterion returns a per-batch mean (e.g. default CrossEntropyLoss) — TODO confirm
        n_valid = train_logits.size(0)
        total_loss += loss.item() * n_valid
        total += n_valid

    # The sum is weighted by sample count, so normalise by samples, not by batches.
    avg_loss = total_loss / total if total > 0 else 0.0
    print(f"Train Loss: {avg_loss:.4f}")
    return avg_loss


def validate_vectorized(model, criterion, valid_loader, device, top_n, movie_tag_tensor, raw2idx, top_k=5):
    """
    Evaluate the model on candidate-pool classification.

    Only samples whose target lies inside their retrieved candidate pool are scored;
    the rest are counted as skipped and reported through the printed coverage.

    Parameters:
      model: module exposing forward_embedding(...) and get_movie_embeddings()
      criterion: loss taking (logits, class indices)
      valid_loader: iterable yielding (movie_seq, rating_seq, tag_seq, target, seq_raw)
      device: device the batch tensors are moved to
      top_n: int, candidate pool size
      movie_tag_tensor, raw2idx: forwarded to batch_retrieve_candidate_pool
      top_k: int, a sample is a hit when its target is among the top_k scored candidates

    Returns:
      (avg_loss, topk_accuracy): both computed over scored samples only; 0.0 when none.
    """
    model.eval()
    total_loss = 0.0
    total = 0
    correct_topk = 0
    skipped = 0  # samples whose target is not in their candidate pool

    # movie embeddings, computed once per validation epoch
    all_movie_embeds = model.get_movie_embeddings()  # (num_movies, proj_dim)

    with torch.no_grad():
        for movie_seq, rating_seq, tag_seq, target, seq_raw in tqdm(valid_loader, desc="Valid Epoch"):
            movie_seq = movie_seq.to(device)
            rating_seq = rating_seq.to(device)
            tag_seq = tag_seq.to(device)
            target = target.to(device)

            # compute candidate pools
            candidate_pools = batch_retrieve_candidate_pool(seq_raw, rating_seq, movie_tag_tensor, raw2idx, top_n).to(device)  # (B, top_n)

            # Mask: check if target is in candidate pool
            eq = (candidate_pools == target.unsqueeze(1))  # (B, top_n)
            target_mask = eq.any(dim=1)  # (B,)

            # Count every unscored sample, including those in partially valid batches,
            # otherwise coverage is overstated.
            n_valid = int(target_mask.sum().item())
            skipped += target.size(0) - n_valid
            if n_valid == 0:
                continue

            # Embeddings
            seq_embed = model.forward_embedding(movie_seq, rating_seq, tag_seq)  # (B, proj_dim)
            candidate_embeds = all_movie_embeds[candidate_pools]  # (B, top_n, proj_dim)
            batch_logits = torch.bmm(seq_embed.unsqueeze(1), candidate_embeds.transpose(1, 2)).squeeze(1)  # (B, top_n)

            # filter
            valid_logits = batch_logits[target_mask]
            # index in the candidate pool where the true target is located
            valid_targets = eq[target_mask].float().argmax(dim=1)

            # loss
            loss = criterion(valid_logits, valid_targets)
            total_loss += loss.item() * n_valid

            # top-k predictions (k capped at the pool size so topk cannot raise)
            k = min(top_k, valid_logits.size(1))
            topk_indices = valid_logits.topk(k=k, dim=1, largest=True, sorted=True)[1]
            correct_topk += (topk_indices == valid_targets.unsqueeze(1)).any(dim=1).float().sum().item()

            total += n_valid

    avg_loss = total_loss / total if total > 0 else 0.0
    topk_accuracy = correct_topk / total if total > 0 else 0.0
    coverage = total / (total + skipped) if (total + skipped) > 0 else 0.0

    print(f"Valid Loss: {avg_loss:.4f} | Top-{top_k} Accuracy: {topk_accuracy * 100:.2f}% | Coverage: {coverage:.4f}")
    return avg_loss, topk_accuracy

In [None]:
# data preparation

# sample out part of the dataset
# rating_df = rating.sample(n=CFG.sample_size, random_state=CFG.seed)

# Take one contiguous window of CFG.sample_size ratings. The start offset comes from a
# dedicated RNG seeded with CFG.seed, so the same window is drawn on every run, and
# the upper bound is clamped so a dataset smaller than sample_size does not raise.
max_start = max(0, len(rating) - CFG.sample_size)
start_idx = random.Random(CFG.seed).randint(0, max_start)
rating_df = rating.iloc[start_idx:start_idx + CFG.sample_size]


# stratified split to ensure the actual ratio for train, valid, test is 0.8 : 0.1 : 0.1
# Users are binned by how many ratings they have (quartiles) and the split is
# stratified on that bin.
user_activity = rating_df['userId'].value_counts().rename('count').reset_index()
user_activity.columns = ['userId', 'count']
user_activity['activity_bin'] = pd.qcut(user_activity['count'], q=4, labels=False, duplicates="drop")
train_users, test_users = train_test_split(
    user_activity['userId'],
    test_size=0.2,
    stratify=user_activity['activity_bin'],
    random_state=CFG.seed
)

# split the held-out 20% of users in half: validation and test
test_user_bins = user_activity.set_index('userId').loc[test_users]['activity_bin']
valid_users, test_users = train_test_split(
    test_users,
    test_size=0.5,
    stratify=test_user_bins,
    random_state=CFG.seed
)

# indexing
train_rating = rating_df[rating_df['userId'].isin(train_users)].copy()
valid_rating = rating_df[rating_df['userId'].isin(valid_users)].copy()
test_rating = rating_df[rating_df['userId'].isin(test_users)].copy()

# datasets
train_dataset = MovieRatingTagDataset(train_rating, movie_tag_df, seq_len=CFG.seq_len)
valid_dataset = MovieRatingTagDataset(valid_rating, movie_tag_df, seq_len=CFG.seq_len)
test_dataset = MovieRatingTagDataset(test_rating, movie_tag_df, seq_len=CFG.seq_len)

# loaders
train_loader = DataLoader(train_dataset, batch_size=CFG.bs, shuffle=True, collate_fn=custom_collate)
valid_loader = DataLoader(valid_dataset, batch_size=CFG.bs, shuffle=False, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=CFG.bs, shuffle=False, collate_fn=custom_collate)

print("DATALOADER READY!!!")

In [None]:
# search_space = {
#     "temperature": tune.uniform(0.5, 1.0),
#     "lr": tune.loguniform(1e-5, 1e-3),
#     "emb_dim": tune.choice([64, 128]),
#     "hidden_dim": tune.choice([128, 256]),
#     "lstm_layers": tune.choice([2, 4]),  # If you increase seq_len later, you might try more layers.
#     "dropout": tune.uniform(0.2, 0.5),
#     "wd": tune.loguniform(1e-5, 1e-3),
#     "epochs": tune.choice([10, 15, 20]),
#     "candidate_pool_size": tune.choice([300, 500, 1000]),
#     "label_smoothing": tune.uniform(0, 0.2),
# }

# movie_tag_tensor_ref = ray.put(movie_tag_tensor)
# raw2idx_ref = ray.put(raw2idx)
# tuned_train_model = tune.with_parameters(
#     train_model,  # your training function defined earlier
#     movie_tag_tensor=movie_tag_tensor_ref,
#     raw2idx=raw2idx_ref,
#     train_loader=train_loader,  # assuming these are defined globally or similarly passed
#     valid_loader=valid_loader
# )

# def train_model(config):
#     CFG.temperature = config["temperature"]
    
#     # Create the model with hyperparameters from config.
#     model = MovieRatingTagLSTM(
#         num_movies=26744,
#         tag_dim=movie_tag_tensor.shape[1],
#         emb_dim=config["emb_dim"],
#         hidden_dim=config["hidden_dim"],
#         lstm_layers=config["lstm_layers"],
#         dropout=config["dropout"],
#         proj_dim=128  # You can also tune this if needed.
#     ).to(CFG.device)
    
#     optimizer = optim.Adam(model.parameters(), lr=config["lr"], weight_decay=config["wd"])
#     criterion = nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"])
    
#     # For simplicity, we run for a fixed number of epochs.
#     for epoch in range(config["epochs"]):
#         loss = train_one_epoch_vectorized(
#             model, optimizer, criterion, train_loader,
#             device, config["candidate_pool_size"], movie_tag_tensor, raw2idx, config
#         )
#         # Here you could add a validation loop and report validation metrics.
#         # For now, we report training loss.
#         tune.report(loss=loss)

In [None]:
# training block
tag_dim = movie_tag_df.shape[1]

model = MovieRatingTagLSTM(
    num_movies=num_movies, 
    tag_dim=tag_dim, emb_dim=CFG.emb_dim, 
    hidden_dim=CFG.hidden_dim, 
    lstm_layers=CFG.lstm_layers,
    dropout=CFG.dropout,
).to(CFG.device)
print("MODEL READY!!!")

# training utils
criterion = nn.CrossEntropyLoss(label_smoothing = 0.1)
optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
scheduler = CosineAnnealingLR(optimizer, T_max=CFG.epochs, eta_min=1e-5)


def _snapshot_state_dict(module):
    """Return a detached copy of `module.state_dict()`.

    state_dict() hands out references to the live parameters, so keeping it as the
    "best" state without cloning would silently track every later optimizer step.
    """
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


# record training results
train_losses, valid_losses, valid_accs = [], [], []
best_acc = 0
patience_count = 0
best_model_state = _snapshot_state_dict(model)

if not CFG.hyper_trials:
    # make sure the checkpoint directory exists before the first save
    os.makedirs(os.path.dirname(CFG.save_model_path), exist_ok=True)

    for epoch in range(CFG.epochs):
        tloss = train_one_epoch_vectorized(model, optimizer, criterion, train_loader, CFG.device, CFG.candidate_pool_size, movie_tag_tensor, raw2idx)
        vloss, vacc = validate_vectorized(model, criterion, valid_loader, CFG.device, CFG.candidate_pool_size, movie_tag_tensor, raw2idx, CFG.top_k)
        train_losses.append(tloss)
        valid_losses.append(vloss)
        valid_accs.append(vacc)

        # The first epoch always becomes the baseline so the untrained initial
        # weights are never restored when accuracy stays at 0.
        if epoch == 0 or vacc > best_acc:
            best_acc = vacc
            best_model_state = _snapshot_state_dict(model)
            patience_count = 0
            torch.save(best_model_state, CFG.save_model_path)
            print("Best model state updated.")
        else:
            patience_count += 1
            if patience_count >= CFG.patience:
                print("Early stopping triggered")
                break
        scheduler.step()

    # restore the best epoch's weights and persist them
    model.load_state_dict(best_model_state)
    torch.save(best_model_state, CFG.save_model_path)

else:
    # NOTE(review): `tune`, `tuned_train_model` and `search_space` are only defined in
    # code that is currently commented out, so this branch raises NameError as-is.
    analysis = tune.run(
        tuned_train_model,
        config=search_space,
        num_samples=CFG.hyper_trials,
        resources_per_trial={"cpu": 4, "gpu": 1}, 
        metric="loss",
        mode="min",
        progress_reporter=tune.CLIReporter(metric_columns=["loss", "accuracy"])
    )

    print("Best hyperparameters found were: ", analysis.best_config)
    print("Best result:", analysis.best_result)

In [None]:
# visualize valid accuracy
# One point per completed epoch; skipped when the hyper-parameter search branch ran instead.
if not CFG.hyper_trials:
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(valid_accs, marker='o', linestyle='-', color='b', label='Eval Accuracy')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('Evaluation Accuracy over Epochs')
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()

In [None]:
# visualize train loss
# One point per completed epoch; skipped when the hyper-parameter search branch ran instead.
if not CFG.hyper_trials:
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(train_losses, marker='o', linestyle='-', color='b', label='Train Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Train Loss')
    ax.set_title('Train Loss over Epochs')
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()

In [None]:
# visualize valid loss
# One point per completed epoch; skipped when the hyper-parameter search branch ran instead.
if not CFG.hyper_trials:
    plt.figure(figsize=(6, 4))
    # the legend entry must name the plotted series (validation loss, not train loss)
    plt.plot(valid_losses, marker='o', linestyle='-', color='b', label='Valid Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Valid Loss')
    plt.title('Valid Loss over Epochs')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

### Inference

In [None]:
def compute_top_k_accuracy(predictions, targets, k=5):
    """Fraction of samples whose target appears among their first `k` predictions.

    Parameters:
      predictions: iterable of per-sample ranked predictions (array-like, best first)
      targets: iterable of per-sample targets (scalar or single-element array-like)
      k: int, number of leading predictions to consider

    Returns:
      float in [0, 1]; 0.0 when there are no samples.
    """
    hits = 0
    total = 0

    for pred, target in zip(predictions, targets):
        # np.asarray accepts arrays, lists and plain scalars alike
        pred = np.asarray(pred).ravel()
        target = np.asarray(target).ravel()

        # any() counts at most one hit per sample, even if a prediction repeats
        hits += int((pred[:k] == target).any())
        total += 1

    return hits / total if total > 0 else 0.0


def infer_with_test_loader(model, test_loader, device, top_n, movie_tag_tensor, raw2idx, return_top_k=5):
    """
    Score the test loader and compute top-k accuracy over candidate-pool positions.

    Only samples whose target lies inside their retrieved candidate pool are scored.

    Parameters:
      model: module exposing forward_embedding(...) and get_movie_embeddings()
      test_loader: iterable yielding (movie_seq, rating_seq, tag_seq, target, seq_raw)
      device: device the batch tensors are moved to
      top_n: int, candidate pool size
      movie_tag_tensor, raw2idx: forwarded to batch_retrieve_candidate_pool
      return_top_k: int, number of top-scored candidates kept per sample

    Returns:
      all_predictions: list with one array per scored batch, shape (n_valid, k); values
        are POSITIONS inside each sample's candidate pool (not movie ids)
      all_targets: list with one array per scored batch, shape (n_valid,); the pool
        position of each sample's true target
      accuracy: float, top-k accuracy over scored samples (0.0 when none were scored)
    """
    model.eval()

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        # The weights do not change during inference: project the movie table once.
        all_movie_embeds = model.get_movie_embeddings()  # (num_movies, proj_dim)

        for movie_seq, rating_seq, tag_seq, target, seq_raw in tqdm(test_loader, desc="Inference"):
            # Move inputs to device
            movie_seq = movie_seq.to(device)
            rating_seq = rating_seq.to(device)
            tag_seq = tag_seq.to(device)
            target = target.to(device)

            # Compute candidate pools for the batch
            candidate_pools = batch_retrieve_candidate_pool(seq_raw, rating_seq, movie_tag_tensor, raw2idx, top_n).to(device)

            # Mask: check if target is in candidate pool
            eq = (candidate_pools == target.unsqueeze(1))  # (B, top_n)
            target_mask = eq.any(dim=1)  # (B,)
            if not target_mask.any():
                continue  # no sample of this batch can be scored

            # Embeddings
            seq_embed = model.forward_embedding(movie_seq, rating_seq, tag_seq)  # (B, proj_dim)
            candidate_embeds = all_movie_embeds[candidate_pools]  # (B, top_n, proj_dim)
            batch_logits = torch.bmm(seq_embed.unsqueeze(1), candidate_embeds.transpose(1, 2)).squeeze(1)  # (B, top_n)

            # Filter valid samples
            valid_logits = batch_logits[target_mask]
            valid_targets = eq[target_mask].float().argmax(dim=1)

            # Top-k pool positions by score (k capped at the pool size so topk cannot raise)
            k = min(return_top_k, valid_logits.size(1))
            _, topk_positions = valid_logits.topk(k, dim=1, largest=True, sorted=True)

            # Store predictions and targets
            all_predictions.append(topk_positions.cpu().numpy())
            all_targets.append(valid_targets.cpu().numpy())

    # np.concatenate raises on an empty list, so handle "nothing scored" explicitly
    if not all_predictions:
        return all_predictions, all_targets, 0.0

    # Compute top-K accuracy
    accuracy = compute_top_k_accuracy(np.concatenate(all_predictions), np.concatenate(all_targets), k=return_top_k)

    return all_predictions, all_targets, accuracy

In [None]:
# Rebuild the architecture with the same CFG hyper-parameters, then restore saved weights.
tag_dim = movie_tag_df.shape[1]
model = MovieRatingTagLSTM(
    num_movies=num_movies,
    tag_dim=tag_dim,
    emb_dim=CFG.emb_dim,
    hidden_dim=CFG.hidden_dim,
    lstm_layers=CFG.lstm_layers,
    dropout=CFG.dropout,
)
model = model.to(CFG.device)

# NOTE(review): the weights are read from a Kaggle input dataset, not from
# CFG.save_model_path — confirm this is the checkpoint that should be evaluated.
model.load_state_dict(
    torch.load("/kaggle/input/movie-rec-lstm-v1/lstm_starter_v1/state.pth", map_location=CFG.device)
)

In [None]:
if not CFG.hyper_trials:
    # Use the same candidate pool size as training/validation rather than a second
    # hard-coded value that could drift from the config.
    top_n = CFG.candidate_pool_size
    return_top_k = 1

    # Run inference with the test loader and compute accuracy.
    # The accuracy only covers samples whose target is inside their candidate pool.
    predictions, targets, accuracy = infer_with_test_loader(
        model, test_loader, CFG.device, top_n, movie_tag_tensor, raw2idx, return_top_k
    )

    print(f"Top-{return_top_k} Accuracy: {accuracy * 100:.2f}%")

**TODO:** Build an inference pipeline that accepts raw inputs (raw movie IDs and their ratings) directly, rather than a prepared test loader.