<a href="https://colab.research.google.com/github/yjs1210/movie-recommendations/blob/master/recommendationSystemBaseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive inside the Colab VM; the ratings CSV is read from under this mount point.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://www-us.apache.org/dist/spark/spark-2.4.4/spark-2.4.4-bin-hadoop2.7.tgz
!tar xf spark-2.4.4-bin-hadoop2.7.tgz
!pip install -q findspark

In [0]:
import os

# Export the Java and Spark install locations for the tooling that starts Spark.
os.environ.update({
    "JAVA_HOME": "/usr/lib/jvm/java-8-openjdk-amd64",
    "SPARK_HOME": "/content/spark-2.4.4-bin-hadoop2.7",
})

In [0]:
# findspark.init() has to run before the first pyspark import, so keep this order.
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T

# Reuse an existing session if one is running; otherwise start one on the "local[*]" master.
spark = (
    SparkSession.builder
    .master("local[*]")
    .getOrCreate()
)

In [0]:
# Read the ratings CSV from the mounted Drive (first line is the header, column
# types are inferred from the data) and keep only 1000 rows for this notebook.
ratings = spark.read.csv(
    path='/content/drive/My Drive/RecommendationSystems/ml-20m/ratings.csv',
    header=True,
    inferSchema=True,
).limit(1000)
ratings.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- timestamp: integer (nullable = true)



In [0]:
class BaselineModel:
  """Baseline rating statistics: a global mean plus per-user and per-item offsets.

  The constructor only records column names. `train` computes the statistics
  from a Spark DataFrame, stores them on the instance and also returns them.
  """

  def __init__(self, user_column, item_column, ratings_column):
    """Remember which DataFrame columns hold user ids, item ids and ratings.

    Args:
      user_column: name of the user-id column.
      item_column: name of the item-id column.
      ratings_column: name of the rating column.
    """
    self.user_col = user_column
    self.item_col = item_column
    self.ratings_col = ratings_column

  def __find_avg_of_col(self, data, column):
    """Return the mean of `column` over `data`, collected to the driver."""
    return data.select(F.mean(F.col(column))).collect()[0][0]

  def __subtract_from_col(self, data, column, value):
    """Return `data` with an added 'normalized_<column>' column holding `column - value`."""
    return data.withColumn('normalized_' + column, F.col(column)-value)

  def __avg_by_key(self, data, key_column, value_column):
    """Return a dict mapping each distinct `key_column` value to the mean of `value_column`."""
    grouped = data.groupby(key_column).agg(F.avg(value_column))
    # The grouped frame has exactly two columns (key, average); collecting the rows
    # directly avoids converting the DataFrame to an RDD just to build the dict.
    return {row[0]: row[1] for row in grouped.collect()}

  def train(self, training_data):
    """Compute the global mean rating and the per-user / per-item offsets from it.

    Both offsets are averages of (rating - global mean): one grouped by user,
    one grouped by item. They are computed independently of each other.

    Side effects: sets `self.training_data` (the input with the extra
    'normalized_<ratings column>' column), `self.avg_rating`, `self.user_bias`
    and `self.item_bias`.

    Args:
      training_data: Spark DataFrame containing the user, item and rating columns
        named in the constructor.

    Returns:
      Tuple `(avg_rating, user_bias, item_bias)` where the two biases are dicts
      keyed by user id and item id respectively.
    """
    ratings_col = self.ratings_col
    normalized_col = 'normalized_' + ratings_col

    # Centre the ratings on the global mean, then average the residuals per key.
    avg_rating = self.__find_avg_of_col(training_data, ratings_col)
    norm_training_data = self.__subtract_from_col(training_data, ratings_col, avg_rating)
    user_bias = self.__avg_by_key(norm_training_data, self.user_col, normalized_col)
    item_bias = self.__avg_by_key(norm_training_data, self.item_col, normalized_col)

    self.training_data = norm_training_data
    self.avg_rating = avg_rating
    self.user_bias = user_bias
    self.item_bias = item_bias

    return avg_rating, user_bias, item_bias

In [0]:
# Arguments are, in order: user-id column, item-id column, rating column.
baselineModel = BaselineModel('userId', 'movieId', 'rating')

In [0]:
# Fit on the loaded ratings. The three results are kept as module-level names
# because udf_predict reads avg_rating, user_bias and item_bias as globals.
(avg_rating, user_bias, item_bias) = baselineModel.train(training_data=ratings)

In [0]:
def udf_predict(user_id, item_id):
  """Return the baseline prediction for one (user, item) pair.

  Reads the module-level `avg_rating`, `user_bias` and `item_bias` produced by
  training. The prediction is the global mean plus the user's offset plus the
  item's offset.

  A user or item that is missing from the trained dictionaries contributes a
  zero offset instead of raising KeyError, so a pair with unseen ids is still
  scored (a pair where both ids are unseen gets exactly the global mean).

  Args:
    user_id: key looked up in `user_bias`.
    item_id: key looked up in `item_bias`.

  Returns:
    The predicted rating as a number.
  """
  return avg_rating + user_bias.get(user_id, 0.0) + item_bias.get(item_id, 0.0)

In [0]:
def predict(test_data, user_column, item_column, ratings_column):
  """Return `test_data` with baseline predictions written to `ratings_column`.

  `udf_predict` is wrapped as a Spark UDF declared to return FloatType and is
  applied row by row to the user-id and item-id columns. The result is attached
  with `withColumn`, so a column already named `ratings_column` is replaced by
  the predictions.

  Args:
    test_data: Spark DataFrame containing `user_column` and `item_column`.
    user_column: name of the user-id column.
    item_column: name of the item-id column.
    ratings_column: name of the column that receives the predictions.

  Returns:
    A new DataFrame; `test_data` itself is not modified.
  """
  scorer = F.udf(udf_predict, T.FloatType())
  user_ids = F.col(user_column)
  item_ids = F.col(item_column)
  predicted = scorer(user_ids, item_ids)
  return test_data.withColumn(ratings_column, predicted)

In [0]:
# Score the same rows the model was trained on. The 'rating' column displayed here
# holds predictions, because predict() writes them into the column it is given.
predict(
    ratings,
    user_column='userId',
    item_column='movieId',
    ratings_column='rating',
).show()

+------+-------+---------+----------+
|userId|movieId|   rating| timestamp|
+------+-------+---------+----------+
|     1|      2| 3.264857|1112486027|
|     1|     29| 3.514857|1112484676|
|     1|     32| 3.389857|1112484819|
|     1|     47|4.2648573|1112484727|
|     1|     50|4.2648573|1112484580|
|     1|    112| 3.514857|1094785740|
|     1|    151| 3.514857|1094785734|
|     1|    223|4.5148573|1112485573|
|     1|    253| 4.181524|1112484940|
|     1|    260|4.6398573|1112484826|
|     1|    293|4.0148573|1112484703|
|     1|    296| 4.181524|1112484767|
|     1|    318|4.7648573|1112484798|
|     1|    337| 3.264857|1094785709|
|     1|    367|3.3148572|1112485980|
|     1|    541| 4.681524|1112484603|
|     1|    589|4.5148573|1112485557|
|     1|    593|4.1398573|1112484661|
|     1|    653| 3.514857|1094785691|
|     1|    919| 3.764857|1094785621|
+------+-------+---------+----------+
only showing top 20 rows



In [0]:
# Display the first 20 rows of the original ratings for comparison with the predictions.
ratings.show(n=20)

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|      2|   3.5|1112486027|
|     1|     29|   3.5|1112484676|
|     1|     32|   3.5|1112484819|
|     1|     47|   3.5|1112484727|
|     1|     50|   3.5|1112484580|
|     1|    112|   3.5|1094785740|
|     1|    151|   4.0|1094785734|
|     1|    223|   4.0|1112485573|
|     1|    253|   4.0|1112484940|
|     1|    260|   4.0|1112484826|
|     1|    293|   4.0|1112484703|
|     1|    296|   4.0|1112484767|
|     1|    318|   4.0|1112484798|
|     1|    337|   3.5|1094785709|
|     1|    367|   3.5|1112485980|
|     1|    541|   4.0|1112484603|
|     1|    589|   3.5|1112485557|
|     1|    593|   3.5|1112484661|
|     1|    653|   3.0|1094785691|
|     1|    919|   3.5|1094785621|
+------+-------+------+----------+
only showing top 20 rows

