# Setup

In [None]:
import os
import sys

# Resolve ../src/ relative to the notebook's working directory so the project
# packages imported in the next cell (data, utils, model, evaluation) are found.
# (os.path.join with a single argument is a no-op, so abspath alone suffices.)
module_path = os.path.abspath('../src/')
print(module_path)

# Append only once, so re-running this cell does not keep growing sys.path.
if module_path in sys.path:
    pass
else:
    sys.path.append(module_path)

In [None]:
import glob
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import nibabel as nib
from tqdm.auto import tqdm
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

import dipy
from dipy.viz import window, actor

from data.BundleData import BundleData
from data.data_util import *
from utils.general_util import *
from model.model import *
from model.train_model import train_model
from evaluation import *

In [None]:
SEED = 2022
DEVICE_NUM = 5  # index of the GPU to use when CUDA is available

set_seed(seed=SEED)
DEVICE = set_device()


def _select_gpu(index):
    """Make GPU `index` the current CUDA device and print
    (#visible GPUs, current device index, name of GPU `index`)."""
    torch.cuda.set_device(index)
    print(torch.cuda.device_count(),
          torch.cuda.current_device(),
          torch.cuda.get_device_name(index))


if DEVICE == 'cuda':
    _select_gpu(DEVICE_NUM)

In [None]:
# Output locations, relative to the notebook's working directory.
results_root = "../results/"
model_folder = results_root + "models/"
plot_folder = results_root + "plots/"
result_data_folder = results_root + "data/"
log_folder = results_root + "logs/"

# Project metadata files (e.g. metadata.csv).
data_files_folder = "../data_files/"

# CHANGE DATA FOLDER BELOW: root folder passed to FiberData as `data_folder`.
data_folder = ""

# Load data

Adjust the code below to select the training subjects.

In [None]:
# Load the subject metadata, keep only rows with DX == 'CN', and order them by
# bundle_count and then streamline_count, largest first.
df_meta = pd.read_csv(data_files_folder + "metadata.csv")
df_tmp = (
    df_meta[df_meta["DX"] == "CN"]
    .sort_values(by=["bundle_count", "streamline_count"], ascending=False)
)
df_tmp.head()

In [None]:
# Train on the first n_subj subjects of the sorted CN table.
n_subj = 10
subjs_train = df_tmp.head(n_subj)["Subject"].values
subjs_train

`FiberData` loads bundle data from RecoBundles output. See the example file structure under `subjects_small` [here](https://github.com/dipy/dipy/blob/master/doc/interfaces/buan_flow.rst).

In [None]:
%%time

args = {'n_points' : 256, 'n_lines' : None, 'min_lines' : 2, 
        'tracts_exclude' : ['CST_L_s', 'CST_R_s'],'preprocess' : '3d', 
        'rng' : None, 'verbose': False, 'data_folder' : data_folder}

data = FiberData(subjs_train, **args)

# 1D convVAE

## Dataset & Dataloader

Preprocess the data into a PyTorch `Dataset` and `DataLoader`.

In [None]:
# Pull the feature array and labels out of the loaded FiberData object.
X, y = data.X, data.y

In [None]:
# Hold out 20% of the samples for evaluation; split_data receives SEED as
# random_state (presumably making the split reproducible — see split_data).
train, test = split_data(X, y, n_splits=50, test_size=0.2, random_state=SEED)

X_train, y_train = torch.from_numpy(X[train]), torch.from_numpy(y[train])
X_test, y_test = torch.from_numpy(X[test]), torch.from_numpy(y[test])

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

In [None]:
# Per-coordinate (x, y, z) mean and std of the training data, reduced over the
# first two dimensions; these statistics are used to standardize every split.
mean = X_train.mean(dim=(0, 1))
std = X_train.std(dim=(0, 1))
print(mean, std)

In [None]:
# Standard-scale (zero mean, unit variance) every split using the *training*
# statistics, so nothing from the test split leaks into the scaling.
X_train_norm = X_train.sub(mean).div(std)
X_test_norm = X_test.sub(mean).div(std)

# torch.from_numpy shares memory with X instead of copying the whole dataset
# (torch.tensor would); .sub() is out-of-place, so X itself is never modified.
X_norm = torch.from_numpy(X).sub(mean).div(std)
data.X_norm = X_norm

print(X_train_norm.shape, X_test_norm.shape, X_norm.shape)

# Sanity check: scaled training coordinates should be centred at 0 with unit spread.
fig, ax = plt.subplots()
ax.hist(X_train_norm.numpy().ravel(), bins=50, density=True)
ax.set(title="Standardized training coordinates",
       xlabel="standardized value", ylabel="density");

In [None]:
# Wrap the normalized tensors in TensorDatasets and build the train/eval loaders.
# Seeding is delegated to make_data_loader via SEED (presumably it seeds any
# shuffling/worker RNGs — see make_data_loader); no extra torch.Generator is needed.
batch_size = 512
num_workers = 4  # DataLoader worker processes per loader

train_data = TensorDataset(X_train_norm, y_train)
train_loader = make_data_loader(train_data, SEED, batch_size, num_workers=num_workers)

test_data = TensorDataset(X_test_norm, y_test)
test_loader = make_data_loader(test_data, SEED, batch_size, num_workers=num_workers)

print(f"# Batches: train {len(train_loader)}, eval {len(test_loader)}")

In [None]:
# Smoke-test the model's forward/loss path on one eval batch before training.
# This throwaway model is replaced when `model` is reassigned by init_model below.
# NOTE(review): convVAE(3, 2, ...) presumably means (input channels, latent dim) —
# confirm against convVAE's signature.
set_seed(SEED)
model = convVAE(3, 2, Encoder3L, Decoder3L)
model.to(DEVICE)
print("# Params: ", sum(p.numel() for p in model.parameters()))

x, _ = next(iter(test_loader))
x = x.to(DEVICE)

# Only shapes and the loss value are inspected, so skip building an autograd graph.
with torch.no_grad():
    x_hat, z, elbo = model.loss(x, computeMSE=False)
print(z.shape, x_hat.shape, elbo.item())
print(model.result_dict)

In [None]:
# Release the notebook-level names for the split tensors. The unnormalized
# X_train/X_test copies become collectable; the normalized tensors stay alive
# because train_data/test_data (TensorDatasets) still reference them.
del X_train, X_test
del X_train_norm, X_test_norm

# Training

- [Save torch models](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended)

In [None]:
# Training hyperparameters
total_epochs = 100
save_every = 5  # save model every N epochs

zdim = 32
model_type = "3L"
lr = 2e-4
wd = 1e-3

# The run name encodes the settings, and parse_model_setting reads them back
# (train_model later takes model_info['grad_type'] and model_info['GC'] from it).
# NOTE(review): "XUXU" and "GCN2E+00" are hard-coded tokens; keep them in sync
# with parse_model_setting if the gradient settings change.
name_parts = [
    f"convVAE{model_type}",
    "XUXU",
    f"Z{zdim}",
    f"B{batch_size}",
    f"LR{lr:.0E}",
    f"WD{wd:.0E}",
    "GCN2E+00",
    f"CN{n_subj}",
]
model_name = "_".join(name_parts)
model_info = parse_model_setting(model_name)
print(f"Saving to {model_name}")
model_info

To resume training a model, set `model_resume` to `True` and `resume_epoch` to the epoch at which the model was last saved.

In [None]:
# To resume: set model_resume = True and resume_epoch to the last saved epoch.
model_resume = False
resume_epoch = 0

(model, optimizer,
 starting_epoch, starting_batch_train, _) = init_model(
    model_folder, model_name, SEED, DEVICE,
    model_resume=model_resume, resume_epoch=resume_epoch,
)

In [None]:
# Create the TensorBoard log, model-checkpoint and result-data folders for this run.
# os.makedirs(exist_ok=True) is safe to re-run, also creates missing parent
# folders (e.g. ../results/logs/), and avoids interpolating paths into a shell.
for run_dir in (log_folder + model_name,
                model_folder + model_name,
                result_data_folder + model_name):
    os.makedirs(run_dir, exist_ok=True)

In [None]:
# Keyword arguments for train_model; TensorBoard events go to log_folder/model_name.
writer = SummaryWriter(log_folder + model_name)

args = dict(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=total_epochs,
    writer=writer,
    starting_epoch=starting_epoch,
    starting_batch_train=starting_batch_train,
    mean=mean,
    std=std,
    gradient_type=model_info['grad_type'],
    gradient_clip=model_info['GC'],
    computeMSE=False,
    # NOTE(review): `verbose` receives the SummaryWriter object, not a bool;
    # presumably only its truthiness is used — confirm in train_model.
    verbose=writer,
    save_folder=model_folder + model_name,
    save_every=save_every,
    save_type='checkpoint',
    device=DEVICE,
)

In [None]:
# Train the model. The TensorBoard writer is flushed and closed in `finally`,
# so pending events are written even if training raises or is interrupted
# (e.g. KeyboardInterrupt from stopping the cell).
try:
    train_losses, eval_losses = train_model(**args)
finally:
    writer.flush()
    writer.close()