In [1]:
from datetime import datetime, timedelta
import logging
import os
import pandas as pd
import psycopg2
import configparser
import datetime

from pyspark.sql.functions import expr, udf, col, lit, year, month, upper, to_date, monotonically_increasing_id, when, substring
from pyspark.sql.functions import sum as _sum
from pyspark.sql.functions import date_add as d_add 
from pyspark.sql.types import DoubleType

from pyspark.sql import SparkSession

In [2]:
def create_spark_session():
    """Return a SparkSession with Hive support and the spark-sas7bdat package configured.

    Uses ``getOrCreate``, so an already-running session is returned when one exists.
    """
    builder = SparkSession.builder \
        .config("spark.jars.packages", "saurfang:spark-sas7bdat:3.0.0-s_2.12") \
        .enableHiveSupport()
    return builder.getOrCreate()

def rename_columns(table, new_columns):
    """Rename the columns of ``table`` positionally.

    The i-th column of the original table is renamed to ``new_columns[i]``; pairing
    stops at the shorter of the two sequences, so extra columns keep their names.
    Returns the renamed table.
    """
    renamed = table
    for old_name, new_name in zip(table.columns, new_columns):
        renamed = renamed.withColumnRenamed(old_name, new_name)
    return renamed

# DB Properties: PostgreSQL connection settings come from the [postgresql]
# section of parameters.cfg (keys read: url, username, password, driver).
# Note: ConfigParser.read silently skips a missing file, in which case the
# config['postgresql'] lookup below raises KeyError.
db_properties = {}
config = configparser.ConfigParser()
config.read("parameters.cfg")
db_prop = config['postgresql']

# Individual values passed as JDBC options by the table writers in this notebook.
db_url, db_user, db_password = db_prop['url'], db_prop['username'], db_prop['password']

# The same settings bundled into a single dict.
db_properties.update(
    (key, db_prop[key]) for key in ('username', 'password', 'url', 'driver')
)

In [3]:
def read_immigration_data(input_path):
    """
        Description: read every file under ``{input_path}sas/`` with the spark-sas7bdat
            reader and union them into a single dataframe
        Arguments:
            input_path: base directory; must end with a path separator because
                ``sas/`` is appended by plain string concatenation
        Return:
            df_final: Spark dataframe holding the rows of all files. When more than one
                file is read, its columns are the sorted union of every file's columns,
                and columns a file lacks are filled with null literals.
        Raises:
            ValueError: if the ``sas/`` directory is empty
    """
    sas_dir = f"{input_path}sas/"
    # os.listdir returns entries in arbitrary order; sort so the union order is deterministic.
    sas_files = sorted(os.listdir(sas_dir))
    if not sas_files:
        # Without this guard df_final would never be assigned and the caller would
        # see an UnboundLocalError instead of a meaningful message.
        raise ValueError(f"No SAS files found in {sas_dir}")

    spark = create_spark_session()

    def _align(df, all_cols):
        # Project df onto all_cols in that order, filling missing columns with null literals.
        present = set(df.columns)
        return df.select([name if name in present else lit(None).alias(name) for name in all_cols])

    def customUnion(df1, df2):
        # Positional union of two frames whose column sets may differ: both sides are
        # projected onto the same sorted column list first so positions line up.
        cols1 = df1.columns
        total_cols = sorted(cols1 + list(set(df2.columns) - set(cols1)))
        return _align(df1, total_cols).union(_align(df2, total_cols))

    def _read_sas(file_name):
        return spark \
            .read \
            .format('com.github.saurfang.sas.spark') \
            .load(f"{sas_dir}{file_name}")

    df_final = _read_sas(sas_files[0])
    for file_name in sas_files[1:]:
        df_final = customUnion(df_final, _read_sas(file_name))

    return df_final

In [4]:
def create_immigration_table(df, table):
    """ 
        Description: create the fact immigration table and write it to PostgreSQL over JDBC
        Arguments:
            df: raw immigration dataframe generated by Spark (see read_immigration_data)
            table: target table name; an existing table is overwritten
        Return:
            df_immigration: the transformed dataframe that was written
    """

    logging.info("Start creating immigration table")

    # arrdate/depdate are SAS numeric dates: a count of days since this epoch.
    sas_epoch = datetime.datetime(1960, 1, 1)
    epoch_literal = sas_epoch.strftime("%Y-%m-%d")

    def sas_date(column_name):
        # Turn a SAS day-count column into a DATE (null day counts stay null).
        # The previous to_date(lit("arrdate")) parsed the literal string "arrdate",
        # which always produced null.
        return expr(f"date_add(to_date('{epoch_literal}'), cast({column_name} as int))")

    # NOTE: "bird_year" looks like a typo for "birth_year" but is the existing
    # output column name, so it is kept for schema compatibility.
    df_immigration = df \
            .withColumn("immigration_id", monotonically_increasing_id()) \
            .withColumn("cic_id", col("cicid").cast("integer")) \
            .drop("cicid") \
            .withColumnRenamed("i94addr", "cod_state") \
            .withColumnRenamed("i94port", "cod_port") \
            .withColumn("cod_visa", col("i94visa").cast("integer")) \
            .drop("i94visa") \
            .withColumn("cod_mode", col("i94mode").cast("integer")) \
            .drop("i94mode") \
            .withColumn("cod_country_origin", col("i94res").cast("integer")) \
            .drop("i94res") \
            .withColumn("cod_country_cit", col("i94cit").cast("integer")) \
            .drop("i94cit") \
            .withColumn("year", col("i94yr").cast("integer")) \
            .drop("i94yr") \
            .withColumn("month", col("i94mon").cast("integer")) \
            .drop("i94mon") \
            .withColumn("bird_year", col("biryear").cast("integer")) \
            .drop("biryear") \
            .withColumn("age", col("i94bir").cast("integer")) \
            .drop("i94bir") \
            .withColumn("counter", col("count").cast("integer")) \
            .drop("count") \
            .withColumn("arrival_date", sas_date("arrdate")) \
            .withColumn("departure_date", sas_date("depdate")) \
            .drop("arrdate", "depdate")

    df_immigration.printSchema()

    df_immigration.write \
        .mode("overwrite") \
        .format("jdbc") \
        .option("url", db_url) \
        .option("dbtable", table) \
        .option("user", db_user) \
        .option("password", db_password) \
        .save()

    return df_immigration

In [5]:
def create_airlines_table(input_path, table):
    """Load ``airlines.csv`` (headerless, comma-separated) and overwrite ``table`` over JDBC.

    The first eight positional CSV columns are given readable names, and values of
    ``alias`` equal to the literal string ``\\N`` are replaced with null before writing.
    Returns None.
    """
    spark = create_spark_session()

    df_raw = spark.read \
        .format("csv") \
        .option("header", "false") \
        .option("delimiter", ",") \
        .load(input_path + "airlines.csv")

    # Names for positional columns _c0 .. _c7.
    # NOTE: "ica0" looks like a typo for "icao" but is the existing table column name.
    airline_columns = ["airline_id", "airline_name", "alias", "iata", "ica0", "callsign", "country", "active"]

    df_airlines = df_raw.select("*")
    for position, name in enumerate(airline_columns):
        df_airlines = df_airlines.withColumn(name, df_raw[position])
    df_airlines = df_airlines.drop(*[f"_c{position}" for position in range(len(airline_columns))])

    df_airlines.printSchema()

    alias_col = df_airlines["alias"]
    cleaned = df_airlines \
        .select(*airline_columns) \
        .withColumn("alias", when(alias_col == "\\N", None).otherwise(alias_col))

    cleaned.write \
        .mode("overwrite") \
        .format("jdbc") \
        .option("url", db_url) \
        .option("dbtable", table) \
        .option("user", db_user) \
        .option("password", db_password) \
        .save()

In [6]:
def create_airports_table(input_path, table):
    """Load ``airport-codes_csv.csv``, keep US large/medium/small airports and overwrite ``table`` over JDBC.

    ``iso_region`` is reduced to its 4th-5th characters and ``elevation_ft`` is cast
    to float. Returns None.
    NOTE(review): the file is read as ';'-delimited with a header row — confirm this
    matches the local copy of the dataset.
    """
    spark = create_spark_session()

    reader = spark.read \
        .format("csv") \
        .option("header", "true") \
        .option("delimiter", ";")
    df_raw = reader.load(input_path + "airport-codes_csv.csv")

    is_us = col("iso_country") == "US"
    is_airport = col("type").isin("large_airport", "medium_airport", "small_airport")

    # presumably iso_region looks like "US-CA"; characters 4-5 are then the state code
    region_code = substring(col("iso_region"), 4, 2)

    df_airports = df_raw \
        .select("*") \
        .where(is_us & is_airport) \
        .withColumn("iso_region", region_code) \
        .withColumn("elevation_ft", col("elevation_ft").cast("float"))

    df_airports.printSchema()

    df_airports.write \
        .mode("overwrite") \
        .format("jdbc") \
        .option("url", db_url) \
        .option("dbtable", table) \
        .option("user", db_user) \
        .option("password", db_password) \
        .save()

In [7]:
def create_demographics_table(input_path, table):
    """Load ``us-cities-demographics.csv`` and overwrite ``table`` over JDBC.

    Rows are grouped by the city-level attribute columns and pivoted on ``Race``,
    each race column holding the integer sum of ``count``. Missing counts for the
    five listed races are filled with 0. Returns None.
    """
    spark = create_spark_session()

    df_raw = spark.read \
        .format("csv") \
        .option("header", "true") \
        .option("delimiter", ";") \
        .load(input_path + "us-cities-demographics.csv")

    group_columns = [
        "City", "State", "Median Age", "Male Population", "Female Population",
        "Total Population", "Number of Veterans", "Foreign-born",
        "Average Household Size", "State Code",
    ]
    race_columns = [
        "American Indian and Alaska Native", "Asian", "Black or African-American",
        "Hispanic or Latino", "White",
    ]

    grouped = df_raw.select("*").groupBy(*[col(name) for name in group_columns])
    pivoted = grouped.pivot("Race").agg(_sum("count").cast("integer"))
    df_demographics = pivoted.fillna(dict.fromkeys(race_columns, 0))

    df_demographics.printSchema()

    df_demographics.write \
        .mode("overwrite") \
        .format("jdbc") \
        .option("url", db_url) \
        .option("dbtable", table) \
        .option("user", db_user) \
        .option("password", db_password) \
        .save()

In [8]:
def _parse_label_lines(lines):
    """Parse SAS label lines of the form ``code = 'value'`` into a {code: value} dict.

    Only the first ``=`` separates code from value, so an ``=`` inside the value is
    kept. Surrounding whitespace and single quotes are stripped from both parts.
    A line without ``=`` raises IndexError.
    """
    mapping = {}
    for line in lines:
        pair = line.split('=', 1)
        code = pair[0].strip().strip("'")
        mapping[code] = pair[1].strip().strip("'")
    return mapping


def _write_jdbc_table(df, table):
    """Overwrite ``table`` in the configured PostgreSQL database with ``df`` over JDBC."""
    df.write \
        .mode("overwrite") \
        .format("jdbc") \
        .option("url", db_url) \
        .option("dbtable", table) \
        .option("user", db_user) \
        .option("password", db_password) \
        .save()


def process_label_descriptions(input_path):
    """ Parse the I94 SAS label description file and load its code tables
        (countries, cities, states) into PostgreSQL.
        Arguments:
            input_path {str}: local directory containing I94_SAS_Labels_Descriptions.SAS
                (read with open(), so it must be a local filesystem path)
        Returns:
            None
    """

    spark = create_spark_session()

    logging.info("Start processing label descriptions")
    label_file = os.path.join(input_path, "I94_SAS_Labels_Descriptions.SAS")
    with open(label_file) as f:
        contents = f.readlines()

    # (progress message, target table, value column, first line, end line) for each
    # block of `code = 'value'` lines. The line ranges are hard-coded for the version
    # of the label file this notebook was written against.
    sections = (
        ("    Creating country table", "countries", "country", 10, 298),
        ("    Creating city table", "cities", "city", 303, 962),
        ("    Creating state table", "states", "state", 982, 1036),
    )

    for message, table, value_column, start, end in sections:
        print(message)
        codes = _parse_label_lines(contents[start:end])
        df_codes = spark.createDataFrame(list(codes.items()), ['code', value_column])
        _write_jdbc_table(df_codes, table)
        df_codes.printSchema()

In [9]:
# Data locations: override with the INPUT_PATH / OUTPUT_PATH environment variables.
# INPUT_PATH must end with a path separator because the loaders append file names to it.
input_path = os.environ.get("INPUT_PATH", "D:/Temp/Data/")
output_path = os.environ.get("OUTPUT_PATH", "D:/Temp/Results/")  # not used by the steps below

print("Reading immigration data...")
cachedDF = read_immigration_data(input_path)
print("Creating immigration table...")
create_immigration_table(cachedDF, "immigration")
print("Creating airlines table...")
create_airlines_table(input_path, "airlines")
print("Creating airports table...")
create_airports_table(input_path, "airports")
# create_demographics_table was defined above but never invoked, so the
# demographics data was never loaded.
print("Creating demographics table...")
create_demographics_table(input_path, "demographics")
print("Creating tables from SAS desciption file...")
process_label_descriptions(input_path)
print("Process finished")

Reading immigration data...
Creating immigration table...
root
 |-- admnum: double (nullable = true)
 |-- airline: string (nullable = true)
 |-- delete_days: double (nullable = true)
 |-- delete_dup: double (nullable = true)
 |-- delete_mexl: double (nullable = true)
 |-- delete_recdup: double (nullable = true)
 |-- delete_visa: double (nullable = true)
 |-- dtaddto: string (nullable = true)
 |-- dtadfile: string (nullable = true)
 |-- entdepa: string (nullable = true)
 |-- entdepd: string (nullable = true)
 |-- entdepu: string (nullable = true)
 |-- fltno: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- cod_state: string (nullable = true)
 |-- cod_port: string (nullable = true)
 |-- insnum: string (nullable = true)
 |-- matflag: string (nullable = true)
 |-- occup: string (nullable = true)
 |-- validres: double (nullable = true)
 |-- visapost: string (nullable = true)
 |-- visatype: string (nullable = true)
 |-- immigration_id: long (nullable = false)
 |-- cic_id: 