# rPCA + Group Registration

In [13]:
import glob
import logging
import multiprocessing
import os
import time
from pathlib import Path

import hydra
import numpy as np
import pandas as pd
import scipy.io
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph_group as vxm  # nopep8

## Weight = 0.3, rPCA ranks = 10, 5, 5

In [14]:
def main(path, weight=0.3, rank1=10):
    """Load an OmegaConf YAML config and override its rPCA settings.

    Args:
        path: Config file path handed to ``OmegaConf.load``.
        weight: Value written to the top-level ``weight`` key.
        rank1: Value written to ``rpca_rank.rank1``.

    Returns:
        A fresh config built with ``OmegaConf.structured`` from the
        resolved (interpolations expanded) container of the overridden config.
    """
    loaded = OmegaConf.load(path)
    loaded['weight'] = weight
    loaded.rpca_rank.rank1 = rank1
    # Resolve interpolations now so the returned config holds concrete values.
    resolved = OmegaConf.to_container(loaded, resolve=True)
    return OmegaConf.structured(resolved)

# Base experiment config: rPCA weight 0.3, first rank 10.
conf = main('../conf/config.yaml', weight=0.3, rank1=10)

# Evaluate the registered images written by registration round 3.
conf.round = 3
conf.moved = os.path.join(
    conf.inference,
    f"round{conf.round}",
    'moved',
)

In [None]:
# Compare the T1-fitting error of the raw (``conf.moving``) and registered
# (``conf.moved``) volumes for every subject, then save and log a summary.
col = ['Cases', 'raw MSE', 'registered MSE', 'raw PCA',
       'registered PCA', 'raw T1err', 'registered T1err']
# One dict per subject; turned into a DataFrame once after the loop rather
# than growing a DataFrame row by row.
rows = []

# Load the TI for all subjects, if a TI file is configured. Initialised to
# None so the per-subject lookup below is well-defined in both cases.
TI_dict = None
if conf.TI_json:
    import json
    with open(f"{conf.TI_json}") as json_file:
        TI_dict = json.load(json_file)

device = 'cpu'

# os.listdir returns entries in arbitrary order; sort so the CSV row order
# is reproducible across runs.
train_files = sorted(os.listdir(conf.moving))

add_feat_axis = not conf.multichannel
for idx, subject in enumerate(train_files):
    name = Path(subject).stem
    # Inversion times handed to update_atlas as `tvec`.
    # NOTE(review): assumes the TI JSON is keyed by the file stem -- confirm.
    tvec = np.asarray(TI_dict[name]) if TI_dict is not None else None
    start = time.time()
    raw_vols, fixed_affine = vxm.py.utils.load_volfile(
        os.path.join(conf.moving, name), add_feat_axis=add_feat_axis, ret_affine=True)
    rig_vols, fixed_affine = vxm.py.utils.load_volfile(
        os.path.join(conf.moved, name), add_feat_axis=add_feat_axis, ret_affine=True)
    orig_T1err = vxm.groupwise.utils.update_atlas(raw_vols, -1, 't1map', tvec=tvec)
    rigs_T1err = vxm.groupwise.utils.update_atlas(rig_vols, -1, 't1map', tvec=tvec)
    et = time.time()
    mean_orig_T1err = np.mean(orig_T1err)
    mean_rigs_T1err = np.mean(rigs_T1err)
    print(f"{name}, Time elapsed: {(et - start)/60} mins, T1 error orig {mean_orig_T1err} and rigs {mean_rigs_T1err}")
    # Store the per-subject result so it reaches results.csv. Only the T1
    # error is computed in this loop; the MSE/PCA columns are left as NaN.
    rows.append({
        'Cases': name,
        'raw T1err': mean_orig_T1err,
        'registered T1err': mean_rigs_T1err,
    })

df = pd.DataFrame(rows, columns=col)

# NOTE(review): `percentage_change`, `hydralog` and `logger` are not defined in
# this notebook's visible cells; they must be provided before this cell runs.
df['MSE changes percentage'] = percentage_change(
    df['raw MSE'], df['registered MSE'])
df['PCA changes percentage'] = percentage_change(
    df['raw PCA'], df['registered PCA'])
df['T1err changes percentage'] = percentage_change(
    df['raw T1err'], df['registered T1err'])
df.to_csv(os.path.join(conf.result, 'results.csv'), index=False)
hydralog.info(f"The summary is \n {df.describe()}")

logger.log_dataframe(df, 'Results', path=os.path.join(
    conf.result, 'results.csv'))