In [33]:
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm

import os
import sys
# `__file__` is not defined inside a notebook; the original idiom
# `os.path.dirname(os.path.abspath("__file__"))` only ever resolved to the current
# working directory, so say so directly. Assumes the kernel was started in the
# notebook's own directory (the relative `../data/...` paths below assume the same).
cur_dir = os.getcwd()
# Parent of the notebook directory: the folder that contains the `src` package, so
# that `from src.constant import ...` resolves. Normalized (no trailing `../`) so the
# membership check below is not fooled by an equivalent, differently-spelled path.
src_dir = os.path.abspath(os.path.join(cur_dir, os.pardir))
if src_dir not in sys.path:
    sys.path.append(src_dir)
    
from src.constant import sidewalks, stations

In [21]:
base_name = "PID001_Control_1_6_14_9"
input = pd.read_csv(f"../data/PredictionModelOutput/{base_name}_ModelInputs.csv")
output = pd.read_csv(f"../data/PredictionModelOutput/{base_name}_ModelOutputs.csv")
features = pd.read_csv(f"../data/FeatureGeneratorOutput/{base_name}.csv")


def _normalize_column_name(name):
    """Return `name` with surrounding whitespace stripped and inner spaces turned into underscores."""
    return name.strip().replace(' ', '_')


input = input.rename(columns=_normalize_column_name)
output = output.rename(columns=_normalize_column_name)
features = features.rename(columns=_normalize_column_name)

# Keep only the columns later cells use. `.copy()` makes `input` an independent frame,
# so adding trajectory columns to it later does not raise SettingWithCopyWarning.
input = input[['Timestamp', 'User_X', 'User_Y', 'AGV_X', 'AGV_Y']].copy()

# Collapse the model output to one row per model run: `pred_traj` holds that run's
# predicted (X, Y) points as an (n, 2) array. Selecting ['X', 'Y'] before `apply`
# keeps the grouping column out of the applied frame (pandas >= 2.2 warns otherwise).
output = (
    output.groupby('ModelRunTimestamp')[['X', 'Y']]
    .apply(lambda points: points.to_numpy())
    .reset_index(name='pred_traj')
    .rename(columns={'ModelRunTimestamp': 'Timestamp'})
)

# `astype` with a mapping returns a new frame instead of writing into a column slice,
# which avoids SettingWithCopyWarning. NOTE(review): fails if the scenario columns
# contain NaN — confirm the feature generator always fills them.
features = features[
    ['Timestamp', 'Phase1_scenario_num', 'Phase2_scenario_num', 'GazeDirection_X', 'GazeDirection_Y']
].astype({'Phase1_scenario_num': int, 'Phase2_scenario_num': int})

In [22]:
# Trajectory window lengths, counted in rows of `input`.
PAST_WINDOW = 30    # observed history: the current row and up to 29 rows before it
FUTURE_WINDOW = 40  # ground-truth horizon: the current row and up to 39 rows after it

def get_input_traj(dframe, idx, window=30):
    """
    Return the observed [User_X, User_Y] history that ends at row label `idx`.

    At most `window` rows are taken: `idx` itself plus up to `window - 1` rows
    before it. Near the start of the frame the slice is clipped at label 0,
    so fewer rows come back.

    Rows are selected with `.loc` (label-based, inclusive), so this assumes a
    0..n-1 integer index such as the one `read_csv` produces.
    """
    cols = ['User_X', 'User_Y']
    first = idx - window + 1
    if first < 0:
        first = 0
    return dframe.loc[first:idx, cols].to_numpy()

def get_gt_traj(dframe, idx, window=40):
    """
    Return the ground-truth [User_X, User_Y] future that starts at row label `idx`.

    At most `window` rows are taken: `idx` itself plus up to `window - 1` rows
    after it. Near the end of the frame the slice is clipped at the last row,
    so fewer rows come back.

    Rows are selected with `.loc` (label-based, inclusive), so this assumes a
    0..n-1 integer index such as the one `read_csv` produces.
    """
    last = idx + window - 1
    final_row = len(dframe) - 1
    if last > final_row:
        last = final_row
    return dframe.loc[idx:last, ['User_X', 'User_Y']].to_numpy()

# Attach each row's observed history and ground-truth future as (n, 2) arrays.
# Iterating over index labels avoids building a Series per row as `apply(axis=1)`
# does, and `assign` returns a new frame, so there is no SettingWithCopyWarning
# even if `input` is a column slice. The helpers slice with `.loc`, so `input` is
# assumed to keep its default 0..n-1 index.
input = input.assign(
    input_traj=[get_input_traj(input, idx, PAST_WINDOW) for idx in input.index],
    gt_traj=[get_gt_traj(input, idx, FUTURE_WINDOW) for idx in input.index],
)

# Inner joins keep only timestamps present in the inputs, the model output and the
# features; the result is ordered by time.
data = (
    input
    .merge(output, on='Timestamp', how='inner')
    .merge(features, on='Timestamp', how='inner')
    .sort_values(by='Timestamp')
)


In [4]:
pd.unique(data['Phase1_scenario_num'])  # Phase 1 scenario ids, in order of first appearance

array([ 2,  4,  6,  7, 10, 11, 13, 15, 16,  0,  1])

## Configs

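- `PHASE1_SCENARIO`: the Phase 1 scenario (a value of `Phase1_scenario_num`) to animate; the cell above lists the available values.
- `SAVE` / `SAVE_PATH`: whether to also export the animation as a GIF, and where to write it.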

In [35]:
PHASE1_SCENARIO = 2  # value of `Phase1_scenario_num` to animate
SAVE = False         # also export the animation as a GIF to SAVE_PATH
SAVE_PATH = f"../data/animation/{base_name}_phase1_scenario_{PHASE1_SCENARIO}.gif"

# Columns carried into the plotting cells.
PLOT_COLUMNS = [
    'Timestamp', 'input_traj', 'gt_traj', 'pred_traj',
    'User_X', 'User_Y', 'AGV_X', 'AGV_Y',
    'GazeDirection_X', 'GazeDirection_Y',
]
df = data.loc[data['Phase1_scenario_num'] == PHASE1_SCENARIO, PLOT_COLUMNS]

# Fail fast with the valid choices instead of a cryptic error later when plotting
# an empty selection.
if df.empty:
    raise ValueError(
        f"No rows with Phase1_scenario_num == {PHASE1_SCENARIO}; "
        f"available values: {sorted(data['Phase1_scenario_num'].unique().tolist())}"
    )

In [30]:
def flatten_trajectory(df, col_name, label):
    """
    Flatten a column of per-row trajectories into one row per trajectory point.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'Timestamp' column and `col_name`. Each cell of `col_name`
        is an iterable of (X, Y) pairs, e.g. an (n, 2) numpy array. The index
        should be unique, because points are numbered per index label.
    col_name : str
        Column holding the trajectories, e.g. 'input_traj'.
    label : str
        Value written to the 'type' column of every output row, e.g. 'input_traj'.

    Returns
    -------
    pandas.DataFrame
        Columns ``Timestamp, frame, X, Y, type`` with a fresh 0..n-1 index, where
        ``frame`` is the position of the point within its own trajectory
        (0, 1, 2, ...), not the animation frame. An empty `df` yields an empty
        frame with the same columns.
    """
    # One row per point; each point keeps the index label of the row it came from.
    out = df[['Timestamp', col_name]].explode(col_name, ignore_index=False)

    # Number the points within each original row (0, 1, 2, ...).
    out['frame'] = out.groupby(level=0).cumcount()

    # Split every (X, Y) point into two columns. Naming the columns explicitly keeps
    # the empty-input case working: built from an empty list without names, the
    # frame would have zero columns and the two-column assignment would raise.
    out[['X', 'Y']] = pd.DataFrame(out[col_name].tolist(), index=out.index, columns=['X', 'Y'])

    # Tag every point with the trajectory type it belongs to.
    out['type'] = label

    out = out.drop(columns=[col_name])
    return out.reset_index(drop=True)


In [38]:
def _add_sidewalk(fig, x0, y0, x1, y1, showlegend):
    """Draw one dashed black sidewalk segment from (x0, y0) to (x1, y1) on `fig`."""
    fig.add_shape(type='line',
                  x0=x0, y0=y0, x1=x1, y1=y1,
                  line=dict(color='black', width=2, dash='dash'),
                  name='sidewalks',
                  legendgroup='sidewalks',  # all segments toggle together in the legend
                  showlegend=showlegend)
    return fig


def _draw_rectangle(fig, center, lx, ly, label):
    """
    Draw an unfilled `lx` x `ly` rectangle centred on `center`, plus a yellow text tag.

    The tag is placed 200 units below the centre when the centre's y <= 8000,
    and 150 units above it otherwise.
    """
    cx, cy = center
    x0 = cx - lx / 2
    y0 = cy - ly / 2
    fig.add_shape(type="rect",
                  x0=x0, y0=y0, x1=x0 + lx, y1=y0 + ly,
                  line=dict(color="black", width=2),
                  fillcolor="rgba(0,0,0,0)",
                  name=str(label))

    dy = -200 if cy <= 8000 else 150
    fig.add_annotation(x=cx, y=cy + dy, text=str(label),
                       showarrow=False,
                       bgcolor='yellow',
                       bordercolor='black',
                       borderpad=4,
                       font=dict(color='black'))
    return fig


def _position_markers(user_x, user_y, agv_x, agv_y):
    """Return the (User, AGV) single-point marker traces: User in blue, AGV in red."""
    return (
        go.Scatter(x=[user_x], y=[user_y], mode='markers',
                   marker=dict(size=10, color='blue'), name='User'),
        go.Scatter(x=[agv_x], y=[agv_y], mode='markers',
                   marker=dict(size=10, color='red'), name='AGV'),
    )


def build_plot(df, downsample=1, frame_duration=500, transition_duration=100):
    """
    Build the animated Plotly figure of input, ground-truth and predicted trajectories.

    One animation frame is produced per distinct 'Timestamp'. Each frame shows the
    three trajectory lines (coloured by type), the User (blue) and AGV (red)
    positions at that timestamp, and the static sidewalk and station overlays
    from `src.constant`.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'Timestamp', 'input_traj', 'gt_traj', 'pred_traj', 'User_X',
        'User_Y', 'AGV_X' and 'AGV_Y' columns; trajectory cells hold (X, Y) points.
    downsample : int, default 1
        Keep every `downsample`-th row (fewer frames, faster export).
    frame_duration, transition_duration : int, default 500 / 100
        Play-button timings in milliseconds.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If `downsample` is less than 1.
    """
    if downsample < 1:
        raise ValueError(f"downsample must be >= 1, got {downsample!r}")
    df = df.iloc[::downsample]

    # Long format for plotly express: one row per trajectory point, tagged by type.
    df_flat = pd.concat(
        [flatten_trajectory(df, col, col) for col in ('input_traj', 'gt_traj', 'pred_traj')],
        ignore_index=True,
    )
    # Local name is `fig`, not `plt`, so matplotlib.pyplot's usual alias is never shadowed.
    fig = px.line(df_flat, x="X", y="Y", animation_frame="Timestamp", animation_group="type",
                  color="type", hover_name="type",
                  range_x=[0, 15000], range_y=[5000, 10000],
                  width=15000 / 12, height=5000 / 8)

    fig.update_layout({
        'autosize': False,
        'plot_bgcolor': 'white',   # opaque white plot area
        'paper_bgcolor': 'white',  # opaque white page around it
        'xaxis': {'showgrid': False},
        'yaxis': {'showgrid': False},
    })
    # Hide both axes entirely (lines, ticks and titles).
    fig.update_xaxes(title_text='', showticklabels=False, visible=False)
    fig.update_yaxes(title_text='', showticklabels=False, visible=False)

    # The play/pause menu may be absent (e.g. a figure with a single frame).
    if fig.layout.updatemenus:
        play_args = fig.layout.updatemenus[0].buttons[0].args[1]
        play_args['frame']['duration'] = frame_duration
        play_args['transition']['duration'] = transition_duration

    # Static overlays; only the first sidewalk segment gets a legend entry.
    for i, segment in enumerate(sidewalks.values()):
        _add_sidewalk(fig, *segment, showlegend=(i == 0))
    for label, center in stations.items():
        _draw_rectangle(fig, center, 500, 100, label)

    # User/AGV position per timestamp (first row wins on duplicates), keyed by
    # str(timestamp): plotly stores frame names as strings, so comparing
    # `frame.name` against a numeric 'Timestamp' column would never match.
    positions = {}
    for ts, user_x, user_y, agv_x, agv_y in zip(df['Timestamp'], df['User_X'], df['User_Y'],
                                                df['AGV_X'], df['AGV_Y']):
        positions.setdefault(str(ts), (user_x, user_y, agv_x, agv_y))
    if not positions:
        return fig

    # Frame traces are applied to the base figure's traces by position, so the
    # markers must also exist in `fig.data`; otherwise they are ignored during
    # playback and when a frame's data is copied onto the figure for export.
    initial_key = str(fig.frames[0].name) if fig.frames else next(iter(positions))
    fig.add_traces(list(_position_markers(*positions[initial_key])))
    for frame in fig.frames:
        frame.data += _position_markers(*positions[str(frame.name)])

    return fig


# Named `fig` so this cell no longer rebinds `plt` (matplotlib.pyplot) for the session.
fig = build_plot(df)
fig.show()

if SAVE:
    import tempfile

    import imageio
    import plotly.io as pio

    # The GIF uses every 5th row to keep the number of rendered frames down.
    export_fig = build_plot(df, downsample=5)
    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
    # Render PNGs into a throwaway directory that is deleted afterwards, so stale
    # frames from an earlier, longer export can never leak into this GIF.
    with tempfile.TemporaryDirectory(dir=os.path.dirname(SAVE_PATH)) as temp_path:
        # Keep an explicit ordered list: sorting file names breaks past 999 frames
        # ("frame_1000.png" sorts before "frame_101.png").
        frame_files = []
        for frame_number, frame in enumerate(tqdm(export_fig.frames, desc="Rendering frames")):
            export_fig.update(data=frame.data, layout=frame.layout)
            frame_file = os.path.join(temp_path, f"frame_{frame_number:03d}.png")
            pio.write_image(export_fig, frame_file)
            frame_files.append(frame_file)

        # Frame timing is left at imageio's default; pass a `duration` to
        # `get_writer` to change it (its unit depends on the imageio version).
        with imageio.get_writer(SAVE_PATH, mode="I") as writer:
            for frame_file in tqdm(frame_files, desc="Saving animation"):
                writer.append_data(imageio.imread(frame_file))

    print(f"Animation saved as {SAVE_PATH}")