In [1]:
import matplotlib.pyplot as plt

import numpy as np
import os
from ctd.comparison.analysis.tt.tt import Analysis_TT 
from ctd.comparison.analysis.tt.tasks.tt_MultiTask import Analysis_TT_MultiTask
from ctd.comparison.analysis.tt.tasks.tt_RandomTarget import Analysis_TT_RandomTarget
from ctd.comparison.analysis.dd.dd import Analysis_DD
# Import pca
import dotenv
from ctd.comparison.comparison import Comparison
import torch
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from matplotlib.animation import FuncAnimation, FFMpegWriter
from mpl_toolkits.mplot3d import Axes3D  # required for 3D plotting
from matplotlib import gridspec
from sklearn.decomposition import PCA
# Populate os.environ (e.g. HOME_DIR, read in the next cell) from the nearest .env file.
# The boolean result is the cell's displayed value.
dotenv.load_dotenv(dotenv_path=dotenv.find_dotenv())



True

In [2]:

# Benchmark root, read from the environment (populated from .env) instead of a
# hard-coded local path.
HOME_DIR = os.environ['HOME_DIR']
print(HOME_DIR)

# Task-trained (TT) model directories. os.path.join works whether or not HOME_DIR
# ends in "/" (plain concatenation silently glued path components together when it
# did not); every path keeps its trailing "/" because the loaders receive it as-is.
pathTT_3BFF = os.path.join(HOME_DIR, 'content/trained_models/task-trained/tt_3bff/')
pathTT_MT = os.path.join(HOME_DIR, 'content/trained_models/task-trained/tt_MultiTask/')
pathTT_RT = os.path.join(HOME_DIR, 'content/trained_models/task-trained/tt_RandomTarget/')

an_TT_3BFF = Analysis_TT(run_name = "TT_3BFF", filepath = pathTT_3BFF)
an_TT_MT = Analysis_TT_MultiTask(run_name = "TT_MT", filepath = pathTT_MT)
an_TT_RT = Analysis_TT_RandomTarget(run_name = "TT_RT", filepath = pathTT_RT)


def _sorted_subdirs(parent):
    """Return the full paths of the immediate subdirectories of `parent`, sorted.

    os.scandir yields entries in arbitrary, filesystem-dependent order. Later cells
    keep the *last* LDS analysis loaded from these lists, so sorting makes that
    choice reproducible. The context manager closes the directory handle.
    """
    with os.scandir(parent) as entries:
        return sorted(entry.path for entry in entries if entry.is_dir())


# LDS sweep run folders, one list per task.
path_LDS_Sweep_3BFF = pathTT_3BFF + "20250130_NBFF_LDS_Viz/"
subfolders_LDS_3BFF = _sorted_subdirs(path_LDS_Sweep_3BFF)

path_LDS_Sweep_MT = pathTT_MT + "20250131_MultiTask_LDS_Viz/"
subfolders_LDS_MT = _sorted_subdirs(path_LDS_Sweep_MT)

path_LDS_Sweep_RT = pathTT_RT + "20250130_RandomTarget_LDS_Viz/"
subfolders_LDS_RT = _sorted_subdirs(path_LDS_Sweep_RT)

/home/csverst/Github/CtDBenchmark/


In [3]:
# 3BFF comparison: the task-trained model is the reference ("TT" group) and every
# LDS sweep run is added to the "LDS" group.
# NOTE: `analysis_LDS` is left bound to the last run loaded; the next cells use it.
comparison_3BFF = Comparison(comparison_tag="3BFF")
comparison_3BFF.load_analysis(
    an_TT_3BFF,
    reference_analysis=True,
    group="TT",
)

for subfolder in subfolders_LDS_3BFF:
    subfolder += "/"
    analysis_LDS = Analysis_DD.create(
        run_name="LDS",
        filepath=subfolder,
        model_type="SAE",
    )
    comparison_3BFF.load_analysis(analysis_LDS, group="LDS")

comparison_3BFF.regroup()

In [4]:
# One-to-one comparison: the TT reference against only the last LDS run loaded above.
comparison_NBFF_single = Comparison(comparison_tag="NBFF_single")
comparison_NBFF_single.load_analysis(
    an_TT_3BFF,
    reference_analysis=True,
    group="TT",
)
comparison_NBFF_single.load_analysis(
    analysis_LDS,
    group="LDS",
)

In [11]:

# Validation-split spikes (last LDS run) and TT latents, converted to numpy arrays.
nbff_spikes = analysis_LDS.get_spiking(phase="val").detach().cpu().numpy()
latents = an_TT_3BFF.get_latents(phase="val").detach().cpu().numpy()

In [19]:


# -------------------------------------------------------------------
# Data Setup
# -------------------------------------------------------------------
# Inputs come from the previous cell (numpy arrays):
#   nbff_spikes = analysis_LDS.get_spiking(phase="val")   # shape: (B, T, N)
#   latents     = an_TT_3BFF.get_latents(phase="val")      # shape: (B, T, D)
#
# To try the animation without trained models, simulate inputs instead:
#
# B, T, N, D = 10, 100, 50, 10
# nbff_spikes = np.random.poisson(0.5, size=(B, T, N))
# latents     = np.random.randn(B, T, D)

B, T, N = nbff_spikes.shape
_, _, D = latents.shape

# Number of trials to include in the combined video (must be <= B).
Y = 6  # Change as needed
# Animate every `frame_skip`-th time step to speed up the video: 1 (no skip), 2, 3 or 4.
frame_skip = 2

# Enforce the documented constraints up front instead of failing later with an
# IndexError (Y > B) or rendering a video that skips more than allowed.
assert 1 <= Y <= B, f"Y={Y} must be between 1 and the number of trials B={B}"
assert frame_skip in (1, 2, 3, 4), f"frame_skip={frame_skip} must be 1, 2, 3 or 4"

# -------------------------------------------------------------------
# Compute the top-3 Principal Components (PCs)
# -------------------------------------------------------------------
# Fit on every (trial, time) latent so all trials share one projection.
latents_all = latents.reshape(-1, D)
pca = PCA(n_components=3)
pca.fit(latents_all)
all_pc = pca.transform(latents_all)  # shape: (B*T, 3)

# Axis limits over all latent data, so the 3D view never rescales between trials.
x_min, x_max = np.min(all_pc[:, 0]), np.max(all_pc[:, 0])
y_min, y_max = np.min(all_pc[:, 1]), np.max(all_pc[:, 1])
z_min, z_max = np.min(all_pc[:, 2]), np.max(all_pc[:, 2])

# PC coordinates of the first Y trials, shape (Y, T, 3). reshape(-1, D) above
# flattened (B, T) in C order, so the rows of all_pc are grouped trial by trial
# and can be split back without re-running pca.transform for each trial.
pc_coords_all = all_pc.reshape(B, T, 3)[:Y].copy()

# -------------------------------------------------------------------
# Setup Colors for Each Trial
# -------------------------------------------------------------------
# Use a colormap (here, tab10) to assign distinct colors to each trial.
colors = plt.cm.tab10(np.linspace(0, 1, Y))

# -------------------------------------------------------------------
# Create the Figure and Subplots with GridSpec
# -------------------------------------------------------------------
fig = plt.figure(figsize=(12, 6))
# Use GridSpec to set the imshow panel to 75% of the width of the 3D panel.
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 0.75])

# Left subplot: 3D latent trajectory.
ax_pc = fig.add_subplot(gs[0], projection='3d')
ax_pc.set_title("Trial 1: Latent Trajectory (Top 3 PCs)")
ax_pc.set_xlabel("PC1")
ax_pc.set_ylabel("PC2")
ax_pc.set_zlabel("PC3")
ax_pc.set_xlim(x_min, x_max)
ax_pc.set_ylim(y_min, y_max)
ax_pc.set_zlim(z_min, z_max)
ax_pc.set_xticklabels([])
ax_pc.set_yticklabels([])
ax_pc.set_zticklabels([])
# Initialize the active trail line and current point using the first trial's color.
trail_line, = ax_pc.plot([], [], [], color=colors[0], lw=2)
current_point = ax_pc.scatter([], [], [], color=colors[0], s=50)

# Right subplot: Spiking activity.
ax_spike = fig.add_subplot(gs[1])
ax_spike.set_title("Trial 1: Spiking Activity")
# Transpose so that time is on the x-axis and neurons on the y-axis.
im = ax_spike.imshow(nbff_spikes[0].T, aspect='auto',
                     interpolation='nearest', cmap='viridis')
ax_spike.set_xlabel("Time")
ax_spike.set_ylabel("Neuron")
plt.colorbar(im, ax=ax_spike, orientation='vertical')
# Add a vertical red line to indicate the current time point.
vline = ax_spike.axvline(x=0, color='red', lw=2)

# -------------------------------------------------------------------
# Animation length
# -------------------------------------------------------------------
# Every trial has T steps, so the video covers Y*T "effective" frames, of which
# only every frame_skip-th is rendered (ceil keeps a frame for a final partial stride).
total_effective_frames = Y * T
n_frames = int(np.ceil(total_effective_frames / frame_skip))

# -------------------------------------------------------------------
# Define the Update Function for the Animation
# -------------------------------------------------------------------
def update(frame):
    """Render animation frame `frame` of the combined multi-trial video.

    Animation frames map to effective time steps ``frame * frame_skip``, laid out
    trial-major over the first Y trials (T steps each). On the first frame drawn
    for a trial, the previous trial's trajectory is frozen as a faint line, and the
    spike raster, titles, trail colour and marker are reset for the new trial.

    Parameters
    ----------
    frame : int
        Animation frame index in ``range(n_frames)``.

    Returns
    -------
    tuple
        The artists touched this frame (blitting is disabled, so informational).
    """
    global trail_line, current_point
    eff_frame = min(frame * frame_skip, total_effective_frames - 1)
    trial_idx = eff_frame // T      # Current trial index (0-indexed)
    time_idx = eff_frame % T        # Current time point within the trial

    # Frames are sampled every `frame_skip` steps, so the first frame drawn for a
    # trial has time_idx in [0, frame_skip). Testing `time_idx == 0` misses trial
    # starts whenever T is not a multiple of frame_skip (e.g. odd T with
    # frame_skip=2), leaving stale spikes/colours and unfrozen trajectories.
    if time_idx < frame_skip:
        print(f"Trial: {trial_idx}")
        # Permanently draw the previous trial's trajectory, semi-transparent.
        if trial_idx > 0:
            prev_pc = pc_coords_all[trial_idx - 1]
            ax_pc.plot(prev_pc[:, 0], prev_pc[:, 1], prev_pc[:, 2],
                       color=colors[trial_idx - 1], alpha=0.3, lw=2)
        trail_line.set_color(colors[trial_idx])

        # Spiking activity and titles for the new trial.
        im.set_data(nbff_spikes[trial_idx].T)
        ax_spike.set_title(f"Trial {trial_idx+1}: Spiking Activity")
        ax_pc.set_title(f"Trial {trial_idx+1}: Latent Trajectory (Top 3 PCs)")

        # Clear the active trail; replace the marker so it takes the new colour.
        trail_line.set_data([], [])
        trail_line.set_3d_properties([])
        current_point.remove()
        current_point = ax_pc.scatter([], [], [], color=colors[trial_idx], s=50)

    # Active trail: every point of the current trial up to time_idx.
    current_pc = pc_coords_all[trial_idx]
    trail_line.set_data(current_pc[:time_idx+1, 0], current_pc[:time_idx+1, 1])
    trail_line.set_3d_properties(current_pc[:time_idx+1, 2])

    # Move the marker to the current coordinate (private Matplotlib attribute).
    cp = current_pc[time_idx]
    current_point._offsets3d = ([cp[0]], [cp[1]], [cp[2]])

    # Line2D.set_xdata expects a sequence; a bare scalar is deprecated since
    # Matplotlib 3.7 and emits a warning.
    vline.set_xdata([time_idx, time_idx])

    # One full 360° rotation over the whole animation, at a fixed 30° elevation.
    ax_pc.view_init(elev=30, azim=360 * frame / n_frames)

    return trail_line, current_point, vline, im

# -------------------------------------------------------------------
# Create and Save the Animation
# -------------------------------------------------------------------
# 60 frames per second: the writer's fps sets the saved file's frame rate, and the
# matching interval (1000/60 ≈ 16.67 ms) is used for interactive playback.
writer = FFMpegWriter(fps=60)
interval_ms = 1000 / 60
ani = FuncAnimation(fig, update, frames=n_frames, interval=interval_ms, blit=False)
ani.save("combined_trials_animation.mp4", writer=writer)

# Release the figure now that the video has been written.
plt.close(fig)
print("Combined video saved as 'combined_trials_animation.mp4'.")


  vline.set_xdata(time_idx)


Combined video saved as 'combined_trials_animation.mp4'.


In [46]:
# MultiTask comparison: the TT reference plus every LDS run from the MultiTask sweep.
# NOTE: `analysis_LDS` stays bound to the last run loaded; the next cell uses it.
comparison_MT = Comparison(comparison_tag="MultiTask")
comparison_MT.load_analysis(
    an_TT_MT,
    reference_analysis=True,
    group="TT",
)

for subfolder in subfolders_LDS_MT:
    subfolder += "/"
    analysis_LDS = Analysis_DD.create(
        run_name="LDS",
        filepath=subfolder,
        model_type="SAE",
    )
    comparison_MT.load_analysis(analysis_LDS, group="LDS")

comparison_MT.regroup()

In [47]:
# All-split data for the MultiTask model. `mt_task_flag` selects the MemoryPro trials;
# `phase_dict` holds the per-trial phase info that is indexed by name further down.
mt_spikes = analysis_LDS.get_spiking(phase="all")
mt_latents = an_TT_MT.get_latents(phase="all")
(mt_task_flag, phase_dict) = an_TT_MT.get_task_flag("MemoryPro", phase="all")


In [48]:
# MemoryPro-only slices of the spikes and latents, as numpy arrays.
mt_memPro_spikes = mt_spikes[mt_task_flag, :, :].detach().cpu().numpy()
mt_memPro_latents = mt_latents[mt_task_flag, :, :].detach().cpu().numpy()

# Decode the model outputs from the latents. This is inference only, so disable
# autograd rather than building (and holding) a graph over every trial and step.
readout = an_TT_MT.model.readout
with torch.no_grad():
    output1 = readout(mt_latents[mt_task_flag, :, :]).detach().cpu().numpy()
print(output1.shape)

(500, 320, 3)


In [49]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation, FFMpegWriter
from mpl_toolkits.mplot3d import Axes3D  # required for 3D plotting
from matplotlib import gridspec
from sklearn.decomposition import PCA

# =============================================================================
# Preliminaries: Extract Data and Readout
# =============================================================================
trial_count = 30  # how many trials to include in the video (clamped below to what exists)
frame_skip = 2    # allowed to skip frames (choose 1,2,3, or 4)
assert frame_skip in (1, 2, 3, 4), f"frame_skip={frame_skip} must be 1, 2, 3 or 4"

# Re-fetch the MemoryPro mask and per-trial phase info at the top of the cell.
# `phase_dict` is replaced by the selected subset further down, so without this a
# second run of the cell would pair that subset with the full output arrays.
mt_task_flag, phase_dict = an_TT_MT.get_task_flag('MemoryPro', phase="all")

# Convert tensors to numpy arrays (MemoryPro trials only).
mt_memPro_spikes = mt_spikes[mt_task_flag, :, :].detach().cpu().numpy()
mt_memPro_latents = mt_latents[mt_task_flag, :, :].detach().cpu().numpy()
readout = an_TT_MT.model.readout  # the readout module
# Inference only: skip building an autograd graph over every trial and time step.
with torch.no_grad():
    output1 = readout(mt_latents[mt_task_flag, :, :]).detach().cpu().numpy()
print("Readout shape:", output1.shape)  # e.g., (B, T, R) where R>=1

# =============================================================================
# 1. Data Setup – Determine Trial Lengths
# =============================================================================
# Effective trial length = end of the "response" phase + 1 (boundaries are
# treated as inclusive).
trial_lengths = [d['response'][1] + 1 for d in phase_dict]

# Cumulative lengths map global frame indices to (trial, time).
cumulative_lengths = np.cumsum(trial_lengths)
total_effective_frames_all = cumulative_lengths[-1]

# =============================================================================
# 1a. Reorder Trials by Final Output Angle to Span 0 to 2π
# =============================================================================
num_total_trials = len(phase_dict)
all_angles = []
for i in range(num_total_trials):
    # Direction of the readout vector (channels 1 and 2) at the trial's last frame.
    final_vector = output1[i, trial_lengths[i] - 1, 1:3]
    angle = np.arctan2(final_vector[1], final_vector[0])
    if angle < 0:
        angle += 2 * np.pi
    all_angles.append(angle)
all_angles = np.array(all_angles)

sorted_indices = np.argsort(all_angles)
if num_total_trials >= trial_count:
    # Pick trial_count trials equally spaced along the angle-sorted order.
    chosen_indices = sorted_indices[np.linspace(0, num_total_trials - 1, trial_count, dtype=int)]
else:
    # Fewer trials than requested: use all of them, still ordered by angle.
    chosen_indices = sorted_indices
# Later code indexes the first `trial_count` selected trials, so never ask for
# more than were actually selected.
trial_count = len(chosen_indices)

# Reorder the data arrays and phase_dict according to the chosen indices.
mt_memPro_spikes = mt_memPro_spikes[chosen_indices]
mt_memPro_latents = mt_memPro_latents[chosen_indices]
output1 = output1[chosen_indices]
phase_dict = [phase_dict[i] for i in chosen_indices]
trial_lengths = [trial_lengths[i] for i in chosen_indices]
cumulative_lengths = np.cumsum(trial_lengths)
print("Chosen trial indices (ordered by final angle):", chosen_indices)
print("Final angles (radians):", all_angles[chosen_indices])

# =============================================================================
# 2. Compute PCA – Fit on "mem1" Data Only
# =============================================================================
# Collect the "mem1" period of every selected trial (boundaries inclusive).
all_mem1_latents_list = []
for i, d in enumerate(phase_dict):
    mem1_start, mem1_end = d['mem1'][0], d['mem1'][1]
    all_mem1_latents_list.append(mt_memPro_latents[i, mem1_start:mem1_end + 1, :])
all_mem1_latents_concat = np.concatenate(all_mem1_latents_list, axis=0)

# Fit PCA using only the mem1 period data.
pca = PCA(n_components=3)
pca.fit(all_mem1_latents_concat)

# Project each full trial (up to its effective length) onto the mem1 PCs.
trial_pc_list = []
for i, length in enumerate(trial_lengths):
    trial_latents = mt_memPro_latents[i, :length, :]
    trial_pc = pca.transform(trial_latents)  # shape: (length, 3)
    trial_pc_list.append(trial_pc)

# =============================================================================
# 3. Helper Function for Current Phase (for text display)
# =============================================================================
def get_current_phase(trial_idx, time_idx):
    """Return the name of the phase of trial `trial_idx` that contains `time_idx`.

    Each entry of ``phase_dict[trial_idx]`` maps a phase name to a (start, end)
    pair; both bounds are inclusive and the first matching phase (in dict order)
    wins. Returns "unknown" when no phase covers `time_idx`.
    """
    phases = phase_dict[trial_idx].items()
    return next(
        (name for name, (start, end) in phases if start <= time_idx <= end),
        "unknown",
    )

# =============================================================================
# 4. Animation Setup
# =============================================================================
# Animate the first `trial_count` selected trials, clamped to how many exist so a
# smaller dataset does not raise IndexError below.
n_shown_trials = min(trial_count, len(trial_lengths))
assert n_shown_trials > 0, "No MemoryPro trials available to animate."
total_effective_frames = cumulative_lengths[n_shown_trials - 1]
n_frames = int(np.ceil(total_effective_frames / frame_skip))

# Create a figure with two subplots:
# Left: 3D latent trajectory; Right: Spiking activity.
fig = plt.figure(figsize=(12, 6))
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 0.6])

# --- 3D Latent Trajectory Subplot ---
ax_pc = fig.add_subplot(gs[0], projection='3d')
ax_pc.set_title("Trial 1: Latent Trajectory")
ax_pc.set_xlabel("PC1")
ax_pc.set_ylabel("PC2")
ax_pc.set_zlabel("Readout (Z)")
# Axis limits over the animated trials: PC1/PC2 of the latents, readout channel 1 for z.
all_pc_full = [trial_pc_list[i][:, :2] for i in range(n_shown_trials)]
all_readout = [output1[i, :trial_lengths[i], 1:2] for i in range(n_shown_trials)]
all_pc_full_concat = np.concatenate(all_pc_full, axis=0)
all_readout_concat = np.concatenate(all_readout, axis=0)
ax_pc.set_xlim(np.min(all_pc_full_concat[:, 0]), np.max(all_pc_full_concat[:, 0]))
ax_pc.set_ylim(np.min(all_pc_full_concat[:, 1]), np.max(all_pc_full_concat[:, 1]))
ax_pc.set_zlim(np.min(all_readout_concat), np.max(all_readout_concat))

# Initialize an active trail line and a scatter for the current point.
trail_line, = ax_pc.plot([], [], [], lw=2)
current_point = ax_pc.scatter([], [], [], s=50)

# --- Spiking Activity Subplot ---
ax_spike = fig.add_subplot(gs[1])
ax_spike.set_title("Trial 1: Spiking Activity")
# Show spiking data for trial 0 (truncated to its effective length).
im = ax_spike.imshow(mt_memPro_spikes[0, :trial_lengths[0], :].T, aspect='auto',
                     interpolation='nearest', cmap='viridis')
ax_spike.set_xlabel("Time")
ax_spike.set_ylabel("Neuron")
num_neurons = mt_memPro_spikes.shape[2]
# Set extent so that the horizontal axis spans 0 to trial_lengths[0].
im.set_extent((0, trial_lengths[0], 0, num_neurons))
plt.colorbar(im, ax=ax_spike, orientation='vertical')
vline = ax_spike.axvline(x=0, color='red', lw=2)

# --- Add phase text in the top right of the 3D plot ---
phase_text = ax_pc.text2D(0.95, 0.95, "", transform=ax_pc.transAxes,
                            ha="right", va="top", fontsize=12, color='black')

# --- Global variables for current trial and phase segmentation ---
global_current_trial = -1  # to detect when a new trial starts
last_phase = None          # last phase seen in the current trial
phase_segment_start = 0    # time index (within the trial) where the current phase segment began

# --- Function to compute trial color based on final output angle ---
def get_trial_color(trial_idx):
    """HSV colour encoding the direction of the trial's final readout vector.

    Readout channels 1 and 2 at the trial's last frame (``trial_lengths[trial_idx] - 1``)
    are treated as (x, y). The angle of that vector, shifted to be non-negative, is
    mapped onto the cyclic HSV colormap.
    """
    last_frame = trial_lengths[trial_idx] - 1
    x = output1[trial_idx, last_frame, 1]
    y = output1[trial_idx, last_frame, 2]
    theta = np.arctan2(y, x)
    theta = theta + 2 * np.pi if theta < 0 else theta
    return plt.cm.hsv(theta / (2 * np.pi))

# =============================================================================
# 5. Update Function for the Animation
# =============================================================================
def update(frame):
    """Render animation frame `frame` of the MultiTask (MemoryPro) video.

    Animation frames map to effective steps ``frame * frame_skip`` over the
    concatenated, variable-length trials (see ``cumulative_lengths``). When a frame
    lands in a new trial, the previous trial's trajectory is frozen as a faint line
    coloured by its final output angle. The spike raster, extent, titles, trail and
    marker are then reset for the new trial.

    Parameters
    ----------
    frame : int
        Animation frame index in ``range(n_frames)``.

    Returns
    -------
    tuple
        The artists touched this frame (blitting is disabled, so informational).
    """
    global global_current_trial, last_phase, phase_segment_start, trail_line, current_point

    # Print progress every 10 frames.
    if frame % 10 == 0:
        remaining = n_frames - frame
        print(f"Processing frame {frame}/{n_frames} (remaining: {remaining})")

    # Effective global step (with frame skipping) -> (trial, time within trial).
    # Trials differ in length, so locate the trial via the cumulative lengths.
    eff_frame = min(frame * frame_skip, total_effective_frames - 1)
    trial_idx = np.searchsorted(cumulative_lengths, eff_frame, side='right')
    if trial_idx > 0:
        time_idx = eff_frame - cumulative_lengths[trial_idx - 1]
    else:
        time_idx = eff_frame

    # A new trial is detected by index comparison, which stays correct even when
    # frame skipping jumps over the trial's first step.
    if trial_idx != global_current_trial:
        # Freeze the previous trial's trace before switching.
        if global_current_trial != -1:
            prev_trial = global_current_trial
            prev_len = trial_lengths[prev_trial]
            # phase_segment_start is only ever reset to 0, so this freezes the
            # whole previous trial.
            if phase_segment_start < prev_len:
                segment = trial_pc_list[prev_trial][phase_segment_start:prev_len, :]
                if segment.shape[0] > 0:
                    zs = output1[prev_trial, phase_segment_start:prev_len, 1]
                    ax_pc.plot(segment[:, 0], segment[:, 1], zs,
                               color=get_trial_color(prev_trial),
                               alpha=0.3, lw=2)
        print(f"Switching to trial {trial_idx+1}. Processed {eff_frame} frames so far.")
        global_current_trial = trial_idx
        last_phase = None
        phase_segment_start = 0

        # Spiking plot for the new trial; its own length sets the x extent.
        im.set_data(mt_memPro_spikes[trial_idx, :trial_lengths[trial_idx], :].T)
        im.set_extent((0, trial_lengths[trial_idx], 0, num_neurons))
        ax_spike.set_title(f"Trial {trial_idx+1}: Spiking Activity")
        ax_pc.set_title(f"Trial {trial_idx+1}: Latent Trajectory")

        # Clear the active trace and marker for the new trial.
        trail_line.set_data([], [])
        trail_line.set_3d_properties([])
        current_point.remove()
        current_point = ax_pc.scatter([], [], [], s=50)

    # Active trace: PC1/PC2 of the latents, readout channel 1 on the z axis.
    active_segment = trial_pc_list[trial_idx][:time_idx+1, :]
    if active_segment.shape[0] > 0:
        xs = active_segment[:, 0]
        ys = active_segment[:, 1]
        zs = output1[trial_idx, :time_idx+1, 1]
        trail_line.set_data(xs, ys)
        trail_line.set_3d_properties(zs)
        # Move the marker to the newest point (private Matplotlib attribute).
        current_point._offsets3d = ([xs[-1]], [ys[-1]], [zs[-1]])

    # Trace colour encodes the trial's final output angle.
    trace_color = get_trial_color(trial_idx)
    trail_line.set_color(trace_color)
    current_point.set_color(trace_color)

    # Phase label (text display only).
    phase_text.set_text(f"Phase: {get_current_phase(trial_idx, time_idx)}")

    # Line2D.set_xdata expects a sequence; a bare scalar is deprecated since
    # Matplotlib 3.7 and emits a warning.
    vline.set_xdata([time_idx, time_idx])

    # Half a turn (180°) over the whole animation, at a fixed 30° elevation.
    ax_pc.view_init(elev=30, azim=180 * frame / n_frames)

    return trail_line, current_point, vline, im, phase_text

# =============================================================================
# 6. Create and Save the Animation
# =============================================================================
# 30 frames per second: ~33.3 ms between frames for interactive playback, and the
# writer's fps sets the saved file's frame rate.
interval_ms = 1000 / 30
ani = FuncAnimation(fig, update, frames=n_frames, interval=interval_ms, blit=False)
writer = FFMpegWriter(fps=30)
ani.save("MultiTask_animation.mp4", writer=writer)

# Release the figure now that the video has been written.
plt.close(fig)
print("Combined video saved as 'MultiTask_animation.mp4'.")


Readout shape: (500, 320, 3)
Chosen trial indices (ordered by final angle): [394 330 300 391 441  78  11 341 208 179 267  96 244 382  28 289 198 145
 183 133 144  46 328 385 235  33  15 187 265  89]
Final angles (radians): [5.16433129e-03 2.39183351e-01 3.82965297e-01 6.60612464e-01
 8.38039041e-01 1.06450009e+00 1.29632914e+00 1.43564343e+00
 1.73454678e+00 1.88094890e+00 2.10781908e+00 2.28656697e+00
 2.46091890e+00 2.64802480e+00 2.93870020e+00 3.19498855e+00
 3.40189392e+00 3.59356362e+00 3.74535257e+00 3.93201191e+00
 4.25495941e+00 4.43342561e+00 4.62641234e+00 4.86270196e+00
 5.06245202e+00 5.38599617e+00 5.65947247e+00 5.89929554e+00
 6.10455934e+00 6.27054740e+00]
Processing frame 0/2085 (remaining: 2085)
Switching to trial 1. Processed 0 frames so far.
Processing frame 0/2085 (remaining: 2085)


  vline.set_xdata(time_idx)


Processing frame 10/2085 (remaining: 2075)
Processing frame 20/2085 (remaining: 2065)
Processing frame 30/2085 (remaining: 2055)
Processing frame 40/2085 (remaining: 2045)
Processing frame 50/2085 (remaining: 2035)
Processing frame 60/2085 (remaining: 2025)
Processing frame 70/2085 (remaining: 2015)
Processing frame 80/2085 (remaining: 2005)
Switching to trial 2. Processed 162 frames so far.
Processing frame 90/2085 (remaining: 1995)
Processing frame 100/2085 (remaining: 1985)
Processing frame 110/2085 (remaining: 1975)
Processing frame 120/2085 (remaining: 1965)
Processing frame 130/2085 (remaining: 1955)
Processing frame 140/2085 (remaining: 1945)
Switching to trial 3. Processed 296 frames so far.
Processing frame 150/2085 (remaining: 1935)
Processing frame 160/2085 (remaining: 1925)
Processing frame 170/2085 (remaining: 1915)
Processing frame 180/2085 (remaining: 1905)
Processing frame 190/2085 (remaining: 1895)
Processing frame 200/2085 (remaining: 1885)
Processing frame 210/2085 (

In [3]:
an_TT_RT = Analysis_TT_RandomTarget(run_name = "TT_RT", filepath = pathTT_RT)

# Only the last LDS run in the sweep is used downstream (as `analysis_LDS`), so build
# just that one instead of loading and discarding an analysis per folder. Fail loudly
# if the sweep folder is empty; otherwise `analysis_LDS` would silently keep whatever
# an earlier cell (e.g. the MultiTask sweep) left behind.
if not subfolders_LDS_RT:
    raise FileNotFoundError(f"No LDS run folders found in {path_LDS_Sweep_RT}")
subfolder = subfolders_LDS_RT[-1] + "/"
analysis_LDS = Analysis_DD.create(run_name = "LDS", filepath = subfolder, model_type = "SAE")


In [8]:
# Validation-split data for the RandomTarget task. Spikes, latents, env inputs and
# true inputs are converted to numpy here; model outputs/inputs stay as returned and
# are indexed in the next cell.
rt_spikes = analysis_LDS.get_spiking(phase="val").detach().cpu().numpy()
rt_latents = an_TT_RT.get_latents(phase="val").detach().cpu().numpy()
outputs_rt = an_TT_RT.get_model_outputs(phase="val")
inputs_rt = an_TT_RT.get_model_inputs(phase="val")
ext_inputs = (
    an_TT_RT.get_inputs_to_env(phase="val")
    .detach()
    .cpu()
    .numpy()
)
true_inputs_rt = (
    an_TT_RT.get_true_inputs(phase="val")
    .detach()
    .cpu()
    .numpy()
)


In [14]:

# Joint-angle channels of the model output (0 -> shoulder, 1 -> elbow), as numpy arrays.
shoulder_ang, elbow_ang = (
    outputs_rt['joints'][:, :, channel].detach().cpu().numpy() for channel in (0, 1)
)
# Target positions: model input #2, and the first two channels of the true inputs.
target_pos = inputs_rt[2].detach().cpu().numpy()
target_pos_input = true_inputs_rt[:, :, :2]

(200, 155, 2)


In [18]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation, FFMpegWriter
from mpl_toolkits.mplot3d import Axes3D  # needed for 3D plotting
from matplotlib import gridspec
from sklearn.decomposition import PCA
import torch  # if your data are torch tensors

# Assume these are already defined (for RandomTarget task):
#   rt_spikes: shape (B, T, N)
#   rt_latents: shape (B, T, D)
#   target_pos: shape (B, T, 2)
#   shoulder_ang: shape (B, T)
#   elbow_ang: shape (B, T)
# Arm segment lengths come from the trained environment's skeleton:
upper_arm_length = an_TT_RT.env.effector.skeleton.l1
forearm_length   = an_TT_RT.env.effector.skeleton.l2

# =============================================================================
# 1. Data Setup
# =============================================================================
B, T, N = rt_spikes.shape
_, _, D = rt_latents.shape
num_trials = 15  # For video demonstration; adjust as desired (must be <= B).
frame_skip = 2  # choose 1 (no skip), 2, 3, or 4

# Enforce the documented constraints before any plotting work.
assert 1 <= num_trials <= B, f"num_trials={num_trials} must be between 1 and B={B}"
assert frame_skip in (1, 2, 3, 4), f"frame_skip={frame_skip} must be 1, 2, 3 or 4"

# =============================================================================
# 2. PCA on Latent Data
# =============================================================================
# Fit on every (trial, time) latent so all trials share one projection.
all_latents = rt_latents.reshape(-1, D)
pca = PCA(n_components=3)
pca.fit(all_latents)
all_pc = pca.transform(all_latents)  # shape (B*T, 3)

# Per-trial projections, each of shape (T, 3). reshape(-1, D) flattened (B, T) in
# C order, so the rows of all_pc are grouped trial by trial and can be split back
# without re-running pca.transform for every trial.
trial_pc_list = list(all_pc.reshape(B, T, 3))

# =============================================================================
# 3. Set Up Figure with Three Subplots
# =============================================================================
# Three subplots in one row: 2D arm kinematics (left), 3D latent trajectory
# (middle), and spiking activity (right, 75% of the 3D plot's width).
fig = plt.figure(figsize=(18, 6))
gs = gridspec.GridSpec(1, 3, width_ratios=[1.4, 1, 0.75])

# --- Left: 2D Arm Kinematics ---
ax_arm = fig.add_subplot(gs[0])
ax_arm.set_title("Trial 1: Arm Kinematics")
ax_arm.set_xlim(-1.2, 1.2)
ax_arm.set_ylim(-0.2, 1.2)
ax_arm.axis("off")  # Hide axes for better visualization

# --- Middle: 3D Latent Trajectory ---
ax_pc = fig.add_subplot(gs[1], projection='3d')
ax_pc.set_title("Trial 1: Latent Trajectory")
ax_pc.set_xlabel("PC1")
ax_pc.set_ylabel("PC2")
ax_pc.set_zlabel("PC3")
ax_pc.set_xlim(np.min(all_pc[:, 0]), np.max(all_pc[:, 0]))
ax_pc.set_ylim(np.min(all_pc[:, 1]), np.max(all_pc[:, 1]))
ax_pc.set_zlim(np.min(all_pc[:, 2]), np.max(all_pc[:, 2]))
# Initialize an empty line and a current-point marker.
trail_line, = ax_pc.plot([], [], [], lw=2)
current_point = ax_pc.scatter([], [], [], s=50)

# --- Right: Spiking Activity ---
ax_spike = fig.add_subplot(gs[2])
ax_spike.set_title("Trial 1: Spiking Activity")
# Display spiking data for trial 0 (transposed so time is x-axis, neurons on y-axis)
im = ax_spike.imshow(rt_spikes[0].T, aspect='auto', interpolation='nearest', cmap='viridis')
ax_spike.set_xlabel("Time")
ax_spike.set_ylabel("Neuron")
plt.colorbar(im, ax=ax_spike, orientation='vertical')
vline = ax_spike.axvline(x=0, color='red', lw=2)

# =============================================================================
# 4. Helper Functions
# =============================================================================
def calculate_arm_positions(shoulder_angle, elbow_angle, upper_arm_length=1.0, forearm_length=1.0):
    """
    Forward kinematics of a planar two-link arm whose shoulder sits at the origin.

    The upper arm points along `shoulder_angle`; the forearm points along
    `shoulder_angle + elbow_angle` (the elbow angle is relative to the upper arm).

    Returns
    -------
    (shoulder_pos, elbow_pos, hand_pos) : tuple of numpy arrays [x, y]
    """
    upper_dir = np.array([np.cos(shoulder_angle), np.sin(shoulder_angle)])
    elbow_pos = upper_dir * upper_arm_length

    forearm_angle = shoulder_angle + elbow_angle
    forearm_dir = np.array([np.cos(forearm_angle), np.sin(forearm_angle)])
    hand_pos = elbow_pos + forearm_dir * forearm_length

    return np.array([0, 0]), elbow_pos, hand_pos

def get_reach_angle(trial_idx):
    """
    Direction (radians, shifted to be non-negative) of one trial's target displacement.

    Takes the vector from target_pos[trial_idx, 0] (first time step) to
    target_pos[trial_idx, -1] (last time step) and returns arctan2 of its
    components, adding 2*pi when the result is negative.
    """
    displacement = target_pos[trial_idx, -1, :] - target_pos[trial_idx, 0, :]
    theta = np.arctan2(displacement[1], displacement[0])
    return theta if theta >= 0 else theta + 2 * np.pi

def get_trial_color_by_reach(trial_idx):
    """
    Colour for a trial: its reach angle, scaled by 1/(2*pi), looked up in the cyclic HSV colormap.
    """
    return plt.cm.hsv(get_reach_angle(trial_idx) / (2 * np.pi))

# For displaying progress in the arm plot, you might also want a helper to plot the arm.
# (The arm plot update is done in the update function below.)

def get_current_phase(trial_idx, time_idx):
    """
    Phase label for the text overlay of the RandomTarget animation.

    No phase annotations are used for this task, so every (trial, time) pair is
    labelled "Reach"; both arguments are ignored and exist only to match the
    call signature used by the animation code.
    NOTE: this redefines the MultiTask `get_current_phase` from earlier in the notebook.
    """
    del trial_idx, time_idx  # unused
    return "Reach"

# =============================================================================
# 5. Frame Skipping & Global Animation Parameters
# =============================================================================
# Every RandomTarget trial has the same length T, so the video spans num_trials*T
# steps, rendering every frame_skip-th one (ceil so a final partial stride gets a frame).
total_effective_frames = T * num_trials
n_frames = -(-total_effective_frames // frame_skip)

# =============================================================================
# 6. Update Function for Animation
# =============================================================================
def _rt_begin_trial(trial_idx, prev_trial_idx, eff_frame):
    """Freeze the previously drawn trial's latent trace and reset the artists for `trial_idx`."""
    global current_point
    print(f"Trial {trial_idx} started at frame {eff_frame}")
    # Leave the trial that was just animated on the 3D axes as a faint, static trace.
    if prev_trial_idx is not None:
        prev_trial_pc = trial_pc_list[prev_trial_idx]
        ax_pc.plot(prev_trial_pc[:, 0], prev_trial_pc[:, 1], prev_trial_pc[:, 2],
                   color=get_trial_color_by_reach(prev_trial_idx),
                   alpha=0.3, lw=2)
    # Empty the live trace and recreate the current-position marker for the new trial.
    trail_line.set_data([], [])
    trail_line.set_3d_properties([])
    current_point.remove()
    current_point = ax_pc.scatter([], [], [], s=50)
    ax_pc.set_title(f"Trial {trial_idx+1}: Latent Trajectory")
    ax_spike.set_title(f"Trial {trial_idx+1}: Spiking Activity")
    # Show this trial's spike raster. All RandomTarget trials share length T,
    # so the extent is the same for every trial.
    im.set_data(rt_spikes[trial_idx].T)
    im.set_extent((0, T, 0, N))


def _rt_draw_arm(trial_idx, time_idx):
    """Redraw the 2D arm panel: arm segments, input arrow, hand path, targets and progress bar."""
    # The arm axes are cleared and rebuilt from scratch on every frame.
    ax_arm.cla()
    ax_arm.set_xlim(-1.2, 1.2)
    ax_arm.set_ylim(-0.2, 1.2)
    ax_arm.axis("off")
    ax_arm.set_title(f"Trial {trial_idx+1}: Arm Kinematics")

    s_angle = shoulder_ang[trial_idx, time_idx]
    e_angle = elbow_ang[trial_idx, time_idx]
    shoulder_pos, elbow_pos, hand_pos = calculate_arm_positions(s_angle, e_angle,
                                                                upper_arm_length,
                                                                forearm_length)
    ax_arm.plot([shoulder_pos[0], elbow_pos[0]],
                [shoulder_pos[1], elbow_pos[1]],
                'k-', lw=3)
    ax_arm.plot([elbow_pos[0], hand_pos[0]],
                [elbow_pos[1], hand_pos[1]],
                'k-', lw=3)

    # External input at this step, drawn from the hand as an arrow whose shaft
    # is normalized to length 1/50 (only its direction varies).
    bump = ext_inputs[trial_idx, time_idx, :]
    inputs_mag = np.linalg.norm(bump)
    if inputs_mag > 0.001:
        ax_arm.arrow(
            hand_pos[0],
            hand_pos[1],
            bump[0] / (50 * inputs_mag),
            bump[1] / (50 * inputs_mag),
            head_width=0.05,
            head_length=0.1,
            fc="b",
            ec="b",
        )
    # The shoulder marker belongs on every frame, not only while an input is present.
    ax_arm.scatter(shoulder_pos[0], shoulder_pos[1], color='black', s=50)

    # Hand trajectory from the start of the trial up to the current step.
    hand_traj = []
    for t in range(time_idx + 1):
        _, _, h_pos = calculate_arm_positions(shoulder_ang[trial_idx, t],
                                              elbow_ang[trial_idx, t],
                                              upper_arm_length, forearm_length)
        hand_traj.append(h_pos)
    hand_traj = np.array(hand_traj)
    if hand_traj.shape[0] > 1:
        ax_arm.plot(hand_traj[:, 0], hand_traj[:, 1], 'k--')

    # Targets: the cued target input in green (only when non-zero), and the
    # target position in red, drawn on top.
    targ = target_pos[trial_idx, time_idx, :]
    targ_input = target_pos_input[trial_idx, time_idx, :]
    # Test the vector norm, as for the input arrow. Summing the coordinates
    # would hide cued targets whose x + y <= 0, e.g. targets with negative x.
    if np.linalg.norm(targ_input) > 0.001:
        ax_arm.add_patch(patches.Rectangle((targ_input[0]-0.05, targ_input[1]-0.05),
                                           0.1, 0.1, color='green'))
    ax_arm.add_patch(patches.Rectangle((targ[0]-0.05, targ[1]-0.05), 0.1, 0.1, color='red'))

    # Ground line and a progress bar showing how far into the trial we are.
    ax_arm.plot([-0.5, 0.5], [0, 0], 'k-', lw=3)
    bar_start = -0.5
    bar_width = 1.0 * (time_idx / T)
    ax_arm.plot([bar_start, bar_start + bar_width], [-0.1, -0.1], 'b-', lw=3)
    ax_arm.set_aspect("equal")

    ax_arm.scatter(hand_pos[0], hand_pos[1], color='black', s=50)


def update(frame):
    """
    FuncAnimation callback: draw animation frame `frame` of the RandomTarget video.

    Animation frames are mapped onto the flattened (trial, time) timeline with a
    stride of `frame_skip`. When the trial changes, the previous trial's latent
    trace is frozen and the per-trial artists are reset. Each call then
    - extends the live 3D latent trace,
    - moves the time cursor on the spike raster,
    - redraws the arm panel, and
    - slowly rotates the 3D view.

    Returns the artists (trail_line, current_point, vline, im) that FuncAnimation
    expects from a callback.
    """
    # Map the animation frame onto the flattened timeline (all trials have length T).
    eff_frame = min(frame * frame_skip, total_effective_frames - 1)
    trial_idx, time_idx = divmod(eff_frame, T)

    # Detect a new trial by comparing with the last trial drawn, not by testing
    # time_idx == 0. When T is not a multiple of frame_skip, a trial's first
    # step can be skipped and the reset would never run. This also avoids
    # resetting twice when FuncAnimation renders frame 0 as its initial draw.
    prev_trial_idx = getattr(update, "_last_trial_idx", None)
    if trial_idx != prev_trial_idx:
        _rt_begin_trial(trial_idx, prev_trial_idx, eff_frame)
        update._last_trial_idx = trial_idx

    # --- 3D latent trajectory: trace from trial start up to the current step ---
    trial_pc = trial_pc_list[trial_idx]
    current_segment = trial_pc[:time_idx+1, :]
    if current_segment.shape[0] > 0:
        trail_line.set_data(current_segment[:, 0], current_segment[:, 1])
        trail_line.set_3d_properties(current_segment[:, 2])
        cp = current_segment[-1, :]
        current_point._offsets3d = ([cp[0]], [cp[1]], [cp[2]])

    # Color the trace by this trial's reach direction.
    trace_color = get_trial_color_by_reach(trial_idx)
    trail_line.set_color(trace_color)
    current_point.set_color(trace_color)

    # --- Spike raster time cursor ---
    # Line2D.set_xdata requires a sequence (a scalar is deprecated since
    # Matplotlib 3.7). Assumes vline is a two-point vertical line such as one
    # returned by axvline.
    vline.set_xdata([time_idx, time_idx])

    # --- 2D arm kinematics ---
    _rt_draw_arm(trial_idx, time_idx)

    # --- Slowly rotate the 3D view (120 degrees over the whole video) ---
    current_azim = 120 * frame / n_frames
    ax_pc.view_init(elev=30, azim=current_azim)

    return trail_line, current_point, vline, im

# =============================================================================
# 7. Create and Save the Animation
# =============================================================================
VIDEO_FPS = 60                        # drives both playback interval and encoded frame rate
VIDEO_PATH = "random_target_video.mp4"
interval_ms = 1000 / VIDEO_FPS        # about 16.67 ms per frame at 60 fps
n_total_frames = n_frames
# Without init_func, FuncAnimation calls update(frames[0]) once as its initial
# draw before the saved frames. This duplicates work and log lines for frame 0.
# A no-op init_func avoids that extra call (its return value is unused when
# blit=False).
ani = FuncAnimation(fig, update, frames=n_total_frames, init_func=lambda: (),
                    interval=interval_ms, blit=False)
writer = FFMpegWriter(fps=VIDEO_FPS)
try:
    ani.save(VIDEO_PATH, writer=writer)
finally:
    # Close the figure even if encoding fails, so it does not linger in the kernel.
    plt.close(fig)
print(f"Combined video saved as '{VIDEO_PATH}'.")


Trial 0 started at frame 0
Trial 0 started at frame 0


  vline.set_xdata(time_idx)


Trial 1 started at frame 150
Trial 2 started at frame 300
Trial 3 started at frame 450
Trial 4 started at frame 600
Trial 5 started at frame 750
Trial 6 started at frame 900
Trial 7 started at frame 1050
Trial 8 started at frame 1200
Trial 9 started at frame 1350
Trial 10 started at frame 1500
Trial 11 started at frame 1650
Trial 12 started at frame 1800
Trial 13 started at frame 1950
Trial 14 started at frame 2100
Combined video saved as 'random_target_video.mp4'.
