In [10]:
import os
import shutil
import zipfile

In [11]:
base_folder = os.getcwd()
temporary_folder = os.path.join(os.getcwd(), "tmp")

In [12]:
def unzip_files():
# Unzip file on a temporary folder
    if os.path.exists(temporary_folder):
        shutil.rmtree(temporary_folder)
        
    if not os.path.exists(temporary_folder):
        os.makedirs(temporary_folder)
        
    local_file_name = os.path.join(base_folder, "training_dataset", "trainingandtestdata.zip")
    with zipfile.ZipFile(local_file_name, 'r') as zip_ref:
        zip_ref.extractall(temporary_folder)

In [29]:
def cleansing_and_tokenizing(tweet):
# Cleansing and tokenizing tweet
    from nltk.tokenize import word_tokenize
    from nltk.corpus import stopwords 
    from string import punctuation 
    from bs4 import BeautifulSoup
    import re
    
    terms_to_remove = set(stopwords.words("english") + ["USERTAGGING","URL"])
    tweet = BeautifulSoup(tweet, 'html.parser').get_text() # Extracts text from HTML (just in case!)
    tweet = tweet.lower() # Converts text to lower-case
    tweet = re.sub("((www\.[^\s]+)|(https?://[^\s]+))", "URL", tweet) # Replces URLs by URL constan
    tweet = re.sub("@[^\s]+", "USERTAGGING", tweet) # Replaces usernames by USERTAGGING constant 
    tweet = re.sub(r"#([^\s]+)", r"\1", tweet) # Removes the # in #hashtag
    for p in punctuation: 
        tweet = tweet.replace(p, "") # Removes punctiation
    tweet = word_tokenize(tweet) # Creates a list of words
    words = ""
    for each_word in tweet:
        if each_word not in terms_to_remove:
            words = words + " " + each_word
    # return [word for word in tweet if word not in terms_to_remove]
    return words[1:]

In [13]:
unzip_files()

In [18]:
from pyspark.sql import SparkSession, functions


spark = SparkSession.builder.master("local").appName("Training Twitter Sentiment Analysis").getOrCreate()
training_data = spark.read.load(
    "tmp/training.1600000.processed.noemoticon.csv",
    format="csv")
training_data = training_data.withColumnRenamed("_c0", "label") \
    .withColumnRenamed("_c1", "tweet_id") \
    .withColumnRenamed("_c2", "date") \
    .withColumnRenamed("_c3", "query") \
    .withColumnRenamed("_c4", "user") \
    .withColumnRenamed("_c5", "tweet")


In [19]:
sample_size = 10000
training_data = training_data.sample(sample_size / 1600000)

training_data = training_data.select(functions.col("label"), functions.col("tweet"))

In [30]:
udf_cleansing_and_tokenizing = functions.udf(cleansing_and_tokenizing)
training_data = training_data.withColumn("tweet_cleansed", udf_cleansing_and_tokenizing(functions.col("tweet")))

In [32]:
training_data.show(5)

+-----+--------------------+--------------------+
|label|               tweet|      tweet_cleansed|
+-----+--------------------+--------------------+
|    0|Poor Joshy is sic...|poor joshy sick d...|
|    0|@nicolerichie OH ...|         oh yes miss|
|    0|Just heard that t...|heard found sandr...|
|    0|I don't like the ...|dont like previou...|
|    0|worked his heart ...|worked heart toda...|
+-----+--------------------+--------------------+
only showing top 5 rows



In [33]:
training_data.describe()

DataFrame[summary: string, label: string, tweet: string, tweet_cleansed: string]

In [39]:
from pyspark.ml.feature import Tokenizer

tokenizer = Tokenizer(inputCol="tweet_cleansed", outputCol="words")
training_data = tokenizer.transform(training_data)

In [40]:
from pyspark.ml.feature import Word2Vec

word2vec = Word2Vec(inputCol="words", outputCol="vectorized") 
model = word2vec.fit(training_data)
training_data = model.transform(training_data)   


In [41]:
training_data.show(5)

+-----+--------------------+--------------------+--------------------+--------------------+
|label|               tweet|      tweet_cleansed|               words|          vectorized|
+-----+--------------------+--------------------+--------------------+--------------------+
|    0|Poor Joshy is sic...|poor joshy sick d...|[poor, joshy, sic...|[-0.0268824268132...|
|    0|@nicolerichie OH ...|         oh yes miss|     [oh, yes, miss]|[-0.0248483438044...|
|    0|Just heard that t...|heard found sandr...|[heard, found, sa...|[-0.0023246163701...|
|    0|I don't like the ...|dont like previou...|[dont, like, prev...|[-0.0151059520430...|
|    0|worked his heart ...|worked heart toda...|[worked, heart, t...|[-0.0098882078018...|
+-----+--------------------+--------------------+--------------------+--------------------+
only showing top 5 rows



In [47]:
training_data = training_data.withColumn("label", functions.col("label").cast("integer"))
training, test = training_data.select("label", "vectorized").randomSplit([0.5, 0.5])

In [48]:
from pyspark.ml.classification import NaiveBayes
model = NaiveBayes(featuresCol="vectorized", labelCol="label")
model_train = model.fit(training)
predictions = model_train.transform(test)
predictions.show(5)



Py4JJavaError: An error occurred while calling o407.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 17.0 failed 1 times, most recent failure: Lost task 0.0 in stage 17.0 (TID 21) (192.168.19.2 executor driver): org.apache.spark.SparkException: Failed to execute user defined function(NaiveBayes$$Lambda$3543/0x00000008014d9040: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1193)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.update(Summarizer.scala:374)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.update(Summarizer.scala:344)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:563)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1.$anonfun$applyOrElse$2(AggregationIterator.scala:196)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1.$anonfun$applyOrElse$2$adapted(AggregationIterator.scala:196)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateProcessRow$7(AggregationIterator.scala:213)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateProcessRow$7$adapted(AggregationIterator.scala:207)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:158)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:77)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2(ObjectHashAggregateExec.scala:107)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2$adapted(ObjectHashAggregateExec.scala:85)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:885)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:885)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.lang.IllegalArgumentException: requirement failed: Naive Bayes requires nonnegative feature values but found [-0.0805140420794487,0.05776373557746411,0.04472080077975989,-0.06039468050003052,-0.03948916606605053,-9.597588330507279E-4,0.005369102954864502,-0.09084606654942036,0.02347710451576859,0.06366781643591822,-0.14358641281723977,0.006951858254615218,0.043731844052672386,0.08499483093619348,-0.049567058729007844,0.030544337257742884,0.028140976279973987,-0.053075131960213184,-0.0013623829931020738,0.027514648064970972,0.14804260134696962,0.03093543676659465,-0.010327377170324326,0.006033461587503553,0.006913885287940503,-0.02880595326423645,-0.038855304848402744,-0.06669797860085964,0.005866802111268044,0.09486495889723301,0.06530047934502363,0.039418494701385504,-0.01340730143710971,-0.015596297010779381,-0.07274153102189303,-0.045411583688110116,0.06434355136007071,-0.07642288971692324,-0.040948635526001455,1.460350351408124E-4,0.05751045816577971,-0.00635526767000556,0.19435634613037112,-0.14347269125282766,-0.07047285307198763,-0.017421470303088427,0.04953364711254835,0.036655936390161514,0.005263985693454743,-0.07514694128185512,-0.006953094527125359,0.04137033857405186,-0.06784584894776345,-0.028330883197486403,-0.07520322129130363,0.07573919892311097,-0.004189199849497527,0.006160224974155426,0.06073436979204416,0.00759834535419941,-0.10546217951923609,0.0017996541224420072,0.040534944087266926,-0.05724871009588242,0.008800731413066388,0.09513340722769499,0.021922691445797685,0.08935730271041394,0.0022789261653088032,0.027210958488285544,-0.033789423666894434,-0.0827432606369257,0.01922410521656275,-0.04829748757183552,0.09769701017066837,-0.08675419241189958,-0.047246372886002065,-0.010551116894930601,0.009164162259548903,-0.048094664572272454,-0.07105480320751667,0.03222439514938742,-0.0960861410945654,0.11337167620658875,-0.11992524415254593,0.10749803557991983,-0.07398652248084546,-0.13408291712403297,0.021345064602792264,-0.014502267865464092,-0.012791129859397189,-0.03857611007988453,-0.06581990905106068,0.009313061146531255,-0.013449309300631285,0.007577323634177447,0.03889341484755278,-0.015883912704885008,-0.006004027277231217,-0.011298755276948215].
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.classification.NaiveBayes$.requireNonnegativeValues(NaiveBayes.scala:358)
	at org.apache.spark.ml.classification.NaiveBayes.$anonfun$trainDiscreteImpl$1(NaiveBayes.scala:177)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$f$2(ScalaUDF.scala:205)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1190)
	... 29 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2253)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2202)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2201)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2201)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1078)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1078)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1078)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2440)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2382)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2371)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2202)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2223)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2242)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2267)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1030)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1029)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:390)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3696)
	at org.apache.spark.sql.Dataset.$anonfun$collect$1(Dataset.scala:2965)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3687)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:772)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3685)
	at org.apache.spark.sql.Dataset.collect(Dataset.scala:2965)
	at org.apache.spark.ml.classification.NaiveBayes.trainDiscreteImpl(NaiveBayes.scala:193)
	at org.apache.spark.ml.classification.NaiveBayes.$anonfun$trainWithLabelCheck$1(NaiveBayes.scala:160)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.NaiveBayes.trainWithLabelCheck(NaiveBayes.scala:144)
	at org.apache.spark.ml.classification.NaiveBayes.train(NaiveBayes.scala:133)
	at org.apache.spark.ml.classification.NaiveBayes.train(NaiveBayes.scala:95)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:115)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(NaiveBayes$$Lambda$3543/0x00000008014d9040: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1193)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.update(Summarizer.scala:374)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.update(Summarizer.scala:344)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:563)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1.$anonfun$applyOrElse$2(AggregationIterator.scala:196)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1.$anonfun$applyOrElse$2$adapted(AggregationIterator.scala:196)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateProcessRow$7(AggregationIterator.scala:213)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateProcessRow$7$adapted(AggregationIterator.scala:207)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:158)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:77)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2(ObjectHashAggregateExec.scala:107)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec.$anonfun$doExecute$2$adapted(ObjectHashAggregateExec.scala:85)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:885)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:885)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
Caused by: java.lang.IllegalArgumentException: requirement failed: Naive Bayes requires nonnegative feature values but found [-0.0805140420794487,0.05776373557746411,0.04472080077975989,-0.06039468050003052,-0.03948916606605053,-9.597588330507279E-4,0.005369102954864502,-0.09084606654942036,0.02347710451576859,0.06366781643591822,-0.14358641281723977,0.006951858254615218,0.043731844052672386,0.08499483093619348,-0.049567058729007844,0.030544337257742884,0.028140976279973987,-0.053075131960213184,-0.0013623829931020738,0.027514648064970972,0.14804260134696962,0.03093543676659465,-0.010327377170324326,0.006033461587503553,0.006913885287940503,-0.02880595326423645,-0.038855304848402744,-0.06669797860085964,0.005866802111268044,0.09486495889723301,0.06530047934502363,0.039418494701385504,-0.01340730143710971,-0.015596297010779381,-0.07274153102189303,-0.045411583688110116,0.06434355136007071,-0.07642288971692324,-0.040948635526001455,1.460350351408124E-4,0.05751045816577971,-0.00635526767000556,0.19435634613037112,-0.14347269125282766,-0.07047285307198763,-0.017421470303088427,0.04953364711254835,0.036655936390161514,0.005263985693454743,-0.07514694128185512,-0.006953094527125359,0.04137033857405186,-0.06784584894776345,-0.028330883197486403,-0.07520322129130363,0.07573919892311097,-0.004189199849497527,0.006160224974155426,0.06073436979204416,0.00759834535419941,-0.10546217951923609,0.0017996541224420072,0.040534944087266926,-0.05724871009588242,0.008800731413066388,0.09513340722769499,0.021922691445797685,0.08935730271041394,0.0022789261653088032,0.027210958488285544,-0.033789423666894434,-0.0827432606369257,0.01922410521656275,-0.04829748757183552,0.09769701017066837,-0.08675419241189958,-0.047246372886002065,-0.010551116894930601,0.009164162259548903,-0.048094664572272454,-0.07105480320751667,0.03222439514938742,-0.0960861410945654,0.11337167620658875,-0.11992524415254593,0.10749803557991983,-0.07398652248084546,-0.13408291712403297,0.021345064602792264,-0.014502267865464092,-0.012791129859397189,-0.03857611007988453,-0.06581990905106068,0.009313061146531255,-0.013449309300631285,0.007577323634177447,0.03889341484755278,-0.015883912704885008,-0.006004027277231217,-0.011298755276948215].
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.classification.NaiveBayes$.requireNonnegativeValues(NaiveBayes.scala:358)
	at org.apache.spark.ml.classification.NaiveBayes.$anonfun$trainDiscreteImpl$1(NaiveBayes.scala:177)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$f$2(ScalaUDF.scala:205)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1190)
	... 29 more


In [None]:
    model = NaiveBayes(featuresCol="features", labelCol="label")
    model_train = model.fit(training)
    predictions = model_train.transform(test)
    predictions.show(5)

    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
    accuracy = evaluator.evaluate(predictions)
    print(accuracy)