In [1]:
import os
import gc
import math
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import polars as pl
import scipy.sparse as sp

from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [4]:
import pandas as pd
import numpy as np
from collections import defaultdict
import warnings
from tqdm import tqdm

warnings.filterwarnings('ignore')

class TShoppingRecommender:
    def __init__(self, recent_days_window=3):
        self.user_recent_clicks = defaultdict(list)
        self.global_popularity = defaultdict(int)
        self.last_day = None
        self.recent_days_window = recent_days_window
        
    def fit(self, train_data):
        """Обучение модели на исторических данных"""
        print("Обработка тренировочных данных...")
        
        # Находим последний день в данных
        self.last_day = train_data['date'].max()
        print(f"Последний день в данных: {self.last_day}")
        print(f"Используется окно из {self.recent_days_window} последних дней")
        
        # Собираем клики пользователей за последние N дней
        recent_days = [self.last_day - i for i in range(self.recent_days_window)]
        recent_data = train_data[train_data['date'].isin(recent_days)]
        
        print(f"Найдено взаимодействий за последние {self.recent_days_window} дней: {len(recent_data)}")
        
        print("Сбор персональных предпочтений...")
        # Используем tqdm для прогресса
        for _, row in tqdm(recent_data.iterrows(), total=len(recent_data), desc="Обработка кликов"):
            user_id = row['user_id']
            item_id = row['item_id']
            date = row['date']
            
            # Сохраняем клики с учетом даты (более поздние имеют больший вес)
            self.user_recent_clicks[user_id].append((item_id, date))
        
        print("Сортировка кликов по дате...")
        # Сортируем клики каждого пользователя по дате (сначала последние)
        for user_id in tqdm(self.user_recent_clicks.keys(), desc="Сортировка пользователей"):
            self.user_recent_clicks[user_id].sort(key=lambda x: x[1], reverse=True)
        
        # Собираем глобальную популярность товаров за последний день
        last_day_data = train_data[train_data['date'] == self.last_day]
        item_counts = last_day_data['item_id'].value_counts()
        
        print("Расчет глобальной популярности...")
        for item_id, count in tqdm(item_counts.items(), total=len(item_counts), desc="Популярные товары"):
            self.global_popularity[item_id] = count
            
        print(f"Обработано пользователей: {len(self.user_recent_clicks)}")
        print(f"Уникальных популярных товаров: {len(self.global_popularity)}")
    
    def predict(self, user_ids, top_k=20):
        """Предсказание топ-K товаров для каждого пользователя"""
        print("Генерация рекомендаций...")
        
        # Создаем глобальный рейтинг популярности
        global_ranking = list(self.global_popularity.keys())
        
        submissions = []
        
        # Используем tqdm для прогресса предсказания
        for user_id in tqdm(user_ids, desc="Генерация рекомендаций"):
            user_recommendations = []
            
            # Берем уникальные товары из последних кликов пользователя
            if user_id in self.user_recent_clicks:
                recent_items = []
                for item_id, date in self.user_recent_clicks[user_id]:
                    if item_id not in recent_items:
                        recent_items.append(item_id)
                    if len(recent_items) >= top_k:
                        break
                
                user_recommendations.extend(recent_items)
            
            # Если недостаточно персональных рекомендаций, добавляем глобально популярные
            if len(user_recommendations) < top_k:
                # Берем только те товары, которых еще нет в рекомендациях
                additional_items = [item for item in global_ranking 
                                  if item not in user_recommendations]
                
                # Добавляем столько, сколько нужно до top_k
                needed = top_k - len(user_recommendations)
                user_recommendations.extend(additional_items[:needed])
            
            # Обеспечиваем, что рекомендаций ровно top_k
            user_recommendations = user_recommendations[:top_k]
            
            # Добавляем в submission
            for item_id in user_recommendations:
                submissions.append({
                    'user_id': user_id,
                    'item_id': item_id
                })
        
        return pd.DataFrame(submissions)

def main():
    # Загрузка данных
    print("Загрузка данных...")
    train_data = pd.read_parquet('/kaggle/input/stupidshit777/train_data.pq')
    
    # Загрузка sample submission для получения списка пользователей
    sample_submission = pd.read_csv('/kaggle/input/stupidshit777/sample_submission (11).csv')
    test_users = sample_submission['user_id'].unique()
    
    print(f"Всего пользователей в тесте: {len(test_users)}")
    print(f"Всего взаимодействий в тренировочных данных: {len(train_data)}")
    print(f"Уникальных пользователей: {train_data['user_id'].nunique()}")
    print(f"Уникальных товаров: {train_data['item_id'].nunique()}")
    
    # Настройка параметров
    RECENT_DAYS_WINDOW = 15  # Можно изменить на 3, 5, 10 и т.д.
    
    # Обучение модели
    model = TShoppingRecommender(recent_days_window=RECENT_DAYS_WINDOW)
    model.fit(train_data)
    
    # Предсказание
    submission_df = model.predict(test_users, top_k=20)
    
    # Проверка формата
    print(f"\nПроверка формата submission:")
    print(f"Всего строк в submission: {len(submission_df)}")
    print(f"Ожидалось: {len(test_users) * 20}")
    
    # Сохранение результатов
    submission_df.to_csv('submission.csv', index=False)
    print("\nSubmission сохранен в submission.csv")
    
    # Пример рекомендаций для первых 5 пользователей
    print("\nПример рекомендаций для первых 5 пользователей:")
    for i, user_id in enumerate(test_users[:5]):
        user_recs = submission_df[submission_df['user_id'] == user_id]['item_id'].tolist()
        print(f"Пользователь {user_id}: {user_recs[:5]}... (всего {len(user_recs)} рекомендаций)")
    
    # Статистика по рекомендациям
    print(f"\nСтатистика:")
    print(f"Использовано дней для персональных предпочтений: {RECENT_DAYS_WINDOW}")
    print(f"Размер глобального пула популярных товаров: {len(model.global_popularity)}")

if __name__ == "__main__":
    main()

Загрузка данных...
Всего пользователей в тесте: 293230
Всего взаимодействий в тренировочных данных: 8777975
Уникальных пользователей: 2682603
Уникальных товаров: 740651
Обработка тренировочных данных...
Последний день в данных: 46
Используется окно из 15 последних дней
Найдено взаимодействий за последние 15 дней: 3030646
Сбор персональных предпочтений...


Обработка кликов: 100%|██████████| 3030646/3030646 [02:14<00:00, 22549.26it/s]


Сортировка кликов по дате...


Сортировка пользователей: 100%|██████████| 1157076/1157076 [00:01<00:00, 973740.50it/s] 


Расчет глобальной популярности...


Популярные товары: 100%|██████████| 56659/56659 [00:00<00:00, 1708275.73it/s]


Обработано пользователей: 1157076
Уникальных популярных товаров: 56659
Генерация рекомендаций...


Генерация рекомендаций: 100%|██████████| 293230/293230 [27:46<00:00, 175.94it/s]



Проверка формата submission:
Всего строк в submission: 5864600
Ожидалось: 5864600

Submission сохранен в submission.csv

Пример рекомендаций для первых 5 пользователей:
Пользователь 247446: [1030, 658, 114, 302, 20]... (всего 20 рекомендаций)
Пользователь 352619: [658, 114, 302, 20, 34]... (всего 20 рекомендаций)
Пользователь 352620: [15093, 658, 114, 302, 20]... (всего 20 рекомендаций)
Пользователь 352622: [154849, 884, 106647, 437758, 68094]... (всего 20 рекомендаций)
Пользователь 3257: [171, 5049, 658, 57227, 2144]... (всего 20 рекомендаций)

Статистика:
Использовано дней для персональных предпочтений: 15
Размер глобального пула популярных товаров: 56659


In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict
import warnings
from tqdm import tqdm

warnings.filterwarnings('ignore')

class TShoppingRecommender:
    def __init__(self, recent_days_window=3):
        self.user_recent_clicks = defaultdict(list)
        self.global_popularity = defaultdict(int)
        self.last_day = None
        self.recent_days_window = recent_days_window
        
    def fit(self, train_data):
        """Обучение модели на исторических данных"""
        print("Обработка тренировочных данных...")
        
        # Находим последний день в данных
        self.last_day = train_data['date'].max()
        print(f"Последний день в данных: {self.last_day}")
        print(f"Используется окно из {self.recent_days_window} последних дней")
        
        # Собираем клики пользователей за последние N дней
        recent_days = [self.last_day - i for i in range(self.recent_days_window)]
        recent_data = train_data[train_data['date'].isin(recent_days)]
        
        print(f"Найдено взаимодействий за последние {self.recent_days_window} дней: {len(recent_data)}")
        
        print("Сбор персональных предпочтений...")
        # Используем tqdm для прогресса
        for _, row in tqdm(recent_data.iterrows(), total=len(recent_data), desc="Обработка кликов"):
            user_id = row['user_id']
            item_id = row['item_id']
            date = row['date']
            
            # Сохраняем клики с учетом даты (более поздние имеют больший вес)
            self.user_recent_clicks[user_id].append((item_id, date))
        
        print("Сортировка кликов по дате...")
        # Сортируем клики каждого пользователя по дате (сначала последние)
        for user_id in tqdm(self.user_recent_clicks.keys(), desc="Сортировка пользователей"):
            self.user_recent_clicks[user_id].sort(key=lambda x: x[1], reverse=True)
        
        # Собираем глобальную популярность товаров за последний день
        last_day_data = train_data[train_data['date'] == self.last_day]
        item_counts = last_day_data['item_id'].value_counts()
        
        print("Расчет глобальной популярности...")
        for item_id, count in tqdm(item_counts.items(), total=len(item_counts), desc="Популярные товары"):
            self.global_popularity[item_id] = count
            
        print(f"Обработано пользователей: {len(self.user_recent_clicks)}")
        print(f"Уникальных популярных товаров: {len(self.global_popularity)}")
    
    def predict(self, user_ids, top_k=20):
        """Предсказание топ-K товаров для каждого пользователя"""
        print("Генерация рекомендаций...")
        
        # Создаем глобальный рейтинг популярности
        global_ranking = list(self.global_popularity.keys())
        
        submissions = []
        
        # Используем tqdm для прогресса предсказания
        for user_id in tqdm(user_ids, desc="Генерация рекомендаций"):
            user_recommendations = []
            
            # Берем уникальные товары из последних кликов пользователя
            if user_id in self.user_recent_clicks:
                recent_items = []
                for item_id, date in self.user_recent_clicks[user_id]:
                    if item_id not in recent_items:
                        recent_items.append(item_id)
                    if len(recent_items) >= top_k:
                        break
                
                user_recommendations.extend(recent_items)
            
            # Если недостаточно персональных рекомендаций, добавляем глобально популярные
            if len(user_recommendations) < top_k:
                # Берем только те товары, которых еще нет в рекомендациях
                additional_items = [item for item in global_ranking 
                                  if item not in user_recommendations]
                
                # Добавляем столько, сколько нужно до top_k
                needed = top_k - len(user_recommendations)
                user_recommendations.extend(additional_items[:needed])
            
            # Обеспечиваем, что рекомендаций ровно top_k
            user_recommendations = user_recommendations[:top_k]
            
            # Добавляем в submission
            for item_id in user_recommendations:
                submissions.append({
                    'user_id': user_id,
                    'item_id': item_id
                })
        
        return pd.DataFrame(submissions)

def main():
    # Загрузка данных
    print("Загрузка данных...")
    train_data = pd.read_parquet('/kaggle/input/stupidshit777/train_data.pq')
    
    # Загрузка sample submission для получения списка пользователей
    sample_submission = pd.read_csv('/kaggle/input/stupidshit777/sample_submission (11).csv')
    test_users = sample_submission['user_id'].unique()
    
    print(f"Всего пользователей в тесте: {len(test_users)}")
    print(f"Всего взаимодействий в тренировочных данных: {len(train_data)}")
    print(f"Уникальных пользователей: {train_data['user_id'].nunique()}")
    print(f"Уникальных товаров: {train_data['item_id'].nunique()}")
    
    # Настройка параметров
    RECENT_DAYS_WINDOW = 15  # Можно изменить на 3, 5, 10 и т.д.
    
    # Обучение модели
    model = TShoppingRecommender(recent_days_window=RECENT_DAYS_WINDOW)
    model.fit(train_data)
    
    # Предсказание
    submission_df = model.predict(test_users, top_k=20)
    
    # Проверка формата
    print(f"\nПроверка формата submission:")
    print(f"Всего строк в submission: {len(submission_df)}")
    print(f"Ожидалось: {len(test_users) * 20}")
    
    # Сохранение результатов
    submission_df.to_csv('submission.csv', index=False)
    print("\nSubmission сохранен в submission.csv")
    
    # Пример рекомендаций для первых 5 пользователей
    print("\nПример рекомендаций для первых 5 пользователей:")
    for i, user_id in enumerate(test_users[:5]):
        user_recs = submission_df[submission_df['user_id'] == user_id]['item_id'].tolist()
        print(f"Пользователь {user_id}: {user_recs[:5]}... (всего {len(user_recs)} рекомендаций)")
    
    # Статистика по рекомендациям
    print(f"\nСтатистика:")
    print(f"Использовано дней для персональных предпочтений: {RECENT_DAYS_WINDOW}")
    print(f"Размер глобального пула популярных товаров: {len(model.global_popularity)}")

if __name__ == "__main__":
    main()

Загрузка данных...
Всего пользователей в тесте: 293230
Всего взаимодействий в тренировочных данных: 8777975
Уникальных пользователей: 2682603
Уникальных товаров: 740651
Обработка тренировочных данных...
Последний день в данных: 46
Используется окно из 15 последних дней
Найдено взаимодействий за последние 15 дней: 3030646
Сбор персональных предпочтений...


Обработка кликов:  64%|██████▍   | 1938742/3030646 [01:26<00:47, 23029.43it/s]

In [4]:
!pip install polars pyarrow numpy scipy pandas tqdm torch scikit-learn
!pip install implicit


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
pip install faiss-gpu-cu12

Collecting faiss-gpu-cu12
  Downloading faiss_gpu_cu12-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading faiss_gpu_cu12-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (48.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.1/48.1 MB[0m [31m41.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu-cu12
Successfully installed faiss-gpu-cu12-1.12.0
Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install faiss-gpu

[31mERROR: Could not find a version that satisfies the requirement faiss-gpu (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for faiss-gpu[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [5]:
from implicit.gpu.als import AlternatingLeastSquares as GPU_ALS

In [None]:
GPU ALS not available (No CUDA extension has been built, can't train on GPU.), falling back to CPU.
Training ALS (CPU)..

In [7]:
from implicit.gpu.als import AlternatingLeastSquares as GPU_ALS
als = GPU_ALS(factors=factors, regularization=reg, iterations=iters, random_state=42)
print("Training ALS (GPU)...")
als.fit(ui.T, show_progress=True)

NameError: name 'factors' is not defined

In [2]:
# -----------------------------
# Utils
# -----------------------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def average_precision_at_k(true_items_set, ranked_items, k=20):
    if not true_items_set:
        return 0.0
    score = 0.0
    hits = 0
    for i, item in enumerate(ranked_items[:k]):
        if item in true_items_set:
            hits += 1
            score += hits / (i + 1)
    return score / min(len(true_items_set), k)

def map_at_k(gt_by_user, preds_by_user, k=20):
    aps = []
    for u, true_set in gt_by_user.items():
        preds = preds_by_user.get(u, [])
        aps.append(average_precision_at_k(true_set, preds, k=k))
    return float(np.mean(aps)) if aps else 0.0

# -----------------------------
# Load data
# -----------------------------
def load_data(train_path="train_data.pq", sample_path="sample_submission.csv"):
    print("Reading parquet with Polars...")
    dfpl = pl.read_parquet(train_path)
    # Ensure dtypes
    dfpl = dfpl.with_columns([
        pl.col("user_id").cast(pl.Int64),
        pl.col("item_id").cast(pl.Int64),
        pl.col("date").cast(pl.Int32)
    ])
    df = dfpl.to_pandas()
    del dfpl
    gc.collect()

    print("Mapping ids to contiguous indices...")
    df["uid"], user_uniques = pd.factorize(df["user_id"], sort=True)
    df["iid"], item_uniques = pd.factorize(df["item_id"], sort=True)

    uid2orig = np.array(user_uniques)
    iid2orig = np.array(item_uniques)

    # Sort by user, then by date for sequence models
    df.sort_values(["uid", "date"], inplace=True)
    df.reset_index(drop=True, inplace=True)

    print("Loading sample users...")
    sample = pd.read_csv(sample_path)
    # Sample repeats user_id 20 times; we only need unique users
    sample_users = sample["user_id"].unique()
    # Map sample users to internal uids
    # Build map from original user_id to uid
    user_orig2uid = {orig: idx for idx, orig in enumerate(uid2orig)}
    sample_uids = []
    for u in sample_users:
        if u in user_orig2uid:
            sample_uids.append(user_orig2uid[u])
        else:
            # shouldn't happen; but keep -1 to handle later
            sample_uids.append(-1)
    sample_uids = np.array(sample_uids, dtype=np.int64)

    return df, uid2orig, iid2orig, sample_users, sample_uids

# -----------------------------
# Split: last 7 days as validation holdout
# -----------------------------
def time_split(df, holdout_days=7):
    max_day = int(df["date"].max())
    val_start = max_day - (holdout_days - 1)
    train_mask = df["date"] < val_start
    val_mask = df["date"] >= val_start

    train_df = df[train_mask].copy()
    val_df = df[val_mask].copy()

    # Ground truth for validation: set of items per user in holdout period
    gt_val = {}
    for uid, grp in tqdm(val_df.groupby("uid", sort=False), desc="Build val GT"):
        gt_val[uid] = set(grp["iid"].tolist())

    return train_df, val_df, gt_val

# -----------------------------
# Build CSR matrix for implicit ALS
# -----------------------------
def build_user_item_csr(df, n_users, n_items):
    rows = df["uid"].values.astype(np.int32)
    cols = df["iid"].values.astype(np.int32)
    data = np.ones_like(rows, dtype=np.float32)
    ui = sp.coo_matrix((data, (rows, cols)), shape=(n_users, n_items), dtype=np.float32).tocsr()
    return ui

# -----------------------------
# ALS Retrieval
# -----------------------------
def train_als(user_item_csr, use_gpu=True, factors=128, reg=1e-4, iters=20):
    try:
        import implicit
        from implicit.nearest_neighbours import bm25_weight
        ui = bm25_weight(user_item_csr, K1=1.2, B=0.75).tocsr()
        gc.collect()

        if use_gpu:
            #try:
                from implicit.gpu.als import AlternatingLeastSquares as GPU_ALS
                als = GPU_ALS(factors=factors, regularization=reg, iterations=iters, random_state=42)
                print("Training ALS (GPU)...")
                als.fit(ui.T, show_progress=True)
                return als, ui
            #except Exception as e:
                print(f"GPU ALS not available ({e}), falling back to CPU.")
        from implicit.als import AlternatingLeastSquares as CPU_ALS
        als = CPU_ALS(factors=factors, regularization=reg, iterations=iters, use_cg=True, random_state=42)
        print("Training ALS (CPU)...")
        als.fit(ui.T, show_progress=True)
        return als, ui
    except Exception as e:
        print(f"Implicit not installed or failed: {e}")
        return None, user_item_csr

def als_recommend_batch(model, user_item_csr, user_ids, N=300, filter_seen=False, batch_size=1024):
    recs = {}
    desc = "ALS recommend"
    for i in tqdm(range(0, len(user_ids), batch_size), desc=desc):
        batch = user_ids[i: i+batch_size]
        for u in batch:
            if u < 0:
                recs[u] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
                continue
            try:
                ids, scores = model.recommend(
                    userid=int(u),
                    user_items=user_item_csr,
                    N=N,
                    recalculate_user=True,
                    filter_already_liked_items=filter_seen
                )
                recs[u] = (ids.astype(np.int32), scores.astype(np.float32))
            except Exception:
                recs[u] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
    return recs

# -----------------------------
# FAISS index for item similarity (ALS item factors)
# -----------------------------
def build_faiss_index(item_factors):
    try:
        import faiss
    except Exception as e:
        print(f"FAISS not available: {e}")
        return None, None

    d = item_factors.shape[1]
    try:
        res = faiss.StandardGpuResources() if torch.cuda.is_available() else None
    except Exception:
        res = None

    index = faiss.IndexFlatIP(d)
    if res is not None:
        index = faiss.index_cpu_to_gpu(res, 0, index)
    # Optionally normalize for cosine similarity; but ALS uses IP
    # faiss.normalize_L2(item_factors)
    index.add(item_factors.astype(np.float32))
    return index, res

def similar_items_for_users(index, item_factors, user_histories, topk_per_item=50, last_k=10):
    """
    user_histories: dict uid -> np.array of recent item ids (0-based)
    returns dict uid -> (item_ids, scores) aggregated over last_k items
    """
    try:
        import faiss
    except Exception:
        index = None

    recs = {}
    if index is None:
        return recs

    for uid, hist in tqdm(user_histories.items(), desc="FAISS item2item"):
        if hist.size == 0:
            recs[uid] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
            continue
        queries = hist[-last_k:]
        # Get neighbors for each query
        # Note: include self in results; we will drop identical ids
        distances, neighbors = index.search(item_factors[queries].astype(np.float32), k=topk_per_item+1)
        # Aggregate by sum of scores, exclude original items
        agg = defaultdict(float)
        hist_set = set(queries.tolist())
        for row in range(neighbors.shape[0]):
            for j in range(neighbors.shape[1]):
                it = int(neighbors[row, j])
                if it in hist_set:
                    continue
                agg[it] += float(distances[row, j])
        if not agg:
            recs[uid] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
        else:
            items = np.fromiter(agg.keys(), dtype=np.int32)
            scores = np.fromiter(agg.values(), dtype=np.float32)
            order = np.argsort(-scores)
            recs[uid] = (items[order], scores[order])
    return recs

# -----------------------------
# Popularity candidates
# -----------------------------
def popularity_scores(df_train, df_recent_days=7, n_items=None):
    max_day = int(df_train["date"].max())
    recent_start = max_day - (df_recent_days - 1)

    all_cnt = df_train.groupby("iid").size()
    rec_cnt = df_train[df_train["date"] >= recent_start].groupby("iid").size()

    # Align
    if n_items is None:
        n_items = int(df_train["iid"].max()) + 1
    all_vec = np.zeros(n_items, dtype=np.float32)
    rec_vec = np.zeros(n_items, dtype=np.float32)

    all_vec[all_cnt.index.values] = all_cnt.values.astype(np.float32)
    rec_vec[rec_cnt.index.values] = rec_cnt.values.astype(np.float32)

    # Normalize
    all_vec /= (all_vec.max() + 1e-9)
    rec_vec /= (rec_vec.max() + 1e-9)

    # Weighted blend (tuneable)
    pop = 0.3 * all_vec + 0.7 * rec_vec
    return pop

def top_pop_items(pop_scores, topk=500):
    idx = np.argsort(-pop_scores)[:topk]
    return idx, pop_scores[idx]

# -----------------------------
# Build user sequences (offsets) for SASRec
# -----------------------------
def build_user_offsets(df, n_users):
    # df must be sorted by ["uid", "date"]
    uids = df["uid"].values.astype(np.int64)
    iids = df["iid"].values.astype(np.int64)
    # Keep 0-based for ALS; SASRec will use +1 for padding
    counts = np.bincount(uids, minlength=n_users).astype(np.int64)
    offsets = np.zeros(n_users + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts)
    # items array already in user-sorted order
    return uids, iids, offsets

def get_user_last_items(iids, offsets, uids):
    user_hist = {}
    for u in tqdm(np.unique(uids), desc="Collect last items per user"):
        s, e = offsets[u], offsets[u+1]
        if e - s > 0:
            user_hist[u] = iids[s:e]
        else:
            user_hist[u] = np.array([], dtype=np.int64)
    return user_hist

# -----------------------------
# SASRec model
# -----------------------------
class SASRec(nn.Module):
    def __init__(self, n_items_padded, max_len=50, d_model=128, n_heads=4, n_layers=2, dropout=0.2):
        super().__init__()
        self.n_items = n_items_padded
        self.max_len = max_len
        self.item_emb = nn.Embedding(n_items_padded, d_model, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers, norm=nn.LayerNorm(d_model))
        self.dropout = nn.Dropout(dropout)

        nn.init.normal_(self.item_emb.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_emb.weight, mean=0.0, std=0.02)

    def forward(self, seq_ids):
        # seq_ids: (B, L) with 0 = PAD
        B, L = seq_ids.size()
        positions = torch.arange(L, device=seq_ids.device).unsqueeze(0).expand(B, L)
        x = self.item_emb(seq_ids) + self.pos_emb(positions)
        x = self.dropout(x)

        # Causal mask (prevent attending to future positions)
        causal_mask = torch.triu(torch.ones(L, L, device=seq_ids.device), diagonal=1).bool()
        key_padding_mask = (seq_ids == 0)  # True for PAD

        x = self.encoder(x, mask=causal_mask, src_key_padding_mask=key_padding_mask)
        # Gather last non-pad timestep
        lengths = (~key_padding_mask).sum(dim=1).clamp(min=1)  # at least 1
        idx = (lengths - 1).unsqueeze(1).unsqueeze(2).expand(-1, 1, x.size(-1))  # B x 1 x D
        out = x.gather(1, idx).squeeze(1)  # B x D
        return out  # user representation

    def item_embedding(self, item_ids):
        return self.item_emb(item_ids)

# -----------------------------
# SASRec dataset
# -----------------------------
class SequenceDataset(Dataset):
    def __init__(self, items_by_user, offsets, max_len=50, users=None):
        """
        items_by_user: 1D np.array of item ids (0-based), sorted by user/date
        offsets: np.array of size n_users+1
        """
        self.items = items_by_user.astype(np.int64)  # 0-based
        self.offsets = offsets.astype(np.int64)
        self.max_len = max_len
        self.n_users = offsets.size - 1

        if users is None:
            # Only users with at least 2 interactions
            all_users = np.arange(self.n_users, dtype=np.int64)
            lens = self.offsets[1:] - self.offsets[:-1]
            self.users = all_users[lens >= 2]
        else:
            self.users = np.array(users, dtype=np.int64)

        # Shuffle initial order
        self.shuffle_users()

    def shuffle_users(self):
        np.random.shuffle(self.users)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        u = self.users[idx]
        s, e = self.offsets[u], self.offsets[u+1]
        seq = self.items[s:e]  # 0-based
        L = len(seq)
        # Sample a cut point: prefix -> next item
        # t in [1..L-1], target = seq[t]
        t = np.random.randint(1, L)
        prefix = seq[:t]
        target = seq[t]

        # Truncate to last max_len of prefix
        if len(prefix) > self.max_len:
            prefix = prefix[-self.max_len:]
        # Convert to padded space: add 1
        prefix_padded = prefix + 1  # 1..n_items
        # Left-pad with zeros
        pad_len = self.max_len - len(prefix_padded)
        if pad_len > 0:
            prefix_padded = np.pad(prefix_padded, (pad_len, 0), constant_values=0)

        target_padded = int(target + 1)
        return prefix_padded.astype(np.int64), target_padded

def collate_batch(batch):
    seqs_np = np.stack([b[0] for b in batch], axis=0).astype(np.int64, copy=False)
    seqs = torch.from_numpy(seqs_np)
    targets = torch.tensor([b[1] for b in batch], dtype=torch.long)
    return seqs, targets

# -----------------------------
# Train SASRec with in-batch negatives
# -----------------------------
def train_sasrec(model, dataset, epochs=2, batch_size=1024, lr=1e-3, val_eval_fn=None):
    model = model.to(DEVICE)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    steps_per_epoch = math.ceil(len(dataset) / batch_size)
    for epoch in range(1, epochs + 1):
        dataset.shuffle_users()
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2,
                            collate_fn=collate_batch, pin_memory=True)
        model.train()
        pbar = tqdm(loader, total=steps_per_epoch, desc=f"SASRec epoch {epoch}")
        losses = []
        for seqs, targets in pbar:
            seqs = seqs.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)

            user_vec = model(seqs)                       # (B, D)
            pos_embs = model.item_embedding(targets)     # (B, D)
            logits = user_vec @ pos_embs.T               # (B, B), in-batch negatives
            labels = torch.arange(seqs.size(0), device=DEVICE)
            loss = loss_fn(logits, labels)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()

            losses.append(loss.item())
            if len(losses) % 20 == 0:
                pbar.set_postfix(loss=np.mean(losses[-20:]))

        avg_loss = np.mean(losses) if losses else 0.0
        print(f"Epoch {epoch} train loss: {avg_loss:.4f}")

        if val_eval_fn is not None:
            val_map = val_eval_fn(model)
            print(f"Epoch {epoch} VAL mAP@20: {val_map:.5f}")

# -----------------------------
# Candidate merger and reranking
# -----------------------------
def merge_candidates(als_recs, item2item_recs, pop_items, pop_scores, users, topk_merge=500,
                     w_als=1.0, w_i2i=1.0, w_pop=0.3):
    """
    Return dict uid -> np.array of candidate item ids (0-based) with blended scores
    """
    merged = {}
    for u in tqdm(users, desc="Merge candidates"):
        scores = defaultdict(float)
        # ALS
        if als_recs is not None and u in als_recs:
            ids, sc = als_recs[u]
            for iid, s in zip(ids, sc):
                scores[int(iid)] += w_als * float(s)
        # i2i
        if item2item_recs is not None and u in item2item_recs:
            ids, sc = item2item_recs[u]
            for iid, s in zip(ids, sc):
                scores[int(iid)] += w_i2i * float(s)
        # POP
        if pop_items is not None:
            for iid, s in zip(pop_items, pop_scores):
                scores[int(iid)] += w_pop * float(s)

        if not scores:
            merged[u] = np.array([], dtype=np.int32)
        else:
            items = np.fromiter(scores.keys(), dtype=np.int32)
            scs = np.fromiter(scores.values(), dtype=np.float32)
            order = np.argsort(-scs)[:topk_merge]
            merged[u] = items[order]
    return merged

@torch.no_grad()
def rerank_with_sasrec(model, user_candidates, items_by_user, offsets, max_len=50, batch_size=2048):
    """
    user_candidates: dict uid -> np.array of candidates (0-based)
    returns dict uid -> sorted list of item ids (0-based) by SASRec score
    """
    model.eval()
    uids = list(user_candidates.keys())
    results = {}

    # Prebuild user sequences (padded) for batch scoring
    seq_cache = {}
    for u in tqdm(uids, desc="Prepare sequences"):
        s, e = offsets[u], offsets[u+1]
        hist = items_by_user[s:e]  # 0-based
        if hist.size == 0:
            seq_cache[u] = torch.zeros((max_len,), dtype=torch.long)
        else:
            seq = hist[-max_len:] + 1  # to padded space
            pad_len = max_len - len(seq)
            if pad_len > 0:
                seq = np.pad(seq, (pad_len, 0), constant_values=0)
            seq_cache[u] = torch.tensor(seq, dtype=torch.long)

    # Batch over users
    for i in tqdm(range(0, len(uids), batch_size), desc="Re-ranking"):
        batch_u = uids[i: i+batch_size]
        # Build seq batch
        seq_batch = torch.stack([seq_cache[u] for u in batch_u], dim=0).to(DEVICE)
        user_vecs = model(seq_batch)  # (B, D)

        # For each user, score their candidates
        for j, u in enumerate(batch_u):
            cands = user_candidates[u]
            if cands.size == 0:
                results[u] = []
                continue
            cands_pad = torch.tensor(cands + 1, dtype=torch.long, device=DEVICE)  # to padded idx
            item_embs = model.item_embedding(cands_pad)  # (C, D)
            logits = torch.mv(item_embs, user_vecs[j])   # (C,)
            # sort by scores desc
            order = torch.argsort(logits, descending=True).detach().cpu().numpy()
            ranked = cands[order]
            results[u] = ranked.tolist()
    return results

# -----------------------------
# End-to-end
# -----------------------------
def main(train_path="/kaggle/input/stupidshit777/train_data.pq", sample_path="/kaggle/input/stupidshit777/sample_submission (11).csv", submission_path="submission1.csv",
         holdout_days=7, max_seq_len=50):
    # Load
    df, uid2orig, iid2orig, sample_users_orig, sample_uids = load_data(train_path, sample_path)
    n_users = int(df["uid"].max()) + 1
    n_items = int(df["iid"].max()) + 1

    # Split
    train_df, val_df, gt_val = time_split(df, holdout_days=holdout_days)
    print(f"Train interactions: {len(train_df):,}, Val interactions: {len(val_df):,}")
    del val_df
    gc.collect()

    # CSR for ALS
    user_item_csr = build_user_item_csr(train_df, n_users, n_items)

    # Train ALS
    als_model, weighted_ui = train_als(user_item_csr, use_gpu=torch.cuda.is_available(),
                                       factors=128, reg=1e-4, iters=20)
    als_item_factors = None
    if als_model is not None:
        try:
            als_item_factors = als_model.item_factors  # (n_items, d)
        except:
            als_item_factors = None

    # Build user sequences for SASRec
    uids_sorted, iids_sorted, offsets = build_user_offsets(train_df, n_users)
    # For item2item (FAISS), we need last items per user (we'll only for users we care about)
    # We'll restrict to val users and test users to speed up
    users_for_hist = np.unique(np.concatenate([np.array(list(gt_val.keys()), dtype=np.int64),
                                               sample_uids[sample_uids >= 0]]))
    user_histories = {}
    for u in tqdm(users_for_hist, desc="Build histories for target users"):
        s, e = offsets[u], offsets[u+1]
        user_histories[u] = iids_sorted[s:e]

    # Popularity
    pop_scores_vec = popularity_scores(train_df, df_recent_days=7, n_items=n_items)
    pop_items, pop_scores = top_pop_items(pop_scores_vec, topk=300)

    # ALS candidates
    target_users_val = np.array(list(gt_val.keys()), dtype=np.int64)
    als_val_recs = als_recommend_batch(als_model, weighted_ui if als_model is not None else user_item_csr,
                                       target_users_val, N=300, filter_seen=False, batch_size=2048) if als_model else {}

    # FAISS item2item candidates
    faiss_index, _ = build_faiss_index(als_item_factors) if als_item_factors is not None else (None, None)
    i2i_val_recs = similar_items_for_users(faiss_index, als_item_factors, user_histories,
                                           topk_per_item=80, last_k=10) if faiss_index is not None else {}

    # Merge candidates for VAL
    cand_val = merge_candidates(als_val_recs, i2i_val_recs, pop_items, pop_scores, target_users_val,
                                topk_merge=500, w_als=1.0, w_i2i=1.0, w_pop=0.3)

    # SASRec training dataset
    sasrec_items = iids_sorted  # 0-based
    sasrec_dataset = SequenceDataset(sasrec_items, offsets, max_len=max_seq_len)

    # Model
    sasrec_model = SASRec(n_items_padded=n_items + 1, max_len=max_seq_len, d_model=128, n_heads=4, n_layers=2, dropout=0.2)

    # Define validation evaluation function (rerank val candidates -> mAP@20)
    def eval_val(model):
        preds_val = rerank_with_sasrec(model, cand_val, sasrec_items, offsets, max_len=max_seq_len, batch_size=2048)
        # Convert to simple dict uid->list (0-based)
        # Compute mAP@20
        return map_at_k(gt_val, preds_val, k=20)

    # Train SASRec with val monitoring
    train_sasrec(sasrec_model, sasrec_dataset, epochs=2, batch_size=1024, lr=3e-4, val_eval_fn=eval_val)

    # Build candidates for TEST users (from sample)
    test_users = sample_uids[sample_uids >= 0]
    test_users = np.unique(test_users)
    # ALS
    als_test_recs = als_recommend_batch(als_model, weighted_ui if als_model is not None else user_item_csr,
                                        test_users, N=300, filter_seen=False, batch_size=4096) if als_model else {}
    # i2i
    # Ensure we have histories
    for u in test_users:
        if u not in user_histories:
            s, e = offsets[u], offsets[u+1]
            user_histories[u] = iids_sorted[s:e]
    i2i_test_recs = similar_items_for_users(faiss_index, als_item_factors, user_histories,
                                            topk_per_item=80, last_k=10) if faiss_index is not None else {}
    # Merge
    cand_test = merge_candidates(als_test_recs, i2i_test_recs, pop_items, pop_scores, test_users,
                                 topk_merge=500, w_als=1.0, w_i2i=1.0, w_pop=0.3)

    # Rerank for TEST
    preds_test = rerank_with_sasrec(sasrec_model, cand_test, sasrec_items, offsets, max_len=max_seq_len, batch_size=4096)

    # Prepare submission: sample has 20 rows per user
    print("Writing submission...")
    sample = pd.read_csv(sample_path)
    
    # Явно приводим маппинг к int64
    user_orig2uid = {int(orig): int(idx) for idx, orig in enumerate(uid2orig)}
    item_idx2orig = np.asarray(iid2orig, dtype=np.int64)  # только int64
    
    def ensure_int_ids(arr, name=""):
        a = np.asarray(arr)
        if not np.issubdtype(a.dtype, np.integer):
            raise TypeError(f"{name}: ожидаю item_id (int), получил {a.dtype}. Пример: {a[:10]}")
        return a.astype(np.int64, copy=False)
    
    # кеш предсказаний
    pop20 = top_pop_items(pop_scores_vec, topk=20)[0]
    pop20_orig = item_idx2orig[pop20]
    pred_item_cache = {}
    
    for u_orig in tqdm(sample["user_id"].unique(), desc="Cache predictions"):
        u_orig = int(u_orig)
        uid = user_orig2uid.get(u_orig, -1)
        if uid == -1:
            pred_item_cache[u_orig] = pop20_orig
            continue
    
        ranked0 = preds_test.get(uid, None)
        if ranked0 is None or len(ranked0) == 0:
            ranked0 = cand_test.get(uid, np.array([], dtype=np.int32)).tolist()
    
        # дополняем популярными без повторов
        seen = set(ranked0)
        for it in pop_items:
            if len(ranked0) >= 20: break
            it = int(it)
            if it not in seen:
                ranked0.append(it); seen.add(it)
        ranked0 = np.array(ranked0[:20], dtype=np.int64)
    
        # маппим к оригинальным id и валидируем
        pred_item_cache[u_orig] = ensure_int_ids(item_idx2orig[ranked0], name=f"user {u_orig}")
    
    # аккуратно раскладываем по 20 строк каждого юзера
    from collections import defaultdict
    counters = defaultdict(int)
    out_items = np.empty(len(sample), dtype=np.int64)
    
    for idx, u in tqdm(enumerate(sample["user_id"].values), total=len(sample), desc="Fill submission"):
        u = int(u)
        i = counters[u]
        arr = pred_item_cache[u]
        if i >= len(arr):
            # подстраховка
            arr = pop20_orig
            pred_item_cache[u] = arr
            i = 0
        out_items[idx] = int(arr[i])
        counters[u] += 1
    
    sample["item_id"] = out_items.astype(np.int64)
    # финальная гарантия типа
    assert np.issubdtype(sample["item_id"].dtype, np.integer), sample["item_id"].dtype
    sample.to_csv(submission_path, index=False)
    print(f"Saved submission to {submission_path}")

if __name__ == "__main__":
    main()

Reading parquet with Polars...
Mapping ids to contiguous indices...
Loading sample users...


Build val GT:   0%|          | 0/592309 [00:00<?, ?it/s]

Train interactions: 7,326,431, Val interactions: 1,451,544
Implicit not installed or failed: No module named 'implicit'


Build histories for target users:   0%|          | 0/782173 [00:00<?, ?it/s]

Merge candidates:   0%|          | 0/592309 [00:00<?, ?it/s]

SASRec epoch 1:   0%|          | 0/1145 [00:00<?, ?it/s]

Epoch 1 train loss: 6.9032


Prepare sequences:   0%|          | 0/592309 [00:00<?, ?it/s]

Re-ranking:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 1 VAL mAP@20: 0.04629


SASRec epoch 2:   0%|          | 0/1145 [00:00<?, ?it/s]

Epoch 2 train loss: 6.8856


Prepare sequences:   0%|          | 0/592309 [00:00<?, ?it/s]

Re-ranking:   0%|          | 0/290 [00:00<?, ?it/s]

Epoch 2 VAL mAP@20: 0.04629


Merge candidates:   0%|          | 0/293230 [00:00<?, ?it/s]

Prepare sequences:   0%|          | 0/293230 [00:00<?, ?it/s]

Re-ranking:   0%|          | 0/72 [00:00<?, ?it/s]

Writing submission...


Cache predictions:   0%|          | 0/293230 [00:00<?, ?it/s]

Fill submission:   0%|          | 0/5864600 [00:00<?, ?it/s]

Saved submission to submission1.csv


In [3]:
import os, gc, math, random
import numpy as np
import pandas as pd
import polars as pl
from tqdm import tqdm
from collections import defaultdict, Counter

# -----------------------------
# Utils
# -----------------------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)

set_seed(42)

def average_precision_at_k(true_items_set, ranked_items, k=20):
    if not true_items_set:
        return 0.0
    score = 0.0
    hits = 0
    for i, item in enumerate(ranked_items[:k]):
        if item in true_items_set:
            hits += 1
            score += hits / (i + 1)
    return score / min(len(true_items_set), k)

def map_at_k(gt_by_user, preds_by_user, k=20):
    aps = []
    for u, true_set in gt_by_user.items():
        preds = preds_by_user.get(u, [])
        aps.append(average_precision_at_k(true_set, preds, k=k))
    return float(np.mean(aps)) if aps else 0.0

# -----------------------------
# Load & factorize
# -----------------------------
def load_data(train_path="train_data.pq", sample_path="sample_submission.csv"):
    print("Reading parquet with Polars...")
    dfpl = pl.read_parquet(train_path)
    dfpl = dfpl.with_columns([
        pl.col("user_id").cast(pl.Int64),
        pl.col("item_id").cast(pl.Int64),
        pl.col("date").cast(pl.Int32)
    ])
    df = dfpl.to_pandas()
    del dfpl; gc.collect()

    print("Mapping ids to contiguous indices...")
    df["uid"], user_uniques = pd.factorize(df["user_id"], sort=True)
    df["iid"], item_uniques = pd.factorize(df["item_id"], sort=True)

    uid2orig = np.asarray(user_uniques, dtype=np.int64)
    iid2orig = np.asarray(item_uniques, dtype=np.int64)

    # Sort by user/date
    df.sort_values(["uid", "date"], inplace=True)
    df.reset_index(drop=True, inplace=True)

    print("Loading sample users...")
    sample = pd.read_csv(sample_path)
    sample_users = sample["user_id"].unique().astype(np.int64)

    # Map sample users to internal uids
    user_orig2uid = {int(orig): int(i) for i, orig in enumerate(uid2orig)}
    sample_uids = np.array([user_orig2uid.get(int(u), -1) for u in sample_users], dtype=np.int64)

    return df, uid2orig, iid2orig, sample_users, sample_uids

# -----------------------------
# Split: last 7 days as validation holdout (for tuning)
# -----------------------------
def time_split(df, holdout_days=7):
    max_day = int(df["date"].max())
    val_start = max_day - (holdout_days - 1)
    train_mask = df["date"] < val_start
    val_mask = ~train_mask
    train_df = df[train_mask].copy()
    val_df = df[val_mask].copy()

    gt_val = {}
    for uid, grp in tqdm(val_df.groupby("uid", sort=False), desc="Build val GT"):
        gt_val[int(uid)] = set(grp["iid"].astype(int).tolist())
    return train_df, val_df, gt_val

# -----------------------------
# Build offsets (CSR-like)
# -----------------------------
def build_user_offsets(df, n_users):
    uids = df["uid"].values.astype(np.int64)
    iids = df["iid"].values.astype(np.int64)
    dates = df["date"].values.astype(np.int32)

    counts = np.bincount(uids, minlength=n_users).astype(np.int64)
    offsets = np.zeros(n_users + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts)
    return uids, iids, dates, offsets

# -----------------------------
# Popularity (global)
# -----------------------------
def popularity_scores(df, recent_days=7, n_items=None):
    max_day = int(df["date"].max())
    recent_start = max_day - (recent_days - 1)

    all_cnt = df.groupby("iid").size()
    rec_cnt = df[df["date"] >= recent_start].groupby("iid").size()

    if n_items is None:
        n_items = int(df["iid"].max()) + 1
    all_vec = np.zeros(n_items, dtype=np.float32)
    rec_vec = np.zeros(n_items, dtype=np.float32)

    all_vec[all_cnt.index.values] = all_cnt.values.astype(np.float32)
    rec_vec[rec_cnt.index.values] = rec_cnt.values.astype(np.float32)

    all_vec /= (all_vec.max() + 1e-9)
    rec_vec /= (rec_vec.max() + 1e-9)

    pop = 0.3 * all_vec + 0.7 * rec_vec
    return pop

def top_pop_items(pop_scores, topk=500):
    idx = np.argsort(-pop_scores)[:topk]
    return idx, pop_scores[idx]

# -----------------------------
# Co-visitation graph (pure counting, no ML)
# -----------------------------
def build_covisit_adjacency(iids, dates, offsets,
                            max_last_per_user=60, window=20,
                            time_decay=0.08,  # per day
                            symmetric=True,
                            cap_neighbors=None):
    """
    Returns: dict[item] -> dict[neighbor] -> score
    Weights:
      w = (1 / gap) * exp(-time_decay * (max_day - date_j))
    Only last `max_last_per_user` interactions per user, pairs within `window`.
    """
    n_users = offsets.size - 1
    max_day = int(dates.max()) if len(dates) > 0 else 0
    adj = defaultdict(lambda: defaultdict(float))

    print(f"Building co-vis adjacency: max_last_per_user={max_last_per_user}, window={window}")
    for u in tqdm(range(n_users), desc="Covis per user"):
        s, e = int(offsets[u]), int(offsets[u+1])
        m = e - s
        if m <= 1:
            continue
        start = max(s, e - max_last_per_user)
        seq = iids[start:e]      # item ids
        dts = dates[start:e]     # days
        L = len(seq)
        # iterate sequential pairs within window
        for i in range(L):
            item_i = int(seq[i])
            di = int(dts[i])
            # j after i
            max_j = min(L, i + 1 + window)
            for j in range(i + 1, max_j):
                item_j = int(seq[j])
                dj = int(dts[j])
                gap = j - i
                w = (1.0 / gap) * math.exp(-time_decay * (max_day - dj))
                adj[item_i][item_j] += w
                if symmetric:
                    adj[item_j][item_i] += w

    # Optional pruning per item: keep only top cap_neighbors
    if cap_neighbors is not None and cap_neighbors > 0:
        print(f"Pruning adjacency to top-{cap_neighbors} neighbors per item...")
        for it in tqdm(list(adj.keys()), desc="Prune neighbors"):
            neigh = adj[it]
            if len(neigh) > cap_neighbors:
                # keep top by score
                top = sorted(neigh.items(), key=lambda kv: kv[1], reverse=True)[:cap_neighbors]
                adj[it] = defaultdict(float, top)

    return adj

# -----------------------------
# Personalized repeat (no ML)
# -----------------------------
def user_repeat_scores(hist_items, hist_dates, max_day, time_decay=0.08):
    """
    Sum exp(-time_decay * age_in_days) per item in user's history.
    Returns dict[item] -> score
    """
    rep = defaultdict(float)
    for it, d in zip(hist_items, hist_dates):
        rep[int(it)] += math.exp(-time_decay * (max_day - int(d)))
    return rep

# -----------------------------
# Recommend for a batch of users (algorithmic)
# -----------------------------
def recommend_batch(users, offsets, iids, dates, adj, pop_items, pop_scores,
                    last_k=10, topk=20,
                    w_covis=1.0, w_seed_recency=0.85,  # seed recency downweight per step back
                    w_repeat=0.3, w_pop=0.2,
                    cap_candidates=600):
    """
    Returns dict uid -> list of item ids (0-based), length topk
    """
    results = {}
    max_day = int(dates.max()) if len(dates) > 0 else 0
    for u in tqdm(users, desc="Recommend users"):
        u = int(u)
        s, e = int(offsets[u]), int(offsets[u+1])
        if e - s == 0:
            # cold-start: top-pop
            cand = list(map(int, pop_items[:topk]))
            results[u] = cand
            continue

        hist_items = iids[s:e]
        hist_dates = dates[s:e]
        # seeds: last_k items (most recent at end)
        seeds = hist_items[-last_k:]
        seed_recency_weights = []
        for idx in range(len(seeds)):
            # last seed gets 1.0, previous gets w_seed_recency, etc.
            power = len(seeds) - 1 - idx
            seed_recency_weights.append(w_seed_recency ** power)

        scores = defaultdict(float)

        # 1) co-vis neighbors aggregated
        for seed, srw in zip(seeds, seed_recency_weights):
            seed = int(seed)
            neigh = adj.get(seed, None)
            if not neigh:
                continue
            for nb, w in neigh.items():
                scores[int(nb)] += w_covis * srw * float(w)

        # 2) repeat booster (time-decayed personal popularity)
        rep = user_repeat_scores(hist_items, hist_dates, max_day, time_decay=0.08)
        for it, w in rep.items():
            scores[int(it)] += w_repeat * float(w)

        # 3) global pop as fallback/regularizer
        for it, w in zip(pop_items, pop_scores):
            scores[int(it)] += w_pop * float(w)

        # collect top candidates
        if not scores:
            cand = list(map(int, pop_items[:topk]))
            results[u] = cand
        else:
            # take top by score
            items = np.fromiter(scores.keys(), dtype=np.int32)
            scs = np.fromiter(scores.values(), dtype=np.float32)
            order = np.argsort(-scs)[:max(cap_candidates, topk)]
            ranked = items[order]
            # Dedup and cut to topk
            unique_ranked = []
            seen = set()
            for it in ranked:
                it = int(it)
                if it in seen:
                    continue
                seen.add(it)
                unique_ranked.append(it)
                if len(unique_ranked) >= topk:
                    break
            # if still short, pad with pop
            if len(unique_ranked) < topk:
                for it in pop_items:
                    it = int(it)
                    if it in seen:
                        continue
                    unique_ranked.append(it)
                    seen.add(it)
                    if len(unique_ranked) >= topk:
                        break
            results[u] = unique_ranked[:topk]
    return results

# -----------------------------
# End-to-end
# -----------------------------
def main(train_path="train_data.pq", sample_path="sample_submission.csv", submission_path="submission.csv",
         holdout_days=7,
         max_last_per_user=60, window=20, covis_cap=300,
         last_k=10, topk_pred=20):

    # Load
    df, uid2orig, iid2orig, sample_users_orig, sample_uids = load_data(train_path, sample_path)
    n_users = int(df["uid"].max()) + 1
    n_items = int(df["iid"].max()) + 1

    # Split for evaluation/tuning
    train_df, val_df, gt_val = time_split(df, holdout_days=holdout_days)
    print(f"Train interactions: {len(train_df):,}, Val interactions: {len(val_df):,}")

    # Build arrays for TRAIN
    uids_tr, iids_tr, dates_tr, offsets_tr = build_user_offsets(train_df, n_users)

    # Popularity on TRAIN
    pop_scores_tr = popularity_scores(train_df, recent_days=7, n_items=n_items)
    pop_items_tr, pop_scores_vals_tr = top_pop_items(pop_scores_tr, topk=500)

    # Co-vis adjacency on TRAIN
    covis_adj_tr = build_covisit_adjacency(
        iids=iids_tr, dates=dates_tr, offsets=offsets_tr,
        max_last_per_user=max_last_per_user, window=window,
        time_decay=0.08, symmetric=True, cap_neighbors=covis_cap
    )

    # Recommend for VAL users (uids present in GT)
    val_users = np.array(list(gt_val.keys()), dtype=np.int64)
    preds_val = recommend_batch(
        users=val_users, offsets=offsets_tr, iids=iids_tr, dates=dates_tr, adj=covis_adj_tr,
        pop_items=pop_items_tr, pop_scores=pop_scores_vals_tr,
        last_k=last_k, topk=topk_pred,
        w_covis=1.0, w_seed_recency=0.85, w_repeat=0.3, w_pop=0.2,
        cap_candidates=600
    )
    val_map = map_at_k(gt_val, preds_val, k=20)
    print(f"Validation mAP@20 (algorithmic): {val_map:.5f}")

    # -----------------------------
    # Refit on FULL (train + last 7 days) for final predictions
    # -----------------------------
    print("Refit on FULL data (including last 7 days)...")
    uids_full, iids_full, dates_full, offsets_full = build_user_offsets(df, n_users)

    pop_scores_full = popularity_scores(df, recent_days=7, n_items=n_items)
    pop_items_full, pop_scores_vals_full = top_pop_items(pop_scores_full, topk=500)

    covis_adj_full = build_covisit_adjacency(
        iids=iids_full, dates=dates_full, offsets=offsets_full,
        max_last_per_user=max_last_per_user, window=window,
        time_decay=0.08, symmetric=True, cap_neighbors=covis_cap
    )

    # Predict for TEST users (from sample)
    test_users = np.unique(sample_uids[sample_uids >= 0])
    preds_test = recommend_batch(
        users=test_users, offsets=offsets_full, iids=iids_full, dates=dates_full, adj=covis_adj_full,
        pop_items=pop_items_full, pop_scores=pop_scores_vals_full,
        last_k=last_k, topk=topk_pred,
        w_covis=1.0, w_seed_recency=0.85, w_repeat=0.3, w_pop=0.2,
        cap_candidates=600
    )

    # -----------------------------
    # Build submission
    # -----------------------------
    print("Writing submission...")
    sample = pd.read_csv(sample_path)
    user_orig2uid = {int(orig): int(i) for i, orig in enumerate(uid2orig)}
    item_idx2orig = np.asarray(iid2orig, dtype=np.int64)

    # Precompute per-user 20 items (orig ids)
    pred_item_cache = {}
    pop20_orig = item_idx2orig[top_pop_items(pop_scores_full, topk=20)[0]]

    for u_orig in tqdm(sample["user_id"].unique(), desc="Cache predictions"):
        u_orig = int(u_orig)
        uid = user_orig2uid.get(u_orig, -1)
        if uid == -1:
            pred_item_cache[u_orig] = pop20_orig
            continue
        ranked0 = preds_test.get(uid, [])
        if not ranked0:
            ranked0 = list(map(int, pop_items_full[:topk_pred]))
        ranked0 = np.asarray(ranked0[:topk_pred], dtype=np.int64)
        pred_item_cache[u_orig] = item_idx2orig[ranked0]

    # Fill submission: 20 rows per user in order
    from collections import defaultdict
    counters = defaultdict(int)
    out_items = np.empty(len(sample), dtype=np.int64)

    for idx, u in tqdm(enumerate(sample["user_id"].values), total=len(sample), desc="Fill submission"):
        u = int(u)
        i = counters[u]
        arr = pred_item_cache[u]
        if i >= len(arr):
            arr = pop20_orig
            pred_item_cache[u] = arr
            i = 0
        out_items[idx] = int(arr[i])
        counters[u] += 1

    sample["item_id"] = out_items
    # sanity: ensure integer ids
    assert np.issubdtype(sample["item_id"].dtype, np.integer)
    sample.to_csv(submission_path, index=False)
    print(f"Saved submission to {submission_path}")

if __name__ == "__main__":
    # Пример запуска: пути подставь свои
    main(
        train_path="/kaggle/input/stupidshit777/train_data.pq",
        sample_path="/kaggle/input/stupidshit777/sample_submission (11).csv",
        submission_path="submission.csv",
        holdout_days=7,
        max_last_per_user=60,
        window=20,
        covis_cap=300,
        last_k=10,
        topk_pred=20
    )

Reading parquet with Polars...
Mapping ids to contiguous indices...
Loading sample users...


Build val GT: 100%|██████████| 592309/592309 [00:44<00:00, 13328.33it/s]


Train interactions: 7,326,431, Val interactions: 1,451,544
Building co-vis adjacency: max_last_per_user=60, window=20


Covis per user: 100%|██████████| 2682603/2682603 [00:48<00:00, 55529.50it/s] 


Pruning adjacency to top-300 neighbors per item...


Prune neighbors: 100%|██████████| 620175/620175 [00:06<00:00, 89229.58it/s] 
Recommend users: 100%|██████████| 592309/592309 [03:13<00:00, 3053.18it/s] 


Validation mAP@20 (algorithmic): 0.05538
Refit on FULL data (including last 7 days)...
Building co-vis adjacency: max_last_per_user=60, window=20


Covis per user: 100%|██████████| 2682603/2682603 [00:59<00:00, 44964.31it/s] 


Pruning adjacency to top-300 neighbors per item...


Prune neighbors: 100%|██████████| 700420/700420 [00:09<00:00, 76351.91it/s] 
Recommend users: 100%|██████████| 293230/293230 [04:07<00:00, 1185.95it/s]


Writing submission...


Cache predictions: 100%|██████████| 293230/293230 [00:01<00:00, 204217.93it/s]
Fill submission: 100%|██████████| 5864600/5864600 [00:04<00:00, 1317464.25it/s]


Saved submission to submission.csv


In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict
import warnings
from tqdm import tqdm

warnings.filterwarnings('ignore')

class TShoppingRecommender:
    def __init__(self, recent_days_window=3):
        self.user_recent_clicks = defaultdict(list)
        self.global_popularity = defaultdict(int)
        self.last_day = None
        self.recent_days_window = recent_days_window
        
    def fit(self, train_data):
        """Обучение модели на исторических данных"""
        print("Обработка тренировочных данных...")
        
        # Находим последний день в данных
        self.last_day = train_data['date'].max()
        print(f"Последний день в данных: {self.last_day}")
        print(f"Используется окно из {self.recent_days_window} последних дней")
        
        # Собираем клики пользователей за последние N дней
        recent_days = [self.last_day - i for i in range(self.recent_days_window)]
        recent_data = train_data[train_data['date'].isin(recent_days)]
        
        print(f"Найдено взаимодействий за последние {self.recent_days_window} дней: {len(recent_data)}")
        
        print("Сбор персональных предпочтений...")
        # Используем tqdm для прогресса
        for _, row in tqdm(recent_data.iterrows(), total=len(recent_data), desc="Обработка кликов"):
            user_id = row['user_id']
            item_id = row['item_id']
            date = row['date']
            
            # Сохраняем клики с учетом даты (более поздние имеют больший вес)
            self.user_recent_clicks[user_id].append((item_id, date))
        
        print("Сортировка кликов по дате...")
        # Сортируем клики каждого пользователя по дате (сначала последние)
        for user_id in tqdm(self.user_recent_clicks.keys(), desc="Сортировка пользователей"):
            self.user_recent_clicks[user_id].sort(key=lambda x: x[1], reverse=True)
        
        # Собираем глобальную популярность товаров за последний день
        last_day_data = train_data[train_data['date'] == self.last_day]
        item_counts = last_day_data['item_id'].value_counts()
        
        print("Расчет глобальной популярности...")
        for item_id, count in tqdm(item_counts.items(), total=len(item_counts), desc="Популярные товары"):
            self.global_popularity[item_id] = count
            
        print(f"Обработано пользователей: {len(self.user_recent_clicks)}")
        print(f"Уникальных популярных товаров: {len(self.global_popularity)}")
    
    def predict(self, user_ids, top_k=20):
        """Предсказание топ-K товаров для каждого пользователя"""
        print("Генерация рекомендаций...")
        
        # Создаем глобальный рейтинг популярности
        global_ranking = list(self.global_popularity.keys())
        
        submissions = []
        
        # Используем tqdm для прогресса предсказания
        for user_id in tqdm(user_ids, desc="Генерация рекомендаций"):
            user_recommendations = []
            
            # Берем уникальные товары из последних кликов пользователя
            if user_id in self.user_recent_clicks:
                recent_items = []
                for item_id, date in self.user_recent_clicks[user_id]:
                    if item_id not in recent_items:
                        recent_items.append(item_id)
                    if len(recent_items) >= top_k:
                        break
                
                user_recommendations.extend(recent_items)
            
            # Если недостаточно персональных рекомендаций, добавляем глобально популярные
            if len(user_recommendations) < top_k:
                # Берем только те товары, которых еще нет в рекомендациях
                additional_items = [item for item in global_ranking 
                                  if item not in user_recommendations]
                
                # Добавляем столько, сколько нужно до top_k
                needed = top_k - len(user_recommendations)
                user_recommendations.extend(additional_items[:needed])
            
            # Обеспечиваем, что рекомендаций ровно top_k
            user_recommendations = user_recommendations[:top_k]
            
            # Добавляем в submission
            for item_id in user_recommendations:
                submissions.append({
                    'user_id': user_id,
                    'item_id': item_id
                })
        
        return pd.DataFrame(submissions)

def main():
    # Загрузка данных
    print("Загрузка данных...")
    train_data = pd.read_parquet('/kaggle/input/stupidshit777/train_data.pq')
    
    # Загрузка sample submission для получения списка пользователей
    sample_submission = pd.read_csv('/kaggle/input/stupidshit777/sample_submission (11).csv')
    test_users = sample_submission['user_id'].unique()
    
    print(f"Всего пользователей в тесте: {len(test_users)}")
    print(f"Всего взаимодействий в тренировочных данных: {len(train_data)}")
    print(f"Уникальных пользователей: {train_data['user_id'].nunique()}")
    print(f"Уникальных товаров: {train_data['item_id'].nunique()}")
    
    # Настройка параметров
    RECENT_DAYS_WINDOW = 17  # Можно изменить на 3, 5, 10 и т.д.
    
    # Обучение модели
    model = TShoppingRecommender(recent_days_window=RECENT_DAYS_WINDOW)
    model.fit(train_data)
    
    # Предсказание
    submission_df = model.predict(test_users, top_k=20)
    
    # Проверка формата
    print(f"\nПроверка формата submission:")
    print(f"Всего строк в submission: {len(submission_df)}")
    print(f"Ожидалось: {len(test_users) * 20}")
    
    # Сохранение результатов
    submission_df.to_csv('submission.csv', index=False)
    print("\nSubmission сохранен в submission.csv")
    
    # Пример рекомендаций для первых 5 пользователей
    print("\nПример рекомендаций для первых 5 пользователей:")
    for i, user_id in enumerate(test_users[:5]):
        user_recs = submission_df[submission_df['user_id'] == user_id]['item_id'].tolist()
        print(f"Пользователь {user_id}: {user_recs[:5]}... (всего {len(user_recs)} рекомендаций)")
    
    # Статистика по рекомендациям
    print(f"\nСтатистика:")
    print(f"Использовано дней для персональных предпочтений: {RECENT_DAYS_WINDOW}")
    print(f"Размер глобального пула популярных товаров: {len(model.global_popularity)}")

if __name__ == "__main__":
    main()

Reading parquet with Polars...
Mapping ids to contiguous indices...
Loading sample users...


Build val GT: 100%|██████████| 592309/592309 [00:45<00:00, 13076.60it/s]


Train interactions: 7,326,431, Val interactions: 1,451,544
Building co-vis graph: last=100, short_w=8, long_w=25


Covis per user: 100%|██████████| 2682603/2682603 [01:27<00:00, 30510.86it/s] 
Prune/sort neighbors: 100%|██████████| 634257/634257 [00:23<00:00, 26516.17it/s] 
Prune/sort neighbors: 100%|██████████| 600257/600257 [00:16<00:00, 35755.27it/s] 
Recommend users: 100%|██████████| 592309/592309 [1:03:47<00:00, 154.74it/s] 


Validation mAP@20 (algorithmic): 0.04337
Refit on FULL data (including last 7 days)...
Building co-vis graph: last=100, short_w=8, long_w=25


Covis per user: 100%|██████████| 2682603/2682603 [01:48<00:00, 24667.76it/s] 
Prune/sort neighbors: 100%|██████████| 717614/717614 [00:32<00:00, 22353.08it/s] 
Prune/sort neighbors: 100%|██████████| 681615/681615 [00:19<00:00, 35870.99it/s] 
Recommend users:   0%|          | 187/293230 [00:05<2:19:51, 34.92it/s]

In [None]:
4

In [None]:
# -----------------------------
# Global accelerators
# -----------------------------
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
    from torch.backends.cuda import sdp_kernel
    sdp_kernel.enable_flash_sdp(True)
    sdp_kernel.enable_mem_efficient_sdp(True)
    sdp_kernel.enable_math_sdp(False)
except Exception:
    pass
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

In [None]:
# -----------------------------
# RP3beta Retrieval (implicit CPU)
# -----------------------------
def train_rp3beta(user_item_csr, K=200, alpha=0.4, beta=0.7):
    try:
        from implicit.nearest_neighbours import RP3betaRecommender
        from implicit.nearest_neighbours import bm25_weight
    except Exception as e:
        print(f"[RP3] implicit недоступен: {e}. Пропускаем RP3beta.")
        return None, user_item_csr

    try:
        ui = user_item_csr.tocsr().astype(np.float32, copy=False)
        ui_w = bm25_weight(ui, K1=1.2, B=0.75).tocsr()
        model = RP3betaRecommender(K=K, alpha=alpha, beta=beta)
        print("Training RP3beta (CPU)...")
        model.fit(ui_w.T, show_progress=True)
        return model, ui_w
    except Exception as e:
        print(f"[RP3] ошибка обучения: {e}. Пропускаем RP3beta.")
        return None, user_item_csr

def rp3_recommend_batch(model, user_item_csr, user_ids, N=300, filter_seen=False, batch_size=4096):
    recs = {}
    if model is None:
        return recs
    for i in tqdm(range(0, len(user_ids), batch_size), desc="RP3 recommend"):
        batch = user_ids[i: i+batch_size]
        for u in batch:
            if u < 0:
                recs[u] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
                continue
            try:
                ids, scores = model.recommend(
                    userid=int(u),
                    user_items=user_item_csr,
                    N=N,
                    recalculate_user=True,
                    filter_already_liked_items=filter_seen
                )
                recs[u] = (ids.astype(np.int32), scores.astype(np.float32))
            except Exception:
                recs[u] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
    return recs

In [None]:
# -----------------------------
# Utils
# -----------------------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def average_precision_at_k(true_items_set, ranked_items, k=20):
    if not true_items_set:
        return 0.0
    score = 0.0
    hits = 0
    for i, item in enumerate(ranked_items[:k]):
        if item in true_items_set:
            hits += 1
            score += hits / (i + 1)
    return score / min(len(true_items_set), k)

def map_at_k(gt_by_user, preds_by_user, k=20):
    aps = []
    for u, true_set in gt_by_user.items():
        preds = preds_by_user.get(u, [])
        aps.append(average_precision_at_k(true_set, preds, k=k))
    return float(np.mean(aps)) if aps else 0.0
    

In [None]:
# -----------------------------
# Load data
# -----------------------------
def load_data(train_path="train_data.pq", sample_path="sample_submission.csv"):
    print("Reading parquet with Polars...")
    dfpl = pl.read_parquet(train_path)
    # Ensure dtypes
    dfpl = dfpl.with_columns([
        pl.col("user_id").cast(pl.Int64),
        pl.col("item_id").cast(pl.Int64),
        pl.col("date").cast(pl.Int32)
    ])
    df = dfpl.to_pandas()
    del dfpl
    gc.collect()

    print("Mapping ids to contiguous indices...")
    df["uid"], user_uniques = pd.factorize(df["user_id"], sort=True)
    df["iid"], item_uniques = pd.factorize(df["item_id"], sort=True)

    uid2orig = np.array(user_uniques)
    iid2orig = np.array(item_uniques)

    # Sort by user, then by date for sequence models
    df.sort_values(["uid", "date"], inplace=True)
    df.reset_index(drop=True, inplace=True)

    print("Loading sample users...")
    sample = pd.read_csv(sample_path)
    # Sample repeats user_id 20 times; we only need unique users
    sample_users = sample["user_id"].unique()
    # Map sample users to internal uids
    # Build map from original user_id to uid
    user_orig2uid = {orig: idx for idx, orig in enumerate(uid2orig)}
    sample_uids = []
    for u in sample_users:
        if u in user_orig2uid:
            sample_uids.append(user_orig2uid[u])
        else:
            # shouldn't happen; but keep -1 to handle later
            sample_uids.append(-1)
    sample_uids = np.array(sample_uids, dtype=np.int64)

    return df, uid2orig, iid2orig, sample_users, sample_uids

# -----------------------------
# Split: last 7 days as validation holdout
# -----------------------------
def time_split(df, holdout_days=7):
    max_day = int(df["date"].max())
    val_start = max_day - (holdout_days - 1)
    train_mask = df["date"] < val_start
    val_mask = df["date"] >= val_start

    train_df = df[train_mask].copy()
    val_df = df[val_mask].copy()

    # Ground truth for validation: set of items per user in holdout period
    gt_val = {}
    for uid, grp in tqdm(val_df.groupby("uid", sort=False), desc="Build val GT"):
        gt_val[uid] = set(grp["iid"].tolist())

    return train_df, val_df, gt_val

# -----------------------------
# Build CSR matrix for implicit ALS
# -----------------------------
def build_user_item_csr(df, n_users, n_items):
    rows = df["uid"].values.astype(np.int32)
    cols = df["iid"].values.astype(np.int32)
    data = np.ones_like(rows, dtype=np.float32)
    ui = sp.coo_matrix((data, (rows, cols)), shape=(n_users, n_items), dtype=np.float32).tocsr()
    return ui

# -----------------------------
# ALS Retrieval
# -----------------------------
def train_als(user_item_csr, use_gpu=True, factors=128, reg=1e-4, iters=20):
    try:
        import implicit
        from implicit.nearest_neighbours import bm25_weight
        ui = bm25_weight(user_item_csr, K1=1.2, B=0.75).tocsr()
        gc.collect()

        if use_gpu:
            try:
                from implicit.gpu.als import AlternatingLeastSquares as GPU_ALS
                als = GPU_ALS(factors=factors, regularization=reg, iterations=iters, random_state=42)
                print("Training ALS (GPU)...")
                als.fit(ui.T, show_progress=True)
                return als, ui
            except Exception as e:
                print(f"GPU ALS not available ({e}), falling back to CPU.")
        from implicit.als import AlternatingLeastSquares as CPU_ALS
        als = CPU_ALS(factors=factors, regularization=reg, iterations=iters, use_cg=True, random_state=42)
        print("Training ALS (CPU)...")
        als.fit(ui.T, show_progress=True)
        return als, ui
    except Exception as e:
        print(f"Implicit not installed or failed: {e}")
        return None, user_item_csr

def als_recommend_batch(model, user_item_csr, user_ids, N=300, filter_seen=False, batch_size=1024):
    recs = {}
    desc = "ALS recommend"
    for i in tqdm(range(0, len(user_ids), batch_size), desc=desc):
        batch = user_ids[i: i+batch_size]
        for u in batch:
            if u < 0:
                recs[u] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
                continue
            try:
                ids, scores = model.recommend(
                    userid=int(u),
                    user_items=user_item_csr,
                    N=N,
                    recalculate_user=True,
                    filter_already_liked_items=filter_seen
                )
                recs[u] = (ids.astype(np.int32), scores.astype(np.float32))
            except Exception:
                recs[u] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
    return recs

# -----------------------------
# FAISS index for item similarity (ALS item factors)
# -----------------------------
def build_faiss_index(item_factors):
    try:
        import faiss
    except Exception as e:
        print(f"FAISS not available: {e}")
        return None, None

    d = item_factors.shape[1]
    try:
        res = faiss.StandardGpuResources() if torch.cuda.is_available() else None
    except Exception:
        res = None

    index = faiss.IndexFlatIP(d)
    if res is not None:
        index = faiss.index_cpu_to_gpu(res, 0, index)
    # Optionally normalize for cosine similarity; but ALS uses IP
    # faiss.normalize_L2(item_factors)
    index.add(item_factors.astype(np.float32))
    return index, res

def similar_items_for_users(index, item_factors, user_histories, topk_per_item=50, last_k=10):
    """
    user_histories: dict uid -> np.array of recent item ids (0-based)
    returns dict uid -> (item_ids, scores) aggregated over last_k items
    """
    try:
        import faiss
    except Exception:
        index = None

    recs = {}
    if index is None:
        return recs

    for uid, hist in tqdm(user_histories.items(), desc="FAISS item2item"):
        if hist.size == 0:
            recs[uid] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
            continue
        queries = hist[-last_k:]
        # Get neighbors for each query
        # Note: include self in results; we will drop identical ids
        distances, neighbors = index.search(item_factors[queries].astype(np.float32), k=topk_per_item+1)
        # Aggregate by sum of scores, exclude original items
        agg = defaultdict(float)
        hist_set = set(queries.tolist())
        for row in range(neighbors.shape[0]):
            for j in range(neighbors.shape[1]):
                it = int(neighbors[row, j])
                if it in hist_set:
                    continue
                agg[it] += float(distances[row, j])
        if not agg:
            recs[uid] = (np.empty((0,), dtype=np.int32), np.empty((0,), dtype=np.float32))
        else:
            items = np.fromiter(agg.keys(), dtype=np.int32)
            scores = np.fromiter(agg.values(), dtype=np.float32)
            order = np.argsort(-scores)
            recs[uid] = (items[order], scores[order])
    return recs

# -----------------------------
# Popularity candidates
# -----------------------------
def popularity_scores(df_train, df_recent_days=7, n_items=None):
    max_day = int(df_train["date"].max())
    recent_start = max_day - (df_recent_days - 1)

    all_cnt = df_train.groupby("iid").size()
    rec_cnt = df_train[df_train["date"] >= recent_start].groupby("iid").size()

    # Align
    if n_items is None:
        n_items = int(df_train["iid"].max()) + 1
    all_vec = np.zeros(n_items, dtype=np.float32)
    rec_vec = np.zeros(n_items, dtype=np.float32)

    all_vec[all_cnt.index.values] = all_cnt.values.astype(np.float32)
    rec_vec[rec_cnt.index.values] = rec_cnt.values.astype(np.float32)

    # Normalize
    all_vec /= (all_vec.max() + 1e-9)
    rec_vec /= (rec_vec.max() + 1e-9)

    # Weighted blend (tuneable)
    pop = 0.3 * all_vec + 0.7 * rec_vec
    return pop

def top_pop_items(pop_scores, topk=500):
    idx = np.argsort(-pop_scores)[:topk]
    return idx, pop_scores[idx]

# -----------------------------
# Build user sequences (offsets) for SASRec
# -----------------------------
def build_user_offsets(df, n_users):
    # df must be sorted by ["uid", "date"]
    uids = df["uid"].values.astype(np.int64)
    iids = df["iid"].values.astype(np.int64)
    # Keep 0-based for ALS; SASRec will use +1 for padding
    counts = np.bincount(uids, minlength=n_users).astype(np.int64)
    offsets = np.zeros(n_users + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts)
    # items array already in user-sorted order
    return uids, iids, offsets

def get_user_last_items(iids, offsets, uids):
    user_hist = {}
    for u in tqdm(np.unique(uids), desc="Collect last items per user"):
        s, e = offsets[u], offsets[u+1]
        if e - s > 0:
            user_hist[u] = iids[s:e]
        else:
            user_hist[u] = np.array([], dtype=np.int64)
    return user_hist

# -----------------------------
# SASRec model
# -----------------------------
# -----------------------------
# SASRec (norm_first + SDPA-friendly)
# -----------------------------
class SASRec(nn.Module):
    def __init__(self, n_items_padded, max_len=50, d_model=192, n_heads=4, n_layers=3, dropout=0.2):
        super().__init__()
        self.n_items = n_items_padded
        self.max_len = max_len
        self.item_emb = nn.Embedding(n_items_padded, d_model, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model*4,
            dropout=dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers, norm=nn.LayerNorm(d_model))
        self.dropout = nn.Dropout(dropout)
        nn.init.normal_(self.item_emb.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_emb.weight, mean=0.0, std=0.02)

    def forward(self, seq_ids):
        B, L = seq_ids.size()
        positions = torch.arange(L, device=seq_ids.device).unsqueeze(0).expand(B, L)
        x = self.item_emb(seq_ids) + self.pos_emb(positions)
        x = self.dropout(x)
        causal = torch.triu(torch.ones(L, L, device=seq_ids.device), diagonal=1).bool()
        key_pad = (seq_ids == 0)
        x = self.encoder(x, mask=causal, src_key_padding_mask=key_pad)
        lengths = (~key_pad).sum(dim=1).clamp(min=1)
        idx = (lengths - 1).unsqueeze(1).unsqueeze(2).expand(-1, 1, x.size(-1))
        out = x.gather(1, idx).squeeze(1)
        return out

    def item_embedding(self, item_ids):
        return self.item_emb(item_ids)

# -----------------------------
# Train SASRec with AMP + compile + cosine warmup + early stop
# -----------------------------
def train_sasrec(model, dataset, epochs=3, batch_size=1024, lr=3e-4, val_eval_fn=None, amp_dtype=None,
                 grad_clip=1.0, patience=2, save_path="sasrec_best.pt"):
    model = model.to(DEVICE)
    if amp_dtype is None:
        amp_dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
    try:
        model = torch.compile(model, mode="max-autotune")
    except Exception:
        pass

    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

    steps_per_epoch = math.ceil(len(dataset) / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup = max(1, int(0.1 * total_steps))

    def lr_lambda(step):
        if step < warmup:
            return float(step + 1) / float(warmup)
        progress = (step - warmup) / max(1, (total_steps - warmup))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_lambda)
    best_map, bad_epochs = -1.0, 0

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2,
                        collate_fn=collate_batch, pin_memory=True, persistent_workers=True)

    for epoch in range(1, epochs + 1):
        dataset.shuffle_users()
        model.train()
        pbar = tqdm(loader, total=steps_per_epoch, desc=f"SASRec epoch {epoch}")
        losses = []
        for step, (seqs, targets) in enumerate(pbar, start=1):
            seqs = seqs.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=(DEVICE.type=="cuda"), dtype=amp_dtype):
                user_vec = model(seqs)
                pos_embs = model.item_embedding(targets)
                logits = user_vec @ pos_embs.T
                # явная диагональная маска (чистые in-batch negatives)
                logits = logits - torch.eye(logits.size(0), device=logits.device) * 1e9
                labels = torch.arange(seqs.size(0), device=DEVICE)
                loss = loss_fn(logits, labels)
            optim.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optim)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
            if len(losses) % 20 == 0:
                pbar.set_postfix(loss=np.mean(losses[-20:]), lr=optim.param_groups[0]["lr"])
        print(f"Epoch {epoch} train loss: {np.mean(losses):.4f}")

        if val_eval_fn is not None:
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(DEVICE.type=="cuda"), dtype=amp_dtype):
                val_map = val_eval_fn(model)
            print(f"Epoch {epoch} VAL mAP@20: {val_map:.5f}")
            if val_map > best_map + 1e-5:
                best_map, bad_epochs = val_map, 0
                try:
                    torch.save(model.state_dict(), save_path)
                except Exception:
                    pass
            else:
                bad_epochs += 1
                if bad_epochs >= patience:
                    print(f"Early stop on epoch {epoch}, best mAP@20 = {best_map:.5f}")
                    break

@torch.no_grad()
def rerank_with_sasrec(model, user_candidates, items_by_user, offsets, max_len=50, batch_size=2048, amp_dtype=None):
    model.eval()
    if amp_dtype is None:
        amp_dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16

    uids = list(user_candidates.keys())
    results = {}

    # кэш последовательностей
    seq_cache = {}
    for u in tqdm(uids, desc="Prepare sequences"):
        s, e = offsets[u], offsets[u+1]
        hist = items_by_user[s:e]
        if hist.size == 0:
            seq_cache[u] = torch.zeros((max_len,), dtype=torch.long)
        else:
            seq = hist[-max_len:] + 1
            pad = max_len - len(seq)
            if pad > 0:
                seq = np.pad(seq, (pad, 0), constant_values=0)
            seq_cache[u] = torch.from_numpy(seq.astype(np.int64, copy=False))

    for i in tqdm(range(0, len(uids), batch_size), desc="Re-ranking"):
        batch_u = uids[i: i+batch_size]
        seq_batch = torch.stack([seq_cache[u] for u in batch_u], dim=0).to(DEVICE, non_blocking=True)
        cmax = max((len(user_candidates[u]) for u in batch_u), default=0)
        if cmax == 0:
            for u in batch_u:
                results[u] = []
            continue

        cand_mat = np.zeros((len(batch_u), cmax), dtype=np.int64)
        cand_lens = np.zeros((len(batch_u),), dtype=np.int32)
        for bi, u in enumerate(batch_u):
            arr = user_candidates[u]
            cand_lens[bi] = len(arr)
            cand_mat[bi, :len(arr)] = arr + 1  # к padded space

        cand_t = torch.from_numpy(cand_mat).to(DEVICE, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=(DEVICE.type=="cuda"), dtype=amp_dtype):
            user_vecs = model(seq_batch)                  # (B, D)
            item_embs = model.item_embedding(cand_t)      # (B, Cmax, D)
            logits = (item_embs * user_vecs.unsqueeze(1)).sum(dim=-1)  # (B, Cmax)
            logits = logits.masked_fill(cand_t.eq(0), float("-inf"))
            topk = min(500, cmax)
            _, idxs = torch.topk(logits, k=topk, dim=1)

        for bi, u in enumerate(batch_u):
            valid = cand_lens[bi]
            if valid == 0:
                results[u] = []
                continue
            order = idxs[bi, :min(topk, valid)].detach().cpu().numpy()
            ranked = cand_mat[bi][order] - 1
            results[u] = ranked.tolist()
    return results

# -----------------------------
# SASRec dataset
# -----------------------------
class SequenceDataset(Dataset):
    def __init__(self, items_by_user, offsets, max_len=50, users=None):
        """
        items_by_user: 1D np.array of item ids (0-based), sorted by user/date
        offsets: np.array of size n_users+1
        """
        self.items = items_by_user.astype(np.int64)  # 0-based
        self.offsets = offsets.astype(np.int64)
        self.max_len = max_len
        self.n_users = offsets.size - 1

        if users is None:
            # Only users with at least 2 interactions
            all_users = np.arange(self.n_users, dtype=np.int64)
            lens = self.offsets[1:] - self.offsets[:-1]
            self.users = all_users[lens >= 2]
        else:
            self.users = np.array(users, dtype=np.int64)

        # Shuffle initial order
        self.shuffle_users()

    def shuffle_users(self):
        np.random.shuffle(self.users)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        u = self.users[idx]
        s, e = self.offsets[u], self.offsets[u+1]
        seq = self.items[s:e]  # 0-based
        L = len(seq)
        # Sample a cut point: prefix -> next item
        # t in [1..L-1], target = seq[t]
        t = np.random.randint(1, L)
        prefix = seq[:t]
        target = seq[t]

        # Truncate to last max_len of prefix
        if len(prefix) > self.max_len:
            prefix = prefix[-self.max_len:]
        # Convert to padded space: add 1
        prefix_padded = prefix + 1  # 1..n_items
        # Left-pad with zeros
        pad_len = self.max_len - len(prefix_padded)
        if pad_len > 0:
            prefix_padded = np.pad(prefix_padded, (pad_len, 0), constant_values=0)

        target_padded = int(target + 1)
        return prefix_padded.astype(np.int64), target_padded

def collate_batch(batch):
    seqs_np = np.stack([b[0] for b in batch], axis=0).astype(np.int64, copy=False)
    seqs = torch.from_numpy(seqs_np)
    targets = torch.tensor([b[1] for b in batch], dtype=torch.long)
    return seqs, targets

# -----------------------------
# Train SASRec with in-batch negatives
# -----------------------------


# -----------------------------
# Candidate merger and reranking
# -----------------------------
def merge_candidates(als_recs, i2i_recs, pop_items, pop_scores, users, topk_merge=500,
                     w_als=1.0, w_i2i=1.0, w_pop=0.3, rp3_recs=None, w_rp3=1.0):
    merged = {}
    for u in tqdm(users, desc="Merge candidates"):
        scores = defaultdict(float)
        if als_recs is not None and u in als_recs:
            ids, sc = als_recs[u]
            for iid, s in zip(ids, sc):
                scores[int(iid)] += w_als * float(s)
        if rp3_recs is not None and u in rp3_recs:
            ids, sc = rp3_recs[u]
            for iid, s in zip(ids, sc):
                scores[int(iid)] += w_rp3 * float(s)
        if i2i_recs is not None and u in i2i_recs:
            ids, sc = i2i_recs[u]
            for iid, s in zip(ids, sc):
                scores[int(iid)] += w_i2i * float(s)
        if pop_items is not None:
            for iid, s in zip(pop_items, pop_scores):
                scores[int(iid)] += w_pop * float(s)

        if not scores:
            merged[u] = np.array([], dtype=np.int32)
        else:
            items = np.fromiter(scores.keys(), dtype=np.int32)
            scs = np.fromiter(scores.values(), dtype=np.float32)
            order = np.argsort(-scs)[:topk_merge]
            merged[u] = items[order]
    return merged


# -----------------------------
# End-to-end
# -----------------------------
def main(train_path="/kaggle/input/stupidshit777/train_data.pq", sample_path="/kaggle/input/stupidshit777/sample_submission (11).csv", submission_path="submission1.csv",
         holdout_days=7, max_seq_len=50):
    # Load
    df, uid2orig, iid2orig, sample_users_orig, sample_uids = load_data(train_path, sample_path)
    n_users = int(df["uid"].max()) + 1
    n_items = int(df["iid"].max()) + 1

    # Split (для тюнинга)
    train_df, val_df, gt_val = time_split(df, holdout_days=holdout_days)
    print(f"Train interactions: {len(train_df):,}, Val interactions: {len(val_df):,}")

    # Stage A: Обучение на первых 47 днях + валидация на последних 7
    # CSR/ALS для train
    user_item_csr = build_user_item_csr(train_df, n_users, n_items)
    als_model, weighted_ui = train_als(user_item_csr, use_gpu=False, factors=128, reg=1e-4, iters=20)

    als_item_factors, als_user_factors = None, None
    if als_model is not None:
        als_item_factors = getattr(als_model, "item_factors", getattr(als_model, "item_factors_", None))
        als_user_factors = getattr(als_model, "user_factors", getattr(als_model, "user_factors_", None))

    faiss_index, _ = build_faiss_index(als_item_factors) if als_item_factors is not None else (None, None)

    # RP3beta (train)
    rp3_model, rp3_ui = train_rp3beta(user_item_csr, K=200, alpha=0.4, beta=0.7)

    # Sequences (train)
    uids_sorted, iids_sorted, offsets = build_user_offsets(train_df, n_users)

    # Истории целевых юзеров (val + test) на train-отрезке
    users_for_hist = np.unique(np.concatenate([np.array(list(gt_val.keys()), dtype=np.int64),
                                               sample_uids[sample_uids >= 0]]))
    user_histories = {}
    for u in tqdm(users_for_hist, desc="Build histories (train)"):
        s, e = offsets[u], offsets[u+1]
        user_histories[u] = iids_sorted[s:e]

    # Популярность (train)
    pop_scores_vec = popularity_scores(train_df, df_recent_days=7, n_items=n_items)
    pop_items, pop_scores = top_pop_items(pop_scores_vec, topk=300)

    # Кандидаты на валидацию
    target_users_val = np.array(list(gt_val.keys()), dtype=np.int64)
    als_val_recs = als_recommend_batch(als_model, weighted_ui if als_model is not None else user_item_csr,
                                       target_users_val, N=300, filter_seen=False, batch_size=2048) if als_model else {}
    rp3_val_recs = rp3_recommend_batch(rp3_model, rp3_ui if rp3_model is not None else user_item_csr,
                                       target_users_val, N=300, filter_seen=False, batch_size=4096) if rp3_model else {}
    i2i_val_recs = similar_items_for_users(faiss_index, als_item_factors, user_histories,
                                           topk_per_item=80, last_k=10) if faiss_index is not None else {}

    cand_val = merge_candidates(als_val_recs, i2i_val_recs, pop_items, pop_scores, target_users_val,
                                topk_merge=500, w_als=1.0, w_i2i=1.0, w_pop=0.2, rp3_recs=rp3_val_recs, w_rp3=1.0)

    # SASRec train (на train-части)
    sasrec_items = iids_sorted
    sasrec_dataset = SequenceDataset(sasrec_items, offsets, max_len=max_seq_len)
    sasrec_model = SASRec(n_items_padded=n_items + 1, max_len=max_seq_len, d_model=192, n_heads=4, n_layers=3, dropout=0.2)

    def eval_val(model):
        preds_val = rerank_with_sasrec(model, cand_val, sasrec_items, offsets, max_len=max_seq_len, batch_size=2048)
        return map_at_k(gt_val, preds_val, k=20)

    train_sasrec(sasrec_model, sasrec_dataset, epochs=3, batch_size=1024, lr=3e-4, val_eval_fn=eval_val,
                 patience=2, save_path="sasrec_best.pt")

    # Stage B: Финальный рефит на всех 47+7 днях
    print("Refit ALL modules on full data (train+val)...")
    full_df = df  # все дни
    # CSR/ALS/RP3 на FULL
    user_item_csr_full = build_user_item_csr(full_df, n_users, n_items)
    als_model_full, weighted_ui_full = train_als(user_item_csr_full, use_gpu=False, factors=128, reg=1e-4, iters=20)
    als_item_factors_full, als_user_factors_full = None, None
    if als_model_full is not None:
        als_item_factors_full = getattr(als_model_full, "item_factors", getattr(als_model_full, "item_factors_", None))
        als_user_factors_full = getattr(als_model_full, "user_factors", getattr(als_model_full, "user_factors_", None))
    faiss_index_full, _ = build_faiss_index(als_item_factors_full) if als_item_factors_full is not None else (None, None)

    rp3_model_full, rp3_ui_full = train_rp3beta(user_item_csr_full, K=200, alpha=0.4, beta=0.7)

    # Sequences FULL
    uids_sorted_full, iids_sorted_full, offsets_full = build_user_offsets(full_df, n_users)

    # Истории для тестовых пользователей — по FULL
    test_users = np.unique(sample_uids[sample_uids >= 0])
    user_histories_full = {}
    for u in tqdm(test_users, desc="Build histories (full)"):
        s, e = offsets_full[u], offsets_full[u+1]
        user_histories_full[u] = iids_sorted_full[s:e]

    # Популярность FULL
    pop_scores_vec_full = popularity_scores(full_df, df_recent_days=7, n_items=n_items)
    pop_items_full, pop_scores_full = top_pop_items(pop_scores_vec_full, topk=300)

    # Кандидаты для TEST (на FULL)
    als_test_recs = als_recommend_batch(als_model_full, weighted_ui_full if als_model_full is not None else user_item_csr_full,
                                        test_users, N=300, filter_seen=False, batch_size=4096) if als_model_full else {}
    rp3_test_recs = rp3_recommend_batch(rp3_model_full, rp3_ui_full if rp3_model_full is not None else user_item_csr_full,
                                        test_users, N=300, filter_seen=False, batch_size=16384) if rp3_model_full else {}
    i2i_test_recs = similar_items_for_users(faiss_index_full, als_item_factors_full, user_histories_full,
                                            topk_per_item=80, last_k=10) if faiss_index_full is not None else {}
    cand_test = merge_candidates(als_test_recs, i2i_test_recs, pop_items_full, pop_scores_full, test_users,
                                 topk_merge=500, w_als=1.0, w_i2i=1.0, w_pop=0.2, rp3_recs=rp3_test_recs, w_rp3=1.0)

    # Реранкер FULL: подгружаем лучший чекпоинт и дообучаем на FULL 1-2 эпохи маленьким LR (быстрый fine-tune)
    sasrec_model_full = SASRec(n_items_padded=n_items + 1, max_len=max_seq_len, d_model=192, n_heads=4, n_layers=3, dropout=0.2)
    try:
        sasrec_model_full.load_state_dict(torch.load("sasrec_best.pt", map_location="cpu"), strict=False)
        print("Loaded best SASRec weights from train-part.")
    except Exception as e:
        print(f"Failed to load SASRec best weights: {e}. Using current weights.")

    full_dataset = SequenceDataset(iids_sorted_full, offsets_full, max_len=max_seq_len)
    # Короткий фт на FULL с меньшим LR
    train_sasrec(sasrec_model_full, full_dataset, epochs=2, batch_size=1024, lr=1.5e-4, val_eval_fn=None, patience=1)

    # Реранк по FULL секвенциям
    preds_test = rerank_with_sasrec(sasrec_model_full, cand_test, iids_sorted_full, offsets_full,
                                    max_len=max_seq_len, batch_size=2048)

    # Сабмишен (на FULL поп/обратные маппинги)
    print("Writing submission...")
    sample = pd.read_csv(sample_path)
    user_orig2uid = {orig: idx for idx, orig in enumerate(uid2orig)}
    item_idx2orig = iid2orig

    pop_scores_vec = pop_scores_vec_full
    pop_items, pop_scores = pop_items_full, pop_scores_full

    pred_item_cache = {}
    pop20 = top_pop_items(pop_scores_vec, topk=20)[0]
    for u_orig in tqdm(sample["user_id"].unique(), desc="Cache predictions"):
        uid = user_orig2uid.get(int(u_orig), -1)
        if uid == -1:
            pred_item_cache[int(u_orig)] = item_idx2orig[pop20]
            continue
        ranked0 = preds_test.get(uid)
        if not ranked0:
            ranked0 = cand_test.get(uid, np.array([], dtype=np.int32)).tolist()
        seen_set = set(ranked0)
        for i in pop_items:
            if len(ranked0) >= 20: break
            if int(i) not in seen_set:
                ranked0.append(int(i)); seen_set.add(int(i))
        ranked0 = ranked0[:20]
        pred_item_cache[int(u_orig)] = item_idx2orig[np.array(ranked0, dtype=np.int32)]

    from collections import defaultdict
    counters = defaultdict(int)
    out_items = np.empty(len(sample), dtype=item_idx2orig.dtype)
    for idx, u in tqdm(enumerate(sample["user_id"].values), total=len(sample), desc="Fill submission"):
        i = counters[int(u)]
        out_items[idx] = pred_item_cache[int(u)][i]
        counters[int(u)] += 1

    sample["item_id"] = out_items
    sample.to_csv(submission_path, index=False)
    print(f"Saved submission to {submission_path}")bmission to {submission_path}")

if __name__ == "__main__":
    main()