In [1]:
import os

import psycopg2
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql.types import NullType, StringType

In [2]:
# Notebook configuration parameters
CREATE_TABLES = True # Set to False to skip table creation
FIX_NULLS = True # Set to False to not fix nulls (the write to Postgres will then not work)

# Postgres configuration parameters
# Every setting can be overridden with an environment variable of the same name,
# so credentials do not have to be edited into the notebook. The fallbacks are
# the original local-development values and keep the notebook runnable as-is.
POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost")
POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432")
POSTGRES_DATABASE = os.environ.get("POSTGRES_DATABASE", "postgres")
POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres")
# Local-development default only; set POSTGRES_PASSWORD in the environment for anything else
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "vanguard")

# JDBC connection URL assembled from the host, port and database above
POSTGRES_URL = "jdbc:postgresql://{}:{}/{}".format(POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DATABASE)

In [3]:
# Generates a Postgres column schema to be used in CREATE TABLE statements
def build_column_schema(column_names, column_optionals, column_types):
    """Build the comma-separated column list for a CREATE TABLE statement.

    The three arguments are parallel sequences indexed by column position.

    Args:
        column_names: Column names, one string per column.
        column_optionals: Flags per column; a falsy flag adds ``NOT NULL``.
        column_types: Postgres type names, one string per column.

    Returns:
        A string such as ``"id uuid NOT NULL, title varchar, ..."`` that always
        ends with the four fixed columns ``yyyy``, ``yyyymmdd``, ``timestamp``
        and ``kind``.

    Raises:
        IndexError: If ``column_optionals`` or ``column_types`` is shorter than
            ``column_names``.
    """
    definitions = []
    for position, name in enumerate(column_names):
        definition = name + " " + column_types[position]
        nullability = "" if column_optionals[position] else " NOT NULL"
        definitions.append(definition + nullability)

    # Fixed columns appended after the per-table columns
    definitions.extend([
        "yyyy integer",
        "yyyymmdd integer",
        "timestamp timestamp",
        "kind varchar",
    ])
    return ", ".join(definitions)

In [4]:
# Function to create Postgres tables
def create_table(input_df, table_name):
    """Create the Postgres table for ``table_name`` if it does not already exist.

    The column definitions and the target schema are read from the first
    distinct (columnnames, columnoptionals, columntypes, schema) row of
    ``input_df``.

    Args:
        input_df: Spark DataFrame with the columns ``columnnames``,
            ``columnoptionals``, ``columntypes`` and ``schema``.
        table_name: Name of the table to create inside that schema.

    Raises:
        IndexError: If ``input_df`` has no rows, so no metadata row exists.
        psycopg2.Error: If connecting or executing the statement fails.
    """

    # After inspecting the data, this will result in only 1 row, so can use distinct and collect
    collected = input_df.select("columnnames", "columnoptionals", "columntypes", "schema").distinct().collect()[0]
    column_names = collected["columnnames"]
    column_optionals = collected["columnoptionals"]
    column_types = collected["columntypes"]
    schema = collected["schema"]

    column_schema = build_column_schema(column_names, column_optionals, column_types)

    # NOTE(review): identifiers and types are interpolated straight from the input data;
    # this assumes the change records come from a trusted source.
    create_statement = "CREATE TABLE IF NOT EXISTS {}.{} ({})".format(schema, table_name, column_schema)

    # Connect to Postgres and create table
    # The port is passed explicitly so this connection targets the same server as POSTGRES_URL
    conn = psycopg2.connect(
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        database=POSTGRES_DATABASE,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD
    )
    try:
        # The cursor context manager closes the cursor even if execute() raises
        with conn.cursor() as cur:
            cur.execute(create_statement)
        conn.commit()
    finally:
        # Always release the connection; previously it was left open after every call
        conn.close()

In [5]:
# Function to process incoming data and insert into Postgres
def process_data(input_df, table_name, fix_nulls=True):
    """Turn the change records of one table into columns and append them to Postgres.

    Each row's ``columnnames`` / ``columnvalues`` arrays are zipped into a map,
    one output column is selected per column name, and the result is appended
    over JDBC to ``<schema>.<table_name>``.

    Args:
        input_df: Spark DataFrame with the columns ``columnnames``,
            ``columnoptionals``, ``columntypes``, ``columnvalues``, ``schema``,
            ``yyyy``, ``yyyymmdd``, ``timestamp`` and ``kind``.
        table_name: Target table; also selects which null replacements apply.
        fix_nulls: When truthy, nulls in selected columns of the ``forms``,
            ``questions`` and ``submissions`` tables are replaced by
            placeholder values before writing.

    Raises:
        IndexError: If ``input_df`` has no rows, so no metadata row exists.
    """
    # Placeholder values per table, used in place of nulls.
    # Bug fix - Writes with nulls won't work, so cleaning them to allow code to run
    # https://stackoverflow.com/questions/64671739/pyspark-nullable-uuid-type-uuid-but-expression-is-of-type-character-varying
    null_replacements = {
        "forms": {
            "deleted_by": "00000000-0000-0000-0000-000000000000",
            "deleted_at": "2999-12-31 23:59:59.000",
        },
        "questions": {
            "deleted_by": "00000000-0000-0000-0000-000000000000",
            "deleted_at": "2999-12-31 23:59:59.000",
            "min_selections": 0,
            "max_selections": 0,
            "duplicated_from_question_id": "00000000-0000-0000-0000-000000000000",
        },
        "submissions": {
            "updated_by": "00000000-0000-0000-0000-000000000000",
            "context_id": "00000000-0000-0000-0000-000000000000",
            "created_by": "00000000-0000-0000-0000-000000000000",
        },
    }

    # The first distinct metadata row supplies the column names and target schema
    metadata = input_df.select("columnnames", "columnoptionals", "columntypes", "schema").distinct().collect()[0]
    column_names = metadata["columnnames"]
    schema = metadata["schema"]

    # One output column per name, looked up in the "mapped" column built below
    mapped = col("mapped")
    select_cols = [mapped.getItem(name).alias(str(name)) for name in column_names]
    select_cols += ["yyyy", "yyyymmdd", "timestamp", "kind"]

    # Zip the name and value arrays into entries, then turn the entries into a map
    name_value_entries = F.arrays_zip(col("columnnames"), col("columnvalues"))
    output_df = (
        input_df
        .withColumn("mapped", F.map_from_entries(name_value_entries))
        .select(select_cols)
    )

    if fix_nulls and table_name in null_replacements:
        output_df = output_df.fillna(null_replacements[table_name])

    connection_properties = {
        "user": POSTGRES_USER,
        "password": POSTGRES_PASSWORD,
        "stringtype": "unspecified",
    }

    # Not used: .option("createTableColumnTypes", column_schema) # SEE APPROACH #1 IN NOTES
    writer = output_df.write.mode("append").option("driver", "org.postgresql.Driver")
    writer.jdbc(POSTGRES_URL, schema + "." + table_name, properties=connection_properties)

In [6]:
# MAIN WORKFLOW

In [7]:
# Create Spark Session
# The PostgreSQL JDBC driver jar is added to both the driver and executor class paths
jdbc_driver_jar = "postgresql-42.2.18.jar"
session_builder = (
    SparkSession.builder
    .master("local")
    .appName("vanguard")
    .config("spark.driver.extraClassPath", jdbc_driver_jar)
    .config("spark.executor.extraClassPath", jdbc_driver_jar)
)
spark = session_builder.getOrCreate()

In [8]:
# Read in the write ahead logs
# Every JSON file two directory levels below "data" is loaded, with "data" as the base path
data = (
    spark.read
    .format("json")
    .options(basePath="data")
    .load("data/*/*")
)

In [9]:
# Explode the change object so there's only one database transaction per row
# Fields pulled out of each exploded "change" struct
change_fields = [
    "columnnames",
    "columnoptionals",
    "columntypes",
    "columnvalues",
    "kind",
    "table",
    "schema",
]
# Top-level columns carried through unchanged
passthrough_fields = ["yyyy", "yyyymmdd", "timestamp"]

exploded = (
    data
    .withColumn("change", F.explode(col("change")))
    .select(
        *["change." + field for field in change_fields],
        *passthrough_fields
    )
)

In [10]:
# Process each table and write to Postgres
# Tables are handled one at a time: optional table creation first, then the data load
tables_to_load = ("forms", "questions", "responses", "submissions")

for table_name in tables_to_load:
    input_df = exploded.where(col("table") == table_name)
    if CREATE_TABLES:
        create_table(input_df, table_name)
    process_data(input_df, table_name, fix_nulls=FIX_NULLS)