In [None]:
import torch
from dataloaders.brats2021 import BRATS2021EncoderSegDataset
from glob import glob
from gridencoder import GridEncoder
import os
import nibabel as nib

import torch
from torch import nn
%pylab
%matplotlib notebook

In [None]:
# Multimodal model: a single GridEncoder whose features feed one MLP that outputs 4 channels per point.
# NOTE(review): the decoder's 64 input features presumably equal GridEncoder's default level count x level_dim=4 — confirm.
encoder = GridEncoder(
    level_dim=4,
    desired_resolution=196,
    gridtype='tiled',
    align_corners=True,
).cuda()
decoder = nn.Sequential(nn.Linear(64, 256), nn.LeakyReLU(), nn.Linear(256, 4)).cuda()
# Only the shared decoder is loaded here; per-subject encoder weights are loaded in later cells.
decoder.load_state_dict(torch.load('/data/Implicit3DCNNTasks/brats2021/decoder.pth'))

In [None]:
# Per-subject multimodal encoder checkpoints, sorted so index i is stable across runs.
encoders = glob("/data/Implicit3DCNNTasks/brats2021/encoder_BraTS2021_*pth"); encoders.sort()

In [None]:
# Decode the full volume from one interactively chosen multimodal encoder checkpoint.
idx = int(input("Enter index: "))
enc = encoders[idx]
data = torch.load(enc)
# data['embeddings'] = data['embeddings']*2.0 + 0.5 * torch.randn_like(data['embeddings']) * data['embeddings'].std(0)[None]
encoder.load_state_dict(data)

# run eval: query every voxel index, mapping each axis from [0, size-1] onto [-1, 1]
HWD = torch.tensor([240, 240, 155]).long()
xyz = torch.meshgrid([torch.arange(t) for t in HWD], indexing='ij')
xyz = torch.stack(xyz, dim=-1).reshape(-1, 3)
xyz = xyz / (HWD - 1) * 2 - 1

# Ceil division: at most 64 chunks that together cover every voxel, even when the voxel
# count is not a multiple of 64 (floor division would silently drop the remainder).
chunk = -(-xyz.shape[0] // 64)
with torch.no_grad():
    # Only one chunk of coordinates is sent to the GPU per decoder call; results come back to CPU.
    imgs = [decoder(encoder(minixyz.cuda())).cpu() for minixyz in xyz.split(chunk)]

# (N, 4) -> (H, W, D, 4). torch.cat tolerates a shorter final chunk, unlike torch.stack,
# and the target shape is taken from HWD so the two cannot drift apart.
imgs = torch.cat(imgs, dim=0).reshape(*HWD.tolist(), 4)

In [None]:
# Show slice 60 along the third axis for each of the 4 decoded channels (channel k -> k-th panel, row-major).
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for channel, ax in enumerate(axs.flat):
    ax.imshow(imgs[:, :, 60, channel].data.cpu().numpy(), cmap='gray')
    ax.axis('off')

In [None]:
# # brats_path = sorted(glob("/data/BRATS2021/training/BraTS2021_00612/*nii.gz"))
# # del brats_path[1]
# # gtimgs = [nib.load(x).get_fdata() for x in brats_path]

# fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# for i in range(2):
#     for j in range(2):
#         axs[i][j].imshow(gtimgs[i*2 + j][:, :, 60], cmap='gray')
#         axs[i][j].axis('off')

In [None]:
## Visualize BraTS ground-truth images

In [None]:
# One directory per BraTS 2021 training subject, sorted so indices line up across cells.
brats_images = glob("/data/BRATS2021/training/*/"); brats_images.sort()

In [None]:
# Pick one BraTS subject interactively and load its non-segmentation volumes as arrays.
# The largest valid index is len - 1, so advertise that rather than len.
idx = int(input("Enter index (0-{}): ".format(len(brats_images) - 1)))
# Sort so the modality order is deterministic (glob order is arbitrary) and matches the PSNR cells.
imgs = sorted(glob(os.path.join(brats_images[idx], '*nii.gz')))
# Filter on the file name only, so a 'seg' substring in a parent directory cannot drop every file.
imgs = [x for x in imgs if 'seg' not in os.path.basename(x)]
print(imgs)
imgs = [nib.load(x).get_fdata() for x in imgs]

In [None]:
# fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# for i in range(2):
#     for j in range(2):
#         idx = i*2 + j
#         axs[i][j].hist(imgs[idx].reshape(-1), bins=500)
#         axs[i][j].set_yscale('log')

In [None]:
## Calculate PSNR of the decoded volumes against the ground truth

In [None]:
from utils.util import uniform_normalize

In [None]:
# Evaluate every multimodal encoder against its subject's ground-truth volumes and print one line
# per subject: "<index> <psnr_0>, <psnr_1>, <psnr_2>, <psnr_3>".
# NOTE(review): encoders[idx] and brats_images[idx] are paired only by sorted position, which assumes
# both lists cover the same subjects in the same order — TODO confirm.

# The query grid is identical for every subject, so build it once instead of on every iteration.
HWD = torch.tensor([240, 240, 155]).long()
xyz = torch.meshgrid([torch.arange(t) for t in HWD], indexing='ij')
xyz = torch.stack(xyz, dim=-1).reshape(-1, 3)
xyz = xyz / (HWD - 1) * 2 - 1
# Ceil division: at most 64 chunks that together cover every voxel, even if the count isn't a multiple of 64.
chunk = -(-xyz.shape[0] // 64)

for idx in range(len(encoders)):
    enc = encoders[idx]
    data = torch.load(enc)
    encoder.load_state_dict(data)
    print("Loaded state dict... {}".format(enc))

    # Get images: sorted for a deterministic modality order, segmentation masks excluded.
    print("Loading ground truth images... {}".format(brats_images[idx]))
    gtimgs = sorted(glob(os.path.join(brats_images[idx], '*nii.gz')))
    # Filter on the file name only, so a 'seg' substring in a parent directory cannot drop every file.
    gtimgs = [x for x in gtimgs if 'seg' not in os.path.basename(x)]
    print(gtimgs)
    # The decoder predicts 4 channels; fail loudly rather than compare against missing/extra volumes.
    assert len(gtimgs) == 4, "expected 4 non-seg volumes in {}, found {}".format(brats_images[idx], len(gtimgs))
    gtimgs = [uniform_normalize(nib.load(x).get_fdata()) for x in gtimgs]
    print("Loaded ground truth images... {}".format(brats_images[idx]))

    # Decode the whole volume; only one chunk of coordinates is sent to the GPU per call.
    with torch.no_grad():
        imgs = [decoder(encoder(minixyz.cuda())).cpu() for minixyz in xyz.split(chunk)]
    predimgs = torch.cat(imgs, dim=0).reshape(*HWD.tolist(), 4)

    psnrs = []
    for i in range(4):
        mse = ((predimgs[..., i] - gtimgs[i]) ** 2).mean().item()
        # PSNR with peak^2 = 4, i.e. a value range of width 2 — presumably uniform_normalize maps
        # intensities to [-1, 1]; TODO confirm.
        psnrs.append(10 * np.log10(4 / mse))
    print(idx, ", ".join([str(x) for x in psnrs]))

In [None]:
# Spot-check: the encoder checkpoint and ground-truth folder at the same index should name the same subject.
tuple(paths[44] for paths in (encoders, brats_images))

## Separate per-modality encoders and decoders

In [None]:
# Unimodal variant: one independent GridEncoder + single-output MLP per modality (4 of each).
encoder = [
    GridEncoder(level_dim=2, desired_resolution=196, gridtype='tiled', align_corners=True).cuda()
    for _ in range(4)
]
decoder = []
for modality in range(4):
    # Build the MLP, then immediately load its trained weights for this modality.
    mlp = nn.Sequential(nn.Linear(32, 256), nn.LeakyReLU(), nn.Linear(256, 1)).cuda()
    mlp.load_state_dict(torch.load(f'/data/Implicit3DCNNTasks/brats2021_unimodal/decoder{modality}.pth'))
    decoder.append(mlp)

In [None]:
# encoders_unimodal[m] is the sorted list of per-subject encoder checkpoints for modality m.
encoders_unimodal = []
for modality in range(4):
    paths = glob(f"/data/Implicit3DCNNTasks/brats2021_unimodal/encoder_BraTS2021_*{modality}.pth")
    paths.sort()
    encoders_unimodal.append(paths)

In [None]:
# Decode all 4 modalities for one interactively chosen subject using the per-modality encoder/decoder pairs.
idx = int(input("Enter index: "))
encs = [x[idx] for x in encoders_unimodal]
for i, enc in enumerate(encs):
    data = torch.load(enc)
    encoder[i].load_state_dict(data)

# run eval: query every voxel index, mapping each axis from [0, size-1] onto [-1, 1]
HWD = torch.tensor([240, 240, 155]).long()
xyz = torch.meshgrid([torch.arange(t) for t in HWD], indexing='ij')
xyz = torch.stack(xyz, dim=-1).reshape(-1, 3)
xyz = xyz / (HWD - 1) * 2 - 1

# Ceil division: at most 64 chunks that together cover every voxel, even if the count isn't a multiple of 64.
chunk = -(-xyz.shape[0] // 64)
with torch.no_grad():
    allimgs = []
    # for all images (one encoder/decoder pair per modality)
    for imgid in range(4):
        imgs = [decoder[imgid](encoder[imgid](minixyz.cuda())).cpu() for minixyz in xyz.split(chunk)]
        # (N, 1) per modality; cat tolerates a shorter final chunk, unlike stack.
        allimgs.append(torch.cat(imgs, dim=0))

# (N, 4) -> (H, W, D, 4), with the target shape taken from HWD.
allimgs = torch.cat(allimgs, dim=-1).reshape(*HWD.tolist(), 4)

In [None]:
%matplotlib inline
# Layout check: title each panel with its (row, col, flat index) to confirm which channel lands where.
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for flat_idx, ax in enumerate(axs.flat):
    row, col = divmod(flat_idx, 2)
    # ax.imshow(allimgs[:, :, 45, flat_idx].data.cpu().numpy(), cmap='gray')
    ax.set_title('i, j, idx = {}, {}, {}'.format(row, col, flat_idx))

In [None]:
# Show slice 45 along the third axis for each of the 4 decoded modalities (modality k -> k-th panel, row-major).
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for channel, ax in enumerate(axs.flat):
    ax.imshow(allimgs[:, :, 45, channel].data.cpu().numpy(), cmap='gray')
    ax.axis('off')

In [None]:
# Evaluate every subject's 4 per-modality encoders against the ground-truth volumes and print one line
# per subject: "<index> <psnr_0>, <psnr_1>, <psnr_2>, <psnr_3>".
# NOTE(review): encoders_unimodal[m][idx] and brats_images[idx] are paired only by sorted position, which
# assumes all lists cover the same subjects in the same order — TODO confirm.

# Every modality must have the same number of checkpoints, otherwise encs below would IndexError
# or silently mix subjects.
n_subjects = len(encoders_unimodal[0])
assert all(len(x) == n_subjects for x in encoders_unimodal), \
    "per-modality encoder counts differ: {}".format([len(x) for x in encoders_unimodal])

# The query grid is identical for every subject, so build it once instead of on every iteration.
HWD = torch.tensor([240, 240, 155]).long()
xyz = torch.meshgrid([torch.arange(t) for t in HWD], indexing='ij')
xyz = torch.stack(xyz, dim=-1).reshape(-1, 3)
xyz = xyz / (HWD - 1) * 2 - 1
# Ceil division: at most 64 chunks that together cover every voxel, even if the count isn't a multiple of 64.
chunk = -(-xyz.shape[0] // 64)

for idx in range(n_subjects):
    encs = [x[idx] for x in encoders_unimodal]
    for i in range(4):
        data = torch.load(encs[i])
        encoder[i].load_state_dict(data)
    print("Loaded state dict... {}".format(encs[0]))

    # Get images: sorted for a deterministic modality order, segmentation masks excluded.
    print("Loading ground truth images... {}".format(brats_images[idx]))
    gtimgs = sorted(glob(os.path.join(brats_images[idx], '*nii.gz')))
    # Filter on the file name only, so a 'seg' substring in a parent directory cannot drop every file.
    gtimgs = [x for x in gtimgs if 'seg' not in os.path.basename(x)]
    # One ground-truth volume per modality decoder; fail loudly on missing/extra files.
    assert len(gtimgs) == 4, "expected 4 non-seg volumes in {}, found {}".format(brats_images[idx], len(gtimgs))
    gtimgs = [uniform_normalize(nib.load(x).get_fdata()) for x in gtimgs]
    print("Loaded ground truth images... {}".format(brats_images[idx]))

    # Decode each modality with its own encoder/decoder pair; one chunk of coordinates per GPU call.
    with torch.no_grad():
        allimgs = []
        for imgid in range(4):
            imgs = [decoder[imgid](encoder[imgid](minixyz.cuda())).cpu() for minixyz in xyz.split(chunk)]
            allimgs.append(torch.cat(imgs, dim=0))  # (N, 1)

    predimgs = torch.cat(allimgs, dim=-1).reshape(*HWD.tolist(), 4)
    psnrs = []
    for i in range(4):
        mse = ((predimgs[..., i] - gtimgs[i]) ** 2).mean().item()
        # PSNR with peak^2 = 4, i.e. a value range of width 2 — presumably uniform_normalize maps
        # intensities to [-1, 1]; TODO confirm.
        psnrs.append(10 * np.log10(4 / mse))
    print(idx, ", ".join([str(x) for x in psnrs]))