# Setting up

## load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
import numpy as np
import xarray as xr
from dask.distributed import Client, LocalCluster
from tqdm import tqdm
import matplotlib.pyplot as plt 
from scipy.signal import periodogram, find_peaks
import re
import warnings
from os import listdir
from pathlib import Path
import cv2
import dask as da
import math
import dask.array as darr
import xarray as xr
import zarr as zr
from natsort import natsorted
from tifffile import TiffFile, imread
import matplotlib

## set path and parameters

In [None]:
# Basic session parameters.
# `dpath` is the folder containing the recording; replace the placeholder.
dpath = "FILE_PATH_HERE"
dpath = os.path.abspath(dpath)
# Number of frames stored in each video file; used later to map a global
# frame index back to the file that contains it.
framesPerFile = 1000

# Keyword arguments forwarded to load_videos().
param_load_videos = {
    # Raw string so that `\.` reaches the regex engine as an escaped dot
    # instead of being an invalid string escape.
    "pattern": r"[0-9]+\.avi$",
    "dtype": np.uint8,
    # A step of 1 along every dimension keeps every sample.
    "downsample": dict(frame=1, height=1, width=1),
    "downsample_strategy": "subset",
}

# Restrict the numerical backends to one thread each.
# NOTE(review): libraries already imported in this process may have read
# these variables at import time -- confirm they take effect where intended.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMBA_NUM_THREADS"] = "1"

## functions for loading in all videos

In [None]:
def load_videos(
    vpath,
    pattern=r"msCam[0-9]+\.avi$",
    dtype=np.float64,
    in_memory=False,
    downsample=None,
    downsample_strategy="subset",
    post_process=None,
):
    r"""
    Load videos from the folder specified in `vpath` and according to the regex
    `pattern`, then concatenate them together across time and return a
    `xarray.DataArray` representation of the concatenated videos. The default
    assumption is video filenames start with ``msCam`` followed by at least a
    number, and then followed by ``.avi``. In addition, it is assumed that the
    name of the folder correspond to a recording session identifier.

    Parameters
    ----------
    vpath : str
        The path to search for videos.
    pattern : str, optional
        Regular expression matched with ``re.search`` against every filename
        in `vpath`. (Default value = r'msCam[0-9]+\.avi$')
    dtype : numpy dtype or None, optional
        If truthy, the concatenated array is cast to this dtype.
    in_memory : bool, optional
        Not used in this function body; kept so existing callers that pass it
        keep working.
    downsample : dict, optional
        Mapping from dimension name (``frame``, ``height``, ``width``) to an
        integer step. If falsy, no downsampling is done.
    downsample_strategy : str, optional
        ``"subset"`` keeps every n-th sample along each dimension; ``"mean"``
        averages non-overlapping windows, dropping an incomplete trailing
        window. Any other value emits a ``RuntimeWarning`` and leaves the
        array as it is.
    post_process : callable, optional
        If given, called as ``post_process(varr, vpath, ssname, vlist,
        varr_list)`` and its return value is returned instead.

    Returns
    -------
    xarray.DataArray
        The labeled 3-d array representation of the videos, named
        ``fluorescence``, with dimensions ``frame``, ``height`` and ``width``
        (or whatever `post_process` returns).

    Raises
    ------
    FileNotFoundError
        If no filename in `vpath` matches `pattern`.
    ValueError
        If the first matched file is neither ``.avi``/``.mkv`` nor
        ``.tif``/``.tiff``.
    """
    vpath = os.path.normpath(vpath)
    ssname = os.path.basename(vpath)
    vlist = natsorted(
        [os.path.join(vpath, v) for v in os.listdir(vpath) if re.search(pattern, v)]
    )
    if not vlist:
        raise FileNotFoundError(
            "No data with pattern {}"
            " found in the specified folder {}".format(pattern, vpath)
        )
    print("loading {} videos in folder {}".format(len(vlist), vpath))

    # The loader is picked from the first file only; compare case-insensitively
    # so that e.g. ".AVI" is handled like ".avi".
    file_extension = os.path.splitext(vlist[0])[1].lower()
    if file_extension in (".avi", ".mkv"):
        movie_load_func = load_avi_lazy
    elif file_extension in (".tif", ".tiff"):
        movie_load_func = load_tif_lazy
    else:
        raise ValueError("Extension not supported.")

    varr_list = [movie_load_func(v) for v in vlist]
    varr = darr.concatenate(varr_list, axis=0)
    varr = xr.DataArray(
        varr,
        dims=["frame", "height", "width"],
        coords=dict(
            frame=np.arange(varr.shape[0]),
            height=np.arange(varr.shape[1]),
            width=np.arange(varr.shape[2]),
        ),
    )
    if dtype:
        varr = varr.astype(dtype)
    if downsample:
        if downsample_strategy == "mean":
            # With boundary="trim" an incomplete trailing window is dropped,
            # so the new coordinates must cover complete windows only;
            # otherwise their length would not match the coarsened array.
            bin_starts = {
                d: np.arange(0, (varr.sizes[d] // w) * w, w)
                for d, w in downsample.items()
            }
            varr = (
                varr.coarsen(**downsample, boundary="trim")
                .mean()
                .assign_coords(**bin_starts)
            )
        elif downsample_strategy == "subset":
            bin_eg = {
                d: np.arange(0, varr.sizes[d], w) for d, w in downsample.items()
            }
            varr = varr.sel(**bin_eg)
        else:
            warnings.warn("unrecognized downsampling strategy", RuntimeWarning)
    varr = varr.rename("fluorescence")
    if post_process:
        varr = post_process(varr, vpath, ssname, vlist, varr_list)
    return varr

def load_tif_lazy(fname):
    """
    Build a lazy dask array over the pages of a TIFF file.

    Each page becomes one delayed call to `load_tif_perframe`; the pages are
    stacked along a new leading axis. The first page is read eagerly once to
    learn the dtype and shape used for every page.

    Parameters
    ----------
    fname : str
        Path of the TIFF file.

    Returns
    -------
    dask.array.Array
        Stack of pages with the page index as the first axis.

    Raises
    ------
    ValueError
        If the file contains no pages.
    """
    # Only the page count is needed here; close the file handle right away
    # instead of leaving it open for the lifetime of the process.
    with TiffFile(fname) as tif:
        n_pages = len(tif.pages)
    if n_pages == 0:
        raise ValueError("No pages found in {}".format(fname))

    fmread = da.delayed(load_tif_perframe)
    flist = [fmread(fname, i) for i in range(n_pages)]

    sample = flist[0].compute()
    arr = [
        da.array.from_delayed(fm, dtype=sample.dtype, shape=sample.shape)
        for fm in flist
    ]
    return da.array.stack(arr, axis=0)


def load_tif_perframe(fname, fid):
    """Read the single page selected by `fid` from the TIFF file `fname`."""
    page = imread(fname, key=fid)
    return page


def load_avi_lazy(fname):
    """
    Build a lazy dask array over the frames of a video file.

    Each frame becomes one delayed call to `load_avi_perframe`; the frames are
    stacked along a new leading axis. The first frame is decoded eagerly once
    to learn the dtype and shape used for every frame.

    Parameters
    ----------
    fname : str
        Path of the video file.

    Returns
    -------
    dask.array.Array
        Stack of frames with the frame index as the first axis.

    Raises
    ------
    ValueError
        If OpenCV reports zero frames for the file.
    """
    cap = cv2.VideoCapture(fname)
    try:
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # The capture is only needed for the frame count; release the handle.
        cap.release()
    if n_frames <= 0:
        raise ValueError("No frames found in {}".format(fname))

    fmread = da.delayed(load_avi_perframe)
    flist = [fmread(fname, i) for i in range(n_frames)]
    sample = flist[0].compute()
    arr = [
        da.array.from_delayed(fm, dtype=sample.dtype, shape=sample.shape)
        for fm in flist
    ]
    return da.array.stack(arr, axis=0)


def load_avi_perframe(fname, fid):
    """
    Decode one frame of a video file as a 2-d grayscale array.

    Parameters
    ----------
    fname : str
        Path of the video file.
    fid : int
        Zero-based index of the frame to read.

    Returns
    -------
    numpy.ndarray
        The frame converted to grayscale and flipped along its first axis.
        If the frame cannot be read, a message is printed and an all-zero
        array of the video's height and width is returned instead.
    """
    cap = cv2.VideoCapture(fname)
    try:
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        cap.set(cv2.CAP_PROP_POS_FRAMES, fid)
        ret, fm = cap.read()
    finally:
        # This function runs once per frame, so an unreleased capture would
        # leak one handle per frame.
        cap.release()
    if ret:
        return np.flip(cv2.cvtColor(fm, cv2.COLOR_RGB2GRAY), axis=0)
    print("frame read failed for frame {}".format(fid))
    # Keep the placeholder's dtype in line with decoded frames so it matches
    # the dtype the lazy array was declared with.
    # NOTE(review): assumes decoded frames are 8-bit -- confirm for your data.
    return np.zeros((h, w), dtype=np.uint8)


## start cluster

In [None]:
# Start a local Dask cluster with a single worker limited to 8GB of memory,
# then connect a client to it.
cluster = LocalCluster(n_workers=1, memory_limit="8GB")
client = Client(cluster)

# Misalignment detection

## loading videos and visualization

In [None]:
%%time
# Build the concatenated video array from every file in `dpath` whose name
# matches the configured pattern.
varr = load_videos(dpath, **param_load_videos)

In [None]:
# Rechunk into blocks of 20 frames; -1 keeps height and width in one chunk,
# so every chunk holds complete frames.
varr_ref = varr.chunk({"frame": 20, "height": -1, "width": -1})

## get rid of V4 stripe noise (present on earlier releases of the V4 scope)

In [None]:
# f = 0
# im = varr_ref[f,:,:].copy().values
# im_fft = fftpack.fft2(im.astype(np.float64))
# im_fft2 = im_fft.copy()
# n = 5
# y = 2
# im_fft2[n:im_fft2.shape[0]-n, :y] = 0
# im_fft2[n:im_fft2.shape[0]-n, -y:] = 0
# im_new = fftpack.ifft2(im_fft2).real
# fig, ax = plt.subplots(ncols=4, figsize=(50,15))
# ax[0].imshow(np.abs(im_fft), norm=LogNorm(vmin=5))
# ax[0].set_title('Fourier transform', fontsize=30)
# ax[1].imshow(np.abs(im_fft2), norm=LogNorm(vmin=5))
# ax[1].set_title('Filtered spectrum', fontsize=30)
# ax[2].imshow(im, cmap='binary_r', vmin=0, vmax=255, origin='lower')
# ax[2].set_title('Old image', fontsize=30)
# ax[3].imshow(im_new, cmap='binary_r', vmin=0, vmax=255, origin='lower')
# ax[3].set_title('New image', fontsize=30)
# plt.show()
# im_opts = dict(frame_width=500, aspect=608/608, cmap='Spectral_r', colorbar=True)
# hv.Image(np.log(np.abs(im_fft)), ['width', 'height'], label='before_mc').opts(**im_opts)

In [None]:
# def sensor_denoise(varr, n, y):
#     return xr.apply_ufunc(
#         sensor_denoise_perframe,
#         varr,
#         input_core_dims=[['height', 'width']],
#         output_core_dims=[['height', 'width']],
#         vectorize=True,
#         dask='parallelized',
#         output_dtypes=[np.float64],
#         kwargs=dict(n=n, y=y))

# def sensor_denoise_perframe(f, n, y):
#     im_fft = fftpack.fft2(f)
#     im_fft[n:im_fft.shape[0]-n, :y] = 0
#     im_fft[n:im_fft.shape[0]-n, -y:] = 0
#     im_new = fftpack.ifft2(im_fft).real
#     return im_new

# varr_ref = sensor_denoise(varr_ref, n=5, y=2)
# min_val = varr_ref.min().compute()
# max_val = varr_ref.max().compute()
# varr_ref = ((varr_ref - min_val) / (max_val - min_val) * 255)
# varr_ref = varr_ref.astype(np.uint8)

## find frames with misaligned rows

In [None]:
# First, find the frames where the stripes exist.

from scipy.signal import periodogram, find_peaks
def no_stripes_frames(varr, thresh=20):
    """
    Apply `no_stripes_in_frame` to every frame of `varr`.

    Parameters
    ----------
    varr : xarray.DataArray
        Video array with ``height`` and ``width`` dimensions; the function is
        applied over every other dimension.
    thresh : float, optional
        Peak-height threshold forwarded to `no_stripes_in_frame`. Lower it to
        flag more frames. (Default value = 20)

    Returns
    -------
    xarray.DataArray
        Boolean array without the ``height`` and ``width`` dimensions; True
        where no stripes were detected.
    """
    return xr.apply_ufunc(
        no_stripes_in_frame,
        varr,
        input_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[bool],
        # Forwarded so the threshold can be tuned without editing the detector.
        kwargs=dict(thresh=thresh))

def no_stripes_in_frame(x, thresh=20):
    """
    Report whether a single frame looks free of stripes.

    Only the first column of the frame is analysed. Its periodogram is
    computed with a sampling frequency of 1, and peaks of at least `thresh`
    are searched for in the lowest fifth of the amplitude spectrum. The frame
    counts as striped when any such peak sits above a frequency of 0.03.

    Parameters
    ----------
    x : numpy.ndarray
        2-d frame.
    thresh : float, optional
        Minimum peak height in the amplitude spectrum. (Default value = 20)

    Returns
    -------
    bool
        True if no qualifying peak was found, False otherwise.
    """
    first_column = x[:, 0]
    freqs, power = periodogram(first_column, 1)
    amplitude = np.sqrt(power)
    n_low_bins = int(len(power) / 5)
    peak_idx, _ = find_peaks(amplitude[:n_low_bins], height=thresh)
    stripe_peaks = freqs[peak_idx] > 0.03
    if np.any(stripe_peaks):
        return False
    return True

# Boolean mask over frames: True where no stripes were detected.
stripe_free = no_stripes_frames(varr_ref)
frames_without_stripes = stripe_free.values

# Frame coordinates of everything that was flagged as striped.
flagged = varr_ref[~frames_without_stripes]
bad_frames = np.asarray(flagged.frame)

print('bad frames:')
bad_frames

In [None]:
# Show the minimum projection computed from the stripe-free frames only.
# If stripes are still visible, lower `thresh` in no_stripes_in_frame() and
# rerun the detection.

clean_frames = varr_ref[frames_without_stripes]
varr_min = clean_frames.min("frame").compute()
plt.imshow(varr_min)

In [None]:
# Inspect frames by eye and optionally, save the frames to a folder.

def plot_bad_frame(bad_frame, save=False, show_plot=False, dpath=dpath, thresh=20):
    """
    Plot one frame next to the amplitude spectrum of its first column.

    Parameters
    ----------
    bad_frame : int
        Value of the ``frame`` coordinate in the global `varr_ref` to plot.
    save : bool, optional
        If True, write the figure to ``<dpath>/bad_frames/<bad_frame>.png``;
        the folder is created if it does not exist.
    show_plot : bool, optional
        If True, show the figure; otherwise close it after optional saving.
    dpath : str, optional
        Folder under which ``bad_frames`` is placed. The default is bound to
        the global `dpath` at definition time.
    thresh : float, optional
        Minimum peak height in the amplitude spectrum. (Default value = 20)

    Returns
    -------
    f : numpy.ndarray
        Frequencies returned by the periodogram.
    peaks : numpy.ndarray
        Indices of the detected peaks in the lowest fifth of the spectrum.
    """
    # Materialise the frame once; it is used for both the image and the
    # spectrum.
    x = varr_ref.sel(frame=bad_frame).values
    f, Pxx_spec = periodogram(x[:, 0], 1)
    amplitude = np.sqrt(Pxx_spec)
    peaks = find_peaks(amplitude[:int(len(Pxx_spec)/5)], height=thresh)[0]
    fig, ax = plt.subplots(ncols=2, figsize=(40,10))
    ax[0].imshow(x, cmap='binary_r', aspect='equal', origin='lower')
    ax[1].semilogy(f, amplitude, c='k')
    ax[1].scatter(f[peaks], amplitude[peaks], s=100, c='brown')
    # Same criterion as no_stripes_in_frame(), so the label agrees with the
    # detection result.
    if np.any(f[peaks] > 0.03):
        ax[1].text(x = 0.01, y = 0.1, s='Frame has stripes', color='brown', fontsize=20)
    ax[1].set_ylim([1e-2, 1e3])
    ax[1].margins(x=0.01)
    ax[1].set_xlabel('Frequency [Hz]', fontsize=30)
    ax[1].set_ylabel('Linear spectrum [V RMS]', fontsize=30)

    # Save before the figure is possibly closed below.
    if save:
        out_dir = os.path.join(dpath, 'bad_frames')
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, str(bad_frame) + '.png'))

    if show_plot:
        fig.show()
    else:
        plt.close(fig)

    return f, peaks

In [None]:
# Save a diagnostic figure for every flagged frame under <dpath>/bad_frames.
bad_frames_folder = os.path.join(dpath, 'bad_frames')
# exist_ok avoids the separate existence check and the race that comes with it.
os.makedirs(bad_frames_folder, exist_ok=True)
for bad_frame in tqdm(bad_frames):
    plot_bad_frame(bad_frame, show_plot=False, save=True)

## get file names that need to be replaced

In [None]:
def get_filename(frame_list, frames_per_file=None, folder=None):
    """
    Map global frame indices to video files and per-file frame indices.

    File names are built as ``<video index>.avi``, where the video index is
    the global frame index divided (floor) by the number of frames per file.

    Parameters
    ----------
    frame_list : array-like of int
        Global frame indices.
    frames_per_file : int, optional
        Number of frames per video file. Defaults to the global
        `framesPerFile`.
    folder : str, optional
        Folder the file names are joined to. Defaults to the global `dpath`.

    Returns
    -------
    fnames : list of str
        One path per video that contains at least one listed frame, in
        ascending order of video index.
    relative_frame_numbers : list of numpy.ndarray
        For each entry of `fnames`, the listed frames expressed as indices
        within that video.
    """
    if frames_per_file is None:
        frames_per_file = framesPerFile
    if folder is None:
        folder = dpath

    # One divmod by the file length yields both the video index and the
    # position inside that video. Dividing by n * frames_per_file instead
    # breaks for the first video (n == 0), whose frames would be dropped.
    frames = np.asarray(frame_list)
    video_index, frame_in_video = np.divmod(frames, frames_per_file)
    video_index = video_index.astype(int)

    vid_numbers = np.unique(video_index)
    fnames = [os.path.join(folder, str(n) + '.avi') for n in vid_numbers]
    relative_frame_numbers = [
        frame_in_video[video_index == n] for n in vid_numbers
    ]

    return fnames, relative_frame_numbers

In [None]:
# Map the flagged frames to the video files that contain them and to their
# frame indices within those files, then display the file names.
fnames, frame_numbers = get_filename(bad_frames)
fnames

# Realignment

## rewrite videos into a new folder

In [None]:
def fix_frame(frame, shift_amount=8184*2, show_plot=False, ax=None, buffer_size=8184):
    """
    Return a copy of `frame` with every other pixel buffer pulled forward.

    The frame is flattened and split into consecutive buffers of
    `buffer_size` pixels. Every pixel in an even-numbered buffer (0, 2, ...)
    is replaced by the pixel `shift_amount` positions later, provided that
    position lies inside the frame. Pixels in odd-numbered buffers and pixels
    whose source would fall past the end keep their value. The input is not
    modified.

    Parameters
    ----------
    frame : array-like
        The frame to repair.
    shift_amount : int, optional
        Offset, in flattened pixels, of the source pixel. (Default value =
        8184*2)
    show_plot : bool, optional
        If True, display the repaired frame.
    ax : matplotlib axes, optional
        Axes to draw on when `show_plot` is True; a new figure is created if
        omitted.
    buffer_size : int, optional
        Number of pixels per buffer. (Default value = 8184)

    Returns
    -------
    numpy.ndarray
        The repaired frame, with the same shape as the input.

    Raises
    ------
    ValueError
        If `buffer_size` is not positive.
    """
    if buffer_size <= 0:
        raise ValueError("buffer_size must be positive")

    frame = np.asarray(frame)
    frame_size = frame.shape
    flattened_frame = frame.flatten()  # always a copy
    n_pixels = flattened_frame.size

    if shift_amount >= 0:
        # Every source lies at or after its destination, so a pixel-by-pixel
        # pass would never read a value it has already overwritten; a single
        # vectorised assignment therefore yields the same result.
        pixel_index = np.arange(n_pixels)
        in_even_buffer = (pixel_index // buffer_size) % 2 == 0
        has_source = pixel_index + shift_amount < n_pixels
        dest = pixel_index[in_even_buffer & has_source]
        flattened_frame[dest] = flattened_frame[dest + shift_amount]
    else:
        # A negative shift can read pixels that were already overwritten, so
        # the sequential pass is kept to preserve the previous behaviour.
        for pixel_number in range(n_pixels):
            if (pixel_number // buffer_size) % 2 == 0:
                if (pixel_number + shift_amount) < n_pixels:
                    flattened_frame[pixel_number] = flattened_frame[pixel_number + shift_amount]

    fixed_frame = flattened_frame.reshape(frame_size)

    if show_plot:
        if ax is None:
            fig, ax = plt.subplots(figsize=(24,24))
        ax.imshow(fixed_frame)

    return fixed_frame

def fix_video(fnames, frame_numbers):
    """
    Rewrite videos into a ``repaired`` folder, fixing the listed frames.

    Every frame of every video in `fnames` is copied (first channel only) to
    a file of the same name under ``<folder of the first video>/repaired``,
    encoded with the FFV1 codec at 60 fps. Frames whose index within the
    video appears in the matching entry of `frame_numbers` are passed through
    `fix_frame` before being written.

    Parameters
    ----------
    fnames : sequence of str
        Paths of the videos to rewrite.
    frame_numbers : sequence of array-like of int
        For each video, the per-file indices of the frames to repair.

    Returns
    -------
    None
    """
    if len(fnames) == 0:
        print('No videos to repair')
        return

    folder = os.path.join(os.path.split(fnames[0])[0], 'repaired')
    if not os.path.exists(folder):
        os.makedirs(folder)
        print(f'Created {folder}')

    compressionCodec = "FFV1"
    codec = cv2.VideoWriter_fourcc(*compressionCodec)

    buffer_size = 8184
    shift_amount = buffer_size*2

    # For each video...
    for video, bad_frame_numbers in zip(fnames, frame_numbers):
        print(f'Rewriting {video}')
        # A set gives constant-time membership tests inside the frame loop.
        frames_to_fix = {int(i) for i in bad_frame_numbers}

        cap = cv2.VideoCapture(video)
        writeFile = None
        try:
            if not cap.isOpened():
                print(f'Could not open {video}, skipping')
                continue

            rows = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            cols = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            fname = os.path.split(video)[1]
            new_fpath = os.path.join(folder, fname)
            writeFile = cv2.VideoWriter(new_fpath, codec, 60, (cols,rows), isColor=False)

            for frame_number in tqdm(range(n_frames)):
                ret, frame = cap.read()
                if not ret:
                    break

                write_frame = frame[:,:,0]

                if frame_number in frames_to_fix:
                    # fix_frame returns a repaired copy and leaves its input
                    # untouched, so the result has to be kept.
                    write_frame = fix_frame(write_frame, shift_amount, show_plot=False)

                writeFile.write(np.uint8(write_frame))
        finally:
            # Release both handles even if reading or writing fails.
            if writeFile is not None:
                writeFile.release()
            cap.release()
    cv2.destroyAllWindows()
                

In [None]:
# Rewrite the affected videos into the 'repaired' folder next to the originals.
fix_video(fnames, frame_numbers)