In [1]:
# Location of the raw user/artist play-count text file that is loaded below.
raw_user_artist_path = "/content/user_artist_data.txt"

In [3]:
from pyspark.sql import SparkSession

# Obtain the session through the public builder API instead of importing
# `pyspark.shell`, which is the interactive-shell bootstrap module and creates
# the session as an import side effect. getOrCreate() reuses an
# already-running session if there is one.
spark = SparkSession.builder.getOrCreate()

# Read the file as plain text: one row per line, in a single string column.
raw_user_artist_data = spark.read.text(raw_user_artist_path)

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488491 sha256=b58b1a04ebe0a99ebac99754a224ad66bcf279e02aa221240ad758e0a3ce3b10
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 3.5.1
      /_/

Using Python version 3.10.12 (main, Nov 20 2023 15:14:05)
Spark context Web UI available at http://899b2e8b7c2c:4040
Spark c

In [4]:
# Preview the first five raw lines before any parsing.
raw_user_artist_data.show(5)

+-------------------+
|              value|
+-------------------+
|       1000002 1 55|
| 1000002 1000006 33|
|  1000002 1000007 8|
|1000002 1000009 144|
|1000002 1000010 314|
+-------------------+
only showing top 5 rows



In [5]:
# Load the artist file as plain text, one row per line (parsed into columns in a later cell).
raw_artist_data = spark.read.text("/content/artist_data.txt")

In [6]:
# Preview the first five raw artist lines before any parsing.
raw_artist_data.show(5)

+--------------------+
|               value|
+--------------------+
|1134999\t06Crazy ...|
|6821360\tPang Nak...|
|10113088\tTerfel,...|
|10151459\tThe Fla...|
|6826647\tBodensta...|
+--------------------+
only showing top 5 rows



In [7]:
# Load the artist-alias file as plain text, one row per line (parsed into columns in a later cell).
raw_artist_alias = spark.read.text("/content/artist_alias.txt")

In [8]:
# Preview the first five raw alias lines before any parsing.
raw_artist_alias.show(5)

+-----------------+
|            value|
+-----------------+
| 1092764\t1000311|
| 1095122\t1000557|
| 6708070\t1007267|
|10088054\t1042317|
| 1195917\t1042317|
+-----------------+
only showing top 5 rows



In [9]:
from pyspark.sql.functions import split, min, max
from pyspark.sql.types import IntegerType, StringType

# Each raw line holds three space-separated fields. Build the split expression
# once and project the three integer columns from it in a single select.
user_artist_fields = split(raw_user_artist_data["value"], " ")

user_artist_df = raw_user_artist_data.select(
    user_artist_fields.getItem(0).cast(IntegerType()).alias("user"),
    user_artist_fields.getItem(1).cast(IntegerType()).alias("artist"),
    user_artist_fields.getItem(2).cast(IntegerType()).alias("count"),
)

# Sanity check on the id ranges of the parsed columns.
# NOTE: `min` / `max` here are the Spark aggregate functions imported above,
# not the Python builtins.
id_range_aggregates = [min("user"), max("user"), min("artist"), max("artist")]
user_artist_df.select(id_range_aggregates).show()

+---------+---------+-----------+-----------+
|min(user)|max(user)|min(artist)|max(artist)|
+---------+---------+-----------+-----------+
|       90|  2443548|          1|   10794401|
+---------+---------+-----------+-----------+



In [10]:
from pyspark.sql.functions import col


# Split each line on the first tab only (limit=2), so any further tabs stay
# inside the second field, then project the typed id / name columns.
artist_fields = split(col("value"), "\t", 2)

artist_by_id = raw_artist_data.select(
    artist_fields.getItem(0).cast(IntegerType()).alias("id"),
    artist_fields.getItem(1).cast(StringType()).alias("name"),
)

artist_by_id.show(5)

+--------+--------------------+
|      id|                name|
+--------+--------------------+
| 1134999|        06Crazy Life|
| 6821360|        Pang Nakarin|
|10113088|Terfel, Bartoli- ...|
|10151459| The Flaming Sidebur|
| 6826647|   Bodenstandig 3000|
+--------+--------------------+
only showing top 5 rows



In [11]:
# Each alias line is "<artist id>\t<alias id>"; parse both fields as integers.
alias_fields = split(col("value"), "\t")

artist_alias = raw_artist_alias.select(
    alias_fields.getItem(0).cast(IntegerType()).alias("artist"),
    alias_fields.getItem(1).cast(IntegerType()).alias("alias"),
)

artist_alias.show(5)

+--------+-------+
|  artist|  alias|
+--------+-------+
| 1092764|1000311|
| 1095122|1000557|
| 6708070|1007267|
|10088054|1042317|
| 1195917|1042317|
+--------+-------+
only showing top 5 rows



In [24]:
# Look up the names behind one artist/alias id pair to see what the alias
# table is mapping between.
artist_by_id.where(col("id").isin(1092764, 1000311)).show()

+-------+--------------+
|     id|          name|
+-------+--------------+
|1000311| Steve Winwood|
|1092764|Winwood, Steve|
+-------+--------------+



In [13]:
from pyspark.sql.functions import broadcast, when


# Left-join the play counts with the alias table (sent with a broadcast hint)
# so every play-count row is kept and gains an `alias` column that is null
# when the artist id has no alias entry.
train_data = user_artist_df.join(broadcast(artist_alias), "artist", how="left")

# Replace the artist id with its alias where one exists; otherwise keep the id.
train_data = train_data.withColumn("artist",
                                    when(col("alias").isNull(), col("artist")). \
                                    otherwise(col("alias")))

# Make sure the id column is an integer, then drop the helper column.
train_data = train_data.withColumn("artist",
                                    col("artist").cast(IntegerType())). \
                                    drop("alias")

# Mark the frame for caching; the count() below is an action that materialises it.
train_data.cache()

train_data.count()

24296858

In [14]:
from pyspark.ml.recommendation import ALS


# Configure the ALS estimator separately from fitting so the hyper-parameters
# are easy to read and tweak. implicitPrefs=True makes ALS use its
# implicit-feedback variant on the `count` column.
als = ALS(
    rank=10,
    seed=0,
    maxIter=5,
    regParam=0.1,
    implicitPrefs=True,
    alpha=1.0,
    userCol="user",
    itemCol="artist",
    ratingCol="count",
)

model = als.fit(train_data)

In [15]:
# Show one row of the fitted model's user factors, without truncating the feature vector.
model.userFactors.show(1, truncate=False)

+---+------------------------------------------------------------------------------------------------------------------------------+
|id |features                                                                                                                      |
+---+------------------------------------------------------------------------------------------------------------------------------+
|90 |[0.16020624, 0.20717518, -0.17194684, 0.06038469, 0.062727705, 0.54658705, -0.40481892, 0.43657345, -0.10396775, -0.042728312]|
+---+------------------------------------------------------------------------------------------------------------------------------+
only showing top 1 row

