In [None]:
import os
import json
import math
from math import isnan
from pathlib import Path
from inspect import signature
from itertools import chain
from collections import Counter

import cv2
import h5py
import numpy as np
import matplotlib as mpl
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib import rcParams
from IPython.display import display
from scipy.optimize import curve_fit
from scipy.signal import butter, filtfilt, sosfilt

from vlab.data.io import get_slice
from vlab.data.annotations import load_annotations
from vlab.images.transform import mip_threeview, apply_lut

In [None]:
def load_numpy(path, file_name='traces.npy'):
    """Load and return the array stored in ``<path>/<file_name>`` using ``np.load``."""
    full_path = os.path.join(path, file_name)
    return np.load(full_path)

def save_numpy(path, table, file_name='traces.npy'):
    """Write ``table`` to ``<path>/<file_name>`` using ``np.save``."""
    target = os.path.join(path, file_name)
    np.save(target, table)
    
def save_traces(path, filename, traces, n):
    """Persist neuron ``n``'s slice of ``traces`` to ``<path>/<filename>``.

    When the file does not exist yet, the whole ``traces`` array is written.
    Otherwise the stored array is loaded, only its ``[:, n, :]`` slice is
    replaced with ``traces[:, n, :]``, and the result is written back, so
    rows for other neurons saved earlier are kept.
    """
    if not os.path.exists(os.path.join(path, filename)):
        save_numpy(path, traces, filename)
        return
    stored = load_numpy(path, filename)
    stored[:, n, :] = traces[:, n, :]
    save_numpy(path, stored, filename)
    
    
def get_names(path):
    """Return the neuron names from the annotations at ``path`` as a list of ``str``.

    The names come from the ``'name'`` column of the ``.df`` attribute of the
    second object returned by ``load_annotations``. ``bytes`` entries are
    decoded (UTF-8); entries that are already text are converted with ``str``
    instead of raising AttributeError on ``.decode()``.
    """
    _, table = load_annotations(path)
    return [name.decode() if isinstance(name, bytes) else str(name)
            for name in table.df['name']]
    
def double_exp(t, a1, b1, a2, b2):
    """Evaluate the double exponential ``a1 * exp(-b1 * t) + a2 * exp(-b2 * t)``.

    Args:
        t: scalar, NumPy array, or any numeric sequence (e.g. ``range``). It is
            passed through ``np.asarray`` so that plain Python float parameters
            also work; ``-b1 * range(...)`` raises TypeError when ``b1`` is a
            Python float.
        a1, b1, a2, b2: amplitudes and decay rates of the two terms.

    Returns:
        The curve evaluated elementwise at ``t``.
    """
    t = np.asarray(t)
    return a1 * np.exp(-b1 * t) + a2 * np.exp(-b2 * t)
    
def fit_double_exp(x, y):
    """Fit ``double_exp`` to ``y`` sampled at the indices ``x``, skipping NaN samples.

    Args:
        x: sequence of integer indices into ``y`` to use for the fit.
        y: 1-D array of samples.

    Returns:
        The fitted ``(a1, b1, a2, b2)`` from ``curve_fit`` (robust ``soft_l1``
        loss, ``trf`` method). Amplitudes are bounded to [0, 1]. The rates are
        bounded by ``10 / len(y)`` (``b1``) and ``100 / len(y)`` (``b2``).
    """
    n_samples = y.shape[0]
    samples = y[x]
    keep = ~np.isnan(samples)
    lower = [0.0, 0.0, 0.0, 0.0]
    upper = [1.0, 10.0 / n_samples, 1.0, 100.0 / n_samples]
    popt, _ = curve_fit(double_exp, np.array(x)[keep], samples[keep], loss='soft_l1',
                        bounds=(lower, upper), method='trf')
    return popt


def remove_artifacts(r_range1, r_range2, traces, scale):
    """Divide the red channel's relative deviation from its fit out of both channels.

    The red trace ``traces[0]`` is fitted with ``fit_double_exp`` over the union of
    ``range(r_range1[0], r_range1[1])`` and ``range(r_range2[0], r_range2[1])``.
    With ``a = (red - curve) / curve``, the returned red row is ``red / (1 + a)``,
    which is approximately the fitted curve wherever ``red`` is non-zero. The
    returned green row is ``traces[1] / (1 + scale * a)``.

    Returns:
        ``(corrected, time_points, params_r)``: the stacked ``[red, green]``
        result, the list of indices used for the fit, and the fitted red
        parameters.
    """
    red, green = traces[0], traces[1]
    fit_indices = chain(range(r_range1[0], r_range1[1]), range(r_range2[0], r_range2[1]))
    time_points = list(set(fit_indices))

    params_r = fit_double_exp(time_points, red)
    curve_r = double_exp(range(traces.shape[1]), *params_r)

    # Relative deviation of the red channel from its smooth fit.
    deviation = (red - curve_r) / curve_r
    corrected = np.stack([red / (1 + deviation), green / (1 + scale * deviation)])
    return corrected, time_points, params_r


def remove_nans(signal):
    """Fill the NaN samples of a 1-D signal so it can be filtered.

    Each NaN takes the value of the nearest preceding non-NaN sample (forward
    fill). NaNs before the first valid sample take that first valid value. If
    the signal has no NaNs, or consists only of NaNs, an unmodified copy is
    returned.

    Returns:
        ``(filled, nans_idx)``: a filled copy (the input is not modified) and the
        indices of the original NaN samples, suitable for ``insert_nans``.
    """
    filled = np.array(signal, copy=True)
    missing = np.isnan(filled)
    nans_idx = np.where(missing)[0]
    if nans_idx.size == 0 or nans_idx.size == filled.shape[0]:
        return filled, nans_idx

    first_valid = int(np.argmax(~missing))
    # For each position, the index of the sample it copies from: itself when valid,
    # otherwise the latest valid index seen so far (leading NaNs -> first_valid).
    source = np.where(missing, first_valid, np.arange(filled.shape[0]))
    source = np.maximum.accumulate(source)
    return filled[source], nans_idx

def insert_nans(signal, nans_idx):
    """Return a copy of ``signal`` with NaN written at every index in ``nans_idx``.

    This undoes the fill done by ``remove_nans``. It uses ``np.nan``, because the
    ``np.NaN`` alias no longer exists in NumPy 2.0. The input is not modified.
    """
    restored = np.array(signal, copy=True)
    idx = np.asarray(nans_idx, dtype=np.intp)
    if idx.size:
        restored[idx] = np.nan
    return restored

def lowpass_filter(signal, cutoff, fs=4.0, order=2):
    """Zero-phase Butterworth low-pass filter for a 1-D signal that may contain NaNs.

    NaN samples are filled with ``remove_nans`` before filtering, because
    ``filtfilt`` does not tolerate NaN. They are put back as NaN afterwards with
    ``insert_nans``.

    Args:
        signal: 1-D array. With the default padding, ``filtfilt`` needs more than
            ``3 * (order + 1)`` samples.
        cutoff: cutoff frequency in the same units as ``fs``; it must lie strictly
            between 0 and ``fs / 2``.
        fs: sampling rate. Defaults to 4.0, the value previously hard-coded here
            (presumably the volume rate in Hz -- TODO confirm).
        order: Butterworth order (default 2, as before). ``filtfilt`` runs the
            filter forward and backward, so the combined order is ``2 * order``.

    Returns:
        The filtered signal, with NaN at the original NaN positions.
    """
    filled, nans_idx = remove_nans(signal)
    b, a = butter(N=order, Wn=cutoff, btype='lowpass', output='ba', fs=fs, analog=False)
    filtered = filtfilt(b, a, filled)
    return insert_nans(filtered, nans_idx)

def _scale_down_to_unit_peak(trace):
    """Divide ``trace`` by its largest non-NaN value if that value exceeds 1; otherwise return it as is."""
    valid = trace[~np.isnan(trace)]
    if valid.size and valid.max() > 1:
        return trace / valid.max()
    return trace


def debleach(g_range1, g_range2, g_range3, g_range4, traces, params_r):
    """Divide fitted double-exponential decay curves out of both channels.

    The green channel ``traces[1]`` is fitted with ``fit_double_exp`` over the
    union of ``range(r[0], r[1])`` for the four ``g_range`` pairs. The red channel
    ``traces[0]`` reuses the already-fitted ``params_r``. Each channel is divided
    by its curve and multiplied by its first sample, or by the curve's first
    value when that sample is NaN. It is then divided by its peak if the peak
    exceeds 1.

    The peak uses only non-NaN samples. These traces keep their NaN samples, and a
    plain ``.max()`` returns NaN, which silently skipped the rescaling.

    Args:
        g_range1, g_range2, g_range3, g_range4: ``(start, stop)`` index pairs
            selecting the samples used for the green fit.
        traces: ``(2, T)`` array; row 0 red (tagRFP), row 1 green (GCaMP).
        params_r: the 4 double-exponential parameters of the red channel.

    Returns:
        ``(debleached, time_points, params_g)``: the stacked ``[red, green]``
        result, the list of indices used for the green fit, and the fitted green
        parameters.
    """
    ranges = (g_range1, g_range2, g_range3, g_range4)
    time_points = list(set(chain(*(range(r[0], r[1]) for r in ranges))))

    params_g = fit_double_exp(time_points, traces[1])
    timeline = range(traces.shape[1])
    curve_g = double_exp(timeline, *params_g)
    curve_r = double_exp(timeline, *params_r)
    # Anchor each corrected trace to its first sample; fall back to the curve if that sample is NaN.
    scale_r = curve_r[0] if isnan(traces[0, 0]) else traces[0, 0]
    scale_g = curve_g[0] if isnan(traces[1, 0]) else traces[1, 0]

    red = _scale_down_to_unit_peak(traces[0] * scale_r / curve_r)
    green = _scale_down_to_unit_peak(traces[1] * scale_g / curve_g)
    return np.stack([red, green]), time_points, params_g



def get_list_diff(list1, list2):
    """Return a new list: ``list1`` with one occurrence removed per element of ``list2``.

    This is a multiset difference that keeps the order of ``list1``; earlier
    occurrences are removed first. An element of ``list2`` that is absent from
    ``list1``, or that is repeated more often than ``list1`` contains it, is
    ignored rather than raising ValueError. Elements must be hashable. The
    inputs are not modified. Runs in O(len(list1) + len(list2)).
    """
    pending = Counter(list2)
    result = []
    for element in list1:
        if pending[element] > 0:
            pending[element] -= 1
        else:
            result.append(element)
    return result
    
def _split_used_unused(values, used_idx):
    """Return ``(unused, used)`` float copies of ``values``; each has the other subset's samples set to NaN."""
    values = np.asarray(values, dtype=float)
    used = np.zeros(values.shape[0], dtype=bool)
    used[np.asarray(used_idx, dtype=int)] = True
    return np.where(used, np.nan, values), np.where(used, values, np.nan)


def plot(path, name, total_range, range_r, range_g, params, raw,
         artifact_subtracted, lowpass_filtered, debleached,
         save):
    """Draw the five processing stages of one neuron and optionally save them as a PDF.

    Panels, top to bottom:
      0. Raw tagRFP, split into samples used / not used for the red fit, plus the
         artifact-subtracted red row (labelled "fitted curve").
      1. Raw vs. artifact-subtracted GCaMP.
      2. Both low-pass filtered channels.
      3. Low-pass filtered GCaMP, split into samples used / not used for the green
         fit, plus the fitted green double exponential.
      4. Both debleached channels.

    Args:
        path: dataset directory. When saving, the PDF goes to
            ``<path>/plots/<name>.pdf``, and that directory must already exist.
        name: neuron name, used as the figure title and, with '/' replaced by '_',
            as the file name.
        total_range: x values, one per timepoint.
        range_r, range_g: sample indices used for the red and green curve fits.
        params: ``(2, 4)`` fit parameters; only row 1 (the green fit) is drawn.
        raw, artifact_subtracted, lowpass_filtered, debleached: ``(2, T)``
            traces, with row 0 tagRFP and row 1 GCaMP.
        save: if true, save the figure before showing it.

    Returns:
        ``(figure, axes)`` from ``plt.subplots``.
    """
    figure, axes = plt.subplots(nrows=5, ncols=1)
    figure.set_size_inches(12, 20)

    # Each plotted series gets its own array, so no line shares a buffer with another.
    red_unused, red_used = _split_used_unused(raw[0], range_r)
    green_unused, green_used = _split_used_unused(lowpass_filtered[1], range_g)
    # Evaluate the green fit once; it is both plotted and used for the y-limits.
    green_curve = double_exp(np.array(total_range), *params[1])

    axes[0].plot(total_range, red_unused, color='#F1948A', label='tagRFP raw not used for curve fitting', alpha=0.8)
    axes[0].plot(total_range, red_used, color='#E74C3C', label='tagRFP raw used for curve fitting', alpha=0.8)
    axes[0].plot(total_range, artifact_subtracted[0], color='#424949', label='fitted curve', alpha=0.8)
    axes[0].legend()

    axes[1].plot(total_range, raw[1], color='#424949', label='GCaMP raw', alpha=0.8)
    axes[1].plot(total_range, artifact_subtracted[1], color='#229954', label='GCaMP artifact subtracted', alpha=0.8)
    axes[1].legend()

    axes[2].plot(total_range, lowpass_filtered[0], color='#E74C3C', label='tagRFP lowpass filtered')
    axes[2].plot(total_range, lowpass_filtered[1], color='#229954', label='GCaMP lowpass filtered')
    axes[2].legend()

    axes[3].plot(total_range, green_unused, color='#82E0AA', label='GCaMP not used for curve fitting', alpha=0.8)
    axes[3].plot(total_range, green_used, color='#229954', label='GCaMP used for curve fitting', alpha=0.8)
    axes[3].plot(total_range, green_curve, color='#424949', label='GCaMP fitted curve', alpha=0.8)
    axes[3].legend()

    axes[4].plot(total_range, debleached[0], color='#E74C3C', label='tagRFP debleached')
    axes[4].plot(total_range, debleached[1], color='#229954', label='GCaMP debleached')
    axes[4].legend()

    # 20% headroom above the tallest trace of each panel.
    axes[0].set_ylim([0, 1.2 * np.nanmax([artifact_subtracted[0], raw[0]])])
    axes[1].set_ylim([0, 1.2 * np.nanmax([artifact_subtracted[1], raw[1]])])
    axes[2].set_ylim([0, 1.2 * np.nanmax(lowpass_filtered)])
    axes[3].set_ylim([0, 1.2 * np.nanmax([green_curve, lowpass_filtered[1]])])
    axes[4].set_ylim([0, 1.2 * np.nanmax(debleached)])
    # Only the bottom panel keeps its x axis.
    for ax in axes[:4]:
        ax.xaxis.set_visible(False)
    figure.suptitle(name, fontsize=16)

    if save:
        file_stem = name.replace('/', '_').strip()
        filename = os.path.join(path, 'plots', file_stem + '.pdf')
        figure.savefig(filename, dpi=1000, transparent=True)
    plt.show()
    return figure, axes

def _timepoint_range_slider(n_timepoints):
    """Return an IntRangeSlider spanning ``[0, n_timepoints]`` that only updates on release."""
    return widgets.IntRangeSlider(value=[0, n_timepoints], min=0, max=n_timepoints, step=1, disabled=False,
                                  continuous_update=False, orientation='horizontal', readout=True,
                                  readout_format='d')


def _fraction_slider(value):
    """Return a FloatSlider over [0.01, 1.99] (step 0.01) that only updates on release."""
    return widgets.FloatSlider(value=value, min=0.01, max=1.99, step=0.01,
                               disabled=False, continuous_update=False,
                               orientation='horizontal', readout=True, readout_format='.2f')


def _save_fit_parameters(path, fit_params, n):
    """Store ``fit_params[:, n, :]`` in dataset ``parameters`` of ``<path>/fit_parameters.h5``.

    If the file does not exist yet, it is created holding the whole ``fit_params``
    array. Otherwise only neuron ``n``'s slice is overwritten.
    """
    parameter_filename = os.path.join(path, 'fit_parameters.h5')
    # Context managers close the HDF5 file even if the write raises.
    if os.path.exists(parameter_filename):
        with h5py.File(parameter_filename, 'r+') as parameter_file:
            parameter_file['parameters'][:, n, :] = fit_params[:, n, :]
    else:
        with h5py.File(parameter_filename, 'w') as parameter_file:
            parameter_file.create_dataset('parameters', data=fit_params)


def interactive_plot(path):
    """Show an ipywidgets UI for tuning and saving per-neuron trace preprocessing.

    Loads ``traces.npy`` from ``path`` (indexed ``[channel, neuron, timepoint]``,
    with channel 0 red and channel 1 green) and the neuron names. Whenever a widget
    changes, the neuron selected by ``n`` is processed in three steps:
      1. artifact removal using a red-channel fit,
      2. low-pass filtering of both channels,
      3. debleaching using a green-channel fit.
    All stages are then plotted. With "Save" ticked, the plot is written to
    ``<path>/plots/`` and the neuron's processed rows are stored in
    ``artifact_subtracted_traces.npy``, ``lowpass_filtered_traces.npy``,
    ``debleached_traces.npy`` and ``fit_parameters.h5`` inside ``path``.
    """
    # plot() writes its PDFs here when saving.
    os.makedirs(os.path.join(path, 'plots'), exist_ok=True)

    traces = load_numpy(path)
    _, n_neurons, n_timepoints = traces.shape
    total_range = list(range(n_timepoints))

    # Per-neuron results are filled in place as neurons are processed in the UI.
    artifact_subtracted_traces = np.zeros_like(traces)
    lowpass_filtered_traces = np.zeros_like(traces)
    debleached_traces = np.zeros_like(traces)
    # fit_params[0] holds the red (artifact) fits, fit_params[1] the green (bleaching) fits.
    fit_params = np.zeros((2, n_neurons, 4))

    names = get_names(path)

    def widget_fn(n, r_range1, r_range2, scale, cutoff, g_range1, g_range2, g_range3, g_range4, save):
        # Step 1: fit the red channel and divide its artifact out of both channels.
        artifact_subtracted_traces[:, n, :], range_r, fit_params[0, n, :] = remove_artifacts(
            r_range1, r_range2, traces[:, n, :], scale)

        # Step 2: low-pass filter each channel independently.
        for channel in (0, 1):
            lowpass_filtered_traces[channel, n, :] = lowpass_filter(
                artifact_subtracted_traces[channel, n, :], cutoff)

        # Step 3: fit the green channel and divide bleaching out of both channels.
        debleached_traces[:, n, :], range_g, fit_params[1, n, :] = debleach(
            g_range1, g_range2, g_range3, g_range4,
            lowpass_filtered_traces[:, n, :], fit_params[0, n, :])

        plot(path, names[n], total_range,
             range_r, range_g, fit_params[:, n, :],
             traces[:, n, :], artifact_subtracted_traces[:, n, :],
             lowpass_filtered_traces[:, n, :], debleached_traces[:, n, :],
             save)

        if save:
            save_traces(path, 'artifact_subtracted_traces.npy', artifact_subtracted_traces, n)
            save_traces(path, 'lowpass_filtered_traces.npy', lowpass_filtered_traces, n)
            save_traces(path, 'debleached_traces.npy', debleached_traces, n)
            _save_fit_parameters(path, fit_params, n)

    w = widgets.interactive(
        widget_fn,
        n=widgets.BoundedIntText(value=0, min=0, max=len(names) - 1, step=1, disabled=False),
        r_range1=_timepoint_range_slider(n_timepoints),
        r_range2=_timepoint_range_slider(n_timepoints),
        scale=_fraction_slider(1.0),
        cutoff=_fraction_slider(1.99),
        g_range1=_timepoint_range_slider(n_timepoints),
        g_range2=_timepoint_range_slider(n_timepoints),
        g_range3=_timepoint_range_slider(n_timepoints),
        g_range4=_timepoint_range_slider(n_timepoints),
        save=widgets.Checkbox(value=False, description='Save', disabled=False, indent=False),
    )
    display(w)

In [None]:
# Paths to the datasets of interest. Set the XE2551_DATA environment variable to
# use another copy of the data; the default is the original location.
DATA_ROOT = Path(os.environ.get('XE2551_DATA', r'C:\Users\Mahdi\Desktop\XE2551'))

# (recording date, animal) pairs; each dataset lives in <date>/<animal>/run_1/centered.
RECORDINGS = [
    ('01_22_2021', 'animal_1'),
    ('01_22_2021', 'animal_3'),
    ('01_22_2021', 'animal_4'),
    ('01_30_2021', 'animal_2'),
    ('01_30_2021', 'animal_3'),
    ('01_30_2021', 'animal_7'),
    ('01_30_2021', 'animal_8'),
    ('01_30_2021', 'animal_10'),
    ('12_14_2020', 'animal_3'),
    ('12_14_2020', 'animal_5'),
    ('12_14_2020', 'animal_8'),
    ('01_22_2021', 'animal_2'),
    ('01_30_2021', 'animal_6'),
]
paths = [DATA_ROOT / date / animal / 'run_1' / 'centered' for date, animal in RECORDINGS]

# Widgets:
#   n: neuron number.
#   r_range1, r_range2: timepoints used to fit a curve to the red channel.
#   scale: how much of the artifact found in the red channel is subtracted from the green channel.
#   cutoff: cutoff of the low-pass filter applied to both channels.
#   g_range1, ..., g_range4: timepoints used to fit a curve to the green channel.
#   save: save the plot and traces of the current neuron.
interactive_plot(paths[0])