# Project Overview
A dimensional data mart that models a simple business involving customers, products, and vendors.

###  Import Required Libraries

In [None]:
import findspark
findspark.init()
print(findspark.find())

import os
import sys
import json
import time
import pymongo
import certifi
import shutil
import pandas as pd

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window as W
from sqlalchemy import create_engine, text

### Instantiate Global Variables and Paths for Data Sources

In [None]:
# --------------------------------------------------------------------------------
# Specify MySQL Server Connection Information
# --------------------------------------------------------------------------------
# NOTE(review): the user name and password were committed in plain text. They can
# now be supplied through the MYSQL_USER / MYSQL_PASSWORD environment variables;
# the literals remain only as backward-compatible defaults and should be rotated
# and removed from source control.
mysql_args = {
    "host_name" : "localhost",
    "port" : "3306",
    "db_name" : "northwind_dw",
    "conn_props" : {
        "user" : os.environ.get("MYSQL_USER", "root"),
        "password" : os.environ.get("MYSQL_PASSWORD", "Ashwaniis#1!"),
        "driver" : "com.mysql.cj.jdbc.Driver"
    }
}

# --------------------------------------------------------------------------------
# Specify MongoDB Cluster Connection Information
# --------------------------------------------------------------------------------
# NOTE(review): same concern as above. MONGODB_USER / MONGODB_PASSWORD override the
# committed defaults; the credentials are only used when "cluster_location" is "atlas".
mongodb_args = {
    "cluster_location" : "local", # "atlas"
    "user_name" : os.environ.get("MONGODB_USER", "vaneeshagupta10"),
    "password" : os.environ.get("MONGODB_PASSWORD", "Fdztq26kWFlyBXiE"),
    "cluster_name" : "cluster0",
    "cluster_subnet" : "koqso",
    "db_name" : "northwind",
    "collection" : "",
    "null_column_threshold" : 0.5  # columns with a larger share of nulls are dropped
}


# --------------------------------------------------------------------------------
# Source data layout: <base_dir>/source_data/{batch,streaming}
# --------------------------------------------------------------------------------
base_dir = "dbfs:/mnt/data"  # DBFS path for your data
data_dir = os.path.join(base_dir, 'source_data')
batch_dir, stream_dir = (
    os.path.join(data_dir, sub_dir) for sub_dir in ('batch', 'streaming')
)

# --------------------------------------------------------------------------------
# One streaming source directory per fact feed
# --------------------------------------------------------------------------------
orders_stream_dir, purchase_orders_stream_dir, inventory_trans_stream_dir = (
    os.path.join(stream_dir, feed)
    for feed in ('orders', 'purchase_orders', 'inventory_transactions')
)

# --------------------------------------------------------------------------------
# Warehouse location: <cwd>/spark-warehouse/<dest_database>.db
# Each fact table gets a bronze, silver and gold output directory beneath it.
# --------------------------------------------------------------------------------
dest_database = "data_mart_dlh"
dest_database_dir = f"{dest_database}.db"
sql_warehouse_dir = os.path.abspath('spark-warehouse')
database_dir = os.path.join(sql_warehouse_dir, dest_database_dir)

orders_output_bronze, orders_output_silver, orders_output_gold = (
    os.path.join(database_dir, 'fact_orders', layer)
    for layer in ('bronze', 'silver', 'gold')
)

purchase_orders_output_bronze, purchase_orders_output_silver, purchase_orders_output_gold = (
    os.path.join(database_dir, 'fact_purchase_orders', layer)
    for layer in ('bronze', 'silver', 'gold')
)

inventory_trans_output_bronze, inventory_trans_output_silver, inventory_trans_output_gold = (
    os.path.join(database_dir, 'fact_inventory_transactions', layer)
    for layer in ('bronze', 'silver', 'gold')
)




### Define Helper Functions for MySQL, MongoDB, and File Handling

In [None]:

def get_file_info(path: str):
    """Summarize the regular files found directly inside ``path``.

    Sub-directories are skipped and the files are listed in sorted name order.

    Returns:
        A pandas DataFrame with the columns ``name``, ``size`` (bytes) and
        ``modification_time`` (the file's mtime converted with
        ``pd.to_datetime(..., unit='s')``).
    """
    rows = []
    for entry in sorted(os.listdir(path)):
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            continue
        modified = pd.to_datetime(os.path.getmtime(full_path), unit='s')
        rows.append((entry, os.path.getsize(full_path), modified))
    return pd.DataFrame(data=rows, columns=['name', 'size', 'modification_time'])

def wait_until_stream_is_ready(query, min_batches=1, poll_seconds=5):
    """Block until a streaming query has reported at least ``min_batches`` batches.

    Args:
        query: Object exposing ``recentProgress`` (a sequence of progress
            reports) and ``isActive``, such as a PySpark ``StreamingQuery``.
        min_batches: Number of progress reports to wait for.
        poll_seconds: Seconds to sleep between checks.

    The wait also ends when the query is no longer active. Without that check a
    query that stops or fails before producing ``min_batches`` reports (for
    example an ``availableNow`` trigger over an empty source directory) would
    keep this function looping forever.
    """
    while len(query.recentProgress) < min_batches and query.isActive:
        time.sleep(poll_seconds)

    # Re-read after the loop: the query may have reported a batch and stopped
    # between the two checks above.
    processed = len(query.recentProgress)
    if processed < min_batches:
        print(f"The stream stopped before processing {min_batches} batches")
    print(f"The stream has processed {processed} batches")

def remove_directory_tree(path: str):
    """Recursively delete ``path`` and report the outcome as a message.

    This function never raises: any exception is caught and folded into the
    returned string.

    Returns:
        A sentence stating that the directory was removed, that it did not
        exist, or that an error occurred (with the exception text).
    """
    try:
        if not os.path.exists(path):
            return f"Directory '{path}' does not exist."
        shutil.rmtree(path)
        return f"Directory '{path}' has been removed successfully."
    except Exception as e:
        return f"An error occurred: {e}"

def drop_null_columns(df, threshold):
    """Drop every column whose share of null values exceeds ``threshold``.

    Args:
        df: A DataFrame supporting ``columns``, ``count()``, ``filter()``,
            ``df[name].isNull()`` and ``drop(*names)`` (as a Spark DataFrame does).
        threshold: Fraction between 0 and 1; a column is dropped when
            ``null_rows / total_rows > threshold``.

    Returns:
        The DataFrame without the mostly-null columns. An empty DataFrame is
        returned unchanged, because no null ratio can be computed for it.
    """
    # Count the rows once instead of once per column.
    total_rows = df.count()
    if total_rows == 0:
        # Avoids a ZeroDivisionError on an empty input.
        return df

    columns_with_nulls = [
        name for name in df.columns
        if df.filter(df[name].isNull()).count() / total_rows > threshold
    ]
    return df.drop(*columns_with_nulls)

def get_mysql_dataframe(spark_session, sql_query : str, **args):
    """Run ``sql_query`` against MySQL through the JDBC reader.

    Args:
        spark_session: Session whose ``read`` attribute provides the reader.
        sql_query: Statement passed to the reader's ``query`` option.
        **args: Must contain ``host_name``, ``port``, ``db_name`` and a
            ``conn_props`` mapping with ``driver``, ``user`` and ``password``.

    Returns:
        Whatever the reader's ``load()`` returns (a Spark DataFrame).
    """
    conn_props = args['conn_props']
    reader_options = {
        "url": f"jdbc:mysql://{args['host_name']}:{args['port']}/{args['db_name']}",
        "driver": conn_props['driver'],
        "user": conn_props['user'],
        "password": conn_props['password'],
        "query": sql_query,
    }

    reader = spark_session.read.format("jdbc")
    for option_name, option_value in reader_options.items():
        reader = reader.option(option_name, option_value)
    return reader.load()

def get_mongo_uri(**args):
    """Build the MongoDB connection URI for a local or Atlas deployment.

    Args:
        **args: Must contain ``cluster_location`` (``'atlas'`` or ``'local'``).
            For ``'atlas'`` it must also contain ``user_name``, ``password``,
            ``cluster_name`` and ``cluster_subnet``, all given as raw
            (unescaped) values.

    Returns:
        ``mongodb://localhost:27017/`` for a local cluster, otherwise a
        ``mongodb+srv://`` URI for the Atlas cluster.

    Raises:
        ValueError: If ``cluster_location`` is neither ``'atlas'`` nor ``'local'``.
    """
    from urllib.parse import quote_plus

    location = args["cluster_location"]
    if location not in ('atlas', 'local'):
        raise ValueError("You must specify either 'atlas' or 'local' for the 'cluster_location' parameter.")

    if location == "local":
        return "mongodb://localhost:27017/"

    # Percent-escape the credentials so characters such as '@', ':' or '/' in a
    # user name or password cannot corrupt the URI structure.
    user = quote_plus(str(args['user_name']))
    password = quote_plus(str(args['password']))
    return f"mongodb+srv://{user}:{password}@{args['cluster_name']}.{args['cluster_subnet']}.mongodb.net/"

def get_spark_conf_args(spark_jars : list, **args):
    """Assemble the keyword arguments consumed by ``get_spark_conf``.

    Args:
        spark_jars: Paths of the jar files to register with Spark.
        **args: MongoDB connection settings, forwarded to ``get_mongo_uri``.

    Returns:
        A dict with the keys ``app_name``, ``worker_threads``,
        ``shuffle_partitions``, ``mongo_uri``, ``spark_jars`` and
        ``database_dir``.
    """
    # os.cpu_count() may return None; fall back to a single CPU in that case.
    cpu_total = os.cpu_count() or 1

    sparkConf_args = {
        "app_name" : "PySpark Northwind Data Lakehouse (Medallion Architecture)",
        # Use half of the CPUs, but never fewer than one thread: on a
        # single-CPU machine the original expression produced "local[0]".
        "worker_threads" : f"local[{max(1, cpu_total // 2)}]",
        "shuffle_partitions" : cpu_total,
        "mongo_uri" : get_mongo_uri(**args),
        "spark_jars" : ", ".join(str(jar) for jar in spark_jars),
        "database_dir" : sql_warehouse_dir
    }

    return sparkConf_args

def get_spark_conf(**args):
    """Create the ``SparkConf`` used to start the notebook's Spark session.

    Args:
        **args: Must contain ``app_name``, ``worker_threads``, ``spark_jars``,
            ``mongo_uri``, ``shuffle_partitions`` and ``database_dir`` (the
            dict produced by ``get_spark_conf_args``).

    Returns:
        The configured ``SparkConf`` instance.
    """
    # Settings are applied in the order listed here.
    settings = {
        'spark.driver.memory': '4g',
        'spark.executor.memory': '2g',
        'spark.jars': args['spark_jars'],
        'spark.jars.packages': 'org.mongodb.spark:mongo-spark-connector_2.12:3.0.1',
        'spark.mongodb.input.uri': args['mongo_uri'],
        'spark.mongodb.output.uri': args['mongo_uri'],
        'spark.sql.adaptive.enabled': 'false',
        'spark.sql.debug.maxToStringFields': 35,
        'spark.sql.shuffle.partitions': args['shuffle_partitions'],
        'spark.sql.streaming.forceDeleteTempCheckpointLocation': 'true',
        'spark.sql.streaming.schemaInference': 'true',
        'spark.sql.warehouse.dir': args['database_dir'],
        'spark.streaming.stopGracefullyOnShutdown': 'true',
    }

    sparkConf = SparkConf().setAppName(args['app_name']).setMaster(args['worker_threads'])
    for setting_name, setting_value in settings.items():
        sparkConf = sparkConf.set(setting_name, setting_value)

    return sparkConf

def get_mongo_client(**args):
    """Open a ``pymongo.MongoClient`` for the configured cluster.

    Args:
        **args: MongoDB connection settings; ``cluster_location`` selects the
            connection style and the rest is forwarded to ``get_mongo_uri``.

    Returns:
        A client for the local server, or a client for Atlas that uses the
        certifi CA bundle.

    Raises:
        Exception: If ``cluster_location`` is neither ``'atlas'`` nor ``'local'``.
    """
    mongo_uri = get_mongo_uri(**args)
    location = args['cluster_location']

    if location == "atlas":
        return pymongo.MongoClient(mongo_uri, tlsCAFile=certifi.where())
    if location == "local":
        return pymongo.MongoClient(mongo_uri)
    raise Exception("A MongoDB Client could not be created.")
    
def set_mongo_collections_with_pyspark(spark_session, data_directory: str, json_files: dict, **mongo_args):
    """Load JSON files into MongoDB collections through Spark.

    Args:
        spark_session: Spark session used to read the JSON files.
        data_directory: Directory that contains the JSON files.
        json_files: Mapping of collection name -> JSON file name.
        **mongo_args: MongoDB connection settings; ``db_name`` names the
            target database and the rest is forwarded to ``get_mongo_uri``.

    Each collection is written in ``overwrite`` mode and a confirmation line is
    printed per file.
    """
    target_database = mongo_args["db_name"]
    target_uri = get_mongo_uri(**mongo_args)

    for collection_name, filename in json_files.items():
        source_path = os.path.join(data_directory, filename)

        # "multiline" lets one JSON document span several lines of the file.
        reader = spark_session.read.option("multiline", "true")
        documents = reader.json(source_path)

        writer = documents.write.format("com.mongodb.spark.sql.DefaultSource").mode("overwrite")
        writer = writer.option("uri", target_uri)
        writer = writer.option("database", target_database)
        writer = writer.option("collection", collection_name)
        writer.save()

        print(f"✔ Loaded {filename} into MongoDB collection '{collection_name}'")

def get_mongodb_dataframe(spark_session, **args):
    """Read a MongoDB collection into a Spark DataFrame.

    Args:
        spark_session: Spark session used for the read. No URI is passed here,
            so the connection comes from the session's own configuration.
        **args: Must contain ``db_name``, ``collection`` and
            ``null_column_threshold``.

    Returns:
        The collection without its ``_id`` column and without the columns that
        ``drop_null_columns`` removes for the given threshold.
    """
    reader = spark_session.read.format("com.mongodb.spark.sql.DefaultSource")
    reader = reader.option("database", args['db_name'])
    reader = reader.option("collection", args['collection'])

    without_id = reader.load().drop('_id')
    return drop_null_columns(without_id, args['null_column_threshold'])


### Initialize Data Lakehouse Directory Structure
Remove the Data Lakehouse Database Directory Structure to Ensure Idempotency

In [None]:
# Delete any lakehouse database directory left over from a previous run.
# The helper does not raise; it returns a status message, which the notebook
# shows as this cell's output.
remove_directory_tree(database_dir)

### Create a New Spark Session

In [None]:
# Half of the available CPUs, expressed as a Spark "local[N]" master string.
worker_threads = "local[{}]".format(int(os.cpu_count()/2))

# JDBC driver jars, resolved relative to the current working directory.
mysql_spark_jar = os.path.join(os.getcwd(), "mysql-connector-j-9.1.0", "mysql-connector-j-9.1.0.jar")
mssql_spark_jar = os.path.join(os.getcwd(), "sqljdbc_12.8", "enu", "jars", "mssql-jdbc-12.8.1.jre11.jar")

# Only the MySQL driver is registered; the SQL Server jar path is defined but not used.
jars = [mysql_spark_jar]

# Build the configuration and start (or reuse) the session.
sparkConf_args = get_spark_conf_args(jars, **mongodb_args)
sparkConf = get_spark_conf(**sparkConf_args)

spark = (
    SparkSession.builder
    .config(conf=sparkConf)
    .getOrCreate()
)
spark.sparkContext.setLogLevel("OFF")
spark

### Create a New Metadata Database

In [None]:
# Start from a clean slate: drop the destination database and everything in it.
spark.sql(f"DROP DATABASE IF EXISTS {dest_database} CASCADE;")

# Re-create the destination database that the dimension and fact tables are saved into.
sql_create_db = f"""
    CREATE DATABASE IF NOT EXISTS {dest_database}
    COMMENT 'DS-2002 Lab 06 Database'
    WITH DBPROPERTIES (contains_pii = true, purpose = 'DS-2002 Lab 6.0');
"""
spark.sql(sql_create_db)

### Fetch Reference Data from MongoDB, MySQL, and CSV Files

#### MongoDB (Note: Customer Data)

In [None]:
# Fetch the customer reference data from MongoDB as a Spark DataFrame.
# The pandas-style helpers this cell previously called (get_mongo_dataframe and
# set_dataframe) are not defined in this notebook, so the Spark helper
# get_mongodb_dataframe() defined above is used instead.
collection_name = "customers"
df_mongo_customers = get_mongodb_dataframe(
    spark, **{**mongodb_args, "collection": collection_name}
)

# Standardize the key column name and keep only the columns the dimension needs
df_mongo_customers = (
    df_mongo_customers
    .withColumnRenamed("CustomerID", "customer_id")
    .select("customer_id", "TerritoryID", "AccountNumber", "CustomerType")
)

# Save as the dim_customer table in the lakehouse database. The Silver layer
# reads it back with spark.table(f"{dest_database}.dim_customer").
df_mongo_customers.write.saveAsTable(f"{dest_database}.dim_customer", mode="overwrite")

# Unit Test: Describe and Preview the Table
# (queries the database created above rather than 'data_mart', which is never created in Spark)
spark.sql(f"DESCRIBE EXTENDED {dest_database}.dim_customer").show()
spark.sql(f"SELECT * FROM {dest_database}.dim_customer LIMIT 2").toPandas()


#### MySQL (Note: Product Data and Date Data)

In [None]:
# Fetch product data from MySQL as a Spark DataFrame.
# get_mysql_dataframe() is the JDBC helper defined in this notebook; the
# previously called get_sql_dataframe / set_dataframe are not defined here.
# The statement has no trailing ';' because the JDBC reader's "query" option
# embeds it as a sub-query.
sql_product = "SELECT ProductID, Name, ProductNumber, ListPrice FROM product"
df_product_mysql = get_mysql_dataframe(spark, sql_product, **mysql_args)

# Standardize the column names for the data mart
df_product_mysql = (
    df_product_mysql
    .withColumnRenamed("ProductID", "product_id")
    .withColumnRenamed("Name", "product_name")
    .withColumnRenamed("ProductNumber", "product_code")
    .withColumnRenamed("ListPrice", "list_price")
    .select("product_id", "product_name", "product_code", "list_price")
)

# Save as the dim_product table in the lakehouse database. The Silver layer
# reads it back with spark.table(f"{dest_database}.dim_product").
df_product_mysql.write.saveAsTable(f"{dest_database}.dim_product", mode="overwrite")

# Unit Test: Describe and Preview the Table
spark.sql(f"DESCRIBE EXTENDED {dest_database}.dim_product").show()
spark.sql(f"SELECT * FROM {dest_database}.dim_product LIMIT 2").toPandas()


In [None]:
# Fetch the date dimension from MySQL through the JDBC helper defined in this
# notebook (the previously called get_sql_dataframe is not defined here).
# No trailing ';' because the JDBC "query" option embeds the statement as a sub-query.
sql_dim_date = "SELECT date_key, full_date FROM data_mart.dim_date"

# Ensure 'full_date' is a DATE column so it can be matched against order dates
df_dim_date_spark = (
    get_mysql_dataframe(spark, sql_dim_date, **mysql_args)
    .withColumn("full_date", col("full_date").cast("date"))
)

# Check the schema of the date dimension to ensure correctness
df_dim_date_spark.printSchema()

# Save the date dimension in the lakehouse database created above. The Spark
# database 'data_mart' is never created in this notebook, so saving there fails.
df_dim_date_spark.write.saveAsTable(f"{dest_database}.dim_date", mode="overwrite")

# Keep a pandas copy under the original name for cells that merge with pandas
df_dim_date = df_dim_date_spark.toPandas()
df_dim_date.head()

# Unit Test: Describe and Preview the Table
spark.sql(f"DESCRIBE EXTENDED {dest_database}.dim_date").show()

# Preview the first few rows from the dim_date table to confirm data insertion
spark.sql(f"SELECT * FROM {dest_database}.dim_date LIMIT 2").toPandas()

#### CSV File Using PySpark (Note: Vendor Data)

In [None]:
# Read the vendor CSV file into a Spark DataFrame (header row, inferred column types)
vendor_csv_path = os.path.join(data_dir, 'vendor.csv')

df_vendor = spark.read.format('csv') \
    .option('header', 'true') \
    .option('inferSchema', 'true') \
    .load(vendor_csv_path)

# Show the first few rows to check the data
df_vendor.show(5)

# Rename columns to standardize them according to the data mart schema
df_vendor = df_vendor.withColumnRenamed("VendorID", "vendor_id") \
    .withColumnRenamed("Name", "vendor_name") \
    .withColumnRenamed("AccountNumber", "vendor_account_number") \
    .withColumnRenamed("CreditRating", "credit_rating") \
    .withColumnRenamed("PreferredVendorStatus", "preferred_vendor_status") \
    .withColumnRenamed("ActiveFlag", "active_flag") \
    .withColumnRenamed("PurchasingWebServiceURL", "purchasing_web_service_url") \
    .withColumnRenamed("ModifiedDate", "modified_date")

# Drop any columns that are not required for the dimension table
df_vendor = df_vendor.select("vendor_id", "vendor_name", "vendor_account_number", "credit_rating", 
                             "preferred_vendor_status", "active_flag", "purchasing_web_service_url", "modified_date")

# Show the transformed DataFrame to check
df_vendor.show(5)

# Save as the dim_vendor table in the lakehouse database. df_vendor is a Spark
# DataFrame and set_dataframe() is not defined in this notebook, so the table is
# written with saveAsTable; the Silver layer reads it back with
# spark.table(f"{dest_database}.dim_vendor").
df_vendor.write.saveAsTable(f"{dest_database}.dim_vendor", mode="overwrite")

# Unit Test: Describe and Preview the Table
# Check the structure of the dim_vendor table
spark.sql(f"DESCRIBE EXTENDED {dest_database}.dim_vendor").show()

# Preview the first few rows from the dim_vendor table to confirm data insertion
spark.sql(f"SELECT * FROM {dest_database}.dim_vendor LIMIT 2").toPandas()


### Fact Table

In [None]:
# Build a batch (pandas) version of the orders fact table from the MySQL
# sales-order header and detail tables, then store it.
#
# NOTE(review): this cell cannot run as written with only what is visible in
# this notebook - confirm and fix before relying on it:
#   * get_sql_dataframe() and set_dataframe() are not defined in the visible file.
#   * df_dim_customer, df_dim_vendor and df_dim_product are first assigned in the
#     Silver-layer cell further down, where they are Spark tables, not pandas frames.
#   * nothing selected or renamed here gives df_fact_orders a 'vendor_id' column
#     before the merge on 'vendor_id' (unless the customer dimension supplies one).
#   * the final preview queries the Spark database 'data_mart', while the
#     database created above is named by dest_database.

# Extracting Fact Table Data (Sales Orders)
sql_sales_order_header = """
    SELECT SalesOrderID, CustomerID, OrderDate, TotalDue 
    FROM adventureworks.SalesOrderHeader;
"""
df_sales_order_header = get_sql_dataframe(sql_sales_order_header, **mysql_args)

# Extracting Order Details to get the ProductID
sql_sales_order_detail = """
    SELECT SalesOrderID, ProductID, OrderQty, LineTotal 
    FROM adventureworks.SalesOrderDetail;
"""
df_sales_order_detail = get_sql_dataframe(sql_sales_order_detail, **mysql_args)

# Merging the sales order header and order details
# (left merge: every header row is kept, one output row per matching detail line)
df_fact_orders = pd.merge(df_sales_order_header, df_sales_order_detail, on='SalesOrderID', how='left')

# Renaming columns for consistency
df_fact_orders.rename(columns={
    "SalesOrderID": "sales_order_id",
    "CustomerID": "customer_id",
    "ProductID": "product_id",
    "OrderDate": "order_date",
    "OrderQty": "order_qty",
    "LineTotal": "line_total",
    "TotalDue": "total_due"
}, inplace=True)

# Merge with the customer dimension to get more information
df_fact_orders = pd.merge(df_fact_orders, df_dim_customer, on='customer_id', how='left')

# Merge with the vendor dimension to get vendor details
df_fact_orders = pd.merge(df_fact_orders, df_dim_vendor, on='vendor_id', how='left')

# Merge with the product dimension to get product details
df_fact_orders = pd.merge(df_fact_orders, df_dim_product, on='product_id', how='left')

# Merge with the date dimension to get the date details
# NOTE(review): assumes 'order_date' and 'full_date' hold comparable date values - confirm.
df_fact_orders = pd.merge(df_fact_orders, df_dim_date, left_on='order_date', right_on='full_date', how='left')

# Drop the 'full_date' column from the date dimension (since it's now included as 'order_date')
df_fact_orders.drop(['full_date'], axis=1, inplace=True)

# Handle missing data
df_fact_orders['vendor_id'].fillna(0, inplace=True)  # Replace NaN with 0
df_fact_orders['order_qty'].fillna(0, inplace=True)  # Replace NaN with 0

# Remove duplicates based on the primary key column
# NOTE(review): this keeps only the first detail line of each sales order -
# presumably unintended for a line-level fact table; confirm.
df_fact_orders = df_fact_orders.drop_duplicates(subset='sales_order_id', keep='first')

# Ensure the column types are correct
df_fact_orders['order_qty'] = df_fact_orders['order_qty'].astype(int)

# Save the fact table to MySQL (or another DB) using the set_dataframe function
table_name = "fact_orders"
primary_key = "sales_order_id"
db_operation = "insert"

# Insert data into MySQL database
set_dataframe(df_fact_orders, table_name, primary_key, db_operation, **mysql_args)

# Unit Test: Preview the fact table in the database
spark.sql(f"SELECT * FROM data_mart.fact_orders LIMIT 2").toPandas()


#### Use PySpark Structured Streaming to Process (Hot Path) Orders Fact Data

In [None]:
# Define the path to the orders streaming data (streaming data for new sales orders)
# (same value as the orders_stream_dir assigned in the global-variables cell)
orders_stream_dir = os.path.join(stream_dir, 'orders')

# Create a streaming DataFrame to read JSON data (representing the incoming orders).
# No explicit schema is supplied; get_spark_conf() sets
# 'spark.sql.streaming.schemaInference' to 'true'.
# NOTE(review): "schemaLocation" looks like an Auto Loader-style option - confirm
# that the plain JSON file source actually uses it.
df_orders_bronze = (
    spark.readStream
    .option("schemaLocation", orders_output_bronze)  # Location for schema inference
    .option("maxFilesPerTrigger", 1)  # Max files per trigger (mini-batch)
    .option("multiLine", "true")  # Handle multi-line JSON
    .json(orders_stream_dir)
)

# Check if the DataFrame is streaming (displayed as the cell output)
df_orders_bronze.isStreaming


#### Write the Streaming Data to a Parquet File (Bronze Layer)

In [None]:
# The checkpoint lives inside the bronze output directory.
orders_checkpoint_bronze = os.path.join(orders_output_bronze, '_checkpoint')

# Write the incoming orders data to the bronze layer (Parquet format).
# Two audit columns are added before writing: the time the row was received and
# the file it came from.
orders_bronze_query = (
    df_orders_bronze
    .withColumn("receipt_time", current_timestamp())  # Add timestamp
    .withColumn("source_file", input_file_name())  # Add source file for traceability
    .writeStream
    .format("parquet")
    .outputMode("append")
    .queryName("orders_bronze")
    .trigger(availableNow=True)  # Process available data immediately
    .option("checkpointLocation", orders_checkpoint_bronze)
    .option("compression", "snappy")
    .start(orders_output_bronze)
)


In [None]:
# Wait until the stream processes at least one batch
# NOTE(review): the source is read one file per trigger, so one completed batch
# does not necessarily mean every input file has reached the bronze layer - confirm
# this is enough before the Silver layer starts reading.
wait_until_stream_is_ready(orders_bronze_query, 1)


#### Create the Silver Layer: Integrate "Cold-Path" Data & Make Transformations

In [None]:
# Load the reference (dimension) tables saved in the destination database
df_dim_customer = spark.table(f"{dest_database}.dim_customer")
df_dim_product = spark.table(f"{dest_database}.dim_product")
df_dim_vendor = spark.table(f"{dest_database}.dim_vendor")
df_dim_date = spark.table(f"{dest_database}.dim_date")

# Stream the bronze Parquet files and enrich them with the dimensions.
# The customer, product and vendor joins use the default join type, so orders
# without a match in those dimensions are dropped; the date join is a left outer join.
# NOTE(review): assumes the bronze data carries customer_id, product_id, vendor_id
# and order_date_key columns - confirm against the incoming JSON.
df_orders_silver = (
    spark.readStream.format("parquet").load(orders_output_bronze)
    .join(df_dim_customer, "customer_id")  # Join with customer dimension
    .join(df_dim_product, "product_id")  # Join with product dimension
    .join(df_dim_vendor, "vendor_id")  # Join with vendor dimension
    .join(df_dim_date, df_dim_date.date_key == col("order_date_key"), "left_outer")  # Join with date dimension
    .select(
        col("sales_order_id"),
        col("customer_id"),
        col("product_id"),
        col("vendor_id"),
        col("order_date"),
        col("order_qty"),
        col("line_total"),
        col("total_due"),
        col("date_key").alias("order_date_key")
    )
)

# Check if the DataFrame is streaming (displayed as the cell output)
df_orders_silver.isStreaming


#### Write the Transformed Streaming Data to the Data Lakehouse (Silver Layer)

In [None]:
# The checkpoint lives inside the silver output directory.
orders_checkpoint_silver = os.path.join(orders_output_silver, '_checkpoint')

# Write the enriched orders data (silver layer) to the Data Lakehouse
orders_silver_query = (
    df_orders_silver.writeStream
    .format("parquet")
    .outputMode("append")
    .queryName("orders_silver")
    .trigger(availableNow=True)
    .option("checkpointLocation", orders_checkpoint_silver)
    .option("compression", "snappy")
    .start(orders_output_silver)
)

# The Gold layer reads the silver output directory right after this cell. Block
# until this availableNow query has processed the available data and stopped, so
# the silver files exist before they are read. A failure of the query is raised here.
orders_silver_query.awaitTermination()


####  Define a Query to Create a Business Report (Gold Layer)

In [None]:
# Gold report: number of order rows per month and product category, read as a
# stream from the silver output and joined again with the product and date dimensions.
# NOTE(review): the grouping columns month_of_year, category and month_name are
# not among the columns this notebook selects when building dim_product and
# dim_date - confirm those tables actually provide them.
df_orders_by_product_category_gold = (
    spark.readStream.format("parquet").load(orders_output_silver)
    .join(df_dim_product, "product_id")  # Join with product dimension
    .join(df_dim_date, df_dim_date.date_key == col("order_date_key"))
    .groupBy("month_of_year", "category", "month_name")
    .agg(count("product_id").alias("product_count"))
    .orderBy(asc("month_of_year"), desc("product_count"))
)


#### Write the Aggregated Data to Memory in "Complete" Mode

In [None]:
# Start the gold aggregation as an in-memory table named after the query
# ("fact_orders_by_product_category"), using "complete" output mode.
orders_gold_query = (
    df_orders_by_product_category_gold.writeStream
    .format("memory")
    .outputMode("complete")
    .queryName("fact_orders_by_product_category")
    .start()
)


#### Query the Gold Data from Memory

In [None]:
# Wait until the gold stream has reported at least one batch
wait_until_stream_is_ready(orders_gold_query, 1)

# Query the gold data from the in-memory table created by the gold query
df_fact_orders_by_product_category = spark.sql("SELECT * FROM fact_orders_by_product_category")
df_fact_orders_by_product_category.show()


#### Create the Final Selection

In [None]:
# Final report layout: presentation-friendly column names, ordered by month and
# then by descending product count. month_of_year is used for ordering only and
# is not part of the selected columns.
df_fact_orders_by_product_category_gold_final = df_fact_orders_by_product_category \
    .select(col("month_name").alias("Month"), col("category").alias("Product Category"), col("product_count").alias("Product Count")) \
    .orderBy(asc("month_of_year"), desc("Product Count"))


#### Load the Final Results into a New Table and Display the Results


In [None]:
# Persist the final report as a table in the destination database, replacing any previous version
df_fact_orders_by_product_category_gold_final.write.saveAsTable(f"{dest_database}.fact_orders_by_product_category", mode="overwrite")

# Display the results of the gold layer
spark.sql(f"SELECT * FROM {dest_database}.fact_orders_by_product_category").show()
