In [1]:
from manim import *
from utils import download_glm_hmm, save_data_path, val_eid, all_eid
from one.api import ONE
import os

# Shared ONE client used by the data loaders in this notebook.
# NOTE(review): 'international' appears to be the public IBL/OpenAlyx password; it stays
# the default, but a ONE_PASSWORD environment variable now overrides it so that private
# credentials never have to be written into the notebook itself.
one = ONE(password=os.environ.get("ONE_PASSWORD", "international"))

In [2]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

Original video: frame rate 60.0 fps, dimensions 1280x1024 (width x height).

In [3]:
def get_frames(video_path, start_frame, end_frame, frame_shape=None):
    """Read the frames ``[start_frame, end_frame)`` of a video into one array.

    Parameters
    ----------
    video_path : str or Path
        Video file readable by ``cv2.VideoCapture``.
    start_frame, end_frame : int
        Half-open range of frame indices to read.
    frame_shape : (height, width), optional
        Size of each frame. When None it is read from the capture's
        ``CAP_PROP_FRAME_HEIGHT`` / ``CAP_PROP_FRAME_WIDTH``. If the container
        does not report a size, it falls back to 1024 x 1280.

    Returns
    -------
    np.ndarray, dtype uint8, shape (max(end_frame - start_frame, 0), height, width, 3)
        Frames as decoded by OpenCV (BGR channel order). If the video ends or a
        read fails before ``end_frame``, the remaining slots are left as zeros
        (black frames) and a warning is emitted. The previous ``np.empty``
        allocation returned uninitialised memory for those slots.
    """
    import warnings

    total_frames = max(int(end_frame) - int(start_frame), 0)
    cap = cv2.VideoCapture(str(video_path))
    try:
        if frame_shape is None:
            # `or` falls back to the defaults when OpenCV reports 0 (size unknown).
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 1024
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 1280
        else:
            height, width = frame_shape

        # Zero-filled rather than np.empty: if reading stops early, the unread
        # frames must not contain garbage.
        frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        n_read = 0
        for i in range(total_frames):
            ok, frame = cap.read()
            if not ok:
                break
            frames[i] = frame
            n_read += 1
    finally:
        # Release the capture even if decoding or assignment raised.
        cap.release()

    if n_read < total_frames if total_frames else False:
        warnings.warn(
            f"get_frames: read only {n_read} of {total_frames} frames from {video_path}; "
            "the remainder are zero-filled."
        )
    return frames

In [4]:
def load_trial_data(eid, trial_idx):
    """Fetch one trial's event times, outcome and session metadata via the ONE client.

    Parameters
    ----------
    eid : str
        Session ID passed to ``one.load_object`` / ``one.get_details``.
    trial_idx : int
        Index of the trial within the session's ``trials`` object.

    Returns
    -------
    timeline_data : tuple
        ``(intervals, stimOn_times, stimOff_times, goCue_times, feedback_times)``
        for the trial. The first entry is the trial's ``intervals`` row, the
        rest are single event times.
    feedback_type
        The trial's ``feedbackType`` value.
    lab, subject
        The ``'lab'`` and ``'subject'`` fields of ``one.get_details(eid)``.
    """
    trials = one.load_object(eid, 'trials')

    # Order matters: callers unpack this as (interval, stim_on, stim_off, go_cue, feedback).
    event_keys = ('intervals', 'stimOn_times', 'stimOff_times', 'goCue_times', 'feedback_times')
    timeline_data = tuple(trials[key][trial_idx] for key in event_keys)
    outcome = trials["feedbackType"][trial_idx]

    details = one.get_details(eid)
    return timeline_data, outcome, details['lab'], details['subject']

def load_frame_data(eid, start, end):
    """Load left-camera frames, timestamps and per-frame probabilities for a time window.

    Parameters
    ----------
    eid : str
        Session ID.
    start, end : float
        Window bounds, in the same time base as the left-camera timestamps.

    Returns
    -------
    first : int
        Index of the first camera timestamp >= ``start``.
    stop : int
        One past the index of the last camera timestamp <= ``end``.
    times : np.ndarray
        ``camera_times[first:stop]``.
    clip : np.ndarray
        Frames ``[first, stop)`` of ``raw_video_data/_iblrig_leftCamera.raw.mp4``.
    frame_probs : np.ndarray
        Rows ``[first, stop)`` of ``data/frame_probs/{eid}_probs.csv``.
        NOTE(review): assumes one CSV row per left-camera frame and no index column — confirm.
    """
    cam_times = one.load_dataset(eid, '*leftCamera.times.npy', collection='alf')

    # side='left' -> first sample at/after start; side='right' -> past the last sample at/before end.
    first = np.searchsorted(cam_times, start, side='left')
    stop = np.searchsorted(cam_times, end, side='right')
    window = slice(first, stop)

    video_file = one.eid2path(eid).joinpath("raw_video_data/_iblrig_leftCamera.raw.mp4")
    clip = get_frames(video_file, first, stop)

    probs_csv = Path("data") / "frame_probs" / f"{eid}_probs.csv"
    frame_probs = pd.read_csv(probs_csv).to_numpy()[window]

    return first, stop, cam_times[window], clip, frame_probs

def calc_argmaxcumsum(probabilities):
    """Running per-class count of frames on which that class had the highest probability.

    Parameters
    ----------
    probabilities : array-like, shape (n_frames, n_classes)
        Per-frame class probabilities, one row per frame.

    Returns
    -------
    cumulative_counts : np.ndarray, shape (n_frames, n_classes)
        ``cumulative_counts[i, c]`` is the number of frames ``0..i`` whose argmax was ``c``.
        Ties go to the lowest class index, as with ``np.argmax``.
    c_max : scalar
        Largest entry of ``cumulative_counts`` (useful as a fixed y-axis top), or 0
        when there are no frames.
    """
    # Accept plain nested lists as well as arrays.
    probabilities = np.asarray(probabilities)

    # One-hot encode each frame's winning class.
    max_class_array = np.zeros_like(probabilities)
    max_indices = np.argmax(probabilities, axis=1)
    max_class_array[np.arange(probabilities.shape[0]), max_indices] = 1

    cumulative_counts = np.cumsum(max_class_array, axis=0)
    # Vectorised max instead of Python's max() over a flattened copy. The empty case
    # is guarded because both forms raise on an empty array.
    c_max = cumulative_counts.max() if cumulative_counts.size else 0
    return cumulative_counts, c_max

In [5]:
def make_legend(position = UP * 2.5 + RIGHT * 1):
    """Build the colour key for the four behaviour classes.

    Each entry is a small filled square followed by the class name. Entries are
    stacked top-to-bottom and left-aligned, the key is scaled to a height of 1
    and centred on ``position``.
    """
    swatches = (("still", RED), ("move", BLUE), ("wheel_turn", GREEN), ("groom", YELLOW))

    entries = []
    for name, swatch_color in swatches:
        swatch = Square(
            side_length=0.1,
            fill_color=swatch_color,
            fill_opacity=1,
            stroke_color=swatch_color,
        )
        caption = Text(name, font_size=12)
        entries.append(VGroup(swatch, caption).arrange(RIGHT, buff=0.2))

    legend = VGroup(*entries)
    legend.arrange(DOWN, aligned_edge=LEFT, buff=0.15)
    legend.height = 1
    legend.move_to(position)
    return legend


In [6]:
def make_image(idx, 
               frames, 
               position = UP * 2 + RIGHT * 4):
    """Wrap ``frames[idx]`` (interpreted as BGR) in an ImageMobject of height 3 centred on ``position``."""
    picture = ImageMobject(frames[idx], image_mode="BGR")
    picture.height = 3
    return picture.move_to(position)

In [7]:
def make_graph(idx, 
               x, 
               y, 
               colors = None,
               position = UP * 2.5 + LEFT * 3.5,
               min_x_span = 1e-3):
    """
    Plot one line per series on a time axis that grows with ``idx``.

    Parameters:
    idx : int
        Index of the current frame; the x-axis spans ``x[0][0]`` .. ``x[0][idx]``.
    x : (lines, x_values)
        Sequence of x-value sequences, one per line. The axis range is taken
        from ``x[0]``.
    y : (lines, y_values)
        Sequence of y-value sequences, one per line, drawn against a fixed
        [0, 1] y-axis.
    colors : sequence, optional
        ``colors[l]`` is the colour of line ``l``. Every line is BLUE when omitted.
    position : np.ndarray
        Point the finished group is centred on.
    min_x_span : float
        Smallest x-axis width. At ``idx == 0`` the start and end coincide, and
        manim cannot build a zero-width axis (it divides by the range width).
        The range is widened to this span instead.

    Returns:
    graph_group : VGroup
        Group of axes, axis labels and the plotted lines.
    """
    x_start = x[0][0]
    x_end = x[0][idx]
    # Written as `not (>=)` so a NaN span is widened too.
    if not (x_end - x_start >= min_x_span):
        x_end = x_start + min_x_span

    ax = Axes(x_range=[x_start, x_end, max((x_end - x_start)/10,1)],
              y_range=[0, 1, 0.2],
              x_length=12,
              y_length=3,
              axis_config={"include_numbers": True},
              tips=False,
              )

    labels = ax.get_axis_labels(x_label=MathTex('t (s)'), y_label=MathTex('p(frame)'))

    lines_arr = []
    for l in range(len(x)):
        line = ax.plot_line_graph(x_values=np.array(x[l]), 
                                  y_values=np.array(y[l]), 
                                  add_vertex_dots=False,
                                  stroke_width=2,
                                  line_color= colors[l] if colors else BLUE,
                                  )
        lines_arr.append(line)

    # Group axes, labels and lines, then size and place the whole panel.
    graph_group = VGroup(ax, labels, *lines_arr)
    graph_group.height = 2
    graph_group.move_to(position)

    return graph_group

In [8]:
def make_cum_bar(idx,
                 cumulative_counts,
                 c_max = None,
                 position = DOWN * 1.5 + LEFT * 3.7):
    """Bar chart of the per-class cumulative argmax counts at frame ``idx``.

    Parameters
    ----------
    idx : int
        Row of ``cumulative_counts`` to draw.
    cumulative_counts : array-like, shape (n_frames, 4)
        Running per-class counts. The four columns are drawn as Still, Move,
        WheelT and Groom.
    c_max : float, optional
        Top of the y-axis. Passing the final maximum keeps the axis fixed across
        frames. When None (previously a crash on ``None / 10``) it is taken from
        ``cumulative_counts``. A non-positive value is raised to 1 so the y-axis
        keeps a non-zero extent and tick step.
    position : np.ndarray
        Where the bar chart is moved before its title is attached and the whole
        group is resized.

    Returns
    -------
    VGroup
        The chart, its title and the per-bar value labels, scaled to height 4.5.
    """
    if c_max is None:
        c_max = np.max(cumulative_counts) if np.size(cumulative_counts) else 0
    if not c_max > 0:
        # A zero (or NaN) maximum would give BarChart a zero-width y-range and step.
        c_max = 1

    y_range = [0, c_max, c_max / 10]

    bar_graph = BarChart(
        values=cumulative_counts[idx],
        bar_names=["Still", "Move", "WheelT", "Groom"],
        bar_colors=[RED, BLUE, GREEN, YELLOW],
        y_range=y_range,
        y_axis_config={"font_size": 16, },
        y_length = 4,
        x_length=6,
    )
    bar_graph.move_to(position)
    title = Text("Cumulative Frame Counts", font_size=18)
    title.next_to(bar_graph, UP, buff=0.1)

    c_bar_lbls = bar_graph.get_bar_labels(font_size=18)

    bar_group = VGroup(bar_graph, title, c_bar_lbls)
    bar_group.height = 4.5

    return bar_group

In [9]:
def _timeline_event_marker(number_line, event_time, offset, text, label_side):
    """Square-tipped arrow pointing at ``event_time`` on ``number_line`` from ``offset`` away, plus its label."""
    anchor = number_line.n2p(event_time)
    arrow = Arrow(
        start=anchor + offset,
        end=anchor,
        max_stroke_width_to_length_ratio=4,
        max_tip_length_to_length_ratio=0.3,
        tip_shape=ArrowSquareTip,
    )
    label = Text(text, font_size=18).next_to(arrow, label_side)
    return arrow, label


def make_timeline(timeline_data,
                  play_head_position = None,
                  marker_spacing=1, 
                  scale = 0.7,
                  position = DOWN * 2.5 + RIGHT * 4):
    """Draw the trial's time axis with event markers and a playhead.

    Parameters
    ----------
    timeline_data : tuple
        ``(tr_interval, stim_on, stim_off, go_cue, feedback_time)``. ``tr_interval[0]``
        and ``tr_interval[1]`` bound the number line.
    play_head_position : float, optional
        Time at which the red playhead dot is drawn. Defaults to the trial start.
    marker_spacing : float
        Length of the event arrows. The feedback arrow is 1.4x this length.
    scale : float
        Scale applied to the whole timeline after it is moved to ``position``.
    position : np.ndarray
        Point the timeline group is centred on.

    Returns
    -------
    VGroup
        ``VGroup(VGroup(number_line, *arrows), VGroup(*labels), playhead)``.
    """
    tr_interval, stim_on, stim_off, go_cue, feedback_time = timeline_data

    number_line = NumberLine(x_range=[tr_interval[0], tr_interval[1], 1],
                             length=5,
                             color=WHITE,
                             include_numbers=True,
                             label_direction=UP)

    # (time, arrow start offset, label text, side of the arrow the label sits on).
    # Stimulus events hang above the line and response events below. Feedback is
    # pushed further down, presumably so its label clears goCue's when they are close.
    events = [
        (stim_on, UP * marker_spacing, "stimOn", UP),
        (stim_off, UP * marker_spacing, "stimOff", UP),
        (go_cue, DOWN * marker_spacing, "goCue", DOWN),
        (feedback_time, (DOWN + DOWN * 0.4) * marker_spacing, "feedback", DOWN),
    ]
    markers = [_timeline_event_marker(number_line, t, offset, text, side)
               for t, offset, text, side in events]
    arrows = [arrow for arrow, _ in markers]
    labels = [label for _, label in markers]

    head_time = tr_interval[0] if play_head_position is None else play_head_position
    playhead = Dot(number_line.n2p(head_time), radius=.1, color=RED)

    number_group = VGroup(number_line, *arrows)
    label_group = VGroup(*labels)
    timeline_group = VGroup(number_group, label_group, playhead)

    timeline_group.move_to(position).scale(scale)

    return timeline_group

In [10]:
def make_text_display(idx,
                      camera_times,
                      feedback_type,
                      lab_name,
                      subj_name,
                      eid,
                      trial,
                      font_size=18,
                      position = DOWN * 0.5 + RIGHT * 4):
    """Metadata panel: session ID and trial number, current time and frame, lab and subject, outcome.

    ``idx`` indexes ``camera_times``. When ``camera_times`` is a clip slice (as in
    MouseInspection), the displayed frame number is relative to the start of the
    clip. ``feedback_type == 1`` is shown as a green "correct"; anything else is
    shown as a red "incorrect".
    """
    is_correct = feedback_type == 1
    outcome_word = " correct" if is_correct else " incorrect"
    outcome_color = GREEN if is_correct else RED

    def line(text, color=WHITE, size=font_size):
        return Text(text, font_size=size, slant=NORMAL, color=color)

    header = VGroup(
        line(eid, color=BLUE, size=16),
        line(f"Trial #{trial}", color=BLUE),
    ).arrange(DOWN, buff=0.1)

    clip_info = VGroup(
        line(f"{camera_times[idx]:.2f} (seconds)"),
        line(f"#frame: {idx} "),
    ).arrange(DOWN, aligned_edge=LEFT, buff=0.1)

    session_info = VGroup(
        line(f"Lab: {lab_name}"),
        line(f"Subject: {subj_name}"),
    ).arrange(DOWN, aligned_edge=RIGHT, buff=0.1)

    body = VGroup(clip_info, session_info).arrange(buff=0.1)
    outcome = line(f"Response: {outcome_word}", color=outcome_color)

    panel = VGroup(header, body, outcome).arrange(DOWN, buff=0.2)
    panel.move_to(position)
    return panel

In [11]:
# Session and trial to render: entry 17 of the validation session list, trial index 6 (both 0-based).
eid, trial_idx = val_eid[17], 6

In [12]:
%%manim -qm -v WARNING MouseInspection

class MouseInspection(Scene):
    """Replay one trial.

    Shows the video frame, class-probability traces, cumulative argmax counts, the
    trial event timeline and session metadata. All panels are driven by one shared
    frame index.
    """

    def construct(self):
        # --- data ---------------------------------------------------------
        timeline_data, feedback_type, lab, sub = load_trial_data(eid, trial_idx=trial_idx)
        _, _, camera_times, frames, probabilities = load_frame_data(eid, *timeline_data[0])
        cum_counts, c_max = calc_argmaxcumsum(probabilities)

        # --- placeholders (all but the legend are rebuilt every frame) ----
        image = ImageMobject(np.zeros((1024, 1280, 3)))
        legend = make_legend()
        graph, cum_graph, timeline, meta = VGroup(), VGroup(), VGroup(), VGroup()
        self.add(image, graph, legend, timeline, cum_graph, meta)

        # The animated quantity: a (fractional) frame index, truncated on use.
        timestep = ValueTracker(0)

        def current_idx():
            return int(timestep.get_value())

        def update_image(mob):
            mob.become(make_image(current_idx(), frames))

        def update_prob_graph(mob):
            i = current_idx()
            # Every class trace shares the same time axis.
            xs = np.array([camera_times[:i + 1]] * 4)
            ys = probabilities[:i + 1].T
            mob.become(make_graph(i, xs, ys, colors=[RED, BLUE, GREEN, YELLOW]))

        def update_cum_graph(mob):
            mob.become(make_cum_bar(current_idx(), cum_counts, c_max))

        def update_timeline(mob):
            mob.become(make_timeline(timeline_data, play_head_position=camera_times[current_idx()]))

        def update_meta(mob):
            mob.become(make_text_display(current_idx(), camera_times, feedback_type, lab, sub, eid, trial_idx))

        bindings = [
            (image, update_image),
            (graph, update_prob_graph),
            (cum_graph, update_cum_graph),
            (timeline, update_timeline),
            (meta, update_meta),
        ]
        for mob, updater in bindings:
            mob.add_updater(updater)

        # Real-time playback: the animation lasts as long as the clip's timestamps span.
        duration = camera_times[-1] - camera_times[0]
        self.play(timestep.animate.set_value(len(camera_times) - 1), run_time=duration, rate_func=linear)

        for mob, updater in bindings:
            mob.remove_updater(updater)

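# NOTE(review): the three lines below appear to be fragments of a manim traceback pasted into
# this cell (line-length scaling, Mobject.scale, NumberLine.number_to_point). Both divisions
# fail on a zero-width axis. make_graph likely builds one at frame 0, when x[0][0] == x[0][idx].
# They are kept as comments so the cell still parses.
#   return self.scale(length / self.get_length())
#   lambda points: scale_factor * points, **kwargs
#   alphas = (number - self.x_range[0]) / (self.x_range[1] - self.x_range[0])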
                                                                                            