# BraTS Autoencoder Training Process

In [None]:
import os
import sys
import torch
import numpy as np
from matplotlib import pyplot as plt
from monai.bundle import ConfigParser

# Local checkout of the MONAI bundle (path is relative to the notebook's
# working directory). Its directory is put on sys.path so that the bundle's
# `scripts` module can be imported below.
BUNDLE = '../brats-mri/brats_mri_class_cond'
# Guard against piling up duplicate sys.path entries when this cell is re-run.
if BUNDLE not in sys.path:
    sys.path.append(BUNDLE)
from scripts.ct_rsna import CTSubset

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when no GPU is visible.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
def get_autoencoder(config_file_name):
    """Build the bundle's autoencoder from a config file and load its weights.

    Reads ``<BUNDLE>/configs/<config_file_name>`` with MONAI's
    ``ConfigParser``, overrides ``bundle_root`` and ``model_dir`` so they
    resolve against the local bundle checkout, instantiates the
    ``autoencoder`` config entry and loads the checkpoint whose path is given
    by the ``load_autoencoder_path`` entry.

    Args:
        config_file_name: Name of a config file inside the bundle's
            ``configs`` directory (e.g. ``'inference.json'``).

    Returns:
        The autoencoder module with its weights loaded, moved to ``device``.
    """
    config = ConfigParser()
    config.read_config(os.path.join(BUNDLE, 'configs', config_file_name))

    # Point the config's path variables at this checkout of the bundle.
    config['bundle_root'] = BUNDLE
    config['model_dir'] = os.path.join(BUNDLE, 'models')

    autoencoder = config.get_parsed_content('autoencoder')
    checkpoint_path = config.get_parsed_content('load_autoencoder_path')
    # map_location is a torch.load argument (it was previously handed to
    # get_parsed_content, where it had no effect): it remaps tensors saved on
    # a GPU onto `device`, so the checkpoint also loads on CPU-only machines.
    state_dict = torch.load(checkpoint_path, map_location=device)

    # strict=False tolerates key mismatches between checkpoint and model;
    # report them instead of discarding the result silently.
    incompatible = autoencoder.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        import warnings
        warnings.warn(
            f"{config_file_name}: checkpoint/model key mismatch - "
            f"missing={incompatible.missing_keys}, "
            f"unexpected={incompatible.unexpected_keys}"
        )
    return autoencoder.to(device)


In [None]:
# 'inference.json'     -> mri_vae (reconstructions shown as "initial")
# 'inference_new.json' -> ct_vae  (reconstructions shown as "trained")
mri_vae, ct_vae = map(get_autoencoder, ('inference.json', 'inference_new.json'))

In [None]:
# CT RSNA training data via the bundle's CTSubset wrapper; the meaning of
# size / flip_prob / subset_len is defined in `scripts.ct_rsna` in the bundle.
ds = CTSubset(
    '../data/ct-rsna/train/',
    'train_set_dropped_nans.csv',
    size=256,
    flip_prob=0.5,
    subset_len=1024,
)

In [None]:
# Visual comparison grid: K dataset images (row 0) next to their
# reconstructions by mri_vae (row 1, "initial") and ct_vae (row 2, "trained").
np.random.seed(7)
K = 5
# replace=False: sampling with replacement could draw the same image twice
# and waste a column of the comparison.
k_samples = np.random.choice(len(ds), K, replace=False)

# Run both models in inference mode.
mri_vae.eval()
ct_vae.eval()

fig, ax = plt.subplots(3, K, figsize=(10, 6), sharex=True, sharey=True)
with torch.no_grad():
    for i, idx in enumerate(k_samples):
        x = ds[idx]['image'].to(device)
        batch = x.unsqueeze(0)  # add a leading batch dimension
        # Each model returns a 3-tuple; only its first element is plotted.
        # NOTE(review): if forward() samples the latent, these reconstructions
        # are stochastic - confirm against the autoencoder implementation.
        y1, _, _ = mri_vae(batch)
        y2, _, _ = ct_vae(batch)

        for row, img in enumerate((x, y1, y2)):
            ax[row, i].imshow(img.squeeze().cpu().numpy(), vmin=0., vmax=1., cmap='gray')

ax[0, 0].set_ylabel('input')
ax[1, 0].set_ylabel('initial')
ax[2, 0].set_ylabel('trained')
# The axes are shared (sharex/sharey), so clearing ticks here clears them on
# every subplot.
ax[2, 0].set_xticks([])
ax[2, 0].set_yticks([])
plt.show()

In [None]:
# Loss history written by the training run. The values are only plotted, so
# load onto the CPU: this keeps the load working on machines without the GPU
# that any stored tensors may have been saved from.
loss_dict = torch.load(
    '../data/outputs/radimagenet_perceptual_1024_30epochs/losses_dict_epoch_30',
    map_location='cpu',
)

In [None]:
# Training-loss curve. loss_dict['train'] is unpacked as (x, y) pairs,
# presumably (epoch, loss) per logged epoch - TODO confirm with the trainer.
x, y = zip(*loss_dict['train'])
plt.plot(x, y, marker='o', linestyle='-')
plt.xlabel('Epoch #')
plt.ylabel('KL-VAE Loss')
# Must be called: a bare `plt.show` only references the function and never
# displays the figure outside an inline notebook backend.
plt.show()