In [1]:
import gridencoder as ge
import torch
from torch import nn
from torch.nn import functional as F
import SimpleITK as sitk
from glob import glob
from tqdm import tqdm
import os
import numpy as np
from os import path as osp
from matplotlib import pyplot as plt

In [2]:
# Root folder of the BraTS 2021 data; override with the BRATS_PATH environment variable
# instead of editing the notebook when running on another machine.
brats_path = os.environ.get('BRATS_PATH', '/data/BRATS2021/')

In [3]:
# Sorted list of the sub-directories of brats_path (presumably one per case).
paths = [entry for entry in sorted(glob(osp.join(brats_path, "*"))) if osp.isdir(entry)]

### Load T1 images only

In [4]:
def load_nifti(filepath):
    '''Read an image file with SimpleITK and return its voxels as a numpy array.

    Note: SimpleITK's numpy view lists the axes in the reverse of the image's
    (x, y, z) index order.
    '''
    return sitk.GetArrayFromImage(sitk.ReadImage(filepath))

In [5]:
def _first_t1(case_dir):
    '''Return the lexicographically first '*_t1.nii.gz' file inside `case_dir`.

    Raises:
        FileNotFoundError: if the directory contains no matching file, so a malformed
            case folder is reported by name instead of as a bare IndexError.
    '''
    matches = sorted(glob(osp.join(case_dir, '*_t1.nii.gz')))
    if not matches:
        raise FileNotFoundError(f"no '*_t1.nii.gz' file in {case_dir}")
    return matches[0]


# One T1 volume path per case directory, in the same order as `paths`.
t1files = [_first_t1(case_dir) for case_dir in paths]

In [6]:
def sizeof_net(network):
    '''Return the memory held by `network`'s parameters in MB (units of 1024**2 bytes).

    Only `network.parameters()` are counted; registered buffers are not included.
    '''
    total_bytes = sum(param.numel() * param.element_size() for param in network.parameters())
    return total_bytes / 1024**2

In [7]:
class ImplicitWrapper(nn.Module):
    '''Hash-grid implicit representation: one grid encoder per volume + a shared MLP decoder.

    `num_encoders` independent `ge.GridEncoder`s are stored in `self.encoder`; all of them
    feed the same `self.decoder`. `forward(x, idx)` encodes `x` with encoder `idx` and
    decodes the result.

    The decoder is Linear -> LeakyReLU, followed by `decoder_num_hiddenlayers` repeats of
    (Linear -> LeakyReLU), and a final Linear to `decoder_num_outputs` values.
    '''

    def __init__(self, num_levels=16, desired_resolution=256, level_dim=2, base_resolution=16,
                 log2_hashmap_size=19, decoder_num_hiddenlayers=1, decoder_num_nodes=256,
                 decoder_num_outputs=1,
                 num_encoders=1,
                 ):
        super().__init__()
        grid_kwargs = dict(num_levels=num_levels,
                           level_dim=level_dim,
                           base_resolution=base_resolution,
                           desired_resolution=desired_resolution,
                           log2_hashmap_size=log2_hashmap_size)
        self.encoder = nn.ModuleList(ge.GridEncoder(**grid_kwargs) for _ in range(num_encoders))

        # Input width num_levels * level_dim — presumably the encoder's output width.
        layers = [nn.Linear(num_levels * level_dim, decoder_num_nodes), nn.LeakyReLU()]
        for _ in range(decoder_num_hiddenlayers):
            layers += [nn.Linear(decoder_num_nodes, decoder_num_nodes), nn.LeakyReLU()]
        layers.append(nn.Linear(decoder_num_nodes, decoder_num_outputs))
        self.decoder = nn.Sequential(*layers)

    def forward(self, x, idx=0):
        features = self.encoder[idx](x)
        return self.decoder(features)

In [8]:
## Load image and run for single image

In [9]:
image = load_nifti(t1files[0])
# Volume dimensions and the largest one (used to normalise voxel coordinates).
H, W, D = image.shape
maxDim = max(image.shape)

In [10]:
print(f"{H} {W} {D}")

155 240 240


In [11]:
# Scale intensities so the volume maximum maps to 1 (and 0 to -1), then move the
# volume to the GPU as float32.
image_tensor = torch.as_tensor(image.astype(float) / image.max() * 2 - 1,
                               dtype=torch.float32, device='cuda')

In [12]:
def training_loop(net, batch_size, n_iters, lr=1e-4, image=None, eval_chunk_size=2**20):
    '''Fit `net` to one volume by random voxel sampling and return the reconstruction PSNR.

    Args:
        net: module mapping [B, 3] normalised voxel coordinates to [B, 1] intensities.
        batch_size: number of voxels sampled per Adam step.
        n_iters: number of Adam steps.
        lr: Adam learning rate.
        image: 3-D float tensor to fit (expected in [-1, 1]). Defaults to the
            notebook-level `image_tensor` for backward compatibility. `net` must live
            on the same device as `image`.
        eval_chunk_size: number of voxels per forward pass when reconstructing the
            whole volume; bounds peak memory during evaluation.

    Returns:
        0-d numpy array with the PSNR in dB of the full reconstruction, computed with a
        peak-to-peak range of 2, i.e. 10 * log10(4 / MSE).
    '''
    if image is None:
        image = image_tensor
    device = image.device
    h, w, d = image.shape
    # Every axis is divided by the largest dimension so voxel spacing stays isotropic.
    max_dim = max(h, w, d)

    optim = torch.optim.Adam(net.parameters(), lr=lr)
    for _ in range(n_iters):
        optim.zero_grad()
        x, y, z = [torch.randint(0, size, (batch_size,), device=device) for size in (h, w, d)]
        val = image[x, y, z][..., None]                              # [B, 1]
        coord = torch.stack([x, y, z], dim=1) / max_dim * 2 - 1     # [B, 3]
        loss = F.mse_loss(net(coord), val)
        loss.backward()
        optim.step()

    with torch.no_grad():
        grids = torch.meshgrid(*(torch.arange(n, device=device) for n in (h, w, d)), indexing='ij')
        coord = torch.stack([g.reshape(-1) for g in grids], dim=1) / max_dim * 2 - 1   # [H*W*D, 3]
        # Reconstruct in chunks: a single pass over every voxel needs huge activations.
        pred = torch.cat([net(chunk) for chunk in coord.split(eval_chunk_size)], dim=0)
        pred = pred.reshape(image.shape)
        psnr = 10 * torch.log10(4 / F.mse_loss(pred, image))
    return psnr.cpu().numpy()

In [None]:
desired_res = [32, 48, 64, 96, 128, 192, 256]
psnrs = []
SEED = 0

# Sweep the finest grid resolution; all other hyper-parameters keep their defaults.
for dres in desired_res:
    print(dres)
    # Re-seed so parameter initialisation and the torch.randint voxel sampling start from
    # the same RNG state for every configuration; differences then reflect `dres` only.
    torch.manual_seed(SEED)
    net = ImplicitWrapper(desired_resolution=dres).cuda()
    psnr = training_loop(net, batch_size=1000, n_iters=5000)
    psnrs.append(psnr)
    # release the model's GPU memory before the next configuration
    del net
    torch.cuda.empty_cache()

In [None]:
fig, ax = plt.subplots()
ax.plot(desired_res, psnrs)
ax.scatter(desired_res, psnrs)
ax.set(xlabel='Desired resolution', ylabel='PSNR (dB)',
       title='Single image, PSNR with final resolution')
plt.show()

In [None]:
# Report GPU memory usage of device 0 in GB, then release cached blocks.
_GB = 1024 ** 3
for _label, _value in [("memory_allocated", torch.cuda.memory_allocated(0)),
                       ("memory_reserved", torch.cuda.memory_reserved(0)),
                       ("max_memory_reserved", torch.cuda.max_memory_reserved(0))]:
    print(f"torch.cuda.{_label}: {_value / _GB:f}GB")
torch.cuda.empty_cache()

In [None]:
# Parameter memory of each resolution-sweep configuration (models are only instantiated
# to count parameters; they are not moved to the GPU).
network_size = [sizeof_net(ImplicitWrapper(desired_resolution=dres)) for dres in desired_res]
fig, ax = plt.subplots()
ax.plot(network_size, psnrs)
ax.scatter(network_size, psnrs)
ax.set(xlabel='Network size (MB)', ylabel='PSNR (dB)', title='Single image, network size v/s PSNR')
plt.show()

## Effect of hash table size

In [None]:
log2_size = [14, 16, 19, 20, 21, 22, 23, 24]
psnrs = []
network_size = []
SEED = 0

# Sweep log2 of the hash-table size per level; all other hyper-parameters keep defaults.
for size in log2_size:
    print(size)
    # Same RNG state for every configuration (initialisation + torch.randint sampling).
    torch.manual_seed(SEED)
    net = ImplicitWrapper(log2_hashmap_size=size).cuda()
    network_size.append(sizeof_net(net))
    psnr = training_loop(net, batch_size=1000, n_iters=5000)
    psnrs.append(psnr)
    # release the model's GPU memory before the next configuration
    del net
    torch.cuda.empty_cache()

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(14, 4))
# (x values, y values, x label, y label, title) for each panel
panels = [
    (log2_size, psnrs, 'log2 hashmap size', 'PSNR (dB)', 'Single image, hashmap size v/s PSNR'),
    (network_size, psnrs, 'network size (MB)', 'PSNR (dB)', 'Single image, network_size v/s PSNR'),
    (log2_size, network_size, 'log2 hashmap size', 'network size (MB)', 'Single image, hash v/s memory'),
]
for ax, (xs, ys, xlabel, ylabel, title) in zip(axs, panels):
    ax.plot(xs, ys)
    ax.scatter(xs, ys)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
plt.show()

## Effect of per-level embedding size (`level_dim`), with the number of levels chosen so that `num_levels × level_dim = 32`

In [None]:
levels = [1, 2, 4, 8]
psnrs = []
network_size = []
SEED = 0

# Sweep the per-level feature width (level_dim) while keeping
# num_levels * level_dim = 32, i.e. a constant decoder input width.
for lvl in levels:
    num_lvls = 32 // lvl
    print(lvl, num_lvls)

    # Same RNG state for every configuration (initialisation + torch.randint sampling).
    torch.manual_seed(SEED)
    net = ImplicitWrapper(num_levels=num_lvls, level_dim=lvl).cuda()
    network_size.append(sizeof_net(net))
    psnr = training_loop(net, batch_size=1000, n_iters=5000)
    psnrs.append(psnr)
    # release the model's GPU memory before the next configuration
    del net
    torch.cuda.empty_cache()

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(14, 4))
# (x values, y values, x label, y label, title) for each panel
panels = [
    (levels, psnrs, 'embedding size (level_dim)', 'PSNR (dB)', 'Single image, embedding size v/s PSNR'),
    (network_size, psnrs, 'network size (MB)', 'PSNR (dB)', 'Single image, network_size v/s PSNR'),
    (levels, network_size, 'embedding size (level_dim)', 'network size (MB)',
     'Single image, embedding size v/s memory'),
]
for ax, (xs, ys, xlabel, ylabel, title) in zip(axs, panels):
    ax.plot(xs, ys)
    ax.scatter(xs, ys)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
plt.show()

## Effect of decoder layers

In [None]:
num_layers = np.arange(1, 7)
psnrs = []
network_size = []
SEED = 0

# Sweep the number of hidden layers in the decoder MLP.
for l in num_layers:
    print(l)
    # Same RNG state for every configuration (initialisation + torch.randint sampling).
    torch.manual_seed(SEED)
    net = ImplicitWrapper(decoder_num_hiddenlayers=l).cuda()
    network_size.append(sizeof_net(net))
    psnr = training_loop(net, batch_size=1000, n_iters=5000)
    psnrs.append(psnr)
    # release the model's GPU memory before the next configuration
    del net
    torch.cuda.empty_cache()

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(14, 4))
# `num_layers` is passed as decoder_num_hiddenlayers, so label it as hidden layers.
panels = [
    (num_layers, psnrs, 'decoder hidden layers', 'PSNR (dB)', 'Single image, decoder depth v/s PSNR'),
    (network_size, psnrs, 'network size (MB)', 'PSNR (dB)', 'Single image, network size v/s PSNR'),
    (num_layers, network_size, 'decoder hidden layers', 'network size (MB)',
     'Single image, decoder depth v/s memory'),
]
for ax, (xs, ys, xlabel, ylabel, title) in zip(axs, panels):
    ax.plot(xs, ys)
    ax.scatter(xs, ys)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
plt.show()

## Multiple images

This experiment checks whether fitting several images at once — one hash-grid encoder per image with a shared decoder — lowers the per-image reconstruction PSNR.

In [None]:
def training_loop_multi_image(net, batch_size, n_iters, num_images, lr=1e-4, diff_optims=False, silent=False,
                              images=None, eval_chunk_size=2**20):
    '''Jointly fit `num_images` volumes with per-image encoders and a shared decoder.

    Each step picks one image uniformly at random and `batch_size` random voxels from it,
    and regresses their intensities normalised as `x * 2 / maxVal - 1`, where `maxVal` is
    the maximum over all images.

    Args:
        net: module with `encoder` (indexable, one entry per image), `decoder`, and
            `forward(coord, idx)` mapping [B, 3] coordinates to [B, 1] values.
        batch_size: voxels sampled per step.
        n_iters: steps per image; the total number of steps is n_iters * num_images.
        num_images: number of volumes to fit.
        lr: Adam learning rate.
        diff_optims: if True, use one Adam per encoder plus one for the decoder, and step
            only the decoder and the sampled image's encoder optimiser.
        silent: if True, suppress both the progress bar and the start-up message.
        images: optional sequence of 3-D numpy volumes (all the same shape); when None,
            the first `num_images` entries of `t1files` are read with `load_nifti`.
        eval_chunk_size: voxels per forward pass when reconstructing each volume.

    Returns:
        list of PSNR values in dB (peak-to-peak range 2), one per image.
    '''
    # load all images
    if images is None:
        imgs = [load_nifti(t1files[i]) for i in range(num_images)]
    else:
        imgs = [images[i] for i in range(num_images)]
    H, W, D = imgs[0].shape
    maxDim = max([H, W, D])
    # One global maximum so every image shares the same intensity normalisation.
    maxVal = max([np.max(i) for i in imgs])
    if not silent:
        print(f"Training with {num_images} image(s).")
    device = next(net.parameters()).device

    # see if adam updates mess up with the encoders
    if diff_optims:
        optim_encoder = [torch.optim.Adam(net.encoder[i].parameters(), lr=lr) for i in range(num_images)]
        optim_decoder = torch.optim.Adam(net.decoder.parameters(), lr=lr)
        optims = optim_encoder + [optim_decoder]
    else:
        optims = [torch.optim.Adam(net.parameters(), lr=lr)]

    net.train()
    steps = range(int(n_iters * num_images))
    for _ in (steps if silent else tqdm(steps)):
        for o in optims:
            o.zero_grad()
        # sample voxels and an image uniformly at random
        x, y, z = [np.random.randint(T, size=batch_size) for T in (H, W, D)]
        img_idx = np.random.randint(num_images)

        val = imgs[img_idx][x, y, z][..., None] * 2.0 / maxVal - 1    # [B, 1]
        coord = np.stack([x, y, z], axis=1) / maxDim * 2 - 1           # [B, 3]
        val = torch.as_tensor(val, dtype=torch.float32, device=device)
        coord = torch.as_tensor(coord, dtype=torch.float32, device=device)
        loss = F.mse_loss(net(coord, img_idx), val)
        loss.backward()
        if diff_optims:
            # step the shared decoder and only the sampled image's encoder optimiser
            optim_decoder.step()
            optim_encoder[img_idx].step()
        else:
            optims[0].step()

    # Drop every reference to the optimisers (and their Adam state) so the cached GPU
    # memory can actually be released before the evaluation below.
    del optims
    if diff_optims:
        del optim_encoder, optim_decoder
    torch.cuda.empty_cache()

    psnrs = []
    net.eval()
    with torch.no_grad():
        grids = torch.meshgrid(*(torch.arange(n, device=device) for n in (H, W, D)), indexing='ij')
        coords = torch.stack([g.reshape(-1) for g in grids], dim=1) / maxDim * 2 - 1   # [H*W*D, 3]
        for i in range(num_images):
            # Reconstruct in chunks to bound peak activation memory.
            pred = torch.cat([net(chunk, i) for chunk in coords.split(eval_chunk_size)], dim=0)
            pred = pred.reshape(imgs[i].shape).cpu().numpy()
            gt = imgs[i] * 2.0 / maxVal - 1
            psnrs.append(10 * np.log10(4 / ((gt - pred) ** 2).mean()))
    return psnrs

In [None]:
SEED = 0
all_psnrs = []
# Fit 1..25 images jointly; all_psnrs[k] holds the per-image PSNRs of the run with k + 1 images.
for num_images in range(1, 26):
    # Re-seed torch (initialisation) and numpy (voxel/image sampling in
    # training_loop_multi_image) so runs differ only by the number of images.
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    net = ImplicitWrapper(num_encoders=num_images).cuda()
    psnrs = training_loop_multi_image(net, 1000, 5000, num_images, diff_optims=True, silent=True)
    all_psnrs.append(psnrs)
    # release the model's GPU memory before the next run
    del net
    torch.cuda.empty_cache()

In [None]:
# Optional repeat of the multi-image sweep with a deeper decoder (3 hidden layers).
# Disabled by default because it retrains 25 models; set to True to populate all_psnrs_v2.
RUN_DEEP_DECODER_SWEEP = False
all_psnrs_v2 = []
if RUN_DEEP_DECODER_SWEEP:
    for num_images in range(1, 26):
        net = ImplicitWrapper(num_encoders=num_images, decoder_num_hiddenlayers=3).cuda()
        psnrs = training_loop_multi_image(net, 1000, 5000, num_images, diff_optims=True, silent=True)
        all_psnrs_v2.append(psnrs)
        del net
        torch.cuda.empty_cache()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# Mean PSNR over images vs. number of jointly-fitted images, one panel per decoder depth.
experiments = [(all_psnrs, 'decoder hidden layers=1'), (all_psnrs_v2, 'decoder hidden layers=3')]
for ax, (runs, title) in zip(axs, experiments):
    if runs:  # the deeper-decoder sweep may not have been run
        n_images = np.arange(len(runs)) + 1
        mean_psnr = [np.mean(r) for r in runs]
        ax.plot(n_images, mean_psnr)
        ax.scatter(n_images, mean_psnr)
    ax.set(xlabel='#images', ylabel='mean PSNR (dB)', title=title)
plt.show()

In [None]:
## Plot individual ones
# Per-image PSNR curves. Assumes runs[k] was trained on k + 1 images (as in the sweeps
# above), so image t (1-indexed) appears in every run from the t-th onwards; the mean
# over images is overlaid in each panel. This cell creates its own figure.
N1, N2 = len(all_psnrs), len(all_psnrs_v2)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, runs, n_runs, title in zip(axs, (all_psnrs, all_psnrs_v2), (N1, N2),
                                   ('decoder hidden layers=1', 'decoder hidden layers=3')):
    if n_runs:
        n_images = np.arange(1, n_runs + 1)
        mean_psnr = [np.mean(r) for r in runs]
        ax.plot(n_images, mean_psnr)
        ax.scatter(n_images, mean_psnr)
    for t in range(1, n_runs + 1):
        image_psnr = [r[t - 1] for r in runs[t - 1:]]
        ax.plot(np.arange(t, n_runs + 1), image_psnr)
        ax.scatter(np.arange(t, n_runs + 1), image_psnr)
    ax.set(xlabel='#images', ylabel='PSNR (dB)', title=title)
plt.show()

## Visualize slices

In [13]:
from ipywidgets import interact, widgets
%matplotlib notebook

def plot_slices(image1, image2, axis, slice_num, shared_range=True):
    '''Show the same slice of a predicted and a ground-truth volume side by side.

    Args:
        image1: predicted volume (left panel).
        image2: ground-truth volume (right panel).
        axis: axis along which to slice (0, 1 or 2).
        slice_num: index of the slice along `axis`.
        shared_range: if True, both panels use the same grey-level limits (min/max over
            the two slices) so intensity differences between them stay visible; if
            False, each panel is auto-scaled independently.
    '''
    pred_slice = np.take(image1, slice_num, axis=axis)
    gt_slice = np.take(image2, slice_num, axis=axis)
    limits = {}
    if shared_range:
        limits = dict(vmin=min(pred_slice.min(), gt_slice.min()),
                      vmax=max(pred_slice.max(), gt_slice.max()))
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    for panel, data, title in zip(ax, (pred_slice, gt_slice), ("Predicted", "Ground truth")):
        panel.imshow(data, cmap='gray', **limits)
        panel.set_title(title)
        panel.axis('off')
    plt.show()

In [29]:
num_images = 25
checkpoint_path = "25images.pt"
net = ImplicitWrapper(num_encoders=num_images).cuda()
if osp.exists(checkpoint_path):
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    net.load_state_dict(torch.load(checkpoint_path))
else:
    # No cached weights: train from scratch once and cache them for later runs.
    psnrs = training_loop_multi_image(net, 1000, 5000, num_images, diff_optims=True, silent=False)
    torch.save(net.state_dict(), checkpoint_path)

<All keys matched successfully>

In [30]:
# Switch to evaluation mode (equivalent to eval()); the module is returned and displayed.
net.train(False)

ImplicitWrapper(
  (encoder): ModuleList(
    (0): GridEncoder: input_dim=3 num_levels=16 level_dim=2 resolution=16 -> 256 per_level_scale=1.2030 params=(4555440, 2) gridtype=hash align_corners=False interpolation=linear
    (1): GridEncoder: input_dim=3 num_levels=16 level_dim=2 resolution=16 -> 256 per_level_scale=1.2030 params=(4555440, 2) gridtype=hash align_corners=False interpolation=linear
    (2): GridEncoder: input_dim=3 num_levels=16 level_dim=2 resolution=16 -> 256 per_level_scale=1.2030 params=(4555440, 2) gridtype=hash align_corners=False interpolation=linear
    (3): GridEncoder: input_dim=3 num_levels=16 level_dim=2 resolution=16 -> 256 per_level_scale=1.2030 params=(4555440, 2) gridtype=hash align_corners=False interpolation=linear
    (4): GridEncoder: input_dim=3 num_levels=16 level_dim=2 resolution=16 -> 256 per_level_scale=1.2030 params=(4555440, 2) gridtype=hash align_corners=False interpolation=linear
    (5): GridEncoder: input_dim=3 num_levels=16 level_dim=2 res

In [None]:
### Training loop for image 1
# Fit a fresh hash-grid encoder to volume 0 while re-using the decoder of the loaded
# `net`; only the new encoder is optimised.
image = load_nifti(t1files[0])
H, W, D = image.shape
maxDim = max(H, W, D)
# convert to tensor: intensities scaled so the volume maximum maps to 1 (and 0 to -1)
image_tensor = torch.as_tensor(image.astype(float) / image.max() * 2 - 1,
                               dtype=torch.float32, device='cuda')

encoder = ge.GridEncoder(num_levels=16,
                         level_dim=2,
                         base_resolution=16,
                         desired_resolution=256,
                         log2_hashmap_size=19).cuda()

optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)
# The decoder is not optimised here. Freezing it stops backward() from accumulating
# gradients into its parameters on every step; gradients still flow *through* it.
decoder_grad_flags = [p.requires_grad for p in net.decoder.parameters()]
net.decoder.requires_grad_(False)
try:
    for i in tqdm(range(5000)):
        optim.zero_grad()
        x, y, z = [torch.randint(0, T, (1000,), device=image_tensor.device) for T in [H, W, D]]
        val = image_tensor[x, y, z][..., None]                    # [B, 1]
        coord = torch.stack([x, y, z], dim=1) / maxDim * 2 - 1    # [B, 3]
        pred = net.decoder(encoder(coord))
        loss = F.mse_loss(pred, val)
        loss.backward()
        optim.step()
finally:
    # restore the decoder's original trainability
    for p, flag in zip(net.decoder.parameters(), decoder_grad_flags):
        p.requires_grad_(flag)

with torch.no_grad():
    grids = torch.meshgrid(*(torch.arange(n, device='cuda') for n in (H, W, D)), indexing='ij')
    coord = torch.stack([g.reshape(-1) for g in grids], dim=1) / maxDim * 2 - 1   # [H*W*D, 3]
    # reconstruct in chunks to bound peak activation memory
    pred = torch.cat([net.decoder(encoder(c)) for c in coord.split(2**20)], dim=0)
    pred = pred.reshape(image_tensor.shape)
    psnr = 10 * torch.log10(4 / F.mse_loss(pred, image_tensor))
    print(psnr.data.cpu().numpy())

 50%|██████████████████████████████████████████████████████████████████████████▉                                                                            | 2482/5000 [00:03<00:03, 757.73it/s]

In [None]:
# Inspect the gradient stored on the decoder's first Linear layer
# (may be None if no gradient has been accumulated into it).
net.decoder[0].weight.grad

In [12]:
# load all images
imgs = [load_nifti(t1files[i]) for i in range(num_images)]
H, W, D = imgs[0].shape
maxDim = max([H, W, D])
maxVal = max([np.max(i) for i in imgs])

with torch.no_grad():
    x, y, z = torch.meshgrid(torch.arange(H, device='cuda'), torch.arange(W, device='cuda'), torch.arange(D, device='cuda'), indexing='ij')
    coord = torch.cat([t.reshape(1, -1) for t in [x, y, z]], dim=0)
    coord = coord/maxDim*2-1
    net.eval()
    # get predicted image
    pred = net(coord.permute(1, 0), 0)
    pred = pred.reshape(imgs[0].shape)
    pred = pred.data.cpu().numpy()

In [13]:
# Invert the `x * 2 / maxVal - 1` normalisation to get intensities comparable to imgs[0].
# A new name keeps `pred` untouched, so re-running this cell gives the same result.
pred_unit = 0.5 * (pred + 1)
int16_info = np.iinfo(np.int16)
# Round (not truncate) and clip so out-of-range predictions cannot wrap around in int16.
predint = np.clip(np.rint(pred_unit * maxVal), int16_info.min, int16_info.max).astype(np.int16)

In [16]:
# Ground truth in the same normalised scale as the network output, as float32.
gt = np.asarray(imgs[0] * 2.0 / maxVal - 1, dtype=np.float32)

In [17]:
tuple(arr.dtype for arr in (predint, imgs[0], gt, pred))

(dtype('int16'), dtype('int16'), dtype('float32'), dtype('float32'))

In [18]:
# Create a dropdown widget for the axis
axis_dropdown = widgets.Dropdown(options=[("X", 0), ("Y", 1), ("Z", 2)], value=0)
# Slice slider whose range follows the selected axis, so every slice is reachable.
slice_slider = widgets.IntSlider(min=0, max=predint.shape[axis_dropdown.value] - 1, step=1,
                                 value=predint.shape[axis_dropdown.value] // 2)


def _sync_slice_range(change):
    # change['new'] is the newly selected axis index
    slice_slider.max = predint.shape[change['new']] - 1


axis_dropdown.observe(_sync_slice_range, names='value')
# Create an interactive plot
interact(plot_slices, image1=widgets.fixed(predint), image2=widgets.fixed(imgs[0]),
         axis=axis_dropdown, slice_num=slice_slider);

interactive(children=(Dropdown(description='axis', options=(('X', 0), ('Y', 1), ('Z', 2)), value=0), IntSlider…