In [1]:
import pyspark.pandas as pspd
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, collect_list, udf
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, FloatType



In [2]:
# Create (or, on re-run, reuse) the SparkSession for this notebook.
# NOTE(review): if this runs in local mode, executors live inside the driver JVM,
# so spark.executor.memory likely has no effect and spark.driver.memory may be
# what was intended — confirm against the deployment.
spark = (
    SparkSession.builder
    .appName('Collaborative Filtering')
    .config("spark.executor.memory", "6g")
    .getOrCreate()
)

In [3]:
# Underlying SparkContext handle (used by the RDD-based cells below, which are
# currently commented out).
sc = spark.sparkContext

In [4]:
# Explicit schema for ratings.csv so the columns get numeric types on read.
# NOTE: IntegerType is 32-bit signed; assuming `timestamp` holds Unix seconds
# (values like 964982703 suggest so), dates after 2038-01-19 would not fit and
# LongType would be needed.
schema = StructType([
    StructField(field_name, field_type, nullable=True)
    for field_name, field_type in (
        ('userId', IntegerType()),
        ('movieId', IntegerType()),
        ('rating', FloatType()),
        ('timestamp', IntegerType()),
    )
])

In [5]:
import pandas as pd

# Toy pandas example: two rows that share the same key in column 'a'.
df = pd.DataFrame({'a': [10, 10], 'b': [1, 0]})
df

Unnamed: 0,a,b
0,10,1
1,10,0


In [6]:
# Collapse each group of 'a' into a Python list of its 'b' values, keeping 'a'
# as a regular column rather than the index.
df = df.groupby(by='a', as_index=False)[['b']].agg(list)
df

Unnamed: 0,a,b
0,10,"[1, 0]"


In [7]:
# Superseded RDD approach: load ratings.csv as raw text lines.
# Kept for reference only; the DataFrame reader with an explicit schema is used instead.
# rdd = sc.textFile('ratings.csv')
# rdd.collect()

In [8]:
# Superseded RDD approach: drop the CSV header by filtering out every line equal
# to the first one. NOTE: this would also drop any data line identical to the header.
# headers = rdd.first()
# rdd = rdd.filter(lambda x: x != headers)
# rdd.collect()

In [9]:
%%time
rdf = spark.read.options(header= True,).schema(schema=schema).csv('ratings.csv')
rdf

CPU times: total: 0 ns
Wall time: 1.45 s


DataFrame[userId: int, movieId: int, rating: float, timestamp: int]

In [10]:
# Preview the parsed ratings (show() prints the first 20 rows by default).
rdf.show()

+------+-------+------+---------+
|userId|movieId|rating|timestamp|
+------+-------+------+---------+
|     1|      1|   4.0|964982703|
|     1|      3|   4.0|964981247|
|     1|      6|   4.0|964982224|
|     1|     47|   5.0|964983815|
|     1|     50|   5.0|964982931|
|     1|     70|   3.0|964982400|
|     1|    101|   5.0|964980868|
|     1|    110|   4.0|964982176|
|     1|    151|   5.0|964984041|
|     1|    157|   5.0|964984100|
|     1|    163|   5.0|964983650|
|     1|    216|   5.0|964981208|
|     1|    223|   3.0|964980985|
|     1|    231|   5.0|964981179|
|     1|    235|   4.0|964980908|
|     1|    260|   5.0|964981680|
|     1|    296|   3.0|964982967|
|     1|    316|   3.0|964982310|
|     1|    333|   5.0|964981179|
|     1|    349|   4.0|964982563|
+------+-------+------+---------+
only showing top 20 rows



In [11]:
# Confirm the explicit schema was applied (column names, types, nullability).
rdf.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: float (nullable = true)
 |-- timestamp: integer (nullable = true)



In [12]:
# Number of partitions backing the DataFrame; with a single partition the scan
# and narrow transformations run as one task unless the data is repartitioned.
rdf.rdd.getNumPartitions()

1

In [13]:
# Total number of rating rows.
rdf.count()

100836

In [14]:
# Project the user, movie and rating columns by name.
rdf.select(['userId', 'movieId', 'rating']).show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|      1|   4.0|
|     1|      3|   4.0|
|     1|      6|   4.0|
|     1|     47|   5.0|
|     1|     50|   5.0|
|     1|     70|   3.0|
|     1|    101|   5.0|
|     1|    110|   4.0|
|     1|    151|   5.0|
|     1|    157|   5.0|
|     1|    163|   5.0|
|     1|    216|   5.0|
|     1|    223|   3.0|
|     1|    231|   5.0|
|     1|    235|   4.0|
|     1|    260|   5.0|
|     1|    296|   3.0|
|     1|    316|   3.0|
|     1|    333|   5.0|
|     1|    349|   4.0|
+------+-------+------+
only showing top 20 rows



In [15]:
# Same projection, this time using column references bound to rdf.
rdf.select(rdf['userId'], rdf['movieId'], rdf['rating']).show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|      1|   4.0|
|     1|      3|   4.0|
|     1|      6|   4.0|
|     1|     47|   5.0|
|     1|     50|   5.0|
|     1|     70|   3.0|
|     1|    101|   5.0|
|     1|    110|   4.0|
|     1|    151|   5.0|
|     1|    157|   5.0|
|     1|    163|   5.0|
|     1|    216|   5.0|
|     1|    223|   3.0|
|     1|    231|   5.0|
|     1|    235|   4.0|
|     1|    260|   5.0|
|     1|    296|   3.0|
|     1|    316|   3.0|
|     1|    333|   5.0|
|     1|    349|   4.0|
+------+-------+------+
only showing top 20 rows



In [16]:
def normalize(rdf, column, suffix='_new'):
    """Mean-center a numeric column of a Spark DataFrame.

    Despite the name, this only subtracts the column mean; it does not scale.

    Args:
        rdf: Spark DataFrame that contains ``column``.
        column: name of the numeric column to center.
        suffix: appended to ``column`` to name the output column
            (default ``'_new'``, so ``rating`` becomes ``rating_new``).

    Returns:
        A new DataFrame with an extra column ``column + suffix`` holding
        ``column - mean(column)``. The input DataFrame is not modified.
    """
    # A global agg() always yields exactly one row; first() pulls it to the driver.
    column_mean = rdf.agg(avg(col(column))).first()[0]
    return rdf.withColumn(column + suffix, col(column) - column_mean)

In [17]:
# pandas-on-Spark view of the ratings before mean-centering.
# NOTE(review): this rebinds `df`, replacing the toy pandas DataFrame from the
# earlier cells with an unrelated frame.
df = rdf.pandas_api()
df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,4.0,964982703
1,1,3,4.0,964981247
2,1,6,4.0,964982224
3,1,47,5.0,964983815
4,1,50,5.0,964982931


In [18]:
# Add rating_new = rating - mean(rating) and continue with the widened frame.
rdf = normalize(rdf, column='rating')
rdf.show()

+------+-------+------+---------+------------------+
|userId|movieId|rating|timestamp|        rating_new|
+------+-------+------+---------+------------------+
|     1|      1|   4.0|964982703| 0.498443016383038|
|     1|      3|   4.0|964981247| 0.498443016383038|
|     1|      6|   4.0|964982224| 0.498443016383038|
|     1|     47|   5.0|964983815| 1.498443016383038|
|     1|     50|   5.0|964982931| 1.498443016383038|
|     1|     70|   3.0|964982400|-0.501556983616962|
|     1|    101|   5.0|964980868| 1.498443016383038|
|     1|    110|   4.0|964982176| 0.498443016383038|
|     1|    151|   5.0|964984041| 1.498443016383038|
|     1|    157|   5.0|964984100| 1.498443016383038|
|     1|    163|   5.0|964983650| 1.498443016383038|
|     1|    216|   5.0|964981208| 1.498443016383038|
|     1|    223|   3.0|964980985|-0.501556983616962|
|     1|    231|   5.0|964981179| 1.498443016383038|
|     1|    235|   4.0|964980908| 0.498443016383038|
|     1|    260|   5.0|964981680| 1.4984430163

In [19]:
# Same mean-centering as normalize(), expressed with the pandas-on-Spark API.
df['rating_new'] = df['rating'].sub(df['rating'].mean())
df.head()

Unnamed: 0,userId,movieId,rating,timestamp,rating_new
0,1,1,4.0,964982703,0.498443
1,1,3,4.0,964981247,0.498443
2,1,6,4.0,964982224,0.498443
3,1,47,5.0,964983815,1.498443
4,1,50,5.0,964982931,1.498443


In [20]:
# Convert back to a Spark DataFrame; as the output shows, the pandas-on-Spark
# default index is not carried over as a column.
sdf = df.to_spark()
sdf.show()

+------+-------+------+---------+------------------+
|userId|movieId|rating|timestamp|        rating_new|
+------+-------+------+---------+------------------+
|     1|      1|   4.0|964982703| 0.498443016383038|
|     1|      3|   4.0|964981247| 0.498443016383038|
|     1|      6|   4.0|964982224| 0.498443016383038|
|     1|     47|   5.0|964983815| 1.498443016383038|
|     1|     50|   5.0|964982931| 1.498443016383038|
|     1|     70|   3.0|964982400|-0.501556983616962|
|     1|    101|   5.0|964980868| 1.498443016383038|
|     1|    110|   4.0|964982176| 0.498443016383038|
|     1|    151|   5.0|964984041| 1.498443016383038|
|     1|    157|   5.0|964984100| 1.498443016383038|
|     1|    163|   5.0|964983650| 1.498443016383038|
|     1|    216|   5.0|964981208| 1.498443016383038|
|     1|    223|   3.0|964980985|-0.501556983616962|
|     1|    231|   5.0|964981179| 1.498443016383038|
|     1|    235|   4.0|964980908| 0.498443016383038|
|     1|    260|   5.0|964981680| 1.4984430163



In [21]:
# Ratings of at least 4 that are also above the global mean rating.
rdf.filter((col('rating') >= 4) & (col('rating_new') > 0)).show()

+------+-------+------+---------+-----------------+
|userId|movieId|rating|timestamp|       rating_new|
+------+-------+------+---------+-----------------+
|     1|      1|   4.0|964982703|0.498443016383038|
|     1|      3|   4.0|964981247|0.498443016383038|
|     1|      6|   4.0|964982224|0.498443016383038|
|     1|     47|   5.0|964983815|1.498443016383038|
|     1|     50|   5.0|964982931|1.498443016383038|
|     1|    101|   5.0|964980868|1.498443016383038|
|     1|    110|   4.0|964982176|0.498443016383038|
|     1|    151|   5.0|964984041|1.498443016383038|
|     1|    157|   5.0|964984100|1.498443016383038|
|     1|    163|   5.0|964983650|1.498443016383038|
|     1|    216|   5.0|964981208|1.498443016383038|
|     1|    231|   5.0|964981179|1.498443016383038|
|     1|    235|   4.0|964980908|0.498443016383038|
|     1|    260|   5.0|964981680|1.498443016383038|
|     1|    333|   5.0|964981179|1.498443016383038|
|     1|    349|   4.0|964982563|0.498443016383038|
|     1|    

In [22]:
# First rows with a rating of at least 4, using a pandas-style boolean mask.
df[df.rating >= 4].head(n=5)

Unnamed: 0,userId,movieId,rating,timestamp,rating_new
0,1,1,4.0,964982703,0.498443
1,1,3,4.0,964981247,0.498443
2,1,6,4.0,964982224,0.498443
3,1,47,5.0,964983815,1.498443
4,1,50,5.0,964982931,1.498443


In [23]:
# Distinct rating values present in the data; the rows are not sorted.
rdf.select('rating').distinct().show()

+------+
|rating|
+------+
|   5.0|
|   2.5|
|   2.0|
|   3.0|
|   1.5|
|   0.5|
|   3.5|
|   1.0|
|   4.5|
|   4.0|
+------+



In [24]:
# Number of distinct users.
rdf.select('userId').distinct().count()

610

In [25]:
# Row count after removing duplicate (userId, movieId) pairs; compare with the
# total row count to check that each user rated each movie at most once.
rdf.dropDuplicates(['userId', 'movieId']).count()

100836

In [26]:
# Total row count (same query as the earlier count cell). It matches the
# deduplicated count above, so there are no duplicate (userId, movieId) pairs.
rdf.count()

100836

In [27]:
# Per-user lists of rated movies and their ratings. Explicit aliases avoid
# depending on Spark's auto-generated names such as 'collect_list(movieId)';
# withColumnRenamed silently does nothing when such a name doesn't match.
# NOTE(review): both lists come from the same aggregation and are expected to
# stay index-aligned, but collect_list order is not guaranteed — confirm before
# pairing movieId[i] with rating[i].
group = rdf.groupBy('userId').agg(
    collect_list('movieId').alias('movieId'),
    collect_list('rating').alias('rating'),
)

def get_avg(x):
    """Return the arithmetic mean of a sequence of numbers as a float.

    Used as the body of a FloatType Spark UDF over the per-user rating lists.

    Args:
        x: sequence of numbers, e.g. one row's ``rating`` array. May be ``None``
            (null array) and may contain ``None`` entries, which are ignored.

    Returns:
        The mean of the non-None values, or ``0.0`` when there are none.
        The fallback must be a float: PySpark converts a Python ``int``
        returned from a ``FloatType`` UDF into null.
    """
    if x is None:
        return 0.0
    values = [v for v in x if v is not None]
    if not values:
        return 0.0
    return sum(values) / len(values)


# Wrap get_avg as a FloatType Spark UDF and attach each user's mean rating.
# NOTE(review): a Python UDF ships every row to a Python worker; the built-in
# avg() aggregate would avoid that serialization overhead.
_udf = udf(lambda ratings: get_avg(ratings), returnType=FloatType())
group = group.withColumn('avg_rating', _udf(col('rating')))
group.show()

+------+--------------------+--------------------+----------+
|userId|             movieId|              rating|avg_rating|
+------+--------------------+--------------------+----------+
|     1|[1, 3, 6, 47, 50,...|[4.0, 4.0, 4.0, 5...| 4.3663793|
|     2|[318, 333, 1704, ...|[3.0, 4.0, 4.5, 4...| 3.9482758|
|     3|[31, 527, 647, 68...|[0.5, 0.5, 0.5, 0...| 2.4358974|
|     4|[21, 32, 45, 47, ...|[3.0, 2.0, 3.0, 2...| 3.5555556|
|     5|[1, 21, 34, 36, 3...|[4.0, 4.0, 4.0, 4...| 3.6363637|
|     6|[2, 3, 4, 5, 6, 7...|[4.0, 5.0, 3.0, 5...| 3.4936306|
|     7|[1, 50, 58, 150, ...|[4.5, 4.5, 3.0, 4...| 3.2302632|
|     8|[2, 10, 11, 21, 3...|[4.0, 2.0, 4.0, 4...| 3.5744681|
|     9|[41, 187, 223, 37...|[3.0, 3.0, 4.0, 3...| 3.2608695|
|    10|[296, 356, 588, 5...|[1.0, 3.5, 4.0, 3...| 3.2785714|
|    11|[6, 10, 36, 44, 9...|[5.0, 3.0, 4.0, 2...|   3.78125|
|    12|[39, 168, 222, 25...|[4.0, 5.0, 5.0, 5...|  4.390625|
|    13|[47, 305, 597, 11...|[5.0, 1.0, 3.0, 3...| 3.6451614|
|    14|

In [28]:
# Convert the per-user aggregate to pandas-on-Spark for pandas-style display.
# The isinstance guard makes the cell safe to re-run: without it, a second run
# would call .pandas_api() on an already-converted pandas-on-Spark DataFrame,
# which does not provide that method.
if not isinstance(group, pspd.DataFrame):
    group = group.pandas_api()
group.head()

Unnamed: 0,userId,movieId,rating,avg_rating
0,1,"[1, 3, 6, 47, 50, 70, 101, 110, 151, 157, 163,...","[4.0, 4.0, 4.0, 5.0, 5.0, 3.0, 5.0, 4.0, 5.0, ...",4.366379
1,2,"[318, 333, 1704, 3578, 6874, 8798, 46970, 4851...","[3.0, 4.0, 4.5, 4.0, 4.0, 3.5, 4.0, 4.0, 4.5, ...",3.948276
2,3,"[31, 527, 647, 688, 720, 849, 914, 1093, 1124,...","[0.5, 0.5, 0.5, 0.5, 0.5, 5.0, 0.5, 0.5, 0.5, ...",2.435897
3,4,"[21, 32, 45, 47, 52, 58, 106, 125, 126, 162, 1...","[3.0, 2.0, 3.0, 2.0, 3.0, 3.0, 4.0, 5.0, 1.0, ...",3.555556
4,5,"[1, 21, 34, 36, 39, 50, 58, 110, 150, 153, 232...","[4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 5.0, 4.0, 3.0, ...",3.636364


In [29]:
# Convert the pandas-on-Spark frame back to a Spark DataFrame.
# NOTE(review): despite its name, `rdd_new` is a Spark DataFrame, not an RDD.
rdd_new = group.to_spark()
rdd_new.show()



+------+--------------------+--------------------+----------+
|userId|             movieId|              rating|avg_rating|
+------+--------------------+--------------------+----------+
|     1|[1, 3, 6, 47, 50,...|[4.0, 4.0, 4.0, 5...| 4.3663793|
|     2|[318, 333, 1704, ...|[3.0, 4.0, 4.5, 4...| 3.9482758|
|     3|[31, 527, 647, 68...|[0.5, 0.5, 0.5, 0...| 2.4358974|
|     4|[21, 32, 45, 47, ...|[3.0, 2.0, 3.0, 2...| 3.5555556|
|     5|[1, 21, 34, 36, 3...|[4.0, 4.0, 4.0, 4...| 3.6363637|
|     6|[2, 3, 4, 5, 6, 7...|[4.0, 5.0, 3.0, 5...| 3.4936306|
|     7|[1, 50, 58, 150, ...|[4.5, 4.5, 3.0, 4...| 3.2302632|
|     8|[2, 10, 11, 21, 3...|[4.0, 2.0, 4.0, 4...| 3.5744681|
|     9|[41, 187, 223, 37...|[3.0, 3.0, 4.0, 3...| 3.2608695|
|    10|[296, 356, 588, 5...|[1.0, 3.5, 4.0, 3...| 3.2785714|
|    11|[6, 10, 36, 44, 9...|[5.0, 3.0, 4.0, 2...|   3.78125|
|    12|[39, 168, 222, 25...|[4.0, 5.0, 5.0, 5...|  4.390625|
|    13|[47, 305, 597, 11...|[5.0, 1.0, 3.0, 3...| 3.6451614|
|    14|

In [30]:
# Mark an aliased view of rdf for caching; cache() is lazy and returns the same
# DataFrame, so nothing is stored until an action runs.
rdf_cached = rdf.alias('rdf_cached').cache()
rdf_cached

DataFrame[userId: int, movieId: int, rating: float, timestamp: int, rating_new: double]

In [31]:
# First action after cache(): cached partitions are populated when an action
# like this one first computes them.
rdf_cached.show()

+------+-------+------+---------+------------------+
|userId|movieId|rating|timestamp|        rating_new|
+------+-------+------+---------+------------------+
|     1|      1|   4.0|964982703| 0.498443016383038|
|     1|      3|   4.0|964981247| 0.498443016383038|
|     1|      6|   4.0|964982224| 0.498443016383038|
|     1|     47|   5.0|964983815| 1.498443016383038|
|     1|     50|   5.0|964982931| 1.498443016383038|
|     1|     70|   3.0|964982400|-0.501556983616962|
|     1|    101|   5.0|964980868| 1.498443016383038|
|     1|    110|   4.0|964982176| 0.498443016383038|
|     1|    151|   5.0|964984041| 1.498443016383038|
|     1|    157|   5.0|964984100| 1.498443016383038|
|     1|    163|   5.0|964983650| 1.498443016383038|
|     1|    216|   5.0|964981208| 1.498443016383038|
|     1|    223|   3.0|964980985|-0.501556983616962|
|     1|    231|   5.0|964981179| 1.498443016383038|
|     1|    235|   4.0|964980908| 0.498443016383038|
|     1|    260|   5.0|964981680| 1.4984430163