In [11]:
import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
import pyspark.ml as M
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pyspark.sql.window as W

In [2]:
import psutil
NUM_WORKER = psutil.cpu_count(logical = False)

In [3]:
conf_spark = SparkConf().set("spark.driver.host", "127.0.0.1")
sc = SparkContext(conf = conf_spark)
sc.setLogLevel("ERROR")
spark = SparkSession(sc)
spark.conf.set("spark.sql.shuffle.partitions", NUM_WORKER)
spark

In [6]:
ratings = spark.read.csv("./ml-latest/ratings.csv",header=True)
ratings.head(3)

[Row(userId='1', movieId='307', rating='3.5', timestamp='1256677221'),
 Row(userId='1', movieId='481', rating='3.5', timestamp='1256677456'),
 Row(userId='1', movieId='1091', rating='1.5', timestamp='1256677471')]

# function

In [1]:
# return the ratings df by spliting the latest n movies each user in the df watched
import pyspark.sql.functions as F
import pyspark.sql.window as W
def train_test_split_by_time(ratings,n=2):
    window = W.Window.partitionBy(ratings['userId']).orderBy(ratings['timestamp'].desc())
    ranked = ratings.select('*', F.rank().over(window).alias('rank'))
    train = ranked.select(["userId","movieId","rating","timestamp"]).where(ranked.rank > n)
    test = ranked.select(["userId","movieId","rating","timestamp"]).where(ranked.rank <= n)
    # test may contain more than n records for each user because the timestamp is the same
    return train,test

In [17]:

window = W.Window.partitionBy(ratings['userId']).orderBy(ratings['timestamp'].desc())

ranked = ratings.select('*', F.rank().over(window).alias('rank'))

train = ranked.select(["userId","movieId","rating","timestamp"]).where(ranked.rank > 2)#n=2
test = ranked.select(["userId","movieId","rating","timestamp"]).where(ranked.rank <= 2)#n=2
# test may contain more than n records for each user because the timestamp is the same

In [21]:
ranked.show()

+------+-------+------+----------+----+
|userId|movieId|rating| timestamp|rank|
+------+-------+------+----------+----+
|     1|   2840|   3.0|1256677500|   1|
|     1|   2986|   2.5|1256677496|   2|
|     1|   3893|   3.5|1256677486|   3|
|     1|   1591|   1.5|1256677475|   4|
|     1|   1091|   1.5|1256677471|   5|
|     1|   2134|   4.5|1256677464|   6|
|     1|   1257|   4.5|1256677460|   7|
|     1|    481|   3.5|1256677456|   8|
|     1|   3424|   4.5|1256677444|   9|
|     1|   1449|   4.5|1256677264|  10|
|     1|   3020|   4.0|1256677260|  11|
|     1|   3698|   3.5|1256677243|  12|
|     1|   2478|   4.0|1256677239|  13|
|     1|   1590|   2.5|1256677236|  14|
|     1|    307|   3.5|1256677221|  15|
|     1|   3826|   2.0|1256677210|  16|
|100002|   1639|   5.0| 895853114|   1|
|100002|    329|   5.0| 869227428|   2|
|100002|    151|   5.0| 869227384|   3|
|100002|    338|   3.0| 869227384|   3|
+------+-------+------+----------+----+
only showing top 20 rows



In [18]:
train.show()

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|   3893|   3.5|1256677486|
|     1|   1591|   1.5|1256677475|
|     1|   1091|   1.5|1256677471|
|     1|   2134|   4.5|1256677464|
|     1|   1257|   4.5|1256677460|
+------+-------+------+----------+
only showing top 5 rows



In [19]:
test.show(5)

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|   2840|   3.0|1256677500|
|     1|   2986|   2.5|1256677496|
|100002|   1639|   5.0| 895853114|
|100002|    329|   5.0| 869227428|
| 10001|  46578|   4.5|1300426052|
+------+-------+------+----------+
only showing top 5 rows



In [22]:
test.groupBy("userId").agg(F.count("userId")).show(20)

+------+-------------+
|userId|count(userId)|
+------+-------------+
|     1|            2|
|100002|            2|
| 10001|            2|
|100010|            3|
|100015|            2|
|100018|            2|
|100023|            2|
|100028|            7|
|100039|            5|
|100042|            2|
|100056|            3|
|100063|            4|
|100065|            2|
|100066|            2|
|100073|            2|
|100077|            2|
| 10008|            2|
|100085|            2|
|100088|            2|
|100096|            2|
+------+-------------+
only showing top 20 rows



In [23]:
train.groupBy("userId").agg(F.count("userId")).show(20)

+------+-------------+
|userId|count(userId)|
+------+-------------+
|     1|           14|
|100002|          110|
| 10001|           16|
|100010|            3|
|100015|           11|
|100018|           14|
|100023|           70|
|100028|           26|
|100042|            3|
|100056|           60|
|100063|           17|
|100065|          254|
|100066|           73|
|100073|           13|
| 10008|           19|
|100085|            3|
|100088|           15|
|100096|            8|
|100100|           19|
|100103|           24|
+------+-------------+
only showing top 20 rows

