In [None]:
import argparse
from dataloaders import BRATS2021KPlanesDataset
from gridencoder import GridEncoder
import torch
from torch import nn
from tqdm import tqdm
from os import path as osp
import numpy as np
from torch.nn import functional as F
import gc
from configs.config import get_cfg_defaults
import os
from networks.kplanes.encoder import ImageKPlaneAttnEncoder
from collections import namedtuple

In [None]:
%pylab
%matplotlib inline

In [None]:
# Root of the BraTS 2021 training data. Override it with the BRATS_ROOT environment variable
# instead of editing a hard-coded absolute path. The default keeps the original location.
BRATS_ROOT = os.environ.get('BRATS_ROOT', '/data/rohitrango/BRATS2021/training')
# augment / multimodal / mlabel are forwarded unchanged to BRATS2021KPlanesDataset;
# see that class for their exact meaning.
dataset = BRATS2021KPlanesDataset(root_dir=BRATS_ROOT, augment=True,
                                  multimodal=False, mlabel=0)

In [None]:
# Loader over the dataset: one volume per batch, dataset order, loading in the main process.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Take the first batch explicitly. Unlike `for datum in dataloader: break`, this raises
# StopIteration on an empty loader instead of leaving `datum` undefined (or stale from an
# earlier run of the cell).
datum = next(iter(dataloader))

In [None]:
# The single volume that every model below is fitted to, moved to the current CUDA device.
img = datum['images'].to('cuda')

In [None]:
# Sizes of the last three axes of the 5-D `img`; they bound the voxel indices sampled below.
# Slicing (rather than `_, _, H, W, D = ...`) avoids rebinding `_`, which would shadow
# IPython's last-output variable. It still raises if `img` is not 5-D.
H, W, D = img.shape[2:]

In [None]:
# Optional experiment (disabled): replace the volume by a 0/1 mask of its positive voxels.
# img = (img > 0).float()

In [None]:
# Number of feature channels per plane.
C = 128

In [None]:
# Learnable tri-plane features: one 2-D plane per axis pair, down-sampled by `d` relative to
# the volume and initialised as 0.01 * N(0, 1).
d = 4    # plane down-sampling factor
C = 128  # channels per plane
f_xy = nn.Parameter(0.01 * torch.randn(1, C, H // d, W // d).cuda())  # [1, C, H/d, W/d]
f_yz = nn.Parameter(0.01 * torch.randn(1, C, W // d, D // d).cuda())  # [1, C, W/d, D/d]
f_zx = nn.Parameter(0.01 * torch.randn(1, C, H // d, D // d).cuda())  # [1, C, H/d, D/d]
# Alternative initialisation from the volume's axis means, repeated over C channels:
# f_xy = nn.Parameter((img.mean(-1).repeat(1, C, 1, 1))).cuda()
# f_yz = nn.Parameter((img.mean(2).repeat(1, C, 1, 1))).cuda()
# f_zx = nn.Parameter((img.mean(3).repeat(1, C, 1, 1))).cuda()

In [None]:
class SineLayer(nn.Module):
    """Linear layer followed by ``sin(omega_0 * linear(x))`` (SIREN-style).

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

    If is_first=True, omega_0 is a frequency factor which simply multiplies the activations
    before the nonlinearity. Different signals may require different omega_0 in the first
    layer - this is a hyperparameter.

    If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude
    of activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the inner ``nn.Linear`` has a bias.
        activation: if False, the sine is skipped and ``omega_0 * linear(x)`` is returned
            (e.g. for an output layer).
        is_first: selects the first-layer weight initialisation described above.
        omega_0: scale applied to the linear output before the (optional) sine.
    """

    def __init__(self, in_features, out_features, bias=True, activation=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.activation = activation

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-initialise ``linear.weight`` uniformly (the bias keeps PyTorch's default init).

        First layer: U(-1/in, 1/in). Other layers: U(-sqrt(6/in)/omega_0, sqrt(6/in)/omega_0).
        """
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                            1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                            np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``, or the pre-activation if ``activation`` is False."""
        ret = self.omega_0 * self.linear(input)
        return torch.sin(ret) if self.activation else ret

    def forward_with_intermediate(self, input):
        """Return ``(output, pre_activation)``, e.g. for visualising activation distributions.

        ``output`` equals :meth:`forward`; in particular the sine is skipped when
        ``activation`` is False (previously it was always applied here).
        """
        intermediate = self.omega_0 * self.linear(input)
        output = torch.sin(intermediate) if self.activation else intermediate
        return output, intermediate

In [None]:
# MLP decoder: concatenated tri-plane features (3 planes x C channels) -> one intensity.
# Five Linear(., 256) + GELU stages, then a linear read-out; modules are built in order.
decoder = nn.Sequential(
    nn.Linear(C * 3, 256), nn.GELU(),
    *[layer for _ in range(4) for layer in (nn.Linear(256, 256), nn.GELU())],
    nn.Linear(256, 1),
    # SIREN alternative for the layers above:
    # SineLayer(C*3, 256, is_first=True),
    # SineLayer(256, 256),
    # SineLayer(256, 256),
    # SineLayer(256, 256),
    # SineLayer(256, 1, activation=False),
).cuda()

In [None]:
# Separate Adam optimizers (different learning rates) for the feature planes and the decoder.
opt = torch.optim.Adam(params=[f_xy, f_yz, f_zx], lr=5e-2)      # feature planes
opt_d = torch.optim.Adam(params=decoder.parameters(), lr=3e-4)  # decoder MLP

In [None]:
# How the three plane features are combined: add -> sum, cat -> concatenate,
# neither -> elementwise product.
add, cat = False, True

In [None]:
def concat_coord(x, y):
    """Stack two coordinate arrays of shape [N] into a [1, 1, N, 2] sampling grid.

    ``x`` fills ``grid[..., 0]`` and ``y`` fills ``grid[..., 1]``. For ``F.grid_sample``,
    component 0 indexes the input's last axis and component 1 its second-to-last axis.
    """
    return np.concatenate([c[None, None, :, None] for c in (x, y)], axis=-1)

# Fit the tri-plane features and decoder to random voxel minibatches of `img`.
# Down-sampled planes (d > 1) only implement the concat combination. Fail early instead of
# silently training with concat while evaluation uses add / product.
assert d == 1 or (cat and not add), "d > 1 only supports cat=True, add=False"
pbar = tqdm(range(10000))
losses = []
for i in pbar:
    # Random voxel indices for this minibatch.
    x = np.random.randint(H, size=(50000,))
    y = np.random.randint(W, size=(50000,))
    z = np.random.randint(D, size=(50000,))
    target = img[..., x, y, z]  # [1, 1, N]
    if d == 1:
        # Full-resolution planes: gather features by integer indexing, each [1, C, N].
        if add:
            f = f_xy[..., x, y] + f_yz[..., y, z] + f_zx[..., x, z]
        elif cat:
            f = torch.cat([f_xy[..., x, y], f_yz[..., y, z], f_zx[..., x, z]], dim=1)  # [1, 3C, N]
        else:
            f = f_xy[..., x, y] * f_yz[..., y, z] * f_zx[..., x, z]
    else:
        # Voxel index i in [0, S-1] -> i / (S - 1) * 2 - 1. With align_corners=True this puts
        # voxels 0 and S-1 exactly on the planes' corner texels, the same convention as the
        # F.interpolate(..., align_corners=True) used for evaluation. Dividing by S instead
        # shifted every sample by up to ~1 voxel towards the origin.
        x1, y1, z1 = x / (H - 1) * 2.0 - 1, y / (W - 1) * 2.0 - 1, z / (D - 1) * 2.0 - 1
        # grid[..., 0] indexes a plane's last axis, grid[..., 1] its first spatial axis.
        xy_c = torch.from_numpy(concat_coord(y1, x1)).cuda().float()
        yz_c = torch.from_numpy(concat_coord(z1, y1)).cuda().float()
        xz_c = torch.from_numpy(concat_coord(z1, x1)).cuda().float()
        fxy = F.grid_sample(f_xy, xy_c, mode='bilinear', align_corners=True)  # [1, C, 1, N]
        fyz = F.grid_sample(f_yz, yz_c, mode='bilinear', align_corners=True)
        fzx = F.grid_sample(f_zx, xz_c, mode='bilinear', align_corners=True)
        f = torch.cat([fxy, fyz, fzx], dim=1).squeeze(2)  # [1, 3C, N]

    f = f.permute(0, 2, 1)            # [1, N, features]
    f2 = decoder(f).permute(0, 2, 1)  # [1, 1, N]
    opt.zero_grad()
    opt_d.zero_grad()
    loss = F.mse_loss(f2, target)
    l1_loss = torch.abs(f2 - target).mean()
    (loss + 0.1 * l1_loss).backward()
    opt.step()
    opt_d.step()
    pbar.set_description("Iter: {}, loss: {}".format(i, loss.item()))
    losses.append(loss.item())

In [None]:
# Training curve of the single-scale tri-plane fit (per-iteration minibatch MSE).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_yscale('log')
ax.set(title='Tri-plane fit: training MSE', xlabel='iteration', ylabel='MSE (log scale)');

In [None]:
# Residual on the last training minibatch, summarised. The raw [1, 1, N] tensor is too large
# to read and still carries autograd history.
with torch.no_grad():
    residual = f2 - img[..., x, y, z]
{'mean_abs': residual.abs().mean().item(), 'max_abs': residual.abs().max().item()}

In [None]:
with torch.no_grad():
    # Build a [1, C', H/d, W/d, D/d] feature volume from the three planes, combined the same
    # way as in training (add / cat / product).
    if add:
        f = f_xy[..., None] + f_yz[..., None, :, :] + f_zx[..., None, :]
    elif cat:
        fxyimg = f_xy[..., None].repeat(1, 1, 1, 1, D//d)
        fyzimg = f_yz[..., None, :, :].repeat(1, 1, H//d, 1, 1)
        fzximg = f_zx[..., None, :].repeat(1, 1, 1, W//d, 1)
        f = torch.cat([fxyimg, fyzimg, fzximg], dim=1)
    else:
        # Broadcast product over the whole grid. The previous code indexed the planes with the
        # last training minibatch's x/y/z, which gives [1, C, N] rather than a volume.
        f = f_xy[..., None] * f_yz[..., None, :, :] * f_zx[..., None, :]

    # Upsample to full resolution (align_corners=True, as in training's grid_sample) and decode.
    # NOTE(review): this materialises C' x H x W x D floats on the GPU at once.
    f = F.interpolate(f, (H, W, D), mode='trilinear', align_corners=True)
    f = f.permute(0, 2, 3, 4, 1)
    f = decoder(f).permute(0, 4, 1, 2, 3)
    mse = F.mse_loss(f, img)
    # NOTE(review): a PSNR peak of 4 assumes intensities span a range of 2 (e.g. [-1, 1]) — confirm.
    psnr = 10 * torch.log10(4/mse)
    print(mse.item(), psnr.item())
   

In [None]:
s = 70  # index along the last (D) axis of the slice to display
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(20, 10))
ax0.imshow(f[0, 0, ..., s].detach().cpu().numpy(), cmap='gray')
ax0.set_title(f'Tri-plane prediction (slice {s})')
ax1.imshow(img[0, 0, ..., s].cpu().numpy(), cmap='gray')
ax1.set_title(f'Ground truth (slice {s})');

In [None]:
# Full-volume MSE of the last prediction `f` against `img`, shown as the cell output.
F.mse_loss(input=f, target=img)

# Multi-scale K-Planes: tri-plane features at several resolutions

In [None]:
# Multi-scale tri-planes: for each down-sampling factor d, one plane per axis pair with C
# channels, initialised as 0.99 * N(0, 1).
C = 64
params = {key: [] for key in ('xy', 'yz', 'zx')}
for d in (2, 4, 8, 16):
    f_xy = nn.Parameter(0.99 * torch.randn(1, C, H // d, W // d).cuda())
    f_yz = nn.Parameter(0.99 * torch.randn(1, C, W // d, D // d).cuda())
    f_zx = nn.Parameter(0.99 * torch.randn(1, C, H // d, D // d).cuda())
    params['xy'].append(f_xy)
    params['yz'].append(f_yz)
    params['zx'].append(f_zx)

In [None]:
# Decoder for the concatenated multi-scale tri-plane features. The input width is derived from
# the actual plane shapes (as the multi-resolution 3-D section does with `inps`) rather than a
# hard-coded C*4*3. Changing the scale list or C then cannot silently desynchronise the two.
decoder = nn.Sequential(
    nn.Linear(sum(p.shape[1] for p in params['xy'] + params['yz'] + params['zx']), 256), nn.GELU(),
    *[layer for _ in range(4) for layer in (nn.Linear(256, 256), nn.GELU())],
    nn.Linear(256, 1),
).cuda()

In [None]:
# Planes of every scale share one optimizer; the decoder uses a smaller learning rate.
opt = torch.optim.Adam(params=[*params['xy'], *params['yz'], *params['zx']], lr=5e-2)
opt_d = torch.optim.Adam(params=decoder.parameters(), lr=3e-4)

In [None]:
def concat_coord(x, y):
    """Pair two [N] coordinate arrays into a [1, 1, N, 2] ``grid_sample`` grid.

    Redefines ``concat_coord`` with the same behaviour as the earlier definition:
    ``x`` -> ``grid[..., 0]``, ``y`` -> ``grid[..., 1]``.
    """
    columns = [np.expand_dims(c, (0, 1, 3)) for c in (x, y)]  # each [1, 1, N, 1]
    return np.concatenate(columns, axis=-1)

In [None]:
# Fit the multi-scale tri-planes and decoder to random voxel minibatches of `img`.
pbar = tqdm(range(10000))
losses = []
for i in pbar:
    # Random voxel indices for this minibatch.
    x = np.random.randint(H, size=(50000,))
    y = np.random.randint(W, size=(50000,))
    z = np.random.randint(D, size=(50000,))
    target = img[..., x, y, z]  # [1, 1, N]
    # Voxel index i in [0, S-1] -> i / (S - 1) * 2 - 1. With align_corners=True this puts
    # voxels 0 and S-1 exactly on every plane's corner texels (the F.interpolate
    # align_corners=True convention). Dividing by S instead shifted every sample by up to
    # ~1 voxel towards the origin.
    x1, y1, z1 = x / (H - 1) * 2.0 - 1, y / (W - 1) * 2.0 - 1, z / (D - 1) * 2.0 - 1
    xy_c = torch.from_numpy(concat_coord(y1, x1)).cuda().float()
    yz_c = torch.from_numpy(concat_coord(z1, y1)).cuda().float()
    xz_c = torch.from_numpy(concat_coord(z1, x1)).cuda().float()
    # One [1, C, 1, N] feature per plane and scale, in `params` order.
    fxy = [F.grid_sample(p, xy_c, mode='bilinear', align_corners=True) for p in params['xy']]
    fyz = [F.grid_sample(p, yz_c, mode='bilinear', align_corners=True) for p in params['yz']]
    fzx = [F.grid_sample(p, xz_c, mode='bilinear', align_corners=True) for p in params['zx']]
    f = torch.cat([*fxy, *fyz, *fzx], dim=1)  # [1, 3 * n_scales * C, 1, N]
    f = f.squeeze(2).permute(0, 2, 1)         # [1, N, 3 * n_scales * C]
    f2 = decoder(f).permute(0, 2, 1)          # [1, 1, N]
    opt.zero_grad()
    opt_d.zero_grad()
    loss = F.mse_loss(f2, target)
    l1_loss = torch.abs(f2 - target).mean()
    (loss + l1_loss * 0.1).backward()
    opt.step()
    opt_d.step()
    pbar.set_description("Iter: {}, loss: {}".format(i, loss.item()))
    losses.append(loss.item())
    

In [None]:
with torch.no_grad():
    # Decode every voxel with the same plane sampling as training (grid_sample,
    # align_corners=True, voxel i -> i / (S - 1) * 2 - 1), a few x-slices at a time to bound
    # GPU memory.
    # This replaces an octet scheme with two problems:
    #   - Its `c*D//4:(c+1)*D//4`-style slicing skipped the last half-resolution slice on
    #     odd-sized axes (D=155: 2*(D//4)=76 of D//2=77 slices). The final voxel plane of
    #     pred_img stayed 0 and biased MSE/PSNR.
    #   - It upsampled each octet independently (align_corners=False), creating seams.
    slab = 4  # x-slices decoded per step; lower it if the GPU runs out of memory
    pred_img = np.zeros((1, 1, H, W, D))
    for x_start in range(0, H, slab):
        x_stop = min(x_start + slab, H)
        gx, gy, gz = np.meshgrid(np.arange(x_start, x_stop), np.arange(W), np.arange(D),
                                 indexing='ij')
        xn = gx.ravel() / (H - 1) * 2.0 - 1
        yn = gy.ravel() / (W - 1) * 2.0 - 1
        zn = gz.ravel() / (D - 1) * 2.0 - 1
        xy_c = torch.from_numpy(concat_coord(yn, xn)).cuda().float()
        yz_c = torch.from_numpy(concat_coord(zn, yn)).cuda().float()
        xz_c = torch.from_numpy(concat_coord(zn, xn)).cuda().float()
        # Same channel order as training: all xy scales, then yz, then zx.
        feats = ([F.grid_sample(p, xy_c, mode='bilinear', align_corners=True) for p in params['xy']]
                 + [F.grid_sample(p, yz_c, mode='bilinear', align_corners=True) for p in params['yz']]
                 + [F.grid_sample(p, xz_c, mode='bilinear', align_corners=True) for p in params['zx']])
        f = torch.cat(feats, dim=1).squeeze(2).permute(0, 2, 1)  # [1, N, 3 * n_scales * C]
        pred = decoder(f)  # [1, N, 1], N ordered x-slowest by meshgrid(indexing='ij')
        pred_img[0, 0, x_start:x_stop] = pred.reshape(x_stop - x_start, W, D).cpu().numpy()
    mse = ((pred_img - img.cpu().numpy())**2).mean()
    # NOTE(review): a PSNR peak of 4 assumes intensities span a range of 2 (e.g. [-1, 1]) — confirm.
    psnr = 10 * np.log10(4/mse)
    print(mse.item(), psnr.item())

In [None]:
s = 100  # index along the last (D) axis of the slice to display
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(20, 10))
ax0.imshow(pred_img[0, 0, ..., s], cmap='gray')
ax0.set_title(f'Multi-scale tri-plane prediction (slice {s})')
ax1.imshow(img[0, 0, ..., s].cpu().numpy(), cmap='gray')
ax1.set_title(f'Ground truth (slice {s})');

# Multi-resolution 3D feature grids (instead of 2D planes)

In [None]:
# One dense 3-D feature grid per down-sampling factor d (full resolution down to 1/16), each
# with C channels and initialised as 0.01 * N(0, 1). The d=1 grid alone holds C*H*W*D values.
C = 64
params = []
for d in (1, 2, 4, 8, 16):
    p = nn.Parameter(0.01 * torch.randn(1, C, H // d, W // d, D // d).cuda())
    params.append(p)
print([p.shape for p in params])

In [None]:
# Channel count contributed by each grid; their sum is the decoder's input width.
inps = [p.size(1) for p in params]

In [None]:
# MLP decoder over the concatenated multi-resolution grid features (sum(inps) channels):
# five Linear(., 256) + GELU stages, then a linear read-out to one intensity.
decoder = nn.Sequential(
    nn.Linear(sum(inps), 256), nn.GELU(),
    *[layer for _ in range(4) for layer in (nn.Linear(256, 256), nn.GELU())],
    nn.Linear(256, 1),
).cuda()

In [None]:
# Feature grids and decoder get separate Adam optimizers with different learning rates.
opt = torch.optim.Adam(params=params, lr=5e-2)
opt_d = torch.optim.Adam(params=decoder.parameters(), lr=3e-4)

In [None]:
def concat_coord_3d(x, y, z):
    """Stack three [N] coordinate arrays into a [1, 1, 1, N, 3] 5-D ``grid_sample`` grid.

    ``x``, ``y`` and ``z`` fill ``grid[..., 0]``, ``grid[..., 1]`` and ``grid[..., 2]``.
    """
    columns = [np.expand_dims(c, (0, 1, 2, 4)) for c in (x, y, z)]  # each [1, 1, 1, N, 1]
    return np.concatenate(columns, axis=-1)

# Fit the multi-resolution 3-D grids and decoder to random voxel minibatches of `img`.
pbar = tqdm(range(10000))
losses = []
for i in pbar:
    # Random voxel indices for this minibatch.
    x = np.random.randint(H, size=(50000,))
    y = np.random.randint(W, size=(50000,))
    z = np.random.randint(D, size=(50000,))
    target = img[..., x, y, z]  # [1, 1, N]
    # Voxel index i in [0, S-1] -> i / (S - 1) * 2 - 1. With align_corners=True this lands
    # exactly on grid point i of the d=1 grid, matching the F.interpolate(align_corners=True)
    # evaluation. Dividing by S instead shifted every sample by up to ~1 voxel towards the
    # origin, i.e. between grid points.
    x1, y1, z1 = x / (H - 1) * 2.0 - 1, y / (W - 1) * 2.0 - 1, z / (D - 1) * 2.0 - 1
    # The 5-D grid lists coordinates for the input's last, middle and first spatial axes, hence (z, y, x).
    coords = torch.from_numpy(concat_coord_3d(z1, y1, x1)).cuda().float()
    # mode='bilinear' on 5-D input performs trilinear interpolation.
    fs = [F.grid_sample(p, coords, mode='bilinear', align_corners=True) for p in params]
    fs = torch.cat(fs, dim=1)        # [1, sum(inps), 1, 1, N]
    fs = fs.squeeze(2).squeeze(2)    # [1, sum(inps), N]
    f = fs.permute(0, 2, 1)          # [1, N, sum(inps)]
    f2 = decoder(f).permute(0, 2, 1) # [1, 1, N]
    opt.zero_grad()
    opt_d.zero_grad()
    loss = F.mse_loss(f2, target)
    l1_loss = torch.abs(f2 - target).mean()
    (loss + 0.1 * l1_loss).backward()
    opt.step()
    opt_d.step()
    pbar.set_description("Iter: {}, loss: {}".format(i, loss.item()))
    losses.append(loss.item())
    

In [None]:
with torch.no_grad():
    # Upsample every grid to full resolution (align_corners=True), stack channels, decode.
    fimgs = torch.cat(
        [F.interpolate(p, (H, W, D), mode='trilinear', align_corners=True) for p in params],
        dim=1,
    ).permute(0, 2, 3, 4, 1)                   # [1, H, W, D, sum(inps)]
    f = decoder(fimgs).permute(0, 4, 1, 2, 3)  # [1, 1, H, W, D]
    mse = F.mse_loss(f, img)
    # NOTE(review): a PSNR peak of 4 assumes intensities span a range of 2 (e.g. [-1, 1]) — confirm.
    psnr = 10 * torch.log10(4 / mse)
    print(mse.item(), psnr.item())

In [None]:
s = 60  # index along the last (D) axis of the slice to display
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(20, 10))
ax0.imshow(f[0, 0, ..., s].detach().cpu().numpy(), cmap='gray')
ax0.set_title(f'Multi-resolution 3-D grid prediction (slice {s})')
ax1.imshow(img[0, 0, ..., s].cpu().numpy(), cmap='gray')
ax1.set_title(f'Ground truth (slice {s})');