# Reddit Mental Health Data Using Spark ML

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[4]').appName('spark_ml').getOrCreate()

### Load the data

First, we load the data.

In [2]:
reddit = spark.read.options(inferSchema = True).csv('mental_disorders_reddit.csv', header=True)
reddit.printSchema()

root
 |-- title: string (nullable = true)
 |-- selftext: string (nullable = true)
 |-- created_utc: string (nullable = true)
 |-- over_18: string (nullable = true)
 |-- subreddit: string (nullable = true)



In [14]:
reddit = reddit.repartition(30)

In [15]:
import time
start_time = time.time()
print(reddit.count())
print("--- %s seconds ---" % (time.time() - start_time))

2018951
--- 2.069136381149292 seconds ---


### Word Counts of the Free Text Column

In [18]:
import pyspark.sql.functions as fn

start_time = time.time()

#Split the lines into words
words = reddit.select(fn.explode(fn.split(fn.concat_ws(" ", reddit.selftext), ' ')).alias('word'))

#Generate word count
word_counts = words.groupBy('word').count()

word_counts.show(10)

print("--- %s seconds ---" % (time.time() - start_time))

+-----------+------+
|       word| count|
+-----------+------+
|      still| 75215|
|       some|101747|
|   anymore,|  4839|
|     online|  8884|
|interaction|  1518|
|      those| 24251|
|  involving|   475|
|     fight,|   291|
|        few| 55002|
|       hope| 17592|
+-----------+------+
only showing top 10 rows

--- 11.045085668563843 seconds ---
