In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import sys
sys.path.append('../')

In [4]:
from src.utils import loading, Spark
import pyspark.ml as M
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.ml.recommendation import ALS, ALSModel
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from tqdm.notebook import tqdm

In [5]:
from scipy import sparse
import numpy as np
import pandas as pd

In [6]:
spark = Spark()

Spark UI address http://127.0.0.1:4040


In [7]:
datas = loading(spark, '../data/interim')

In [8]:
cast_int = lambda df: df.select([F.col(c).cast('int') for c in ['userId', 'movieId']] + \
                                 [F.col('rating').cast('float')])
training, test = cast_int(datas['train_0.75_0.25']), cast_int(datas['test_0.75_0.25'])

In [9]:
class indexTransformer():
    def __init__(self, usercol='userId', itemcol='movieId'):
        self.usercol = usercol
        self.itemcol = itemcol
        self.u_indxer =  M.feature.StringIndexer(inputCol=usercol, 
                                                outputCol=usercol+'_idx', 
                                                handleInvalid = 'skip')
        self.i_indxer = M.feature.StringIndexer(inputCol=itemcol, 
                                                outputCol=itemcol+'_idx', 
                                                handleInvalid = 'skip')
        self.X = None
    def fit(self, X):
        self.X = X
        self.u_indxer = self.u_indxer.fit(self.X)
        self.i_indxer = self.i_indxer.fit(self.X)
        return
    def transform(self, X):
        X_ = self.u_indxer.transform(X)
        X_ = self.i_indxer.transform(X_)
        return self._cast_int(X_).orderBy([self.usercol+'_idx', self.itemcol+'_idx'])
    
    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)
    
    def _cast_int(self, X):
        return X.select([F.col(c).cast('int') for c in X.columns])

In [10]:
idxer = indexTransformer()
training = idxer.fit_transform(training)
test = idxer.transform(test)

In [13]:
X_train = training.select('userId_idx', 'movieId_idx', 'rating').toPandas().values

In [14]:
row = X_train[:, 0]
col = X_train[:, 1]
data = X_train[:, 2]

In [15]:
X_train = sparse.csr_matrix((data, (row, col)))

In [46]:
def pearson_corr(A):
    n = A.shape[1]
    
    rowsum = A.sum(1)
    centering = rowsum.dot(rowsum.T) / n
    C = (A.dot(A.T) - centering) / (n - 1)
    
    d = np.diag(C)
    coeffs = C / np.sqrt(np.outer(d, d))
    return np.array(np.nan_to_num(coeffs))

In [47]:
sim = pearson_corr(X_train)

  coeffs = C / np.sqrt(np.outer(d, d))


In [48]:
sim

array([[ 1.00000000e+00, -2.00924617e-02,  4.46350065e-02, ...,
        -4.53758402e-03, -4.53758402e-03, -4.53758402e-03],
       [-2.00924617e-02,  1.00000000e+00, -1.77460937e-02, ...,
        -4.02180630e-03, -4.02180630e-03, -4.02180630e-03],
       [ 4.46350065e-02, -1.77460937e-02,  1.00000000e+00, ...,
        -4.00769168e-03, -4.00769168e-03, -4.00769168e-03],
       ...,
       [-4.53758402e-03, -4.02180630e-03, -4.00769168e-03, ...,
         1.00000000e+00, -9.08265213e-04, -9.08265213e-04],
       [-4.53758402e-03, -4.02180630e-03, -4.00769168e-03, ...,
        -9.08265213e-04,  1.00000000e+00, -9.08265213e-04],
       [-4.53758402e-03, -4.02180630e-03, -4.00769168e-03, ...,
        -9.08265213e-04, -9.08265213e-04,  1.00000000e+00]])

In [49]:
mu_u = np.array(np.nan_to_num(X_train.sum(1) / (X_train != 0).sum(1))).reshape(-1)

  mu_u = np.array(np.nan_to_num(X_train.sum(1) / (X_train != 0).sum(1))).reshape(-1)


In [63]:
items = sparse.find(X_train[0, :])[1]

In [70]:
items

array([   3,   21,   34,   90,  157,  167,  185,  194,  204,  236,  254,
        304,  314,  330,  378,  494,  524,  566,  608,  612,  794,  830,
        951,  968, 1067], dtype=int32)

In [71]:
users = sparse.find(X_train[:, 3])[0]
rating = sparse.find(X_train[:, 3])[2]

In [72]:
users

array([    0,     6,    32,   214,   299,   301,   366,   419,   521,
         583,   605,   658,   659,   660,   734,   756,   866,   921,
        1016,  1102,  1133,  1363,  1394,  1416,  1457,  1603,  1767,
        1777,  1851,  1985,  2074,  2179,  2305,  2434,  2657,  2706,
        2729,  2740,  2752,  2770,  2888,  2904,  3001,  3045,  3048,
        3081,  3116,  3171,  3176,  3238,  3331,  3354,  3366,  3497,
        3587,  3667,  3691,  3714,  3930,  3938,  4105,  4106,  4116,
        4133,  4142,  4158,  4232,  4240,  4340,  4368,  4419,  4437,
        4444,  4461,  4542,  4553,  4586,  4604,  4683,  4712,  4717,
        4727,  4773,  4787,  4841,  4922,  4923,  4967,  5008,  5057,
        5219,  5253,  5291,  5292,  5308,  5563,  5800,  6014,  6025,
        6095,  6150,  6238,  6285,  6291,  6388,  6395,  6430,  6867,
        7064,  7076,  7229,  7239,  7357,  7436,  7440,  7448,  7697,
        7704,  7853,  7913,  8013,  8100,  8286,  8346,  8404,  8491,
        8527,  8534,

In [74]:
rating

array([4, 5, 4, 5, 5, 3, 4, 3, 3, 3, 4, 3, 5, 5, 4, 5, 5, 4, 4, 5, 4, 5,
       5, 4, 4, 4, 5, 2, 5, 3, 5, 4, 4, 4, 5, 5, 5, 4, 4, 5, 3, 4, 5, 5,
       5, 4, 4, 5, 3, 4, 4, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 3, 5, 4, 4, 3,
       3, 4, 4, 5, 5, 5, 3, 4, 4, 4, 2, 3, 5, 4, 5, 3, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 5, 5, 4, 5, 3, 3, 5, 4, 4, 4, 5, 4, 5, 4, 5, 4, 4, 4,
       3, 4, 5, 5, 4, 4, 5, 4, 4, 4, 4, 5, 3, 4, 5, 5, 4, 4, 4, 4, 4, 4,
       5, 5, 5, 5, 4, 5, 2, 5, 4, 3, 5, 5, 4, 4, 5, 4, 4, 5, 2, 5, 3, 5,
       4, 4, 3, 5, 4, 5, 5, 4, 4, 4, 4, 4, 5, 3, 5, 2, 3, 4, 4, 4, 5, 5,
       3, 4, 4, 5, 4, 4, 3, 4, 4, 4, 2, 5, 4, 5, 4, 4, 4, 5, 4, 4, 4, 4,
       5, 5, 2, 5, 4, 2, 4, 5, 5, 4, 5, 4, 5, 4, 5, 4, 4, 3, 4, 4, 4, 5,
       3, 4, 4, 4, 3, 5, 4, 4, 4, 4, 5, 5, 2, 5, 5, 5, 4, 5, 5, 5, 5, 4,
       5], dtype=int32)

In [56]:
np.linalg.norm(sim[0, users], ord = 1)

273.35453756686104

In [79]:
mu_u = np.array(np.nan_to_num(X_train.sum(1) / (X_train != 0).sum(1))).reshape(-1)

pred = np.zeros(X_train.shape)
for i in tqdm(range(X_train.shape[0])):
    items = sparse.find(X_train[i, :])[1]
    for j in items:
        users = sparse.find(X_train[:, j])[0]
        rating = sparse.find(X_train[:, j])[2]
        val = mu_u[i] + np.array(sim[i, users]).reshape(-1).dot((rating - mu_u[users]))/\
                              np.linalg.norm(sim[i, users], ord = 1)
        print(val)
        pred[i,j] = val

  mu_u = np.array(np.nan_to_num(X_train.sum(1) / (X_train != 0).sum(1))).reshape(-1)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=23503.0), HTML(value='')))

4.182532287197313
4.0706118508898275
4.027610275032192
3.6464846136135938
3.5216433801542006
3.9508751458869944
4.010624294317435
3.864521093211673
3.686605480340293
3.8359357396636864
3.891335664689529
3.2380540554418795
4.03586638317891
3.781365817711667
3.799462401572413
3.5430927514183437
3.4561463367190846
3.1665205615683774
3.7846076680518244
3.8411607787748507
3.7774202943418436
3.175917433439548
3.6577724596519063
3.2956752977907935
3.774229947559386
4.016640989034514
3.5173970103955683
3.98475071614801
3.1440083100012224
3.6795960027001016
3.8760309508415896
3.107086360919912
3.80484465957035
3.489814936002982
3.3995682805937895
3.7188903952306442
3.545841877809717
3.7104374648100142
3.551437890411867
3.726521424981569
3.7360013795252742
3.830842096059552
3.266581063998689
3.779624225010081
1.7568703291232775
3.5306189332095843
3.1943206120746535
3.081867896160562
2.965207219207899
2.862512012042367
2.9084141723917294
3.3284374821377387
2.5092128337766
3.263146135269038
2.7812

3.6301070362852577
3.395912696819738
3.6709549820433414
3.0532421629784983
3.162552249135048
3.1800353997285287
3.504097774776242
2.9590078097082766
2.684893675041404
2.656110004977128
3.2837759417802905
3.0669037684224123
3.328740966135205
3.7352460108037873
3.6708608206945446
3.7039744939704966
3.4136342732648783
3.227139325624887
3.4469822177678897
3.3914115825526534
3.6375371127618736
3.558546633914511
3.061252892321371
3.4076701822395687
3.2696834927349583
2.980074295929772
3.533454096991138
4.144661111794263
4.094107528731259
3.8855345673631874
3.9228561631689174
3.522391691474856
3.8565893471225494
3.7965161933108758
3.8274764328711406
3.749389787585858
3.9323874444714915
3.4484916813981026
3.8591262537647455
4.002116782816743
3.936253655360852
3.665293994093483
3.6290924991964175
3.10151799249517
3.3112379331292248
3.498867829873526
3.388053537998378
3.229986537192622
3.2494060362256825
2.5326286277371826
2.4617490053815367
3.117442482926689
3.038306884061334
1.966309739565398


2.9764040583758744
3.122157080420948
2.8415551765873146
2.736766105115624
2.27253138236494
2.473314138677568
2.522737204889761
2.6329958728820957
2.593502687927543
2.7899799130205696
2.8652607401737633
2.2827470132798253
2.320737882968722
2.9220514208288666
2.2012184198767253
3.363062690861512
2.9051290377877432
3.1195446383819627
3.224556579564213
3.219642442303332
3.5082642670428617
3.364922752022985
3.3892320762317274
3.1711529008886408
3.3385568350564605
3.1609776865675405
3.19194057449426
3.5259171137540943
3.266885939911007
2.7533917902288803
2.7452747144897387
2.420425046384864
2.4821191183691136
2.2071130541623014
2.7299364846711587
2.342229008495984
2.5448040525440434
2.5411733524748508
2.014821232699762
2.4923509222370397
2.9819958323817595
2.9913291923619783
3.03405170227298
3.2766525736079295
3.2403120435196135
3.0797158131731956
2.4395850022596486
3.2007140621336623
2.916548574796123
2.7585920976455958
2.661045483771904
2.917655285905896
2.6091382899786058
2.49428482135953

3.2165376183714605
3.2388700767584995
3.1706335467769864
3.1194577021234826
2.7430439722902036
3.0651411175615326
3.2418354605439332
4.492221265775121
4.2156632596869885
4.12610651743774
4.397683627432819
4.225971629986726
3.875037001324261
4.3269686744563804
4.655519855662004
4.63741805636651
4.090020674987706
4.214009220436124
4.094704996975164
4.103462553064657
2.835578456486824
2.4620371081888477
2.809053994428937
2.442074711158162
2.2047829670101073
2.5137739043533642
1.8087883535731923
2.490636600016481
1.7964354206779267
2.5072734503633742
3.350405220009401
3.3718204622690795
2.90635689765685
2.9657079739657437
2.9105388347201595
2.63433097111321
2.0376211469353023
3.3135165595950715
2.701255473251443
3.1078308153377407
3.0992895126617377
2.7110336121630034
3.141920333947698
3.54959150422785
3.5312186979679527
3.0602489593416635
3.1236191991542683
3.314803927379752
3.2845184115997883
2.1416724989651046
2.545341582868521
3.243679189733417
3.1189042851505806
2.9588539821787987
2.7

4.032006620121387
3.6902584473380666
3.6207694051728687
3.578213342843906
3.49664652369527
3.8375606651464036
4.0944698119621545
3.9613425288800728
4.218419451183674
3.7628544135952655
3.8631124278120863
4.117093321264764
4.194762906263499
4.155101827715944
2.9734147766973864
4.076806143398548
4.252753224501083
4.137693551728296
3.734222041523104
3.968553831992129
3.7089845784850195
3.901801423093765
3.6143301974633557
3.686298608554418
3.482450311076606
3.8002765067387774
3.8122303173915135
3.7841347199383906
3.7255465148372866
3.9256956518558637
3.0646121286233527
3.9051910358436017
3.4443621726222458
3.389212178856117
3.314296337660202
2.812380574975319
2.790719004684146
2.3817290224760765
3.3752921192519123
2.244828870718657
2.4644174985206697
2.8360715515647006
3.362233095981599
3.271354335350634
2.794475164602505
3.7758507080592913
3.921387184847886
3.9054101865022384
3.942186145482772
4.09017986663774
4.0543293944384375
3.7077243267809603
4.200940956958913
4.272119662890131
3.54

3.2534499650012285
3.559074579386492
3.349558604765819
3.130157951996709
3.2135903860665027
3.0331600010911997
2.8360217539406825
3.8032746733159195
3.465772110360561
3.613399324048669
3.392586510309967
3.357804052558611
3.7139613304388366
3.596841642843576
3.6734143172478233
3.1900919767158227
3.067371075087928
3.4242149913702717
3.3531975115006927
3.662585312080104
3.8835633481009446
4.255484037741922
3.493434941671206
3.9163254758399217
3.6826147507573737
4.033808551111951
3.99808960430662
3.983432152424568
3.588108291925878
3.692558136086241
4.035161190094965
3.1836933773407123
3.203418526025509
3.0846702367738486
2.4118105332721083
2.3946529424970464
2.735949433369875
2.754310458971395
2.788129235037988
2.7131506309125304
2.834454511302871
2.1901307067832834
3.806476567909755
3.6302460223958457
3.348334037663018
3.629842976459229
3.9993139241535807
3.864486205188441
4.087617665510892
3.325726335800764
3.589624638353328
3.585967765448551
3.9916712206616944
3.527164154649878
3.18012

2.9580139489825923
2.701075074044799
2.492595237502131
2.7199310683299096
2.1337562228409084
2.0366665432827054
4.3096139737513495
4.078032916183319
3.9907151450898786
3.6850757498649758
3.9698058594010988
4.22588934886261
3.822683692916317
3.783173470575389
3.928292528277398
3.847399336728102
4.297060118102045
3.4255751628794036
2.6662294650538096
2.272732790095776
2.3715934755937025
2.29815607805625
2.2068245220713094
2.3009476712687573
2.3604810364856927
1.94281638990973
1.9885957324265484
2.392039299542199
3.314698064588808
3.458602911630992
3.195833046115946
2.6526483929531457
2.9696280298443285
3.1390679183740957
2.6181704436333524
2.8550917144519676
3.0915024893775787
3.2492068608350784
3.041140179211007
2.970546903701656
3.6812029625455263
3.5003500080956496
3.297831692566045
3.8012673646401645
3.594282054412068
3.5842702347993205
2.763333791433905
3.6605924116694712
3.1824289516441984
3.376962766329324
3.7870568331180694
3.482427471452456
3.9577095409962735
3.864382644546767
3

3.995392497290047
3.478203528674297
3.5705119406831267
3.3952949147621077
3.310177210182247
3.31901114940208
3.1972286408263377
3.0562624734950905
3.1066124837341076
2.948460612238564
3.335070086231894
2.9388828784056225
3.565933748443825
3.322404074517526
3.0412670918052167
3.2176891423835263
3.7405268585109326
3.3715096762208123
2.738423755766069
3.0355365343550025
3.690606637316846
3.5338814432430583
3.5197806788375052
3.36059937261961
2.9657087111423412
2.943051071519885
3.1086307190348337
2.5524095762857417
3.2209474810148024
3.050816555661641
3.4153850560646335
2.7728551688150507
3.885832908549268
3.714507822968132
3.6390421308854837
3.5845534530970253
3.4759144099126718
3.9012059923811493
3.275543304782992



KeyboardInterrupt: 

In [80]:
sparse.find?