In [1]:
import pyspark.pandas as pspd
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, FloatType



In [2]:
# Spark configuration requesting 6 GB of memory per executor.
conf = SparkConf().set("spark.executor.memory", "6g")

In [3]:
# Call getOrCreate on the class, not on a fresh instance: SparkContext() would
# first build a context with default settings, so `conf` would never be applied
# (and it raises ValueError if a context is already running, e.g. on re-run).
sc = SparkContext.getOrCreate(conf=conf)

In [4]:
# Obtain the SparkSession (entry point for the DataFrame API), naming the app.
spark = (
    SparkSession.builder
    .appName('Collaborative Filtering')
    .getOrCreate()
)

In [5]:
# Explicit schema used when reading the ratings file: (name, Spark type) pairs,
# every column nullable.
_rating_columns = [
    ('userId', IntegerType()),
    ('movieId', IntegerType()),
    ('rating', FloatType()),
    ('timestamp', IntegerType()),
]
schema = StructType([
    StructField(name, dtype, nullable=True) for name, dtype in _rating_columns
])

In [6]:
# rdd = sc.textFile('ratings.csv')
# rdd.collect()

In [7]:
# headers = rdd.first()
# rdd = rdd.filter(lambda x: x != headers)
# rdd.collect()

In [8]:
%%time
# Load the ratings with the explicit schema; the first line of the file is a header.
rdf = (
    spark.read
    .options(header=True)
    .schema(schema=schema)
    .csv('ratings.csv')
)
rdf

CPU times: total: 0 ns
Wall time: 1.42 s


DataFrame[userId: int, movieId: int, rating: float, timestamp: int]

In [9]:
# Preview the loaded ratings (show() prints the first 20 rows by default).
rdf.show()

+------+-------+------+---------+
|userId|movieId|rating|timestamp|
+------+-------+------+---------+
|     1|      1|   4.0|964982703|
|     1|      3|   4.0|964981247|
|     1|      6|   4.0|964982224|
|     1|     47|   5.0|964983815|
|     1|     50|   5.0|964982931|
|     1|     70|   3.0|964982400|
|     1|    101|   5.0|964980868|
|     1|    110|   4.0|964982176|
|     1|    151|   5.0|964984041|
|     1|    157|   5.0|964984100|
|     1|    163|   5.0|964983650|
|     1|    216|   5.0|964981208|
|     1|    223|   3.0|964980985|
|     1|    231|   5.0|964981179|
|     1|    235|   4.0|964980908|
|     1|    260|   5.0|964981680|
|     1|    296|   3.0|964982967|
|     1|    316|   3.0|964982310|
|     1|    333|   5.0|964981179|
|     1|    349|   4.0|964982563|
+------+-------+------+---------+
only showing top 20 rows



In [10]:
# Confirm the column names, types and nullability applied on read.
rdf.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: float (nullable = true)
 |-- timestamp: integer (nullable = true)



In [11]:
# Number of partitions of the RDD underlying the DataFrame.
rdf.rdd.getNumPartitions()

1

In [12]:
# Total number of rating rows.
rdf.count()

100836

In [13]:
# Project three columns by string name, leaving timestamp out of the preview.
rdf.select('userId', 'movieId', 'rating').show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|      1|   4.0|
|     1|      3|   4.0|
|     1|      6|   4.0|
|     1|     47|   5.0|
|     1|     50|   5.0|
|     1|     70|   3.0|
|     1|    101|   5.0|
|     1|    110|   4.0|
|     1|    151|   5.0|
|     1|    157|   5.0|
|     1|    163|   5.0|
|     1|    216|   5.0|
|     1|    223|   3.0|
|     1|    231|   5.0|
|     1|    235|   4.0|
|     1|    260|   5.0|
|     1|    296|   3.0|
|     1|    316|   3.0|
|     1|    333|   5.0|
|     1|    349|   4.0|
+------+-------+------+
only showing top 20 rows



In [14]:
# Project userId, movieId and rating using DataFrame attribute access
# instead of string column names.
rdf.select(rdf.userId, rdf.movieId, rdf.rating).show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|      1|   4.0|
|     1|      3|   4.0|
|     1|      6|   4.0|
|     1|     47|   5.0|
|     1|     50|   5.0|
|     1|     70|   3.0|
|     1|    101|   5.0|
|     1|    110|   4.0|
|     1|    151|   5.0|
|     1|    157|   5.0|
|     1|    163|   5.0|
|     1|    216|   5.0|
|     1|    223|   3.0|
|     1|    231|   5.0|
|     1|    235|   4.0|
|     1|    260|   5.0|
|     1|    296|   3.0|
|     1|    316|   3.0|
|     1|    333|   5.0|
|     1|    349|   4.0|
+------+-------+------+
only showing top 20 rows



In [15]:
def normalize(rdf, column):
    """Return ``rdf`` with an extra mean-centred copy of ``column``.

    The average of ``column`` over the whole DataFrame is computed with a
    Spark aggregation and brought back to the driver, then subtracted from
    every row. The result is stored in a column named ``<column>_new``.

    Args:
        rdf: Spark DataFrame that contains ``column``.
        column: Name of the column to centre.

    Returns:
        A DataFrame with the additional ``<column>_new`` column.
    """
    # An aggregation without grouping yields exactly one row holding the mean.
    column_mean = rdf.agg(avg(col(column))).first()[0]
    return rdf.withColumn(column + '_new', col(column) - column_mean)

In [16]:
# Pandas-on-Spark view of the same ratings, for pandas-style syntax.
df = rdf.pandas_api()
df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,4.0,964982703
1,1,3,4.0,964981247
2,1,6,4.0,964982224
3,1,47,5.0,964983815
4,1,50,5.0,964982931


In [17]:
# Add a mean-centred 'rating_new' column; the original 'rating' column is kept.
rdf = normalize(rdf, 'rating')
rdf.show()

+------+-------+------+---------+------------------+
|userId|movieId|rating|timestamp|        rating_new|
+------+-------+------+---------+------------------+
|     1|      1|   4.0|964982703| 0.498443016383038|
|     1|      3|   4.0|964981247| 0.498443016383038|
|     1|      6|   4.0|964982224| 0.498443016383038|
|     1|     47|   5.0|964983815| 1.498443016383038|
|     1|     50|   5.0|964982931| 1.498443016383038|
|     1|     70|   3.0|964982400|-0.501556983616962|
|     1|    101|   5.0|964980868| 1.498443016383038|
|     1|    110|   4.0|964982176| 0.498443016383038|
|     1|    151|   5.0|964984041| 1.498443016383038|
|     1|    157|   5.0|964984100| 1.498443016383038|
|     1|    163|   5.0|964983650| 1.498443016383038|
|     1|    216|   5.0|964981208| 1.498443016383038|
|     1|    223|   3.0|964980985|-0.501556983616962|
|     1|    231|   5.0|964981179| 1.498443016383038|
|     1|    235|   4.0|964980908| 0.498443016383038|
|     1|    260|   5.0|964981680| 1.4984430163

In [18]:
# Centre the ratings on their overall mean, overwriting the 'rating' column.
rating_mean = df['rating'].mean()
df['rating'] = df['rating'] - rating_mean
df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,0.498443,964982703
1,1,3,0.498443,964981247
2,1,6,0.498443,964982224
3,1,47,1.498443,964983815
4,1,50,1.498443,964982931


In [19]:
# Convert the pandas-on-Spark frame back to a plain Spark DataFrame and preview it.
sdf = df.to_spark()
sdf.show()

+------+-------+------------------+---------+
|userId|movieId|            rating|timestamp|
+------+-------+------------------+---------+
|     1|      1| 0.498443016383038|964982703|
|     1|      3| 0.498443016383038|964981247|
|     1|      6| 0.498443016383038|964982224|
|     1|     47| 1.498443016383038|964983815|
|     1|     50| 1.498443016383038|964982931|
|     1|     70|-0.501556983616962|964982400|
|     1|    101| 1.498443016383038|964980868|
|     1|    110| 0.498443016383038|964982176|
|     1|    151| 1.498443016383038|964984041|
|     1|    157| 1.498443016383038|964984100|
|     1|    163| 1.498443016383038|964983650|
|     1|    216| 1.498443016383038|964981208|
|     1|    223|-0.501556983616962|964980985|
|     1|    231| 1.498443016383038|964981179|
|     1|    235| 0.498443016383038|964980908|
|     1|    260| 1.498443016383038|964981680|
|     1|    296|-0.501556983616962|964982967|
|     1|    316|-0.501556983616962|964982310|
|     1|    333| 1.498443016383038

