In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when
from pyspark.sql.types import StructType,ArrayType,StructField,StringType
from snowflake.snowpark import Session

In [0]:
# Locations of the Region A and Region B order workbooks loaded by read_files().
region_a_path = "/FileStore/tables/order_region_a.xlsx"
region_b_path = "/FileStore/tables/order_region_b.xlsx"


In [0]:
def flatten_struct(schema):
    """Build a select list that flattens every nested struct in ``schema``.

    The schema is walked depth-first in field order. Each field whose type is
    not a ``StructType`` becomes one column, addressed by its dotted path and
    aliased to that path with the dots replaced by underscores
    (``a.b.c`` -> ``a_b_c``). Only direct struct nesting is followed; fields
    of any other type are emitted as they are.

    Args:
        schema: A ``StructType`` such as ``df.schema``.

    Returns:
        A list of aliased columns that can be passed to ``df.select(...)``.
    """
    def leaf_paths(struct, parent=""):
        # Yield the dotted path of every non-struct field, in declaration order.
        for field in struct.fields:
            dotted = f'{parent}.{field.name}' if parent else f'{field.name}'
            if isinstance(field.dataType, StructType):
                yield from leaf_paths(field.dataType, dotted)
            else:
                yield dotted

    return [col(name).alias(name.replace('.', '_')) for name in leaf_paths(schema)]


In [0]:
from pyspark.sql.functions import from_json, col

def flatten_drop(df):
  """Parse the JSON ``PromotionDiscount`` column and flatten it into columns.

  The column is decoded with ``from_json`` into a struct with two string
  fields, ``CurrencyCode`` and ``Amount``. The frame is then flattened with
  ``flatten_struct``, so those fields surface as
  ``PromotionDiscount_CurrencyCode`` and ``PromotionDiscount_Amount``.

  Args:
      df: DataFrame containing a ``PromotionDiscount`` column.

  Returns:
      The flattened DataFrame, with any column still named
      ``PromotionDiscount`` dropped.
  """
  discount_schema = StructType([
      StructField("CurrencyCode", StringType(), True),
      StructField("Amount", StringType(), True),
  ])
  discount_struct = from_json(col("PromotionDiscount"), discount_schema)
  parsed_df = df.withColumn("PromotionDiscount", discount_struct)
  flat_columns = flatten_struct(parsed_df.schema)
  return parsed_df.select(flat_columns).drop('PromotionDiscount')

In [0]:
def create_spark_session():
    """Return a SparkSession named ``sales_etl`` that requests spark-excel.

    The spark-excel Maven coordinate is supplied through
    ``spark.jars.packages`` so the ``com.crealytics.spark.excel`` reader used
    by ``read_files`` is available.

    Returns:
        The session produced by ``SparkSession.builder...getOrCreate()``.
    """
    builder = SparkSession.builder.appName("sales_etl")
    builder = builder.config(
        "spark.jars.packages",
        "com.crealytics:spark-excel-2.12.17-3.2.2_2.12:3.2.2_0.18.1",
    )
    # NOTE(review): getOrCreate() hands back an already-running session when
    # one exists; confirm the package setting takes effect in that case.
    return builder.getOrCreate()

def _read_region(spark, path, region):
    """Load one region's Excel workbook, tag it with ``region`` and flatten it."""
    raw_df = (
        spark.read
        .format("com.crealytics.spark.excel")
        .option("header", "true")
        .option("inferSchema", "true")
        .load(path)
    )
    # Tag rows with their region before flattening so the literal column is
    # carried through the flattened select as well.
    return flatten_drop(raw_df.withColumn('region', lit(region)))


def read_files(spark):
    """Read the Region A and Region B order workbooks.

    Both files are read with the spark-excel reader (first row as header,
    schema inferred), given a literal ``region`` column (``'A'`` or ``'B'``)
    and passed through ``flatten_drop``.

    Args:
        spark: Active SparkSession.

    Returns:
        A tuple ``(df_region_a, df_region_b)`` of flattened DataFrames.
    """
    df_region_a = _read_region(spark, region_a_path, 'A')
    df_region_b = _read_region(spark, region_b_path, 'B')
    return (df_region_a, df_region_b)

def transform_data(df_region_a, df_region_b):
    """Combine both regions and derive the sales measures.

    Steps:
      1. Union the two regional frames, matching columns by name.
      2. ``total_sales = QuantityOrdered * ItemPrice``.
      3. ``net_sale = total_sales - discount``, where the discount is
         ``PromotionDiscount_Amount`` cast to double and a missing (NULL)
         amount counts as 0.
      4. Keep one row per ``(OrderId, region)``.
      5. Keep only rows whose ``net_sale`` is greater than 0.

    Args:
        df_region_a: Flattened Region A DataFrame.
        df_region_b: Flattened Region B DataFrame.

    Returns:
        The combined DataFrame with ``total_sales`` and ``net_sale`` added.
    """
    # Local import: coalesce is not among the names the notebook imports up top.
    from pyspark.sql.functions import coalesce

    # unionByName aligns columns by name; plain union() is positional and
    # would silently mix columns if the two workbooks ordered them differently.
    combined_df = df_region_a.unionByName(df_region_b)

    # The amount is parsed as a string, so cast it explicitly. Without the
    # coalesce, a NULL amount makes net_sale NULL and the row is then removed
    # by the `net_sale > 0` filter even though it has sales.
    discount = coalesce(col("PromotionDiscount_Amount").cast("double"), lit(0.0))

    # NOTE(review): which of several rows sharing (OrderId, region) survives
    # dropDuplicates is not controlled here - confirm that is acceptable.
    transformed_df = (
        combined_df
        .withColumn("total_sales", col("QuantityOrdered") * col("ItemPrice"))
        .withColumn("net_sale", col("total_sales") - discount)
        .dropDuplicates(["OrderId", "region"])
        .where(col("net_sale") > 0)
    )

    return transformed_df

def create_snowflake_connection(snowflake_configs):
    """Open a Snowpark session from the given connection settings.

    Args:
        snowflake_configs: Settings passed unchanged to
            ``Session.builder.configs``.

    Returns:
        The object returned by ``Session.builder.configs(...).create()``.
    """
    session_builder = Session.builder.configs(snowflake_configs)
    return session_builder.create()

def upload_to_snowflake(df,session, table_name):
    """Write a Spark DataFrame to a Snowflake table, replacing the table.

    The Spark DataFrame is first collected into a pandas DataFrame, which is
    then written through the Snowpark session's ``write_pandas``. The earlier
    ``pandas.DataFrame.to_sql(con=session.connection)`` call could not work:
    ``to_sql`` expects a SQLAlchemy connectable or a sqlite3 connection, not a
    raw Snowflake connection.

    Args:
        df: Spark DataFrame to upload; it must fit in driver memory because
            ``toPandas()`` collects it.
        session: Snowpark ``Session`` (see ``create_snowflake_connection``).
        table_name: Name of the destination table.

    Returns:
        None.
    """
    # Convert DataFrame to Pandas DataFrame
    pandas_df = df.toPandas()

    # auto_create_table + overwrite reproduces the former if_exists='replace'
    # intent: the table is created when absent and replaced when present.
    # NOTE(review): write_pandas quotes identifiers by default, so table and
    # column names keep their exact case - confirm that is what consumers expect.
    session.write_pandas(
        pandas_df,
        table_name,
        auto_create_table=True,
        overwrite=True,
    )

def load_to_sf(df):
    """Expose ``df`` to SQL as the temporary view ``sales_data``.

    Despite the name, nothing is sent to Snowflake here: the function only
    creates (or replaces) a Spark temp view.

    Args:
        df: DataFrame to register.
    """
    view_name = "sales_data"
    df.createOrReplaceTempView(view_name)




In [0]:
# Create (or reuse) the Spark session used by every step below.
spark = create_spark_session()

# Extract: read both regional workbooks into flattened DataFrames.
df_region_a, df_region_b = read_files(spark)

# Transform: combine the regions and derive total_sales / net_sale.
transformed_df = transform_data(df_region_a, df_region_b)

# Load: register the result as the temp view queried by the %sql cells.
load_to_sf(transformed_df)

In [0]:
# Preview the flattened Region A rows returned by read_files().
display(df_region_a)

OrderId,OrderItemId,QuantityOrdered,ItemPrice,PromotionDiscount_CurrencyCode,PromotionDiscount_Amount,batch_id,region
171-0001135-1657958,11168926687715.0,1.0,949.0,INR,10.0,359.0,A
171-0001497-9165123,19760298917699.0,1.0,699.0,INR,10.1,1135.0,A
171-0002127-1363507,5949764099083.0,1.0,399.0,INR,10.0,297.0,A
171-0002370-0601169,57571868836379.0,1.0,499.0,INR,10.1,114.0,A
171-0004526-2028348,33851287891403.0,1.0,1699.0,INR,10.0,764.0,A
171-0004781-3853173,43686103544491.0,1.0,399.0,INR,10.1,809.0,A
171-0004947-4305927,15941372058555.0,1.0,1399.0,INR,10.0,15.0,A
171-0004947-4305927,15941372058555.0,1.0,1399.0,INR,10.1,330.0,A
171-0005467-8036365,33952397753619.0,1.0,349.0,INR,10.0,868.0,A
171-0006030-2254725,31456208605443.0,1.0,499.0,INR,10.1,1494.0,A


In [0]:
# Preview the flattened Region B rows returned by read_files().
display(df_region_b)

OrderId,OrderItemId,QuantityOrdered,ItemPrice,PromotionDiscount_CurrencyCode,PromotionDiscount_Amount,batch_id,region
171-0001135-1657958,11168926687715.0,1.0,949.0,INR,10.0,359.0,B
171-0001497-9165123,19760298917699.0,1.0,699.0,INR,10.1,1135.0,B
171-0002127-1363507,5949764099083.0,1.0,399.0,INR,10.0,297.0,B
171-0002370-0601169,57571868836379.0,1.0,499.0,INR,10.1,114.0,B
171-0004526-2028348,33851287891403.0,1.0,1699.0,INR,10.0,764.0,B
171-0004781-3853173,43686103544491.0,1.0,399.0,INR,10.1,809.0,B
171-0004947-4305927,15941372058555.0,1.0,1399.0,INR,10.0,15.0,B
171-0004947-4305927,15941372058555.0,1.0,1399.0,INR,10.1,330.0,B
171-0005467-8036365,33952397753619.0,1.0,349.0,INR,10.0,868.0,B
171-0006030-2254725,31456208605443.0,1.0,499.0,INR,10.1,1494.0,B


In [0]:
# Preview the combined frame, including the derived total_sales and net_sale columns.
display(transformed_df)

OrderId,OrderItemId,QuantityOrdered,ItemPrice,PromotionDiscount_CurrencyCode,PromotionDiscount_Amount,batch_id,region,total_sales,net_sale
171-0001135-1657958,11168926687715.0,1.0,949.0,INR,10.0,359.0,A,949.0,939.0
171-0004781-3853173,43686103544491.0,1.0,399.0,INR,10.1,809.0,B,399.0,388.9
171-0008662-0057942,66326653881907.0,1.0,599.0,INR,10.1,122.0,A,599.0,588.9
171-0010322-1769977,47459215833307.0,1.0,399.0,INR,10.0,676.0,B,399.0,389.0
171-0010803-5365973,21135179938611.0,1.0,299.0,INR,10.1,669.0,B,299.0,288.9
171-0015668-5065935,8918074736283.0,1.0,699.0,INR,10.0,709.0,A,699.0,689.0
171-0016656-1255573,21577356095867.0,1.0,499.0,INR,10.1,695.0,A,499.0,488.9
171-0018345-8874731,59365948515931.0,1.0,599.0,INR,10.0,150.0,A,599.0,589.0
171-0022691-4700335,60239001823435.0,1.0,499.0,INR,10.0,64.0,A,499.0,489.0
171-0023912-2838777,12927344177755.0,1.0,499.0,INR,10.1,910.0,A,499.0,488.9


Validation

In [0]:
def validate_data(df):
    """Print basic sanity checks for the transformed sales data.

    Reports, in order: the total row count, the sum of ``total_sales`` per
    region, the average ``total_sales``, and whether any
    ``(OrderId, region)`` pair occurs more than once.

    Args:
        df: DataFrame with ``OrderId``, ``region`` and ``total_sales`` columns.

    Returns:
        None. All results are printed or shown.
    """
    # Total records
    record_total = df.count()
    print(f"Total Records: {record_total}")

    # Sales by region
    sales_by_region = df.groupBy("region").sum("total_sales")
    sales_by_region.show()

    # Average sales
    average_sales = df.agg({"total_sales": "avg"})
    average_sales.show()

    # Duplicate check: any (OrderId, region) key seen more than once
    repeated_keys = df.groupBy(["OrderId","region"]).count().where(col("count") > 1)
    has_repeats = repeated_keys.count() > 0

    print("Warning: Duplicate OrderIds found!" if has_repeats else "No duplicates found.")

In [0]:
# Run the sanity checks against the transformed frame.
validate_data(transformed_df)

Total Records: 82104
+------+-------------------+
|region|   sum(total_sales)|
+------+-------------------+
|     B|3.457098452000002E7|
|     A|      3.457098452E7|
+------+-------------------+

+-----------------+
| avg(total_sales)|
+-----------------+
|842.1266812822765|
+-----------------+

No duplicates found.


In [0]:
%sql
-- Row count of the sales_data temp view; cross-check against the
-- "Total Records" figure printed by validate_data().
SELECT COUNT(*) as total_records 
FROM sales_data;


total_records
82104


In [0]:
%sql
-- Sum of total_sales for each region in the sales_data temp view
SELECT region, 
       SUM(total_sales) as total_sales_amount
FROM sales_data
GROUP BY region;


region,total_sales_amount
B,34570984.52000002
A,34570984.52


In [0]:
%sql
-- Average total_sales per row of the sales_data temp view
SELECT AVG(total_sales) as avg_sales_per_transaction
FROM sales_data;


avg_sales_per_transaction
842.1266812822765


In [0]:
%sql
-- List (OrderId, region) pairs that occur more than once;
-- an empty result means there are no duplicates.
SELECT OrderId,region, COUNT(*) as count
FROM sales_data
GROUP BY OrderId,region
HAVING count > 1;

OrderId,region,count
