# Select the channel subset that is sensitive to a parcel subset

In [None]:
import numpy as np
import cedalion
import cedalion.nirs
import cedalion.imagereco.forward_model as fw
import cedalion.datasets
import os
import cedalion.xrutils as xrutils
import cedalion.plots
import xarray as xr
import cedalion.geometry.landmarks as cd_landmarks
import matplotlib.pyplot as plt
import cedalion.sim.synthetic_hrf as synHRF_ced
from cedalion import units
import cedalion.dataclasses as cdc
import pyvista as pv
import cedalion.models.glm as glm
from cedalion.imagereco.solver import pseudo_inverse_stacked
#pv.set_jupyter_backend('server') # this enables interactive plots
import cedalion.sigproc.quality as quality
from cedalion.sigproc.quality import repair_amp
import random
import pickle
import cedalion.vis.plot_sensitivity_matrix

# Collapse the data section of xarray reprs so notebook cell output stays compact.
xr.set_options(display_expand_data=False)

import cedalion.imagereco.forward_model as fw
from cedalion.imagereco.solver import pseudo_inverse_stacked
import configs
from configs import load_dataset_configs

In [None]:
# Root directory of the local data files, taken from the project's `configs` module.
data_path = configs.data_path_prefix

# Alias to the ForwardModel class itself (not an instance). `parcel_sensitivity`
# is called on the class further down, so it is presumably a static/class
# method — TODO confirm.
fwm = fw.ForwardModel

In [None]:
# Load precomputed inputs. Every path is built with os.path.join so it is
# correct whether or not `data_path` ends with a path separator.
# NOTE(review): pickle.load can execute arbitrary code; these files are assumed
# to be trusted local artifacts.

# Parcel label files for the Colin and ICBM head models (presumably one label
# per vertex; parcels_icbm is attached to Adot's "vertex" dimension later).
with open(os.path.join(data_path, 'parcels_colin.pickle'), 'rb') as f:
    parcels_colin = pickle.load(f)  # not referenced in this section

with open(os.path.join(data_path, 'parcels_icbm.pickle'), 'rb') as f:
    parcels_icbm = pickle.load(f)

# Channel list used to further restrict parcel_dOD when data_type == 'BS_Laura'.
laura_sens_path = os.path.join(data_path, 'BS_Laura', "BS_Laura_YY_parcel_sens_channels")
with open(laura_sens_path, 'rb') as f:
    channel_roi_sens_laura = pickle.load(f)

# Presumably channels close to C3/C4 in the NN22 resting-state data (from the
# file name); not referenced in this section.
c3_c4_path = os.path.join(data_path, 'NN22_Resting_State', "NN_22_C3_C4_close_channels")
with open(c3_c4_path, 'rb') as f:
    c3_c4_rs = pickle.load(f)

In [None]:
# Dataset to process: 'HD_Squeezing' or 'BS_Laura' (the two values the loading
# cell handles). Switch by swapping which line is commented out.
data_type = 'HD_Squeezing'
#data_type = 'BS_Laura'

In [None]:
# Load the recording and the precomputed sensitivity matrix (Adot) for the
# chosen dataset, then keep only channels whose source-detector separation is
# below a cut-off.
if data_type == 'BS_Laura':
    laura_subject = 'sub-586'
    snirf_path = os.path.join(
        data_path, 'BS_Laura', 'BS_Laura_Data', laura_subject, 'nirs',
        f'{laura_subject}_task-BS_run-01_nirs.snirf',
    )
    # read_snirf returns a sequence of recordings; use the first one.
    rec = cedalion.io.read_snirf(snirf_path)[0]
    adot_path = os.path.join(data_path, 'BS_Laura', 'Adot', 'Adot_BSLaura_ICBM.pickle')
elif data_type == 'HD_Squeezing':
    rec = cedalion.datasets.get_fingertappingDOT()
    adot_path = os.path.join(data_path, 'HD_Squeezing', 'Adot', 'Adot_HDSqueezing_ICBM.pickle')
else:
    # Fail here rather than with a confusing NameError on `rec`/`Adot` later.
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'BS_Laura' or 'HD_Squeezing'"
    )

with open(adot_path, 'rb') as f:
    Adot = pickle.load(f)

geo3d = rec.geo3d
amp = rec["amp"]

# Maximum source-detector distance for a channel to be kept. It is compared
# against the bare magnitude (`dists.values`), so its unit is whatever geo3d
# uses — presumably mm; TODO confirm.
max_channel_dist = 44
dists = cedalion.xrutils.norm(geo3d.loc[amp.source] - geo3d.loc[amp.detector], dim=geo3d.points.crs)
selected_channels = list(dists.channel[dists.values < max_channel_dist].values)
amp = amp.sel(channel=selected_channels)

In [None]:
# Number of channels kept after the source-detector distance cut.
len(selected_channels)

In [None]:
# Parameters passed positionally to ForwardModel.parcel_sensitivity (its
# signature is not visible here — TODO confirm the exact meaning of each):
# a dOD threshold, a minimum channel count, and assumed HbO/HbR changes in µM.
dOD_thresh = 0.001
minCh = 1 
dHbO = 10 #µM 
dHbR = -3 #µM

# Attach the parcel labels as a coordinate along Adot's "vertex" dimension.
Adot = Adot.assign_coords(parcel = ("vertex", parcels_icbm))
# parcel_dOD is indexed later by channel, parcel and wavelength; parcel_mask is
# used as a boolean condition (presumably True for parcels passing the
# thresholds above — confirm).
parcel_dOD, parcel_mask = fwm.parcel_sensitivity(Adot, None, dOD_thresh, minCh, dHbO, dHbR)
# Names of the parcels where parcel_mask is True.
sensitive_parcels = parcel_mask.where(parcel_mask, drop=True)["parcel"].values.tolist()

In [None]:
# Restrict parcel_dOD to the channels that passed the distance cut. For
# BS_Laura, further restrict it to the precomputed channel list loaded earlier.
parcel_dOD = parcel_dOD.sel(channel=selected_channels)
if data_type == 'BS_Laura':
    parcel_dOD = parcel_dOD.sel(channel=channel_roi_sens_laura)

In [None]:
# Load the project's HD_Squeezing dataset configuration, including sensitivity
# info. This is hard-coded to 'HD_Squeezing' regardless of data_type; in this
# section it only feeds `hd_sq_par`.
dataset_configs = load_dataset_configs(data_types=['HD_Squeezing'], load_sensitivity=True)

In [None]:
# Sensitive parcels recorded in the HD_Squeezing dataset config.
hd_sq_par = dataset_configs['HD_Squeezing'].sensitive_parcels

# Sensitive parcels (computed above) whose name contains 'SomMotA', in order.
som_mot_a_parcels = list(filter(lambda parcel_name: 'SomMotA' in parcel_name, sensitive_parcels))

In [None]:
# Number of sensitive parcels in the HD_Squeezing config.
len(hd_sq_par)

In [None]:
# For each channel, take the maximum dOD at wavelength '850' over the ROI
# parcels. The commented line uses all HD_Squeezing-config parcels as the ROI;
# the active line uses only the SomMotA parcels.
#parcel_dOD_roi_max = parcel_dOD.sel(parcel=hd_sq_par, wavelength='850').max(dim='parcel')
parcel_dOD_roi_max = parcel_dOD.sel(parcel=som_mot_a_parcels, wavelength='850').max(dim='parcel')

In [None]:
# A channel counts as ROI-sensitive when its max dOD over the ROI parcels
# exceeds a threshold. Alternatives that were used before:
#   all-parcel ROI:                  parcel_dOD_roi_max > parcel_dOD_roi_max.mean()
#   SomMotA in BS_Laura, restrictive: parcel_dOD_roi_max > 0.01
# Current choice: SomMotA, less restrictive.
roi_dOD_threshold = 0.0035
parcel_dOD_roi_mask = parcel_dOD_roi_max > roi_dOD_threshold

In [None]:
# Hand-picked example HRF ROI channels, named by source/detector pair
# (S<source>D<detector>).
example_hrf_roi_channels = [
    'S52D106', 'S51D108', 'S52D111', 'S46D105', 'S48D107',
    'S46D112', 'S46D107', 'S40D19', 'S50D23', 'S53D30',
    'S11D110', 'S53D21', 'S11D111', 'S53D24', 'S53D32',
]

In [None]:
# Number of channels passing the ROI threshold (sum of a boolean mask).
parcel_dOD_roi_mask.sum()

In [None]:
# Integer indicator per channel: 1 if it passes the ROI threshold, else 0.
channel_array = xr.where(parcel_dOD_roi_mask, 1, 0)
# Also set the example HRF ROI channels (those present in channel_array) to 1,
# which merges them into the selection that is later extracted with == 1.
# NOTE(review): the comment here used to say "set entries to 2", which would
# mark them separately instead of merging them — confirm which is intended.
channel_array.loc[dict(channel=[ch for ch in example_hrf_roi_channels if ch in channel_array.channel.values])] = 1

In [None]:
# Inspect the per-channel indicator array.
channel_array

In [None]:
# Same display as the previous cell (duplicate cell).
channel_array

In [None]:
# Scalp plot of the channel selection on a 10x10 inch figure; values are 0/1,
# so the colour range is pinned to [0, 1] and no colorbar is drawn.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))

scalp_plot_options = dict(
    title="",
    vmin=0,
    vmax=1,
    cb_label="",
    add_colorbar=False,
)
cedalion.plots.scalp_plot(amp, geo3d, channel_array, ax, **scalp_plot_options)

In [None]:
# Names of the channels marked 1 in channel_array, as plain Python strings
# (.tolist() avoids pickling numpy string scalars).
parcel_roi_masked_channels = channel_array[channel_array == 1].channel.values.tolist()

# Output file per dataset; each lives in a sub-directory named after data_type.
sommota_channel_files = {
    'BS_Laura': "BS_Laura_SomMotA_sens_channels",
    'HD_Squeezing': "HD_Squeezing_SomMotA_sens_channels",
}
if data_type not in sommota_channel_files:
    # Previously write_path was silently left undefined for unknown data types.
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected one of {sorted(sommota_channel_files)}"
    )
write_path = os.path.join(data_path, data_type, sommota_channel_files[data_type])

# Writing is deliberately disabled; uncomment to (over)write the file at write_path.
#with open(write_path, 'wb') as f:
#    pickle.dump(parcel_roi_masked_channels, f)