In [1]:
import nibabel as nib
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import argparse
import tinycudann as tcnn
import os

  warn(f"Failed to load image Python extension: {e}")


In [2]:
class HashMLP(pl.LightningModule):
  """Coordinate regressor made of a tiny-cuda-nn encoding followed by a tiny-cuda-nn MLP.

  The encoder and the MLP are kept as separate attributes (``encoding`` and
  ``mlp``) so each can be called on its own; ``model`` chains the two.
  """

  def __init__(self, config, dim_in=3, dim_out=1):
    """Build the encoder, the MLP and their sequential composition.

    Args:
      config: mapping with an ``'encoding'`` entry (forwarded to
        ``tcnn.Encoding``) and a ``'network'`` entry (forwarded to
        ``tcnn.Network``).
      dim_in: number of input coordinates per sample.
      dim_out: number of values predicted per sample.
    """
    super().__init__()
    self.dim_in = dim_in
    self.dim_out = dim_out

    self.encoding = tcnn.Encoding(
      n_input_dims=dim_in,
      encoding_config=config['encoding'],
    )
    # The MLP input width is whatever the encoder reports as its output width.
    self.mlp = tcnn.Network(
      n_input_dims=self.encoding.n_output_dims,
      n_output_dims=dim_out,
      network_config=config['network'],
    )
    self.model = torch.nn.Sequential(self.encoding, self.mlp)

  def forward(self, x):
    """Encode ``x`` and pass the result through the MLP."""
    return self.model(x)

  def configure_optimizers(self):
    """Return an Adam optimizer over all parameters with learning rate 5e-3."""
    return torch.optim.Adam(self.parameters(), lr=5e-3)

  def training_step(self, batch, batch_idx):
    """Compute, log (as ``train_loss``) and return the MSE between prediction and target.

    ``batch`` is unpacked as an ``(inputs, targets)`` pair.
    """
    coords, target = batch
    loss = F.mse_loss(self(coords), target)

    self.log("train_loss", loss)
    return loss

  def predict_step(self, batch, batch_idx):
    """Return the model output for the inputs of an ``(inputs, targets)`` batch."""
    coords, _ = batch
    return self(coords)

In [3]:
# Training hyper-parameters.
num_epochs = 25
batch_size = 4096*4
# os.cpu_count() can return None, which DataLoader does not accept for
# num_workers; fall back to loading in the main process in that case.
num_workers = os.cpu_count() or 0

# Read image
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
image_file = '/home/rousseau/Sync-Exp/Data/template_dHCP/fetal_brain_mri_atlas/structural/t1-t21.00.nii.gz'
image = nib.load(image_file)
data = image.get_fdata()

# Create grid: one coordinate axis in [-1, 1] per image axis, with as many
# steps as the image has samples along that axis.
dim = 3
x = torch.linspace(-1, 1, steps=data.shape[0])
y = torch.linspace(-1, 1, steps=data.shape[1])
z = torch.linspace(-1, 1, steps=data.shape[2])

# indexing='ij' is what the implicit default produced; passing it explicitly
# keeps the grid axes in image-axis order and avoids the torch.meshgrid warning.
mgrid = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1)

# Convert to X=(x,y,z) and Y=intensity. Both are flattened in the same
# row-major order, so row i of X is the coordinate of intensity i of Y.
X = mgrid.reshape(-1, dim)
Y = torch.tensor(data.flatten(), dtype=torch.float32)

# Normalize intensities between [-1,1]
intensity_min = torch.min(Y)
intensity_max = torch.max(Y)
if intensity_max == intensity_min:
  # A constant image would make the rescaling divide by zero and silently
  # produce NaN targets.
  raise ValueError("Cannot normalize a constant-intensity image: max == min")
Y = (Y - intensity_min) / (intensity_max - intensity_min) * 2 - 1
Y = torch.reshape(Y, (-1, 1))

# Pytorch dataloader
dataset = torch.utils.data.TensorDataset(X, Y)
loader = torch.utils.data.DataLoader(
  dataset,
  batch_size=batch_size,
  shuffle=True,
  num_workers=num_workers,
  pin_memory=True,
)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
# Training
# tiny-cuda-nn configuration; option reference:
# https://github.com/NVlabs/tiny-cuda-nn/blob/master/DOCUMENTATION.md
config = {
    # Multiresolution hash-grid input encoding.
    "encoding": {
        "otype": "HashGrid",
        "n_levels": 8,                # number of resolution levels
        "n_features_per_level": 2,    # features stored per level
        "log2_hashmap_size": 15,      # log2 of the hash table size
        "base_resolution": 16,        # resolution of the coarsest level
        "per_level_scale": 1.3819,    # growth factor between levels (alternative noted: 1.5)
    },
    # Fully fused MLP applied to the encoded coordinates.
    "network": {
        "otype": "FullyFusedMLP",
        "activation": "ReLU",
        "output_activation": "None",
        "n_neurons": 128,             # width of each hidden layer
        "n_hidden_layers": 2,
    },
}

In [5]:
# Build the model from the tiny-cuda-nn configuration and fit it on the
# coordinate/intensity loader for num_epochs epochs with 16-bit precision.
net = HashMLP(
  config=config,
  dim_in=3,
  dim_out=1,
)
trainer = pl.Trainer(
  max_epochs=num_epochs,
  precision=16,
)
trainer.fit(net, loader)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoding | Encoding   | 419 K 
1 | mlp      | Network    | 20.5 K
2 | model    | Sequential | 440 K 
----------------------------------------
440 K     Trainable params
0         Non-trainable params
440 K     Total params
1.761     Total estimated model params size (MB)
Traceback (most recent call

Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.


In [22]:
# Inspect the encoder alone: move the coordinates and the model to the GPU,
# encode every coordinate, and compare the resulting shapes.
print(X.shape)
X = X.to('cuda')
net = net.to('cuda')
enc = net.encoding(X)
print(enc.shape, mgrid.shape, sep="\n")

torch.Size([7160400, 3])
torch.Size([7160400, 16])
torch.Size([180, 221, 180, 3])
