# Search HPS model
- get the inputs correct
- run through all steps
- get an accuracy estimate
- now with subtype weights

In [1]:
%matplotlib inline

In [2]:
import os
import re
import sys
import time
sys.path.append('/home/surchs/git/HPS')
from hps.predic import high_confidence
from hps.visu import hps_visu
sys.path.append('/home/surchs/git/HPS/examples/')
import visu_demo
import scipy as sp
import patsy as pat
import numpy as np
import pandas as pd
import nibabel as nib
import sklearn as skl
import scipy.io as sio
import seaborn as sbn
from scipy import cluster as scl
from scipy import stats as spt
from matplotlib import pyplot as plt
from sklearn import linear_model as sln
from sklearn import preprocessing as skp
from sklearn.model_selection import StratifiedKFold

In [3]:
import warnings
# Suppress every warning for the rest of the session. Note that this also hides
# deprecation warnings raised by the libraries used below.
warnings.simplefilter('ignore')

In [4]:
# n_seed: number of seed maps per subject (also selects the MIST_<n_seed> folder)
# n_subtypes: number of clusters requested from the hierarchical clustering
n_seed, n_subtypes = 7, 4

In [5]:
# Paths
# Everything below hangs off one absolute project root; change root_p to
# relocate the whole analysis.
root_p = '/home/surchs/sim_big/PROJECT/abide_hps/'
# Pheno
# Phenotype table: one row per subject, read with pandas further down
sample_p = os.path.join(root_p, 'pheno', 'psm_abide1.csv')
# Data
# Directory holding the per-subject cortical thickness text files
ct_p = os.path.join(root_p, 'ct')
# Directory holding the per-subject seed map arrays for the chosen n_seed
seed_p = os.path.join(root_p, 'seed', 'MIST_{}'.format(n_seed))
# Mask image; its non-zero voxels define n_vox
mask_p = os.path.join(root_p, 'mask', 'MIST_mask.nii.gz')
# Semicolon-separated ROI label table
label_p = os.path.join(root_p, 'mask', 'roi_label_scale_20_overlap.csv')
# File templates
# CT file name: filled with Site, Subject (zero-padded to 7), Session, Run and
# the hemisphere ('left' / 'right') in regress_ct
ct_t = '{}+{:07}_{}+{}_native_rms_rsl_tlaplace_30mm_{}.txt'
# Seed file name: n_seed is baked in here, the remaining '{}' is filled with
# the subject ID in regress_fc
sd_t = 'sub_{{}}_mist_{}.npy'.format(n_seed)

In [6]:
# Load data
sample = pd.read_csv(sample_p)
sample['DX_CODE'] = sample['DX_GROUP'].replace({'Autism':1, 'Control':0})
label = pd.read_csv(label_p, delimiter=';')

In [7]:
# Load the mask image and turn it into a boolean voxel mask.
mask_i = nib.load(mask_p)
# `get_data()` is deprecated (removed in nibabel 5.0); reading the array through
# `dataobj` is the documented drop-in replacement and works on older versions too.
mask = np.asanyarray(mask_i.dataobj).astype(bool)
# Number of voxels inside the mask
n_vox = np.sum(mask)

# Run the CV model

In [8]:
def corr2_coeff(A,B):
    """Correlate every row of A with every row of B.

    Parameters
    ----------
    A : 2D array, shape (n_a, n_features)
    B : 2D array, shape (n_b, n_features)

    Returns
    -------
    2D array, shape (n_a, n_b); entry (i, j) is the Pearson correlation
    between row i of A and row j of B. Rows with zero variance produce
    non-finite values because of the division by zero.
    """
    # Centre each row on its own mean
    a_centred = A - A.mean(axis=1)[:, np.newaxis]
    b_centred = B - B.mean(axis=1)[:, np.newaxis]

    # Sum of squared deviations per row
    a_ss = np.square(a_centred).sum(axis=1)
    b_ss = np.square(b_centred).sum(axis=1)

    # Cross products divided by the product of the row norms
    cross = a_centred.dot(b_centred.T)
    scale = np.sqrt(np.outer(a_ss, b_ss))
    return cross / scale

In [9]:
def subtype(stack, n_subtypes):
    """Cluster the rows of `stack` hierarchically (Ward linkage).

    Parameters
    ----------
    stack : 2D array, one row per subject
    n_subtypes : maximum number of clusters to cut the tree into

    Returns
    -------
    order : list with the leaf order of the dendrogram
    part : 1D array of cluster labels per row (as returned by `fcluster`)
    dist : condensed pairwise distance vector between the standardized rows
    """
    # Standardize each row before measuring distances
    standardized = skp.scale(stack, axis=1)
    # Condensed (upper triangle) pairwise distances
    condensed = sp.spatial.distance.pdist(standardized)
    # Build the tree, then read off leaf order and flat partition
    tree = scl.hierarchy.linkage(condensed, method='ward')
    leaves = scl.hierarchy.dendrogram(tree, no_plot=True)['leaves']
    assignment = scl.hierarchy.fcluster(tree, n_subtypes, criterion='maxclust')
    return leaves, assignment, condensed

In [10]:
def regress_fc(sample, formula, n_vox, n_seed, seed_p, sd_t):
    """Regress nuisance covariates out of the seed maps of all subjects.

    Parameters
    ----------
    sample : DataFrame with one row per subject; needs a 'SUB_ID' column and
        every column referenced by `formula`
    formula : patsy formula describing the nuisance design matrix
    n_vox : number of values per seed map
    n_seed : number of seed maps per subject
    seed_p : directory containing the per-subject .npy files
    sd_t : file name template with one placeholder for the subject ID

    Returns
    -------
    Array of shape (n_sub, n_vox, n_seed) with the residuals of a linear
    model fitted separately for every seed.
    """
    n_sub = sample.shape[0]
    dmat_seed = pat.dmatrix(formula, data=sample)

    # Read each subject file exactly once and keep all of its seeds. The
    # array is filled with raw maps first and overwritten with residuals below.
    resid_seed = np.zeros((n_sub, n_vox, n_seed))
    # Line index doesn't necessarily match continuous index, hence enumerate
    for rid, (_, row) in enumerate(sample.iterrows()):
        p = os.path.join(seed_p, sd_t.format(row['SUB_ID']))
        d = np.load(p)
        for sid in range(n_seed):
            resid_seed[rid, :, sid] = d[sid, ...]

    for sid in range(n_seed):
        # Contiguous copy of the raw maps for this seed
        sub_seed = np.ascontiguousarray(resid_seed[..., sid])
        # Build the regression model for the seed maps. `normalize=True` is
        # not passed any more: it was removed in scikit-learn 1.2 and does not
        # change ordinary least-squares predictions, so the residuals are the same.
        mod = sln.LinearRegression(fit_intercept=True, n_jobs=-1)
        res = mod.fit(dmat_seed, sub_seed)
        resid_seed[..., sid] = sub_seed - res.predict(dmat_seed)

    return resid_seed

In [11]:
def regress_ct(sample, formula, ct_p, ct_t):
    """Regress nuisance covariates out of the cortical thickness maps.

    Parameters
    ----------
    sample : DataFrame with one row per subject; needs 'Site', 'Subject',
        'Session' and 'Run' columns plus every column referenced by `formula`
    formula : patsy formula describing the nuisance design matrix
    ct_p : directory containing the thickness text files
    ct_t : file name template taking site, subject, session, run, hemisphere

    Returns
    -------
    Array of shape (n_sub, n_vert) with the residuals; n_vert is the length of
    the concatenated left + right hemisphere vector of the first subject.

    Raises
    ------
    ValueError if `sample` has no rows.
    """
    n_sub = sample.shape[0]
    if n_sub == 0:
        # Previously this surfaced later as an unbound `sub_ct`
        raise ValueError('regress_ct needs at least one subject in sample')

    sub_ct = None
    # Line index doesn't necessarily match continuous index, hence enumerate
    for rid, (_, row) in enumerate(sample.iterrows()):
        hemispheres = []
        # Left first, then right: this fixes the vertex order of the output
        for side in ('left', 'right'):
            p_side = os.path.join(ct_p, ct_t.format(row['Site'], row['Subject'],
                                                    row['Session'], row['Run'], side))
            hemispheres.append(pd.read_csv(p_side, header=None)[0].values)
        # Combine left and right
        ct = np.concatenate(hemispheres)
        if sub_ct is None:
            # Size the output from the first subject
            sub_ct = np.zeros((n_sub, len(ct)))
        sub_ct[rid, :] = ct

    dmat_ct = pat.dmatrix(formula, data=sample)
    # `normalize=True` is not passed any more: it was removed in scikit-learn
    # 1.2 and does not change ordinary least-squares predictions.
    mod = sln.LinearRegression(fit_intercept=True, n_jobs=-1)
    res = mod.fit(dmat_ct, sub_ct)
    resid_ct = sub_ct - res.predict(dmat_ct)

    return resid_ct

In [12]:
def make_subtype_fc(resid, n_subtypes):
    """Derive subtype maps and subject weights for every seed.

    Parameters
    ----------
    resid : array of shape (n_sub, n_vox, n_seed)
    n_subtypes : number of clusters requested per seed

    Returns
    -------
    subtypes_fc : array (n_subtypes, n_vox, n_seed); the mean map of the
        subjects assigned to each cluster label 1..n_subtypes
    weights_fc : array (n_sub, n_subtypes, n_seed); correlation of each
        subject map with each subtype map
    """
    n_sub, n_vox, n_seed = resid.shape
    weights_fc = np.zeros((n_sub, n_subtypes, n_seed))
    subtypes_fc = np.zeros((n_subtypes, n_vox, n_seed))

    for sid in range(n_seed):
        seed_resid = resid[..., sid]
        # Only the partition is needed here; the leaf order and distances were
        # previously stored in large arrays that nothing ever read.
        _, part_fc, _ = subtype(seed_resid, n_subtypes)
        # Make the subtypes: average the members of each cluster label.
        # NOTE(review): a label with no members would yield a NaN map - confirm
        # whether fcluster can return fewer than n_subtypes clusters here.
        subtypes_fc_tmp = np.array([np.mean(seed_resid[part_fc==i, :], 0)
                                    for i in range(1, n_subtypes+1)])
        subtypes_fc[..., sid] = subtypes_fc_tmp
        # Compute the weights
        weights_fc[..., sid] = corr2_coeff(seed_resid, subtypes_fc_tmp)
    return subtypes_fc, weights_fc

In [13]:
def make_subtype_ct(resid, n_subtypes):
    """Derive subtype maps and subject weights from one 2D residual matrix.

    Returns a tuple (subtypes_ct, weights_ct): the mean row of the subjects in
    each cluster label 1..n_subtypes, and the correlation of every subject
    row with every subtype.
    """
    _, part_ct, _ = subtype(resid, n_subtypes)
    # Average the members of each cluster, label by label
    templates = []
    for cluster_id in range(1, n_subtypes+1):
        members = resid[part_ct==cluster_id, :]
        templates.append(np.mean(members, 0))
    subtypes_ct = np.array(templates)
    # Subject-by-subtype correlations
    weights_ct = corr2_coeff(resid, subtypes_ct)
    return (subtypes_ct, weights_ct)

In [14]:
def make_weights_fc(subtypes, resid):
    """Correlate subject maps with existing subtype maps, seed by seed.

    subtypes has shape (n_subtypes, n_vox, n_seed) and resid has shape
    (n_sub, n_vox, n_seed); the result has shape (n_sub, n_subtypes, n_seed).
    """
    n_sub, n_vox, n_seed = resid.shape
    weights_fc = np.zeros((n_sub, subtypes.shape[0], n_seed))
    for seed_idx in range(n_seed):
        subject_maps = resid[..., seed_idx]
        subtype_maps = subtypes[..., seed_idx]
        weights_fc[..., seed_idx] = corr2_coeff(subject_maps, subtype_maps)
    return weights_fc

In [15]:
def make_weights_ct(subtypes, resid):
    """Correlate every row of `resid` with every row of `subtypes`."""
    return corr2_coeff(resid, subtypes)

In [16]:
# Get the full range of subject indices and clinical labels
sub_indices = sample.index.values
labels = sample['DX_CODE'].values

In [17]:
# Names of the weight columns. The FC names vary fastest over the seed, which
# is the order produced by reshaping the (subject, subtype, seed) weights.
ct_cols = ['ct_s{}'.format(sbt_no)
           for sbt_no in range(1, n_subtypes + 1)]
fc_cols = ['fc_n{}_s{}'.format(seed_no, sbt_no)
           for sbt_no in range(1, n_subtypes + 1)
           for seed_no in range(1, n_seed + 1)]
cols = ct_cols + fc_cols
# Model inputs: three phenotype covariates followed by all subtype weights.
# Variant without age: ['BV', 'FD_scrubbed'] + cols
col_features = ['BV', 'AGE_AT_SCAN', 'FD_scrubbed'] + cols

In [18]:
# Per-fold outputs: stage 1 / stage 2 predictions and the true test labels
scores_s1_l = list()
scores_s2_l = list()
y_target_l = list()

# Number of CV folds; also drives the remaining-time estimate below
n_splits = 10
# Stratify on the diagnosis codes read straight from the table. The global
# `labels` is overwritten with a one-hot matrix further down the notebook, so
# relying on it would break this cell when it is re-run.
cv_labels = sample['DX_CODE'].values

start = time.time()
took = []
skf = StratifiedKFold(n_splits=n_splits)
for cv_idx, (train_pos, test_pos) in enumerate(skf.split(sub_indices, cv_labels)):
    cv_start = time.time()

    # skf.split yields positions; translate them into index labels before
    # using .loc (identical for a default RangeIndex)
    train_index = sub_indices[train_pos]
    test_index = sub_indices[test_pos]

    # Get the train, and test sample
    train_sample = sample.loc[train_index]
    test_sample = sample.loc[test_index]
    n_sub_train = train_sample.shape[0]
    n_sub_test = test_sample.shape[0]

    # Replicate the subtyping process
    # Extract the train and test data and regress nuisance factors
    # (the nuisance model is fitted separately on each of the two samples)
    train_resid_fc = regress_fc(train_sample,
                                'AGE_AT_SCAN + FD_scrubbed + Site',
                                n_vox, n_seed=n_seed,
                                seed_p=seed_p, sd_t=sd_t)
    test_resid_fc = regress_fc(test_sample,
                               'AGE_AT_SCAN + FD_scrubbed + Site',
                               n_vox, n_seed=n_seed,
                               seed_p=seed_p, sd_t=sd_t)
    train_resid_ct = regress_ct(train_sample, 'AGE_AT_SCAN + Site', ct_p, ct_t)
    test_resid_ct = regress_ct(test_sample, 'AGE_AT_SCAN + Site', ct_p, ct_t)
    # Make the subtypes from the train data
    (subtypes_fc, train_weights_fc) = make_subtype_fc(train_resid_fc, n_subtypes=n_subtypes)
    (subtypes_ct, train_weights_ct) = make_subtype_ct(train_resid_ct, n_subtypes=n_subtypes)
    # Get the test weights against the training subtypes
    test_weights_fc = make_weights_fc(subtypes_fc, test_resid_fc)
    test_weights_ct = make_weights_ct(subtypes_ct, test_resid_ct)

    # Build input data: flatten (subtype, seed) so that it matches `fc_cols`
    train_fc = np.reshape(train_weights_fc, (n_sub_train, n_subtypes*n_seed))
    test_fc = np.reshape(test_weights_fc, (n_sub_test, n_subtypes*n_seed))
    train_w = np.concatenate((train_weights_ct, train_fc), 1)
    test_w = np.concatenate((test_weights_ct, test_fc), 1)

    # Make sure we use the correct index or else there will be NaNs in the weight columns
    w_data_train = pd.DataFrame(data=train_w, columns=cols, index=train_index)
    data_train = train_sample.join(w_data_train)
    w_data_test = pd.DataFrame(data=test_w, columns=cols, index=test_index)
    data_test = test_sample.join(w_data_test)

    # Select the features
    scaler = skl.preprocessing.StandardScaler()
    x_train = data_train.loc[:, col_features]
    # Normalize
    X_train = scaler.fit_transform(x_train)
    # Take the numeric diagnosis code, 0 is control, 1 is autism
    y_train = data_train.loc[:, ['DX_CODE']].values.squeeze()

    # Same for the test data
    x_test = data_test.loc[:, col_features]
    # Normalize, but use the fitted scaler of the training data
    X_test = scaler.transform(x_test)
    y_test = data_test.loc[:, ['DX_CODE']].values.squeeze()

    # Train the model
    # NOTE(review): no random seed is set anywhere in this notebook; if
    # TwoStagesPrediction resamples internally the scores will vary between
    # runs - confirm whether it accepts a random state.
    hps = high_confidence.TwoStagesPrediction(verbose=False,
                                              n_iter=1000,
                                              shuffle_test_split=0.5,
                                              gamma=1,
                                              min_gamma=0.95,
                                              thresh_ratio=0.2)
    hps.fit(X_train, y_train)
    scores, dic_results = hps.predict(X_test)
    scores_s1_l.append(dic_results['s1_hat'])
    scores_s2_l.append(dic_results['s2_hat'])
    y_target_l.append(y_test)

    # Timing report
    current_duration = time.time()-cv_start
    took.append(current_duration)
    avg_time = np.mean(took)
    elapsed_time = np.sum(took)
    # Folds still to run, derived from n_splits instead of a hardcoded 9
    remaining_time = avg_time * (n_splits-1-cv_idx)

    print('CV fold {} done. Took {:.2f}s ({:.2f}s), {:.2f}s total, {:.2f}s to go.'.format(cv_idx+1,
                                                                              current_duration,
                                                                              avg_time,
                                                                              elapsed_time,
                                                                              remaining_time))

Stage 1
Stage 2
CV fold 1 done. Took 167.32s (167.32s), 167.32s total, 1505.84s to go.
Stage 1
Stage 2
CV fold 2 done. Took 163.13s (165.22s), 330.45s total, 1321.79s to go.
Stage 1
Stage 2
CV fold 3 done. Took 174.95s (168.47s), 505.40s total, 1179.27s to go.
Stage 1
Stage 2
CV fold 4 done. Took 119.45s (156.21s), 624.85s total, 937.27s to go.
Stage 1
Stage 2
CV fold 5 done. Took 168.77s (158.72s), 793.62s total, 793.62s to go.
Stage 1
Stage 2
CV fold 6 done. Took 149.45s (157.18s), 943.08s total, 628.72s to go.
Stage 1
Stage 2
CV fold 7 done. Took 140.95s (154.86s), 1084.03s total, 464.58s to go.
Stage 1
Stage 2
CV fold 8 done. Took 171.23s (156.91s), 1255.26s total, 313.82s to go.
Stage 1
Stage 2
CV fold 9 done. Took 126.80s (153.56s), 1382.06s total, 153.56s to go.
Stage 1
Stage 2
CV fold 10 done. Took 162.27s (154.43s), 1544.32s total, 0.00s to go.


In [19]:
# One-hot encode the diagnosis codes; `ohe` is reused below to encode the
# held-out targets in the same way.
y = sample.DX_CODE.values.squeeze()
try:
    # scikit-learn >= 1.2 spells the dense-output switch `sparse_output`
    ohe = skl.preprocessing.OneHotEncoder(sparse_output=False)
except TypeError:
    # Older releases only know the original `sparse` keyword
    ohe = skl.preprocessing.OneHotEncoder(sparse=False)
ohe.fit(y.reshape(-1, 1))
# NOTE: this rebinds `labels` from the 1D diagnosis codes to a one-hot matrix
labels = ohe.transform(y.reshape(-1, 1))

In [20]:
# Stack the per-fold results into one array per stage / one target vector
scores_s1_arr = np.vstack(scores_s1_l)
scores_s2_arr = np.vstack(scores_s2_l)
y_target_arr = np.hstack(y_target_l)

# Print the scores of both stages against the one-hot encoded targets
print('##########################')
for stage_title, pred_y_ in (('Stage 1 (BASE)', scores_s1_arr),
                             ('Stage 2 (HPS)', scores_s2_arr)):
    y_mb = ohe.transform(y_target_arr[:,np.newaxis])
    print(stage_title)
    hps_visu.print_scores(hps_visu.scores(y_mb, pred_y_))
print('##########################')

##########################
Stage 1 (BASE)
Class 0 Precision: 58.06 Specificity: 57.14 Recall: 57.45 N: 186
Class 1 Precision: 56.52 Specificity: 57.45 Recall: 57.14 N: 184
Total Precision: 57.29 Specificity: 57.29 Recall: 57.29 N: 185
Stage 2 (HPS)
Class 0 Precision: 69.05 Specificity: 92.86 Recall: 15.43 N: 42
Class 1 Precision: 62.07 Specificity: 94.15 Recall:  9.89 N: 29
Total Precision: 65.56 Specificity: 93.50 Recall: 12.66 N: 35
##########################


# Findings

1. min_gamma = 0.9, gamma=0.98, min_thresh=0.2, n_iter=1000, split_ratio=0.5:
    - prec: 85.71%
    - spec: 97.87%
    - sens: 13.19%
2. min_gamma = 0.85, gamma=1, min_thresh=0.25, n_iter=1000, split_ratio=0.5:
    - prec: 75.68%
    - spec: 95.21%
    - sens: 15.38%
3. min_gamma = 0.99, gamma=1, min_thresh=0.25, n_iter=1000, split_ratio=0.5:
    - prec: something around 85
    - spec: something around 99
    - sens: something around 1% (I got 2 labeled).
4. min_gamma = 0.96, gamma=1, min_thresh=0.25, n_iter=1000, split_ratio=0.5:
    - prec: 66.67
    - spec: 97.87
    - sens: 4.40 (12)
5. min_gamma = 0.9, gamma=1, min_thresh=0.3, n_iter=1000, split_ratio=0.5:
    - prec: 85.71
    - spec: 97.81
    - sens: 13.19
6. min_gamma = 0.7, gamma=1, min_thresh=0.3, n_iter=1000, split_ratio=0.5:
    - prec: 70.27
    - spec: 88.30
    - sens: 28.12
7. N_sbt=3 (so far 5). min_gamma = 0.9, gamma=0.98, min_thresh=0.1, n_iter=1000, split_ratio=0.5:
    - prec: 73.68
    - spec: 94.68
    - sens: 15.38
8. N_sbt=3 (so far 5). min_gamma = 0.9, gamma=0.98, min_thresh=0.2, n_iter=1000, split_ratio=0.5:
    - prec: 68.09
    - spec: 92.02
    - sens: 17.58
9. Scale 12, N_sbt=3 (so far 5). min_gamma = 0.9, gamma=0.98, min_thresh=0.2, n_iter=1000, split_ratio=0.5:
    - prec: 74.51
    - spec: 93.09
    - sens: 20.88
10. Scale 12, N_sbt=5. min_gamma = 0.9, gamma=1, min_thresh=0.2, n_iter=1000, split_ratio=0.5:
    - prec: 74.93
    - spec: 95.67
    - sens: 13.46
11. Scale 12, N_sbt=3. min_gamma = 0.95, gamma=1, min_thresh=0.3, n_iter=500, split_ratio=0.5 (NO AGE):
    - prec: 75
    - spec: 94.68
    - sens: 16.48
12. Scale 12, N_sbt=4. min_gamma = 0.95, gamma=1, min_thresh=0.2, n_iter=500, split_ratio=0.5 (NO AGE):
    - prec: 84.62
    - spec: 97.87
    - sens: 12.09
13. Scale 12, N_sbt=4. min_gamma = 0.95, gamma=1, min_thresh=0.2, n_iter=500, split_ratio=0.5 (WITH AGE):
    - prec: 91.30
    - spec: 98.94
    - sens: 11.54
14. Scale 20, N_sbt=4. min_gamma = 0.95, gamma=1, min_thresh=0.2, n_iter=500, split_ratio=0.5 (WITH AGE):
    - prec: 81.25
    - spec: 98.40
    - sens: 7.14