In [1]:
!pip install findspark



In [2]:
!pip install pyspark




In [3]:
import pyspark

# Report which PySpark release the kernel has imported.
installed_pyspark = pyspark.__version__
print(installed_pyspark)

3.4.0


In [4]:
import warnings


def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accepts any arguments and does nothing."""


# Silence warnings for the rest of the session in two ways: install an
# "ignore" filter, and replace warnings.warn itself with the no-op above.
warnings.filterwarnings('ignore')
warnings.warn = warn

# findspark is imported only after warnings are silenced, then initialised
# before pyspark is used by the following cells.
import findspark
findspark.init()

In [5]:

from pyspark.sql import SparkSession

# Initialise the SparkSession (or reuse an existing one) for the app named
# "model", pointing at the standalone master at spark-master:7077.
spark = (
    SparkSession.builder
    .appName("model")
    .master("spark://spark-master:7077")
    .getOrCreate()
)
print(spark.version)

3.4.0


In [6]:
# Load the preprocessed training CSV from HDFS; the first row is the header
# and column types are inferred from the data.
df = spark.read.csv(
    "hdfs://namenode:9000/data/train_processed.csv",
    header=True,
    inferSchema=True,
)
# Preview the first five rows.
df.show(5)

+-----------------+--------------------+----------+------------+------+---+----+-------+--------------+-----+-----------------+----------------+--------------------+-------------------+------------------+------------------+----------------------+-------------------+-----------------+
|      artist_name|          track_name|popularity|danceability|energy|key|mode|valence|time_signature|class|artist_name_index|track_name_index|duration_in_ms_trans|     loudness_trans| speechiness_trans|acousticness_trans|instrumentalness_trans|     liveness_trans|      tempo_trans|
+-----------------+--------------------+----------+------------+------+---+----+-------+--------------+-----+-----------------+----------------+--------------------+-------------------+------------------+------------------+----------------------+-------------------+-----------------+
|   Blonde Redhead|Misery Is a Butte...|      42.0|       0.548|  0.77|9.0| 0.0|   0.41|           4.0|  6.0|           1477.0|          7946.0| 

In [7]:
# Ask Spark to cache the loaded DataFrame, which the following cells reuse.
# cache() returns the DataFrame, so its column summary is shown as the cell output.
df.cache()

DataFrame[artist_name: string, track_name: string, popularity: double, danceability: double, energy: double, key: double, mode: double, valence: double, time_signature: double, class: double, artist_name_index: double, track_name_index: double, duration_in_ms_trans: double, loudness_trans: double, speechiness_trans: double, acousticness_trans: double, instrumentalness_trans: double, liveness_trans: double, tempo_trans: double]

In [8]:
# Print the column names, inferred types and nullability of the loaded data.
df.printSchema()

root
 |-- artist_name: string (nullable = true)
 |-- track_name: string (nullable = true)
 |-- popularity: double (nullable = true)
 |-- danceability: double (nullable = true)
 |-- energy: double (nullable = true)
 |-- key: double (nullable = true)
 |-- mode: double (nullable = true)
 |-- valence: double (nullable = true)
 |-- time_signature: double (nullable = true)
 |-- class: double (nullable = true)
 |-- artist_name_index: double (nullable = true)
 |-- track_name_index: double (nullable = true)
 |-- duration_in_ms_trans: double (nullable = true)
 |-- loudness_trans: double (nullable = true)
 |-- speechiness_trans: double (nullable = true)
 |-- acousticness_trans: double (nullable = true)
 |-- instrumentalness_trans: double (nullable = true)
 |-- liveness_trans: double (nullable = true)
 |-- tempo_trans: double (nullable = true)



In [9]:
from pyspark.ml.feature import StringIndexer

# Columns removed before modelling: the two raw name strings, their index
# encodings, and "energy".
dropped_columns = [
    "artist_name",
    "track_name",
    "artist_name_index",
    "track_name_index",
    "energy",
]
df = df.drop(*dropped_columns)

# Number of rows per class label, smallest class first.
class_counts = df.groupBy("class").count()
class_counts.orderBy("count").show()


+-----+-----+
|class|count|
+-----+-----+
|  4.0|  387|
|  3.0|  402|
|  7.0|  574|
|  0.0|  625|
|  1.0|  951|
|  2.0| 1220|
|  5.0| 1409|
|  8.0| 1738|
|  6.0| 2152|
|  9.0| 2392|
| 10.0| 4468|
+-----+-----+



In [10]:
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import col
from pyspark.ml.classification import RandomForestClassificationModel

# Every column except the label is a feature. Compare with != rather than
# `not in "class"`: the latter is a substring test on the string "class" and
# would also exclude any column whose name happens to be a substring of it.
# The loop variable is deliberately not called `col`, which is the name of the
# function imported above.
feature_cols = [name for name in df.columns if name != "class"]

# Pack the feature columns into a single vector column.
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features_unscaled")
df_vec = assembler.transform(df)

# Standardise the assembled vector (centre on the mean, scale by the standard
# deviation) into the "features" column.
# NOTE(review): the scaler is fitted on the full dataset, before the train/test
# split performed in a later cell, so test-set statistics influence the
# features. Consider fitting it on the training split only.
scaler = StandardScaler(
    inputCol="features_unscaled",
    outputCol="features",
    withStd=True,
    withMean=True,
)
scaler_model = scaler.fit(df_vec)
df_scaled = scaler_model.transform(df_vec)

In [11]:
# 70/30 train/test split of the scaled data; a seed is passed so the split
# can be repeated.
split_weights = [0.7, 0.3]
training_data, testing_data = df_scaled.randomSplit(split_weights, seed=42)
# The training split is reused by several later cells, so ask Spark to cache it.
training_data.cache()

DataFrame[popularity: double, danceability: double, key: double, mode: double, valence: double, time_signature: double, class: double, duration_in_ms_trans: double, loudness_trans: double, speechiness_trans: double, acousticness_trans: double, instrumentalness_trans: double, liveness_trans: double, tempo_trans: double, features_unscaled: vector, features: vector]

In [12]:
from pyspark.ml.classification import RandomForestClassifier

# Random forest configuration: 90 trees of depth at most 10, predicting the
# "class" label from the scaled "features" vector, with a seed for repeatability.
rf_params = {
    "labelCol": "class",
    "featuresCol": "features",
    "numTrees": 90,
    "maxDepth": 10,
    "seed": 42,
}
rf = RandomForestClassifier(**rf_params)

# Fit on the training split, then score the held-out split.
rf_model = rf.fit(training_data)
rf_predictions = rf_model.transform(testing_data)

In [13]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Accuracy of the "prediction" column against the "class" label on the test split.
rf_evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="class",
)
rf_accuracy = rf_evaluator.evaluate(rf_predictions)
print(f"Random Forest Accuracy = {rf_accuracy}")


Random Forest Accuracy = 0.48414023372287146


In [14]:
# Persist the fitted random forest to HDFS.
# A plain save() raises if the target path already exists, which makes the
# notebook fail on a second run; write().overwrite() replaces the previous copy.
model_path = "hdfs://namenode:9000/model/random_forest_model"
rf_model.write().overwrite().save(model_path)

In [15]:
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, LinearSVC, NaiveBayes, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.classification import  MultilayerPerceptronClassifier


In [16]:
# Preview the first five rows of the training split, including the assembled
# "features_unscaled" and scaled "features" vector columns.
training_data.show(5)

+----------+------------+----+----+-------+--------------+-----+--------------------+-------------------+------------------+-------------------+----------------------+-------------------+------------------+--------------------+--------------------+
|popularity|danceability| key|mode|valence|time_signature|class|duration_in_ms_trans|     loudness_trans| speechiness_trans| acousticness_trans|instrumentalness_trans|     liveness_trans|       tempo_trans|   features_unscaled|            features|
+----------+------------+----+----+-------+--------------+-----+--------------------+-------------------+------------------+-------------------+----------------------+-------------------+------------------+--------------------+--------------------+
|       1.0|      0.0921| 7.0| 1.0|  0.039|           3.0|  7.0|   414.8734747124718|-3.1700990790581587|19.492797411356506| 0.9802803584934617|  -0.06187433988886...|-1.6347505922264112| 5.506647609471016|[1.0,0.0921,7.0,1...|[-2.4929400558154...|
|   

In [17]:
# Input layer width: length of the scaled feature vector, read from the first
# training row.
num_features = len(training_data.select("features").first()[0])

# Output layer width: one unit per label index 0..max_label. Deriving this from
# the largest label instead of the number of distinct labels keeps the layer
# large enough even if some label index is absent from the training split.
max_label = training_data.agg({"class": "max"}).first()[0]
num_classes = int(max_label) + 1

# Two hidden layers (64 and 32 units) between the input and output layers.
layers = [num_features, 64, 32, num_classes]

mlp = MultilayerPerceptronClassifier(
    labelCol="class",
    featuresCol="features",
    layers=layers,
    maxIter=100,
    blockSize=128,
    seed=42
)
# Accuracy of the "prediction" column against the "class" label.
mlp_eva = MulticlassClassificationEvaluator(
    labelCol="class",
    predictionCol="prediction",
    metricName="accuracy"
)

# Fit on the training split, score the held-out split, and report accuracy.
mlp_model = mlp.fit(training_data)
mlp_predictions = mlp_model.transform(testing_data)
mlp_accuracy = mlp_eva.evaluate(mlp_predictions)
print(f"MLP Accuracy = {mlp_accuracy}")

MLP Accuracy = 0.46389816360601


In [18]:
# Persist the fitted multilayer perceptron to HDFS.
# A plain save() raises if the target path already exists, which makes the
# notebook fail on a second run; write().overwrite() replaces the previous copy.
model_path = "hdfs://namenode:9000/model/MLP_model"
mlp_model.write().overwrite().save(model_path)

In [19]:
# Display rows of the training split using show()'s default row limit.
training_data.show()

+----------+------------+----+----+-------+--------------+-----+--------------------+-------------------+------------------+--------------------+----------------------+--------------------+------------------+--------------------+--------------------+
|popularity|danceability| key|mode|valence|time_signature|class|duration_in_ms_trans|     loudness_trans| speechiness_trans|  acousticness_trans|instrumentalness_trans|      liveness_trans|       tempo_trans|   features_unscaled|            features|
+----------+------------+----+----+-------+--------------+-----+--------------------+-------------------+------------------+--------------------+----------------------+--------------------+------------------+--------------------+--------------------+
|       1.0|      0.0921| 7.0| 1.0|  0.039|           3.0|  7.0|   414.8734747124718|-3.1700990790581587|19.492797411356506|  0.9802803584934617|  -0.06187433988886...| -1.6347505922264112| 5.506647609471016|[1.0,0.0921,7.0,1...|[-2.4929400558154.