# Preprocessing

Goal: Link the flight data to fuel segments and chunk them for batching.

In [19]:
import os
import pandas as pd
import polars as pl
import pyarrow.dataset as ds
from typing import Literal

In [20]:
# Folder containing every parquet input, resolved against the current working directory.
DATA_PATH = os.path.join(os.getcwd(), "data")

# Dataset split that the rest of the notebook processes.
DATA_TYPE: Literal["train", "rank", "final"] = "train"

## 1. Augment segments with select features

Refer to DATA.md for explanations on the data.

Let's link the features to each fuel segment.

Since we are working with LLMs, we can save tokens by:
- excluding features that are not used in their reasoning
- summarizing features which have many data points, such as track_points and timestamps
- rounding (and converting) numbers, especially those with high decimal precision
- possibly excluding missing values ("NaN" --> "")

In [21]:
# The "train" split is stored as fuel_train; every other split as fuel_<split>_submission.
fuel_suffix = DATA_TYPE if DATA_TYPE == "train" else f"{DATA_TYPE}_submission"
fuel_file_name = f"fuel_{fuel_suffix}"

# Per-split inputs: fuel segments, flight list, and the folder of trajectory files.
fuel_file_path = os.path.join(DATA_PATH, f"{fuel_file_name}.parquet")
flightlist_path = os.path.join(DATA_PATH, f"flightlist_{DATA_TYPE}.parquet")
flights_folder_path = os.path.join(DATA_PATH, f"flights_{DATA_TYPE}")

# Airport table; its path does not depend on the split.
airports_file_path = os.path.join(DATA_PATH, "apt.parquet")

In [None]:
# Eagerly load the fuel segment table with pandas for the quick sanity check in the next cell.
fuel_df_test = pd.read_parquet(fuel_file_path)

In [None]:
# Sanity check: distinct flights, total fuel segments, and average segments per flight.
n_flights = fuel_df_test["flight_id"].nunique()
n_segments = len(fuel_df_test)

print(n_flights)
print(n_segments)
print(n_segments / n_flights)

### Method 1: a single lazy query over all trajectory files

In [13]:
# Scan (instead of read + .lazy()) so Polars only loads the columns the final query needs.
fuel_df = pl.scan_parquet(fuel_file_path)
# NOTE(review): airports_df is not referenced by the query below; kept in case later cells use it.
airports_df = pl.scan_parquet(airports_file_path)
flightlist_df = pl.scan_parquet(flightlist_path)

# One lazy scan over every per-flight trajectory file.
flights_folder = pl.scan_parquet(os.path.join(flights_folder_path, "*.parquet"))

# One row per fuel segment, enriched with flight-level metadata.
skeleton_df = (
    fuel_df
    .join(flightlist_df, on="flight_id", how="left")
    .with_columns([
        # 1. Duration (in minutes)
        ((pl.col("end") - pl.col("start")).dt.total_minutes()).alias("duration_min"),

        # 2. Route (Origin-Dest)
        (pl.col("origin_icao") + "-" + pl.col("destination_icao")).alias("route_icao"),

        # 3. Progress: (Segment Start - Takeoff) / (Landed - Takeoff).
        #    This is a fraction of the flight, not a percentage, despite the column name.
        (
            (pl.col("start") - pl.col("takeoff")) /
            (pl.col("landed") - pl.col("takeoff"))
        ).alias("progress_pct"),
    ])
)

# Trajectory points that fall inside each fuel segment's [start, end] window.
# Inner join: the filter below discards unmatched (null timestamp) rows anyway,
# so a left join here would only look like it preserves every segment.
joined_df = (
    skeleton_df
    .join(flights_folder, on="flight_id", how="inner")
    .filter(
        (pl.col("timestamp") >= pl.col("start")) &
        (pl.col("timestamp") <= pl.col("end"))
    )
)

# Per-segment trajectory statistics. Ordering is done per group with sort_by, so the
# result does not depend on a global sort surviving the group_by (or the streaming engine).
segment_stats = (
    joined_df
    .group_by("idx")
    .agg([
        pl.col("altitude").sort_by("timestamp").first().alias("alt_start"),
        pl.col("altitude").sort_by("timestamp").last().alias("alt_end"),
        pl.col("mach").mean().alias("mach_avg"),
    ])
)

# Left-join the statistics back onto the skeleton so that segments without any
# trajectory point in their window are kept (with null statistics) instead of
# silently disappearing from the output.
# Assumes `idx` identifies exactly one row of skeleton_df — TODO confirm.
final_lazy = (
    skeleton_df
    .join(segment_stats, on="idx", how="left")
    .select([
        pl.col("idx"),
        pl.col("aircraft_type").alias("aircraft"),
        pl.col("duration_min"),
        pl.col("route_icao"),

        # Create the 'status' struct
        pl.struct([
            pl.col("progress_pct"),
            pl.col("alt_start"),
            pl.col("alt_end"),
            pl.col("mach_avg"),
        ]).alias("status"),
    ])
)


In [14]:
# Materialize the Method 1 query, asking Polars to use its streaming execution mode.
# NOTE(review): recent Polars releases spell this `engine="streaming"` — confirm against the installed version.
result = final_lazy.collect(streaming=True)

In [17]:
# Number of rows produced by Method 1 (the query yields one row per `idx`).
len(result)

### Method 2: process flights in batches and write one parquet file per batch

In [29]:
import polars as pl
import os
import gc
from tqdm import tqdm

# --- Configuration ---
BATCH_SIZE = 100  # number of flights written to each output parquet file
SEGMENT_BUFFER = "30s"  # slack added on both sides of a segment window
# Every trajectory statistic is stored with this dtype so that all batch files
# (including those without any trajectory data) share one schema.
STAT_DTYPE = pl.Float64
STAT_COLUMNS = ("alt_start", "alt_end", "mach_avg")

OUTPUT_DIR = os.path.join(DATA_PATH, "output_batches", DATA_TYPE)
# All inputs are derived from DATA_TYPE so they always match OUTPUT_DIR; the
# previous hard-coded "train" paths wrote train data into other splits' folders.
TRAJECTORY_FOLDER = os.path.join(DATA_PATH, f"flights_{DATA_TYPE}")
fuel_file_name = "fuel_" + (DATA_TYPE if DATA_TYPE == "train" else DATA_TYPE + "_submission")
fuel_file_path = os.path.join(DATA_PATH, fuel_file_name + ".parquet")
flightlist_path = os.path.join(DATA_PATH, f"flightlist_{DATA_TYPE}.parquet")

os.makedirs(OUTPUT_DIR, exist_ok=True)


def _naive_us(column: str) -> pl.Expr:
    """Return `column` as a timezone-naive datetime with microsecond precision."""
    return pl.col(column).dt.cast_time_unit("us").dt.replace_time_zone(None)


def _segment_stats(batch_skeleton: pl.DataFrame, trajectory_files: list) -> pl.DataFrame:
    """Aggregate trajectory points into per-segment statistics.

    Args:
        batch_skeleton: fuel segments of the current batch (needs `idx`,
            `flight_id`, `start` and `end`).
        trajectory_files: existing per-flight parquet paths for this batch.

    Returns:
        One row per `idx` that has at least one trajectory point inside its
        buffered window, with the columns listed in STAT_COLUMNS. Empty (but
        correctly typed) when `trajectory_files` is empty.
    """
    if not trajectory_files:
        stats_schema = {"idx": batch_skeleton.schema["idx"]}
        stats_schema.update({name: STAT_DTYPE for name in STAT_COLUMNS})
        return pl.DataFrame(schema=stats_schema)

    # Normalize join key and timestamps so both sides compare with identical types.
    traj_lazy = (
        pl.scan_parquet(trajectory_files)
        .with_columns([
            pl.col("flight_id").cast(pl.String).str.strip_chars(),
            _naive_us("timestamp"),
        ])
    )
    skeleton_lazy = (
        batch_skeleton.lazy()
        .with_columns([_naive_us("start"), _naive_us("end")])
    )

    return (
        skeleton_lazy
        .join(traj_lazy, on="flight_id", how="inner")
        .filter(
            # Buffer the window to catch points for short/zero-duration segments
            (pl.col("timestamp") >= pl.col("start").dt.offset_by(f"-{SEGMENT_BUFFER}")) &
            (pl.col("timestamp") <= pl.col("end").dt.offset_by(SEGMENT_BUFFER))
        )
        .sort("timestamp")
        .group_by("idx")
        .agg([
            pl.col("altitude").first().cast(STAT_DTYPE).alias("alt_start"),
            pl.col("altitude").last().cast(STAT_DTYPE).alias("alt_end"),
            pl.col("mach").mean().cast(STAT_DTYPE).alias("mach_avg"),
        ])
        .collect()
    )


def _format_batch(batch_skeleton: pl.DataFrame, stats_df: pl.DataFrame) -> pl.DataFrame:
    """Left-join the statistics onto the skeleton (keeping every segment) and shape the output."""
    return (
        batch_skeleton
        .join(stats_df, on="idx", how="left")
        .select([
            pl.col("idx"),
            pl.col("flight_id"),
            pl.col("aircraft_type").alias("aircraft"),
            pl.col("duration_min"),
            pl.col("route_icao"),
            pl.struct([
                pl.col("progress_pct"),
                pl.col("alt_start"),
                pl.col("alt_end"),
                pl.col("mach_avg"),
            ]).alias("status"),
        ])
    )


# --- 1. Load Metadata & Sanitize ---
print("Loading metadata...")
fuel_df = pl.read_parquet(fuel_file_path)
flightlist_df = pl.read_parquet(flightlist_path)

skeleton_df = (
    fuel_df
    .join(flightlist_df, on="flight_id", how="left")
    # Sanitize IDs and pre-calculate the segment-level features
    .with_columns([
        pl.col("flight_id").cast(pl.String).str.strip_chars(),
        ((pl.col("end") - pl.col("start")).dt.total_minutes()).alias("duration_min"),
        (pl.col("origin_icao") + "-" + pl.col("destination_icao")).alias("route_icao"),
        ((pl.col("start") - pl.col("takeoff")) / (pl.col("landed") - pl.col("takeoff"))).alias("progress_pct"),
    ])
)

# Sort the ids: unique() alone has no stable order, and the resume logic below
# relies on batch N containing the same flights on every run.
unique_flight_ids = skeleton_df["flight_id"].unique().sort().to_list()
total_flights = len(unique_flight_ids)
print(f"Found {total_flights} flights. Processing...")

# --- 2. Batch Processing ---
failed_batches = []

for i in tqdm(range(0, total_flights, BATCH_SIZE), desc="Processing Batches"):

    batch_index = i // BATCH_SIZE
    batch_ids = unique_flight_ids[i : i + BATCH_SIZE]
    batch_filename = os.path.join(OUTPUT_DIR, f"batch_{batch_index}.parquet")

    # Resume support: a batch file that already exists is considered done.
    if os.path.exists(batch_filename):
        continue

    # Only trajectory files that actually exist on disk can be scanned.
    candidate_paths = (os.path.join(TRAJECTORY_FOLDER, f"{fid}.parquet") for fid in batch_ids)
    valid_files = [path for path in candidate_paths if os.path.exists(path)]

    # Prepare the skeleton for this batch
    current_skeleton = skeleton_df.filter(pl.col("flight_id").is_in(batch_ids))

    try:
        stats_df = _segment_stats(current_skeleton, valid_files)
        final_batch = _format_batch(current_skeleton, stats_df)

        # Write to a temporary name first so an interrupted write can never be
        # mistaken for a finished batch by the exists() check above.
        partial_filename = batch_filename + ".tmp"
        final_batch.write_parquet(partial_filename)
        os.replace(partial_filename, batch_filename)

    except Exception as e:
        # Best effort: keep going, but remember the batch so it is reported below.
        failed_batches.append(batch_index)
        print(f"Error in batch {batch_index}: {e}")

    gc.collect()

if failed_batches:
    print(f"Processing finished with {len(failed_batches)} failed batch(es): {failed_batches}. Re-run this cell to retry them.")
else:
    print("Processing complete.")

In [30]:
# Spot-check the first batch file written by the batch loop.
batch_0_path = os.path.join(OUTPUT_DIR, "batch_0.parquet")
df = pl.read_parquet(batch_0_path)
df.head()