# In [5]:
import plotly.graph_objs as go
import numpy as np
from scipy.spatial.transform import Rotation as R
import imageio
import os

def plot_3d_poses(poses_list, poses_names, name = "default", fps = 2):
    """Animate one or more pose trajectories in 3D and save them as a GIF.

    Frame ``i`` shows, for every trajectory, the path travelled up to and
    including pose ``i`` (markers + lines) and an arrow at the current
    position pointing along the pose's local x axis, i.e. the pose rotation
    applied to ``[1, 0, 0]`` (the first column of ``pose[:3, :3]``).

    Args:
        poses_list: list of array-likes of shape (T, 4, 4) ("T x 4 x 4
            poses"). ``[t, :3, 3]`` is read as the translation and
            ``[t, :3, :3]`` as the rotation. All trajectories must share T.
        poses_names: list of legend labels, same length as ``poses_list``.
        name: output file stem; the GIF is written to ``f"{name}.gif"``.
        fps: frame rate forwarded to ``imageio.get_writer``.

    Raises:
        AssertionError: if ``poses_list`` and ``poses_names`` differ in length.
        ValueError: if ``poses_list`` is empty, a trajectory has no poses, or
            the trajectories are not all 3-D arrays with the same length T.
    """
    import tempfile  # local import: keeps this cell independent of the import header

    assert len(poses_list) == len(poses_names)
    if len(poses_list) == 0:
        raise ValueError("poses_list must contain at least one trajectory")

    poses = [np.asarray(p) for p in poses_list]
    num_paths = len(poses)
    T = poses[0].shape[0]
    if T == 0 or any(p.ndim != 3 or p.shape[0] != T for p in poses):
        raise ValueError(
            "every entry of poses_list must be a non-empty (T, 4, 4) array with the same T"
        )

    # Positions (translation column) and headings (rotation @ [1, 0, 0], i.e.
    # the first rotation column), both of shape (num_paths, T, 3).
    walks = np.stack([p[:, :3, 3] for p in poses]).astype(float)
    headings = np.stack([p[:, :3, 0] for p in poses]).astype(float)

    # First path is always red; the others get random colours. randint's upper
    # bound is exclusive, so 256 is needed for 255 to be reachable.
    def _random_rgba():
        r, g, b = (np.random.randint(0, 256) for _ in range(3))
        return f'rgba({r}, {g}, {b}, 0.8)'

    colors = ['rgba(255,0,0,0.8)'] + [_random_rgba() for _ in range(num_paths - 1)]

    # Shared, fixed axis ranges so the view does not jump between frames.
    all_points = walks.reshape(-1, 3)
    lo = all_points.min(axis=0)
    hi = all_points.max(axis=0)
    span = hi - lo
    data_diagonal = float(np.linalg.norm(span))
    # Axes with zero extent (a single pose, planar motion, ...) would give
    # plotly an empty range; pad them symmetrically, scaled to the data.
    pad = 0.5 * float(span.max()) if span.max() > 0 else 0.5
    lo = np.where(span > 0, lo, lo - pad)
    hi = np.where(span > 0, hi, hi + pad)
    x_range, y_range, z_range = ([float(a), float(b)] for a, b in zip(lo, hi))

    # Arrows are 3% of the data's bounding-box diagonal (padded box only when
    # every point coincides, so the arrow never collapses to zero length).
    if data_diagonal == 0:
        data_diagonal = float(np.linalg.norm(hi - lo))
    arrow_length = 0.03 * data_diagonal

    def _arrow(point, direction, color):
        """Return [shaft, head] traces for one arrow, or [] if direction is zero."""
        norm = np.linalg.norm(direction)
        if norm == 0:
            return []
        unit = direction / norm
        tip = point + unit * arrow_length
        shaft = go.Scatter3d(
            x=[point[0], tip[0]],
            y=[point[1], tip[1]],
            z=[point[2], tip[2]],
            mode='lines',
            line=dict(color=color, width=4),
            showlegend=False,
        )
        head = go.Cone(
            x=[tip[0]],
            y=[tip[1]],
            z=[tip[2]],
            u=[unit[0]],
            v=[unit[1]],
            w=[unit[2]],
            sizemode='absolute',
            sizeref=0.1,
            anchor='tip',
            colorscale=[[0, color], [1, color]],
            showscale=False,
            showlegend=False,
        )
        return [shaft, head]

    def _frame_traces(i):
        """Return the traces for frame i: each path so far plus its current arrow."""
        traces = []
        for w in range(num_paths):
            traces.append(go.Scatter3d(
                x=walks[w, :i + 1, 0],
                y=walks[w, :i + 1, 1],
                z=walks[w, :i + 1, 2],
                mode='markers+lines',
                marker=dict(size=5, color=colors[w]),
                line=dict(color=colors[w], width=2),
                name=poses_names[w],  # one legend entry per path, identical in every frame
                legendgroup=poses_names[w],
            ))
            traces.extend(_arrow(walks[w, i], headings[w, i], colors[w]))
        return traces

    layout = dict(
        scene=dict(
            xaxis=dict(range=x_range, autorange=False),
            yaxis=dict(range=y_range, autorange=False),
            zaxis=dict(range=z_range, autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
            camera=dict(
                eye=dict(x=1.25, y=1.25, z=1.25),
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
            ),
        ),
        margin=dict(l=0, r=0, t=0, b=0),  # Reduce white space around the plot
    )

    # Render into a private temporary directory: unlike a fixed "frames/"
    # folder it cannot collide with an existing directory, and it is removed
    # even if rendering or GIF encoding raises.
    with tempfile.TemporaryDirectory() as frame_dir:
        image_files = []
        for i in range(T):
            img_file = os.path.join(frame_dir, f'frame_{i:03d}.png')
            go.Figure(data=_frame_traces(i), layout=layout).write_image(img_file)
            image_files.append(img_file)

        # Encode only after every frame rendered, so a rendering failure does
        # not leave a truncated GIF behind.
        with imageio.get_writer(f'{name}.gif', mode='I', fps=fps) as writer:
            for img_file in image_files:
                writer.append_data(imageio.imread(img_file))

    print(f"GIF saved as '{name}.gif'")


# Output: GIF saved as 'random_walks.gif'






# In [None]:
import plotly.graph_objs as go
import numpy as np
from scipy.spatial.transform import Rotation as R
import imageio
import os

# Function to generate random walks
def random_walks_3D(steps, num_walks):
    walks = np.zeros((num_walks, steps, 3))
    for walk in range(num_walks):
        for i in range(1, steps):
            step = (np.random.rand(3) - 0.5) * 2
            walks[walk, i] = walks[walk, i-1] + step
    return walks

# Generate the random walk data
num_steps = 100  # points per walk; also the number of GIF frames rendered below
num_walks = 2  # Number of random walks
walks = random_walks_3D(num_steps, num_walks)

# One random, semi-transparent colour per walk. randint's upper bound is
# exclusive, so 256 is required for 255 to be a possible channel value.
colors = [
    f'rgba({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)}, 0.8)'
    for _ in range(num_walks)
]

# Fixed axis ranges covering every point of every walk, so the camera view
# does not change between frames (create_arrow also reads these to size arrows).
all_walks = walks.reshape(-1, 3)  # (num_walks * num_steps, 3)
x_range = [all_walks[:, 0].min(), all_walks[:, 0].max()]
y_range = [all_walks[:, 1].min(), all_walks[:, 1].max()]
z_range = [all_walks[:, 2].min(), all_walks[:, 2].max()]

# Random 3x3 rotation matrix helper
def random_rotation_matrix():
    """Return a random 3x3 rotation matrix drawn with scipy's ``Rotation.random``."""
    return R.random().as_matrix()

# Define a function to create an arrow at a given point and direction
def create_arrow(point, direction, color='red', length_scale=1, sizeref=0.5):
    """Build a 3D arrow (line shaft + cone head) starting at ``point``.

    The arrow length is ``length_scale`` times 3% of the diagonal of the box
    spanned by the module-level ``x_range``, ``y_range`` and ``z_range``, so
    arrows scale with the plotted data.

    Args:
        point: (x, y, z) start of the arrow.
        direction: 3-vector giving the arrow direction; it is normalised, so
            only its direction matters.
        color: colour used for both the shaft and the cone head.
        length_scale: multiplier on the default arrow length.
        sizeref: plotly ``Cone.sizeref`` for the head (``sizemode='absolute'``).

    Returns:
        (shaft, head): a ``go.Scatter3d`` line and a ``go.Cone`` trace.

    Raises:
        ValueError: if ``direction`` has zero length and so cannot be normalised
            (previously this produced NaN coordinates silently).
    """
    diagonal = np.linalg.norm([x_range[1] - x_range[0], y_range[1] - y_range[0], z_range[1] - z_range[0]])
    length = length_scale * diagonal * 0.03

    direction = np.asarray(direction, dtype=float)
    norm = np.linalg.norm(direction)
    if norm == 0:
        raise ValueError("create_arrow: direction must be a non-zero vector")
    unit = direction / norm

    start = np.asarray(point, dtype=float)
    tip = start + unit * length

    # Shaft: straight segment from the start point to the tip.
    shaft = go.Scatter3d(
        x=[start[0], tip[0]],
        y=[start[1], tip[1]],
        z=[start[2], tip[2]],
        mode='lines',
        line=dict(color=color, width=4)
    )

    # Head: single cone anchored with its tip at the end of the shaft.
    head = go.Cone(
        x=[tip[0]],
        y=[tip[1]],
        z=[tip[2]],
        u=[unit[0]],
        v=[unit[1]],
        w=[unit[2]],
        sizemode='absolute',
        sizeref=sizeref,
        anchor='tip',
        colorscale=[[0, color], [1, color]],
        showscale=False
    )

    return shaft, head

# Render every frame into a private temporary directory, then assemble the GIF.
# A fixed "frames/" folder could collide with an existing directory (os.rmdir
# would then delete or fail on it) and was left behind if rendering raised;
# TemporaryDirectory avoids both and is always cleaned up.
import tempfile

with tempfile.TemporaryDirectory() as frame_dir:
    image_files = []
    for i in range(num_steps):
        frame_data = []
        for w in range(num_walks):
            # Demo only: each arrow points in a freshly drawn random direction.
            rotation_matrix = random_rotation_matrix()
            arrow_direction = rotation_matrix @ np.array([1, 0, 0])
            shaft, head = create_arrow(walks[w, i], arrow_direction, color=colors[w])
            # Arrows are decoration; keep them out of the legend so it lists
            # only the walks instead of anonymous "trace N" entries.
            shaft.showlegend = False
            head.showlegend = False

            frame_data.extend([
                go.Scatter3d(
                    x=walks[w, :i+1, 0],
                    y=walks[w, :i+1, 1],
                    z=walks[w, :i+1, 2],
                    mode='markers+lines',
                    marker=dict(size=5, color=colors[w]),
                    line=dict(color=colors[w], width=2),
                    name=f'walk {w}',
                ),
                shaft,
                head
            ])

        # Define the figure for the current frame; fixed ranges keep the view still.
        fig = go.Figure(
            data=frame_data,
            layout=go.Layout(
                scene=dict(
                    xaxis=dict(range=x_range, autorange=False),
                    yaxis=dict(range=y_range, autorange=False),
                    zaxis=dict(range=z_range, autorange=False),
                    aspectratio=dict(x=1, y=1, z=1)
                ),
                margin=dict(l=0, r=0, t=0, b=0)  # Reduce white space around the plot
            )
        )

        # Save the figure as an image file
        img_file = os.path.join(frame_dir, f'frame_{i:03d}.png')
        fig.write_image(img_file)
        image_files.append(img_file)

    # Create a GIF using the saved image files (the frames are deleted together
    # with the temporary directory when this block exits).
    with imageio.get_writer('random_walks.gif', mode='I', duration=0.1) as writer:
        for filename in image_files:
            writer.append_data(imageio.imread(filename))

print("GIF saved as 'random_walks.gif'")
