# Group-Specific Discriminant Analysis (GSDA) for Sex-Specific Lateralization: Running Demo

[Open in Colab](https://colab.research.google.com/github/shuo-zhou/GSDA-Lateralization/blob/main/gsda_demo.ipynb)  (click `Runtime` → `Run all (Ctrl+F9)`)

## Setup

The first code cell sets up the notebook's execution environment. It checks whether the notebook is running on Google Colab and, if so, installs the required package (`yacs`), clones the GSDA-Lateralization repository, and changes into it.


In [None]:
#@title  ---- setup environment and fetch code from GitHub ----
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    !pip install yacs
    !git clone https://github.com/shuo-zhou/GSDA-Lateralization
    %cd GSDA-Lateralization
else:
    print('Not running on CoLab')
     

## Import required modules

In [None]:
#@title ---- import modules ----
import os
from configs.default_cfg import get_cfg_defaults
from utils.experiment import run_experiment
from joblib import Parallel, delayed

from utils.io_ import load_result, reformat_results
from utils import plot

## Configurations

The custom configuration for this demo is stored in `configs/demo-hcp.yaml`. Where this file specifies a value, it overrides the corresponding default in `configs/default_cfg.py`; all other settings keep their defaults.
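
The override behaviour can be illustrated without `yacs`. The sketch below is not how `yacs` is implemented; it uses plain dictionaries and a hypothetical helper, `merge_overrides`, to show the idea that only the keys present in the override file replace their defaults. The key names mirror those used later in this notebook, but every value here is made up for illustration.

```python
from copy import deepcopy


def merge_overrides(defaults, overrides):
    """Return a copy of `defaults` with `overrides` applied recursively.

    Hypothetical helper for illustration; not part of yacs or this repository.
    """
    merged = deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested section: merge key by key instead of replacing it wholesale.
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


# Made-up default values (stand-in for default_cfg.py).
defaults = {
    "SOLVER": {"LAMBDA_": [0.0, 1.0], "SEED": 2023},
    "DATASET": {"TEST_RATIO": 0.2},
}
# Made-up overrides (stand-in for the demo .yaml file).
overrides = {"SOLVER": {"LAMBDA_": [0.0, 2.0, 5.0]}}

cfg_demo = merge_overrides(defaults, overrides)
print(cfg_demo["SOLVER"]["LAMBDA_"])      # value taken from the overrides
print(cfg_demo["SOLVER"]["SEED"])         # default kept, not mentioned in overrides
print(cfg_demo["DATASET"]["TEST_RATIO"])  # default kept, section not overridden
```

In the notebook itself this merge is handled by `cfg.merge_from_file(cfg_path)`, and `cfg.freeze()` then makes the resulting configuration read-only.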

In [None]:
#@title ---- setup configs ----
cfg_path = "configs/demo-hcp.yaml" # Path to `.yaml` config file

cfg = get_cfg_defaults()
cfg.merge_from_file(cfg_path)
cfg.freeze()
print(cfg)

## Model training

The cell below runs one experiment for each value in `cfg.SOLVER.LAMBDA_`, two at a time. Running all of them can take a while (about 15 to 25 minutes).

In [None]:
Parallel(n_jobs=2)(delayed(run_experiment)(cfg, lambda_) for lambda_ in cfg.SOLVER.LAMBDA_)
# run_experiment(cfg)

## Load and visualize results 

In [None]:
#@title ---- load results to a dataframe ----

dataset = cfg.DATASET.DATASET
model_root_dir = cfg.OUTPUT.ROOT
lambdas = cfg.SOLVER.LAMBDA_
seed_start = cfg.SOLVER.SEED
test_size = cfg.DATASET.TEST_RATIO

res_df = load_result(dataset=dataset, root_dir=model_root_dir,
                     lambdas=lambdas, seed_start=seed_start, test_size=test_size)

# GSI = 2 * acc_tgt * (acc_tgt - acc_nt), computed separately for each session
res_df["GSI_train_session"] = 2 * (res_df["acc_tgt_train_session"] *
                                   (res_df["acc_tgt_train_session"] -
                                    res_df["acc_nt_train_session"]))
res_df["GSI_test_session"] = 2 * (res_df["acc_tgt_test_session"] *
                                  (res_df["acc_tgt_test_session"] -
                                   res_df["acc_nt_test_session"]))

res_df_train_session = reformat_results(res_df, ["acc_tgt_train_session", "acc_nt_train_session"])
res_df_test_session = reformat_results(res_df, ["acc_tgt_test_session", "acc_nt_test_session"])

res_df.loc[res_df["train_group"] == 0, "train_group"] = "Male"
res_df.loc[res_df["train_group"] == 1, "train_group"] = "Female"
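
The two GSI columns computed in the cell above apply the same arithmetic to the train-session and test-session accuracies. A small worked example may make the behaviour of that expression easier to see. It assumes that `acc_tgt` and `acc_nt` are accuracies on the target and non-target groups, which this notebook does not state explicitly; the helper `gsi` and the input values are hypothetical.

```python
def gsi(acc_tgt, acc_nt):
    """GSI as computed in this notebook: 2 * acc_tgt * (acc_tgt - acc_nt).

    Hypothetical helper for illustration; the notebook computes the same
    expression directly on DataFrame columns.
    """
    return 2 * acc_tgt * (acc_tgt - acc_nt)


# Made-up accuracies, chosen only to show how the sign and size behave.
higher_on_target = gsi(0.75, 0.5)   # target accuracy exceeds non-target
equal_accuracy = gsi(0.75, 0.75)    # no gap between the two groups
lower_on_target = gsi(0.5, 0.75)    # target accuracy below non-target

print(higher_on_target)  # 0.375
print(equal_accuracy)    # 0.0
print(lower_on_target)   # -0.25
```

Under this expression the value is positive only when accuracy on the target group is higher than on the non-target group, and for a fixed gap it grows with the target accuracy itself.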

In [None]:
# Create the output folder for figures if it does not exist yet
os.makedirs("figures", exist_ok=True)
plot.plot_accuracy(res_df_train_session)

In [None]:
plot.plot_gsi(res_df, x="lambda", y="GSI_train_session", hue="train_group")