# Session-Based Recommendation with Transformers4Rec

This notebook implements a session-based recommender system using NVIDIA's [Transformers4Rec](https://github.com/NVIDIA-Merlin/Transformers4Rec) library.

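We will:
1.  **Setup**: Import the required libraries and select the compute device.
2.  **Preprocess**: Use NVTabular to turn the enriched rental interactions into per-session sequences.
3.  **Model**: Define a Transformer-based model (XLNet).
4.  **Train**: Train the model to predict the next item in a session.
5.  **Evaluate**: Measure ranking quality (Recall@k and NDCG@k) on held-out sessions.
6.  **Predict**: Generate top-6 recommendations for the test sessions and write a submission file.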

## 1. Setup

In [1]:
import os
import glob
import pandas as pd
import numpy as np
import torch

import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags

import transformers4rec.torch as tr

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


## 2. Preprocessing with NVTabular

We need to transform our raw interaction data into a format suitable for sequential models.
This involves:
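1.  Loading the enriched interaction data (hits and visits already merged and enriched by the feature-engineering notebook).
2.  Encoding categorical features and normalizing the continuous popularity counters.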
3.  Using NVTabular to group interactions by session (`visit_id`) and create sequences.
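In plain pandas, sorting by time and grouping by session looks roughly like the sketch below. The toy data is hypothetical; NVTabular's `Groupby` op performs the equivalent transformation at scale and, with `name_sep="-"`, names the outputs `item_id-list`, `date_time-first`, and so on, which the sketch imitates.

```python
import pandas as pd

# Toy interactions: one row per product view (hypothetical data)
interactions = pd.DataFrame({
    "visit_id": [2, 1, 1, 2, 1],
    "item_id": ["c", "b", "a", "d", "c"],
    "date_time": pd.to_datetime([
        "2022-01-20 10:05", "2022-01-20 09:02", "2022-01-20 09:00",
        "2022-01-20 10:01", "2022-01-20 09:07",
    ]),
})

# Sort chronologically so list order within each session follows time,
# then collapse each session into one row holding its item sequence
sessions = (
    interactions.sort_values(["visit_id", "date_time"])
    .groupby("visit_id")
    .agg(**{
        "item_id-list": ("item_id", list),
        "date_time-first": ("date_time", "first"),
    })
    .reset_index()
)
print(sessions)
```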

In [None]:
# 1. Load Enriched Data
# The data was prepared in the feature-engineering notebook.
# This file contains: visit_id, item_id, context features, item metadata, and counter features.

print("Loading enriched interactions...")
interactions = pd.read_parquet('data/enriched_interactions.parquet')

print(f"Loaded {len(interactions)} interactions.")
print("Columns:", interactions.columns.tolist())
interactions.head()

Loading raw data...
Processed interactions: 408562


|        | visit_id        | item_id                             | date_time           | traffic_source | region_city |
|--------|-----------------|-------------------------------------|---------------------|----------------|-------------|
| 122887 | 463311640199432 | avtokreslo-chicco-synthesis-xt-plus | 2022-01-20 03:29:26 | ad             | Moscow      |
| 122896 | 640428033179772 | manezh-krovat-capella-best-friends  | 2022-01-20 03:40:42 | ad             | Moscow      |
| 122902 | 714740689010850 | piratskiy-korabl-elc                | 2022-01-20 03:45:26 | direct         | Moscow      |
| 122903 | 714740689010850 | piratskiy-korabl-elc                | 2022-01-20 03:45:26 | direct         | Moscow      |
| 122906 | 714740689010850 | piratskiy-korabl-elc                | 2022-01-20 03:45:26 | direct         | Moscow      |

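The `Categorify` ops in the next cell replace each categorical value with an integer index. The sketch below illustrates the idea with hypothetical helpers `build_vocab` and `encode`, assuming the convention this notebook relies on later when it rebuilds the maps by hand: index 0 is reserved for padding and unseen values, and known categories are numbered from 1. It numbers values in order of first appearance; NVTabular's own ordering may differ.

```python
import pandas as pd

def build_vocab(train_values):
    """Map each distinct training value to an integer starting at 1 (0 stays reserved)."""
    uniques = pd.Series(train_values).dropna().unique()
    return {value: i + 1 for i, value in enumerate(uniques)}

def encode(values, vocab):
    """Encode values with a training vocabulary; unseen values fall back to 0."""
    return pd.Series(values).map(vocab).fillna(0).astype("int64")

vocab = build_vocab(["ad", "direct", "ad", "organic"])
encoded = encode(["direct", "organic", "social"], vocab)
print(vocab)
print(encoded.tolist())
```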

In [None]:
# 2. Define NVTabular Workflow
# We will create a workflow to:
# - Categorify categorical features
# - Normalize continuous features
# - Group by visit_id to create sequences

# Define Feature Columns
# Categorical Features
item_id = ['item_id'] >> Categorify(dtype="int64") >> TagAsItemID()
traffic_source = ['traffic_source'] >> Categorify(dtype="int64")
region_city = ['region_city'] >> Categorify(dtype="int64")
brand = ['brand'] >> Categorify(dtype="int64")
main_category = ['main_category'] >> Categorify(dtype="int64")
price_bucket = ['price_bucket'] >> Categorify(dtype="int64")
hour = ['hour'] >> Categorify(dtype="int64")
day_of_week = ['day_of_week'] >> Categorify(dtype="int64")
is_weekend = ['is_weekend'] >> Categorify(dtype="int64")

# Continuous Features (Counters)
# Apply LogOp (log(1 + x)) and then Normalize to handle the skewed distributions typical of popularity counts
item_popularity = ['item_popularity'] >> LogOp() >> Normalize()
category_popularity = ['category_popularity'] >> LogOp() >> Normalize()

session_id = ['visit_id'] >> Categorify(dtype="int64") >> TagAsUserID()
time_col = ['date_time']

# Grouping to create sequences
# Group by 'visit_id', sort by 'date_time', and aggregate the other columns into lists.
# Context features (city, source, time) are constant within a session in our logic,
# but T4Rec expects sequence inputs, so we aggregate them into lists as well.
groupby_features = (
    session_id + item_id + traffic_source + region_city + 
    brand + main_category + price_bucket + 
    hour + day_of_week + is_weekend +
    item_popularity + category_popularity +
    time_col
) >> Groupby(
    groupby_cols=['visit_id'],
    sort_cols=['date_time'],
    aggs={
        'item_id': 'list',
        'traffic_source': 'list',
        'region_city': 'list',
        'brand': 'list',
        'main_category': 'list',
        'price_bucket': 'list',
        'hour': 'list',
        'day_of_week': 'list',
        'is_weekend': 'list',
        'item_popularity': 'list',
        'category_popularity': 'list',
        'date_time': 'first'
    },
    name_sep="-"
)

workflow = nvt.Workflow(groupby_features)

# Create a dataset from the pandas dataframe
interactions = interactions.reset_index(drop=True)
dataset = nvt.Dataset(interactions)

# Fit and Transform
print("Fitting and transforming with NVTabular...")
workflow.fit(dataset)
workflow.transform(dataset).to_parquet("data/processed_sessions")

print("NVTabular processing complete.")



Fitting and transforming with NVTabular...




NVTabular processing complete.
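`LogOp` computes log(1 + x), and `Normalize` standardizes each column with the mean and standard deviation learned during `workflow.fit`. The sketch below reproduces that pair of transforms in pandas under those assumptions (hypothetical function names; pandas' sample standard deviation is used, as in the manual test-set code later in this notebook). The statistics are fitted once and then reused for new data.

```python
import numpy as np
import pandas as pd

def fit_log_normalize(train_counts):
    """Fit on training data: take log(1 + x), then record mean and std of the logged values."""
    logged = np.log1p(pd.Series(train_counts, dtype="float64"))
    return logged.mean(), logged.std()

def apply_log_normalize(counts, mean, std):
    """Apply the same transform to any data using the training statistics."""
    return (np.log1p(pd.Series(counts, dtype="float64")) - mean) / std

mean, std = fit_log_normalize([0, 1, 3, 7, 1000])
print(apply_log_normalize([0, 1000], mean, std).tolist())
```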


## 3. Dataset Creation

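We load the processed Parquet files, sort the sessions by start time, and split them into training (earliest 80%) and validation (latest 20%) sets that T4Rec reads from Parquet.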
We also define the schema, which tells the model which features are categorical, which is the item ID, etc.
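The split used below keeps validation sessions strictly later in time than training sessions. A minimal sketch of that chronological 80/20 split on toy data (the `time_split` helper is hypothetical):

```python
import pandas as pd

def time_split(sessions, time_col, train_frac=0.8):
    """Split sessions chronologically: the earliest train_frac go to training."""
    ordered = sessions.sort_values(time_col, kind="stable").reset_index(drop=True)
    cut = int(len(ordered) * train_frac)
    return ordered.iloc[:cut], ordered.iloc[cut:]

sessions = pd.DataFrame({
    "visit_id": [10, 11, 12, 13, 14],
    "date_time-first": pd.to_datetime(
        ["2022-01-03", "2022-01-01", "2022-01-05", "2022-01-02", "2022-01-04"]
    ),
})
train, valid = time_split(sessions, "date_time-first")
print(train["visit_id"].tolist(), valid["visit_id"].tolist())
```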

In [10]:
# Load the processed data and split it by time.
# For session-based recommendation, validation sessions should come after the
# training sessions, so we sort by session start time and hold out the latest 20%.

processed_path = "data/processed_sessions"
schema = workflow.output_schema

# Check the schema
print("Schema:", schema)

# Create a Merlin Dataset over the processed sessions
import merlin.io
ds = merlin.io.Dataset(processed_path, engine="parquet")

# For data of this size, the simplest approach is to reload the sessions as a DataFrame,
# split them, and save each split back to Parquet for T4Rec.
# glob selects only the .parquet part files, skipping metadata files such as schema.pbtxt.
parquet_files = glob.glob(os.path.join(processed_path, "*.parquet"))
df = pd.read_parquet(parquet_files)
df = df.sort_values('date_time-first')

split_index = int(len(df) * 0.8)
train_df = df.iloc[:split_index]
valid_df = df.iloc[split_index:]

train_df.to_parquet("data/train.parquet")
valid_df.to_parquet("data/valid.parquet")

print(f"Train sessions: {len(train_df)}")
print(f"Valid sessions: {len(valid_df)}")

Schema: [{'name': 'visit_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.USER: 'user'>, <Tags.ID: 'id'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.visit_id.parquet', 'domain': {'min': 0, 'max': 139158, 'name': 'visit_id'}, 'embedding_sizes': {'cardinality': 139159, 'dimension': 512}}, 'dtype': DType(name='uint64', element_type=<ElementType.UInt: 'uint'>, element_size=64, element_unit=None, signed=None, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>, <Tags.ID: 'id'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 1199, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 1200, 'dimension': 85}, 'value_count': {'min': 0, 'max': None}}, 'dtype': DType(name='int64', element_type=<Elem



Train sessions: 111324
Valid sessions: 27832


## 4. Model Configuration

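We define the Transformer model.
We use `SequentialBlock` to stack the input module and the Transformer body, then wrap the result in a `Head` with a next-item prediction task:
1.  **Embeddings**: For items and categorical side information (traffic source, city, brand, category, price bucket, temporal features), together with the continuous popularity counters.
2.  **Transformer Body**: A 2-layer XLNet with causal masking.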
3.  **Prediction Head**: To predict the next item.
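With `masking="causal"` the model is trained as a next-item predictor: each position sees only the items before it and is asked for the item that follows. The sketch below only illustrates the resulting (prefix, target) pairs for one session; T4Rec builds these internally, so `causal_pairs` is a hypothetical helper, not part of its API.

```python
def causal_pairs(session_items):
    """Return (prefix, target) pairs: each item is predicted from the items before it."""
    return [
        (session_items[:t], session_items[t])
        for t in range(1, len(session_items))
    ]

for prefix, target in causal_pairs([294, 497, 519]):
    print(prefix, "->", target)
```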

In [None]:
# Define the Schema for the model
from merlin.schema import Schema, Tags
import merlin.io

# Load schema from processed data
train_schema = merlin.io.Dataset("data/processed_sessions", engine="parquet").schema

# Select features
# We include all the new features we engineered
selected_features = [
    'item_id-list', 
    'traffic_source-list', 'region_city-list', # Original Context
    'brand-list', 'main_category-list', 'price_bucket-list', # Item Metadata
    'hour-list', 'day_of_week-list', 'is_weekend-list', # Temporal
    'item_popularity-list', 'category_popularity-list' # Continuous Counters
]

input_schema = train_schema.select_by_name(selected_features)

# Model parameters
d_model = 64
max_seq_length = 20

# WORKAROUND: the workflow's output schema leaves value_count.max as None for the
# list columns. Set it to max_seq_length so T4Rec knows the maximum sequence length.
new_cols = []
for col in input_schema:
    props = dict(col.properties)
    # Build a new value_count dict rather than mutating the original one in place
    props['value_count'] = {**(props.get('value_count') or {}), 'max': max_seq_length}
    new_cols.append(col.with_properties(props))

input_schema = Schema(new_cols)
print("Input Schema (Fixed):", input_schema)

# Define the Input Block
# T4Rec handles, based on the schema:
# - Embedding tables for categorical features (brand, city, etc.)
# - Continuous features (popularity), passed through as-is unless continuous_projection is set
# - Concatenation of all features, projected to d_model (d_output)
input_module = tr.TabularSequenceFeatures.from_schema(
    input_schema,
    max_sequence_length=max_seq_length,
    aggregation="concat",
    d_output=d_model,
    masking="causal",
)

# Define the Transformer Body - XLNet
transformer_config = tr.XLNetConfig.build(
    d_model=d_model,
    n_head=4,
    n_layer=2,
    total_seq_length=max_seq_length
)

# Define body
body = tr.SequentialBlock(
    input_module,
    tr.TransformerBlock(transformer_config, masking=input_module.masking)
)

# Define ranking metrics
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt
metrics = [
    RecallAt(top_ks=[6, 10], labels_onehot=True),
    NDCGAt(top_ks=[6, 10], labels_onehot=True)
]

# Define the Head
head = tr.Head(
    body,
    tr.NextItemPredictionTask(weight_tying=True, metrics=metrics),
    inputs=input_module,
)

# Get the end-to-end Model
model = tr.Model(head)

print("Model built successfully!")
print(model)

Input Schema (Fixed): [{'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>, <Tags.ID: 'id'>}, 'properties': {'freq_threshold': 0.0, 'num_buckets': None, 'cat_path': './/categories/unique.item_id.parquet', 'max_size': 0.0, 'embedding_sizes': {'dimension': 85.0, 'cardinality': 1200.0}, 'domain': {'min': 0, 'max': 1199, 'name': 'item_id'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'traffic_source-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'freq_threshold': 0.0, 'num_buckets': None, 'cat_path': './/categories/unique.traffic_source.parquet', 'max_size': 0.0, 'embedding_sizes': {'dimension': 16.0, 'cardinality': 14.0}, 'domain': {'min': 0, 'max': 13, 'name': 'traffic_source'}, 'value_count': {'min': 0
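`NextItemPredictionTask(weight_tying=True)` ties the output layer to the item-embedding table, so each candidate item is scored by the dot product of the session's hidden state with that item's embedding, followed by a softmax over the catalogue. The NumPy sketch below shows this scoring step with random values and, for simplicity, assumes the hidden state has the same width as the item embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
n_items, dim = 5, 4

# Item embedding table (rows = items); with weight tying the same matrix
# embeds input items and scores candidate next items.
item_embeddings = rng.normal(size=(n_items, dim))

# Hypothetical final hidden state for one session (e.g. the transformer
# output at the last position), assumed to match the embedding width.
hidden = rng.normal(size=(dim,))

logits = item_embeddings @ hidden             # one score per item
probs = np.exp(logits - logits.max())
probs /= probs.sum()                          # softmax over the catalogue
top3 = np.argsort(-logits)[:3]                # highest-scoring item indices
print(top3, probs.round(3))
```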



## 5. Training

We use the `Trainer` class (based on HuggingFace Trainer) to train the model.
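The learning rates in the training log below fall linearly from 0.0009 at step 50 to 0.0 at step 500, which is consistent with the Hugging Face Trainer's default linear schedule without warmup. A sketch of that schedule (the function name is hypothetical):

```python
def linear_decay_lr(step, base_lr=0.001, max_steps=500, warmup_steps=0):
    """Optional linear warmup, then linear decay to zero at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, max_steps - step)
    return base_lr * remaining / max(1, max_steps - warmup_steps)

print([round(linear_decay_lr(s), 6) for s in (50, 100, 250, 500)])
```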

In [34]:
from transformers4rec.torch.trainer import Trainer

# Training Arguments
# T4RecTrainingArguments extends Hugging Face's TrainingArguments with T4Rec-specific
# options such as compute_metrics_each_n_steps and predict_top_k
training_args = tr.T4RecTrainingArguments(
    output_dir="./t4r_output",
    max_steps=500,
    learning_rate=0.001,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    logging_steps=50,
    eval_steps=100,
    save_steps=100,
    evaluation_strategy="steps",
    report_to=[], # Disable wandb/mlflow for now
    dataloader_drop_last=False,
    compute_metrics_each_n_steps=1,
    use_mps_device=False, # Force CPU to avoid MPS errors on macOS
    no_cuda=True # Ensure we use CPU
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    schema=train_schema,
    train_dataset_or_path="data/train.parquet",
    eval_dataset_or_path="data/valid.parquet",
)

# Train
print("Starting training...")
trainer.train()

Starting training...
{'loss': 6.5266, 'learning_rate': 0.0009000000000000001, 'epoch': 0.03}
{'loss': 5.9934, 'learning_rate': 0.0008, 'epoch': 0.06}
{'eval_/loss': 7.164107322692871, 'eval_runtime': 3.5477, 'eval_samples_per_second': 7847.357, 'eval_steps_per_second': 122.615, 'epoch': 0.06}
{'loss': 5.7634, 'learning_rate': 0.0007, 'epoch': 0.09}
{'loss': 5.5622, 'learning_rate': 0.0006, 'epoch': 0.11}
{'eval_/loss': 6.12920618057251, 'eval_runtime': 3.3524, 'eval_samples_per_second': 8304.605, 'eval_steps_per_second': 129.759, 'epoch': 0.11}
{'loss': 5.4245, 'learning_rate': 0.0005, 'epoch': 0.14}
{'loss': 5.2823, 'learning_rate': 0.0004, 'epoch': 0.17}
{'eval_/loss': 5.769600868225098, 'eval_runtime': 3.3671, 'eval_samples_per_second': 8268.13, 'eval_steps_per_second': 129.19, 'epoch': 0.17}
{'loss': 5.138, 'learning_rate': 0.0003, 'epoch': 0.2}
{'loss': 5.1025, 'learning_rate': 0.0002, 'epoch': 0.23}
{'eval_/loss': 5.535675525665283, 'eval_runtime': 3.4422, 'eval_samples_per_second': 8087.771, 'eval_steps_per_second': 126.371, 'epoch': 0.23}
{'loss': 5.0881, 'learning_rate': 0.0001, 'epoch': 0.26}
{'loss': 4.988, 'learning_rate': 0.0, 'epoch': 0.29}
{'eval_/loss': 5.49398946762085, 'eval_runtime': 3.5249, 'eval_samples_per_second': 7898.023, 'eval_steps_per_second': 123.407, 'epoch': 0.29}
{'train_runtime': 33.3216, 'train_samples_per_second': 960.338, 'train_steps_per_second': 15.005, 'train_loss': 5.486896209716797, 'epoch': 0.29}

TrainOutput(global_step=500, training_loss=5.486896209716797, metrics={'train_runtime': 33.3216, 'train_samples_per_second': 960.338, 'train_steps_per_second': 15.005, 'total_flos': 0.0, 'train_loss': 5.486896209716797})

## 6. Evaluation

We evaluate the model on the validation sessions with the ranking metrics configured earlier: Recall@k and NDCG@k for k = 6 and k = 10.

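T4Rec computes Recall@k and NDCG@k internally. As a reading aid, the sketch below (hypothetical helper `recall_ndcg_at_k`) computes them for the case of one held-out next item per session: Recall@k is the fraction of sessions whose target appears in the top k, and because there is a single relevant item the ideal DCG is 1, so NDCG@k is 1 / log2(rank + 1) for a 1-based hit rank and 0 for a miss.

```python
import numpy as np

def recall_ndcg_at_k(top_k_preds, targets, k):
    """Recall@k and NDCG@k for next-item prediction with one relevant item per session."""
    recalls, ndcgs = [], []
    for preds, target in zip(top_k_preds, targets):
        hits = np.flatnonzero(np.asarray(preds[:k]) == target)
        if hits.size:  # target found at 0-based position hits[0]
            recalls.append(1.0)
            ndcgs.append(1.0 / np.log2(hits[0] + 2))
        else:
            recalls.append(0.0)
            ndcgs.append(0.0)
    return float(np.mean(recalls)), float(np.mean(ndcgs))

preds = [[7, 3, 9], [4, 1, 2], [5, 6, 8]]
targets = [7, 2, 0]
print(recall_ndcg_at_k(preds, targets, k=3))
```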
In [35]:
# Evaluate
# Enable metric computation (it defaults to None in Trainer)
trainer.compute_metrics = True

# Note: The trainer already has the eval_dataset_or_path from initialization
eval_metrics = trainer.evaluate()

print("\nAll Evaluation Metrics:")
for key, value in eval_metrics.items():
    print(f"{key}: {value}")

print("\nFiltered Results:")
for key, value in eval_metrics.items():
    if "recall" in key or "ndcg" in key:
        print(f"{key}: {value:.4f}")



All Evaluation Metrics:
eval_/next-item/recall_at_6: 0.28340524435043335
eval_/next-item/recall_at_10: 0.325452983379364
eval_/next-item/ndcg_at_6: 0.20350845158100128
eval_/next-item/ndcg_at_10: 0.21658070385456085
eval_/loss: 5.49398946762085
eval_runtime: 4.7652
eval_samples_per_second: 5842.301
eval_steps_per_second: 91.286

Filtered Results:
eval_/next-item/recall_at_6: 0.2834
eval_/next-item/recall_at_10: 0.3255
eval_/next-item/ndcg_at_6: 0.2035
eval_/next-item/ndcg_at_10: 0.2166


## 7. Generating Predictions

In [None]:
# 1. Load and Preprocess Test Data (Pandas)
print("Loading test data...")
test_hits_df = pd.read_csv('data/metrika_hits_test.csv', low_memory=False)
test_visits_df = pd.read_csv('data/metrika_visits_test.csv', low_memory=False)

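# Parse watch_ids into one list per visit
# NOTE: parse_watch_ids is not defined in this notebook; it must be defined (or imported)
# before this cell runs, using the same parsing as for the training visits.
test_visits_df['watch_ids_list'] = test_visits_df['watch_ids'].apply(parse_watch_ids)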
test_session_hits = test_visits_df.explode('watch_ids_list').rename(columns={'watch_ids_list': 'watch_id'})

# Ensure types match
test_session_hits['watch_id'] = test_session_hits['watch_id'].astype(str)
test_hits_df['watch_id'] = test_hits_df['watch_id'].astype(str)

# Merge
test_full_data = test_session_hits.merge(test_hits_df, on='watch_id', how='inner')

# Filter PRODUCT
test_interactions = test_full_data[test_full_data['page_type'] == 'PRODUCT'].copy()

# Sort
test_interactions['date_time'] = pd.to_datetime(test_interactions['date_time_x'])
test_interactions = test_interactions.sort_values(['visit_id', 'date_time'])

# Rename slug -> item_id
test_interactions = test_interactions.rename(columns={'slug': 'item_id'})
test_interactions = test_interactions.dropna(subset=['item_id'])

# The merge suffixed overlapping columns with _x/_y; strip the _x suffix where no column with the base name exists
for col in test_interactions.columns:
    if col.endswith('_x'):
        base_name = col[:-2]
        if base_name not in test_interactions.columns:
            test_interactions = test_interactions.rename(columns={col: base_name})

# Merge Product Metadata (same as training)
# Reload the product catalogues so this cell is self-contained
print("Loading product metadata for test set...")
new_products = pd.read_csv('data/new_site_products.csv')
old_products = pd.read_csv('data/old_site_products.csv')
meta_cols = ['slug', 'brand', 'main_category', 'price_per_period_week']
products_combined = pd.concat([new_products[meta_cols], old_products[meta_cols]])
products_meta = products_combined.drop_duplicates(subset=['slug']).copy()
products_meta['brand'] = products_meta['brand'].fillna('Unknown')
products_meta['main_category'] = products_meta['main_category'].fillna('Unknown')
products_meta['price_per_period_week'] = products_meta['price_per_period_week'].fillna(0)

test_interactions = test_interactions.merge(products_meta, left_on='item_id', right_on='slug', how='left')
test_interactions['brand'] = test_interactions['brand'].fillna('Unknown')
test_interactions['main_category'] = test_interactions['main_category'].fillna('Unknown')
test_interactions['price_per_period_week'] = test_interactions['price_per_period_week'].fillna(0)

# Feature Engineering (Same as training)
test_interactions['hour'] = test_interactions['date_time'].dt.hour
test_interactions['day_of_week'] = test_interactions['date_time'].dt.dayofweek
test_interactions['is_weekend'] = test_interactions['day_of_week'].isin([5, 6]).astype(int)
test_interactions['price_bucket'] = pd.qcut(test_interactions['price_per_period_week'], q=10, labels=False, duplicates='drop').fillna(0).astype(int)

# Select columns
cols_to_keep = [
    'visit_id', 'item_id', 'date_time', 
    'traffic_source', 'region_city',
    'brand', 'main_category', 'price_bucket',
    'hour', 'day_of_week', 'is_weekend'
]
test_interactions = test_interactions[cols_to_keep]

print(f"Processed test interactions: {len(test_interactions)}")

Loading test data...
Processed test interactions: 6434

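The `pd.qcut` call in the previous cell derives decile boundaries from the test interactions themselves, so bucket k at test time need not cover the same price range as bucket k during training. One way to keep them consistent, sketched below under the assumption that the training prices are available, is to learn the edges once with `retbins=True` and apply them with `pd.cut`. The helper names and the toy prices are hypothetical.

```python
import numpy as np
import pandas as pd

def fit_price_bins(train_prices, q=10):
    """Learn quantile bin edges on training prices; open the outer edges for new data."""
    _, edges = pd.qcut(train_prices, q=q, labels=False, retbins=True, duplicates="drop")
    edges = edges.astype("float64").copy()
    edges[0], edges[-1] = -np.inf, np.inf
    return edges

def apply_price_bins(prices, edges):
    """Bucket new prices with the training edges (bucket ids 0 .. len(edges) - 2)."""
    return pd.cut(prices, bins=edges, labels=False).astype(int)

train_prices = pd.Series([100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])
edges = fit_price_bins(train_prices, q=4)
print(apply_price_bins(pd.Series([50, 250, 650, 5000]), edges).tolist())
```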

In [None]:
# 2. Manual Transformation (Pandas)
# We map the categorical features using the dictionaries created during training

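def load_map(col_name):
    """Rebuild the training Categorify mapping (category value -> integer index)."""
    path = f"categories/unique.{col_name}.parquet"
    # Some maps may be missing (e.g. a column handled differently during training)
    if not os.path.exists(path):
        print(f"Warning: Map for {col_name} not found at {path}")
        return {}

    df = pd.read_parquet(path)
    df = df.reset_index(drop=True)
    # Assumes index 0 is reserved and categories are numbered from 1 in file order;
    # check this against the Categorify behaviour of your NVTabular version.
    return {val: i + 1 for i, val in enumerate(df[col_name])}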

print("Loading category maps...")
# Load maps for ALL categorical features
cat_feats = ['item_id', 'traffic_source', 'region_city', 'brand', 'main_category', 
             'price_bucket', 'hour', 'day_of_week', 'is_weekend']
maps = {col: load_map(col) for col in cat_feats}

# Apply mappings
print("Mapping categorical columns...")
for col in cat_feats:
    # Map and fillna(0) for unknown/padding
    test_interactions[f'{col}_mapped'] = test_interactions[col].map(maps[col]).fillna(0).astype(int)

# Handle Continuous Features (Popularity)
# We need to use the TRAINING data statistics to ensure consistency
print("Calculating continuous features based on training stats...")
train_data = pd.read_parquet('data/enriched_interactions.parquet')

# 1. Compute Counts (Popularity)
item_counts = train_data['item_id'].value_counts()
category_counts = train_data['main_category'].value_counts()

test_interactions['item_popularity'] = test_interactions['item_id'].map(item_counts).fillna(0)
test_interactions['category_popularity'] = test_interactions['main_category'].map(category_counts).fillna(0)

# 2. LogOp (np.log(x + 1))
test_interactions['item_popularity'] = np.log(test_interactions['item_popularity'] + 1)
test_interactions['category_popularity'] = np.log(test_interactions['category_popularity'] + 1)

# 3. Normalize ((x - mean) / std)
# Compute stats from TRAINING data (after LogOp)
train_item_pop = np.log(train_data['item_popularity'] + 1)
train_cat_pop = np.log(train_data['category_popularity'] + 1)

test_interactions['item_popularity'] = (test_interactions['item_popularity'] - train_item_pop.mean()) / train_item_pop.std()
test_interactions['category_popularity'] = (test_interactions['category_popularity'] - train_cat_pop.mean()) / train_cat_pop.std()

# Groupby to create lists
print("Grouping by session...")
agg_dict = {f'{col}_mapped': list for col in cat_feats}
agg_dict['item_popularity'] = list
agg_dict['category_popularity'] = list
agg_dict['date_time'] = 'first'

test_grouped = test_interactions.groupby('visit_id').agg(agg_dict).reset_index()

# Rename columns to match schema expected by T4Rec
rename_dict = {f'{col}_mapped': f'{col}-list' for col in cat_feats}
rename_dict['item_popularity'] = 'item_popularity-list'
rename_dict['category_popularity'] = 'category_popularity-list'

test_grouped = test_grouped.rename(columns=rename_dict)

# Save to parquet
test_grouped.to_parquet("data/processed_test_sessions_pandas.parquet")
print(f"Processed {len(test_grouped)} test sessions.")
test_grouped.head()

Loading category maps...
Mapping columns...
Grouping by session...
Processed 2062 test sessions.


|   | visit_id            | item_id-list         | traffic_source-list | region_city-list | date_time           |
|---|---------------------|----------------------|---------------------|------------------|---------------------|
| 0 | 3705189088312688779 | [601]                | [4]                 | [1]              | 2025-07-02 17:09:34 |
| 1 | 3705579516103753793 | [294, 497, 519, 519] | [1, 1, 1, 1]        | [0, 0, 0, 0]     | 2025-07-02 17:34:23 |
| 2 | 3706260855703732415 | [775, 775]           | [4, 4]              | [1, 1]           | 2025-07-02 18:17:43 |
| 3 | 3706380385759789156 | [373, 373]           | [5, 5]              | [1, 1]           | 2025-07-02 18:25:19 |
| 4 | 3706985600653983919 | [705, 705]           | [4, 4]              | [1, 1]           | 2025-07-02 19:03:47 |

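The model was configured with `max_seq_length = 20`, but the grouping above keeps every interaction of a test session. Whether the data loader keeps the first or the last 20 positions of a longer session is not shown in this notebook; trimming each list to its most recent interactions before saving removes that ambiguity. A minimal sketch, with a hypothetical helper and toy data:

```python
import pandas as pd

def keep_last_n(df, list_cols, n):
    """Trim each list column to its last n elements (the most recent interactions)."""
    out = df.copy()
    for col in list_cols:
        out[col] = out[col].apply(lambda seq: list(seq)[-n:])
    return out

demo = pd.DataFrame({
    "visit_id": [1, 2],
    "item_id-list": [[1, 2, 3, 4, 5], [7]],
    "hour-list": [[9, 9, 10, 10, 11], [14]],
})
trimmed = keep_last_n(demo, ["item_id-list", "hour-list"], n=3)
print(trimmed)
```

Applied to this notebook, it would be called on `test_grouped` with its `-list` columns and `n=max_seq_length` before writing the Parquet file.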

In [39]:
# 3. Generate Predictions

# Update args to return top-6 predictions
training_args.predict_top_k = 6
trainer.args = training_args

# Predict
print("Generating predictions...")
# The Trainer was initialized only with training and validation paths,
# so we pass the test sessions explicitly as a Merlin Dataset.
import merlin.io
test_dataset = merlin.io.Dataset("data/processed_test_sessions_pandas.parquet", engine="parquet")

# We can pass the dataset directly to predict
test_predictions = trainer.predict(test_dataset)

# The predictions object contains a tuple (item_ids, scores) because we set predict_top_k
# We extract the item IDs (first element of the tuple)
# Note: Depending on T4Rec version, it might be in predictions.predictions or just predictions
if isinstance(test_predictions.predictions, tuple):
    top_k_ids = test_predictions.predictions[0]
else:
    top_k_ids = test_predictions.predictions

print("Predictions shape:", top_k_ids.shape)

Generating predictions...
Predictions shape: (2062, 6)
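`predict_top_k` makes the Trainer return the top-k item indices directly. If you instead have a raw score matrix, the selection can be sketched as below; it also masks index 0, which in this notebook's mapping is reserved for padding and unknown items and would otherwise waste a recommendation slot. The `top_k_items` helper and the scores are hypothetical.

```python
import numpy as np

def top_k_items(scores, k=6, exclude=(0,)):
    """Return the k highest-scoring item indices per row, skipping excluded indices."""
    scores = np.array(scores, dtype="float64", copy=True)
    scores[:, list(exclude)] = -np.inf        # never recommend reserved indices
    # argsort on negated scores gives descending order; keep the first k columns
    return np.argsort(-scores, axis=1, kind="stable")[:, :k]

scores = np.array([
    [9.0, 0.1, 0.5, 0.3, 0.2, 0.9, 0.4, 0.8],
    [0.0, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
])
print(top_k_items(scores, k=3))
```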


In [42]:
# 4. Create Submission File

import pandas as pd
import numpy as np

# 1. Load Product Mappings (Slug -> Numeric ID)
print("Loading product catalogs...")
try:
    new_products = pd.read_csv('data/new_site_products.csv', usecols=['id', 'slug'])
    old_products = pd.read_csv('data/old_site_products.csv', usecols=['id', 'slug'])
    
    # Combine and create map
    products_df = pd.concat([new_products, old_products]).drop_duplicates(subset=['slug'])
    # Ensure IDs are strings
    products_df['id'] = products_df['id'].astype(str)
    slug_to_id = dict(zip(products_df['slug'], products_df['id']))
    
    # Create a list of all available IDs for absolute fallback
    all_product_ids = products_df['id'].tolist()
    
    print(f"Loaded {len(slug_to_id)} product mappings.")
except Exception as e:
    print(f"Error loading product catalogs: {e}")
    slug_to_id = {}
    all_product_ids = []

# 2. Determine Popular Items (Fallback for cold users)
print("Calculating popular items for fallback...")
top_ids = []  # defined up front so the mapping loop below still works if this step fails
try:
    # IMPORTANT: Filter for page_type='PRODUCT' to avoid counting 'MAIN', 'CATALOG' etc.
    hits_df = pd.read_csv('data/metrika_hits.csv', usecols=['slug', 'page_type'])
    product_hits = hits_df[hits_df['page_type'] == 'PRODUCT']
    
    # Get top 50 to be safe
    top_slugs = product_hits['slug'].value_counts().head(50).index.tolist()
    
    # Map to IDs
    top_ids = []
    for s in top_slugs:
        if s in slug_to_id:
            tid = str(slug_to_id[s])
            if tid not in top_ids:
                top_ids.append(tid)
    
    print(f"Found {len(top_ids)} popular product IDs.")
    
    # Ensure we have at least 6
    if len(top_ids) < 6:
        print("Warning: Not enough popular items found. Padding with random products.")
        # Pad with any available products
        for pid in all_product_ids:
            if pid not in top_ids:
                top_ids.append(pid)
            if len(top_ids) >= 6:
                break
        
    default_pred_str = " ".join(top_ids[:6])
    print(f"Default prediction: {default_pred_str}")
except Exception as e:
    print(f"Error calculating popular items: {e}")
    # Absolute fallback
    if len(all_product_ids) >= 6:
        default_pred_str = " ".join(all_product_ids[:6])
    else:
        default_pred_str = "0 0 0 0 0 0" # Should not happen if data is loaded

# 3. Map Model Predictions to Product IDs
# reverse_item_map: Model Int -> Slug
reverse_item_map = {i + 1: val for i, val in enumerate(pd.read_parquet("categories/unique.item_id.parquet")['item_id'])}
reverse_item_map[0] = "unknown"

# Get visit_ids from the processed test set
processed_test = pd.read_parquet("data/processed_test_sessions_pandas.parquet")
predicted_visit_ids = processed_test['visit_id'].astype(str).values

prediction_map = {}
print("Mapping predictions...")
for i, vid in enumerate(predicted_visit_ids):
    model_ids = top_k_ids[i] # Array of model integers
    real_ids = []
    for mid in model_ids:
        slug = reverse_item_map.get(mid, "unknown")
        if slug in slug_to_id:
            real_ids.append(str(slug_to_id[slug]))
    
    # If we didn't find 6 valid IDs, pad with popular items
    if len(real_ids) < 6:
        # Only add unique popular items that aren't already predicted
        for pid in top_ids:
            if pid not in real_ids:
                real_ids.append(pid)
            if len(real_ids) >= 6:
                break
        
    prediction_map[vid] = " ".join(real_ids[:6])

# 4. Generate Final Submission for ALL Test Visits
print("Generating final submission file...")
test_visits = pd.read_csv('data/metrika_visits_test.csv', usecols=['visit_id'])
test_visits['visit_id'] = test_visits['visit_id'].astype(str)

def get_pred(vid):
    return prediction_map.get(vid, default_pred_str)

test_visits['product_ids'] = test_visits['visit_id'].apply(get_pred)

# Save
test_visits.to_csv("submission.csv", index=False)
print(f"Submission saved to submission.csv with {len(test_visits)} rows.")
print(test_visits.head())

Loading product catalogs...
Loaded 1400 product mappings.
Calculating popular items for fallback...
Found 50 popular product IDs.
Default prediction: 3714 3282 5602 3582 3365 3746
Mapping predictions...
Generating final submission file...
Submission saved to submission.csv with 3891 rows.
              visit_id                                   product_ids
0  3705073560174199024                 3714 3282 5602 3582 3365 3746
1  3705189088312688779  3942 463480486 463480210 3631 3653 495264803
2  3705549051029618879                 3714 3282 5602 3582 3365 3746
3  3705579516103753793            3393 3625 463480227 3942 3405 3301
4  3705717843210797336                 3714 3282 5602 3582 3365 3746
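The mapping loop above (model predictions first, then popular items until six IDs are present) can be factored into a small helper. The sketch below, with the hypothetical name `fill_to_k`, also drops duplicate IDs from the model's own list, which the loop above does not do. The example IDs are taken from the default prediction printed above.

```python
def fill_to_k(model_ids, fallback_ids, k=6):
    """Keep model predictions in order, drop duplicates, then top up from fallback_ids."""
    result = []
    for pid in list(model_ids) + list(fallback_ids):
        if pid not in result:
            result.append(pid)
        if len(result) == k:
            break
    return result

popular = ["3714", "3282", "5602", "3582", "3365", "3746"]
print(" ".join(fill_to_k(["3942", "3631", "3942"], popular)))
```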