In [None]:
import pandas as pd
import numpy as np
from rectools import Columns
from rectools.dataset import Dataset

from service.models.ann import ApproximateNearestNeighbors
from service.utils import load_model

# Load data

In [None]:
!mkdir ../data
!wget https://storage.yandexcloud.net/itmo-recsys-public-data/kion_train.zip -O ../data/data_original.zip
!unzip ../data/data_original.zip -d ../data

In [2]:
# Read the three KION tables (interactions, users, items) from the unpacked archive.
interactions, users, items = map(
    pd.read_csv,
    (
        '../data/kion_train/interactions.csv',
        '../data/kion_train/users.csv',
        '../data/kion_train/items.csv',
    ),
)

# Preprocess

In [3]:
# Point `Columns.Datetime` at the raw 'last_watch_dt' column so that every later
# `Columns.Datetime` lookup in this notebook resolves to that column name.
# NOTE(review): this reassigns an attribute on the imported `Columns` object, so
# the change is visible to any other code in the kernel that reads it — confirm
# this is intended rather than renaming the DataFrame column.
Columns.Datetime = 'last_watch_dt'

In [4]:
# Remove, in place, every row whose date string is not exactly 10 characters
# long (the length of a YYYY-MM-DD value).
malformed_dates = interactions[Columns.Datetime].str.len().ne(10)
interactions.drop(index=interactions.loc[malformed_dates].index, inplace=True)

In [5]:
# Convert the date column to pandas datetimes using the explicit YYYY-MM-DD format.
interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime], format='%Y-%m-%d')

In [6]:
# Latest watch date present in the data; the train/test split below is measured back from it.
max_date = interactions[Columns.Datetime].max()

In [7]:
# Weight 3 for interactions with watched_pct above 10, weight 1 for everything else.
interactions[Columns.Weight] = np.where(interactions['watched_pct'].gt(10), 3, 1)

In [8]:
# Time-based split: rows dated strictly before `max_date - 7 days` go to train,
# rows on or after that cutoff go to test.
split_date = max_date - pd.Timedelta(days=7)
watch_dates = interactions[Columns.Datetime]

train = interactions.loc[watch_dates < split_date].copy()
test = interactions.loc[watch_dates >= split_date].copy()

print(f"train: {train.shape}")
print(f"test: {test.shape}")

train: (4985269, 6)
test: (490982, 6)


In [9]:
# Remove, in place, training rows whose total_dur is below 300.
short_views = train.query("total_dur < 300")
train.drop(index=short_views.index, inplace=True)

In [10]:
# Cold users: ids that appear in the test split but never in the training split.
# They are filtered out of the test set in the next cell.
cold_users = set(test[Columns.User]).difference(train[Columns.User])

In [11]:
# Remove, in place, every test row that belongs to a cold user.
is_cold_row = test[Columns.User].isin(cold_users)
test.drop(index=test.loc[is_cold_row].index, inplace=True)

# Models

In [20]:
# Notebook-wide configuration constants.
# presumably the number of recommendations to produce per user — TODO confirm
K_RECOS = 10
# NOTE(review): RANDOM_STATE, NUM_THREADS and N_FACTORS are not referenced in
# the cells visible here; confirm they are still needed.
RANDOM_STATE = 42
NUM_THREADS = 16
# One-element tuple (note the trailing comma), not a bare int.
N_FACTORS = (32,)

In [26]:
# Build the rectools Dataset from the training interactions only.
dataset = Dataset.construct(interactions_df=train)

In [27]:
# Unique user ids remaining in the test split (cold users were dropped earlier).
TEST_USERS = test[Columns.User].unique()

# Approximate Nearest Neighbors

In [30]:
# Load a previously saved model through the project's `load_model` helper.
# NOTE(review): the '.pickle' extension suggests pickle deserialization, which
# can execute arbitrary code — only load model files from a trusted source.
model = load_model('../models/warp_12.pickle')

In [1]:
# Wrap the loaded model and the training dataset in the project's ANN helper.
# NOTE(review): the saved NameError output comes from running this cell as In [1],
# before the import cell; re-run the notebook top to bottom (Restart & Run All).
ann = ApproximateNearestNeighbors(model=model, dataset=dataset)

NameError: name 'ApproximateNearestNeighbors' is not defined

In [None]:
# Use the shared K_RECOS constant from the configuration cell instead of
# repeating the literal 10, so the value is set in a single place.
# presumably k_reco is the number of recommendations prepared per user — TODO confirm
ann.fit(k_reco=K_RECOS)

In [None]:
# Example query for a single user.
# NOTE(review): 1245 is a hard-coded id; confirm it is present in the training dataset.
ann.predict(user_id=1245)
