# CV Split Search (v1)


In [1]:
# -------------------------
# 0) Config
# -------------------------
import os

# Where the project code and the competition data live.
CSIRO_CODE_DIR = "/notebooks/CSIRO"
DATA_ROOT = "/notebooks/kaggle/csiro"
TRAIN_CSV = f"{DATA_ROOT}/train.csv"

# Search budget: forwarded to search_cv_splits as n_trials / top_k / seed_start / n_splits.
N_TRIALS = 100_000
TOP_K = 10
SEED_START = 0
N_SPLITS = 5

# Constraints / scoring
# NOTE(review): in the cells visible in this notebook, MIN_TARGET_VAR,
# MIN_SEASONS_PER_FOLD, N_BINS and MIN_BIN_N are not read by the multi-config
# search, which takes those values from each SEARCH_CONFIGS preset instead.
MIN_FOLD_N = None          # None = auto
MIN_TARGET_VAR = 1e-3
MIN_STATES_PER_FOLD = None
MIN_SEASONS_PER_FOLD = 2
N_BINS = 4
MIN_BIN_N = 5

# Default grouping; a preset may override it with its own "group_mode" key.
GROUP_MODE = "state_quarter"  # "state_quarter" or "date"
DATE_COL = "Sampling_Date"
STATE_COL = "State"

# Destination of the top-K table; only written when SAVE_OUT is True.
OUT_PATH = "/notebooks/cv/cv_split_search_v1.csv"
SAVE_OUT = False


# Guard rails
def require_set(**settings):
    """Fail fast when a required setting is missing.

    Args:
        **settings: setting name -> value pairs to validate.

    Raises:
        ValueError: naming the first setting whose value is None or a
            blank/whitespace-only string.
    """
    for name, val in settings.items():
        if val is None or not str(val).strip():
            raise ValueError(f"{name} is None or empty; set it before running.")


# DATA_ROOT is validated directly: TRAIN_CSV is an f-string built from it and
# therefore can never be None itself (a missing root would yield "None/train.csv").
require_set(
    CSIRO_CODE_DIR=CSIRO_CODE_DIR,
    DATA_ROOT=DATA_ROOT,
    TRAIN_CSV=TRAIN_CSV,
)


## Search configs (preset sweeps)


In [2]:
# -------------------------
# 0.1) Search configs (presets)
# -------------------------
def _preset(name, min_target_var, min_seasons_per_fold, n_bins, min_bin_n, **overrides):
    """Build one preset dict; `overrides` carries optional keys such as group_mode."""
    preset = {
        "name": name,
        "min_target_var": min_target_var,
        "min_seasons_per_fold": min_seasons_per_fold,
        "n_bins": n_bins,
        "min_bin_n": min_bin_n,
    }
    preset.update(overrides)
    return preset


# Presets are listed from tightest to loosest thresholds.
# Columns: name, min_target_var, min_seasons_per_fold, n_bins, min_bin_n
SEARCH_CONFIGS = [
    _preset("strict", 1e-3, 2, 4, 5),
    _preset("relaxed_bins", 1e-3, 2, 3, 3),
    _preset("relaxed_var", 1e-5, 2, 3, 3),
    _preset("no_seasons", 1e-5, None, 3, 2),
    _preset("loose", 1e-6, None, 3, 2),
    _preset("date_group", 1e-6, None, 3, 2, group_mode="date"),
    # Less strict
    _preset("very_loose", 1e-8, None, 2, 1),
    _preset("very_loose_date", 1e-8, None, 2, 1, group_mode="date"),
]


In [3]:
# -------------------------
# 1) Imports
# -------------------------
import sys
import pandas as pd

# Put CSIRO_CODE_DIR at the front of the module search path before the
# `csiro` imports below run (presumably that is where the package lives —
# confirm against the checkout layout).
sys.path.insert(0, CSIRO_CODE_DIR)

from csiro.data import load_train_wide
from csiro.utils_v2 import search_cv_splits


In [4]:
# -------------------------
# 2) Load data
# -------------------------
# Load the training table via the project loader; every later cell reads `wide_df`.
wide_df = load_train_wide(TRAIN_CSV, root=DATA_ROOT)

# Report the row count as a quick sanity check on the load.
n_rows = len(wide_df)
print("rows", n_rows)


rows 357


In [5]:
# -------------------------
# 3) Run search (multi-config)
# -------------------------
# Arguments that are the same for every preset.
shared_search_kwargs = dict(
    n_splits=N_SPLITS,
    n_trials=N_TRIALS,
    seed_start=SEED_START,
    top_k=TOP_K,
    date_col=DATE_COL,
    state_col=STATE_COL,
    min_fold_n=MIN_FOLD_N,
    min_states_per_fold=MIN_STATES_PER_FOLD,
)


def _score_of(candidate):
    """Sort key: the candidate's score as a float, -1 when the key is absent."""
    return float(candidate.get("score", -1))


all_results = []
for preset in SEARCH_CONFIGS:
    preset_name = preset.get("name", "cfg")
    found = search_cv_splits(
        wide_df,
        # A preset may override the notebook-wide grouping mode.
        group_mode=preset.get("group_mode", GROUP_MODE),
        min_target_var=preset["min_target_var"],
        min_seasons_per_fold=preset["min_seasons_per_fold"],
        n_bins=preset["n_bins"],
        min_bin_n=preset["min_bin_n"],
        **shared_search_kwargs,
    )
    # Tag each candidate with the preset that produced it so the merged
    # ranking below stays traceable.
    for candidate in found:
        candidate["config_name"] = preset_name
    all_results.extend(found)
    print(preset_name, "found", len(found))

# Rank candidates from all presets together, highest score first, and keep TOP_K.
all_results.sort(key=_score_of, reverse=True)
results = all_results[:TOP_K]
print("total candidates", len(all_results), "top", len(results))
results[:3]


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

strict found 0


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

relaxed_bins found 0


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

relaxed_var found 0


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

no_seasons found 0


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

loose found 0


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

date_group found 0


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

very_loose found 10


cv_search:   0%|          | 0/100000 [00:00<?, ?it/s]

very_loose_date found 10
total candidates 20 top 10


[{'score': 0.8990320702034426,
  'fold_sizes': [72, 70, 71, 70, 74],
  'min_target_var': 68.4383773803711,
  'size_balance': 0.9790383342612338,
  'mean_balance': 0.8800633996725082,
  'var_balance': 0.7304593324661255,
  'state_balance': 0.9669112230148316,
  'season_balance': 0.9292803411317411,
  'cv_params': {'mode': 'gkf', 'cv_seed': 34096, 'n_splits': 5},
  'group_mode': 'date',
  'config_name': 'very_loose_date'},
 {'score': 0.8833571760396148,
  'fold_sizes': [66, 69, 76, 75, 71],
  'min_target_var': 60.396522521972656,
  'size_balance': 0.947896148523196,
  'mean_balance': 0.9247022122144699,
  'var_balance': 0.6843403875827789,
  'state_balance': 0.9647576091615239,
  'season_balance': 0.8718668272342984,
  'cv_params': {'mode': 'gkf', 'cv_seed': 38568, 'n_splits': 5},
  'group_mode': 'date',
  'config_name': 'very_loose_date'},
 {'score': 0.8832377420028095,
  'fold_sizes': [76, 71, 68, 72, 70],
  'min_target_var': 43.90837860107422,
  'size_balance': 0.9628389384436519,
  '

## All results

In [17]:
# Display every retained candidate (the TOP_K slice built by the search cell),
# not only the first three shown there.
results

[{'score': 0.8990320702034426,
  'fold_sizes': [72, 70, 71, 70, 74],
  'min_target_var': 68.4383773803711,
  'size_balance': 0.9790383342612338,
  'mean_balance': 0.8800633996725082,
  'var_balance': 0.7304593324661255,
  'state_balance': 0.9669112230148316,
  'season_balance': 0.9292803411317411,
  'cv_params': {'mode': 'gkf', 'cv_seed': 34096, 'n_splits': 5},
  'group_mode': 'date',
  'config_name': 'very_loose_date'},
 {'score': 0.8833571760396148,
  'fold_sizes': [66, 69, 76, 75, 71],
  'min_target_var': 60.396522521972656,
  'size_balance': 0.947896148523196,
  'mean_balance': 0.9247022122144699,
  'var_balance': 0.6843403875827789,
  'state_balance': 0.9647576091615239,
  'season_balance': 0.8718668272342984,
  'cv_params': {'mode': 'gkf', 'cv_seed': 38568, 'n_splits': 5},
  'group_mode': 'date',
  'config_name': 'very_loose_date'},
 {'score': 0.8832377420028095,
  'fold_sizes': [76, 71, 68, 72, 70],
  'min_target_var': 43.90837860107422,
  'size_balance': 0.9628389384436519,
  '

## Split inspection (single cv_params)


In [18]:
# Hand-entered split parameters for the inspection cell below.
# NOTE(review): the inspected fold sizes match the third-ranked search
# candidate, so this seed was presumably copied from `results` — confirm.
CV_PARAMS = dict(mode="gkf", cv_seed=34024, n_splits=5)
from csiro.config import TARGETS
import numpy as np

In [19]:
# -------------------------
# 3.1) Inspect a split
# -------------------------
from csiro.utils_v2 import build_cv_splits

CV_TO_INSPECT = CV_PARAMS  # set to a candidate (e.g., from results)
splits = build_cv_splits(wide_df, cv_params=CV_TO_INSPECT)

# Only the second element of each split pair (the validation indices) is
# inspected in this cell.
val_index_sets = [va_idx for _, va_idx in splits]

fold_sizes = [len(va_idx) for va_idx in val_index_sets]
print("fold_sizes", fold_sizes, "mean", sum(fold_sizes)/len(fold_sizes))

# Per-fold target variance (numpy default, i.e. ddof=0), one column per target.
y = wide_df[TARGETS].to_numpy(dtype=float)
rows = []
for f, va_idx in enumerate(val_index_sets):
    y_f = y[va_idx]
    fold_row = {"fold": f}
    for i, t in enumerate(TARGETS):
        fold_row[t] = float(y_f[:, i].var())
    rows.append(fold_row)
display(pd.DataFrame(rows))

# Extra diagnostics. The original note says these mirror the search score
# components; the search scorer itself is not visible in this notebook.
df_local = wide_df.copy()
fold_frames = [df_local.iloc[va_idx] for va_idx in val_index_sets]
has_dates = DATE_COL in df_local.columns
has_states = STATE_COL in df_local.columns

val_sizes = [len(va_idx) for va_idx in val_index_sets]
n_dates = [int(va[DATE_COL].nunique()) if has_dates else 0 for va in fold_frames]
# Share of rows in each fold with a strictly positive clover mass.
clover_rates = [float((va["Dry_Clover_g"] > 0).mean()) for va in fold_frames]
# 99th percentile of log1p(total mass) per fold.
q99s = [
    float(np.quantile(np.log1p(va["Dry_Total_g"].astype(float).values), 0.99))
    for va in fold_frames
]
max_totals = [float(va["Dry_Total_g"].max()) for va in fold_frames]

# State diagnostics stay empty when the state column is absent.
state_top_fracs, n_states, wa_present = [], [], []
if has_states:
    for va in fold_frames:
        shares = va[STATE_COL].value_counts(normalize=True)
        # value_counts sorts descending, so the first entry is the dominant state.
        state_top_fracs.append(float(shares.iloc[0]) if len(shares) else 1.0)
        n_states.append(int(va[STATE_COL].nunique()))
        wa_present.append(int("WA" in set(va[STATE_COL].unique())))


def _spread(values):
    """Return the gap between the largest and smallest entry of `values`."""
    return max(values) - min(values)


size_range = _spread(val_sizes)
clover_range = _spread(clover_rates)
q99_range = _spread(q99s)
min_fold_max = min(max_totals)
max_fold_max = max(max_totals)
max_state_top = max(state_top_fracs) if state_top_fracs else None
min_states = min(n_states) if n_states else None
wa_folds = sum(wa_present) if wa_present else None
n_dates_range = _spread(n_dates) if n_dates else None

global_max_total = float(df_local["Dry_Total_g"].max())

# Weighted sum of the ranges above, plus a penalty once the dominant state
# exceeds 55% of some fold. Larger values mean less even folds on these measures.
if max_state_top is not None:
    state_penalty = 0.5 * max(0.0, max_state_top - 0.55)
else:
    state_penalty = 0.0
score = 1.0 * size_range + 2.0 * clover_range + 2.0 * q99_range + state_penalty

for label, value in [
    ("size_range", size_range),
    ("clover_range", clover_range),
    ("q99_range", q99_range),
    ("min_fold_max_total", min_fold_max),
    ("max_fold_max_total", max_fold_max),
    ("global_max_total", global_max_total),
    ("max_state_top_frac", max_state_top),
    ("min_n_states", min_states),
    ("wa_folds", wa_folds),
    ("n_dates_range", n_dates_range),
    ("score", score),
]:
    print(label, value)

# State / season coverage (if available)
if STATE_COL in wide_df.columns:
    state_counts = [
        {"fold": f, "n_states": wide_df.iloc[va_idx][STATE_COL].astype(str).nunique()}
        for f, va_idx in enumerate(val_index_sets)
    ]
    display(pd.DataFrame(state_counts))

if DATE_COL in wide_df.columns:
    d = pd.to_datetime(wide_df[DATE_COL], errors="coerce")

    def _season_of(m):
        """Map a month number to a season name (Dec-Feb is summer); None if missing."""
        if not pd.notna(m):
            return None
        if m in (12, 1, 2):
            return "summer"
        if m in (3, 4, 5):
            return "autumn"
        if m in (6, 7, 8):
            return "winter"
        return "spring"

    seasons = d.dt.month.apply(_season_of)
    season_counts = [
        {"fold": f, "n_seasons": seasons.iloc[va_idx].nunique()}
        for f, va_idx in enumerate(val_index_sets)
    ]
    display(pd.DataFrame(season_counts))


fold_sizes [76, 71, 68, 72, 70] mean 71.4


Unnamed: 0,fold,Dry_Green_g,Dry_Clover_g,Dry_Dead_g,GDM_g,Dry_Total_g
0,0,676.923679,197.857983,142.898591,674.781055,1056.514256
1,1,528.172652,197.906881,87.882596,503.802423,509.772906
2,2,953.211439,43.908378,146.158662,826.779692,840.492631
3,3,394.27339,92.584008,186.910451,366.297936,578.183704
4,4,542.118397,164.881203,165.373105,589.094782,816.031568


size_range 8
clover_range 0.12693498452012386
q99_range 0.49761117973102476
min_fold_max_total 109.6
max_fold_max_total 185.7
global_max_total 185.7
max_state_top_frac 0.5571428571428572
min_n_states 3
wa_folds 3
n_dates_range 1
score 9.252663757073725


Unnamed: 0,fold,n_states
0,0,4
1,1,3
2,2,3
3,3,4
4,4,4


Unnamed: 0,fold,n_seasons
0,0,2
1,1,3
2,2,4
3,3,3
4,4,4


In [6]:
# -------------------------
# 4) Save top-K
# -------------------------
# Writes the top-K candidates to OUT_PATH as CSV; does nothing unless SAVE_OUT is True.
if SAVE_OUT:
    import os
    import pandas as pd

    # os.path.dirname returns "" for a bare filename and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(OUT_PATH)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if results:
        # One row per candidate; nested values such as cv_params are written
        # using their string form.
        df_out = pd.DataFrame(results)
        df_out.to_csv(OUT_PATH, index=False)
        print("Wrote", OUT_PATH)
    else:
        print("No candidates found; relax constraints.")
