In [None]:
import nibabel as nib
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import argparse
import tinycudann as tcnn
import os
from os.path import expanduser
# Current user's home directory; input and output paths below are built from it.
home = os.path.expanduser("~")

In [None]:
class HashMLP(pl.LightningModule):
  """Coordinate network made of a tiny-cuda-nn encoding followed by a tiny-cuda-nn MLP.

  Args:
    config: dict with an 'encoding' entry and a 'network' entry, forwarded
      unchanged to ``tcnn.Encoding`` and ``tcnn.Network`` respectively.
    dim_in: number of input coordinates per sample.
    dim_out: number of values predicted per sample.
    lr: learning rate handed to the Adam optimizer.
  """

  def __init__(self, config, dim_in=3, dim_out=1, lr=5e-3):
    super().__init__()
    self.dim_in = dim_in
    self.dim_out = dim_out
    self.lr = lr

    # The encoding and the MLP stay available as separate attributes so that
    # they can be evaluated independently; `model` simply chains them.
    self.encoding = tcnn.Encoding(n_input_dims=dim_in, encoding_config=config['encoding'])
    self.mlp= tcnn.Network(n_input_dims=self.encoding.n_output_dims, n_output_dims=dim_out, network_config=config['network'])
    self.model = torch.nn.Sequential(self.encoding, self.mlp)

  def forward(self, x):
    """Encode the coordinates `x` and map the encoding through the MLP."""
    return self.model(x)

  def configure_optimizers(self):
    """Return an Adam optimizer over all parameters, using the configured learning rate."""
    return torch.optim.Adam(self.parameters(), lr=self.lr)

  def training_step(self, batch, batch_idx):
    """Compute, log (as "train_loss") and return the MSE between prediction and target."""
    x, y = batch
    z = self(x)

    loss = F.mse_loss(z, y)

    self.log("train_loss", loss)
    return loss

  def predict_step(self, batch, batch_idx):
    """Return the prediction for the first element of `batch` (the coordinates).

    Any further element of the batch (e.g. a target) is ignored, so the batch
    does not need to contain one.
    """
    x = batch[0]
    return self(x)

In [None]:
batch_size = 4096*4
# os.cpu_count() returns None when the count cannot be determined; in that case
# fall back to loading the batches in the main process.
num_workers = os.cpu_count() or 0

#Read image
image_file = os.path.join(home,'Sync-Exp','dhcp128.nii.gz')

image = nib.load(image_file)
data = image.get_fdata()

#Create grid
dim = 3
# Coordinates are sampled in [0, 1]. An earlier version sampled them in
# [-1, 1], which was a bug: the coordinates need to be positive.
x = torch.linspace(0, 1, steps=data.shape[0])
y = torch.linspace(0, 1, steps=data.shape[1])
z = torch.linspace(0, 1, steps=data.shape[2])

#Convert to X=(x,y,z) and Y=intensity
# 'ij' indexing followed by a row-major reshape gives the same voxel order as
# data.flatten(), so row n of X and row n of Y describe the same voxel.
mgrid = torch.stack(torch.meshgrid(x,y,z,indexing='ij'), dim=-1)
X = torch.Tensor(mgrid.reshape(-1,dim))
Y = torch.Tensor(data.flatten())

#Normalize intensities between [-1,1]
y_min = torch.min(Y)
y_max = torch.max(Y)
if y_max == y_min:
  # A constant image would make the division below produce NaN targets.
  raise ValueError("Cannot normalize intensities: the image is constant")
Y = (Y - y_min) / (y_max - y_min) * 2 - 1
Y = torch.reshape(Y, (-1,1))

#Pytorch dataloader
dataset = torch.utils.data.TensorDataset(X,Y)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)


In [None]:
# Quick look at the coordinate grid. Only the last expression of a cell is
# displayed automatically, so the shape has to be printed explicitly.
print(mgrid.shape)
print(X[0:4,:])

In [None]:

# Number of voxels along each of the three axes, and the largest of them.
nx, ny, nz = data.shape[:3]
nmax = np.max([nx,ny,nz])

# Disabled sanity check, kept for reference. It is written as comments (not as
# a trailing string literal) so that the cell does not echo it as its output.
#
# # Manual building of the grid of coordinates
# half_dx = 0.5 / nx
# half_dy = 0.5 / ny
# half_dz = 0.5 / nz
# n_voxels = nx * ny * nz
#
# X2 = torch.zeros(X.shape)
# Y2 = torch.zeros(Y.shape)
# n=0
# for i in range(nx):
#   for j in range(ny):
#     for k in range(nz):
#       X2[n,0] = (i * 1.0 / nx)
#       X2[n,1] = (j * 1.0 / ny)
#       X2[n,2] = (k * 1.0 / nz)
#       Y2[n,0] = data[i,j,k]
#       n = n+1
# Y2 = (Y2 - torch.min(Y2)) / (torch.max(Y2) - torch.min(Y2)) * 2 - 1
#
# #Check ordering
# nonz = np.argwhere(Y>0)
# print(nonz.shape)
# elm = nonz[0,0]
# print(X[elm:elm+3,:])
# print(X2[elm:elm+3,:])
# print(Y[elm:elm+3,:])
# print(Y2[elm:elm+3,:])


In [None]:
# Encoding / network configuration used for training

# Multi-resolution grid settings
n_levels = 6
n_features_per_level = 2
base_resolution = 32  # 16 was tried as well
n_features = n_features_per_level * n_levels

# Per-level growth factor such that base_resolution * b**(n_levels-1) equals
# nmax, the largest image dimension.
log_resolution_gap = np.log(nmax)-np.log(base_resolution)
b = np.exp(log_resolution_gap/(n_levels-1))
print(b)

# Option names follow
# https://github.com/NVlabs/tiny-cuda-nn/blob/master/DOCUMENTATION.md
config = dict(
    encoding=dict(
        otype="HashGrid",
        n_levels=n_levels,
        n_features_per_level=n_features_per_level,
        log2_hashmap_size=19,
        base_resolution=base_resolution,
        per_level_scale=b,  # fixed values 1.3819 and 1.5 were tried as well
    ),
    network=dict(
        otype="FullyFusedMLP",
        activation="ReLU",
        output_activation="None",
        n_neurons=128,
        n_hidden_layers=2,
    ),
)

In [None]:
# Model mapping a 3D coordinate to a single intensity value.
net = HashMLP(config, dim_in=3, dim_out=1)

In [None]:
num_epochs = 25
# No device is requested explicitly here. An earlier variant of this call also
# passed gpus=1; whether that is required may depend on the GPU type.
trainer = pl.Trainer(precision=16, max_epochs=num_epochs)
trainer.fit(net, loader)


In [None]:
print(X.shape)
X = X.to(device='cuda')
net = net.to(device='cuda')
# Evaluation only: disable autograd so that no graph is kept for the whole
# volume (every voxel goes through the network in a single pass).
with torch.no_grad():
  enc = net.encoding(X)
  recon = net(X)
print(enc.shape)
print(mgrid.shape)
print(data.shape)
print(recon.shape)

In [None]:
# Encoded features: one channel per encoding feature, stored as a 4D volume
# that reuses the affine of the input image.
data4d = enc.detach().cpu().numpy().reshape(nx, ny, nz, n_features)
enc_image = nib.Nifti1Image(data4d.astype(np.float32), image.affine)
nib.save(enc_image, os.path.join(home, 'enc.nii.gz'))

# Network output, stored with a trailing singleton channel.
data3d = recon.detach().cpu().numpy().reshape(nx, ny, nz, 1)
recon_image = nib.Nifti1Image(data3d.astype(np.float32), image.affine)
nib.save(recon_image, os.path.join(home, 'recon.nii.gz'))


In [None]:
# Split the encoding features into three consecutive bands:
# [0, nf1), [nf1, nf2) and [nf2, n_features).
nf1 = n_features // 3
nf2 = (2 * n_features) // 3
print(nf1)
print(nf2)


def _reconstruct_band(first, last, filename):
  """Decode the volume from the encoding features [first, last) only.

  All other features are set to zero before the MLP is applied. The result is
  saved as `filename` in the home directory, using the affine of the input image.

  Returns:
    (masked encoding, MLP output tensor, MLP output as a (nx, ny, nz, 1) array).
  """
  with torch.no_grad():
    # zeros_like keeps the device and dtype of `enc`, so the MLP input lives
    # on the same device as the encoding instead of being created on the CPU.
    enc_band = torch.zeros_like(enc)
    enc_band[:, first:last] = enc[:, first:last]
    recon_band = net.mlp(enc_band)
  data_band = recon_band.detach().cpu().numpy().reshape((nx, ny, nz, 1))
  nib.save(nib.Nifti1Image(np.float32(data_band), image.affine), os.path.join(home, filename))
  return enc_band, recon_band, data_band


enc_low, recon_low, data_low = _reconstruct_band(0, nf1, 'recon_low.nii.gz')
enc_med, recon_med, data_med = _reconstruct_band(nf1, nf2, 'recon_med.nii.gz')
enc_high, recon_high, data_high = _reconstruct_band(nf2, n_features, 'recon_high.nii.gz')
