In [None]:
# OSM Silver → Gold: Hex-Level Infrastructure Features (counts + area m² + area %)
# Silver: s3://ie-datalake/silver/osm/features_energy/country=XX/snapshot_date=.../feature_type=<TYPE>/
# Gold:   s3://ie-datalake/gold/osm_hex_features/country=XX/h3_resolution=N/

# ─── CONFIG ───
# S3 layout: silver input and gold output live under the same bucket.
S3_BUCKET = "ie-datalake"
SILVER_PREFIX = "silver/osm/features_energy"
GOLD_PREFIX = "gold/osm_hex_features"

# Run scope. An empty/None value disables the corresponding filter.
COUNTRIES = ["ES"]
SNAPSHOT_DATES = None  # None = all; or ["2026-02-25"]
H3_RESOLUTIONS = [6, 7, 8, 9]

# Hexagon area in m² keyed by H3 resolution.
# NOTE(review): values match the *average* hexagon areas in the public H3
# resolution table; real cell area varies with location — confirm that an
# average is acceptable for the percentage/density metrics.
H3_AREA_M2 = {
    6: 36_129_062.16,
    7: 5_161_293.36,
    8: 737_327.60,
    9: 105_332.51,
}

# Codec passed to the Parquet writer for the gold output.
PARQUET_COMPRESSION = "snappy"

# ─── SPARK / GLUE ───
from pyspark.sql import functions as F
# Use the Glue-provided session when the awsglue package is importable (i.e.
# when running as an AWS Glue job); otherwise fall back to a plain local or
# cluster SparkSession so the same script runs outside Glue.
try:
    from awsglue.context import GlueContext
    from pyspark.context import SparkContext
    sc = SparkContext.getOrCreate()
    glueContext = GlueContext(sc)
    spark = glueContext.spark_session
except ImportError:
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()

# Fix Parquet schema mismatch (BINARY vs INT/dict across files): the
# vectorized reader is disabled so the row-based reader is used instead.
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")

from pyspark.sql.types import StructType, StructField, StringType, DoubleType
# Explicit read schema for the silver layer. Every field is nullable; the
# field order is: string attributes, numeric measures, then one H3 index
# column per resolution (h3_6 .. h3_9).
SILVER_SCHEMA = StructType(
    [
        StructField(attr_name, StringType(), True)
        for attr_name in (
            "country",
            "snapshot_date",
            "feature_type",
            "highway",
            "waterway",
            "man_made",
            "power",
            "plant_source",
            "generator_source",
        )
    ]
    + [
        StructField(measure_name, DoubleType(), True)
        for measure_name in ("length_m", "area_m2")
    ]
    + [
        StructField(f"h3_{res}", StringType(), True)
        for res in (6, 7, 8, 9)
    ]
)

def read_silver():
    """Load the silver OSM feature table, restricted to the configured scope.

    Reads every Parquet file under ``s3://{S3_BUCKET}/{SILVER_PREFIX}/`` with
    ``SILVER_SCHEMA`` and keeps only rows whose ``country`` is in ``COUNTRIES``
    and whose ``snapshot_date`` is in ``SNAPSHOT_DATES``. A filter is skipped
    when its setting is empty or ``None``.

    NOTE(review): ``explode_h3_resolutions`` drops ``snapshot_date`` and the
    aggregation groups only by country / resolution / hex, so selecting more
    than one snapshot sums the same hex across snapshots — confirm that this
    is intended when ``SNAPSHOT_DATES`` is ``None``.

    Returns:
        A DataFrame with the columns of ``SILVER_SCHEMA``.
    """
    silver_path = f"s3://{S3_BUCKET}/{SILVER_PREFIX}/"
    frame = spark.read.schema(SILVER_SCHEMA).parquet(silver_path)

    scope_filters = (
        ("country", COUNTRIES),
        ("snapshot_date", SNAPSHOT_DATES),
    )
    for column_name, allowed_values in scope_filters:
        if allowed_values:
            frame = frame.filter(F.col(column_name).isin(allowed_values))
    return frame

def ensure_length_area(df):
    """Guarantee non-null ``length_m`` and ``area_m2`` columns.

    For each of the two measures: if the column exists its NULLs are replaced
    by 0.0, otherwise the column is added with a constant 0.0 (backward
    compatibility with older silver data that lacks the measures).

    Args:
        df: Silver DataFrame.

    Returns:
        The DataFrame with both measure columns present and NULL-free.
    """
    for measure in ("length_m", "area_m2"):
        if measure in df.columns:
            filled = F.coalesce(F.col(measure), F.lit(0.0))
        else:
            filled = F.lit(0.0)
        df = df.withColumn(measure, filled)
    return df

# Values of the ``highway`` column that make a ROADS feature count towards
# ``major_road_count``.
MAJOR_HIGHWAY = (
    "motorway",
    "trunk",
    "primary",
    "secondary",
    "tertiary",
)

def explode_h3_resolutions(df, resolutions=(6, 7, 8, 9)):
    """Emit one row per (input row, H3 resolution).

    Uses SQL ``stack`` to unpivot the ``h3_<res>`` columns into a pair of
    columns ``h3_resolution`` (the integer resolution) and ``h3_index`` (the
    value of the matching ``h3_<res>`` column). Rows with a NULL ``h3_index``
    or a NULL ``country`` are dropped. Only the attribute and measure columns
    listed below are carried over; ``snapshot_date`` is not.

    Args:
        df: Silver DataFrame containing an ``h3_<res>`` column for every
            requested resolution.
        resolutions: H3 resolutions to unpivot. The default reproduces the
            previously hard-coded ``h3_6 .. h3_9`` expansion.

    Returns:
        The exploded DataFrame.

    Raises:
        ValueError: If ``resolutions`` is empty (``stack`` needs >= 1 row).
    """
    # int() both validates the values and keeps anything unexpected out of
    # the SQL expression string built below.
    wanted = [int(res) for res in resolutions]
    if not wanted:
        raise ValueError("resolutions must contain at least one H3 resolution")

    stack_pairs = ", ".join(f"{res}, h3_{res}" for res in wanted)
    stack_expr = f"stack({len(wanted)}, {stack_pairs}) as (h3_resolution, h3_index)"

    return df.selectExpr(
        "country", "feature_type", "highway", "waterway", "man_made", "power",
        "plant_source", "generator_source", "length_m", "area_m2",
        stack_expr,
    ).filter(F.col("h3_index").isNotNull() & F.col("country").isNotNull())

# (output column prefix, silver feature_type) for every per-hex area sum.
# Tuple order is the order of the "<prefix>_area_m2" output columns.
_AREA_FEATURE_TYPES = (
    ("waterbody", "WATERBODIES"),
    ("waterway", "WATERWAYS"),
    ("wetland", "WETLANDS"),
    ("road", "ROADS"),
    ("building", "BUILDINGS"),
    ("parks_green", "PARKS_GREEN_URBAN"),
    ("industrial", "INDUSTRIAL_AREAS"),
    ("residential", "RESIDENTIAL_AREAS"),
    ("commercial", "COMMERCIAL_AREAS"),
    ("parking", "PARKING_AREAS"),
    ("cemetery", "CEMETERIES"),
    ("construction", "CONSTRUCTION"),
    ("retention_basin", "RETENTION_BASIN"),
    ("agri", "LANDUSE_AGRICULTURE"),
    ("managed_forest", "FORESTRY_MANAGED"),
    ("natural_habitat", "NATURAL_HABITATS"),
    ("protected", "PROTECTED_AREAS"),
    ("restricted", "RESTRICTED_AREAS"),
    ("waste_site", "WASTE_POLLUTION"),
)

# Prefixes of "<prefix>_area_m2" columns that get a "<prefix>_area_pct"
# column, in output order ("water_wetland_area_pct" is emitted before these).
_AREA_PCT_PREFIXES = (
    "waterbody", "waterway", "wetland", "road", "building", "parks_green",
    "agri", "managed_forest", "natural_habitat", "protected", "restricted",
    "human_footprint", "industrial", "residential", "commercial", "parking",
    "cemetery", "construction", "retention_basin", "urban_footprint",
)

# Count columns that get a "<column>_per_km2" density column, in output order.
_DENSITY_COUNT_COLUMNS = (
    "road_count", "building_count", "power_plant_count", "protected_area_count",
)


def _hex_count_conditions():
    """Return ordered (output column, row predicate) pairs for per-hex counts."""
    feature_type = F.col("feature_type")

    def tag(column_name):
        # NULL tags are treated as the empty string so comparisons and regex
        # matches evaluate to False instead of NULL.
        return F.coalesce(F.col(column_name), F.lit(""))

    def is_type(name):
        return feature_type == name

    is_power_plant = is_type("POWER_PLANTS")
    # plant_source and generator_source are concatenated (no separator) and
    # lower-cased before the energy-source regexes are applied.
    plant_sources = F.lower(F.concat(tag("plant_source"), tag("generator_source")))

    return [
        # A) Transport
        ("road_count", is_type("ROADS")),
        ("major_road_count", is_type("ROADS") & tag("highway").isin(list(MAJOR_HIGHWAY))),
        ("trail_count", is_type("TRAILS_TRACKS")),
        ("rail_count", is_type("RAIL")),
        ("port_feature_count", is_type("PORTS_TERMINALS")),
        ("airport_feature_count", is_type("AIRPORTS")),
        # B) Energy
        ("pipeline_count", is_type("PIPELINES")),
        ("power_line_count", is_type("POWER_LINES")),
        ("power_substation_count", is_type("POWER_SUBSTATIONS")),
        ("power_plant_count", is_power_plant),
        ("solar_plant_count", is_power_plant & plant_sources.rlike("solar|photovoltaic")),
        ("wind_plant_count", is_power_plant & plant_sources.rlike("wind")),
        ("hydro_plant_count", is_power_plant & plant_sources.rlike("hydro|water")),
        ("industrial_area_count", is_type("INDUSTRIAL_AREAS")),
        ("storage_tank_count", is_type("STORAGE_TANKS")),
        ("fuel_station_count", is_type("FUEL_STATIONS")),
        # C) Hydro & wetness
        ("waterway_count", is_type("WATERWAYS")),
        ("waterbody_count", is_type("WATERBODIES")),
        ("wetland_count", is_type("WETLANDS")),
        ("coastline_count", is_type("COASTLINE")),
        # dam / weir / lock are matched on the tag value alone, i.e. across
        # all feature types.
        ("dam_count", tag("man_made") == "dam"),
        ("weir_count", tag("waterway") == "weir"),
        ("lock_count", tag("waterway") == "lock_gate"),
        ("water_barrier_count_total", is_type("WATER_BARRIERS")),
        ("water_infra_poi_count", is_type("WATER_INFRA_POI")),
        # D) Built footprint
        ("building_count", is_type("BUILDINGS")),
        ("amenity_count_total", is_type("AMENITIES_POI")),
        ("parks_green_count", is_type("PARKS_GREEN_URBAN")),
        ("tree_rows_hedgerow_count", is_type("TREE_ROWS_HEDGEROWS")),
        # E) Landuse / habitat
        ("landuse_agriculture_count", is_type("LANDUSE_AGRICULTURE")),
        ("managed_forest_count", is_type("FORESTRY_MANAGED")),
        ("natural_habitat_count", is_type("NATURAL_HABITATS")),
        # F) Constraints
        ("protected_area_count", is_type("PROTECTED_AREAS")),
        ("restricted_area_count", is_type("RESTRICTED_AREAS")),
        ("admin_boundary_count", is_type("ADMIN_BOUNDARIES")),
        # G) Fragmentation / barriers
        ("barrier_count", is_type("BARRIERS")),
        ("linear_disturbance_count", is_type("LINEAR_DISTURBANCE")),
        # H) Pollution
        ("waste_site_count", is_type("WASTE_POLLUTION")),
    ]


def _hex_area_m2_expr():
    """Return a Column mapping h3_resolution to its area from H3_AREA_M2.

    Resolutions missing from H3_AREA_M2 map to NULL rather than to the area
    of some other resolution.
    """
    resolution = F.col("h3_resolution")
    expr = None
    for res, area in sorted(H3_AREA_M2.items()):
        if expr is None:
            expr = F.when(resolution == res, float(area))
        else:
            expr = expr.when(resolution == res, float(area))
    if expr is None:
        return F.lit(None).cast("double")
    return expr


def compute_hex_metrics_unified(df_exploded):
    """Aggregate exploded features into one metrics row per hex.

    Rows are first restricted to ``H3_RESOLUTIONS`` and then grouped once by
    (country, h3_resolution, h3_index). Output columns, in order:

    * the three grouping keys;
    * ``*_count`` columns: number of rows matching each predicate in
      ``_hex_count_conditions``;
    * ``<prefix>_area_m2`` columns: sum of ``area_m2`` per feature type in
      ``_AREA_FEATURE_TYPES`` (a NULL ``area_m2`` contributes 0);
    * ``hex_area_m2`` / ``hex_area_km2``: hex area looked up in
      ``H3_AREA_M2``;
    * ``human_footprint_area_m2``, ``urban_footprint_area_m2``,
      ``water_surface_area_m2``: sums of the area columns above;
    * ``*_area_pct``: area as a percentage of ``hex_area_m2``;
    * ``*_per_km2``: count divided by ``hex_area_km2``.

    Args:
        df_exploded: Output of ``explode_h3_resolutions``; needs the columns
            country, h3_resolution, h3_index, feature_type, highway,
            waterway, man_made, plant_source, generator_source and area_m2.

    Returns:
        DataFrame with one row per (country, h3_resolution, h3_index).

    Raises:
        ValueError: If a resolution in ``H3_RESOLUTIONS`` has no entry in
            ``H3_AREA_M2`` (its percentages/densities could not be computed).
    """
    unmapped = sorted(set(H3_RESOLUTIONS) - set(H3_AREA_M2))
    if unmapped:
        raise ValueError(
            f"H3_AREA_M2 has no hex area for configured resolution(s): {unmapped}"
        )

    # Drop unwanted resolutions before the shuffle instead of after it.
    wanted = df_exploded.filter(F.col("h3_resolution").isin(list(H3_RESOLUTIONS)))

    feature_type = F.col("feature_type")
    # Coalesce here so every area sum is non-NULL even if the caller did not
    # run ensure_length_area first.
    row_area = F.coalesce(F.col("area_m2"), F.lit(0.0))

    count_aggs = [
        F.sum(F.when(predicate, 1).otherwise(0)).alias(column_name)
        for column_name, predicate in _hex_count_conditions()
    ]
    area_aggs = [
        F.sum(F.when(feature_type == type_name, row_area).otherwise(0.0)).alias(f"{prefix}_area_m2")
        for prefix, type_name in _AREA_FEATURE_TYPES
    ]
    agg = wanted.groupBy("country", "h3_resolution", "h3_index").agg(*count_aggs, *area_aggs)

    def m2(prefix):
        return F.col(f"{prefix}_area_m2")

    # Derived columns are added with two select() calls rather than one
    # withColumn() per column, which would add a projection for each call.
    hex_area = _hex_area_m2_expr()
    human_footprint = m2("building") + m2("industrial") + m2("parks_green") + m2("waste_site")
    urban_footprint = (
        human_footprint + m2("residential") + m2("commercial") + m2("parking")
        + m2("road") + m2("cemetery") + m2("construction")
    )
    water_surface = m2("waterbody") + m2("waterway") + m2("wetland")
    agg = agg.select(
        "*",
        hex_area.alias("hex_area_m2"),
        (hex_area / 1e6).alias("hex_area_km2"),
        human_footprint.alias("human_footprint_area_m2"),
        urban_footprint.alias("urban_footprint_area_m2"),
        water_surface.alias("water_surface_area_m2"),
    )

    hex_m2 = F.col("hex_area_m2")
    hex_km2 = F.col("hex_area_km2")
    water_wetland_pct = (m2("waterbody") + m2("wetland") + m2("waterway")) / hex_m2 * 100
    pct_columns = [
        (m2(prefix) / hex_m2 * 100).alias(f"{prefix}_area_pct")
        for prefix in _AREA_PCT_PREFIXES
    ]
    density_columns = [
        (F.col(count_column) / hex_km2).alias(f"{count_column}_per_km2")
        for count_column in _DENSITY_COUNT_COLUMNS
    ]
    return agg.select(
        "*",
        water_wetland_pct.alias("water_wetland_area_pct"),
        *pct_columns,
        *density_columns,
    )

# ─── RUN ───
print("Reading silver...")
# Silver is not cached: it is scanned exactly once, by the aggregation below,
# so caching it would only spend executor memory.
df = ensure_length_area(read_silver())
print("Exploding h3 resolutions & computing hex metrics (single pass)...")
df_exploded = explode_h3_resolutions(df)
# The aggregate is cached because it is consumed twice: count() and write.
agg = compute_hex_metrics_unified(df_exploded).cache()
n = agg.count()
gold_base = f"s3://{S3_BUCKET}/{GOLD_PREFIX}"
# Dynamic partition overwrite replaces only the (country, h3_resolution)
# partitions produced by this run. With the default static mode,
# mode("overwrite") would first delete everything under gold_base, including
# partitions of countries that are not part of COUNTRIES for this run.
(
    agg.write
    .mode("overwrite")
    .option("partitionOverwriteMode", "dynamic")
    .partitionBy("country", "h3_resolution")
    .parquet(gold_base, compression=PARQUET_COMPRESSION)
)
print(f"Done. Written {n} hexes to {gold_base}")