In [None]:
import sys
from pathlib import Path

# Put the notebook directory's parent on sys.path so the project-local
# modules imported in the next cell (params, common, components, parts)
# can presumably be found there.
_project_root = Path().absolute().parent
sys.path.append(str(_project_root))
print(_project_root)


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

import torch.nn.functional as F
from torch import Tensor

import nibabel as nib
from pathlib import Path

from params import ModelConfig, GMM_ModelConfig , config
from nipype.interfaces.fsl import FAST

from collections.abc import Callable

import tempfile
import os

from common.utils import *
from components.datawrapper import norm, norm_mri, mask_soft_tissue_only, sigmoid_soft_clip, check_nan, is_all_zero

from skimage.metrics import structural_similarity as ssim_np
from skimage.metrics import peak_signal_noise_ratio as psnr_np

from parts import get_network, get_loss_func, safe_mean_loss
import h5py



In [None]:
def load_data(path: Path):
    """Load a NIfTI volume from ``path`` and return its voxel data.

    The returned value is whatever nibabel's ``get_fdata()`` produces for the
    image, i.e. a floating-point NumPy array.
    """
    return nib.load(str(path)).get_fdata()

def to_tensor(x):
    """Convert an image array into a float32 tensor with two leading singleton dims.

    Args:
        x: NumPy array (or array-like). In this notebook it is a 2-D slice
           (H, W); other ranks also work and get the same two leading dims.

    Returns:
        ``torch.Tensor`` of dtype float32 with shape ``(1, 1, *x.shape)``,
        i.e. ``[1, 1, H, W]`` for a 2-D slice.
    """
    arr = np.asarray(x)
    # torch.from_numpy rejects arrays with negative strides (e.g. views
    # returned by np.rot90 / np.flip without .copy()); copy only in that case.
    if any(stride < 0 for stride in arr.strides):
        arr = arr.copy()
    return torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]

def get_slice_arrays(axes, slice_idx, mri, ct, mask=None):
    """Extract the same 2-D slice from the MRI, CT and optional mask volumes.

    The index ``slice_idx`` is fixed on one dimension of each volume and the
    resulting 2-D plane is rotated with ``np.rot90`` (k=3 for "z", k=1 for
    "y" and "x"). Every returned slice is a fresh contiguous copy.

    Args:
        axes: "x" (fix the first index), "y" (second) or "z" (third).
        slice_idx: index along that dimension.
        mri, ct: volumes indexed with three subscripts; presumably the same
            shape (not checked here).
        mask: optional volume sliced the same way; ``None`` is passed through.

    Returns:
        ``(mri_slice, ct_slice, mask_slice)``; ``mask_slice`` is ``None`` when
        ``mask`` is ``None``.

    Raises:
        ValueError: if ``axes`` is not one of "x", "y", "z".
    """
    # axis name -> (dimension that receives slice_idx, np.rot90 k)
    layout = {"x": (0, 1), "y": (1, 1), "z": (2, 3)}
    if axes not in layout:
        raise ValueError(f"Unknown axis {axes!r}; expected 'x', 'y' or 'z'")
    fixed_dim, k = layout[axes]

    index = [slice(None)] * 3
    index[fixed_dim] = slice_idx
    index = tuple(index)

    def _take(volume):
        # rot90 returns a (negative-stride) view; .copy() makes it contiguous.
        return np.rot90(volume[index], k=k).copy()

    return _take(mri), _take(ct), (_take(mask) if mask is not None else None)


In [None]:
def center_crop_or_pad_3d(x: np.ndarray, target_shape) -> np.ndarray:
    """
    Zero-pad and/or center-crop a 3-D array so its shape becomes target_shape.

    Each axis is handled independently. An axis shorter than its target is
    zero-padded, with the extra voxel of an odd pad placed at the end. Every
    axis is then cropped to a centered window of the target length; when the
    excess is odd, the extra voxel is dropped from the end.

    Args:
        x: array with exactly three dimensions (D, H, W).
        target_shape: three target lengths (tD, tH, tW).

    Returns:
        Array of shape ``target_shape``.

    Raises:
        ValueError: if ``x`` or ``target_shape`` does not have exactly three
            entries (from tuple unpacking).
    """
    (d, h, w), (td, th, tw) = x.shape, target_shape
    targets = (td, th, tw)

    # Split each shortfall into (before, after); the floor half goes before.
    shortfalls = (max(0, td - d), max(0, th - h), max(0, tw - w))
    pad_width = tuple((gap // 2, gap - gap // 2) for gap in shortfalls)
    padded = np.pad(x, pad_width, mode="constant", constant_values=0)

    # Centered window of the requested length on every axis.
    window = []
    for have, want in zip(padded.shape, targets):
        start = (have - want) // 2
        window.append(slice(start, start + want))
    return padded[tuple(window)]

In [None]:
# First set of evaluation settings. Several of these names (ckpt, log_v,
# run_idx, idx, best) are reassigned by a later configuration cell before the
# checkpoint is loaded.
ckpt, log_v, run_idx, idx, best = 2, config.run_dir, 5, 102, False
config.model_type = "diffusion"

# Project root: the parent of the notebook's working directory.
ROOT_DIR = Path().absolute().parent



In [None]:
def check_nan(arr, name=""):
    """Report whether ``arr`` contains any NaN.

    Note: this definition shadows the ``check_nan`` imported from
    ``components.datawrapper`` in the imports cell.

    Args:
        arr: NumPy array (must support ``np.isnan`` and have ``.shape``).
        name: label used in the printed warning.

    Returns:
        -1 if a NaN was found (after printing a warning), otherwise 0.
    """
    found_nan = np.isnan(arr).any()
    if not found_nan:
        return 0
    print(f"⚠️ NaN detected in {name} (shape={arr.shape})")
    return -1

In [None]:
DATA_DIR = ROOT_DIR / "sample_data"
IMG_CROP = (192, 192, 192)  # common (x, y, z) size every volume is cropped/padded to

# Sample-subject volumes. The brain-mask filename is kept verbatim; the doubled
# "_mask" suffix presumably comes from the masking tool's naming - verify.
mri_path, ct_path, mask_path, brain_mask_path = (
    DATA_DIR / fname
    for fname in ("mr.nii.gz", "ct.nii.gz", "mask.nii.gz", "brain_mask_mask.nii.gz")
)

# Loaded in order: MRI, CT, mask, brain mask.
mri_np, ct_np, mask_np, brain_mask_np = (
    load_data(p) for p in (mri_path, ct_path, mask_path, brain_mask_path)
)

# Bring every volume to IMG_CROP. load_data always returns an array here,
# so the None guard on the mask is purely defensive.
mri_np_crop, ct_np_crop = (center_crop_or_pad_3d(v, IMG_CROP) for v in (mri_np, ct_np))
mask_np_crop = center_crop_or_pad_3d(mask_np, IMG_CROP) if mask_np is not None else None
brain_mask_np_crop = center_crop_or_pad_3d(brain_mask_np, IMG_CROP)

In [None]:
# Use the cropped brain mask as the working mask. The cropped mask.nii.gz
# volume computed above is discarded; the slice-extraction cell reads
# mask_np_crop for masking, normalisation and empty-slice filtering.
mask_np_crop = brain_mask_np_crop

In [None]:
# Sanity check: the cropped CT shape should match IMG_CROP.
print(f"{ct_np_crop.shape}")



In [None]:
x_dim, y_dim, z_dim = ct_np_crop.shape
axes = "x"  # slicing axis: "x", "y" or "z" (same convention as get_slice_arrays)

ct_clip = (-1000, 1000)  # (min_hu, max_hu) passed to mask_soft_tissue_only
scale = 0.05             # `scale` argument of sigmoid_soft_clip
n = 1

# NOTE(review): each "subject" iteration below reprocesses the same cropped
# volume, so n > 1 only duplicates the entries of valid_slices.
num_subj = n

# Slice count along the chosen axis; unknown axis names are rejected here.
num_slices_by_axis = {"x": x_dim, "y": y_dim, "z": z_dim}
if axes not in num_slices_by_axis:
    raise ValueError(f"Unknown axis {axes!r}")
num_slices = num_slices_by_axis[axes]

valid_slices = []

for subj_idx in range(num_subj):
    mri = mri_np_crop[:, :, :]
    ct = ct_np_crop[:, :, :]
    mask = mask_np_crop[:, :, :] if mask_np_crop is not None else None

    if mask is not None:
        # Binarise the mask and zero the MRI outside it.
        mask = (mask > 0).astype(np.float32)
        mri = mri * mask

    # CT: presumably restrict to the [min_hu, max_hu] window, normalise,
    # soft-clip with a sigmoid, then normalise again.
    ct_hard = mask_soft_tissue_only(ct, mask, min_hu=ct_clip[0], max_hu=ct_clip[1])
    ct_hard = norm(ct_hard, mask)
    ct_soft = sigmoid_soft_clip(ct_hard, scale=scale)

    mri = norm_mri(mri, mask)
    ct = norm(ct_soft, mask)

    for slice_idx in range(num_slices):
        mri_slice, ct_slice, mask_slice = get_slice_arrays(axes, slice_idx, mri, ct, mask)

        # Skip slices with NaNs. The notebook's check_nan prints a warning and
        # returns -1; short-circuiting means at most one warning per slice.
        has_nan = (
            check_nan(mri_slice, "MRI slice") < 0
            or check_nan(ct_slice, "CT slice") < 0
            or (mask_slice is not None and check_nan(mask_slice, "Mask slice") < 0)
        )
        if has_nan:
            continue

        # Skip slices where any of the images is entirely zero.
        if (
            is_all_zero(mri_slice)
            or is_all_zero(ct_slice)
            or (mask_slice is not None and is_all_zero(mask_slice))
        ):
            continue

        valid_slices.append((mri_slice, ct_slice, mask_slice))

data_len = len(valid_slices)
print(f"Number of valid slices: {data_len}")


In [None]:
idx = 62  # position in valid_slices (not a raw volume slice index)

mri_img, ct_img, mask_img = valid_slices[idx]

# [1, 1, H, W] float tensors (via to_tensor) used by the model and metric cells.
mri_slice = to_tensor(mri_img)
ct_slice = to_tensor(ct_img)
mask_slice = to_tensor(mask_img) if mask_img is not None else None

# Side-by-side view of the selected slice. The mask panel is omitted when
# there is no mask, because imshow cannot draw None.
panels = [
    (f"CT, valid slice {idx}", ct_img),
    (f"MRI, valid slice {idx}", mri_img),
    ("MASK", mask_img),
]
panels = [(title, img) for title, img in panels if img is not None]

# `panel_axes` rather than `axes`, so the slicing-axis setting is not overwritten.
fig, panel_axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
fig.patch.set_facecolor('black')

for ax, (title, img) in zip(panel_axes, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title, color='white', fontsize=14)
    ax.axis('off')
    ax.set_facecolor('black')

plt.subplots_adjust(wspace=0.05, hspace=0, left=0.01, right=0.99, top=0.92, bottom=0.08)
plt.show()

In [None]:
# Evaluation settings for the checkpoint-loading cells below. They override
# the values set in the first configuration cell.
ckpt = 1
config.log_lv = "INFO"
run_idx = 0
idx = 102

# Runs directory containing "<run_idx:05d>_train/checkpoints/...". To evaluate
# the archived baseline run instead, set:
#   log_v = ROOT_DIR / 'log/log_2025_07_31_clip_otflow_baseline_192_192'
log_v = config.run_dir
config.model_type = "diffusion"

best = True  # load from checkpoints/best/ instead of checkpoints/


In [None]:
# Checkpoint files live in <log_v>/<run_idx:05d>_train/checkpoints/, or in
# its best/ subfolder when `best` is truthy.
ckpt_dir = Path(f"{log_v}/{run_idx:05d}_train/checkpoints")
ckpt_path = (ckpt_dir / "best" if best else ckpt_dir) / f"checkpoint_{ckpt}.ckpt"

print("Loading checkpoint...")
# weights_only=True restricts unpickling to tensors and plain containers.
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)

In [None]:
# Rebuild the model configs saved with the weights. When an entry is missing
# or empty, the config class is constructed with its own defaults.
model_config_dict = checkpoint.get("model_config", {})
gmm_model_config_dict = checkpoint.get("gmm_model_config", {})

modelconfig = ModelConfig(**(model_config_dict or {}))
gmm_modelconfig = GMM_ModelConfig(**(gmm_model_config_dict or {}))


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network described by the checkpoint configs and move it to `device`.
diff = get_network(
    device=device,
    model_type=config.model_type,
    modelconfig=modelconfig,
    gmm_modelconfig=gmm_modelconfig,
).to(device)
print("Model instantiated and moved to device.")

# Accept either a full training checkpoint or a bare state dict.
state_dict = checkpoint.get("model_state_dict", checkpoint)

print("Loading model state dict...")
# With strict=True, load_state_dict raises on any key mismatch, so both lists
# printed below are empty whenever this call returns.
missing, unexpected = diff.load_state_dict(state_dict, strict=True)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)
print("Model loaded successfully.")

In [None]:
# Recomputed so this cell also works when run on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate an image from the CT slice (mode="recon"); the next cell compares
# it with the MRI slice. NOTE(review): `interval` is presumably the sampler's
# step setting - confirm in the model code.
diff.eval()
# Inference only. Without no_grad the sampler records an autograd graph,
# which wastes memory and leaves `output` requiring grad.
with torch.no_grad():
    output = diff(
        cond=ct_slice.to(device),
        interval=5,
        mode="recon",
    )

In [None]:
vmin, vmax = 0.0, 1.0  # display window; assumes intensities normalised to [0, 1] - TODO confirm

mri_target = mri_slice.to(device)
# mask_slice is None when the selected slice came without a mask.
mask_target = mask_slice.to(device) if mask_slice is not None else None

# --- Visual comparison: input CT, generated image, reference MRI -------------
# detach() keeps .numpy() working even if `output` still requires grad
# (it does when the model was run outside torch.no_grad()).
out_img = output[0].squeeze().cpu().detach().numpy()

titles = ["CT", "Gen", "MRI"]
images = [
    ct_slice[0].squeeze().cpu().detach().numpy(),
    out_img,
    mri_slice[0].squeeze().cpu().detach().numpy(),
]

# `panel_axes` rather than `axes`, so the slicing-axis setting is not overwritten.
fig, panel_axes = plt.subplots(1, 3, figsize=(12, 5))
fig.patch.set_facecolor("black")

for ax, img, title in zip(panel_axes, images, titles):
    ax.imshow(img, cmap="gray", vmin=vmin, vmax=vmax)
    ax.set_title(title, color='white', fontsize=16)
    ax.axis("off")
    ax.set_facecolor("black")

plt.subplots_adjust(wspace=0.02, hspace=0, left=0.01, right=0.99, top=0.92, bottom=0.08)
plt.show()

# --- Metrics ---------------------------------------------------------------
# Unmasked pixel-wise MSE over the whole slice.
loss = F.mse_loss(output, mri_target).item()
print(f"Reconstruction loss: {loss:.6f}")

# Configured loss (config.loss_model), reduced over dims (1, 2, 3) and
# reshaped to (B, 1, 1, 1) if the reduction dropped those dims.
loss_func = get_loss_func(config.loss_model)
loss_mse = loss_func(output, mri_target)
loss_mse = safe_mean_loss(loss_mse, dims=(1, 2, 3), keepdim=True)
if loss_mse.dim() == 0:
    loss_mse = loss_mse.view(1, 1, 1, 1)
elif loss_mse.dim() == 1:
    loss_mse = loss_mse.view(-1, 1, 1, 1)

# NOTE(review): when there is no mask, None is forwarded - confirm that
# calculate_psnr / calculate_ssim accept a None mask.
psnr = calculate_psnr(output, mri_target, mask_target)
ssim = calculate_ssim(output, mri_target, mask_target)

# Collapse each metric to a scalar mean for reporting.
loss_mse = np.mean(loss_mse.cpu().detach().numpy())
psnr = np.mean(psnr.cpu().detach().numpy())
ssim = np.mean(ssim.cpu().detach().numpy())

print(f"loss: {loss_mse:.6f}")
print(f"psnr: {psnr:.2f}dB")
print(f"ssim: {ssim:.4f}")

