In [None]:
import pandas as pd
import pyarrow.parquet as pq
import plotly.graph_objects as go
from paths import PROCESSED_DATA_DIR
from paths import SPLITS_DIR
import matplotlib.pyplot as plt

In [None]:
# Load the processed dataset found at PROCESSED_DATA_DIR into a pandas DataFrame.
# NOTE(review): `df` is rebound to the sorted validation split a few cells further
# down and nothing in between reads it, so this load appears to be unused there —
# confirm whether it is still needed.
df = pq.read_table(PROCESSED_DATA_DIR).to_pandas()

In [None]:
# Read the training, validation and test splits from SPLITS_DIR, one pandas
# DataFrame per parquet file (read in that order).
train_df, val_df, test_df = (
    pq.read_table(SPLITS_DIR / split_file).to_pandas()
    for split_file in ('train.parquet', 'val.parquet', 'test.parquet')
)

In [None]:
# From here on `df` is the validation split, ordered by vessel, then segment,
# then time, so that each segment's rows are contiguous and chronological.
df = val_df.sort_values(by=['MMSI', 'segment_id', 'Timestamp'])

In [None]:
# Quick look at the data: first rows, then summary statistics, one after the other.
print(df.head(), df.describe(), sep='\n')

In [None]:
# Count distinct vessels (MMSI) and distinct (MMSI, segment_id) pairs in `df`.
unique_mmsi = df['MMSI'].unique()
n_segments = df.groupby(['MMSI', 'segment_id']).ngroups
print(f"Total unique MMSIs in dataset: {len(unique_mmsi)}")
print(f"Number of segments: {n_segments}")

In [None]:
# Count the missing (NaN/None) entries in every column.
missing_per_column = df.isna().sum()
print(missing_per_column)

In [None]:
# Check coordinate bounds
# The extremes are formatted explicitly: printing a tuple of NumPy scalars shows
# their repr, which under NumPy >= 2 reads "np.float64(...)" instead of the number.
for coord_label, coord_column in (("Lat", "Latitude"), ("Lon", "Longitude")):
    coord_min, coord_max = df[coord_column].min(), df[coord_column].max()
    print(f"{coord_label} range: ({coord_min}, {coord_max})")

In [None]:
# Sampling-interval statistics.
# Δt is the gap in seconds between consecutive rows of the same (MMSI, segment_id)
# group, so it is never computed across a segment boundary; the first row of each
# group has no predecessor and therefore gets NaN.
segment_keys = ['MMSI', 'segment_id']
df['Δt'] = (
    df.groupby(segment_keys)['Timestamp']
    .diff()
    .dt.total_seconds()
)

# One mean Δt per segment, then summary statistics across those per-segment means.
segment_means = df.groupby(segment_keys)['Δt'].mean()
print(segment_means.describe())

# Mean of the per-segment means: every segment weighs the same regardless of length.
overall_mean = segment_means.mean()
print(f"\nOverall mean Δt across segments: {overall_mean:.2f} seconds ({overall_mean/60:.2f} minutes)")

In [None]:
# Static overview of every position in `df` (longitude on x, latitude on y),
# drawn with small semi-transparent markers.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(df['Longitude'], df['Latitude'], s=1, alpha=0.5)
ax.set_title('AIS Vessel Positions')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
plt.show()

In [None]:
# Draw a reproducible random sample of vessels (fixed random_state) for plotting.
N_SAMPLED_VESSELS = 10
# Cap the sample size at the number of available vessels: Series.sample raises
# ValueError when n exceeds the population and replace=False (the default).
random_mmsi = pd.Series(unique_mmsi).sample(
    n=min(N_SAMPLED_VESSELS, len(unique_mmsi)),
    random_state=42,
)
# Keep only the rows that belong to the sampled vessels.
df_sampled = df[df['MMSI'].isin(random_mmsi)]

def plot_vessel_tracks_with_filter(df: pd.DataFrame) -> None:
    """Show an interactive map of vessel positions with a per-vessel dropdown filter.

    One ``go.Scattermap`` trace is created per (MMSI, segment_id) pair. A dropdown
    offers "All vessels" plus one entry per MMSI; choosing an entry shows only the
    traces whose name equals that MMSI and updates the figure title text.

    Args:
        df: Frame with the columns ``MMSI``, ``segment_id``, ``Timestamp``,
            ``Latitude``, ``Longitude``, ``SOG`` and ``COG``. It is not modified;
            a sorted copy is used internally.

    Returns:
        None. The figure is rendered with ``fig.show()``; it is deliberately not
        returned so that a call at the end of a notebook cell does not display it
        a second time.
    """
    # Ensure data is sorted
    df = df.sort_values(['MMSI', 'segment_id', 'Timestamp'])

    # Prepare traces: one trace per vessel segment.
    # A single groupby replaces repeated boolean filtering per vessel and segment;
    # keys come out in sorted order, matching the sort above. observed=True keeps
    # categorical keys from producing groups that are not present in the data.
    traces = []
    vessel_list = df['MMSI'].unique()
    segments = df.groupby(['MMSI', 'segment_id'], sort=True, observed=True)
    for (vessel, _segment), segment_df in segments:
        # Build hover labels with a plain comprehension instead of a row-wise apply.
        hover_text = [
            f"MMSI: {mmsi}<br>segment_id: {segment_id}<br>SOG: {sog}<br>COG: {cog}<br>Time: {timestamp}"
            for mmsi, segment_id, sog, cog, timestamp in zip(
                segment_df['MMSI'],
                segment_df['segment_id'],
                segment_df['SOG'],
                segment_df['COG'],
                segment_df['Timestamp'],
            )
        ]
        traces.append(
            go.Scattermap(
                lat=segment_df['Latitude'],
                lon=segment_df['Longitude'],
                mode='markers',
                # Only takes effect if mode is changed to include 'lines'.
                line=dict(width=2),
                name=str(vessel),
                visible=True,
                hoverinfo='text',
                text=hover_text,
            )
        )

    # Create figure
    fig = go.Figure(data=traces)

    # Create buttons for filtering.
    # The title is changed through "title.text" so that only the text is replaced
    # and the centred title layout set below is preserved; a bare string under
    # "title" is the legacy form.
    buttons = [dict(
        label="All vessels",
        method="update",
        args=[{"visible": [True] * len(traces)},
              {"title.text": "All Vessels"}]
    )]

    # One option per vessel: show every segment trace named after that vessel.
    trace_names = [trace.name for trace in traces]
    for vessel in vessel_list:
        vessel_name = str(vessel)
        visibility = [trace_name == vessel_name for trace_name in trace_names]
        buttons.append(dict(
            label=vessel_name,
            method="update",
            args=[{"visible": visibility},
                  {"title.text": f"Vessel {vessel}"}]
        ))

    # Add dropdown menu.
    # Scattermap traces are drawn on the `map` subplot, so style and centre must be
    # set via `map_*`; `mapbox_*` only affects Scattermapbox traces.
    fig.update_layout(
        map_style="open-street-map",
        map_center={"lat": 56, "lon": 8},
        height=800,
        margin={"r": 0, "t": 50, "l": 0, "b": 0},
        updatemenus=[dict(
            active=0,
            buttons=buttons,
            x=0,
            y=1.05,
            xanchor='left',
            yanchor='top'
        )],
        title=dict(
            text="Vessel Track using AIS Data",
            x=0.5,
            xanchor='center',
            yanchor='top'
        )
    )

    fig.show()
# Render the interactive map for the sampled vessels only.
plot_vessel_tracks_with_filter(df_sampled)