# 03  Candidate Recall + Logistic Ranker

In [1]:
# Adjustable: put the parent directory (the one that contains utils/) on the
# Python search path so the `utils` package can be imported from this notebook.
import os
import sys

sys.path.append(os.path.abspath(os.pardir))

In [2]:
# === Unified imports ===
import polars as pl

# Display setup: configure how many table rows / columns Polars prints (5 and
# 10) and select the "ASCII_FULL" table formatting.
pl.Config.set_tbl_rows(5)
pl.Config.set_tbl_cols(10)
pl.Config.set_tbl_formatting("ASCII_FULL")

# Compatibility: enable the global string cache for joins, using whichever
# entry point the installed Polars exposes. The newer name is tried first and
# the older toggle is only used when the newer one is absent.
if hasattr(pl, "enable_string_cache"):
    pl.enable_string_cache()   # newer Polars API name
elif hasattr(pl, "toggle_string_cache"):
    pl.toggle_string_cache(True)  # legacy API name

# --- Standard imports ---
import numpy as np, pandas as pd, json, joblib
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

from utils.config import DATA_DIR, INTERIM_DIR, PROCESSED_DIR
from utils.etl_clean import ensure_interim
from utils.splits import temporal_split, add_crisis_flag
from utils.candidates import build_origin_next_transitions, global_mf_next, build_pc_coords, build_candidates_for_split
from utils.features import build_ports_attr, compute_port_degree, attach_port_side, build_sample_side, merge_all_features
from utils.metrics import eval_topk_mrr

In [None]:
%%time
# Load the task-A samples, the cleaned port-call table and the two raw CSVs.
samples = pl.read_parquet(PROCESSED_DIR / "samples_taskA.parquet")
pc = pl.read_parquet(INTERIM_DIR / "port_calls.cleaned.parquet")
tr = pl.read_csv(DATA_DIR / "trades.csv", try_parse_dates=True)
vs = pl.read_csv(DATA_DIR / "vessels.csv", try_parse_dates=True)

# Split the samples, then run every split through add_crisis_flag.
train, val, test = (add_crisis_flag(part) for part in temporal_split(samples))

# Transition statistics are built from the training split only; the port
# coordinates come from the cleaned port calls.
trans = build_origin_next_transitions(train)
g_top = global_mf_next(trans)
pc_coords = build_pc_coords(pc)

# Candidate generation per split (train, val, test — in that order).
# add_true_label is switched on for train/val and off for test.
cand_train, cand_val, cand_test = (
    build_candidates_for_split(
        split, trans, pc_coords,
        add_true_label=with_label, N=10, M=10, global_top1=g_top,
    )
    for split, with_label in ((train, True), (val, True), (test, False))
)

# Port-side features: the same attribute / degree tables are attached to
# every candidate frame.
ports_attr = build_ports_attr(pc_coords)
port_degree = compute_port_degree(trans)
cand_train, cand_val, cand_test = (
    attach_port_side(cands, ports_attr, port_degree)
    for cands in (cand_train, cand_val, cand_test)
)

# Sample-side features are computed once from all samples and then merged
# into each candidate frame together with its own split.
s_side = build_sample_side(samples, pc, vs)
cand_train, cand_val, cand_test = (
    merge_all_features(cands, s_side, split)
    for cands, split in ((cand_train, train), (cand_val, val), (cand_test, test))
)

print("cand_train:", cand_train.shape)

In [None]:
# Logistic ranker: One-Hot + LR

# Columns handed to the model as-is (not one-hot encoded).
num_cols = [
    "dist_km",
    "is_same_region",
    "in_cnt",
    "out_cnt",
    "age",
    "prev_dist_km",
    "last_leg_knots_est",
    "month_sin",
    "month_cos",
    "dow_sin",
    "dow_cos",
    "is_crisis_time",
    "dist_x_crisis",
]

# Columns that are one-hot encoded before reaching the classifier.
cat_cols = [
    "origin",
    "candidate",
    "vessel_type",
    "dwt_bucket",
    "product_family_dom",
]

def to_xy(df: pl.DataFrame):
    """Convert a candidate frame into pandas features, targets and metadata.

    Any required column that is absent from ``df`` is created with a constant
    filler: ``0.0`` for columns listed in ``num_cols`` and the string
    ``"unk"`` for every other column (this includes ``label`` and ``y``).

    Args:
        df: Polars candidate frame (presumably one row per sample/candidate
            pair — confirm against the candidate builder).

    Returns:
        A tuple ``(X, y, meta)``:
        ``X`` is a pandas DataFrame holding the ``num_cols + cat_cols``
        columns, ``y`` is the ``y`` column as an array, and ``meta`` is a
        pandas DataFrame with ``sample_port_call_id``, ``origin``,
        ``candidate`` and ``label``.
    """
    meta_cols = ["sample_port_call_id", "origin", "candidate", "label"]

    # "origin" and "candidate" are both metadata and categorical features.
    # De-duplicate (order-preserving) so the Polars select below is never
    # asked for the same column name twice.
    keep = list(dict.fromkeys(meta_cols + ["y"] + num_cols + cat_cols))

    # Create every missing column in a single with_columns call.
    # NOTE(review): when "y" or "label" is missing it is filled with "unk",
    # so the returned targets are placeholders, not usable labels.
    present = set(df.columns)
    fillers = [
        pl.lit(0.0 if col in num_cols else "unk").alias(col)
        for col in keep
        if col not in present
    ]
    if fillers:
        df = df.with_columns(fillers)

    pdf = df.select(keep).to_pandas()
    X = pdf[num_cols + cat_cols]
    y = pdf["y"].values
    meta = pdf[meta_cols]
    return X, y, meta

# Materialize pandas features / targets / metadata for each split.
Xtr, ytr, mtr = to_xy(cand_train)
Xva, yva, mva = to_xy(cand_val)
Xte, yte, mte = to_xy(cand_test)

# One-hot encode the categorical columns and pass every other column through
# unchanged; sparse_threshold=1.0 keeps the combined matrix sparse.
# The encoder already produces sparse output by default, so the
# version-specific keyword (`sparse`, renamed to `sparse_output` in
# scikit-learn 1.2 and removed in 1.4) is deliberately not passed. This keeps
# the cell working on both old and new scikit-learn releases.
preproc = ColumnTransformer(
    transformers=[("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols)],
    remainder="passthrough",
    sparse_threshold=1.0,
)
# class_weight="balanced" reweights the two classes by inverse frequency.
clf = LogisticRegression(max_iter=300, class_weight="balanced", n_jobs=None)
pipe = Pipeline([("prep", preproc), ("clf", clf)])
pipe.fit(Xtr, ytr)

# Per-sample candidate ranking
def rank_predict(pipe, X, meta, ks=(1,3,5)):
    """Rank each sample's candidates by predicted positive-class probability.

    Args:
        pipe: Fitted estimator exposing ``predict_proba``; column 1 of its
            output is used as the ranking score.
        X: Feature rows, aligned row-for-row with ``meta``.
        meta: pandas DataFrame with at least ``sample_port_call_id``,
            ``candidate`` and ``label`` columns. It is not modified.
        ks: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``(preds, truth)`` — two lists aligned by sample, in the order pandas
        ``groupby`` yields the ``sample_port_call_id`` groups. ``preds[i]`` is
        the sample's candidates sorted by descending score; ``truth[i]`` is
        the ``label`` of the sample's first row.
    """
    # assign() returns a new frame, so the caller's meta is left untouched.
    scored = meta.assign(score=pipe.predict_proba(X)[:, 1])

    preds = []
    truth = []
    # One pass over the groups builds both outputs in the same order.
    for _, group in scored.groupby("sample_port_call_id"):
        ranked = group.sort_values("score", ascending=False)
        preds.append(ranked["candidate"].tolist())
        # The label is read from the group's first row in original row order.
        truth.append(group["label"].iloc[0])
    return preds, truth

# Rank the candidates of every validation and test sample.
preds_val, truth_val = rank_predict(pipe, Xva, mva)
preds_te, truth_te = rank_predict(pipe, Xte, mte)

# Only the five highest-scored candidates per sample are passed to the metric.
# NOTE(review): cand_test is built with add_true_label=False; confirm that a
# real `label` column still reaches this point, otherwise the TEST line is
# evaluated against placeholder labels.
from utils.metrics import eval_topk_mrr
print("VAL:", eval_topk_mrr([ranked[:5] for ranked in preds_val], truth_val, ks=(1, 3, 5)))
print("TEST:", eval_topk_mrr([ranked[:5] for ranked in preds_te], truth_te, ks=(1, 3, 5)))

# Persist the fitted pipeline next to the processed data.
import joblib
import json
outm = PROCESSED_DIR / "model_taskA_logreg.joblib"
joblib.dump(pipe, outm)
print("saved:", outm)