In [1]:
# Make the parent of the notebook directory importable and use it as the
# working directory, so that `src.*` imports and relative `data/...` paths
# used by the following cells resolve.
import os
import sys

# Remember the starting directory only once per kernel. Without this guard a
# re-run of the cell would recompute it from the already-changed cwd and climb
# one more directory level each time.
if '_notebook_dir' not in globals():
    _notebook_dir = os.path.abspath('')

dir2 = _notebook_dir            # working directory at the first run of this cell
dir1 = os.path.dirname(dir2)    # its parent directory
if dir1 not in sys.path:
    sys.path.append(dir1)
# Change to the remembered parent rather than '..' so re-running is a no-op.
os.chdir(dir1)

In [2]:
from collections import defaultdict

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import ndcg_score
import lightgbm as lgb
from tqdm import tqdm

from src.data import get_sequences, load_data, download_dataset, ValSASRecDataset, TrainSASRecDataset
from src.utils import fix_seed
from src.model import SASRec

%load_ext autoreload
%autoreload

In [3]:
# Name of the dataset and the directory (relative to the current working
# directory) that is passed to the project helpers below.
dataset = 'beauty'
data_path = f"data/{dataset}"

# Presumably fetches the dataset into `data_path` when it is missing — the
# recorded run printed "Dataset already exists."; verify against src.data.
download_dataset(data_path, dataset)
# Last expression of the cell, so its return value (a table with user_id,
# item_id, rating and timestamp columns in the recorded output) is displayed.
# NOTE(review): the result is not bound to a name here — confirm that
# displaying it is the only intent.
load_data(data_path, dataset)

Dataset already exists.


Unnamed: 0,user_id,item_id,rating,timestamp
0,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00V6R3R3S,5.0,1452647102000
1,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00PA7VMD2,3.0,1452648690000
2,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00JIIUJ5Q,4.0,1454675735000
3,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B007Z2R15I,3.0,1458094710000
4,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00YAZBWZI,5.0,1458095420000
...,...,...,...,...
6624436,AHK7KOTTU4XURTJ76PW5KDF7S7MQ,B000PYAC86,5.0,1691973175744
6624437,AHK7KOTTU4XURTJ76PW5KDF7S7MQ,B07SPY1GJG,5.0,1691973217288
6624438,AHK7KOTTU4XURTJ76PW5KDF7S7MQ,B0B9HV7S8B,5.0,1691973274033
6624439,AHK7KOTTU4XURTJ76PW5KDF7S7MQ,B07L8QNGZF,5.0,1691973478471


In [4]:
# Read the raw interaction file of the configured dataset. The path is built
# from `data_path` / `dataset` instead of being hard-coded, so switching the
# dataset in the configuration cell cannot silently keep loading "beauty".
# NOTE(review): assumes the archive is always named "<dataset>.csv.gz" — confirm.
raw_file = f"{data_path}/{dataset}.csv.gz"

# NOTE(review): the displayed frame has an "Unnamed: 0" column, which looks
# like a saved index; consider `index_col=0` if nothing downstream needs it.
df = pd.read_csv(raw_file, compression="gzip")
# Normalise the item column name; `rename` ignores the key when the column is
# absent, so this is a no-op for files that already use "item_id".
df = df.rename(columns={"parent_asin": "item_id"})
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp
0,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00V6R3R3S,5.0,1452647102000
1,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00PA7VMD2,3.0,1452648690000
2,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00JIIUJ5Q,4.0,1454675735000
3,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B007Z2R15I,3.0,1458094710000
4,AGKASBHYZPGTEPO6LWZPVJWB2BVA,B00YAZBWZI,5.0,1458095420000


In [5]:
# Keep only interactions rated above `min_rating`.
min_rating = 3.5
ratings = df[df["rating"] > min_rating]
# Minimum number of interactions a user and an item must have to be kept.
min_interactions = 5

num_users = ratings['user_id'].nunique()
num_items = ratings['item_id'].nunique()

# Distinct users / items before the activity filter.
print(num_users, num_items)

# Iterative activity filter: dropping sparse items can push users below the
# threshold again (and vice versa), so repeat both passes until the number of
# rows stops changing.
prev_len = 0
while len(ratings) != prev_len:
    prev_len = len(ratings)

    # Filter users
    user_counts = ratings["user_id"].value_counts()
    valid_users = user_counts[user_counts >= min_interactions].index
    ratings = ratings[ratings["user_id"].isin(valid_users)]

    # Filter items
    item_counts = ratings["item_id"].value_counts()
    valid_items = item_counts[item_counts >= min_interactions].index
    ratings = ratings[ratings["item_id"].isin(valid_items)]

# Data preparation
num_users = ratings['user_id'].nunique()
num_items = ratings['item_id'].nunique()

# Distinct users / items after the activity filter.
print(num_users, num_items)

# Map raw ids to consecutive integers in order of first appearance.
# Users start at 0; items start at 1, leaving 0 unused
# (presumably reserved as a padding index — confirm against the model code).
user2id = {val: i for i, val in enumerate(ratings['user_id'].unique())}
item2id = {val: i + 1 for i, val in enumerate(ratings['item_id'].unique())}

# Build a new frame with `assign` instead of writing columns into a frame that
# was derived by boolean slicing, which raises SettingWithCopyWarning.
ratings = ratings.assign(
    user_id=ratings['user_id'].map(user2id),
    item_id=ratings['item_id'].map(item2id),
)

725156 207369
475139 158366


In [6]:
# Inspect the filtered interactions with integer-encoded ids; the recorded
# output shows only the first and last rows, separated by "...".
ratings

Unnamed: 0,user_id,item_id,rating,timestamp
10,0,1,5.0,1361240647000
12,0,2,5.0,1481817544000
13,0,3,5.0,1481817706000
14,0,4,5.0,1549383515613
15,0,5,5.0,1597804385775
...,...,...,...,...
6624436,475138,67619,5.0,1691973175744
6624437,475138,12,5.0,1691973217288
6624438,475138,13146,5.0,1691973274033
6624439,475138,42546,5.0,1691973478471
