# Try to visualize how the warping of the input image occurs via the motion field output of the model

## First, let's just do some inference.

In [1]:
# model_names = ["Original-Pretrained-R2plus1DMotionSegNet_model.pth", "dropout-0_10-R2plus1DMotionSegNet_model.pth", "dropout-0_25-R2plus1DMotionSegNet_model.pth", "dropout-0_50-R2plus1DMotionSegNet_model.pth", "dropout-0_75-R2plus1DMotionSegNet_model.pth"]
model_names = ["Original_Pretrained_R2plus1DMotionSegNet_model.pth", "dropout_v3_0_10_R2plus1D_18_MotionNet.pth"]

In [2]:
%config Completer.use_jedi = False

import echonet
from echonet.datasets import Echo

import torch.nn.functional as F
from torchvision.models.video import r2plus1d_18
from torch.utils.data import Dataset, DataLoader, Subset
from multiprocessing import cpu_count

from src.utils.torch_utils import TransformDataset, torch_collate
from src.transform_utils import generate_2dmotion_field
from src.loss_functions import huber_loss, convert_to_1hot, convert_to_1hot_tensor

from src.model.R2plus1D_18_MotionNet import R2plus1D_18_MotionNet # original model
# new models (small alterations)
from src.model.dropout_0_10_R2plus1D_18_MotionNet import dropout_0_10_R2plus1D_18_MotionNet 
from src.model.dropout_0_25_R2plus1D_18_MotionNet import dropout_0_25_R2plus1D_18_MotionNet 
from src.model.dropout_0_50_R2plus1D_18_MotionNet import dropout_0_50_R2plus1D_18_MotionNet 
from src.model.dropout_0_75_R2plus1D_18_MotionNet import dropout_0_75_R2plus1D_18_MotionNet 


from src.echonet_dataset import EchoNetDynamicDataset
from src.clasfv_losses import deformation_motion_loss, motion_seg_loss, DiceLoss, categorical_dice
from src.train_test import train, test



######
# for slider visualizations
%matplotlib widget
import matplotlib.pyplot as plt
import scipy.ndimage as ndimage
import numpy as np

from scipy.ndimage import correlate
from skimage.filters import *

from ipywidgets import VBox, IntSlider, AppLayout
# initialize to use dark background ...
plt.style.use('dark_background')
######
# for creating gif animation and viewing them
from matplotlib import animation
from IPython.display import Image

######

import torch
import torch.nn as nn
import torch.optim as optim

import random
import pickle
import time

tic, toc = (time.time, time.time)

## Load in models

We have the luxury of trying inference on multiple models, since I'm trying to create slightly different models.

Original pre-trained model by Yida will be `model_name_1`.
The one that I will add dropout and k-fold cross validation to will be `model_name_2`.

In [3]:
# hold tuples of (name, model object)
loaded_in_models = []

for model_name in model_names:
    model_save_path = f"save_models/{model_name}"
    
    # original model
    if model_name == "Original-Pretrained-R2plus1DMotionSegNet_model.pth":
        # model = DDP(R2plus1D_18_MotionNet())
         model = torch.nn.DataParallel(R2plus1D_18_MotionNet())
        
    # altered models
    if model_name == "dropout-0_75-R2plus1DMotionSegNet_model.pth":
        # model = DDP(dropout_0_75_R2plus1D_18_MotionNet())
        model = torch.nn.DataParallel(dropout_0_75_R2plus1D_18_MotionNet())
    if model_name == "dropout-0_50-R2plus1DMotionSegNet_model.pth":
        # model = DDP(dropout_0_50_R2plus1D_18_MotionNet())
        model = torch.nn.DataParallel(dropout_0_50_R2plus1D_18_MotionNet())
    if model_name == "dropout-0_25-R2plus1DMotionSegNet_model.pth":
        # model = DDP(dropout_0_25_R2plus1D_18_MotionNet())
        model = torch.nn.DataParallel(dropout_0_25_R2plus1D_18_MotionNet())
    if model_name == "dropout-0_10-R2plus1DMotionSegNet_model.pth":
        # model = DDP(dropout_0_10_R2plus1D_18_MotionNet())
        model = torch.nn.DataParallel(dropout_0_10_R2plus1D_18_MotionNet())
    
    
    model.to("cuda")
    torch.cuda.empty_cache()
    model.load_state_dict(torch.load(model_save_path)["model"])
    print(f'{model_name} has {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters.')
    model.eval();
    
    loaded_in_models.append((model_name, model))

print(len(loaded_in_models))

NameError: name 'model' is not defined

## Load in Testing Dataset to do inference on pretrained model

In [None]:
with open("fold_indexes/stanford_valid_sampled_indices", "rb") as infile:
    valid_mask = pickle.load(infile)
infile.close()

# batch_size = 4
# num_workers = max(4, cpu_count()//2)

# def worker_init_fn_valid(worker_id):                                                          
#     np.random.seed(np.random.get_state()[1][0] + worker_id)

test_dataset = EchoNetDynamicDataset(split='test', clip_length="full", raise_for_es_ed=False, period=1)
# test_dataloader = DataLoader(valid_dataset, 
#                               batch_size=batch_size, 
#                               num_workers=num_workers,
#                               shuffle=False, 
#                               pin_memory=("cuda"),
#                               worker_init_fn=worker_init_fn_valid
#                              )

In [None]:
print(len(test_dataset))

### Grab a video to look at, can be random or manually choose one of the 1276 from test dataset
### For sake of comparison, let's look at first sample from the test dataset

In [None]:
# test_pat_index = np.random.randint(len(test_dataset))
test_pat_index = 0 

video, (filename, EF, es_clip_index, ed_clip_index, es_index, ed_index, es_frame, ed_frame, es_label, ed_label) = test_dataset[test_pat_index]

In [None]:
print(type(video))

### Get all possible 32-Frame Clips that covers ED-ES

In [None]:
def divide_to_consecutive_clips(video, clip_length=32, interpolate_last=False):
    source_video = video.copy()
    video_length = video.shape[1]
    left = video_length % clip_length
    if left != 0 and interpolate_last:
        source_video = torch.Tensor(source_video).unsqueeze(0)
        source_video = F.interpolate(source_video, size=(int(np.round(video_length / clip_length) * clip_length), 112, 112),
                                     mode="trilinear", align_corners=False)
        source_video = source_video.squeeze(0).squeeze(0)
        source_video = source_video.numpy()
    
    videos = np.empty(shape=(1, 3, clip_length, 112, 112))

    for start in range(0, int(clip_length * np.round(video_length / clip_length)), clip_length):
        one_clip = source_video[:, start: start + clip_length]
        one_clip = np.expand_dims(one_clip, 0)
        videos = np.concatenate([videos, one_clip])
    return videos[1:]


# goes thru a video and annotates where we can start clips given video length, cli length, etc.
def get_all_possible_start_points(ed_index, es_index, video_length, clip_length):
    assert es_index - ed_index > 0, "not a ED to ES clip pair"
    possible_shift = clip_length - (es_index - ed_index)
    allowed_right = video_length - es_index
    if allowed_right < possible_shift:
        return np.arange(ed_index - possible_shift + 1, video_length - clip_length + 1)
    if possible_shift < 0:
        return np.array([ed_index])
    elif ed_index < possible_shift:
        return np.arange(ed_index + 1)
    else:
        return np.arange(ed_index - possible_shift + 1, ed_index + 1)

In [None]:
possible_starts = get_all_possible_start_points(ed_index, es_index, video.shape[1], clip_length=32)
print(len(possible_starts))
print(possible_starts)

In [None]:
ed_index

In [None]:
es_index

### Segment All 32-Frame Clips

In [None]:
# segment using all models
all_segmentation_outputs = []
all_motion_outputs = []

# for each model, segment the clips
for name, model in loaded_in_models:
    
    segmentation_outputs = np.empty(shape=(1, 2, 32, 112, 112))
    motion_outputs = np.empty(shape=(1, 4, 32, 112, 112))
    for start in possible_starts:
        one_clip = np.expand_dims(video[:, start: start + 32], 0)
        segmentation_output, motion_output = model(torch.Tensor(one_clip))
        segmentation_outputs = np.concatenate([segmentation_outputs, segmentation_output.cpu().detach().numpy()])
        motion_outputs = np.concatenate([motion_outputs, motion_output.cpu().detach().numpy()])
    segmentation_outputs = segmentation_outputs[1:]
    motion_outputs = motion_outputs[1:]
    
    # save 
    all_segmentation_outputs.append(segmentation_outputs)
    all_motion_outputs.append(motion_outputs)


In [None]:
print(len(all_segmentation_outputs), len(all_motion_outputs))
print(len(all_segmentation_outputs[0]))

## Know that our shapes are: [Forward x, y, backward x, y]


In [None]:
segmentation_outputs_1.shape

In [None]:
motion_outputs_1.shape

In [None]:
# last 32 frame segmented clip
(motion_outputs_1[-1].shape, motion_outputs_1[-1].min(), motion_outputs_1[-1].max(), motion_outputs_1[-1].dtype)

In [None]:
motion_outputs_1[-1][0].shape

In [None]:
motion_outputs_1[-1][0][0].shape

In [None]:
tmp = motion_outputs_1[-1][0][0]
tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min())

In [None]:
tmp.min(), tmp.max()

In [None]:
motion_outputs_1[-1][0][0].min(), motion_outputs_1[-1][0][0].max()

In [None]:
mid = (motion_outputs_1[-1][0][0].max() + motion_outputs_1[-1][0][0].min()) / 2
print(mid)

In [None]:
# def find_mid(x):
#     '''Expects a single image, calculates the mid point value of this single image
#     '''
#     return (x.min() + x.max()) / 2

In [None]:
# def normalize(x):
#     ''' normalizes the thing passed in, assumes input is a ndarray numpy object'''
#     return (x - x.min()) / (x.max() - x.min())

In [None]:
clip_index = -1   # last 32 frame clip
frame_index = 0   # first frame
which_direction = [0,1,2,3] # in order of: forward x,y, backward x,y 

## Plot frame 0 of all 4

In [None]:
# plot all 4, forward x,y and backward x, y
# make sure to normalize.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(10,5));
plt.suptitle(model_name_1)
ax1.set_title("forward x");
fx = motion_outputs_1[clip_index][which_direction[0]][frame_index]
ax1_img = ax1.imshow(fx, cmap="viridis");
fig.colorbar(ax1_img, ax=ax1)

ax2.set_title("forward y")
fy = motion_outputs_1[clip_index][which_direction[1]][frame_index]
ax2_img = ax2.imshow(fy, cmap="viridis");
fig.colorbar(ax2_img, ax=ax2)

ax3.set_title("backward x")
bx = motion_outputs_1[clip_index][which_direction[2]][frame_index]
ax3_img = ax3.imshow(bx, cmap="viridis");
fig.colorbar(ax3_img, ax=ax3)

ax4.set_title("backward y")
by = motion_outputs_1[clip_index][which_direction[3]][frame_index]
ax4_img = ax4.imshow(by, cmap="viridis");
fig.colorbar(ax4_img, ax=ax4)

fig.show()

## That's not good, as seen by the colorbars that the same colors across images do not mean the same number. Let's fix that using `vmin` and `vmax`

In [None]:
clip_index = -1   # last 32 frame clip
frame_index = 0   # first frame
which_direction = [0,1,2,3] # in order of: forward x,y, backward x,y 

# find the absolute max and min values of the 4 motion fields for this singular frame and use that for min and max
all_mins = [all_motion_outputs[0][clip_index][which_direction[i]][frame_index].min() for i in range(4)]
all_maxes = [all_motion_outputs[0][clip_index][which_direction[i]][frame_index].max() for i in range(4)]
color_min = min(all_mins)
color_max = max(all_maxes)
print(color_min, color_max)

In [None]:
# plot all 4, forward x,y and backward x, y
# make sure to normalize.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15,5));

ax1.set_title("forward x");
fx = all_motion_outputs[0][clip_index][which_direction[0]][frame_index]
ax1_img = ax1.imshow(fx, cmap="viridis", vmin=color_min, vmax=color_max);
# fig.colorbar(ax1_img, ax=ax1)

ax2.set_title("forward y")
fy = all_motion_outputs[0][clip_index][which_direction[1]][frame_index]
ax2_img = ax2.imshow(fy, cmap="viridis", vmin=color_min, vmax=color_max);
# fig.colorbar(ax2_img, ax=ax2)

ax3.set_title("backward x")
bx = all_motion_outputs[0][clip_index][which_direction[2]][frame_index]
ax3_img = ax3.imshow(bx, cmap="viridis", vmin=color_min, vmax=color_max);
# fig.colorbar(ax3_img, ax=ax3)

ax4.set_title("backward y")
by = all_motion_outputs[0][clip_index][which_direction[3]][frame_index]
ax4_img = ax4.imshow(by, cmap="viridis", vmin=color_min, vmax=color_max);

cbar_ax = fig.add_axes([0.92, 0.1, 0.01, 0.75])
fig.colorbar(ax4_img, cax=cbar_ax)

# fig.tight_layout()
fig.show()

## Look at all frames in our specific clip

In [None]:
plt.ioff()

# Create a slide object

slider = IntSlider(
    value=0, # start value
    min=0,
    max=31,
    step=1,
    description='Frame:',
    continuous_update=True,
    orientation='horizontal',
)

slider.layout.margin = '0px 00% 0px 00%'
slider.layout.width = '80%'

#######################

# calculate midpoints of all 

# plot all 4, forward x,y and backward x, y
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15,5))
ax1.set_title("forward x");
fx = motion_outputs_1[clip_index][which_direction[0]][frame_index]
ax1.imshow(fx, cmap="viridis");

ax2.set_title("forward y")
fy = motion_outputs_1[clip_index][which_direction[1]][frame_index]
ax2.imshow(fy, cmap="viridis");

ax3.set_title("backward x")
bx = motion_outputs_1[clip_index][which_direction[2]][frame_index]
ax3.imshow(bx, cmap="viridis");

ax4.set_title("backward y")
by = motion_outputs_1[clip_index][which_direction[3]][frame_index]
ax4.imshow(by, cmap="viridis");

######################

# A function that will be called whenever the slider changes.
def update_lines(change):
    fx = motion_outputs_1[clip_index][which_direction[0]][slider.value]
    ax1.imshow(fx, cmap="viridis")
    fy = motion_outputs_1[clip_index][which_direction[1]][slider.value]
    ax2.imshow(fy, cmap="viridis")
    
    bx = motion_outputs_1[clip_index][which_direction[2]][slider.value]
    ax3.imshow(bx, cmap="viridis")
    by = motion_outputs_1[clip_index][which_direction[3]][slider.value]
    ax4.imshow(by, cmap="viridis")
    
    fig.canvas.draw()
    fig.canvas.flush_events()

# Connecting the slider object to the update function above.
# This is event-handling.
slider.observe(update_lines, names='value')

# Creates an application interface with the various 
# pieces we already instantiated inside of it. 
AppLayout(
    center=fig.canvas,
    footer=slider,
    pane_heights=[0, 6, 1]
)

fig.show()

## Define the function to create our GIF Animations

How to save animation as gif: http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/

In [None]:
def create_32_frame_motion_colormap_gif(model_name, motion_output_obj, clip_index, which_direction, out_file_comment="", cmap="viridis"):
    # plot all 4, forward x,y and backward x, y
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15,5));
    fig.tight_layout()
    
    start_frame = 0
    
    # find the max and min value needed for colormapping for all 4 motion outputs across ALL 
    # 32 frames
    all_mins = []
    all_maxes = []
    for frame_ind in range(32):
        for direction_ind in range(4):
            all_mins.append(motion_output_obj[clip_index][which_direction[direction_ind]][frame_ind].min())
            all_maxes.append(motion_output_obj[clip_index][which_direction[direction_ind]][frame_ind].max())
    
    color_min = min(all_mins)
    color_max = max(all_maxes)
    
    ax1.set_title("forward x");
    fx = motion_output_obj[clip_index][which_direction[0]][start_frame]
    ax1_img = ax1.imshow(fx, cmap=cmap, vmin=color_min, vmax=color_max);
    fig.colorbar(ax1_img, ax=ax1) # show colorbar

    ax2.set_title("forward y")
    fy = motion_output_obj[clip_index][which_direction[1]][start_frame]
    ax2_img = ax2.imshow(fy, cmap=cmap, vmin=color_min, vmax=color_max);
    fig.colorbar(ax2_img, ax=ax2)

    ax3.set_title("backward x")
    bx = motion_output_obj[clip_index][which_direction[2]][start_frame]
    ax3_img = ax3.imshow(bx, cmap=cmap, vmin=color_min, vmax=color_max);
    fig.colorbar(ax3_img, ax=ax3)

    ax4.set_title("backward y")
    by = motion_output_obj[clip_index][which_direction[3]][start_frame]
    ax4_img = ax4.imshow(by, cmap=cmap, vmin=color_min, vmax=color_max);
    fig.colorbar(ax4_img, ax=ax4)
    
    # funct to update imshow with new frame of motion output
    def animate(frame):
        fx = motion_output_obj[clip_index][which_direction[0]][frame]
        ax1_img = ax1.imshow(fx, cmap=cmap, vmin=color_min, vmax=color_max);
        # fig.colorbar(ax1_img, ax=ax1)
        fy = motion_output_obj[clip_index][which_direction[1]][frame]
        ax2_img = ax2.imshow(fy, cmap=cmap, vmin=color_min, vmax=color_max);
        # fig.colorbar(ax2_img, ax=ax2)

        bx = motion_output_obj[clip_index][which_direction[2]][frame]
        ax3_img = ax3.imshow(bx, cmap=cmap, vmin=color_min, vmax=color_max);
        # fig.colorbar(ax3_img, ax=ax3)
        by = motion_output_obj[clip_index][which_direction[3]][frame]
        ax4_img = ax4.imshow(by, cmap=cmap, vmin=color_min, vmax=color_max);
        # fig.colorbar(ax4_img, ax=ax4)

        return [ax1, ax2, ax3, ax4]
    anim = animation.FuncAnimation(fig, animate, np.arange(0, 32), interval=500, blit=True); # interval is milliseconds between each redraw, adjusts animation speed
    anim.save(f'./warren-random/visualization-outputs/{model_name}_motion_colormap_{out_file_comment}.gif', writer='imagemagick', fps=10)

## Model 1: Video animation of forward and backward x,y

In [None]:
clip_index = -1   # last 32 frame clip
which_direction = [0,1,2,3] # in order of: forward x,y, backward x,y 

out_file_comment = "3-colorbar-seismic"
color_mapping = "viridis"

create_32_frame_motion_colormap_gif(model_name_1, motion_outputs_1, clip_index, which_direction, out_file_comment=out_file_comment, cmap=color_mapping);

In [None]:
Image(f'./warren-random/visualization-outputs/{model_name_1}_motion_colormap_{out_file_comment}.gif')

## Model 2: Video animation of forward x,y

In [None]:
clip_index = -1   # last 32 frame clip
which_direction = [0,1,2,3] # in order of: forward x,y, backward x,y 

out_file_comment = "1-colorbar-seismic"
color_mapping = "seismic"

create_32_frame_motion_colormap_gif(model_name_2, motion_outputs_2, clip_index, which_direction, out_file_comment=out_file_comment, cmap=color_mapping);

In [None]:
Image(f'./warren-random/visualization-outputs/{model_name_2}_motion_colormap_{out_file_comment}.gif')

## Try to print motion output as vector fields

### One field for Forward (x,y), and another one for Backward (x,y). 

https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.quiver.html

In [None]:
clip_index = -1   # last 32 frame clip
frame_index = 0   # first frame
which_direction = [0,1,2,3] # in order of: forward x,y, backward x,y 

In [None]:
fx.shape, fy.shape

In [None]:
len(fx), len(fx[0])

In [None]:
print(plt.style.available)

In [None]:
# try to plot a single instance of the vector field from this 
# significant help from: https://stackoverflow.com/a/40370633

# plt.style.use('seaborn-white')
plt.style.use('dark_background')


## TODO: figure out how to display 4 of these: 
''' [ m1fxy, m1bxy ]
    [ m2fxy, m1bxy ] 
'''

# fig, ax = plt.subplots(2, 2, figsize=(15, 5))

##### Model 1, Forward

fx = motion_outputs_1[clip_index][which_direction[0]][frame_index]
fy = motion_outputs_1[clip_index][which_direction[1]][frame_index]

nrows, ncols = fx.shape
x_tmp = np.linspace(0, 112, ncols)  
y_tmp = np.linspace(0, 112, nrows)
x_tails, y_tails = np.meshgrid(x_tmp, y_tmp, indexing='xy') ## Todo: what is a numpy meshgrid ?

plt.figure()
plt.title(f"Forward x,y for {model_name_1}")
plt.xlabel('$\Delta x$')
plt.ylabel('$\Delta y$')
ax = plt.gca(); ax.invert_yaxis()
# plt.quiver(x_tails, y_tails, fx, fy, alpha=0.5)
plt.quiver(x_tails, y_tails, fx, fy, edgecolor='b', facecolor='none', linewidth=0.7)
plt.tight_layout()
plt.show()

### TODO
##### Model 1, Backward

In [None]:
print(len(x_tails), len(y_tails))

## Now let's try to make an animation of this.

In [None]:
def create_32_frame_motion_vector_field_gifs(model_name, motion_output_obj, clip_index, which_direction, out_file_comment=""):
    # 2 subplots, forward and backward motion
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20,10));
    
    # starting frame ind
    start_frame = 0
    
    # get initial forward and backward motion
    fx = motion_output_obj[clip_index][which_direction[0]][start_frame]
    fy = motion_output_obj[clip_index][which_direction[1]][start_frame]
    
    bx = motion_output_obj[clip_index][which_direction[2]][start_frame]
    by = motion_output_obj[clip_index][which_direction[3]][start_frame]
    
    # create values for the tails of the vectors (these don't need to change)
    nrows, ncols = fx.shape
    x_tmp = np.linspace(0, 112, ncols)  
    y_tmp = np.linspace(0, 112, nrows)
    x_tails, y_tails = np.meshgrid(x_tmp, y_tmp, indexing='xy')
    
    # put titles, axes on subplots
    plt.suptitle(f"Motion Vector Field (x,y) for {model_name}")
    ax1.set_title('Forward Motion')
    ax1.set_xlabel('$\Delta x$')
    ax1.set_ylabel('$\Delta y$')
    ax1.invert_yaxis()
    
    ax2.set_title('Backward Motion')
    ax2.set_xlabel('$\Delta x$')
    ax2.set_ylabel('$\Delta y$')
    ax2.invert_yaxis()
    
    # put initial magnitudes to the vectors with fixed tails
    # forward
    ax1.quiver(x_tails, y_tails, fx, fy, edgecolor='b', facecolor='none', linewidth=0.7);
    # backward
    ax2.quiver(x_tails, y_tails, bx, by, edgecolor='b', facecolor='none', linewidth=0.7);

    
    # funct to update imshow with new frame of motion output
    def animate(frame):
        # update forward xy and backward xy
        fx = motion_output_obj[clip_index][which_direction[0]][frame]
        fy = motion_output_obj[clip_index][which_direction[1]][frame]

        bx = motion_output_obj[clip_index][which_direction[2]][frame]
        by = motion_output_obj[clip_index][which_direction[3]][frame]
        
        # clear forward and backward axes first (we'll need to reapply the labels and inverting the y axis)
        ax1.clear()
        ax2.clear()
        ax1.set_title('Forward Motion')
        ax1.set_xlabel('$\Delta x$')
        ax1.set_ylabel('$\Delta y$')
        ax1.invert_yaxis()

        ax2.set_title('Backward Motion')
        ax2.set_xlabel('$\Delta x$')
        ax2.set_ylabel('$\Delta y$')
        ax2.invert_yaxis()
        
        # update the magnitudes of the vectors with fixed tails
        # forward
        ax1.quiver(x_tails, y_tails, fx, fy, edgecolor='b', facecolor='none', linewidth=0.7);
        # backward
        ax2.quiver(x_tails, y_tails, bx, by, edgecolor='b', facecolor='none', linewidth=0.7);

        
    anim = animation.FuncAnimation(fig, animate, np.arange(0, 32), interval=500, blit=True); # interval is milliseconds between each redraw, adjusts animation speed
    anim.save(f'./warren-random/visualization-outputs/{model_name}_motion_vector_field_{out_file_comment}.gif', writer='imagemagick', fps=10);

In [None]:
clip_index = -1   # last 32 frame clip
which_direction = [0,1,2,3] # in order of: forward x,y, backward x,y 

out_file_comment = "attempt3"

create_32_frame_motion_vector_field_gifs(model_name = model_name_1, 
                                         motion_output_obj = motion_outputs_1, 
                                         clip_index = clip_index, 
                                         which_direction = which_direction, 
                                         out_file_comment=out_file_comment);

In [None]:
Image(f'./outputs/{model_name_1}_motion_vector_field_{out_file_comment}.gif')