In [None]:
# OSM Bronze → Silver (SageMaker Studio)
# Run this cell first to install dependencies, then restart kernel
%pip install -q osmium geopandas pyarrow s3fs h3 pyproj shapely

In [None]:
from __future__ import annotations
# OSM Bronze → Silver: Energy & Infrastructure (SageMaker Studio)
# Reads pre-split bboxes from S3 (bronze/osm/bboxes/bbox_00.osm.pbf .. bbox_19.osm.pbf), parallel processing

# ─── CONFIG ───
# Upper bound on worker processes given to ProcessPoolExecutor in the RUN section.
PARALLELISM = 16
DRY_RUN = True  # True = only process 1 bbox (the first entry of the bbox list)
# Written to the `country` column and used in the `country=` segment of the output path.
COUNTRY = "ES"
# Bucket that holds both the input bbox extracts and the silver output.
S3_BUCKET = "ie-datalake"
# Key prefix (inside S3_BUCKET) of the input files bbox_NN.osm.pbf.
BBOX_S3_PREFIX = "bronze/osm/bboxes"
# Key prefix (inside S3_BUCKET) under which the Parquet parts are written.
SILVER_PREFIX = "silver/osm/features_energy"
# Number of rows the handler buffers before handing a batch to the Parquet writer.
BATCH_SIZE = 20_000
# A progress line is logged each time this many features have been emitted for a bbox.
HEARTBEAT_EVERY = 100_000
# Compression codec passed to pyarrow.parquet.write_table.
PARQUET_COMPRESSION = "snappy"

import json, logging, os, sys, tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import date
from pathlib import Path
from typing import Any
import geopandas as gpd
import h3
import osmium as osm
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import s3fs
from shapely import wkb
from shapely.ops import transform
from pyproj import Transformer

# Shared transformer from EPSG:4326 to EPSG:25830, used by length_area_from_geom to
# measure geometries in the projected CRS. always_xy=True keeps coordinates in (x, y) order.
# NOTE(review): EPSG:25830 presumably matches COUNTRY = "ES" — confirm before processing other regions.
_GEOM_TRANSFORMER = Transformer.from_crs("EPSG:4326", "EPSG:25830", always_xy=True)

class FlushHandler(logging.StreamHandler):
    """Stream handler that flushes its stream after every emitted record."""

    def emit(self, record):
        """Write *record* to the stream, then flush the stream."""
        super().emit(record)
        self.flush()


# Install FlushHandler *through* basicConfig so it receives the configured format and
# datefmt. Attaching a bare handler to the root logger afterwards would leave it without
# a formatter, and records would be printed without timestamp or level.
# force=True removes any handlers that were already attached to the root logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    handlers=[FlushHandler(sys.stdout)],
    force=True,
)
log = logging.getLogger("osm_silver")
# Snapshot partition: used in the output path and written to the snapshot_date column.
SNAPSHOT_DATE = "2026-02-25"

# Feature-type labels produced by match_feature_type; used to initialise the per-type
# counters (handler counts, per-bbox part counters, run totals).
FEATURE_TYPES = ["PIPELINES", "POWER_LINES", "POWER_SUBSTATIONS", "POWER_PLANTS", "INDUSTRIAL_AREAS", "STORAGE_TANKS", "FUEL_STATIONS", "PORTS_TERMINALS", "AIRPORTS", "ROADS", "RAIL", "ADMIN_BOUNDARIES", "PROTECTED_AREAS", "WATERWAYS", "WATERBODIES", "WETLANDS", "COASTLINE", "WATER_BARRIERS", "WATER_INFRA_POI", "RESTRICTED_AREAS", "RESIDENTIAL_AREAS", "COMMERCIAL_AREAS", "PARKING_AREAS", "CEMETERIES", "CONSTRUCTION", "RETENTION_BASIN", "BUILDINGS", "AMENITIES_POI", "WASTE_POLLUTION", "LANDUSE_AGRICULTURE", "FORESTRY_MANAGED", "NATURAL_HABITATS", "BARRIERS", "LINEAR_DISTURBANCE", "PARKS_GREEN_URBAN", "TREE_ROWS_HEDGEROWS", "TRAILS_TRACKS"]

# Width (metres) assumed per `highway` value when a ROADS line feature has no area:
# the handler multiplies length_m by this width and falls back to 5 for unlisted values.
ROAD_WIDTH_M = {"motorway": 12, "trunk": 10, "primary": 9, "secondary": 8, "tertiary": 7, "residential": 7, "service": 6, "unclassified": 6, "living_street": 6}

# Promoted attribute columns. Every output row carries all of these keys, stringified,
# with None where the matched rule did not supply a value.
PROMOTED_KEYS = ("pipeline", "substance", "location", "operator", "ref", "power", "voltage", "cables", "substation", "plant_source", "generator_source", "capacity", "landuse", "industrial", "man_made", "content", "amenity", "brand", "opening_hours", "harbour", "aeroway", "iata", "icao", "highway", "surface", "bridge", "tunnel", "oneway", "maxspeed", "access", "railway", "electrified", "usage", "admin_level", "boundary", "protect_class", "designation", "waterway", "water", "wetland", "intermittent", "width", "reservoir", "military", "building", "building_levels", "crop", "forest", "natural", "barrier", "sac_scale", "leisure")
# Raw OSM tag keys that prune_tags keeps for the tags_json column.
TAGS_JSON_KEYS = frozenset(["name", "operator", "ref", "brand", "opening_hours", "pipeline", "substance", "location", "power", "voltage", "cables", "substation", "plant:source", "generator:source", "plant:output:electricity", "capacity", "landuse", "industrial", "man_made", "content", "amenity", "harbour", "aeroway", "iata", "icao", "highway", "surface", "bridge", "tunnel", "oneway", "maxspeed", "access", "railway", "electrified", "usage", "boundary", "admin_level", "protect_class", "designation", "leisure", "waterway", "water", "wetland", "intermittent", "width", "reservoir", "military", "building", "building:levels", "crop", "forest", "natural", "barrier", "sac_scale", "emergency"])

def _tags_dict(obj): return {t.k: t.v for t in obj.tags}
def prune_tags(tags):
    """Return the subset of *tags* that is stored in the ``tags_json`` column.

    A key is kept when it is listed in TAGS_JSON_KEYS and does not start with ``name:``.
    """
    kept = {}
    for key, value in tags.items():
        if key not in TAGS_JSON_KEYS or key.startswith("name:"):
            continue
        kept[key] = value
    return kept

def _pick_tags(tags, *keys):
    """Return ``{key: tags.get(key)}`` for *keys*, preserving their order."""
    return {key: tags.get(key) for key in keys}


def match_feature_type(tags):
    """Classify an OSM tag dict into one feature type.

    Rules are tested top to bottom and the first match wins.

    The loose pipeline heuristic (a bare ``substance`` tag, or
    ``location=underground/overground``) is evaluated LAST. Tested first, it
    captured every underground power cable, every storage tank carrying a
    ``substance`` tag, tunnels, etc. as PIPELINES before their own rule could run.
    Explicit pipeline tags (``man_made=pipeline`` / ``pipeline=*``) still win first.

    Args:
        tags: Mapping of OSM tag key to value.

    Returns:
        ``(feature_type, promoted)`` where ``promoted`` maps promoted column names
        to tag values (``None`` when the tag is absent), or ``(None, {})`` when no
        rule matches.
    """
    if not tags:
        # Untagged objects (the vast majority of nodes) cannot match any rule.
        return None, {}

    get = tags.get
    man_made = get("man_made")
    power = get("power")
    landuse = get("landuse")
    industrial = get("industrial")
    amenity = get("amenity")
    aeroway = get("aeroway")
    highway = get("highway")
    railway = get("railway")
    boundary = get("boundary")
    leisure = get("leisure")
    waterway = get("waterway")
    natural = get("natural")
    barrier = get("barrier")

    # ── Energy & industry ──
    if man_made == "pipeline" or get("pipeline"):
        return "PIPELINES", _pick_tags(tags, "pipeline", "substance", "location", "operator", "ref")
    if power in ("line", "minor_line", "cable"):
        return "POWER_LINES", _pick_tags(tags, "power", "voltage", "cables", "location", "operator")
    if power == "substation":
        return "POWER_SUBSTATIONS", _pick_tags(tags, "substation", "voltage", "operator")
    if power in ("plant", "generator"):
        return "POWER_PLANTS", {
            "plant_source": get("plant:source") or get("plant_source"),
            "generator_source": get("generator:source") or get("generator_source"),
            "operator": get("operator"),
            "capacity": get("plant:output:electricity") or get("capacity"),
            "power": power,
        }
    if landuse == "industrial" or industrial or man_made == "works":
        return "INDUSTRIAL_AREAS", _pick_tags(tags, "landuse", "industrial", "man_made", "operator")
    # NOTE(review): the `industrial in ("oil", "chemical")` test below can never be the
    # deciding condition, because any truthy `industrial` tag already matched
    # INDUSTRIAL_AREAS above. Left as-is; confirm the intended precedence.
    if man_made == "storage_tank" or landuse == "depot" or industrial in ("oil", "chemical"):
        return "STORAGE_TANKS", {
            "man_made": man_made,
            "content": get("content") or get("substance"),
            "operator": get("operator"),
        }
    if amenity == "fuel":
        return "FUEL_STATIONS", _pick_tags(tags, "amenity", "brand", "operator", "opening_hours")

    # ── Transport ──
    if get("harbour") or landuse == "port" or man_made == "pier" or amenity == "ferry_terminal":
        return "PORTS_TERMINALS", _pick_tags(tags, "harbour", "landuse", "man_made", "operator")
    if aeroway in ("aerodrome", "terminal", "runway", "taxiway"):
        return "AIRPORTS", _pick_tags(tags, "aeroway", "iata", "icao", "operator")
    if highway in ("path", "track", "footway", "bridleway", "cycleway"):
        return "TRAILS_TRACKS", _pick_tags(tags, "highway", "surface", "access", "sac_scale", "name")
    if highway:
        return "ROADS", _pick_tags(
            tags, "highway", "surface", "bridge", "tunnel", "oneway", "maxspeed", "access", "ref"
        )
    if railway in ("rail", "light_rail", "subway", "tram"):
        return "RAIL", _pick_tags(tags, "railway", "electrified", "operator", "usage")

    # ── Boundaries ──
    if boundary == "administrative":
        return "ADMIN_BOUNDARIES", _pick_tags(tags, "admin_level", "boundary", "name")
    if boundary == "protected_area" or leisure == "nature_reserve":
        return "PROTECTED_AREAS", _pick_tags(tags, "protect_class", "designation", "operator", "name")

    # ── Water ──
    if waterway in ("dam", "weir", "lock_gate") or man_made == "dam":
        return "WATER_BARRIERS", _pick_tags(tags, "man_made", "waterway", "operator", "name")
    if waterway:
        return "WATERWAYS", _pick_tags(tags, "waterway", "name", "intermittent", "tunnel", "width")
    if landuse == "reservoir" or man_made == "reservoir_covered":
        return "WATERBODIES", _pick_tags(tags, "water", "landuse", "man_made", "name", "reservoir")
    if natural == "water" or get("water"):
        return "WATERBODIES", _pick_tags(tags, "water", "name", "reservoir")
    if natural == "wetland" or get("wetland"):
        return "WETLANDS", _pick_tags(tags, "wetland", "name")
    if natural == "coastline":
        return "COASTLINE", _pick_tags(tags, "name")
    if man_made in ("water_tower",) or amenity == "drinking_water" or get("emergency") == "water_tank":
        return "WATER_INFRA_POI", _pick_tags(tags, "man_made", "amenity", "name", "operator")

    # ── Land use & built environment ──
    if get("military") or landuse == "military" or boundary == "military":
        return "RESTRICTED_AREAS", _pick_tags(tags, "military", "landuse", "name")
    if landuse == "residential":
        return "RESIDENTIAL_AREAS", _pick_tags(tags, "landuse", "name")
    if landuse in ("commercial", "retail"):
        return "COMMERCIAL_AREAS", _pick_tags(tags, "landuse", "name")
    if amenity == "parking" or landuse == "garages":
        return "PARKING_AREAS", _pick_tags(tags, "amenity", "landuse", "name")
    if landuse == "cemetery":
        return "CEMETERIES", _pick_tags(tags, "landuse", "name")
    if landuse == "construction":
        return "CONSTRUCTION", _pick_tags(tags, "landuse", "name")
    if landuse == "basin":
        return "RETENTION_BASIN", _pick_tags(tags, "landuse", "name")
    if get("building"):
        return "BUILDINGS", {
            "building": get("building"),
            "building_levels": get("building:levels"),
            "name": get("name"),
        }
    if amenity in ("school", "hospital", "university", "marketplace", "prison", "waste_disposal", "recycling"):
        return "AMENITIES_POI", _pick_tags(tags, "amenity", "name", "operator")
    if landuse == "landfill" or man_made == "wastewater_plant":
        return "WASTE_POLLUTION", _pick_tags(tags, "landuse", "man_made", "name", "operator")

    # ── Vegetation & habitat ──
    if landuse in ("farmland", "farmyard", "orchard", "vineyard", "meadow"):
        return "LANDUSE_AGRICULTURE", _pick_tags(tags, "landuse", "crop", "name")
    if landuse == "forest" or get("forest"):
        return "FORESTRY_MANAGED", _pick_tags(tags, "landuse", "forest", "name")
    if natural in ("wood", "heath", "scrub", "grassland", "bare_rock", "sand", "beach", "cliff"):
        return "NATURAL_HABITATS", _pick_tags(tags, "natural", "name")
    if natural == "tree_row" or barrier == "hedge":
        return "TREE_ROWS_HEDGEROWS", _pick_tags(tags, "natural", "barrier")
    # "hedge" is listed here but is always taken by TREE_ROWS_HEDGEROWS above.
    if barrier in ("fence", "wall", "hedge", "gate", "bollard"):
        return "BARRIERS", _pick_tags(tags, "barrier", "access")
    if man_made in ("cutline", "embankment") or barrier == "ditch":
        return "LINEAR_DISTURBANCE", _pick_tags(tags, "man_made", "barrier")
    if leisure in ("park", "garden", "common", "golf_course", "pitch") or landuse == "recreation_ground":
        return "PARKS_GREEN_URBAN", _pick_tags(tags, "leisure", "landuse", "name")

    # ── Fallback: loose pipeline hints on otherwise unclassified objects ──
    if get("substance") or get("location") in ("underground", "overground"):
        return "PIPELINES", _pick_tags(tags, "pipeline", "substance", "location", "operator", "ref")
    return None, {}

def add_h3_from_point(lat, lon):
    """Return the H3 cell ids at resolutions 6 to 9 for the point ``(lat, lon)``.

    The resolution-9 cell is looked up first and the coarser cells are taken as its
    parents. If any h3 call raises, every value in the returned dict is None.
    """
    column_names = ("h3_6", "h3_7", "h3_8", "h3_9")
    try:
        finest = h3.latlng_to_cell(lat, lon, 9)
        parents = [h3.cell_to_parent(finest, resolution) for resolution in (6, 7, 8)]
    except Exception:
        return dict.fromkeys(column_names)
    return dict(zip(column_names, [*parents, finest]))

def centroid_from_geom(geom):
    """Return an ``(x, y)`` float pair locating *geom*, or ``(None, None)``.

    Polygon and MultiPolygon geometries use ``representative_point()``; every other
    geometry type uses ``centroid``. ``None``/empty input, or any error while
    computing the point, yields ``(None, None)``.
    """
    missing = (None, None)
    if geom is None or geom.is_empty:
        return missing
    try:
        if geom.geom_type in ("Polygon", "MultiPolygon"):
            anchor = geom.representative_point()
        else:
            anchor = geom.centroid
        coords = float(anchor.x), float(anchor.y)
    except Exception:
        return missing
    return coords

def length_area_from_geom(geom):
    """Return ``(length, area)`` of *geom* measured after reprojection.

    The geometry is reprojected with the module-level ``_GEOM_TRANSFORMER``. Length is
    reported only for LineString/MultiLineString and area only for
    Polygon/MultiPolygon; the other value is 0.0. ``None``/empty input, or any error
    during reprojection or measurement, yields ``(0.0, 0.0)``.
    """
    if geom is None or geom.is_empty:
        return 0.0, 0.0

    def _reproject(x, y):
        return _GEOM_TRANSFORMER.transform(x, y)

    try:
        projected = transform(_reproject, geom)
        kind = projected.geom_type
        length, area = 0.0, 0.0
        if kind in ("LineString", "MultiLineString"):
            length = projected.length
        elif kind in ("Polygon", "MultiPolygon"):
            area = projected.area
        measured = float(length), float(area)
    except Exception:
        return 0.0, 0.0
    return measured

def _table_no_dictionary(tbl):
    for i in range(tbl.num_columns):
        if pa.types.is_dictionary(tbl.schema.field(i).type):
            tbl = tbl.set_column(i, tbl.schema.field(i).name, tbl.column(i).dictionary_decode())
    return tbl

class EnergyFeaturesHandler(osm.SimpleHandler):
    """osmium handler that turns matching OSM objects into feature rows.

    Every node, way and area is classified with ``match_feature_type``. Matches are
    converted to a row dict (geometry, length/area, centroid, H3 cells, promoted tag
    columns) and buffered; once BATCH_SIZE rows are buffered they are passed to
    ``batch_callback``. Call ``_flush()`` after ``apply_file`` to deliver the remainder.

    Args:
        wkb_factory: osmium WKB geometry factory used to build geometries.
        country: Value stored in each row's ``country`` column.
        snapshot_date: Value stored in each row's ``snapshot_date`` column.
        batch_callback: Callable receiving a list of row dicts.
        bbox_idx: Optional bbox index, only used in heartbeat log lines.
    """

    def __init__(self, wkb_factory, country, snapshot_date, batch_callback, bbox_idx=None):
        super().__init__()
        self.wkb_factory = wkb_factory
        self.country = country
        self.snapshot_date = snapshot_date
        self.batch_callback = batch_callback
        self._batch = []
        self._counts = {ft: 0 for ft in FEATURE_TYPES}
        self._n_emitted = 0
        self._bbox_idx = bbox_idx

    def _emit(self, osm_id, osm_type, feature_type, geom_wkb, tags, promoted):
        """Build one row for a matched object and append it to the current batch.

        Objects whose geometry yields no centroid are dropped silently.
        """
        # The factory may hand back raw bytes or a hex string.
        if isinstance(geom_wkb, bytes):
            geom = wkb.loads(geom_wkb)
        elif geom_wkb:
            geom = wkb.loads(bytes.fromhex(geom_wkb))
        else:
            geom = None

        lon, lat = centroid_from_geom(geom)
        if lat is None or lon is None:
            return

        length_m, area_m2 = length_area_from_geom(geom)
        if feature_type == "ROADS" and area_m2 == 0 and length_m > 0:
            # Roads are lines: estimate their footprint as length x assumed width.
            highway = promoted.get("highway") or ""
            area_m2 = length_m * ROAD_WIDTH_M.get(highway, 5)

        h3_cols = add_h3_from_point(lat, lon)

        # Every row carries all promoted columns so batches share one column set.
        promoted_full = {}
        for key in PROMOTED_KEYS:
            value = promoted.get(key)
            promoted_full[key] = None if value is None else str(value)

        # json.dumps({}) is the truthy string "{}", so test the dict itself to get
        # None when no tag survives pruning.
        pruned = prune_tags(tags)
        tags_json = json.dumps(pruned) if pruned else None

        row = {
            "osm_id": osm_id,
            "osm_type": osm_type,
            "feature_type": feature_type,
            "geometry": geom,
            "length_m": length_m,
            "area_m2": area_m2,
            "centroid_lon": lon,
            "centroid_lat": lat,
            **h3_cols,
            "name": tags.get("name"),
            "tags_json": tags_json,
            "snapshot_date": self.snapshot_date,
            "country": self.country,
            **promoted_full,
        }
        self._batch.append(row)
        self._counts[feature_type] = self._counts.get(feature_type, 0) + 1
        self._n_emitted += 1
        if self._n_emitted % HEARTBEAT_EVERY == 0:
            log.info("Bbox %s: heartbeat %d features", self._bbox_idx if self._bbox_idx is not None else "?", self._n_emitted)
        if len(self._batch) >= BATCH_SIZE:
            self._flush()

    def _flush(self):
        """Hand the buffered rows to ``batch_callback`` and start a new buffer."""
        if self._batch:
            self.batch_callback(self._batch)
            self._batch = []

    def node(self, n):
        """Emit a point feature for a matching node with a valid location."""
        if not n.location.valid():
            return
        tags = _tags_dict(n)
        ft, promoted = match_feature_type(tags)
        if ft is None:
            return
        try:
            wkb_bytes = self.wkb_factory.create_point(n)
        except Exception:
            # Best effort: objects whose geometry cannot be built are skipped.
            return
        self._emit(f"n{n.id}", "node", ft, wkb_bytes, tags, promoted)

    def way(self, w):
        """Emit a line feature for a matching way.

        Closed ways with at least 4 nodes are skipped here.
        NOTE(review): presumably those are expected to arrive through ``area()``
        instead — confirm this also holds for closed ways that are not areas.
        """
        if w.is_closed() and len(w.nodes) >= 4:
            return
        tags = _tags_dict(w)
        ft, promoted = match_feature_type(tags)
        if ft is None:
            return
        try:
            wkb_bytes = self.wkb_factory.create_linestring(w)
        except Exception:
            # Best effort: objects whose geometry cannot be built are skipped.
            return
        self._emit(f"w{w.id}", "way", ft, wkb_bytes, tags, promoted)

    def area(self, a):
        """Emit a multipolygon feature for a matching area.

        The id/type refer to the original object: ``w<id>``/"way" when the area was
        built from a way, otherwise ``r<id>``/"relation".
        """
        tags = _tags_dict(a)
        ft, promoted = match_feature_type(tags)
        if ft is None:
            return
        try:
            wkb_bytes = self.wkb_factory.create_multipolygon(a)
        except Exception:
            # Best effort: objects whose geometry cannot be built are skipped.
            return
        if a.from_way():
            osm_type, osm_id = "way", f"w{a.orig_id()}"
        else:
            osm_type, osm_id = "relation", f"r{a.orig_id()}"
        self._emit(osm_id, osm_type, ft, wkb_bytes, tags, promoted)

def process_bbox(bbox_idx, bbox_s3_path, tmpdir_str):
    """Download one bbox extract from S3, parse it and write silver Parquet parts.

    Runs inside a worker process and therefore creates its own ``S3FileSystem``.
    Each handler batch is split by ``feature_type`` and written to
    ``.../feature_type=<ft>/bbox_<NN>-part-<NNNNN>.parquet``. The input extract on S3
    is only read, never modified.

    Args:
        bbox_idx: Index of the bbox; used in file names and log lines.
        bbox_s3_path: S3 path (``bucket/key``) of the ``.osm.pbf`` extract.
        tmpdir_str: Local directory in which the extract is stored while parsing.

    Returns:
        Dict mapping feature type to the number of features emitted for this bbox.

    Raises:
        Any exception raised by the S3 download, the PBF parse or the Parquet write;
        the local copy of the extract is removed in every case.
    """
    from geopandas.io.arrow import _geopandas_to_arrow

    s3_fs = s3fs.S3FileSystem()
    local_path = Path(tmpdir_str) / f"bbox_{bbox_idx:02d}.osm.pbf"
    part_counter = {ft: 0 for ft in FEATURE_TYPES}
    s3_prefix = f"{S3_BUCKET}/{SILVER_PREFIX}/country={COUNTRY}/snapshot_date={SNAPSHOT_DATE}"
    string_cols = list(PROMOTED_KEYS) + ["osm_id", "osm_type", "feature_type", "name", "tags_json", "snapshot_date", "country", "h3_6", "h3_7", "h3_8", "h3_9"]
    write_count = 0

    def _coerce(df):
        """Cast the string columns to ``str`` while keeping missing values as None."""
        for col in string_cols:
            if col not in df.columns:
                continue
            series = df[col]
            null_mask = series.isna()
            df[col] = series.where(~null_mask).astype(str).mask(null_mask, None)
        return df

    def on_batch(batch):
        """Write one handler batch as one Parquet part per feature type."""
        nonlocal write_count
        gdf = gpd.GeoDataFrame(batch, crs="EPSG:4326", geometry="geometry")
        for ft, sub in gdf.groupby("feature_type"):
            part_counter[ft] += 1
            out_path = f"s3://{s3_prefix}/feature_type={ft}/bbox_{bbox_idx:02d}-part-{part_counter[ft]:05d}.parquet"
            sub = _coerce(sub.copy())
            tbl = _geopandas_to_arrow(sub, index=False, geometry_encoding="WKB")
            tbl = _table_no_dictionary(tbl)
            pq.write_table(tbl, out_path, compression=PARQUET_COMPRESSION, filesystem=s3_fs, use_dictionary=False, row_group_size=len(tbl))
        write_count += 1
        # Log the first few batches, then only every tenth, to keep the output readable.
        if write_count <= 3 or write_count % 10 == 0:
            log.info("Bbox %d: wrote batch %d (%d rows)", bbox_idx, write_count, len(batch))

    try:
        log.info("Bbox %d: downloading from S3...", bbox_idx)
        s3_fs.get(bbox_s3_path, str(local_path))
        size_mb = local_path.stat().st_size / (1024 * 1024)
        log.info("Bbox %d: downloaded (%.1f MB), parsing PBF...", bbox_idx, size_mb)

        handler = EnergyFeaturesHandler(osm.geom.WKBFactory(), COUNTRY, SNAPSHOT_DATE, on_batch, bbox_idx=bbox_idx)
        handler.apply_file(str(local_path), locations=True)
        handler._flush()  # deliver the final, partially filled batch
        total_feats = sum(handler._counts.values())
        log.info("Bbox %d: done (%d features)", bbox_idx, total_feats)
    finally:
        # Free the local disk even when download, parsing or writing failed.
        local_path.unlink(missing_ok=True)
    return handler._counts

# ─── RUN ───
# Fan the bbox extracts out over worker processes and add up the per-type counts.
log.info("Starting OSM Bronze -> Silver pipeline (ProcessPoolExecutor)")
bbox_list = [(i, f"{S3_BUCKET}/{BBOX_S3_PREFIX}/bbox_{i:02d}.osm.pbf") for i in range(20)]
to_process = bbox_list[:1] if DRY_RUN else bbox_list
log.info("Processing %d bbox(es) with PARALLELISM=%d", len(to_process), PARALLELISM)
if DRY_RUN:
    log.info("DRY_RUN: processing only 1 bbox")
total_counts = {ft: 0 for ft in FEATURE_TYPES}
# Never ask for more workers than there are bboxes (a DRY_RUN submits a single task).
n_workers = max(1, min(PARALLELISM, len(to_process)))

with tempfile.TemporaryDirectory() as tmp2:
    tmpdir_str = str(tmp2)
    with ProcessPoolExecutor(max_workers=n_workers) as ex:
        futures = {ex.submit(process_bbox, idx, s3_path, tmpdir_str): (idx, s3_path) for idx, s3_path in to_process}
        for fut in as_completed(futures):
            idx, s3_path = futures[fut]
            try:
                counts = fut.result()
            except Exception as e:
                # Fail the whole run on the first bbox error; the log line says which bbox.
                log.error("Bbox %d failed: %s", idx, e)
                raise
            for ft, n in counts.items():
                total_counts[ft] += n
            log.info("Bbox %d done", idx)

log.info("Done. Total features: %d", sum(total_counts.values()))
# Per-type summary, largest first; types without features are omitted.
for ft, n in sorted(total_counts.items(), key=lambda x: -x[1]):
    if n > 0:
        log.info("  %s: %d", ft, n)