# In [None]:
import sys
sys.path.append(r"# path to WBGT module directory")

import os
import pandas as pd
import numpy as np
from WBGT import WBGT_Liljegren
from tqdm import tqdm
from joblib import Parallel, delayed

# ---- Run configuration -------------------------------------------------

# Root of the daily weather store; the processing loop expects one
# "year=YYYY" sub-directory per year underneath it.
weather_base = r"# path to daily weather files"

# Years to process: 1980 through 2019 inclusive (range end is exclusive).
years_to_process = range(1980, 2019 + 1)

# Worker count handed to joblib.Parallel.
n_jobs = 12

# Input columns that must all be present in a file, and all non-null in a
# row, before WBGT is computed for that row. Kept as a *list* on purpose:
# it is used as ``df[required_cols]``, where a tuple would be treated as a
# single column key.
# NOTE(review): czda_* is presumably cosine of the solar zenith angle — confirm.
required_cols = (
    ['tmin', 'tmax', 'tmean']
    + ['rh_min', 'rh_mean', 'rh_max']
    + [f'era5_wind_speed_{stat}' for stat in ('min', 'mean', 'max')]
    + [f'era5_sp_{stat}' for stat in ('min', 'mean', 'max')]
    + ['srad']
    + [f'czda_{stat}' for stat in ('min', 'mean', 'max')]
)

def process_wbgt_file(fpath, threshold_c=18.3):
    """Add Liljegren WBGT min/mean/max columns to one parquet file, in place.

    The file is read, and the three ``wbgt_liljegren_*`` columns are computed
    for rows whose inputs (``required_cols``) are all non-null and whose
    paired air temperature is >= ``threshold_c``. The result is then written
    back to ``fpath`` without the DataFrame index. Rows that do not qualify
    are NaN in the output columns.

    Each output pairs one temperature statistic with the other inputs:

    * ``wbgt_liljegren_min``  <- tmin  + rh_max,  era5_sp_max,  era5_wind_speed_max,  czda_min
    * ``wbgt_liljegren_mean`` <- tmean + rh_mean, era5_sp_mean, era5_wind_speed_mean, czda_mean
    * ``wbgt_liljegren_max``  <- tmax  + rh_min,  era5_sp_min,  era5_wind_speed_min,  czda_max

    Before calling ``WBGT_Liljegren``, the inputs are adjusted:

    * temperatures are shifted by +273.15;
    * wind speed is floored at 0.5;
    * czda values are clipped to [0.2, 0.9].

    The WBGT result is shifted back by -273.15.

    Parameters
    ----------
    fpath : str
        Path to the parquet file. On success it is replaced by the updated
        table via a temp file + ``os.replace``.
    threshold_c : float, default 18.3
        Minimum value of the paired temperature column for a row to be
        computed. It is in the same units as tmin/tmean/tmax, i.e. °C given
        the +273.15 conversion.

    Returns
    -------
    str
        A one-line status: "Processed ...", "Skipped ... — <reason>" or
        "Failed ...: <error>". Exceptions are reported here rather than
        raised, so one bad file does not abort a batch run.
    """
    name = os.path.basename(fpath)
    output_cols = ['wbgt_liljegren_min', 'wbgt_liljegren_max', 'wbgt_liljegren_mean']
    # (output column, temperature, humidity, pressure, wind speed, czda)
    variants = (
        ('wbgt_liljegren_min', 'tmin', 'rh_max', 'era5_sp_max',
         'era5_wind_speed_max', 'czda_min'),
        ('wbgt_liljegren_mean', 'tmean', 'rh_mean', 'era5_sp_mean',
         'era5_wind_speed_mean', 'czda_mean'),
        ('wbgt_liljegren_max', 'tmax', 'rh_min', 'era5_sp_min',
         'era5_wind_speed_min', 'czda_max'),
    )

    try:
        df = pd.read_parquet(fpath)

        if all(col in df.columns for col in output_cols):
            return f"Skipped {name} — already processed"
        if not all(col in df.columns for col in required_cols):
            return f"Skipped {name} — missing inputs"

        # Positional (numpy) boolean masks instead of index labels, so the
        # assignments below stay correct even if the stored index has
        # duplicate labels.
        valid = df[required_cols].notnull().all(axis=1).to_numpy()
        if not valid.any():
            return f"Skipped {name} — no valid rows"

        masks = {
            temp_col: valid & df[temp_col].ge(threshold_c).to_numpy(dtype=bool, na_value=False)
            for temp_col in ('tmin', 'tmean', 'tmax')
        }
        if not any(mask.any() for mask in masks.values()):
            return f"Skipped {name} — no rows above WBGT threshold"

        for out_col, temp_col, rh_col, sp_col, wind_col, czda_col in variants:
            mask = masks[temp_col]
            # Create every output column up front so all processed files share
            # one schema, and so they pass the "already processed" check on
            # re-runs even when no row meets the threshold for this variant.
            df[out_col] = np.nan
            if not mask.any():
                continue

            rows = df.loc[mask]
            czda = np.clip(rows[czda_col].to_numpy(), 0.2, 0.9)
            # NOTE(review): the same clipped czda array is passed for two
            # consecutive WBGT_Liljegren arguments, and the trailing False is
            # a flag whose meaning is defined by the WBGT module. Confirm both
            # against that module's signature.
            wbgt_k = WBGT_Liljegren(
                rows[temp_col].to_numpy() + 273.15,
                rows[rh_col].to_numpy(),
                rows[sp_col].to_numpy(),
                rows[wind_col].clip(lower=0.5).to_numpy(),
                rows['srad'].to_numpy(),
                czda,
                czda,
                False,
            )
            df.loc[mask, out_col] = wbgt_k - 273.15

        # The file is both input and output. Write to a sibling temp file and
        # swap it in with os.replace, so a crash or write error cannot leave
        # the original truncated. The ".tmp" suffix is not picked up by the
        # loop's ".parquet" filter.
        tmp_path = f"{fpath}.tmp"
        try:
            df.to_parquet(tmp_path, index=False)
            os.replace(tmp_path, fpath)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        return f"Processed {name}"

    except Exception as e:
        # Batch boundary: report instead of raising so one bad file does not
        # abort the whole Parallel run.
        return f"Failed {name}: {e}"

# Process every year partition ("year=YYYY" under weather_base), fanning the
# parquet files of one year out over n_jobs workers and printing one status
# line per file.
for year in tqdm(years_to_process, desc="Processing years"):
    year_dir = os.path.join(weather_base, f"year={year}")
    # A missing partition should not abort the remaining years
    # (os.listdir would raise FileNotFoundError here).
    if not os.path.isdir(year_dir):
        print(f"Year {year}: directory not found, skipping ({year_dir})")
        continue

    all_files = sorted(
        os.path.join(year_dir, f)
        for f in os.listdir(year_dir)
        if f.endswith(".parquet")
    )
    print(f"Year {year}: {len(all_files)} files")

    # NOTE(review): the threading backend only runs files in parallel to the
    # extent that parquet I/O and WBGT_Liljegren release the GIL. If this is
    # CPU-bound pure Python, a process backend (e.g. "loky") may be faster —
    # measure. The inner tqdm tracks task dispatch, not completion.
    results = Parallel(n_jobs=n_jobs, backend="threading")(
        delayed(process_wbgt_file)(fpath)
        for fpath in tqdm(all_files, desc=f"{year}", leave=False)
    )

    for r in results:
        print(r)