In [None]:
# Databricks notebook source
# MAGIC %md
# MAGIC # Trade Area Feature Aggregation
# MAGIC
# MAGIC Aggregates H3 features to trade area level:
# MAGIC 1. H3 index trade areas → get h3_cell_ids
# MAGIC 2. Join to h3_features_gold on h3_cell_id
# MAGIC 3. Aggregate to the store level (grouped by store_number and its descriptive attributes)

In [None]:
from pyspark.sql import functions as F
import yaml

# Notebook parameters as (name, default[, label]); registered in this order.
widget_specs = [
    ("catalog", "geo_site_selection"),
    ("silver_schema", "silver"),
    ("gold_schema", "gold"),
    ("config_path", "/Workspace/resources/configs/h3_features_config.yml"),
    ("trade_area_table", "", "Trade Area Table (optional)"),
    ("output_table_override", "", "Output Table (optional)"),
]
for widget_spec in widget_specs:
    dbutils.widgets.text(*widget_spec)

# Read the current widget values.
catalog, silver_schema, gold_schema, config_path = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("catalog", "silver_schema", "gold_schema", "config_path")
)
trade_area_table_override = dbutils.widgets.get("trade_area_table")
output_table_override = dbutils.widgets.get("output_table_override")

# Load the YAML feature configuration; only the H3 resolution is needed here,
# the full `config` dict is kept for later cells.
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

H3_RESOLUTION = config["h3_grid"]["resolution"]

# Input table: an explicit (non-blank) override wins, otherwise fall back to the
# default isochrone table in the silver schema. The default output name is
# derived from the last dot-separated part of an overridden input table.
requested_input = (trade_area_table_override or "").strip()
if requested_input:
    trade_area_table = requested_input
    default_output = f"{trade_area_table.rpartition('.')[2]}_features"
else:
    trade_area_table = f"{catalog}.{silver_schema}.silver_rmc_urbanicity_based_isochrones"
    default_output = "rmc_trade_area_features"

# Output name: a non-blank override wins over the derived default.
output_table_name = (output_table_override or "").strip() or default_output

print(f"Input: {trade_area_table}")
print(f"Output: {catalog}.{gold_schema}.gold_{output_table_name}")

%md
## H3 Index Trade Areas

In [None]:
trade_areas = spark.table(trade_area_table)

# Trade-area attributes carried through unchanged next to each H3 cell id.
# `geometry` is kept so the isochrone polygon can be written to the output.
passthrough_cols = [
    "store_number",
    "latitude",
    "longitude",
    "store_type",
    "city",
    "state",
    "drive_time_minutes",
    "area_sqkm",
    "geometry",
]

# Array of H3 cell ids (as strings) covering each trade-area polygon at the
# configured resolution.
polyfill_cells = F.expr(
    f"h3_polyfillash3string(ST_AsText(geometry), {H3_RESOLUTION})"
)

# One output row per (trade area, H3 cell).
# NOTE(review): explode emits no row for a null/empty array, so a polygon that
# polyfills to zero cells would drop out here — confirm that is acceptable.
ta_h3 = trade_areas.select(
    *passthrough_cols,
    F.explode(polyfill_cells).alias("h3_cell_id"),
)

print("Trade areas indexed with H3")
display(ta_h3)

In [None]:
# Show the source trade-area table (as loaded, before H3 indexing) for inspection.
display(trade_areas)

%md
## Join to H3 Features

In [None]:
# NOTE(review): the table is named `silver_h3_features` but is read from the
# gold schema — confirm this is the intended source table.
h3_features_table = f"{catalog}.{gold_schema}.silver_h3_features"

# Columns of the H3 feature table that are not needed for aggregation.
h3_columns_to_drop = ["h3_geometry", "h3_resolution", "processing_timestamp"]
h3_features = spark.table(h3_features_table).drop(*h3_columns_to_drop)

# Inner join: only H3 cells that have a feature row are kept.
ta_with_features = ta_h3.join(h3_features, on="h3_cell_id", how="inner")

print("Joined trade areas with H3 features")
display(ta_with_features.limit(5))

%md
## Aggregate by Store

In [None]:
demo_vars = config['demographic_variables']

# Config groups whose variables are counts; concatenated in this order.
count_groups = (
    'population',
    'income',
    'households',
    'education',
    'employment',
    'housing',
    'commute',
)
count_vars = [var for group in count_groups for var in demo_vars[group]]
median_vars = demo_vars['median']

# Columns actually available after the join; configured variables that are
# missing from the joined frame are ignored.
joined_columns = ta_with_features.columns


def _columns_with_prefix(prefix):
    """Return the joined-frame columns whose name starts with `prefix`, in frame order."""
    return [name for name in joined_columns if name.startswith(prefix)]


existing_count_vars = [var for var in count_vars if var in joined_columns]
existing_median_vars = [var for var in median_vars if var in joined_columns]
poi_cols = _columns_with_prefix('poi_count_')
competitor_cols = _columns_with_prefix('competitor_count_')
distance_cols = _columns_with_prefix('distance_to_')

In [None]:
def _abs_sum(name):
    """Sum of `name` over the group, made non-negative and cast to long, aliased to `name`."""
    return F.abs(F.sum(name)).cast("long").alias(name)


def _abs_avg(name):
    """Average of `name` over the group, made non-negative, aliased to `name`."""
    return F.abs(F.avg(name)).alias(name)


# Aggregations keyed by output column name (insertion-ordered). Registering by
# name guarantees each output column appears once: previously `per_capita_income`
# was appended a second time when it was also listed among the median variables,
# producing two identically named columns in the result.
agg_by_name = {}


def _register(name, expr):
    """Add `expr` as the aggregate for output column `name` unless one is already registered."""
    agg_by_name.setdefault(name, expr)


# Count variables: sum (convert negatives to positive for demo purposes).
# Note the absolute value is applied to the aggregated total, not per row.
for name in existing_count_vars:
    _register(name, _abs_sum(name))

# POI counts: sum (convert negatives to positive)
for name in poi_cols:
    _register(name, _abs_sum(name))
_register("total_poi_count", _abs_sum("total_poi_count"))

# Competitor counts: sum (convert negatives to positive)
for name in competitor_cols:
    _register(name, _abs_sum(name))
_register("total_competitor_count", _abs_sum("total_competitor_count"))

# Median/rate variables: avg (convert negatives to positive)
for name in existing_median_vars:
    _register(name, _abs_avg(name))
if 'per_capita_income' in ta_with_features.columns:
    # No-op when the config already listed it as a median variable.
    _register("per_capita_income", _abs_avg("per_capita_income"))

# Distance features: min (keep as-is, distances should be positive)
for name in distance_cols:
    _register(name, F.min(name).alias(name))

# Urbanicity: avg (convert negatives to positive); plus the number of joined
# H3 cells per group.
_register("urbanicity_score", _abs_avg("urbanicity_score"))
_register("h3_cell_count", F.count("h3_cell_id").alias("h3_cell_count"))
# `geometry` is not part of the grouping key, so one value per group is taken.
# NOTE(review): presumably identical for every row of a group — confirm.
_register("geometry", F.first("geometry").alias("geometry"))

agg_exprs = list(agg_by_name.values())

# Grouping key: the store id plus the trade-area attributes carried alongside it.
group_cols = [
    "store_number",
    "latitude",
    "longitude",
    "store_type",
    "city",
    "state",
    # "urbanicity_category",
    "drive_time_minutes",
    "area_sqkm",
]

ta_features_agg = ta_with_features.groupBy(*group_cols).agg(*agg_exprs)

display(ta_features_agg.limit(5))

In [None]:
# Show the aggregated trade-area features (no row limit applied here) for inspection.
display(ta_features_agg)

%md
## Write to Gold

In [None]:
from pyspark.sql.window import Window

ta_features_final = ta_features_agg.withColumn("processing_timestamp", F.current_timestamp())

# Replace nulls with 0 in the aggregated numeric feature columns only; the
# store-level attributes listed below are left untouched.
# NOTE(review): this also zero-fills `distance_to_*` and averaged columns, where
# 0 reads as a real value rather than "missing" — confirm this is intended.
numeric_type_names = {'long', 'double', 'integer', 'float'}
store_attribute_cols = {'latitude', 'longitude', 'drive_time_minutes', 'area_sqkm'}
numeric_cols = [
    field.name
    for field in ta_features_final.schema.fields
    if field.dataType.typeName() in numeric_type_names
    and field.name not in store_attribute_cols
]
ta_features_final = ta_features_final.fillna(0, subset=numeric_cols)

# Ensure one record per store (deduplication).
# `processing_timestamp` is the same for every row of this run, so ordering by
# it alone leaves the kept row arbitrary. The remaining grouping columns are
# added as tie-breakers so the choice is deterministic: largest area first,
# then longest drive time, then the other attributes. Adjust the order if a
# different trade area should win.
window_spec = Window.partitionBy("store_number").orderBy(
    F.desc("processing_timestamp"),
    F.desc("area_sqkm"),
    F.desc("drive_time_minutes"),
    F.asc("latitude"),
    F.asc("longitude"),
    F.asc("store_type"),
    F.asc("city"),
    F.asc("state"),
)
ta_features_final = (
    ta_features_final
    .withColumn("row_num", F.row_number().over(window_spec))
    .filter(F.col("row_num") == 1)
    .drop("row_num")
)

output_table = f"{catalog}.{gold_schema}.gold_{output_table_name}"

# Overwrite the gold table, replacing its schema if the feature set changed.
(
    ta_features_final
    .write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(output_table)
)

print(f"Written to {output_table}")

# Count from the written table instead of the lazy DataFrame: counting before
# the write executed the whole polyfill/join/aggregate pipeline a second time.
print(f"Records written: {spark.table(output_table).count()}")

In [None]:
# Summary statistics over the table that was just written, as a quick sanity check.
# NOTE(review): assumes the output has a `total_population` column, which is only
# present if the config lists it and the H3 feature table provides it — confirm.
summary_query = f"""
  SELECT
    COUNT(*) as total_trade_areas,
    ROUND(AVG(area_sqkm), 2) as avg_area_sqkm,
    ROUND(AVG(total_population), 0) as avg_population,
    ROUND(AVG(total_poi_count), 0) as avg_poi_count,
    ROUND(AVG(total_competitor_count), 0) as avg_competitor_count,
    ROUND(AVG(h3_cell_count), 0) as avg_h3_cells
  FROM {output_table}
"""
output_summary = spark.sql(summary_query)
display(output_summary)