In [0]:
from pyspark.ml.feature import VectorAssembler, StandardScaler, PCA
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier, NaiveBayes, MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import mlflow
import mlflow.spark

In [0]:
# Load every column of the `movie1` table into a Spark DataFrame.
# NOTE(review): `spark` is not created in any cell visible here — presumably the
# session is injected by the notebook runtime; confirm.
movie = spark.sql('select * from movie1')
# movies = spark.sql('select * from movie_final')

In [0]:
# Preview the loaded table before any feature engineering.
display(movie)
# display(movies)

title,day,month,year,director_number,actor_number,budget,popularity,revenue,runtime,company,country,language,award,vote_average,genre
Ariel,21,10,1988,1,4,0,0.823904,0,69,2,1,2,88,7.1,8
Shadows in Paradise,16,10,1986,1,7,0,0.47445,0,76,1,1,3,88,7.0,8
Four Rooms,25,12,1995,4,24,4000000,1.698,4300000,98,2,1,1,15,6.5,1
Judgment Night,15,10,1993,1,15,0,1.32287,12136938,110,3,2,1,14,6.5,0
Life in Loops (A Megacities RMX),1,1,2006,1,0,42000,0.054716,0,80,1,1,5,5,6.4,10
Sunday in August,2,9,2004,2,2,0,0.001647,0,15,0,1,1,12,5.3,8
Star Wars,25,5,1977,1,106,11000000,10.492614,775398007,121,2,1,1,84,8.0,6
Finding Nemo,30,5,2003,1,24,94000000,9.915573,940335536,100,1,1,1,94,7.6,15
American Beauty,15,9,1999,1,41,15000000,8.191009,356296601,122,2,1,1,63,7.9,8
Citizen Kane,30,4,1941,1,151,839727,3.82689,23217674,119,2,1,1,37,7.9,14


In [0]:
# import scipy.sparse
# from pyspark.ml.linalg import Vectors, _convert_to_vector, VectorUDT
# from pyspark.sql.functions import udf, col

In [0]:
# def dense_to_sparse(vector):
#     return _convert_to_vector(scipy.sparse.csc_matrix(vector.toArray()).T)

# to_sparse = udf(dense_to_sparse, VectorUDT())
# DF.withColumn("sparse", to_sparse(col("densevector")))

In [0]:
# Same data as `movie`, with the 'genre' column renamed to 'genres'.
# NOTE(review): no later cell visible here reads `movies`; the assembler and the
# train/test split continue from `movie` and its original 'genre' column.
movies = movie.withColumnRenamed('genre', 'genres')
display(movies)

title,day,month,year,director_number,actor_number,budget,popularity,revenue,runtime,company,country,language,award,vote_average,genres
Ariel,21,10,1988,1,4,0,0.823904,0,69,2,1,2,88,7.1,8
Shadows in Paradise,16,10,1986,1,7,0,0.47445,0,76,1,1,3,88,7.0,8
Four Rooms,25,12,1995,4,24,4000000,1.698,4300000,98,2,1,1,15,6.5,1
Judgment Night,15,10,1993,1,15,0,1.32287,12136938,110,3,2,1,14,6.5,0
Life in Loops (A Megacities RMX),1,1,2006,1,0,42000,0.054716,0,80,1,1,5,5,6.4,10
Sunday in August,2,9,2004,2,2,0,0.001647,0,15,0,1,1,12,5.3,8
Star Wars,25,5,1977,1,106,11000000,10.492614,775398007,121,2,1,1,84,8.0,6
Finding Nemo,30,5,2003,1,24,94000000,9.915573,940335536,100,1,1,1,94,7.6,15
American Beauty,15,9,1999,1,41,15000000,8.191009,356296601,122,2,1,1,63,7.9,8
Citizen Kane,30,4,1941,1,151,839727,3.82689,23217674,119,2,1,1,37,7.9,14


In [0]:
# Vector assembler: pack the predictor columns into one 'features' vector.
# 'genre' is not among the inputs (the split cell uses it as the label), and
# 'budget' / 'revenue' are left out as well; the commented-out variant below
# is the earlier version that included all three.
assembler = VectorAssembler(
    inputCols=[
        'day', 'month', 'year',
        'vote_average',
        'director_number', 'actor_number',
        'popularity', 'runtime',
        'company', 'country', 'language',
        'award',
    ],
    outputCol='features',
)
#assembler = VectorAssembler(inputCols=['day', 'month', 'year', 'genre', 'director_number', 'actor_number', 'budget', 'popularity', 'revenue', 'runtime', 'company', 'country', 'language', 'award'], outputCol='features1')

# Adds the 'features' column to every row of the loaded table.
movie1 = assembler.transform(movie)

In [0]:
# 80/20 random split with a fixed seed. Each part is reduced to the two columns
# the classifiers need: the assembled 'features' vector and 'genre', exposed
# under the name 'label'.
train, test = (
    part.select(part['features'], part['genre'].alias('label'))
    for part in movie1.randomSplit([0.8, 0.2], seed=2)
)
print(f"Train Count: {train.count()}")
print(f"Test Count: {test.count()}")

In [0]:
# normalization
# Fit the scaler on the training split only and apply that same fitted model to
# both splits, so the test rows are centred/scaled with training statistics.
regular = StandardScaler(inputCol='features', outputCol='features1', withStd=True, withMean=True).fit(train)
reg_train = regular.transform(train)
reg_test = regular.transform(test)
# Expose the scaled vector under the name 'features' again, and cache the
# result: these two frames feed every cross-validated fit and every evaluation
# in the cells below, and without caching each of those actions would recompute
# the whole lineage (SQL read -> assemble -> split -> scale).
train = reg_train.select(reg_train['features1'].alias('features'), 'label').cache()
test = reg_test.select(reg_test['features1'].alias('features'), 'label').cache()

In [0]:
# PCA (optional)
#pca = PCA(k=3, inputCol="features", outputCol="features1").fit(train)
#print(pca.explainedVariance)
#pca_train = pca.transform(train)
#pca_test = pca.transform(test)
#train = pca_train.select(pca_train['features1'].alias('features'), 'label')
#test = pca_test.select(pca_test['features1'].alias('features'), 'label')

In [0]:
# Scorer shared by all model cells below. The arguments are spelled out so the
# logged metric name is visible here; they match the library defaults, so
# behaviour is the same as constructing the evaluator with no arguments.
evaluator = MulticlassClassificationEvaluator(
    predictionCol='prediction',
    labelCol='label',
    metricName='f1',
)

In [0]:
# Logistic-regression baseline, tracked as one named MLflow run.
with mlflow.start_run(run_name='LogisticRegression'):
  lr = LogisticRegression()
  # Empty grid -> a single candidate with the estimator's default parameters,
  # so the 5-fold CV does no tuning; its only product is the out-of-fold score.
  paramGrid = ParamGridBuilder().build()
  crossval = CrossValidator(estimator=lr,
                            evaluator=evaluator,
                            numFolds=5,
                            estimatorParamMaps=paramGrid,
                            seed=2)  # pin fold assignment; same value as the train/test split seed
  cvModel = crossval.fit(train)

  metric_name = evaluator.getMetricName()
  # avgMetrics holds one fold-averaged score per param map; there is exactly
  # one map here. Logging it keeps the cross-validation result instead of
  # discarding it.
  mlflow.log_metric('cv_' + metric_name, float(cvModel.avgMetrics[0]))

  # Score the refit model on the data it was trained on and on the held-out split.
  train_pred = cvModel.transform(train)
  train_metric = evaluator.evaluate(train_pred)
  mlflow.log_metric('train_' + metric_name, train_metric)
  test_pred = cvModel.transform(test)
  test_metric = evaluator.evaluate(test_pred)
  mlflow.log_metric('test_' + metric_name, test_metric)

  #mlflow.spark.log_model(spark_model=cvModel.bestModel, artifact_path='best-model') 

In [0]:
# Inspect the test-set predictions held in `test_pred`. When cells run in
# notebook order this is the logistic-regression output from the cell above;
# every model cell reassigns the same name.
display(test_pred)

features,label,rawPrediction,probability,prediction
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883426, 0.2579957218484962, 0.35568780181844367, 1.1801621800971176, -0.15661325072631158, 1.546495373470508, 4.4224293316065655, 0.39541571577255413, 1.5100274181792657, 0.2350648666164713, 2.921608831058877, -0.049427546321821024))",19,"Map(vectorType -> dense, length -> 21, values -> List(6.142791002864293, 3.1473197212243784, 1.3503969375223952, 3.9754539593679077, 3.632200086295211, -0.3930206635671392, 4.8869142346189065, -15.51271569676868, 3.803522841553278, -1.3111768734156075, -7.830079751052991, 3.9618569740293967, 3.6622323805757113, -0.08972986559789065, 1.9785178725649912, 1.3694316503124293, -6.465623580485674, 4.541176748376149, 0.31142779156773615, 4.5450178932076515, -15.705913663192456))","Map(vectorType -> dense, length -> 21, values -> List(0.4411685321134077, 0.02206418442799157, 0.003658425618864009, 0.050505887716635306, 0.035831809945997176, 6.399375375295761E-4, 0.1256562737599762, 1.7367506289893297E-10, 0.04252783488911014, 2.554979443800798E-4, 3.769327409173593E-7, 0.04982380752744642, 0.03692424344123414, 8.666726784148911E-4, 0.006856203282284018, 0.00372872968594217, 1.4751617680940201E-6, 0.0889267241930868, 0.0012944213355268728, 0.08926896149082533, 1.4316361432337722E-10))",0.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.271770900598584, -0.2738653806602984, -0.24059886637983785, -1.1406789069018983, 1.2772404593442372, 0.17801116562340052, -0.20161451296512697, 0.39541571577255413, -0.6919233318719775, 0.2350648666164713, 0.1499707968926724, -0.2963566777462492))",19,"Map(vectorType -> dense, length -> 21, values -> List(0.6628928184187695, -0.13953914669066408, 0.497172568108881, -0.14237948320620425, -0.7922834072874305, -1.1399106200438471, 0.07145345731292883, -2.1904659893210896, 2.4330630941215623, -0.6201809113615644, 0.6759345925907599, -0.3901688794735506, -0.6239051708592962, -0.6050000506202666, -0.6945376031056921, -1.846684112569672, 0.20563414776969702, 0.4863526469280256, -0.6778886943398726, 1.93550812263542, 2.8949326209931066))","Map(vectorType -> dense, length -> 21, values -> List(0.03733997218449342, 0.01673717725492668, 0.031637532175574634, 0.016689705489012444, 0.008713627395467327, 0.006154976580817819, 0.020668794114069203, 0.002152662795702489, 0.2192548318055058, 0.01035004288734657, 0.03783014105755756, 0.013026720815888048, 0.01031156833090807, 0.010508364132948285, 0.009608364024096183, 0.003035837890022997, 0.023636821342948727, 0.03129706182384842, 0.009769671873951752, 0.13331032718035826, 0.34796579884455525))",20.0
"Map(vectorType -> dense, length -> 12, values -> List(0.7984745630100742, 1.055787375611688, 0.35568780181844367, 0.5443153235601492, -0.15661325072631158, -0.6641329622825117, -0.19915397927167725, 0.5279211453053034, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.17289211203403512))",1,"Map(vectorType -> dense, length -> 21, values -> List(0.720192880252497, -0.09973913883407776, 0.1584794170807644, 0.4130243714968689, -1.0479637188905684, -0.3386497961614271, -0.34946608100584303, -0.9968255086989457, 2.5027145660185837, -1.2667383289618215, 2.42309880810125, -0.4164966120857331, -1.058111441731584, -1.4773005310802254, -0.9104540917432553, -0.13966953832121576, -0.10185380019818979, 0.527419476372323, -1.3996678006610463, 1.6172997015324302, 1.2407071675192158))","Map(vectorType -> dense, length -> 21, values -> List(0.04525333595207232, 0.019932356599038654, 0.025804862081774876, 0.033285034493862053, 0.007722358256146377, 0.015696436525789295, 0.015527574274679217, 0.008127537613985672, 0.2690233223131167, 0.0062049364751198864, 0.24843526733361823, 0.014520869605482522, 0.0076443901732028245, 0.00502679688355943, 0.008860732340394797, 0.019152130669503855, 0.0198902509499283, 0.037319015148158156, 0.005432588420438824, 0.11098368531949679, 0.07615651857063131))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-0.7542095346964194, 0.2579957218484962, -0.02527312508601396, 1.1801621800971176, -0.15661325072631158, -0.032524866353077546, -0.12425857237021509, -0.49426359680447707, 0.7760438348288513, 0.2350648666164713, 1.5357898139757746, -0.3375115329836539))",16,"Map(vectorType -> dense, length -> 21, values -> List(1.1637624379518159, 0.44468421101626177, 0.2923487444916744, 0.2675438161197847, -0.6288237041872676, -1.180949303715861, 0.16763114843722715, -1.8693912211198098, 2.5483019244383307, -0.6942393261354539, 0.7981333573010962, -0.15284749247472207, -0.45989395566314467, -1.1808145212851593, -0.7898295995800916, 0.7149752480753085, -0.4621642107011526, 0.710871637991795, -1.232813172211625, 2.2997571035927673, -0.756243122341774))","Map(vectorType -> dense, length -> 21, values -> List(0.07483927205117796, 0.036461778505220074, 0.03130973530091428, 0.030542652605854644, 0.012462905002896839, 0.007175202849520673, 0.027638548556951383, 0.0036045218097773687, 0.29883221472257926, 0.011673729910600936, 0.051920499554700225, 0.0200601015372287, 0.01475654009829109, 0.0071761700059778935, 0.010609509664516216, 0.04777753666286302, 0.014723076988059242, 0.04758187800962699, 0.006812554530677748, 0.2330696743162353, 0.010971897316330092))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.1774009239274768, -0.805726483169093, 0.3391242832573803, 1.4345009833509863, 1.2772404593442372, -0.1377928823413166, -0.2000808470226026, 0.8118613514469092, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.5021309539332727))",8,"Map(vectorType -> dense, length -> 21, values -> List(1.2417999291377593, -0.10561425545417752, 0.08132084120013106, 0.3142514352596514, -0.7188661279930574, -0.7246174119463472, -0.11673590297345311, -1.0810919737584335, 2.475690469421636, -0.9936365204224245, 1.595353030960322, -0.2774866643580889, -0.6151388255698272, -1.172542024794411, -0.9159631321361071, -0.9322151420147162, 0.09080755983606253, 0.9803219042757507, -2.020957111991241, 1.9309779013229318, 0.9643420219980394))","Map(vectorType -> dense, length -> 21, values -> List(0.0823845463841798, 0.021412689013326713, 0.02581404664839083, 0.032584912730419455, 0.011596830387463387, 0.011530325151512632, 0.02117586401683719, 0.008072845097863999, 0.2829551531721938, 0.008810651840220715, 0.11732536553688948, 0.018031338682257217, 0.01286433972087444, 0.007367334129115217, 0.009522284744426164, 0.009368779243889682, 0.026060102531894808, 0.06342891630375876, 0.0031539033180584637, 0.16411638612670215, 0.06242338521972526))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.168258627418151, 1.055787375611688, 0.3391242832573803, -1.1406789069018983, -0.15661325072631158, 0.17801116562340052, -0.19798122492948467, 0.3007689803920189, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.3375115329836539))",8,"Map(vectorType -> dense, length -> 21, values -> List(0.7195455879815909, -0.04017621779687852, 0.46248245587907, 0.7224381129336668, -0.7454534038481448, -0.04211090069735812, -0.14730368497017832, -1.1740726083984898, 2.531115795191284, -0.9713209021507518, 0.17252190376076026, -0.14368552746871863, -0.8156427266029684, -1.0045820661377172, -0.7678090034073805, -1.9955547296051894, -0.07496156506956608, 0.7456791448897822, -0.3442439286745643, 1.6510751125394523, 1.2620591516522996))","Map(vectorType -> dense, length -> 21, values -> List(0.05380477157827855, 0.025169686367526846, 0.041608273186862735, 0.05396062852399433, 0.01243311108289172, 0.025121038080379996, 0.022612726818256608, 0.008099027217820961, 0.32928787279859556, 0.00991943202535097, 0.031135183128380952, 0.022694691416418427, 0.01159036140901065, 0.009594926797165115, 0.012158245272683211, 0.003561785046292943, 0.02430920295870919, 0.05522941607154506, 0.01857046912609786, 0.13657742621593214, 0.09256172487780623))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.4879377434687755, -0.805726483169093, 0.30599724613525353, 0.41714592193321476, -0.15661325072631158, -0.1377928823413166, 0.17350016639456328, 1.7961873994044757, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.460976098695868))",8,"Map(vectorType -> dense, length -> 21, values -> List(1.8019229207961827, 0.5734315388311007, 0.4575578235936717, 0.9583061036845061, -0.8645449281604036, 0.03878642508443797, 0.30046699191788734, -2.2034913616987253, 2.8990361437276446, -0.7371101161433202, 1.067999572603445, -0.47248230545240233, -0.7640918596621153, -0.8208627234598501, -0.4569738625097574, -2.6123732726747924, -0.5243897167913996, 0.8174688502460487, -1.143266086177368, 1.7229596544954129, -0.038349792250203185))","Map(vectorType -> dense, length -> 21, values -> List(0.1250673140145438, 0.03661143863197399, 0.032605696333684424, 0.05379794596562837, 0.008691836532155906, 0.021449806165532645, 0.027865689794721118, 0.0022783188442509635, 0.37463991103395444, 0.009873151145968307, 0.060035061883966845, 0.012864182229985493, 0.009610318106187472, 0.00907992979011396, 0.013065240688188401, 0.001513698374613546, 0.012213470315765947, 0.046730545679235805, 0.006577555011593138, 0.11557143668097059, 0.019857452776965184))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.271770900598584, -1.0716570344234904, 0.05754446771930292, 0.6078999485747645, -0.15661325072631158, 0.3885471975998786, -0.08085248836826865, -0.05888861405401499, -0.6919233318719775, 0.2350648666164713, 0.1499707968926724, -0.4198212434584633))",8,"Map(vectorType -> dense, length -> 21, values -> List(1.2235920000486276, 0.3998732853289166, 0.23576901596042954, 0.4088383554046663, -0.40173473764256307, -0.694379012238006, -0.028588324799160222, -2.0288892724233545, 2.541365562035927, -0.9170464478654319, -0.25819469385989524, -0.15455971278243205, -0.45787041194891775, -1.1548886638010718, -0.7562981045534793, -1.3478287150155204, -0.4744716008851255, 0.8126246901790364, -0.9644961992243878, 2.1930282950519633, 1.8241546930297794))","Map(vectorType -> dense, length -> 21, values -> List(0.07595738083188965, 0.03332985992844016, 0.028285503239321524, 0.03363000787168548, 0.014952008920650312, 0.01115849835813257, 0.021714757089310293, 0.0029378885174717174, 0.28370812092545356, 0.008931049887344857, 0.017259896870076445, 0.019144600002010636, 0.014135791600828677, 0.007040588360508971, 0.010488530217460932, 0.005805186194004197, 0.013903057827456503, 0.05036039711120304, 0.008517170674141502, 0.20025843340962737, 0.1384812721629813))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(1.00549910937094, -0.007934829405901128, 0.32256076469631695, 0.22639189529166512, -0.15661325072631158, 0.17801116562340052, -0.017453743994327402, 0.3007689803920189, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.2963566777462492))",17,"Map(vectorType -> dense, length -> 21, values -> List(1.1788394922990806, 0.30738734611269536, 0.43726676471197456, 0.7220010423293883, -0.5968982379542388, -0.17422739250928065, -0.0823301861726032, -1.6786606471961962, 2.648198464483109, -1.0129999595732893, 0.13929790786598917, -0.1800079569164721, -0.6862336901838211, -1.1863096101000106, -0.6217376373777203, -1.3434306671992848, -0.4578209366835729, 0.8817505550897817, -0.7735583582654886, 2.012668621429498, 0.46680508581046176))","Map(vectorType -> dense, length -> 21, values -> List(0.07854308559300734, 0.032857998114151694, 0.03741510903335106, 0.049739972490389965, 0.01330193621920578, 0.02029917917433235, 0.022253018377001856, 0.004509323741329548, 0.34138344151162364, 0.008774132318710761, 0.027774150523444046, 0.020182176956930253, 0.012165135963590973, 0.007377967785531563, 0.012975593964303582, 0.0063052156747972965, 0.015286757126278696, 0.05835577923136524, 0.011147881554159004, 0.1808153716273877, 0.038536773019107726))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883426, 1.3217179268660852, 0.1403620605246198, 0.639692261082072, -0.15661325072631158, -0.34832891431779467, 1.5818405088941414, 0.5468504923814105, 0.7760438348288513, 1.5971841878435262, 4.3074278481419785, -0.17289211203403512))",19,"Map(vectorType -> dense, length -> 21, values -> List(3.1815323024432987, 1.2987754414726635, 0.8406636368529226, 1.5509552857097697, 0.3666294138697721, -1.7801570018422734, 2.3993487111675798, -6.693428052818396, 3.2498287609325764, -0.283597672899915, 0.24265593503829808, 1.0679410556068472, 0.8012923942291914, -0.3384219527406611, 0.05213715105985328, 1.2102688150291532, -3.209131740387125, 1.8286554846608492, -1.0485044334423175, 2.6165971784666295, -7.354040712408716))","Map(vectorType -> dense, length -> 21, values -> List(0.22765652407361614, 0.034642496769927505, 0.021910574060989004, 0.04457891560602392, 0.013639024484794765, 0.0015938440109308115, 0.10413149010093105, 1.1712222368751033E-5, 0.24374789604013158, 0.0071185784128899076, 0.012048758822099848, 0.027501671176612737, 0.02106468857894527, 0.006738812783565088, 0.009958663425231145, 0.03170817494405016, 3.8181234936084243E-4, 0.0588481973480897, 0.0033128249474204266, 0.12939929007184536, 6.04977017580717E-6))",8.0


In [0]:
# Decision-tree baseline, tracked as one named MLflow run.
with mlflow.start_run(run_name='DecisionTreeClassifier'):
  dt = DecisionTreeClassifier()
  # Empty grid -> a single candidate with the estimator's default parameters,
  # so the 5-fold CV does no tuning; its only product is the out-of-fold score.
  paramGrid = ParamGridBuilder().build()
  crossval = CrossValidator(estimator=dt,
                            evaluator=evaluator,
                            numFolds=5,
                            estimatorParamMaps=paramGrid,
                            seed=2)  # pin fold assignment; same value as the train/test split seed
  cvModel = crossval.fit(train)

  metric_name = evaluator.getMetricName()
  # avgMetrics holds one fold-averaged score per param map; there is exactly
  # one map here. Logging it keeps the cross-validation result instead of
  # discarding it.
  mlflow.log_metric('cv_' + metric_name, float(cvModel.avgMetrics[0]))

  # Score the refit model on the data it was trained on and on the held-out split.
  train_pred = cvModel.transform(train)
  train_metric = evaluator.evaluate(train_pred)
  mlflow.log_metric('train_' + metric_name, train_metric)
  test_pred = cvModel.transform(test)
  test_metric = evaluator.evaluate(test_pred)
  mlflow.log_metric('test_' + metric_name, test_metric)

  #mlflow.spark.log_model(spark_model=cvModel.bestModel, artifact_path='best-model')

In [0]:
# Inspect the test-set predictions held in `test_pred`. When cells run in
# notebook order this is the decision-tree output from the cell above;
# every model cell reassigns the same name.
display(test_pred)

features,label,rawPrediction,probability,prediction
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883426, 0.2579957218484962, 0.35568780181844367, 1.1801621800971176, -0.15661325072631158, 1.546495373470508, 4.4224293316065655, 0.39541571577255413, 1.5100274181792657, 0.2350648666164713, 2.921608831058877, -0.049427546321821024))",19,"Map(vectorType -> dense, length -> 21, values -> List(696.0, 241.0, 142.0, 289.0, 109.0, 94.0, 272.0, 7.0, 1180.0, 18.0, 122.0, 137.0, 139.0, 18.0, 64.0, 280.0, 26.0, 589.0, 52.0, 1375.0, 35.0))","Map(vectorType -> dense, length -> 21, values -> List(0.11826677994902295, 0.040951571792693285, 0.024129141886151232, 0.049107901444350045, 0.018521665250637212, 0.01597281223449448, 0.04621920135938828, 0.0011894647408666101, 0.20050977060322855, 0.003058623619371283, 0.020730671197960918, 0.023279524214103654, 0.023619371282922685, 0.003058623619371283, 0.010875106202209005, 0.0475785896346644, 0.004418011894647409, 0.10008496176720476, 0.008836023789294817, 0.2336448598130841, 0.00594732370433305))",19.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.271770900598584, -0.2738653806602984, -0.24059886637983785, -1.1406789069018983, 1.2772404593442372, 0.17801116562340052, -0.20161451296512697, 0.39541571577255413, -0.6919233318719775, 0.2350648666164713, 0.1499707968926724, -0.2963566777462492))",19,"Map(vectorType -> dense, length -> 21, values -> List(83.0, 82.0, 83.0, 95.0, 33.0, 20.0, 45.0, 5.0, 815.0, 19.0, 170.0, 41.0, 42.0, 21.0, 28.0, 24.0, 29.0, 112.0, 22.0, 492.0, 938.0))","Map(vectorType -> dense, length -> 21, values -> List(0.025945608002500783, 0.025633010315723664, 0.025945608002500783, 0.029696780243826194, 0.01031572366364489, 0.006251953735542357, 0.014066895904970303, 0.0015629884338855893, 0.25476711472335106, 0.005939356048765239, 0.053141606752110035, 0.012816505157861832, 0.01312910284463895, 0.006564551422319475, 0.0087527352297593, 0.007502344482650828, 0.009065332916536417, 0.0350109409190372, 0.006877149109096593, 0.15379806189434198, 0.29321663019693656))",20.0
"Map(vectorType -> dense, length -> 12, values -> List(0.7984745630100742, 1.055787375611688, 0.35568780181844367, 0.5443153235601492, -0.15661325072631158, -0.6641329622825117, -0.19915397927167725, 0.5279211453053034, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.17289211203403512))",1,"Map(vectorType -> dense, length -> 21, values -> List(29.0, 18.0, 26.0, 13.0, 8.0, 16.0, 13.0, 6.0, 342.0, 12.0, 263.0, 8.0, 7.0, 6.0, 5.0, 27.0, 28.0, 30.0, 4.0, 130.0, 204.0))","Map(vectorType -> dense, length -> 21, values -> List(0.024267782426778243, 0.01506276150627615, 0.021757322175732216, 0.010878661087866108, 0.0066945606694560665, 0.013389121338912133, 0.010878661087866108, 0.00502092050209205, 0.28619246861924685, 0.0100418410041841, 0.2200836820083682, 0.0066945606694560665, 0.005857740585774059, 0.00502092050209205, 0.0041841004184100415, 0.022594142259414227, 0.023430962343096235, 0.02510460251046025, 0.0033472803347280333, 0.1087866108786611, 0.1707112970711297))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-0.7542095346964194, 0.2579957218484962, -0.02527312508601396, 1.1801621800971176, -0.15661325072631158, -0.032524866353077546, -0.12425857237021509, -0.49426359680447707, 0.7760438348288513, 0.2350648666164713, 1.5357898139757746, -0.3375115329836539))",16,"Map(vectorType -> dense, length -> 21, values -> List(1023.0, 834.0, 828.0, 725.0, 261.0, 367.0, 553.0, 78.0, 5935.0, 250.0, 863.0, 439.0, 207.0, 124.0, 303.0, 396.0, 342.0, 970.0, 578.0, 3950.0, 1180.0))","Map(vectorType -> dense, length -> 21, values -> List(0.05062852618034247, 0.04127486885083639, 0.04097792734831238, 0.03588043155498367, 0.01291695535979412, 0.018162921904384836, 0.02736810848262892, 0.003860239532812036, 0.2937246362466594, 0.012372562605166781, 0.042710086113035735, 0.02172621993467287, 0.010244481837078096, 0.006136791052162724, 0.01499554587746214, 0.019598139166584184, 0.01692566564386816, 0.048005542908047115, 0.028605364743145602, 0.19548648916163516, 0.05839849549638721))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.1774009239274768, -0.805726483169093, 0.3391242832573803, 1.4345009833509863, 1.2772404593442372, -0.1377928823413166, -0.2000808470226026, 0.8118613514469092, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.5021309539332727))",8,"Map(vectorType -> dense, length -> 21, values -> List(2067.0, 777.0, 813.0, 738.0, 172.0, 180.0, 657.0, 97.0, 7810.0, 262.0, 407.0, 187.0, 245.0, 245.0, 250.0, 81.0, 337.0, 450.0, 131.0, 3236.0, 1018.0))","Map(vectorType -> dense, length -> 21, values -> List(0.1025297619047619, 0.03854166666666667, 0.040327380952380955, 0.03660714285714286, 0.008531746031746031, 0.008928571428571428, 0.032589285714285716, 0.004811507936507937, 0.38740079365079366, 0.012996031746031747, 0.020188492063492065, 0.00927579365079365, 0.012152777777777778, 0.012152777777777778, 0.01240079365079365, 0.0040178571428571425, 0.01671626984126984, 0.022321428571428572, 0.006498015873015873, 0.16051587301587303, 0.05049603174603175))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.168258627418151, 1.055787375611688, 0.3391242832573803, -1.1406789069018983, -0.15661325072631158, 0.17801116562340052, -0.19798122492948467, 0.3007689803920189, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.3375115329836539))",8,"Map(vectorType -> dense, length -> 21, values -> List(1023.0, 834.0, 828.0, 725.0, 261.0, 367.0, 553.0, 78.0, 5935.0, 250.0, 863.0, 439.0, 207.0, 124.0, 303.0, 396.0, 342.0, 970.0, 578.0, 3950.0, 1180.0))","Map(vectorType -> dense, length -> 21, values -> List(0.05062852618034247, 0.04127486885083639, 0.04097792734831238, 0.03588043155498367, 0.01291695535979412, 0.018162921904384836, 0.02736810848262892, 0.003860239532812036, 0.2937246362466594, 0.012372562605166781, 0.042710086113035735, 0.02172621993467287, 0.010244481837078096, 0.006136791052162724, 0.01499554587746214, 0.019598139166584184, 0.01692566564386816, 0.048005542908047115, 0.028605364743145602, 0.19548648916163516, 0.05839849549638721))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.4879377434687755, -0.805726483169093, 0.30599724613525353, 0.41714592193321476, -0.15661325072631158, -0.1377928823413166, 0.17350016639456328, 1.7961873994044757, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.460976098695868))",8,"Map(vectorType -> dense, length -> 21, values -> List(2067.0, 777.0, 813.0, 738.0, 172.0, 180.0, 657.0, 97.0, 7810.0, 262.0, 407.0, 187.0, 245.0, 245.0, 250.0, 81.0, 337.0, 450.0, 131.0, 3236.0, 1018.0))","Map(vectorType -> dense, length -> 21, values -> List(0.1025297619047619, 0.03854166666666667, 0.040327380952380955, 0.03660714285714286, 0.008531746031746031, 0.008928571428571428, 0.032589285714285716, 0.004811507936507937, 0.38740079365079366, 0.012996031746031747, 0.020188492063492065, 0.00927579365079365, 0.012152777777777778, 0.012152777777777778, 0.01240079365079365, 0.0040178571428571425, 0.01671626984126984, 0.022321428571428572, 0.006498015873015873, 0.16051587301587303, 0.05049603174603175))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.271770900598584, -1.0716570344234904, 0.05754446771930292, 0.6078999485747645, -0.15661325072631158, 0.3885471975998786, -0.08085248836826865, -0.05888861405401499, -0.6919233318719775, 0.2350648666164713, 0.1499707968926724, -0.4198212434584633))",8,"Map(vectorType -> dense, length -> 21, values -> List(286.0, 103.0, 105.0, 169.0, 54.0, 54.0, 72.0, 32.0, 1227.0, 21.0, 281.0, 64.0, 56.0, 22.0, 48.0, 87.0, 48.0, 290.0, 17.0, 998.0, 700.0))","Map(vectorType -> dense, length -> 21, values -> List(0.060414026193493876, 0.02175749894381073, 0.022179974651457542, 0.03569919729615547, 0.011406844106463879, 0.011406844106463879, 0.015209125475285171, 0.006759611322348965, 0.25918884664131814, 0.004435994930291508, 0.05935783692437685, 0.01351922264469793, 0.011829319814110688, 0.004647232784114913, 0.010139416983523447, 0.018377693282636248, 0.010139416983523447, 0.06125897760878749, 0.0035910435149978876, 0.21081537811575835, 0.1478664976763836))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(1.00549910937094, -0.007934829405901128, 0.32256076469631695, 0.22639189529166512, -0.15661325072631158, 0.17801116562340052, -0.017453743994327402, 0.3007689803920189, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.2963566777462492))",17,"Map(vectorType -> dense, length -> 21, values -> List(1023.0, 834.0, 828.0, 725.0, 261.0, 367.0, 553.0, 78.0, 5935.0, 250.0, 863.0, 439.0, 207.0, 124.0, 303.0, 396.0, 342.0, 970.0, 578.0, 3950.0, 1180.0))","Map(vectorType -> dense, length -> 21, values -> List(0.05062852618034247, 0.04127486885083639, 0.04097792734831238, 0.03588043155498367, 0.01291695535979412, 0.018162921904384836, 0.02736810848262892, 0.003860239532812036, 0.2937246362466594, 0.012372562605166781, 0.042710086113035735, 0.02172621993467287, 0.010244481837078096, 0.006136791052162724, 0.01499554587746214, 0.019598139166584184, 0.01692566564386816, 0.048005542908047115, 0.028605364743145602, 0.19548648916163516, 0.05839849549638721))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883426, 1.3217179268660852, 0.1403620605246198, 0.639692261082072, -0.15661325072631158, -0.34832891431779467, 1.5818405088941414, 0.5468504923814105, 0.7760438348288513, 1.5971841878435262, 4.3074278481419785, -0.17289211203403512))",19,"Map(vectorType -> dense, length -> 21, values -> List(2067.0, 777.0, 813.0, 738.0, 172.0, 180.0, 657.0, 97.0, 7810.0, 262.0, 407.0, 187.0, 245.0, 245.0, 250.0, 81.0, 337.0, 450.0, 131.0, 3236.0, 1018.0))","Map(vectorType -> dense, length -> 21, values -> List(0.1025297619047619, 0.03854166666666667, 0.040327380952380955, 0.03660714285714286, 0.008531746031746031, 0.008928571428571428, 0.032589285714285716, 0.004811507936507937, 0.38740079365079366, 0.012996031746031747, 0.020188492063492065, 0.00927579365079365, 0.012152777777777778, 0.012152777777777778, 0.01240079365079365, 0.0040178571428571425, 0.01671626984126984, 0.022321428571428572, 0.006498015873015873, 0.16051587301587303, 0.05049603174603175))",8.0


In [0]:
# Random-forest model, tracked as one named MLflow run.
with mlflow.start_run(run_name='RandomForestClassifier'):
  # Explicit seed so the fitted forest does not depend on the library's
  # default seed; same value as the train/test split seed.
  rf = RandomForestClassifier(seed=2)
  # Every grid axis has a single value, so there is exactly one candidate
  # (maxDepth=10, minInfoGain=0.001, numTrees=30) and the 5-fold CV does no
  # tuning; its only product is the out-of-fold score.
  paramGrid = ParamGridBuilder().addGrid(rf.maxDepth, [10]) \
                                .addGrid(rf.minInfoGain, [0.001]) \
                                .addGrid(rf.numTrees, [30]) \
                                .build()
  crossval = CrossValidator(estimator=rf,
                            evaluator=evaluator,
                            numFolds=5,
                            estimatorParamMaps=paramGrid,
                            seed=2)  # pin fold assignment as well
  cvModel = crossval.fit(train)

  metric_name = evaluator.getMetricName()
  # avgMetrics holds one fold-averaged score per param map; there is exactly
  # one map here. Logging it keeps the cross-validation result instead of
  # discarding it.
  mlflow.log_metric('cv_' + metric_name, float(cvModel.avgMetrics[0]))

  # Score the refit model on the data it was trained on and on the held-out split.
  train_pred = cvModel.transform(train)
  train_metric = evaluator.evaluate(train_pred)
  mlflow.log_metric('train_' + metric_name, train_metric)
  test_pred = cvModel.transform(test)
  test_metric = evaluator.evaluate(test_pred)
  mlflow.log_metric('test_' + metric_name, test_metric)

  #mlflow.spark.log_model(spark_model=cvModel.bestModel, artifact_path='best-model')

In [0]:
# Inspect the test-set predictions held in `test_pred`. When cells run in
# notebook order this is the random-forest output from the cell above;
# every model cell reassigns the same name.
display(test_pred)

features,label,rawPrediction,probability,prediction
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883426, 0.2579957218484962, 0.35568780181844367, 1.1801621800971176, -0.15661325072631158, 1.546495373470508, 4.4224293316065655, 0.39541571577255413, 1.5100274181792657, 0.2350648666164713, 2.921608831058877, -0.049427546321821024))",19,"Map(vectorType -> dense, length -> 21, values -> List(3.963686312163781, 1.1208110164295122, 0.43264703702811463, 1.24499721408606, 0.7904986289162985, 0.13439392156114402, 3.6262470455694453, 0.01092670285760479, 5.672721968020863, 0.18467909673233254, 0.451447536936876, 0.28354397379990387, 1.1203426093237714, 0.046803051663225036, 0.4345344269039295, 3.468630837215247, 0.07504904612048718, 1.32835447823654, 0.12533524126739723, 5.434742238558925, 0.049607616608542154))","Map(vectorType -> dense, length -> 21, values -> List(0.13212287707212603, 0.03736036721431707, 0.014421567900937153, 0.04149990713620199, 0.026349954297209946, 0.004479797385371467, 0.1208749015189815, 3.6422342858682624E-4, 0.18909073226736206, 0.0061559698910777505, 0.015048251231229198, 0.009451465793330128, 0.037344753644125706, 0.001560101722107501, 0.01448448089679765, 0.11562102790717489, 0.002501634870682906, 0.04427848260788466, 0.004177841375579907, 0.1811580746186308, 0.0016535872202847383))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.271770900598584, -0.2738653806602984, -0.24059886637983785, -1.1406789069018983, 1.2772404593442372, 0.17801116562340052, -0.20161451296512697, 0.39541571577255413, -0.6919233318719775, 0.2350648666164713, 0.1499707968926724, -0.2963566777462492))",19,"Map(vectorType -> dense, length -> 21, values -> List(1.0878056193555112, 0.917789496946721, 0.8015300580595099, 0.7546719097618191, 0.5166101768646097, 0.21890590201118623, 0.3493591418384885, 0.12130607948654924, 8.50880502552492, 0.38032936793875843, 1.5975758281042434, 0.31906339727826794, 0.28203492733523705, 0.27624748283142975, 0.2549370168647153, 0.2633475042520066, 0.46065143731844643, 0.9924190713508001, 0.18516056813419754, 4.883631570463335, 6.827818418279245))","Map(vectorType -> dense, length -> 21, values -> List(0.03626018731185037, 0.030592983231557364, 0.026717668601983664, 0.02515573032539397, 0.017220339228820325, 0.007296863400372874, 0.011645304727949616, 0.004043535982884975, 0.283626834184164, 0.012677645597958614, 0.05325252760347478, 0.010635446575942264, 0.009401164244507902, 0.009208249427714326, 0.008497900562157178, 0.008778250141733553, 0.01535504791061488, 0.03308063571169333, 0.006172018937806585, 0.16278771901544448, 0.22759394727597482))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.7984745630100742, 1.055787375611688, 0.35568780181844367, 0.5443153235601492, -0.15661325072631158, -0.6641329622825117, -0.19915397927167725, 0.5279211453053034, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.17289211203403512))",1,"Map(vectorType -> dense, length -> 21, values -> List(1.078975557584326, 0.6203042711057606, 0.8766278143004868, 0.6173414224940774, 0.31162729059997, 0.5148171016977008, 0.38358758213029215, 0.25709739147996596, 11.11729860800616, 0.2673493162870491, 4.439882167429507, 0.4631077609748697, 0.14345211506354222, 0.21079374355374153, 0.2707281150859598, 0.43613132794142845, 0.7872957939181136, 0.6915379460396875, 0.0799180676535639, 4.208616821929604, 2.2235097847241962))","Map(vectorType -> dense, length -> 21, values -> List(0.035965851919477536, 0.020676809036858687, 0.02922092714334956, 0.020578047416469247, 0.010387576353332335, 0.017160570056590028, 0.012786252737676406, 0.008569913049332198, 0.37057662026687205, 0.00891164387623497, 0.1479960722476502, 0.015436925365828991, 0.004781737168784741, 0.007026458118458051, 0.009024270502865327, 0.014537710931380949, 0.026243193130603786, 0.023051264867989585, 0.0026639355884521304, 0.14028722739765345, 0.07411699282413987))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-0.7542095346964194, 0.2579957218484962, -0.02527312508601396, 1.1801621800971176, -0.15661325072631158, -0.032524866353077546, -0.12425857237021509, -0.49426359680447707, 0.7760438348288513, 0.2350648666164713, 1.5357898139757746, -0.3375115329836539))",16,"Map(vectorType -> dense, length -> 21, values -> List(1.6994257689042946, 0.8656604324501856, 0.7873718843301387, 0.9969814302913341, 0.5966344990345942, 0.9987964293787732, 0.7479678885093738, 0.0905316157078944, 6.373921299713372, 0.23223471474805069, 3.1717713273697843, 0.8711808698051433, 0.43746581045649097, 0.12057297787575336, 0.37976429907340337, 2.4332431113428754, 0.6420578694291529, 1.7196666425937686, 0.3080670468156693, 4.7123294423961255, 1.8143546397738215))","Map(vectorType -> dense, length -> 21, values -> List(0.05664752563014316, 0.028855347748339522, 0.02624572947767129, 0.03323271434304447, 0.019887816634486473, 0.033293214312625774, 0.02493226295031246, 0.00301772052359648, 0.21246404332377905, 0.007741157158268356, 0.10572571091232615, 0.02903936232683811, 0.014582193681883033, 0.004019099262525112, 0.012658809969113446, 0.08110810371142918, 0.021401928980971763, 0.057322221419792284, 0.010268901560522311, 0.15707764807987085, 0.060478487992460715))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.1774009239274768, -0.805726483169093, 0.3391242832573803, 1.4345009833509863, 1.2772404593442372, -0.1377928823413166, -0.2000808470226026, 0.8118613514469092, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.5021309539332727))",8,"Map(vectorType -> dense, length -> 21, values -> List(2.485860142400155, 0.8905954553286674, 1.1482698307586539, 1.2464533625642122, 0.3209077311939832, 0.26516175806119036, 0.8352991679596945, 0.22558475042485504, 9.167718196050233, 0.3057187625244567, 1.997859359515262, 0.3671513896244195, 0.3354530792479859, 0.20108393740098712, 0.3652176790962527, 0.44788676022077645, 1.0714961946540305, 1.384391919577171, 0.20777036375915794, 5.208809897569694, 1.521310262068162))","Map(vectorType -> dense, length -> 21, values -> List(0.08286200474667181, 0.029686515177622236, 0.038275661025288454, 0.04154844541880706, 0.010696924373132771, 0.008838725268706342, 0.027843305598656477, 0.007519491680828499, 0.3055906065350077, 0.010190625417481888, 0.06659531198384205, 0.012238379654147313, 0.011181769308266193, 0.006702797913366236, 0.012173922636541754, 0.014929558674025878, 0.03571653982180101, 0.04614639731923902, 0.00692567879197193, 0.17362699658565642, 0.05071034206893872))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.168258627418151, 1.055787375611688, 0.3391242832573803, -1.1406789069018983, -0.15661325072631158, 0.17801116562340052, -0.19798122492948467, 0.3007689803920189, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.3375115329836539))",8,"Map(vectorType -> dense, length -> 21, values -> List(1.5605309338905862, 1.1603163827977991, 1.3475993755520475, 1.4945867895248826, 0.3962306897705596, 0.6560016912244745, 0.7798797512264882, 0.12662079551376812, 8.744368208766456, 0.3412204875188709, 1.32596223594583, 0.6504249659103237, 0.2586663494367529, 0.18410184288916195, 0.4283971922963751, 0.25379281944237164, 0.5432610477434776, 1.632394130570946, 0.39449108539970856, 5.489484819942644, 2.231668404636477))","Map(vectorType -> dense, length -> 21, values -> List(0.05201769779635287, 0.03867721275992663, 0.044919979185068244, 0.04981955965082941, 0.013207689659018652, 0.021866723040815814, 0.025995991707549605, 0.00422069318379227, 0.29147894029221516, 0.011374016250629029, 0.04419874119819433, 0.021680832197010787, 0.008622211647891762, 0.006136728096305398, 0.014279906409879167, 0.008459760648079054, 0.018108701591449252, 0.05441313768569819, 0.01314970284665695, 0.18298282733142143, 0.0743889468212159))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.4879377434687755, -0.805726483169093, 0.30599724613525353, 0.41714592193321476, -0.15661325072631158, -0.1377928823413166, 0.17350016639456328, 1.7961873994044757, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.460976098695868))",8,"Map(vectorType -> dense, length -> 21, values -> List(5.508454731458979, 1.2444992855528858, 1.1294310471676037, 1.5455384709954136, 0.32605504911519234, 0.4582676257091013, 1.3253164545855403, 0.095688933145474, 8.689866459556322, 0.37427149949522226, 0.5322808606276366, 0.37332858988191475, 0.40129976594741124, 0.4260510462954058, 0.37090163138930227, 0.27697402404098953, 0.39220279464992275, 1.2744815174975523, 0.20992970475461784, 4.3685788092136635, 0.6765816989198485))","Map(vectorType -> dense, length -> 21, values -> List(0.18361515771529932, 0.04148330951842953, 0.03764770157225346, 0.05151794903318046, 0.010868501637173079, 0.015275587523636712, 0.04417721515285135, 0.0031896311048491338, 0.2896622153185441, 0.012475716649840743, 0.017742695354254556, 0.01244428632939716, 0.01337665886491371, 0.014201701543180196, 0.012363387712976744, 0.009232467468032985, 0.01307342648833076, 0.042482717249918416, 0.006997656825153928, 0.14561929364045548, 0.022552723297328287))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.271770900598584, -1.0716570344234904, 0.05754446771930292, 0.6078999485747645, -0.15661325072631158, 0.3885471975998786, -0.08085248836826865, -0.05888861405401499, -0.6919233318719775, 0.2350648666164713, 0.1499707968926724, -0.4198212434584633))",8,"Map(vectorType -> dense, length -> 21, values -> List(2.1916129443699983, 0.8024843859058357, 0.5893558833625192, 1.132215555592053, 0.5099536646095233, 0.2603556566352401, 0.5899273213808199, 0.11112171050757323, 5.821987373783151, 0.1851980126867724, 2.729205090145368, 0.49732165465038164, 0.7121263504278954, 0.07346321793607821, 0.3375173192714426, 0.8438194942992455, 0.3228893075894859, 2.6007101850224577, 0.2937619797380872, 6.298969763960764, 3.096003128125308))","Map(vectorType -> dense, length -> 21, values -> List(0.07305376481233329, 0.02674947953019453, 0.019645196112083978, 0.03774051851973511, 0.016998455486984115, 0.008678521887841338, 0.019664244046027336, 0.0037040570169191087, 0.19406624579277176, 0.006173267089559082, 0.09097350300484562, 0.01657738848834606, 0.023737545014263187, 0.0024487739312026077, 0.011250577309048089, 0.028127316476641524, 0.010762976919649532, 0.08669033950074861, 0.009792065991269576, 0.20996565879869217, 0.10320010427084363))",19.0
"Map(vectorType -> dense, length -> 12, values -> List(1.00549910937094, -0.007934829405901128, 0.32256076469631695, 0.22639189529166512, -0.15661325072631158, 0.17801116562340052, -0.017453743994327402, 0.3007689803920189, 0.04206025147843692, 0.2350648666164713, 0.1499707968926724, -0.2963566777462492))",17,"Map(vectorType -> dense, length -> 21, values -> List(2.906975038829597, 1.062034139640062, 0.9562412474319057, 1.7340945908970067, 0.575111189257107, 0.7324292951222175, 0.9033868916127891, 0.10721754486804302, 7.27656499638, 0.2080238816201696, 0.8389640102752817, 0.710578751104055, 0.49367452815625196, 0.1350532949910107, 0.40701441147703304, 0.4533518269332813, 0.22869415601978865, 2.8809535016106733, 0.4133645541836712, 6.205258786788003, 0.771013362802054))","Map(vectorType -> dense, length -> 21, values -> List(0.09689916796098655, 0.03540113798800207, 0.031874708247730185, 0.05780315302990022, 0.0191703729752369, 0.02441430983740725, 0.030112896387092965, 0.0035739181622681006, 0.24255216654599998, 0.006934129387338986, 0.027965467009176053, 0.023685958370135162, 0.0164558176052084, 0.0045017764997003565, 0.013567147049234432, 0.015111727564442708, 0.007623138533992954, 0.09603178338702244, 0.013778818472789037, 0.20684195955960008, 0.025700445426735132))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883426, 1.3217179268660852, 0.1403620605246198, 0.639692261082072, -0.15661325072631158, -0.34832891431779467, 1.5818405088941414, 0.5468504923814105, 0.7760438348288513, 1.5971841878435262, 4.3074278481419785, -0.17289211203403512))",19,"Map(vectorType -> dense, length -> 21, values -> List(3.329570435157281, 1.263569552531437, 0.8750592524624783, 1.090860158913073, 0.22056766500569158, 0.31038383917977885, 1.256420769992432, 0.04456736610752648, 12.646526897188187, 0.27765392936057626, 0.4220946484932457, 0.21165335393328638, 0.5102828661570035, 0.24594163801480778, 0.3893109225831688, 0.11248197124704581, 0.22662826882344506, 1.2014931024766713, 0.16270363723020154, 4.8217610248022815, 0.380468700340381))","Map(vectorType -> dense, length -> 21, values -> List(0.11098568117190939, 0.04211898508438124, 0.029168641748749283, 0.03636200529710244, 0.0073522555001897215, 0.010346127972659297, 0.041880692333081074, 0.0014855788702508828, 0.42155089657293965, 0.009255130978685877, 0.014069821616441526, 0.007055111797776214, 0.017009428871900118, 0.008198054600493596, 0.012977030752772296, 0.0037493990415681947, 0.007554275627448171, 0.04004977008255572, 0.005423454574340052, 0.16072536749340943, 0.012682290011346036))",8.0


In [0]:
# Train a multilayer-perceptron classifier under k-fold cross-validation and
# record its configuration and scores in a single MLflow run.
# Relies on `evaluator`, `train` and `test` defined in earlier cells.
with mlflow.start_run():
  # Network shape: the first entry is the input width and the last entry the
  # number of output classes (the displayed feature vectors have length 12 and
  # the probability vectors length 21); the middle entries are hidden layers.
  mlp_layers = [12, 24, 36, 21]
  mlp_seed = 1
  cv_folds = 5

  nn = MultilayerPerceptronClassifier(layers=mlp_layers, seed=mlp_seed)
  # An empty grid yields exactly one parameter map, so cross-validation is used
  # here only to estimate the out-of-fold score, not to select hyperparameters.
  paramGrid = ParamGridBuilder().build()
  crossval = CrossValidator(estimator=nn,
                            evaluator=evaluator,
                            numFolds=cv_folds,
                            estimatorParamMaps=paramGrid)
  cvModel = crossval.fit(train)

  metric_name = evaluator.getMetricName()

  # Log the configuration so the run can be reproduced and compared later.
  mlflow.log_param('mlp_layers', str(mlp_layers))
  mlflow.log_param('mlp_seed', mlp_seed)
  mlflow.log_param('cv_num_folds', cv_folds)

  # avgMetrics holds one fold-averaged score per parameter map; with a single
  # map, index 0 is the cross-validated score that was previously discarded.
  mlflow.log_metric('cv_avg_' + metric_name, float(cvModel.avgMetrics[0]))

  # Score the refit best model on the training split (optimistic estimate).
  train_pred = cvModel.transform(train)
  train_metric = evaluator.evaluate(train_pred)
  mlflow.log_metric('train_' + metric_name, train_metric)

  # Score on the held-out test split; `test_pred` is displayed in the next cell.
  test_pred = cvModel.transform(test)
  test_metric = evaluator.evaluate(test_pred)
  mlflow.log_metric('test_' + metric_name, test_metric)

  #mlflow.spark.log_model(spark_model=cvModel.bestModel, artifact_path='best-model')

In [0]:
# Show the test-set predictions produced by cvModel.transform(test).
display(test_pred)

features,label,rawPrediction,probability,prediction
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883428, 0.25799572184849595, 0.3556878018184399, 1.1801621800971174, -0.15661325072631127, 1.5464953734705078, 4.4224293316065655, 0.39541571577255435, 1.5100274181792657, 0.2350648666164713, 2.921608831058877, -0.0494275463218211))",19,"Map(vectorType -> dense, length -> 21, values -> List(2.9424825598276905, 0.8171773138519575, -0.3029759992453993, 0.9423145743542523, 1.001029262974704, -0.6127768565738613, 2.148948057019174, -3.061058608992228, 1.7878171514469283, -1.509231775039811, -1.6734844532370696, -0.06501991430605603, 0.8494465739218496, -1.6258581529393579, -0.5793026572135084, 2.0439844148266837, -1.8883884364597765, 2.2472577483848624, -0.5296105025100267, 2.157885000982854, -3.045803989047968))","Map(vectorType -> dense, length -> 21, values -> List(0.2581450339469401, 0.030821618036675005, 0.010054929532389652, 0.0349302630156859, 0.037042587979878125, 0.0073762262357265984, 0.11674441652971557, 6.376156575719127E-4, 0.08135775562880875, 0.0030096001794479143, 0.002553728788239259, 0.012756221328415468, 0.03183243021324126, 0.0026782962407328505, 0.007627318615686767, 0.10511168263758044, 0.0020598858687965378, 0.12880463030724215, 0.008015911562982468, 0.11779243088579142, 6.474168084518115E-4))",0.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.2717709005985838, -0.27386538066029864, -0.2405988663798416, -1.1406789069018983, 1.2772404593442375, 0.17801116562340036, -0.20161451296512697, 0.39541571577255435, -0.6919233318719775, 0.2350648666164713, 0.14997079689267223, -0.2963566777462493))",19,"Map(vectorType -> dense, length -> 21, values -> List(0.9237342779799518, 0.5472760835123789, 0.8011405301976666, 0.13149576646694056, -0.6768187932289335, -0.3694554881481985, 0.2997149334512167, -1.2265566189819765, 2.716853648809647, -0.563957175661464, 0.4405358271545526, -0.22392172544785932, -0.8416220839395787, -0.4108444683251751, -0.31592448949913854, -2.5325888574126547, 0.39435151083583386, 0.3221027898004013, -0.7166081963951377, 2.2881401897197176, 2.89099994099285))","Map(vectorType -> dense, length -> 21, values -> List(0.04086449919252676, 0.028044807564180485, 0.03614967462750166, 0.018504670953919004, 0.008245838299426976, 0.011212978808765288, 0.021894650798161903, 0.004758682283696262, 0.2455206521615018, 0.009231026373694804, 0.025205525715347128, 0.012969568027959007, 0.006992969825850044, 0.010758358092100314, 0.011829577088073901, 0.001289092600412775, 0.02406789822198555, 0.022390353248338348, 0.00792418298718422, 0.15991902725306686, 0.29222596587630684))",20.0
"Map(vectorType -> dense, length -> 12, values -> List(0.7984745630100745, 1.0557873756116878, 0.3556878018184399, 0.5443153235601491, -0.15661325072631127, -0.664132962282512, -0.19915397927167725, 0.5279211453053037, 0.04206025147843692, 0.2350648666164713, 0.14997079689267223, -0.17289211203403518))",1,"Map(vectorType -> dense, length -> 21, values -> List(1.1550158430063093, 0.15309412418625812, 0.6744852004324718, 0.8940931713976363, -0.5851643491796813, -0.09826279405849506, -0.42130324706322486, -0.7235974651580912, 3.002876825991987, -1.127281917142253, 2.1806728213713997, -0.1499008974279083, -0.702447796124314, -0.5732863804044818, -0.7456963136635639, -0.9236717843021796, 0.36433455922713764, 0.35561469018480185, -3.683525858423257, 2.1518394793582893, 1.454440510981894))","Map(vectorType -> dense, length -> 21, values -> List(0.05358204202734361, 0.01967388764654289, 0.03313808855861169, 0.04127641632314336, 0.00940304697426286, 0.015301262450076862, 0.011077265834926188, 0.008187435040717237, 0.34004397861797997, 0.005468018023245746, 0.14943641034796026, 0.01453118797951129, 0.008362440714693945, 0.009515402026540937, 0.008008486712628243, 0.006702806691295808, 0.02430136912512184, 0.02409038558010505, 4.2429439948079087E-4, 0.14518918432377945, 0.07228659060203198))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-0.7542095346964192, 0.25799572184849595, -0.025273125086017728, 1.1801621800971174, -0.15661325072631127, -0.03252486635307773, -0.12425857237021509, -0.49426359680447673, 0.7760438348288513, 0.2350648666164713, 1.5357898139757746, -0.337511532983654))",16,"Map(vectorType -> dense, length -> 21, values -> List(0.561407041308881, 0.27016419578426654, 0.28385843795186105, 0.3493018788559263, -0.08593794386973666, -0.6041163201914108, 0.13173102098850487, -1.6708799463874373, 2.372555051472777, -1.3910236696159013, 1.3638003798303246, 0.16603698758459778, -0.18456233508675, -1.2638883952881288, -0.6077094707225242, 0.7913651766898474, -0.27825964576639284, 1.0852303767848257, -1.3759910719404038, 2.3377187403274666, 0.22137173982547048))","Map(vectorType -> dense, length -> 21, values -> List(0.039746330473051726, 0.029703790863673217, 0.03011335973612278, 0.0321499970601752, 0.020804567137382424, 0.012391294383435687, 0.025863727419503425, 0.004264097421300036, 0.24314683785719085, 0.005641143599197403, 0.08866904652111793, 0.026766402633185426, 0.018850664123702645, 0.00640591693977688, 0.012346850492053543, 0.05002263770553601, 0.01716462937235645, 0.06711052299036714, 0.005726585237243371, 0.23482233823465692, 0.02828925979897072))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.17740092392747697, -0.8057264831690932, 0.33912428325737654, 1.4345009833509863, 1.2772404593442375, -0.13779288234131676, -0.2000808470226026, 0.8118613514469093, 0.04206025147843692, 0.2350648666164713, 0.14997079689267223, -0.5021309539332728))",8,"Map(vectorType -> dense, length -> 21, values -> List(1.176450781077448, 0.080276652400119, 0.4102873183155068, 0.318509188677044, -0.5323187498452755, -0.4410836672287144, -0.3431674992667373, -0.5240705410582556, 2.6208078103047936, -1.2271937935476849, 2.625017182914333, 0.21557947245672499, -0.473503168158497, -0.608322476733948, -0.8075365901908789, -0.2437692477269749, 0.7853832161447302, 0.6658459428160679, -3.6342386579185386, 2.7153594309098525, 0.9583117221233393))","Map(vectorType -> dense, length -> 21, values -> List(0.0513871623707788, 0.017170868452334478, 0.023884385502717757, 0.021789905008570833, 0.009305623092275365, 0.010194556829007599, 0.011243274206969455, 0.009382695231026566, 0.21783663718729662, 0.004644779066069828, 0.21875552537581242, 0.019658642692989763, 0.009869354318430341, 0.008624570151181342, 0.00706675229868138, 0.012418264757010935, 0.03475490861037156, 0.030839105498346585, 4.1840682381577614E-4, 0.23943860364522007, 0.04131597888109236))",19.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.1682586274181508, 1.0557873756116878, 0.33912428325737654, -1.1406789069018983, -0.15661325072631127, 0.17801116562340036, -0.19798122492948467, 0.3007689803920191, 0.04206025147843692, 0.2350648666164713, 0.14997079689267223, -0.337511532983654))",8,"Map(vectorType -> dense, length -> 21, values -> List(0.8272125559727164, 0.31976080669028495, 0.8640696313991058, 1.46523790620263, -0.28691945661323154, 0.47921568566016515, -0.44684802594057094, -1.290132100212377, 2.692134939300956, -1.6302504934780404, 0.8218208573323929, 0.25806893763375016, -0.8000159124018686, -1.1743136839893389, -0.5356490082835355, -1.6271449109132932, -0.5412294403127188, 1.1963509761354003, -1.7209053182380485, 2.0345640802684697, 1.6926758819742243))","Map(vectorType -> dense, length -> 21, values -> List(0.044960457707952685, 0.027067442898064186, 0.046648485485856844, 0.08509844272860513, 0.014756023514446449, 0.03174662801853752, 0.012575152007160991, 0.00541102603251863, 0.2902393040842915, 0.0038509517918937195, 0.044718696805897684, 0.02544806649786972, 0.00883353184858321, 0.006075456594000826, 0.011506611941864222, 0.0038629298304030937, 0.011442578908204593, 0.06503475413679329, 0.003517201031022198, 0.1503754114072475, 0.106830846728786))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(0.4879377434687757, -0.8057264831690932, 0.3059972461352498, 0.41714592193321465, -0.15661325072631127, -0.13779288234131676, 0.17350016639456328, 1.7961873994044755, 0.04206025147843692, 0.2350648666164713, 0.14997079689267223, -0.4609760986958681))",8,"Map(vectorType -> dense, length -> 21, values -> List(2.5603594518274897, 0.8360573920628589, 0.7740204443857327, 1.075863818344729, -0.44656323270898113, -0.19253692896291072, 0.6887453265652176, -1.0023066479331049, 3.1134121317650756, -0.47311773377126476, -0.14861116796996066, -0.5898547663409723, -0.41737010414166, -0.13541892701624478, -0.4240494657538783, -2.1684903935694746, -0.14080001578021537, 0.5951668693186162, -2.6600336318275537, 2.378168463296075, 0.26268682164571217))","Map(vectorType -> dense, length -> 21, values -> List(0.19652310549242274, 0.035039568441855566, 0.03293187396032047, 0.04453539989195008, 0.009716809511801964, 0.01252696635186467, 0.030240010070716465, 0.005574005466867813, 0.34166567126775027, 0.009462180222463895, 0.013089487011389534, 0.008419629218116025, 0.010004654689337514, 0.013263310719175967, 0.009938052659995441, 0.0017366017115217614, 0.013192131349900499, 0.027538566133918807, 0.0010622474781493234, 0.16379063904008223, 0.019749089310398822))",8.0
"Map(vectorType -> dense, length -> 12, values -> List(-1.2717709005985838, -1.0716570344234906, 0.05754446771929915, 0.6078999485747643, -0.15661325072631127, 0.38854719759987844, -0.08085248836826865, -0.058888614054014715, -0.6919233318719775, 0.2350648666164713, 0.14997079689267223, -0.4198212434584634))",8,"Map(vectorType -> dense, length -> 21, values -> List(1.1208177262330299, 0.5544790768035045, 0.4830787948631806, 1.066340861643561, -0.0252513730218731, 0.16912877244062932, -0.37343582039907797, -1.599648037530184, 2.37825428561433, -2.0163977683939773, 0.6904258618612863, 0.3317041892271651, -0.556814542065248, -1.6620866213329932, -0.5106286348307969, -0.688289359031154, -0.6579626328995062, 1.6899387857291028, -1.450327735357964, 2.5947229023382423, 1.7723117851945929))","Map(vectorType -> dense, length -> 21, values -> List(0.056797580395140136, 0.03223829645357699, 0.030016727054825963, 0.05378619638905436, 0.01805503560735361, 0.02192888620184063, 0.012746289115933519, 0.003739784398778852, 0.19972277610563163, 0.0024652130444176735, 0.036932867094523474, 0.02580014698841632, 0.010610684422442207, 0.003513418090471068, 0.011112241795603422, 0.00930346248308766, 0.009589927865391873, 0.10034505391017232, 0.004342057284798765, 0.24799259881899396, 0.1089607564795456))",19.0
"Map(vectorType -> dense, length -> 12, values -> List(1.0054991093709402, -0.007934829405901365, 0.3225607646963132, 0.22639189529166495, -0.15661325072631127, 0.17801116562340036, -0.017453743994327402, 0.3007689803920191, 0.04206025147843692, 0.2350648666164713, 0.14997079689267223, -0.2963566777462493))",17,"Map(vectorType -> dense, length -> 21, values -> List(1.761605427825576, 0.6002269301476957, 0.6312199464216074, 1.6249203432754742, 0.11856788057593856, 0.398538353381482, 0.042449321520787175, -1.6334129917064384, 2.472902198518429, -1.9638003155624915, 0.16354647353643625, 0.37946207654743147, -0.33092631063420763, -1.5011661151120876, -0.5399573472625817, -0.6217405105147754, -0.9571637560876599, 1.8313083922639026, -1.9223831125564348, 2.534457966673699, 0.09229867665116538))","Map(vectorType -> dense, length -> 21, values -> List(0.10464765160566285, 0.032760401362108886, 0.033791643114115984, 0.0912783768802626, 0.020237988578103425, 0.02677669537706906, 0.01875467213432662, 0.0035098755768069004, 0.21312860841441422, 0.0025223556957990776, 0.021169046709740923, 0.026270736967972864, 0.01291084001466989, 0.004006137119702773, 0.010475468947775266, 0.009652848661177191, 0.0069021283755590504, 0.11220212952325638, 0.0026290181981286718, 0.226660101933931, 0.019713274809416377))",19.0
"Map(vectorType -> dense, length -> 12, values -> List(0.3844254702883428, 1.321717926866085, 0.14036206052461603, 0.6396922610820719, -0.15661325072631127, -0.34832891431779484, 1.5818405088941414, 0.5468504923814107, 0.7760438348288513, 1.5971841878435262, 4.3074278481419785, -0.17289211203403518))",19,"Map(vectorType -> dense, length -> 21, values -> List(2.9087233322678894, 0.7853454614202937, -0.08677432266619733, 0.4579622986266086, 0.5849357397396496, -1.1041103155529888, 2.1317691038464486, -2.4329111582728262, 2.319099385728609, -0.7246788307185436, -1.1209414068401848, -0.3811445612916889, 0.8288060458404064, -0.7888146027416972, -0.6385902113172728, 1.1440085753318077, -1.3295531916987915, 1.3951122524380648, -1.116565566917009, 1.9481345790063045, -2.4995401373667225))","Map(vectorType -> dense, length -> 21, values -> List(0.28887057470067623, 0.03455668013970518, 0.014446917853371759, 0.024908714340433888, 0.02828102703061196, 0.005223382087507556, 0.1328239438646012, 0.0013831233776409786, 0.1601890696422854, 0.007633730006130015, 0.005136202587871783, 0.010762962647666997, 0.036091647340913795, 0.007159504784376053, 0.008320024552178343, 0.04946493290995526, 0.00416910850602208, 0.0635843689915912, 0.005158727033953067, 0.11054138725588086, 0.0012939703466263248))",0.0


In [0]:
# Disabled experiment (every line is commented out, so nothing here runs):
# an RDD-based SVM baseline built with pyspark.mllib's SVMWithSGD on the same
# `train` / `test` DataFrames, converting each row to a LabeledPoint.
# NOTE(review): `svmTotalCorrect` is not defined anywhere in this cell, so the
# accuracy line would fail if the block were simply uncommented.
# NOTE(review): `evaluator(scoreAndLabels)` calls `evaluator` like a function
# and then reads areaUnderPR / areaUnderROC from the result; this presumably
# was meant to be a binary-metrics object. Confirm before reviving.
#from pyspark.mllib.classification import SVMModel
#from pyspark.mllib.classification import SVMWithSGD
#from pyspark.mllib.regression import LabeledPoint
#train1 = train.rdd
#test1 = test.rdd
#trainrdd = train1.map(lambda row:LabeledPoint(row[1], row[0][:]))
#testrdd = test1.map(lambda row:LabeledPoint(row[1], row[0][:]))

#svm = SVMWithSGD.train(sc.parallelize(trainrdd.collect()))

# (the message below reads: "Overall prediction accuracy is {}")
#svmAccuracy = float(svmTotalCorrect)/testrdd.count()
#print("总体预测准确率为{}".format(svmAccuracy))
 
# AUC computation
# (the message below reads: "PR value: {:.4f}, AUC value: {:.4f}")
#scoreAndLabels = testrdd.map(lambda x:(float(svm.predict(x.features)),x.label))
#metrics = evaluator(scoreAndLabels)
#print('PR值:{:.4f},AUC值:{:.4f}'.format(metrics.areaUnderPR, metrics.areaUnderROC))