
# 03 · Destination Reranker (Plus)
**Goal:** stronger Task‑A baseline by adding (1) larger candidate recall, (2) historical priors, and (3) structured geo/channel features.  
**Pipeline:** build candidates → enrich features → pointwise LR reranker → Top‑K metrics.


In [None]:

# === Path bootstrap (so `utils/` is importable if running inside notebooks/) ===
import sys, os
from pathlib import Path
# Start from the current working directory and decide whether it is the project root.
proj_root = Path.cwd()
# Step up one level when the cwd is named "notebooks" (case-insensitive) or has no `utils/`
# directory of its own, but only if the parent directory does contain `utils/`.
if (proj_root.name.lower() == "notebooks" or not (proj_root/"utils").exists()) and (proj_root.parent/"utils").exists():
    proj_root = proj_root.parent
# Register the root on sys.path once so the `utils.*` imports below resolve.
if str(proj_root) not in sys.path:
    sys.path.append(str(proj_root))
print("Project root:", proj_root)

# === Unified imports ===
import polars as pl
import numpy as np, pandas as pd, json, joblib, time

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.impute import SimpleImputer

from utils.config import DATA_DIR, INTERIM_DIR, PROCESSED_DIR
from utils.splits import temporal_split, add_crisis_flag
from utils.candidates import build_origin_next_transitions, global_mf_next, build_pc_coords, build_candidates_for_split
from utils.features import build_ports_attr, compute_port_degree, attach_port_side, build_sample_side, merge_all_features
from utils.metrics import eval_topk_mrr
from utils.etl_clean import ensure_interim

# Use all CPU cores for Polars
# NOTE(review): `polars` is already imported at this point. The Polars docs say
# POLARS_MAX_THREADS has to be set before the thread pool is created, so this
# assignment may have no effect on the running session — confirm.
os.environ["POLARS_MAX_THREADS"] = str(os.cpu_count())

# Display and string cache
# Show at most 5 rows / 10 columns per printed table, drawn with ASCII borders.
pl.Config.set_tbl_rows(5)
pl.Config.set_tbl_cols(10)
pl.Config.set_tbl_formatting("ASCII_FULL")
# Turn on the global string cache through whichever entry point this Polars
# build exposes; if neither attribute exists the cache is left untouched.
if hasattr(pl, "enable_string_cache"):
    pl.enable_string_cache()
elif hasattr(pl, "toggle_string_cache"):
    pl.toggle_string_cache(True)


### 1) Load samples & cleaned tables; temporal split

In [None]:

# Load / rebuild samples_taskA if missing
samples_path = PROCESSED_DIR / "samples_taskA.parquet"
if not samples_path.exists():
    print("samples_taskA.parquet not found -> rebuilding from cleaned port_calls...")
    pc_clean = ensure_interim()  # returns cleaned port_calls as Polars DataFrame
    # Order calls per vessel by start time so shift(-1) below looks at the following call.
    pc_clean = pc_clean.sort(["vessel_id","start_utc"])
    pc_clean = pc_clean.with_columns([
        # id / destination of the same vessel's next row (null on each vessel's last call)
        pl.col("id").shift(-1).over("vessel_id").alias("next_call_id"),
        pl.col("destination").shift(-1).over("vessel_id").alias("next_call_name"),
        pl.col("start_utc").alias("call_ts")
    ])
    keep_cols = [
        "id","vessel_id","destination","destination_latitude","destination_longitude",
        "call_ts","next_call_id","next_call_name","is_load","is_discharge",
        "prev_dist_km","last_leg_knots_est","product_family_dom"
    ]
    # Tolerate cleaned tables that lack some of the optional columns.
    keep_cols = [c for c in keep_cols if c in pc_clean.columns]
    samples = pc_clean.select(keep_cols).rename({"id":"sample_port_call_id"})
    # A sample needs a known next call to serve as the prediction target.
    samples = samples.filter(pl.col("next_call_name").is_not_null())
    # Cache the rebuilt samples so later runs take the read branch.
    samples.write_parquet(samples_path)
else:
    samples = pl.read_parquet(samples_path)

# Ensure call_ts is datetime
# Only string-typed call_ts is parsed; strict=False is passed so parsing is non-strict.
if samples.schema.get("call_ts") == pl.Utf8:
    samples = samples.with_columns(pl.col("call_ts").str.strptime(pl.Datetime, strict=False))

# Load cleaned port_calls & static CSVs
# NOTE(review): `tr` (trades) is not referenced again in the visible cells — confirm it is needed.
pc = pl.read_parquet(INTERIM_DIR / "port_calls.cleaned.parquet")
tr = pl.read_csv(DATA_DIR / "trades.csv",  try_parse_dates=True)
vs = pl.read_csv(DATA_DIR / "vessels.csv", try_parse_dates=True)

# Temporal split (Jan-Sep Train, Oct-Nov Val, Dec Test inside utils.splits)
train, val, test = temporal_split(samples)
# Apply add_crisis_flag to every split (the model later reads an `is_crisis_time`
# feature — presumably produced here; verify in utils.splits).
train = add_crisis_flag(train); val = add_crisis_flag(val); test = add_crisis_flag(test)

print("train/val/test rows:", train.height, val.height, test.height)


### 2) Build transitions, coordinates, and candidate sets (larger N/M for better recall)

In [None]:

# Transitions (from training only) & global-most-frequent next
# `trans` is consumed later with columns `destination`, `next_call_name` and `cnt`.
trans = build_origin_next_transitions(train)
# Passed to the candidate builder below as `global_top1`.
g_top = global_mf_next(trans)

# Coordinates
# Built from the cleaned port_calls table; reused for candidates and port attributes.
pc_coords = build_pc_coords(pc)

# Candidate sets (N/M enlarged to 30/30); keep cached files if present
def maybe_build_or_load(path, builder):
    """Return the table cached at ``path``, building and caching it on a miss.

    Args:
        path: Parquet file location; only ``path.exists()`` is queried before
            it is handed to the reader/writer.
        builder: Zero-argument callable that produces the table when no cache
            file is present.

    Returns:
        The frame read from ``path`` if the file exists, otherwise the freshly
        built frame (after it has been written to ``path``).
    """
    if not path.exists():
        built = builder()
        built.write_parquet(path)
        return built
    return pl.read_parquet(path)

# Cache locations; the N30_M30 suffix mirrors the N=30 / M=30 arguments used below.
cand_train_path = PROCESSED_DIR / "cand_train_N30_M30.parquet"
cand_val_path   = PROCESSED_DIR / "cand_val_N30_M30.parquet"
cand_test_path  = PROCESSED_DIR / "cand_test_N30_M30.parquet"

# Each builder is wrapped in a lambda so it only runs when its cache file is absent.
# All three splits use the training-derived `trans` and `g_top`.
cand_train = maybe_build_or_load(
    cand_train_path,
    lambda: build_candidates_for_split(train, trans, pc_coords, add_true_label=True,  N=30, M=30, global_top1=g_top)
)
cand_val = maybe_build_or_load(
    cand_val_path,
    lambda: build_candidates_for_split(val,   trans, pc_coords, add_true_label=True,  N=30, M=30, global_top1=g_top)
)
# NOTE(review): the test split is built with add_true_label=False, yet later cells
# read `y` and `label` from cand_test (recall check, test metrics) — confirm what
# build_candidates_for_split emits in that mode.
cand_test = maybe_build_or_load(
    cand_test_path,
    lambda: build_candidates_for_split(test,  trans, pc_coords, add_true_label=False, N=30, M=30, global_top1=g_top)
)

print("cand_train / val / test shapes:", cand_train.shape, cand_val.shape, cand_test.shape)

# Optional: downsample training candidates for speed (keeps all positives).
# A previously written sample is reused as-is whenever its cache file exists.
sampled_path = PROCESSED_DIR / "cand_train_N30_M30_sampled500k.parquet"
if sampled_path.exists():
    cand_train = pl.read_parquet(sampled_path)
else:
    pos = cand_train.filter(pl.col("y") == 1)
    neg = cand_train.filter(pl.col("y") == 0)
    # Negative budget: top up to roughly 500k rows overall, with a floor of 50k negatives.
    target_n = max(50_000, 500_000 - pos.height)
    # Sampling without replacement cannot return more rows than exist, so cap the
    # request at the number of available negatives (asking for more raises).
    n_neg = min(neg.height, target_n)
    if n_neg < neg.height:
        # Fixed seed keeps the subsample reproducible across rebuilds.
        neg = neg.sample(n=n_neg, seed=42)
    cand_train = pl.concat([pos, neg])
    cand_train.write_parquet(sampled_path)
print("cand_train after sampling:", cand_train.shape)


### 3) Attach port-side features and sample-side features

In [None]:

# Port attributes and network degrees
# Port attributes come from the coordinate table; degrees come from the training transitions.
ports_attr  = build_ports_attr(pc_coords)
port_degree = compute_port_degree(trans)

# The same port-side tables are attached to every split.
cand_train  = attach_port_side(cand_train, ports_attr, port_degree)
cand_val    = attach_port_side(cand_val,   ports_attr, port_degree)
cand_test   = attach_port_side(cand_test,  ports_attr, port_degree)

# Sample-side features (vessel static, speed/seasonal, laden/product flags)
# Built once from the full `samples` table, cleaned port calls and vessels CSV.
s_side = build_sample_side(samples, pc, vs)

# Each candidate table is merged with the shared sample-side table and its own split frame.
cand_train = merge_all_features(cand_train, s_side, train)
cand_val   = merge_all_features(cand_val,   s_side, val)
cand_test  = merge_all_features(cand_test,  s_side, test)

print("cols:", len(cand_train.columns))


### 4) Add historical priors & structured geo/channel features

In [None]:

# ---- (A) Base prior: P(candidate | origin) and hist counts from training transitions ----
# For every (destination, next_call_name) row of `trans`:
#   tot           = sum of `cnt` over all rows sharing the same `destination`
#   prior_prob_oc = cnt / tot
# The key columns are then renamed to origin/candidate so they line up with the
# candidate tables, and the raw count is kept as `hist_cnt_oc`.
prior_base = (
    trans
    .with_columns(pl.col("cnt").sum().over("destination").alias("tot"))
    .with_columns((pl.col("cnt")/pl.col("tot")).alias("prior_prob_oc"))
    .select(["destination","next_call_name","prior_prob_oc","cnt"])
    .rename({"destination":"origin","next_call_name":"candidate","cnt":"hist_cnt_oc"})
)

def join_priors(cdf: pl.DataFrame) -> pl.DataFrame:
    """Left-join the base origin->candidate prior columns onto a candidate table.

    Every row of ``cdf`` is kept; the columns of the module-level
    ``prior_base`` table are matched on the (origin, candidate) pair.
    """
    pair_keys = ["origin", "candidate"]
    return cdf.join(prior_base, how="left", on=pair_keys)


# Enrich all three splits with the same training-derived prior table.
cand_train, cand_val, cand_test = (
    join_priors(split) for split in (cand_train, cand_val, cand_test)
)

# ---- (B) Conditional priors: add vessel_type / dwt_bucket / laden / product_family_dom ----
def add_conditional_prior(cdf: pl.DataFrame, keys: list[str], col_suffix: str) -> pl.DataFrame:
    """Attach a prior of the candidate given origin and extra conditioning keys.

    The prior is estimated from the positive rows (``y == 1``) of the
    module-level ``cand_train`` table: the count of each
    (origin, candidate, *keys) combination divided by the total count of its
    (origin, *keys) group.

    Args:
        cdf: Candidate table to enrich.
        keys: Conditioning columns; an empty list returns ``cdf`` untouched.
        col_suffix: Suffix of the output column ``prior_prob_<col_suffix>``.

    Returns:
        ``cdf`` left-joined with the prior column, or ``cdf`` unchanged when
        any of ``keys`` is missing from ``cand_train``.
    """
    if not keys:
        return cdf

    # Positive instances in training stand in for the observed transitions.
    available = [k for k in keys if k in cand_train.columns]
    observed = cand_train.select(["origin", "candidate", *available, "y"])
    if len(available) != len(keys):
        return cdf  # at least one conditioning key is unavailable -> skip

    prior_col = f"prior_prob_{col_suffix}"
    pair_and_keys = ["origin", "candidate", *available]
    counts = (
        observed.filter(pl.col("y") == 1)
        .group_by(pair_and_keys)
        .len()
        .rename({"len": "cnt"})
    )
    # Normalise each count by the total of its (origin, *keys) group.
    totals = counts.with_columns(
        pl.col("cnt").sum().over(["origin", *available]).alias("tot")
    )
    prior = totals.with_columns(
        (pl.col("cnt") / pl.col("tot")).alias(prior_col)
    ).select([*pair_and_keys, prior_col])
    return cdf.join(prior, on=pair_and_keys, how="left")

# Each group of three calls adds one conditional prior column to every split.
# add_conditional_prior reads the global `cand_train` as its source of observed
# transitions, so the statement order here matters: train is reassigned first
# and val/test then see the already-enriched training table.

# Prior conditioned on vessel type -> prior_prob_oc_vtype
cand_train = add_conditional_prior(cand_train, ["vessel_type"],         "oc_vtype")
cand_val   = add_conditional_prior(cand_val,   ["vessel_type"],         "oc_vtype")
cand_test  = add_conditional_prior(cand_test,  ["vessel_type"],         "oc_vtype")

# Prior conditioned on deadweight bucket -> prior_prob_oc_dwt
cand_train = add_conditional_prior(cand_train, ["dwt_bucket"],          "oc_dwt")
cand_val   = add_conditional_prior(cand_val,   ["dwt_bucket"],          "oc_dwt")
cand_test  = add_conditional_prior(cand_test,  ["dwt_bucket"],          "oc_dwt")

# Prior conditioned on the laden-after-call flag -> prior_prob_oc_laden
cand_train = add_conditional_prior(cand_train, ["is_laden_after_call"], "oc_laden")
cand_val   = add_conditional_prior(cand_val,   ["is_laden_after_call"], "oc_laden")
cand_test  = add_conditional_prior(cand_test,  ["is_laden_after_call"], "oc_laden")

# Prior conditioned on dominant product family -> prior_prob_oc_pf
cand_train = add_conditional_prior(cand_train, ["product_family_dom"],  "oc_pf")
cand_val   = add_conditional_prior(cand_val,   ["product_family_dom"],  "oc_pf")
cand_test  = add_conditional_prior(cand_test,  ["product_family_dom"],  "oc_pf")

# Fill NAs for prior columns
def fill_prior_nas(df: pl.DataFrame) -> pl.DataFrame:
    """Replace nulls with 0.0 in every prior column of ``df``.

    Affected columns are those whose name starts with ``prior_prob_`` plus
    ``hist_cnt_oc``; all other columns are left as they are.
    """
    prior_cols = [
        name
        for name in df.columns
        if name.startswith("prior_prob_") or name == "hist_cnt_oc"
    ]
    if not prior_cols:
        return df
    return df.with_columns([pl.col(name).fill_null(0.0) for name in prior_cols])


# Pairs never seen in training come out of the joins as null; treat them as zero.
cand_train, cand_val, cand_test = (
    fill_prior_nas(split) for split in (cand_train, cand_val, cand_test)
)

# ---- (C) Structured geo/channel features ----
# Pattern used to flag candidates whose name looks like a waypoint/channel entry.
WAYPOINT_RX = "(?i)light|anchorage|canal|suez|panama|offshore|STS"


def add_geo_feats(df: pl.DataFrame) -> pl.DataFrame:
    """Add structured geo/channel features to a candidate table.

    New columns:
        cand_is_waypoint: 1 when the candidate name matches ``WAYPOINT_RX``,
            else 0 (null names count as 0).
        log_dist_km: ``log1p(dist_km)``, or 0.0 where ``dist_km`` is null.
        geo_rank_in_sample: "min" rank of ``dist_km`` within each
            ``sample_port_call_id``; 9999 where ``dist_km`` is null.
    """
    has_dist = pl.col("dist_km").is_not_null()

    waypoint_flag = (
        pl.col("candidate")
        .cast(pl.Utf8)
        .str.contains(WAYPOINT_RX)
        .fill_null(False)
        .cast(pl.Int8)
        .alias("cand_is_waypoint")
    )
    log_dist = (
        pl.when(has_dist)
        .then(pl.col("dist_km").log1p())
        .otherwise(pl.lit(0.0))
        .alias("log_dist_km")
    )
    # Rank by distance inside each sample; a missing distance gets a large
    # (i.e. worse) placeholder rank.
    geo_rank = (
        pl.when(has_dist)
        .then(pl.col("dist_km").rank("min").over("sample_port_call_id"))
        .otherwise(pl.lit(None))
        .fill_null(pl.lit(9999))
        .alias("geo_rank_in_sample")
    )
    return df.with_columns([waypoint_flag, log_dist]).with_columns(geo_rank)


cand_train, cand_val, cand_test = (
    add_geo_feats(split) for split in (cand_train, cand_val, cand_test)
)

print("Feature enrich done.")


### 5) Candidate recall sanity check

In [None]:

def candidate_recall(cands: pl.DataFrame) -> float:
    """Return the mean, over samples, of the per-sample maximum of ``y``.

    With 0/1 labels this is the share of samples whose candidate set contains
    the true next destination.
    """
    per_sample = cands.group_by("sample_port_call_id").agg(
        pl.col("y").max().alias("has_truth")
    )
    hit_rate = per_sample.select(pl.col("has_truth").mean())
    return hit_rate.item()


print({
    f"recall_{split_name}": float(candidate_recall(split_frame))
    for split_name, split_frame in (
        ("train", cand_train),
        ("val", cand_val),
        ("test", cand_test),
    )
})


### 6) Train reranker (LR + priors) and evaluate Top‑K

In [None]:

# Numeric model inputs: these go through the numeric branch (median imputation + scaling)
# of the preprocessing pipeline. Columns absent from a candidate table are created
# as 0.0 by to_xy.
num_cols = [
    "dist_km","log_dist_km","geo_rank_in_sample",
    "is_same_region","in_cnt","out_cnt","age",
    "prev_dist_km","last_leg_knots_est",
    "month_sin","month_cos","dow_sin","dow_cos",
    "is_crisis_time","dist_x_crisis",
    "cand_is_waypoint",
    "prior_prob_oc","hist_cnt_oc",
    "prior_prob_oc_vtype","prior_prob_oc_dwt","prior_prob_oc_laden","prior_prob_oc_pf",
]
# Categorical model inputs: one-hot encoded; columns absent from a candidate table
# are created as the constant "unk" by to_xy.
cat_cols = ["origin","candidate","vessel_type","dwt_bucket","product_family_dom"]

def to_xy(df: pl.DataFrame):
    """Convert a candidate table into pandas features, labels and metadata.

    Columns the model expects but ``df`` lacks are created with a neutral
    constant: 0.0 for numeric features, integer 0 for the target ``y`` and
    ``"unk"`` for every other column. Rows are then de-duplicated per
    (sample_port_call_id, candidate), keeping the first occurrence and the
    original row order.

    Args:
        df: Candidate table (one row per sample/candidate pair).

    Returns:
        Tuple ``(X, y, meta)`` where ``X`` is a pandas DataFrame with
        ``num_cols + cat_cols``, ``y`` is an integer numpy array (missing or
        null labels become 0) and ``meta`` holds ``sample_port_call_id``,
        ``origin``, ``candidate`` and ``label``.
    """
    base = ["sample_port_call_id", "origin", "candidate", "label", "y"]
    keep = list(dict.fromkeys(base + num_cols + cat_cols))

    # Add missing columns. `y` must stay numeric: filling it with the string
    # "unk" would make the integer cast below raise for unlabeled tables.
    fillers = []
    for c in keep:
        if c in df.columns:
            continue
        if c == "y":
            filler = pl.lit(0)
        elif c in num_cols:
            filler = pl.lit(0.0)
        else:
            filler = pl.lit("unk")
        fillers.append(filler.alias(c))
    if fillers:
        df = df.with_columns(fillers)

    # Select and de-duplicate per (sample, candidate); maintain_order keeps the
    # row order deterministic from run to run.
    df = df.select(keep).unique(
        subset=["sample_port_call_id", "candidate"],
        keep="first",
        maintain_order=True,
    )
    pdf = df.to_pandas()
    X = pdf[num_cols + cat_cols]
    # Null labels arrive as NaN in pandas and cannot be cast to int directly.
    y = pdf["y"].fillna(0).astype(int).values
    meta = pdf[["sample_port_call_id", "origin", "candidate", "label"]]
    return X, y, meta

# Convert each split into (features, labels, metadata) in pandas/numpy form.
Xtr, ytr, mtr = to_xy(cand_train)
Xva, yva, mva = to_xy(cand_val)
Xte, yte, mte = to_xy(cand_test)

# Build pipeline
# Numeric branch: median-impute missing values, then scale without centering.
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler(with_mean=False))  # with_mean=False works nicely with sparse later
])

# Request sparse one-hot output; if this scikit-learn build rejects the
# `sparse_output` keyword, retry with the `sparse` spelling.
# Categories not seen during fit are ignored at transform time.
try:
    ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=True)
except TypeError:
    ohe = OneHotEncoder(handle_unknown="ignore", sparse=True)

# Route numeric and categorical columns to their transformers, drop everything
# else, and pass sparse_threshold=1.0 for the stacked output.
preproc = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, num_cols),
        ("cat", ohe, cat_cols)
    ],
    remainder="drop",
    sparse_threshold=1.0
)

# Class-weighted logistic regression; falls back to the default solver and C
# if the first constructor call raises TypeError.
try:
    clf = LogisticRegression(max_iter=2000, class_weight="balanced", n_jobs=-1, solver="saga", C=0.5)
except TypeError:
    clf = LogisticRegression(max_iter=2000, class_weight="balanced", n_jobs=-1)

pipe = Pipeline([("prep", preproc), ("clf", clf)])

# Fit on the training candidates only and report wall-clock time.
t0 = time.time()
pipe.fit(Xtr, ytr)
print(f"Fit time: {time.time()-t0:.1f}s")

# Ranking & metrics
def rank_predict(pipe, X, meta, k=5):
    """Score candidates and return, per sample, its candidates ranked best-first.

    Args:
        pipe: Fitted model exposing ``predict_proba``; column 1 of its output
            is used as the ranking score.
        X: Feature rows aligned with ``meta``.
        meta: pandas DataFrame with ``sample_port_call_id``, ``candidate`` and
            ``label`` columns.
        k: Accepted for call compatibility; the returned lists are not
            truncated, so callers slice them themselves.

    Returns:
        ``(preds, truths)``: two lists aligned by sample (in groupby order).
        ``preds[i]`` is the full candidate list sorted by descending score and
        ``truths[i]`` is the ``label`` of the sample's first row.
    """
    scored = meta.assign(score=pipe.predict_proba(X)[:, 1])
    preds, truths = [], []
    for _, group in scored.groupby("sample_port_call_id"):
        ranked = group.sort_values("score", ascending=False)
        preds.append(ranked["candidate"].tolist())
        truths.append(group["label"].iloc[0])
    return preds, truths

# Rank every sample's candidates on the validation and test splits.
preds_val, truth_val = rank_predict(pipe, Xva, mva, k=5)
preds_te,  truth_te  = rank_predict(pipe, Xte, mte,  k=5)

# rank_predict returns full rankings, so cut them to the top 5 here before scoring at k = 1, 3, 5.
# NOTE(review): test candidates were built with add_true_label=False — confirm that
# `label` carries real ground truth for the TEST line.
print("VAL:",  eval_topk_mrr([p[:5] for p in preds_val], truth_val, ks=(1,3,5)))
print("TEST:", eval_topk_mrr([p[:5] for p in preds_te],  truth_te,  ks=(1,3,5)))


### 7) Persist model and optional Top‑5 tables

In [None]:

# Persist the whole fitted pipeline (preprocessing + classifier) as one joblib file.
outm = PROCESSED_DIR / "model_taskA_logreg_plus.joblib"
joblib.dump(pipe, outm)
print("Model saved to:", outm)

# Optional: persist top-5 candidate tables for inspection
def dump_topk(meta: pd.DataFrame, scores: np.ndarray, out_path: Path, k=5):
    """Write each sample's k best-scored candidates to a parquet file.

    Args:
        meta: pandas DataFrame with ``sample_port_call_id``, ``candidate`` and
            ``label`` columns, row-aligned with ``scores``.
        scores: One score per row of ``meta``; higher ranks first.
        out_path: Destination parquet path.
        k: Number of candidates kept per sample.

    The output has one row per kept candidate with columns
    ``sample_port_call_id``, ``rank`` (1-based), ``candidate``, ``score`` and
    ``label``.
    """
    scored = meta.copy()
    scored["score"] = scores
    records = []
    for sample_id, group in scored.groupby("sample_port_call_id"):
        best = (
            group.sort_values("score", ascending=False)
            .head(k)
            .reset_index(drop=True)
        )
        # After the index reset, the row position is the zero-based rank.
        records.extend(
            {
                "sample_port_call_id": sample_id,
                "rank": pos + 1,
                "candidate": rec["candidate"],
                "score": rec["score"],
                "label": rec["label"],
            }
            for pos, rec in best.iterrows()
        )
    table = pl.from_pandas(pd.DataFrame(records))
    table.write_parquet(out_path)
    print("Saved:", out_path)

# Save val/test top5
# NOTE(review): the `_Path` alias is not referenced anywhere in the visible cells.
from pathlib import Path as _Path
# Positive-class probability for every candidate row, aligned with mva / mte.
val_scores = pipe.predict_proba(Xva)[:,1]
test_scores= pipe.predict_proba(Xte)[:,1]
dump_topk(mva, val_scores,  PROCESSED_DIR / "val_top5_taskA_logreg_plus.parquet",  k=5)
dump_topk(mte, test_scores, PROCESSED_DIR / "test_top5_taskA_logreg_plus.parquet", k=5)
