In [11]:
# %% Imports
import numpy as np
import pandas as pd
import torch
from pathlib import Path
import pickle
from typing import Dict, List, Tuple, Optional, Union
import folium
import warnings
warnings.filterwarnings('ignore')

# Add project root to path
import sys
sys.path.append('../..')

from ml_mobility_ns3.models.vae import ConditionalTrajectoryVAE
from ml_mobility_ns3.utils.model_utils import load_model_from_checkpoint

# %% Configuration
# Update these paths according to your setup
MODEL_PATH = Path("../results/large_lstm_gpu/lstm_vae/best_model.pt")
DATA_PATH = Path("../preprocessing/vae_dataset.npz")
PREPROCESSING_DIR = Path("../preprocessing/")

# Both trajectory generation (torch sampling) and real-data sampling
# (np.random.choice) are stochastic; seed them so that
# Restart Kernel -> Run All reproduces the same maps.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# --- Device setup ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# %% Load Model and Data
def load_vae_model_and_data(model_path, data_path, preprocessing_dir):
    """Load the trained model, the preprocessed dataset, metadata and scalers.

    Args:
        model_path: Checkpoint path passed to ``load_model_from_checkpoint``.
        data_path: Path to the ``.npz`` dataset archive.
        preprocessing_dir: Directory (``Path`` or ``str``) containing
            ``metadata.pkl`` and ``scalers.pkl``.

    Returns:
        Tuple ``(model, data, metadata, scalers)``. ``data`` is a plain ``dict``
        mapping every array name stored in the ``.npz`` archive to its array.

    Note:
        ``metadata.pkl`` and ``scalers.pkl`` are read with ``pickle.load``,
        which can execute arbitrary code -- only load files you trust.
    """
    print("Loading model...")
    # The checkpoint loader also returns a config object that is not needed here.
    model, _config = load_model_from_checkpoint(model_path, device)

    print("Loading data...")
    # np.load on an .npz returns a lazy NpzFile: each `data[key]` access
    # re-reads and decompresses that array from disk, and the file handle stays
    # open. Materialize every array once and close the archive.
    with np.load(data_path) as npz:
        data = {key: npz[key] for key in npz.files}

    print("Loading metadata and scalers...")
    preprocessing_dir = Path(preprocessing_dir)
    with open(preprocessing_dir / "metadata.pkl", 'rb') as f:
        metadata = pickle.load(f)

    with open(preprocessing_dir / "scalers.pkl", 'rb') as f:
        scalers = pickle.load(f)

    return model, data, metadata, scalers

model, data, metadata, scalers = load_vae_model_and_data(
    MODEL_PATH, DATA_PATH, PREPROCESSING_DIR
)

# --- Create transport mode mappings ---
# Two-way lookup between mode names and the integer mode labels used in the dataset.
idx_to_mode = dict(enumerate(metadata['transport_modes']))
mode_to_idx = {mode_name: label for label, mode_name in idx_to_mode.items()}

print(f"\nModel architecture: {model.architecture}")
print(f"Transport modes: {metadata['transport_modes']}")
print(f"Sequence length: {metadata['sequence_length']}")

# %% Trajectory Generation Function
def generate_trajectories(
    model: ConditionalTrajectoryVAE,
    transport_mode: Union[str, int],
    trip_length: int,
    n_samples: int = 1
) -> List[Dict[str, Union[np.ndarray, str, int]]]:
    """Generate trajectories for a given transport mode and unscale them.

    Args:
        model: Trained conditional VAE; its ``generate`` method is called with
            one mode label and one length per sample.
        transport_mode: Mode name (key of ``mode_to_idx``) or integer mode index.
        trip_length: Number of points kept from each generated sequence.
        n_samples: Number of trajectories to generate.

    Returns:
        One dict per sample with keys ``'trajectory'`` (array of ``trip_length``
        rows passed through the trajectory scaler's ``inverse_transform``),
        ``'mode'`` (mode name), ``'length'`` and ``'type'`` (``'Generated'``).

    Raises:
        KeyError: Unknown mode name or mode index.
        ValueError: ``n_samples`` or ``trip_length`` is below 1, or
            ``trip_length`` exceeds the length of the generated sequences.
    """
    if isinstance(transport_mode, str):
        mode_idx = mode_to_idx[transport_mode]  # KeyError for unknown names
    else:
        mode_idx = int(transport_mode)
        if mode_idx not in idx_to_mode:
            # Fail before running the model instead of after generation.
            raise KeyError(f"Unknown transport mode index: {mode_idx}")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if trip_length < 1:
        raise ValueError(f"trip_length must be >= 1, got {trip_length}")

    mode_tensor = torch.full((n_samples,), mode_idx, dtype=torch.long, device=device)
    length_tensor = torch.full((n_samples,), trip_length, dtype=torch.long, device=device)

    with torch.no_grad():
        generated = model.generate(mode_tensor, length_tensor, n_samples=n_samples, device=device)

    generated_np = generated.cpu().numpy()
    # generated_np is indexed as [sample, step, feature] below. Slicing past the
    # generated sequence length would silently return fewer rows while the
    # result still claimed 'length' == trip_length.
    if trip_length > generated_np.shape[1]:
        raise ValueError(
            f"trip_length={trip_length} exceeds the generated sequence length "
            f"{generated_np.shape[1]}"
        )

    # Unscale each trajectory, keeping only its first `trip_length` points.
    trajectory_scaler = scalers['trajectory']
    mode_name = idx_to_mode[mode_idx]
    return [
        {
            'trajectory': trajectory_scaler.inverse_transform(generated_np[i, :trip_length, :]),
            'mode': mode_name,
            'length': trip_length,
            'type': 'Generated',
        }
        for i in range(n_samples)
    ]

# %% Real Data Sampling Function
def sample_real_trajectories(
    data: Dict[str, np.ndarray],
    transport_mode: str,
    n_samples: int = 1
) -> List[Dict[str, Union[np.ndarray, str, int]]]:
    """Randomly sample real trajectories of one transport mode and unscale them.

    Args:
        data: Mapping with ``'transport_modes'`` (mode index per trip),
            ``'trip_lengths'`` (number of valid points per trip) and
            ``'trajectories'`` (indexed as ``[trip, point, feature]``).
        transport_mode: Mode name; must be a key of ``mode_to_idx``.
        n_samples: Number of trajectories to draw. Sampling is without
            replacement when enough trips exist, otherwise with replacement.

    Returns:
        One dict per sample with keys ``'trajectory'`` (unscaled, trimmed to the
        trip length), ``'mode'``, ``'length'`` and ``'type'`` (``'Real'``).

    Raises:
        KeyError: ``transport_mode`` is not a known mode name.
        ValueError: The dataset contains no trips for ``transport_mode``.
    """
    mode_idx = mode_to_idx[transport_mode]
    valid_indices = np.where(data['transport_modes'] == mode_idx)[0]

    if len(valid_indices) == 0:
        # np.random.choice would otherwise fail with an opaque "a cannot be empty" error.
        raise ValueError(f"No real trajectories found for transport mode {transport_mode!r}.")

    if len(valid_indices) < n_samples:
        print(f"Warning: Only {len(valid_indices)} trajectories for {transport_mode}. Sampling with replacement.")
        sample_indices = np.random.choice(valid_indices, n_samples, replace=True)
    else:
        sample_indices = np.random.choice(valid_indices, n_samples, replace=False)

    trajectory_scaler = scalers['trajectory']
    # Fetch the arrays once: if `data` is a lazy NpzFile, every data[key]
    # access re-reads the whole array from disk.
    trip_lengths = data['trip_lengths']
    trajectories = data['trajectories']

    samples = []
    for idx in sample_indices:
        length = int(trip_lengths[idx])
        scaled_traj = trajectories[idx, :length, :]
        samples.append({
            'trajectory': trajectory_scaler.inverse_transform(scaled_traj),
            'mode': transport_mode,
            'length': length,
            'type': 'Real'
        })

    return samples

# %% Visualization Function
def create_interactive_map(
    trajectories: List[Dict],
    center: Optional[Tuple[float, float]] = None,
    zoom_start: int = 12,
    mode_colors: Optional[Dict[str, str]] = None
) -> folium.Map:
    """Create an interactive Folium map showing trajectories as polylines.

    Real trajectories are drawn as solid lines and generated ones as dashed
    lines. Each trajectory gets a cyan start marker and a magenta end marker.

    Args:
        trajectories: Dicts as produced by ``generate_trajectories`` or
            ``sample_real_trajectories``; columns 0 and 1 of ``'trajectory'``
            are plotted as (lat, lon).
        center: Map center as (lat, lon). Defaults to the mean of all points.
        zoom_start: Initial zoom level.
        mode_colors: Optional mode-name -> color overrides, merged over the
            built-in mapping. Modes without a color use the ``'default'`` entry.

    Returns:
        The populated ``folium.Map``.

    Raises:
        ValueError: ``trajectories`` is empty and no ``center`` was given.
    """
    if center is None:
        if not trajectories:
            # np.mean([]) would give a NaN center (its warning is silenced by
            # the notebook-wide warnings filter) and an unusable map.
            raise ValueError("Cannot infer a map center from an empty trajectory list; pass `center`.")
        all_points = np.concatenate([traj['trajectory'][:, :2] for traj in trajectories])
        center = (float(all_points[:, 0].mean()), float(all_points[:, 1].mean()))

    m = folium.Map(location=center, zoom_start=zoom_start)

    # Colors per transport mode; anything not listed falls back to 'default'.
    colors = {
        'PRIV_CAR_PASSENGER': 'blue',
        'WALKING': 'green',
        'BIKE': 'red',
        'default': 'gray'
    }
    if mode_colors:
        colors.update(mode_colors)

    for traj in trajectories:
        points = traj['trajectory'][:, :2].tolist()
        if not points:
            # Nothing to draw and no start/end point to mark.
            continue
        color = colors.get(traj['mode'], colors['default'])

        folium.PolyLine(
            points,
            color=color,
            weight=4,
            opacity=0.8,
            # folium's documented path option is snake_case `dash_array`; a
            # camelCase `dashArray` kwarg may be silently dropped.
            dash_array='5, 5' if traj['type'] == 'Generated' else None,
            popup=f"{traj['type']} {traj['mode']} ({traj['length']} points)",
        ).add_to(m)

        # Add start/end markers
        folium.CircleMarker(points[0], radius=5, color='cyan', fill=True, popup="Start").add_to(m)
        folium.CircleMarker(points[-1], radius=5, color='magenta', fill=True, popup="End").add_to(m)

    return m



Using device: cpu
Loading model...
Loading data...
Loading metadata and scalers...

Model architecture: lstm
Transport modes: ['BIKE', 'BUS', 'ELECT_BIKE', 'ELECT_SCOOTER', 'LIGHT_COMM_VEHICLE', 'ON_DEMAND', 'OTHER', 'PLANE', 'PRIV_CAR_DRIVER', 'PRIV_CAR_PASSENGER', 'SUBWAY', 'TAXI', 'TRAIN', 'TRAIN_EXPRESS', 'TRAMWAY', 'TWO_WHEELER', 'WALKING', 'mixed']
Sequence length: 2070


In [12]:
# 1. Generate new trajectories
print("\n1. Generating new trajectories...")
# (mode, trip length in points) for each mode; two samples are drawn per mode.
generated_car, generated_walk, generated_bike = (
    generate_trajectories(model, mode_name, trip_length=n_points, n_samples=2)
    for mode_name, n_points in (
        ('PRIV_CAR_PASSENGER', 1400),
        ('WALKING', 250),
        ('BIKE', 350),
    )
)
all_generated = [*generated_car, *generated_walk, *generated_bike]
print(f"Generated {len(all_generated)} total trajectories.")
interactive_map = create_interactive_map(all_generated)

# Display the map
interactive_map



1. Generating new trajectories...
Generated 6 total trajectories.


In [13]:
# 2. Sample real trajectories for comparison
print("\n2. Sampling real trajectories...")
real_car = sample_real_trajectories(data, 'PRIV_CAR_PASSENGER', n_samples=2)
real_walk = sample_real_trajectories(data, 'WALKING', n_samples=2)
real_bike = sample_real_trajectories(data, 'BIKE', n_samples=2)
all_real = real_car + real_walk + real_bike
print(f"Sampled {len(all_real)} real trajectories.")

# 3. Display on map
print("\n3. Creating interactive map...")
# Overlay the real samples with the generated ones from step 1 so they can be
# compared on one map: solid lines = REAL, dashed lines = GENERATED;
# blue = CAR, green = WALK, red = BIKE.
comparison_map = create_interactive_map(all_real + all_generated)

# Display the map
comparison_map


2. Sampling real trajectories...
Sampled 6 real trajectories.

3. Creating interactive map...


In [14]:
# Inspect the raw result: each entry holds the unscaled trajectory array plus its mode, length and type.
generated_car

[{'trajectory': array([[48.763355  ,  2.2417483 ,  3.2850466 ],
         [48.74356   ,  2.2469306 ,  1.138957  ],
         [48.726654  ,  2.2580402 ,  0.72156435],
         ...,
         [48.87992   ,  2.1093218 , 14.76383   ],
         [48.879963  ,  2.1092927 , 14.744748  ],
         [48.88      ,  2.1092641 , 14.725921  ]], dtype=float32),
  'mode': 'PRIV_CAR_PASSENGER',
  'length': 1400,
  'type': 'Generated'},
 {'trajectory': array([[48.766903 ,  2.6003618,  8.512765 ],
         [48.78536  ,  2.5915048,  8.407952 ],
         [48.785725 ,  2.6269233,  9.337168 ],
         ...,
         [48.733322 ,  2.5165026,  5.1790247],
         [48.73332  ,  2.5165   ,  5.1792536],
         [48.73332  ,  2.5164974,  5.179495 ]], dtype=float32),
  'mode': 'PRIV_CAR_PASSENGER',
  'length': 1400,
  'type': 'Generated'}]