<a href="https://colab.research.google.com/github/wogsim/two_tower_recmodel/blob/main/notebooks/two_tower_rank_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



# Рекомендации на основе Трансформера

Этот проект — это система рекомендаций, использующая архитектуру **Трансформер** для анализа истории позитивных действий пользователя.

**Цель:** Научить модель эффективно ранжировать товары, которые пользователь, скорее всего, добавит в корзину.

---

## 1. Входные данные (История взаимодействий)

Мы будем использовать историю **положительных взаимодействий** пользователя (сессии):
* **Клики**
* **Добавления в корзину**
* **Покупки**

(Примечание: Просмотры временно **игнорируются** для упрощения, но могут быть добавлены позже.)

---

## 2. Двухэтапное Обучение

Обучение модели будет состоять из двух фаз: **Pretrain** (Предварительное обучение) и **Finetune** (Тонкая настройка).

### А. Pretrain (Предварительное обучение)

* **Задача:** Обучить модель решать **общую задачу**, максимально утилизируя **все имеющиеся данные** (клики, корзины, покупки).
* **Результат:** Получение сильных общих представлений (эмбеддингов) для пользователей и товаров.
* **Свойство:** Хорошо масштабируется — добавление данных повышает качество.

### Б. Finetune (Тонкая настройка)

* **Задача:** **Адаптировать** модель под **конкретную целевую задачу ранжирования** на тесте.
* **Целевая Метрика:** **$ndcg@10$** (Normalized Discounted Cumulative Gain на 10) по группам (`request_id`).
* **Метод:** Обучение на **попарную функцию потерь** (Pairwise Loss) — будем использовать **Calibrated Pairwise Logistic** (подробности ниже).

---

## 3. Целевая Задача Finetune (Ранжирование)

На этапе Finetune мы фокусируемся на ранжировании внутри групп (`request_id`).

* **Метки:**
    * **1 (Позитивный):** Товар был **добавлен в корзину**.
    * **0 (Негативный):** Товар был **просмотрен** (но не добавлен в корзину).
* **Исключения:** Клики и покупки **не** учитываются, так как их нет в итоговом тестовом наборе.
* **Цель:** **Отранжировать** единички (корзины) как можно **выше** ноликов (просмотров)

### Pretrain:

<div align="center">
  <img src="https://i.ibb.co/8GksnWD/Screenshot-2025-05-04-at-10-15-36.png" width="500" alt="pretrain">
</div>

Предварительное обучение состоит из двух взаимодополняющих задач: **Next-Positive Prediction (NPP)** и **Feedback Prediction (FP)**.

### Общая структура токена

Каждое положительное взаимодействие ($t$) кодируется как **токен**, который является суммой векторов трёх характеристик:
$$\text{Токен}_t = \mathbf{c}_t + \mathbf{i}_t + \mathbf{f}_t$$
* $\mathbf{i}_t$ — **Товар** (Item) взаимодействия.
* $\mathbf{c}_t$ — **Контекст** (Context) взаимодействия.
* $\mathbf{f}_t$ — **Фидбек** (Feedback): клик, корзина или покупка.

### Next-Positive Prediction ($\mathcal{L}_{\mathrm{NPP}}$)
* **Задача:** Предсказать вероятность следующего товара $i_t$, используя Softmax и косинусное сходство:
$$
P(\text{item}=i_t\mid \text{history}=S_{t-1},\;\text{context}=c_t)\,.
$$
* **Позднее связывание (Next-Positive):** Трансформер выдает $h_{t-1}$. Для предсказания используем:
$$
\hat h^c_t = \mathrm{MLP}\bigl(\mathrm{Concat}(h_{t-1},\mathbf{c}_t)\bigr).
$$

### Feedback Prediction ($\mathcal{L}_{\mathrm{FP}}$)
* **Задача:** Предсказать вероятность типа фидбека $f_t$ (классификация на три класса: клик, корзина, покупка):
$$
P(\text{feedback}=f_t\mid \text{history}=S_{t-1},\;\text{context}=c_t,\;\text{item}=i_t).
$$
* **Позднее связывание (Feedback):** Используем $h_{t-1}$, $\mathbf{c}_t$ и $\mathbf{i}_t$:
$$
\hat h^i_t = \mathrm{MLP}\bigl(\mathrm{Concat}(h_{t-1},\mathbf{c}_t,\mathbf{i}_t)\bigr).
$$

### Итоговый Pretrain Loss

$$\mathcal{L}_{\rm pre\text{-}train} = \mathcal{L}_{\mathrm{NPP}} + \mathcal{L}_{\mathrm{FP}}.$$

### Finetune

<p align="center">
  <img src="https://i.ibb.co/kgwt7pRb/Screenshot-2025-05-04-at-10-15-48.png" width="500" alt="finetune">
</p>


#### Постановка задачи

Пусть для пользователя с идентификатором `user_id` есть сформированная история взаимодействий и набор групп товаров, каждая из которых имеет свой `request_id`.

* **Скрытые состояния:**
    $$h_0, h_1, \dots, h_t$$
* **Группа (`request_id`):** Множество товаров, каждый из которых имеет бинарную метку:
    * **1:** Товар был добавлен в корзину (позитив).
    * **0:** Товар был просмотрен (негатив).

**Цель:** Для каждого товара внутри группы оценить его **релевантность** пользователю, используя актуальное скрытое состояние пользователя.

---

#### Учет временной задержки (Consistency with Validation)

На валидации и тесте между последним известным состоянием истории и моментом показа новой группы может быть значительный разрыв (от 2 дней до 1 месяца). Чтобы имитировать это и обеспечить консистентность, мы добавляем задержку в процесс обучения:

1.  **Выбор задержки:** Для каждой группы случайно выбирается задержка $\Delta$:
$$\Delta \sim \mathrm{Uniform}(2\ \text{дн.},\ 32 \text{дн.}).$$
2.  **Выбор состояния:** Находим самое *позднее* скрытое состояние $h_k$, которое предшествует текущей группе *не менее* чем на $\Delta$.
3.  **Вектор пользователя:** Используем это состояние **$h_k$** как итоговый вектор пользователя для расчета потерь в данной группе.

---

#### Попарная ранжирующая функция потерь

Мы обучаем модель ранжировать позитивные товары выше негативных в рамках одной группы.

1.  **Формирование пар:** Внутри каждой группы (`request_id`) генерируем все возможные пары товаров $(i, j)$, где:
    * Товар $i$ имеет метку **1** (корзина, позитив).
    * Товар $j$ имеет метку **0** (просмотр, негатив).
2.  **Расчет потерь:** Для каждой такой пары рассчитываем ранжирующую функцию потерь (например, **Calibrated Pairwise Logistic**), используя предсказанные релевантности для товаров $i$ и $j$.

## 2. Предобработка данных

In [3]:
!git clone https://github.com/wogsim/two_tower_recmodel

fatal: destination path 'two_tower_recmodel' already exists and is not an empty directory.


In [1]:
!pip install -r /content/two_tower_recmodel/notebooks/requirements.txt
!pip install -e two_tower_recmodel/grocery

Obtaining file:///content/two_tower_recmodel/grocery
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: grocery
  Building editable for grocery (pyproject.toml) ... [?25l[?25hdone
  Created wheel for grocery: filename=grocery-0.1.0-py3-none-any.whl size=1515 sha256=bfd4ef5aaa1d56a73fce1201fff8e4074d57d2fc74b1828b92a6b4c946e736f7
  Stored in directory: /tmp/pip-ephem-wheel-cache-no_yu498/wheels/1e/69/79/9e835b8b9571913ea43b36767a1872057c821869f21e022ee7
Successfully built grocery
Installing collected packages: grocery
  Attempting uninstall: grocery
    Found existing installation: grocery 0.1.0
    Uninstalling grocery-0.1.0:
      Successfully uninstalled grocery-0.1.0
Successfully instal

In [None]:
DATA_DIR = "/content/recsys_course/data/lavka"

from grocery.utils.dataset import download_and_extract
download_and_extract(
     url="https://www.kaggle.com/api/v1/datasets/download/thekabeton/ysda-recsys-2025-lavka-dataset",
     filename="lavka.zip",
    dest_dir=DATA_DIR
)

lavka.zip: 100%|██████████████████████████████████████████████████| 447M/447M [00:05<00:00, 75.5MB/s]


Unpacking lavka.zip...
Files from lavka.zip successfully unpacked



In [None]:
DATA_DIR = "/content/recsys_course/data/lavka"

In [None]:
import polars as pl
from collections import deque
from typing import Dict, Any, Generator, Iterable, Optional
from abc import ABC, abstractmethod
from itertools import chain

 Оставляем только небольшой поднабор всех признаков. Остальные признаки буду учтены в будущем.

In [None]:
train_df = pl.read_parquet(DATA_DIR + '/train.parquet').select(['action_type', 'product_id', 'source_type', 'timestamp', 'user_id', 'request_id'])

### Разделение на трейн и валидация:

<div align="center">
  <img src="https://i.ibb.co/yBPn87t7/IMG000-19.jpg" width="500" alt="split">
</div>

Для оценки модели будем использовать train данные. Из них сформируем валидационную и обучающую части. Для того, чтобы получить корректные метрики на валидации, важно повторить все особенности тестовых данных:
- 2 дня разница между train и valid
- 1 месяц на valid
- оставляем только просмотр и корзину
- группы с >= 10 товарами
- группы с хотя бы одной корзиной и хотя бы одним просмотром

In [None]:
class ActionType:
    VIEW = 'AT_View'
    CLICK = 'AT_Click'
    CART_UPDATE = 'AT_CartUpdate'
    PURCHASE = 'AT_Purchase'


class Preprocessor:
    mapping_action_types = {
        ActionType.VIEW: 0,
        ActionType.CART_UPDATE: 1,
        ActionType.CLICK: 2,
        ActionType.PURCHASE: 3
    }
    def __init__(
        self,
        train_df: pl.DataFrame,
    ):
        self.train_df = train_df

    def _map_col(self, column: str, cast: pl.DataType = None) -> dict:
        uniques = sorted(self.train_df.select(pl.col(column)).unique().to_series().to_list())
        mapping = {val: idx for idx, val in enumerate(uniques)}

        for attr in ("train_df",):
            df = getattr(self, attr)
            df = df.with_columns(
                pl.col(column)
                .replace(mapping)
                .alias(column)
            )
            if cast is not None:
                df = df.with_columns(pl.col(column).cast(cast))
            setattr(self, attr, df)

        return mapping

    def run(self):
        self.train_df = self.train_df.with_columns(
            pl.col("source_type").fill_null("").alias("source_type")
        )

        self.mapping_product_ids = self._map_col("product_id")
        self.mapping_user_ids = self._map_col("user_id")
        self.mapping_source_types = self._map_col("source_type", cast=pl.Int8)

        self.train_df = self.train_df.with_columns(
            pl.col("action_type")
            .replace(self.mapping_action_types)
            .cast(pl.Int8)
            .alias("action_type")
        )

        self.targets = (
            self.train_df
            .filter(
                pl.col("request_id").is_not_null() &
                pl.col("action_type").is_in([0, 1]) &
                (pl.col("source_type") != self.mapping_source_types["ST_Catalog"])
            )
            .group_by([
                "user_id",
                "request_id",
                "product_id",
            ])
            .agg([
                pl.col("action_type").max(),
                pl.col("timestamp").min(),
                pl.col("source_type").mode().first()
            ])
        )

        requests_with_cartupdate_and_view = (
            self.targets
            .select(["request_id", "action_type", "timestamp"])
            .group_by("request_id")
            .agg([
                pl.col("action_type").max().alias("max_t"),
                pl.col("action_type").min().alias("min_t"),
                pl.len(),
                pl.col("timestamp").min().alias("req_ts")
            ])
            .with_columns(sum_targets=pl.col('max_t').add(pl.col('min_t')))
            .filter(pl.col('sum_targets') == 1)
            .filter(pl.col('len') >= 10)
            .select(["request_id", "req_ts"])
        )
        self.targets = (
            self.targets
            .drop("timestamp")
            .join(requests_with_cartupdate_and_view, on="request_id", how="inner")
            .with_columns(pl.col("req_ts").alias("timestamp"))
            .drop("req_ts")
        )
        self.targets = (
            self.targets
            .group_by(['user_id', 'request_id', 'timestamp', 'source_type'])
            .agg([
                pl.col('product_id'),
                pl.col('action_type'),
            ])
        )

        self.timesplit_valid_end = self.train_df["timestamp"].max()
        self.timesplit_valid_start = self.timesplit_valid_end - 30 * 24 * 60 * 60
        self.timesplit_train_end = self.timesplit_valid_start - 2 * 24 * 60 * 60
        self.timesplit_train_start = self.train_df["timestamp"].min()

        self.train_df = (
            self.train_df
            .filter(pl.col("action_type") != 0)
            .drop("request_id")
        )

        self.train_targets = self.targets.filter(
            pl.col("timestamp") <= self.timesplit_train_end
        )
        self.valid_targets = self.targets.filter(
            (pl.col("timestamp") > self.timesplit_valid_start) &
            (pl.col("timestamp") <= self.timesplit_valid_end)
        )
        self.train_history = self.train_df.filter(pl.col('timestamp') <= self.timesplit_train_end)
        self.valid_history = self.train_df.filter(pl.col('timestamp') > self.timesplit_train_end)

        return (
            self.train_history,
            self.valid_history,
            self.train_targets,
            self.valid_targets
        )

In [None]:
preprocessor = Preprocessor(train_df)
train_history, valid_history, train_targets, valid_targets = preprocessor.run()

- train_history/valid_history - позитивные взаимодействия пользователей
- train_targets/valid_targets - группы для finetune

## 3. Подготовка данных для pretrain

Для pretrain и finetune работа происходит с двумя последовательностями: последовательностью позитивных взаимодействий и последовательностью request-ов пользователя. Далее будем называть их history и candidates. На pretrain нам будет нужна только history.

#### Обрезание историй

У пользователей может быть разное количество позитивных событий в истории. Для простоты будет рассматривать последние 512 событий. Если вдруг их будет больше, то будем обрезать.

#### Схемы таблиц и пример

Схема для history имеет вид:
```python
HISTORY_SCHEMA = pl.Struct([{
    'source_type': pl.List(pl.Int64),
    'action_type': pl.List(pl.Int64),
    'product_id': pl.List(pl.Int64),
    'position': pl.List(pl.Int64),
    'targets_inds': pl.List(pl.Int64),
    'targets_lengths': pl.List(pl.Int64), # количество таргет событий в истории
    'lengths': pl.List(pl.Int64), # длина всей истории
}]).
```
`position` это индексы событий в истории (нужно будут далее для позиционных эмбеддингов), `targets_inds` - индексы тех позиций, которые будут участвовать в подсчете функции потерь. Они нужны, чтобы разделять потерю по событиям из обучающий и валидационной частей. `targets_lengths` - количество таких событий в истории.
Пример: Пусть есть некоторый пользователь с историей из позитивных взаимодействий длины 5. Пусть первые 3 события попадают в обучение, а последние 2 в валидацию. Тогда:
```python
history_train_sample = pl.DataFrame([{
    'source_type': [1, 1, 2, 3, 4],
    'action_type': [1, 0, 1, 0, 1],
    'product_id': [1, 2, 3, 4, 5],
    'position': [0, 1, 2, 3, 4],
    'targets_inds': [0, 1, 2],
    'targets_lengths': [3],
    'lengths': [5]
}])
history_valid_sample = pl.DataFrame([{
    'source_type': [1, 1, 2, 3, 4],
    'action_type': [1, 0, 1, 0, 1],
    'product_id': [1, 2, 3, 4, 5],
    'position': [0, 1, 2, 3, 4],
    'targets_inds': [3, 4],
    'targets_lengths': [2],
    'lengths': [5]
}])
```

In [None]:
def ensure_sorted_by_timestamp(group: Iterable[Dict[str, Any]]) -> Generator[Dict[str, Any], None, None]:
    """
    Ensures that the given iterable of events is sorted by the 'timestamp' field.

    This function iterates over each event in the provided iterable and checks if the
    'timestamp' of the current event is greater than or equal to the 'timestamp' of the
    previous event. If any event has a 'timestamp' that is less than the previous event's
    'timestamp', an AssertionError is raised.

    @param group: An iterable of dictionaries, where each dictionary represents an event with at least a 'timestamp' key.
    @return: A generator yielding each event from the input iterable in order, ensuring they are sorted by 'timestamp'.
    @raises AssertionError: If the events are not sorted by 'timestamp'.
    """

    events = chain(group)

    prev_timestamp = 0
    for event in events:
        if event["timestamp"] >= prev_timestamp:
            prev_timestamp = event["timestamp"]
            yield event
        else:
            raise AssertionError("Events are not sorted by timestamp")

In [None]:
class Mapper(ABC):
    HISTORY_SCHEMA = pl.Struct({
        'source_type': pl.List(pl.Int64),
        'action_type': pl.List(pl.Int64),
        'product_id': pl.List(pl.Int64),
        'position': pl.List(pl.Int64),
        'targets_inds': pl.List(pl.Int64),
        'targets_lengths': pl.List(pl.Int64), # количество таргет событий в истории
        'lengths': pl.List(pl.Int64), # длина всей истории
    })
    CANDIDATES_SCHEMA = pl.Struct({
        'source_type': pl.List(pl.Int64),
        'action_type': pl.List(pl.Int64),
        'product_id': pl.List(pl.Int64),
        'lengths': pl.List(pl.Int64), # длина каждого реквеста
        'num_requests': pl.List(pl.Int64) # общее количество реквестов у этого пользователя
    })

    def __init__(self, min_length: int, max_length: int):
        self._min_length: int = min_length
        self._max_length: int = max_length

    @abstractmethod
    def __call__(self, group: pl.DataFrame) -> pl.DataFrame:
        pass

    def get_empty_frame(self, candidates=False):
        return pl.DataFrame(schema=pl.Schema({
            'history': self.HISTORY_SCHEMA,
            **({'candidates': self.CANDIDATES_SCHEMA} if candidates else {})
        }))

Структура, которая:
- накапливает history для пользователя и оставлять последние `max_length`
- умеет обращаться по индексу к событию истории
- имеет метод `get(self, targets_ids)`, который превращает `self._data` в dict в соотвествии со схемой `HISTORY_SCHEMA`.

In [None]:
class HistoryDeque:
    def __init__(self, max_length=512):
        self._data = deque([], maxlen=max_length)

    def append(self, x):
        self._data.append(x)

    def __len__(self):
        return (len(self._data))

    def __getitem__(self, idx):
        return self._data[idx]

    def get(self, targets_inds=None):
        """
        Retrieves a dictionary containing various attributes of the dataset samples.

        If `targets_inds` is not provided, it automatically identifies indices of samples where the `target` is 1.

        @param targets_inds: List of indices of target samples. If None, it will be determined based on samples with target value 1.
        @return: Dictionary with keys ['source_type', 'action_type', 'product_id', 'position', 'targets_inds', 'targets_lengths', 'lengths']
                Each key maps to a list or value representing the respective attribute of the dataset samples.
        """
        if targets_inds is None:
            targets_inds = [i for i, value in enumerate(self._data)
                                        if value['target'] == 1]

        history = {'source_type': [stori['source_type'] for stori in self._data],
                    'action_type': [stori['action_type'] for stori in self._data],
                    'product_id': [stori['product_id'] for stori in self._data],
                    'position': list(range(len(self._data))),
                    'targets_inds': targets_inds,
                    'targets_lengths': [len(targets_inds)],
                    'lengths': [len(self._data)]}


        return history

На основе функции `get_pretrain_data` реализуем `PretrainMapper`, который будет по данном пользователю выдавать обучающий пример в нужном формате.

In [None]:
class PretrainMapper(Mapper):
    def __call__(self, group: pl.DataFrame) -> pl.DataFrame:
        """
        Processes a group of data by maintaining a history of rows up to a specified maximum length.
        If the history meets the minimum length requirement and contains at least one target, it returns
        a DataFrame with the history. Otherwise, it returns an empty DataFrame.

        @param group: A Polars DataFrame containing the group of data to process.
        @return: A Polars DataFrame containing the history if conditions are met; otherwise, an empty DataFrame.
        """
        deque = HistoryDeque(self._max_length)
        events_generator = ensure_sorted_by_timestamp(group.to_struct())

        for event in events_generator:
            deque.append(event)

        if len(deque) > self._min_length:
            return pl.DataFrame([{'history': deque.get()}], schema=pl.Schema({'history': Mapper.HISTORY_SCHEMA}))

        else:
            return self.get_empty_frame()


In [None]:
def get_pretrain_data(train_history: pl.DataFrame,
                      valid_history: pl.DataFrame,
                      min_length: int = 5,
                      max_length: int = 4096) -> pl.DataFrame:
    mapper = PretrainMapper(
        min_length=min_length,
        max_length=max_length,
    )

    train_data = (
        train_history.with_columns(target=pl.lit(1))
        .sort(['user_id', 'timestamp'])
        .group_by('user_id')
        .map_groups(mapper)
    )

    valid_data = (
        pl.concat([
            train_history.with_columns(target=pl.lit(0)),
            valid_history.with_columns(target=pl.lit(1))
        ], how='diagonal')
        .sort(['user_id', 'timestamp'])
        .group_by('user_id')
        .map_groups(mapper)
    )

    return train_data, valid_data

In [None]:
pretrain_train_data, pretrain_valid_data = get_pretrain_data(train_history, valid_history, min_length=5, max_length=512)

In [None]:
pretrain_valid_data

history
struct[7]
"{[1, 4, … 0],[1, 2, … 3],[2594, 14908, … 3651],[0, 1, … 182],[180, 181, 182],[3],[183]}"
"{[1, 1, … 1],[2, 2, … 1],[19210, 8368, … 10986],[0, 1, … 6],[5, 6],[2],[7]}"
"{[1, 1, … 1],[1, 1, … 2],[248, 24235, … 4297],[0, 1, … 39],[],[0],[40]}"
"{[3, 3, … 12],[1, 1, … 1],[11151, 6673, … 15581],[0, 1, … 39],[],[0],[40]}"
"{[1, 1, … 0],[1, 1, … 3],[11905, 8728, … 5651],[0, 1, … 102],[68, 69, … 102],[35],[103]}"
…
"{[8, 8, … 0],[1, 1, … 3],[11122, 15697, … 22773],[0, 1, … 15],[],[0],[16]}"
"{[8, 8, … 8],[2, 2, … 1],[9902, 9902, … 10053],[0, 1, … 78],[78],[1],[79]}"
"{[1, 0, … 0],[1, 3, … 3],[10964, 10964, … 24734],[0, 1, … 31],[],[0],[32]}"
"{[1, 1, … 0],[1, 1, … 3],[4612, 14635, … 14635],[0, 1, … 5],[],[0],[6]}"


## 4. Реализуем свой torch.nn.utils.data.Dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import polars as pl
from tqdm import tqdm

In [None]:
def convert_dict_to_tensor(data_dict):
    """
    Recursively converts lists within a dictionary to PyTorch tensors with dtype=torch.int64.

    @param data_dict: A dictionary potentially containing nested dictionaries and lists.
    @return: A new dictionary with the same structure as `data_dict`, but with lists converted to PyTorch tensors.
    """
    if not isinstance(data_dict, dict):
        if isinstance(data_dict, str):
            return data_dict
        return torch.tensor(data_dict, dtype=torch.int64)
    else:
        new_dict = {}
        for key in data_dict:
            new_dict[key] = convert_dict_to_tensor(data_dict[key])
    return new_dict


In [None]:
def test_convert_dict_to_tensor_basic():
    input_data = {
        'a': [1, 2, 3],
        'b': {
            'c': [4, 5],
            'd': 6
        },
        'e': 'text'
    }

    result = convert_dict_to_tensor(input_data)
    print(result)

    assert isinstance(result['a'], torch.Tensor)
    assert result['a'].dtype == torch.int64
    assert result['a'].tolist() == [1, 2, 3]

    assert isinstance(result['b'], dict)
    assert isinstance(result['b']['c'], torch.Tensor)
    assert result['b']['c'].dtype == torch.int64
    assert result['b']['c'].tolist() == [4, 5]

    assert result['b']['d'] == 6
    assert result['e'] == 'text'

test_convert_dict_to_tensor_basic()

{'a': tensor([1, 2, 3]), 'b': {'c': tensor([4, 5]), 'd': tensor(6)}, 'e': 'text'}


In [None]:
class LavkaDataset(Dataset):
    def __init__(self, data):
        self.data = data

    @classmethod
    def from_dataframe(cls, df: pl.DataFrame) -> 'LavkaDataset':
        converted_data = [convert_dict_to_tensor(group) for group in df.to_struct()]

        return cls(converted_data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
print("Creating train dataset ...")
train_ds = LavkaDataset.from_dataframe(pretrain_train_data)
print("Creating valid dataset ...")
valid_ds = LavkaDataset.from_dataframe(pretrain_valid_data)

Creating train dataset ...
Creating valid dataset ...


## 5. Реализуем основной backbone модели

Реализуем класс ResNet-блока согласно следующим формулам:

Для входа $x\in\mathbb{R}^{\text{batch}\times d}$ вычислить:  
$$
    \begin{aligned}
    z &= xW + b,\\
    a &= \mathrm{ReLU}(z),\\
    d'&= \mathrm{Dropout}(a),\\
    y &= \mathrm{LayerNorm}\bigl(x + d'\bigr).
    \end{aligned}
$$  
В компактном виде:  
$$
    y = \mathrm{LayerNorm}\Bigl(x + \mathrm{Dropout}\bigl(\mathrm{ReLU}(xW + b)\bigr)\Bigr)\,.
$$


In [None]:
class ResNet(nn.Module):
    def __init__(self, embedding_dim, dropout=0.):
        super().__init__()

        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):

        identity = x
        out = self.linear(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.layer_norm(identity + out)

        return out

In [None]:
def test_resnet_output_shape():
    embedding_dim = 32
    batch_size = 8
    model = ResNet(embedding_dim)
    x = torch.randn(batch_size, embedding_dim)
    out = model(x)
    assert out.shape == (batch_size, embedding_dim)

test_resnet_output_shape()

Реализуем `ContextEncoder`, `ItemEncoder` и `ActionEncoder`. На вход они принимают torch.tensor с индексами размера (seq_len,), а на выходе получают torch.tensor с соотвествующими векторами размера (seq_len, embedding_dim).

In [None]:
print(f'длинна action {len(preprocessor.mapping_action_types)}')
print(f'длинна item {len(preprocessor.mapping_product_ids)}')
print(f'длинна source {len(preprocessor.mapping_source_types)}')

длинна action 4
длинна item 26522
длинна source 13


In [None]:
class ContextEncoder(nn.Module):
    def __init__(self, embedding_dim=64):
        super().__init__()
        self.embeddings = nn.Embedding(13, embedding_dim)

    def forward(self, inputs):
        return self.embeddings(inputs)


class ItemEncoder(nn.Module):
    def __init__(self, embedding_dim=64):
        super().__init__()
        self.embeddings = nn.Embedding(26522, embedding_dim)

    def forward(self, inputs):
        return self.embeddings(inputs)


class ActionEncoder(nn.Module):
    def __init__(self, embedding_dim=64):
        super().__init__()
        self.embeddings = nn.Embedding(4, embedding_dim)

    def forward(self, inputs):
        return self.embeddings(inputs)

In [None]:
import math
class PositionalEncoding(nn.Module):
    def __init__(self, max_seq_len: int = 512, embedding_dim: int= 64):
        super().__init__()

        pe = torch.zeros(max_seq_len, embedding_dim)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return self.pe[0].index_select(0, pos.flatten())

In [None]:
def test_context_encoder_shape():
    embedding_dim = 16
    batch_size = 5
    seq_len = 7
    model = ContextEncoder(embedding_dim=embedding_dim)
    x = torch.randint(0, 4, (batch_size, seq_len))
    out = model(x)
    assert out.shape == (batch_size, seq_len, embedding_dim)
test_context_encoder_shape()

def test_item_encoder_shape():
    embedding_dim = 20
    batch_size = 6
    seq_len = 10
    model = ItemEncoder(embedding_dim=embedding_dim)
    x = torch.randint(0, 20000, (batch_size, seq_len))
    out = model(x)
    assert out.shape == (batch_size, seq_len, embedding_dim)
test_item_encoder_shape()

def test_action_encoder_shape():
    embedding_dim = 12
    batch_size = 4
    seq_len = 3
    model = ActionEncoder(embedding_dim=embedding_dim)
    x = torch.randint(0, 4, (batch_size, seq_len))
    out = model(x)
    assert out.shape == (batch_size, seq_len, embedding_dim)
test_action_encoder_shape()

In [None]:
def get_mask(lengths):
    """
    Generates a mask tensor based on the given sequence lengths.

    The mask is a boolean tensor where each row corresponds to a sequence and contains
    True values up to the length of the sequence and False values thereafter.

    @param lengths: A 1D tensor containing the lengths of sequences.
    @return: A 2D boolean tensor where each row has True up to the corresponding sequence length.
    """
    max_length = max(lengths)
    arange_tensor = torch.arange(max_length, device=lengths.device)

    return (arange_tensor < lengths.unsqueeze(1))

In [None]:
def test_get_mask():
    lengths = torch.tensor([2, 3, 1])
    expected_mask = torch.tensor([
        [True, True, False],
        [True, True, True],
        [True, False, False]
    ])
    assert torch.equal(get_mask(lengths), expected_mask)

test_get_mask()

Реализуем `ModelBackbone`. Эта часть модели является общей для pretrain и finetune. Она кодируется входные события, преобразует их в нужный для трансформера формат `(batch_size, seq_len, embedding_dim)`, прогоняет через них трансформер и возвращает три поля: выходы трансформера, вектора товаров и вектора feedback-ов.

In [None]:
class ModelBackbone(nn.Module):
    def __init__(self,
                 embedding_dim=64,
                 num_heads=2,
                 max_seq_len=512,
                 dropout_rate=0.2,
                 num_transformer_layers=2):
        super().__init__()
        self.context_encoder = ContextEncoder(embedding_dim)
        self.item_encoder = ItemEncoder(embedding_dim)
        self.action_encoder = ActionEncoder(embedding_dim)
        self.position_embeddings = PositionalEncoding(max_seq_len, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout_rate,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_transformer_layers)
        self._embedding_dim = embedding_dim

    @property
    def embedding_dim(self):
        return self._embedding_dim

    def forward(self, inputs):
        context_embeddings = self.context_encoder(inputs['history']['source_type'])
        item_embeddings = self.item_encoder(inputs['history']['product_id'])
        action_embeddings = self.action_encoder(inputs['history']['action_type'])
        position_embedding = self.position_embeddings(inputs['history']['position'])

        padding_mask = get_mask(inputs['history']['lengths'])
        batch_size, seq_len = padding_mask.shape

        token_embeddings = item_embeddings.new_zeros(
            batch_size, seq_len, self.embedding_dim, device=context_embeddings.device)

        summed_embs = (context_embeddings + item_embeddings
                       + action_embeddings + position_embedding)

        token_embeddings[padding_mask] = summed_embs

        source_embeddings = self.transformer_encoder(
            token_embeddings,
            mask=torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(device=context_embeddings.device),
            src_key_padding_mask=~padding_mask)

        return {
            'source_embeddings': source_embeddings[padding_mask],
            'item_embeddings': item_embeddings.squeeze(),
            'context_embeddings': context_embeddings.squeeze()}

In [None]:
def test_model_backbone():
    sample = {
        'history': {
            'source_type': torch.tensor([[1, 1, 7, 1, 1,]]),
            'action_type': torch.tensor([[2, 2, 2, 1, 2]]),
            'product_id': torch.tensor([[19210,  8368,  5165,  5326, 12476]]),
            'position': torch.tensor([[0, 1, 2, 3, 4]]),
            'targets_inds': torch.tensor([[0, 1, 2]]),
            'targets_lengths': torch.tensor([3]),
            'lengths': torch.tensor([5])
        }
    }


    backbone = ModelBackbone()
    output = backbone(sample)
    print(output['source_embeddings'].shape)
    print(output['item_embeddings'].shape)
    print(output['context_embeddings'].shape)
    assert output['source_embeddings'].shape == (5, 64)
    assert output['item_embeddings'].shape == (5, 64)
    assert output['context_embeddings'].shape == (5, 64)

test_model_backbone()

torch.Size([5, 64])
torch.Size([5, 64])
torch.Size([5, 64])


## 6. Реализуем pretrain модель

#### Головы (heads)

- **User–context fusion**  
      Последовательность ResNet-блоков, свёртка размерности `2D → D`. На входе конкатенация `source_embeddings ∥ context_embeddings`.  
- **Candidate projector**  
      Три ResNet-блока, преобразующие эмбеддинг товара из `D → D`.  
- **Classifier**  
      Три ResNet-блока, затем линейный слой `3D → 3`, для предсказания типа следующего действия из трёх возможных (cart, click, purchase).  
- **Параметр τ**  
      Скаляp-коэффициент `τ = clip(exp(τ_raw), τ_min, τ_max)` для масштабирования скалярных произведений – температура в contrastive-лоссе.


#### Формулы лоссов

Обозначения:  
- $u_i \in \mathbb{R}^D$ – нормализованный вектор пользователя для i-го примера.  
- $c_i \in \mathbb{R}^D$ – нормализованный вектор кандидата (позитивного товара) для i-го примера.  
- $\{n_{ij}\}_{j=1}^M\subset\mathbb{R}^D$ – нормализованные векторы $M$ негативных товаров (весь каталог товаров минус один товар, позитивный).  
- $\text{temp}>0$ – «температура» (скаляр).  
- $K= M+1$ – общее число кандидатов (1 позитивный + M негативных). Количество товаров во всем каталоге будет $1 + M$.

1. **Retrieval loss** (контрастивный softmax-лосс)  
   Для каждого примера $i$ вычисляем скалярные логиты:  
   $$
     \ell_i = [\,\underbrace{u_i^\top c_i}_{\text{позитивный логит}} \,\big|\,
               \underbrace{u_i^\top n_{i1},\,u_i^\top n_{i2},\,\dots,\,u_i^\top n_{iM}}_{\text{негативные логиты}}] \;\times\;\\text{temp}
   $$
   Затем лосс  
   $$
     \mathcal{L}_{\mathrm{retr}}
     = -\frac1N \sum_{i=1}^N \log\frac{\exp\bigl(u_i^\top c_i \,\text{temp}\bigr)}
                                      {\exp\bigl(u_i^\top c_i \,\text{temp}\bigr)
                                     + \sum_{j=1}^M \exp\bigl(u_i^\top n_{ij}\,\text{temp}\bigr)}.
   $$

2. **Action loss** (кросс-энтропия)  
   Для каждого положительного шага $i$ модель выдаёт логиты $\mathbf{z}_i \in \mathbb{R}^3$ по трём классам действий, а истинная метка $y_i\in\{0,1,2\}$.  
   $$
     \mathcal{L}_{\mathrm{action}}
     = -\frac1N \sum_{i=1}^N \sum_{k=0}^2 \delta_{y_i=k}\,\log\bigl(\mathrm{softmax}(\mathbf{z}_i)_k\bigr),
   $$
   где $\mathrm{softmax}(\mathbf{z})_k = \frac{e^{z_k}}{\sum_{m=0}^2 e^{z_m}}$.

3. **Итоговый лосс**  
   $$
     \mathcal{L}
     = \mathcal{L}_{\mathrm{retr}}
       \;+\; 10 \times \mathcal{L}_{\mathrm{action}}.
   $$
   Перевзвешиваем action часть т.к. у нее сильно меньше масштаб.

In [None]:
class PretrainModel(nn.Module):
    MIN_TEMPERATURE = 0.01
    MAX_TEMPERATURE = 100

    def __init__(self,
                 backbone,
                 embedding_dim=64):
        super().__init__()
        self.backbone = backbone
        self.user_context_fusion = nn.Sequential(
            ResNet(2 * embedding_dim),
            ResNet(2 * embedding_dim),
            ResNet(2 * embedding_dim),
            nn.Linear(2 * embedding_dim, embedding_dim),
        )
        self.candidate_projector = nn.Sequential(
            ResNet(embedding_dim),
            ResNet(embedding_dim),
            ResNet(embedding_dim),
        )
        self.classifier = nn.Sequential(
            ResNet(3 * embedding_dim),
            ResNet(3 * embedding_dim),
            ResNet(3 * embedding_dim),
            nn.Linear(3 * embedding_dim, 3),
        )
        self._embedding_dim = embedding_dim
        self.tau = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

        self.cross_entr_loss = nn.CrossEntropyLoss()

    @property
    def embedding_dim(self):
        return self._embedding_dim

    @property
    def temperature(self):
        return torch.clip(torch.exp(self.tau), min=self.MIN_TEMPERATURE, max=self.MAX_TEMPERATURE)

    def forward(self, inputs):
        backbone_outputs = self.backbone(inputs)
        source_embeddings = backbone_outputs['source_embeddings']
        item_embeddings = backbone_outputs['item_embeddings']
        context_embeddings = backbone_outputs['context_embeddings']

        lengths = inputs['history']['lengths']
        offsets = torch.cumsum(lengths, dim=0) - lengths
        target_mask = torch.full((sum(lengths).item(),), False, dtype=torch.bool, device=lengths.device)
        target_inds = torch.repeat_interleave(offsets, inputs['history']['targets_lengths']) + inputs['history']['targets_inds']
        target_mask[target_inds] = True
        non_first_element = torch.full((sum(lengths).item(),), True, dtype=torch.bool, device=lengths.device)
        non_first_element[offsets] = False
        source_mask = torch.roll(non_first_element & target_mask, -1)

        source_embeddings = source_embeddings[source_mask]
        context_embeddings = context_embeddings[non_first_element & target_mask]
        item_embeddings = item_embeddings[non_first_element & target_mask]


        # calc retrieval loss
        user_embeddings = torch.nn.functional.normalize(self.user_context_fusion(torch.cat([source_embeddings, context_embeddings], dim=-1)))
        candidate_embeddings = torch.nn.functional.normalize(
                                self.candidate_projector(
                                self.backbone.item_encoder.embeddings(inputs['history']['product_id'][target_mask & non_first_element])))
        negative_embeddings = torch.nn.functional.normalize(self.candidate_projector(self.backbone.item_encoder.embeddings.weight))
        pos_logits = torch.sum(user_embeddings * candidate_embeddings, dim=-1) * self.temperature
        neg_logits = user_embeddings @ negative_embeddings.T * self.temperature
        next_positive_prediction_loss = - torch.mean(
                                          (pos_logits) - (torch.logsumexp(neg_logits, dim=-1))
                                                  )

        # calc action loss
        logits = self.classifier(torch.cat([source_embeddings, context_embeddings, item_embeddings], dim=-1))
        targets = inputs['history']['action_type'][non_first_element & target_mask] - 1
        feedback_prediction_loss = self.cross_entr_loss(logits, targets)

        return {
            'next_positive_prediction_loss': next_positive_prediction_loss,
            'feedback_prediction_loss': feedback_prediction_loss,
            'loss': next_positive_prediction_loss + feedback_prediction_loss * 10
        }

## 7. (0.5 балл) Обучим pretrain модель

In [None]:
def collate_fn(batch):
    """
    Collates a batch of samples from a dataset.

    This function is designed to handle batches where each sample is a dictionary.
    It recursively collates values associated with the same keys across all samples in the batch.
    For tensor values, it concatenates them along the first dimension.
    For dictionary values, it applies the same collation logic recursively.
    For other types of values, it simply aggregates them into a list.

    @param batch: A list of samples, where each sample is a dictionary.
    @return: A dictionary with the same keys as the samples, where each value is either
             a concatenated tensor, a recursively collated dictionary, or a list of values.
    """
    if isinstance(batch, list) and isinstance(batch[0], dict):
        batched_dict = {}
        for key in batch[0].keys():
            list_of_values = [item[key] for item in batch]
            batched_dict[key] = collate_fn(list_of_values)
        return batched_dict
    elif isinstance(batch, list) and isinstance(batch[0], torch.Tensor):
        return torch.cat(batch, dim=0)
    elif isinstance(batch, list):
        return batch
    else:
        return batch

def move_to_device(batch, device):
    """
    Moves a batch of data to a specified device (e.g., CPU or GPU).

    Args:
        batch (torch.Tensor or dict): The batch of data to move. Can be a single tensor or a dictionary of tensors.
        device (torch.device): The target device to which the batch should be moved.

    Returns:
        torch.Tensor or dict: The batch of data moved to the specified device.
                             If the input is a dictionary, the returned value will be a dictionary with the same keys
                             and values moved to the specified device.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    elif isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    elif isinstance(batch, list):
        return [move_to_device(item, device) for item in batch]
    elif isinstance(batch, tuple):
        return tuple(move_to_device(item, device) for item in batch)

In [None]:
def test_collate_fn_basic():
    batch = [
        {
            'x': torch.tensor([1, 2]),
            'y': {
                'z': torch.tensor([[10], [20]]),
                'w': 5
            },
            's': 'foo'
        },
        {
            'x': torch.tensor([3, 4]),
            'y': {
                'z': torch.tensor([[30], [40]]),
                'w': 6
            },
            's': 'bar'
        }
    ]

    result = collate_fn(batch)

    assert isinstance(result['x'], torch.Tensor)
    assert result['x'].tolist() == [1, 2, 3, 4]

    assert isinstance(result['y'], dict)
    assert isinstance(result['y']['z'], torch.Tensor)
    assert result['y']['z'].tolist() == [[10], [20], [30], [40]]
    assert result['y']['w'] == [5, 6]

    assert result['s'] == ['foo', 'bar']
test_collate_fn_basic()

In [None]:
from statistics import mean
from sklearn.metrics import ndcg_score


def train_pretrain_model(model, train_loader, valid_loader, optimizer, scheduler, num_epochs, device):
    global_cnt = 0
    prev_valid_loss = None
    for epoch in range(num_epochs):
        model.train()
        train_losses = []
        action_losses = []
        retrieval_losses = []
        for batch in tqdm(train_loader):
            batch = move_to_device(batch, device)
            optimizer.zero_grad()
            output = model(batch)
            loss = output['loss']
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            action_losses.append(output['feedback_prediction_loss'].item())
            retrieval_losses.append(output['next_positive_prediction_loss'].item())
        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {mean(train_losses):.6f}, Train Feedback Loss: {mean(action_losses):.6f}, Train NPP Loss: {mean(retrieval_losses):.6f}")


        model.eval()
        valid_losses = []
        action_losses = []
        retrieval_losses = []
        with torch.inference_mode():
            for batch in tqdm(valid_loader):
                if len(batch['history']['targets_inds']) == 0:
                    continue
                batch = move_to_device(batch, device)
                output = model(batch)
                loss = output['loss']
                valid_losses.append(loss.item())
                action_losses.append(output['feedback_prediction_loss'].item())
                retrieval_losses.append(output['next_positive_prediction_loss'].item())

        avg_valid_loss = mean(valid_losses)
        print(f"Epoch {epoch+1}/{num_epochs}, Valid Loss: {avg_valid_loss:.6f}, Valid Feedback Loss: {mean(action_losses):.6f}, Valid NPP Loss: {mean(retrieval_losses):.6f}")

        if prev_valid_loss is None or prev_valid_loss > avg_valid_loss:
            global_cnt = 0
            prev_valid_loss = avg_valid_loss
            with torch.no_grad():
                torch.save(model, './pretrain.pt')
        else:
            global_cnt += 1
            if global_cnt == 10:
                break

In [None]:
lr = 0.001
batch_size = 8
warmup_epochs = 4
start_factor = 0.1
num_epochs = 100

embedding_dim = 64
num_heads = 2
max_seq_len = 512
dropout_rate = 0.2
num_transformer_layers = 2

In [None]:
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
device = torch.device('cuda')

backbone = ModelBackbone(embedding_dim=embedding_dim,
                        num_heads=num_heads,
                        max_seq_len=max_seq_len,
                        dropout_rate=dropout_rate,
                        num_transformer_layers=num_transformer_layers).to(device)
model_pretrain = PretrainModel(backbone=backbone,
                             embedding_dim=embedding_dim).to(device)
optimizer = optim.AdamW(model_pretrain.parameters(), lr=lr, weight_decay=0.01)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=start_factor,
    total_iters=warmup_epochs
)

In [None]:
train_pretrain_model(model_pretrain, train_loader, valid_loader, optimizer, scheduler, num_epochs, device)

100%|██████████| 245/245 [00:07<00:00, 31.42it/s]


Epoch 1/100, Train Loss: 15.076300, Train Feedback Loss: 0.491430, Train NPP Loss: 10.162004


100%|██████████| 263/263 [00:01<00:00, 139.36it/s]


Epoch 1/100, Valid Loss: 14.385434, Valid Feedback Loss: 0.424745, Valid NPP Loss: 10.137979


100%|██████████| 245/245 [00:07<00:00, 31.44it/s]


Epoch 2/100, Train Loss: 14.032211, Train Feedback Loss: 0.400547, Train NPP Loss: 10.026740


100%|██████████| 263/263 [00:01<00:00, 150.71it/s]


Epoch 2/100, Valid Loss: 13.988038, Valid Feedback Loss: 0.407550, Valid NPP Loss: 9.912540


100%|██████████| 245/245 [00:07<00:00, 31.33it/s]


Epoch 3/100, Train Loss: 13.579624, Train Feedback Loss: 0.390384, Train NPP Loss: 9.675787


100%|██████████| 263/263 [00:01<00:00, 151.00it/s]


Epoch 3/100, Valid Loss: 13.541099, Valid Feedback Loss: 0.401443, Valid NPP Loss: 9.526667


100%|██████████| 245/245 [00:07<00:00, 31.25it/s]


Epoch 4/100, Train Loss: 13.109988, Train Feedback Loss: 0.380026, Train NPP Loss: 9.309725


100%|██████████| 263/263 [00:01<00:00, 149.76it/s]


Epoch 4/100, Valid Loss: 13.084324, Valid Feedback Loss: 0.387779, Valid NPP Loss: 9.206531


100%|██████████| 245/245 [00:07<00:00, 31.26it/s]


Epoch 5/100, Train Loss: 12.581448, Train Feedback Loss: 0.364485, Train NPP Loss: 8.936603


100%|██████████| 263/263 [00:01<00:00, 150.33it/s]


Epoch 5/100, Valid Loss: 12.654987, Valid Feedback Loss: 0.383135, Valid NPP Loss: 8.823638


100%|██████████| 245/245 [00:07<00:00, 31.58it/s]


Epoch 6/100, Train Loss: 12.023960, Train Feedback Loss: 0.349631, Train NPP Loss: 8.527654


100%|██████████| 263/263 [00:01<00:00, 150.61it/s]


Epoch 6/100, Valid Loss: 12.249575, Valid Feedback Loss: 0.375640, Valid NPP Loss: 8.493171


100%|██████████| 245/245 [00:07<00:00, 31.35it/s]


Epoch 7/100, Train Loss: 11.568967, Train Feedback Loss: 0.338255, Train NPP Loss: 8.186420


100%|██████████| 263/263 [00:01<00:00, 145.70it/s]


Epoch 7/100, Valid Loss: 11.966042, Valid Feedback Loss: 0.378144, Valid NPP Loss: 8.184600


100%|██████████| 245/245 [00:07<00:00, 31.19it/s]


Epoch 8/100, Train Loss: 11.096124, Train Feedback Loss: 0.325050, Train NPP Loss: 7.845627


100%|██████████| 263/263 [00:01<00:00, 148.52it/s]


Epoch 8/100, Valid Loss: 11.621952, Valid Feedback Loss: 0.375110, Valid NPP Loss: 7.870847


100%|██████████| 245/245 [00:07<00:00, 31.29it/s]


Epoch 9/100, Train Loss: 10.697235, Train Feedback Loss: 0.314754, Train NPP Loss: 7.549692


100%|██████████| 263/263 [00:01<00:00, 151.54it/s]


Epoch 9/100, Valid Loss: 11.464827, Valid Feedback Loss: 0.381268, Valid NPP Loss: 7.652151


100%|██████████| 245/245 [00:07<00:00, 31.31it/s]


Epoch 10/100, Train Loss: 10.411818, Train Feedback Loss: 0.306068, Train NPP Loss: 7.351135


100%|██████████| 263/263 [00:01<00:00, 149.93it/s]


Epoch 10/100, Valid Loss: 11.413508, Valid Feedback Loss: 0.387704, Valid NPP Loss: 7.536468


100%|██████████| 245/245 [00:07<00:00, 31.44it/s]


Epoch 11/100, Train Loss: 10.231369, Train Feedback Loss: 0.296865, Train NPP Loss: 7.262723


100%|██████████| 263/263 [00:01<00:00, 150.35it/s]


Epoch 11/100, Valid Loss: 11.553783, Valid Feedback Loss: 0.402013, Valid NPP Loss: 7.533655


100%|██████████| 245/245 [00:07<00:00, 31.44it/s]


Epoch 12/100, Train Loss: 10.075129, Train Feedback Loss: 0.285695, Train NPP Loss: 7.218182


100%|██████████| 263/263 [00:01<00:00, 150.37it/s]


Epoch 12/100, Valid Loss: 11.473783, Valid Feedback Loss: 0.397993, Valid NPP Loss: 7.493850


100%|██████████| 245/245 [00:07<00:00, 31.41it/s]


Epoch 13/100, Train Loss: 9.935277, Train Feedback Loss: 0.276890, Train NPP Loss: 7.166374


100%|██████████| 263/263 [00:01<00:00, 147.76it/s]


Epoch 13/100, Valid Loss: 11.553828, Valid Feedback Loss: 0.405777, Valid NPP Loss: 7.496060


100%|██████████| 245/245 [00:07<00:00, 31.15it/s]


Epoch 14/100, Train Loss: 9.794553, Train Feedback Loss: 0.266706, Train NPP Loss: 7.127493


100%|██████████| 263/263 [00:01<00:00, 146.65it/s]


Epoch 14/100, Valid Loss: 11.590894, Valid Feedback Loss: 0.411010, Valid NPP Loss: 7.480794


100%|██████████| 245/245 [00:07<00:00, 31.14it/s]


Epoch 15/100, Train Loss: 9.652846, Train Feedback Loss: 0.254320, Train NPP Loss: 7.109643


100%|██████████| 263/263 [00:01<00:00, 148.73it/s]


Epoch 15/100, Valid Loss: 11.692090, Valid Feedback Loss: 0.421313, Valid NPP Loss: 7.478960


100%|██████████| 245/245 [00:07<00:00, 31.28it/s]


Epoch 16/100, Train Loss: 9.539109, Train Feedback Loss: 0.247122, Train NPP Loss: 7.067885


100%|██████████| 263/263 [00:01<00:00, 150.51it/s]


Epoch 16/100, Valid Loss: 11.932073, Valid Feedback Loss: 0.444021, Valid NPP Loss: 7.491863


100%|██████████| 245/245 [00:07<00:00, 31.40it/s]


Epoch 17/100, Train Loss: 9.405526, Train Feedback Loss: 0.236054, Train NPP Loss: 7.044984


100%|██████████| 263/263 [00:01<00:00, 149.44it/s]


Epoch 17/100, Valid Loss: 12.032263, Valid Feedback Loss: 0.456069, Valid NPP Loss: 7.471575


100%|██████████| 245/245 [00:07<00:00, 31.19it/s]


Epoch 18/100, Train Loss: 9.294823, Train Feedback Loss: 0.225077, Train NPP Loss: 7.044048


100%|██████████| 263/263 [00:01<00:00, 148.57it/s]


Epoch 18/100, Valid Loss: 12.270269, Valid Feedback Loss: 0.479740, Valid NPP Loss: 7.472866


100%|██████████| 245/245 [00:07<00:00, 31.39it/s]


Epoch 19/100, Train Loss: 9.170962, Train Feedback Loss: 0.213693, Train NPP Loss: 7.034028


100%|██████████| 263/263 [00:01<00:00, 146.15it/s]


Epoch 19/100, Valid Loss: 12.528954, Valid Feedback Loss: 0.504380, Valid NPP Loss: 7.485155


100%|██████████| 245/245 [00:07<00:00, 31.39it/s]


Epoch 20/100, Train Loss: 9.066911, Train Feedback Loss: 0.203776, Train NPP Loss: 7.029149


100%|██████████| 263/263 [00:01<00:00, 148.21it/s]

Epoch 20/100, Valid Loss: 12.576748, Valid Feedback Loss: 0.509613, Valid NPP Loss: 7.480619





In [None]:
def best_metric_model(valid_loader):
    test_model = torch.load('./pretrain.pt', weights_only=False)
    valid_losses = []
    with torch.inference_mode():
        for batch in tqdm(valid_loader):
            if len(batch['history']['targets_inds']) == 0:
                    continue
            batch = move_to_device(batch, device)
            output = test_model(batch)
            loss = output['loss']
            valid_losses.append(loss.item())
    print(mean(valid_losses))

best_metric_model(valid_loader)

100%|██████████| 263/263 [00:01<00:00, 146.41it/s]

11.413507751318125





## 8. Подготовка данных для finetune

Схемы для candidates:
```python
CANDIDATES_SCHEMA = pl.Struct({
    'source_type': pl.List(pl.Int64),
    'action_type': pl.List(pl.Int64),
    'product_id': pl.List(pl.Int64),
    'lengths': pl.List(pl.Int64), # длина каждого реквеста
    'num_requests': pl.List(pl.Int64) # общее количество реквестов у этого пользователя
})
```

Пример семпла:

```python
finetune_train_sample = {
    'history': {...},
    'candidates': {
        'source_type': [1, 2, 3],
        'action_type': [1, 0, 1, 0, 1, 0, 1, 1, 1],
        'product_id': [10, 20, 30, 40, 50, 60, 70, 80, 90],
        'lengths': [3, 3, 3],
        'num_requests': 3
    }
}
```

In [None]:
class Candidates:
    def __init__(self, max_requests_size):
        self._data = deque([], maxlen=max_requests_size)

    def append(self, x):
        if x['request_id']:
            self._data.append(x)

    def popleft(self):
        return self._data.popleft()

    def __getitem__(self, idx):
        return self._data[idx]

    def __len__(self):
        return len(self._data)

    def get(self):
        """
        Aggregates data from the internal _data attribute into a structured dictionary format.

        This method constructs a dictionary with keys 'source_type', 'action_type', 'product_id', 'lengths', and 'num_requests'.
        - 'source_type' contains the source types from each sample.
        - 'action_type' contains all action types from each sample's action_type_list flattened into a single list.
        - 'product_id' contains all product IDs from each sample's product_id_list flattened into a single list.
        - 'lengths' contains the length of the product_id_list for each sample.
        - 'num_requests' contains the total number of samples.

        Returns:
            Dict[str, Any]: A dictionary with aggregated data.
        """

        candidate_deque = {'source_type':[],
                           'action_type':[],
                           'product_id':[],
                           'lengths':[],
                           'num_requests':[len(self)]}

        for x in self:
            candidate_deque['source_type'].append(x['source_type'])
            candidate_deque['action_type'].extend(x['action_type_list'])
            candidate_deque['product_id'].extend(x['product_id_list'])
            candidate_deque['lengths'].append(len(x['action_type_list']))

        return candidate_deque

In [None]:
import random
import bisect


class FinetuneTrainMapper(Mapper):
    def __call__(self, group: pl.DataFrame) -> pl.DataFrame:
        """
        Processes a group of interactions to generate history and candidate sets for recommendation.

        This method processes a DataFrame containing interaction data, separating actions into history and candidates based on the presence of 'action_type_list'.
        It ensures the data is sorted by timestamp, filters candidates based on time constraints, and selects historical interactions within a specified lag range for each candidate.
        If there are no valid candidates or insufficient history, it returns an empty DataFrame.

        @param group: A Polars DataFrame containing interaction data with at least 'timestamp' and 'action_type_list' columns.
        @return: A Polars DataFrame with 'history' and 'candidates' columns, or an empty DataFrame if no valid candidates are found.
        """
        history_deque = HistoryDeque()
        candidate_deque = Candidates(self._max_length)

        history_generator = ensure_sorted_by_timestamp(group.to_struct())
        for event in history_generator:
            if event['action_type_list'] is None:
                history_deque.append(event)

        candidate_generator = ensure_sorted_by_timestamp(group.to_struct())
        for event in candidate_generator:
            if event['action_type_list']:
                max_time = event['timestamp'] - random.randrange(2, 32) * 86400
                target_ind = bisect.bisect_right(history_deque, max_time, key=lambda x: x['timestamp'])
                if target_ind == 0:
                    continue
                event['targets_inds'] = target_ind - 1
                candidate_deque.append(event)

        targets_inds = [candidate['targets_inds'] for candidate in candidate_deque]


        if len(candidate_deque) > self._min_length and len(history_deque) > self._min_length:
            return pl.DataFrame([{'history': history_deque.get(targets_inds),
                                  'candidates': candidate_deque.get()}],
                                 schema=pl.Schema({'history': Mapper.HISTORY_SCHEMA,
                                                  'candidates': Mapper.CANDIDATES_SCHEMA}))
        else:
            return self.get_empty_frame(candidates=True)

class FinetuneValidMapper(Mapper):
    def __call__(self, group: pl.DataFrame) -> pl.DataFrame:
        """
        Differs only in the formation of target_inds
        """
        history_deque = HistoryDeque()
        candidate_deque = Candidates(self._max_length)

        history_generator = ensure_sorted_by_timestamp(group.to_struct())
        for event in history_generator:
            if event['action_type_list'] is None:
                history_deque.append(event)

        candidate_generator = ensure_sorted_by_timestamp(group.to_struct())
        for event in candidate_generator:
            if event['action_type_list']:
                max_time = event['timestamp']
                target_ind = bisect.bisect_right(history_deque, max_time, key=lambda x: x['timestamp'])
                if target_ind == 0:
                    continue
                event['targets_inds'] = target_ind - 1
                candidate_deque.append(event)

        targets_inds = [candidate['targets_inds'] for candidate in candidate_deque]


        if len(candidate_deque) > self._min_length and len(history_deque) > self._min_length:
            return pl.DataFrame([{'history': history_deque.get(targets_inds),
                                  'candidates': candidate_deque.get()}],
                                 schema=pl.Schema({'history': Mapper.HISTORY_SCHEMA,
                                                  'candidates': Mapper.CANDIDATES_SCHEMA}))
        else:
            return self.get_empty_frame(candidates=True)

In [None]:
def get_finetune_data(train_history: pl.DataFrame,
                      train_targets: pl.DataFrame,
                      valid_targets: pl.DataFrame,
                      min_length: int = 5,
                      max_length: int = 4096) -> pl.DataFrame:
    mapper = FinetuneTrainMapper(
        min_length=min_length,
        max_length=max_length,
    )

    train_data = (
        pl.concat([
            train_history,
            train_targets.with_columns([
                pl.col('product_id').alias('product_id_list'),
                pl.col('action_type').alias('action_type_list')
            ]).drop(['product_id', 'action_type'])
        ], how='diagonal')
        .sort(['user_id', 'timestamp'])
        .group_by('user_id')
        .map_groups(mapper)
    )

    mapper = FinetuneValidMapper(
        min_length=min_length,
        max_length=max_length,
    )

    valid_data = (
        pl.concat([
            train_history,
            valid_targets.with_columns([
                pl.col('product_id').alias('product_id_list'),
                pl.col('action_type').alias('action_type_list')
            ]).drop(['product_id', 'action_type'])
        ], how='diagonal')
        .sort(['user_id', 'timestamp'])
        .group_by('user_id')
        .map_groups(mapper)
    )

    return train_data, valid_data

In [None]:
finetune_train_data, finetune_valid_data = get_finetune_data(train_history, train_targets, valid_targets, min_length=5, max_length=512)

In [None]:
print("Creating train dataset ...")
train_ds = LavkaDataset.from_dataframe(finetune_train_data)
print("Creating valid dataset ...")
valid_ds = LavkaDataset.from_dataframe(finetune_valid_data)

Creating train dataset ...
Creating valid dataset ...


## 9. Реализуем finetune модель

#### Функция make_groups: разметка элементов по «группам»

Дано: вектор длин последовательностей  
$$
\mathbf{l} = [\,l_1, l_2, \dots, l_B\,],\quad l_i\in\mathbb{N},\;
B=\text{batch size}.
$$  
Нужно получить вектор «номеров групп» длиной  
$$
N = \sum_{i=1}^B l_i
$$
так, чтобы первые $l_1$ элементов имели номер группы 0, следующие $l_2$ - номер 1 и т.д.  
Результат:  
$$
\mathrm{groups} = [\,\underbrace{0,\dots,0}_{l_1},\;
\underbrace{1,\dots,1}_{l_2},\;\dots\;,\underbrace{B-1,\dots,B-1}_{l_B}\,]\,.
$$


In [None]:
def make_groups(lengths: torch.Tensor) -> torch.Tensor:
    range_tensor = torch.arange(0, len(lengths), device=lengths.device)
    return torch.repeat_interleave(range_tensor, lengths)

In [None]:
def test_make_groups_basic():
    lengths = torch.tensor([2, 3, 1])
    expected = torch.tensor([0, 0, 1, 1, 1, 2])
    result = make_groups(lengths)
    assert torch.equal(result, expected)

test_make_groups_basic()

#### Функция make_pairs: построение всех упорядоченных пар внутри групп

Цель: для каждого «блока» длины $l_i$ сгенерировать всех $l_i\times l_i$ упорядоченных пар индексов  
$$
( p, q ),\quad p,q\in\{0,\dots,l_i-1\},
$$
а затем «развернуть» их по всему батчу. Результат - двумерный тензор shape $(2,\,\sum_i l_i^2)$, где

- первая строка `pairs` - индексы «первого» элемента пары в пределах своего блока,  
- вторая строка `pairs` - индексы «второго».

Математически пары нумеруются так:
$$
\{\, (p,q)\;\big|\;p=0..l_i-1,\;q=0..l_i-1\;\}\quad\forall i=1..B.
$$

In [None]:
def make_pairs(lengths: torch.Tensor) -> torch.Tensor:
    num_pairs_per_group = lengths**2
    total_pairs = torch.sum(num_pairs_per_group)

    group_idx = make_groups(num_pairs_per_group)

    pair_offsets = torch.cumsum(num_pairs_per_group, dim=0) - num_pairs_per_group
    local_pair_idx = torch.arange(total_pairs, device=lengths.device) - pair_offsets.repeat_interleave(num_pairs_per_group)

    local_p = local_pair_idx // lengths[group_idx]
    local_q = local_pair_idx % lengths[group_idx]

    offsets = torch.cumsum(lengths, dim=0) - lengths
    global_offsets = offsets[group_idx]

    pairs_first = local_p + global_offsets
    pairs_second = local_q + global_offsets

    return torch.stack([pairs_first, pairs_second], dim=0)

In [None]:
def test_make_pairs_simple():
    lengths = torch.tensor([1, 2], dtype=torch.long)
    expected = torch.tensor([
        [0, 1, 1, 2, 2],
        [0, 1, 2, 1, 2]
    ], dtype=torch.long)

    pairs = make_pairs(lengths)
    assert pairs.shape == (2, 5)
    assert torch.equal(pairs, expected)

test_make_pairs_simple()

#### Класс CalibratedPairwiseLogistic: попарная калиброванная логистическая функция потерь

Идея была предложена здесь: [Calibrated Pairwise Logistic](https://arxiv.org/pdf/2211.01494). Пусть у нас есть:

- логиты всех элементов: $\mathbf{c} \in \mathbb{R}^N$,
- таргеты $\mathbf{t}\in\mathbb{R}^N$,

Шаги:

1. Генерируем все упорядоченные пары индексов внутри групп:  
   $$
   \mathrm{pairs} = \bigl[\;I_0,\;I_1\bigr],\quad
   I_0,I_1\in\{0,\dots,N-1\}
   $$
2. Для каждой пары извлекаем  
   $$
   c_i = c_{I_0},\quad c_j = c_{I_1},\quad
   t_i = t_{I_0},\quad t_j = t_{I_1}.
   $$
3. Отбираем только «положительные» пары, где $t_i > t_j$. Вводим индикатор  
   $$
   w_{ij} =
     \begin{cases}
       1,&t_i > t_j,\\
       0,&\text{иначе}.
     \end{cases}
   $$
   И считаем $W=\sum w_{ij}$.
4. Если $W>0$, вычисляем попарный loss для каждой положительной пары:
   
   а) сначала вычисляем «калиброванную вероятность» того, что $i$ лучше $j$:
   $$
     p_{ij}
     = \frac{\sigma(c_i)}{\sigma(c_i)+\sigma(c_j)},
     \quad
     \sigma(x)=\frac1{1+e^{-x}}.
   $$
   б) берём отрицательный логарифм правдоподобия:
   $$
     \ell_{ij}
     = -\log p_{ij}
     = -\log\frac{\sigma(c_i)}{\sigma(c_i)+\sigma(c_j)}.
   $$
   
5. Итоговая loss - усреднённая:
   $$
     \mathcal{L}
     = \frac{1}{W}\sum_{i,j} w_{ij}\;\ell_{ij}.
   $$
6. Если $W=0$ (нет ни одной пары с $t_i>t_j$), возвращаем нуль.

Таким образом, CalibratedPairwiseLogistic минимизирует  
$$
-\frac{1}{W}\sum_{t_i>t_j}\log\frac{\sigma(c_i)}{\sigma(c_i)+\sigma(c_j)},
$$
то есть учит давать более высокие оценки $c_i$ элементам с большим таргетом $t_i$.

In [None]:
import torch.nn.functional as F

class CalibratedPairwiseLogistic(nn.Module):
    def forward(self, logits, targets, lengths):
        pairs = make_pairs(lengths)
        targets_pairs = targets[pairs]
        logits_pairs = logits[pairs]

        w = targets_pairs[0] > targets_pairs[1]
        ci = logits_pairs[0][w]
        cj = logits_pairs[1][w]

        if ci.numel() == 0:
            return logits.new_tensor(0.0)

        term1 = F.softplus(-ci)

        log_sig_ci = -F.softplus(-ci)
        log_sig_cj = -F.softplus(-cj)
        term2 = torch.logaddexp(log_sig_ci, log_sig_cj)

        loss = term1 + term2

        #loss = F.softplus(-(ci - cj))

        return torch.mean(loss)

В FinetuneModel к сырым логитам  
$$\ell_i = \langle u_i, v_i\rangle$$  
применяется калибровка:

1. Параметр «scale» (обозначим $s$) хранится в виде логарифма, то есть в модели он задан как $\text{scale}$, а реальный множитель берётся как  
   $$
     \alpha = \exp(\text{scale}).
   $$

2. Параметр «bias» (обозначим $b$) - это свободный смещающий коэффициент.

Калиброванный логит получается по формуле  
$$
  \hat\ell_i \;=\; \frac{\ell_i}{\alpha} \;+\; b
  \;=\;
  \frac{\langle u_i, v_i\rangle}{\exp(\text{s})} \;+\; b.
$$

Благодаря этому механизмy модель может автоматически подстраивать и жёсткость (разброс) логитов (через $\alpha$), и их среднее значение (через $b$), что важно для оптимальной работы попарной логистической функции потерь.

In [None]:
class FinetuneModel(nn.Module):
    def __init__(self,
                 backbone,
                 embedding_dim=64):
        super().__init__()
        self.backbone = backbone
        self.user_context_fusion = nn.Sequential(
            ResNet(2 * embedding_dim),
            ResNet(2 * embedding_dim),
            ResNet(2 * embedding_dim),
            nn.Linear(2 * embedding_dim, embedding_dim),
        )
        self.candidate_projector = nn.Sequential(
            ResNet(embedding_dim),
            ResNet(embedding_dim),
            ResNet(embedding_dim),
        )
        self._embedding_dim = embedding_dim
        self.scale = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self.pairwise_loss = CalibratedPairwiseLogistic()

    @property
    def embedding_dim(self):
        return self._embedding_dim

    def forward(self, inputs):
        backbone_outputs = self.backbone(inputs)
        source_embeddings = backbone_outputs['source_embeddings']

        lengths = inputs['history']['lengths']
        offsets = torch.cumsum(lengths, dim=0) - lengths

        target_inds = torch.repeat_interleave(offsets, inputs['history']['targets_lengths']) + inputs['history']['targets_inds']
        source_embeddings = source_embeddings[target_inds]
        context_embeddings = self.backbone.context_encoder(inputs['candidates']['source_type'])
        candidate_embeddings = self.backbone.item_encoder(inputs['candidates']['product_id'])

        source_embeddings = torch.nn.functional.normalize(
                                self.user_context_fusion(
                                torch.cat([source_embeddings, context_embeddings], dim=-1)))

        candidate_embeddings = torch.nn.functional.normalize(self.candidate_projector(candidate_embeddings))
        source_embeddings = torch.repeat_interleave(source_embeddings, inputs['candidates']['lengths'], dim=0)
        output_logits = torch.sum((candidate_embeddings * source_embeddings), dim=-1) / torch.exp(self.scale) + self.bias

        return {
            'logits': output_logits,
            'loss': self.pairwise_loss(output_logits,
                                       inputs['candidates']['action_type'],
                                       inputs['candidates']['lengths'])
        }

In [None]:
def test_finetune_model():
    sample = {
        'history': {
            'source_type': torch.tensor([8, 8, 8, 8, 8]),
            'action_type': torch.tensor([1, 1, 2, 2, 1]),
            'product_id': torch.tensor([ 3551, 17044, 10396, 10396, 10396]),
            'position': torch.tensor([0, 1, 2, 3, 4]),
            'targets_inds': torch.tensor([1]),
            'targets_lengths': torch.tensor([1]),
            'lengths': torch.tensor([5])
        },
        'candidates': {
            'source_type': torch.tensor([8]),
            'action_type': torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
            'product_id': torch.tensor([18391,  6750, 21647,  5339,  3171,  6150,  3454, 20012, 19954, 10690, 24020,  5551,  5699, 17388, 10396]),
        'lengths': torch.tensor([15]),
        'num_requests': torch.tensor([1])
        }
    }
    backbone = ModelBackbone()
    model_finetune = FinetuneModel(backbone)
    output = model_finetune(sample)

    assert output['logits'].shape == (15,)
    assert output['loss'].shape == ()

test_finetune_model()

## 10. Обучаем finetune модель

In [None]:
def train_finetune_model(model, train_loader, valid_loader, optimizer, scheduler, num_epochs, device):
    prev_valid_ndcg = None
    global_cnt = 0
    for epoch in range(num_epochs):
        model.train()
        train_losses = []
        for batch in tqdm(train_loader):
            batch = move_to_device(batch, device)
            optimizer.zero_grad()
            loss = model(batch)['loss']
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {mean(train_losses):.6f}")

        model.eval()
        valid_losses = []
        valid_logits = []
        valid_targets = []
        with torch.inference_mode():
            for batch in tqdm(valid_loader):
                batch = move_to_device(batch, device)
                output = model(batch)
                loss = output['loss']
                valid_losses.append(loss.item())
                logits = output['logits']
                targets = batch['candidates']['action_type']
                lengths = batch['candidates']['lengths']
                i = 0
                for length in lengths:
                    if length > 1:
                        valid_logits.append(logits[i:i + length].cpu().numpy())
                        valid_targets.append(targets[i:i + length].cpu().numpy())
                    i += length

        avg_valid_ndcg = 0
        for logits, targets in zip(valid_logits, valid_targets):
            avg_valid_ndcg += ndcg_score(targets[None,], logits[None,], k=10, ignore_ties=True)
        avg_valid_ndcg /= len(valid_logits)

        print(f"Epoch {epoch+1}/{num_epochs}, Valid Loss: {mean(valid_losses):.6f}")
        print(f"Epoch {epoch+1}/{num_epochs}, Valid NDCG@10: {avg_valid_ndcg}")

        if prev_valid_ndcg is None or prev_valid_ndcg < avg_valid_ndcg:
            global_cnt = 0
            prev_valid_ndcg = avg_valid_ndcg
            with torch.no_grad():
                torch.save(model, './finetune.pt')
        else:
            global_cnt += 1
            if global_cnt == 10:
                break

In [None]:
lr = 0.001
batch_size = 4
warmup_epochs = 6
start_factor = 0.1
num_epochs = 100

embedding_dim = 64
num_heads = 2
max_seq_len = 512
dropout_rate = 0.1
num_transformer_layers = 2

In [None]:
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
backbone = ModelBackbone(embedding_dim=embedding_dim,
                        num_heads=num_heads,
                        max_seq_len=max_seq_len,
                        dropout_rate=dropout_rate,
                        num_transformer_layers=num_transformer_layers).to(device)

model_finetune = FinetuneModel(backbone=backbone,
                             embedding_dim=embedding_dim).to(device).to(device)
optimizer = optim.AdamW(model_finetune.parameters(), lr=lr, weight_decay=0.01)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=start_factor,
    total_iters=warmup_epochs
)

In [None]:
train_finetune_model(model_finetune, train_loader, valid_loader, optimizer, scheduler, num_epochs, device)

100%|██████████| 103/103 [00:01<00:00, 59.03it/s]


Epoch 1/100, Train Loss: 0.694108


100%|██████████| 44/44 [00:01<00:00, 37.28it/s]


Epoch 1/100, Valid Loss: 0.693395
Epoch 1/100, Valid NDCG@10: 0.2546861600107447


100%|██████████| 103/103 [00:01<00:00, 68.56it/s]


Epoch 2/100, Train Loss: 0.692111


100%|██████████| 44/44 [00:01<00:00, 37.36it/s]


Epoch 2/100, Valid Loss: 0.693322
Epoch 2/100, Valid NDCG@10: 0.2570427670724341


100%|██████████| 103/103 [00:01<00:00, 65.38it/s]


Epoch 3/100, Train Loss: 0.688936


100%|██████████| 44/44 [00:01<00:00, 37.85it/s]


Epoch 3/100, Valid Loss: 0.693499
Epoch 3/100, Valid NDCG@10: 0.25713644874995256


100%|██████████| 103/103 [00:01<00:00, 67.18it/s]


Epoch 4/100, Train Loss: 0.684719


100%|██████████| 44/44 [00:01<00:00, 37.60it/s]


Epoch 4/100, Valid Loss: 0.693407
Epoch 4/100, Valid NDCG@10: 0.25852186271025285


100%|██████████| 103/103 [00:01<00:00, 68.16it/s]


Epoch 5/100, Train Loss: 0.676360


100%|██████████| 44/44 [00:01<00:00, 37.81it/s]


Epoch 5/100, Valid Loss: 0.694664
Epoch 5/100, Valid NDCG@10: 0.2616602168785861


100%|██████████| 103/103 [00:01<00:00, 67.35it/s]


Epoch 6/100, Train Loss: 0.662996


100%|██████████| 44/44 [00:01<00:00, 37.90it/s]


Epoch 6/100, Valid Loss: 0.695768
Epoch 6/100, Valid NDCG@10: 0.2617906381541963


100%|██████████| 103/103 [00:01<00:00, 66.86it/s]


Epoch 7/100, Train Loss: 0.642791


100%|██████████| 44/44 [00:01<00:00, 37.83it/s]


Epoch 7/100, Valid Loss: 0.696026
Epoch 7/100, Valid NDCG@10: 0.26647013288501536


100%|██████████| 103/103 [00:01<00:00, 67.80it/s]


Epoch 8/100, Train Loss: 0.615586


100%|██████████| 44/44 [00:01<00:00, 37.70it/s]


Epoch 8/100, Valid Loss: 0.698590
Epoch 8/100, Valid NDCG@10: 0.2713029140449932


100%|██████████| 103/103 [00:01<00:00, 65.19it/s]


Epoch 9/100, Train Loss: 0.579750


100%|██████████| 44/44 [00:01<00:00, 38.04it/s]


Epoch 9/100, Valid Loss: 0.706566
Epoch 9/100, Valid NDCG@10: 0.268429292915627


100%|██████████| 103/103 [00:01<00:00, 68.95it/s]


Epoch 10/100, Train Loss: 0.539574


100%|██████████| 44/44 [00:01<00:00, 38.34it/s]


Epoch 10/100, Valid Loss: 0.713357
Epoch 10/100, Valid NDCG@10: 0.276569378244181


100%|██████████| 103/103 [00:01<00:00, 67.60it/s]


Epoch 11/100, Train Loss: 0.493453


100%|██████████| 44/44 [00:01<00:00, 37.43it/s]


Epoch 11/100, Valid Loss: 0.733010
Epoch 11/100, Valid NDCG@10: 0.2744708748629545


100%|██████████| 103/103 [00:01<00:00, 66.82it/s]


Epoch 12/100, Train Loss: 0.447443


100%|██████████| 44/44 [00:01<00:00, 38.00it/s]


Epoch 12/100, Valid Loss: 0.753066
Epoch 12/100, Valid NDCG@10: 0.27648931996836923


100%|██████████| 103/103 [00:01<00:00, 67.07it/s]


Epoch 13/100, Train Loss: 0.403074


100%|██████████| 44/44 [00:01<00:00, 37.00it/s]


Epoch 13/100, Valid Loss: 0.783191
Epoch 13/100, Valid NDCG@10: 0.27505940661850004


100%|██████████| 103/103 [00:01<00:00, 67.42it/s]


Epoch 14/100, Train Loss: 0.361130


100%|██████████| 44/44 [00:01<00:00, 37.74it/s]


Epoch 14/100, Valid Loss: 0.814066
Epoch 14/100, Valid NDCG@10: 0.27694571772869075


100%|██████████| 103/103 [00:01<00:00, 68.30it/s]


Epoch 15/100, Train Loss: 0.326589


100%|██████████| 44/44 [00:01<00:00, 32.61it/s]


Epoch 15/100, Valid Loss: 0.851196
Epoch 15/100, Valid NDCG@10: 0.2765302838861126


100%|██████████| 103/103 [00:01<00:00, 68.73it/s]


Epoch 16/100, Train Loss: 0.292372


100%|██████████| 44/44 [00:01<00:00, 37.88it/s]


Epoch 16/100, Valid Loss: 0.891140
Epoch 16/100, Valid NDCG@10: 0.2770081934039496


100%|██████████| 103/103 [00:01<00:00, 66.47it/s]


Epoch 17/100, Train Loss: 0.263815


100%|██████████| 44/44 [00:01<00:00, 37.65it/s]


Epoch 17/100, Valid Loss: 0.937039
Epoch 17/100, Valid NDCG@10: 0.27672211828648846


100%|██████████| 103/103 [00:01<00:00, 65.01it/s]


Epoch 18/100, Train Loss: 0.239859


100%|██████████| 44/44 [00:01<00:00, 38.02it/s]


Epoch 18/100, Valid Loss: 0.976588
Epoch 18/100, Valid NDCG@10: 0.2759399086248945


100%|██████████| 103/103 [00:01<00:00, 67.95it/s]


Epoch 19/100, Train Loss: 0.217944


100%|██████████| 44/44 [00:01<00:00, 38.07it/s]


Epoch 19/100, Valid Loss: 1.020218
Epoch 19/100, Valid NDCG@10: 0.27831837665400627


100%|██████████| 103/103 [00:01<00:00, 66.12it/s]


Epoch 20/100, Train Loss: 0.198325


100%|██████████| 44/44 [00:01<00:00, 37.31it/s]


Epoch 20/100, Valid Loss: 1.065931
Epoch 20/100, Valid NDCG@10: 0.2771304165878406


100%|██████████| 103/103 [00:01<00:00, 69.48it/s]


Epoch 21/100, Train Loss: 0.183882


100%|██████████| 44/44 [00:01<00:00, 38.05it/s]


Epoch 21/100, Valid Loss: 1.105885
Epoch 21/100, Valid NDCG@10: 0.2777282734656781


100%|██████████| 103/103 [00:01<00:00, 68.57it/s]


Epoch 22/100, Train Loss: 0.168667


100%|██████████| 44/44 [00:01<00:00, 37.98it/s]


Epoch 22/100, Valid Loss: 1.146273
Epoch 22/100, Valid NDCG@10: 0.27696814350310045


100%|██████████| 103/103 [00:01<00:00, 68.27it/s]


Epoch 23/100, Train Loss: 0.157582


100%|██████████| 44/44 [00:01<00:00, 38.32it/s]


Epoch 23/100, Valid Loss: 1.180521
Epoch 23/100, Valid NDCG@10: 0.27702362427335275


100%|██████████| 103/103 [00:01<00:00, 68.38it/s]


Epoch 24/100, Train Loss: 0.145772


100%|██████████| 44/44 [00:01<00:00, 37.82it/s]


Epoch 24/100, Valid Loss: 1.217736
Epoch 24/100, Valid NDCG@10: 0.277624958646671


100%|██████████| 103/103 [00:01<00:00, 67.39it/s]


Epoch 25/100, Train Loss: 0.136098


100%|██████████| 44/44 [00:01<00:00, 37.80it/s]


Epoch 25/100, Valid Loss: 1.262788
Epoch 25/100, Valid NDCG@10: 0.2796917890621476


100%|██████████| 103/103 [00:01<00:00, 67.23it/s]


Epoch 26/100, Train Loss: 0.128355


100%|██████████| 44/44 [00:01<00:00, 37.63it/s]


Epoch 26/100, Valid Loss: 1.290439
Epoch 26/100, Valid NDCG@10: 0.2801923030409099


100%|██████████| 103/103 [00:01<00:00, 65.99it/s]


Epoch 27/100, Train Loss: 0.118789


100%|██████████| 44/44 [00:01<00:00, 38.44it/s]


Epoch 27/100, Valid Loss: 1.322291
Epoch 27/100, Valid NDCG@10: 0.281143559199074


100%|██████████| 103/103 [00:01<00:00, 68.60it/s]


Epoch 28/100, Train Loss: 0.114026


100%|██████████| 44/44 [00:01<00:00, 37.39it/s]


Epoch 28/100, Valid Loss: 1.359577
Epoch 28/100, Valid NDCG@10: 0.27960267665894967


100%|██████████| 103/103 [00:01<00:00, 67.32it/s]


Epoch 29/100, Train Loss: 0.108824


100%|██████████| 44/44 [00:01<00:00, 32.34it/s]


Epoch 29/100, Valid Loss: 1.394563
Epoch 29/100, Valid NDCG@10: 0.28024568728227045


100%|██████████| 103/103 [00:01<00:00, 67.98it/s]


Epoch 30/100, Train Loss: 0.101837


100%|██████████| 44/44 [00:01<00:00, 38.00it/s]


Epoch 30/100, Valid Loss: 1.412267
Epoch 30/100, Valid NDCG@10: 0.2787179308377197


100%|██████████| 103/103 [00:01<00:00, 67.66it/s]


Epoch 31/100, Train Loss: 0.095012


100%|██████████| 44/44 [00:01<00:00, 37.47it/s]


Epoch 31/100, Valid Loss: 1.435590
Epoch 31/100, Valid NDCG@10: 0.2831360686156984


100%|██████████| 103/103 [00:01<00:00, 68.82it/s]


Epoch 32/100, Train Loss: 0.091292


100%|██████████| 44/44 [00:01<00:00, 37.79it/s]


Epoch 32/100, Valid Loss: 1.465331
Epoch 32/100, Valid NDCG@10: 0.2798770118282553


100%|██████████| 103/103 [00:01<00:00, 68.39it/s]


Epoch 33/100, Train Loss: 0.086790


100%|██████████| 44/44 [00:01<00:00, 37.83it/s]


Epoch 33/100, Valid Loss: 1.475408
Epoch 33/100, Valid NDCG@10: 0.2817704699945992


100%|██████████| 103/103 [00:01<00:00, 68.45it/s]


Epoch 34/100, Train Loss: 0.081552


100%|██████████| 44/44 [00:01<00:00, 38.14it/s]


Epoch 34/100, Valid Loss: 1.506873
Epoch 34/100, Valid NDCG@10: 0.28079910201217434


100%|██████████| 103/103 [00:01<00:00, 68.25it/s]


Epoch 35/100, Train Loss: 0.079628


100%|██████████| 44/44 [00:01<00:00, 37.91it/s]


Epoch 35/100, Valid Loss: 1.530875
Epoch 35/100, Valid NDCG@10: 0.2819948348623824


100%|██████████| 103/103 [00:01<00:00, 65.56it/s]


Epoch 36/100, Train Loss: 0.075387


100%|██████████| 44/44 [00:01<00:00, 37.71it/s]


Epoch 36/100, Valid Loss: 1.534424
Epoch 36/100, Valid NDCG@10: 0.2851178730571711


100%|██████████| 103/103 [00:01<00:00, 67.34it/s]


Epoch 37/100, Train Loss: 0.072174


100%|██████████| 44/44 [00:01<00:00, 37.84it/s]


Epoch 37/100, Valid Loss: 1.576321
Epoch 37/100, Valid NDCG@10: 0.2834175607235177


100%|██████████| 103/103 [00:01<00:00, 66.11it/s]


Epoch 38/100, Train Loss: 0.069576


100%|██████████| 44/44 [00:01<00:00, 37.26it/s]


Epoch 38/100, Valid Loss: 1.586873
Epoch 38/100, Valid NDCG@10: 0.28079375741346896


100%|██████████| 103/103 [00:01<00:00, 68.01it/s]


Epoch 39/100, Train Loss: 0.066024


100%|██████████| 44/44 [00:01<00:00, 37.73it/s]


Epoch 39/100, Valid Loss: 1.608439
Epoch 39/100, Valid NDCG@10: 0.2813023848924957


100%|██████████| 103/103 [00:01<00:00, 66.91it/s]


Epoch 40/100, Train Loss: 0.063204


100%|██████████| 44/44 [00:01<00:00, 37.29it/s]


Epoch 40/100, Valid Loss: 1.632899
Epoch 40/100, Valid NDCG@10: 0.28328605454411476


100%|██████████| 103/103 [00:01<00:00, 68.37it/s]


Epoch 41/100, Train Loss: 0.060043


100%|██████████| 44/44 [00:01<00:00, 37.68it/s]


Epoch 41/100, Valid Loss: 1.642190
Epoch 41/100, Valid NDCG@10: 0.2836053274862247


100%|██████████| 103/103 [00:01<00:00, 68.25it/s]


Epoch 42/100, Train Loss: 0.056804


100%|██████████| 44/44 [00:01<00:00, 38.13it/s]


Epoch 42/100, Valid Loss: 1.669807
Epoch 42/100, Valid NDCG@10: 0.28271136990400586


100%|██████████| 103/103 [00:01<00:00, 67.82it/s]


Epoch 43/100, Train Loss: 0.056408


100%|██████████| 44/44 [00:01<00:00, 32.35it/s]


Epoch 43/100, Valid Loss: 1.700580
Epoch 43/100, Valid NDCG@10: 0.2823041095260559


100%|██████████| 103/103 [00:01<00:00, 67.92it/s]


Epoch 44/100, Train Loss: 0.053669


100%|██████████| 44/44 [00:01<00:00, 37.68it/s]


Epoch 44/100, Valid Loss: 1.715311
Epoch 44/100, Valid NDCG@10: 0.2797988911833067


100%|██████████| 103/103 [00:01<00:00, 66.18it/s]


Epoch 45/100, Train Loss: 0.050857


100%|██████████| 44/44 [00:01<00:00, 37.84it/s]


Epoch 45/100, Valid Loss: 1.738091
Epoch 45/100, Valid NDCG@10: 0.2788766263616418


100%|██████████| 103/103 [00:01<00:00, 67.55it/s]


Epoch 46/100, Train Loss: 0.048640


100%|██████████| 44/44 [00:01<00:00, 37.93it/s]


Epoch 46/100, Valid Loss: 1.753403
Epoch 46/100, Valid NDCG@10: 0.2813816070701094


Пробуем инициализироваться предобученной моделью, только backbone:

In [None]:
model_pretrain = torch.load('./pretrain.pt', weights_only=False)
model_finetune = FinetuneModel(model_pretrain.backbone, embedding_dim).to(device)
optimizer = optim.AdamW(model_finetune.parameters(), lr=lr, weight_decay=0.01)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=start_factor,
    total_iters=warmup_epochs
)

In [None]:
train_finetune_model(model_finetune, train_loader, valid_loader, optimizer, scheduler, num_epochs, device)

100%|██████████| 103/103 [00:01<00:00, 65.05it/s]


Epoch 1/100, Train Loss: 0.691471


100%|██████████| 44/44 [00:01<00:00, 37.32it/s]


Epoch 1/100, Valid Loss: 0.690597
Epoch 1/100, Valid NDCG@10: 0.2715581525414558


100%|██████████| 103/103 [00:01<00:00, 66.37it/s]


Epoch 2/100, Train Loss: 0.686504


100%|██████████| 44/44 [00:01<00:00, 38.20it/s]


Epoch 2/100, Valid Loss: 0.688594
Epoch 2/100, Valid NDCG@10: 0.2853143746444059


100%|██████████| 103/103 [00:01<00:00, 66.62it/s]


Epoch 3/100, Train Loss: 0.682444


100%|██████████| 44/44 [00:01<00:00, 37.38it/s]


Epoch 3/100, Valid Loss: 0.687203
Epoch 3/100, Valid NDCG@10: 0.29198543922684267


100%|██████████| 103/103 [00:01<00:00, 66.47it/s]


Epoch 4/100, Train Loss: 0.675731


100%|██████████| 44/44 [00:01<00:00, 37.53it/s]


Epoch 4/100, Valid Loss: 0.685760
Epoch 4/100, Valid NDCG@10: 0.2946631830565442


100%|██████████| 103/103 [00:01<00:00, 67.23it/s]


Epoch 5/100, Train Loss: 0.663543


100%|██████████| 44/44 [00:01<00:00, 37.61it/s]


Epoch 5/100, Valid Loss: 0.682968
Epoch 5/100, Valid NDCG@10: 0.2980635697373388


100%|██████████| 103/103 [00:01<00:00, 67.62it/s]


Epoch 6/100, Train Loss: 0.646557


100%|██████████| 44/44 [00:01<00:00, 37.71it/s]


Epoch 6/100, Valid Loss: 0.681894
Epoch 6/100, Valid NDCG@10: 0.3027882754625199


100%|██████████| 103/103 [00:01<00:00, 67.04it/s]


Epoch 7/100, Train Loss: 0.622757


100%|██████████| 44/44 [00:01<00:00, 37.93it/s]


Epoch 7/100, Valid Loss: 0.681384
Epoch 7/100, Valid NDCG@10: 0.2993919543717133


100%|██████████| 103/103 [00:01<00:00, 67.60it/s]


Epoch 8/100, Train Loss: 0.592705


100%|██████████| 44/44 [00:01<00:00, 38.07it/s]


Epoch 8/100, Valid Loss: 0.680859
Epoch 8/100, Valid NDCG@10: 0.30762836971960333


100%|██████████| 103/103 [00:01<00:00, 65.93it/s]


Epoch 9/100, Train Loss: 0.557855


100%|██████████| 44/44 [00:01<00:00, 38.06it/s]


Epoch 9/100, Valid Loss: 0.688612
Epoch 9/100, Valid NDCG@10: 0.30201546988229494


100%|██████████| 103/103 [00:01<00:00, 64.96it/s]


Epoch 10/100, Train Loss: 0.518535


100%|██████████| 44/44 [00:01<00:00, 37.57it/s]


Epoch 10/100, Valid Loss: 0.698482
Epoch 10/100, Valid NDCG@10: 0.2975723309701491


100%|██████████| 103/103 [00:01<00:00, 67.11it/s]


Epoch 11/100, Train Loss: 0.479594


100%|██████████| 44/44 [00:01<00:00, 32.70it/s]


Epoch 11/100, Valid Loss: 0.710182
Epoch 11/100, Valid NDCG@10: 0.3031960916175086


100%|██████████| 103/103 [00:01<00:00, 67.86it/s]


Epoch 12/100, Train Loss: 0.441269


100%|██████████| 44/44 [00:01<00:00, 36.34it/s]


Epoch 12/100, Valid Loss: 0.733430
Epoch 12/100, Valid NDCG@10: 0.30074281482762627


100%|██████████| 103/103 [00:01<00:00, 69.49it/s]


Epoch 13/100, Train Loss: 0.402006


100%|██████████| 44/44 [00:01<00:00, 37.93it/s]


Epoch 13/100, Valid Loss: 0.758549
Epoch 13/100, Valid NDCG@10: 0.30010136594632275


100%|██████████| 103/103 [00:01<00:00, 66.92it/s]


Epoch 14/100, Train Loss: 0.368085


100%|██████████| 44/44 [00:01<00:00, 37.65it/s]


Epoch 14/100, Valid Loss: 0.788902
Epoch 14/100, Valid NDCG@10: 0.29746587528192325


100%|██████████| 103/103 [00:01<00:00, 67.88it/s]


Epoch 15/100, Train Loss: 0.333925


100%|██████████| 44/44 [00:01<00:00, 38.21it/s]


Epoch 15/100, Valid Loss: 0.823396
Epoch 15/100, Valid NDCG@10: 0.297276212739027


100%|██████████| 103/103 [00:01<00:00, 66.90it/s]


Epoch 16/100, Train Loss: 0.306275


100%|██████████| 44/44 [00:01<00:00, 37.63it/s]


Epoch 16/100, Valid Loss: 0.849111
Epoch 16/100, Valid NDCG@10: 0.29814354090040174


100%|██████████| 103/103 [00:01<00:00, 67.19it/s]


Epoch 17/100, Train Loss: 0.280859


100%|██████████| 44/44 [00:01<00:00, 37.31it/s]


Epoch 17/100, Valid Loss: 0.894657
Epoch 17/100, Valid NDCG@10: 0.301113224742707


100%|██████████| 103/103 [00:01<00:00, 67.58it/s]


Epoch 18/100, Train Loss: 0.258980


100%|██████████| 44/44 [00:01<00:00, 37.96it/s]


Epoch 18/100, Valid Loss: 0.932736
Epoch 18/100, Valid NDCG@10: 0.2962556867154115


Попробуем еще дополнительно иницилизировать user_context_fusion и candidate_projector:

In [None]:
model_pretrain = torch.load('./pretrain.pt', weights_only=False)
model_finetune = FinetuneModel(model_pretrain.backbone, embedding_dim).to(device)
model_finetune.user_context_fusion = model_pretrain.user_context_fusion
model_finetune.candidate_projector = model_pretrain.candidate_projector

optimizer = optim.AdamW(model_finetune.parameters(), lr=lr, weight_decay=0.01)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=start_factor,
    total_iters=warmup_epochs
)

In [None]:
train_finetune_model(model_finetune, train_loader, valid_loader, optimizer, scheduler, num_epochs, device)

100%|██████████| 208/208 [00:03<00:00, 59.95it/s]


Epoch 1/100, Train Loss: 0.679132


100%|██████████| 79/79 [00:01<00:00, 59.53it/s]


Epoch 1/100, Valid Loss: 0.682856
Epoch 1/100, Valid NDCG@10: 0.30732482430123936


100%|██████████| 208/208 [00:03<00:00, 59.70it/s]


Epoch 2/100, Train Loss: 0.669400


100%|██████████| 79/79 [00:01<00:00, 60.29it/s]


Epoch 2/100, Valid Loss: 0.679741
Epoch 2/100, Valid NDCG@10: 0.30849815430583744


100%|██████████| 208/208 [00:03<00:00, 60.62it/s]


Epoch 3/100, Train Loss: 0.659003


100%|██████████| 79/79 [00:01<00:00, 61.26it/s]


Epoch 3/100, Valid Loss: 0.678653
Epoch 3/100, Valid NDCG@10: 0.30931249287854684


100%|██████████| 208/208 [00:03<00:00, 61.85it/s]


Epoch 4/100, Train Loss: 0.644037


100%|██████████| 79/79 [00:01<00:00, 62.01it/s]


Epoch 4/100, Valid Loss: 0.678889
Epoch 4/100, Valid NDCG@10: 0.31046323207923315


100%|██████████| 208/208 [00:03<00:00, 61.43it/s]


Epoch 5/100, Train Loss: 0.625504


100%|██████████| 79/79 [00:01<00:00, 60.87it/s]


Epoch 5/100, Valid Loss: 0.683930
Epoch 5/100, Valid NDCG@10: 0.31124725443319823


100%|██████████| 208/208 [00:03<00:00, 60.70it/s]


Epoch 6/100, Train Loss: 0.601912


100%|██████████| 79/79 [00:01<00:00, 60.44it/s]


Epoch 6/100, Valid Loss: 0.694785
Epoch 6/100, Valid NDCG@10: 0.3086770732148856


100%|██████████| 208/208 [00:03<00:00, 61.33it/s]


Epoch 7/100, Train Loss: 0.580264


100%|██████████| 79/79 [00:01<00:00, 61.04it/s]


Epoch 7/100, Valid Loss: 0.705092
Epoch 7/100, Valid NDCG@10: 0.3069931432152151


100%|██████████| 208/208 [00:03<00:00, 59.95it/s]


Epoch 8/100, Train Loss: 0.557128


100%|██████████| 79/79 [00:01<00:00, 60.53it/s]


Epoch 8/100, Valid Loss: 0.730049
Epoch 8/100, Valid NDCG@10: 0.30959840999771815


100%|██████████| 208/208 [00:03<00:00, 61.33it/s]


Epoch 9/100, Train Loss: 0.532404


100%|██████████| 79/79 [00:01<00:00, 60.96it/s]


Epoch 9/100, Valid Loss: 0.759088
Epoch 9/100, Valid NDCG@10: 0.3093430684839358


100%|██████████| 208/208 [00:03<00:00, 61.65it/s]


Epoch 10/100, Train Loss: 0.505664


100%|██████████| 79/79 [00:01<00:00, 61.38it/s]


Epoch 10/100, Valid Loss: 0.792779
Epoch 10/100, Valid NDCG@10: 0.3062908252445432


100%|██████████| 208/208 [00:03<00:00, 59.96it/s]


Epoch 11/100, Train Loss: 0.476132


100%|██████████| 79/79 [00:01<00:00, 59.85it/s]


Epoch 11/100, Valid Loss: 0.806685
Epoch 11/100, Valid NDCG@10: 0.3092291720752101


100%|██████████| 208/208 [00:03<00:00, 61.59it/s]


Epoch 12/100, Train Loss: 0.440795


100%|██████████| 79/79 [00:01<00:00, 61.34it/s]


Epoch 12/100, Valid Loss: 0.855497
Epoch 12/100, Valid NDCG@10: 0.30261216950507674


100%|██████████| 208/208 [00:03<00:00, 59.41it/s]


Epoch 13/100, Train Loss: 0.405742


100%|██████████| 79/79 [00:01<00:00, 60.11it/s]


Epoch 13/100, Valid Loss: 0.878947
Epoch 13/100, Valid NDCG@10: 0.3052794059313995


100%|██████████| 208/208 [00:03<00:00, 60.80it/s]


Epoch 14/100, Train Loss: 0.372725


100%|██████████| 79/79 [00:01<00:00, 60.12it/s]


Epoch 14/100, Valid Loss: 0.933721
Epoch 14/100, Valid NDCG@10: 0.30291167875036706


100%|██████████| 208/208 [00:03<00:00, 59.90it/s]


Epoch 15/100, Train Loss: 0.343517


100%|██████████| 79/79 [00:01<00:00, 50.97it/s]


Epoch 15/100, Valid Loss: 0.995619
Epoch 15/100, Valid NDCG@10: 0.30357367883426856


In [None]:
def best_metrics_finetune_model(valid_loader):
    valid_logits = []
    valid_targets = []
    with torch.inference_mode():
        test_model = torch.load('./finetune.pt', weights_only=False)
        for batch in tqdm(valid_loader):
            batch = move_to_device(batch, device)
            output = test_model(batch)
            logits = output['logits']
            targets = batch['candidates']['action_type']
            lengths = batch['candidates']['lengths']
            i = 0
            for length in lengths:
                if length > 1:
                    valid_logits.append(logits[i:i + length].cpu().numpy())
                    valid_targets.append(targets[i:i + length].cpu().numpy())
                i += length

    avg_valid_ndcg = 0
    for logits, targets in zip(valid_logits, valid_targets):
        avg_valid_ndcg += ndcg_score(targets[None,], logits[None,], k=10, ignore_ties=True)
    avg_valid_ndcg /= len(valid_logits)
    print(avg_valid_ndcg)

test_finetune_model(valid_loader)

100%|██████████| 79/79 [00:01<00:00, 59.99it/s]


0.31124725443319823


AssertionError: 