In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# import pyspark
# from collections import Counter

In [None]:
# data = pd.read_csv('the-reddit-climate-change-dataset-comments.csv')
# data = data.dropna() # drop any rows with missing values

In [None]:
# counter = Counter(data['subreddit.nsfw'])
# print(counter)

# print(data['score'])
#print(data['body'].iloc[20340])

In [None]:
# Build (or reuse) the Spark session that every later cell relies on.
# Memory sizes and the shuffle partition count are set explicitly for this dataset.
spark = (
    SparkSession.builder
    .appName("Reddit Climate Change Comments")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "8g")
    .config("spark.sql.shuffle.partitions", "100")
    .getOrCreate()
)

In [None]:
# Load the raw comments CSV into a Spark DataFrame.
# The 'body' column is free text, so a single record can span several physical
# lines and contain double quotes. With Spark's defaults (one record per line,
# backslash as the escape character) such records are split into malformed rows.
# multiLine=True keeps quoted newlines inside the field, and escape='"' treats a
# doubled quote ("") as a literal quote.
# NOTE(review): assumes the file uses standard CSV quoting (embedded quotes are
# doubled) -- confirm against the dataset.
df = spark.read.csv(
    "the-reddit-climate-change-dataset-comments.csv",
    header=True,
    inferSchema=True,
    multiLine=True,
    quote='"',
    escape='"',
)
df = df.repartition(100)  # spread the rows over 100 partitions for the large dataset
df.show(5, truncate=False)

df_original = df  # keep a handle to the unfiltered dataset


In [None]:
# Data exploration: report the size and the column names of the freshly loaded data
row_count1 = df.count()
print(
    f"Initial dataset rows: {row_count1}",
    f"Initial dataset columns: {df.columns}",
    sep="\n",
)

In [None]:
# --- clean dataset ---
# Keep only the rows whose 'type' column is exactly "comment"
is_comment = df["type"] == "comment"
df_filtered = df.where(is_comment)


In [None]:
# Intermediate check: display five untruncated rows to confirm the filter worked
df_filtered.show(n=5, truncate=False)


In [None]:
# Count the rows left after filtering; comparing with the initial row count
# shows whether the filter actually removed anything
comment_row_count = df_filtered.count()
print(f"Number of rows with type 'comment': {comment_row_count}")


In [None]:
# Create the Spark database that holds the comments table.
# IF NOT EXISTS makes this cell safe to run on every execution: it creates the
# database on a fresh catalog and is a no-op when reddit_db is already there,
# so the statement no longer has to be commented in and out by hand.
spark.sql("CREATE DATABASE IF NOT EXISTS reddit_db")

In [None]:
# List the databases in the catalog; reddit_db is expected to be among them
databases = spark.sql("SHOW DATABASES")
databases.show()


In [None]:
# List the tables of the current database; no tables are expected at this point
tables_before = spark.sql("SHOW TABLES")
tables_before.show()

In [None]:
# Remove the comments table when it is already present.
# (Alternative kept for reference: spark.sql("USE reddit_db") followed by
#  df_filtered.write.mode("overwrite").saveAsTable("reddit_db.comments").)
drop_comments_table_sql = "DROP TABLE IF EXISTS reddit_db.comments"
spark.sql(drop_comments_table_sql)


In [None]:
# DDL for the Parquet-backed comments table; every column is declared as STRING
# and the dotted column names are quoted with backticks.
create_comments_table_sql = """
CREATE TABLE IF NOT EXISTS reddit_db.comments (
    `type` STRING,
    `id` STRING,
    `subreddit.id` STRING,
    `subreddit.name` STRING,
    `subreddit.nsfw` STRING,
    `created_utc` STRING,
    `permalink` STRING,
    `body` STRING,
    `sentiment` STRING,
    `score` STRING
)
USING PARQUET
"""
spark.sql(create_comments_table_sql)

In [None]:
# List the tables of reddit_db; 'comments' should now be present.
# A bare SHOW TABLES only lists the current database, and the USE reddit_db
# statement in this notebook is commented out, so the database is named explicitly.
spark.sql("SHOW TABLES IN reddit_db").show()

In [None]:
# Display five untruncated rows of the filtered data before aligning the column names
df_filtered.show(n=5, truncate=False)

In [None]:
#df_filtered.printSchema()


In [None]:
# Tip: spark.sql("DESCRIBE reddit_db.comments").show() prints the table's columns.
# Align the column names: swap the dots in the three subreddit columns for underscores.
# NOTE(review): the reddit_db.comments DDL declares these columns with dotted
# names, so the names produced here do not match the table's -- confirm which
# naming is intended.
column_renames = {
    "subreddit.id": "subreddit_id",
    "subreddit.name": "subreddit_name",
    "subreddit.nsfw": "subreddit_nsfw",
}
df_aligned = df_filtered
for dotted_name, underscored_name in column_renames.items():
    df_aligned = df_aligned.withColumnRenamed(dotted_name, underscored_name)



In [None]:
df_aligned.printSchema() # double check that the three subreddit columns now use underscores


In [None]:
# Confirm the target table exists before inserting into it.
# A bare SHOW TABLES only lists the current database, and the USE reddit_db
# statement in this notebook is commented out, so the database is named explicitly.
spark.sql("SHOW TABLES IN reddit_db").show()

In [None]:
# Append the rows of the aligned DataFrame to the Spark table; overwrite=False
# leaves rows already in the table untouched. insertInto resolves columns by
# position, not by name, so the DataFrame column order must match the table's.
df_aligned.write.insertInto(tableName="reddit_db.comments", overwrite=False)

In [None]:
# Validate the insert by displaying up to five rows read back from the table
sample_rows = spark.sql("SELECT * FROM reddit_db.comments LIMIT 5")
sample_rows.show()