# Demo: Iterative Approach to ML-based Item-wise Collaborative Filtering Applied to Clustered Data

In [1]:
import pandas as pd
import numpy as np
import sys
sys.path.insert(1, '../resype')
%load_ext autoreload
%autoreload 2 

## Prepare data

In [2]:
# Seed NumPy's global RNG so the synthetic ratings drawn below are repeatable.
np.random.seed(202109)

# Pool of values a rating can take: the scores 1-5, plus NaN for "not rated".
rating_vals = np.append(np.arange(1, 6, dtype=float), np.nan)
rating_vals

array([ 1.,  2.,  3.,  4.,  5., nan])

In [3]:
# Synthetic id spaces: 1000 users and 1000 items, both numbered from 0.
userids = np.arange(1000)
itemids = np.arange(1000)

# One draw from the rating pool (1-5 or NaN) for every (user, item) pair.
random_ratings = np.random.choice(rating_vals,
                                  size=userids.size * itemids.size)

In [4]:
# Long-format ratings table covering the full user x item grid, user-major:
# each user id is repeated once per item, and the item ids cycle in step.
transactions = pd.DataFrame({
    'user_id': np.repeat(userids, len(itemids)),
    'item_id': np.tile(itemids, len(userids)),
    'rating': random_ratings,
}).drop_duplicates()

In [5]:
# Preview the long-format table: one row per (user_id, item_id) pair, with
# `rating` left as NaN where no rating was drawn.
transactions

Unnamed: 0,user_id,item_id,rating
0,0,0,2.0
1,0,1,
2,0,2,
3,0,3,5.0
4,0,4,4.0
...,...,...,...
999995,999,995,1.0
999996,999,996,3.0
999997,999,997,
999998,999,998,2.0


## Load resype

In [6]:
from collab_filtering import CollabFilteringModel

In [7]:
# Wrap the transactions table in the recommender object used by every cell
# below.
# NOTE(review): the name `re` is also the name of the stdlib regex module;
# importing that module later in this notebook would rebind this variable.
# Consider a more specific name (would need renaming in all later cells).
re = CollabFilteringModel(transactions)

In [8]:
# Build the user x item utility matrix from the transactions: rows are indexed
# by user_id, columns by item_id, and unrated pairs show as NaN in the output.
utility_matrix = re.construct_utility_matrix()
utility_matrix

item_id,0,1,2,3,4,5,6,7,8,9,...,990,991,992,993,994,995,996,997,998,999
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,2.0,,,5.0,4.0,4.0,3.0,4.0,4.0,3.0,...,5.0,1.0,4.0,2.0,3.0,,2.0,,5.0,2.0
1,1.0,,4.0,5.0,3.0,2.0,1.0,3.0,1.0,,...,3.0,3.0,2.0,4.0,4.0,3.0,4.0,4.0,3.0,4.0
2,3.0,4.0,4.0,4.0,2.0,4.0,2.0,4.0,1.0,4.0,...,5.0,4.0,3.0,1.0,,5.0,2.0,2.0,,5.0
3,5.0,2.0,1.0,,2.0,4.0,3.0,3.0,,1.0,...,2.0,,2.0,3.0,5.0,2.0,,5.0,,1.0
4,2.0,1.0,3.0,1.0,2.0,2.0,3.0,1.0,3.0,5.0,...,,3.0,1.0,4.0,4.0,1.0,2.0,1.0,2.0,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,5.0,,,2.0,1.0,5.0,1.0,2.0,2.0,3.0,...,,1.0,5.0,1.0,5.0,2.0,2.0,,3.0,5.0
996,1.0,5.0,5.0,2.0,5.0,4.0,1.0,,1.0,5.0,...,,,1.0,,,,4.0,5.0,1.0,2.0
997,2.0,5.0,,4.0,5.0,4.0,,3.0,,3.0,...,1.0,4.0,1.0,5.0,4.0,3.0,,5.0,2.0,2.0
998,3.0,2.0,2.0,1.0,1.0,4.0,2.0,1.0,5.0,1.0,...,3.0,4.0,4.0,5.0,5.0,,4.0,5.0,4.0,4.0


## Cluster data 

In [9]:
from sklearn.cluster import (KMeans, SpectralClustering,
                             AgglomerativeClustering, DBSCAN, OPTICS,
                             cluster_optics_dbscan, Birch)

# Clustering models: `model1` groups users, `model2` groups items.
# random_state is pinned (same seed as the rest of the notebook) so the cluster
# labels do not depend on the state of NumPy's global RNG, i.e. on which cells
# were run, or re-run, before the clustering step.
model1 = KMeans(n_clusters=15, random_state=202109)
model2 = KMeans(n_clusters=20, random_state=202109)

In [10]:
# Cluster users with `model1` and items with `model2`. Each call returns three
# objects, which are unpacked here; the visible cells below do not read them
# again and instead work through the state held on `re`.
x_u, y_u, df_u = re.cluster_users(model1)
x_i, y_i, df_i = re.cluster_items(model2)

## Generate new utility matrix based on clusters

In [11]:
# Aggregate the utility matrix to cluster level: rows become user clusters
# (`u_cluster`) and columns become item clusters (`i_cluster`).
# Side effect: this also replaces `re.utility_matrix` with the aggregated
# matrix, so the original per-user / per-item matrix is no longer held there.
Uc_df = re.utility_matrix_agg()
Uc_df

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
0,2.459586,2.481728,2.564732,2.47013,2.536458,2.578073,2.524068,2.572545,2.442857,2.510445,2.545788,2.487155,2.446429,2.475377,2.557738,2.484524,2.554825,2.430019,2.555314,2.518707
1,2.5,2.507558,2.525,2.506061,2.488542,2.527132,2.555072,2.523958,2.441333,2.530818,2.499145,2.557018,2.528169,2.490683,2.646111,2.494815,2.439474,2.545495,2.464634,2.428571
2,2.433293,2.496349,2.452762,2.513531,2.518169,2.505408,2.470172,2.569767,2.39907,2.502852,2.44037,2.495104,2.552571,2.503972,2.544574,2.552972,2.4694,2.496543,2.48979,2.526578
3,2.568279,2.499686,2.589105,2.468796,2.570101,2.454431,2.534665,2.518581,2.505405,2.542325,2.509356,2.512802,2.528169,2.517962,2.538739,2.55015,2.51138,2.466764,2.51269,2.629987
4,2.411483,2.465513,2.500355,2.459711,2.457623,2.450317,2.639328,2.383523,2.505455,2.520583,2.511364,2.522329,2.480154,2.499859,2.484848,2.448232,2.560805,2.512285,2.497506,2.438853
5,2.541963,2.543055,2.515203,2.497297,2.502252,2.468887,2.491187,2.6875,2.551351,2.518103,2.42966,2.491228,2.529882,2.491858,2.447748,2.513514,2.534851,2.523009,2.520435,2.384813
6,2.456672,2.542929,2.479167,2.492195,2.523464,2.436105,2.465086,2.517992,2.533737,2.504955,2.518778,2.503456,2.518993,2.497522,2.476263,2.497755,2.544568,2.507917,2.492547,2.458393
7,2.406534,2.475742,2.579203,2.51442,2.538793,2.508821,2.50075,2.508621,2.34069,2.531555,2.45977,2.457955,2.511413,2.479439,2.364368,2.542146,2.480036,2.439888,2.526493,2.537767
8,2.560272,2.483121,2.601815,2.490909,2.476478,2.405101,2.677419,2.375,2.412903,2.469872,2.391232,2.460668,2.517038,2.544781,2.444086,2.445878,2.579513,2.52136,2.406373,2.50384
9,2.433114,2.475048,2.561849,2.493182,2.433594,2.453973,2.483696,2.516927,2.479167,2.475236,2.60203,2.474415,2.466549,2.566253,2.49375,2.451852,2.461988,2.490991,2.504065,2.64881


## Train iterative model using `train_model_iterative_cluster`

#### Create the model object (imported from scikit-learn)

In [12]:
from sklearn.ensemble import RandomForestRegressor
# Base regressor handed to the iterative trainer in the cells below;
# random_state is fixed so repeated fits use the same seed.
rs_model1 = RandomForestRegressor(random_state=202109)

#### Train model

In [13]:
# Inspect the matrix currently held on the model object. After the aggregation
# step this is the cluster-level (u_cluster x i_cluster) matrix, not the
# original user x item matrix.
re.utility_matrix

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
0,2.459586,2.481728,2.564732,2.47013,2.536458,2.578073,2.524068,2.572545,2.442857,2.510445,2.545788,2.487155,2.446429,2.475377,2.557738,2.484524,2.554825,2.430019,2.555314,2.518707
1,2.5,2.507558,2.525,2.506061,2.488542,2.527132,2.555072,2.523958,2.441333,2.530818,2.499145,2.557018,2.528169,2.490683,2.646111,2.494815,2.439474,2.545495,2.464634,2.428571
2,2.433293,2.496349,2.452762,2.513531,2.518169,2.505408,2.470172,2.569767,2.39907,2.502852,2.44037,2.495104,2.552571,2.503972,2.544574,2.552972,2.4694,2.496543,2.48979,2.526578
3,2.568279,2.499686,2.589105,2.468796,2.570101,2.454431,2.534665,2.518581,2.505405,2.542325,2.509356,2.512802,2.528169,2.517962,2.538739,2.55015,2.51138,2.466764,2.51269,2.629987
4,2.411483,2.465513,2.500355,2.459711,2.457623,2.450317,2.639328,2.383523,2.505455,2.520583,2.511364,2.522329,2.480154,2.499859,2.484848,2.448232,2.560805,2.512285,2.497506,2.438853
5,2.541963,2.543055,2.515203,2.497297,2.502252,2.468887,2.491187,2.6875,2.551351,2.518103,2.42966,2.491228,2.529882,2.491858,2.447748,2.513514,2.534851,2.523009,2.520435,2.384813
6,2.456672,2.542929,2.479167,2.492195,2.523464,2.436105,2.465086,2.517992,2.533737,2.504955,2.518778,2.503456,2.518993,2.497522,2.476263,2.497755,2.544568,2.507917,2.492547,2.458393
7,2.406534,2.475742,2.579203,2.51442,2.538793,2.508821,2.50075,2.508621,2.34069,2.531555,2.45977,2.457955,2.511413,2.479439,2.364368,2.542146,2.480036,2.439888,2.526493,2.537767
8,2.560272,2.483121,2.601815,2.490909,2.476478,2.405101,2.677419,2.375,2.412903,2.469872,2.391232,2.460668,2.517038,2.544781,2.444086,2.445878,2.579513,2.52136,2.406373,2.50384
9,2.433114,2.475048,2.561849,2.493182,2.433594,2.453973,2.483696,2.516927,2.479167,2.475236,2.60203,2.474415,2.466549,2.566253,2.49375,2.451852,2.461988,2.490991,2.504065,2.64881


In [14]:
%%time
# Run the iterative trainer directly on the cluster-level utility matrix with
# the random-forest regressor. This is the slow step of the notebook: the
# recorded run took about 3.5 minutes of wall time.
utility_matrix_imputed = re.train_model_iterative_cluster(
    re.utility_matrix, rs_model1)

CPU times: user 3min 35s, sys: 1.55 s, total: 3min 36s
Wall time: 3min 36s


#### Prediction

In [15]:
# Show the matrix returned by `train_model_iterative_cluster`.
# NOTE(review): the displayed values are close to zero while the input cluster
# averages are around 2.5 -- presumably the result is mean-centred rather than
# on the 1-5 rating scale; confirm against the trainer before interpreting.
utility_matrix_imputed

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
0,-0.039038,-0.022122,0.0453,-0.036716,0.01999,0.04436,0.011957,0.042639,-0.052608,0.004387,0.0292,-0.014744,-0.051624,-0.023198,0.040038,-0.016478,0.034931,-0.063235,0.039611,0.01224
1,-0.007077,-0.003058,0.016287,-0.007953,-0.015383,0.010219,0.033615,0.009101,-0.053635,0.017727,-0.010577,0.034276,0.012579,-0.012927,0.117389,-0.010746,-0.055454,0.025589,-0.032478,-0.067923
2,-0.048111,0.000198,-0.027908,0.011536,0.017701,0.004306,-0.019295,0.064164,-0.078891,0.006888,-0.042252,-0.001617,0.042408,0.006379,0.037838,0.042446,-0.01984,-0.003636,-0.006375,0.021758
3,0.026788,-0.022419,0.05344,-0.048951,0.034686,-0.054147,0.006116,-0.002173,-0.022897,0.015019,-0.009705,-0.011122,-0.00028,-0.00546,0.00986,0.016758,-0.010572,-0.051794,-0.008496,0.078387
4,-0.052601,-0.014234,0.012786,-0.024935,-0.024088,-0.035817,0.121849,-0.085713,0.012799,0.027142,0.021823,0.026909,-0.009737,0.010896,-0.0036,-0.033763,0.064394,0.020455,0.006621,-0.041351
5,0.02441,0.028842,0.006173,-0.01292,-0.004719,-0.036184,-0.015279,0.145937,0.031671,0.00826,-0.065516,-0.015894,0.013434,-0.012665,-0.047536,0.003986,0.022734,0.008966,0.006743,-0.100565
6,-0.031623,0.033249,-0.013314,-0.007262,0.019462,-0.05423,-0.024824,0.01785,0.027609,0.007455,0.011738,0.003573,0.010688,0.000969,-0.017861,0.00047,0.035631,0.007821,-0.002139,-0.033589
7,-0.057189,-0.008005,0.076034,0.019468,0.041875,0.012856,0.015624,0.025557,-0.117945,0.036005,-0.020845,-0.020936,0.016381,-0.001789,-0.082199,0.04461,-0.003032,-0.040903,0.032167,0.041692
8,0.060031,-0.002638,0.086681,0.000781,-0.011971,-0.071955,0.151291,-0.091279,-0.05828,-0.011058,-0.072192,-0.017735,0.0222,0.046403,-0.031919,-0.031753,0.070165,0.023076,-0.069526,0.012452
9,-0.053259,-0.017612,0.052396,-0.009449,-0.052178,-0.038003,-0.002272,0.01498,-0.017043,-0.01752,0.085133,-0.014214,-0.026713,0.052898,-0.003127,-0.038636,-0.025355,-0.006227,0.004296,0.1214


## Train iterative model using `fit`

#### Train model

In [16]:
# Same regressor, trained through the higher-level `fit` entry point with the
# iterative method selected by name. The return value is not captured here;
# the following cell reads the result from `re.utility_matrix_preds`.
re.fit(rs_model1, method='iterative')

#### Prediction

In [17]:
# Predictions exposed on the model object after `fit`.
# NOTE(review): these values are close to, but not identical to, those shown
# for the direct `train_model_iterative_cluster` call above, despite the same
# regressor and seed -- worth confirming where the two code paths differ.
re.utility_matrix_preds

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
0,-0.049105,-0.020916,0.05604,-0.038562,0.027767,0.069381,0.015377,0.063853,-0.056062,0.001142,0.033964,-0.015928,-0.062263,-0.021516,0.048859,-0.024168,0.022215,-0.078672,0.028942,0.010016
1,-0.008513,-0.003207,0.014235,-0.004705,-0.022224,0.016366,0.027618,0.013193,-0.069432,0.020052,-0.011072,0.035971,0.017404,-0.009562,0.135346,-0.01595,-0.071292,0.03473,-0.03604,-0.066478
2,-0.049866,0.00465,-0.045232,0.015537,0.020175,0.007415,-0.018152,0.06943,-0.079609,0.008455,-0.057624,-0.002889,0.054578,0.005979,0.04658,0.054978,-0.028593,-0.00145,-0.006199,0.021468
3,0.026739,-0.025368,0.05431,-0.056257,0.045048,-0.070622,0.009612,-0.006472,-0.019648,0.017272,-0.008279,-0.012251,-0.000887,-0.007091,0.002625,0.018397,-0.009745,-0.058289,-0.008653,0.104934
4,-0.026869,-0.022401,0.019945,-0.028203,-0.030291,-0.037597,0.151414,-0.104391,0.011025,0.018639,0.02314,0.028802,-0.008212,0.011735,-0.002689,-0.033195,0.072892,0.024371,0.009592,-0.049061
5,0.007887,0.031554,0.003702,-0.011881,-0.009248,-0.036155,-0.020314,0.175999,0.039851,0.006602,-0.08184,-0.020273,0.018381,-0.010763,-0.051397,0.002013,0.02335,0.011509,0.008934,-0.099627
6,-0.031917,0.034755,-0.019988,-0.00696,0.024309,-0.052428,-0.034069,0.018838,0.034583,0.005801,0.006499,0.004301,0.019838,0.001013,-0.019262,-0.001399,0.045414,0.008762,-0.006607,-0.040761
7,-0.083518,-0.012468,0.089151,0.024368,0.048741,0.018769,0.011622,0.016399,-0.123056,0.032265,-0.030282,-0.032097,0.021361,-0.010613,-0.018579,0.052094,-0.010016,-0.043031,0.036441,0.034407
8,0.069397,-0.007754,0.11094,-0.002517,-0.017619,-0.085774,0.161858,-0.091702,-0.054989,-0.021003,-0.099643,-0.022135,0.026163,0.053906,-0.03303,-0.036014,0.088638,0.018683,-0.046843,0.012965
9,-0.064271,-0.022337,0.064464,-0.015816,-0.063792,-0.041631,-0.01369,0.019542,-0.009724,-0.014233,0.104645,-0.02297,-0.030836,0.056294,-0.006586,-0.045533,-0.035397,-0.006394,0.00668,0.119648
