# PyShopper example
---
This notebook contains a quick example of PyShopper that covers:
1. Loading data
2. Instantiating and fitting the Shopper model via MCMC sampling
3. Inference diagnostics
4. Prediction on unseen test data

In [1]:
# Imports

import numpy as np
import pandas as pd
import warnings

from pyshopper import shopper

# Ignore FutureWarning and UserWarning
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)



## 1. Load data

In [2]:
# Read the training baskets together with the item price table

train_path = 'data/train.tsv'
prices_path = 'data/prices.tsv'
X_train = shopper.load_data(train_path, prices_path)
X_train

Unnamed: 0,user_id,item_id,session_id,quantity,price
0,1,100,1,1,1.0
1,2,100,1,1,1.0
2,4,100,1,1,1.0
3,5,100,1,1,1.0
4,6,100,1,1,1.0
...,...,...,...,...,...
306042,208,200,123,1,5.0
306043,209,200,123,1,5.0
306044,227,200,123,1,5.0
306045,238,200,123,1,5.0


In [3]:
# Limit the data to a reproducible random sample of baskets
# (a basket is one (user_id, session_id) pair)

sample_size = 100
sample_seed = 42  # fixed so Restart & Run All selects the same baskets

groupby_baskets = X_train.groupby(['user_id', 'session_id'])

# Shuffle the basket group numbers with a locally seeded generator; an
# unseeded np.random.shuffle would pick different baskets on every run.
baskets_idx = np.random.default_rng(sample_seed)\
                       .permutation(groupby_baskets.ngroups)

# ngroup() labels each row with its basket's group number, so the mask keeps
# every row that belongs to one of the first `sample_size` shuffled baskets.
in_sample = groupby_baskets.ngroup().isin(baskets_idx[:sample_size])
X_train_limited = X_train[in_sample]
X_train_limited = X_train_limited.sort_values(['session_id', 'user_id'])\
                                 .reset_index(drop=True)
X_train_limited

Unnamed: 0,user_id,item_id,session_id,quantity,price
0,214,300,4,1,1.0
1,214,301,4,1,1.0
2,214,200,4,1,1.0
3,214,201,4,1,1.0
4,218,400,5,1,1.0
...,...,...,...,...,...
302,74,300,394,1,1.0
303,74,301,394,1,1.0
304,93,400,399,1,1.0
305,93,401,399,1,1.0


## 2. Instantiate and fit model

In [4]:
# Create Shopper instance from the sampled baskets
# (per the logged output, the model is built during construction)

model = shopper.Shopper(X_train_limited)

INFO:root:Building the Shopper model...
INFO:root:Done building the Shopper model.


In [None]:
# Fit model by MCMC sampling with a fixed seed for reproducibility.
# The logged output shows PyMC3 auto-assigning NUTS and running 4 chains in 4 jobs.
# NOTE(review): draws=1000 is presumably per chain — confirm against pyshopper docs.

res = model.fit(draws=1000, random_seed=42)

Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta_c, gamma_u, lambda_c, theta_u, alpha_c, rho_c]
INFO:pymc3:NUTS: [beta_c, gamma_u, lambda_c, theta_u, alpha_c, rho_c]


## 3. Diagnostics

In [None]:
# Trace plot of the fit result
# NOTE(review): presumably shows the sampled values per parameter for checking
# chain mixing/convergence — confirm against pyshopper's trace_plot docs.

res.trace_plot()