In [2]:
import os
from typing import List, Tuple

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


class NextItemDataset(Dataset):
    def __init__(self, root_dir: str, dataset: str, split: str, min_len: int = 1):
        name_map = {
            "Dunnhumby": ("Dunnhumby_history.csv", "Dunnhumby_future.csv"),
            "Instacart": ("Instacart_history.csv", "Instacart_future.csv"),
            "TaFang": ("TaFang_history_NB.csv", "TaFang_future_NB.csv"),
            "ValuedShopper": ("VS_history_order.csv", "VS_future_order.csv"),
        }
        if dataset not in name_map:
            raise ValueError(f"Unknown dataset: {dataset}")
        if split not in {"train", "val"}:
            raise ValueError("split must be 'train' or 'val'")
        filename = name_map[dataset][0 if split == "train" else 1]
        path = os.path.join(root_dir, filename)

        dtypes = {"CUSTOMER_ID": "int64", "ORDER_NUMBER": "int64", "MATERIAL_NUMBER": "int64"}
        df = pd.read_csv(path, dtype=dtypes)
        df = df.sort_values(["CUSTOMER_ID", "ORDER_NUMBER"])  # stable within order
        grouped = df.groupby("CUSTOMER_ID")["MATERIAL_NUMBER"].apply(list)
        self.sequences: List[torch.Tensor] = [
            torch.tensor(seq, dtype=torch.long) for seq in grouped.tolist() if len(seq) >= min_len
        ]

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.sequences[index]


def collate_fn(batch: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    lengths = torch.tensor([len(x) for x in batch], dtype=torch.long)
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


In [3]:
DATA_DIR = "../external_repos/TIFUKNN/data"
DATASET_NAME = "Dunnhumby"  # change to one of: Dunnhumby, Instacart, TaFang, ValuedShopper

BATCH_SIZE = 64

train_ds = NextItemDataset(root_dir=DATA_DIR, dataset=DATASET_NAME, split="train")
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_fn)


In [4]:
val_ds = NextItemDataset(root_dir=DATA_DIR, dataset=DATASET_NAME, split="val")
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=collate_fn)

In [5]:
batch, lengths = next(iter(train_loader))
print(batch.shape, lengths.shape)


torch.Size([64, 73]) torch.Size([64])


In [None]:
import sys, os
sys.path.append(os.path.abspath(".."))  # make project root importable

from models.sasrec import SASRecModel
model = SASRecModel(num_items=5000, max_seq_len=50, hidden_size=64, num_heads=1, num_layers=1, dropout=0.1)




In [10]:
out = model(batch, lengths)
out

tensor([[ 0.2089, -0.2415,  0.2061,  ..., -0.8096, -0.6885, -0.3644],
        [-0.5335,  0.4761,  1.1201,  ...,  0.1036,  0.1481, -0.4338],
        [-0.2751,  1.2457,  0.2522,  ...,  0.3616, -0.0234,  0.2877],
        ...,
        [-0.5629, -0.5497, -1.3004,  ...,  0.4144,  0.1152,  0.8768],
        [-1.0638, -0.7551,  0.1086,  ..., -0.0692, -1.0627, -0.1640],
        [ 0.0054, -0.0542,  1.3416,  ...,  0.8523,  0.3845,  0.1576]],
       grad_fn=<AddmmBackward0>)

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from metrics import ndcg_at_k

def train(model, train_dataset, val_dataset, batch_size, num_epochs, metric_fn):
    device = next(model.parameters()).device
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in range(num_epochs):
        model.train()
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        train_losses = []
        train_preds = []
        train_targets = []
        for batch, lengths in progress_bar:
            batch = batch.to(device)
            lengths = lengths.to(device)
            bsz = batch.size(0)
            idx = lengths - 1
            targets = batch[torch.arange(bsz, device=device), idx]
            inputs = batch.clone()
            inputs[torch.arange(bsz, device=device), idx] = 0
            seq_lens = torch.clamp(lengths - 1, min=1)

            logits = model(inputs, seq_lens)
            loss = criterion(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            progress_bar.set_postfix(loss=loss.item())
            train_losses.append(loss.item())
            train_preds.append(logits.detach().cpu())
            train_targets.append(targets.detach().cpu())

        train_preds_tensor = torch.cat(train_preds, dim=0)
        train_targets_tensor = torch.cat(train_targets, dim=0)
        train_ndcg = metric_fn(train_preds_tensor, train_targets_tensor)

        # Validation with tqdm progress bar
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            val_progress_bar = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}")
            for batch, lengths in val_progress_bar:
                batch = batch.to(device)
                lengths = lengths.to(device)
                bsz = batch.size(0)
                idx = lengths - 1
                targets = batch[torch.arange(bsz, device=device), idx]
                inputs = batch.clone()
                inputs[torch.arange(bsz, device=device), idx] = 0
                seq_lens = torch.clamp(lengths - 1, min=1)

                logits = model(inputs, seq_lens)
                val_preds.append(logits.cpu())
                val_targets.append(targets.cpu())
        val_preds_tensor = torch.cat(val_preds, dim=0)
        val_targets_tensor = torch.cat(val_targets, dim=0)
        val_ndcg = metric_fn(val_preds_tensor, val_targets_tensor)

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {sum(train_losses)/len(train_losses):.4f} - Train NDCG@10: {train_ndcg:.4f} - Val NDCG@10: {val_ndcg:.4f}")

train_dataset = NextItemDataset(root_dir=DATA_DIR, dataset=DATASET_NAME, split="train", min_len=2)
val_dataset = NextItemDataset(root_dir=DATA_DIR, dataset=DATASET_NAME, split="val")

train(model, train_dataset, val_dataset, BATCH_SIZE, num_epochs=10, metric_fn=lambda preds, targets: ndcg_at_k(preds, targets, k=10))

Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:02<00:00,  9.06it/s, loss=7.33]


Epoch 1/10 - Loss: 7.2519 - Train NDCG@10: 0.0810 - Val NDCG@10: 0.0720


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:08<00:00,  8.29it/s, loss=6.54]


Epoch 2/10 - Loss: 6.8846 - Train NDCG@10: 0.0885 - Val NDCG@10: 0.0739


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [06:08<00:00,  1.54it/s, loss=6.47]


Epoch 3/10 - Loss: 6.5230 - Train NDCG@10: 0.0975 - Val NDCG@10: 0.0735


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [03:42<00:00,  2.55it/s, loss=6.01]


Epoch 4/10 - Loss: 6.1277 - Train NDCG@10: 0.1165 - Val NDCG@10: 0.0693


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [03:15<00:00,  2.91it/s, loss=5.92]


Epoch 5/10 - Loss: 5.7687 - Train NDCG@10: 0.1456 - Val NDCG@10: 0.0677


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:14<00:00,  7.63it/s, loss=5.44]


Epoch 6/10 - Loss: 5.4776 - Train NDCG@10: 0.1751 - Val NDCG@10: 0.0676


Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:09<00:00,  8.19it/s, loss=5.5]


Epoch 7/10 - Loss: 5.2433 - Train NDCG@10: 0.1976 - Val NDCG@10: 0.0628


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:20<00:00,  7.06it/s, loss=5.52]


Epoch 8/10 - Loss: 5.0580 - Train NDCG@10: 0.2185 - Val NDCG@10: 0.0602


Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:14<00:00,  7.62it/s, loss=5.77]


Epoch 9/10 - Loss: 4.9040 - Train NDCG@10: 0.2339 - Val NDCG@10: 0.0601


Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 567/567 [01:01<00:00,  9.15it/s, loss=5.34]


Epoch 10/10 - Loss: 4.7665 - Train NDCG@10: 0.2494 - Val NDCG@10: 0.0600
