In [1]:
import os

import torch
from hydra.utils import instantiate
from monai.apps import download_url
from neuro_utils.visualize import plot_scans
from torch import nn

from diffusion_3d.chestct.autoencoder.vae.maisi.config import get_config
from diffusion_3d.datasets.ct_rate import CTRATEDataModule

In [2]:
trained_autoencoder_path = r"/raid3/arjun/checkpoints/maisi/autoencoder_epoch273.pt"
trained_autoencoder_path_url = "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt"
if not os.path.exists(trained_autoencoder_path):
    download_url(url=trained_autoencoder_path_url, filepath=trained_autoencoder_path)

In [3]:
state_dict = torch.load(trained_autoencoder_path, map_location='cpu')

  state_dict = torch.load(trained_autoencoder_path, map_location='cpu')


In [4]:
model_config = {
    "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 1,
    "latent_channels": 4,
    "num_channels": [64, 128, 256],
    "num_res_blocks": [2, 2, 2],
    "norm_num_groups": 32,
    "norm_eps": 1e-06,
    "attention_levels": [False, False, False],
    "with_encoder_nonlocal_attn": False,
    "with_decoder_nonlocal_attn": False,
    "use_checkpointing": False,
    "use_convtranspose": False,
    "norm_float16": False,
    "num_splits": 1,
    "dim_split": 2,
    'save_mem': False,
}

In [5]:
device = torch.device('cuda:0')
# Don't use cpu, server crashes because of mem and compute requirements

In [7]:
model: nn.Module = instantiate(model_config)
model.load_state_dict(state_dict)
model.to(device)
model.eval()


[1;35mAutoencoderKlMaisi[0m[1m([0m
  [1m([0mencoder[1m)[0m: [1;35mMaisiEncoder[0m[1m([0m
    [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m[1m)[0m: [1;35mMaisiConvolution[0m[1m([0m
        [1m([0mconv[1m)[0m: [1;35mConvolution[0m[1m([0m
          [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m64[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
        [1m)[0m
      [1m)[0m
      [1m([0m[1;36m1[0m-[1;36m2[0m[1m)[0m: [1;36m2[0m x [1;35mMaisiResBlock[0m[1m([0m
        [1m([0mnorm1[1m)[0m: [1;35mMaisiGroupNorm3D[0m[1m([0m[1;36m32[0m, [1;36m64[0m, [33meps[0m=[1;36m1e[0m[1;36m-06[0m, [33maffine[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mconv1[1m)[0m: [1;35mMaisiConvol

In [30]:
data_config = get_config((24, 512, 512)).data

datamodule = CTRATEDataModule(data_config)
dataloader = datamodule.val_dataloader()
len(dataloader)

valid:   0%|          | 0/500 [00:00<?, ?it/s]

No. of valid datapoints: 500


[1;36m500[0m

In [24]:
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        x = batch['image']
        x = x[:1]
        print(x.shape, x.min(), x.max())
        torch.cuda.reset_peak_memory_stats()
        # y = model.encode(x.to(device))[0].cpu()
        y = model(x.to(device))[0].cpu()
        mem = torch.cuda.max_memory_allocated()
        print(y.shape, y.min(), y.max())
        print(f"Memory: {mem / 2 ** 30:.2f} GB")

        for i in range(x.shape[0]):
            plot_scans([x[i, 0], y[i, 0]], ["Original", "Reconstructed"])
        
        break

torch.Size([1, 1, 12, 512, 512]) metatensor(0.) metatensor(1.)
torch.Size([1, 1, 12, 512, 512]) tensor(-0.0341) tensor(1.1892)
Memory: 5.35 GB


interactive(children=(IntSlider(value=0, description='z', max=11), Output()), _dom_classes=('widget-interact',…

In [37]:
from monai.inferers.inferer import SlidingWindowInferer

inferer = SlidingWindowInferer((4, 512, 512), overlap=0.25)

In [38]:
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        x = batch['image']
        x = x[:1]
        print(x.shape, x.min(), x.max())
        torch.cuda.reset_peak_memory_stats()
        # y = model.encode(x.to(device))[0].cpu()
        y = inferer(x.to(device), model)[0].cpu()
        mem = torch.cuda.max_memory_allocated()
        print(y.shape, y.min(), y.max())
        print(f"Memory: {mem / 2 ** 30:.2f} GB")

        for i in range(x.shape[0]):
            plot_scans([x[i, 0], y[i, 0]], ["Original", "Reconstructed"])
        
        break