<a href="https://colab.research.google.com/github/rossrco/experiments/blob/recommenders/recommenders/collaborative_filtering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pyspark

Collecting pyspark
[?25l  Downloading https://files.pythonhosted.org/packages/45/b0/9d6860891ab14a39d4bddf80ba26ce51c2f9dc4805e5c6978ac0472c120a/pyspark-3.1.1.tar.gz (212.3MB)
[K     |████████████████████████████████| 212.3MB 70kB/s 
[?25hCollecting py4j==0.10.9
[?25l  Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)
[K     |████████████████████████████████| 204kB 38.1MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.1.1-py2.py3-none-any.whl size=212767604 sha256=6ae31e98754cd081333641caeeeb2e7708ceab09d40826c007fbfa96af9a89f1
  Stored in directory: /root/.cache/pip/wheels/0b/90/c0/01de724414ef122bd05f056541fb6a0ecf47c7ca655f8b3c0f
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9 pyspark-3.1.1


In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import types as T
from pyspark.sql import functions as F
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RankingEvaluator

# Start (or reuse) a Spark session running locally; 'local[*]' is the master URL.
spark = (
    SparkSession.builder
    .master('local[*]')
    .getOrCreate()
)

In [3]:
def download_dataset():
  """Download the MovieLens 100k archive and unpack it into the working directory.

  The archive is saved as ``movielens.zip``, every member is extracted
  relative to the current directory, and the raw contents of the bundled
  ``ml-100k/u.info`` summary file are printed.
  """
  print('Downloading movielens data...')
  from urllib.request import urlretrieve
  import zipfile

  url = 'http://files.grouplens.org/datasets/movielens/ml-100k.zip'
  dest_file = 'movielens.zip'

  urlretrieve(url, dest_file)
  # Open the archive in a context manager so its file handle is closed
  # deterministically, even if extraction or the read below raises.
  with zipfile.ZipFile(dest_file, 'r') as zip_ref:
    zip_ref.extractall()
    print('Done. Dataset contains:')
    print(zip_ref.read('ml-100k/u.info'))


def read_ratings():
  """Load the MovieLens ``u.data`` ratings file into a Spark DataFrame.

  The tab-separated, header-less file is parsed with an explicit schema and
  the ``unix_timestamp`` column is replaced by the result of
  ``F.to_timestamp``. The number of ingested rows is printed.

  Returns:
    DataFrame with the columns ``user_id``, ``movie_id``, ``rating`` and
    ``unix_timestamp``.
  """
  print('Reading the ratings file...')

  # (column name, Spark type, nullable) for every field, in file order.
  column_specs = [
      ('user_id', T.DoubleType(), False),
      ('movie_id', T.DoubleType(), True),
      ('rating', T.DoubleType(), True),
      ('unix_timestamp', T.LongType(), True),
  ]
  ratings_schema = T.StructType(
      [T.StructField(name, dtype, nullable)
       for name, dtype, nullable in column_specs])

  raw_ratings = spark.read.load(
      'ml-100k/u.data', format='csv', sep='\t',
      header='false', schema=ratings_schema)
  ratings = raw_ratings.withColumn(
      'unix_timestamp', F.to_timestamp(F.col('unix_timestamp')))

  print(f'Ingested {ratings.count()} ratings.')
  return ratings

In [4]:
# Fetch and unpack the dataset first: read_ratings() loads 'ml-100k/u.data',
# which download_dataset() extracts into the working directory.
download_dataset()

ratings = read_ratings()

Downloading movielens data...
Done. Dataset contains:
b'943 users\n1682 items\n100000 ratings\n'
Reading the ratings file...
Ingested 100000 ratings.


In [5]:
# Preview the parsed ratings (show() prints the first 20 rows by default).
ratings.show()

+-------+--------+------+-------------------+
|user_id|movie_id|rating|     unix_timestamp|
+-------+--------+------+-------------------+
|  196.0|   242.0|   3.0|1997-12-04 15:55:49|
|  186.0|   302.0|   3.0|1998-04-04 19:22:22|
|   22.0|   377.0|   1.0|1997-11-07 07:18:36|
|  244.0|    51.0|   2.0|1997-11-27 05:02:03|
|  166.0|   346.0|   1.0|1998-02-02 05:33:16|
|  298.0|   474.0|   4.0|1998-01-07 14:20:06|
|  115.0|   265.0|   2.0|1997-12-03 17:51:28|
|  253.0|   465.0|   5.0|1998-04-03 18:34:27|
|  305.0|   451.0|   3.0|1998-02-01 09:20:17|
|    6.0|    86.0|   3.0|1997-12-31 21:16:53|
|   62.0|   257.0|   2.0|1997-11-12 22:07:14|
|  286.0|  1014.0|   5.0|1997-11-17 15:38:45|
|  200.0|   222.0|   5.0|1997-10-05 09:05:40|
|  210.0|    40.0|   3.0|1998-03-27 21:59:54|
|  224.0|    29.0|   3.0|1998-02-21 23:40:57|
|  303.0|   785.0|   3.0|1997-11-14 05:28:38|
|  122.0|   387.0|   5.0|1997-11-11 17:47:39|
|  194.0|   274.0|   2.0|1997-11-14 20:36:34|
|  291.0|  1042.0|   4.0|1997-09-2

In [6]:
# Seeded 80/20 random split of the ratings into training and held-out sets.
train, test = ratings.randomSplit([0.8, 0.2], seed=42)

In [8]:
# Fit an ALS collaborative-filtering model on the training ratings.
# The seed is fixed so the random factor initialisation - and therefore the
# fitted model - can be reproduced after a kernel restart.
model = ALS(userCol='user_id', itemCol='movie_id', ratingCol='rating',
            seed=42).fit(train)

In [9]:
# Number of recommendations to generate per user.
k = 3
# Top-k recommended movies for the users that appear in the test split.
test_recomm = model.recommendForUserSubset(dataset=test, numItems=k)
# truncate=False prints the full recommendation arrays instead of clipping them.
test_recomm.show(truncate=False)

+-------+---------------------------------------------------------+
|user_id|recommendations                                          |
+-------+---------------------------------------------------------+
|471    |[{266, 5.074799}, {680, 5.0551834}, {309, 4.8210654}]    |
|463    |[{1591, 4.727555}, {1062, 4.349219}, {1449, 4.2445273}]  |
|833    |[{1368, 5.0188975}, {1643, 4.883041}, {1597, 4.7060757}] |
|496    |[{1591, 5.084086}, {253, 4.5795393}, {475, 4.501776}]    |
|148    |[{1129, 5.4802837}, {408, 5.0231385}, {114, 4.9768257}]  |
|540    |[{1449, 4.997586}, {1643, 4.744245}, {1398, 4.713765}]   |
|392    |[{1643, 5.7083597}, {119, 5.063157}, {1449, 5.056046}]   |
|243    |[{1449, 4.605709}, {1398, 4.428873}, {1643, 4.3661985}]  |
|623    |[{1643, 4.760572}, {694, 4.6491685}, {496, 4.5237613}]   |
|737    |[{1591, 5.588254}, {1512, 5.0285597}, {1449, 5.002753}]  |
|897    |[{1643, 5.496185}, {1169, 4.922133}, {313, 4.858496}]    |
|858    |[{853, 4.5270243}, {1473, 4.4608526}, {