# 05  Sequence Forecast (Greedy / Beam) for Task‑A

In [1]:
# Adjustable: add the parent directory (which contains utils/) to the Python search path.
import os
import sys

# ".." is resolved against the kernel's current working directory.
_PARENT_DIR = os.path.abspath("..")
# Guard the append so re-running this cell does not stack duplicate sys.path entries.
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

In [6]:
import numpy as np, pandas as pd, polars as pl, joblib
from utils.config import DATA_DIR, INTERIM_DIR, PROCESSED_DIR
from utils.splits import temporal_split, add_crisis_flag
from utils.candidates import build_origin_next_transitions, global_mf_next, build_pc_coords, build_candidates_for_split
from utils.features import build_ports_attr, compute_port_degree, attach_port_side, build_sample_side, merge_all_features
from sklearn.preprocessing import OneHotEncoder

# Parquet inputs: the Task-A sample table and the cleaned port-call table.
samples = pl.read_parquet(PROCESSED_DIR / "samples_taskA.parquet")
pc = pl.read_parquet(INTERIM_DIR / "port_calls.cleaned.parquet")

# CSV inputs, read with date parsing enabled.
tr = pl.read_csv(DATA_DIR / "trades.csv", try_parse_dates=True)
vs = pl.read_csv(DATA_DIR / "vessels.csv", try_parse_dates=True)

# Split the samples, then pass every partition through add_crisis_flag.
train, val, test = temporal_split(samples)
train = add_crisis_flag(train)
val = add_crisis_flag(val)
test = add_crisis_flag(test)

# Transition table is built from the training partition only; the global
# fallback candidate is derived from it. Coordinates come from the port calls.
trans = build_origin_next_transitions(train)
g_top = global_mf_next(trans)
pc_coords = build_pc_coords(pc)

shape: (1, 16)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ sample_po ┆ vessel_id ┆ destinati ┆ destinati ┆ … ┆ prev_dist ┆ last_leg_ ┆ product_f ┆ is_crisi │
│ rt_call_i ┆ ---       ┆ on        ┆ on_latitu ┆   ┆ _km       ┆ knots_est ┆ amily_dom ┆ s_time   │
│ d         ┆ i64       ┆ ---       ┆ de        ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---      │
│ ---       ┆           ┆ str       ┆ ---       ┆   ┆ f64       ┆ f64       ┆ str       ┆ bool     │
│ i64       ┆           ┆           ┆ f64       ┆   ┆           ┆           ┆           ┆          │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ 371013981 ┆ 56709     ┆ Texas     ┆ 29.370731 ┆ … ┆ 10.109735 ┆ 0.012947  ┆ chem/bio  ┆ true     │
│           ┆           ┆ City      ┆           ┆   ┆           ┆           ┆           ┆          │
└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────

In [4]:
# Load a trained ranker: prefer the GBDT pack when it exists, otherwise fall
# back to the logistic-regression artifact.
lr_path = PROCESSED_DIR / "model_taskA_logreg.joblib"
gbdt_path = PROCESSED_DIR / "model_taskA_gbdt.joblib"
use_gbdt = gbdt_path.exists()

# NOTE: joblib.load unpickles arbitrary Python objects. Only load artifacts
# that were produced by this project's own training notebooks.
if use_gbdt:
    # The GBDT pack is a dict carrying the classifier, the fitted encoder and
    # the numeric / categorical column lists.
    pack = joblib.load(gbdt_path)
    clf = pack["clf"]
    enc = pack["enc"]
    num_cols = pack["num_cols"]
    cat_cols = pack["cat_cols"]
elif lr_path.exists():
    clf = joblib.load(lr_path)  # pipeline
    # Reset the GBDT-only globals so values left over from an earlier
    # (out-of-order) run of this cell cannot leak into later cells.
    enc = None
    num_cols = None
    cat_cols = None
else:
    raise FileNotFoundError(
        f"No trained ranker found: expected {gbdt_path} or {lr_path}. "
        "Train a Task-A model before running this notebook."
    )

In [17]:
# Demo sample: the first row of the test split, kept as a one-row frame.
seed = test.slice(0, 1)
print(seed)

shape: (1, 16)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ sample_po ┆ vessel_id ┆ destinati ┆ destinati ┆ … ┆ prev_dist ┆ last_leg_ ┆ product_f ┆ is_crisi │
│ rt_call_i ┆ ---       ┆ on        ┆ on_latitu ┆   ┆ _km       ┆ knots_est ┆ amily_dom ┆ s_time   │
│ d         ┆ i64       ┆ ---       ┆ de        ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---      │
│ ---       ┆           ┆ str       ┆ ---       ┆   ┆ f64       ┆ f64       ┆ str       ┆ bool     │
│ i64       ┆           ┆           ┆ f64       ┆   ┆           ┆           ┆           ┆          │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ 371013981 ┆ 56709     ┆ Texas     ┆ 29.370731 ┆ … ┆ 10.109735 ┆ 0.012947  ┆ chem/bio  ┆ true     │
│           ┆           ┆ City      ┆           ┆   ┆           ┆           ┆           ┆          │
└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────

In [11]:
%%time
# Greedy one-step predictor
# The three tables below are built once at cell level and then read as
# globals by rank_topk_for_sample.
ports_attr  = build_ports_attr(pc_coords)         # passed to attach_port_side
port_degree = compute_port_degree(trans)          # derived from the train-only transition table
s_side      = build_sample_side(samples, pc, vs)  # built from the full samples frame, not a single split

def rank_topk_for_sample(sample_row: pl.DataFrame, k: int = 5) -> pd.DataFrame:
    """Score the candidate next ports for one sample and return the best ``k``.

    Reads the notebook-level globals ``trans``, ``pc_coords``, ``g_top``,
    ``ports_attr``, ``port_degree``, ``s_side`` and ``clf`` (plus ``enc`` when
    ``clf`` is not a pipeline).

    Parameters
    ----------
    sample_row : pl.DataFrame
        Exactly one row. Must contain ``sample_port_call_id``; if it also
        contains ``next_call_name`` that value is reported as the truth.
    k : int, default 5
        Maximum number of ranked candidates to return.

    Returns
    -------
    pd.DataFrame
        Columns ``sample_port_call_id``, ``candidate``, ``score`` and
        ``truth_next_port``, sorted by ``score`` descending, with a fresh
        0..n-1 index. Empty (same columns) when no candidate was generated.

    Raises
    ------
    ValueError
        If ``sample_row`` is empty or holds more than one row.
    """
    if sample_row.height == 0:
        raise ValueError("sample_row is empty; supply at least one sample_port_call_id.")
    if sample_row.height > 1:
        # .item() below needs a 1x1 frame; fail early with an actionable message.
        raise ValueError(
            f"sample_row must contain exactly one row, got {sample_row.height}; "
            "pass one sample at a time."
        )
    sample_id = sample_row.select("sample_port_call_id").item()
    truth = sample_row.select("next_call_name").item() if "next_call_name" in sample_row.columns else None

    cands = build_candidates_for_split(
        sample_row,
        trans,
        pc_coords,
        add_true_label=False,
        N=10,
        M=10,
        global_top1=g_top,
    )
    # Keep only the identifying columns (those that exist) before features are attached.
    base_cols = ["sample_port_call_id", "origin", "candidate", "label", "y"]
    cands = cands.select([c for c in base_cols if c in cands.columns])

    cands = attach_port_side(cands, ports_attr, port_degree)

    sample_row = add_crisis_flag(sample_row)
    # Drop origin/candidate from the sample side before merging it onto the candidates.
    sample_row = sample_row.drop([c for c in ("origin", "candidate") if c in sample_row.columns])

    cands = merge_all_features(cands, s_side, sample_row)
    # keep="first" + maintain_order=True: de-duplicate without shuffling rows,
    # so this step does not add run-to-run ordering noise.
    cands = cands.unique(
        subset=["sample_port_call_id", "candidate"],
        keep="first",
        maintain_order=True,
    )

    # NOTE(review): these lists are hard-coded here and shadow the
    # num_cols/cat_cols shipped in the GBDT pack; confirm they match the
    # columns (and order) the model and encoder were fitted on.
    num_cols = [
        "dist_km", "is_same_region", "in_cnt", "out_cnt", "age",
        "prev_dist_km", "last_leg_knots_est",
        "month_sin", "month_cos", "dow_sin", "dow_cos",
        "is_crisis_time", "dist_x_crisis",
    ]
    cat_cols = ["origin", "candidate", "vessel_type", "dwt_bucket", "product_family_dom"]
    cols = list(dict.fromkeys(["sample_port_call_id", "origin", "candidate"] + num_cols + cat_cols))

    # Back-fill any feature column the merge did not produce with a constant.
    fillers = [pl.lit(0.0).alias(c) for c in num_cols if c not in cands.columns]
    fillers += [pl.lit("unk").alias(c) for c in cat_cols if c not in cands.columns]
    if fillers:
        cands = cands.with_columns(fillers)

    pdf = cands.select(cols).to_pandas()

    if pdf.empty:
        # Nothing to score: return an empty frame with the usual columns
        # instead of calling predict_proba on zero rows.
        return pd.DataFrame(
            columns=["sample_port_call_id", "candidate", "score", "truth_next_port"]
        )

    if hasattr(clf, "predict_proba") and hasattr(clf, "steps"):
        # Pipeline: it receives the raw feature frame.
        X = pdf[num_cols + cat_cols]
    else:
        # Bare classifier: assemble the matrix by hand. Force a float matrix,
        # because mixing bool and float columns would otherwise give an object array.
        X_num = pdf[num_cols].to_numpy(dtype=np.float64)
        X_cat = enc.transform(pdf[cat_cols])
        if hasattr(X_cat, "toarray"):
            # Encoders may return a scipy sparse matrix, which np.hstack
            # cannot combine with a dense array.
            X_cat = X_cat.toarray()
        X = np.hstack([X_num, X_cat])
    proba = clf.predict_proba(X)[:, 1]

    pdf["score"] = proba
    # mergesort is stable, so tied scores keep their incoming order.
    ranked = pdf.sort_values("score", ascending=False, kind="mergesort")
    topk = ranked.head(k)[["candidate", "score"]].reset_index(drop=True)
    topk.insert(0, "sample_port_call_id", sample_id)
    topk["truth_next_port"] = truth
    return topk

def preview_samples(split: str = "test", sample_ids: list[int] | None = None, n: int = 3, k: int = 5) -> pd.DataFrame:
    """Rank the top-``k`` candidates for several samples of one split.

    Reads the notebook-level globals ``train``, ``val`` and ``test`` and calls
    ``rank_topk_for_sample`` once per selected row.

    Parameters
    ----------
    split : str, default "test"
        One of ``"train"``, ``"val"`` or ``"test"``.
    sample_ids : list[int] | None
        ``sample_port_call_id`` values to preview. When ``None``, ``n`` rows
        are sampled from the split with a fixed seed (42).
    n : int, default 3
        Number of rows to sample when ``sample_ids`` is ``None``; capped at the
        split's height.
    k : int, default 5
        Number of ranked candidates per sample.

    Returns
    -------
    pd.DataFrame
        The per-sample rankings concatenated with a fresh index. Empty, with
        columns ``sample_port_call_id``, ``candidate``, ``score`` and
        ``truth_next_port``, when no sample was selected.

    Raises
    ------
    KeyError
        If ``split`` is not a known split name.
    """
    split_map = {"train": train, "val": val, "test": test}
    if split not in split_map:
        raise KeyError(f"unknown split {split!r}; expected one of {sorted(split_map)}")
    base = split_map[split]

    if sample_ids is None:
        n = min(n, base.height)
        picks = base.sample(n=n, seed=42)
    else:
        picks = base.filter(pl.col("sample_port_call_id").is_in(sample_ids))

    # Hand each picked row over as a one-row slice. This avoids re-scanning the
    # whole split per id, and never yields a multi-row frame when ids repeat.
    frames = [rank_topk_for_sample(picks.slice(i, 1), k=k) for i in range(picks.height)]

    if not frames:
        # pd.concat raises on an empty list; return an empty result instead.
        return pd.DataFrame(
            columns=["sample_port_call_id", "candidate", "score", "truth_next_port"]
        )
    return pd.concat(frames, ignore_index=True)

    sample_port_call_id         candidate     score truth_next_port
3             371013981          Quintero  0.900633      Texas City
2             371013981           Houston  0.723351      Texas City
11            371013981             Tampa  0.667947      Texas City
4             371013981  Galveston Light.  0.624828      Texas City
12            371013981     Puerto Cortes  0.574844      Texas City
   sample_port_call_id      candidate     score truth_next_port
0            371140677          Chiba  0.744085           Chiba
1            371140677       Kawasaki  0.695198           Chiba
2            371140677          Ulsan  0.648185           Chiba
3            370979733  San Francisco  0.706300   San Francisco
4            370979733     Esmeraldas  0.658761   San Francisco
5            370979733  Panama Light.  0.614956   San Francisco
CPU times: user 3min 20s, sys: 995 ms, total: 3min 21s
Wall time: 3min 21s


In [18]:
# Example: rank the top-5 candidates for the single `seed` sample defined above
print(rank_topk_for_sample(seed, k=5))

    sample_port_call_id         candidate     score truth_next_port
1             371013981          Quintero  0.900633      Texas City
7             371013981           Houston  0.723351      Texas City
0             371013981             Tampa  0.667947      Texas City
12            371013981  Galveston Light.  0.624828      Texas City
9             371013981     Puerto Cortes  0.574844      Texas City


In [20]:
# Example: sample 5 test rows (preview_samples uses a fixed seed) and view the top-3 candidates for each
print(preview_samples(split="test", n=5, k=3))

    sample_port_call_id              candidate     score  \
0             371070683                  Chiba  0.725499   
1             371070683               Kawasaki  0.674615   
2             371070683                  Ulsan  0.626134   
3             371284739  Khor Al Zubair Light.  0.792790   
4             371284739                  Sikka  0.589102   
5             371284739    Khor Al Zubair Port  0.583356   
6             371220225                  Chiba  0.768799   
7             371220225                  Ulsan  0.678155   
8             371220225                 Nagoya  0.668338   
9             371174635              Rotterdam  0.753377   
10            371174635                Antwerp  0.698497   
11            371174635               Goteborg  0.610989   
12            371108205                Malacca  0.674182   
13            371108205                  Dumai  0.668929   
14            371108205               Surabaya  0.641927   

          truth_next_port  
0          