# Demo: ML-based Collaborative Filtering on Utility Matrix with Reduced Dimensions Applied to Clustered Data

In [1]:
import pandas as pd
import numpy as np
import sys
# Make the local `resype` package importable. The relative path is resolved
# against the kernel's current working directory, so this only works when the
# kernel runs from the directory that contains this notebook.
sys.path.insert(1, '../resype')
# Automatically reload edited modules (such as resype) before each cell runs.
%load_ext autoreload
%autoreload 2 

## Prepare synthetic data

In [2]:
np.random.seed(202109)
# Candidate ratings: the integers 1-5 plus NaN. Drawing NaN leaves that
# (user, item) pair without a rating.
rating_vals = np.array([1, 2, 3, 4, 5, np.nan])
rating_vals

array([ 1.,  2.,  3.,  4.,  5., nan])

In [3]:
N_USERS = 1000
N_ITEMS = 1000
userids = np.arange(N_USERS)
itemids = np.arange(N_ITEMS)
# One random draw per (user, item) pair; a NaN draw means "not rated".
random_ratings = np.random.choice(rating_vals, size=userids.size * itemids.size)

In [4]:
transactions = pd.DataFrame(
    {
        # User-major row order: every item for user 0, then every item for
        # user 1, and so on, so each (user_id, item_id) pair appears once.
        'user_id': np.repeat(userids, len(itemids)),
        'item_id': np.tile(itemids, len(userids)),
        'rating': random_ratings,
    }
).drop_duplicates()

In [5]:
# Long-format ratings table with one row per (user_id, item_id) pair.
# Blank ratings are NaN draws, i.e. pairs left unrated.
transactions

Unnamed: 0,user_id,item_id,rating
0,0,0,2.0
1,0,1,
2,0,2,
3,0,3,5.0
4,0,4,4.0
...,...,...,...
999995,999,995,1.0
999996,999,996,3.0
999997,999,997,
999998,999,998,2.0


## Load resype

In [6]:
from resype import Resype

In [7]:
# Build the Resype object from the transactions table.
# NOTE(review): the name `re` shadows Python's standard-library `re` module.
# A name like `rs` would avoid confusion, but every later cell that uses `re`
# would have to be renamed at the same time.
re = Resype(transactions)

In [8]:
# Build the user x item utility matrix. As the output shows, the index is
# user_id, the columns are item_id, and unrated pairs are blank (NaN).
utility_matrix = re.construct_utility_matrix()
utility_matrix

item_id,0,1,2,3,4,5,6,7,8,9,...,990,991,992,993,994,995,996,997,998,999
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,2.0,,,5.0,4.0,4.0,3.0,4.0,4.0,3.0,...,5.0,1.0,4.0,2.0,3.0,,2.0,,5.0,2.0
1,1.0,,4.0,5.0,3.0,2.0,1.0,3.0,1.0,,...,3.0,3.0,2.0,4.0,4.0,3.0,4.0,4.0,3.0,4.0
2,3.0,4.0,4.0,4.0,2.0,4.0,2.0,4.0,1.0,4.0,...,5.0,4.0,3.0,1.0,,5.0,2.0,2.0,,5.0
3,5.0,2.0,1.0,,2.0,4.0,3.0,3.0,,1.0,...,2.0,,2.0,3.0,5.0,2.0,,5.0,,1.0
4,2.0,1.0,3.0,1.0,2.0,2.0,3.0,1.0,3.0,5.0,...,,3.0,1.0,4.0,4.0,1.0,2.0,1.0,2.0,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,5.0,,,2.0,1.0,5.0,1.0,2.0,2.0,3.0,...,,1.0,5.0,1.0,5.0,2.0,2.0,,3.0,5.0
996,1.0,5.0,5.0,2.0,5.0,4.0,1.0,,1.0,5.0,...,,,1.0,,,,4.0,5.0,1.0,2.0
997,2.0,5.0,,4.0,5.0,4.0,,3.0,,3.0,...,1.0,4.0,1.0,5.0,4.0,3.0,,5.0,2.0,2.0
998,3.0,2.0,2.0,1.0,1.0,4.0,2.0,1.0,5.0,1.0,...,3.0,4.0,4.0,5.0,5.0,,4.0,5.0,4.0,4.0


## Cluster users and items

In [9]:
from sklearn.cluster import (KMeans, SpectralClustering,
                             AgglomerativeClustering, DBSCAN, OPTICS,
                             cluster_optics_dbscan, Birch)

# K-means models: model1 clusters users (10 clusters) and model2 clusters
# items (15 clusters).
# Without random_state, KMeans draws from numpy's global RNG, so the clusters
# change whenever cells are re-run or run out of order. A fixed random_state
# makes the clustering reproducible on its own.
# n_init is set explicitly because its default changed from 10 to 'auto' in
# scikit-learn 1.4. Pinning it keeps the same number of restarts on every version.
model1 = KMeans(n_clusters=10, n_init=10, random_state=202109)
model2 = KMeans(n_clusters=15, n_init=10, random_state=202109)

In [10]:
# Each call returns a 3-tuple. The _u / _i suffixes keep the user-side and
# item-side results apart. Users are clustered first, as in the original.
x_u, y_u, df_u = re.cluster_users(model1)
x_i, y_i, df_i = re.cluster_items(model2)

## Aggregate the utility matrix by user and item clusters

In [11]:
# Aggregate the utility matrix to one row per user cluster and one column per
# item cluster (u_cluster x i_cluster in the output).
# Running this overwrites the original utility matrix stored on `re`.
# NOTE(review): the ratings were drawn uniformly from 1-5 (plus NaN), so the
# mean of the non-missing ratings is about 3.0. Yet every cell here is about
# 2.5, which is the mean of {0, 1, 2, 3, 4, 5}. That suggests missing ratings
# are counted as 0 somewhere before or during aggregation. Confirm in resype.
Uc_df = re.utility_matrix_agg()
Uc_df

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.454815,2.509972,2.45679,2.464988,2.527778,2.555556,2.657239,2.391714,2.55211,2.533981,2.670139,2.528024,2.520129,2.505612,2.452222
1,2.472771,2.563485,2.458278,2.470162,2.46988,2.48494,2.533625,2.517868,2.47016,2.508714,2.546875,2.482567,2.475118,2.524827,2.55241
2,2.535014,2.498061,2.358543,2.48989,2.520435,2.485294,2.5233,2.542088,2.557944,2.478176,2.488445,2.521232,2.547193,2.483066,2.521765
3,2.543915,2.577534,2.542622,2.47433,2.480159,2.476757,2.492352,2.500673,2.538206,2.480814,2.512153,2.498244,2.492063,2.515873,2.527778
4,2.48338,2.573402,2.41758,2.522557,2.535531,2.533702,2.445839,2.519694,2.532918,2.481061,2.545445,2.52362,2.482445,2.494131,2.489789
5,2.466241,2.5,2.449173,2.543218,2.400629,2.483663,2.560348,2.480707,2.445571,2.45879,2.516124,2.508944,2.47086,2.567537,2.47266
6,2.537361,2.521234,2.626929,2.514404,2.494555,2.506324,2.468371,2.50459,2.463905,2.4732,2.55013,2.486173,2.502566,2.476484,2.487083
7,2.5369,2.495885,2.495199,2.500354,2.544426,2.470018,2.465918,2.504708,2.507704,2.54317,2.479102,2.48465,2.499911,2.461279,2.487613
8,2.459506,2.510684,2.47668,2.492332,2.492424,2.494048,2.556566,2.44005,2.562446,2.520316,2.499421,2.520977,2.535158,2.519921,2.522222
9,2.491814,2.52629,2.441163,2.492484,2.4229,2.496835,2.603452,2.502896,2.514572,2.500061,2.480024,2.513946,2.48138,2.527043,2.545316


## Train iterative model using `train_model_svd_cluster`

#### Create the regressor (from scikit-learn)

In [12]:
from sklearn.ensemble import RandomForestRegressor
# Regressor passed to both training routines below. A fixed random_state makes
# the forest reproducible.
rs_model1 = RandomForestRegressor(random_state=202109)

#### Train model

In [13]:
# After utility_matrix_agg(), re.utility_matrix holds the cluster-aggregated
# matrix (identical to Uc_df), not the original user x item matrix.
re.utility_matrix

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.454815,2.509972,2.45679,2.464988,2.527778,2.555556,2.657239,2.391714,2.55211,2.533981,2.670139,2.528024,2.520129,2.505612,2.452222
1,2.472771,2.563485,2.458278,2.470162,2.46988,2.48494,2.533625,2.517868,2.47016,2.508714,2.546875,2.482567,2.475118,2.524827,2.55241
2,2.535014,2.498061,2.358543,2.48989,2.520435,2.485294,2.5233,2.542088,2.557944,2.478176,2.488445,2.521232,2.547193,2.483066,2.521765
3,2.543915,2.577534,2.542622,2.47433,2.480159,2.476757,2.492352,2.500673,2.538206,2.480814,2.512153,2.498244,2.492063,2.515873,2.527778
4,2.48338,2.573402,2.41758,2.522557,2.535531,2.533702,2.445839,2.519694,2.532918,2.481061,2.545445,2.52362,2.482445,2.494131,2.489789
5,2.466241,2.5,2.449173,2.543218,2.400629,2.483663,2.560348,2.480707,2.445571,2.45879,2.516124,2.508944,2.47086,2.567537,2.47266
6,2.537361,2.521234,2.626929,2.514404,2.494555,2.506324,2.468371,2.50459,2.463905,2.4732,2.55013,2.486173,2.502566,2.476484,2.487083
7,2.5369,2.495885,2.495199,2.500354,2.544426,2.470018,2.465918,2.504708,2.507704,2.54317,2.479102,2.48465,2.499911,2.461279,2.487613
8,2.459506,2.510684,2.47668,2.492332,2.492424,2.494048,2.556566,2.44005,2.562446,2.520316,2.499421,2.520977,2.535158,2.519921,2.522222
9,2.491814,2.52629,2.441163,2.492484,2.4229,2.496835,2.603452,2.502896,2.514572,2.500061,2.480024,2.513946,2.48138,2.527043,2.545316


In [14]:
%%time
# Train rs_model1 on the cluster-level utility matrix. The result, one value
# per (user cluster, item cluster) cell, is kept as utility_matrix_imputed.
# NOTE(review): if `d` is the number of SVD components kept, d=10 equals the
# smaller dimension of the 10 x 15 cluster matrix, so no dimensionality
# reduction actually happens. Consider a smaller d, and confirm what `d`
# means in resype.
utility_matrix_imputed = re.train_model_svd_cluster(
    re.utility_matrix, rs_model1, d=10)

CPU times: user 4min 2s, sys: 23.3 s, total: 4min 25s
Wall time: 2min 18s


#### Prediction

In [15]:
# Predicted value for every (user cluster, item cluster) cell.
utility_matrix_imputed

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.466272,2.513166,2.459775,2.474797,2.517915,2.541614,2.62595,2.404765,2.543694,2.525905,2.622318,2.52381,2.516525,2.505338,2.462357
1,2.478062,2.558491,2.464877,2.478315,2.472403,2.487963,2.534694,2.510531,2.482757,2.505758,2.541863,2.489099,2.480062,2.520197,2.540449
2,2.529058,2.505441,2.379741,2.491896,2.515267,2.488869,2.524186,2.53265,2.544379,2.483282,2.496231,2.51874,2.537056,2.487435,2.517189
3,2.535307,2.568619,2.528143,2.477714,2.481867,2.483589,2.498421,2.499637,2.531398,2.483808,2.515864,2.499489,2.493944,2.513694,2.52513
4,2.486414,2.564253,2.427597,2.516909,2.522703,2.52607,2.466264,2.514719,2.526553,2.484249,2.541904,2.517948,2.487961,2.497894,2.493111
5,2.4729,2.504416,2.45618,2.534716,2.419165,2.486917,2.555516,2.48218,2.457763,2.469241,2.517845,2.508839,2.479368,2.5563,2.478499
6,2.530166,2.521985,2.599523,2.510559,2.494761,2.503811,2.47931,2.499716,2.471584,2.481983,2.545528,2.49088,2.502183,2.482279,2.489302
7,2.52846,2.502133,2.492569,2.499804,2.531517,2.475873,2.479517,2.501107,2.508983,2.528828,2.488871,2.490359,2.500488,2.4731,2.490236
8,2.467683,2.514963,2.475501,2.493566,2.490198,2.495348,2.552074,2.451479,2.549598,2.514788,2.505543,2.518837,2.525264,2.517475,2.518055
9,2.492034,2.527224,2.449305,2.492726,2.438136,2.497574,2.590144,2.499573,2.514165,2.499667,2.491165,2.51221,2.487065,2.522942,2.538749


## Train iterative model using `fit`

#### Train model

In [16]:
# Same regressor and d as the train_model_svd_cluster cell, but called through
# the higher-level fit() API. The next cell reads the resulting predictions
# from re.utility_matrix_preds.
re.fit(rs_model1, method='svd', d=10, return_models=False)

#### Prediction

In [17]:
# Predictions produced by fit(). They differ from utility_matrix_imputed even
# though the same model and d were used. Some cells also equal the aggregated
# input exactly (e.g. user cluster 0, item cluster 1), so fit() apparently
# keeps some input values. Confirm in resype.
re.utility_matrix_preds

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.465671,2.509972,2.45679,2.475876,2.519187,2.555556,2.627252,2.391714,2.55211,2.533981,2.670139,2.517909,2.520129,2.505612,2.46491
1,2.472771,2.554303,2.463032,2.470162,2.46988,2.487777,2.543422,2.517868,2.47016,2.508714,2.537948,2.482567,2.478335,2.522495,2.541438
2,2.535014,2.503709,2.358543,2.48989,2.520435,2.485294,2.522476,2.527101,2.557944,2.486198,2.502189,2.518047,2.537405,2.492063,2.521765
3,2.543915,2.577534,2.537896,2.475875,2.482764,2.483126,2.492352,2.500673,2.53266,2.488145,2.512153,2.498244,2.492063,2.515873,2.524786
4,2.491203,2.573402,2.41758,2.522557,2.535531,2.527069,2.445839,2.519694,2.527769,2.486712,2.545445,2.52362,2.482445,2.498762,2.494333
5,2.472693,2.508708,2.455878,2.543218,2.400629,2.487341,2.560348,2.482107,2.445571,2.45879,2.523053,2.508944,2.47086,2.555348,2.47266
6,2.533839,2.521234,2.592658,2.508309,2.494555,2.506324,2.482867,2.50459,2.463905,2.480569,2.547765,2.490496,2.502495,2.483248,2.487083
7,2.5369,2.495885,2.495199,2.500354,2.532125,2.470018,2.465918,2.504708,2.507704,2.54317,2.499179,2.48465,2.499911,2.461279,2.490387
8,2.459506,2.510684,2.47668,2.492332,2.490115,2.494048,2.556566,2.452792,2.553516,2.515019,2.499421,2.520977,2.528187,2.519921,2.51778
9,2.492417,2.524758,2.441163,2.492484,2.4229,2.495658,2.590032,2.502896,2.514572,2.500061,2.485965,2.513946,2.48138,2.527043,2.545316
