# In [None]:
import geopandas as gpd
import pandas as pd
import numpy as np
import os
from joblib import Parallel, delayed
from tqdm import tqdm

# GDAL data directory override (optional; only needed if GDAL cannot find it on its own)
os.environ['GDAL_DATA'] = r"# path to GDAL share directory, if required"

# ---- Locations ----
input_base = r"# path to daily weather input files"
output_base = r"# path to save tract-interpolated outputs"
tract_shapefile = r"# path to 1990 Census Tract shapefile"

# ---- Tract settings ----
tract_id_col = "GISJOIN2"  # tract identifier column kept from the shapefile
decade = ""  # Choose the right decade vintage; it becomes part of the overlay cache file name

# ---- Run settings ----
years_to_process = range(1980, 2020)  # 1980 through 2019 (end is exclusive)
n_jobs = 12  # worker count handed to joblib.Parallel
batch_size = 25  # number of daily files submitted per Parallel call

# ---- Weather variables read from each daily file and averaged per tract ----
group_cols = [
    # air temperature
    'tmin',
    'tmax',
    'tmean',
    # heat index
    'heat_index_min_c',
    'heat_index_max_c',
    'heat_index_mean_c',
    # wet-bulb globe temperature (Liljegren)
    'wbgt_liljegren_min',
    'wbgt_liljegren_max',
    'wbgt_liljegren_mean',
]

# Locate one daily file to build the static grid from.
# The first year folder (in years_to_process order) that contains a usable
# file wins. Within that folder the alphabetically first file is taken, so the
# choice is reproducible regardless of the order os.listdir happens to return.
sample_file = None
for year in years_to_process:
    folder = os.path.join(input_base, f"year={year}")
    if not os.path.isdir(folder):
        # Missing year, or a stray non-directory entry with that name.
        continue
    # Same filter as the main file collection: skip hidden files (e.g.
    # "._x.parquet" sidecars) that are not real data files.
    candidates = sorted(
        name for name in os.listdir(folder)
        if name.endswith(".parquet") and not name.startswith(".")
    )
    if candidates:
        sample_file = os.path.join(folder, candidates[0])
        break

if sample_file is None:
    raise FileNotFoundError("No .parquet file found to initialize grid.")

# Read the grid coordinates from the sample file, stored as float32.
df_base = pd.read_parquet(sample_file, columns=['lat', 'lon'])
df_base = df_base.astype({'lat': np.float32, 'lon': np.float32})

# Text key "<lat>_<lon>" identifying each grid point.
df_base["combo_latlon_id"] = df_base["lat"].astype(str).str.cat(df_base["lon"].astype(str), sep="_")

# Points in geographic coordinates, reprojected to EPSG:5070 so the buffer
# distance below is expressed in that CRS's units.
gdf_grid = gpd.GeoDataFrame(
    df_base,
    geometry=gpd.points_from_xy(x=df_base["lon"], y=df_base["lat"]),
    crs="EPSG:4326",
)
gdf_grid = gdf_grid.to_crs("EPSG:5070")

# Grow each point by 500 units with square caps (cap_style=3).
# NOTE(review): presumably this yields ~1 km square cells matching the grid
# resolution -- confirm against the input data.
gdf_grid["geometry"] = gdf_grid.geometry.buffer(500, cap_style=3)

# Hand the grid back in geographic coordinates.
gdf_grid = gdf_grid.to_crs("EPSG:4326")

del df_base

# Load the tract polygons, reproject them to EPSG:5070, and keep only the
# identifier column plus the geometry.
gdf_tracts_proj = (
    gpd.read_file(tract_shapefile)
    .to_crs("EPSG:5070")
    [[tract_id_col, 'geometry']]
)

# Overlay cache: intersection area between every grid cell and every tract.
# The result is written to disk once and reloaded on later runs; the cache
# file name carries the `decade` label.
overlay_cache_path = os.path.join(output_base, f"grid_to_tract_overlay_{decade}.parquet")
if os.path.exists(overlay_cache_path):
    overlay = pd.read_parquet(overlay_cache_path)
else:
    gdf_grid_proj = gdf_grid.to_crs("EPSG:5070")
    overlay = gpd.overlay(gdf_grid_proj, gdf_tracts_proj, how="intersection")
    # Area of each cell/tract intersection, measured in the projected CRS.
    overlay["area_overlap_m2"] = overlay.geometry.area.astype(np.float32)
    # Keep only the lookup columns as a plain DataFrame, so both branches of
    # this block leave `overlay` with the same type and columns.
    overlay = pd.DataFrame(overlay[[tract_id_col, "combo_latlon_id", "area_overlap_m2"]])
    # The output folder may not exist yet on a first run.
    os.makedirs(output_base, exist_ok=True)
    overlay.to_parquet(overlay_cache_path, index=False, compression="zstd")
    del gdf_grid_proj

# File processing function
def summarize_by_tract(merged, id_col, value_cols, weight_col="area_overlap_m2"):
    """Return per-tract weighted means plus coverage statistics.

    Args:
        merged: DataFrame with one row per (grid cell, tract) pair. It must
            contain ``id_col``, ``weight_col`` and every column in
            ``value_cols``.
        id_col: Name of the column to group by.
        value_cols: Columns to average, weighted by ``weight_col``.
        weight_col: Column holding the weights.

    Returns:
        DataFrame with one row per distinct ``id_col`` value (sorted),
        holding ``id_col``, the weighted mean of each value column,
        ``area_covered_m2`` (sum of weights) and ``n_cells`` (row count).
        All numeric columns are float64.
    """
    keys = merged[id_col]
    # Work in float64 so float32 inputs do not lose precision in the sums.
    weights = merged[weight_col].astype(np.float64)
    values = merged[value_cols].astype(np.float64)

    weights_by_tract = weights.groupby(keys)
    weight_totals = weights_by_tract.sum()
    weighted_sums = values.mul(weights, axis=0).groupby(keys).sum()

    # groupby.sum skips NaN, whereas numpy.average (which this replaces)
    # yields NaN for any group containing one; mask to keep that behaviour.
    has_missing = values.isna().groupby(keys).any()

    summary = weighted_sums.div(weight_totals, axis=0).mask(has_missing)
    summary['area_covered_m2'] = weight_totals
    # Stored as float to keep the same column type as earlier outputs.
    summary['n_cells'] = weights_by_tract.size().astype(np.float64)
    return summary.reset_index()


def process_file(file_path):
    """Aggregate one daily weather file to tracts and write the result.

    Reads the grid-point values in ``file_path``, joins them to the
    module-level ``overlay`` table through ``combo_latlon_id``, computes the
    overlap-area-weighted mean of every column in ``group_cols`` for each
    tract, attaches the tract geometry, and writes
    ``<output_base>/year=<YYYY>/<YYYY-MM-DD>_weighted.parquet``.

    Args:
        file_path: Path to one daily parquet file containing ``lat``, ``lon``,
            ``day`` and the columns listed in ``group_cols``.

    Returns:
        A status string: ``"<date> processed"`` on success,
        ``"No match on <date>"`` when no grid point joins to the overlay, or
        ``"Failed: <path> → <error>"`` when any exception is raised. Errors
        are reported this way rather than raised so one bad file does not
        abort a parallel batch.
    """
    try:
        # Single read: the date comes from the same file as the values.
        df_weather = pd.read_parquet(file_path, columns=['lat', 'lon', 'day'] + group_cols)
        day_val = pd.to_datetime(df_weather['day'].iloc[0])
        day_label = day_val.strftime('%Y-%m-%d')

        # Build the join key exactly as the grid did: cast to float32 before
        # formatting. Without the cast, float64 coordinates format to
        # different text and the merge below finds nothing.
        coords = df_weather[['lat', 'lon']].astype(np.float32)
        df_weather["combo_latlon_id"] = coords['lat'].astype(str) + "_" + coords['lon'].astype(str)

        merged = overlay.merge(
            df_weather[["combo_latlon_id"] + group_cols],
            on="combo_latlon_id",
            how="inner",
        )
        if merged.empty:
            return f"No match on {day_label}"

        summary = summarize_by_tract(merged, tract_id_col, group_cols)

        # Attach tract polygons and convert the result to geographic coordinates.
        summary = summary.merge(gdf_tracts_proj, on=tract_id_col, how='left')
        summary = gpd.GeoDataFrame(summary, crs="EPSG:5070").to_crs("EPSG:4326")
        summary["day"] = day_val
        # Zero-width buffer, presumably to repair invalid polygons -- TODO confirm.
        summary["geometry"] = summary["geometry"].buffer(0)

        year_str = day_val.strftime('%Y')
        out_folder = os.path.join(output_base, f"year={year_str}")
        os.makedirs(out_folder, exist_ok=True)
        out_path = os.path.join(out_folder, f"{day_label}_weighted.parquet")
        summary.to_parquet(out_path, index=False, compression="zstd")

        return f"{day_label} processed"

    except Exception as e:
        return f"Failed: {file_path} → {e}"

# Collect every daily input file, year by year, in a reproducible order.
all_files = []
for year in years_to_process:
    folder = os.path.join(input_base, f"year={year}")
    if not os.path.isdir(folder):
        continue
    all_files.extend(
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.endswith(".parquet") and not name.startswith(".")
    )

# Split into fixed-size batches; each batch is one Parallel call.
batches = [all_files[i:i + batch_size] for i in range(0, len(all_files), batch_size)]

for i, batch in enumerate(batches):
    print(f"\nBatch {i+1}/{len(batches)} — {len(batch)} files")
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_file)(fp) for fp in tqdm(batch, desc=f"Batch {i+1}")
    )
    # process_file reports success as "<date> processed". Match that exact
    # shape instead of searching for the substring "processed", which would
    # also match a failure message whose path or error text contains it.
    problems = [
        r for r in results
        if r.startswith("Failed:") or not r.endswith(" processed")
    ]
    completed = len(results) - len(problems)
    # Surface skipped and failed files so they can be investigated or re-run.
    for message in problems:
        print(f"  {message}")
    print(f"Batch {i+1} done. {completed}/{len(batch)} succeeded.")