# Trajectory prediction

## Setup

In [None]:
from pathlib import Path
from tqdm import tqdm
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from src.data.download import download_ais_data, download_file
from src.data.cleaning import process_multiple_zip_files
from src.data.preprocessing import (
    load_and_prepare_data,
    create_sequences_with_features,
    split_by_vessel,
    normalize_data,
)
from src.models import TrajectoryDataset, EncoderDecoderGRU
from src.utils.model_utils import HaversineLoss, train_model, evaluate_model, create_prediction_sequences, predict_trajectories, plot_training_history, load_model_and_config, visualize_predictions
from src.visualization import plot_trajectory_comparison, create_prediction_map, create_trajectory_map
from src.utils import set_seed

sns.set_style("darkgrid")

# Constants
DATA_DIR = Path("data")  # folder handed to the download, cleaning and loading helpers
MODEL_PATH = "best_model_encoder_decoder.pt"  # file the best checkpoint is saved to / loaded from
MODEL = EncoderDecoderGRU  # model class instantiated for training and for reloading
INPUT_HOURS = 2  # passed to create_sequences_with_features as the input window
OUTPUT_HOURS = 1  # passed to create_sequences_with_features as the prediction horizon
SAMPLING_RATE = 5  # NOTE(review): presumably minutes between samples — confirm in src.data.preprocessing
HIDDEN_SIZE = 256
NUM_LAYERS = 3
BATCH_SIZE = 2048
EPOCHS = 50
LEARNING_RATE = 0.0001
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

set_seed(SEED)

# An f-string avoids the double space that print's default separator added
# after the trailing space in the original label.
print(f"Using device: {DEVICE}")

## Data

### Acquisition

In [None]:
# Daily AIS archives to fetch. An empty list switches to downloading the
# whole year through download_ais_data instead.
ZIP_NAMES = [
    "aisdk-2024-03-01.zip",
    "aisdk-2024-03-02.zip",
    "aisdk-2024-03-03.zip",
    # "aisdk-2024-03-04.zip",
    # "aisdk-2024-03-05.zip",
    # "aisdk-2024-03-06.zip",
    # "aisdk-2024-03-07.zip",
    # "aisdk-2024-03-08.zip",
    # "aisdk-2024-03-09.zip",
    # "aisdk-2024-03-10.zip"
]
# ZIP_NAMES = []  # Uncomment to download all files

# Defined once so the bulk download and the per-file URLs cannot drift apart.
YEAR = "2024"
MAX_WORKERS = 8  # forwarded to download_ais_data
BASE_URL = f"http://aisdata.ais.dk/{YEAR}"

if not ZIP_NAMES:
    download_ais_data(YEAR, DATA_DIR, MAX_WORKERS)
else:
    for zip_name in ZIP_NAMES:
        download_file(f"{BASE_URL}/{zip_name}", DATA_DIR / zip_name)

### Cleaning

In [None]:
# Run the project's cleaning step, pointing it at DATA_DIR.
# NOTE(review): presumably this processes the downloaded zip archives found
# there — confirm in src.data.cleaning.
process_multiple_zip_files(DATA_DIR)

In [None]:
# Load the prepared data and cut it into model-ready input/target windows.
df = load_and_prepare_data(DATA_DIR)
sequences, targets, mmsi_labels, feature_cols = create_sequences_with_features(
    df, INPUT_HOURS, OUTPUT_HOURS, SAMPLING_RATE
)

# Split using the per-sequence MMSI labels. The notebook-wide SEED is reused
# here so that changing it in the setup cell also changes this split, instead
# of a second hard-coded 42 silently staying behind.
X_train, X_val, X_test, y_train, y_val, y_test = split_by_vessel(
    sequences, targets, mmsi_labels, train_ratio=0.7, val_ratio=0.15, random_seed=SEED
)

# normalize_data also returns the input/output scalers, which are later
# passed to the loss and stored in the checkpoint.
X_train, X_val, X_test, y_train, y_val, y_test, input_scaler, output_scaler = normalize_data(
    X_train, X_val, X_test, y_train, y_val, y_test
)

# Settings shared by all three loaders; only shuffling differs per split.
loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": 8}

# Training batches are reshuffled; validation and test keep a fixed order.
train_dataset = TrajectoryDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = TrajectoryDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

test_dataset = TrajectoryDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Model dimensions derived from the prepared data.
input_size = len(feature_cols)
# NOTE(review): the // 2 assumes each target row holds two values per
# predicted timestep (presumably lat/lon) — confirm in src.data.preprocessing.
output_timesteps = y_train.shape[1] // 2

### Overview

In [None]:
# Quick size summary of the loaded data, followed by a preview of the first rows.
total_rows = len(df)
vessel_count = df["MMSI"].n_unique()
print(f"Total rows: {total_rows:,}")
print(f"Unique vessels (MMSI): {vessel_count}")
df.head()

In [None]:
# Value passed to create_trajectory_map as its second argument.
# NOTE(review): presumably an upper bound on the vessels drawn — confirm in src.visualization.
MAX_VESSELS = 100
# Kept as the cell's last expression so the notebook displays the returned value.
create_trajectory_map(df, MAX_VESSELS)

## Training

In [None]:
# Build the network from the notebook hyperparameters and move it to DEVICE.
model_kwargs = {
    "input_size": input_size,
    "hidden_size": HIDDEN_SIZE,
    "num_layers": NUM_LAYERS,
    "output_seq_len": output_timesteps,
    "dropout": 0.3,
}
model = MODEL(**model_kwargs).to(DEVICE)

# Report the architecture and the total number of parameter elements.
total_params = sum(param.numel() for param in model.parameters())
print("\nModel architecture:")
print(model)
print(f"\nTotal parameters: {total_params:,}")

In [None]:
# Loss is built from the output scaler and placed on the training device.
criterion = HaversineLoss(output_scaler)
criterion = criterion.to(DEVICE)

# AdamW over all model parameters with a small weight decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)

# Plateau scheduler in "min" mode: scales the learning rate by 0.5 once the
# monitored value has stopped improving for longer than the patience.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=10,
)

In [None]:
print(f"\nStarting training for {EPOCHS} epochs...")

# Loss curves plus early-stopping bookkeeping.
train_losses, val_losses = [], []
best_val_loss = float("inf")
patience_counter = 0
early_stop_patience = 20

# Hyperparameters stored next to the weights; none of them change between
# epochs, so the dict is built once and reused for every save.
checkpoint_config = {
    "input_size": input_size,
    "hidden_size": HIDDEN_SIZE,
    "num_layers": NUM_LAYERS,
    "output_seq_len": output_timesteps,
    "input_hours": INPUT_HOURS,
    "output_hours": OUTPUT_HOURS,
    "sampling_rate": SAMPLING_RATE,
    "feature_cols": feature_cols,
}

for epoch in range(EPOCHS):
    # Starts at 1.0 and decreases linearly with the epoch index, floored at 0.2.
    teacher_forcing_ratio = max(0.2, 1.0 - (0.8 * (epoch / EPOCHS)))

    train_loss = train_model(
        model, train_loader, criterion, optimizer, DEVICE, epoch, EPOCHS, teacher_forcing_ratio
    )
    val_loss = evaluate_model(model, val_loader, criterion, DEVICE)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    print(f"Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    improved = val_loss < best_val_loss
    if not improved:
        # No new best: count the stale epoch and stop once patience runs out.
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break
        continue

    # New best validation loss: reset patience and overwrite the checkpoint.
    best_val_loss = val_loss
    patience_counter = 0
    best_state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_loss": val_loss,
        "input_scaler": input_scaler,
        "output_scaler": output_scaler,
        "config": checkpoint_config,
    }
    torch.save(best_state, MODEL_PATH)
    print(f"  -> Saved best model (val_loss: {val_loss:.6f})")

print(f"\nTraining complete! Best validation loss: {best_val_loss:.6f}")

plot_training_history(train_losses, val_losses)

## Testing

In [None]:
# Restore the best checkpoint written during training and score it on the
# held-out test split.
#
# weights_only=False is needed because the checkpoint also holds the fitted
# scalers and config alongside the tensors; this goes through pickle, so only
# load checkpoint files you trust.
# map_location=DEVICE remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be opened on a CPU-only machine.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
test_loss = evaluate_model(model, test_loader, criterion, DEVICE)
print(f"Final Test Loss: {test_loss:.6f}")

In [None]:
# Visualize model predictions on the test loader; the output scaler is passed along.
# NOTE(review): presumably used to map predictions back to original units — confirm in src.utils.model_utils.
visualize_predictions(model, test_loader, output_scaler, DEVICE)

## Results

In [None]:
N_VESSELS = 200

# Reload the saved model together with its config and scalers from disk.
model, config, input_scaler, output_scaler = load_model_and_config(MODEL_PATH, MODEL)

# Build prediction inputs from the loaded data using the stored config.
prediction_inputs = create_prediction_sequences(df, config, n_vessels=N_VESSELS)
sequences, targets, mmsi_list, full_trajectories, timestamps_list = prediction_inputs

# Only the first returned value is used below.
predictions, _ = predict_trajectories(model, sequences, input_scaler, output_scaler)

horizon_hours = config["output_hours"]

# Alternative comparison plot, currently disabled:
# plot_trajectory_comparison(full_trajectories, predictions, mmsi_list, horizon_hours)

# Kept as the cell's last expression so the notebook displays the returned value.
create_prediction_map(full_trajectories, predictions, mmsi_list, horizon_hours, "output.html")