### Imports

In [1]:
import os
# NOTE: star import kept for backward compatibility; it stays before the explicit
# imports below so those names take precedence over anything pylab exports.
from pylab import *

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from tqdm import tqdm_notebook as tqdm

### Load sample image

In [66]:
### Config
# NOTE(review): `root` is already a full path to one specific file
# (conn030 / workingmemMB), so `project`, `subject`, `time_session` and `task`
# below are not used to locate the file and do not describe what is loaded.
root         = "/Volumes/hd_4tb/raw/connhc/000/conn030/workingmemMB/normalized.nii.gz"
project      = "connhc"
subject      = "conn001"
time_session = "000"
task         = "conscious"

### Load image
# Alternative (unused) layout when `root` is a directory:
# filepath = os.path.join(root, project, time_session, subject, task, "normalized.nii.gz")
filepath = root
# Reorient to the closest canonical orientation so array axes 0/1/2 are consistent.
image    = nib.as_closest_canonical(nib.load(filepath))
# `get_data()` was deprecated in nibabel 3.0 and removed in 5.0.
# `np.asanyarray(image.dataobj)` is the documented replacement that keeps the
# stored dtype (use `image.get_fdata()` if a float array is wanted instead).
data     = np.asanyarray(image.dataobj)

### Helper functions

`fps`: frames per second of the output video.

<b>TODO:</b>
- `include_flash`: flash when the image updates. This should increase the fps and flash at a fixed time interval (e.g. every N-th frame, using the modulo operator).
- normalize and threshold activations

In [74]:
# Resolution (dots per inch) passed to `ani.save` when animate_frames writes the video.
dpi = 100

def get_data_slice(data, axis):
    """Return the middle slice of a 4-D array along a spatial axis.

    Args:
        data: 4-D array; ``"x"``, ``"y"`` and ``"z"`` select axes 0, 1 and 2.
        axis: ``"x"``, ``"y"`` or ``"z"`` (case-insensitive).

    Returns:
        A 3-D view of ``data`` with the chosen axis fixed at ``shape // 2``.

    Raises:
        AssertionError: if ``axis`` is not one of x/y/z.
    """
    axis_name = axis.lower()
    assert axis_name in {"x", "y", "z"}, "invalid axis %s" % axis_name
    axis_index = {"x": 0, "y": 1, "z": 2}.get(axis_name)
    if axis_index is None:
        # Only reachable with assertions disabled (python -O); returns None like a fall-through.
        return None
    middle = data.shape[axis_index] // 2
    # Basic (slice/int) indexing keeps the result a view, as with data[L, :, :, :].
    index = [slice(None)] * 4
    index[axis_index] = middle
    return data[tuple(index)]

def norm_data(data):
    """Z-score every voxel's time series along axis 3.

    For each voxel, the mean over axis 3 is subtracted and the result is divided
    by the standard deviation over axis 3. Voxels whose standard deviation is 0
    (constant over time, e.g. an all-zero background) are divided by 1 instead,
    so they become 0 rather than NaN.

    Args:
        data: array whose axis 3 indexes time points.

    Returns:
        A new floating-point array with the same shape as ``data``. The input is
        not modified. Non-float input is promoted to float64; previously the
        z-scores were written back into the input array, which truncated them
        for integer data.
    """
    volume = np.asarray(data)
    if not np.issubdtype(volume.dtype, np.floating):
        volume = volume.astype(np.float64)
    mean = volume.mean(axis=3, keepdims=True)
    std = volume.std(axis=3, keepdims=True)
    # Computed once (it does not depend on the time point); avoids division by zero.
    std[std == 0] = 1
    return (volume - mean) / std

def thresh_data(data, threshold=1):
    """Zero out values strictly inside ``(-threshold, threshold)``.

    Values at or beyond the threshold (and NaNs) are kept unchanged. The default
    of 1 reproduces the original fixed cutoff.

    Args:
        data: numeric array.
        threshold: magnitude below which values are set to 0.

    Returns:
        A new array of the same dtype; the input is left untouched (it used to be
        modified in place).
    """
    thresholded = np.array(data, copy=True, subok=True)
    thresholded[(thresholded < threshold) & (thresholded > -threshold)] = 0
    return thresholded

def animate_nii(data, save_path, axis="x", num_dummies=3, TR=2, speed=1, norm=True):
    """Animate the middle slice of a 4-D image along ``axis`` and save it as a video.

    Args:
        data: 4-D array; ``get_data_slice`` takes its middle slice along ``axis``
            and the remaining last axis is animated frame by frame.
        save_path: output video path, forwarded to ``animate_frames``.
        axis: ``"x"``, ``"y"`` or ``"z"`` (case-insensitive).
        num_dummies: NOTE(review): accepted but not used yet.
        TR: interval between time points; playback runs at ``1 / TR`` frames per
            second before ``speed`` is applied.
        speed: playback speed multiplier, forwarded to ``animate_frames``.
        norm: NOTE(review): accepted but not used yet (normalization is still a TODO).
    """
    data_slice = get_data_slice(data, axis)
    # `speed` was previously dropped here, so e.g. speed=20 had no effect on the video.
    animate_frames(data_slice, save_path, fps=1 / TR, speed=speed)

def _create_plt_figure():
    """Create a figure with one axes set up to display image frames.

    The axes uses an equal aspect ratio, and both the x and y axes are hidden so
    only the image content is shown.

    Returns:
        (fig, ax): the new Figure and its single Axes.
    """
    figure, axes = plt.subplots()
    axes.set_aspect('equal')
    for single_axis in (axes.xaxis, axes.yaxis):
        single_axis.set_visible(False)
    return figure, axes
    
def _load_frame(frames, n):
    frame = frames[:, :, n]
    return np.ma.masked_where(frame == 0, frame)

def _flash(n):
    if n == 10:
        return True
    return False

def animate_frames(frames, save_path, fps = 30, speed = 1):
    """Render a stack of 2-D frames to a video file using the ffmpeg writer.

    Args:
        frames: array where ``frames[:, :, n]`` is frame ``n``; the number of
            frames is ``frames.shape[2]``. Zero-valued pixels are masked and drawn
            in the colormap's "bad" colour (black, or yellow on frames for which
            ``_flash`` returns True).
        save_path: output path passed to ``Animation.save``.
        fps: base frames per second.
        speed: multiplier applied to ``fps``.

    Raises:
        ValueError: if the effective frame rate ``fps * speed`` is not positive.

    The saved video uses the module-level ``dpi``.
    """
    import copy  # only needed to take a private copy of the colormap

    fps = fps * speed
    if fps <= 0:
        raise ValueError("fps * speed must be positive, got %r" % (fps,))
    N = frames.shape[2]
    fig, ax = _create_plt_figure()
    try:
        # Work on a private copy: calling set_bad on plt.cm.RdBu itself would
        # change the shared colormap for every later plot in the session.
        cmap = copy.copy(plt.cm.RdBu)
        cmap.set_bad("black")

        frame = _load_frame(frames, 0)
        im = ax.imshow(frame, cmap=cmap)

        fig.set_size_inches([5, 5])
        fig.tight_layout()

        def update_img(n):
            im.set_data(_load_frame(frames, n))
            cmap.set_bad("yellow" if _flash(n) else "black")
            # Re-set the (mutated) colormap so the image is marked as changed.
            im.set_cmap(cmap)
            return (im,)

        # FuncAnimation's `interval` is the delay between frames in milliseconds,
        # not a frame rate.
        ani = animation.FuncAnimation(fig, update_img, N, interval=1000 / fps)
        writer = animation.writers["ffmpeg"](fps=fps)
        ani.save(save_path, writer=writer, dpi=dpi)
    finally:
        # Close this figure even if saving fails, so figures do not accumulate.
        plt.close(fig)

### Main

Steps:

- Raw image (not interesting)
- Speed up (still not interesting)
- Normalize (getting somewhere)
- Threshold (to remove noise, very exciting!)

In [75]:
# Playback rate is derived from TR inside animate_nii (fps = 1 / TR).
TR = 0.1
# Write to the current user's Desktop rather than a hard-coded /Users/<name> path.
save_path = os.path.join(os.path.expanduser("~"), "Desktop", "thresh.mp4")
# animate_nii writes the video to save_path; it has no return value worth keeping.
animate_nii(data, save_path, axis="z", TR=TR, speed=20)

### cmap colors

In [71]:
# Matplotlib colormap names, grouped as in matplotlib's colormap reference.
colors = [
    # Sequential (2)
    'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',
    'spring', 'summer', 'autumn', 'winter', 'cool', 'Wistia',
    'hot', 'afmhot', 'gist_heat', 'copper',
    # Diverging
    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu',
    'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic',
    # Cyclic
    'twilight', 'twilight_shifted', 'hsv',
    # Qualitative
    'Pastel1', 'Pastel2', 'Paired', 'Accent',
    'Dark2', 'Set1', 'Set2', 'Set3',
    'tab10', 'tab20', 'tab20b', 'tab20c',
]