In [None]:
import pandas as pd
import plotly.graph_objects as go
import plotly.io as py
import numpy as np
import colorsys

In [None]:
def plot_trajectory(observation_number, data_dir='./data', fig_dir='./fig'):
    """Save a static 3D plot of every identity's trajectory as a PDF.

    Reads ``{data_dir}/Ob{observation_number}.txt`` -- a headerless,
    comma-separated file whose columns are identity, posx, posy, posz, time,
    vx, vy, vz, ax, ay, az -- draws one line per identity through its
    (posx, posy, posz) points in time order, and writes the figure to
    ``{fig_dir}/traj_Ob{observation_number}.pdf``.

    Parameters
    ----------
    observation_number : int
        Suffix of the observation file to load (``Ob<n>.txt``).
    data_dir : str, default './data'
        Directory containing the observation files.
    fig_dir : str, default './fig'
        Directory the PDF is written to.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure that was written, for further inspection or display.
    """
    columns = ['identity', 'posx', 'posy', 'posz', 'time',
               'vx', 'vy', 'vz', 'ax', 'ay', 'az']
    df = pd.read_csv(f'{data_dir}/Ob{observation_number}.txt',
                     delimiter=',', header=None, names=columns)

    fig_traj = go.Figure()

    # One pass over the data: groupby(sort=False) yields identities in
    # first-appearance order (the same trace order as iterating unique())
    # without re-scanning the whole frame once per identity.
    for identity, group in df.groupby('identity', sort=False):
        # Connect points chronologically; file order is not guaranteed to be.
        track = group.sort_values('time', kind='mergesort')
        fig_traj.add_trace(go.Scatter3d(
            x=track['posx'],
            y=track['posy'],
            z=track['posz'],
            mode='lines',
            name=f'Identity {identity}',  # labels the trace; legend stays hidden
        ))

    # Axis labels and title; the legend is disabled (one entry per identity
    # would be unreadable).
    fig_traj.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        title=f'Trajectories of All Identities for Observation {observation_number}',
        showlegend=False
    )

    fig_traj.write_image(f'{fig_dir}/traj_Ob{observation_number}.pdf')
    return fig_traj

def plot_animation(observation_number, n_identities=100, time_step=10, seed=None,
                   data_dir='./data', fig_dir='./fig'):
    """Write an animated 3D plot of a random subset of trajectories to HTML.

    Reads ``{data_dir}/Ob{observation_number}.txt`` -- a headerless,
    comma-separated file whose columns are identity, posx, posy, posz, time,
    vx, vy, vz, ax, ay, az -- randomly picks up to ``n_identities``
    identities, and builds a Plotly animation with one frame per
    ``time_step``-th distinct timepoint (in sorted order). Each frame shows,
    for every selected identity, the path travelled so far (faded line) and
    its position at the frame's time (marker). The result is written to
    ``{fig_dir}/anim_Ob{observation_number}.html``.

    Parameters
    ----------
    observation_number : int
        Suffix of the observation file to load (``Ob<n>.txt``).
    n_identities : int, default 100
        Maximum number of identities to animate. If the file contains fewer,
        all of them are used.
    time_step : int, default 10
        Keep every ``time_step``-th distinct timepoint as a frame.
    seed : int or None, default None
        Seed for the identity selection. ``None`` draws from NumPy's global
        random state.
    data_dir : str, default './data'
        Directory containing the observation files.
    fig_dir : str, default './fig'
        Directory the HTML file is written to.

    Returns
    -------
    plotly.graph_objects.Figure
        The animated figure that was written.
    """
    columns = ['identity', 'posx', 'posy', 'posz', 'time',
               'vx', 'vy', 'vz', 'ax', 'ay', 'az']
    df = pd.read_csv(f'{data_dir}/Ob{observation_number}.txt',
                     delimiter=',', header=None, names=columns)

    # Frame times: every `time_step`-th distinct timepoint in chronological
    # order (unique() on its own returns first-appearance order).
    timepoints_downsampled = np.sort(df['time'].unique())[::time_step]

    # Random subset of identities, capped at the number available; sampling
    # more than exist without replacement would raise ValueError.
    identities = df['identity'].unique()
    rng = np.random if seed is None else np.random.default_rng(seed)
    selected_identities = rng.choice(
        identities, size=min(n_identities, len(identities)), replace=False)

    # One distinct hue per identity. endpoint=False because hue 0 and hue 1
    # are the same red, which would give the first and last identity the
    # same color.
    colors = ['rgb' + str(tuple(int(255 * c) for c in colorsys.hsv_to_rgb(h, 1, 1)))
              for h in np.linspace(0, 1, len(selected_identities), endpoint=False)]

    # Each selected identity's rows in time order, grouped once up front
    # instead of re-filtering the whole frame for every trace of every frame.
    selected_rows = df[df['identity'].isin(selected_identities)]
    tracks = {identity: group.sort_values('time', kind='mergesort')
              for identity, group in selected_rows.groupby('identity')}

    def frame_traces(time):
        # Per identity: its path up to `time` (line) and its position at
        # `time` (marker). Every call returns the same layout -- all trails,
        # then all heads -- because Plotly maps frame traces onto the
        # figure's traces by index, so frames must match the initial data.
        trails, heads = [], []
        for identity, color in zip(selected_identities, colors):
            track = tracks[identity]
            past = track[track['time'] <= time]
            current = track[track['time'] == time]
            trails.append(go.Scatter3d(
                x=past['posx'], y=past['posy'], z=past['posz'],
                mode='lines',
                line=dict(color=color, width=10),
                opacity=0.5,  # fade the trail so the current marker stands out
                name=f'Identity {identity}',
            ))
            heads.append(go.Scatter3d(
                x=current['posx'], y=current['posy'], z=current['posz'],
                mode='markers',
                marker=dict(color=color, size=8),  # larger than the trail width
                name=f'Identity {identity}',
            ))
        return trails + heads

    fig_anim = go.Figure(
        data=frame_traces(timepoints_downsampled[0]) if len(timepoints_downsampled) else [],
        layout=go.Layout(
            updatemenus=[dict(type="buttons",
                              buttons=[dict(label="Play",
                                            method="animate",
                                            args=[None])])],
            scene=dict(
                # Fixed ranges and a cubic aspect so the box does not
                # rescale from frame to frame.
                xaxis=dict(title='X', range=[df['posx'].min(), df['posx'].max()]),
                yaxis=dict(title='Y', range=[df['posy'].min(), df['posy'].max()]),
                zaxis=dict(title='Z', range=[df['posz'].min(), df['posz'].max()]),
                aspectmode='cube'
            ),
            title=f'Animated Trajectories of Different Identities for Observation {observation_number}',
            showlegend=False
        ),
        frames=[go.Frame(data=frame_traces(time), name=str(time))
                for time in timepoints_downsampled],
    )

    fig_anim.write_html(f'{fig_dir}/anim_Ob{observation_number}.html')
    return fig_anim

In [None]:
# # Generate the plot for observations 1 to 19
# for observation_number in range(1, 20):
#     plot_trajectory(observation_number)

In [None]:
# # Generate the animations for observations 1 to 19
# for observation_number in range(1, 20):
#     plot_animation(observation_number)