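# Trajectory prediction

Predict a vessel's upcoming track (by default the next 1 hour) from its recent AIS history (by default the previous 2 hours), using AIS data from aisdata.ais.dk and a GRU encoder–decoder with attention.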

## Setup

In [1]:
from pathlib import Path
from tqdm import tqdm
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

from src.data.download import download_ais_data, download_file
from src.data.cleaning import process_multiple_zip_files
from src.data.preprocessing import (
    load_and_prepare_data,
    create_sequences,
    split_by_vessel,
    normalize_data,
)

from src.models import TrajectoryDataset, EncoderDecoderGRU, EncoderDecoderGRUWithAttention
from src.utils.model_utils import HaversineLoss, train_model, evaluate_model, create_prediction_sequences, predict_trajectories, plot_training_history, load_model_and_config, visualize_predictions
from src.visualization import plot_trajectory_comparison, create_prediction_map, create_trajectory_map
from src.utils import set_seed, haversine_distance
from src.utils import pointwise_haversine, mean_haversine_error, rmse_haversine, ade, fde, dtw_distance_trajectory, dtw_batch_mean


sns.set_style("darkgrid")

# ---- Paths ----
DATA_DIR = Path("data")
MODEL_PATH = "best_model_encoder_decoder.pt"  # best checkpoint is saved here during training and re-loaded later

# ---- Model & sequence windows ----
MODEL = EncoderDecoderGRUWithAttention
INPUT_HOURS = 2       # history length fed to the encoder
OUTPUT_HOURS = 1      # prediction horizon
SAMPLING_RATE = 5     # resampling interval, in minutes
HIDDEN_SIZE = 256
NUM_LAYERS = 3

# ---- Optimization ----
BATCH_SIZE = 512
EPOCHS = 50
LEARNING_RATE = 1e-6  # NOTE(review): unusually small for AdamW — confirm this is intended

# ---- Runtime ----
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
set_seed(SEED)

print("Using device: ", DEVICE)

Using device:  cpu


## Data

### Acquisition

In [2]:
# Daily AIS archives to fetch. Keep only the days you need uncommented, or set
# ZIP_NAMES = [] to hand the whole YEAR over to download_ais_data.
YEAR = "2024"
ZIP_NAMES = [
    "aisdk-2024-03-01.zip",
    # "aisdk-2024-03-02.zip",
    # "aisdk-2024-03-03.zip",
    # "aisdk-2024-03-04.zip",
    # "aisdk-2024-03-05.zip",
    # "aisdk-2024-03-06.zip",
    # "aisdk-2024-03-07.zip",
    # "aisdk-2024-03-08.zip",
    # "aisdk-2024-03-09.zip",
    # "aisdk-2024-03-10.zip"
]
# ZIP_NAMES = []  # Uncomment to download all files

# Build the per-file URL from YEAR so both download paths agree on the year.
BASE_URL = f"http://aisdata.ais.dk/{YEAR}/"

if not ZIP_NAMES:
    MAX_WORKERS = 8  # worker count passed to download_ais_data
    download_ais_data(YEAR, DATA_DIR, MAX_WORKERS)
else:
    for zip_name in ZIP_NAMES:
        download_file(BASE_URL + zip_name, DATA_DIR / zip_name)

Skipping aisdk-2024-03-01.zip (already exists)


### Cleaning

In [2]:
# Turn each downloaded zip in DATA_DIR into a parquet file, skipping zips whose parquet
# already exists; the returned summary counts processed / skipped / failed files.
cleaning_summary = process_multiple_zip_files(DATA_DIR)
cleaning_summary

Found 1 zip file(s) in data
Requested 4 worker(s)
Skipping aisdk-2024-03-01.zip - aisdk-2024-03-01.parquet already exists

No files to process (all already exist)


{'processed': 0, 'skipped': 1, 'failed': 0}

In [3]:
# Load the cleaned AIS data, cut it into fixed-length input/target windows, split by
# vessel (no MMSI shared between splits), normalize, and wrap everything in DataLoaders.
NUM_WORKERS = 8  # DataLoader worker processes (shared by all three loaders)

df = load_and_prepare_data(DATA_DIR)
sequences, targets, mmsi_labels, feature_cols = create_sequences(
    df, INPUT_HOURS, OUTPUT_HOURS, SAMPLING_RATE
)

# Use the notebook-wide SEED so changing it in the config cell also changes the split.
X_train, X_val, X_test, y_train, y_val, y_test = split_by_vessel(
    sequences, targets, mmsi_labels, train_ratio=0.7, val_ratio=0.15, random_seed=SEED
)

# normalize_data also returns the fitted scalers; they are needed later to
# de-normalize predictions and are stored in the training checkpoint.
X_train, X_val, X_test, y_train, y_val, y_test, input_scaler, output_scaler = normalize_data(
    X_train, X_val, X_test, y_train, y_val, y_test
)

train_dataset = TrajectoryDataset(X_train, y_train)
val_dataset = TrajectoryDataset(X_val, y_val)
test_dataset = TrajectoryDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

input_size = len(feature_cols)
# Targets are flattened with 2 values per predicted step, so the decoder
# produces half as many steps as there are target columns.
output_timesteps = y_train.shape[1] // 2

Loading data...
Found 1 parquet files


Loading files: 100%|██████████| 1/1 [00:00<00:00,  6.03it/s]

Loaded 3228001 rows from 1 files






Merging continuous segments across file boundaries...
  Original (MMSI, FileIndex, Segment) combinations: 821
  Merged (MMSI, GlobalSegment) combinations: 821
  Merged 0 segment pairs across file boundaries
  Time gap threshold: 15.0 minutes

Creating sequences (2h input -> 1h output)...
  Filtering sequences: min 5 km/h (15 km over 3h)
Resampling and adding features with Polars...
Creating sequences with stride=1
Processing by ['MMSI', 'GlobalSegment'] (merged continuous segments across files)...


Processing segments: 100%|██████████| 821/821 [00:02<00:00, 376.28it/s]


  Skipped 7391 sequences with irregular time spacing
  Skipped 12345 sequences with insufficient distance traveled
Created 65930 sequences from 567 unique vessels
  Input shape: (65930, 24, 12)
  Target shape: (65930, 24)
  Stride: 1 timesteps (5 minutes)
  Sequence overlap: ~95.8%

Vessel-based split:
  Train vessels: 396 (70%)
  Val vessels: 85 (15%)
  Test vessels: 86 (15%)
  Train sequences: 45355
  Val sequences: 10437
  Test sequences: 10138
  ✅ No vessel overlap - proper split confirmed!

Normalizing data...
  Enforcing uniform spatial scaling: 1.9582
  ✅ Data validation passed: No NaNs or Infs detected
  X_train_scaled range: [-5.00, 5.00]
  y_train_scaled range: [-3.70, 2.64]
  Input features: 12
  Features normalized (Lat, Lon, SOG, SOG_diff): [0, 1, 2, 9]
  Features NOT normalized (sin/cos): [3, 4, 5, 6, 7, 8, 10, 11]
  Input scaler - mean: [5.60377200e+01 1.12586038e+01 5.46309012e+00 4.24998081e-03]
  Input scaler - scale: [1.95822011 1.95822011 2.21855121 0.75522489]
  Ou

### Overview

In [None]:
# Quick size check of the loaded data, then a peek at the first rows.
n_rows = len(df)
n_vessels = df['MMSI'].n_unique()
print(f"Total rows: {n_rows:,}")
print(f"Unique vessels (MMSI): {n_vessels}")
df.head()

In [None]:
# Cap on how many vessels are drawn (presumably to keep the map light).
MAX_VESSELS = 100
trajectory_map = create_trajectory_map(df, MAX_VESSELS)
trajectory_map

## Training

In [4]:
DROPOUT = 0.3  # dropout rate passed to MODEL

# Instantiate the configured model class and move it to the selected device.
model = MODEL(
    input_size=input_size,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    output_seq_len=output_timesteps,
    dropout=DROPOUT,
).to(DEVICE)

n_params = sum(param.numel() for param in model.parameters())
print("\nModel architecture:")
print(model)
print(f"\nTotal parameters: {n_params:,}")


Model architecture:
EncoderDecoderGRUWithAttention(
  (encoder): GRU(12, 256, num_layers=3, batch_first=True, dropout=0.3)
  (attention): Attention(
    (W_h): Linear(in_features=256, out_features=256, bias=False)
    (W_s): Linear(in_features=256, out_features=256, bias=False)
    (v): Linear(in_features=256, out_features=1, bias=False)
  )
  (decoder): GRU(258, 256, num_layers=3, batch_first=True, dropout=0.3)
  (fc): Linear(in_features=256, out_features=2, bias=True)
)

Total parameters: 2,314,498


In [5]:
# Haversine-based loss; it receives output_scaler, presumably to measure distance on
# de-normalized coordinates (confirm in HaversineLoss).
criterion = HaversineLoss(output_scaler).to(DEVICE)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)

# Halve the learning rate when validation loss stalls for 10 scheduler steps
# (the training loop steps it once per epoch).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=10,
)

In [None]:
# Bookkeeping for the training loop: per-epoch loss history, best-checkpoint
# tracking and the early-stopping counter.
print(f"\nStarting training for {EPOCHS} epochs...")
train_losses, val_losses = [], []
best_val_loss = float("inf")
patience_counter = 0
early_stop_patience = 20  # epochs without validation improvement before stopping

In [None]:

# Train for up to EPOCHS epochs. Teacher forcing is annealed linearly from 1.0
# (floored at 0.2); after each epoch the LR scheduler is stepped on validation
# loss, the best model is checkpointed to MODEL_PATH, and training stops early
# after `early_stop_patience` epochs without improvement.
for epoch in range(EPOCHS):
    teacher_forcing_ratio = max(0.2, 1.0 - (0.8 * (epoch / EPOCHS)))
    train_loss = train_model(model, train_loader, criterion, optimizer, DEVICE, epoch, EPOCHS, teacher_forcing_ratio)
    # evaluate_model returns (loss, predictions, targets) — the Testing cell unpacks
    # the same call — so keep only the scalar loss: the scheduler, the comparison
    # with best_val_loss and the `:.6f` format all need a number, not a tuple.
    val_loss, _, _ = evaluate_model(model, val_loader, criterion, output_scaler, DEVICE)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    scheduler.step(val_loss)

    print(f"Epoch [{epoch+1}/{EPOCHS}] - " f"Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Store everything needed to rebuild and use the model later: weights,
        # optimizer state, fitted scalers and the data/model configuration.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss,
                "input_scaler": input_scaler,
                "output_scaler": output_scaler,
                "config": {
                    "input_size": input_size,
                    "hidden_size": HIDDEN_SIZE,
                    "num_layers": NUM_LAYERS,
                    "output_seq_len": output_timesteps,
                    "input_hours": INPUT_HOURS,
                    "output_hours": OUTPUT_HOURS,
                    "sampling_rate": SAMPLING_RATE,
                    "feature_cols": feature_cols,
                },
            },
            MODEL_PATH,
        )
        print(f"  -> Saved best model (val_loss: {val_loss:.6f})")
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

print(f"\nTraining complete! Best validation loss: {best_val_loss:.6f}")

plot_training_history(train_losses, val_losses)

## Testing

In [6]:
# Restore the best checkpoint written during training and score it on the test vessels.
# weights_only=False is presumably required because the checkpoint also pickles the
# fitted scalers — only load checkpoints you produced yourself.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])

test_loss, test_predictions, true_targets = evaluate_model(
    model, test_loader, criterion, output_scaler, DEVICE
)
print(f"Final Test Loss: {test_loss:.6f}")

Final Test Loss: 3.592035


In [None]:
# Visual check of the restored model's predictions on the test loader;
# output_scaler is passed in, presumably to plot de-normalized coordinates (confirm).
visualize_predictions(model, test_loader, output_scaler, DEVICE)

## Results

In [None]:
# Reload the best checkpoint, predict trajectories for a handful of vessels and
# pass them to create_prediction_map (presumably written to PREDICTION_MAP_PATH).
N_VESSELS = 25
PREDICTION_MAP_PATH = "output.html"

model, config, input_scaler, output_scaler = load_model_and_config(
    MODEL_PATH, MODEL
)

# Distinct names so the training/test windows built in the Data section
# (`sequences`, `targets`) are not silently overwritten by this cell.
pred_sequences, pred_targets, mmsi_list, full_trajectories, timestamps_list = create_prediction_sequences(
    df, config, n_vessels=N_VESSELS
)

predictions, _ = predict_trajectories(model, pred_sequences, input_scaler, output_scaler)

# Alternative (non-map) comparison plot, currently disabled:
# plot_trajectory_comparison(full_trajectories, predictions, mmsi_list, config["output_hours"])

create_prediction_map(
    full_trajectories,
    predictions,
    mmsi_list,
    config["output_hours"],
    PREDICTION_MAP_PATH,
)

## Statistics

In [17]:
# Test-set predictions and ground truth from the Testing cell, both flattened as (N, 2*T).
y_pred = test_predictions
y_true = true_targets

# ---- reshape to (N, T, 2) ----
# NOTE(review): assumes each row is interleaved per timestep ([c0_t0, c1_t0, c0_t1, c1_t1, ...]);
# confirm against create_sequences and the model's output flattening.
N, D = y_true.shape          # N = number of samples, D = 2 * T
assert D % 2 == 0, f"expected 2 target columns per timestep, got an odd width {D}"
assert tuple(y_pred.shape) == (N, D), (
    f"prediction shape {tuple(y_pred.shape)} does not match target shape {(N, D)}"
)
T = D // 2                   # T = number of timesteps

y_true = y_true.reshape(N, T, 2)
y_pred = y_pred.reshape(N, T, 2)

print("y_true shape:", y_true.shape)
print("y_pred shape:", y_pred.shape)
print()

# ---- distance metrics on the (N, T, 2) arrays ----
print("Mean Haversine Error (km):", mean_haversine_error(y_true, y_pred))
print("Root Mean Squared Error (km):", rmse_haversine(y_true, y_pred))
print("Average Displacement (km):", ade(y_true, y_pred))
print("Final Displacement (km):", fde(y_true, y_pred))
print("Dynamic Time Warping Distance (km):", dtw_batch_mean(y_true, y_pred))


y_true shape: (10138, 12, 2)
y_pred shape: (10138, 12, 2)

Mean Haversine Error (km): 3.1057066917419434
Root Mean Squared Error (km): 4.803907871246338
Average Displacement (km): 3.1057066917419434
10138
Final Displacement (km): 6.41465425491333
Dynamic Time Warping Distance (km): 1.4476720768843871
