In [3]:
%matplotlib inline

import numpy as np
import scipy
import scipy.io
import scipy.sparse as sp
import matplotlib.pyplot as plt
from helpers import *
import csv

%load_ext autoreload
%autoreload 2

In [4]:
from helpers import load_data, preprocess_data

path_dataset = "47b05e70-6076-44e8-96da-2530dc2187de_data_train.csv"
path_submission = "9b4d32bb-f99a-466f-95a1-0ab80048971c_sample_submission (2).csv"
ratings = load_data(path_dataset)
ratings_=ratings.toarray()
submission = load_submission(path_submission)
submission_row_col = submission[0]
submission_pos = submission[1]
num_item, num_user = ratings.get_shape()

number of items: 10000, number of users: 1000


In [8]:
nb_nzi = len(ratings.nonzero()[0])
nb_nzu = len(ratings.nonzero()[1])

users = ratings.nonzero()[1]
items = ratings.nonzero()[0]

stars = []
for i in range(nb_nzi):
    stars.append(ratings[items[i],users[i]])

In [9]:
with open('surprise', 'w') as csvfile:
    fieldnames = ['item', 'user','ratings']
    writer = csv.DictWriter(csvfile, delimiter=";", fieldnames=fieldnames)
    for r1, r2,r3 in zip(items, users, stars):
        writer.writerow({'item':r1,'user':r2,'ratings':r3})

### mean calculation

In [11]:
def extract_user_mean(train):
    #calculate user mean
    user_nnz = train.getnnz(axis=0)
    user_sum = train.sum(axis=0)
    user_mean = np.empty((1, num_user))
    for ind in range(num_user):
        user_mean[0,ind] = user_sum[0,ind] / user_nnz[ind]
    return user_mean
def extract_global_mean(train):
    # calculate the global mean
    nonzero_train = train[train.nonzero()]
    global_mean = nonzero_train.mean()
    return global_mean
user_mean=extract_user_mean(ratings)
global_mean=extract_global_mean(ratings)

### baseline estimate

In [19]:
def baseline_estimate(train,lamda_i,lamda_u,epochs):
    # set the user and item baselines
    bu = np.zeros(num_user)
    bi = np.zeros(num_item)    
    
    # group the indices by row or column index
    nz_train, nz_item_userindices, nz_user_itemindices = build_index_groups(train)
    
    #using Alternating Least Squares (ALS)
    for iter_ in range(epochs):
        for i,i_users in nz_item_userindices:
            dev_i = 0
            for u in i_users:
                dev_i += train[i,u] - global_mean - bu[u]

            bi[i] = dev_i / (lamda_i + len(i_users))

        for u,u_items in nz_user_itemindices:
            dev_u = 0    
            for i in u_items:
                dev_u += train[i,u] - global_mean - bi[i]

            bu[u] = dev_u / (lamda_u + len(u_items))
   
    return bu,bi

In [20]:
# set the parameters
lamda_i = 10
lamda_u = 15
epochs = 10
#baseline_estimate
bu,bi = baseline_estimate(ratings,lamda_i,lamda_u,epochs)

### Similarity with pearson baseline

In [138]:
def user_based_similarity_by_pearson_baseline(train,min_support,global_mean, user_biases, item_biases, shrinkage=100):
    train=train.toarray()
    # set some matrixs
    freq = np.zeros((num_user, num_user))# matrix of number of common items
    prods = np.zeros((num_user, num_user))# matrix of sum (r_ui - b_ui) * (r_vi - b_vi) for common items
    sq_diff_u = np.zeros((num_user,num_user))# matrix of sum (r_ui - b_ui)**2 for common items
    sq_diff_v = np.zeros((num_user,num_user))# matrix of sum (r_vi - b_vi)**2 for common items
    sim = np.zeros((num_user, num_user))#matrix of similatiries

    # Need this because of shrinkage. When pearson coeff is zero when support is 1, so that's OK.
    min_support = max(2, min_support)

    # group the indices by row or column index
    nz_train, nz_item_userindices, nz_user_itemindices = build_index_groups(train)
    
    for u,items_u in nz_user_itemindices:
        sim[u, u] = 1
        for v,items_v in nz_user_itemindices[(u+1):]:  
            com_items = np.intersect1d(items_u,items_v)
            freq[u, v] = len(com_items)
            diff_u = (train[com_items,u] - (global_mean + item_biases[com_items] + user_biases[u]))
            diff_v = (train[com_items,v] - (global_mean + item_biases[com_items] + user_biases[v]))
            prods[u, v]= diff_u.T @ diff_v
            sq_diff_u[u, v] = diff_u.T @ diff_u
            sq_diff_v[u, v] = diff_v.T @ diff_v
            if freq[u, v] < min_support:
                sim[u, v] = 0
            else:
                # calculate the similarity
                sim[u, v] = prods[u, v] / (np.sqrt(sq_diff_u[u, v] *
                                                       sq_diff_v[u, v]))
                # shrunk similarity
                sim[u, v] *= (freq[u, v] - 1) / (freq[u, v] - 1 +
                                                     shrinkage)

            sim[v, u] = sim[u, v]

    return sim

In [139]:
#set the parameters
min_support = 1
shrinkage = 1000
sim = user_based_similarity_by_pearson_baseline(ratings, min_support, global_mean, bu, bi, shrinkage)

### KNN with means 

In [323]:
def KNN_with_user_means(train,sim_matrix,k,min_k,user_mean):
    
    pred=[]
    
    for row,col in submission_row_col:
        i = row-1
        u = col-1
        #x, y = self.switch(u, i)neighbors=[]
        neighbors=[]
        for v in range(num_user):
            if train[i,v]>0:
                new_neighbors=(v,sim_matrix[u, v],train[i,v])
                neighbors.append(new_neighbors)
        # Extract the top-K most-similar ratings
        k_neighbors = heapq.nlargest(k, neighbors, key=lambda t: t[1])

        #initial setting
        est = user_mean[u]
        sum_sim = 0
        sum_ratings = 0
        actual_k = 0
        
        # compute weighted average
        for (nb,sim, r) in k_neighbors:
            if sim > 0:
                sum_sim += sim
                sum_ratings += (sim * (r - user_mean[nb]) )
                print(sum_ratings)
                actual_k += 1

        if actual_k < min_k:
            sum_ratings = 0
        if sum_sim>0:
            est += sum_ratings / sum_sim
        
        # round ratings
        if est < 1:
            est = 1
        elif est > 5:
            est = 5
        else:
            est = np.round(est)
        
        pred.append(est)         
    return pred


In [324]:
import heapq
import math
#initial setting    
k = 200
min_k = 1
ratings_ = ratings.toarray()
pred = KNN_with_user_means(ratings_,sim,k,min_k,user_mean[0].T)

-0.0065502421894853766
-0.023340649916030996
-257.5315550891973
-257.5161663553159
-257.5116641769296
-257.5319152568012
-257.55187431832144
-257.54604646473246
-257.5325788640442
-257.5791640372158
-257.557061506829
-257.5812886427144
-257.58191247016134
-257.59380283528066
-257.61188890120184
-257.6039817824696
-257.5914523647635
-257.6158370838669
-257.60083324186166
-257.5956037891966
-257.6020825118384
-257.61141132332966
-257.6033059886896
-257.601101604286
-257.5990501594532
-257.58503561621825
-257.57095684059135
-257.56749994729904
-257.5681528962354
-257.56718231574195
-257.5611218115769
-257.5590778660517
-257.5575683115635
-257.55962044199725
-257.55861900463026
-257.5535682451109
-257.5602067415307
-257.5653599478826
-257.55472275012977
-257.5593839825808
-257.5561975326309
-257.546219177643
-257.5490095885169
-257.5503130472405
-257.5508781992573
989.299454601118
989.3030117768085
989.3049024190755
964.1145035472076
964.1082965687423
964.116053372256
964.1131426837159
964

-605.5441257811267
-605.5441148038215
-605.5432697486505
-605.5406432322359
-605.541900274917
-605.5398844197349
-605.5370296616586
-605.5381643446963
3597.4303804253414
3597.432116077012
3597.4321492884083
3597.43029841923
3597.432155700983
3597.4337133518084
5303.62130224507
5303.622190238978
5303.623313486713
5303.6227820944005
5303.623973854471
5303.624173546857
5303.6236345411535
5303.623481283747
5303.624177916001
5303.624736741386
5303.6246420272855
5303.624819903978
-257.4873336792149
-257.5013936442676
-257.5191336500775
-257.4976011873521
-257.52031167903453
-257.5265876060414
-257.53726231984183
-257.54549084731383
-257.5531673260811
-257.5753943088571
-257.6162396666132
-257.6151384066431
-257.62165363554504
-257.63483902011916
-257.6484541030673
-257.6782147494399
-257.6889862317857
-257.69777389514337
-257.6799752473092
-257.6814387105081
-257.6874884604358
-257.6862062043751
-257.698530595782
-257.7007185839196
-257.70329239508885
-257.68337036393996
-257.6922064243985
-

0.02197376604222865
0.008498668258404134
0.010946190422296193
0.012248018091092093
0.021309644327098868
0.01998819311225512
0.025209948321759122
0.028396398271651233
0.029658884014891296
0.02214120861454447
0.02222302755307976
0.03039083970399232
1246.8807236400794
1246.8737436056767
1246.864218756587
1246.862229233312
1246.861134224586
1246.8600378352276
1246.8663516990434
1246.8630104982085
1246.8618159816745
1246.8568838500425
1246.854179757578
1246.8511123224926
1246.8516814538586
1246.848593351492
1246.851897660582
1246.8497011693369
1246.8507999895921
1246.8495163793484
1246.8483097307173
1246.8503255858993
1246.8527445401626
1246.854343566653
1246.8543945569093
1246.8538254483406
1246.8516082166454
1246.8492672858736
1246.847643820151
1246.8464195167319
1246.8456575711723
1246.845126178859
1246.8431098058534
1246.8429831661479
1246.8427806081042
1246.8434772403575
1246.843278207244
1246.8441200235563
1246.844226745987
1246.84395294647
1246.8441585760133
1246.8441645507874
-0.028

964.9104359586529
964.9110279438486
964.9115115232534
964.9119429443285
964.912148573872
964.912450897434
964.9126517585794
964.9128471244651
-0.00978195382470584
-257.4971156330396
-257.50469556201506
-257.50515192437837
-257.5035625535497
-257.4820300908243
-257.4583254876322
-257.4448578869439
-257.46221431334425
-257.4611486249191
-257.4674280324978
-257.4730943890629
-257.46647719734887
-257.4814745940557
-257.4689451763496
-257.4585730270079
-257.45361426435386
-257.4590599426815
-257.4457109956399
-257.44866396272994
-257.44661251789717
-257.44918113754534
-257.4494231020396
5654.550692112904
5654.561806037484
5654.5483309397
5654.54312187366
5654.5509446011465
5654.541502280437
5654.543298226835
5654.554171867009
5654.551491005056
5654.554556301817
5654.563617928053
5654.562954118662
5654.56244099645
5654.55847145171
5654.555681040836
5654.55764921715
5654.555677507953
5654.561006448712
5654.553470226471
5654.554987828411
5654.551427284792
5654.548516596252
5654.552066868558
56

0.032687858047161
0.03614475133947304
5912.036259966283
5912.045556629987
5912.039559290538
5912.038255831814
5912.037629754705
5912.040343911622
5912.038983112985
5912.042369334021
5912.045110829299
5912.0502325221105
5912.050801653477
5912.048932823739
5912.0489195427535
5912.0513028108735
5912.052401631128
5912.045119747945
5912.043033468395
5912.044418359972
5912.043841246268
5912.041980460324
5912.040417875738
5912.042218740651
5912.042271282268
-0.013727817923819271
-0.04079115381788029
-0.042642295010986896
-0.05316917147905782
-0.055138092654564944
-0.0626875824500505
-0.05478046371778919
-0.06129569261975247
-0.0660111585232014
-0.07688535960353911
-0.07165590693848438
-0.0861521996985202
-0.11022635972704789
1819.5154589886781
1819.5175104335108
1819.5140525673846
1819.519923698009
1819.506915453445
1819.5136816956972
1819.5089971843465
1819.5046104635705
1819.494584006398
1819.4990533323105
1819.4928966169064
1819.487687550867
1819.4894834972642
1819.4897883593649
1819.47874

-257.45654451351135
-257.45763952223746
-257.46895708980963
-257.46748915907506
-257.46611249897643
-257.46033813528
-257.4569519142445
-257.45340324202135
-257.44981038286016
-257.44645207368325
-257.44516325976474
-257.4404160713659
-257.4385765949748
-257.4374096800338
-257.43547353992506
-257.43770418462583
-257.4342934819976
-257.43364536376515
-257.4328003085942
-257.43291366223156
-257.43499994178177
-257.4334009152914
-257.4313503785545
-257.4301636751907
-257.43039981553886
-257.4288800414889
-257.4305035072114
-257.43196008580156
-257.430442848724
-257.4298170684095
-257.42961737602366
-257.42884716539794
-257.4284443855543
-257.42860222611284
-257.42849550368214
-257.4280640826069
-257.42820633668316
-257.4280007071397
-257.4280580118722
-257.42814446354276
0.01267423189078043
-0.0013857331619175836
0.020146729563494416
-0.0025637621189539238
-0.03856544145735345
-0.013592686227826934
-0.0025165256191927495
-0.009031754521156033
-0.016570766493737288
-0.004041348787641371
-0

-257.681973039755
-257.6863873692243
-257.6816780411847
-257.6691486234786
-257.67386408938205
-257.6588602473768
-257.6692368499747
-257.67468252830236
-257.6733342406031
-257.6855958612128
-257.68866289350126
-257.67531394645965
-257.68464275795094
-257.6715782673094
-257.6702960112487
-257.6577997346188
-257.6719663088586
-257.68537171032284
-257.6921921966487
-257.7032457007346
-257.7066224237903
-257.69909226120456
-257.70751262560447
-257.69837552017776
-257.68919969579724
-257.6889829214211
-257.69103505185484
-257.7046546000462
-257.6994328448367
-257.6935470312305
-257.6806772084042
-257.681303285513
-257.6934709305152
-257.6915902766825
-257.6974757304296
989.1586936646939
989.1657740264625
989.1638023172651
989.168060754395
989.1733896951534
989.176946870844
989.1697613471584
989.161488662399
989.1681847699621
989.1660611131275
989.1681925376304
989.1586911173088
989.1568695385974
989.1566041946402
989.1601528668633
989.1483428203046
989.1474703114752
989.152217499874
989.15

5912.01911114495
5912.01373333648
5912.022858125335
5912.029379892349
5912.031272601861
5912.032436672193
5912.034519508878
5912.030962027114
5912.0301377966525
5912.029386316001
5912.030838690753
5912.034431549914
5912.037789859091
5912.042911551903
5912.050955478108
5912.04939539914
5912.047164754439
5912.050941713268
5912.050373769592
5912.043883670416
5912.044735709254
5764.780411909859
5764.781594516392
5764.785756557265
5764.789384625564
5764.7919533149925
5764.792798370164
5764.7949004120555
5764.797012174698
5764.79669084075
5764.800166194044
5764.802585148308
5764.799340170951
5764.8010383166275
5764.799581738037
5764.799166367182
5764.796490115674
5764.796931831071
5764.797087827967
5764.798898829006
5764.798696270963
5764.800067665221
5764.800018265032
5764.8008582494385
5764.800877903149
5764.801976622924
5764.801598082589
5764.802081661994
5764.802252723153
5764.802258697927
5764.802329199529
-0.011678642248150227
0.015857626423384577
-0.03223062428454737
-0.04413198922251

1815.1588123662623
1815.1555673889056
1815.1530112473545
1815.155750410959
1815.158963305189
1815.1587938360183
3521.347906275574
3521.3458018890597
3521.3454869658312
3521.3445853183935
3521.343868208902
3521.3457729883917
3521.3472902254694
3521.3461148987562
3521.344620693072
3521.345317325325
3521.3467376280287
3521.3456896778066
3521.346626008883
3521.347063967987
3521.347903952394
3521.3470119095214
3521.3462234695153
3521.346815454711
3521.3464260408696
3521.3468574619446
3521.347159785507
3521.3473606466523
3521.347307715807
3521.3470890999547
3521.3470347372736
-0.011678642248150227
-0.016367913087549278
0.005164549637862722
-0.05305312173440439
-0.05932904874128876
-0.08455593973793286
-0.10237787261571095
-0.13423174580733477
-0.1231290639962623
-0.1483951141669863
-0.18163460779695625
-0.16340571738581403
-0.1633505436538084
-0.18937594503406222
-0.1841464923690075
-0.199772181884318
-0.20805667396892974
-0.19443236608969144
-0.17327849091220976
-0.16258131363855882
-0.1773

0.052948964843241714
0.05642435115147834
0.046397893979015906
0.050867219891426015
0.060043044271942674
0.058365785274561245
0.06340692151230662
0.06440835887931876
0.07294462665502624
0.08045386155461856
0.08108681864000138
0.08313268771030509
0.0883544429198091
0.0965938068548145
0.0999989168268095
0.09416023139484292
0.07883878857333035
0.09170861139964104
0.09489506134953316
0.10353231390761328
0.11103368550822415
0.11300186182247693
0.10695894879065622
0.11301096600581287
0.12091684254425368
0.12908465469516625
0.1362101222851494
0.14558034776986106
0.13348793659077607
0.12286306053770564
0.11952185970283155
0.12690156720295548
0.12806563753499575
0.1346521517872492
0.1399286319039441
0.13810705319259967
0.1416557254157533
0.13316571679059833
0.13995620353873478
0.14331451271565546
0.14756491507610012
0.15421601227377976
0.15896320067259692
0.16700712687791347
0.17255490939797172
0.17002396550564966
0.17196010561437328
0.17480627789973155
0.1802931435385829
0.1840090347902675
0.18

KeyboardInterrupt: 

### submission

In [306]:
create_csv_submission(submission_pos, pred, "pred_fuxian")