In [7]:
from pyspark.sql import SparkSession

import sys
import pandas as pd
from functools import reduce
from pyspark.sql.functions import udf, broadcast

spark = SparkSession.builder.appName("QueryExecutor").getOrCreate()
sc = spark.sparkContext

In [2]:
queryPath = "./query.sql"
outputFile = "./result.csv"

In [26]:
df = spark.read.csv("./database.csv", header=True)
columnsRenamed = [ (c, c.replace(" ", "")) for c in df.columns]
df = reduce(lambda df, c: df.withColumnRenamed(c[0], c[1]), columnsRenamed, df)

In [27]:
states = spark.read.csv("./state_codes.csv", header=True)
df = df.join(broadcast(states), df.State == states.State, "left_outer").drop("State").withColumnRenamed("Abbreviation", "State")

In [28]:
df.select("State").dropDuplicates().collect()

[Row(State='AZ'),
 Row(State='SC'),
 Row(State='LA'),
 Row(State='MN'),
 Row(State='NJ'),
 Row(State='DC'),
 Row(State='OR'),
 Row(State=None),
 Row(State='VA'),
 Row(State='KY'),
 Row(State='WY'),
 Row(State='NH'),
 Row(State='MI'),
 Row(State='NV'),
 Row(State='WI'),
 Row(State='ID'),
 Row(State='CA'),
 Row(State='CT'),
 Row(State='NE'),
 Row(State='MT'),
 Row(State='NC'),
 Row(State='VT'),
 Row(State='MD'),
 Row(State='DE'),
 Row(State='MO'),
 Row(State='IL'),
 Row(State='ME'),
 Row(State='ND'),
 Row(State='WA'),
 Row(State='MS'),
 Row(State='AL'),
 Row(State='IN'),
 Row(State='OH'),
 Row(State='TN'),
 Row(State='IA'),
 Row(State='NM'),
 Row(State='PA'),
 Row(State='SD'),
 Row(State='NY'),
 Row(State='TX'),
 Row(State='WV'),
 Row(State='GA'),
 Row(State='MA'),
 Row(State='KS'),
 Row(State='CO'),
 Row(State='FL'),
 Row(State='AK'),
 Row(State='AR'),
 Row(State='OK'),
 Row(State='UT'),
 Row(State='HI')]

In [29]:
df.createOrReplaceTempView("homicides")

In [30]:
file = open(queryPath, "r")
query = file.read()
file.close()
query = query.replace("\n", " ")
query

'select state, count(1) as count from homicides group by state'

In [31]:
spark.sql(query).toPandas().to_csv(path_or_buf=outputFile)

[Row(state='AZ', count=12871),
 Row(state='SC', count=11698),
 Row(state='LA', count=19629),
 Row(state='MN', count=3975),
 Row(state='NJ', count=14132),
 Row(state='DC', count=7115),
 Row(state='OR', count=4217),
 Row(state='VA', count=15520),
 Row(state=None, count=1211),
 Row(state='KY', count=6554),
 Row(state='WY', count=630),
 Row(state='NH', count=655),
 Row(state='MI', count=28448),
 Row(state='NV', count=5553),
 Row(state='WI', count=6191),
 Row(state='ID', count=1150),
 Row(state='CA', count=99783),
 Row(state='CT', count=4896),
 Row(state='NE', count=1331),
 Row(state='MT', count=601),
 Row(state='NC', count=20390),
 Row(state='VT', count=412),
 Row(state='MD', count=17312),
 Row(state='DE', count=1179),
 Row(state='MO', count=14832),
 Row(state='IL', count=25871),
 Row(state='ME', count=869),
 Row(state='ND', count=308),
 Row(state='WA', count=7815),
 Row(state='MS', count=6546),
 Row(state='AL', count=11376),
 Row(state='IN', count=11463),
 Row(state='OH', count=19158),
 R

In [41]:
spark.stop()