<a href="https://colab.research.google.com/github/yjs1210/movie-recommendations/blob/master/recommendationSystemsTrainTestSplit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab runtime at /content/drive so the dataset
# stored there can be read by later cells. Mounting is interactive: it asks
# for an authorization code.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://www-us.apache.org/dist/spark/spark-2.4.4/spark-2.4.4-bin-hadoop2.7.tgz
!tar xf spark-2.4.4-bin-hadoop2.7.tgz
!pip install -q findspark

In [0]:
# setup libraries and env
import os

# Tell findspark / pyspark where the JDK and the unpacked Spark distribution live.
# These are set before findspark.init() so it can pick up SPARK_HOME.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.4.4-bin-hadoop2.7"

import findspark
findspark.init()  # must run before any pyspark import

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T
# Explicit names instead of a wildcard import, so it is clear where Window comes from.
from pyspark.sql.window import Window, WindowSpec
import numpy as np

# Local Spark session using all available cores of the runtime.
spark = SparkSession.builder.master("local[*]").getOrCreate()

In [5]:
# Load the ratings CSV (header row present, column types inferred) and
# convert the `timestamp` column to a proper Spark timestamp.
ratings = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv('/content/drive/My Drive/RecommendationSystems/ml-20m/ratings.csv')
)
ratings = ratings.withColumn("timestamp", F.col("timestamp").cast(T.TimestampType()))
ratings.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- timestamp: timestamp (nullable = true)



In [0]:
# entities with fewer rows than this are removed from the split entirely
limit = 10
# fraction of each entity's rows that is held out for the test set
test_size = 0.05
def udf_user_limit(user_counts):
  """Return how many of an entity's rows belong in the training set.

  Args:
    user_counts: total number of rows recorded for the entity.

  Returns:
    -1 when the entity has fewer than `limit` rows (signals "drop it"),
    `limit` when the training share would not exceed `limit`,
    otherwise the training share `user_counts * (1 - test_size)` rounded
    to the nearest integer.
  """
  # too few rows: flag the entity for removal
  if user_counts < limit:
    return -1
  train_share = user_counts * (1-test_size)
  # never train on fewer than `limit` rows for an entity that is kept
  if train_share <= limit:
    return limit
  return int(np.around(train_share))
# Spark UDF wrapper around udf_user_limit; yields an integer column.
user_limit = F.udf(f=udf_user_limit, returnType=T.IntegerType())

In [0]:
# train-test function
def train_test_split(data, col_to_split_on= 'userId', timestamp_col = 'timestamp', tie_break_cols = None):
  """Temporal per-entity train/test split of a Spark DataFrame.

  For every distinct value of `col_to_split_on`, rows are ordered by
  `timestamp_col`; the earliest `train_limit` rows go to train and the
  remaining (latest) rows go to test. `train_limit` is computed from the
  entity's row count by the `user_limit` UDF; entities whose `train_limit`
  is not positive are dropped from both outputs.

  Args:
    data: input Spark DataFrame.
    col_to_split_on: column identifying the entity to split within.
    timestamp_col: column used to order each entity's rows.
    tie_break_cols: extra columns used to order rows that share the same
      timestamp. Defaults to every other column of `data`, which makes the
      split reproducible when timestamps are tied (row_number() is otherwise
      free to order tied rows differently between runs). Pass `[]` to order
      by `timestamp_col` only.

  Returns:
    (train, test): two DataFrames with the same columns as `data`.
  """
  # original columns
  orig_cols = data.columns
  if tie_break_cols is None:
    tie_break_cols = [c for c in orig_cols if c not in (col_to_split_on, timestamp_col)]
  # name Spark gives to the aggregated count column
  count_col = 'count(' + col_to_split_on + ')'
  # number of rows per entity
  counts_for_col = data.groupby(col_to_split_on).agg(F.count(col_to_split_on))
  # per-entity number of rows that belong in train
  limits = counts_for_col.withColumn('train_limit', user_limit(F.col(count_col)))
  # drop entities flagged for removal and keep only what the join needs
  limits_filtered = limits.filter(limits.train_limit > 0).select(col_to_split_on, 'train_limit')
  # chronological position of each row within its entity, with deterministic tie-breaking
  ordering = Window.partitionBy(col_to_split_on).orderBy(timestamp_col, *tie_break_cols)
  data_row_num = data.withColumn("row_num", F.row_number().over(ordering))
  # attach train_limit to every row; the inner join also removes rows of dropped entities.
  # Joining on the column name avoids ambiguous references between the two frames.
  merged_data = data_row_num.join(limits_filtered, on=col_to_split_on, how='inner')
  # rows up to train_limit are train, later rows are test
  in_train = F.col('row_num') <= F.col('train_limit')
  train = merged_data.filter(in_train).select(orig_cols)
  test = merged_data.filter(~in_train).select(orig_cols)
  return train, test

In [0]:
# to check if test has more entities in any column
def compatibility_test(train, test, cols_to_test = ['userId', 'movieId']):
  """Report columns whose test values include entities never seen in train.

  For each column, counts the distinct values that occur in `test` but not
  in `train` and prints the count.

  Args:
    train: training Spark DataFrame.
    test: test Spark DataFrame.
    cols_to_test: names of the columns to compare.

  Returns:
    List of the column names for which test has at least one value that is
    absent from train (empty when test is fully covered by train).
  """
  cols_greater = []
  for col_name in cols_to_test:
    # subtract() is EXCEPT DISTINCT: distinct test values missing from train,
    # computed by Spark without collecting the ids to the driver
    size = test.select(col_name).subtract(train.select(col_name)).count()
    print("Test has %d more %s" %(size, col_name))
    # only report columns that actually have unseen entities
    if size > 0:
      cols_greater.append(col_name)
  return cols_greater

In [0]:
## split within each movie ('movieId') rather than within each user:
## each movie's earliest ratings go to train, its latest ratings go to test
train, test = train_test_split(data=ratings, col_to_split_on='movieId', timestamp_col='timestamp')

In [10]:
# Preview the training split. show() prints 20 rows by default, so only the
# first 20 of the 100 limited rows are displayed.
train.limit(100).show()

+------+-------+------+-------------------+
|userId|movieId|rating|          timestamp|
+------+-------+------+-------------------+
|   148|     86|   2.0|2002-04-16 14:06:04|
|   148|    908|   4.0|2002-04-16 14:06:04|
|   148|   2103|   2.0|2002-04-16 14:06:04|
|   148|   1210|   5.0|2002-04-16 14:06:26|
|   148|   1097|   5.0|2002-04-16 14:06:56|
|   148|   1453|   3.0|2002-04-16 14:06:56|
|   148|   1968|   3.0|2002-04-16 14:06:56|
|   148|     18|   1.0|2002-04-16 14:07:27|
|   148|    368|   3.0|2002-04-16 14:07:27|
|   148|   1270|   5.0|2002-04-16 14:07:27|
|   148|   5218|   4.0|2002-04-16 14:08:44|
|   148|   5009|   2.0|2002-04-16 14:09:29|
|   148|   5067|   3.0|2002-04-16 14:10:05|
|   148|   5093|   1.0|2002-04-16 14:10:05|
|   148|   5128|   1.0|2002-04-16 14:10:21|
|   148|   5171|   3.0|2002-04-16 14:10:46|
|   148|   4995|   5.0|2002-04-16 14:12:38|
|   148|   4975|   1.0|2002-04-16 14:12:55|
|   148|   3512|   2.0|2002-04-16 14:13:21|
|   148|   3978|   4.0|2002-04-1

In [11]:
# Preview the test split. show() prints 20 rows by default, so only the
# first 20 of the 100 limited rows are displayed.
test.limit(100).show()

+------+-------+------+-------------------+
|userId|movieId|rating|          timestamp|
+------+-------+------+-------------------+
|   148|   2168|   4.0|2002-04-16 15:12:25|
|   148|   5303|   3.0|2002-04-16 15:12:25|
|   148|   2629|   4.0|2002-04-16 15:12:33|
|   148|   3259|   3.0|2002-04-16 15:12:47|
|   148|   1286|   3.0|2002-04-16 15:13:10|
|   148|   1409|   3.0|2002-04-16 15:13:10|
|   463|    477|   4.0|1996-05-30 14:35:01|
|   463|     45|   4.0|1996-05-30 14:35:40|
|   463|    281|   5.0|1996-05-30 14:35:40|
|   463|    648|   3.0|1996-05-31 07:14:40|
|   471|    588|   4.5|2009-08-08 03:38:16|
|   471|   1265|   3.5|2009-08-08 03:40:19|
|   471|    292|   3.0|2009-08-08 03:41:43|
|   471|   6874|   4.5|2009-08-08 03:48:25|
|   471|   4262|   4.0|2009-08-08 03:48:35|
|   471|   3702|   3.0|2009-08-08 03:48:44|
|   471|   3703|   3.0|2009-08-08 03:48:54|
|   471|   3421|   3.0|2009-08-08 03:49:16|
|   471|   3107|   2.0|2009-08-08 03:49:45|
|   471|   2881|   4.0|2009-08-0

In [0]:
# (number of rows, number of columns) of the training split
print((train.count(), len(train.columns)))

In [0]:
# (number of rows, number of columns) of the test split, before rows with
# entities unseen in train are removed
print((test.count(), len(test.columns)))

In [16]:
# Count, per column, the entities that occur in test but never in train;
# the returned column list drives the filtering loop in the next cell.
cols_greater = compatibility_test(train, test, cols_to_test=['userId', 'movieId'])

Test has 0 more userId
Test has 1594 more movieId


In [0]:
# Remove test rows whose entity (in any flagged column) never appears in train,
# since a model fitted on train has nothing to predict from for them.
# A left-semi join keeps only test rows with a match in train and stays inside
# Spark, instead of collecting every distinct id to the driver for an isin() list.
test_cols = test.columns
for col_name in cols_greater:
  seen_in_train = train.select(col_name).distinct()
  # a name-based join moves the key column first, so restore the original column order
  test = test.join(seen_in_train, on=col_name, how='left_semi').select(test_cols)

In [18]:
# (number of rows, number of columns) of the training split
print((train.count(), len(train.columns)))

(19002066, 4)


In [19]:
# (number of rows, number of columns) of the test split after removing rows
# with entities unseen in train
print((test.count(), len(test.columns)))

(996207, 4)


In [20]:
# Sanity check after filtering: every userId and movieId in test should now
# also occur in train, so both printed counts are expected to be 0.
cols_greater = compatibility_test(train, test, cols_to_test=['userId', 'movieId'])

Test has 0 more userId
Test has 0 more movieId


In [0]:
# save dataframes
# Spark writes each path as a directory of part files, not a single CSV file.
# mode='overwrite' lets this cell be re-run; the default mode raises an error
# when the target path already exists.
# NOTE(review): no header row is written, so readers must supply the column
# names (userId, movieId, rating, timestamp) themselves; confirm this is intended.
train.write.csv('train.csv', mode='overwrite')
test.write.csv('test.csv', mode='overwrite')