In [None]:
import numpy as np
import imageio as imio
import matplotlib.pyplot as plt
%matplotlib inline
from skimage.filters import try_all_threshold as tat
from skimage.filters import threshold_otsu as otsu
from skimage.morphology import remove_small_objects as rso
from skimage.morphology import watershed
from skimage.feature import peak_local_max
from sklearn import preprocessing
from scipy import ndimage as ndi
import nibabel as nib
from scipy.stats import pearsonr
import os
from skimage.feature import match_template
import math

## Functions

In [None]:
def cmpnt_tc(data, sm_dict, verbose=True):
    """Reconstruct a time course for each ground-truth component.

    Each spatial map in ``sm_dict`` is flattened (C order) into a row of an
    (n_components, n_voxels) matrix, which is multiplied by ``data.T``: the
    time course of a component is the map-weighted sum of the voxel series.

    Parameters
    ----------
    data : numpy.ndarray, shape (n_timepoints, n_voxels)
        Voxel time courses. The second axis must equal the number of elements
        of every spatial map (the map size is no longer hard-coded to 16384).
    sm_dict : dict
        Spatial maps keyed by the strings ``'1'`` .. ``str(len(sm_dict))``.
        Maps of any shape are accepted.
    verbose : bool, default True
        Print the intermediate shapes/keys (the original debugging output).

    Returns
    -------
    numpy.ndarray, shape (n_components, n_timepoints)
        Row ``i`` is the time course of component ``str(i + 1)``.
    """
    if verbose:
        print(data.shape)
        print(sm_dict.keys())
    # Keys are 1-based strings; walk them numerically so row i always holds
    # component i+1 regardless of the dict's insertion order.
    sm_matrix = np.asarray([np.ravel(sm_dict[str(sm)])
                            for sm in range(1, len(sm_dict) + 1)])
    if verbose:
        print(sm_matrix.shape)
    sm_tc = np.matmul(sm_matrix, np.transpose(data))
    if verbose:
        print(sm_tc.shape)
    return sm_tc

In [None]:
def temp_corr(tc_est, tc_gt, max_only=True, pos_th=0.6, neg_th=-0.6):
    """
    Find the temporal (Pearson) correlation between each weight matrix's
    reconstructed time course and each ground truth's reconstructed time course.

    Parameters
    ----------
    tc_est : sequence of 1-D arrays
        Estimated time courses, one per weight matrix ("neuron").
    tc_gt : sequence of 1-D arrays
        Ground-truth component time courses; each must have the same length
        as the estimated ones (``pearsonr`` requires equal lengths).
    max_only : bool, default True
        If True, return only the most positive and most negative match for
        each ground truth. Otherwise return every pair whose correlation is
        ``>= pos_th`` or ``<= neg_th``.
    pos_th, neg_th : float
        Thresholds, used only when ``max_only`` is False.

    Returns
    -------
    tuple of dict, when max_only is True
        ``(pos_tc, neg_tc)``, keyed ``'neuron_<i>__gt_<j>'`` (1-based). For
        each ground truth ``j`` they hold its max / min correlation.
    dict, when max_only is False
        ``temp_corrs``, keyed ``'neuron_<i>__sm_<j>'`` (note ``sm``, not ``gt``).
    """
    pos_tc = {}
    neg_tc = {}
    temp_corrs = {}
    for tct, gt_course in enumerate(tc_gt):
        # Correlation of this ground truth with every estimate, keyed by the
        # 0-based estimate index as a string.
        e_tcorr = {}
        for tce, est_course in enumerate(tc_est):
            tempcorr = pearsonr(gt_course, est_course)[0]
            e_tcorr[str(tce)] = tempcorr
            # Threshold the *current* pair only. The previous
            # any(... for v in e_tcorr.values()) became True as soon as an
            # earlier estimate crossed a threshold. From then on it also
            # recorded every later, possibly weak, correlation.
            if not max_only and (tempcorr >= pos_th or tempcorr <= neg_th):
                temp_corrs['neuron_{}__sm_{}'.format(tce + 1, tct + 1)] = tempcorr
        if max_only:
            max_tcorr = max(e_tcorr, key=e_tcorr.get)
            min_tcorr = min(e_tcorr, key=e_tcorr.get)
            pos_tc['neuron_{}__gt_{}'.format(int(max_tcorr) + 1, tct + 1)] = e_tcorr[max_tcorr]
            neg_tc['neuron_{}__gt_{}'.format(int(min_tcorr) + 1, tct + 1)] = e_tcorr[min_tcorr]
    if not max_only:
        return temp_corrs
    return pos_tc, neg_tc

In [None]:
def _plot_match_pair(key, corr, wt_dict, gt_dict):
    """Draw one figure: weight matrix (left) beside its ground-truth map (right)."""
    wt_name = key.split('__')[0]
    gt_label = key.split('__')[1]
    fig, (ax_wt, ax_gt) = plt.subplots(1, 2)
    fig.suptitle('{} corr={}'.format(key, corr))
    ax_wt.set_title(wt_name)
    ax_wt.imshow(wt_dict[wt_name], cmap='gray')
    ax_gt.set_title(gt_label)
    # gt_dict is indexed by the text after the last '_' of the key
    # (e.g. 'neuron_3__gt_7' -> '7').
    ax_gt.imshow(gt_dict[key.split('_')[-1]], cmap='gray')
    return fig


def plot_temp_corr(temp_corr_pos, temp_corr_neg, wt_dict, gt_dict, pos_th, neg_th):
    """Plot the temporally matched weight matrices and ground-truth component maps.

    Keys are expected in the form produced by ``temp_corr``
    (``'neuron_<i>__gt_<j>'``). The part before ``'__'`` indexes ``wt_dict``
    and the text after the last ``'_'`` indexes ``gt_dict``.

    Parameters
    ----------
    temp_corr_pos, temp_corr_neg : dict
        Positive / negative correlations keyed as above.
    wt_dict, gt_dict : dict
        Images shown with ``imshow`` (weight matrices, ground-truth maps).
    pos_th, neg_th : float
        A pair is plotted only if its correlation is strictly ``> pos_th``
        (positive dict) or strictly ``< neg_th`` (negative dict).

    Returns
    -------
    dict
        Every plotted key and its correlation, positive matches first.
    """
    strong_temp_corrs = {}
    # Both signs are plotted identically; only the selection test differs.
    selections = (
        (temp_corr_pos, lambda v: v > pos_th),
        (temp_corr_neg, lambda v: v < neg_th),
    )
    for corrs, is_strong in selections:
        for k, v in corrs.items():
            if is_strong(v):
                _plot_match_pair(k, v, wt_dict, gt_dict)
                strong_temp_corrs[k] = v
    return strong_temp_corrs

## Load the reconstructed time courses
##### Reconstruction was performed on the original machine on which the RBM was run

In [None]:
# Estimated time courses from the RBM run. temp_corr treats each row as one
# neuron's time course.
sim_tc = np.load(file='./timecourses.npy')
np.shape(sim_tc)

In [None]:
# # not needed
# subject_tc = np.asarray([sim_tc[:,ts:ts+400] for ts in range(0,sim_tc.shape[1],400)])
# subject_tc.shape

#### Load the raw data

In [None]:
# Simulated data (masked and standardized, per the filename); cmpnt_tc projects
# it onto the ground-truth spatial maps.
gt_tc = np.load(file='./simtb_masked_stdz.npy')
np.shape(gt_tc)

### Reconstruct the component-wise time courses

In [None]:
# NOTE(review): `sm_cmpnts` is not defined in any visible cell. cmpnt_tc expects
# a dict of ground-truth spatial maps keyed '1'..'N'. Define or load it above
# so that Restart & Run All works.
sm_tc = cmpnt_tc(data=gt_tc, sm_dict=sm_cmpnts)

### Calculate the temporal correlations

In [None]:
# Both inputs must share the time axis length for pearsonr.
for tc_arr in (sim_tc, sm_tc):
    print(tc_arr.shape)
pos_tcorrs, neg_tcorrs = temp_corr(tc_est=sim_tc, tc_gt=sm_tc)

### Print the temporal correlations for verification

In [None]:
# for k,v in pos_tcorrs.items():
#     print(k,v)

In [None]:
# for k,v in neg_tcorrs.items():
#     print(k,v)

### Plot the spatial maps corresponding to the temporal correlation matches

In [None]:
# NOTE(review): `wts_stdz` (weight matrices, indexed by keys like 'neuron_<i>')
# and `sm_cmpnts` are not defined in any visible cell.
# The thresholds equal temp_corr's default pos_th/neg_th.
plot_temp_corr(
    temp_corr_pos=pos_tcorrs,
    temp_corr_neg=neg_tcorrs,
    wt_dict=wts_stdz,
    gt_dict=sm_cmpnts,
    pos_th=0.6,
    neg_th=-0.6,
)

#### 21 of 27 (77.78%) ground-truth component time courses were reconstructed
#### 11 of 14 (78.57%) experimental component time courses were reconstructed

In [None]:
### Functional connectivity analysis
#### Plot all spatial correlations to the ground truth for every weight matrix that has at least one correlation >= the threshold

In [None]:
# NOTE(review): `plot_sm_fnc`, `scorrs` and `wts_stdz` are not defined in any
# visible cell, so this cell fails on a fresh kernel. `scorrs` is presumably the
# spatial correlations of the weight matrices to the ground-truth maps (per the
# header above) — define or import all three before this cell.
plot_sm_fnc(scorrs, wts_stdz)

| _ | _ |
| :-------------- | :--------: |
| Number of weight matrices with spatial correlations to two or more spatial components | 22 |
| Number of selected weights with at least partially correct FNCs | 20 (91%) |
| Number of selected weights with fully correct FNCs | 10 (45%) |
| Number of unique fully inaccurate FNCs | 1 |
| Number of accurately selected components | 10 (62.5%) |
| Number of inaccurate FNC components | 6 (37.5%) |
| Number of connected components represented | 10 of 14 (71%) |