## DSPy train/eval

Note: I use my own fork of DSPy because I had to implement asynchronous batching: https://github.com/rayanehmi/dspy/tree/feat/async_batching

In [None]:
from pathlib import Path
from typing import Literal
import os 

DATA_PATH = Path.cwd().parent / "data"
DATA_TYPE : Literal["train", "rank", "final"] = "train"
OUTPUT_DIR = os.path.join(DATA_PATH)


In [None]:
# Load complete data
import polars as pl

SEGMENTS_PATH = os.path.join(OUTPUT_DIR, "llm_segments_" + DATA_TYPE + ".parquet")
df = pl.read_parquet(SEGMENTS_PATH)
df.head()

In [None]:
import dspy
from dotenv import load_dotenv
load_dotenv()

In [None]:
api_key = os.getenv("OPENAI_API_KEY")
gpt_4_1 = dspy.LM("openai/gpt-4.1", api_key=api_key)
gpt_4_1_nano = dspy.LM("openai/gpt-4.1-nano", api_key=api_key)
groq_api_key = os.getenv("GROQ_API_KEY")
gpt_oss_120b = dspy.LM("groq/openai/gpt-oss-120b", api_key=groq_api_key, cache=False)
gpt_5_1_instant = dspy.LM("openai/gpt-5.1", api_key=api_key, temperature=1.0, max_tokens=32000, reasoning_effort="none", cache=False)

json_adapter = dspy.JSONAdapter()

dspy.configure(lm=gpt_oss_120b, adapter=json_adapter)

In [52]:
from typing import Any

class BurntFuelPrediction(dspy.Signature):
    """Predict the amount of fuel burnt in kgs by the plane over the given segment of flight.
    Segment data is constructed from noisy telemetry: use your common sense if values seem wrong.
    hint: vertical_rate_balance contains positive_frac, negative_frac and near_zero_frac, each corresponding 
    to the share of samples where vertical rate is respectively more than, less than or around 64 ft/min.
    hint 2: estimate the fuel weight penalty (heavy in the beginning, lighter in the end).
    """
    features : dict[str, Any] = dspy.InputField()
    fuel_kg : float = dspy.OutputField() 

# Zero-shot chain of thought
fuel_cot = dspy.ChainOfThought(BurntFuelPrediction)

In [35]:
def df_to_dspy_examples(row, with_fuel: bool = True):
    """Converts a row to a dspy.Example."""
    row_data = row.to_dicts()[0] if hasattr(row, "to_dicts") else row

    def clean(value):
        return "" if value is None else value

    inputs = [
        "aircraft_type",
        "origin_name",
        "origin_destination",
        "track_points_compact",
        "track_points_compact",
        "vertical_rate_balance"
    ]
    
    features = {key: clean(row_data.get(key)) for key in inputs}
    example = dspy.Example(features=features).with_inputs("features")
    if with_fuel:
        example.fuel_kg = clean(row_data.get("fuel_kg"))
    return example


In [36]:
examples = [
    df_to_dspy_examples(row, with_fuel=True)
    for row in df.iter_rows(named=True)
]
examples[0]

Example({'features': {'aircraft_type': 'B789', 'origin_name': 'Kuala Lumpur International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-04-13T02:31:04.447000->2025-04-13T03:01:04.487000 (30 min) | sources acars:2, adsb:3076 | altitude 3.597e+04 -> 3.597e+04 -> 3.6e+04 -> 3.6e+04 (delta 24.99, range 24.99, mean 3.599e+04) | groundspeed 467 -> 474 -> 471 -> 476 (delta 9, range 15, mean 472) | vertical_rate 0 -> 0 -> 0 -> -64 (delta -64, range 128, mean -3.36) | mach 0.86 -> 0.86 (delta 0, range 0, mean 0.86) | path 45.18/24.35 -> 45.9/22.72 -> 46.62/20.97 -> 47.22/19.52 | delta_lat 2.033 delta_lon -4.833 | phase cruise | vr balance +0.00 / -0.00 / ~0 1.00', 'vertical_rate_balance': {'positive_frac': 0.0, 'negative_frac': 0.0, 'near_zero_frac': 1.0}}, 'fuel_kg': 2500.0}) (input_keys={'features'})

In [37]:
import random
import copy

randomized_examples = copy.deepcopy(examples)
random.Random(42).shuffle(randomized_examples)

## Metrics

In [38]:
import math

def float_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None):
    """Return a scalar score (negative squared error) for the evaluator."""
    true_value = gold.fuel_kg
    pred_value = pred.fuel_kg
    if true_value is None or pred_value is None:
        return float("nan")

    squared_error = (true_value - pred_value) ** 2

    if trace is None: # if we're doing evaluation or optimization
        return -squared_error
    else:  # During bootstrapping / trace collection we simply mark good demos.
        return squared_error < 40000  # Squared error 200.
    
fake_example = dspy.Example(features={"foo": "bar"}, fuel_kg=500.0)
fake_prediction = dspy.Prediction(features={"foo": "bar"}, fuel_kg=600.0)
print(float_metric(fake_example, fake_prediction))  # error -100.0
print(float_metric(fake_example, fake_prediction, trace='foo'))  # True

-10000.0
True


## Evaluation

In [39]:
from dspy.evaluate import Evaluate

In [47]:
evaluator_50 = Evaluate(
    devset=randomized_examples[:50],
    num_threads=50, 
    display_progress=True, 
    display_table=True
)

evaluator_500 = Evaluate(
    devset=randomized_examples[:400],
    num_threads=15, 
    display_progress=True, 
    display_table=True
)

In [48]:
with dspy.context(lm=gpt_oss_120b):
    eval_results = evaluator_500(fuel_cot, metric=float_metric)

Average Metric: -26070998.89 / 231 (-11286146.7%):  57%|█████▊    | 230/400 [00:20<00:48,  3.48it/s]



Average Metric: -26091407.26 / 232 (-11246296.2%):  58%|█████▊    | 232/400 [00:20<00:32,  5.22it/s]



Average Metric: -26186239.29 / 235 (-11143080.5%):  59%|█████▉    | 235/400 [00:21<00:36,  4.51it/s]



Average Metric: -26186239.29 / 236 (-11095864.1%):  59%|█████▉    | 236/400 [00:21<00:32,  5.05it/s]



Average Metric: -32147972.99 / 247 (-13015373.7%):  62%|██████▏   | 247/400 [00:23<00:35,  4.35it/s]

2025/11/29 22:00:36 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A320', 'origin_name': 'Toronto Pearson International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-06-29T08:15:15.211000->2025-06-29T08:20:15.498000 (5.005 min) | sources acars:2, adsb:467 | altitude 3.4e+04 -> 3.4e+04 -> 3.4e+04 -> 3.4e+04 (delta 0, range 49.99, mean 3.401e+04) | groundspeed 446 -> 446 -> 448 -> 449 (delta 3, range 4, mean 447.1) | vertical_rate 0 -> 0 -> 0 -> 0 (delta 0, range 128, mean -4.861) | mach 0.774 -> 0.774 (delta 0, range 0, mean 0.774) | path 31.84/-81.95 -> 31.62/-82.01 -> 31.41/-82.05 -> 31.23/-82.09 | delta_lat -0.61 delta_lon -0.1383 | phase cruise | vr balance +0.00 / -0.00 / ~0 1.00', 'vertical_rate_balance': {'positive_frac': 0.0, 'negative_frac': 0.0, 'near_zero_frac': 1.0}}, 'fuel_kg': 200.0}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {"error":{"message":"Rate limit reached for model

Average Metric: -32467250.09 / 254 (-12782381.9%):  64%|██████▎   | 254/400 [00:24<00:24,  5.96it/s]



Average Metric: -32618923.44 / 259 (-12594178.9%):  65%|██████▍   | 259/400 [00:25<00:27,  5.08it/s]



Average Metric: -32652517.79 / 264 (-12368378.0%):  66%|██████▋   | 265/400 [00:27<00:28,  4.77it/s]



Average Metric: -32673621.53 / 267 (-12237311.4%):  67%|██████▋   | 268/400 [00:27<00:29,  4.55it/s]

2025/11/29 22:00:40 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A320', 'origin_name': 'Uruapan International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-07-12T04:34:52.818000->2025-07-12T04:39:53.385000 (5.009 min) | sources acars:2, adsb:608 | altitude 2.63e+04 -> 2.253e+04 -> 1.912e+04 -> 1.581e+04 (delta -1.049e+04, range 1.049e+04, mean 2.087e+04) | groundspeed 429 -> 400 -> 375 -> 354 (delta -75, range 75, mean 389.6) | vertical_rate -2496 -> -2176 -> -2112 -> -1920 (delta 576, range 832, mean -2096) | mach 0.71 -> 0.561 (delta -0.149, range 0.149, mean 0.6355) | path 33.1/-116.6 -> 33.15/-116.8 -> 33.18/-117 -> 33.2/-117.2 | delta_lat 0.1017 delta_lon -0.6328 | phase descent | vr balance +0.00 / -1.00 / ~0 0.00', 'vertical_rate_balance': {'positive_frac': 0.0, 'negative_frac': 1.0, 'near_zero_frac': 0.0}}, 'fuel_kg': 100.0}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {"error":{

Average Metric: -33977180.48 / 272 (-12491610.5%):  68%|██████▊   | 274/400 [00:28<00:19,  6.59it/s]



Average Metric: -34170305.50 / 275 (-12425565.6%):  69%|██████▉   | 277/400 [00:29<00:19,  6.46it/s]



Average Metric: -34194450.74 / 278 (-12300162.1%):  70%|███████   | 280/400 [00:30<00:23,  5.19it/s]



Average Metric: -34258273.65 / 282 (-12148324.0%):  71%|███████   | 283/400 [00:30<00:21,  5.37it/s]



Average Metric: -35563740.84 / 292 (-12179363.3%):  73%|███████▎  | 293/400 [00:32<00:12,  8.37it/s]

2025/11/29 22:00:44 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A20N', 'origin_name': 'Quito Mariscal Sucre International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-08-10T03:43:18.399000->2025-08-10T03:48:18.755000 (5.006 min) | sources acars:2, adsb:589 | altitude 3.7e+04 -> 3.7e+04 -> 3.7e+04 -> 3.7e+04 (delta 0, range 49.99, mean 3.7e+04) | groundspeed 474 -> 474 -> 475 -> 475 (delta 1, range 1, mean 474.3) | vertical_rate 64 -> 0 -> 64 -> 0 (delta -64, range 128, mean -1.099) | mach 0.783 -> 0.779 (delta -0.004, range 0.004, mean 0.781) | path 35.31/-76.97 -> 35.52/-76.87 -> 35.72/-76.76 -> 35.91/-76.65 | delta_lat 0.6017 delta_lon 0.315 | phase cruise | vr balance +0.00 / -0.00 / ~0 1.00', 'vertical_rate_balance': {'positive_frac': 0.0, 'negative_frac': 0.0, 'near_zero_frac': 1.0}}, 'fuel_kg': 74.389088}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {"error":{"message":"Rate limi

Average Metric: -35568650.38 / 293 (-12139471.1%):  74%|███████▍  | 295/400 [00:32<00:13,  7.98it/s]



Average Metric: -35570350.38 / 295 (-12057745.9%):  74%|███████▍  | 297/400 [00:33<00:13,  7.75it/s]



Average Metric: -78469331.44 / 307 (-25560042.8%):  78%|███████▊  | 310/400 [00:35<00:11,  8.07it/s]



Average Metric: -78559331.44 / 308 (-25506276.4%):  78%|███████▊  | 310/400 [00:35<00:11,  8.07it/s]



Average Metric: -78561760.11 / 309 (-25424517.8%):  78%|███████▊  | 312/400 [00:36<00:22,  3.83it/s]

2025/11/29 22:00:48 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A20N', 'origin_name': 'Quito Mariscal Sucre International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-08-13T04:04:48.579000->2025-08-13T04:06:44.636000 (1.934 min) | sources acars:2, adsb:231 | altitude 2.25e+04 -> 2.088e+04 -> 1.93e+04 -> 1.78e+04 (delta -4700, range 4700, mean 2.009e+04) | groundspeed 426 -> 417 -> 409 -> 402 (delta -24, range 24, mean 413.6) | vertical_rate -2496 -> -2432 -> -2240 -> -1536 (delta 960, range 1216, mean -2418) | mach 0.685 -> 0.626 (delta -0.059, range 0.059, mean 0.6555) | path 39.42/-74.45 -> 39.48/-74.38 -> 39.54/-74.32 -> 39.59/-74.27 | delta_lat 0.1683 delta_lon 0.1833 | phase descent | vr balance +0.00 / -1.00 / ~0 0.00', 'vertical_rate_balance': {'positive_frac': 0.0, 'negative_frac': 1.0, 'near_zero_frac': 0.0}}, 'fuel_kg': 8.164656}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {

Average Metric: -78561760.11 / 309 (-25424517.8%):  78%|███████▊  | 313/400 [00:36<00:20,  4.22it/s]



Average Metric: -78742924.51 / 316 (-24918647.0%):  80%|████████  | 320/400 [00:37<00:12,  6.19it/s]



Average Metric: -78779246.89 / 321 (-24541821.5%):  81%|████████  | 324/400 [00:38<00:16,  4.63it/s]

2025/11/29 22:00:50 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A20N', 'origin_name': 'Portland International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-05-13T07:13:16.220000->2025-05-13T07:18:21.096000 (5.081 min) | sources acars:2, adsb:594 | altitude 3.494e+04 -> 3.502e+04 -> 3.5e+04 -> 3.496e+04 (delta 20, range 113.9, mean 3.501e+04) | groundspeed 448 -> 454 -> 454 -> 448 (delta 0, range 7, mean 451.7) | vertical_rate 64 -> 64 -> -128 -> 0 (delta -64, range 256, mean 5.444) | path 41.04/-120.9 -> 40.83/-120.9 -> 40.61/-120.9 -> 40.4/-120.9 | delta_lat -0.6343 delta_lon 0.0547 | phase level | vr balance +0.00 / -0.04 / ~0 0.96', 'vertical_rate_balance': {'positive_frac': 0.0, 'negative_frac': 0.0414, 'near_zero_frac': 0.9586}}, 'fuel_kg': 90.7184}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {"error":{"message":"Rate limit reached for model `openai/gpt-oss-120b` in organization `

Average Metric: -78779247.89 / 322 (-24465604.9%):  82%|████████▏ | 326/400 [00:38<00:11,  6.26it/s]



Average Metric: -78814323.07 / 328 (-24028757.0%):  83%|████████▎ | 333/400 [00:39<00:12,  5.42it/s]



Average Metric: -78818529.32 / 330 (-23884402.8%):  84%|████████▍ | 335/400 [00:40<00:18,  3.47it/s]



Average Metric: -78824198.95 / 332 (-23742228.6%):  84%|████████▍ | 336/400 [00:41<00:16,  3.91it/s]



Average Metric: -79561141.42 / 335 (-23749594.5%):  85%|████████▌ | 340/400 [00:41<00:12,  4.92it/s]



Average Metric: -79561291.12 / 336 (-23678955.7%):  85%|████████▌ | 341/400 [00:41<00:13,  4.47it/s]



Average Metric: -79803604.25 / 345 (-23131479.5%):  87%|████████▋ | 349/400 [00:43<00:09,  5.14it/s]

2025/11/29 22:00:55 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A359', 'origin_name': 'San Francisco International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-07-18T17:17:46.771000->2025-07-18T17:32:47.415000 (15.01 min) | sources acars:2, adsb:810 | altitude 4e+04 -> 4e+04 -> 4e+04 -> 4e+04 (delta 0, range 49.99, mean 4e+04) | groundspeed 497 -> 492 -> 487 -> 478 (delta -19, range 21, mean 488.6) | vertical_rate 0 -> 64 -> 0 -> -64 (delta -64, range 256, mean -4.115) | mach 0.849 -> 0.85 (delta 0.001, range 0.001, mean 0.8495) | path 27.43/128.7 -> 26.95/128.3 -> 26.46/127.9 -> 25.95/127.2 | delta_lat -1.485 delta_lon -1.499 | phase level | vr balance +0.01 / -0.02 / ~0 0.98', 'vertical_rate_balance': {'positive_frac': 0.0051, 'negative_frac': 0.0169, 'near_zero_frac': 0.978}}, 'fuel_kg': 1270.0576}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {"error":{"message":"Rate limit reached 

Average Metric: -79868647.32 / 350 (-22819613.5%):  89%|████████▉ | 355/400 [00:44<00:06,  6.43it/s]



Average Metric: -80128590.16 / 355 (-22571433.8%):  90%|█████████ | 361/400 [00:45<00:07,  5.13it/s]



Average Metric: -85545965.56 / 362 (-23631482.2%):  92%|█████████▏| 367/400 [00:46<00:05,  5.69it/s]



Average Metric: -85546190.56 / 363 (-23566443.7%):  92%|█████████▏| 369/400 [00:46<00:05,  5.33it/s]



Average Metric: -85613268.07 / 368 (-23264475.0%):  94%|█████████▎| 374/400 [00:47<00:04,  6.14it/s]



Average Metric: -86055077.87 / 377 (-22826280.6%):  96%|█████████▌| 383/400 [00:49<00:02,  7.59it/s]

2025/11/29 22:01:02 ERROR dspy.utils.parallelizer: Error for Example({'features': {'aircraft_type': 'A20N', 'origin_name': 'Guadalajara International Airport', 'origin_destination': '', 'track_points_compact': 'time 2025-07-10T02:13:55.309000->2025-07-10T02:18:54.684000 (4.99 min) | sources acars:2, adsb:585 | altitude 3.602e+04 -> 3.6e+04 -> 3.6e+04 -> 3.595e+04 (delta -74, range 99.97, mean 3.6e+04) | groundspeed 398 -> 414 -> 440 -> 416 (delta 18, range 42, mean 420) | vertical_rate 64 -> 0 -> 0 -> 0 (delta -64, range 256, mean -14.72) | mach 0.754 -> 0.78 (delta 0.026, range 0.026, mean 0.767) | path 44.18/-121.6 -> 44.36/-121.7 -> 44.54/-121.8 -> 44.7/-121.9 | delta_lat 0.5181 delta_lon -0.3483 | phase level | vr balance +0.02 / -0.09 / ~0 0.89', 'vertical_rate_balance': {'positive_frac': 0.0175, 'negative_frac': 0.0877, 'near_zero_frac': 0.8947}}, 'fuel_kg': 45.3592}) (input_keys={'features'}): litellm.RateLimitError: RateLimitError: GroqException - {"error":{"message":"Rate limi

Average Metric: -86060068.12 / 379 (-22707142.0%):  96%|█████████▋| 385/400 [00:49<00:02,  6.28it/s]



Average Metric: -86106547.58 / 382 (-22540981.0%):  97%|█████████▋| 389/400 [00:50<00:01,  5.50it/s]



Average Metric: -86597409.44 / 393 (-22034964.2%): 100%|██████████| 400/400 [00:54<00:00,  7.37it/s]

2025/11/29 22:01:06 INFO dspy.evaluate.evaluate: Average Metric: -86597409.4413901 / 400 (-21649352.4%)





Unnamed: 0,features,example_fuel_kg,pred_fuel_kg,reasoning,float_metric,fuel_kg
0,"{'aircraft_type': 'A320', 'origin_name': 'Guatemala City La Aurora...",54.000000,78.2,The segment lasts about 3.13 minutes (from 05:28:03 to 05:31:11). ...,✔️ [-585.640],
1,"{'aircraft_type': 'A21N', 'origin_name': 'Guadalajara Internationa...",300.000000,235.0,The segment lasts roughly 10 minutes (9.99 min) and is a descent f...,✔️ [-4225.000],
2,"{'aircraft_type': 'A20N', 'origin_name': 'San Jose Juan Santamaria...",90.718400,200.0,"The segment is a short 5.1‑minute cruise at 38,000 ft on an A320ne...",✔️ [-11942.468],
3,"{'aircraft_type': 'A20N', 'origin_name': 'Montreal Pierre Elliott ...",74.389088,190.0,The segment lasts about 5.03 minutes (≈0.0838 h). The aircraft is ...,✔️ [-13365.883],
4,"{'aircraft_type': 'B738', 'origin_name': 'Istanbul Sabiha Gokcen I...",526.166720,400.0,The segment is a cruise phase lasting 9.423 minutes (0.157 hours) ...,✔️ [-15918.041],
...,...,...,...,...,...,...
395,"{'aircraft_type': 'B789', 'origin_name': 'Shanghai Pudong Internat...",400.068144,350.0,The segment is a short cruise leg lasting 3.948 minutes (~0.0658 h...,✔️ [-2506.819],
396,"{'aircraft_type': 'A20N', 'origin_name': 'Guatemala City La Aurora...",90.718400,223.4,The segment lasts 5.362 minutes (about 0.0894 hours). For an Airbu...,✔️ [-17604.407],
397,"{'aircraft_type': 'A359', 'origin_name': 'Taipei Taoyuan Internati...",1769.008800,1240.0,The segment is a short cruise leg of about 15 minutes (0.25 h) at ...,✔️ [-279850.310],
398,"{'aircraft_type': 'A359', 'origin_name': 'Rome Leonardo da Vinci F...",1700.000000,1800.0,"The segment shows a 20‑minute level flight at cruise altitude (38,...",✔️ [-10000.000],


In [49]:
def compute_rmse(eval_results):
    rmse = math.sqrt(abs(eval_results.score)/len(eval_results.results))
    return rmse

In [50]:
compute_rmse(eval_results)

232.64432273322294

In [None]:
# dump in a csv file
import pandas as pd
import csv

results_list = []
for result in eval_results.results:
    true_value = result[0].fuel_burnt
    reasoning = result[1].reasoning
    pred_value = result[1].fuel_burnt
    metric = result[2]
    results_list.append({
        "true_value": true_value,
        "reasoning": reasoning,
        "pred_value": pred_value,
        "metric": metric
    })
results_df = pd.DataFrame(results_list)
results_df.to_csv("eval_results.csv", index=False, quoting=csv.QUOTE_ALL)

## Batch API

In [None]:
# 0. prepare batches of 100 examples each


In [None]:


# 1. create a jsonl file
artifacts = fuel_cot.create_batch_file(
    examples,
    input_file_path="batches/math_input.jsonl",
)
artifacts

## Usage

In [51]:
for lm in [gpt_oss_120b]:
  cost = sum([x['cost'] for x in lm.history if x['cost'] is not None])
  print(cost)

0.3628254
