# requirements

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats

from src.load_data import load_data

# load_data() hands back the dataset mapping and the demographics object.
fbirn_data, demographics = load_data()

# Pull the individual entries out of the dataset mapping in one go.
timeseries, diagnoses, sexes, ages = (
    fbirn_data[key] for key in ("data", "diags", "sexes", "ages")
)

print(demographics)
# Axis order of `timeseries` as reported below: subjects, timepoints, features.
print(f"# subjects: {timeseries.shape[0]}, # timepoints: {timeseries.shape[1]}, # features: {timeseries.shape[2]}")

In [None]:
# Read the component coordinate table and take its "Domain" column.
ica_coords = pd.read_csv("data/ICN_coordinates.csv")
domains = ica_coords["Domain"]
# Missing entries inherit the previous row's value (forward fill).
# Series.ffill() is used because fillna(method='ffill') is deprecated in
# recent pandas releases.
domains = domains.ffill()
domains = np.asarray(domains.tolist())

# change_idx marks the start index of each run of identical domain labels
change_idx = np.flatnonzero(np.r_[True, domains[1:] != domains[:-1]])
starts = change_idx
# compute ends (inclusive) indices for each group
ends = np.r_[starts[1:] - 1, domains.size - 1]
# tick positions: the middle of each group
centers = ((starts + ends) / 2.0).tolist()
# guide-line positions: (end + 0.4) for each group except the last
boundaries = (ends[:-1] + 0.4).tolist()

group_names_full = [domains[s] for s in starts]
# Short names used as tick labels.
# NOTE(review): assumes the table contains exactly these seven groups, in
# this order — confirm against group_names_full.
group_names = ["SC", "AU", "SM", "VIS", "CC", "DM", "CB"]  # short names

# 0. First milestone
- Write a general setup for the experiments
- Find important features using statistical tests
- Train classifiers, inspect the features that they found important

## derive PCC

In [None]:
from src.utils import corrcoef_batch

# Derive the PCC matrices from the time series with the project helper.
# Presumably this yields one correlation matrix per subject — TODO confirm
# against src.utils.corrcoef_batch.
pcc_matrices = corrcoef_batch(timeseries)
# Final expression of the cell: the shape of the result.
pcc_matrices.shape

## run stat tests on the data

In [None]:
def ttest(data0, data1):
    """Run an unequal-variance (Welch) two-sample t-test along axis 0.

    Returns the ``(statistic, p_value)`` pair produced by
    ``scipy.stats.ttest_ind`` with ``equal_var=False``.
    """
    result = stats.ttest_ind(data0, data1, axis=0, equal_var=False)
    return result.statistic, result.pvalue

def analyze_group_differences(data, labels, stat_func=ttest, alpha=0.05):
    """Compare the two groups of `data` defined by `labels`.

    Parameters
    ----------
    data : array
        Samples along axis 0; it is indexed with a boolean mask built from
        `labels`.
    labels : array-like
        One label per sample. Must contain exactly two distinct values.
    stat_func : callable
        Called as ``stat_func(group0, group1)`` and expected to return
        ``(statistic, p_value)``. The groups follow the sorted order of
        ``np.unique(labels)``.
    alpha : float
        Significance level applied to the p-values (default 0.05).

    Returns
    -------
    stat, p_value
        Whatever `stat_func` returned.
    p_thresh
        Integer array that is 1 where ``p_value < alpha`` and 0 elsewhere.
        NaN p-values compare as False and therefore map to 0.

    Raises
    ------
    ValueError
        If `labels` does not contain exactly two distinct values.
    """
    # Coerce so that plain lists also produce an element-wise boolean mask.
    labels = np.asarray(labels)
    groups = np.unique(labels)
    if groups.size != 2:
        # Previously any groups beyond the first two were silently ignored.
        raise ValueError(
            f"expected exactly 2 groups in labels, found {groups.size}"
        )
    group_data = [data[labels == g] for g in groups]

    stat, p_value = stat_func(group_data[0], group_data[1])
    p_thresh = (p_value < alpha).astype(int)

    # NOTE: disabled permutation-based p-values, kept for reference.
    # # permutation-based p-values
    # n_perm = 5000
    # rng = np.random.RandomState(42)
    # all_data = np.concatenate(group_data, axis=0)
    # n1 = group_data[0].shape[0]
    # perm_stats = np.empty((n_perm,) + stat.shape)

    # for i in range(n_perm):
    #     idx = rng.permutation(all_data.shape[0])
    #     g0 = all_data[idx[:n1]]
    #     g1 = all_data[idx[n1:]]
    #     tperm, _ = stats.ttest_ind(g0, g1, axis=0, equal_var=False)
    #     perm_stats[i] = tperm

    # perm_p = (np.sum(np.abs(perm_stats) >= np.abs(stat), axis=0) + 1) / (n_perm + 1)
    # perm_p[np.isnan(stat)] = np.nan
    # p_value_perm = perm_p

    return stat, p_value, p_thresh

def plot_heatmap(matrix, ax, cmap='bwr', vmin=None, vmax=None, guides_color='k'):
    """Show `matrix` on `ax` with group tick labels and separator lines.

    Tick positions, tick labels and separator positions come from the
    module-level `centers`, `group_names` and `boundaries`.

    Returns the image object created by ``ax.imshow`` so the caller can
    attach a colorbar to it.
    """
    image = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)

    # Same ticks and labels on both axes.
    ax.set_xticks(centers)
    ax.set_xticklabels(group_names)
    ax.set_yticks(centers)
    ax.set_yticklabels(group_names)

    # One horizontal and one vertical guide line per group boundary.
    guide_style = dict(color=guides_color, linewidth=1.5)
    for position in boundaries:
        ax.axhline(position, **guide_style)
        ax.axvline(position, **guide_style)

    return image

def plot_stats(stat, p_vals, p_thresh):
    """Plot the statistic, the p-values and the thresholded mask side by side.

    Each of the three inputs is drawn with `plot_heatmap` in its own panel
    with its own colorbar; the figure is shown immediately.
    """
    fig, ax = plt.subplots(1, 3, figsize=(14, 5))

    # (matrix, title, extra plot_heatmap arguments) for each panel, left to right
    panels = (
        (stat, "T-statistics", dict(vmin=-5, vmax=5)),
        (p_vals, "p-values", dict(vmin=0, vmax=1, cmap='inferno_r')),
        (p_thresh, "Significant p < 0.05", dict(vmin=0, vmax=1, cmap='inferno')),
    )

    colorbars = []
    for axis, (matrix, title, heatmap_kwargs) in zip(ax, panels):
        image = plot_heatmap(matrix, axis, **heatmap_kwargs)
        axis.set_title(title)
        colorbars.append(fig.colorbar(image, ax=axis, fraction=0.045))

    # The last panel is binary, so label its colorbar False/True.
    mask_bar = colorbars[-1]
    mask_bar.set_ticks([0, 1])
    mask_bar.set_ticklabels(['False', 'True'])

    plt.tight_layout()
    plt.show()

In [None]:
# plot means

# Split the PCC matrices by diagnosis label and average within each group.
groups = np.unique(diagnoses)
group_data = [pcc_matrices[diagnoses == g] for g in groups]
means = [np.mean(gd, axis=0) for gd in group_data]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# NOTE(review): the titles assume the first label returned by np.unique is the
# patient group and the second the control group — confirm the label encoding.
for idx, title in enumerate(("Patients Mean PCC", "Controls Mean PCC")):
    cax = plot_heatmap(means[idx], ax[idx], cmap='bwr', vmin=-1, vmax=1)
    fig.colorbar(cax, ax=ax[idx], fraction=0.045)
    ax[idx].set_title(title)
plt.tight_layout()
plt.show()

In [None]:
# Test every connection for a group difference and plot the outcome.
stat, p_value, p_thresh = analyze_group_differences(
    pcc_matrices,
    diagnoses,
)
plot_stats(stat, p_value, p_thresh)

# Share of True entries in p_thresh; I will use it as a threshold for ML
# feature importance selection.
r_significant = p_thresh.sum() / p_thresh.size
print(f"Proportion of significant connections: {r_significant:.4f}")

## use ML to find predictive features

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

def forest_features(X, y, n_estimators=2000, threshold=0.5, random_state=None):
    """Rank connections by random-forest feature importance.

    Parameters
    ----------
    X : array
        Stack of square matrices, indexed as ``X[sample, row, col]``. Only
        the strictly lower triangle of each matrix is used as features.
    y : array-like
        Class label per sample, passed to the classifier as-is.
    n_estimators : int
        Number of trees in the forest.
    threshold : float
        Fraction (0..1) of matrix entries to keep in the mask; the cut-off is
        the ``100 * (1 - threshold)`` percentile of the full importance
        matrix.
    random_state : int, RandomState instance or None
        Forwarded to ``RandomForestClassifier``. Pass a fixed value to make
        the importances reproducible. The default ``None`` keeps the previous
        non-deterministic behavior.

    Returns
    -------
    importances : ndarray, shape (C, C)
        Symmetric matrix of feature importances with a zero diagonal.
    mask : ndarray of bool, shape (C, C)
        True where the importance is strictly above the percentile cut-off.
    """
    n_nodes = X.shape[1]
    tril_indices = np.tril_indices(n_nodes, k=-1)
    # Indexing with both index arrays already yields (n_samples, n_pairs).
    features = X[:, tril_indices[0], tril_indices[1]]

    clf = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
    )
    clf.fit(features, y)

    # Scatter the per-pair importances back into a symmetric matrix.
    importances = np.zeros((n_nodes, n_nodes))
    importances[tril_indices] = clf.feature_importances_
    importances = importances + importances.T

    # Mask of the top `threshold` fraction of entries.
    cutoff = np.percentile(importances, 100 * (1 - threshold))
    mask = importances > cutoff

    return importances, mask

In [None]:
# Keep the same fraction of top-importance features as the fraction of
# connections the statistical test flagged (r_significant).
importances, importance_mask = forest_features(
    pcc_matrices,
    diagnoses,
    threshold=r_significant,
)

In [None]:
# Three panels: raw importances, their logarithm, and the top-feature mask.
fig, ax = plt.subplots(1, 3, figsize=(14, 5))
cax = plot_heatmap(importances, ax[0], cmap='inferno', guides_color='white')
fig.colorbar(cax, ax=ax[0], fraction=0.045)
ax[0].set_title("Feature Importances")

# The importance matrix has exact zeros (at least on its diagonal), so the
# logarithm is -inf there. Compute it with the divide-by-zero warning
# silenced; the resulting values are the same as before.
with np.errstate(divide='ignore'):
    log_importances = np.log(importances)
cax = plot_heatmap(log_importances, ax[1], cmap='inferno', guides_color='white')
fig.colorbar(cax, ax=ax[1], fraction=0.045)
ax[1].set_title("Log Feature Importances")

cax = plot_heatmap(importance_mask, ax[2], cmap='inferno')
ax[2].set_title("Significant Features")
cbar3 = fig.colorbar(cax, ax=ax[2], fraction=0.045)
# Binary mask: label the colorbar False/True.
cbar3.set_ticks([0, 1])
cbar3.set_ticklabels(['False', 'True'])
plt.tight_layout()
plt.show()

In [None]:
# plot p_thresh and importance_mask, and AND matrix


def _mask_density(mask):
    """Fraction of entries of `mask` that are set."""
    return np.sum(mask)/np.size(mask)


def _binary_panel(figure, axis, mask, title):
    """Draw a binary mask with a False/True colorbar; return (image, colorbar)."""
    image = plot_heatmap(mask, axis, cmap='inferno')
    axis.set_title(title)
    bar = figure.colorbar(image, ax=axis, fraction=0.045)
    bar.set_ticks([0, 1])
    bar.set_ticklabels(['False', 'True'])
    return image, bar


fig, ax = plt.subplots(1, 3, figsize=(14, 5))

# Left: connections flagged by the statistical test.
sp_p = _mask_density(p_thresh)
cax1, cbar1 = _binary_panel(fig, ax[0], p_thresh, f"p < 0.05 ({int(sp_p*100)}% density)")

# Middle: connections selected by the random forest.
sp_rf = _mask_density(importance_mask)
cax2, cbar2 = _binary_panel(fig, ax[1], importance_mask, f"Random Forest results ({int(sp_rf*100)}% density)")

# Right: connections selected by both.
and_matrix = p_thresh & importance_mask
sp_and = _mask_density(and_matrix)
cax3, cbar3 = _binary_panel(fig, ax[2], and_matrix, f"AND Matrix ({int(sp_and*100)}% density)")

plt.tight_layout()
plt.show()

### compare classification scores with different subsets of predictive features 

# More statistics

I will use pyspi to compute the statistics. See the pyspi documentation: https://time-series-features.gitbook.io/pyspi/installing-and-using-pyspi/usage/walkthrough-tutorials/getting-started-a-simple-demonstration

In [None]:
from pyspi.calculator import Calculator
import dill

# Output directory for the pickled calculators (loop-invariant).
save_path = "/Users/ppopov1/adm-proj/data/pyspi/"
# Subjects with a lower index are skipped.
# NOTE(review): presumably they were processed in an earlier run — confirm.
first_subject = 222

for i, test_subject in enumerate(timeseries):

    if i < first_subject:
        continue
    test_subject = test_subject.T # pyspi calc expects data in shape [channels, time]
    calc = Calculator(dataset=test_subject, configfile='./custom_config.yaml') # instantiate the calculator object

    calc.compute()

    # Save the computed calculator; the context manager closes the file
    # handle, which the previous bare open() call never did.
    with open(save_path+f"pyspi_calc_{i:03d}.pkl", 'wb') as f:
        dill.dump(calc, f)

    # To load a saved calculator back:
    # with open(save_path+f"pyspi_calc_{i:03d}.pkl", 'rb') as f:
    #     calc_load = dill.load(f)

In [None]:
# Count how many SPI tables of the last computed calculator are NaN-free
# above the diagonal.
counter = 0

for spi_type in calc.spis:
    # Look up the table of the SPI being iterated. The previous code always
    # read "cov_EmpiricalCovariance", so every SPI reported the same result.
    spi_data = calc.table[spi_type].to_numpy()
    # get values above the diagonal
    # NOTE(review): only the upper triangle is inspected; for asymmetric
    # (directed) SPIs the lower triangle is not checked — confirm intent.
    off_diag = spi_data[np.triu_indices_from(spi_data, k=1)]
    nans_off_diag = np.isnan(off_diag).any()
    if not nans_off_diag:
        counter += 1
    else:
        print(F"{spi_type} data has nans off-diagonal: {nans_off_diag}")

print(F"Number of SPI types without nans off-diagonal: {counter} out of {len(calc.spis)}")

In [None]:
# Look at one SPI table; the cell's final expression is the shape of its
# array form.
cov_table = calc.table["cov_EmpiricalCovariance"]
cov_table.to_numpy().shape