<a href="https://colab.research.google.com/github/seijimorimoto/CE888-Data-Science/blob/master/Lab5/CE888_Lab5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**Lab 5. Recommender Systems**

**1. Loading the *jester-data-1* dataset.**

In [0]:
import pandas as pd
import numpy as np

In [0]:
col_names = ['JokeCount'] + [f'Joke{num + 1}' for num in range(100)]
df = pd.read_csv('https://raw.githubusercontent.com/albanda/CE888/master/lab5-recommender/jester-data-1.csv',header=None, names=col_names)

In [0]:
df = df.drop(columns=['JokeCount'])

In [0]:
arr = df.values

**2. Labeling 10% of the dataset cells as NaNs (represented as 99 in our case) to create a validation set.**

In [0]:
def replace(orig, percentage=0.1):
  """
  Replaces 'percentage'% of the original values in 'orig' with 99's
  :param orig: original data array
  :param percentage: percentage of values to replace (0<percentage<1)
  """
  new_data = orig.copy()
  rated = np.where(arr!=99)
  n_rated = len(rated[0])
  idx = np.random.choice(n_rated, size=int(percentage*n_rated), replace=False)
  new_data[rated[0][idx], rated[1][idx]] = 99
  return new_data, (rated[0][idx], rated[1][idx])

In [0]:
new_arr, validation_indices = replace(arr, 0.1)

**3. Infer hidden ratings of the users and hidden characteristics of the jokes by utilizing latent factors modeling.**

In [0]:
n_latent_factors = 2

user_ratings = new_arr
# Initialise as random values
latent_user_preferences = np.random.random((user_ratings.shape[0], n_latent_factors))
latent_item_features = np.random.random((user_ratings.shape[1], n_latent_factors))

In [0]:
def predict_rating(user_id, item_id):
    """ Predict a rating given a user_id and an item_id.
    """
    user_preference = latent_user_preferences[user_id]
    item_preference = latent_item_features[item_id]
    return user_preference.dot(item_preference)


def train(user_id, item_id, rating, alpha=0.0001):
    
    #print item_id
    prediction_rating = predict_rating(user_id, item_id)
    err =  prediction_rating - rating
    #print err
    user_pref_values = latent_user_preferences[user_id][:]
    latent_user_preferences[user_id] -= alpha * err * latent_item_features[item_id]
    latent_item_features[item_id] -= alpha * err * user_pref_values
    return err
    

def sgd(iterations=50):
    """ Iterate over all users and all items and train for 
        a certain number of iterations
    """
    for iteration in range(iterations):
        error = []
        for user_id in range(latent_user_preferences.shape[0]):
            for item_id in range(latent_item_features.shape[0]):
                rating = user_ratings[user_id][item_id]
                if rating != 99:
                    err = train(user_id, item_id, rating)
                    error.append(err)
        mse = (np.array(error) ** 2).mean()   
        if (iteration % 10) == 0:
            print(mse)

In [0]:
sgd()

25.053632117133635
22.901327893010883
21.289883849050668
20.29711679089692
18.495648917778848


**4. Comparing predictions (obtained through latent factors modeling) with true values (original dataset).**

In [0]:
predictions = latent_user_preferences.dot(latent_item_features.T)
values = [zip(user_ratings[i], predictions[i]) for i in range(predictions.shape[0])]
comparison_data = pd.DataFrame(values)
comparison_data.columns = [f'Joke{num + 1}' for num in range(100)]

In [0]:
comparison_data

Unnamed: 0,Joke1,Joke2,Joke3,Joke4,Joke5,Joke6,Joke7,Joke8,Joke9,Joke10,Joke11,Joke12,Joke13,Joke14,Joke15,Joke16,Joke17,Joke18,Joke19,Joke20,Joke21,Joke22,Joke23,Joke24,Joke25,Joke26,Joke27,Joke28,Joke29,Joke30,Joke31,Joke32,Joke33,Joke34,Joke35,Joke36,Joke37,Joke38,Joke39,Joke40,...,Joke61,Joke62,Joke63,Joke64,Joke65,Joke66,Joke67,Joke68,Joke69,Joke70,Joke71,Joke72,Joke73,Joke74,Joke75,Joke76,Joke77,Joke78,Joke79,Joke80,Joke81,Joke82,Joke83,Joke84,Joke85,Joke86,Joke87,Joke88,Joke89,Joke90,Joke91,Joke92,Joke93,Joke94,Joke95,Joke96,Joke97,Joke98,Joke99,Joke100
0,"(99.0, -3.1883958892690005)","(99.0, -3.764506638261345)","(99.0, -4.820591478704082)","(-8.16, -8.385132141409427)","(-7.52, -2.5084620862662903)","(-8.5, -1.5751803876805703)","(-9.85, -4.568982547222585)","(4.17, -3.783547958829055)","(99.0, -7.3431645777557)","(-4.76, -2.9910108061484566)","(-8.5, -1.3716799884801405)","(-6.75, -1.545340877585691)","(-7.18, -5.967084055011345)","(8.45, -1.472993808324943)","(-7.18, -5.8833170875497185)","(-7.52, -9.17504308137995)","(-7.43, -4.092418817164939)","(-9.81, -4.565710249305481)","(-9.85, -3.460714761540243)","(-9.85, -5.645610141865154)","(-9.37, 0.3160672048576742)","(1.5, -3.7131964580413115)","(-4.37, -5.688849826009972)","(-9.81, -9.2544556636687)","(99.0, -4.3158568961549095)","(1.12, -1.5047874963854457)","(7.82, 3.2387251906967425)","(2.86, -1.6846678816614893)","(9.13, 3.3946874769912125)","(-7.43, -7.138840346995856)","(2.14, 1.108237986878809)","(-4.08, 2.9787481023333444)","(-9.08, -8.476918411268972)","(7.82, -2.6192872308375583)","(5.05, 2.7354227025380435)","(4.95, 3.6267242510120807)","(-9.17, -8.074543560285829)","(-8.4, -2.9520411387216168)","(99.0, -2.4483385092662595)","(-8.4, -3.299079838153281)",...,"(8.59, 1.1210292683534746)","(3.59, 1.9395308014378039)","(-6.84, -4.947252640653697)","(-9.03, -7.325446524038105)","(2.82, 1.3677528281596953)","(-1.36, 1.434114377969336)","(-9.08, -7.799092121451946)","(8.3, 1.4075562857074677)","(99.0, 1.687991448650621)","(-4.81, -4.360541048634479)","(99.0, -5.33875730106742)","(99.0, 0.9154413039119257)","(99.0, -2.4650081662539605)","(99.0, -7.098591216439074)","(99.0, -4.797789930963332)","(99.0, 0.5866458451590482)","(99.0, -3.225099349443256)","(-9.42, -1.6276915205570497)","(99.0, -3.973846732792764)","(99.0, -1.9116052314776721)","(99.0, -0.6398598294165625)","(-7.72, -3.5196210960443373)","(99.0, -0.015589205059639974)","(99.0, -3.6070709640361587)","(99.0, -2.8627608965341813)","(99.0, -4.561016094071904)","(99.0, -0.36356520737630565)","(99.0, -0.6973358379246835)","(99.0, 2.759521471523628)","(99.0, -4.104165597603101)","(2.82, -0.8045321262930739)","(99.0, -2.2227276754881338)","(99.0, 0.3066215383184227)","(99.0, -3.309791737789375)","(99.0, -3.0332324189597943)","(99.0, -2.0725992906259134)","(-5.63, -1.5576108984202672)","(99.0, -1.9348708738236442)","(99.0, -4.944718352261926)","(99.0, -1.7075386589627728)"
1,"(4.08, 2.8831219765214375)","(-0.29, 3.2038812798831757)","(6.36, 3.615952078382612)","(4.37, 4.772747477519686)","(-2.38, 2.3318571392772713)","(-9.66, 2.820866687300003)","(-0.73, 2.8759381610871504)","(-5.34, 2.079253312202786)","(8.88, 4.4702611875595375)","(9.22, 2.9289636396293814)","(6.75, 2.6867121424280076)","(8.64, 2.88007372881708)","(4.42, 2.7943032434583523)","(7.43, 2.8480192430193245)","(99.0, 2.5610390479450715)","(-0.97, 3.6562957949216135)","(4.66, 2.2751889286264926)","(99.0, 2.6373827609834866)","(3.3, 2.719425794484822)","(-1.21, 3.4253741771369084)","(0.87, 2.1420799855130377)","(8.64, 3.5143175522782375)","(8.35, 4.047643942420594)","(9.17, 4.891088734680343)","(0.05, 3.660047271563889)","(7.57, 2.799991660995579)","(99.0, 0.8758525120098716)","(0.87, 2.7925808580413767)","(-0.39, 0.6789812116090436)","(6.99, 4.6709445477910485)","(99.0, 1.7387573577790794)","(-0.92, 0.9859334252894942)","(7.14, 4.816389265785961)","(9.03, 3.082197876504474)","(99.0, 1.1049999219886675)","(0.73, 0.6025807183177946)","(7.09, 4.360475864818317)","(3.4, 3.1511879630680077)","(-0.87, 3.137937202998842)","(7.91, 3.381004905004493)",...,"(-6.7, 1.8569897004996787)","(-3.35, 1.4921167413168905)","(-9.03, 3.818729594694821)","(4.47, 4.334371356031835)","(4.08, 1.633089265384066)","(-3.83, 1.6054523865335941)","(8.74, 4.572694125088371)","(1.12, 1.5896747069629744)","(0.78, 1.5571052333737152)","(7.52, 3.6498217583941366)","(-5.0, 3.379388239576269)","(99.0, 1.8945665449296416)","(8.3, 2.9958069892957586)","(7.77, 3.991595025855203)","(7.33, 3.343066435580344)","(6.21, 1.8414073384984577)","(7.72, 3.111059162723777)","(8.98, 2.8047115493007033)","(8.64, 3.3136953887651988)","(8.2, 2.852684754680316)","(3.93, 2.2544223616493104)","(4.85, 3.2328057593670336)","(4.85, 2.0932837077497646)","(6.07, 3.24498722659662)","(8.98, 3.079965374528792)","(4.51, 3.2394562436014693)","(99.0, 2.040156590753143)","(3.69, 2.2527356221931663)","(4.56, 1.1209807786003188)","(0.58, 3.4671718547907213)","(2.82, 2.235380453990118)","(-4.95, 2.7490183860485478)","(-0.29, 1.9919472746348827)","(7.86, 3.056901393385677)","(-0.19, 2.873600944406585)","(-2.14, 2.7009799976578717)","(3.06, 2.4314101336513456)","(0.34, 2.5424143339634635)","(-4.32, 3.5361355932519296)","(1.07, 2.7021533177873764)"
2,"(99.0, 5.828243737414878)","(99.0, 6.162509779178894)","(99.0, 6.143569425340183)","(99.0, 5.239251369042193)","(9.03, 4.8136142262260275)","(9.27, 7.893853772680035)","(9.03, 3.8433954492410876)","(9.27, 2.097228965336443)","(99.0, 5.631650197247966)","(99.0, 6.272939330112478)","(7.33, 7.700905089538598)","(7.57, 8.148793906567624)","(9.37, 1.5663580690985996)","(6.17, 8.136354685917855)","(99.0, 0.8476293330505662)","(99.0, 0.10946529779191988)","(99.0, 2.362497491563625)","(9.03, 2.9914453387299447)","(99.0, 4.854030610557535)","(99.0, 4.288549477277253)","(99.0, 8.14017947093838)","(99.0, 7.350018899463413)","(8.25, 6.461609333839165)","(99.0, 4.430615183355501)","(99.0, 7.018122930189001)","(7.48, 7.9187842497281125)","(7.28, 7.740722227759061)","(7.28, 7.636923553585167)","(8.93, 7.255116570869475)","(99.0, 6.6421907210423825)","(6.17, 7.816038197411518)","(7.28, 7.767088004708391)","(99.0, 5.265713497120155)","(99.0, 7.3506415073385885)","(8.98, 7.849347543951324)","(7.33, 7.310041544301137)","(99.0, 4.199611179061788)","(6.17, 7.126189943754669)","(9.08, 7.7933639569807776)","(7.33, 7.458957758191632)",...,"(6.46, 8.258732161922753)","(7.28, 8.110017494063365)","(99.0, 6.691961150691832)","(99.0, 5.1688453110239525)","(7.04, 7.804861841211732)","(7.28, 7.7997917013319045)","(99.0, 5.352500213743591)","(7.28, 7.705452009335251)","(8.25, 7.986440677200904)","(99.0, 6.91799880977773)","(99.0, 4.558849632509487)","(99.0, 8.101931928137912)","(99.0, 7.259355225862347)","(99.0, 4.259927964242874)","(99.0, 5.196059907960713)","(8.93, 7.444489473327852)","(99.0, 6.5946276335280825)","(99.0, 7.761331346425354)","(99.0, 6.2597714405625275)","(9.08, 7.530717612780987)","(99.0, 7.18711236981716)","(99.0, 6.613863114318964)","(99.0, 7.4943435242687935)","(99.0, 6.533512358243979)","(99.0, 6.9971359220172005)","(99.0, 5.160003470016319)","(99.0, 6.8097990962291615)","(99.0, 7.099497227688713)","(99.0, 7.940927013858845)","(9.03, 6.625944993302656)","(99.0, 6.88506753345033)","(99.0, 6.716994431691211)","(99.0, 7.587686594758561)","(99.0, 6.2799819338553835)","(99.0, 6.014233085588056)","(99.0, 6.757532681742115)","(99.0, 6.520342622407953)","(99.0, 6.383599234584857)","(99.0, 5.680832067489491)","(99.0, 7.279766956735011)"
3,"(99.0, 1.5100219988792305)","(8.35, 1.413957137361862)","(99.0, 0.9136708504568634)","(99.0, -1.2062069697439812)","(1.8, 1.3051490824684127)","(8.16, 3.3194615650322445)","(-2.82, -0.14991102475351917)","(6.21, -0.6812048189092291)","(99.0, -0.5208180968183912)","(1.84, 1.8299295089160361)","(7.33, 3.3149683468566953)","(6.6, 3.4641586407562492)","(6.31, -1.9679324504867526)","(8.11, 3.4913810023561815)","(-7.23, -2.297875753567893)","(-6.65, -4.205756531458401)","(1.17, -0.6885428293836034)","(-6.6, -0.5856149133242999)","(99.0, 0.8835563690242527)","(-2.09, -0.421564769234883)","(99.0, 4.3243875029573635)","(99.0, 2.047225455135935)","(99.0, 0.6735720742640325)","(99.0, -2.0250148314708767)","(99.0, 1.5969511121764923)","(2.91, 3.3649544496081623)","(3.93, 5.476999676401282)","(6.75, 3.1367455791961407)","(6.6, 5.300231716168891)","(99.0, 0.09270643391832613)","(6.65, 4.526011308907946)","(-6.12, 5.3697677332183655)","(99.0, -1.2352625016834375)","(7.57, 2.555680958999811)","(6.21, 5.298955613359849)","(6.65, 5.436203707310107)","(99.0, -1.5954814265674429)","(-8.3, 2.2859227981211996)","(7.18, 2.862295788892597)","(2.82, 2.295496178961019)",...,"(99.0, 4.759145199799835)","(-3.69, 5.063029580499014)","(99.0, 1.1362716288032462)","(99.0, -0.7501009778291762)","(7.82, 4.64082379826898)","(0.24, 4.669047642459311)","(99.0, -0.8758635460363522)","(7.28, 4.608295565584076)","(-2.33, 4.882765994115769)","(99.0, 1.5248106395648542)","(99.0, -0.1403083474559663)","(99.0, 4.58317629019976)","(99.0, 2.580497266300861)","(99.0, -1.111183257549485)","(99.0, 0.43799713126680945)","(99.0, 4.093044372904384)","(99.0, 1.8862837552458083)","(99.0, 3.227058422486015)","(99.0, 1.3666308062782817)","(99.0, 2.9768247609620455)","(99.0, 3.3912286363756)","(99.0, 1.7593460460063155)","(99.0, 3.838883136718025)","(99.0, 1.677488069505981)","(0.63, 2.2611637100345114)","(99.0, 0.5294775602277048)","(99.0, 3.325932759518287)","(-2.33, 3.3195658230168563)","(99.0, 5.3571487444452925)","(99.0, 1.4940174976093017)","(99.0, 3.15972555888349)","(99.0, 2.414698362694183)","(99.0, 4.036458493004097)","(0.53, 1.6854656432541877)","(99.0, 1.6775479933035342)","(99.0, 2.5052394540802627)","(99.0, 2.622731884382881)","(99.0, 2.377312195992443)","(99.0, 0.6185336929496997)","(99.0, 2.942827713251525)"
4,"(99.0, 2.762972964504446)","(4.61, 2.765983056472014)","(-4.17, 2.3354285064015645)","(-5.39, 0.30206236455729807)","(1.36, 2.3313341261168055)","(1.6, 4.826630367761077)","(7.04, 0.8470275125517118)","(4.61, -0.04790312351204617)","(-0.44, 0.9848434018520402)","(5.73, 3.1479848438622127)","(99.0, 4.773878841222765)","(99.0, 5.014417358952476)","(-3.93, -1.2775467165751857)","(7.23, 5.034429701833354)","(-2.33, -1.7405882555726933)","(-9.66, -3.5514083422525706)","(2.72, 0.013118627497757574)","(-1.36, 0.26019987759504337)","(2.57, 1.9828006921717278)","(99.0, 0.7287275013754697)","(8.2, 5.7443015138472004)","(6.12, 3.6060319009299966)","(8.3, 2.211748814936965)","(99.0, -0.5998094121551499)","(7.77, 3.1386793611514685)","(1.89, 4.871667388415891)","(-1.17, 6.623899057156192)","(5.68, 4.60598402897902)","(8.45, 6.350327305046856)","(4.61, 1.7632144541249766)","(8.06, 5.8336912091723585)","(-9.47, 6.53932875537472)","(7.28, 0.2840458784642259)","(5.68, 4.038893591099691)","(99.0, 6.4999257907699395)","(3.2, 6.4799694733981275)","(-1.26, -0.2928473317720528)","(6.8, 3.7524088652449223)","(99.0, 4.412092838348477)","(99.0, 3.844938881233736)",...,"(7.38, 6.14434977250469)","(6.17, 6.365249356830563)","(4.71, 2.66392594941962)","(-2.28, 0.6723624033233344)","(7.38, 5.928564395103817)","(4.56, 5.951297659581069)","(7.14, 0.6119074539594017)","(4.22, 5.875674104648472)","(3.01, 6.180505665925692)","(3.83, 3.051897289000144)","(99.0, 1.0366236701530125)","(99.0, 5.954836161226556)","(99.0, 4.036864387719793)","(99.0, 0.13459326514228703)","(99.0, 1.6903540482375794)","(99.0, 5.371012446857737)","(4.13, 3.277516597797227)","(99.0, 4.714388909095642)","(99.0, 2.750371098242806)","(99.0, 4.4429572495016725)","(5.24, 4.708490552463146)","(5.92, 3.1743682791712966)","(0.87, 5.1673590860557885)","(7.28, 3.0843305640336953)","(3.93, 3.698613099058724)","(99.0, 1.7590622104293685)","(99.0, 4.557244005305126)","(4.71, 4.625286989669622)","(2.82, 6.57267170745127)","(2.96, 2.951633128176117)","(5.19, 4.434885496624959)","(5.58, 3.758235570729503)","(4.27, 5.359168729687789)","(5.19, 3.026829627263101)","(5.73, 2.952703264354266)","(1.55, 3.8455671425516957)","(3.11, 3.8854086824526846)","(6.55, 3.641877335691461)","(99.0, 1.9669215833910823)","(1.6, 4.350389409867262)"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24978,"(0.44, 2.6179297992815287)","(7.43, 2.6566099284236193)","(9.08, 2.345819472327753)","(2.33, 0.7890270408269865)","(99.0, 2.1975732001847286)","(99.0, 4.323324394897927)","(-8.79, 1.027273281423171)","(-0.53, 0.1947945232834644)","(-8.74, 1.3214757057159627)","(7.23, 2.942582619575165)","(-0.53, 4.264418394768435)","(5.63, 4.485828200050452)","(-7.14, -0.744901061166586)","(-4.08, 4.498818568763261)","(-3.5, -1.1554443531476053)","(-8.2, -2.534510388570712)","(99.0, 0.26753217486937697)","(-9.22, 0.5134157742145881)","(-0.15, 1.952079740131616)","(-6.46, 0.9910858689973746)","(5.63, 5.008236689990278)","(-0.92, 3.3886980081061897)","(99.0, 2.29188621388884)","(-4.17, 0.054005524758068885)","(2.82, 3.017328816788344)","(3.4, 4.358341239018081)","(99.0, 5.595290989795884)","(6.84, 4.137041976744828)","(6.8, 5.346074395634516)","(-0.87, 1.9900033535218569)","(7.38, 5.036916318704536)","(-3.5, 5.537532162346832)","(8.88, 0.7789999019899732)","(7.43, 3.6991413646168616)","(5.39, 5.518266658533612)","(2.23, 5.445032958000078)","(-0.68, 0.2488674863550243)","(3.4, 3.469199377142182)","(-0.58, 4.015108574342854)","(4.42, 3.5719043876404757)",...,"(8.59, 5.308036795464763)","(3.45, 5.4501801697583385)","(0.87, 2.6412794852005868)","(9.27, 1.0468511733025232)","(-4.66, 5.103722226568653)","(5.73, 5.119468724292961)","(-0.49, 1.023569164448541)","(8.35, 5.054936752941076)","(99.0, 5.304211219237635)","(5.0, 2.9441638358411204)","(99.0, 1.2413899310165646)","(8.98, 5.155017683527483)","(8.98, 3.687712453000007)","(-9.81, 0.5619459164013213)","(9.13, 1.7797563866536197)","(9.08, 4.664565924446906)","(9.08, 3.0706084862928376)","(3.98, 4.22836441148229)","(0.73, 2.6560424451009346)","(9.03, 4.008542665241508)","(8.98, 4.161396126652815)","(9.22, 2.9987495875675836)","(8.93, 4.5239871512933885)","(9.13, 2.9254106404480367)","(9.27, 3.4165257346068483)","(-1.99, 1.8250827668805436)","(-9.95, 4.01172250981804)","(-9.9, 4.09216377542134)","(9.13, 5.5804336925502245)","(8.83, 2.8403616005230345)","(8.83, 3.9322114484071133)","(-1.21, 3.428668737719582)","(99.0, 4.671719280945169)","(-6.7, 2.8564800236683)","(8.45, 2.774293476855145)","(9.03, 3.4957173941929653)","(6.55, 3.498369711783968)","(99.0, 3.308809603456669)","(8.79, 2.031029983029315)","(7.43, 3.914749850205091)"
24979,"(99.0, 2.4763550764557976)","(-8.16, 2.8333976197269855)","(8.59, 3.4084531584621036)","(9.08, 5.243706672882376)","(0.87, 1.9769743867415275)","(99.0, 1.8540895624430365)","(-3.5, 2.981583191786484)","(5.78, 2.3325108307279856)","(99.0, 4.723340645078757)","(4.9, 2.424360629553192)","(8.88, 1.718553705727227)","(-8.69, 1.8698384539803958)","(-7.48, 3.4596387636189125)","(-8.83, 1.828716592286393)","(99.0, 3.323441453201903)","(6.6, 5.030425029451683)","(99.0, 2.534756115704045)","(1.5, 2.8726426150533357)","(7.67, 2.5027228059010547)","(99.0, 3.6262409624805056)","(9.22, 0.8509890347420065)","(8.74, 2.9546979123345887)","(9.03, 3.9231930597201012)","(9.08, 5.617324800258138)","(8.93, 3.242472629111723)","(3.74, 1.8187362482051297)","(3.2, -0.7972792381415443)","(-9.17, 1.8816396677919134)","(-8.98, -0.9436306027914387)","(8.79, 4.738719452443071)","(-7.67, 0.3770843094593606)","(-3.06, -0.6518149516858079)","(9.13, 5.297221034254756)","(8.4, 2.3566571975174195)","(-0.63, -0.508425479468008)","(-7.18, -1.0635938141954864)","(0.58, 4.943127997320678)","(99.0, 2.510368045849922)","(9.27, 2.3188692490245466)","(8.5, 2.741971953670738)",...,"(99.0, 0.42576885052278735)","(8.11, -0.04046905848159843)","(-7.96, 3.546680208961293)","(8.93, 4.655445062777653)","(-0.87, 0.23378316895747933)","(-5.87, 0.19686084538827034)","(8.88, 4.937519874574201)","(-1.12, 0.19951678731783226)","(-8.74, 0.08152314634079463)","(8.74, 3.2543118761594894)","(99.0, 3.492458558449539)","(99.0, 0.5184575473096722)","(99.0, 2.2608204646171117)","(99.0, 4.417090662243059)","(99.0, 3.276815846008656)","(4.9, 0.61554594343421)","(99.0, 2.5928130547422317)","(99.0, 1.8661335876681526)","(99.0, 2.9600917212174864)","(99.0, 1.99236495223812)","(-0.29, 1.2537941951396194)","(0.92, 2.75626860494482)","(-0.78, 0.9511014915453503)","(0.15, 2.793977884538303)","(-0.1, 2.445320734069352)","(0.0, 3.1428196368295205)","(-0.19, 1.0552686069365877)","(-0.87, 1.2742009497314235)","(-1.36, -0.5100838904481545)","(-0.58, 3.0774007775456544)","(-1.17, 1.3058436591642373)","(-5.73, 2.0601345150483024)","(-1.46, 0.7866655882053036)","(0.24, 2.5995469035813854)","(9.22, 2.4149082551962646)","(-8.2, 1.9831470784057315)","(-7.23, 1.671733962957461)","(-8.59, 1.8608107093501207)","(9.13, 3.4181229149832886)","(8.45, 1.8492245100055684)"
24980,"(99.0, -0.2301208684066623)","(99.0, -0.2649206410573123)","(99.0, -0.3227543890190289)","(99.0, -0.5100299030117041)","(-7.77, -0.183200130686016)","(99.0, -0.16098908245154225)","(6.7, -0.2872358015699891)","(-6.75, -0.22761922380578523)","(99.0, -0.4564944358062858)","(99.0, -0.22347296266257125)","(99.0, -0.1479904180794232)","(99.0, -0.16175477574355357)","(-6.46, -0.34255666266841983)","(-1.65, -0.15766329889265984)","(-6.8, -0.3311755144292824)","(-6.41, -0.5050322014730217)","(-6.99, -0.24708823997840995)","(7.23, -0.2790196889329038)","(6.75, -0.23589011978161312)","(-6.99, -0.3505757513692687)","(6.55, -0.05942342695383664)","(99.0, -0.2733037464109149)","(99.0, -0.3734492317686507)","(99.0, -0.5501554346563962)","(99.0, -0.30327809588474086)","(0.49, -0.15735745630550912)","(99.0, 0.1048908384281135)","(-6.94, -0.16457980659857202)","(-0.49, 0.11803876302655845)","(99.0, -0.45480291031925907)","(99.0, -0.012850640927755539)","(-0.53, 0.09036118008968307)","(99.0, -0.5153214825759824)","(99.0, -0.21321945560267047)","(-7.86, 0.07621898507159447)","(-0.34, 0.13026680819856926)","(99.0, -0.48316195919729293)","(-6.94, -0.22938092485592237)","(99.0, -0.2080051285451051)","(99.0, -0.25158327493123406)",...,"(0.49, -0.01632388347437669)","(-0.24, 0.030040510731935546)","(99.0, -0.33488501121991765)","(99.0, -0.45115556291675085)","(-3.11, 0.0015103451614943723)","(-6.65, 0.00520351616170079)","(99.0, -0.47890606294543453)","(-0.58, 0.004634535027634575)","(6.31, 0.01738874403513491)","(99.0, -0.3047882062375214)","(99.0, -0.336269553869081)","(-7.86, -0.026138074572672212)","(99.0, -0.2038835682433379)","(99.0, -0.4301203157349164)","(99.0, -0.3125641014823939)","(99.0, -0.03799774020767868)","(99.0, -0.23936627011825135)","(99.0, -0.1626235312385404)","(99.0, -0.27733748302829936)","(99.0, -0.17604400410104487)","(99.0, -0.10294382410003047)","(99.0, -0.25572626179397395)","(99.0, -0.07154973219056074)","(99.0, -0.25977208069846636)","(99.0, -0.2232592662883265)","(99.0, -0.2992176234065328)","(99.0, -0.08420736521510176)","(99.0, -0.10527460727442255)","(99.0, 0.07667890858756392)","(99.0, -0.2879501828441618)","(99.0, -0.10914037507167448)","(99.0, -0.18545868263543808)","(99.0, -0.05473068234175276)","(99.0, -0.24105053554801595)","(99.0, -0.22335191250961045)","(99.0, -0.17759428629144106)","(99.0, -0.1470677693884808)","(99.0, -0.16650136975937155)","(99.0, -0.32520791929287146)","(99.0, -0.16246711212725745)"
24981,"(99.0, 0.27090079079403134)","(99.0, 0.2324242233690662)","(99.0, 0.0850648530216961)","(99.0, -0.5145172573431175)","(-9.71, 0.24089146530629604)","(99.0, 0.7437007019416086)","(4.56, -0.16012468483174472)","(-8.3, -0.2646132171298936)","(99.0, -0.3236766694690673)","(99.0, 0.35209627822229866)","(99.0, 0.7481854358149187)","(99.0, 0.7788052146750487)","(-9.47, -0.6290920932530077)","(99.0, 0.7872293953258188)","(3.45, -0.7049987470552082)","(-0.92, -1.2469013373395084)","(-4.51, -0.2747752550788164)","(-4.13, -0.2632899473952678)","(99.0, 0.11501270114789693)","(-9.51, -0.25386219527324455)","(99.0, 1.0334264680889476)","(99.0, 0.3838974253253367)","(99.0, 0.004487666627656575)","(99.0, -0.7322678197159387)","(99.0, 0.2607551407903897)","(-0.49, 0.7564013981573365)","(2.91, 1.3862784340766243)","(2.62, 0.6974143086374353)","(8.3, 1.3486405625566613)","(99.0, -0.17270964851927426)","(99.0, 1.1028107896060606)","(5.44, 1.3537765687541758)","(99.0, -0.5239059846512035)","(99.0, 0.5342242515288561)","(-0.68, 1.3303596888664535)","(2.04, 1.3871913908322488)","(99.0, -0.5982988660319307)","(99.0, 0.4612217459296218)","(1.55, 0.6115486651993125)","(99.0, 0.45402656864936797)",...,"(-8.83, 1.1584084281992533)","(-0.78, 1.2527448393571181)","(99.0, 0.13436338758805416)","(99.0, -0.37752967959465683)","(4.51, 1.1370965112887228)","(-2.48, 1.1455947985685184)","(99.0, -0.4202498318285555)","(1.26, 1.1304733256793753)","(5.78, 1.2031658028769687)","(99.0, 0.2424405033512896)","(99.0, -0.17884112928165924)","(99.0, 1.1111002737992293)","(99.0, 0.5443125615801291)","(99.0, -0.4569136616100632)","(99.0, -0.027040015934636005)","(99.0, 0.985980840033699)","(99.0, 0.3590675835084022)","(99.0, 0.7203707068224026)","(99.0, 0.21549987591725245)","(3.16, 0.6533270744802847)","(99.0, 0.7862149209342763)","(99.0, 0.320953733794398)","(99.0, 0.9093256666819973)","(99.0, 0.2991699701691483)","(99.0, 0.4577889720906352)","(99.0, 0.0010962416979097538)","(99.0, 0.7782755789849919)","(99.0, 0.7676646408138614)","(99.0, 1.3448076677099956)","(99.0, 0.2421345246503115)","(99.0, 0.7268618913792172)","(99.0, 0.5116281099583561)","(99.0, 0.9649345435854608)","(99.0, 0.3091674646371793)","(99.0, 0.314833016638264)","(99.0, 0.537178898366006)","(99.0, 0.5790666042468311)","(99.0, 0.5106182063443454)","(99.0, 0.0117373085223718)","(99.0, 0.6508353363719726)"


**5. Calculating the performance (MSE) of the trained latent factors model on the validation set.**

In [28]:
original_data = arr
assert len(validation_indices[0]) == len(validation_indices[1])
validation_size = len(validation_indices[0])
errors = []
for i in range(validation_size):
  user_id = validation_indices[0][i]
  joke_id = validation_indices[1][i]
  prediction_rating = predict_rating(user_id, joke_id)
  real_rating = original_data[user_id][joke_id]
  error = prediction_rating - real_rating
  errors.append(error)
mse = (np.array(errors) ** 2).mean()
print('Validation MSE:', mse)

Validation MSE: 18.304230220838267


**6. Making predictions on the test set.**

In [0]:
filled_data = np.zeros(original_data.shape)
rows, columns = original_data.shape
for user_id in range(rows):
  for joke_id in range(columns):
    if original_data[user_id][joke_id] == 99:
      filled_data[user_id][joke_id] = predict_rating(user_id, joke_id)
    else:
      filled_data[user_id][joke_id] = original_data[user_id][joke_id]

In [0]:
filled_data

array([[-7.82      ,  8.79      , -9.66      , ..., -1.93487087,
        -4.94471835, -1.70753866],
       [ 4.08      , -0.29      ,  6.36      , ...,  0.34      ,
        -4.32      ,  1.07      ],
       [ 5.82824374,  6.16250978,  6.14356943, ...,  6.38359923,
         5.68083207,  7.27976696],
       ...,
       [-0.23012087, -0.26492064, -0.32275439, ..., -0.16650137,
        -0.32520792, -0.16246711],
       [ 0.27090079,  0.23242422,  0.08506485, ...,  0.51061821,
         0.01173731,  0.65083534],
       [ 2.43      ,  2.67      , -3.98      , ...,  3.4894801 ,
         2.32802986,  4.09969455]])

**7. Finding the best and worst rated jokes.**

In [0]:
means = np.mean(filled_data, axis=0)

In [0]:
max_joke = -9999999999
max_joke_index = -1
min_joke = 99999999999
min_joke_index = -1

for i in range(len(means)):
  if means[i] > max_joke:
    max_joke = means[i]
    max_joke_index = i
  elif means[i] < min_joke:
    min_joke = means[i]
    min_joke_index = i

print(f'The best-rated joke was joke {max_joke_index + 1}, with a mean rating of: {max_joke}')
print(f'The worst-rated joke was joke {min_joke_index + 1}, with a mean rating of: {min_joke}')

The best-rated joke was joke 50, with a mean rating of: 3.6655005509779905
The worst-rated joke was joke 58, with a mean rating of: -3.343396876647032
