In [None]:
import numpy as np
import imageio as imio
import matplotlib.pyplot as plt
%matplotlib inline
from skimage.filters import try_all_threshold as tat
from skimage.filters import threshold_otsu as otsu
from skimage.morphology import remove_small_objects as rso
from skimage.morphology import watershed
from skimage.feature import peak_local_max
from scipy import ndimage as ndi
import nibabel as nib
from scipy.stats import pearsonr
import os
from skimage.feature import match_template
import math

In [None]:
def plot_img(img, path_template='./neuron_{}.png'):
    """Load one neuron weight image, plot it raw and standardized, return the standardized array.

    Args:
        img: index substituted into ``path_template`` to build the file path.
        path_template: format string for the image path; the default is the
            location this notebook has always used.

    Returns:
        numpy.ndarray of float: channel 0 of the image (or the image itself when
        it is already 2-D) shifted to zero mean and scaled to unit standard
        deviation.
    """
    raw = imio.imread(path_template.format(img))
    # Multi-channel PNGs: keep channel 0 as before. 2-D (grayscale) PNGs are
    # used directly instead of failing on the [:, :, 0] index.
    if raw.ndim == 3:
        raw = raw[:, :, 0]
    plt.figure()  # fresh figure so the two panels below do not draw over the previous image
    ax1 = plt.subplot(121)
    ax1.set_title('raw image')
    ax1.imshow(raw, cmap='gray')
    stdized = raw.astype(float)  # cast first: integer pixel types would overflow/truncate
    stdized = (stdized - np.mean(stdized)) / np.std(stdized)
    ax2 = plt.subplot(122)
    ax2.set_title('stdized')
    ax2.imshow(stdized, cmap='gray')
    return stdized

In [None]:
def spat_corr(sm_est, sm_gt, max_only=True, pos_th=0.6, neg_th=-0.6, n_neurons=128):
    """Correlate every estimated neuron map with every ground-truth spatial map.

    Args:
        sm_est: dict mapping 'neuron_<i>' to an estimated map (any shape; it is
            flattened). Neurons missing from the dict are skipped.
        sm_gt: directory containing the ground-truth '<id>.nii' files.
        max_only: if True, keep only each neuron's highest and lowest correlation.
        pos_th, neg_th: used when max_only is False. A neuron's correlations are
            all kept when at least one is >= pos_th or <= neg_th.
        n_neurons: neurons 'neuron_1'..'neuron_<n_neurons>' are looked up.

    Returns:
        max_only=True: (pos_sc, neg_sc, sm), where pos_sc / neg_sc map
        '<neuron>__sm_<id>' to that neuron's highest / lowest Pearson r.
        max_only=False: (spat_corrs, sm), with every '<neuron>__sm_<id>' -> r
        of each neuron that passed the threshold test.
        In both cases sm maps '<id>' to the loaded ground-truth array.

    Raises:
        ValueError: if sm_gt contains no '.nii' files.
    """
    def _order(fname):
        # Numeric ids sort numerically (1, 2, 10); anything else sorts after, by name.
        stem = fname.split('.')[0]
        return (0, int(stem), '') if stem.isdigit() else (1, 0, stem)

    nii_files = sorted((f for f in os.listdir(sm_gt) if f.endswith('.nii')), key=_order)
    if not nii_files:
        raise ValueError('no .nii files found in {}'.format(sm_gt))

    # Ground-truth maps do not depend on the neuron: load and flatten each once
    # (previously every file was re-read for each of the 128 neurons).
    sm = {}
    gt_flat = {}
    for fname in nii_files:
        gt = fname.split('.')[0]
        sm[gt] = nib.load(os.path.join(sm_gt, fname)).get_fdata()
        gt_flat[gt] = np.ravel(sm[gt])

    pos_sc, neg_sc, spat_corrs = {}, {}, {}
    for n in range(1, n_neurons + 1):
        wt = 'neuron_{}'.format(n)
        est_sm = sm_est.get(wt)
        if est_sm is None:
            continue  # no estimate for this neuron; pearsonr cannot take None
        est_flat = np.ravel(est_sm)
        n_sc = {gt: pearsonr(est_flat, flat)[0] for gt, flat in gt_flat.items()}
        if max_only:
            max_sc = max(n_sc, key=n_sc.get)
            min_sc = min(n_sc, key=n_sc.get)
            pos_sc['{}__sm_{}'.format(wt, max_sc)] = n_sc[max_sc]
            neg_sc['{}__sm_{}'.format(wt, min_sc)] = n_sc[min_sc]
        elif any(v >= pos_th or v <= neg_th for v in n_sc.values()):
            # Decide after *all* maps are scored; testing the partially filled
            # dict inside the loop dropped the correlations that came before the
            # first strong one.
            for gt, sc in n_sc.items():
                spat_corrs['{}__sm_{}'.format(wt, gt)] = sc
    if not max_only:
        return spat_corrs, sm
    return pos_sc, neg_sc, sm

In [None]:
def spat_corr2(sm_est, sm_gt, n_neurons=128):
    """For every ground-truth spatial map, find the best- and worst-matching neuron map.

    This is the transpose of spat_corr: the outer loop runs over the
    ground-truth '.nii' files, the inner one over 'neuron_1'..'neuron_<n_neurons>'.

    Args:
        sm_est: dict mapping 'neuron_<i>' to an estimated map (flattened before
            correlating). Neurons missing from the dict are skipped.
        sm_gt: directory containing the ground-truth '<id>.nii' files.
        n_neurons: number of neuron keys to look up in sm_est.

    Returns:
        (pos_sc, neg_sc, sm): pos_sc / neg_sc map '<neuron>__sm_<id>' to the
        highest / lowest Pearson r for each ground-truth map, in ascending map-id
        order; sm maps '<id>' to the loaded ground-truth array.

    Raises:
        ValueError: if none of the requested neurons is present in sm_est.
    """
    def _order(fname):
        # os.listdir order is arbitrary; sort numeric ids numerically (1, 2, 10)
        # so the output order is reproducible. Non-numeric names sort after, by name.
        stem = fname.split('.')[0]
        return (0, int(stem), '') if stem.isdigit() else (1, 0, stem)

    nii_files = sorted((f for f in os.listdir(sm_gt) if f.endswith('.nii')), key=_order)

    # Flatten each available estimate once instead of once per ground-truth map.
    est_flat = {}
    for n in range(1, n_neurons + 1):
        wt = 'neuron_{}'.format(n)
        if sm_est.get(wt) is not None:
            est_flat[wt] = np.ravel(sm_est[wt])
    if not est_flat:
        raise ValueError('sm_est contains none of neuron_1..neuron_{}'.format(n_neurons))

    pos_sc, neg_sc, sm = {}, {}, {}
    for fname in nii_files:
        gt = fname.split('.')[0]
        gt_sm = nib.load(os.path.join(sm_gt, fname)).get_fdata()
        sm[gt] = gt_sm
        gt_flat = np.ravel(gt_sm)
        n_sc = {wt: pearsonr(gt_flat, est)[0] for wt, est in est_flat.items()}
        max_sc = max(n_sc, key=n_sc.get)
        min_sc = min(n_sc, key=n_sc.get)
        pos_sc['{}__sm_{}'.format(max_sc, gt)] = n_sc[max_sc]
        neg_sc['{}__sm_{}'.format(min_sc, gt)] = n_sc[min_sc]
    return pos_sc, neg_sc, sm

In [None]:
def cmpnt_tc(data, sm_dict, verbose=True):
    """Project every row of ``data`` onto each spatial map to obtain component time courses.

    Args:
        data: 2-D array whose rows each have as many elements as one flattened
            map (the product is ``maps @ data.T``). Presumably one row per time
            point — TODO confirm against the caller.
        sm_dict: dict keyed by integer strings ('1', '2', ...) holding maps of
            identical size.
        verbose: print the input shape, keys, stacked-map shape and result shape
            (the original always printed these).

    Returns:
        numpy.ndarray of shape (len(sm_dict), data.shape[0]); row i is the map
        with the i-th smallest numeric key dotted with every row of ``data``.

    Raises:
        ValueError: if sm_dict is empty, a key is not an integer string, or the
            maps differ in size.
    """
    if verbose:
        print(data.shape)
        print(sm_dict.keys())
    # Order by numeric id; works for any integer keys, not only a contiguous 1..N.
    keys = sorted(sm_dict, key=int)
    # Flatten each map instead of reshaping to a hard-coded 16384 elements;
    # np.stack still rejects maps of differing size.
    sm_matrix = np.stack([np.ravel(sm_dict[k]) for k in keys])
    if verbose:
        print(sm_matrix.shape)
    sm_tc = np.matmul(sm_matrix, np.transpose(data))
    if verbose:
        print(sm_tc.shape)
    return sm_tc

In [None]:
def temp_corr(tc_est, tc_gt, max_only=True, pos_th=0.6, neg_th=-0.6):
    """Correlate every ground-truth time course with every estimated time course.

    Args:
        tc_est: sequence (or 2-D array, one course per row) of estimated time
            courses; entry i is reported as 'neuron_<i+1>'.
        tc_gt: sequence (or 2-D array, one course per row) of ground-truth time
            courses; entry j is reported with id j+1.
        max_only: if True, keep only the highest and lowest correlation of each
            ground-truth course.
        pos_th, neg_th: used when max_only is False. All correlations of a
            ground-truth course are kept when at least one is >= pos_th or
            <= neg_th. These were previously referenced without being defined,
            so max_only=False raised NameError.

    Returns:
        max_only=True: (pos_tc, neg_tc) keyed 'neuron_<i>__gt_<j>'.
        max_only=False: dict keyed 'neuron_<i>__sm_<j>'.
    """
    pos_tc, neg_tc, temp_corrs = {}, {}, {}
    for gt_idx, gt_course in enumerate(tc_gt):
        e_tcorr = {est_idx: pearsonr(gt_course, est_course)[0]
                   for est_idx, est_course in enumerate(tc_est)}
        if max_only:
            best = max(e_tcorr, key=e_tcorr.get)
            worst = min(e_tcorr, key=e_tcorr.get)
            pos_tc['neuron_{}__gt_{}'.format(best + 1, gt_idx + 1)] = e_tcorr[best]
            neg_tc['neuron_{}__gt_{}'.format(worst + 1, gt_idx + 1)] = e_tcorr[worst]
        elif any(v >= pos_th or v <= neg_th for v in e_tcorr.values()):
            # Decide after every estimate is scored; checking the partially filled
            # dict inside the loop dropped correlations preceding the first strong one.
            for est_idx, corr in e_tcorr.items():
                temp_corrs['neuron_{}__sm_{}'.format(est_idx + 1, gt_idx + 1)] = corr
    if not max_only:
        return temp_corrs
    return pos_tc, neg_tc

In [None]:
def plot_spat_corr(pos_spat_corr_dict, neg_spat_corr_dict, wt_dict, gt_dict, pos_th, neg_th):
    """Draw every spatial-correlation pair that clears its threshold, side by side.

    Keys have the form '<neuron>__sm_<id>'. The left panel shows
    wt_dict['<neuron>'] and the right panel gt_dict['<id>'] (the text after the
    last '_'). Positive entries are kept when r >= pos_th, negative entries
    when r <= neg_th (both inclusive). All positive pairs are drawn before the
    negative ones.

    Returns:
        dict: key -> r for every pair that was drawn, in drawing order.
    """
    def _draw_pair(key, corr):
        parts = key.split('__')
        fig = plt.figure()
        fig.suptitle('{} corr={}'.format(key, corr))
        left_ax = plt.subplot(121)
        left_ax.set_title(parts[0])
        plt.imshow(wt_dict[parts[0]], cmap='gray')
        right_ax = plt.subplot(122)
        right_ax.set_title(parts[1])
        plt.imshow(gt_dict[key.split('_')[-1]], cmap='gray')

    selections = (
        (pos_spat_corr_dict, lambda corr: corr >= pos_th),
        (neg_spat_corr_dict, lambda corr: corr <= neg_th),
    )
    selected = {}
    for corr_dict, keep in selections:
        for key, corr in corr_dict.items():
            if keep(corr):
                _draw_pair(key, corr)
                selected[key] = corr
    return selected

In [None]:
def plot_temp_corr(temp_corr_pos, temp_corr_neg, wt_dict, gt_dict, pos_th, neg_th):
    """Draw every temporal-correlation pair that strictly exceeds its threshold.

    Keys have the form '<neuron>__<label>_<id>'. The left panel shows
    wt_dict['<neuron>'] and the right panel gt_dict['<id>'] (the text after the
    last '_'). Positive entries are kept when r > pos_th, negative entries when
    r < neg_th (both strict). All positive pairs are drawn before the negative
    ones.

    Returns:
        dict: key -> r for every pair that was drawn, in drawing order.
    """
    def _show(key, corr):
        pieces = key.split('__')
        figure = plt.figure()
        figure.suptitle('{} corr={}'.format(key, corr))
        est_ax = plt.subplot(121)
        est_ax.set_title(pieces[0])
        plt.imshow(wt_dict[pieces[0]], cmap='gray')
        gt_ax = plt.subplot(122)
        gt_ax.set_title(pieces[1])
        plt.imshow(gt_dict[key.split('_')[-1]], cmap='gray')

    criteria = (
        (temp_corr_pos, lambda corr: corr > pos_th),
        (temp_corr_neg, lambda corr: corr < neg_th),
    )
    strong = {}
    for source, passes in criteria:
        for key, corr in source.items():
            if not passes(corr):
                continue
            _show(key, corr)
            strong[key] = corr
    return strong

In [None]:
def plot_sm_fnc(spat_corr_dict, wt_dict, n_neurons=128, n_sm=27):
    """Bar-plot, per neuron, its correlations with the ground-truth spatial maps.

    Args:
        spat_corr_dict: dict keyed '<neuron>__sm_<id>' -> correlation (the key
            form spat_corr produces).
        wt_dict: not used; kept so existing calls keep working.
        n_neurons: neurons 'neuron_1'..'neuron_<n_neurons>' are considered.
        n_sm: x ticks are placed at 1..n_sm.

    One figure is drawn for each neuron with at least one entry. Bars are
    positioned at the numeric map id and ordered by it.
    """
    # Group entries by neuron in one pass instead of rescanning the whole dict
    # once per neuron.
    by_neuron = {}
    for key, corr in spat_corr_dict.items():
        neuron, _, sm_label = key.partition('__')
        by_neuron.setdefault(neuron, []).append((sm_label, corr))

    for wt in range(1, n_neurons + 1):
        entries = by_neuron.get('neuron_{}'.format(wt))
        if not entries:
            continue
        entries.sort(key=lambda entry: int(entry[0].split('_')[1]))
        f = plt.figure(figsize=(12.8, 4.8))
        f.suptitle('Wt matrix {}'.format(wt))
        plt.xticks(range(1, n_sm + 1))
        plt.bar([int(label.split('_')[1]) for label, _ in entries],
                [corr for _, corr in entries])

In [None]:
def max_corrs(pos_sm_corr, neg_sm_corr, wt_dict, gt_dict, n_sm=27):
    """For each ground-truth spatial map, show its strongest positive and negative neuron match.

    One figure per map id 1..n_sm, with up to three panels:
      1. the neuron map with the highest entry for that id in pos_sm_corr,
      2. the neuron map with the lowest entry for that id in neg_sm_corr,
      3. the ground-truth map gt_dict[str(id)].
    Panels 1 and 2 are omitted when the corresponding dict has no entry for the id.

    Args:
        pos_sm_corr, neg_sm_corr: dicts keyed '<neuron>__sm_<id>' -> correlation.
        wt_dict: neuron name -> image.
        gt_dict: map id string -> ground-truth image.
        n_sm: number of ground-truth maps to show.
    """
    for sm in range(1, n_sm + 1):
        sm_id = str(sm)
        f = plt.figure()
        f.suptitle('SM {}'.format(sm))
        # Separate dicts: the negative search used to reuse the positive dict, so
        # an SM without negative entries still got a "neg corr" panel computed
        # from its positive entries.
        pos = {k: v for k, v in pos_sm_corr.items() if k.split('_')[-1] == sm_id}
        neg = {k: v for k, v in neg_sm_corr.items() if k.split('_')[-1] == sm_id}
        if pos:
            best = max(pos, key=pos.get)
            ax1 = plt.subplot(131)
            ax1.set_title('pos corr={0:.3f}'.format(pos[best]))
            plt.imshow(wt_dict[best.split('__')[0]], cmap='gray')
        if neg:
            worst = min(neg, key=neg.get)
            ax2 = plt.subplot(132)
            ax2.set_title('neg corr={0:.3f}'.format(neg[worst]))
            plt.imshow(wt_dict[worst.split('__')[0]], cmap='gray')
        ax3 = plt.subplot(133)
        ax3.set_title('SM')
        plt.imshow(gt_dict[sm_id], cmap='gray')

In [None]:
# Load and plot every neuron image; keep the standardized arrays keyed 'neuron_<i>'.
wts_stdz1 = dict()
for img in range(1, 129):
    n_stdz = plot_img(img)
    wts_stdz1['neuron_%d' % img] = n_stdz

In [None]:
# Re-key in ascending numeric neuron order (neuron_2 before neuron_10).
wts_stdz2 = sorted(wts_stdz1.items(), key=lambda item: int(item[0].split('_')[1]))
wts_stdz = dict(wts_stdz2)
wts_stdz.keys()

In [None]:
# Relative path to the directory of ground-truth spatial-map '.nii' files that
# spat_corr / spat_corr2 read below (kept with its trailing '/').
gt_cmpnts = '../../sim_SM/'

In [None]:
# Per neuron: its highest (sc_pos) and lowest (sc_neg) correlation with any
# ground-truth map; sm_cmpnts holds the loaded maps keyed by file id.
sc_pos, sc_neg, sm_cmpnts = spat_corr(wts_stdz, gt_cmpnts)

In [None]:
# List each neuron's best positive spatial match as "<key> <r>".
for pair in sc_pos.items():
    print(*pair)

In [None]:
# List each neuron's most negative spatial match as "<key> <r>".
for pair in sc_neg.items():
    print(*pair)

In [None]:
# Plot the neuron/map pairs with r >= 0.6 or r <= -0.6 and keep them in spat_corrs.
spat_corrs = plot_spat_corr(sc_pos, sc_neg, wts_stdz, sm_cmpnts, 0.6, -0.6)

In [None]:
# Same comparison from the ground-truth side: best and worst neuron for each map.
# The returned maps are discarded; sm_cmpnts already holds them.
sc_pos2, sc_neg2, _ = spat_corr2(wts_stdz, gt_cmpnts)

In [None]:
# List each ground-truth map's best positive neuron match as "<key> <r>".
for pair in sc_pos2.items():
    print(*pair)

In [None]:
# List each ground-truth map's most negative neuron match as "<key> <r>".
for pair in sc_neg2.items():
    print(*pair)

In [None]:
# Plot the strong pairs (|r| >= 0.6) from the ground-truth-side comparison.
spat_corrs2 = plot_spat_corr(sc_pos2, sc_neg2, wts_stdz, sm_cmpnts, 0.6, -0.6)

In [None]:
# One figure per ground-truth map: its best positive and most negative neuron
# match, shown next to the map itself.
max_corrs(sc_pos, sc_neg, wts_stdz, sm_cmpnts)

In [None]:
# Load the simulated time courses saved next to this notebook; the cell displays their shape.
sim_tc = np.load('./timecourses.npy')
sim_tc.shape