In [1]:
import pandas as pd
import numpy as np
import torch
import warnings
import sys, os

# Put the parent directory on the import path so the `core` package used in later cells resolves.
sys.path.append(os.path.abspath('..'))

# Input files: AIS tracks and the pre-cleaned radar detections.
ais_tracks_path = '../../data/tracks_ais.csv'
cleaned_detections_path = '../../data/cleaned_data/processed_radar_detections.csv'
ais_tracks, radar_detections = (
    pd.read_csv(csv_path) for csv_path in (ais_tracks_path, cleaned_detections_path)
)

# Use the Apple-silicon GPU backend (MPS) when PyTorch reports it, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)



Using device: mps


In [2]:
# Model size (passed to the GRU classifier when it is constructed).
HIDDEN_DIM = 64
NUM_LAYERS_GRU = 1

# Batching (BATCH_SIZE is used by every DataLoader).
BATCH_SIZE = 32
# NOTE(review): MAX_LENGTH is not referenced in the cells visible here -- confirm
# whether sequence truncation was intended.
MAX_LENGTH = 256

## Preprocess

In [3]:
# Show the dtype of each radar-detection column (rendered as the cell output).
radar_detections.dtypes

id_detect       int64
id_track        int64
id_site         int64
id_m2          object
source         object
speed         float64
course        float64
assoc_str       int64
assoc_id        int64
confidence    float64
cdate          object
ctime          object
longitude     float64
latitude      float64
datetime       object
dtype: object

In [8]:
from core.sum_stats import SumStats
from core.vessel_agg import VesselTypeAggregator

# Compute summary statistics from the radar detections.
# NOTE(review): SumStats is project code not shown here; later cells rely on the
# returned frame having an 'id_track' column plus one column per statistic -- confirm.
sum_stats = SumStats(radar_detections)
summary_df = sum_stats()

In [18]:
# Summary features fed to the model: every column except the track key,
# 'duration' and 'detections' (original column order is preserved).
non_feature_cols = ('id_track', 'duration', 'detections')
summary_cols = [name for name in summary_df.columns if name not in non_feature_cols]
print(f'Summary columns:{summary_cols}')

# Lookup table: id_track -> {summary column name: value}
summary_by_track = summary_df.set_index('id_track')
summary_lookup = summary_by_track[summary_cols].to_dict(orient='index')

Summary columns:['p95_speed', 'p5_speed', 'med_speed', 'curviness', 'heading_mean', 'heading_std', 'turning_mean', 'turning_std', 'distance_total', 'distance_o']


In [52]:
# Attach the AIS vessel type to each radar detection through the radar row's
# associated AIS track id (assoc_id). Both frames carry an 'id_track' column, so
# pandas suffixes them: keep the radar one (_x) and discard the AIS one (_y).
ais_types = ais_tracks[['id_track', 'type_m2']]
merged = (
    radar_detections
    .merge(ais_types, left_on='assoc_id', right_on='id_track', how='inner')
    .rename(columns={'id_track_x': 'id_track'})
    .drop(columns=['id_track_y'])
)

# NOTE(review): aggregate_vessel_type is project code not shown here; the following
# cells expect it to add a 'type_m2_agg' column to `merged` in place -- confirm.
vessel_type_aggregator = VesselTypeAggregator()
vessel_type_aggregator.aggregate_vessel_type(merged)

feature_cols = ['speed', 'course', 'longitude', 'latitude']
label_col = 'type_m2_agg'

# Encode each distinct label as an integer id, in order of first appearance.
unique_labels = merged[label_col].unique()
label_dict = {label: position for position, label in enumerate(unique_labels)}
label_list = np.array(list(label_dict))

# id_track -> integer class id, taken from the first row seen for each track.
first_row_per_track = merged.drop_duplicates('id_track').set_index('id_track')
type_dict = first_row_per_track[label_col].map(label_dict).to_dict()


In [45]:
# Show the columns of the merged frame (rendered as the cell output).
merged.columns

Index(['id_detect', 'id_track', 'id_site', 'id_m2', 'source', 'speed',
       'course', 'assoc_str', 'assoc_id', 'confidence', 'cdate', 'ctime',
       'longitude', 'latitude', 'datetime', 'type_m2', 'type_m2_agg'],
      dtype='object')

In [46]:
# One training example per radar track
grouped = merged.groupby('id_track')

track_data = []
for id_track, track_rows in grouped:
    # Only tracks that have an encoded label can be used.
    if id_track in type_dict:
        sequence = torch.tensor(track_rows[feature_cols].values, dtype=torch.float32)
        summary_values = list(summary_lookup[id_track].values())
        example = {
            'features': sequence.to(device),  # T x M
            'summary': torch.tensor(summary_values, dtype=torch.float32).to(device),
            'length': sequence.size(0),  # number of detections in the track
            'label': torch.tensor(type_dict[id_track], dtype=torch.long).to(device),
        }
        track_data.append(example)

print(f"Prepared {len(track_data)} track tensors (raw features) on {device}")

Prepared 14204 track tensors (raw features) on mps


In [47]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate a list of track dicts into padded batch tensors.

    Args:
        batch: list of dicts with keys 'features' (T x M tensor), 'summary'
            (1-D tensor), 'length' (int) and 'label' (0-d tensor).

    Returns:
        Tuple ``(padded_features, lengths, summaries, labels)`` where
        ``padded_features`` is B x T_max x M (zero-padded), ``lengths`` holds
        the true sequence lengths, and all four are ordered longest-first.
    """
    # Order longest-first (optional but helpful for some RNNs). sorted() builds a
    # new list, so the caller's batch list is not reordered as a side effect.
    ordered = sorted(batch, key=lambda item: item['length'], reverse=True)
    features = [item['features'] for item in ordered]
    summaries = torch.stack([item['summary'] for item in ordered])
    labels = torch.stack([item['label'] for item in ordered])
    lengths = torch.tensor([seq.size(0) for seq in features])
    padded_features = pad_sequence(features, batch_first=True)  # B, T_max, M

    return padded_features, lengths, summaries, labels

In [58]:
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import torch

class VesselDataset(Dataset):
    """Map-style dataset that serves items from an in-memory sequence unchanged."""

    def __init__(self, data):
        """Keep a reference to ``data``; no copy is made."""
        self.data = data

    def __len__(self):
        """Number of stored items."""
        item_count = len(self.data)
        return item_count

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.data[idx]
        return item

# Fix the NumPy and PyTorch seeds; the same seed is also passed to every splitter below.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Wrap the prepared per-track examples.
full_dataset = VesselDataset(track_data)

# Integer class id of every track, used to stratify the splits.
labels = torch.stack([sample['label'] for sample in track_data]).cpu().numpy()

# Stage 1: 70% train, 30% held out (validation + test).
# The splitter only needs the labels, so a zero array stands in for X.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.30, random_state=seed)
placeholder_features = np.zeros(len(labels))
train_idx, temp_idx = next(splitter.split(placeholder_features, labels))

# Stage 2: cut the held-out part in half -> validation and test.
temp_labels = labels[temp_idx]
splitter2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=seed)  # 50% to test
val_positions, test_positions = next(splitter2.split(np.zeros(len(temp_labels)), temp_labels))
# The positions index into temp_idx; translate them back to dataset indices.
val_idx, test_idx = temp_idx[val_positions], temp_idx[test_positions]

# Dataset views for each split.
train_set = Subset(full_dataset, train_idx)
val_set = Subset(full_dataset, val_idx)
test_set = Subset(full_dataset, test_idx)

# Stratified 40% sample of the training split for quick experiments.
small_labels = labels[train_idx]
small_train_size = int(0.4 * len(train_set))
splitter_small = StratifiedShuffleSplit(n_splits=1, train_size=small_train_size, random_state=seed)
small_indices, _ = next(splitter_small.split(np.zeros(len(small_labels)), small_labels))
small_train_subset = Subset(train_set, small_indices)

# DataLoaders: only the training loader shuffles.
shared_loader_args = {'batch_size': BATCH_SIZE, 'collate_fn': collate_fn}
train_loader = DataLoader(train_set, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_set, shuffle=False, **shared_loader_args)
test_loader = DataLoader(test_set, shuffle=False, **shared_loader_args)

print(f"Data split: {len(train_set)} train, {len(val_set)} val, {len(test_set)} test")


Data split: 9942 train, 2131 val, 2131 test


In [None]:
from collections import Counter
import torch

def check_loader_label_distribution(loader, name=""):
    """Print the per-class sample count and share for everything a loader yields.

    The loader is iterated once. Each batch is expected to be a sequence whose
    last element is the label tensor, i.e.
    ``(padded_features, lengths, summaries, labels)``.

    Args:
        loader: iterable of batches.
        name: prefix printed in the heading line.
    """
    label_counts = Counter()
    for batch in loader:
        batch_labels = batch[-1]  # labels come last in the batch tuple
        label_counts.update(batch_labels.cpu().numpy())

    n_samples = sum(label_counts.values())
    print(f"\n{name} Label Distribution:")
    for label in sorted(label_counts):
        count = label_counts[label]
        print(f"  Label {label}: {count} ({count/n_samples:.2%})")

# Report the class balance of each split, in train / validation / test order.
for split_name, split_loader in (
    ("Train", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    check_loader_label_distribution(split_loader, name=split_name)

## Model Building

In [14]:
import torch.nn as nn

class VesselRNNClassifier(nn.Module):
    """GRU sequence encoder fused with a per-track summary vector.

    The final hidden state of the last GRU layer is concatenated with the
    summary vector and mapped to class logits by one linear layer.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers, summary_dim):
        """
        Args:
            input_dim: number of features per time step.
            hidden_dim: size of the GRU hidden state.
            num_classes: number of output logits.
            num_layers: number of stacked GRU layers.
            summary_dim: length of the summary vector appended to the hidden state.
        """
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc1 = nn.Linear(in_features=hidden_dim + summary_dim, out_features=num_classes)

    def forward(self, x, lengths, summaries):
        """Return class logits of shape B x num_classes.

        Args:
            x: padded sequences, B x T_max x M.
            lengths: true length of each of the B sequences (copied to CPU for packing).
            summaries: summary vectors, B x summary_dim.
        """
        # Pack so the GRU stops at each sequence's real length and ignores padding;
        # enforce_sorted=False means the batch does not have to be length-ordered.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        final_states = self.gru(packed)[1]
        last_layer_state = final_states[-1]

        # Fuse the sequence encoding with the summary statistics, then classify.
        fused = torch.cat((last_layer_state, summaries), dim=1)
        return self.fc1(fused)  # B * num_classes


In [16]:
EPOCH = 5

from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

input_dim = len(feature_cols)
num_classes = len(label_dict)
summary_dim = len(summary_cols)

model = VesselRNNClassifier(input_dim, HIDDEN_DIM, num_classes, num_layers=NUM_LAYERS_GRU, summary_dim=summary_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
# Halve the learning rate every 2 epochs.
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

# Mean batch loss per epoch; both lists are stored in the checkpoint below.
train_losses = []
val_losses = []

for epoch in range(EPOCH):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)

    for x_batch, lengths, summaries, y_batch in progress_bar:
        # lengths is not moved: the model copies it to CPU for sequence packing.
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        summaries = summaries.to(device)

        optimizer.zero_grad()
        logits = model(x_batch, lengths, summaries)  # B * num_classes
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    scheduler.step()

    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Evaluate on validation set
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x_val, lengths_val, summaries_val, y_val in val_loader:
            # Same explicit device placement as the training batches.
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            summaries_val = summaries_val.to(device)
            output_val = model(x_val, lengths_val, summaries_val)
            loss_val = criterion(output_val, y_val)
            val_loss += loss_val.item()

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"✅ Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

save_path = 'models/gru_5_epochs.pth'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Store every constructor argument of the model (including summary_dim and
# num_layers) so the architecture can be rebuilt from the checkpoint alone.
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'input_dim': input_dim,
    'hidden_dim': HIDDEN_DIM,
    'num_classes': num_classes,
    'summary_dim': summary_dim,
    'num_layers': NUM_LAYERS_GRU,
}, save_path)


                                                                      

✅ Epoch 1 | Train Loss: 1.8236 | Val Loss: 1.8552


                                                                      

✅ Epoch 2 | Train Loss: 1.7629 | Val Loss: 2.0339


                                                                      

✅ Epoch 3 | Train Loss: 1.4894 | Val Loss: 1.3750


                                                                      

✅ Epoch 4 | Train Loss: 1.4962 | Val Loss: 1.5347


                                                                      

✅ Epoch 5 | Train Loss: 1.3493 | Val Loss: 1.4799


### Resume Training

In [None]:
# Resume training from the 5-epoch checkpoint.
checkpoint = torch.load('models/gru_5_epochs.pth', map_location=device)

# Recreate the model architecture. The checkpoint may not contain 'summary_dim' /
# 'num_layers'; in that case fall back to the notebook's current configuration.
model = VesselRNNClassifier(
    input_dim=checkpoint['input_dim'],
    hidden_dim=checkpoint['hidden_dim'],
    num_classes=checkpoint['num_classes'],
    num_layers=checkpoint.get('num_layers', NUM_LAYERS_GRU),
    summary_dim=checkpoint.get('summary_dim', len(summary_cols)),
).to(device)

# Load saved state
model.load_state_dict(checkpoint['model_state_dict'])

# Restore optimizer and scheduler. The lr given here is only a placeholder:
# load_state_dict replaces it with the learning rate stored in the checkpoint.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The scheduler must be rebuilt around the NEW optimizer; a scheduler left over
# from an earlier cell would keep adjusting the old optimizer instead.
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
criterion = nn.CrossEntropyLoss()

# Continue tracking losses from the checkpointed history (not from whatever lists
# happen to be in the kernel). The upper-case names alias the same lists.
TRAIN_LOSSES = train_losses = list(checkpoint['train_losses'])
VAL_LOSSES = val_losses = list(checkpoint['val_losses'])

# Resume from this epoch
start_epoch = checkpoint['epoch']

for epoch in range(start_epoch, start_epoch + 5):  # continue 5 more epochs
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)

    # The loaders yield (features, lengths, summaries, labels).
    for x_batch, lengths, summaries, y_batch in progress_bar:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        summaries = summaries.to(device)
        optimizer.zero_grad()
        logits = model(x_batch, lengths, summaries)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    scheduler.step()

    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Evaluate on validation set
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x_val, lengths_val, summaries_val, y_val in val_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            summaries_val = summaries_val.to(device)
            output_val = model(x_val, lengths_val, summaries_val)
            loss_val = criterion(output_val, y_val)
            val_loss += loss_val.item()

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"✅ Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


                                                                         

✅ Epoch 6 | Train Loss: 1.6352 | Val Loss: 1.6514


                                                                     

✅ Epoch 7 | Train Loss: 1.6380 | Val Loss: 1.6324


                                                                     

✅ Epoch 8 | Train Loss: 1.6363 | Val Loss: 1.6632


                                                                     

✅ Epoch 9 | Train Loss: 1.6342 | Val Loss: 1.6847


                                                                      

✅ Epoch 10 | Train Loss: 1.6393 | Val Loss: 1.6743


                                                                      

✅ Epoch 11 | Train Loss: 1.6332 | Val Loss: 1.6981


                                                                      

✅ Epoch 12 | Train Loss: 1.6360 | Val Loss: 1.6496


                                                                      

✅ Epoch 13 | Train Loss: 1.6356 | Val Loss: 1.6512


                                                                      

✅ Epoch 14 | Train Loss: 1.6309 | Val Loss: 1.6425


                                                                      

✅ Epoch 15 | Train Loss: 1.6319 | Val Loss: 1.6442


                                                                      

✅ Epoch 16 | Train Loss: 1.5850 | Val Loss: 1.6864


                                                                      

✅ Epoch 17 | Train Loss: 1.5874 | Val Loss: 1.5905


                                                                      

✅ Epoch 18 | Train Loss: 1.5679 | Val Loss: 1.5778


                                                                      

✅ Epoch 19 | Train Loss: 1.6000 | Val Loss: 1.6583


                                                                      

✅ Epoch 20 | Train Loss: 1.6454 | Val Loss: 1.6877


                                                                      

✅ Epoch 21 | Train Loss: 1.6511 | Val Loss: 1.7079


                                                                      

✅ Epoch 22 | Train Loss: 1.6306 | Val Loss: 1.6997


                                                                      

✅ Epoch 23 | Train Loss: 1.6303 | Val Loss: 1.7322


                                                                      

✅ Epoch 24 | Train Loss: 1.6105 | Val Loss: 1.6678


                                                                      

✅ Epoch 25 | Train Loss: 1.6100 | Val Loss: 1.6301


                                                                      

✅ Epoch 26 | Train Loss: 1.6542 | Val Loss: 1.6602


                                                                      

✅ Epoch 27 | Train Loss: 1.6330 | Val Loss: 1.6863


                                                                      

✅ Epoch 28 | Train Loss: 1.6328 | Val Loss: 1.6667


                                                                      

✅ Epoch 29 | Train Loss: 1.6208 | Val Loss: 1.6030


                                                                      

✅ Epoch 30 | Train Loss: 1.6203 | Val Loss: 1.6973


                                                                      

✅ Epoch 31 | Train Loss: 1.6349 | Val Loss: 1.7357


                                                                      

✅ Epoch 32 | Train Loss: 1.6457 | Val Loss: 1.6544


                                                                      

✅ Epoch 33 | Train Loss: 1.6090 | Val Loss: 1.7722


                                                                      

✅ Epoch 34 | Train Loss: 1.6372 | Val Loss: 1.6811


                                                                      

✅ Epoch 35 | Train Loss: 1.6321 | Val Loss: 1.6689


                                                                      

✅ Epoch 36 | Train Loss: 1.6207 | Val Loss: 1.6785


                                                                      

✅ Epoch 37 | Train Loss: 1.6098 | Val Loss: 1.6317


                                                                      

✅ Epoch 38 | Train Loss: 1.6064 | Val Loss: 1.5775


                                                                      

✅ Epoch 39 | Train Loss: 1.5869 | Val Loss: 1.6311


                                                                      

✅ Epoch 40 | Train Loss: 1.5973 | Val Loss: 1.6602


In [55]:
# Name the checkpoint after the number of epochs actually completed, so the file
# name always agrees with the 'epoch' value stored inside it.
model_path = f'models/gru_{epoch + 1}_epochs.pth'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Store every constructor argument of the model (including summary_dim and
# num_layers) so the architecture can be rebuilt from the checkpoint alone.
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'input_dim': input_dim,
    'hidden_dim': HIDDEN_DIM,
    'num_classes': num_classes,
    'summary_dim': len(summary_cols),
    'num_layers': NUM_LAYERS_GRU,
}, model_path)


### Model Eval

In [18]:
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Ensure model is in eval mode
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():

    progress_bar = tqdm(test_loader, desc=f"Running Test", leave=False)
    for x_test, lengths_test, summaries_test, y_test in progress_bar:
        # lengths_test is passed as-is: the model copies it to CPU for packing.
        x_test = x_test.to(device)
        summaries_test = summaries_test.to(device)
        logits = model(x_test, lengths_test, summaries_test)

        preds = torch.argmax(logits, dim=1)  # get class with highest logit
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y_test.cpu().numpy())  # true labels

# Convert to numpy arrays for the sklearn metrics
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

print("🔍 Test Accuracy:", accuracy_score(all_labels, all_preds))
# zero_division=0 reports 0.0 for classes the model never predicts instead of
# emitting an UndefinedMetricWarning for each of them.
print("\n📊 Classification Report:\n", classification_report(all_labels, all_preds, zero_division=0))


                                                             

🔍 Test Accuracy: 0.5371402042711235

📊 Classification Report:
               precision    recall  f1-score   support

           0       0.49      0.23      0.32       158
           1       0.00      0.00      0.00        25
           2       0.67      0.92      0.78       869
           3       0.74      0.05      0.10       387
           4       0.25      0.05      0.09        39
           5       0.35      0.78      0.48       376
           6       0.18      0.01      0.02       224
           7       0.00      0.00      0.00        76

    accuracy                           0.54      2154
   macro avg       0.34      0.26      0.22      2154
weighted avg       0.52      0.54      0.44      2154



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
