In [17]:
import pyspark  

import os
import sys
from pyspark.sql import SparkSession
from pyspark.sql import functions as fn

In [18]:
# Build (or reuse) a Spark session that connects to the cluster master at
# spark://spark-master:7077.
spark = (
    SparkSession.builder
    .appName("SparkAppName")
    .master("spark://spark-master:7077")
    .getOrCreate()
)

# Limit Spark's own log output to ERROR-level messages.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

print("Spark Session initialized successfully.")

Spark Session initialized successfully.


In [19]:
# Location of the input CSV as seen from the Spark session.
csv_path = "/app/data/FULL_STOCKS.csv"

# The first line of the file supplies the column names; column types are
# inferred from the data instead of being declared explicitly.
df = spark.read.csv(
    path=csv_path,
    inferSchema=True,
    header=True,
)

print(f"\nLoaded CSV file: {csv_path}")


Loaded CSV file: /app/data/FULL_STOCKS.csv


In [20]:
# Preview the first 10 rows to sanity-check the parsed columns.
print("\nFirst 10 records:")
df.show(n=10)


First 10 records:
+--------------+----------+-----------+------------+----------------+--------+------------------+-----------------+------------------+---------------------+----------+----------+--------------------+------------+--------------+----------+----------+------------+-----------+-------------+
|transaction_id| timestamp|customer_id|stock_ticker|transaction_type|quantity|average_trade_size|      stock_price|total_trade_amount|customer_account_type|is_weekend|is_holiday|stock_liquidity_tier|stock_sector|stock_industry|day_Friday|day_Monday|day_Thursday|day_Tuesday|day_Wednesday|
+--------------+----------+-----------+------------+----------------+--------+------------------+-----------------+------------------+---------------------+----------+----------+--------------------+------------+--------------+----------+----------+------------+-----------+-------------+
|             1|2023-01-02|       4747|           1|               0|      15|             139.4|4.967956350196401

In [21]:
# Question 1: What is the total trading volume for each stock ticker?
print("\n1. Total trading volume for each stock ticker:")
# One output row per ticker, with the summed quantity exposed as total_volume.
q1_result = (
    df.groupBy("stock_ticker")
    .agg(fn.sum("quantity").alias("total_volume"))
)
q1_result.show()


1. Total trading volume for each stock ticker:
+------------+------------+
|stock_ticker|total_volume|
+------------+------------+
|          12|      152604|
|           1|        4405|
|          13|        1512|
|           6|       28486|
|          16|       35435|
|           3|        2857|
|           5|      611667|
|          19|       49724|
|          15|        3210|
|           9|      125862|
|          17|      153811|
|           4|        6997|
|           8|      314915|
|           7|      104680|
|          10|       32119|
|          11|      428057|
|          14|        3302|
|           2|        8570|
|           0|        3557|
|          18|       34229|
+------------+------------+



In [22]:
# Question 2: What is the average stock price by sector?
print("\n2. Average stock price by sector:")
# NOTE(review): stock_price looks log-transformed in the raw data (the first
# previewed record shows ~4.97), so exp() presumably restores the original
# price scale before averaging — confirm against the preprocessing step.
q2_result = (
    df.groupBy("stock_sector")
    .agg(
        fn.avg(fn.exp(fn.col("stock_price"))).alias("avg_stock_price")
    )
)
q2_result.show()


2. Average stock price by sector:
+------------+------------------+
|stock_sector|   avg_stock_price|
+------------+------------------+
|           1| 101.5326450467328|
|           3|152.00790316478066|
|           4|153.67922533016065|
|           2| 79.92351314619141|
|           0|213.62484690770114|
+------------+------------------+



In [23]:
# Question 3: How many buy vs sell transactions occurred on weekends?
print("\n3. Buy vs Sell transactions on weekends:")
# Keep only rows whose is_weekend flag equals 1, then count transaction ids
# per transaction_type value.
q3_result = (
    df.filter(fn.col("is_weekend") == 1)
    .groupBy("transaction_type")
    .agg(fn.count("transaction_id").alias("transaction_count"))
)
q3_result.show()


3. Buy vs Sell transactions on weekends:
+----------------+-----------------+
|transaction_type|transaction_count|
+----------------+-----------------+
+----------------+-----------------+



In [24]:
# Question 4: Which customers have made more than 10 transactions?
print("\n4. Customers with more than 10 transactions:")
# Count transaction ids per customer and keep only customers whose count
# is strictly greater than 10.
q4_result = (
    df.groupBy("customer_id")
    .agg(fn.count("transaction_id").alias("transaction_count"))
    .filter(fn.col("transaction_count") > 10)
)
# Report how many customers qualify, then display the default show() sample.
print(f"Total customers with >10 transactions: {q4_result.count()}")
q4_result.show()


4. Customers with more than 10 transactions:
Total customers with >10 transactions: 74
+-----------+-----------------+
|customer_id|transaction_count|
+-----------+-----------------+
|       4519|              284|
|       1903|               11|
|       1157|              227|
|       4697|               17|
|       3087|               12|
|        193|              419|
|       1243|               13|
|       1816|               86|
|       4700|               63|
|       2871|               29|
|       2750|               13|
|        192|               41|
|       4354|               15|
|        336|               20|
|       2920|               19|
|        319|               17|
|       3498|               39|
|       4987|               14|
|        363|               15|
|        182|              162|
+-----------+-----------------+
only showing top 20 rows



In [25]:
# Question 5: What is the total trade amount per day of the week, ordered from highest to lowest?
print("\n5. Total trade amount per day of the week (highest to lowest):")

# One-hot encoded weekday flag columns; the day label is the column name
# without its "day_" prefix.
day_cols = ["day_Monday", "day_Tuesday", "day_Wednesday", "day_Thursday", "day_Friday"]

# Compute all per-day totals in a single aggregation pass instead of running
# one filter + aggregate + collect job per weekday column. when() without an
# otherwise() yields NULL for rows whose flag is not 1, and sum() skips NULLs,
# so each aggregate only adds up the rows belonging to that day.
day_totals_row = df.agg(*[
    fn.sum(
        fn.when(fn.col(day_col) == 1, fn.col("total_trade_amount"))
    ).alias(day_col)
    for day_col in day_cols
]).collect()[0]

# A day with no matching rows aggregates to NULL (None); report it as 0.0.
sums = [
    (
        day_col.replace("day_", ""),
        float(day_totals_row[day_col]) if day_totals_row[day_col] is not None else 0.0,
    )
    for day_col in day_cols
]

# Create a small Spark DataFrame from the totals and order it by amount, descending.
q5_result = (
    spark.createDataFrame(sums, ["day", "total_trade_amount"])
    .orderBy(fn.desc("total_trade_amount"))
)

q5_result.show()


5. Total trade amount per day of the week (highest to lowest):
+---------+--------------------+
|      day|  total_trade_amount|
+---------+--------------------+
| Thursday|6.0742740599982046E7|
|Wednesday| 5.930424152755955E7|
|   Monday|5.8715772269690834E7|
|   Friday| 5.809313014764415E7|
|  Tuesday|5.2834481894753695E7|
+---------+--------------------+



## Spark SQL Analysis Questions

In [26]:
# Make the DataFrame queryable from Spark SQL under the view name "trades",
# replacing any temporary view that already uses that name.
df.createOrReplaceTempView(name="trades")

In [27]:
# SQL Question 1: What are the top 5 most traded stock tickers by total quantity?
print("SQL 1. Top 5 most traded stock tickers by total quantity:")
# Sum quantity per ticker, sort descending, and keep the five largest totals.
sql1_query = """
    SELECT stock_ticker, 
           SUM(quantity) as total_quantity
    FROM trades
    GROUP BY stock_ticker
    ORDER BY total_quantity DESC
    LIMIT 5
"""
sql1_result = spark.sql(sql1_query)
sql1_result.show()

SQL 1. Top 5 most traded stock tickers by total quantity:
+------------+--------------+
|stock_ticker|total_quantity|
+------------+--------------+
|           5|        611667|
|          11|        428057|
|           8|        314915|
|          17|        153811|
|          12|        152604|
+------------+--------------+



In [28]:
# SQL Question 2: What is the average trade amount by customer account type?
print("\nSQL 2. Average trade amount by customer account type:")
# Rows are grouped on the raw customer_account_type code; the CASE only
# relabels each group for display (0 -> 'Institutional', anything else -> 'Retail').
# NOTE(review): presumably the column holds just two codes — confirm against
# the encoding lookup, otherwise several groups would share the 'Retail' label.
sql2_query = """
    SELECT 
        CASE customer_account_type
            WHEN 0 THEN 'Institutional'
            ELSE 'Retail'
        END as account_type,
        AVG(total_trade_amount) as avg_trade_amount
    FROM trades
    GROUP BY customer_account_type
"""
sql2_result = spark.sql(sql2_query)
sql2_result.show()


SQL 2. Average trade amount by customer account type:
+-------------+------------------+
| account_type|  avg_trade_amount|
+-------------+------------------+
|       Retail|29253.766554710597|
|Institutional|26043.733739066804|
+-------------+------------------+



In [29]:
# SQL Question 3: How many transactions occurred during holidays vs non-holidays?
print("SQL 3. Transactions during holidays vs non-holidays:")
# Group on the is_holiday flag and count transaction ids per group; the CASE
# labels flag value 1 as 'Holiday' and every other value as 'Non-Holiday'.
sql3_query = """
    SELECT 
        CASE is_holiday
            WHEN 1 THEN 'Holiday'
            ELSE 'Non-Holiday'
        END as period_type,
        COUNT(transaction_id) as transaction_count
    FROM trades
    GROUP BY is_holiday
"""
sql3_result = spark.sql(sql3_query)
sql3_result.show()

SQL 3. Transactions during holidays vs non-holidays:
+-----------+-----------------+
|period_type|transaction_count|
+-----------+-----------------+
|    Holiday|              180|
|Non-Holiday|             9820|
+-----------+-----------------+



In [30]:
# SQL Question 4: Which stock sectors had the highest total trading volume on weekends?
print("\nSQL 4. Stock sectors with highest total trading volume on weekends:")
# Restrict to rows with is_weekend = 1, sum quantity per sector, and list the
# sectors from largest to smallest total.
sql4_query = """
    SELECT stock_sector,
           SUM(quantity) as total_volume
    FROM trades
    WHERE is_weekend = 1
    GROUP BY stock_sector
    ORDER BY total_volume DESC
"""
sql4_result = spark.sql(sql4_query)
sql4_result.show()


SQL 4. Stock sectors with highest total trading volume on weekends:
+------------+------------+
|stock_sector|total_volume|
+------------+------------+
+------------+------------+



In [31]:
# SQL Question 5: What is the total buy vs sell amount for each stock liquidity tier?
print("\nSQL 5. Total buy vs sell amount for each stock liquidity tier:")
# Sum total_trade_amount per (liquidity tier, transaction type) pair; the CASE
# labels transaction_type 0 as 'BUY' and every other value as 'SELL'.
# NOTE(review): the output alias reuses the source column name transaction_type;
# the GROUP BY appears to resolve to the raw column rather than the alias — confirm.
sql5_query = """
    SELECT stock_liquidity_tier,
           CASE transaction_type
               WHEN 0 THEN 'BUY'
                ELSE  'SELL'
           END as transaction_type,
           SUM(total_trade_amount) as total_amount
    FROM trades
    GROUP BY stock_liquidity_tier, transaction_type
"""
sql5_result = spark.sql(sql5_query)
sql5_result.show()


SQL 5. Total buy vs sell amount for each stock liquidity tier:
+--------------------+----------------+--------------------+
|stock_liquidity_tier|transaction_type|        total_amount|
+--------------------+----------------+--------------------+
|                High|            SELL| 6.118042233749355E7|
|                High|             BUY|1.6443413419306594E8|
|                 Low|             BUY|   2599141.018825055|
|                 Mid|            SELL|1.5581928747374471E7|
|                 Mid|             BUY| 4.549252643164398E7|
|                 Low|            SELL|  402213.71122708963|
+--------------------+----------------+--------------------+

