# Fuel-Segment Preprocessing

This notebook keeps only the fields DSPy needs for each fuel segment:
`idx`, `flight_id`, `fuel_kg`, `flight_date`, `aircraft_type`, `origin_name`,
`destination_name`, plus a compact `track_summary`, the readable
`track_points_compact` string, and the `vertical_rate_balance` structure.

In [16]:

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from collections import defaultdict, Counter
from datetime import datetime
import math

import polars as pl
from tqdm.auto import tqdm

pl.Config.set_tbl_rows(20)


def _locate_data_root(start: Path) -> Path:
    for candidate in (start, *start.parents):
        data_dir = candidate / "data"
        if data_dir.exists():
            return candidate
    raise FileNotFoundError("Unable to locate the `data/` directory relative to this notebook.")


PROJECT_ROOT = _locate_data_root(Path.cwd())
DATA_DIR = PROJECT_ROOT / "data"

DATA_TYPE: Literal["train", "rank", "final"] = "final"

fuel_filename = "fuel_train.parquet" if DATA_TYPE == "train" else f"fuel_{DATA_TYPE}_submission.parquet"
FUEL_FILE = DATA_DIR / fuel_filename
FLIGHTLIST_FILE = DATA_DIR / f"flightlist_{DATA_TYPE}.parquet"
FLIGHTS_DIR = DATA_DIR / f"flights_{DATA_TYPE}"
OUTPUT_FILE = DATA_DIR / DATA_TYPE / f"llm_segments.parquet"

for required in (FUEL_FILE, FLIGHTLIST_FILE):
    if not required.exists():
        raise FileNotFoundError(required)
FLIGHTS_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Fuel file: {FUEL_FILE.name}")
print(f"Flight list file: {FLIGHTLIST_FILE.name}")
print(f"Flight tracks folder: {FLIGHTS_DIR}")
print(f"Output parquet: {OUTPUT_FILE}")


Project root: c:\Users\rayte\Work\prc2025dspy
Fuel file: fuel_final_submission.parquet
Flight list file: flightlist_final.parquet
Flight tracks folder: c:\Users\rayte\Work\prc2025dspy\data\flights_final
Output parquet: c:\Users\rayte\Work\prc2025dspy\data\final\llm_segments.parquet


## Helper functions

Everything required to summarise track points lives in this notebook so the
preprocessing does not depend on `prc_challenge` modules.

In [17]:

_NUMERIC_TRACK_COLUMNS = ("altitude", "groundspeed", "vertical_rate", "mach", "TAS", "CAS")
_PHASE_TO_CODE = {"climb": 1, "descent": -1, "cruise": 0, "level": 0, "mixed": 2, "unknown": 99}
TRACK_POINT_COLUMNS = [
    "timestamp",
    "source",
    "latitude",
    "longitude",
    "altitude",
    "groundspeed",
    "vertical_rate",
    "mach",
    "TAS",
    "CAS",
]


def _format_number(value: Any) -> str:
    if value is None:
        return "NA"
    try:
        numeric = float(value)
    except (TypeError, ValueError):
        return str(value)
    if math.isnan(numeric):
        return "NA"
    return format(numeric, ".4g")


def _parse_timestamp(value: Any) -> Optional[datetime]:
    if isinstance(value, datetime):
        return value
    if isinstance(value, str):
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return None
    return None


def _sample_indices(length: int, sample_size: int) -> List[int]:
    if length <= 0:
        return []
    if sample_size <= 1 or length == 1:
        return [0]
    step = (length - 1) / max(sample_size - 1, 1)
    indices: List[int] = []
    for i in range(sample_size):
        idx = int(round(i * step))
        if idx >= length:
            idx = length - 1
        if not indices or idx != indices[-1]:
            indices.append(idx)
    if indices[-1] != length - 1:
        indices.append(length - 1)
    return indices


def _clean_numeric_series(track_points: List[Dict[str, Any]], key: str) -> List[float]:
    values: List[float] = []
    for point in track_points:
        if not isinstance(point, dict):
            continue
        value = point.get(key)
        if value is None:
            continue
        try:
            numeric = float(value)
        except (TypeError, ValueError):
            continue
        if math.isnan(numeric):
            continue
        values.append(numeric)
    return values


def _safe_mean(values: List[float]) -> Optional[float]:
    if not values:
        return None
    return sum(values) / len(values)


def _safe_std(values: List[float], mean: Optional[float]) -> Optional[float]:
    if mean is None or len(values) < 2:
        return None
    variance = sum((value - mean) ** 2 for value in values) / (len(values) - 1)
    return math.sqrt(variance)


def _summarise_track_points(track_points: List[Dict[str, Any]], sample_count: int = 4) -> Dict[str, Any]:
    summary: Dict[str, Any] = {
        "num_points": len(track_points),
        "time_window": None,
        "source_counts": {},
        "numeric_profiles": {},
        "path_profile": None,
        "aggregate_features": {},
        "vertical_rate_balance": None,
        "phase_hint": "unknown",
    }
    if not track_points:
        return summary

    start_ts = track_points[0].get("timestamp")
    end_ts = track_points[-1].get("timestamp")
    start_dt = _parse_timestamp(start_ts)
    end_dt = _parse_timestamp(end_ts)
    time_window: Dict[str, Any] = {"start": start_ts, "end": end_ts, "minutes": None}
    if start_dt and end_dt:
        duration_minutes = (end_dt - start_dt).total_seconds() / 60.0
        time_window["minutes"] = round(duration_minutes, 3)
    summary["time_window"] = time_window

    source_counts = Counter((point.get("source") or "unknown") for point in track_points)
    summary["source_counts"] = dict(source_counts)

    numeric_profiles: Dict[str, Dict[str, Any]] = {}
    aggregates: Dict[str, float] = {}
    vertical_balance: Optional[Dict[str, float]] = None
    for column in _NUMERIC_TRACK_COLUMNS:
        series = _clean_numeric_series(track_points, column)
        if not series:
            continue
        indices = _sample_indices(len(series), sample_count)
        samples = [round(series[i], 4) for i in indices]
        delta = series[-1] - series[0]
        col_min = min(series)
        col_max = max(series)
        mean_val = _safe_mean(series)
        std_val = _safe_std(series, mean_val)
        profile: Dict[str, Any] = {
            "samples": samples,
            "delta": round(delta, 4),
            "range": round(col_max - col_min, 4),
            "min": round(col_min, 4),
            "max": round(col_max, 4),
        }
        if mean_val is not None:
            profile["mean"] = round(mean_val, 4)
        if std_val is not None:
            profile["std"] = round(std_val, 4)
        numeric_profiles[column] = profile
        aggregates[f"{column}_delta"] = round(delta, 4)
        aggregates[f"{column}_range"] = round(col_max - col_min, 4)
        if mean_val is not None:
            aggregates[f"{column}_mean"] = round(mean_val, 4)
        if std_val is not None:
            aggregates[f"{column}_std"] = round(std_val, 4)
        if column == "vertical_rate":
            threshold = 64.0
            positives = sum(1 for value in series if value > threshold)
            negatives = sum(1 for value in series if value < -threshold)
            total = len(series)
            zeros = total - positives - negatives
            vertical_balance = {
                "positive_frac": round(positives / total, 4),
                "negative_frac": round(negatives / total, 4),
                "near_zero_frac": round(max(zeros, 0) / total, 4),
            }
            aggregates["vertical_rate_positive_frac"] = vertical_balance["positive_frac"]
            aggregates["vertical_rate_negative_frac"] = vertical_balance["negative_frac"]
            aggregates["vertical_rate_near_zero_frac"] = vertical_balance["near_zero_frac"]
    summary["numeric_profiles"] = numeric_profiles
    summary["aggregate_features"] = aggregates
    summary["vertical_rate_balance"] = vertical_balance
    summary["phase_hint"] = _infer_phase(summary)
    aggregates["phase_hint_code"] = _PHASE_TO_CODE.get(summary["phase_hint"], 99)

    lat_series = _clean_numeric_series(track_points, "latitude")
    lon_series = _clean_numeric_series(track_points, "longitude")
    if lat_series and lon_series:
        idxs = _sample_indices(len(lat_series), sample_count)
        path_profile = {
            "lat_samples": [round(lat_series[i], 4) for i in idxs],
            "lon_samples": [round(lon_series[i], 4) for i in idxs],
            "delta_lat": round(lat_series[-1] - lat_series[0], 4),
            "delta_lon": round(lon_series[-1] - lon_series[0], 4),
        }
        summary["path_profile"] = path_profile
        aggregates["delta_lat"] = path_profile["delta_lat"]
        aggregates["delta_lon"] = path_profile["delta_lon"]

    return summary


def _infer_phase(summary: Dict[str, Any]) -> str:
    aggregates = summary.get("aggregate_features") or {}
    numeric = summary.get("numeric_profiles") or {}
    altitude_profile = numeric.get("altitude") or {}
    vertical_profile = numeric.get("vertical_rate") or {}
    delta_alt = altitude_profile.get("delta")
    if delta_alt is None:
        delta_alt = aggregates.get("altitude_delta")
    mean_vr = vertical_profile.get("mean")
    vrange = vertical_profile.get("range")
    balance = summary.get("vertical_rate_balance") or {}
    pos_frac = balance.get("positive_frac") if balance else None
    neg_frac = balance.get("negative_frac") if balance else None
    mean_vr = mean_vr if mean_vr is not None else aggregates.get("vertical_rate_mean")
    delta_alt = delta_alt if delta_alt is not None else 0.0
    mean_vr = mean_vr if mean_vr is not None else 0.0
    if delta_alt > 800 or mean_vr > 150:
        return "climb"
    if delta_alt < -800 or mean_vr < -150:
        return "descent"
    if vrange is not None and vrange < 200 and abs(mean_vr) < 80:
        return "cruise"
    if pos_frac is not None and neg_frac is not None and pos_frac > 0.2 and neg_frac > 0.2:
        return "mixed"
    return "level"


def _format_track_summary(summary: Dict[str, Any]) -> str:
    if not summary.get("num_points"):
        return "no track points"
    parts: List[str] = []
    time_window = summary.get("time_window") or {}
    start = time_window.get("start")
    end = time_window.get("end")
    minutes = time_window.get("minutes")
    if start and end:
        segment = f"time {start}->{end}"
        if minutes is not None:
            segment += f" ({_format_number(minutes)} min)"
        parts.append(segment)
    sources = summary.get("source_counts") or {}
    if sources:
        source_text = ", ".join(f"{src}:{count}" for src, count in sorted(sources.items()))
        parts.append(f"sources {source_text}")
    numeric_profiles = summary.get("numeric_profiles") or {}
    for column in ("altitude", "groundspeed", "vertical_rate", "mach"):
        stats = numeric_profiles.get(column)
        if not stats:
            continue
        samples = stats.get("samples") or []
        sample_text = " -> ".join(_format_number(val) for val in samples) if samples else "n/a"
        extras = []
        delta = stats.get("delta")
        if delta is not None:
            extras.append(f"delta {_format_number(delta)}")
        value_range = stats.get("range")
        if value_range is not None:
            extras.append(f"range {_format_number(value_range)}")
        mean_val = stats.get("mean")
        if mean_val is not None:
            extras.append(f"mean {_format_number(mean_val)}")
        if extras:
            parts.append(f"{column} {sample_text} ({', '.join(extras)})")
        else:
            parts.append(f"{column} {sample_text}")
    path_profile = summary.get("path_profile") or {}
    lat_samples = path_profile.get("lat_samples")
    lon_samples = path_profile.get("lon_samples")
    if lat_samples and lon_samples:
        pairs = " -> ".join(
            f"{_format_number(lat)}/{_format_number(lon)}"
            for lat, lon in zip(lat_samples, lon_samples)
        )
        parts.append(f"path {pairs}")
    delta_lat = path_profile.get("delta_lat")
    delta_lon = path_profile.get("delta_lon")
    if delta_lat is not None or delta_lon is not None:
        parts.append(f"delta_lat {_format_number(delta_lat)} delta_lon {_format_number(delta_lon)}")
    phase = summary.get("phase_hint")
    if phase:
        parts.append(f"phase {phase}")
    balance = summary.get("vertical_rate_balance") or {}
    if balance:
        parts.append(
            "vr balance +{pos:.2f} / -{neg:.2f} / ~0 {zero:.2f}".format(
                pos=balance.get("positive_frac", 0.0),
                neg=balance.get("negative_frac", 0.0),
                zero=balance.get("near_zero_frac", 0.0),
            )
        )
    compact_text = " | ".join(parts) if parts else "track summary unavailable"
    return compact_text[:600] + ("..." if len(compact_text) > 600 else "")


def _slice_track_window(track_df: Optional[pl.DataFrame], start: datetime, end: datetime) -> pl.DataFrame:
    if track_df is None or track_df.is_empty():
        return pl.DataFrame()
    available_cols = [col for col in TRACK_POINT_COLUMNS if col in track_df.columns]
    if not available_cols:
        return pl.DataFrame()
    return (
        track_df
        .filter((pl.col("timestamp") >= start) & (pl.col("timestamp") <= end))
        .select(available_cols)
        .sort("timestamp")
    )


def _prepare_track_points(df: pl.DataFrame) -> List[Dict[str, Any]]:
    if df.is_empty():
        return []
    records = df.to_dicts()
    for record in records:
        ts = record.get("timestamp")
        if hasattr(ts, "isoformat"):
            record["timestamp"] = ts.isoformat()
        record.setdefault("source", "unknown")
    return records


def summarise_segment_track(track_df: Optional[pl.DataFrame], start: datetime, end: datetime) -> Dict[str, Any]:
    window = _slice_track_window(track_df, start, end)
    track_points = _prepare_track_points(window)
    return _summarise_track_points(track_points)


## Load fuel segments and flight metadata

In [18]:

fuel_df = pl.read_parquet(FUEL_FILE).select(["idx", "flight_id", "start", "end", "fuel_kg"])
flightlist_df = pl.read_parquet(FLIGHTLIST_FILE).select([
    "flight_id",
    "flight_date",
    "aircraft_type",
    "origin_name",
    "destination_name",
])

segments_df = (
    fuel_df
    .join(flightlist_df, on="flight_id", how="left")
    .select([
        "idx",
        "flight_id",
        "fuel_kg",
        "flight_date",
        "aircraft_type",
        "origin_name",
        "destination_name",
        "start",
        "end",
    ])
    .sort(["flight_id", "start"])
)

print(f"Segments: {segments_df.height:,} across {segments_df['flight_id'].n_unique():,} flights")
segments_df.head()


Segments: 61,745 across 4,724 flights


idx,flight_id,fuel_kg,flight_date,aircraft_type,origin_name,destination_name,start,end
i64,str,null,date,str,str,str,datetime[ns],datetime[ns]
0,"""prc806615763""",,,,,,2025-09-01 03:03:10.925,2025-09-01 03:07:51.584
1,"""prc806615763""",,,,,,2025-09-01 03:07:51.584,2025-09-01 03:12:50.921
2,"""prc806615763""",,,,,,2025-09-01 03:12:50.921,2025-09-01 03:17:51.404
3,"""prc806615763""",,,,,,2025-09-01 03:17:51.404,2025-09-01 03:22:50.539
4,"""prc806615763""",,,,,,2025-09-01 03:22:50.539,2025-09-01 03:27:50.727


In [23]:
fuel_df

idx,flight_id,start,end,fuel_kg
i64,str,datetime[ns],datetime[ns],null
0,"""prc806615763""",2025-09-01 03:03:10.925,2025-09-01 03:07:51.584,
1,"""prc806615763""",2025-09-01 03:07:51.584,2025-09-01 03:12:50.921,
2,"""prc806615763""",2025-09-01 03:12:50.921,2025-09-01 03:17:51.404,
3,"""prc806615763""",2025-09-01 03:17:51.404,2025-09-01 03:22:50.539,
4,"""prc806615763""",2025-09-01 03:22:50.539,2025-09-01 03:27:50.727,
5,"""prc806615763""",2025-09-01 03:27:50.727,2025-09-01 03:32:50.802,
6,"""prc806615763""",2025-09-01 03:32:50.802,2025-09-01 03:37:51.778,
7,"""prc806615763""",2025-09-01 03:37:51.778,2025-09-01 03:42:50.706,
8,"""prc806615763""",2025-09-01 03:42:50.706,2025-09-01 03:47:50.452,
9,"""prc806615763""",2025-09-01 03:47:50.452,2025-09-01 03:52:51.183,


## Build the compact dataset

In [19]:
SCHEMA = {
    "idx": pl.Int64,
    "flight_id": pl.Utf8,
    "fuel_kg": pl.Float64,
    "flight_date": pl.Date,
    "aircraft_type": pl.Utf8,
    "origin_name": pl.Utf8,
    "destination_name": pl.Utf8,
    "start": pl.Datetime("ms"),
    "end": pl.Datetime("ms"),
    "track_points_compact": pl.Utf8,
    "vertical_rate_balance": pl.Struct([
        pl.Field("positive_frac", pl.Float64),
        pl.Field("negative_frac", pl.Float64),
        pl.Field("near_zero_frac", pl.Float64),
    ]),
}

def build_segment_dataset(segments: pl.DataFrame) -> pl.DataFrame:
    grouped = defaultdict(list)
    for row in segments.iter_rows(named=True):
        grouped[row["flight_id"]].append(row)

    processed: List[Dict[str, Any]] = []
    for flight_id, rows in tqdm(grouped.items(), desc="Summarising segments", total=len(grouped)):
        flight_path = FLIGHTS_DIR / f"{flight_id}.parquet"
        track_df = None
        if flight_path.exists():
            track_df = pl.read_parquet(flight_path).sort("timestamp")
        for row in rows:
            summary = summarise_segment_track(track_df, row["start"], row["end"])
            processed.append(
                {
                    "idx": row["idx"],
                    "flight_id": row["flight_id"],
                    "fuel_kg": row["fuel_kg"],
                    "flight_date": row["flight_date"],
                    "aircraft_type": row["aircraft_type"],
                    "origin_name": row["origin_name"],
                    "destination_name": row["destination_name"],
                    # "track_summary": summary,
                    "track_points_compact": _format_track_summary(summary),
                    "vertical_rate_balance": summary.get("vertical_rate_balance"),
                }
            )
        del track_df
    return pl.from_dicts(processed, schema=SCHEMA)

segment_records = build_segment_dataset(segments_df)
segment_records.write_parquet(OUTPUT_FILE)
print(f"Wrote {segment_records.height:,} rows to {OUTPUT_FILE}")
segment_records.head()


Summarising segments:   0%|          | 0/4724 [00:00<?, ?it/s]

Wrote 61,745 rows to c:\Users\rayte\Work\prc2025dspy\data\final\llm_segments.parquet


idx,flight_id,fuel_kg,flight_date,aircraft_type,origin_name,destination_name,start,end,track_points_compact,vertical_rate_balance
i64,str,f64,date,str,str,str,datetime[ms],datetime[ms],str,struct[3]
0,"""prc806615763""",,,,,,,,"""no track points""",
1,"""prc806615763""",,,,,,,,"""no track points""",
2,"""prc806615763""",,,,,,,,"""no track points""",
3,"""prc806615763""",,,,,,,,"""no track points""",
4,"""prc806615763""",,,,,,,,"""no track points""",


In [20]:
segment_records[0]["track_points_compact"]

track_points_compact
str
"""no track points"""
