# 2d Brain model analysis

In [1]:
### Base ###
import os
import sys
import csv
from argparse import Namespace
import numpy as np 
import torch 
import torch.nn as nn
from torch.optim import Adam
import fnmatch
import itertools
import math
from sklearn.decomposition import PCA
import PIL.Image as pimg
import pytorch_lightning as pl

### Visualization ###
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.mlab as mlab
from matplotlib import rc
# Render every piece of figure text through LaTeX. This needs a working LaTeX
# installation on the machine that runs the kernel.
rc('text', usetex=True)
# Use a serif font family, with Palatino as the preferred serif face.
rc('font', **{'family':'serif','serif':['Palatino']})
%matplotlib inline

# Resolve the repository root: three directory levels above the directory the
# kernel was started in. This cell also changes the working directory, so
# recomputing `parent` from os.getcwd() on a second execution would climb three
# further levels. Reusing the value already present in the kernel keeps the
# cell idempotent (Restart Kernel still recomputes it from scratch).
if 'parent' not in globals():
    parent = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

# Give the repository root import priority so that `src.*` resolves to this
# checkout, without stacking a duplicate entry each time the cell is re-run.
if sys.path[:1] != [parent]:
    sys.path.insert(0, parent)

# Relative paths used by the project code are resolved from the repository root.
os.chdir(parent)
print('Setting root path to : {}'.format(parent))

from src.support.base_miccai import *
from src.in_out.datasets_miccai import ZeroOneT12DDataset
from src.core.models.bayesian_atlas_2dmiccai import VariationalMetamorphicAtlas2dExecuter

Setting root path to : /network/lustre/dtlake01/aramis/users/paul.vernhet/Scripts/Ongoing/MICCAI2020/deepshape
Setting root path to : /network/lustre/dtlake01/aramis/users/paul.vernhet/Scripts/Ongoing/MICCAI2020/deepshape


In [2]:
# PATHS :
# Root of the storage area that holds both the input data and the experiment
# results. It defaults to the original cluster location; on another machine,
# set the DEEPSHAPE_HOME_PATH environment variable before starting the kernel
# instead of editing this cell.
HOME_PATH = os.environ.get('DEEPSHAPE_HOME_PATH', '/network/lustre/dtlake01/aramis/users/paul.vernhet')
NIFTI_PATH = os.path.join(HOME_PATH, 'Data/MICCAI_dataset/2_datasets/2_t1ce_normalized')
BRATS_TENSORS_PATH = os.path.join(HOME_PATH, 'Data/MICCAI_dataset/3_tensors3d/2_t1ce_normalized/0_reduction')

# Directory of the training run analysed in this notebook.
EXPERIMENTS_DIR = os.path.join(HOME_PATH, 'Results/MICCAI/2dBraTs/2D_rdm_slice2_normalization_1_reduction/VAE_2020-02-28-09-31-48')
# Checkpoint holding the trained weights, and the CSV file holding the
# hyper-parameters of that run (one key/value pair per row).
ckpt_file = os.path.join(EXPERIMENTS_DIR, 'lightning_logs/version_0/checkpoints/_ckpt_epoch_139.ckpt')
metavar_file = os.path.join(EXPERIMENTS_DIR, 'lightning_logs/version_0', 'meta_tags.csv')

Manually load the model: first the hyper-parameters (`meta_tags.csv`), then the trained weights (checkpoint file).

In [3]:
def load_hparams_from_tags_csv(tags_csv):
    """Build a hyper-parameter ``Namespace`` from a key/value CSV file.

    The file must expose a ``key`` column and a ``value`` column. Each row
    becomes one attribute of the returned namespace, its value being passed
    through ``convert`` first.

    Args:
        tags_csv: path (or any source accepted by ``pandas.read_csv``) of the
            CSV file.

    Returns:
        argparse.Namespace: one attribute per row of the file.
    """
    from argparse import Namespace
    import pandas as pd

    records = pd.read_csv(tags_csv).to_dict(orient='records')

    parsed = {}
    for record in records:
        name = record['key']
        # A later row with the same key overwrites an earlier one.
        parsed[name] = convert(record['value'])

    return Namespace(**parsed)


def convert(val):
    """Coerce a raw hyper-parameter value to the most specific Python type.

    Strings are interpreted, in this order, as a boolean (``'true'`` /
    ``'false'`` in any casing), an ``int``, a ``float``, and are otherwise
    returned unchanged.

    Values that are not strings (the CSV reader already typed them) are never
    degraded: a ``bool`` stays a ``bool``, a float becomes an ``int`` only when
    it is integer-valued (``3.0 -> 3``), and anything that cannot be expressed
    as an integer (``0.75``, ``nan``, ``inf``, ``None``) is returned as is
    instead of being truncated or raising.

    Args:
        val: the raw value, usually a string read from a CSV file.

    Returns:
        The converted value (``bool``, ``int``, ``float``, ``str``) or ``val``
        itself when no conversion applies.
    """
    # Must come first: bool is an int subclass, so int(True) would yield 1.
    if isinstance(val, bool):
        return val

    if isinstance(val, str):
        lowered = val.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False

        # Most specific numeric type first.
        for constructor in (int, float):
            try:
                return constructor(val)
            except ValueError:
                continue
        return val

    # Non-string input: only collapse to int when no information is lost.
    try:
        as_int = int(val)
    except (TypeError, ValueError, OverflowError):
        # None, NaN and +/-inf have no integer form: keep them untouched.
        return val
    return as_int if as_int == val else val

In [4]:
# Load the hyper-parameters of the training run from its meta_tags.csv file.
hparams = load_hparams_from_tags_csv(metavar_file)
print('>> hparams loaded')

# Create the model from the hyper-parameters, then load the trained weights.
# The recorded output of this cell shows that the constructor also builds the
# datasets and computes dataset statistics.
model_executer = VariationalMetamorphicAtlas2dExecuter(hparams)
# Snapshot of the template as initialised by the constructor, i.e. before any
# trained weight is loaded. clone() decouples it from the model parameter.
initial_template = model_executer.model.template_intensities.clone()

# The map_location callable returns each storage unchanged, which keeps the
# checkpoint tensors on the CPU whatever device they were saved from.
# NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
checkpoint = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
model_executer.load_state_dict(checkpoint['state_dict'])
model = model_executer.model
# Snapshot of the template after the weights have been loaded.
loaded_template = model_executer.model.template_intensities.clone()
print('>> model_executer loaded')
# Sanity check that loading changed something: fails only if the template is
# identical, element by element, before and after load_state_dict.
assert not torch.all(initial_template.eq(loaded_template))

# Legacy loading code, kept commented out.
# NOTE(review): presumably superseded by the checkpoint loading above (a later
# cell reads model.v_star_average / model.n_star_average directly from the
# model) - confirm before deleting. `path_to_results` is not defined anywhere
# in the visible notebook.
# Load state dict 
#path_to_state_dict = os.path.join(
#    path_to_results, 
#    sorted(fnmatch.filter(os.listdir(path_to_results), 'model__epoch_*.pth'), key=(lambda x: int(x.split('_')[-1][:-4])))[-1])
#print('path_to_state_dict = %s' % os.path.basename(path_to_state_dict))
#state_dict = torch.load(path_to_state_dict, map_location=lambda storage, loc: storage)

# Load centering parameters
#path_to_v_star_average = os.path.join(path_to_results, 'vsa__epoch_%s.npy' % (os.path.basename(path_to_state_dict).split('_')[-1][:-4]))
#path_to_n_star_average = os.path.join(path_to_results, 'nsa__epoch_%s.npy' % (os.path.basename(path_to_state_dict).split('_')[-1][:-4]))
#print('path_to_v_star_average = %s' % os.path.basename(path_to_v_star_average))
#print('path_to_n_star_average = %s' % os.path.basename(path_to_n_star_average))
#v_star_average = torch.from_numpy(np.load(path_to_v_star_average))
#n_star_average = torch.from_numpy(np.load(path_to_n_star_average))

  2%|▏         | 6/330 [00:00<00:05, 54.89it/s]

>> hparams loaded
>> Creating dataset with 330 files (from 335 available)
>> Creating dataset with 16 files (from 125 available)
>> Computing online statistics for dataset ...


100%|██████████| 330/330 [00:03<00:00, 92.83it/s] 

>> Template intensities are torch.Size([1, 160, 192]) = 30720 parameters
>> Encoder2d__5_down has 931528 parameters
>> DeepDecoder2d__5_up has 933200 parameters
>> DeepDecoder2d__4_up has 468320 parameters
>> Metamorphic 2D BayesianAtlas has 2363768 parameters
>> model_executer loaded





In [5]:
# Read once here; the reconstruction cell further down uses it to normalise
# the empirical noise estimate.
noise_dimension = model.noise_dimension

# PRINT TEMPLATE
# Left: template as initialised by the model constructor.
# Right: template after the checkpoint weights were loaded.
img_01 = 255 * initial_template.squeeze().detach().cpu().numpy()
img_02 = 255 * loaded_template.squeeze().detach().cpu().numpy()

figsize = 6
f, ax = plt.subplots(1, 2, figsize=(2*figsize, figsize))
# Style each panel through its own Axes object: the pyplot-level calls
# plt.axis / plt.title only act on the current (last created) axes, which left
# the first panel with visible ticks and without a title.
for template_ax, template_img, template_title in zip(
        ax, (img_01, img_02), ('Initial template', 'Learned template')):
    template_ax.imshow(template_img, cmap='gray')
    template_ax.axis('off')
    template_ax.set_title(template_title)
plt.show()

Select the compute device (CUDA GPU if available, otherwise CPU):

In [6]:
# Seed the CPU random generators of torch and numpy with the seed of the run.
torch.manual_seed(hparams.seed)
np.random.seed(hparams.seed)

# hparams.cuda records how the model was TRAINED. The analysis may run on a
# machine without a GPU, so also ask torch whether CUDA is usable here and fall
# back to the CPU branch instead of crashing in torch.cuda.set_device.
if hparams.cuda and torch.cuda.is_available():
    print('>> GPU available.')
    DEVICE = torch.device('cuda')
    # NOTE(review): hparams.num_gpu is used as a device index here - confirm
    # that it is an index and not a number of GPUs.
    torch.cuda.set_device(hparams.num_gpu)
    torch.cuda.manual_seed(hparams.seed)
else:
    DEVICE = torch.device('cpu')
    print('>> CUDA is not available. Overridding with device = "cpu".')
    print('>> OMP_NUM_THREADS will be set to ' + str(hparams.num_threads))
    # NOTE(review): torch is already imported at this point, so the environment
    # variable probably only affects libraries initialised later; the
    # torch.set_num_threads call below is what configures torch itself.
    os.environ['OMP_NUM_THREADS'] = str(hparams.num_threads)
    torch.set_num_threads(hparams.num_threads)

# Move the executer (and the model it wraps) to the selected device.
model_executer = model_executer.to(DEVICE)

>> GPU available.


# Check data | reconstruction 

### 1. Qualitative analysis

In [25]:
# Encode a handful of training images and rebuild them step by step, keeping
# the intermediate fields (v, n, x) for the plots of the following cells.
# NOTE(review): the DEFORM section presumably mirrors the model's own decoding
# code - verify against the model source if the two must stay in sync.

# Sample positions (within `observations`) drawn by the plotting cell below.
indexes = [0, 1, 2]
data_loader = model_executer.train_dataloader()
# Number of images still to collect (at most 5).
# Caution: the name `n` is reused further down for the smoothed intensity
# increment, so it is only a counter inside this loop.
n = min(5, hparams.nb_train)
intensities_to_write = []
# `batch_idx` is not used.
for batch_idx, intensities in enumerate(data_loader):
    if n <= 0:
        break
    bts = intensities.size(0)
    nb_selected = min(bts, n)
    intensities_to_write.append(intensities[:nb_selected])
    n = n - nb_selected
observations = torch.cat(intensities_to_write).to(DEVICE)

##############
### DEFORM ###
##############

# encode() returns four values; only the first and third are kept.
# NOTE(review): s / a are presumably the shape and appearance latent codes and
# the discarded values their variances - confirm against the encoder.
s, _, a, _ = model.encode(observations)

# INIT
# From here on `bts` is the number of selected observations.
bts = s.size(0)
assert bts == a.size(0)
ntp = model.number_of_time_points
kws = model.kernel_width__s
kwa = model.kernel_width__a
dim = model.dimension
gs = model.grid_size
dgs = model.downsampled_grid_size
dsf = model.downsampling_grid

# Decode both codes and subtract the averages stored on the model, cast to the
# tensor type of the corresponding latent code.
v_star = model.decoder__s(s) - model.v_star_average.type(str(s.type()))
n_star = model.decoder__a(a) - model.n_star_average.type(str(a.type()))

# GAUSSIAN SMOOTHING
# From here on `n` is a tensor (smoothed n_star), no longer the loop counter.
v = batched_vector_smoothing(v_star, kws, scaled=False)
n = batched_scalar_smoothing(n_star, kwa, scaled=False)

# NORMALIZE
# Per-sample squared norms: latent codes are flattened, fields are summed over
# every non-batch dimension (dimensions 1 .. dim + 1).
s_norm_squared = torch.sum(s.view(bts, -1) ** 2, dim=1)
a_norm_squared = torch.sum(a.view(bts, -1) ** 2, dim=1)
v_norm_squared = torch.sum(v * v_star, dim=tuple(range(1, dim + 2)))
n_norm_squared = torch.sum(n * n_star, dim=tuple(range(1, dim + 2)))
# Scaling factor sqrt(|code|^2 / sum(field * field_star)); samples whose code
# has a squared norm of at most 1e-10 get a factor of 0 instead.
# torch.where evaluates both branches, so the quotient is still computed for
# those samples and then discarded.
normalizer__s = torch.where(s_norm_squared > 1e-10,
                            torch.sqrt(s_norm_squared / v_norm_squared),
                            torch.from_numpy(np.array(0.0)).float().type(str(s.type())))
normalizer__a = torch.where(a_norm_squared > 1e-10,
                            torch.sqrt(a_norm_squared / n_norm_squared),
                            torch.from_numpy(np.array(0.0)).float().type(str(a.type())))

# Reshape the per-sample factors to (bts, 1, ..., 1) and expand them to the
# size of the fields before the element-wise product.
normalizer__s = normalizer__s.view(*([bts] + (dim + 1) * [1])).expand(v.size())
normalizer__a = normalizer__a.view(*([bts] + (dim + 1) * [1])).expand(n.size())
v = v * normalizer__s
n = n * normalizer__a

# FLOW
# Regular grid with `dgs` points per axis spanning [0, gs - 1], stacked into
# `dim` coordinate channels and repeated for every sample of the batch.
grid = torch.stack(torch.meshgrid([torch.linspace(0.0, elt - 1.0, delt) for elt, delt in zip(gs, dgs)])
                   ).type(str(s.type())).view(*([1, dim] + list(dgs))).repeat(*([bts] + (dim + 1) * [1]))

# Start from the displacement v / 2**ntp, then apply `ntp` times the update
# that adds the current displacement (x - grid) interpolated at x.
x = grid.clone() + v / float(2 ** ntp)
for t in range(ntp):
    x += batched_vector_interpolation_adaptive(x - grid, x, dsf)

# INTERPOLATE
# Reconstruction: template plus intensity increment, sampled at the points x.
intensities = batched_scalar_interpolation_adaptive(model.template_intensities + n, x).float()

# WRITE
# All images below are scaled by 255 and moved to the CPU.
template = model.template_intensities.float().mul(255).cpu()

images = []
images_ = []
# `sliced_images` is never filled or read in the visible notebook.
sliced_images = []
for i in range(bts):
    # Get data
    # appearance    : template + intensity increment, without deformation
    # shape         : template sampled at the deformed points, without increment
    # metamorphosis : full reconstruction (increment and deformation)
    # target        : the observation being reconstructed
    appearance = (model.template_intensities + n[i]).float().cpu().mul(255)
    shape = batched_scalar_interpolation_adaptive(model.template_intensities.float().cpu(),
                                                  x[i].float().unsqueeze(0).detach().cpu())[0].mul(255)
    metamorphosis = intensities[i].float().mul(255).cpu()
    target = observations[i].float().mul(255).cpu()
    
    # Get sliced image
    # squeeze(1) only drops dimension 1 when its size is 1.
    images_i = [template.squeeze(1), appearance.squeeze(1), shape.squeeze(1),
                metamorphosis.squeeze(1), target.squeeze(1)]
    images_ += images_i
    images.append(images_i)
# Flat concatenation of every image, only used to find a common colour scale.
images_ = torch.cat(images_)
vmax = torch.max(images_).detach().numpy()
# Root of the summed squared reconstruction error divided by
# (number of samples * noise_dimension).
empiric_noise_std = torch.sqrt(torch.sum((intensities - observations) ** 2) / float(intensities.size(0) * noise_dimension)).detach().cpu().numpy()
print('empiric_noise_std = %.2E' % empiric_noise_std)

In [25]:
############
### PLOT ###
############

# One row per selected sample, one column per image kind stored in `images`
# (template, appearance, shape, metamorphosis, target).
figsize = 4
nrows, ncols = len(indexes), len(images_i)

f, axes = plt.subplots(nrows, ncols , figsize=(ncols*figsize, nrows*figsize))
for (i, index), j in itertools.product(enumerate(indexes), range(ncols)):
    ax = axes.reshape(nrows, ncols)[i, j]
    img = images[index][j].detach().cpu().numpy()[0]

    if j == 2:
        # Third column: draw the deformed grid instead of the image, as line
        # segments joining neighbouring grid points along each grid axis.
        g = x[index].permute(1, 2, 0).detach().cpu().numpy()
        ax.plot([g[:-1, :, 0].ravel(), g[1:, :, 0].ravel()],
                [g[:-1, :, 1].ravel(), g[1:, :, 1].ravel()], 'grey', linewidth=0.5)
        ax.plot([g[:, :-1, 0].ravel(), g[:, 1:, 0].ravel()],
                [g[:, :-1, 1].ravel(), g[:, 1:, 1].ravel()], 'grey', linewidth=0.5)
    else:
        # Every image panel shares the colour scale [0, vmax].
        ax.imshow(img, cmap='gray', vmin=0.0, vmax=vmax)

    ax.axis('off')
plt.show()

empiric_noise_std = 4.17E-02


In [None]:
# Summary figure (2 x 3 panels): learned template, batch-averaged intensity
# increment, batch-averaged deformation, and their combinations. The averages
# are taken over the few samples encoded in the reconstruction cell above.
figsize = 5
nrows = 2
ncols = 3

f, axes = plt.subplots(nrows, ncols , figsize=(ncols*figsize, nrows*figsize))

# Learned template
# Rebinds `template` to a numpy array (first channel); the reconstruction cell
# had bound it to a torch tensor scaled by 255.
template = model.template_intensities.detach().cpu().numpy()[0]
img = template
ax = axes[0, 0]
ax.imshow(img, cmap='gray')
ax.axis('off')
ax.set_title('Learned template')

# Average intensity increment
# Mean of `n` over the batch dimension, first channel.
n_average = torch.mean(n, dim=0)[0].detach().cpu().numpy()
img = n_average
ax = axes[0, 1]
ax.imshow(img, cmap='gray')
ax.axis('off')
ax.set_title('Average intensity increment')

# Template + average increment (this comment used to read "Difference", but the
# code computes a sum).
# NOTE(review): this panel is identical to axes[1, 0] below - confirm whether a
# difference, or another quantity, was intended for one of the two.
img = template + n_average
ax = axes[0, 2]
ax.imshow(img, cmap='gray')
ax.axis('off')
ax.set_title('Template + Average increment')

# Learned template + increment
img = template + n_average
ax = axes[1, 0]
ax.imshow(img, cmap='gray')
ax.axis('off')
ax.set_title('Template + Average increment')

# Average deformation
# Mean of the deformed point positions over the batch, drawn as line segments
# joining neighbouring grid points along each grid axis.
x_average = torch.mean(x, dim=0)
g = x_average.permute(1, 2, 0).detach().cpu().numpy()
ax = axes[1, 1]
ax.plot([g[:-1, :, 0].ravel(), g[1:, :, 0].ravel()], 
        [g[:-1, :, 1].ravel(), g[1:, :, 1].ravel()], 'grey', linewidth=0.5)
ax.plot([g[:, :-1, 0].ravel(), g[:, 1:, 0].ravel()],
        [g[:, :-1, 1].ravel(), g[:, 1:, 1].ravel()], 'grey', linewidth=0.5)
ax.axis('off')
ax.set_title('Average deformation')

# Template + average increment, sampled at the average deformed positions (this
# comment used to read "Difference", but no difference is computed).
# NOTE(review): the reconstruction cell uses
# batched_scalar_interpolation_adaptive with `x`; this call uses the
# non-adaptive helper - confirm that it accepts points defined on the
# downsampled grid.
img = batched_scalar_interpolation(model.template_intensities + torch.mean(n, dim=0), 
                                   x_average.unsqueeze(0))[0].detach().cpu().numpy()[0]
ax = axes[1, 2]
ax.imshow(img, cmap='gray')
ax.axis('off')
ax.set_title('Deformed template + Average deformation')

plt.show()

In [None]:
# Intensity increment of one sample, relative to the batch average and then
# re-centred on its own mean.
index = 0

img = n[index, 0].detach().cpu().numpy() - n_average
img = img - np.mean(img)

figsize = 5
f = plt.figure(figsize=(figsize, figsize))
plt.imshow(img, cmap='gray')
plt.axis('off')
# plt.title('Learned template')
plt.show()


# Ternary maps of the same image for several thresholds c: +1 where the value
# lies above min + c * (max - min), -1 where it lies below
# max - c * (max - min), and 0 elsewhere.
thresholds = [0.95, 0.85, 0.75, 0.65, 0.55, 0.54, 0.53, 0.52, 0.51, 0.50]
figsize = 4
# At most 5 panels per row. Round the number of rows UP so that a threshold
# count that is not a multiple of 5 still gets one panel per threshold (floor
# division left the last thresholds without an axes), and keep both counts at
# 1 or more so that an empty list does not divide by zero.
ncols = max(1, min(len(thresholds), 5))
nrows = max(1, -(-len(thresholds) // ncols))

f, axes = plt.subplots(nrows, ncols , figsize=(ncols*figsize, nrows*figsize))
panels = np.ravel(axes)

# The image range does not depend on the threshold: compute it once.
mi = np.min(img)
ma = np.max(img)
for k, c in enumerate(thresholds):
    ax = panels[k]

    img_thres = 1.0 * (img > mi + c * (ma - mi)) - 1.0 * (img < ma - c * (ma - mi))

    ax.imshow(img_thres, cmap='gray')
    ax.axis('off')
    ax.set_title('c = %.2f' % c)

# Hide the panels of the last row that received no threshold.
for ax in panels[len(thresholds):]:
    ax.axis('off')
plt.show()

In [None]:
# Part 1: 2-component PCA scatter plots of the latent codes `s` (titled
# 'shape') and `a` (titled 'appearance'), with two samples highlighted and
# their reconstructions shown next to the plots.
# Part 2: images generated along straight lines between the latent codes of
# those two samples.
zs = s.detach().cpu().numpy()
za = a.detach().cpu().numpy()

# Positions (within the encoded batch) of the two highlighted samples; the
# batch must therefore contain at least three samples.
index_a = 0
index_b = 2

figsize = 4
ncols = 4
nrows = 1
f, axes = plt.subplots(nrows, ncols , figsize=(ncols*figsize, nrows*figsize))

### SHAPE 
# `zs` is replaced by its 2-D PCA projection. The PCA is fitted on the encoded
# batch only (at most 5 samples, see the reconstruction cell).
pca = PCA(2)
zs = pca.fit_transform(zs)

ax = axes.ravel()[0]
ax.scatter(zs[:, 0], zs[:, 1], s=50, c='tab:blue')
ax.scatter(zs[index_a, 0], zs[index_a, 1], s=200, c='tab:red')
ax.scatter(zs[index_b, 0], zs[index_b, 1], s=200, c='tab:green')
ax.set_title('shape')

### APPEARANCE
pca = PCA(2)
za = pca.fit_transform(za)

ax = axes.ravel()[1]
ax.scatter(za[:, 0], za[:, 1], s=50, c='tab:blue')
ax.scatter(za[index_a, 0], za[index_a, 1], s=200, c='tab:red')
ax.scatter(za[index_b, 0], za[index_b, 1], s=200, c='tab:green')
ax.set_title('appearance')

### Image A
# Entry 3 of each `images` row is the full reconstruction ("metamorphosis").

ax = axes.ravel()[2]
img = images[index_a][3].detach().cpu().numpy()[0]
ax.imshow(img, cmap='gray', vmin=0.0, vmax=vmax)
ax.set_title('red')

### Image B

ax = axes.ravel()[3]
img = images[index_b][3].detach().cpu().numpy()[0]
ax.imshow(img, cmap='gray', vmin=0.0, vmax=vmax)
ax.set_title('green')

plt.show()


### 
### INTERPOLATION
###
# NOTE(review): `intensities_mean` and `intensities_std` are not defined in any
# cell visible in this notebook, and `save_image` is not imported explicitly
# (presumably it comes from the star import of src.support.base_miccai, or it
# is missing). On a fresh kernel this part is expected to raise NameError
# unless those names are provided elsewhere - confirm.

s_start = s[index_a]
a_start = a[index_a]
s_end = s[index_b]
a_end = a[index_b]

# T evenly spaced interpolation weights: t * dt runs from 0 to 1 inclusive.
T = 5
dt = 1.0 / float(T-1)

imgs = []
tol = 1e-10

# MIXED INTERPOLATION
# Both codes move from sample A to sample B.
for t in range(T): 
    at = (1.0 - t*dt) * a_start + t*dt * a_end
    st = (1.0 - t*dt) * s_start + t*dt * s_end
    
    # The model is called on the two codes (with a batch dimension added), the
    # first returned item is rescaled, then clamped to [tol, 255 - tol].
    img = torch.clamp(intensities_mean + intensities_std * model(st.unsqueeze(0), at.unsqueeze(0))[0], tol, 255. - tol)
    imgs.append(img)

# APPEARANCE INTERPOLATION
# Only `a` moves; `s` stays at sample A.
for t in range(T): 
    at = (1.0 - t*dt) * a_start + t*dt * a_end
    st = s_start
    
    img = torch.clamp(intensities_mean + intensities_std * model(st.unsqueeze(0), at.unsqueeze(0))[0], tol, 255. - tol)
    imgs.append(img)

# SHAPE INTERPOLATION
# Only `s` moves; `a` stays at sample A.
for t in range(T): 
    at = a_start
    st = (1.0 - t*dt) * s_start + t*dt * s_end
    
    img = torch.clamp(intensities_mean + intensities_std * model(st.unsqueeze(0), at.unsqueeze(0))[0], tol, 255. - tol)
    imgs.append(img)
    
# 3 * T images in total, written as a grid with T images per row: one row per
# interpolation mode, in the order mixed, appearance, shape.
# NOTE(review): if this is torchvision.utils.save_image, recent torchvision
# releases renamed the `range` argument to `value_range` - confirm against the
# installed version.
imgs = torch.stack(imgs)
save_image(imgs, '4_interpolation__start_%d__end_%d.pdf' % (index_a, index_b), 
           nrow=T, normalize=True, range=(0., float(torch.max(imgs).detach().cpu().numpy())))