In [2]:
! pip install gensim
! pip install nltk


Collecting gensim
  Downloading gensim-4.0.1-cp38-cp38-manylinux1_x86_64.whl (23.9 MB)
[K     |████████████████████████████████| 23.9 MB 11.1 MB/s eta 0:00:01     |████████████████████████████▉   | 21.5 MB 11.1 MB/s eta 0:00:01
Collecting smart-open>=1.8.1
  Downloading smart_open-5.1.0-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 4.2 MB/s eta 0:00:01
Installing collected packages: smart-open, gensim
Successfully installed gensim-4.0.1 smart-open-5.1.0
Collecting nltk
  Downloading nltk-3.6.2-py3-none-any.whl (1.5 MB)
[K     |████████████████████████████████| 1.5 MB 3.7 MB/s eta 0:00:01
Collecting regex
  Downloading regex-2021.4.4-cp38-cp38-manylinux2014_x86_64.whl (733 kB)
[K     |████████████████████████████████| 733 kB 10.0 MB/s eta 0:00:01
Installing collected packages: regex, nltk
Successfully installed nltk-3.6.2 regex-2021.4.4


In [3]:
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np
import re
import sys
import scipy
import gensim
from gensim.utils import simple_preprocess
import gensim.corpora as corpora




In [4]:
import string
import pyspark
import pyspark.sql.functions as F
from pyspark import broadcast, SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.util import MLUtils
from pyspark.sql.types import *
from pyspark.ml.feature import CountVectorizer, CountVectorizerModel, Tokenizer, RegexTokenizer, StopWordsRemover, OneHotEncoder, StringIndexer, VectorAssembler, VectorIndexer, Bucketizer
from pyspark.ml.linalg import Vectors, SparseVector
from pyspark.ml.clustering import LDA
from pyspark.ml.functions import vector_to_array


In [88]:
from pyspark.sql.functions import month, year, mean, count, dayofweek, hour, col, min, max, avg, sum, when, lit, desc, unix_timestamp, from_unixtime, udf, regexp_replace

In [9]:
# Build the spark session
spark = SparkSession.builder \
        .appName('reddit_app') \
        .getOrCreate()

In [10]:
spark.version

'3.1.1'

In [11]:
sc = spark.sparkContext

## Load data from csv, strip to required size then save to parquet

In [101]:
# Load all submission csv files
data = spark.read.format('csv').options(header ='true').load(f"./submissions/*.csv")

In [102]:
print(data.printSchema())
data.count()

root
 |-- _c0: string (nullable = true)
 |-- id: string (nullable = true)
 |-- author_fullname: string (nullable = true)
 |-- title: string (nullable = true)
 |-- score: string (nullable = true)
 |-- author_premium: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- over_18: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- permalink: string (nullable = true)
 |-- parent_whitelist_status: string (nullable = true)
 |-- url: string (nullable = true)
 |-- created_utc: string (nullable = true)

None


80648

In [76]:
data.head(5)

[Row(_c0='0', id='mhj2t3', author_fullname='t2_ape487w4', title='British Government’s ‘Gaslighting’ Report on Racism Says Slavery Had Some Upsides', score='1', author_premium='False', domain='thedailybeast.com', over_18='False', subreddit_id='t5_2qh13', permalink='/r/worldnews/comments/mhj2t3/british_governments_gaslighting_report_on_racism/', parent_whitelist_status='all_ads', url='https://www.thedailybeast.com/british-governments-gaslighting-report-on-racism-says-slavery-had-some-upsides', created_utc='1617235221'),
 Row(_c0='1', id='mhj8o3', author_fullname='t2_72m2t7ri', title='Blessed April Greetings from SAM Climate Smart Trust', score='1', author_premium='False', domain='youtube.com', over_18='False', subreddit_id='t5_2qh13', permalink='/r/worldnews/comments/mhj8o3/blessed_april_greetings_from_sam_climate_smart/', parent_whitelist_status='all_ads', url='https://youtube.com/watch?v=JO74vX9teF4&amp;feature=share', created_utc='1617235743'),
 Row(_c0='2', id='mhj983', author_fullna

In [103]:
# Only keep the needed columns and remove duplicate rows
data = data.select('id','title','domain','subreddit_id','created_utc')
data = data.distinct()
data.count()

1700

In [106]:
# Also load all comment files, so can calcluate numebr of comments for each submission
comms =spark.read.format('csv').options(header ='true').load(f"./comments/*.csv")

In [107]:
print(comms.printSchema())
comms.count()

root
 |-- _c0: string (nullable = true)
 |-- id: string (nullable = true)
 |-- link_id: string (nullable = true)

None


71524

In [108]:
# Only keep the needed columns and remove duplicate rows
comms = comms.drop('_c0')
comms = comms.distinct()
comms.count()

1700

In [190]:
# Group comments by link_id and count, then join to submissions data
total_comms = comms.groupBy('link_id').count()
total_comms = total_comms.withColumnRenamed('count','num_comments')
# Remove the prefix in link_id so it will match to submissions format i.e. after the underscore
total_comms = total_comms.withColumn('link_id', regexp_replace(col('link_id'), 't3_', ''))
print(total_comms.count())
total_comms.show(5)

547
+-------+------------+
|link_id|num_comments|
+-------+------------+
| fbc8v0|           1|
| l92q0n|           1|
| hi55dq|           1|
| l9pizh|           2|
| hia8fx|           1|
+-------+------------+
only showing top 5 rows



In [191]:
df = data.join(total_comms, data.id == total_comms.link_id, how= 'left')
print(df.printSchema())
df.count()
df.show(5)

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- link_id: string (nullable = true)
 |-- num_comments: long (nullable = true)

None
+------+--------------------+---------------+------------+-----------+-------+------------+
|    id|               title|         domain|subreddit_id|created_utc|link_id|num_comments|
+------+--------------------+---------------+------------+-----------+-------+------------+
|fbmdfg|"Greta Thunberg B...|              1|     bbc.com|    all_ads|   null|        null|
|k4axmk|Mark Angel is Nig...|   ghlatest.net|    t5_2qh13| 1606787707|   null|        null|
|k4b1z0|Australia demands...|edition.cnn.com|    t5_2qh13| 1606788095|   null|        null|
|i1jwlh|googul TEA provid...|   foodtour.xyz|    t5_2qh13| 1596245778|   null|        null|
|ikahmv|Hundreds protest ...|  worldnewj.com|    t5_2qh13| 159892147

In [192]:
df.filter(df.link_id.isNotNull()).show(truncate=30)

+------+------------------------------+-----------------+------------+-----------+-------+------------+
|    id|                         title|           domain|subreddit_id|created_utc|link_id|num_comments|
+------+------------------------------+-----------------+------------+-----------+-------+------------+
|l9ptoo|Potential military coup und...|        bbc.co.uk|    t5_2qh13| 1612138100| l9ptoo|           1|
|mhj2t3|British Government’s ‘Gasli...|thedailybeast.com|    t5_2qh13| 1617235221| mhj2t3|           1|
|gb7i55|Michigan launches GI Bill f...|       nypost.com|    t5_2qh13| 1588291304| gb7i55|           1|
|ewxhk0|Brexit is finally official ...|24hrnewsworld.com|    t5_2qh13| 1580515275| ewxhk0|           1|
|k48use|Tony Evers on Monday formal...|  edition.cnn.com|    t5_2qh13| 1606780985| k48use|           1|
|fsohtz|28 Texas spring-breakers te...|      reuters.com|    t5_2qh13| 1585699310| fsohtz|           1|
|gb7htr|Logging returns to NSW nati...|       smh.com.au|    t5_

In [193]:
df = df.na.fill(value=0,subset=["num_comments"])

In [194]:
df.show(5)

+------+--------------------+---------------+------------+-----------+-------+------------+
|    id|               title|         domain|subreddit_id|created_utc|link_id|num_comments|
+------+--------------------+---------------+------------+-----------+-------+------------+
|fbmdfg|"Greta Thunberg B...|              1|     bbc.com|    all_ads|   null|           0|
|k4axmk|Mark Angel is Nig...|   ghlatest.net|    t5_2qh13| 1606787707|   null|           0|
|k4b1z0|Australia demands...|edition.cnn.com|    t5_2qh13| 1606788095|   null|           0|
|i1jwlh|googul TEA provid...|   foodtour.xyz|    t5_2qh13| 1596245778|   null|           0|
|ikahmv|Hundreds protest ...|  worldnewj.com|    t5_2qh13| 1598921474|   null|           0|
+------+--------------------+---------------+------------+-----------+-------+------------+
only showing top 5 rows



In [195]:
#Remove punctuation and numbers from titles
df = df.withColumn("title", regexp_replace(col("title"), '[^\sa-zA-Z]', ''))
df.show(5)

+------+--------------------+---------------+------------+-----------+-------+------------+
|    id|               title|         domain|subreddit_id|created_utc|link_id|num_comments|
+------+--------------------+---------------+------------+-----------+-------+------------+
|fbmdfg|Greta Thunberg Br...|              1|     bbc.com|    all_ads|   null|           0|
|k4axmk|Mark Angel is Nig...|   ghlatest.net|    t5_2qh13| 1606787707|   null|           0|
|k4b1z0|Australia demands...|edition.cnn.com|    t5_2qh13| 1606788095|   null|           0|
|i1jwlh|googul TEA provid...|   foodtour.xyz|    t5_2qh13| 1596245778|   null|           0|
|ikahmv|Hundreds protest ...|  worldnewj.com|    t5_2qh13| 1598921474|   null|           0|
+------+--------------------+---------------+------------+-----------+-------+------------+
only showing top 5 rows



In [196]:
df.write.parquet('reddit_data', mode='overwrite')

In [197]:
df= spark.read.parquet('reddit_data')

In [198]:
df.printSchema()

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- link_id: string (nullable = true)
 |-- num_comments: long (nullable = true)



## Prepare titles and do topics modelling

In [199]:
tokenizer = Tokenizer(inputCol="title", outputCol="words")
wordsDataFrame = tokenizer.transform(df)

In [200]:
wordsDataFrame.show()

+------+--------------------+--------------------+------------+-----------+-------+------------+--------------------+
|    id|               title|              domain|subreddit_id|created_utc|link_id|num_comments|               words|
+------+--------------------+--------------------+------------+-----------+-------+------------+--------------------+
|mhlgtj|Cn h bin s hu lu ...|      solifeland.com|    t5_2qh13| 1617243528|   null|           0|[cn, h, bin, s, h...|
|mhlu2w|Reports that mult...|         twitter.com|    t5_2qh13| 1617244881|   null|           0|[reports, that, m...|
|fsq0qp|Russia sent Italy...| businessinsider.com|    t5_2qh13| 1585704991|   null|           0|[russia, sent, it...|
|fblzl5|Ganbare Super Str...|   24hrnewsworld.com|    t5_2qh13| 1583022931|   null|           0|[ganbare, super, ...|
|lw5pgd|We attack Indones...|   news.mongabay.com|    t5_2qh13| 1614701722|   null|           0|[we, attack, indo...|
|i1irvy|Wigan residents s...|    thecanadian.news|    t5

In [201]:
stop_words =StopWordsRemover.loadDefaultStopWords("english")
stop_words = stop_words + ['a','i']

In [202]:
remover = StopWordsRemover(inputCol="words", outputCol="filtered", stopWords = stop_words)
wordsDataFrame = remover.transform(wordsDataFrame)

In [203]:
wordsDataFrame.show()

+------+--------------------+--------------------+------------+-----------+-------+------------+--------------------+--------------------+
|    id|               title|              domain|subreddit_id|created_utc|link_id|num_comments|               words|            filtered|
+------+--------------------+--------------------+------------+-----------+-------+------------+--------------------+--------------------+
|mhlgtj|Cn h bin s hu lu ...|      solifeland.com|    t5_2qh13| 1617243528|   null|           0|[cn, h, bin, s, h...|[cn, h, bin, hu, ...|
|mhlu2w|Reports that mult...|         twitter.com|    t5_2qh13| 1617244881|   null|           0|[reports, that, m...|[reports, multipl...|
|fsq0qp|Russia sent Italy...| businessinsider.com|    t5_2qh13| 1585704991|   null|           0|[russia, sent, it...|[russia, sent, it...|
|fblzl5|Ganbare Super Str...|   24hrnewsworld.com|    t5_2qh13| 1583022931|   null|           0|[ganbare, super, ...|[ganbare, super, ...|
|lw5pgd|We attack Indones..

In [204]:
cv = CountVectorizer(inputCol="filtered", outputCol="vectors")
cvmodel = cv.fit(wordsDataFrame)
df_vect = cvmodel.transform(wordsDataFrame)
basics = df_vect.select('vectors', 'id')
basics.show()

+--------------------+------+
|             vectors|    id|
+--------------------+------+
|(6178,[0,29,39,41...|mhlgtj|
|(6178,[28,30,68,1...|mhlu2w|
|(6178,[1,18,65,72...|fsq0qp|
|(6178,[494,575,45...|fblzl5|
|(6178,[76,102,104...|lw5pgd|
|(6178,[77,127,441...|i1irvy|
|(6178,[1,32,68,12...|i1j4bz|
|(6178,[2027,6137]...|j2zl8e|
|(6178,[0,102,139,...|ika9mn|
|(6178,[0,524,763,...|ikahqt|
|(6178,[0,37,59,13...|guairn|
|(6178,[10,967,105...|gub860|
|(6178,[2,5,26,160...|ko1gqr|
|(6178,[0,249,727,...|ko2er6|
|(6178,[207,1306,1...|hj0m6r|
|(6178,[19,94,294,...|hj0ust|
| (6178,[3297],[1.0])|gb836t|
|(6178,[7,30,65,75...|gb99ca|
|(6178,[3,4,7,90,1...|fspfao|
|(6178,[0,16,31,89...|fbmary|
+--------------------+------+
only showing top 20 rows



In [218]:
# Create the LDA model and fit it
num_topics = 20
lda = LDA(featuresCol='vectors',k=num_topics, seed=42)
#Train the LDA model
model = lda.fit(basics)

In [206]:
# See the results of modelling
ldatopics = model.describeTopics(10)
ldatopics.show(truncate = 100)

+-----+---------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|topic|                                              termIndices|                                                                                         termWeights|
+-----+---------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|    0|    [204, 89, 122, 762, 838, 1425, 1848, 322, 1189, 2235]|[0.0012678998507832564, 0.0011068580560100229, 9.262934147928398E-4, 8.117138231307551E-4, 8.0389...|
|    1|                        [0, 1, 5, 10, 4, 2, 3, 25, 16, 8]|[0.0317462620231647, 0.008963583225651387, 0.004341257794022516, 0.004334524169802166, 0.00416171...|
|    2|          [99, 146, 291, 203, 312, 213, 118, 7, 429, 486]|[0.0018872771501108096, 0.001129977119372255, 8.873721093781615E-4, 8.50515073242995E-4, 7.536073...

In [207]:
# For mapping words to the model term indices, first collect the words used in the input vectors
vocab = cvmodel.vocabulary
vocab_broadcast = sc.broadcast(vocab) # saves the words to a broadcast variable for use in mapping
print("Number of words in word vectors: " + str(len(vocab)))
print("Example last words: ")
vocab[-6:-1]

Number of words in word vectors: 6178
Example last words: 


['juan', 'diego', 'khe', 'wounds', 'skinned']

In [208]:
# Now match input words to terms in model output topics
def map_termID_to_Word(termIndices):
    words = []
    for termID in termIndices:
        words.append(vocab_broadcast.value[termID])

    return words

udf_map_termID_to_Word = udf(map_termID_to_Word , ArrayType(StringType()))

ldatopics_mapped = ldatopics.withColumn("topic_desc", udf_map_termID_to_Word(ldatopics.termIndices))

In [209]:
ldatopics_mapped.show(truncate=False)

+-----+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------+
|topic|termIndices                                              |termWeights                                                                                                                                                                                                                      |topic_desc                                                                                  |
+-----+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [210]:
# Use the model to allocate topic weightings for each title
indiv = model.transform(basics)

In [211]:
indiv.show(truncate = 100)

+----------------------------------------------------------------------------------------------------+------+----------------------------------------------------------------------------------------------------+
|                                                                                             vectors|    id|                                                                                   topicDistribution|
+----------------------------------------------------------------------------------------------------+------+----------------------------------------------------------------------------------------------------+
|(6178,[0,29,39,41,43,44,45,46,47,50,51,53,54,293,434,572,655,668,695,1034,1185,1296,1417,1688,171...|mhlgtj|[0.0016162722546917453,0.002009110673516834,0.0016065295844630877,0.0016167244349579635,0.0016738...|
|                    (6178,[28,30,68,103,196,252,270,768,1374],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|mhlu2w|[0.004700873924317509,0.005843432928403815,0.00

## Build features and target dataset

In [None]:
# convert epoch time into datetime, but keep epoch time as id

In [212]:
df= df.withColumn('timestamp', from_unixtime(df.created_utc,'yyyy-MM-dd HH:mm:ss')) 

In [213]:
df = df.withColumn("hour", hour(col("timestamp"))).withColumn("day", dayofweek(col("timestamp")))
df = df.withColumn("hour", F.col("hour").astype(StringType())).withColumn("day", F.col("day").astype(StringType()))

In [214]:
df.show()

+------+--------------------+--------------------+------------+-----------+-------+------------+-------------------+----+---+
|    id|               title|              domain|subreddit_id|created_utc|link_id|num_comments|          timestamp|hour|day|
+------+--------------------+--------------------+------------+-----------+-------+------------+-------------------+----+---+
|mhlgtj|Cn h bin s hu lu ...|      solifeland.com|    t5_2qh13| 1617243528|   null|           0|2021-04-01 02:18:48|   2|  5|
|mhlu2w|Reports that mult...|         twitter.com|    t5_2qh13| 1617244881|   null|           0|2021-04-01 02:41:21|   2|  5|
|fsq0qp|Russia sent Italy...| businessinsider.com|    t5_2qh13| 1585704991|   null|           0|2020-04-01 01:36:31|   1|  4|
|fblzl5|Ganbare Super Str...|   24hrnewsworld.com|    t5_2qh13| 1583022931|   null|           0|2020-03-01 00:35:31|   0|  1|
|lw5pgd|We attack Indones...|   news.mongabay.com|    t5_2qh13| 1614701722|   null|           0|2021-03-02 16:15:22|  

In [215]:
df.printSchema()

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- domain: string (nullable = true)
 |-- subreddit_id: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- link_id: string (nullable = true)
 |-- num_comments: long (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- day: string (nullable = true)



In [221]:
# Create features df by combining topic distributions for each title, with the other features of domain, hour and day
temp = indiv.select('topicDistribution')
# And a list of sequential column titles
X_titles = [f'T_{i}' for i in range(1, num_topics +1)]
X_titles = ['id'] + X_titles # adding the id at start

In [222]:
temp = indiv.withColumn("T_", vector_to_array("topicDistribution")).select(["id"] + [col("T_")[i] for i in range(20)]).drop('vectors', 'topicDistribution')
temp = temp.toDF(*X_titles)
temp.show(truncate =100)

+------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|    id|                  T_1|                  T_2|                  T_3|                  T_4|                  T_5|                  T_6|                  T_7|                  T_8|                  T_9|                 T_10|                 T_11|                 T_12|                 T_13|                 T_14|                 T_15|                 T_16|                 T_17|                 T_18|                 T_19|                 T_20|
+------+---------------------+---------------------+---------------------+---------------------+------

In [223]:
rf_full = temp.join(df, on = 'id', how= 'left').drop('id','subreddit_id','title','timestamp','created_utc','link_id')

In [224]:
rf_full.printSchema()

root
 |-- T_1: double (nullable = true)
 |-- T_2: double (nullable = true)
 |-- T_3: double (nullable = true)
 |-- T_4: double (nullable = true)
 |-- T_5: double (nullable = true)
 |-- T_6: double (nullable = true)
 |-- T_7: double (nullable = true)
 |-- T_8: double (nullable = true)
 |-- T_9: double (nullable = true)
 |-- T_10: double (nullable = true)
 |-- T_11: double (nullable = true)
 |-- T_12: double (nullable = true)
 |-- T_13: double (nullable = true)
 |-- T_14: double (nullable = true)
 |-- T_15: double (nullable = true)
 |-- T_16: double (nullable = true)
 |-- T_17: double (nullable = true)
 |-- T_18: double (nullable = true)
 |-- T_19: double (nullable = true)
 |-- T_20: double (nullable = true)
 |-- domain: string (nullable = true)
 |-- num_comments: long (nullable = true)
 |-- hour: string (nullable = true)
 |-- day: string (nullable = true)



In [225]:
rf_full.count()

1700

In [226]:
# Convert the number of comments into numerical and create a grouped categorical
rf_full = rf_full.withColumn("num_comments", F.col("num_comments").astype(IntegerType()))
bucketizer = Bucketizer(splitsArray=[[0,1,2,float("inf")]],inputCols=["num_comments"], outputCols=["group_comments"])
rf_full = bucketizer.setHandleInvalid("keep").transform(rf_full)
rf_full = rf_full.drop('num_comments')
rf_full.show(8, truncate=100)

+---------------------+--------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+-------------------------+----+---+--------------+
|                  T_1|                 T_2|                  T_3|                  T_4|                  T_5|                  T_6|                  T_7|                  T_8|                  T_9|                 T_10|                 T_11|                 T_12|                 T_13|                 T_14|                 T_15|                 T_16|                 T_17|                 T_18|                 T_19|                 T_20|                   domain|hour|day|group_comments|
+-----------------

## Now build the ML model

In [71]:
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import GBTClassifier, RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics

In [54]:
# Need to one-hot-encode the domain, hour and day columns as they are categorical
#Create variable for all categorical columns
cat_cols = ['domain','hour','day']
# Look at indexing and one-hot-encoding results for categoricals
for cat_col in cat_cols:
    col_indexer = StringIndexer(inputCol=cat_col, outputCol=f"{cat_col}_ind")
    col_encoder = OneHotEncoder(inputCols=[f"{cat_col}_ind"], outputCols=[f"{cat_col}_ohe"])

In [55]:
#Create and empty list and put the string indexed and one-hot-encoded transformations into it
stages = []
for cat_col in cat_cols:
    col_indexer = StringIndexer(inputCol=cat_col, outputCol=f"{cat_col}_ind",handleInvalid='skip')
    col_encoder = OneHotEncoder(inputCols=[f"{cat_col}_ind"], outputCols=[f"{cat_col}_ohe"])
    stages += [col_indexer, col_encoder]

In [56]:
# Split data into train and test sets
train, test = rf_full.randomSplit([0.8, 0.2], seed=42)
# Seperate out the features from the data
X = rf_full.drop('group_comments')
y= rf_full.select('group_comments')
# make list of numeric columns (here that is all columns in features not categorical)
features_cols = X.columns
num_cols = [x for x in features_cols if x not in cat_cols]

In [57]:
# Change title of one not encoded categoricals
cat_cols_ohe = [f"{cat_col}_ohe" for cat_col in cat_cols]

In [58]:
# Build all features into one vetor
assembler = VectorAssembler(inputCols=cat_cols_ohe + num_cols, outputCol="features")
stages += [assembler]

In [59]:
#specify the model core and target variable
rf = RandomForestClassifier(labelCol="group_comments",featuresCol="features")
stages += [rf]

In [60]:
# Populate the pipeline
pipeline = Pipeline(stages=stages)

In [61]:
# FIT MODEL TO TRAINING SET VIA PIPELINE
from datetime import datetime 
start_time = datetime.now() 

pipeline_model = pipeline.fit(train)

print('Time elapsed (hh:mm:ss.ms) {}'.format(datetime.now() - start_time))

Time elapsed (hh:mm:ss.ms) 0:00:08.346667


In [62]:
#Save the model for later use
pipeline_model.save('rfC_model')

In [63]:
test_preds = pipeline_model.transform(test)

In [65]:
evaluator = MulticlassClassificationEvaluator(labelCol=rf.getLabelCol(), predictionCol=rf.getPredictionCol(), metricName="accuracy")
accuracy = evaluator.evaluate(test_preds)
print("Test Error = %g" % (1.0 - accuracy))

Test Error = 0.672913


In [67]:
test_preds.show(5)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+---------+----+---+--------------+----------+-------------+--------+---------------+-------+-------------+--------------------+--------------------+--------------------+----------+
|                 T_1|                 T_2|                 T_3|                 T_4|                 T_5|                 T_6|                 T_7|                 T_8|                 T_9|                T_10|                T_11|                T_12|                T_13|                T_14|                T_15|                T_16|                T_17|                T_18|                T_1

In [69]:
preds_and_labels = test_preds.select(['prediction','group_comments']).withColumn('label', F.col('group_comments').cast(FloatType())).orderBy('prediction')

In [72]:
#select only prediction and label columns
preds_and_labels = preds_and_labels.select(['prediction','label'])

metrics = MulticlassMetrics(preds_and_labels.rdd.map(tuple))

print(metrics.confusionMatrix().toArray())

[[ 30.  39. 112.]
 [ 39.  42. 126.]
 [ 34.  45. 120.]]


In [8]:
spark.stop()