In [1]:
import pyspark
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import count, desc, expr, lag, round, unix_timestamp, col, round, sum, min, max, row_number, collect_list
from pyspark.sql.types import Row
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, StringType, StructField, StructType, TimestampType

In [9]:
SESSION_PATH = "../lastfm-dataset-1K/user-session-track.tsv"

SESSION_SCHEMA = StructType(
    [
        StructField("userid", StringType(), False),
        StructField("timestamp", TimestampType(), True),
        StructField("artistid", StringType(), True),
        StructField("artistname", StringType(), True),
        StructField("trackid", StringType(), True),
        StructField("trackname", StringType(), True)
    ]
)

conf = SparkConf()

# this needs to be configured if running locally and depending on machine specs
# to avoid warnings similar to
# https://stackoverflow.com/questions/46907447/meaning-of-apache-spark-warning-calling-spill-on-rowbasedkeyvaluebatch
conf.set("spark.driver.memory", "4g")
conf.set("spark.executor.memory", "4g")

spark = SparkSession.builder.appName("lastfm").config(conf=conf).getOrCreate()
spark.sparkContext.setLogLevel('ERROR')
data = (spark.read.format("csv")
        .option("header", "false")
        .option("delimiter", "\t")
        .schema(SESSION_SCHEMA)
        .load(SESSION_PATH)
    )

data.printSchema()

root
 |-- userid: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- artistid: string (nullable = true)
 |-- artistname: string (nullable = true)
 |-- trackid: string (nullable = true)
 |-- trackname: string (nullable = true)



In [10]:
cols_to_drop = ("artistid", "trackid")
data = data.repartition(4)
df = data.drop(*cols_to_drop)
df.show()



+-----------+-------------------+--------------------+--------------------+
|     userid|          timestamp|          artistname|           trackname|
+-----------+-------------------+--------------------+--------------------+
|user_000025|2006-04-25 18:01:40|                Moby|           Down Slow|
|user_000024|2007-05-28 16:57:32|           Destroyer|       The Crossover|
|user_000001|2009-01-05 15:16:41|               Plaid|             Seizure|
|user_000041|2008-07-22 01:31:39|     Younger Brother|  Magic Monkey Juice|
|user_000019|2006-01-08 23:12:58|Einstürzende Neub...|Three Thoughts (D...|
|user_000035|2009-04-04 01:42:31|            Legowelt|Are You Really So...|
|user_000010|2009-04-14 12:43:00|         The Prodigy|                Omen|
|user_000034|2008-12-23 17:35:25|      Black Mountain|         Stormy High|
|user_000026|2006-08-16 16:16:23|           Riverside|   I Turned You Down|
|user_000025|2006-11-02 21:33:20|              Nebula|    Out Of Your Head|
|user_000008

                                                                                

In [11]:
limit,session_cutoff=10,20
w1 = Window.partitionBy("userid").orderBy("timestamp")

df1 = df.withColumn("pretimestamp", lag("timestamp").over(w1)) \
.withColumn("delta_mins", round((col("timestamp").cast("long") - col("pretimestamp").cast("long"))/60)) \
.withColumn("sessionflag", expr(f"CASE WHEN delta_mins > {session_cutoff} OR delta_mins IS NULL THEN 1 ELSE 0 END")) \
.withColumn('sessionID', sum("sessionflag").over(w1)).cache()
                  

In [16]:
df1.select('userid', 'timestamp', 'pretimestamp', 'delta_mins', 'sessionflag', 'sessionID').show()

+-----------+-------------------+-------------------+----------+-----------+---------+
|     userid|          timestamp|       pretimestamp|delta_mins|sessionflag|sessionID|
+-----------+-------------------+-------------------+----------+-----------+---------+
|user_000066|2006-05-09 23:12:52|               null|      null|          1|        1|
|user_000066|2006-05-09 23:31:12|2006-05-09 23:12:52|      18.0|          0|        1|
|user_000066|2006-05-10 01:35:09|2006-05-09 23:31:12|     124.0|          1|        2|
|user_000066|2006-05-10 01:39:47|2006-05-10 01:35:09|       5.0|          0|        2|
|user_000066|2006-05-10 01:47:24|2006-05-10 01:39:47|       8.0|          0|        2|
|user_000066|2006-05-10 02:03:59|2006-05-10 01:47:24|      17.0|          0|        2|
|user_000066|2006-05-10 02:25:20|2006-05-10 02:03:59|      21.0|          1|        3|
|user_000066|2006-05-10 02:34:15|2006-05-10 02:25:20|       9.0|          0|        3|
|user_000066|2006-05-10 02:48:34|2006-05-10

In [None]:
df2 = df1.groupBy("userid", "sessionID") \
.agg(min("timestamp").alias("session_start_ts"),\
max("timestamp").alias("session_end_ts")) \
.withColumn("session_length(hrs)", round((col("session_end_ts").cast("long") - col("session_start_ts").cast("long"))/3600)) \
.orderBy(desc('session_length(hrs)')).limit(10).cache()


In [12]:
df2.show()

+-----------+---------+-------------------+-------------------+-------------------+
|     userid|sessionID|   session_start_ts|     session_end_ts|session_length(hrs)|
+-----------+---------+-------------------+-------------------+-------------------+
|user_000949|      149|2006-02-12 17:49:31|2006-02-27 11:29:37|              354.0|
|user_000997|       18|2007-04-26 01:36:02|2007-05-10 18:55:03|              353.0|
|user_000949|      553|2007-05-01 03:41:15|2007-05-14 01:05:52|              309.0|
|user_000544|       75|2007-02-12 13:03:52|2007-02-23 00:51:08|              252.0|
|user_000949|      137|2005-12-09 08:26:38|2005-12-18 04:40:04|              212.0|
|user_000949|      187|2006-03-18 23:04:14|2006-03-26 19:13:45|              187.0|
|user_000949|      123|2005-11-11 03:30:37|2005-11-18 22:50:07|              187.0|
|user_000544|       55|2007-01-06 01:07:04|2007-01-13 13:57:45|              181.0|
|user_000250|     1258|2008-02-21 15:31:45|2008-02-28 21:18:03|             

In [13]:
df3 = df2.join(df1,["userid","sessionID"]) \
.select("userid","sessionID", "trackname", "session_length(hrs)") \
.groupBy("userid", "sessionID", "session_length(hrs)") \
.agg(collect_list("trackname").alias('tracklist')) \
.orderBy(desc('session_length(hrs)')) \
.cache()


In [14]:
df3.show()

+-----------+---------+-------------------+--------------------+
|     userid|sessionID|session_length(hrs)|           tracklist|
+-----------+---------+-------------------+--------------------+
|user_000949|      149|              354.0|[Chained To You, ...|
|user_000997|       18|              353.0|[Unentitled State...|
|user_000949|      553|              309.0|[White Daisy Pass...|
|user_000544|       75|              252.0|[Finally Woken, O...|
|user_000949|      137|              212.0|[Neighborhood #2 ...|
|user_000949|      123|              187.0|[Excuse Me Miss A...|
|user_000949|      187|              187.0|[Disco Science, H...|
|user_000544|       55|              181.0|[La Murga, Breath...|
|user_000250|     1258|              174.0|[Lazarus Heart, S...|
|user_000949|      150|              170.0|[Y-Control, Banqu...|
+-----------+---------+-------------------+--------------------+

