# In [4]:
# --- put this right after: bundle = safe_load_bundle(PKL_PATH) ---
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
import re

# Pull the preprocessing object out of the loaded bundle; this key is
# required, so a missing "preprocess" entry fails loudly right here.
preprocess = bundle["preprocess"]
# Optional list of datetime column names; None when the bundle has no such key.
datetime_cols_cfg = bundle.get("datetime_cols")

# 1) recursively patch DateTimeExpand instances inside preprocess
def _walk_and_patch(est):
    """Recursively fill in missing attributes on ``DateTimeExpand`` objects.

    Walks *est* and, for every object whose class is named
    ``DateTimeExpand``, sets defaults **in place** for attributes that are
    absent or empty:

    * ``features`` -> ``["year", "month", "day", "dow", "hour"]`` when the
      attribute is missing or ``None``;
    * ``cols`` -> a copy of the module-level ``datetime_cols_cfg`` when the
      attribute is missing, ``None`` or empty and ``datetime_cols_cfg`` is
      truthy;
    * ``drop_original`` -> ``False`` when the attribute is missing.

    Containers that are descended into: ``Pipeline`` (via ``steps``),
    ``ColumnTransformer`` (via the fitted ``transformers_``, falling back to
    ``transformers`` when it has not been fitted) and ``FeatureUnion`` (via
    ``transformer_list``). Any other object, including placeholder strings
    such as ``"drop"`` / ``"passthrough"``, is ignored.

    Args:
        est: An estimator, transformer, container, or placeholder entry.

    Returns:
        None. Matching objects are mutated in place.
    """
    # Local import: FeatureUnion lives in the same module as the
    # already-imported Pipeline, so this adds no new dependency.
    from sklearn.pipeline import FeatureUnion

    # Matched by class *name* rather than isinstance, so this works no matter
    # which module object the unpickled class was resolved from.
    if est.__class__.__name__ == "DateTimeExpand":
        if getattr(est, "features", None) is None:
            est.features = ["year", "month", "day", "dow", "hour"]

        # Only fall back to the bundle's column list when the pickled object
        # carries none of its own.
        current_cols = getattr(est, "cols", None)
        if (current_cols is None or len(current_cols) == 0) and datetime_cols_cfg:
            est.cols = list(datetime_cols_cfg)

        # NOTE(review): the original comment here said this avoids keeping
        # the original columns, yet the default written is False. The
        # DateTimeExpand class is not visible from here -- confirm that
        # False is the intended default.
        if not hasattr(est, "drop_original"):
            est.drop_original = False
        return

    if isinstance(est, Pipeline):
        for _name, step in est.steps:
            _walk_and_patch(step)
    elif isinstance(est, ColumnTransformer):
        # `transformers_` only exists after fit(); an unfitted transformer
        # exposes its (name, transformer, columns) triples as `transformers`.
        fitted_entries = getattr(est, "transformers_", None)
        entries = fitted_entries if fitted_entries is not None else est.transformers
        for _name, trans, _cols in entries:
            _walk_and_patch(trans)
    elif isinstance(est, FeatureUnion):
        # FeatureUnion holds (name, transformer) pairs.
        for _name, trans in est.transformer_list:
            _walk_and_patch(trans)

# Patch the loaded preprocessing object in place (the function returns None).
_walk_and_patch(preprocess)


# 2) A sanitizer to fix any double-underscore datetime columns in the *input* df
# Datetime-part suffixes recognised at the end of a column name.
_dt_part = r"(year|month|day|dow|hour)"
# Group 1 is the base name, group 2 the datetime part. The base is greedy, so
# for a name with several "__" only the last one before the part is consumed.
_double_dt_re = re.compile(rf"^(.*)__{_dt_part}$")


def sanitize_datetime_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Normalise ``<base>__<part>`` column names to ``<base>_<part>``.

    ``<part>`` is one of ``year``, ``month``, ``day``, ``dow`` or ``hour``.
    For each matching column:

    * if ``<base>_<part>`` already exists, the double-underscore column is
      dropped and the existing single-underscore column is kept;
    * otherwise the column is renamed to ``<base>_<part>``.

    Column labels that are not strings (for example the integer labels
    pandas assigns when a frame is built from a list or array, or tuple
    labels) are left untouched instead of being passed to the regex, which
    would raise ``TypeError``.

    Args:
        df: Input frame. Anything that is not a ``DataFrame`` is first
            passed to ``pd.DataFrame(...)``.

    Returns:
        A frame with sanitised column names. When nothing needs changing the
        (possibly converted) input object itself is returned; otherwise a
        new frame is returned and the input is not modified.
    """
    if not isinstance(df, pd.DataFrame):
        df = pd.DataFrame(df)

    rename_map = {}
    drop_cols = []
    for c in df.columns:
        # Only string labels can carry a "__<part>" suffix.
        if not isinstance(c, str):
            continue
        m = _double_dt_re.match(c)
        if m is None:
            continue
        fixed = f"{m.group(1)}_{m.group(2)}"
        # Prefer the already-present single-underscore column over the
        # double-underscore duplicate.
        if fixed in df.columns:
            drop_cols.append(c)
        else:
            rename_map[c] = fixed

    if rename_map:
        df = df.rename(columns=rename_map)
    if drop_cols:
        df = df.drop(columns=drop_cols)
    return df
