# Set up

In [None]:
import os
import sys

# Make the project's ../src/ directory importable from this notebook.
module_path = os.path.abspath('../src/')
print(module_path)
# Append only when missing so re-running the cell does not add duplicates.
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import glob
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
import nibabel as nib
from tqdm.auto import tqdm
from collections import Counter

from data.BundleData import *
from data.data_util import *
from utils.general_util import *
from model.model import *
from evaluation import *

In [None]:
# Seed used for this notebook; it is also forwarded to inference.py as --seed.
SEED = 2022
# Index of the CUDA device to select. Machine-specific: it must be an index
# that exists on the current host.
DEVICE_NUM = 5
# NOTE(review): set_seed, set_device and torch are not imported by name in
# this notebook; presumably they arrive through the star imports above.
set_seed(seed=SEED)
DEVICE = set_device()
# NOTE(review): assumes set_device() returns the plain string 'cuda' when a
# GPU is usable — confirm against its definition.
if DEVICE == 'cuda':
    torch.cuda.set_device(DEVICE_NUM)
    # Prints: number of visible GPUs, index of the current one, and its name.
    print(torch.cuda.device_count(), 
          torch.cuda.current_device(),
          torch.cuda.get_device_name(DEVICE_NUM))

In [None]:
# Project folders, given relative to the notebook's working directory.
# They are joined to file names by plain string concatenation / f-strings,
# so each path must keep its trailing slash.
model_folder = "../results/models/"
plot_folder = "../results/plots/"
result_data_folder = "../results/data/"
log_folder = "../results/logs/"
data_files_folder = "../data_files/"

# CHANGE DATA FOLDER BELOW
# Passed to inference.py as --data_folder; it is empty until set by the user.
data_folder = ""

# Inference

This section uses a metadata file to select CN/MCI/AD subjects for inference. Skip it if it does not apply to your data.

## Get subj IDs

In [None]:
# Load the subject metadata table. select_subject() below filters it on its
# DX and Subject columns, so both must be present in metadata.csv.
df_meta = pd.read_csv(data_files_folder + "metadata.csv")

In [None]:
def select_subject(df, dx='CN', n_subj=5, 
                   subj_train='007_S_6120_20171117_A3_DWI_S127',
                   data_folder='.'):
    """Return the rows of ``df`` with diagnosis ``dx``, optionally subsampled.

    Args:
        df: Metadata DataFrame with at least ``DX`` and ``Subject`` columns.
        dx: Value of the ``DX`` column to keep.
        n_subj: Number of rows to draw at random without replacement. A falsy
            value (``None`` or ``0``) keeps every matching row.
        subj_train: Subject ID to exclude from the selection.
        data_folder: Unused; kept so existing callers keep working.

    Returns:
        DataFrame holding the selected rows.

    Raises:
        ValueError: From ``np.random.choice`` when ``n_subj`` is larger than
            the number of matching rows.
    """
    # Re-seed on every call; presumably this makes the random draw below
    # repeatable (depends on what set_seed seeds — confirm in utils).
    set_seed(SEED)
    df_selected = df[(df.DX == dx) & (df.Subject != subj_train)]

    if n_subj:
        # Draw positional indices from the filtered frame. The original code
        # referenced an undefined ``df_dx`` here and raised NameError.
        idx = np.random.choice(len(df_selected), n_subj, replace=False)
        df_selected = df_selected.iloc[idx]
    print(f"Selected {len(df_selected)} {dx} subject.")

    return df_selected

In [None]:
# Build one cohort per diagnosis label. n_subj=None keeps every matching
# subject instead of drawing a random subset.
subj_cn, subj_mci, subj_ad = (
    select_subject(df_meta, diagnosis, None, data_folder=data_folder)
    for diagnosis in ('CN', 'MCI', 'Dementia')
)
print(subj_cn.shape, subj_mci.shape, subj_ad.shape)

## Run inference script

In [None]:
'''Select which model to perform inference on'''

# subj_setting and model_setting are joined below into model_subfolder.
subj_setting = 'CN10'
model_setting = 'convVAE3L_XUXU_Z2_B512_LR2E-04_WD1E-03_GCN2E+00' 
# Forwarded to inference.py as --epoch and --model_type respectively.
epoch = 100
model_type = "checkpoint"
# Used as the --model_name argument and as the per-model subfolder under
# result_data_folder when looking for existing results.
model_subfolder = f"{model_setting}_{subj_setting}"
print(model_subfolder)

In [None]:
# Collect the subjects that already have inference results for this model so
# they can be skipped in the following steps. Only files whose name starts
# with "E" are considered; the subject ID is the text between the first "_"
# and the next "." in the file name.
subj_inferred = [
    file_name.split("_", 1)[1].split(".")[0]
    for file_name in (
        path.split("/")[-1]
        for path in glob.glob(f"{result_data_folder}{model_subfolder}/*")
    )
    if file_name.startswith("E")
]
len(subj_inferred)

In [None]:
'''Get subjects to infer on'''

subjs = ["Subj01", "Subj02", "Subj03"]
# Drop subjects that already have results. The order of `subjs` is kept
# (duplicates removed) so the resulting command line is the same on every
# run; a plain set difference yields an arbitrary order for strings.
already_inferred = set(subj_inferred)
subj_str = " ".join(
    subj for subj in dict.fromkeys(subjs) if subj not in already_inferred
)
subj_str

In [None]:
# Run the inference script in a subprocess. IPython substitutes each {name}
# with the value of the corresponding notebook variable before running it.
# {subj_str} is unquoted, so the space-separated subject IDs reach the script
# as separate arguments; presumably --subj_list accepts several values
# (confirm in inference.py).
!python ../src/inference.py --model_name {model_subfolder} \
                        --epoch {epoch} \
                        --seed {SEED} \
                        --subj_list {subj_str} \
                        --device {DEVICE} \
                        --device_num {DEVICE_NUM} \
                        --model_type {model_type} \
                        --data_folder {data_folder} \
                        --model_folder {model_folder} \
                        --result_data_folder {result_data_folder}