In [1]:
import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix
import implicit
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss
from sklearn.ensemble import HistGradientBoostingClassifier

from dataclasses import dataclass
from sklearn.isotonic import IsotonicRegression

from src.modeling.business import *
from src.modeling.data_prep import *
from src.modeling.inference import *
from src.modeling.reranker_data import *
from src.modeling.reranker_model import *
from src.modeling.retrieval_als import *
from src.modeling.artifacts import *
from src.modeling.pipeline import recommend_for_client


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Read the five raw CSV tables from data/ in one pass: map() applies
# pd.read_csv to each path in order and the results are unpacked by position.
clients, products, stocks, stores, transactions = map(
    pd.read_csv,
    (
        "data/clients.csv",
        "data/products.csv",
        "data/stocks.csv",
        "data/stores.csv",
        "data/transactions.csv",
    ),
)

In [None]:
# Distinct values present in the ClientGender column of the clients table.
clients["ClientGender"].unique()

In [None]:
# Distinct values present in the Universe column of the products table.
products["Universe"].unique()

# Training

In [None]:
# Parse SaleTransactionDate into pandas datetimes. `assign` returns a new
# frame, which is rebound to `transactions`; the other columns are untouched.
transactions = transactions.assign(
    SaleTransactionDate=pd.to_datetime(transactions["SaleTransactionDate"])
)
transactions.head()

In [None]:
# Cutoff for the time-based split: the 0.8 quantile of the sale dates.
cutoff_date = transactions["SaleTransactionDate"].quantile(q=0.8)
print(cutoff_date)

# Split the transactions around the cutoff (the exact boundary handling lives
# in make_time_split, which is not visible here).
train_tx, test_tx = make_time_split(transactions, cutoff_date)

# Lookup tables built from the client and product tables -- presumably
# id <-> matrix-index mappings, judging by the names; confirm in make_id_maps.
user2idx, idx2user, item2idx, idx2item = make_id_maps(clients, products)

In [None]:
# Build the interaction matrix from the training transactions and the
# user/item index maps. It is later passed as `X_user_item` to the reranker
# data builder and stored in the saved artifacts.
# NOTE(review): presumably a sparse users x items matrix -- confirm in
# build_interaction_matrix.
X = build_interaction_matrix(train_tx, user2idx, item2idx)

In [None]:
# Fit the retrieval model on the interaction matrix using train_als's default
# settings (hyperparameters are defined inside the helper, not here).
model = train_als(X)

In [None]:
# Assemble the reranker training table from the fitted retrieval model and the
# train / test transaction splits. The helper returns the table together with
# the list of feature column names.
training_df, feature_cols = build_reranker_training_set(
    # retrieval model and the matrix / index maps it was fitted with
    model=model,
    X_user_item=X,
    user2idx=user2idx,
    idx2item=idx2item,
    # transaction splits
    train_tx=train_tx,
    test_tx=test_tx,
    # candidate generation / sampling settings
    N_candidates=200,
    n_neg_per_pos=5,
    filtered=True,
    random_state=1,
)

# Fit the binary reranker on that table; `metrics` is displayed as the cell output.
reranker_model, metrics = train_binary_reranker(training_df, feature_cols)
metrics


In [None]:
# Compute the auxiliary inputs that are stored alongside the two models.
user_feats, item_feats, max_train_date = prepare_reranker_artifacts(train_tx)
stock_ctry = build_stock_country_lookup(stores, stocks)
item_value = compute_item_value(train_tx)

# Bundle everything into a single RecoArtifacts object.
artifacts = RecoArtifacts(
    # retrieval stage
    als_model=model,
    X_user_item=X,
    user2idx=user2idx,
    idx2item=idx2item,
    # reranking stage
    reranker_model=reranker_model,
    feature_cols=feature_cols,
    user_feats=user_feats,
    item_feats=item_feats,
    max_train_date=max_train_date,
    # stock / value lookups
    stock_ctry=stock_ctry,
    item_value=item_value,
)

# Persist the bundle under the relative "model" directory.
save_artifacts(artifacts, out_dir="model")
print("Saved artifacts to ./model")


# Inference

In [2]:
# This cell re-reads the client / product tables and loads the saved artifacts
# from "model" instead of reusing the variables from the training section.
clients = pd.read_csv("data/clients.csv")
products = pd.read_csv("data/products.csv")
artifacts = load_artifacts("model")

# Print ten randomly drawn client ids as suggestions (unseeded, so they differ
# between runs), then pick the id to score.
print(f"Some ids to try out: {list(clients['ClientID'].sample(10))}")
client_id = 4508698145640552159

# Top-10 recommendations for the chosen client. The meaning of the tuning
# knobs below is defined in src.modeling.pipeline.recommend_for_client.
top10 = recommend_for_client(
    artifacts,
    client_id=client_id,
    clients_df=clients,
    products_df=products,
    # candidate pool size and final list length
    N_candidates=200,
    top_k=10,
    # stock-related settings
    min_stock=1.0,
    stock_boost=0.02,
    # diversity-related settings
    diversity_boost=0.05,
    diversity_level="FamilyLevel1",
    # gender filter switch (off)
    enforce_gender=False,
)

top10


Some ids to try out: [7621309638700040503, 5990057330784701660, 5865661836541892279, 371075579136468736, 4396526906868656580, 8481428462035789713, 3166068070429227201, 4296486502849546093, 2682890189772573988, 6962613514141281811]


Unnamed: 0,ClientID,ProductID,Category,FamilyLevel2,Universe,als_score,p_buy,StockQty,item_value,business_score
0,4508698145640552159,1053601088228117848,Handball,Select Ultimate,Women,0.902335,0.5,4.0,68.108414,34.097305
1,4508698145640552159,8761826855035940162,Handball,Mizuno Wave Mirage,Women,0.802692,0.22831,4.0,70.582756,16.158758
2,4508698145640552159,2893616851514749639,Handball,Mizuno Wave Mirage,Men,0.566767,0.22831,2.0,68.227104,15.610723
3,4508698145640552159,5166899781459858865,Handball,Mizuno Wave Mirage,Women,0.710164,0.184971,1.0,71.437767,13.23957
4,4508698145640552159,6108384229018093630,Handball,Mizuno Wave Mirage,Men,0.679491,0.186709,1.0,66.935577,12.523114
5,4508698145640552159,2995021495769981207,Handball,Asics Gel-Blast,Men,0.590998,0.135853,3.0,67.807419,9.251348
6,4508698145640552159,387526636274140264,Handball,Asics Gel-Blast,Men,0.254994,0.1198,3.0,67.5495,8.131964
7,4508698145640552159,1613365891365522842,Handball,Molten H3X5001,Men,0.240477,0.100524,1.0,73.430626,7.406282
8,4508698145640552159,6340857046265281758,Hockey,A&R Sports Ice Hockey Puck,Women,0.927525,0.52381,12.0,8.834937,4.694935
9,4508698145640552159,6313624011728899683,Hockey,Bauer Nexus 2N Pro,Women,0.971816,0.494505,12.0,8.96075,4.493099
