In [1]:
import sys
from pathlib import Path

sys.path.append("..")

import multiprocessing as mp
import os

import datasets
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import torch

# from datasets import combined_brain_1mm, combined_brain_17mm
from generative.metrics import MultiScaleSSIMMetric
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler
from models.autoencoderkl import AutoencoderKLDownsampleControl
from monai.config import print_config
from monai.utils import set_determinism
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

In [5]:
def get_cond_diffusion_model(diffusion_root_path, ckpt):
    """Build the conditional diffusion UNet and its DDPM scheduler from a run directory.

    The directory must contain ``config.yaml`` whose ``ldm`` section provides
    ``params`` (kwargs for ``DiffusionModelUNet``) and ``scheduler`` (kwargs for
    ``DDPMScheduler``); missing sections fall back to empty kwargs.

    Args:
        diffusion_root_path: Run directory (str or Path).
        ckpt: Checkpoint file name inside that directory, e.g. "checkpoint100.pth".
            The checkpoint dict must hold the model weights under "diffusion".

    Returns:
        Tuple ``(diffusion_model, scheduler, checkpoint_name)``: the UNet wrapped in
        ``torch.nn.DataParallel`` with weights loaded (still on CPU), the scheduler,
        and the checkpoint's file name.
    """
    diffusion_root_path = Path(diffusion_root_path)
    diffusion_config_path = diffusion_root_path / "config.yaml"
    diffusion_ckpt_path = diffusion_root_path / ckpt
    print(f"Diffusion path {diffusion_ckpt_path}")

    config = OmegaConf.load(diffusion_config_path)
    ldm_config = config["ldm"]
    diffusion_model = DiffusionModelUNet(**ldm_config.get("params", dict()))
    scheduler = DDPMScheduler(**ldm_config.get("scheduler", dict()))

    # The freshly built model lives on CPU, so load the weights there too. This avoids
    # a transient extra copy on the GPU the checkpoint was saved from, and it lets the
    # checkpoint open on machines without that device. The final weights are unchanged.
    diffusion_ckpt = torch.load(diffusion_ckpt_path, map_location="cpu")

    # Weights are loaded into the DataParallel wrapper, so the checkpoint's keys must
    # carry the wrapper's "module." prefix.
    diffusion_model = torch.nn.DataParallel(diffusion_model)
    diffusion_model.load_state_dict(diffusion_ckpt["diffusion"])

    # Path.name is portable (the old str.split("/") breaks on Windows separators).
    checkpoint_name = diffusion_ckpt_path.name

    return diffusion_model, scheduler, checkpoint_name

In [None]:
# Sample every 64^3 patch of one test subject with the conditional diffusion model.
device = torch.device("cuda")
root_path = "/home/sz9jt/data/generative_brain/cond_diffusion/diffusion_tomshcp3d_cond"
diffusion_model, scheduler, ckptname = get_cond_diffusion_model(root_path, "checkpoint100.pth")
print(f"using {ckptname}")
diffusion_model.eval()
diffusion_model.to(device)

config = OmegaConf.load(root_path + "/config.yaml")
ds_params = config.dataset.params
ds_params["type"] = "test"
ds_params["patch_size"] = (224, 288, 224)
test_ds = datasets.hcp_3d.TomsHCP3DCondPatchDataset(**ds_params)
print("If using patch: ", test_ds.patch)

data = test_ds[-3]
gt_img = data["gt_image"].squeeze()
cond_img = data["cond_image"].squeeze()

PATCH_EDGE = 64  # edge length of the cubic patches fed to the diffusion model

# Pad every spatial axis up to the next multiple of PATCH_EDGE (true ceil, so an axis
# that is already a multiple gets no extra slab). The pad value is -1.
data_shape = list(gt_img.shape)
factors = [(int(size) + PATCH_EDGE - 1) // PATCH_EDGE * PATCH_EDGE for size in data_shape]
print(factors)

valid_region = tuple(slice(0, size) for size in data_shape)

gt_padded = np.full(factors, -1.0)
gt_padded[valid_region] = gt_img
gt_img = gt_padded

cond_padded = np.full(factors, -1.0)
cond_padded[valid_region] = cond_img
cond_img = cond_padded

# Cut the padded volumes into non-overlapping cubes, visiting origins x -> y -> z.
# The stitching step must use the same order.
patch_origins = [
    (px, py, pz)
    for px in range(0, factors[0], PATCH_EDGE)
    for py in range(0, factors[1], PATCH_EDGE)
    for pz in range(0, factors[2], PATCH_EDGE)
]
all_gt_patches = np.array(
    [gt_img[px : px + PATCH_EDGE, py : py + PATCH_EDGE, pz : pz + PATCH_EDGE] for px, py, pz in patch_origins]
).reshape(-1, 1, PATCH_EDGE, PATCH_EDGE, PATCH_EDGE)
all_cond_patches = np.array(
    [cond_img[px : px + PATCH_EDGE, py : py + PATCH_EDGE, pz : pz + PATCH_EDGE] for px, py, pz in patch_origins]
).reshape(-1, 1, PATCH_EDGE, PATCH_EDGE, PATCH_EDGE)

print(all_cond_patches.shape)

torch.manual_seed(config.args.seed)

# Patches [0, START_IDX) are expected to exist already in "tmp.npy". The
# reassembly step concatenates tmp.npy + res.npy. This run resumes at START_IDX and
# writes only the remaining patches to "res.npy". Set START_IDX = 0 to sample all.
START_IDX = 4
batch_size = 1
all_res_patches = []
for start in range(START_IDX, len(all_cond_patches), batch_size):
    stop = min(start + batch_size, len(all_cond_patches))

    cond = torch.from_numpy(all_cond_patches[start:stop]).float().to(device)
    # Drawn on the CPU generator and then moved, so the seeded noise stream is kept.
    sample = torch.randn(cond.shape).to(device)
    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, ncols=70):
            # The model is conditioned by stacking the condition patch as an extra channel.
            model_input = torch.cat([sample, cond], dim=1).float()
            noise_pred = diffusion_model(x=model_input, timesteps=torch.asarray((t,)).to(device), context=None)
            sample, _ = scheduler.step(noise_pred, t, sample)
    all_res_patches.append(sample.cpu().numpy())

    # Each patch takes minutes, so save after every one; an interruption loses at most one.
    res_array = np.array(all_res_patches)
    np.save("res.npy", res_array)
    print(res_array.shape)

Diffusion path /home/sz9jt/data/generative_brain/cond_diffusion/diffusion_tomshcp3d_cond/checkpoint100.pth
using checkpoint100.pth
If using patch:  False
[256, 320, 256]
(80, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.87it/s]


(1, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.87it/s]


(2, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.87it/s]


(3, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.87it/s]


(4, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.87it/s]


(5, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.87it/s]


(6, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(7, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(8, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(9, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.88it/s]


(10, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(11, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.88it/s]


(12, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(13, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(14, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.89it/s]


(15, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(16, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(17, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(18, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.88it/s]


(19, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.88it/s]


(20, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.88it/s]


(21, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(22, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(23, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:50<00:00,  5.88it/s]


(24, 1, 1, 64, 64, 64)


100%|█████████████████████████████| 1000/1000 [02:49<00:00,  5.88it/s]


(25, 1, 1, 64, 64, 64)


 56%|████████████████▉             | 565/1000 [01:36<01:13,  5.88it/s]

In [6]:
# Rebuild the model and the padded test volumes (same subject as the sampling run)
# without re-running the sampler.
device = torch.device("cuda")
root_path = "/home/sz9jt/data/generative_brain/cond_diffusion/diffusion_tomshcp3d_cond"
diffusion_model, scheduler, ckptname = get_cond_diffusion_model(root_path, "checkpoint100.pth")
print(f"using {ckptname}")
diffusion_model.eval()
diffusion_model.to(device)

config = OmegaConf.load(root_path + "/config.yaml")
ds_params = config.dataset.params
ds_params["type"] = "test"
ds_params["patch_size"] = (224, 288, 224)
test_ds = datasets.hcp_3d.TomsHCP3DCondPatchDataset(**ds_params)
print("If using patch: ", test_ds.patch)

data = test_ds[-3]
gt_img = data["gt_image"].squeeze()
cond_img = data["cond_image"].squeeze()

PATCH_EDGE = 64  # edge length of the cubic patches fed to the diffusion model

# Pad every spatial axis up to the next multiple of PATCH_EDGE (true ceil, so an axis
# that is already a multiple gets no extra slab). The pad value is -1.
data_shape = list(gt_img.shape)
factors = [(int(size) + PATCH_EDGE - 1) // PATCH_EDGE * PATCH_EDGE for size in data_shape]
print(factors)

valid_region = tuple(slice(0, size) for size in data_shape)

gt_padded = np.full(factors, -1.0)
gt_padded[valid_region] = gt_img
gt_img = gt_padded

cond_padded = np.full(factors, -1.0)
cond_padded[valid_region] = cond_img
cond_img = cond_padded

Diffusion path /home/sz9jt/data/generative_brain/cond_diffusion/diffusion_tomshcp3d_cond/checkpoint100.pth
using checkpoint100.pth
All data: 1113
Dataset size: 223
If using patch:  False
[256, 320, 256]


In [8]:
# Combine the patches sampled in two runs: "tmp.npy" holds the first ones and
# "res.npy" the rest (saved with an extra singleton axis per entry).
tmp1 = np.load("tmp.npy")
tmp2 = np.load("res.npy")

print(tmp1.shape, tmp2.shape)

# Flatten all leading axes so each entry is one (D, H, W) patch. Unlike squeeze(),
# this keeps the patch axis when a file holds a single patch, so concatenation
# still works in that case.
tmp = np.concatenate(
    [
        tmp1.reshape((-1,) + tmp1.shape[-3:]),
        tmp2.reshape((-1,) + tmp2.shape[-3:]),
    ]
)
print(tmp.shape)

(4, 1, 64, 64, 64) (76, 1, 1, 64, 64, 64)
(80, 64, 64, 64)


In [10]:
print(gt_img.shape)

# Stitch the sampled patches back into a volume the size of the padded image.
# Origins are visited x -> y -> z, the same order used when the patches were cut.
output_img = np.full(gt_img.shape, -1.0)

idx = 0
for x in range(0, factors[0], 64):
    for y in range(0, factors[1], 64):
        for z in range(0, factors[2], 64):
            output_img[x : x + 64, y : y + 64, z : z + 64] = tmp[idx]
            idx += 1

# Every sampled patch must be placed exactly once. A mismatch means the saved
# patch files are stale or incomplete for this volume.
assert idx == len(tmp), f"placed {idx} patches but {len(tmp)} were loaded"



(256, 320, 256)


In [11]:
# Reload the un-padded volumes; the padded copies were only needed for patching.
gt_img = data["gt_image"].squeeze()
print(gt_img.shape)
cond_img = data["cond_image"].squeeze()

# Crop away the -1 padding so the output aligns voxel-for-voxel with gt/cond.
# (Previously the crop was assigned to a misspelled name, so the padded
# volume was saved.)
output_img_cropped = output_img[: gt_img.shape[0], : gt_img.shape[1], : gt_img.shape[2]]

# np.eye(4) affine: world coordinates equal voxel indices.
# Use a dedicated loop variable so the stitched-patch array `tmp` is not clobbered.
for filename, volume in (
    ("gt_img.nii.gz", gt_img),
    ("cond_img.nii.gz", cond_img),
    ("output_img.nii.gz", output_img_cropped),
):
    nib.save(nib.Nifti1Image(volume, np.eye(4)), filename)

(224, 288, 224)
