In [None]:
%load_ext autoreload
%autoreload 1
%matplotlib widget

In [None]:
import flammkuchen as fl
from numba import jit
import numpy as np
import matplotlib.pyplot as plt
# plt.style.use("v_paper")
plt.style.use("figures.mplstyle")
import cmocean as cmo
from pathlib import Path
import os

from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.svm import SVC, SVR
from sklearn.model_selection import GridSearchCV

In [None]:
import pandas as pd

In [None]:
# Pre-computed decoding dataset. Defaults to the lab share; set the
# LUMINANCE_DECODING_DATA environment variable to load a local copy instead.
data_path = os.environ.get(
    "LUMINANCE_DECODING_DATA",
    r"\\FUNES\Shared\experiments\E0032_luminance\neat_exps\data_dict4decoding_complete.h5",
)
full_dict = fl.load(data_path)

In [None]:
%aimport decoding

In [None]:
from decoding import confmat, get_population, decode_from_population
from stim_encoding import time_since_flash, get_valid_periods_steps

In [None]:
# Output folder for the manuscript figures (machine-specific absolute path).
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\decoding")

# parents=True also creates missing intermediate folders (os.mkdir would fail);
# exist_ok=True keeps the cell re-runnable.
fig_fold.mkdir(parents=True, exist_ok=True)

# Flashes protocol

In [None]:
# Column 0 of "flashes_img_time" holds the time stamps, column 1 the flash stimulus.
t_flashes, stim_flashes = full_dict["flashes_img_time"][:, :2].T

In [None]:
# Flash stimulus over the whole protocol, labelled so the figure stands alone.
flash_fig, flash_ax = plt.subplots()
flash_ax.plot(t_flashes, stim_flashes)
flash_ax.set_xlabel("time")
flash_ax.set_ylabel("flash stimulus")
flash_ax.set_title("Flashes protocol stimulus");

We will try to answer the following:

Is there a linearly-separable representation
* of time since stimulus onset
* of certain time windows

## Generate data to decode

In [None]:
# Decoding target: time since flash per sample (NaN samples are excluded below).
# The sampling interval is taken from the first two time stamps.
stim_since = time_since_flash(stim_flashes, dt=np.subtract(t_flashes[1], t_flashes[0]))

In [None]:
# Decoding target over time, labelled so the figure stands alone.
since_fig, since_ax = plt.subplots()
since_ax.plot(t_flashes, stim_since)
since_ax.set_xlabel("time")
since_ax.set_ylabel("time since flash onset")
since_ax.set_title("Decoding target, flashes protocol");

In [None]:
# Keep only the samples where a time-since-flash value is defined.
valid_periods = ~np.isnan(stim_since)

stim_cut, t_cut = stim_since[valid_periods], t_flashes[valid_periods]

# Number of time points entering the decoder.
n_t_decode = t_cut.shape[0]

## Generate populations

Use the first `n_rep` = 12 repeats of each cell, and subsample GC populations of `n_pop` = 26 cells (the number of IO cells)

In [None]:
# Repeats used per cell, and size of each random GC subset (chosen to equal the
# number of IO cells).
n_rep, n_pop = 12, 26

Cut the data so only the valid flash periods are taken into account

Run the decoding on many random subsamples of the GC population (`n_samples` draws of `n_pop` cells each)

In [None]:
from tqdm import tqdm

In [None]:
# Number of random GC subsets to decode (iterations of the loop below).
n_samples = 200

In [None]:
# Third argument of decode_from_population. Ground truths are tiled n_test times
# to match the returned predictions (presumably one block per test split).
n_test = 2

In [None]:
# Full GC population recorded during the flashes protocol.
traces = get_population(full_dict, cell_type="GC", protocol="flashes")

In [None]:
# Decode time since flash from many random GC subsets of n_pop cells.
# A seeded generator makes the subset draws reproducible across runs.
subset_rng = np.random.default_rng(0)

predictions = []
best_params = []
for i in tqdm(range(n_samples)):
    sel_cells_subset = subset_rng.choice(traces.shape[0], n_pop, replace=False)
    # First n_rep repeats of the selected cells, restricted to the valid flash periods.
    population = traces[sel_cells_subset, :, 0:n_rep]
    population = population[:, valid_periods, :]
    # Use n_test (not a literal) so the split count matches the tiled ground truth.
    model, pred = decoding.decode_from_population(population, stim_cut, n_test)
    predictions.append(pred)
    best_params.append(model.best_params_["alpha"])

In [None]:
# Ground truth matching the concatenated predictions: one copy per test split.
# (Not used further in this notebook.)
outputs_gt = np.tile(stim_cut, reps=n_test)

In [None]:
# One row per random GC subset, each holding n_test consecutive blocks of
# n_t_decode predictions.
preds = np.stack(predictions)
assert preds.shape[1] == n_test * n_t_decode, "expected n_test blocks of n_t_decode predictions"
# Move each test block to its own row, ordered [all samples' block 0; all samples' block 1; ...].
# This matches the previous two-way split for n_test == 2 and works for any n_test.
preds = np.concatenate(np.split(preds, n_test, axis=1), axis=0)

In [None]:
# Mean of the decoded variable, used in the R² denominator below.
mnc = stim_cut.mean()

In [None]:
# R² of every prediction row (one per subset and test block) against stim_cut.
GC_subset_rsq = 1 - ((preds - stim_cut) ** 2).sum(axis=1) / ((stim_cut - mnc) ** 2).sum()

In [None]:
# Order of the decoded samples by actual time since flash, and the sorted times.
out_order = stim_cut.argsort()
out_xtime = np.sort(stim_cut)

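Realign the predictions. `stim_cut` consists of several segments (index ranges 0–21, 21–50 and 50–135), each of which presumably restarts the time-since-flash count. Each segment is stacked aligned at its first sample, so predictions for the same elapsed time share a column and are averaged, rather than being counted once per segment.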

In [None]:
# Index ranges (start, stop) of stim_cut that are stacked on top of each other below;
# the last range is the longest and sets the width of full_pred.
stim_cut_slices = [(0, 21), (21, 50), (50, 135)]
n_sl = len(stim_cut_slices)

In [None]:
# NaN-padded container: n_sl stacked copies of the prediction rows, as wide as the
# last slice in stim_cut_slices.
full_pred = np.empty((n_sl * preds.shape[0], stim_cut_slices[2][1] - stim_cut_slices[2][0]))
full_pred.fill(np.nan)

In [None]:
# Stack the slices aligned at their first sample. Slice i_sl fills rows
# [i_sl*n_rows, (i_sl+1)*n_rows) and its first (right - left) columns; the rest stays NaN.
# The row count comes from preds itself, as in the allocation of full_pred.
n_rows = preds.shape[0]
for i_sl, (left, right) in enumerate(stim_cut_slices):
    full_pred[i_sl * n_rows:(i_sl + 1) * n_rows, :right - left] = preds[:, left:right]

In [None]:
# Time axis for the stacked predictions: the times of the last (longest) slice.
time_long = stim_cut[slice(*stim_cut_slices[2])]

In [None]:
# Mean and SD over all stacked rows, ignoring the NaN padding of the shorter slices.
pred_mn, pred_sd = np.nanmean(full_pred, axis=0), np.nanstd(full_pred, axis=0)

## IO

In [None]:
cell_type = "IO"

traces = full_dict[cell_type + "_flashes"]["clean_traces"]

# Per cell: number of repeats (last axis) that are not entirely NaN along axis 1 (time).
n_valid_traces = (~np.isnan(traces).all(axis=1)).sum(axis=1)

In [None]:
n_rep = 12

# IO cells with at least n_rep usable repeats.
sel_cells = np.flatnonzero(n_valid_traces >= n_rep)

n_pop = 26  # Number of IO cells (not used to subsample here: all sel_cells are kept)

# First n_rep repeats of the selected cells, restricted to the valid flash periods.
population = traces[sel_cells, :, :n_rep][:, valid_periods, :]

In [None]:
# Decode time since flash from the whole IO population, with the same number of
# test splits as the GC subsets (n_test, not a literal).
model_IO, pred_IO = decoding.decode_from_population(population, stim_cut, n_test)

In [None]:
# Ground truth matching pred_IO: one copy of stim_cut per test split.
time_IO = np.tile(stim_cut, n_test)

In [None]:
# R² of the IO predictions, pooled over all test splits.
IO_rsq = 1 - ((pred_IO - time_IO) ** 2).sum() / ((time_IO - time_IO.mean()) ** 2).sum()

In [None]:
# Time-decoding results for both cell types, saved to the working directory.
duration_decoding_dict = dict(
    GC=dict(time=time_long, full_pred=full_pred, preds=preds, rsquared=GC_subset_rsq),
    IO=dict(time=time_IO, preds=pred_IO, rsquared=IO_rsq),
)
fl.save('duration_decoding.h5', duration_decoding_dict)

In [None]:
fig, ax = plt.subplots()
# Mean +- SD of all stacked GC subset predictions.
ax.fill_between(time_long, pred_mn - pred_sd, pred_mn + pred_sd, color=(0.1, 0.1, 0.1, 0.07), linewidth=0)
ax.plot(time_long, pred_mn, label="GC, average of 26 random cell selections")
# Predictions of a single GC subset (sample 0). With the row layout of `preds`
# (all samples' split 0, then all samples' split 1), its two splits are rows 0 and n_samples.
ax.scatter(time_IO, np.concatenate([preds[0], preds[n_samples]]), color=(0.4, 0.4, 0.4), s=1.0, label="GC, 26 cells")
ax.scatter(time_IO, pred_IO, color=(0.9, 0.3, 0.1), s=1.5, label="IO, 26 cells")

# Identity line: perfect prediction.
ax.plot(time_long, time_long, color=(0.3, 0.3, 0.3, 0.7))
ax.set_xlabel("actual time since flash onset")
ax.set_ylabel("predicted time since flash onset")
ax.set_aspect(1)
ax.legend()

In [None]:
# Save the time-decoding figure into the manuscript figure folder.
fig.savefig(fig_fold.joinpath("time_decoding.pdf"))

In [None]:
# Distribution of GC subset R² values, with the IO R² for comparison.
rsq_fig, rsq_ax = plt.subplots()
rsq_ax.hist(GC_subset_rsq, bins=20, label="GC 26 cells R squared distribution")
rsq_ax.axvline(IO_rsq, color=(0.9, 0.3, 0.1), label="IO R squared")
rsq_ax.set_xlabel("R squared")
rsq_ax.legend()

In [None]:
# Display the R² of the IO decoder as the cell output.
IO_rsq

## Now, trying to classify time bins

In [None]:
# Number of equal-width time bins between 0 and 21 (see bin_boundaries below).
n_time_bins = 7

In [None]:
# n_time_bins equal-width bins between 0 and 21 (same units as stim_cut).
bin_boundaries = np.linspace(0, 21, n_time_bins + 1)
# Bin index per sample: i such that bin_boundaries[i] <= t < bin_boundaries[i + 1]
# (identical to np.digitize(stim_cut, bin_boundaries) - 1).
# NOTE(review): values >= 21 get index n_time_bins, i.e. a class beyond the
# n_time_bins bins -- confirm stim_cut stays below 21.
bin_ids = np.searchsorted(bin_boundaries, stim_cut, side="right") - 1

### On all GCs

In [None]:
# Number of test splits for the time-bin classifiers (same value as for time decoding).
n_test = 2

In [None]:
cell_type = "GC"

traces = full_dict[cell_type + "_flashes"]["clean_traces"]

# Per cell: number of repeats (last axis) that are not entirely NaN along axis 1 (time).
n_valid_traces = (~np.isnan(traces).all(axis=1)).sum(axis=1)

n_rep = 12

# Keep every GC with at least n_rep usable repeats (no subsampling here).
sel_cells = np.flatnonzero(n_valid_traces >= n_rep)

n_pop = sel_cells.size

# First n_rep repeats of the selected cells, restricted to the valid flash periods.
population = traces[sel_cells, :, :n_rep][:, valid_periods, :]

### One vs all

In [None]:
# One-vs-all: for each time bin, fit a balanced logistic classifier separating samples
# inside the bin from all others, and keep the predicted probabilities.
# `multi_class` is deprecated in recent scikit-learn. liblinear only does
# one-vs-rest, so omitting it does not change the fit.
bin_probabilities = []
for i_bin in range(n_time_bins):
    mod_bins_allgc, pred_bins_allgc = decoding.decode_from_population(
        population,
        bin_ids == i_bin,
        n_test,
        model=LogisticRegression(class_weight="balanced", solver="liblinear"),
        hyperparams=dict(C=10.0 ** np.arange(-4, 2)),
        probabilities=True,
    )
    bin_probabilities.append(pred_bins_allgc)

In [None]:
# Per bin: fraction of in-bin samples classified as in-bin, and of out-of-bin samples
# classified as out-of-bin, thresholding column 1 of the probabilities at 0.5
# (presumably P(in bin), the second class of a False/True target).
accs = []
for i_bin in range(n_time_bins):
    positive = np.tile(bin_ids == i_bin, n_test)
    predicted_positive = bin_probabilities[i_bin][:, 1] > 0.5
    correct_preds = positive == predicted_positive
    correct_pos = np.mean(correct_preds[positive])
    correct_neg = np.mean(correct_preds[~positive])
    accs.append((correct_pos, correct_neg))

In [None]:
# Column 0 of accs: in-bin accuracy; column 1: out-of-bin accuracy.
acc_fig, acc_ax = plt.subplots()
acc_ax.plot(np.array(accs))
acc_ax.set_xlabel("time bin index")
acc_ax.set_ylabel("fraction correct")
acc_ax.legend(["in-bin samples", "out-of-bin samples"])
acc_ax.set_title("One-vs-all time-bin classification, all GCs");

## Decode all time bins at once (multiclass)

In [None]:
# Multiclass decoding of the time bin from all GCs, returning class probabilities.
# `multi_class` is deprecated in recent scikit-learn. liblinear only does
# one-vs-rest, so omitting it does not change the fit.
mod_bins_simult_allgc, pred_bins_simult_allgc = decoding.decode_from_population(
    population,
    bin_ids,
    n_test,
    model=LogisticRegression(class_weight="balanced", solver="liblinear"),
    hyperparams=dict(C=10.0 ** np.arange(-4, 2)),
    probabilities=True,
)

In [None]:
# Confusion matrix of the predictions vs. actual bins (ground truth tiled per test split).
confusion_mat = confmat(pred_bins_simult_allgc, np.tile(bin_ids, reps=n_test))

In [None]:
fig, ax = plt.subplots()
# Place the matrix on time axes: x = predicted bin, y = actual bin (top to bottom).
bin_extent = [bin_boundaries[0], bin_boundaries[-1], bin_boundaries[-1], bin_boundaries[0]]
im = ax.imshow(confusion_mat, cmap=cmo.cm.tempo, extent=bin_extent)
ax.set_xticks(bin_boundaries)
ax.set_yticks(bin_boundaries)
ax.set_xlabel("Predicted time bin")
ax.set_ylabel("Actual time bin")
for side in ("bottom", "left"):
    ax.spines[side].set_visible(False)
bar = fig.colorbar(im)
bar.set_label("Probability")

# Decoding luminance

In [None]:
# Column 0 of "steps_img_time" holds the time stamps, column 1 the luminance stimulus.
t_steps, stim_steps = full_dict["steps_img_time"][:, :2].T

In [None]:
# Luminance steps over the whole protocol, labelled so the figure stands alone.
steps_fig, steps_ax = plt.subplots()
steps_ax.plot(t_steps, stim_steps)
steps_ax.set_xlabel("time")
steps_ax.set_ylabel("luminance")
steps_ax.set_title("Steps protocol stimulus");

In [None]:
# Boolean mask of the steps-protocol samples used for decoding. This replaces the
# flashes-protocol valid_periods for all cells below.
# Builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24.
valid_periods = get_valid_periods_steps(stim_steps).astype(bool)

steps_cut = stim_steps[valid_periods]
t_steps_cut = t_steps[valid_periods]

n_t_decode = len(t_steps_cut)

## Decoding absolute luminance

### IO

In [None]:
from sklearn.svm import SVR
import seaborn as sns

In [None]:
# Distinct luminance levels of the steps protocol (np.unique already returns them sorted).
lum_levels = np.unique(steps_cut)

In [None]:
# IO population for the steps protocol, restricted to the valid step periods (axis 1).
pop = get_population(full_dict, "IO", "steps")[:, valid_periods]

In [None]:
# Absolute-luminance decoding from the IO population, with decode_from_population's
# default model and with an SVR. Results are cached on disk; the fits only run when
# the cache file is missing.
abs_lum_cache_IO = "abs_lum_decoding_results_IO_final.h5"
if os.path.exists(abs_lum_cache_IO):
    pred_data = fl.load(abs_lum_cache_IO)
else:
    lum_model_lin, pred_lin = decoding.decode_from_population(pop, steps_cut)
    lum_model_svm, pred_lum_svm = decoding.decode_from_population(
        pop, steps_cut, model=SVR(), hyperparams=dict(C=10.0 ** np.arange(-4, 4)))
    pred_data = pd.DataFrame(dict(lum_true=np.tile(steps_cut, n_test), lum_lin=pred_lin, lum_svm=pred_lum_svm))
    fl.save(abs_lum_cache_IO, pred_data, compression="blosc")

In [None]:
# Predicted vs. true luminance for each decoder. Grey bars mark the true level of
# each category position on the x axis (seaborn places sorted categories at 0, 1, 2, ...).
fig, axes = plt.subplots(1, 2)
# Derive bar positions from the number of levels instead of hard-coding four.
level_pos = np.arange(len(lum_levels))
for pred, ax in zip(["lin", "svm"], axes):
    sns.swarmplot(x=pred_data.lum_true, y=pred_data["lum_" + pred], s=2, ax=ax, color=sns.color_palette()[1])
    ax.hlines(lum_levels, level_pos - 0.4, level_pos + 0.4, lw=1, color=(0.4, 0.4, 0.4))
    ax.set_title("Predictions of absolute luminance, " + pred)

### GC

In [None]:
# GC population for the steps protocol, restricted to the valid step periods (axis 1).
pop = get_population(full_dict, "GC", "steps")[:, valid_periods]

In [None]:
# Absolute-luminance decoding from the GC population, with decode_from_population's
# default model and with an SVR. Results are cached on disk; the fits only run when
# the cache file is missing.
abs_lum_cache_GC = "abs_lum_decoding_results_GC_final.h5"
if os.path.exists(abs_lum_cache_GC):
    pred_data = fl.load(abs_lum_cache_GC)
else:
    lum_model_lin, pred_lin = decoding.decode_from_population(pop, steps_cut)
    lum_model_svm, pred_lum_svm = decoding.decode_from_population(
        pop, steps_cut, model=SVR(), hyperparams=dict(C=10.0 ** np.arange(-4, 4)))
    pred_data = pd.DataFrame(dict(lum_true=np.tile(steps_cut, n_test), lum_lin=pred_lin, lum_svm=pred_lum_svm))
    fl.save(abs_lum_cache_GC, pred_data, compression="blosc")

In [None]:
# Predicted vs. true luminance for each decoder. Grey bars mark the true level of
# each category position on the x axis (seaborn places sorted categories at 0, 1, 2, ...).
fig, axes = plt.subplots(1, 2)
# Derive bar positions from the number of levels instead of hard-coding four.
level_pos = np.arange(len(lum_levels))
for pred, ax in zip(["lin", "svm"], axes):
    sns.swarmplot(x=pred_data.lum_true, y=pred_data["lum_" + pred], s=2, ax=ax, color=sns.color_palette()[0])
    ax.hlines(lum_levels, level_pos - 0.4, level_pos + 0.4, lw=1, color=(0.4, 0.4, 0.4))
    ax.set_title("Predictions of absolute luminance, " + pred)

#### GC, subsampled to the number of IO cells

In [None]:
# Number of IO cells in the steps protocol; used below as max_rois_incl for the GC populations.
n_incl_GCs = len(get_population(full_dict, cell_type="IO", protocol="steps")[:, valid_periods, :])
n_incl_GCs

In [None]:
# Absolute-luminance decoding from GC populations capped at n_incl_GCs cells
# (max_rois_incl; presumably a random subset per call), repeated n_iters times.
# Results are cached on disk; the fits only run when the cache file is missing.
n_iters = 20
abs_lum_cache_GCsub = "abs_lum_decoding_results_GCsubsampled_final.h5"
if os.path.exists(abs_lum_cache_GCsub):
    pred_data = fl.load(abs_lum_cache_GCsub)
else:
    pred_data = []
    for i in tqdm(range(n_iters)):
        pop = get_population(full_dict, cell_type="GC", protocol="steps", max_rois_incl=n_incl_GCs)[:, valid_periods, :]
        lum_model_lin, pred_lin = decoding.decode_from_population(pop, steps_cut)
        lum_model_svm, pred_lum_svm = decoding.decode_from_population(
            pop, steps_cut, model=SVR(), hyperparams=dict(C=10.0 ** np.arange(-4, 4)))
        pred_data.append(pd.DataFrame(dict(lum_true=np.tile(steps_cut, n_test), lum_lin=pred_lin, lum_svm=pred_lum_svm)))
    fl.save(abs_lum_cache_GCsub, pred_data, compression="blosc")

In [None]:
# Which subsampling iteration to show.
plot_iter = 5

fig, axes = plt.subplots(1, 2)
# Derive bar positions from the number of levels instead of hard-coding four.
level_pos = np.arange(len(lum_levels))
for pred, ax in zip(["lin", "svm"], axes):
    sns.swarmplot(x=pred_data[plot_iter].lum_true, y=pred_data[plot_iter]["lum_" + pred], s=2, ax=ax, color=sns.color_palette()[0])
    ax.hlines(lum_levels, level_pos - 0.4, level_pos + 0.4, lw=1, color=(0.4, 0.4, 0.4))
    ax.set_title("Predictions of absolute luminance, " + pred)

## Transition finding

In [None]:
from collections import namedtuple
from itertools import product, starmap

def named_product(**items):
    """Lazily yield every combination of the keyword values as a ``Product`` namedtuple.

    Each keyword becomes a field. Combinations follow ``itertools.product`` order
    (the last keyword varies fastest). Iterating a dict value yields its keys, so
    dict-valued options produce key names.
    """
    Product = namedtuple("Product", items.keys())
    return (Product(*combo) for combo in product(*items.values()))

In [None]:
# Signed luminance change per sample (0 for the first), summed over a 4-sample window
# and binned into integer codes: 3 = |change| < 0.001, lower codes = decreases,
# higher codes = increases.
transition_cont = np.convolve(np.concatenate(([0], np.diff(stim_steps))), np.ones(4, dtype=int), mode="same")
transition_cont = np.digitize(transition_cont, [-1.1, -0.3, -0.1, -0.001, 0.001, 0.1, 0.3, 1.1]) - 1

In [None]:
def prepare_stim(s, n_convolve=4):
    """Extend every truthy sample of `s` forward by ``n_convolve - 1`` samples.

    Returns a boolean array (or `s` itself, untouched, when ``n_convolve <= 1``).
    """
    widened = s
    for _ in range(n_convolve - 1):
        # The same signal delayed by one sample (False shifted in at the start).
        lagged = np.concatenate(([False], widened[:-1]))
        widened = np.logical_or(widened, lagged)
    return widened

In [None]:
# Signed luminance change per sample (0 for the first sample), computed once.
step_diff = np.r_[0, np.diff(stim_steps)]

# Every combination of target feature x population x decoder is fitted (see named_product).
all_options = dict(
    features=dict(
        updown=prepare_stim(np.abs(step_diff) > 0.01),  # any luminance change
        up=prepare_stim(step_diff > 0.01),  # luminance increases
        down=prepare_stim(step_diff < -0.01),  # luminance decreases
        category=transition_cont,  # binned, window-summed signed change
    ),
    population=dict(
        GC=get_population(full_dict, "GC", "steps"),
        IO=get_population(full_dict, "IO", "steps"),
    ),
    decoder=dict(
        svm=SVC(gamma="auto", probability=True),
        # `multi_class` is deprecated in recent scikit-learn. liblinear only does
        # one-vs-rest, so omitting it does not change the fit.
        linear=LogisticRegression(class_weight="balanced", solver="liblinear"),
    ),
)

In [None]:
# Save the option dicts. NOTE(review): all_options holds sklearn estimator instances
# (SVC, LogisticRegression) -- confirm flammkuchen round-trips them as intended.
fl.save("all_options.h5", all_options)

In [None]:
# Lazy iterator over all option combinations; it can only be consumed once.
option_it = named_product(**all_options)

In [None]:
def get_fit(opt, options=None):
    """Fit one decoder configuration and return its predicted probabilities.

    Parameters
    ----------
    opt : Product namedtuple (as yielded by ``named_product``) whose ``features``,
        ``population`` and ``decoder`` fields are keys into the sub-dicts of `options`.
    options : dict, optional
        Dict with "features", "population" and "decoder" sub-dicts. Defaults to the
        module-level ``all_options``, looked up at call time. That name is reassigned
        in several cells, so pass it explicitly to avoid depending on which cell ran last.

    Returns
    -------
    The prediction returned by ``decoding.decode_from_population`` with
    ``probabilities=True``; the fitted model is discarded.
    """
    if options is None:
        options = all_options
    luminance_transitions = options["features"][opt.features]
    print(opt)
    _model, pred = decoding.decode_from_population(
        options["population"][opt.population],
        luminance_transitions,
        n_test,
        model=options["decoder"][opt.decoder],
        hyperparams=dict(C=10.0 ** np.arange(-4, 4)),
        probabilities=True,
    )
    return pred

In [None]:
from joblib import Parallel, delayed

In [None]:
# Fit every option combination in parallel threads. Predictions are cached on disk;
# the fits only run when the cache file is missing.
transition_cache = "transition_results.h5"
if os.path.exists(transition_cache):
    results = fl.load(transition_cache)
else:
    # Use a fresh iterator, so the fit does not depend on option_it being unconsumed.
    results = Parallel(n_jobs=20, backend="threading")(
        delayed(get_fit)(option) for option in named_product(**all_options)
    )
    fl.save(transition_cache, results, compression="blosc")

In [None]:
# All option combinations as a list, in the same order as the fit results.
option_list = [*named_product(**all_options)]

In [None]:
# List every option combination; position i corresponds to results[i].
# NOTE(review): the loop variable `opt` leaks out of this cell and is read again in
# the "Shifts" section (`all_options["features"][opt.features]`).
for opt in option_list:
    print(opt)

In [None]:
# Confusion matrices for the binned transition feature ("category"), one per
# population/decoder pair. Selecting by feature name rather than hard-coded
# positions (12-15) stays correct if the option dicts change size or order.
for i_item, copt in enumerate(option_list):
    if copt.features != "category":
        continue
    confusion_mat_t = confmat(results[i_item],
                              np.tile(all_options["features"][copt.features], n_test))
    plt.figure()
    plt.imshow(confusion_mat_t, cmap=cmo.cm.tempo)
    plt.colorbar()
    plt.xlabel("Predicted transition")
    plt.ylabel("Actual transition")
    plt.title("transition {} decoded from {} with a {} decoder".format(copt.features, copt.population, copt.decoder))

### Trying to replicate plots

In [None]:
def plot_confmat(confusion_mat, opt=None):
    """Show a transition confusion matrix in a new figure (row 0 at the bottom).

    Parameters
    ----------
    confusion_mat : 2-D array; rows are labelled as actual and columns as predicted transitions.
    opt : Product namedtuple, optional
        Option (features, population, decoder) named in the title. Defaults to the
        module-level ``copt``, looked up at call time. Pass it explicitly so the title
        cannot silently describe a stale option.
    """
    if opt is None:
        opt = copt
    fig, ax = plt.subplots()
    im = ax.imshow(confusion_mat, cmap=cmo.cm.tempo, origin='lower')
    fig.colorbar(im, ax=ax)
    ax.set_title("transition {} decoded from {} with a {} decoder".format(opt.features, opt.population, opt.decoder))
    ax.set_xlabel('Predicted transition')
    ax.set_ylabel('Actual transition')

In [None]:
def get_fit(opt):
    """Fit one decoder configuration and return its predictions.

    NOTE(review): this re-defines, and shadows, the `get_fit` from the
    "Transition finding" section; the cells below call this definition.

    `opt` is a Product namedtuple (from `named_product`) whose `features`,
    `population` and `decoder` fields are keys into the module-level
    `all_options` dict. That dict is looked up at call time and reassigned by the
    cells below, so the fit depends on which `all_options` cell ran last.
    Returns the prediction from `decoding.decode_from_population` with
    `probabilities=True`; the fitted model is discarded.
    """
    luminance_transitions = all_options["features"][opt.features]
    print(opt)
    mod, pred  = \
            decoding.decode_from_population(all_options["population"][opt.population], luminance_transitions, n_test,
                                            model=all_options["decoder"][opt.decoder],
                                            hyperparams=dict(C=10.0 ** np.arange(-4, 4)), probabilities=True)
    return pred

### Run IO decoding

In [None]:
all_options = dict(features = dict(
    category=transition_cont
),
population = dict(IO=get_population(full_dict, "IO", "steps")
                   ),
decoder = dict(svm=SVC(gamma="auto", probability=True),
                ))

In [None]:
# Lazy iterator over the IO option combinations; it can only be consumed once.
option_it = named_product(**all_options)

In [None]:
# IO transition decoding. Predictions are cached on disk; the fit only runs when the
# cache file is missing.
transition_cache_IO = "transition_results_IO_final.h5"
if os.path.exists(transition_cache_IO):
    results = fl.load(transition_cache_IO)
else:
    # Use a fresh iterator, so the fit does not depend on option_it being unconsumed.
    results = Parallel(n_jobs=20, backend="threading")(
        delayed(get_fit)(option) for option in named_product(**all_options)
    )
    fl.save(transition_cache_IO, results, compression="blosc")

In [None]:
option_it = [*named_product(**all_options)]

# Single option here: confusion matrix of its predictions vs. the tiled category labels.
copt = option_it[0]
confusion_mat_t = confmat(results[0], np.tile(all_options["features"][copt.features], reps=n_test))

In [None]:
# Plot the IO transition confusion matrix.
plot_confmat(confusion_mat=confusion_mat_t)

### Run GC decoding

In [None]:
# Number of IO cells in the steps protocol; used as max_rois_incl for the GC populations.
n_incl_GCs = len(get_population(full_dict, "IO", "steps"))
n_incl_GCs

In [None]:
#Make options dict
all_options = dict(features = dict(
    category=transition_cont
),
population = dict(GC=get_population(full_dict, "GC", "steps", max_rois_incl=n_incl_GCs)
                   ),
decoder = dict(svm=SVC(gamma="auto", probability=True),
                ))


In [None]:
# GC transition decoding, repeated n_iters times; each iteration draws a new GC
# population capped at n_incl_GCs cells and refits. Results are cached on disk; the
# fits only run when the cache file is missing.
transition_cache_GC = "transition_results_GC_final.h5"
if os.path.exists(transition_cache_GC):
    results = fl.load(transition_cache_GC)
else:
    results = []
    for i in tqdm(range(n_iters)):
        # get_fit reads the module-level all_options, so rebind it on every iteration.
        all_options = dict(
            features=dict(category=transition_cont),
            population=dict(GC=get_population(full_dict, "GC", "steps", max_rois_incl=n_incl_GCs)),
            decoder=dict(svm=SVC(gamma="auto", probability=True)),
        )
        option_it = named_product(**all_options)
        results.append(Parallel(n_jobs=20, backend="threading")(delayed(get_fit)(opt) for opt in option_it))
    fl.save(transition_cache_GC, results, compression="blosc")

In [None]:
# One confusion matrix per GC subsampling iteration, stacked along axis 0.
# The option list and ground truth are the same for every iteration, so build them
# once, and iterate over the loaded results instead of relying on n_iters from an
# earlier cell.
option_it = list(named_product(**all_options))
copt = option_it[0]
gt_category = np.tile(all_options["features"][copt.features], n_test)

confusion_mats_list = [confmat(iter_results[0], gt_category) for iter_results in results]

confusion_mats = np.stack(confusion_mats_list)

In [None]:
# for i in range(n_iters):
#     plot_confmat(confusion_mats[i, :, :])

In [None]:
# Average GC confusion matrix over all subsampling iterations (NaN-aware).
plot_confmat(confusion_mat=np.nanmean(confusion_mats, axis=0))

# Shifts

Investigate whether a time-shift enables better prediction of transitions

In [None]:
def shift_ar(ar, i_shift):
    """Shift a 1-D array by `i_shift` samples, zero-padding; length and dtype are kept.

    Positive shifts move values later (zeros prepended); negative shifts move them
    earlier (zeros appended). Shifts of at least ``len(ar)`` give all zeros instead
    of an array of a different length. ``i_shift == 0`` returns the input array
    unchanged (not a copy).
    """
    ar = np.asarray(ar)
    n_pad = min(abs(i_shift), len(ar))
    if i_shift > 0:
        return np.concatenate([np.zeros(n_pad, dtype=ar.dtype), ar[:len(ar) - n_pad]])
    if i_shift < 0:
        return np.concatenate([ar[n_pad:], np.zeros(n_pad, dtype=ar.dtype)])
    return ar

In [None]:
# Signed luminance change per sample of the steps protocol (0 for the first sample).
# The conditions below threshold it at +-0.01, as the up/down/updown transition
# features do. Looking it up through the leaked loop variable `opt` returned the
# binned "category" codes instead, for which those thresholds are meaningless.
luminance_transitions = np.r_[0, np.diff(stim_steps)]

In [None]:
# Decode transitions (up, down, or either) from the IO population with the ground
# truth shifted by i_delay samples, to test whether a time shift improves prediction.
# IO steps population, uncut like the one used for transition decoding (presumably
# its time axis matches stim_steps).
transition_IO_pop = get_population(full_dict, "IO", "steps")

# Thresholded transition mask for each condition (independent of the delay).
transition_masks = dict(
    pos=luminance_transitions > 0.01,
    neg=luminance_transitions < -0.01,
    both=np.abs(luminance_transitions) > 0.01,
)

conditions = dict(pos=[], neg=[], both=[])
for cond, cl in conditions.items():
    ltr = transition_masks[cond]
    for i_delay in tqdm(range(-4, 5)):
        gt = shift_ar(ltr, i_delay)
        # `multi_class` is deprecated in recent scikit-learn. liblinear only does
        # one-vs-rest, so omitting it does not change the fit.
        tr_model_io_lin, tr_predictions_io_lin = decoding.decode_from_population(
            transition_IO_pop, gt > 0.1, n_test,
            model=LogisticRegression(class_weight="balanced", solver="liblinear"),
            hyperparams=dict(C=10.0 ** np.arange(-4, 2)),
            probabilities=True,
        )
        # One entry per delay: (tiled shifted ground truth, predicted probabilities).
        cl.append((np.tile(gt, n_test), tr_predictions_io_lin))

In [None]:
@jit(nopython=True)
def extract_around(signal, events, n_before=8, n_after=8):
    """Cut windows of `signal` around every event in `events`.

    An event is any truthy (nonzero / True) sample of `events`. For an event at
    index i_t the window is ``signal[i_t - n_before:i_t + n_after]``. Events too close
    to either end for a full window are skipped.

    Parameters
    ----------
    signal : 1-D array to cut windows from (expected to have the length of `events`).
    events : 1-D array marking the event samples.
    n_before, n_after : samples taken before the event, and from the event onward.

    Returns
    -------
    Array of shape (number of events with a full window, n_before + n_after) with
    the dtype of `signal`.
    """
    n_t = min(len(signal), len(events))
    # Event indices with a full window: start >= 0 and end <= n_t.
    first = n_before
    stop = n_t - n_after + 1
    # Count with the same test and range used for filling, so every allocated row is
    # written and the buffer cannot overflow.
    n_events = 0
    for i_t in range(first, stop):
        if events[i_t]:
            n_events += 1
    signals_around = np.empty((n_events, n_after + n_before), dtype=signal.dtype)
    i_ev = 0
    for i_t in range(first, stop):
        if events[i_t]:
            signals_around[i_ev, :] = signal[i_t - n_before:i_t + n_after]
            i_ev += 1
    return signals_around

In [None]:
# Shape of the predicted probabilities for the first delay (-4), "both" condition.
# (`preds` is the 2-D time-decoding array, not the shift results.)
conditions["both"][0][1].shape

In [None]:
plt.figure()
i_delay = 5
# Shift results for transitions in either direction at delay index i_delay;
# ground truth and prediction come from the same entry, so they line up.
gt_delay, proba_delay = conditions["both"][i_delay]
plt.plot(proba_delay[:, 1])
plt.plot(gt_delay)

In [None]:
plt.figure()
i_delay = 7
# Predicted transition probability around each (shifted) ground-truth transition.
gt_delay, proba_delay = conditions["both"][i_delay]
sigs_around = extract_around(proba_delay[:, 1], gt_delay)
plt.plot(sigs_around.T, color=(0, 0, 0, 0.5), lw=0.5)

In [None]:
i_d = 6
plt.figure()
# Tiled shifted ground truth and both probability columns for delay index i_d.
gt_delay, proba_delay = conditions["both"][i_d]
plt.plot(gt_delay)
plt.plot(proba_delay)

In [None]:
# Transition signal, repeated once per test split.
plt.plot(np.tile(luminance_transitions, reps=n_test))

In [None]:
# SVR regression of luminance from the full (uncut) GC steps population, the same
# population used for transition decoding (presumably aligned with stim_steps).
# The flashes-protocol `population` from the time-bin section does not match
# stim_steps, and `model` must be an estimator instance (SVR()), not the class.
model_svr, predictions_svr = decoding.decode_from_population(
    get_population(full_dict, "GC", "steps"),
    stim_steps,
    n_test,
    model=SVR(),
    hyperparams=dict(C=10.0 ** np.arange(-4, 4)),
)

In [None]:
# Ground truth matching predictions_svr: one copy of stim_steps per test split.
pred_gt = np.tile(stim_steps, n_test)

In [None]:
# Actual vs. SVR-predicted luminance, with labels so the figure stands alone.
svr_fig, svr_ax = plt.subplots()
svr_ax.plot(pred_gt, label="actual luminance")
svr_ax.plot(predictions_svr, label="SVR prediction")
svr_ax.set_xlabel("sample (test splits concatenated)")
svr_ax.set_ylabel("luminance")
svr_ax.legend();

In [None]:
fig, ax = plt.subplots()
# Small horizontal jitter so points at the same true luminance do not overlap.
jitter = np.random.randn(*pred_gt.shape) * 0.01
ax.scatter(pred_gt + jitter, predictions_svr)
ax.plot([0, 1], [0, 1], color=(0.1, 0.1, 0.1, 0.3))  # identity line
ax.set_aspect(1)