In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

Welcome to the Glue Interactive Sessions Kernel
For more information on available magic commands, please type %help in any new cell.

Please view our Getting Started page to access the most up-to-date information on the Interactive Sessions kernel: https://docs.aws.amazon.com/glue/latest/dg/interactive-sessions.html
Installed kernel version: 1.0.7 
Trying to create a Glue session for the kernel.
Session Type: glueetl
Session ID: e4b383f1-c9da-41a2-a80e-e393e592dbab
Applying the following default arguments:
--glue_kernel_version 1.0.7
--enable-glue-datacatalog true
Waiting for session e4b383f1-c9da-41a2-a80e-e393e592dbab to get into ready status...
Session e4b383f1-c9da-41a2-a80e-e393e592dbab has been created.



In [2]:
spark = SparkSession.builder.getOrCreate()




In [3]:
train_df = spark.read.parquet("s3://music-preference-bucket/train/")
val_df = spark.read.parquet("s3://music-preference-bucket/validation/")




In [4]:
feature_cols = [c for c in train_df.columns if c not in ['country', 'user_id']]
print(feature_cols)

vec_assembler = VectorAssembler(inputCols=feature_cols, outputCol='features', handleInvalid='skip')
label_indexer = StringIndexer(inputCol='country', outputCol='label')
lg_estimator = LogisticRegression(family='multinomial', maxIter=10, regParam=0.01)

['dim_0', 'dim_1', 'dim_2', 'dim_3', 'dim_4', 'dim_5', 'dim_6', 'dim_7', 'dim_8', 'dim_9', 'dim_10', 'dim_11', 'dim_12', 'dim_13', 'dim_14', 'dim_15', 'dim_16', 'dim_17', 'dim_18', 'dim_19', 'dim_20', 'dim_21', 'dim_22', 'dim_23', 'dim_24', 'dim_25', 'dim_26', 'dim_27', 'dim_28', 'dim_29', 'dim_30', 'dim_31', 'dim_32', 'dim_33', 'dim_34', 'dim_35', 'dim_36', 'dim_37', 'dim_38', 'dim_39', 'dim_40', 'dim_41', 'dim_42', 'dim_43', 'dim_44', 'dim_45', 'dim_46', 'dim_47', 'dim_48', 'dim_49', 'dim_50', 'dim_51', 'dim_52', 'dim_53', 'dim_54', 'dim_55', 'dim_56', 'dim_57', 'dim_58', 'dim_59', 'dim_60', 'dim_61', 'dim_62', 'dim_63', 'dim_64', 'dim_65', 'dim_66', 'dim_67', 'dim_68', 'dim_69', 'dim_70', 'dim_71', 'dim_72', 'dim_73', 'dim_74', 'dim_75', 'dim_76', 'dim_77', 'dim_78', 'dim_79', 'dim_80', 'dim_81', 'dim_82', 'dim_83', 'dim_84', 'dim_85', 'dim_86', 'dim_87', 'dim_88', 'dim_89', 'dim_90', 'dim_91', 'dim_92', 'dim_93', 'dim_94', 'dim_95', 'dim_96', 'dim_97', 'dim_98', 'dim_99', 'dim_100'

In [5]:
lg_pipeline = Pipeline(stages=[vec_assembler, label_indexer, lg_estimator])




In [6]:
model = lg_pipeline.fit(train_df)




In [7]:
preds = model.transform(val_df)
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(preds)
print(f"Validation Accuracy: {accuracy:.4f}")

Validation Accuracy: 0.4708


In [13]:
from pyspark.sql import functions as F
preds.select(F.avg("probability")).show()

AnalysisException: cannot resolve 'avg(probability)' due to data type mismatch: function average requires numeric or interval types, not struct<type:tinyint,size:int,indices:array<int>,values:array<double>>;
'Aggregate [unresolvedalias(avg(probability#2201), Some(org.apache.spark.sql.Column$$Lambda$3347/1297554302@38625107))]
+- Project [country#262, user_id#263, dim_0#264, dim_1#265, dim_2#266, dim_3#267, dim_4#268, dim_5#269, dim_6#270, dim_7#271, dim_8#272, dim_9#273, dim_10#274, dim_11#275, dim_12#276, dim_13#277, dim_14#278, dim_15#279, dim_16#280, dim_17#281, dim_18#282, dim_19#283, dim_20#284, dim_21#285, ... 112 more fields]
   +- Project [country#262, user_id#263, dim_0#264, dim_1#265, dim_2#266, dim_3#267, dim_4#268, dim_5#269, dim_6#270, dim_7#271, dim_8#272, dim_9#273, dim_10#274, dim_11#275, dim_12#276, dim_13#277, dim_14#278, dim_15#279, dim_16#280, dim_17#281, dim_18#282, dim_19#283, dim_20#284, dim_21#285, ... 111 more fields]
      +- Project [country#262, user_id#263,

In [15]:
preds.first()

Row(country='Indonesia', user_id=13515, dim_0=0.05113790375256234, dim_1=-0.0512292432442696, dim_2=-0.03897780428924097, dim_3=-0.020539602275872586, dim_4=0.0709531680358448, dim_5=0.06936525441605007, dim_6=0.027662616471201024, dim_7=-0.03341394467721696, dim_8=-0.033422253561740146, dim_9=-0.02153239314204834, dim_10=-0.026961709000312008, dim_11=0.03505597480202398, dim_12=0.009397621920913889, dim_13=-0.015037941546517304, dim_14=-0.09486028581979575, dim_15=0.07976579658645289, dim_16=0.015420502525363757, dim_17=-0.04555491886994407, dim_18=0.04162913822091982, dim_19=0.06185324945104693, dim_20=-0.1398545950875567, dim_21=-0.017595827466461923, dim_22=0.026781574646240935, dim_23=0.0530154015624052, dim_24=0.010293297970581459, dim_25=0.02994589795891948, dim_26=0.09182263309140147, dim_27=0.0288264381658984, dim_28=0.008506462679865292, dim_29=0.10323017573261595, dim_30=-0.15486386555108747, dim_31=0.05442101885618125, dim_32=0.013379282075711119, dim_33=0.01618851840997194