# Demo: ML-based Collaborative Filtering on Utility Matrix with Reduced Dimensions Applied to Clustered Data

In [1]:
import pandas as pd
import numpy as np
import sys
# Put the sibling ../resype directory on the import path so that modules
# located there can be imported by later cells.
sys.path.insert(1, '../resype')
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2 

## Prepare data

In [2]:
# Seed NumPy's global RNG so the synthetic ratings drawn later are repeatable.
np.random.seed(202109)
# Candidate values for a single rating: the integers 1..5 plus NaN, so that
# some (user, item) pairs end up without a rating. Appending NaN promotes the
# array to float.
rating_vals = np.append(np.arange(1, 6), np.nan)
rating_vals

array([ 1.,  2.,  3.,  4.,  5., nan])

In [3]:
# Synthetic population: 1000 users and 1000 items with consecutive integer ids.
userids = np.arange(1000)
itemids = np.arange(1000)
# One independent draw from rating_vals (NaN included) per (user, item) pair.
random_ratings = np.random.choice(
    rating_vals, size=userids.size * itemids.size)

In [4]:
# Long-format ratings table with one row per (user, item) pair, ordered
# user-major: each user id is repeated once per item, while the full item id
# sequence is tiled once per user.
# drop_duplicates() is kept as in the original; every (user_id, item_id) pair
# occurs exactly once here, so it does not remove any rows.
transactions = pd.DataFrame(
    {'user_id': np.repeat(userids, len(itemids)),
     'item_id': np.tile(itemids, len(userids)),
     'rating': random_ratings}).drop_duplicates()

In [5]:
# Preview the generated ratings table; blank rating cells are the NaN draws.
transactions

Unnamed: 0,user_id,item_id,rating
0,0,0,2.0
1,0,1,
2,0,2,
3,0,3,5.0
4,0,4,4.0
...,...,...,...
999995,999,995,1.0
999996,999,996,3.0
999997,999,997,
999998,999,998,2.0


## Load resype

In [6]:
from collab_filtering import CollabFilteringModel

In [7]:
# Build the collaborative-filtering model from the long-format ratings table.
# NOTE(review): the variable name `re` is also the name of Python's
# standard-library regex module; it would shadow that module if it were ever
# imported in this notebook.
re = CollabFilteringModel(transactions)

In [8]:
# Build the utility matrix from the transactions. As the output shows, rows
# are indexed by user_id, columns by item_id, and unrated pairs are empty.
utility_matrix = re.construct_utility_matrix()
utility_matrix

item_id,0,1,2,3,4,5,6,7,8,9,...,990,991,992,993,994,995,996,997,998,999
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,2.0,,,5.0,4.0,4.0,3.0,4.0,4.0,3.0,...,5.0,1.0,4.0,2.0,3.0,,2.0,,5.0,2.0
1,1.0,,4.0,5.0,3.0,2.0,1.0,3.0,1.0,,...,3.0,3.0,2.0,4.0,4.0,3.0,4.0,4.0,3.0,4.0
2,3.0,4.0,4.0,4.0,2.0,4.0,2.0,4.0,1.0,4.0,...,5.0,4.0,3.0,1.0,,5.0,2.0,2.0,,5.0
3,5.0,2.0,1.0,,2.0,4.0,3.0,3.0,,1.0,...,2.0,,2.0,3.0,5.0,2.0,,5.0,,1.0
4,2.0,1.0,3.0,1.0,2.0,2.0,3.0,1.0,3.0,5.0,...,,3.0,1.0,4.0,4.0,1.0,2.0,1.0,2.0,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,5.0,,,2.0,1.0,5.0,1.0,2.0,2.0,3.0,...,,1.0,5.0,1.0,5.0,2.0,2.0,,3.0,5.0
996,1.0,5.0,5.0,2.0,5.0,4.0,1.0,,1.0,5.0,...,,,1.0,,,,4.0,5.0,1.0,2.0
997,2.0,5.0,,4.0,5.0,4.0,,3.0,,3.0,...,1.0,4.0,1.0,5.0,4.0,3.0,,5.0,2.0,2.0
998,3.0,2.0,2.0,1.0,1.0,4.0,2.0,1.0,5.0,1.0,...,3.0,4.0,4.0,5.0,5.0,,4.0,5.0,4.0,4.0


## Cluster data 

In [9]:
from sklearn.cluster import (KMeans, SpectralClustering,
                             AgglomerativeClustering, DBSCAN, OPTICS,
                             cluster_optics_dbscan, Birch)

# Only KMeans is used in this cell. model1 (10 clusters) is passed to
# re.cluster_users and model2 (15 clusters) to re.cluster_items in the cell
# that follows.
# NOTE(review): no random_state is passed here, unlike the
# RandomForestRegressor used for training. The clustering presumably falls
# back to NumPy's global RNG (seeded at the top of the notebook), so the
# clusters can change if cells are run out of order. Consider passing
# random_state explicitly and regenerating the outputs.
model1 = KMeans(n_clusters = 10)
model2 = KMeans(n_clusters = 15)

In [10]:
# Cluster the users with model1 and the items with model2.
# Each call returns three values that are unpacked here.
# NOTE(review): what x_*, y_* and df_* contain is not shown in this notebook;
# check the resype implementation before relying on them.
x_u,y_u, df_u  = re.cluster_users(model1)
x_i,y_i, df_i  = re.cluster_items(model2)

## Generate new utility matrix based on clusters

In [11]:
# Aggregate the utility matrix to (user cluster, item cluster) level.
# Side effect: running this overwrites the original utility matrix held in
# `re.utility_matrix` with the cluster-level matrix returned here.
# NOTE(review): the aggregated values sit around 2.5, although the non-missing
# ratings are drawn uniformly from 1..5 (mean 3). That would be consistent with
# missing ratings being counted as 0 in the aggregation; confirm in resype.
Uc_df = re.utility_matrix_agg()
Uc_df

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.534848,2.536878,2.490112,2.494817,2.556541,2.468984,2.51384,2.549661,2.511124,2.568413,2.541558,2.513287,2.42803,2.486118,2.481016
1,2.464226,2.526253,2.514443,2.511989,2.521064,2.46628,2.517262,2.501281,2.495215,2.493323,2.493987,2.483294,2.499158,2.533306,2.501188
2,2.493827,2.471349,2.505523,2.498962,2.461156,2.503813,2.510227,2.489359,2.495005,2.488858,2.464727,2.515907,2.290895,2.507758,2.529366
3,2.520833,2.576782,2.494152,2.512721,2.441057,2.464869,2.538143,2.46476,2.531067,2.504237,2.406085,2.49359,2.446759,2.537913,2.460103
4,2.544702,2.479883,2.462414,2.54447,2.515749,2.528438,2.478007,2.485322,2.512199,2.468627,2.483759,2.554933,2.624724,2.495794,2.493572
5,2.495726,2.506047,2.541835,2.497364,2.448405,2.477376,2.452928,2.445082,2.493927,2.456758,2.462759,2.458251,2.563034,2.483888,2.523756
6,2.504902,2.521365,2.54386,2.577653,2.493185,2.444204,2.521949,2.420983,2.449884,2.555583,2.467787,2.540724,2.512255,2.507154,2.536044
7,2.525808,2.508297,2.52333,2.481209,2.518249,2.519608,2.475918,2.517307,2.459313,2.50006,2.457278,2.493362,2.547872,2.509201,2.494333
8,2.516247,2.570844,2.541212,2.549286,2.526001,2.493896,2.377077,2.573641,2.443893,2.465942,2.479784,2.473149,2.66195,2.509179,2.496115
9,2.48538,2.461106,2.476147,2.44745,2.575524,2.495872,2.609322,2.510081,2.458218,2.389236,2.3868,2.526316,2.489766,2.478189,2.56063


## Train iterative model using `train_model_svd_cluster`

#### Create the model object (imported from scikit-learn)

In [12]:
from sklearn.ensemble import RandomForestRegressor
# Regressor handed to the resype training calls in the following cells;
# random_state is fixed so repeated runs build the same forest.
rs_model1 = RandomForestRegressor(random_state=202109)

#### Train model

In [13]:
# Training input: after the aggregation step, the model's utility matrix is
# the u_cluster x i_cluster matrix (same values as Uc_df).
re.utility_matrix

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.534848,2.536878,2.490112,2.494817,2.556541,2.468984,2.51384,2.549661,2.511124,2.568413,2.541558,2.513287,2.42803,2.486118,2.481016
1,2.464226,2.526253,2.514443,2.511989,2.521064,2.46628,2.517262,2.501281,2.495215,2.493323,2.493987,2.483294,2.499158,2.533306,2.501188
2,2.493827,2.471349,2.505523,2.498962,2.461156,2.503813,2.510227,2.489359,2.495005,2.488858,2.464727,2.515907,2.290895,2.507758,2.529366
3,2.520833,2.576782,2.494152,2.512721,2.441057,2.464869,2.538143,2.46476,2.531067,2.504237,2.406085,2.49359,2.446759,2.537913,2.460103
4,2.544702,2.479883,2.462414,2.54447,2.515749,2.528438,2.478007,2.485322,2.512199,2.468627,2.483759,2.554933,2.624724,2.495794,2.493572
5,2.495726,2.506047,2.541835,2.497364,2.448405,2.477376,2.452928,2.445082,2.493927,2.456758,2.462759,2.458251,2.563034,2.483888,2.523756
6,2.504902,2.521365,2.54386,2.577653,2.493185,2.444204,2.521949,2.420983,2.449884,2.555583,2.467787,2.540724,2.512255,2.507154,2.536044
7,2.525808,2.508297,2.52333,2.481209,2.518249,2.519608,2.475918,2.517307,2.459313,2.50006,2.457278,2.493362,2.547872,2.509201,2.494333
8,2.516247,2.570844,2.541212,2.549286,2.526001,2.493896,2.377077,2.573641,2.443893,2.465942,2.479784,2.473149,2.66195,2.509179,2.496115
9,2.48538,2.461106,2.476147,2.44745,2.575524,2.495872,2.609322,2.510081,2.458218,2.389236,2.3868,2.526316,2.489766,2.478189,2.56063


In [14]:
%%time
# Train on the cluster-level utility matrix using rs_model1 and keep the
# returned matrix for inspection. The recorded run took about 2 min 43 s.
# d=10: presumably the number of SVD dimensions retained; confirm in resype.
utility_matrix_imputed = re.train_model_svd_cluster(
    re.utility_matrix, rs_model1, d=10)

CPU times: user 2min 42s, sys: 1.06 s, total: 2min 43s
Wall time: 2min 43s


#### Prediction

In [15]:
# Matrix returned by train_model_svd_cluster; it has the same
# u_cluster x i_cluster layout as the training input.
utility_matrix_imputed

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.527135,2.530404,2.494497,2.500898,2.545132,2.47473,2.512684,2.542626,2.505138,2.555773,2.513247,2.511958,2.441204,2.490325,2.485559
1,2.47148,2.524512,2.513555,2.511858,2.515284,2.470282,2.515367,2.501548,2.490921,2.493068,2.486219,2.489422,2.497437,2.523104,2.503156
2,2.495332,2.481868,2.50629,2.500949,2.471254,2.499296,2.509652,2.491319,2.491049,2.48692,2.462862,2.51414,2.344308,2.506885,2.525012
3,2.517452,2.56415,2.496791,2.510972,2.456073,2.470756,2.532615,2.473109,2.521366,2.500002,2.417311,2.496404,2.455558,2.530546,2.470624
4,2.534622,2.486986,2.470068,2.537159,2.512227,2.523658,2.480095,2.488967,2.503454,2.472577,2.480827,2.540861,2.589293,2.498252,2.496154
5,2.4967,2.507373,2.534156,2.499104,2.460571,2.480676,2.465947,2.45408,2.491818,2.46337,2.461438,2.472282,2.543399,2.488401,2.521855
6,2.504087,2.520935,2.539329,2.561391,2.496364,2.454589,2.519811,2.446874,2.454908,2.534621,2.466017,2.532129,2.510269,2.506554,2.529583
7,2.520606,2.510202,2.519464,2.488109,2.514124,2.513474,2.481534,2.513205,2.463679,2.495893,2.459923,2.495149,2.536652,2.508262,2.495236
8,2.512952,2.560143,2.535756,2.537782,2.523454,2.492725,2.407321,2.560209,2.451281,2.4709,2.475169,2.480005,2.616424,2.50783,2.499289
9,2.488932,2.473964,2.483276,2.460616,2.556494,2.49436,2.585699,2.506738,2.46431,2.413821,2.405149,2.521331,2.493886,2.483878,2.550244


## Train iterative model using `fit`

#### Train model

In [16]:
# Same estimator instance and settings (svd, d=10) as the direct
# train_model_svd_cluster call, but through the generic `fit` entry point.
# return_models=False: presumably the fitted models are not returned; confirm
# in resype. The result is read from re.utility_matrix_preds afterwards.
re.fit(rs_model1, method='svd', d=10, return_models=False)

#### Prediction

In [17]:
# Predicted cluster-level utility matrix stored on the model object
# (presumably populated by the preceding re.fit call).
re.utility_matrix_preds

i_cluster,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
u_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0,2.534848,2.536878,2.490112,2.494817,2.556541,2.470855,2.509634,2.549661,2.502158,2.554855,2.541558,2.509401,2.457254,2.492716,2.481016
1,2.464226,2.524068,2.514721,2.515714,2.51721,2.470587,2.517262,2.499543,2.49425,2.491623,2.487342,2.483294,2.499158,2.533306,2.500489
2,2.493827,2.495688,2.505523,2.498962,2.466994,2.500675,2.510227,2.490792,2.495005,2.491005,2.464727,2.515907,2.334135,2.507758,2.529366
3,2.515038,2.561679,2.494152,2.512721,2.441057,2.464869,2.538143,2.470961,2.531067,2.504237,2.406085,2.49359,2.446759,2.537913,2.460103
4,2.536135,2.494503,2.462414,2.54447,2.515749,2.528438,2.479796,2.500635,2.509105,2.468627,2.482434,2.554933,2.598433,2.495794,2.493572
5,2.499266,2.507535,2.541835,2.502915,2.448405,2.477376,2.452928,2.445082,2.493927,2.460965,2.458942,2.458251,2.563034,2.490432,2.523821
6,2.504902,2.521365,2.54386,2.577653,2.493185,2.444204,2.521949,2.420983,2.449884,2.542458,2.469074,2.525754,2.512255,2.509503,2.521444
7,2.519766,2.516884,2.52333,2.490627,2.514766,2.519608,2.479377,2.517307,2.459313,2.495075,2.457278,2.493362,2.547872,2.51183,2.494333
8,2.516247,2.561399,2.541212,2.549286,2.526001,2.489591,2.404344,2.573641,2.450559,2.465942,2.477215,2.473149,2.634606,2.50987,2.496115
9,2.48538,2.487639,2.476147,2.477527,2.575524,2.495872,2.588401,2.510081,2.458218,2.389236,2.3868,2.526316,2.489766,2.486706,2.56063
