# 08 – MEG→VLCC Destination Ranker
# ---------------------------------
# Focused notebook to isolate Middle East Gulf (MEG) VLCC voyages and
# train a LightGBM ranker with the same memory safeguards used in 04.

## Prerequisites
# - Run notebooks `03_candidates_logistic_ranker` & `04_candidates_gbdt_ranker`
#   so cached candidate tables / enriched features already exist.
# - Ensure raw CSVs are available locally (no iCloud placeholders).
# - This notebook concentrates on crude-oil voyages departing MEG export ports
#   on VLCC/Suezmax tankers; the smaller cohort both reduces memory pressure
#   and gives a clearer signal to validate modelling ideas.

In [1]:
# Adjustable: add the parent directory (which contains utils/) to the Python search path
import sys, os
sys.path.append(os.path.abspath(".."))  # ".." is resolved against the current working directory

In [None]:
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pandas as pd
import polars as pl
import lightgbm as lgb

from utils.config import PROCESSED_DIR, INTERIM_DIR, DATA_DIR
from utils.splits import temporal_split, add_crisis_flag
from utils.candidates import build_origin_next_transitions, build_pc_coords
from utils.features import (
    build_ports_attr,
    compute_port_degree,
    attach_port_side,
    build_sample_side,
    merge_all_features,
)
from utils.metrics import eval_topk_mrr

# ------------------------------------------------------------------
# Helper utilities (same as 04)
# ------------------------------------------------------------------
def downsample_groupwise(
    df: pl.DataFrame,
    neg_per_pos: int = 8,
    seed: int = 42,
    neg_if_no_pos: int = 10,
) -> pl.DataFrame:
    """Keep all positives and a random subset of negatives for every sample.

    Rows are grouped by ``sample_port_call_id`` (groups are visited in order of
    first appearance).  Within each group, rows with ``y == 1`` are always
    kept and rows with ``y == 0`` are sampled without replacement.

    Args:
        df: Candidate table with at least ``sample_port_call_id`` and ``y``.
        neg_per_pos: Negatives kept per positive in groups that have positives.
        seed: Seed of the NumPy generator used for all draws.
        neg_if_no_pos: Negatives kept in groups that contain no positive
            (previously hard-coded to 10).

    Returns:
        The concatenated, downsampled groups; an empty frame when nothing is
        kept.  An empty input is returned unchanged.
    """
    # An empty frame may carry no columns at all (e.g. a bare ``pl.DataFrame()``),
    # in which case grouping on "sample_port_call_id" cannot work; bail out early.
    if df.is_empty():
        return df

    out = []
    rng = np.random.default_rng(seed)
    for _, sub in df.group_by("sample_port_call_id", maintain_order=True):
        pos = sub.filter(pl.col("y") == 1)
        neg = sub.filter(pl.col("y") == 0)

        if pos.height == 0:
            # No positive in this group: keep only a small fixed number of negatives.
            k = min(neg_if_no_pos, neg.height)
            if k > 0:
                keep_idx = rng.choice(neg.height, size=k, replace=False)
                out.append(neg[keep_idx])
            continue

        # Negatives are capped proportionally to the number of positives.
        k = min(neg.height, neg_per_pos * pos.height)
        if k > 0:
            keep_idx = rng.choice(neg.height, size=k, replace=False)
            out.append(pl.concat([pos, neg[keep_idx]]))
        else:
            out.append(pos)
    return pl.concat(out) if out else df.head(0)


def ensure_columns(df: pl.DataFrame, num_cols: list[str], cat_cols: list[str]) -> pl.DataFrame:
    """Return ``df`` with every requested column present.

    Missing numeric columns are added as the literal ``0.0`` and missing
    categorical columns as the literal ``"unk"``.  Existing columns are left
    untouched; numeric names take precedence when a name is listed twice.
    """
    known = set(df.columns)
    additions = []
    for default, names in ((0.0, num_cols), ("unk", cat_cols)):
        for name in names:
            if name in known:
                continue
            known.add(name)
            additions.append(pl.lit(default).alias(name))
    # Nothing missing: hand back the very same frame.
    return df.with_columns(additions) if additions else df


def prep_for_lgb(
    df_pl: pl.DataFrame,
    keep_cols: list[str],
    num_cols: list[str],
    cat_cols: list[str],
):
    """Convert a polars candidate table into LightGBM ranker inputs.

    Args:
        df_pl: Candidate rows; must contain ``sample_port_call_id``, ``origin``,
            ``candidate``, ``label`` and ``y`` (missing feature columns are
            filled in by ``ensure_columns``).
        keep_cols: Columns carried over to pandas.
        num_cols: Numeric feature columns (coerced to numbers; NaN/inf -> 0.0).
        cat_cols: Categorical feature columns (cast to pandas ``category``).

    Returns:
        ``(X, y, groups, meta)`` where all four share the same row order:
        rows are sorted by ``sample_port_call_id`` so that each sample forms
        one contiguous run, and ``groups`` lists the run lengths in that
        same order, as LightGBM's ``group`` argument requires.
    """
    df_pl = ensure_columns(df_pl, num_cols, cat_cols)
    # The incoming frame may be shuffled (row sampling, unordered ``unique``),
    # while a ranking query must occupy consecutive rows.  A stable sort
    # makes the samples contiguous and keeps the within-sample order.
    df_pd = (
        df_pl.select(keep_cols)
        .to_pandas()
        .sort_values("sample_port_call_id", kind="stable")
        .reset_index(drop=True)
    )

    for c in cat_cols:
        if c in df_pd.columns:
            df_pd[c] = df_pd[c].astype("category")
    for c in num_cols:
        if c in df_pd.columns:
            df_pd[c] = (
                pd.to_numeric(df_pd[c], errors="coerce")
                  .replace([np.inf, -np.inf], np.nan)
                  .fillna(0.0)
            )

    X      = df_pd[num_cols + cat_cols]
    y      = df_pd["y"].astype(int).values
    # sort=False keeps order of appearance, which now equals the row order.
    groups = (
        df_pd.groupby("sample_port_call_id", sort=False)["candidate"]
        .size()
        .tolist()
    )
    meta   = df_pd[["sample_port_call_id", "origin", "candidate", "label"]]
    return X, y, groups, meta


def subsample(df: pl.DataFrame, n: int = 75_000, seed: int = 42) -> pl.DataFrame:
    """Return ``df`` unchanged if it has at most ``n`` rows, else a seeded random sample of ``n`` rows."""
    if df.height > n:
        return df.sample(n=n, seed=seed)
    return df


# ------------------------------------------------------------------
# Load base data and identify MEG VLCC cohort
# ------------------------------------------------------------------
# Inputs: the Task-A sample table, the cleaned port-call table and the vessel
# CSV, each read from the directory constants imported from utils.config.
samples  = pl.read_parquet(PROCESSED_DIR / "samples_taskA.parquet")
pc_clean = pl.read_parquet(INTERIM_DIR / "port_calls.cleaned.parquet")
vessels  = pl.read_csv(DATA_DIR / "vessels.csv")

# Three-way split of the full sample table (logic lives in utils.splits).
train_samples, val_samples, test_samples = temporal_split(samples)

# Sample-side table for ALL samples.  Nulls become "" and vessel_type is
# upper-cased so the cohort filter further down can match upper-case keywords.
sample_side_all = build_sample_side(samples, pc_clean, vessels).with_columns(
    pl.col("vessel_type").fill_null("").str.to_uppercase(),
    pl.col("product_family_dom").fill_null(""),
)

# MEG export ports (UN/port names; the list can be extended).
# NOTE(review): "YABLU ISLAND" may be a typo — confirm against the port master data.
MEG_PORTS = [
    "RAS TANURA",
    "JUAYMAH TERMINAL",
    "YABLU ISLAND",
    "YANBU SOUTH TERMINAL",
    "JUBAIL COMMERCIAL PORT",
    "RAS AL KHAIR",
    "AL JUBAIL",
    "RAS LAFFAN",
    "ABU DHABI",
    "DAS ISLAND",
    "FUJAIRAH",
    "MINA AL AHMADI",
    "KHALIFA BIN SALMAN",
    "KHARG ISLAND",
]

# Identify sample IDs where:
#  - destination (current call) is one of MEG_PORTS (i.e. it acts as the origin)
#  - vessel type contains VLCC/SUEZ/CRUDE
#  - product family is crude oil/condensate (optional; NOT applied in the filter below)
# The port names are joined into one regex alternation, so a destination
# matches when it CONTAINS any of the names (substring match, case-sensitive).
meg_vlcc_ids = (
    sample_side_all
    .join(samples.select("sample_port_call_id", "destination"), on="sample_port_call_id", how="left")
    .filter(
        pl.col("destination").str.contains("|".join(MEG_PORTS), literal=False)
        & pl.col("vessel_type").str.contains("VLCC|SUEZ|CRUDE")
    )
    .select("sample_port_call_id")
    .unique()
)

# Plain Python list of cohort IDs, reused by every `is_in` filter below.
meg_id_list = meg_vlcc_ids.get_column("sample_port_call_id").to_list()

def split_card(name: str, df: pl.DataFrame):
    """Print how many rows of ``df`` belong to the MEG-VLCC cohort, with their share of the total."""
    subset = df.filter(pl.col("sample_port_call_id").is_in(meg_id_list))
    if df.height:
        share = subset.height / df.height
    else:
        # Avoid dividing by zero on an empty split.
        share = 0.0
    print(f"{name:5s}: total={df.height:,}  MEG-VLCC={subset.height:,} ({share:.2%})")

# Report cohort coverage for each split.
print("MEG VLCC sample availability:")
split_card("train", train_samples)
split_card("val",   val_samples)
split_card("test",  test_samples)

# Restrict every split to the cohort IDs (same filter applied to all three).
meg_train_samples, meg_val_samples, meg_test_samples = (
    split_frame.filter(pl.col("sample_port_call_id").is_in(meg_id_list))
    for split_frame in (train_samples, val_samples, test_samples)
)

# ------------------------------------------------------------------
# Load cached candidates, filter, enrich, apply memory guards
# ------------------------------------------------------------------
# Cached candidate tables (see Prerequisites: produced by the earlier notebooks).
cand_train = pl.read_parquet(PROCESSED_DIR / "cand_train_cached.parquet")
cand_val   = pl.read_parquet(PROCESSED_DIR / "cand_val_cached.parquet")
cand_test  = pl.read_parquet(PROCESSED_DIR / "cand_test_cached.parquet")

# Keep only candidate rows that belong to the MEG-VLCC cohort.
cand_train = cand_train.filter(pl.col("sample_port_call_id").is_in(meg_id_list))
cand_val   = cand_val.filter(pl.col("sample_port_call_id").is_in(meg_id_list))
cand_test  = cand_test.filter(pl.col("sample_port_call_id").is_in(meg_id_list))

print("Raw MEG VLCC candidate rows:", cand_train.height, cand_val.height, cand_test.height)

# Shared lookup tables consumed by enrich_split.
pc_coords   = build_pc_coords(pc_clean)
ports_attr  = build_ports_attr(pc_coords)
transitions = build_origin_next_transitions(train_samples)  # statistics from the FULL training split, not only the MEG cohort
port_degree = compute_port_degree(transitions)

sample_side_meg = sample_side_all.filter(pl.col("sample_port_call_id").is_in(meg_id_list))

# Per-split cohort samples passed through utils.splits.add_crisis_flag;
# keyed by the split names that enrich_split accepts.
lagged_splits = {
    "train": add_crisis_flag(meg_train_samples),
    "val":   add_crisis_flag(meg_val_samples),
    "test":  add_crisis_flag(meg_test_samples),
}

def enrich_split(cand_df: pl.DataFrame, split: str) -> pl.DataFrame:
    """Attach port-side and sample-side features to one split's candidates.

    ``split`` selects the entry of ``lagged_splits`` to use.  Returns a bare
    ``pl.DataFrame()`` whenever the candidates, the split samples, or their
    intersection are empty.  The result drops every column whose name ends
    with ``_port`` and is de-duplicated on (sample_port_call_id, candidate).
    """
    if cand_df.is_empty() or lagged_splits[split].is_empty():
        return pl.DataFrame()

    split_samples = lagged_splits[split]
    split_ids = split_samples.select("sample_port_call_id").unique()

    # Candidates restricted to samples that are actually part of this split.
    cand_in_split = cand_df.join(split_ids, on="sample_port_call_id", how="inner")
    if cand_in_split.is_empty():
        return pl.DataFrame()

    with_ports = attach_port_side(cand_in_split, ports_attr, port_degree)
    sample_feats = sample_side_meg.join(split_ids, on="sample_port_call_id", how="inner")
    merged = merge_all_features(with_ports, sample_feats, split_samples)

    kept_cols = [col for col in merged.columns if not col.endswith("_port")]
    return merged.select(kept_cols).unique(subset=["sample_port_call_id", "candidate"])

cand_train_meg = enrich_split(cand_train, "train")
cand_val_meg   = enrich_split(cand_val,   "val")
cand_test_meg  = enrich_split(cand_test,  "test")


def _cap_whole_samples(df: pl.DataFrame, n: int, seed: int = 42) -> pl.DataFrame:
    """Cap ``df`` at about ``n`` rows by dropping whole samples, never single rows.

    A row-level random sample can remove a sample's true destination from its
    candidate list and truncate the list itself, which distorts ranking
    metrics and leaves training groups without a positive.  Here a seeded
    random subset of ``sample_port_call_id`` values is kept instead, adding
    samples until the next one would exceed ``n`` rows (at least one sample
    is always kept, even if it alone is larger than ``n``).
    """
    if df.is_empty() or df.height <= n:
        return df
    ids, sizes = np.unique(
        df.get_column("sample_port_call_id").to_numpy(), return_counts=True
    )
    order = np.random.default_rng(seed).permutation(len(ids))
    # Number of leading samples (in shuffled order) whose cumulative row count fits in n.
    n_fit = int(np.searchsorted(np.cumsum(sizes[order]), n, side="right"))
    keep_ids = ids[order[: max(n_fit, 1)]].tolist()
    return df.filter(pl.col("sample_port_call_id").is_in(keep_ids))


# Memory guards: tighter caps (the cohort is smaller)
cand_val_meg   = _cap_whole_samples(cand_val_meg,   n=50_000)
cand_test_meg  = _cap_whole_samples(cand_test_meg,  n=50_000)
cand_train_meg = _cap_whole_samples(cand_train_meg, n=400_000)

# enrich_split returns a column-less frame when a split is empty; skip the
# group-wise downsampling in that case so the "cohort too small" check
# further down is reached instead of failing on a missing column.
cand_train_ds = (
    cand_train_meg
    if cand_train_meg.is_empty()
    else downsample_groupwise(cand_train_meg, neg_per_pos=6, seed=42)
)

print("Enriched MEG VLCC rows:")
print("  train:", cand_train_meg.height, "-> downsampled:", cand_train_ds.height)
print("  val  :", cand_val_meg.height)
print("  test :", cand_test_meg.height)

# ------------------------------------------------------------------
# LightGBM ranker (lighter config)
# ------------------------------------------------------------------
# Core numeric features (added as 0.0 by prep_for_lgb when absent).
num_cols = [
    "dist_km", "is_same_region", "in_cnt", "out_cnt", "age",
    "prev_dist_km", "last_leg_knots_est",
    "month_sin", "month_cos", "dow_sin", "dow_cos",
    "is_crisis_time", "dist_x_crisis",
]
# Optional numeric features: used only when the enriched train table has them.
extra_num = [
    name
    for name in (
        "prior_prob_oc", "hist_cnt_oc", "prior_prob_oc_vtype",
        "prior_prob_oc_dwt", "prior_prob_oc_laden", "prior_prob_oc_pf",
        "geo_rank_in_sample", "log_dist_km", "cand_is_waypoint",
        "in_cnt_cand", "out_cnt_cand",
    )
    if name in cand_train_meg.columns
]
# dict.fromkeys de-duplicates while preserving first-seen order.
num_cols = list(dict.fromkeys([*num_cols, *extra_num]))

# Categorical features, likewise limited to columns that exist.
cat_cols = [
    name
    for name in ("origin", "candidate", "vessel_type", "dwt_bucket", "product_family_dom")
    if name in cand_train_meg.columns
]

# Identifier / target columns first, then every feature, without duplicates.
keep_cols = list(dict.fromkeys(
    ["sample_port_call_id", "origin", "candidate", "label", "y"] + num_cols + cat_cols
))

# Train only when both the downsampled train set and the validation set have rows.
if cand_train_ds.is_empty() or cand_val_meg.is_empty():
    print("⚠️  MEG cohort too small; skipping LightGBM.")
else:
    # Features, targets, per-sample group sizes and metadata for each split.
    Xtr, ytr, gtr, mtr = prep_for_lgb(cand_train_ds, keep_cols, num_cols, cat_cols)
    Xva, yva, gva, mva = prep_for_lgb(cand_val_meg, keep_cols, num_cols, cat_cols)
    # The test split is optional: all four outputs are None when it is empty.
    Xte, yte, gte, mte = (
        prep_for_lgb(cand_test_meg, keep_cols, num_cols, cat_cols)
        if not cand_test_meg.is_empty() else (None, None, None, None)
    )

    ranker = lgb.LGBMRanker(
        objective="lambdarank",
        metric="map",
        learning_rate=0.05,
        n_estimators=250,
        num_leaves=31,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=42,
        n_jobs=4,
    )
    # Validation data drives early stopping (60 rounds without improvement);
    # evaluation results are logged every 25 iterations.
    ranker.fit(
        Xtr, ytr, group=gtr,
        eval_set=[(Xva, yva)],
        eval_group=[gva],
        eval_at=[1, 3, 5],
        categorical_feature=cat_cols,
        callbacks=[
            lgb.early_stopping(stopping_rounds=60, verbose=False),
            lgb.log_evaluation(25),
        ],
    )

    def rank_predict(model, X, meta):
        """Score ``X`` and return per-sample ranked candidates and true labels.

        ``meta`` must be row-aligned with ``X``.  For every
        ``sample_port_call_id`` the candidates are ordered by descending
        score; the truth is the ``label`` of the top-scored row (assumes the
        label is constant within a sample — TODO confirm).
        """
        scores = model.predict(X, num_iteration=model.best_iteration_)
        meta2 = meta.copy()
        meta2["score"] = scores
        preds, truth = [], []
        for sid, g in meta2.groupby("sample_port_call_id"):
            g2 = g.sort_values("score", ascending=False)
            preds.append(g2["candidate"].tolist())
            truth.append(g2["label"].iloc[0])
        return preds, truth

    # Only the five best-scored candidates per sample are handed to the metric helper.
    preds_val, truth_val = rank_predict(ranker, Xva, mva)
    val_metrics = eval_topk_mrr([p[:5] for p in preds_val], truth_val, ks=(1, 3, 5))
    print("Validation metrics:", val_metrics)

    if Xte is not None:
        preds_te, truth_te = rank_predict(ranker, Xte, mte)
        test_metrics = eval_topk_mrr([p[:5] for p in preds_te], truth_te, ks=(1, 3, 5))
        print("Test metrics:", test_metrics)
    else:
        test_metrics = None
        print("Test split empty for MEG cohort.")

    # Persist the trained booster and both metric dicts next to the processed data.
    model_path   = PROCESSED_DIR / "model_taskA_lgbm_ranker_meg_vlcc.txt"
    metrics_path = PROCESSED_DIR / "metrics_gbdt_meg_vlcc.json"
    ranker.booster_.save_model(model_path)
    with open(metrics_path, "w") as fh:
        json.dump({"val": val_metrics, "test": test_metrics}, fh, indent=2)
    print("Saved model:", model_path)
    print("Saved metrics:", metrics_path)
