In [44]:
import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, lit, col
from pyspark.sql.types import IntegerType, FloatType
import pycountry
import datetime

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [45]:
# Date of the daily report to process, kept as zero-padded strings because
# they are spliced directly into the input file name and the partition labels.
year, month, day = "2020", "05", "29"

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [46]:
# Define input and output parameters
# Replace the placeholder below with the name of the S3 bucket holding the ingested data.
my_bucket = "<your-S3-bucket>"
# Daily report files are named MM-DD-YYYY.csv, hence the month, day, year order.
input_path = "s3://{}/covid-ingest-data/csse_covid_19_data/csse_covid_19_daily_reports/{}-{}-{}.csv"\
.format(my_bucket, month, day, year)
# Separate S3 prefixes for the row-level (processed) and the summed (aggregated) outputs.
processed_output_path = "s3://{}/notebook/etl1_processed/".format(my_bucket)
aggregated_output_path = "s3://{}/notebook/etl1_aggregated/".format(my_bucket)
# Database and fully-qualified table names used with saveAsTable().
glue_db = "covid_project"
glue_processed_table = "{}.etl1_ntbk_processed".format(glue_db)
glue_aggregated_table = "{}.etl1_ntbk_aggregated".format(glue_db)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [47]:
# To handle mapping of country name to country code.
# input_countries: name -> ISO 3166-1 alpha-2 code. It is seeded with manual overrides
# for names that fuzzy search failed to resolve in previous runs (previous
# unknown_countries). to_country_code() also caches successful fuzzy matches here.
# Every value must be a 2-letter code so all rows of one country land in the same
# country partition.
# unknown_countries: names that could not be resolved; these pass through unchanged.
unknown_countries = []
input_countries = {
    'Mainland China': 'CN',
    'Macau': 'MO',
    'South Korea': 'KR',
    'Ivory Coast': 'CI',
    'North Ireland': 'GB',
    'Republic of Ireland': 'IE',
    'St. Martin': 'MF',
    'Iran (Islamic Republic of)': 'IR',
    'Taipei and environs': 'TW',
    'occupied Palestinian territory': 'PS',
    # Was 'GBR' (alpha-3), which split UK rows into a separate partition from 'GB'.
    'Channel Islands': 'GB',
    'Korea, South': 'KR',
    'Cruise Ship': 'XZ',  # international waters
    'Taiwan*': 'TW',
    'Congo (Kinshasa)': 'CD',
    'Congo (Brazzaville)': 'CG',
    'Gambia, The': 'GM',
    'Bahamas, The': 'BS',
    'Cape Verde': 'CV',
    'East Timor': 'TL',
    'Laos': 'LA',
    'Diamond Princess': 'XZ',  # Cruise ship
    'West Bank and Gaza': 'PS',
    'Burma': 'MM',
    'MS Zaandam': 'XZ',  # Cruise ship
    'Others': 'XZ',
}

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [48]:
# Column layouts of the daily report CSVs seen so far, oldest first; the example
# file name marks the first report that used each layout. A layout's position in
# this list is the schema id used by the conditional renaming step.
known_schema = [
    # 01-22-2020.csv
    ['Province/State', 'Country/Region', 'Last Update', 'Confirmed', 'Deaths', 'Recovered'],
    # 03-01-2020.csv: adds coordinates
    ['Province/State', 'Country/Region', 'Last Update', 'Confirmed', 'Deaths', 'Recovered',
     'Latitude', 'Longitude'],
    # 03-22-2020.csv: underscore naming, county-level fields, active cases
    ['FIPS', 'Admin2', 'Province_State', 'Country_Region', 'Last_Update', 'Lat', 'Long_',
     'Confirmed', 'Deaths', 'Recovered', 'Active', 'Combined_Key'],
    # 05-29-2020.csv: adds rate columns
    ['FIPS', 'Admin2', 'Province_State', 'Country_Region', 'Last_Update', 'Lat', 'Long_',
     'Confirmed', 'Deaths', 'Recovered', 'Active', 'Combined_Key',
     'Incidence_Rate', 'Case-Fatality_Ratio'],
]


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [49]:
# Function to get Country Code from name
def to_country_code(country):
    """Map a country/region name to an ISO 3166-1 alpha-2 code.

    Lookup order:
      1. the ``input_countries`` table (manual overrides plus cached fuzzy matches);
      2. ``unknown_countries``, the names that already failed to resolve;
      3. ``pycountry.countries.search_fuzzy``, taking its first result.

    Side effects: a successful fuzzy match is stored in ``input_countries``; a
    failed one is appended to ``unknown_countries``.
    NOTE: when run as a Spark UDF, these caches are copies inside each Python
    worker, so updates are not visible back on the driver.

    Args:
        country: name as read from the report; may be None for a null cell.

    Returns:
        The alpha-2 code, or ``country`` unchanged when no code can be found
        (including None and blank names).
    """
    # None / blank names give the fuzzy search nothing meaningful to match on.
    # Pass them through instead of searching or recording them as unknown.
    if country is None or not country.strip():
        return country
    existing = input_countries.get(country)
    if existing:
        return existing
    if country in unknown_countries:
        return country
    try:
        match = pycountry.countries.search_fuzzy(country)[0]
    except Exception:
        # No match (search_fuzzy signals this with LookupError) or any other
        # failure: remember the name so it is not searched again, and keep it.
        unknown_countries.append(country)
        return country
    input_countries[country] = match.alpha_2
    return match.alpha_2

# Register custom function as UDF (default return type: string)
myudf = udf(to_country_code)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [50]:
# Create the Spark session, or reuse the existing one via getOrCreate().
# Hive support is enabled because the results are written with saveAsTable()
# into the Hive metastore.
spark = SparkSession \
  .builder \
  .appName("ETL1-Aggregate-stats-by-state") \
  .enableHiveSupport() \
  .getOrCreate()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [51]:
# Load the day's report using its header row for column names. No schema
# inference is requested, so every column arrives as a string.
raw = spark.read.option("header", True).csv(input_path)
raw.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- FIPS: string (nullable = true)
 |-- Admin2: string (nullable = true)
 |-- Province_State: string (nullable = true)
 |-- Country_Region: string (nullable = true)
 |-- Last_Update: string (nullable = true)
 |-- Lat: string (nullable = true)
 |-- Long_: string (nullable = true)
 |-- Confirmed: string (nullable = true)
 |-- Deaths: string (nullable = true)
 |-- Recovered: string (nullable = true)
 |-- Active: string (nullable = true)
 |-- Combined_Key: string (nullable = true)
 |-- Incidence_Rate: string (nullable = true)
 |-- Case-Fatality_Ratio: string (nullable = true)

In [52]:
# Conditional processing of input file based on match from known schema.
# Each branch renames the source columns to the common lower-case names (state,
# country, last_update, latitude, longitude, confirmed, deaths, recovered, active,
# incidence_rate, case_fatality_ratio). Columns a schema lacks are added as the
# string "null"; the later cast step turns them into nulls under Spark's default
# (non-ANSI) casting.
raw_columns = list(raw.columns)
# list.index() raises ValueError when there is no match. Check membership first;
# otherwise the "new schema" branch below can never run.
schema_id = known_schema.index(raw_columns) if raw_columns in known_schema else None

if schema_id == 0:
    raw = raw.withColumnRenamed("Province/State", "state").withColumnRenamed("Country/Region", "country")\
        .withColumnRenamed("Last Update", "last_update").withColumnRenamed("Confirmed", "confirmed")\
        .withColumnRenamed("Deaths", "deaths").withColumnRenamed("Recovered", "recovered")\
        .withColumn('latitude', lit("null")).withColumn('longitude', lit("null"))\
        .withColumn("active", lit("null")).withColumn("incidence_rate", lit("null"))\
        .withColumn("case_fatality_ratio", lit("null"))
elif schema_id == 1:
    raw = raw.withColumnRenamed("Province/State", "state").withColumnRenamed("Country/Region", "country")\
        .withColumnRenamed("Last Update", "last_update").withColumnRenamed("Confirmed", "confirmed")\
        .withColumnRenamed("Deaths", "deaths").withColumnRenamed("Recovered", "recovered")\
        .withColumnRenamed("Latitude", "latitude").withColumnRenamed("Longitude", "longitude")\
        .withColumn("active", lit("null"))\
        .withColumn("incidence_rate", lit("null")).withColumn("case_fatality_ratio", lit("null"))
elif schema_id == 2:
    # withColumn() needs a column value; it was missing for case_fatality_ratio.
    raw = raw.withColumnRenamed("Province_State", "state").withColumnRenamed("Country_Region", "country")\
        .withColumnRenamed("Last_Update", "last_update").withColumnRenamed("Lat", "latitude").withColumnRenamed("Long_", "longitude")\
        .withColumnRenamed("Confirmed", "confirmed").withColumnRenamed("Deaths", "deaths")\
        .withColumnRenamed("Recovered", "recovered").withColumnRenamed("Active", "active")\
        .withColumn('incidence_rate', lit("null")).withColumn("case_fatality_ratio", lit("null"))
elif schema_id == 3:
    raw = raw.withColumnRenamed("Province_State", "state").withColumnRenamed("Country_Region", "country")\
        .withColumnRenamed("Last_Update", "last_update").withColumnRenamed("Lat", "latitude")\
        .withColumnRenamed("Long_", "longitude").withColumnRenamed("Confirmed", "confirmed").withColumnRenamed("Deaths", "deaths")\
        .withColumnRenamed("Recovered", "recovered").withColumnRenamed("Active", "active")\
        .withColumnRenamed("Incidence_Rate", "incidence_rate").withColumnRenamed("Case-Fatality_Ratio", "case_fatality_ratio")
else:
    print("New schema found!")
    # Afterwards, we may want to send a notification to SNS whenever a new schema is detected.
    # Exit with a non-zero status so the run is reported as failed rather than successful.
    sys.exit(1)

raw.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- FIPS: string (nullable = true)
 |-- Admin2: string (nullable = true)
 |-- state: string (nullable = true)
 |-- country: string (nullable = true)
 |-- last_update: string (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)
 |-- confirmed: string (nullable = true)
 |-- deaths: string (nullable = true)
 |-- recovered: string (nullable = true)
 |-- active: string (nullable = true)
 |-- Combined_Key: string (nullable = true)
 |-- incidence_rate: string (nullable = true)
 |-- case_fatality_ratio: string (nullable = true)

In [53]:
# Replace country names with codes via the UDF, fill nulls in string columns
# with the literal text 'null', then drop exact duplicate rows.
mapped = (
    raw
    .withColumn("country", myudf(raw.country))
    .fillna('null')
    .dropDuplicates()
)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [54]:
# Cast the numeric columns (read from CSV as strings) to their proper types, then
# keep only the normalized output columns in a fixed order.
numeric_casts = [
    ("confirmed", IntegerType()),
    ("deaths", IntegerType()),
    ("recovered", IntegerType()),
    ("latitude", FloatType()),
    ("longitude", FloatType()),
    ("active", IntegerType()),
    ("incidence_rate", FloatType()),
    ("case_fatality_ratio", FloatType()),
]
for column_name, target_type in numeric_casts:
    mapped = mapped.withColumn(column_name, col(column_name).cast(target_type))

output_columns = ["state", "country", "last_update", "latitude", "longitude", "confirmed",
                  "deaths", "recovered", "active", "incidence_rate", "case_fatality_ratio"]
mapped = mapped.select(*output_columns)

mapped.printSchema()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- state: string (nullable = false)
 |-- country: string (nullable = false)
 |-- last_update: string (nullable = false)
 |-- latitude: float (nullable = true)
 |-- longitude: float (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- deaths: integer (nullable = true)
 |-- recovered: integer (nullable = true)
 |-- active: integer (nullable = true)
 |-- incidence_rate: float (nullable = true)
 |-- case_fatality_ratio: float (nullable = true)

In [55]:
# Tag every row with the report month (YYYY-mm) and day. This row-level frame is
# written separately to S3 and the Hive metastore to preserve the original content.
month_label = "{}-{}".format(year, month)
processed = (
    mapped
    .withColumn('month', lit(month_label))
    .withColumn('day', lit(str(day)))
)
processed.show(4)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+-------+-------------------+---------+----------+---------+------+---------+------+--------------+-------------------+-------+---+
|        state|country|        last_update| latitude| longitude|confirmed|deaths|recovered|active|incidence_rate|case_fatality_ratio|  month|day|
+-------------+-------+-------------------+---------+----------+---------+------+---------+------+--------------+-------------------+-------+---+
|      Georgia|     US|2020-05-30 02:32:48|33.942432|-84.576126|     2987|   170|        0|  2817|      392.9534|           5.691329|2020-05| 29|
|New Hampshire|     US|2020-05-30 02:32:48| 44.69063|-71.306335|        5|     0|        0|     5|     15.841333|                0.0|2020-05| 29|
|        Idaho|     US|2020-05-30 02:32:48| 43.35071|-115.47016|       31|     2|        0|    29|      112.6822|           6.451613|2020-05| 29|
|     Illinois|     US|2020-05-30 02:32:48| 39.00072| -89.02453|       20|     3|        0|    17|      93.73828|           

In [56]:
# Create another dataframe by grouping the data by state and country and summing
# the confirmed, deaths and recovered counts.
grouped_sums = mapped.groupBy(['state', 'country']).sum('confirmed', 'deaths', 'recovered')
# sum() names its outputs "sum(<col>)"; restore the plain column names.
sum_renames = {
    "sum(confirmed)": "confirmed",
    "sum(deaths)": "deaths",
    "sum(recovered)": "recovered",
}
summed = grouped_sums
for generated_name, plain_name in sum_renames.items():
    summed = summed.withColumnRenamed(generated_name, plain_name)

# Same month (YYYY-mm) / day tagging as the row-level frame.
month_label = "{}-{}".format(year, month)
aggregated_df = summed.withColumn('month', lit(month_label)).withColumn('day', lit(str(day)))
aggregated_df.show(4)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+-------+---------+------+---------+-------+---+
|        state|country|confirmed|deaths|recovered|  month|day|
+-------------+-------+---------+------+---------+-------+---+
|         null|     LT|     1662|    68|     1216|2020-05| 29|
|    Andalusia|     ES|    12655|  1404|    10671|2020-05| 29|
|Niedersachsen|     DE|    11893|   596|    10600|2020-05| 29|
|Metropolitana|     CL|    72910|   685|        0|2020-05| 29|
+-------------+-------+---------+------+---------+-------+---+
only showing top 4 rows

In [57]:
# Write un-aggregated and processed data to Hive and S3. This retains original stats by co-ordinates for a day.
# coalesce(1) writes from a single task, so each country/month directory gets few
# (typically one) Parquet files.
# NOTE(review): mode("append") with day not a partition column means re-running
# this cell for the same day appends duplicate rows.
processed_writer = (
    processed.coalesce(1).write
    .format("Parquet")
    .mode("append")
    .partitionBy("country", "month")
    .option("path", processed_output_path)
)
processed_writer.saveAsTable(glue_processed_table)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
# Write aggregated data to Hive and S3. This summarizes stats by state and country for a day.
# repartition(1) funnels the output through a single task before writing.
# NOTE(review): as with the processed table, append mode means re-runs for the
# same day add duplicate rows.
aggregated_writer = (
    aggregated_df.repartition(1).write
    .format("Parquet")
    .mode("append")
    .partitionBy("country", "month")
    .option("path", aggregated_output_path)
)
aggregated_writer.saveAsTable(glue_aggregated_table)