In [1]:
import findspark
findspark.init()

In [2]:
from pyspark.sql import SparkSession
spark = SparkSession.builder \
     .appName("Test SparkSession") \
     .getOrCreate()

In [3]:
spark.version

'2.4.4'

In [4]:
df_in=spark.read.load("s3://lusun-bucket1/clean_big_subset.csv",format="csv",header="true")

In [22]:
df_in.take(1)

[Row(_c0='0', analysis_sample_rate='22050', artist_familiarity='0.6479336219623286', artist_hotttnesss='0.4820228266956789', artist_id='AR4PQ891187FB5CA9F', artist_latitude='40.76099', artist_location='East Orange, NJ', artist_longitude='-74.20991', artist_name='Dionne Warwick', artist_terms="['brill building pop', 'quiet storm', 'ballad', 'easy listening', 'motown', 'disco', 'soul jazz', 'smooth jazz', 'soul', 'jazz', 'soft rock', 'uk garage', 'chill-out', 'german pop', 'salsa', 'r&b', 'chanson', 'rock', 'pop', 'blues-rock', 'vocal jazz', 'funk', 'oldies', 'pop rock', 'downtempo', 'hip hop', 'classic rock', 'united states', 'germany', 'adult contemporary', 'folk rock', 'vocal', 'soundtrack', 'blues', 'female vocalist', 'electronic', 'new wave', 'urban', 'reggae', 'singer-songwriter', 'swing', '60s', 'female', 'american', '80s', '90s', 'ambient']", artist_terms_freq='0.7901756112133602', bars_confidence='0.09891509433962263', bars_start='109.345429245283', beats_confidence='0.555841121

In [23]:
from pyspark.sql.types import FloatType,IntegerType
from pyspark.sql.functions import when

#convert field type
float_list=['duration','tempo','loudness','energy','danceability','song_hotttnesss','artist_hotttnesss']

for item in float_list:
    df_in = df_in.withColumn(item, df_in[item].cast(FloatType()))
    
df_in = df_in.withColumn("year", df_in["year"].cast(IntegerType()))

df_in = df_in.withColumn("period", \
              when(df_in["year"] <1950, 0).when(df_in["year"] <1960, 1).when(df_in["year"] <1970, 2).when(df_in["year"] <1980, 3)\
                         .when(df_in["year"] <1990, 4).when(df_in["year"] <2000, 5).when(df_in["year"] <2010, 6).when(df_in["year"] <2020, 7).otherwise(8))

In [24]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer, IndexToString, VectorAssembler,StandardScaler
from pyspark.ml import Pipeline

df_in.select('period').show()

assembler = VectorAssembler(
    inputCols=float_list,
    outputCol="features",handleInvalid='skip')

+------+
|period|
+------+
|     4|
|     6|
|     5|
|     5|
|     0|
|     6|
|     0|
|     0|
|     5|
|     2|
|     0|
|     6|
|     5|
|     6|
|     5|
|     6|
|     0|
|     5|
|     8|
|     5|
+------+
only showing top 20 rows



In [40]:
train, test = df_in.randomSplit([0.95, 0.05], seed=10)
p1=Pipeline(stages=[assembler])
training_data=p1.fit(train).transform(train)

In [45]:
from pyspark.ml.classification import RandomForestClassifier,GBTClassifier
rf = RandomForestClassifier(labelCol="period", featuresCol="features", numTrees=20)
pipeline=Pipeline(stages=[rf])
model = pipeline.fit(training_data)

In [46]:
test_data=p1.fit(test).transform(test)
predictions = model.transform(test_data)

In [47]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(
    labelCol="period", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Accuracy: ",accuracy)

Accuracy:  0.5461830451286377
