### Imports

In [10]:
import os
from pylab import *

import numpy as np
import nibabel as nib
from tqdm import tqdm_notebook as tqdm
import matplotlib.animation as animation

### Load sample image

In [2]:
### Config
root         = "/Volumes/group/PANLab_Datasets"
project      = "ENGAGE"
subject      = "LA13012"
time_session = "000_data_archive"
task         = "100_fMRI/101_fMRI_preproc_GO_NO_GO/s02_globalremoved_func_data.nii"

### Load image
filepath = os.path.join(root, project, "data", subject, time_session, task)
# Reorder the voxel axes to the closest canonical (RAS+) orientation so the
# x/y/z axes used later by get_data_slice mean the same thing across images.
image    = nib.as_closest_canonical(nib.load(filepath))
# get_data() is deprecated and was removed in nibabel 5; get_fdata() always
# returns a floating-point array (float64 by default), which the z-scoring
# in norm_data needs anyway.
data     = image.get_fdata()

### Helper functions

`fps`: frames per second.

<b>TODO:</b>
- `include_flash`: flash when the image updates. This should increase the fps and flash once every module time unit.
- normalize and threshold activations

In [13]:
# Resolution (dots per inch) of the saved video; read as a global by
# animate_frames when it calls ani.save(...).
dpi = 100

def get_data_slice(data, axis):
    """Return the middle slice of a 4-D array along the x, y or z axis.

    Parameters
    ----------
    data : array indexed as data[x, y, z, t]; exactly four indices are applied.
    axis : "x", "y" or "z" (case-insensitive) — the spatial axis to cut through.

    Returns
    -------
    The 3-D sub-array at index ``shape[axis] // 2`` along the chosen axis,
    keeping the remaining two spatial axes and the last axis.

    Raises
    ------
    AssertionError if ``axis`` is not one of x/y/z.
    """
    axis = axis.lower()
    assert axis in {"x", "y", "z"}, "invalid axis %s" % axis

    dim = {"x": 0, "y": 1, "z": 2}[axis]
    # Full slices everywhere except the chosen axis, which gets its midpoint.
    index = [slice(None)] * 4
    index[dim] = data.shape[dim] // 2
    return data[tuple(index)]

def cut_dummies(data, num_dummies):
    """Drop the first ``num_dummies`` entries along axis 3 of a 4-D array.

    Axis 3 is the volume (frame) axis in this notebook. Basic slicing is used,
    so the result is a view: writing to it also writes to ``data``.
    """
    kept_volumes = slice(num_dummies, None)
    return data[:, :, :, kept_volumes]

def norm_data(data):
    """Z-score every voxel's series along axis 3 (the volume axis).

    Parameters
    ----------
    data : array indexed as data[x, y, z, t]; mean and standard deviation are
        computed over axis 3 for each (x, y, z) position.

    Returns
    -------
    A NEW floating-point array of the same shape holding
    ``(data - mean) / std`` per voxel. Voxels whose standard deviation is 0
    are divided by 1 instead, so their series become all zeros. The input
    array is not modified (the previous version wrote into it, which also
    truncated the z-scores when ``data`` had an integer dtype).
    """
    mean = np.mean(data, axis=3)
    std = np.std(data, axis=3)
    # Constant voxels would otherwise divide by zero.
    std[std == 0] = 1
    # Broadcast the per-voxel statistics across the volume axis.
    return (data - mean[..., np.newaxis]) / std[..., np.newaxis]

def thresh_data(data, thresh=1):
    """Set to 0 every value strictly inside (-thresh, thresh), in place.

    Intended for the z-scored output of norm_data, where ``thresh=1`` keeps
    only values at least one standard deviation from the voxel mean.

    Parameters
    ----------
    data : numpy array; modified in place.
    thresh : non-negative magnitude below which values are zeroed
        (default 1, the previously hard-coded value). Values equal to
        ``+thresh`` or ``-thresh`` are kept.

    Returns
    -------
    The same ``data`` object after modification.
    """
    data[(data < thresh) & (data > -thresh)] = 0
    return data

def animate_nii(data, save_path, axis="x", num_dummies=3, TR=2, speed=1, norm=True, cmap="summer"):
    """Render the middle slice of a 4-D image over time as a video.

    Parameters
    ----------
    data : array indexed as data[x, y, z, t].
    save_path : output video path (written by animate_frames).
    axis : "x", "y" or "z" — which middle slice to show (see get_data_slice).
    num_dummies : number of leading volumes to discard (see cut_dummies).
    TR : repetition time; the base playback rate is ``1 / TR`` frames/second.
    speed : multiplier applied to the playback rate by animate_frames.
    norm : if True, z-score each voxel (norm_data) and zero out small values
        (thresh_data) before slicing.
    cmap : colormap name or object passed through to imshow.
    """
    data = cut_dummies(data, num_dummies)
    if norm:
        data = norm_data(data)
        data = thresh_data(data)
    data_slice = get_data_slice(data, axis)
    # `speed` used to be accepted here but never forwarded, so it had no effect.
    animate_frames(data_slice, save_path, cmap, fps=1 / TR, speed=speed)

def animate_frames(frames, save_path, cmap, fps = 30, speed = 1):
    """Save a stack of 2-D frames as a video using the ffmpeg writer.

    Parameters
    ----------
    frames : 3-D array; frame n is ``frames[:, :, n]``.
    save_path : output video path passed to ``FuncAnimation.save``.
    cmap : colormap name or Colormap passed to ``imshow``.
    fps : base frames per second.
    speed : multiplier applied to ``fps``.

    The module-level ``dpi`` sets the saved video's resolution.
    """
    fps = fps * speed
    n_frames = frames.shape[2]

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.set_aspect('equal')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # imshow autoscales its colour limits from the first image only and
        # set_data() never rescales, so fix the limits from ALL frames to
        # avoid clipping later frames. nan* ignores NaN voxels.
        im = ax.imshow(frames[:, :, 0], cmap=cmap,
                       vmin=np.nanmin(frames), vmax=np.nanmax(frames))
        fig.tight_layout()

        def update_img(n):
            im.set_data(frames[:, :, n])
            return (im,)

        # `interval` is the delay between frames in milliseconds, not a rate.
        ani = animation.FuncAnimation(fig, update_img, n_frames, interval=1000 / fps)
        writer = animation.writers['ffmpeg'](fps=fps)
        ani.save(save_path, writer=writer, dpi=dpi)
    finally:
        # Close this specific figure even if saving fails.
        plt.close(fig)

### Main

Steps:

- Raw image (not interesting)
- Speed up (still not interesting)
- Normalize (getting somewhere)
- Threshold (to remove noise, very exciting!)

In [19]:
# Work on a copy so the loaded `data` is left untouched; entries whose
# magnitude is below 0.01 are set to exactly 0.
data_copy = data.copy()
data_copy[np.abs(data_copy) < 0.01] = 0

# Candidate colormap names for comparing renderings, roughly grouped by
# matplotlib's colormap categories.
colors = [
    # sequential (2)
    'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',
    'spring', 'summer', 'autumn', 'winter', 'cool', 'Wistia',
    'hot', 'afmhot', 'gist_heat', 'copper',
    # diverging
    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu',
    'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic',
    # cyclic
    'twilight', 'twilight_shifted', 'hsv',
    # qualitative
    'Pastel1', 'Pastel2', 'Paired', 'Accent',
    'Dark2', 'Set1', 'Set2', 'Set3',
    'tab10', 'tab20', 'tab20b', 'tab20c',
]

def get_cmap():
    """Return a private copy of the 'seismic' colormap that draws bad values black.

    "Bad" values are masked or invalid entries (e.g. NaN voxels) in the image.
    The previous version called ``set_bad`` on the globally registered
    ``plt.cm.seismic`` itself, silently changing it for every other plot in
    the session; a deep copy keeps the global colormap unchanged.
    """
    import copy

    cmap = copy.deepcopy(plt.cm.seismic)
    cmap.set_bad(color="k")
    return cmap

# Render the normalized + thresholded middle z slice of `data_copy`.
# animate_nii derives the playback frame rate from 1 / TR.
# To compare colormaps, iterate over `colors` and pass each name as `cmap`.
TR = 0.1
cmap = get_cmap()
# Resolve the Desktop of whoever runs the notebook instead of a hard-coded
# absolute home directory.
save_path = os.path.expanduser("~/Desktop/thresh.mp4")
animate_nii(data_copy, save_path, axis = "z", TR=TR, speed=20, cmap = cmap)

  ret = um.sqrt(ret, out=ret)
