In [1]:
! pip install gensim

Collecting gensim
  Downloading gensim-4.0.1-cp38-cp38-manylinux1_x86_64.whl (23.9 MB)
[K     |████████████████████████████████| 23.9 MB 3.0 MB/s eta 0:00:01
Collecting smart-open>=1.8.1
  Downloading smart_open-5.1.0-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 3.4 MB/s eta 0:00:01
[?25hInstalling collected packages: smart-open, gensim
Successfully installed gensim-4.0.1 smart-open-5.1.0


In [2]:
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np
import re
import sys
import scipy
import gensim
from gensim.utils import simple_preprocess
import gensim.corpora as corpora




In [3]:
import string
import pyspark
import pyspark.sql.functions as F
from pyspark import broadcast, SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.util import MLUtils
from pyspark.sql.types import *
from pyspark.ml.feature import CountVectorizer, CountVectorizerModel, Tokenizer, RegexTokenizer, StopWordsRemover, OneHotEncoder, StringIndexer, VectorAssembler, VectorIndexer, Bucketizer
from pyspark.ml.linalg import Vectors, SparseVector
from pyspark.ml.clustering import LDA
from pyspark.ml.functions import vector_to_array


In [43]:
from pyspark.sql.functions import month, year, mean, count, dayofweek, hour, col, min, max, avg, sum, when, lit, desc, unix_timestamp, from_unixtime, udf, regexp_replace, isnan

In [5]:
# Build the spark session
spark = SparkSession.builder \
        .appName('kafka') \
        .getOrCreate()

In [6]:
spark.version

'3.1.1'

In [7]:
sc = spark.sparkContext

## Load data from csv, strip to required size then save to parquet

In [6]:
# Load all submission csv files
data = spark.read.format('csv').options(header ='true').load(f"./data/submissions/*.csv")

In [7]:
print(data.printSchema())
data.count()

root
 |-- _c0: string (nullable = true)
 |-- id: string (nullable = true)
 |-- author_fullname: string (nullable = true)
 |-- title: string (nullable = true)
 |-- score: string (nullable = true)
 |-- author_premium: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- over_18: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- permalink: string (nullable = true)
 |-- parent_whitelist_status: string (nullable = true)
 |-- url: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- num_comments: string (nullable = true)
 |-- upvote_ratio: string (nullable = true)

None


127600

In [8]:
data.head(5)

[Row(_c0='0', id='l8zd9o', author_fullname='t2_a30zb78q', title='موقع إخباري متنوع.. تعرف علي آخر الأخبار', score='1', author_premium='False', domain='nabakham.com', over_18='False', subreddit_id='t5_2qh13', permalink='/r/worldnews/comments/l8zd9o/موقع_إخباري_متنوع_تعرف_علي_آخر_الأخبار/', parent_whitelist_status='all_ads', url='https://nabakham.com/', created_utc='1612051145', num_comments='1', upvote_ratio='1.0'),
 Row(_c0='1', id='l8zd2a', author_fullname='t2_9xscy03j', title='New clinical trials raise fears the coronavirus is learning how to resist vaccines', score='1', author_premium='True', domain='google.co.uk', over_18='False', subreddit_id='t5_2qh13', permalink='/r/worldnews/comments/l8zd2a/new_clinical_trials_raise_fears_the_coronavirus/', parent_whitelist_status='all_ads', url='https://www.google.co.uk/amp/s/news.yahoo.com/amphtml/clinical-trials-raise-fears-coronavirus-040855671.html', created_utc='1612051128', num_comments='0', upvote_ratio='1.0'),
 Row(_c0='2', id='l8zaej'

In [8]:
# Only keep the needed columns and remove duplicate rows
df = data.select('id','title','domain','subreddit_id','created_utc','num_comments')
df = df.distinct()
df.count()

127600

In [9]:
df = df.na.fill(value=0,subset=["num_comments"])

In [11]:
df.show(5)

+------+--------------------+--------------------+------------+-----------+------------+
|    id|               title|              domain|subreddit_id|created_utc|num_comments|
+------+--------------------+--------------------+------------+-----------+------------+
|l8w5qa|Jessica Simpson T...|           bluzz.org|    t5_2qh13| 1612042018|           0|
|l8ni30|Pornhub Now Accep...|  cryptobriefing.com|    t5_2qh13| 1612019085|           0|
|l8n6ta|India Mulls Law T...|            ndtv.com|    t5_2qh13| 1612018166|           2|
|l8j19t|tìm hiểu drama là...|indexlink93447958...|    t5_2qh13| 1612003085|           0|
|l8co25|House Democrats p...|           bluzz.org|    t5_2qh13| 1611977317|           0|
+------+--------------------+--------------------+------------+-----------+------------+
only showing top 5 rows



In [10]:
#Remove punctuation and numbers from titles
df = df.withColumn("title", regexp_replace(col("title"), '[^\sa-zA-Z]', ''))
df.show(5)

+------+--------------------+--------------------+------------+-----------+------------+
|    id|               title|              domain|subreddit_id|created_utc|num_comments|
+------+--------------------+--------------------+------------+-----------+------------+
|l8w5qa|Jessica Simpson T...|           bluzz.org|    t5_2qh13| 1612042018|           0|
|l8ni30|Pornhub Now Accep...|  cryptobriefing.com|    t5_2qh13| 1612019085|           0|
|l8n6ta|India Mulls Law T...|            ndtv.com|    t5_2qh13| 1612018166|           2|
|l8j19t|tm hiu drama l g ...|indexlink93447958...|    t5_2qh13| 1612003085|           0|
|l8co25|House Democrats p...|           bluzz.org|    t5_2qh13| 1611977317|           0|
+------+--------------------+--------------------+------------+-----------+------------+
only showing top 5 rows



## Prepare titles and do topics modelling

In [11]:
tokenizer = Tokenizer(inputCol="title", outputCol="words")
wordsDataFrame = tokenizer.transform(df)

In [14]:
wordsDataFrame.show()

+------+--------------------+--------------------+------------+-----------+------------+--------------------+
|    id|               title|              domain|subreddit_id|created_utc|num_comments|               words|
+------+--------------------+--------------------+------------+-----------+------------+--------------------+
|l8w5qa|Jessica Simpson T...|           bluzz.org|    t5_2qh13| 1612042018|           0|[jessica, simpson...|
|l8ni30|Pornhub Now Accep...|  cryptobriefing.com|    t5_2qh13| 1612019085|           0|[pornhub, now, ac...|
|l8n6ta|India Mulls Law T...|            ndtv.com|    t5_2qh13| 1612018166|           2|[india, mulls, la...|
|l8j19t|tm hiu drama l g ...|indexlink93447958...|    t5_2qh13| 1612003085|           0|[tm, hiu, drama, ...|
|l8co25|House Democrats p...|           bluzz.org|    t5_2qh13| 1611977317|           0|[house, democrats...|
|l7y8yl|Congos prime mini...|      abcnews.go.com|    t5_2qh13| 1611940865|           1|[congos, prime, m...|
|l7y6fc|  

In [12]:
stop_words =StopWordsRemover.loadDefaultStopWords("english")
stop_words = stop_words + ['a','i']

In [13]:
remover = StopWordsRemover(inputCol="words", outputCol="filtered", stopWords = stop_words)
wordsDataFrame = remover.transform(wordsDataFrame)

In [17]:
wordsDataFrame.show(truncate = 50)

+------+--------------------------------------------------+-------------------------------+------------+-----------+------------+--------------------------------------------------+--------------------------------------------------+
|    id|                                             title|                         domain|subreddit_id|created_utc|num_comments|                                             words|                                          filtered|
+------+--------------------------------------------------+-------------------------------+------------+-----------+------------+--------------------------------------------------+--------------------------------------------------+
|l8w5qa|Jessica Simpson Twins With Daughter Birdie in A...|                      bluzz.org|    t5_2qh13| 1612042018|           0|[jessica, simpson, twins, with, daughter, birdi...|[jessica, simpson, twins, daughter, birdie, ado...|
|l8ni30|             Pornhub Now Accepts Dogecoin Payments|             

In [14]:
cv = CountVectorizer(inputCol="filtered", outputCol="vectors")
cvmodel = cv.fit(wordsDataFrame)

#save th countvectorizer to apply to new texts
cvmodel.write().overwrite().save('count_vectorizer_model')
# When need to relaod this model elsewhere use : loadedModel = CountVectorizerModel.load('count_vectorizer_model')

df_vect = cvmodel.transform(wordsDataFrame)
basics = df_vect.select('vectors', 'id')
basics.show()

+--------------------+------+
|             vectors|    id|
+--------------------+------+
|(68283,[464,1645,...|l8w5qa|
|(68283,[2220,2382...|l8ni30|
|(68283,[8,129,224...|l8n6ta|
|(68283,[0,47,65,9...|l8j19t|
|(68283,[88,712,84...|l8co25|
|(68283,[99,477,51...|l7y8yl|
|(68283,[1498,9570...|l7y6fc|
|(68283,[6,225,378...|l7qrh2|
|(68283,[768,895,1...|l7g5r4|
|(68283,[4,21,33,1...|l77aqb|
|(68283,[6,97,157,...|l6wpsu|
|(68283,[863,1397,...|l6vfrx|
|(68283,[7,27,247,...|l686x9|
|(68283,[1,191,456...|l5g6ya|
|(68283,[240,323,3...|l5cr2f|
|(68283,[2,5,183,2...|l51rbl|
|(68283,[117,208,2...|l51nju|
|(68283,[18,26,35,...|l4znqi|
|(68283,[38,43,190...|l4vc0j|
|(68283,[6255],[1.0])|l4sjox|
+--------------------+------+
only showing top 20 rows



In [15]:
basics.write.mode("overwrite").parquet('lda_basics')

In [16]:
basics= spark.read.parquet('lda_basics')

In [17]:
# Create the LDA model and fit it
num_topics = 20
lda = LDA(featuresCol='vectors',k=num_topics, seed=42)
#Train the LDA model
lda_model = lda.fit(basics)

# Save teh model to be applied later in stream pipeline
lda_model.write().overwrite().save('lda_distributed_model')
# When need to relaod this model elsewhere use : sameModel = DistributedLDAModel.load('lda_distributed_model')

In [18]:
# See the results of modelling
ldatopics = lda_model.describeTopics(10)
ldatopics.show(truncate = 50)

+-----+--------------------------------------------------+--------------------------------------------------+
|topic|                                       termIndices|                                       termWeights|
+-----+--------------------------------------------------+--------------------------------------------------+
|    0|[0, 273, 7552, 873, 79, 25006, 1528, 616, 1, 5557]|[5.538735792542494E-4, 1.990048559790569E-4, 1....|
|    1|[298, 4858, 0, 446, 2595, 8525, 10599, 9988, 74...|[7.485265172648988E-4, 3.348183251550614E-4, 2....|
|    2|        [48, 12, 2, 81, 85, 918, 0, 94, 160, 5048]|[0.012289764480241687, 0.011485272140385177, 0....|
|    3|        [0, 6026, 11, 6, 433, 3, 593, 678, 331, 1]|[3.3400557656743644E-4, 3.170723394473146E-4, 3...|
|    4|   [161, 139, 149, 133, 120, 19, 0, 231, 236, 233]|[0.006821106519123967, 0.006627767440201427, 0....|
|    5|[1355, 52, 1162, 0, 2799, 7028, 23, 194, 11856,...|[0.0010139612943358192, 9.102326738652258E-4, 8...|
|    6|   

In [27]:
# For mapping words to the model term indices, first collect the words used in the input vectors
vocab = cvmodel.vocabulary
vocab_broadcast = sc.broadcast(vocab) # saves the words to a broadcast variable for use in mapping
print("Number of words in word vectors: " + str(len(vocab)))
print("Example last words: ")
vocab[-6:-1]

Number of words in word vectors: 68283
Example last words: 


['rambut', 'flav', 'pil', 'alirezas', 'lifelessonswecanlearnfrom']

In [28]:
# Now match input words to terms in model output topics
def map_termID_to_Word(termIndices):
    words = []
    for termID in termIndices:
        words.append(vocab_broadcast.value[termID])

    return words

udf_map_termID_to_Word = udf(map_termID_to_Word , ArrayType(StringType()))

ldatopics_mapped = ldatopics.withColumn("topic_desc", udf_map_termID_to_Word(ldatopics.termIndices))

In [29]:
ldatopics_mapped.show(truncate=False)

+-----+----------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------+
|topic|termIndices                                                     |termWeights                                                                                                                                                                                                                          |topic_desc                                                                                                               |
+-----+----------------------------------------------------------------+------------------------------------------------------------------------------

In [19]:
# Use the model to allocate topic weightings for each title
indiv = lda_model.transform(basics)

In [21]:
indiv.show()

+--------------------+------+--------------------+
|             vectors|    id|   topicDistribution|
+--------------------+------+--------------------+
|(68283,[18,26,83,...|l8xslp|[0.00580986353959...|
|(68283,[214,267,2...|l8vqvn|[0.00421993242201...|
|       (68283,[],[])|l8usg7|[0.0,0.0,0.0,0.0,...|
|(68283,[0,66,121,...|l8t0zg|[0.00386716845783...|
|(68283,[0,10,20,2...|l8o2wz|[0.00464351487824...|
|(68283,[12,71,372...|l8k2ic|[0.00580986354293...|
|(68283,[1,815,237...|l8ip36|[0.00464351488199...|
|(68283,[5,11,479,...|l8bylm|[0.00464351491842...|
|(68283,[1,4,43,18...|l89f20|[0.00464351487738...|
|(68283,[1,5,63,21...|l88ld9|[0.00580986353995...|
|(68283,[3,4,126,1...|l7sxn2|[0.00516162081526...|
|(68283,[145,418,4...|l7si03|[0.00664431653027...|
|(68283,[11,218,34...|l7rn9f|[0.00421993242862...|
|(68283,[0,2,45,48...|l7pbej|[0.00464351487861...|
|   (68283,[0],[1.0])|l7ddrt|[0.02357267036537...|
|(68283,[678,895,1...|l763bt|[0.00464351487832...|
|(68283,[1126,2624...|l6vj09|[0

## Build features and target dataset

In [None]:
# convert epoch time into datetime, but keep epoch time as id

In [22]:
df= df.withColumn('timestamp', from_unixtime(df.created_utc,'yyyy-MM-dd HH:mm:ss')) 

In [23]:
df = df.withColumn("hour", hour(col("timestamp"))).withColumn("day", dayofweek(col("timestamp")))
df = df.withColumn("hour", F.col("hour").astype(StringType())).withColumn("day", F.col("day").astype(StringType()))

In [24]:
df.show(truncate =60)

+------+------------------------------------------------------------+-------------------------------+------------+-----------+------------+-------------------+----+---+
|    id|                                                       title|                         domain|subreddit_id|created_utc|num_comments|          timestamp|hour|day|
+------+------------------------------------------------------------+-------------------------------+------------+-----------+------------+-------------------+----+---+
|l8w5qa|  Jessica Simpson Twins With Daughter Birdie in Adorable Pic|                      bluzz.org|    t5_2qh13| 1612042018|           0|2021-01-30 21:26:58|  21|  7|
|l8ni30|                       Pornhub Now Accepts Dogecoin Payments|             cryptobriefing.com|    t5_2qh13| 1612019085|           0|2021-01-30 15:04:45|  15|  7|
|l8n6ta|India Mulls Law To Ban Cryptocurrencies Create Official D...|                       ndtv.com|    t5_2qh13| 1612018166|           2|2021-01-30 14:49

In [25]:
df.printSchema()

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- num_comments: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- day: string (nullable = true)



In [26]:
# Create features df by combining topic distributions for each title, with the other features of domain, hour and day
temp = indiv.select('topicDistribution')
# And a list of sequential column titles
X_titles = [f'T_{i}' for i in range(1, num_topics +1)]
X_titles = ['id'] + X_titles # adding the id at start

In [27]:
temp = indiv.withColumn("T_", vector_to_array("topicDistribution")).select(["id"] + [col("T_")[i] for i in range(num_topics)]).drop('vectors', 'topicDistribution')
temp = temp.toDF(*X_titles)
temp.show()

+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|    id|                 T_1|                 T_2|                 T_3|                 T_4|                 T_5|                 T_6|                 T_7|                 T_8|                 T_9|                T_10|                T_11|                T_12|                T_13|                T_14|                T_15|                T_16|                T_17|                T_18|                T_19|                T_20|
+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------

In [46]:
rf_full = temp.join(df, on = 'id', how= 'left').drop('id','subreddit_id','title','timestamp','created_utc')

In [29]:
rf_full.printSchema()

root
 |-- T_1: double (nullable = true)
 |-- T_2: double (nullable = true)
 |-- T_3: double (nullable = true)
 |-- T_4: double (nullable = true)
 |-- T_5: double (nullable = true)
 |-- T_6: double (nullable = true)
 |-- T_7: double (nullable = true)
 |-- T_8: double (nullable = true)
 |-- T_9: double (nullable = true)
 |-- T_10: double (nullable = true)
 |-- T_11: double (nullable = true)
 |-- T_12: double (nullable = true)
 |-- T_13: double (nullable = true)
 |-- T_14: double (nullable = true)
 |-- T_15: double (nullable = true)
 |-- T_16: double (nullable = true)
 |-- T_17: double (nullable = true)
 |-- T_18: double (nullable = true)
 |-- T_19: double (nullable = true)
 |-- T_20: double (nullable = true)
 |-- domain: string (nullable = true)
 |-- num_comments: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- day: string (nullable = true)



In [30]:
rf_full.count()

127600

In [47]:
# Convert the number of comments into numerical and create a grouped categorical
rf_full = rf_full.withColumn("num_comments", F.col("num_comments").astype(IntegerType()))
bucketizer = Bucketizer(splitsArray=[[0,5,25,float("inf")]],inputCols=["num_comments"], outputCols=["group_comments"])
rf_full = bucketizer.setHandleInvalid("keep").transform(rf_full)
rf_full = rf_full.drop('num_comments')
rf_full.show(20)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----+---+--------------+
|                 T_1|                 T_2|                 T_3|                 T_4|                 T_5|                 T_6|                 T_7|                 T_8|                 T_9|                T_10|                T_11|                T_12|                T_13|                T_14|                T_15|                T_16|                T_17|                T_18|                T_19|                T_20|              domain|hour|day|group_comments|
+--------------------+--------------------+--------------------+--

In [32]:
rf_full.printSchema()

root
 |-- T_1: double (nullable = true)
 |-- T_2: double (nullable = true)
 |-- T_3: double (nullable = true)
 |-- T_4: double (nullable = true)
 |-- T_5: double (nullable = true)
 |-- T_6: double (nullable = true)
 |-- T_7: double (nullable = true)
 |-- T_8: double (nullable = true)
 |-- T_9: double (nullable = true)
 |-- T_10: double (nullable = true)
 |-- T_11: double (nullable = true)
 |-- T_12: double (nullable = true)
 |-- T_13: double (nullable = true)
 |-- T_14: double (nullable = true)
 |-- T_15: double (nullable = true)
 |-- T_16: double (nullable = true)
 |-- T_17: double (nullable = true)
 |-- T_18: double (nullable = true)
 |-- T_19: double (nullable = true)
 |-- T_20: double (nullable = true)
 |-- domain: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- day: string (nullable = true)
 |-- group_comments: double (nullable = true)



In [48]:
# Drop rows where there are null values, to avaoid probelms with modelling
rf_full = rf_full.na.drop("any")

In [49]:
# Check all gone
print(rf_full.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in rf_full.columns]).show()) 

+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+------+----+---+--------------+
|T_1|T_2|T_3|T_4|T_5|T_6|T_7|T_8|T_9|T_10|T_11|T_12|T_13|T_14|T_15|T_16|T_17|T_18|T_19|T_20|domain|hour|day|group_comments|
+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+------+----+---+--------------+
|  0|  0|  0|  0|  0|  0|  0|  0|  0|   0|   0|   0|   0|   0|   0|   0|   0|   0|   0|   0|     0|   0|  0|             0|
+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+------+----+---+--------------+

None


In [50]:
rf_full.write.mode("overwrite").parquet('rf_full')

In [51]:
rf_full= spark.read.parquet('rf_full')

## Now build the ML model

In [52]:
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import GBTClassifier, RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics

In [53]:
# Need to one-hot-encode the domain, hour and day columns as they are categorical
#Create variable for all categorical columns
cat_cols = ['domain','hour','day']
#Create and empty list and put the string indexed and one-hot-encoded transformations into it
stages = []
for cat_col in cat_cols:
    col_indexer = StringIndexer(inputCol=cat_col, outputCol=f"{cat_col}_ind",handleInvalid='skip')
    col_encoder = OneHotEncoder(inputCols=[f"{cat_col}_ind"], outputCols=[f"{cat_col}_ohe"])
    stages += [col_indexer, col_encoder]

In [54]:
# Split data into train and test sets
train, test = rf_full.randomSplit([0.8, 0.2], seed=42)
# Seperate out the features from the data
X = rf_full.drop('group_comments')
y= rf_full.select('group_comments')
# make list of numeric columns (here that is all columns in features not categorical)
features_cols = X.columns
num_cols = [x for x in features_cols if x not in cat_cols]

In [55]:
# Change title of one not encoded categoricals
cat_cols_ohe = [f"{cat_col}_ohe" for cat_col in cat_cols]

In [56]:
# Build all features into one vetor
assembler = VectorAssembler(inputCols=cat_cols_ohe + num_cols, outputCol="features")
stages += [assembler]

In [57]:
#specify the model core and target variable
rf = RandomForestClassifier(labelCol="group_comments",featuresCol="features")
stages += [rf]

In [58]:
# Populate the pipeline
pipeline = Pipeline(stages=stages)

In [59]:
# FIT MODEL TO TRAINING SET VIA PIPELINE
from datetime import datetime 
start_time = datetime.now() 

pipeline_model = pipeline.fit(train)

print('Time elapsed (hh:mm:ss.ms) {}'.format(datetime.now() - start_time))

Time elapsed (hh:mm:ss.ms) 0:01:15.580071


In [60]:
#Save the model for later use
pipeline_model.write().overwrite().save('pipeline_model')

In [61]:
test_preds = pipeline_model.transform(test)

In [62]:
evaluator = MulticlassClassificationEvaluator(labelCol=rf.getLabelCol(), predictionCol=rf.getPredictionCol(), metricName="accuracy")
accuracy = evaluator.evaluate(test_preds)
print("Test Error = %g" % (1.0 - accuracy))

Test Error = 0.172747


In [64]:
test_preds.select('prediction').show(5)

+----------+
|prediction|
+----------+
|       0.0|
|       0.0|
|       0.0|
|       0.0|
|       0.0|
+----------+
only showing top 5 rows



In [65]:
preds_and_labels = test_preds.select(['prediction','group_comments']).withColumn('label', F.col('group_comments').cast(FloatType())).orderBy('prediction')

In [66]:
#select only prediction and label columns
preds_and_labels = preds_and_labels.select(['prediction','label'])

metrics = MulticlassMetrics(preds_and_labels.rdd.map(tuple))

print(metrics.confusionMatrix().toArray())

[[20113.     0.     0.]
 [ 2556.     0.     0.]
 [ 1644.     0.     0.]]


In [67]:
spark.stop()