In [2]:
import os
import gzip
import json
import numpy as np
import pandas as pd
from datetime import datetime
from sklearn.model_selection import train_test_split
from scipy.sparse import coo_matrix

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

try:
    import implicit
except ImportError:
    implicit = None

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# ============================================================
# SASRec: Self-Attentive Sequential Recommendation
# Paper: https://arxiv.org/pdf/1808.09781
# ============================================================

import math
import random
import pickle
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ============================================================
# CONFIGURATION
# ============================================================

DATA_DIR = "/content/drive/MyDrive/data/"
TRAIN_PKL = DATA_DIR + "train_data.pkl"
VAL_PKL   = DATA_DIR + "val_data.pkl"
TEST_PKL  = DATA_DIR + "test_data.pkl"

# Model hyperparameters
MAX_LEN = 50
HIDDEN_UNITS = 50
HEADS = 1
LAYERS = 2
DROPOUT = 0.2

# Training settings
BATCH_SIZE = 128
NEG_SAMPLES = 100
EPOCHS = 5
LR = 0.001
TOPK = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Device:", device)

# Set seed for reproducibility
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


# ============================================================
# DATA LOADING
# ============================================================

def load_pickle_data(pkl_path):
    """Load pickle file and convert to (sequences, targets, lengths) format."""
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"Pickle file not found: {pkl_path}")

    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    sequences = []
    targets = []
    lengths = []

    for user_id, user_data in data.items():
        if 'sequence' in user_data and len(user_data['sequence']) > 0:
            seq = [item['item_id'] for item in user_data['sequence']]

            # Handle validation/test data with separate target field
            if 'target' in user_data:
                target = user_data['target']['item_id'] if isinstance(user_data['target'], dict) else user_data['target']
                input_seq = [item for item in seq if item != target]
                sequences.append(input_seq)
                targets.append(target)
                lengths.append(len(input_seq))
            else:
                # Training data: target is last item in sequence
                if len(seq) > 1:
                    sequences.append(seq[:-1])
                    targets.append(seq[-1])
                    lengths.append(len(seq) - 1)
                elif len(seq) == 1:
                    sequences.append([])
                    targets.append(seq[0])
                    lengths.append(0)

    return sequences, targets, lengths


# ============================================================
# DATASET
# ============================================================

class PickleDataset(Dataset):
    def __init__(self, pkl_path, max_len):
        self.seqs, self.targets, self.lengths = load_pickle_data(pkl_path)
        self.max_len = max_len

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq = self.seqs[idx]
        tgt = self.targets[idx]

        # Truncate to max_len and left-pad with zeros
        seq = seq[-self.max_len:]
        seq = [0] * (self.max_len - len(seq)) + seq

        return torch.tensor(seq, dtype=torch.long), torch.tensor(tgt)


# ============================================================
# INITIALIZE DATA LOADERS
# ============================================================

# Calculate vocab size from all splits to ensure coverage
tr_seqs, tr_targets, _ = load_pickle_data(TRAIN_PKL)
val_seqs, val_targets, _ = load_pickle_data(VAL_PKL)
test_seqs, test_targets, _ = load_pickle_data(TEST_PKL)

all_seqs = tr_seqs + val_seqs + test_seqs
all_targets = tr_targets + val_targets + test_targets

max_seq_item = max([max(seq) for seq in all_seqs if len(seq) > 0]) if all_seqs else 0
max_target = max(all_targets) if all_targets else 0
VOCAB_SIZE = max(max_seq_item, max_target) + 1
print("[INFO] Vocab size:", VOCAB_SIZE)

train_loader = DataLoader(PickleDataset(TRAIN_PKL, MAX_LEN), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(PickleDataset(VAL_PKL, MAX_LEN), batch_size=1)
test_loader = DataLoader(PickleDataset(TEST_PKL, MAX_LEN), batch_size=1)

print("[INFO] Train samples:", len(train_loader.dataset))
print("[INFO] Val samples:", len(val_loader.dataset))
print("[INFO] Test samples:", len(test_loader.dataset))


# ============================================================
# SASRec MODEL
# ============================================================

class SASRec(nn.Module):
    def __init__(self, item_num):
        super().__init__()
        self.item_emb = nn.Embedding(item_num, HIDDEN_UNITS, padding_idx=0)
        self.pos_emb = nn.Embedding(MAX_LEN, HIDDEN_UNITS)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=HIDDEN_UNITS,
            nhead=HEADS,
            dim_feedforward=HIDDEN_UNITS * 4,
            dropout=DROPOUT,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, LAYERS)

        # Initialize item embeddings
        nn.init.normal_(self.item_emb.weight, mean=0.0, std=HIDDEN_UNITS ** -0.5)

    def forward(self, seq):
        B, T = seq.shape
        seq = torch.clamp(seq, 0, self.item_emb.num_embeddings - 1)
        pos_ids = torch.arange(T, device=seq.device).unsqueeze(0)
        x = self.item_emb(seq) + self.pos_emb(pos_ids)
        mask = (seq == 0)
        h = self.encoder(x, src_key_padding_mask=mask)
        return h[:, -1]


model = SASRec(VOCAB_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)


# ============================================================
# BPR LOSS
# ============================================================

def bpr_loss(h, pos, neg):
    """Bayesian Personalized Ranking loss."""
    pos_emb = model.item_emb(pos)
    neg_emb = model.item_emb(neg)
    pos_score = torch.sum(h * pos_emb, dim=-1)
    neg_score = torch.sum(h.unsqueeze(1) * neg_emb, dim=-1)
    return -torch.log(torch.sigmoid(pos_score.unsqueeze(1) - neg_score) + 1e-8).mean()


# ============================================================
# EVALUATION
# ============================================================

def ndcg(rank):
    """Normalized Discounted Cumulative Gain."""
    return 1 / math.log2(rank + 1)


def evaluate(model, loader):
    """Evaluate model on validation/test set."""
    model.eval()
    hits = 0
    mrr = 0
    ndcg_sum = 0
    count = 0

    with torch.no_grad():
        for seq, tgt in loader:
            seq = seq.to(device)
            tgt = tgt.item()

            # Skip if target is out of vocabulary range
            if tgt >= VOCAB_SIZE or tgt < 1:
                count += 1
                continue

            seq_list = seq.squeeze(0).tolist()
            user_items = set([item for item in seq_list if item != 0 and item < VOCAB_SIZE])
            if tgt in user_items:
                user_items.discard(tgt)

            # Sample 100 negatives not in user's history
            negatives = set()
            attempts = 0
            while len(negatives) < 100 and attempts < 10000:
                neg = random.randint(1, VOCAB_SIZE - 1)
                if neg != tgt and neg not in user_items:
                    negatives.add(neg)
                attempts += 1

            # Rank target among candidates (all should be valid now)
            candidates = [tgt] + list(negatives)
            random.shuffle(candidates)

            h = model(seq)
            cand_tensor = torch.tensor(candidates, device=device, dtype=torch.long)

            # Validate all candidates are in range before embedding
            if cand_tensor.max() >= VOCAB_SIZE or cand_tensor.min() < 0:
                count += 1
                continue

            cand_emb = model.item_emb(cand_tensor)
            scores = torch.matmul(h, cand_emb.T).squeeze(0)
            rankings = torch.argsort(scores, descending=True)
            ranked_items = [candidates[i] for i in rankings.tolist()]
            rank = ranked_items.index(tgt) + 1

            if rank <= TOPK:
                hits += 1
                mrr += 1 / rank
                ndcg_sum += ndcg(rank)
            count += 1

    hr = hits / count if count > 0 else 0.0
    mrr_score = mrr / count if count > 0 else 0.0
    ndcg_score = ndcg_sum / count if count > 0 else 0.0
    return hr, mrr_score, ndcg_score


# ============================================================
# TRAINING
# ============================================================

print("\n[INFO] Training started\n")

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    steps = 0

    for seq, tgt in train_loader:
        seq = seq.to(device)
        tgt = tgt.to(device)
        B = seq.shape[0]

        # Sample random negatives
        neg = torch.randint(1, VOCAB_SIZE, (B, NEG_SAMPLES), device=device)

        # Forward pass
        h = model(seq)
        loss = bpr_loss(h, tgt, neg)

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        steps += 1

    print(f"[TRAIN] Epoch {epoch+1}  Avg Loss = {total_loss/steps:.4f}")

    # Validation evaluation
    hr, mrr_score, ndcg_score = evaluate(model, val_loader)
    print(f"[VAL]   Epoch {epoch+1}  HR@10={hr:.4f}  MRR@10={mrr_score:.4f}  NDCG@10={ndcg_score:.4f}\n")


# ============================================================
# FINAL TEST RESULTS
# ============================================================

print("\n[INFO] Final Test Results")
hr, mrr_score, ndcg_score = evaluate(model, test_loader)
print(f"[TEST]  HR@10={hr:.4f}  MRR@10={mrr_score:.4f}  NDCG@10={ndcg_score:.4f}")

print("\n========== DONE ==========")


[INFO] Device: cuda
[INFO] Vocab size: 389162
[INFO] Train samples: 100000
[INFO] Val samples: 100000
[INFO] Test samples: 100000





[INFO] Training started

[TRAIN] Epoch 1  Avg Loss = 0.6690
[VAL]   Epoch 1  HR@10=0.3872  MRR@10=0.2104  NDCG@10=0.2520

[TRAIN] Epoch 2  Avg Loss = 0.2121
[VAL]   Epoch 2  HR@10=0.4052  MRR@10=0.2377  NDCG@10=0.2773

[TRAIN] Epoch 3  Avg Loss = 0.1155
[VAL]   Epoch 3  HR@10=0.4064  MRR@10=0.2392  NDCG@10=0.2787

[TRAIN] Epoch 4  Avg Loss = 0.0623
[VAL]   Epoch 4  HR@10=0.4044  MRR@10=0.2311  NDCG@10=0.2719

[TRAIN] Epoch 5  Avg Loss = 0.0281
[VAL]   Epoch 5  HR@10=0.3970  MRR@10=0.2197  NDCG@10=0.2614


[INFO] Final Test Results
[TEST]  HR@10=0.3803  MRR@10=0.2030  NDCG@10=0.2446

