In [80]:
import plotly.express as px

import os
import sys
# Make the project root importable so `from src.constant import ...` resolves.
# A notebook has no `__file__` variable, and the quoted string "__file__" only
# ever resolved to the current working directory, so use that directly
# (typically the notebook's own folder when launched from Jupyter).
cur_dir = os.getcwd()
# Parent of the working directory, i.e. the folder expected to contain `src/`
# (despite the variable name, this is not the `src` directory itself).
# abspath() normalises away the '..' so the membership test below also matches
# an equivalent entry that is already on sys.path.
src_dir = os.path.abspath(os.path.join(cur_dir, os.pardir))
if src_dir not in sys.path:
    sys.path.append(src_dir)
    
from src.constant import sidewalks, stations
import pandas as pd


### Load predicted, ground-truth and input trajectories

In [94]:
# Read the prediction, ground-truth and input trajectory tables (per the file
# names); the first CSV column becomes the index of each frame.
pred_all, truth_all, input_all = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        '../data/pred_tra_all.csv',
        '../data/truth_tra_all.csv',
        '../data/input_tra_all.csv',
    )
)

In [95]:
# Pick one trajectory and stack its predicted, ground-truth and input points
# into a single long-format frame whose `type` column tells the sources apart.
trajectory_id = 0

# `.assign` returns a new frame instead of writing a column into a
# boolean-indexed slice, which would raise pandas' SettingWithCopyWarning and
# is not guaranteed to behave like an assignment on an independent copy.
pred_tra = pred_all[pred_all['trajectory_id'] == trajectory_id].assign(type='pred')
truth_tra = truth_all[truth_all['trajectory_id'] == trajectory_id].assign(type='truth')
input_tra = input_all[input_all['trajectory_id'] == trajectory_id].assign(type='input')

df = pd.concat([pred_tra, truth_tra, input_tra])
# Fail loudly instead of silently plotting an empty figure for an unknown id.
assert not df.empty, f"no rows found for trajectory_id={trajectory_id}"

In [96]:
# Animated line plot of the selected trajectory: one trace per `type`
# (pred / truth / input), with animation frames keyed by `Group_ID`.
# The axis ranges are fixed so frames do not rescale while animating
# (presumably the scene extent in map coordinates -- TODO confirm).
# NOTE: `plt` holds a plotly Figure here, not matplotlib.pyplot.
plt = px.line(df, x="X", y="Y", animation_frame="Group_ID", animation_group="type",
              color="type", hover_name="type",
              range_x=[0, 15000], range_y=[5000, 10000])

# rgba alpha must lie in [0, 1]; the previous value of 100 was clamped to 1,
# i.e. opaque white, contradicting the intended transparency. Alpha 0 makes
# both backgrounds genuinely transparent.
plt.update_layout({
    'autosize': True,
    'plot_bgcolor': 'rgba(255, 255, 255, 0)',   # transparent plotting area
    'paper_bgcolor': 'rgba(255, 255, 255, 0)',  # transparent figure background
    'xaxis': {'showgrid': False},               # no x grid lines
    'yaxis': {'showgrid': False},               # no y grid lines
})
# Hide both axes entirely (title, tick labels, axis line) so only the
# trajectories and the overlaid shapes are drawn.
plt.update_xaxes(title_text='', showticklabels=False, visible=False)
plt.update_yaxes(title_text='', showticklabels=False, visible=False)


# Function to add a line to the Plotly figure
def add_sidewalk(fig, x0, y0, x1, y1, showlegend):
    """Draw one sidewalk segment on ``fig`` as a dashed black line shape.

    Every segment uses the name and legend group ``'sidewalks'`` so they are
    grouped together in the legend; enable ``showlegend`` for just one segment
    to get a single legend entry.

    Parameters
    ----------
    fig : plotly figure
        Anything exposing ``add_shape``; it is modified in place.
    x0, y0, x1, y1 : float
        Segment end points in data coordinates.
    showlegend : bool
        Whether this segment contributes a legend entry.

    Returns
    -------
    The same ``fig`` object, so calls can be chained.
    """
    dashed_black = dict(color='black', width=2, dash='dash')
    shape_kwargs = dict(
        type='line',
        x0=x0, y0=y0, x1=x1, y1=y1,
        line=dashed_black,
        name='sidewalks',
        legendgroup='sidewalks',
        showlegend=showlegend,
    )
    fig.add_shape(**shape_kwargs)
    return fig


# Add every sidewalk segment to the figure. Each value is unpacked into
# add_sidewalk's (x0, y0, x1, y1) arguments; only the first segment gets a
# legend entry so the group shows up as a single "sidewalks" legend item.
for i, segment in enumerate(sidewalks.values()):
    plt = add_sidewalk(plt, *segment, showlegend=(i == 0))


def draw_rectangle(fig, center, lx, ly, label,
                   label_split_y=8000, label_offset_below=-200, label_offset_above=150):
    """Outline an axis-aligned rectangle around ``center`` and attach a text label.

    Parameters
    ----------
    fig : plotly figure
        Anything exposing ``add_shape`` and ``add_annotation``; it is modified
        in place.
    center : (float, float)
        ``(cx, cy)`` centre of the rectangle in data coordinates.
    lx, ly : float
        Full width and height of the rectangle.
    label : object
        Rendered with ``str()`` both as the shape name and as the text of a
        yellow, black-bordered annotation.
    label_split_y : float, optional
        Rectangles with ``cy <= label_split_y`` (default 8000) get their label
        shifted by ``label_offset_below``; all others by ``label_offset_above``.
    label_offset_below : float, optional
        Vertical label offset from ``cy`` for the lower rectangles (default -200).
    label_offset_above : float, optional
        Vertical label offset from ``cy`` for the upper rectangles (default 150).

    Returns
    -------
    The same ``fig`` object, so calls can be chained.
    """
    cx, cy = center
    # Opposite corners of the rectangle, derived from centre and size.
    x0 = cx - lx / 2
    y0 = cy - ly / 2
    x1 = x0 + lx
    y1 = y0 + ly

    # Unfilled rectangle with a solid black border.
    fig.add_shape(type="rect",
                  x0=x0, y0=y0, x1=x1, y1=y1,
                  line=dict(color="black", width=2),
                  fillcolor="rgba(0,0,0,0)",
                  name=str(label))

    # Place the label below the centre for rectangles at or under the split
    # line, above it otherwise (presumably to keep labels inside the plotted
    # y-range -- TODO confirm).
    dy = label_offset_below if cy <= label_split_y else label_offset_above

    fig.add_annotation(x=cx, y=cy + dy, text=str(label),
                       showarrow=False,
                       bgcolor='yellow',
                       bordercolor='black',
                       borderpad=4,
                       font=dict(color='black'))

    return fig

# Outline each station as a fixed-size labelled box, then render the figure.
# Box size is given in data units (the same coordinates as the trajectories).
STATION_BOX_WIDTH = 500
STATION_BOX_HEIGHT = 100

for station_name, station_center in stations.items():
    plt = draw_rectangle(plt, station_center, STATION_BOX_WIDTH, STATION_BOX_HEIGHT, station_name)

plt.show()