Importing Kaggle Dataset

In [2]:
from kagglehub import dataset_download
#remove warnings
import warnings
warnings.filterwarnings("ignore")

from pyspark.sql import SparkSession
import pyspark.sql.functions as sf
from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType, TimestampType, LongType
import re

path: str = dataset_download("jinquan/cc-sample-data")

print(path)

spark = SparkSession.builder.appName("payNet").config('spark.driver.memory', '8g').getOrCreate()


  from .autonotebook import tqdm as notebook_tqdm


/home/jeevin/.cache/kagglehub/datasets/jinquan/cc-sample-data/versions/1


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/07/22 00:57:28 WARN Utils: Your hostname, jeevin, resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/07/22 00:57:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/22 00:57:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Load JSON and clean up Data

In [3]:
# Read the JSON data from the file
df = spark.read.json(path)

def clean_json_string(json_str):
    """
    Clean JSON string by:
    1. Removing all backslashes
    2. Removing quotes around JSON objects (e.g., "{ }" becomes { })
    """
    if json_str is None:
        return None
    
    # Remove all backslashes
    cleaned = json_str.replace("\\", "")
    
    # Remove quotes around JSON objects - pattern: "{ ... }"
    # This regex finds quoted JSON objects and removes the outer quotes
    cleaned = re.sub(r'"\s*\{\s*(.*?)\s*\}\s*"', r'{\1}', cleaned)
    
    return cleaned

# Register UDF
clean_json_udf = sf.udf(clean_json_string, StringType())

# Apply cleaning to the personal_detail column
df = df.withColumn("personal_detail", clean_json_udf(sf.col("personal_detail")))

# Define schema for address (nested within personal_detail) - all strings initially
address_schema = StructType([
    StructField("street", StringType(), True),
    StructField("city", StringType(), True),
    StructField("state", StringType(), True),
    StructField("zip", StringType(), True)
])

# Define schema for personal_detail - all strings initially
personal_schema = StructType([
    StructField("person_name", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("address", address_schema, True),
    StructField("lat", StringType(), True),
    StructField("long", StringType(), True),
    StructField("city_pop", StringType(), True),
    StructField("job", StringType(), True),
    StructField("dob", StringType(), True)
])

# Parse the cleaned JSON string into proper columns (overwrite the original column)
df_with_parsed_personal = df.withColumn("personal_detail", sf.from_json(sf.col("personal_detail"), personal_schema))

# for debugging
# df_with_parsed_personal.select(col("personal_detail.person_name")).show()
# df_with_parsed_personal.select(col("personal_detail.address.street")).show()

# Process name splitting for 'first' and 'last' names with improved robustness.
df_with_names = df_with_parsed_personal.withColumn(
    "cleaned_person_name",
    sf.when(sf.col("personal_detail.person_name").isNotNull(),
        sf.trim(
            sf.regexp_replace( # Normalize multiple spaces to single space
                sf.regexp_replace( # Replace all non-alphanumeric characters (except spaces) with a single space
                    sf.regexp_replace( # Remove specific trailing strings like 'eeeee' and 'N' followed by 4 or more '0' or 'O' (case-insensitive)
                        sf.regexp_replace(sf.col("personal_detail.person_name"), r"(?i),?eeeee$", ""),
                        r"(?i),?\s*N[0O]{4,}$", "" # Updated regex to handle N0000, NOOOO etc.
                    ),
                    r"[^a-zA-Z0-9\s]", " " # Replace any character that is NOT a letter, number, or whitespace with a space. This will catch /, !, @, |, and also the comma if it's not part of a "Lastname, Firstname" pattern.
                ),
                r"\s+", " " # Normalize multiple spaces to single space
            )
        )
    ).otherwise(None)
)

df_with_names = df_with_names \
    .withColumn("name_parts", sf.split(sf.col("cleaned_person_name"), " ")) \
    .withColumn("first", 
        sf.when(sf.size(sf.col("name_parts")) >= 1, sf.trim(sf.element_at(sf.col("name_parts"), 1)))
        .otherwise(None)
    ) \
    .withColumn("last", 
        sf.when(sf.size(sf.col("name_parts")) > 1, 
                sf.trim(sf.concat_ws(" ", sf.slice(sf.col("name_parts"), 2, sf.size(sf.col("name_parts"))))))
        .otherwise(None)
    ) \
    .drop("cleaned_person_name", "name_parts") # Drop intermediate columns


# Flatten the personal_detail structure and address structure
df_flattened = df_with_names.select(
    # Original columns in desired order
    sf.col("Unnamed: 0"),
    sf.col("trans_date_trans_time"),
    sf.col("cc_num"),
    sf.col("merchant"),
    sf.col("category"),
    sf.col("amt"),
    
    sf.col("first"),
    sf.col("last"),

    # Personal details
    sf.col("personal_detail.gender").alias("gender"),
    
    # Flattened address details
    sf.col("personal_detail.address.street").alias("street"),
    sf.col("personal_detail.address.city").alias("city"),
    sf.col("personal_detail.address.state").alias("state"),
    sf.col("personal_detail.address.zip").alias("zip"),
    
    # Location and demographic info
    sf.col("personal_detail.lat").alias("lat"),
    sf.col("personal_detail.long").alias("long"),
    sf.col("personal_detail.city_pop").alias("city_pop"),
    sf.col("personal_detail.job").alias("job"),
    sf.col("personal_detail.dob").alias("dob"),
    
    # Transaction details
    sf.col("trans_num"),
    sf.col("merch_lat"),
    sf.col("merch_long"),
    sf.col("is_fraud"),
    sf.col("merch_zipcode"),
    sf.col("merch_last_update_time"),
    sf.col("merch_eff_time"),
    sf.col("cc_bic")
)


# Type conversions and rounding in one operation, including date format and timezone handling. Assuming initial timzeone is UTC and converting to UTC
df_cleaned = df_flattened.withColumns({
    'Unnamed: 0': sf.col("Unnamed: 0").cast(IntegerType()),


    # Convert trans_date_trans_time to TimestampType, then to UTC+8, then format
    'trans_date_trans_time': sf.date_format(sf.from_utc_timestamp(sf.col("trans_date_trans_time").cast("timestamp"), "UTC+8"), 'yyyy-MM-dd HH:mm:ss Z'),


    'amt': sf.round(sf.col("amt").cast(FloatType()), 6),
    'merch_lat': sf.round(sf.col("merch_lat").cast(FloatType()), 6),
    'merch_long': sf.round(sf.col("merch_long").cast(FloatType()), 6),
    'is_fraud': sf.col("is_fraud").cast(IntegerType()),


    # Convert merch_eff_time (microseconds) to TimestampType, then to UTC+8, then format
 
    'merch_eff_time': sf.date_format(
        sf.from_utc_timestamp(
            (
                sf.rpad(
                    sf.col("merch_eff_time").cast(StringType()),
                    16,
                    '0'
                ).cast(LongType()) / 1000000
            ).cast("timestamp"),
            "UTC+8"
        ),
        'yyyy-MM-dd HH:mm:ss.SSSSSS Z'
    ),


    # Convert merch_last_update_time (microseconds) to TimestampType, then to UTC+8, then format
    'merch_last_update_time': sf.date_format(
        sf.from_utc_timestamp(
            (
                sf.rpad(
                    sf.col("merch_last_update_time").cast(StringType()),
                    16,
                    '0'
                ).cast(LongType()) / 1000000
            ).cast("timestamp"),
            "UTC+8"
            ),
            'yyyy-MM-dd HH:mm:ss.SSSSSS Z'
        ), 
    
    'lat': sf.round(sf.col("lat").cast(FloatType()), 6),
    'long': sf.round(sf.col("long").cast(FloatType()), 6),
    'city_pop': sf.col("city_pop").cast(IntegerType())
})


# Handle null values and "NA" strings for all string columns automatically
string_columns = [field.name for field in df_cleaned.schema.fields if field.dataType.typeName() == 'string']

# Create dictionary for null value handling across all string columns
null_handling_dict = {}
for col_name in string_columns:
    null_handling_dict[col_name] = sf.when(
        (sf.lower(sf.col(col_name)) == "na") | 
        (sf.lower(sf.col(col_name)) == "null") | 
        (sf.col(col_name) == ""), 
        None
    ).otherwise(sf.col(col_name))

df_cleaned = df_cleaned.withColumns(null_handling_dict)

#all merchant names start with the word 'fraud_'; safe to remove
df_cleaned = df_cleaned.withColumn(
    "merchant",
    sf.regexp_replace(sf.col("merchant"), "fraud_", "")
)

## Display cleaned data 

# Show final result
df_cleaned.show(40,truncate=False)

#convert subsection to markdown for viewing
# print(df_cleaned.filter(sf.col("Unnamed: 0") <= 25).toPandas().to_markdown())

# Show schema to verify structure
df_cleaned.printSchema()

25/07/22 00:57:36 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 1:>                                                          (0 + 1) / 1]

+----------+-------------------------+-------------------+----------------------------------+-------------+------+-----------+----------+------+------------------------------+------------------------+-----+-----+-------+---------+--------+---------------------------------------------+----------+--------------------------------+---------+-----------+--------+-------------+--------------------------------+--------------------------------+-----------+
|Unnamed: 0|trans_date_trans_time    |cc_num             |merchant                          |category     |amt   |first      |last      |gender|street                        |city                    |state|zip  |lat    |long     |city_pop|job                                          |dob       |trans_num                       |merch_lat|merch_long |is_fraud|merch_zipcode|merch_last_update_time          |merch_eff_time                  |cc_bic     |
+----------+-------------------------+-------------------+----------------------------------+-

                                                                                

Handling PII data

In [4]:
# Direct Identifiers
# cc_num, first, last, trans_num, dob

# Indirect Identifiers
# city, state, zip, street, job, lat, long, cc_bic

# Indentifiers to delete : 
identifier_delete = [
    # not needed for analysis
    "cc_num",
    "first",
    "last",
    "trans_num",
    "cc_bic",
    # specific location not required
    "street",
]

# reducing specificity of location by rounding latitude and longitude
df_cleaned = df_cleaned.withColumn("lat", sf.round(sf.col("lat"), 2))
df_cleaned = df_cleaned.withColumn("long", sf.round(sf.col("long"), 2))

# remove columns 
df_cleaned = df_cleaned.drop(*identifier_delete)

# keep only the year (first 4 digits) from dob
df_cleaned = df_cleaned.withColumn("dob", sf.substring(sf.col("dob"), 1, 4))

df_cleaned.show(40,truncate=False)
print(df_cleaned.tail(40))

df_cleaned.printSchema()

+----------+-------------------------+----------------------------------+-------------+------+------+------------------------+-----+-----+-----+-------+--------+---------------------------------------------+----+---------+-----------+--------+-------------+--------------------------------+--------------------------------+
|Unnamed: 0|trans_date_trans_time    |merchant                          |category     |amt   |gender|city                    |state|zip  |lat  |long   |city_pop|job                                          |dob |merch_lat|merch_long |is_fraud|merch_zipcode|merch_last_update_time          |merch_eff_time                  |
+----------+-------------------------+----------------------------------+-------------+------+------+------------------------+-----+-----+-----+-------+--------+---------------------------------------------+----+---------+-----------+--------+-------------+--------------------------------+--------------------------------+
|0         |2019-01-01 08:00

[Stage 3:>                                                          (0 + 1) / 1]

[Row(Unnamed: 0=1296635, trans_date_trans_time='2020-06-21 19:55:55 +0800', merchant='Berge LLC', category='gas_transport', amt=80.44999694824219, gender='M', city='Sontag', state='MS', zip='39665', lat=31.649999618530273, long=-90.18000030517578, city_pop=1196, job='Librarian, academic', dob='1958', merch_lat=30.870750427246094, merch_long=-89.87936401367188, is_fraud=0, merch_zipcode='70427', merch_last_update_time='2013-06-22 03:55:55.907000 +0800', merch_eff_time='2013-06-22 03:55:55.149813 +0800'), Row(Unnamed: 0=1296636, trans_date_trans_time='2020-06-21 19:55:56 +0800', merchant='Lockman, West and Runte', category='grocery_pos', amt=98.9800033569336, gender='M', city='Bridger', state='MT', zip='59014', lat=45.290000915527344, long=-108.91000366210938, city_pop=1446, job='Chartered loss adjuster', dob='1978', merch_lat=45.25511932373047, merch_long=-108.9658203125, is_fraud=0, merch_zipcode=None, merch_last_update_time='2013-06-22 03:55:56.628000 +0800', merch_eff_time='2013-06-2

                                                                                

Plots - Pyspark now has native plotting without needing to convert to pandas. working on a plotly backend. 

In [5]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import math

KDE plot

In [None]:
# Calculate min and max for the legend
min_amt = df_cleaned.agg(sf.min("amt")).collect()[0][0]
max_amt = df_cleaned.agg(sf.max("amt")).collect()[0][0]

print(f'Min: ${min_amt:.2f} \nMax: ${max_amt:.2f}')

kde_plot = df_cleaned.plot.kde(column = "amt", bw_method=.8, template="plotly_white")
kde_plot.update_layout(title="Distribution of Transaction Amounts", xaxis_title="Transaction Amount (USD)", yaxis_title="Density")
kde_plot.update_xaxes(minor=dict(ticks='inside', ticklen=5,showgrid=True), range=[0, math.log(max_amt,10)], type="log")

kde_plot.show()


Percentage transaction by category

In [None]:
# Calculate total count for percentage calculation
total_count = df_cleaned.count()

# Aggregate data to get counts per category and convert to percentage
category_counts_df = df_cleaned.groupBy("category").count() \
    .withColumn("percentage", (sf.col("count") / total_count) * 100) \
    .orderBy(sf.desc("percentage")) \
    .drop("count") # Drop the raw count column as we now have percentage

# Generate the bar plot using PySpark's native plotting
# Map 'category' to x-axis and 'percentage' to y-axis.
bar_plot = category_counts_df.plot.bar(
    x="category",
    y="percentage",
    title="Percentage of Transactions by Category",
    template="plotly_white",    
)

bar_plot.update_yaxes(
    title_text="Percentage",
)

# Display the plot
bar_plot.show()


Transaction amount by fraud status

In [None]:
fraud_count = df_cleaned.filter(sf.col("is_fraud") == 1).count()
print(fraud_count)

percentage_fraud = (fraud_count / total_count) * 100
print(percentage_fraud)

fraud_data = df_cleaned.filter(sf.col("is_fraud") == 1)

max_amt_fraud = fraud_data.agg(sf.max("amt")).collect()[0][0]

#kde plot of fraud data
kde_plot = fraud_data.plot.kde(column = "amt", bw_method=.9, template="plotly_white")
kde_plot.update_layout(title="Distribution of Fraud Transaction Amounts", xaxis_title="Transaction Amount (USD)", yaxis_title="Density")
kde_plot.update_xaxes(range=[0, max_amt_fraud])

kde_plot.show()

Transaction amount vs city population

In [None]:
max_amt_city_pop = df_cleaned.agg(sf.max("city_pop")).collect()[0][0]

# df_cleaned_sample = df_cleaned.sample(fraction=0.01)

city_pop_scatter = df_cleaned.plot.scatter(
    x = 'city_pop',
    y = "amt",
)

city_pop_scatter.update_layout(title="Distribution of Fraud Transaction Amounts", xaxis_title="City Population", yaxis_title="Amount Spent")
city_pop_scatter.update_xaxes(range=[0, max_amt_city_pop])

#uncomment to add scatter into the kde plot

city_pop_scatter.add_histogram(
    x=df_cleaned.filter(sf.col("city_pop")),
    y=df_cleaned.filter(sf.col("amt")),
    mode="markers",
    name="transaction amount scatter",
)

city_pop_scatter.show()

In [6]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pyspark.sql.functions as sf
from pyspark.sql.types import *

# US State Population Data (2020 Census)
us_state_population = {
    'state': ['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA', 'HI', 'ID', 'IL', 'IN', 'IA', 'KS', 'KY', 'LA', 'ME', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'MT', 'NE', 'NV', 'NH', 'NJ', 'NM', 'NY', 'NC', 'ND', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT', 'VT', 'VA', 'WA', 'WV', 'WI', 'WY'],
    'state_name': ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 'Connecticut', 'Delaware', 'Florida', 'Georgia', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland', 'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming'],
    'population_2020': [5024279, 733391, 7151502, 3011524, 39538223, 5773714, 3605944, 989948, 21538187, 10711908, 1455271, 1839106, 12812508, 6785528, 3190369, 2937880, 4505836, 4657757, 1395722, 6177224, 7001399, 10037261, 5737915, 2961279, 6196010, 1084225, 1961504, 3104614, 1377529, 9288994, 2117522, 20201249, 10439388, 779094, 11799448, 3959353, 4237256, 13002700, 1097379, 5118425, 886667, 6910840, 29145505, 3271616, 643077, 8631393, 7705281, 1793716, 5893718, 576851]
}

# Create population DataFrame
pop_df = pd.DataFrame(us_state_population)
pop_spark_df = spark.createDataFrame(pop_df)

print("US State Population Data (2020 Census):")
pop_spark_df.show(10)

# Calculate customers per state from your transaction data
customers_by_state = df_cleaned.groupBy("state").agg(
    sf.countDistinct("Unnamed: 0").alias("customer_count"),  # Using transaction ID as proxy for customers
    sf.count("*").alias("total_transactions"),
    sf.avg("amt").alias("avg_transaction_amount"),
    sf.sum("amt").alias("total_spending")
).orderBy(sf.desc("customer_count"))

print("\nTop 10 States by Customer Count:")
customers_by_state.show(10)

# Join with population data
customer_population_joined = customers_by_state.join(
    pop_spark_df, 
    customers_by_state.state == pop_spark_df.state, 
    "inner"
).select(
    customers_by_state.state,
    pop_spark_df.state_name,
    pop_spark_df.population_2020,
    customers_by_state.customer_count,
    customers_by_state.total_transactions,
    customers_by_state.avg_transaction_amount,
    customers_by_state.total_spending
)

# Calculate customers per capita (per 100,000 people)
customer_per_capita = customer_population_joined.withColumn(
    "customers_per_100k",
    (sf.col("customer_count") / sf.col("population_2020")) * 100000
).withColumn(
    "transactions_per_100k", 
    (sf.col("total_transactions") / sf.col("population_2020")) * 100000
).withColumn(
    "spending_per_capita",
    sf.col("total_spending") / sf.col("population_2020")
)

# Convert to Pandas for visualization
analysis_df = customer_per_capita.toPandas()

print(f"\nCustomer Per Capita Analysis:")
print(f"Total states with data: {len(analysis_df)}")
print(f"Highest customers per 100k: {analysis_df['state_name'].iloc[analysis_df['customers_per_100k'].idxmax()]} ({analysis_df['customers_per_100k'].max():.2f})")
print(f"Lowest customers per 100k: {analysis_df['state_name'].iloc[analysis_df['customers_per_100k'].idxmin()]} ({analysis_df['customers_per_100k'].min():.2f})")

# Sort by customers per capita for better visualization
analysis_df = analysis_df.sort_values('customers_per_100k', ascending=False)

# Create subplot figure with multiple visualizations
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=('Customers Per 100k Population by State', 'Customer Count vs State Population', 
                   'Spending Per Capita by State', 'Geographic Distribution'),
    specs=[[{"type": "bar"}, {"type": "scatter"}],
           [{"type": "bar"}, {"type": "geo"}]],
    vertical_spacing=0.12,
    horizontal_spacing=0.1
)

# 1. Bar chart of customers per 100k
fig.add_trace(
    go.Bar(
        x=analysis_df['state'],
        y=analysis_df['customers_per_100k'],
        name='Customers per 100k',
        marker_color='lightblue',
        text=[f"{val:.1f}" for val in analysis_df['customers_per_100k']],
        textposition='outside'
    ),
    row=1, col=1
)

# 2. Scatter plot: Customer count vs Population
fig.add_trace(
    go.Scatter(
        x=analysis_df['population_2020'],
        y=analysis_df['customer_count'],
        mode='markers+text',
        text=analysis_df['state'],
        textposition='top center',
        name='Customer Count vs Population',
        marker=dict(
            size=analysis_df['customers_per_100k']/2,  # Size based on per capita
            color=analysis_df['customers_per_100k'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Customers per 100k")
        )
    ),
    row=1, col=2
)

# 3. Spending per capita bar chart
fig.add_trace(
    go.Bar(
        x=analysis_df['state'],
        y=analysis_df['spending_per_capita'],
        name='Spending per Capita',
        marker_color='lightcoral',
        text=[f"${val:.0f}" for val in analysis_df['spending_per_capita']],
        textposition='outside'
    ),
    row=2, col=1
)

# 4. Geographic choropleth map
fig.add_trace(
    go.Choropleth(
        locations=analysis_df['state'],
        z=analysis_df['customers_per_100k'],
        locationmode='USA-states',
        colorscale='Blues',
        text=analysis_df['state_name'],
        colorbar=dict(title="Customers per 100k", x=1.02),
        name='Geographic Distribution'
    ),
    row=2, col=2
)

# Update layout
fig.update_layout(
    title_text="Customer Distribution Analysis by US State (Per Capita Metrics)",
    title_x=0.5,
    height=1000,
    width=1400,
    showlegend=False,
    font=dict(size=10)
)

# Update geo layout for the map
fig.update_geos(
    scope="usa",
    projection_type="albers usa",
    row=2, col=2
)

# Update axis labels
fig.update_xaxes(title_text="State", row=1, col=1, tickangle=45)
fig.update_yaxes(title_text="Customers per 100k", row=1, col=1)

fig.update_xaxes(title_text="State Population", row=1, col=2, type="log")
fig.update_yaxes(title_text="Customer Count", row=1, col=2, type="log")

fig.update_xaxes(title_text="State", row=2, col=1, tickangle=45)
fig.update_yaxes(title_text="Spending per Capita ($)", row=2, col=1)

# Show the comprehensive analysis
fig.show()

# Print detailed analysis
print(f"\n=== DETAILED ANALYSIS ===")
print(f"\nTop 5 States by Customers per 100k Population:")
for i, row in analysis_df.head().iterrows():
    print(f"{row['state_name']}: {row['customers_per_100k']:.2f} customers per 100k people")

print(f"\nTop 5 States by Spending per Capita:")
spending_sorted = analysis_df.sort_values('spending_per_capita', ascending=False)
for i, row in spending_sorted.head().iterrows():
    print(f"{row['state_name']}: ${row['spending_per_capita']:.2f} per person")

print(f"\nStates with Highest Market Penetration (customers/population ratio):")
for i, row in analysis_df.head(3).iterrows():
    penetration = (row['customer_count'] / row['population_2020']) * 100
    print(f"{row['state_name']}: {penetration:.4f}% of population are customers")

print(f"\nPotentially Underserved Large Markets (High population, Low customers per capita):")
underserved = analysis_df[(analysis_df['population_2020'] > 5000000) & 
                         (analysis_df['customers_per_100k'] < analysis_df['customers_per_100k'].median())]
for i, row in underserved.iterrows():
    print(f"{row['state_name']}: {row['population_2020']:,} people, {row['customers_per_100k']:.2f} customers per 100k")

# Calculate market opportunity
print(f"\n=== MARKET OPPORTUNITY ANALYSIS ===")
avg_customers_per_100k = analysis_df['customers_per_100k'].mean()
print(f"Average customers per 100k across all states: {avg_customers_per_100k:.2f}")

analysis_df['market_gap'] = (avg_customers_per_100k - analysis_df['customers_per_100k']) * analysis_df['population_2020'] / 100000
analysis_df['market_gap'] = analysis_df['market_gap'].clip(lower=0)  # Only positive gaps

potential_markets = analysis_df[analysis_df['market_gap'] > 0].sort_values('market_gap', ascending=False)
print(f"\nTop Expansion Opportunities (if all states reached average penetration):")
for i, row in potential_markets.head(5).iterrows():
    print(f"{row['state_name']}: {row['market_gap']:.0f} additional customers potential")

US State Population Data (2020 Census):


                                                                                

+-----+-----------+---------------+
|state| state_name|population_2020|
+-----+-----------+---------------+
|   AL|    Alabama|        5024279|
|   AK|     Alaska|         733391|
|   AZ|    Arizona|        7151502|
|   AR|   Arkansas|        3011524|
|   CA| California|       39538223|
|   CO|   Colorado|        5773714|
|   CT|Connecticut|        3605944|
|   DE|   Delaware|         989948|
|   FL|    Florida|       21538187|
|   GA|    Georgia|       10711908|
+-----+-----------+---------------+
only showing top 10 rows

Top 10 States by Customer Count:


                                                                                

+-----+--------------+------------------+----------------------+------------------+
|state|customer_count|total_transactions|avg_transaction_amount|    total_spending|
+-----+--------------+------------------+----------------------+------------------+
|   TX|         94876|             94876|     71.68216968714283| 6800917.531237364|
|   NY|         83501|             83501|     71.93325864714355| 6006499.030295134|
|   PA|         79847|             79847|     72.27584669822063| 5771009.531312823|
|   CA|         56360|             56360|      73.4222588943369| 4138078.511284828|
|   OH|         46480|             46480|      73.0668386385922|3396146.6599217653|
|   MI|         46154|             46154|      71.1254703674236| 3282724.959338069|
|   IL|         43252|             43252|     69.63588501567989| 3011891.298698187|
|   FL|         42671|             42671|     73.94229410416057|3155191.6317186356|
|   AL|         40989|             40989|     65.44968406481794| 2682717.100

                                                                                


Customer Per Capita Analysis:
Total states with data: 50
Highest customers per 100k: Wyoming (3349.57)
Lowest customers per 100k: Delaware (0.91)



=== DETAILED ANALYSIS ===

Top 5 States by Customers per 100k Population:
Wyoming: 3349.57 customers per 100k people
North Dakota: 1897.85 customers per 100k people
Vermont: 1829.95 customers per 100k people
West Virginia: 1432.28 customers per 100k people
South Dakota: 1389.92 customers per 100k people

Top 5 States by Spending per Capita:
Wyoming: $2.54 per person
Vermont: $1.50 per person
North Dakota: $1.24 per person
West Virginia: $1.01 per person
South Dakota: $0.98 per person

States with Highest Market Penetration (customers/population ratio):
Wyoming: 3.3496% of population are customers
North Dakota: 1.8978% of population are customers
Vermont: 1.8300% of population are customers

Potentially Underserved Large Markets (High population, Low customers per capita):
Maryland: 6,177,224 people, 424.03 customers per 100k
New York: 20,201,249 people, 413.35 customers per 100k
Indiana: 6,785,528 people, 406.45 customers per 100k
Ohio: 11,799,448 people, 393.92 customers per 100k
Vir

In [None]:
import pyspark.sql.functions as sf
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

def analyze_underserved_cities(df_cleaned, sample_fraction=0.1):
    """
    Analyze if larger cities are being underserved by plotting:
    1. Transaction count vs city population
    2. Average transaction amount vs city population
    3. Transaction density (transactions per capita) vs city population
    4. Combined analysis showing potential underserved markets
    
    Args:
        df_cleaned: PySpark DataFrame with transaction data
        sample_fraction: Fraction of data to sample for performance (default 0.1)
    
    Returns:
        plotly figure object
    """
    
    # Sample data for better performance if dataset is large
    if sample_fraction < 1.0:
        df_sample = df_cleaned.sample(fraction=sample_fraction, seed=42)
        print(f"Using {sample_fraction*100}% sample of the data for analysis")
    else:
        df_sample = df_cleaned
    
    # Aggregate data by city population
    city_analysis = df_sample.groupBy("city_pop").agg(
        sf.count("*").alias("transaction_count"),
        sf.avg("amt").alias("avg_transaction_amount"),
        sf.sum("amt").alias("total_transaction_amount"),
        sf.countDistinct("city").alias("num_cities")  # In case multiple cities have same population
    ).withColumn(
        "transactions_per_capita", 
        sf.col("transaction_count") / sf.col("city_pop")
    ).filter(
        sf.col("city_pop").isNotNull() & (sf.col("city_pop") > 0)
    ).orderBy("city_pop")
    
    # Convert to Pandas for plotting
    city_pandas = city_analysis.toPandas()
    
    # Create population bins for better visualization
    city_pandas['pop_bin'] = pd.cut(city_pandas['city_pop'], 
                                   bins=10, 
                                   labels=[f'Bin {i+1}' for i in range(10)])
    
    # Create subplots
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Transaction Count vs City Population',
                       'Avg Transaction Amount vs City Population', 
                       'Transactions Per Capita vs City Population',
                       'Market Opportunity Analysis'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"secondary_y": True}]]
    )
    
    # Plot 1: Transaction count vs city population
    fig.add_trace(
        go.Scatter(
            x=city_pandas['city_pop'],
            y=city_pandas['transaction_count'],
            mode='markers',
            name='Transaction Count',
            marker=dict(
                size=8,
                color=city_pandas['transaction_count'],
                colorscale='Blues',
                showscale=False
            ),
            hovertemplate='<b>City Pop:</b> %{x}<br><b>Transactions:</b> %{y}<extra></extra>'
        ),
        row=1, col=1
    )
    
    # Plot 2: Average transaction amount vs city population
    fig.add_trace(
        go.Scatter(
            x=city_pandas['city_pop'],
            y=city_pandas['avg_transaction_amount'],
            mode='markers',
            name='Avg Transaction Amount',
            marker=dict(
                size=8,
                color=city_pandas['avg_transaction_amount'],
                colorscale='Greens',
                showscale=False
            ),
            hovertemplate='<b>City Pop:</b> %{x}<br><b>Avg Amount:</b> $%{y:.2f}<extra></extra>'
        ),
        row=1, col=2
    )
    
    # Plot 3: Transactions per capita vs city population
    fig.add_trace(
        go.Scatter(
            x=city_pandas['city_pop'],
            y=city_pandas['transactions_per_capita'],
            mode='markers',
            name='Transactions Per Capita',
            marker=dict(
                size=8,
                color=city_pandas['transactions_per_capita'],
                colorscale='Reds',
                showscale=False
            ),
            hovertemplate='<b>City Pop:</b> %{x}<br><b>Trans/Capita:</b> %{y:.6f}<extra></extra>'
        ),
        row=2, col=1
    )
    
    # Plot 4: Market opportunity analysis - bubble chart
    # Bubble size represents total transaction amount, color represents transactions per capita
    fig.add_trace(
        go.Scatter(
            x=city_pandas['city_pop'],
            y=city_pandas['transaction_count'],
            mode='markers',
            name='Market Opportunity',
            marker=dict(
                size=city_pandas['total_transaction_amount']/1000,  # Scale down for reasonable bubble size
                sizemode='diameter',
                sizemin=4,
                color=city_pandas['transactions_per_capita'],
                colorscale='RdYlBu_r',
                showscale=True,
                colorbar=dict(
                    title="Trans/Capita",
                    x=1.02,
                    len=0.5,
                    y=0.25
                ),
                line=dict(width=1, color='black')
            ),
            hovertemplate='<b>City Pop:</b> %{x}<br>' +
                         '<b>Transaction Count:</b> %{y}<br>' +
                         '<b>Trans/Capita:</b> %{marker.color:.6f}<br>' +
                         '<b>Total Amount:</b> $%{text}<extra></extra>',
            text=[f'{amt:,.0f}' for amt in city_pandas['total_transaction_amount']]
        ),
        row=2, col=2
    )
    
    # Update layout
    fig.update_layout(
        height=800,
        title_text="City Population vs Transaction Activity Analysis - Identifying Underserved Markets",
        title_x=0.5,
        showlegend=False,
        template="plotly_white"
    )
    
    # Update x-axes
    fig.update_xaxes(title_text="City Population", row=1, col=1, type="log")
    fig.update_xaxes(title_text="City Population", row=1, col=2, type="log")
    fig.update_xaxes(title_text="City Population", row=2, col=1, type="log")
    fig.update_xaxes(title_text="City Population", row=2, col=2, type="log")
    
    # Update y-axes
    fig.update_yaxes(title_text="Transaction Count", row=1, col=1)
    fig.update_yaxes(title_text="Avg Transaction Amount ($)", row=1, col=2)
    fig.update_yaxes(title_text="Transactions Per Capita", row=2, col=1)
    fig.update_yaxes(title_text="Transaction Count", row=2, col=2)
    
    # Print insights
    print("=== UNDERSERVED CITIES ANALYSIS ===")
    print(f"Total cities analyzed: {len(city_pandas)}")
    
    # Identify potentially underserved large cities
    large_cities = city_pandas[city_pandas['city_pop'] > city_pandas['city_pop'].quantile(0.8)]
    low_activity_large_cities = large_cities[
        large_cities['transactions_per_capita'] < large_cities['transactions_per_capita'].median()
    ]
    
    print(f"\nLarge cities (top 20% by population): {len(large_cities)}")
    print(f"Large cities with below-median transaction activity: {len(low_activity_large_cities)}")
    
    if len(low_activity_large_cities) > 0:
        print("\nPotentially underserved large cities:")
        for _, city in low_activity_large_cities.head(5).iterrows():
            print(f"  Population: {city['city_pop']:,} | "
                  f"Trans/Capita: {city['transactions_per_capita']:.6f} | "
                  f"Total Transactions: {city['transaction_count']:,}")
    
    # Calculate correlation
    correlation = city_pandas[['city_pop', 'transaction_count', 'transactions_per_capita']].corr()
    print(f"\nCorrelation between city population and transaction count: {correlation.loc['city_pop', 'transaction_count']:.3f}")
    print(f"Correlation between city population and transactions per capita: {correlation.loc['city_pop', 'transactions_per_capita']:.3f}")
    
    return fig

# Usage example:
# fig = analyze_underserved_cities(df_cleaned, sample_fraction=0.1)
# fig.show()

def create_city_market_heatmap(df_cleaned, sample_fraction=0.1, use_percentile_range=True, percentile_cap=95):
    """
    Create a geographic heatmap showing market penetration by city population.
    
    Args:
        df_cleaned: PySpark DataFrame with transaction data
        sample_fraction: Fraction of data to sample for performance
        use_percentile_range: If True, cap the color scale at the specified percentile
        percentile_cap: Percentile to cap the color scale (default 95th percentile)
    
    Returns:
        plotly figure object
    """
    
    # Sample and aggregate data
    if sample_fraction < 1.0:
        df_sample = df_cleaned.sample(fraction=sample_fraction, seed=42)
    else:
        df_sample = df_cleaned
    
    # Aggregate by location
    location_analysis = df_sample.groupBy("lat", "long", "city_pop").agg(
        sf.count("*").alias("transaction_count"),
        sf.avg("amt").alias("avg_transaction_amount")
    ).withColumn(
        "transactions_per_capita",
        sf.col("transaction_count") / sf.col("city_pop")
    ).filter(
        sf.col("city_pop").isNotNull() & 
        sf.col("lat").isNotNull() & 
        sf.col("long").isNotNull() &
        (sf.col("city_pop") > 0)
    )
    
    # Convert to Pandas
    location_pandas = location_analysis.toPandas()
    
    # Handle outliers in color scale
    if use_percentile_range:
        color_max = location_pandas['transactions_per_capita'].quantile(percentile_cap/100)
        color_min = location_pandas['transactions_per_capita'].quantile(0.05)  # 5th percentile as minimum
        
        # Cap the values for color mapping while preserving original values for hover
        location_pandas['transactions_per_capita_capped'] = location_pandas['transactions_per_capita'].clip(
            lower=color_min, upper=color_max
        )
        color_column = 'transactions_per_capita_capped'
        
        print(f"Color scale range: {color_min:.6f} to {color_max:.6f} (5th to {percentile_cap}th percentile)")
        print(f"Original range: {location_pandas['transactions_per_capita'].min():.6f} to {location_pandas['transactions_per_capita'].max():.6f}")
    else:
        color_column = 'transactions_per_capita'
        color_min = None
        color_max = None
    
    # Create the heatmap
    fig = px.scatter_mapbox(
        location_pandas,
        lat="lat",
        lon="long",
        size="city_pop",
        color=color_column,
        hover_data={
            "city_pop": ":,",
            "transaction_count": ":,",
            "avg_transaction_amount": ":.2f",
            "transactions_per_capita": ":.6f"  # Always show original values in hover
        },
        color_continuous_scale="RdYlBu_r",
        size_max=50,
        zoom=3,
        title=f"Geographic Distribution: City Population vs Transaction Activity<br><sub>Color scale capped at {percentile_cap}th percentile for better sensitivity</sub>",
        range_color=[color_min, color_max] if use_percentile_range else None
    )
    
    fig.update_layout(
        mapbox_style="open-street-map",
        height=600,
        coloraxis_colorbar=dict(
            title="Transactions Per Capita<br>(Capped Scale)" if use_percentile_range else "Transactions Per Capita"
        )
    )
    
    return fig

# Usage example with improved color sensitivity:
# fig_map = create_city_market_heatmap(df_cleaned, sample_fraction=0.1, use_percentile_range=True, percentile_cap=95)
# fig_map.show()

# Alternative: Use 90th percentile for even more sensitivity
fig_map = create_city_market_heatmap(df_cleaned, sample_fraction=0.1, use_percentile_range=True, percentile_cap=90)
fig_map.show()

fig_underserved = analyze_underserved_cities(df_cleaned, sample_fraction=0.1)
fig_underserved.show()

# Or disable capping to see the full range
# fig_map = create_city_market_heatmap(df_cleaned, sample_fraction=0.1, use_percentile_range=False)
# fig_map.show()


Interactive Map overlaying the heatmap and the population distribution to find out underserved markets


In [None]:
import os

pandas_df = df_cleaned.select(["lat", "long", "amt"]).toPandas()

print(f'Pandas dataframe shape: {pandas_df.shape}')

heat_data = pandas_df.to_numpy().tolist()

if not pandas_df.empty:
    map_center = [pandas_df.lat.mean(), pandas_df.long.mean()]
else:
    map_center = [0, 0]


In [None]:
import dash
from dash import html
from dash import dcc
import plotly.express as px
import folium
from folium.plugins import HeatMap


# Set the desired initial zoom level for the Folium map
desired_zoom_level = 3
m = folium.Map(location=map_center, zoom_start=desired_zoom_level, tiles="OpenStreetMap")

HeatMap(heat_data).add_to(m)

assets_folder = 'assets'
if not os.path.exists(assets_folder):
    os.makedirs(assets_folder)
folium_map_path = os.path.join(assets_folder, 'folium_heatmap.html')
m.save(folium_map_path)
print(f"\nFolium map saved to {folium_map_path}")


app = dash.Dash(__name__) 

app.layout = html.Div(
    style={
        'fontFamily': 'Inter, sans-serif',
        'padding': '20px',
        'backgroundColor': '#f8f9fa',
        'minHeight': '100vh',
        'display': 'flex',
        'flexDirection': 'column',
        'alignItems': 'center'
    },
    children=[
        html.H1(
            "Geospatial Heatmap Dashboard",
            style={
                'textAlign': 'center',
                'color': '#343a40',
                'marginBottom': '30px',
                'fontSize': '2.5em',
                'fontWeight': '600'
            }
        ),

        html.Div(
            style={
                'display': 'flex',
                'flexDirection': 'row',
                'gap': '30px',
                'flexWrap': 'wrap',
                'justifyContent': 'center',
                'width': '100%',
                'maxWidth': '1200px'
            },
            children=[
                # Left panel for the Folium Map and its Legend (THIS PART REMAINS UNCHANGED)
                html.Div(
                    style={
                        'display': 'flex',
                        'flexDirection': 'column',
                        'flex': '1',
                        'minWidth': '550px',
                        'backgroundColor': 'white',
                        'padding': '25px',
                        'borderRadius': '12px',
                        'boxShadow': '0 4px 12px rgba(0,0,0,0.1)',
                        'border': '1px solid #e0e0e0'
                    },
                    children=[
                        html.H2(
                            "Expenditure Heatmap",
                            style={
                                'textAlign': 'center',
                                'color': '#495057',
                                'marginBottom': '25px',
                                'fontSize': '1.8em',
                                'fontWeight': '500'
                            }
                        ),
                        html.Div(
                            style={
                                'display': 'flex',
                                'flexDirection': 'row',
                                'gap': '25px',
                                'justifyContent': 'center',
                                'alignItems': 'center',
                                'width': '100%'
                            },
                            children=[
                                html.Div(
                                    style={
                                        'display': 'flex',
                                        'flexDirection': 'column',
                                        'alignItems': 'center',
                                        'padding': '15px 10px',
                                        'backgroundColor': '#f0f2f5',
                                        'borderRadius': '8px',
                                        'border': '1px solid #d0d0d0',
                                        'height': '400px',
                                        'justifyContent': 'space-between',
                                        'minWidth': '90px',
                                        'maxWidth': '120px',
                                        'boxShadow': '0 2px 8px rgba(0,0,0,0.08)'
                                    },
                                    children=[
                                        html.P(
                                            "Intensity",
                                            style={
                                                'fontWeight': '600',
                                                'marginBottom': '10px',
                                                'textAlign': 'center',
                                                'color': '#343a40',
                                                'fontSize': '1.1em'
                                            }
                                        ),
                                        html.Span(
                                            f"Max: ${max_amt:.2f}",
                                            style={
                                                'fontSize': '0.95em',
                                                'whiteSpace': 'nowrap',
                                                'fontWeight': 'bold',
                                                'color': '#dc3545'
                                            }
                                        ),
                                        html.Div(
                                            style={
                                                'width': '25px',
                                                'height': '200px',
                                                'margin': '15px 0',
                                                'background': 'linear-gradient(to bottom, #dc3545, #fd7e14, #ffc107, #28a745, #0011ff)',
                                                'borderRadius': '5px',
                                                'border': '1px solid #ccc'
                                            }
                                        ),
                                        html.Span(
                                            f"Min: ${min_amt:.2f}",
                                            style={
                                                'fontSize': '0.95em',
                                                'whiteSpace': 'nowrap',
                                                'fontWeight': 'bold',
                                                'color': "#0011ff"
                                            }
                                        )
                                    ]
                                ),
                                html.Iframe(
                                    id='folium-map-iframe',
                                    srcDoc=open(folium_map_path, 'r').read(),
                                    style={
                                        'flexGrow': '1',
                                        'minWidth': '400px',
                                        'height': '400px',
                                        'border': 'none',
                                        'borderRadius': '8px',
                                        'boxShadow': '0 2px 8px rgba(0,0,0,0.08)'
                                    }
                                )
                            ]
                        )
                    ]
                ),

                # # Right panel for the Scatter Map - MODIFIED SECTION
                # html.Div(
                #     style={
                #         'display': 'flex',
                #         'flexDirection': 'column',
                #         'flex': '1',
                #         'minWidth': '550px',
                #         'backgroundColor': 'white',
                #         'padding': '25px',
                #         'borderRadius': '12px',
                #         'boxShadow': '0 4px 12px rgba(0,0,0,0.1)',
                #         'border': '1px solid #e0e0e0'
                #     },
                #     children=[
                #         html.H2(
                #             "Scatter Map",
                #             style={
                #                 'textAlign': 'center',
                #                 'color': '#495057',
                #                 'marginBottom': '25px',
                #                 'fontSize': '1.8em',
                #                 'fontWeight': '500'
                #             }
                #         ),
                #         html.Div([ # This div just wraps the dcc.Graph
                #             dcc.Graph(
                #                 id='scatter-map',
                #                 figure=scatter_map,
                #                 style={'height': '450px', 'width': '100%'} # Explicitly set height and width
                #             )
                #         ],
                #             style = { # Remove redundant flex properties here
                #                 'width': '100%',
                #                 'display': 'flex',
                #                 'justifyContent': 'center', # Center the graph within its container
                #                 'alignItems': 'center'
                #             }
                #         )
                #     ]
                # )
            ]
        )
    ]
)

print("\\nRunning Dash application inline in Jupyter Notebook")
app.run(mode='inline', port=8050)