In [1]:
from dask.distributed import Client, LocalCluster
import logging

cluster = LocalCluster(
    n_workers=28,
    threads_per_worker=8,
    silence_logs=logging.DEBUG
)

client = Client(cluster, heartbeat_interval=10000)
print(client.dashboard_link)

distributed.scheduler - INFO - Clear task state
distributed.scheduler - INFO -   Scheduler at:     tcp://127.0.0.1:39379
distributed.scheduler - INFO -   dashboard at:            127.0.0.1:8787
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:43165'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:44125'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:44871'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:44723'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:37765'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:36195'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:39935'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:42569'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:38645'
distributed.nanny - INFO -         Start Nanny at: 'tcp://127.0.0.1:35029'
distributed.nanny - INFO -         Start Nanny at: 'tcp:

http://127.0.0.1:8787/status


In [2]:
import afqinsight as afqi
import joblib
import matplotlib.pyplot as plt
import numpy as np
import os.path as op
import pandas as pd
import pickle
import seaborn as sns

from datetime import datetime

from sklearn.base import clone
from sklearn.model_selection import RepeatedKFold
from sklearn.metrics import median_absolute_error, r2_score
from sklearn.metrics import explained_variance_score, mean_squared_error
from sklearn.linear_model import LassoCV

from skopt import BayesSearchCV
from skopt.plots import plot_convergence, plot_objective, plot_evaluations

print(afqi.__version__)

0.2.9.dev460469908


In [3]:
X, y, groups, columns, subjects, classes = afqi.load_afq_data(
    "../data/raw/age_data",
    target_cols=["Age"],
)

In [4]:
label_sets = afqi.multicol2sets(pd.MultiIndex.from_tuples(columns, names=["metric", "tractID", "nodeID"]))

In [5]:
pyafq_bundles = [
    c for c in columns
    if c[1] not in ["Right Cingulum Hippocampus", "Left Cingulum Hippocampus"]
]
pyafq_bundles = [
    [c] for c in np.unique([col[1] for col in pyafq_bundles])
]

In [6]:
X_pyafq_bundles = afqi.select_groups(
    X,
    pyafq_bundles,
    label_sets
)

In [7]:
print(X.shape)
print(X_pyafq_bundles.shape)
print(len(label_sets))

(77, 10000)
(77, 9000)
10000


In [8]:
columns = [
    c for c in columns 
    if c[1] not in ["Right Cingulum Hippocampus", "Left Cingulum Hippocampus"]
]
label_sets = afqi.multicol2sets(pd.MultiIndex.from_tuples(columns, names=["metric", "tractID", "nodeID"]))

X_md_fa = afqi.select_groups(
    X_pyafq_bundles,
    [["fa"], ["md"]],
    label_sets
)

In [9]:
print(X.shape)
print(X_pyafq_bundles.shape)
print(X_md_fa.shape)

(77, 10000)
(77, 9000)
(77, 3600)


In [10]:
groups_md_fa = groups[:36]

In [11]:
def get_cv_results(n_repeats=5, n_splits=10,
                   power_transformer=False, 
                   shuffle=False,
                   ensembler=None,
                   target_transform_func=None,
                   target_transform_inverse_func=None,
                   n_estimators=10,
                   trim_nodes=0,
                   square_features=False):
    if shuffle:
        rng = np.random.default_rng()
        y_fit = rng.permutation(y)
    else:
        y_fit = np.copy(y)

    if trim_nodes > 0:
        grp_mask = np.zeros_like(groups_md_fa[0], dtype=bool)
        grp_mask[trim_nodes:-trim_nodes] = True
        X_mask = np.concatenate([grp_mask] * len(groups_md_fa))

        groups_trim = []
        start_idx = 0
        
        for grp in groups_md_fa:
            stop_idx = start_idx + len(grp) - 2 * trim_nodes
            groups_trim.append(np.arange(start_idx, stop_idx))
            start_idx += len(grp) - 2 * trim_nodes
            
        X_trim = X_md_fa[:, X_mask]
    elif trim_nodes == 0:
        groups_trim = [grp for grp in groups_md_fa]
        X_trim = np.copy(X_md_fa)
    else:
        raise ValueError("trim_nodes must be non-negative.")
        
    if square_features:
        _n_samples, _n_features = X_trim.shape
        X_trim = np.hstack([X_trim, np.square(X_trim)])
        groups_trim = [np.concatenate([g, g + _n_features]) for g in groups_trim]
    
    cv = RepeatedKFold(
        n_splits=n_splits,
        n_repeats=n_repeats,
        random_state=1729
    )

    cv_results = {}    
    
    pipe_skopt = afqi.pipeline.make_base_afq_pipeline(
        imputer_kwargs={"strategy": "median"},
        power_transformer=power_transformer,
        scaler="standard",
        estimator=LassoCV,
        estimator_kwargs={
            "verbose": 0,
            "n_alphas": 50,
            "cv": 3,
            "n_jobs": 28,
            "max_iter": 500,
        },
        verbose=0,
        ensemble_meta_estimator=ensembler,
        ensemble_meta_estimator_kwargs={
            "n_estimators": n_estimators,
            "n_jobs": 1,
            "oob_score": True,
            "random_state": 1729,
        },
        target_transform_func=target_transform_func,
        target_transform_inverse_func=target_transform_inverse_func,
    )

    for cv_idx, (train_idx, test_idx) in enumerate(cv.split(X_trim, y_fit)):
        start = datetime.now()

        X_train, X_test = X_trim[train_idx], X_trim[test_idx]
        y_train, y_test = y_fit[train_idx], y_fit[test_idx]

        with joblib.parallel_backend("dask"):
            pipe_skopt.fit(X_train, y_train)

        cv_results[cv_idx] = {
            "pipeline": pipe_skopt,
            "train_idx": train_idx,
            "test_idx": test_idx,
            "y_pred": pipe_skopt.predict(X_test),
            "y_true": y_test,
            "test_mae": median_absolute_error(y_test, pipe_skopt.predict(X_test)),
            "train_mae": median_absolute_error(y_train, pipe_skopt.predict(X_train)),
            "test_r2": r2_score(y_test, pipe_skopt.predict(X_test)),
            "train_r2": r2_score(y_train, pipe_skopt.predict(X_train)),
        }
        
        if ((target_transform_func is not None)
            or (target_transform_inverse_func is not None)):
            cv_results[cv_idx]["coefs"] = [
                est.coef_ for est
                in pipe_skopt.named_steps["estimate"].regressor_.estimators_
            ]
            cv_results[cv_idx]["alpha"] = [
                est.alpha_ for est
                in pipe_skopt.named_steps["estimate"].regressor_.estimators_
            ]
        else:
            cv_results[cv_idx]["coefs"] = [
                est.coef_ for est
                in pipe_skopt.named_steps["estimate"].estimators_
            ]
            cv_results[cv_idx]["alpha"] = [
                est.alpha_ for est
                in pipe_skopt.named_steps["estimate"].estimators_
            ]
        
        if ensembler is None:
            if ((target_transform_func is not None)
                or (target_transform_inverse_func is not None)):
                cv_results[cv_idx]["optimizer"] = pipe_skopt.named_steps["estimate"].regressor_.bayes_optimizer_                
            else:
                cv_results[cv_idx]["optimizer"] = pipe_skopt.named_steps["estimate"].bayes_optimizer_

        print(f"CV index [{cv_idx:3d}], Elapsed time: ", datetime.now() - start)
        
    return cv_results, y_fit

In [12]:
results = {}

trim_nodes = 0
results[f"bagging_pure_lasso_trim{trim_nodes}"] = get_cv_results(
    n_splits=10, n_repeats=1, power_transformer=False,
    ensembler="serial-bagging", shuffle=False, n_estimators=20,
    trim_nodes=trim_nodes, square_features=False,
)

results[f"bagging_target_transform_pure_lasso_trim{trim_nodes}"] = get_cv_results(
    n_splits=10, n_repeats=1, power_transformer=False,
    ensembler="serial-bagging", shuffle=False, n_estimators=20,
    target_transform_func=np.log, target_transform_inverse_func=np.exp,
    trim_nodes=trim_nodes, square_features=False,
)

distributed.scheduler - INFO - Receive client connection: Client-worker-6f130502-619c-11eb-90fa-c77ba38258a6
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-6f1141da-619c-11eb-90fa-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-6f11e606-619c-11eb-90fa-42010a8a0002
distributed.core - INFO - Starting established connection
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-741b142e-619c-11eb-9291-cbd01f75bb23
distributed.core - INFO - Starting established connection
  positive)
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-7afb2f82-619c-11eb-9104-d5b9c361e6db
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-7af9d2ba-619c-11eb-9104-42010a8a0002
distributed.core - I

CV index [  0], Elapsed time:  0:01:10.569597


  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-9b6973b8-619c-11eb-9321-a710d86cafda
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-9b67f0da-619c-11eb-9321-42010a8a0002
distributed.core - INFO - Starting established connection
  positive)
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-a0910000-619c-11eb-916b-eb501747007b
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-a08ef01a-619c-11eb-916b-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-a08fd166-619c-11eb-916b-42010a8a0002
distributed.core - INFO - Starting established connection
  positive)
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-a50d2f6e-619c-11eb-910d-814c40f1b

CV index [  1], Elapsed time:  0:01:12.758517


distributed.scheduler - INFO - Receive client connection: Client-worker-c474a3b6-619c-11eb-9371-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-c4776ec2-619c-11eb-9371-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-c4755028-619c-11eb-9371-42010a8a0002
distributed.core - INFO - Starting established connection
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-c9190ecc-619c-11eb-912d-8742d1ea5a56
distributed.core - INFO - Starting established connection
  positive)
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-ced32ed2-619c-11eb-9100-0d79429d727f
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-ced20f90-619c-11eb-9100-42010a8a0002
distributed.core - I

CV index [  2], Elapsed time:  0:01:04.382285


distributed.scheduler - INFO - Receive client connection: Client-worker-ecfd338a-619c-11eb-93fd-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-ecff7e82-619c-11eb-93fd-42010a8a0002
distributed.core - INFO - Starting established connection
  positive)
  positive)
distributed.scheduler - INFO - Receive client connection: Client-worker-f1e21f88-619c-11eb-9257-275d12d021a4
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-f1e1202e-619c-11eb-9257-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-f1dfd9a4-619c-11eb-9257-42010a8a0002
distributed.core - INFO - Starting established connection
distributed.scheduler - INFO - Receive client connection: Client-worker-f5962a30-619c-11eb-93d1-3765a2138d39
distributed.core - INFO - Starti

CV index [  3], Elapsed time:  0:01:06.933526


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  4], Elapsed time:  0:01:07.805319


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  5], Elapsed time:  0:01:06.356548


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  6], Elapsed time:  0:01:03.735059


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  7], Elapsed time:  0:01:13.812823


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  8], Elapsed time:  0:01:07.490455


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  9], Elapsed time:  0:01:09.639459


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  0], Elapsed time:  0:01:16.272958


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  1], Elapsed time:  0:01:18.252264


  positive)
  positive)


CV index [  2], Elapsed time:  0:01:08.465886


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  3], Elapsed time:  0:01:12.507264


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  4], Elapsed time:  0:01:09.281263


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  5], Elapsed time:  0:01:14.867106


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  6], Elapsed time:  0:01:05.705771


  positive)
  positive)
  positive)
  positive)


CV index [  7], Elapsed time:  0:01:18.637546


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  8], Elapsed time:  0:01:16.516677


  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)
  positive)


CV index [  9], Elapsed time:  0:01:16.198299


  positive)


In [13]:
with open("age_regression_pure_lasso.pkl", "wb") as fp:
    pickle.dump(results, fp)

In [14]:
results.keys()

dict_keys(['bagging_pure_lasso_trim0', 'bagging_target_transform_pure_lasso_trim0'])

In [15]:
for key, res in results.items():
    test_mae = [cvr["test_mae"] for cvr in res[0].values()]
    train_mae = [cvr["train_mae"] for cvr in res[0].values()]
    test_r2 = [cvr["test_r2"] for cvr in res[0].values()]
    train_r2 = [cvr["train_r2"] for cvr in res[0].values()]

    print(key, "test  MAE", np.mean(test_mae))
    print(key, "train MAE", np.mean(train_mae))
    print(key, "test  R2 ", np.mean(test_r2))
    print(key, "train R2 ", np.mean(train_r2))

bagging_pure_lasso_trim0 test  MAE 4.423867405561406
bagging_pure_lasso_trim0 train MAE 1.9784476594799372
bagging_pure_lasso_trim0 test  R2  0.5075985597516021
bagging_pure_lasso_trim0 train R2  0.888762570901591
bagging_target_transform_pure_lasso_trim0 test  MAE 3.8852243861875415
bagging_target_transform_pure_lasso_trim0 train MAE 1.3700674941862947
bagging_target_transform_pure_lasso_trim0 test  R2  0.5018346093008995
bagging_target_transform_pure_lasso_trim0 train R2  0.8507616898440411
