In [1]:
import os
import nrrd 
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# datafolder = 'data'
# split = 'val'
# img_dir = f'{datafolder}/img/{split}'
# seg_dir = f'{datafolder}/seg/all/{split}'
# result_dir = f'runs/ddim-AVT-256-segguided-20250725-162416/checkpoint-epoch400/samples_many_1000'

# allfiles = os.listdir(img_dir)

# for filename in allfiles[0:2]:
#     data, header = nrrd.read(os.path.join(img_dir, filename))
#     seg, seg_header = nrrd.read(os.path.join(seg_dir, filename))

#     print(data.shape, seg.shape)

#     for z in range(data.shape[2]):
#         org_CT_slice = data[:, :, z]
#         org_seg_slice = seg[:, :, z]

#         stem = os.path.splitext(os.path.basename(filename))[0]
#         syn_CT_slice = Image.open(os.path.join(result_dir, f'condon_{stem}_axial_{z:04d}.png'))
#         syn_CT_slice = np.array(syn_CT_slice)

#         print(syn_CT_slice.shape, org_CT_slice.shape, org_seg_slice.shape)

#         # Plot original CT slice, synthetic CT slice, and segmentation overlaid on synthetic CT slice
#         plt.figure(figsize=(12, 4)) 
#         plt.subplot(1, 3, 1)
#         plt.imshow(org_CT_slice, cmap='gray')
#         plt.title('Original CT Slice')
#         plt.axis('off')

#         plt.subplot(1, 3, 2)
#         plt.imshow(syn_CT_slice[0], cmap='gray')
#         plt.title('Synthetic CT Slice')
#         plt.axis('off')

#         plt.subplot(1, 3, 3)
#         plt.imshow(syn_CT_slice[0], cmap='gray')
#         plt.imshow(org_seg_slice, alpha=0.5, cmap='jet')
#         plt.title('Segmentation Overlay on Synthetic CT Slice')
#         plt.axis('off')


In [3]:
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import nrrd
import matplotlib.pyplot as plt
from PIL import Image

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
    ScaleIntensityRanged, EnsureTyped
)

# Resampling target passed to Spacingd as `pixdim`, one entry per spatial axis
# (units follow the image header — typically mm; confirm for this dataset).
TARGET_SPACING = (0.8, 0.8, 3.0)
# Intensity range mapped onto [-1, 1] by ScaleIntensityRanged (presumably HU for CT).
WINDOW = (-1_000, 1_000)
# Side length, in pixels, of the square slices produced by pad_crop_2d.
OUT_SIZE = 256

def pad_crop_2d(t: torch.Tensor, size: int = OUT_SIZE) -> torch.Tensor:
    """Zero-pad, then centre-crop, a slice to ``size`` x ``size``.

    A 2D ``(H, W)`` input first gets a leading channel axis; a 3D input is
    treated as ``(C, H, W)``. Any spatial dimension smaller than ``size`` is
    padded with zeros on the bottom/right only. Dimensions larger than
    ``size`` are then cropped symmetrically around the centre.

    Returns:
        A ``(C, size, size)`` tensor.
    """
    chw = t if t.ndim != 2 else t.unsqueeze(0)
    _, height, width = chw.shape

    extra_rows = size - height if size > height else 0
    extra_cols = size - width if size > width else 0
    if extra_rows or extra_cols:
        # F.pad lists the last axis first: (W_left, W_right, H_top, H_bottom).
        chw = F.pad(chw, (0, extra_cols, 0, extra_rows))

    top = (chw.shape[1] - size) // 2
    left = (chw.shape[2] - size) // 2
    return chw[:, top:top + size, left:left + size]

def build_eval_xform(seg_key="seg_all", have_image=True):
    """Build the MONAI pipeline that brings a case into the training space.

    Args:
        seg_key: dictionary key under which the segmentation path is supplied.
        have_image: whether an ``"image"`` key is also supplied. When False only
            the segmentation is loaded and resampled.

    Returns:
        A ``Compose`` that loads the volumes, makes them channel-first,
        reorients to RAS and resamples to TARGET_SPACING. The image uses
        bilinear resampling and the segmentation nearest-neighbour. When an
        image is present, its intensities are windowed from WINDOW to [-1, 1]
        (clipped).
    """
    keys = (["image"] if have_image else []) + [seg_key]
    # One interpolation mode per key, in the same order as `keys`; labels use
    # nearest-neighbour so class ids are never blended.
    spacing_modes = ("bilinear", "nearest") if have_image else ("nearest",)

    steps = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        Orientationd(keys=keys, axcodes="RAS"),
        Spacingd(keys=keys, pixdim=TARGET_SPACING, mode=spacing_modes),
    ]
    if have_image:
        # Windowing applies to the image only. Adding it unconditionally broke the
        # segmentation-only pipeline: the dict has no "image" key and MONAI
        # dictionary transforms reject missing keys by default.
        steps.append(
            ScaleIntensityRanged(
                keys="image",
                a_min=WINDOW[0], a_max=WINDOW[1],
                b_min=-1.0, b_max=1.0, clip=True,
            )
        )
    steps.append(EnsureTyped(keys=keys))
    return Compose(steps)

def to_display(arr: np.ndarray):
    """Rescale an array into [0, 1] for display.

    A leading singleton channel axis is dropped first. ``uint8`` data is
    divided by 255; any other dtype is treated as lying in [-1, 1] and mapped
    linearly onto [0, 1].
    """
    squeezed = arr[0] if (arr.ndim == 3 and arr.shape[0] == 1) else arr
    if squeezed.dtype != np.uint8:
        return (squeezed + 1.0) / 2.0
    return squeezed / 255.0

def read_syn_png(path: str) -> np.ndarray:
    """Load a synthetic slice PNG as a single-channel array.

    Colour images are reduced to one channel:

    * RGB/RGBA: the three colour channels are averaged (alpha is ignored) and
      the result is rounded back to the file's dtype.
    * Grayscale+alpha (2 channels): the grayscale channel is kept.

    The stored dtype is preserved, so an 8-bit PNG always comes back as
    ``uint8`` and callers can use ``dtype == np.uint8`` to detect 0–255 data.

    Args:
        path: path to the PNG file.

    Returns:
        The pixel array with any colour/alpha channel axis removed.
    """
    # The context manager closes the file handle; np.array decodes the pixel
    # data before the file is closed.
    with Image.open(path) as img:
        x = np.array(img)
    if x.ndim == 3:
        if x.shape[-1] in (3, 4):
            # Average RGB and cast back so 8-bit input stays uint8 (previously
            # this produced float64 in 0–255 and skipped the caller's /255).
            x = np.rint(x[..., :3].mean(axis=-1)).astype(x.dtype)
        elif x.shape[-1] == 2:
            # 'LA' mode: channel 0 is luminance, channel 1 is alpha.
            x = x[..., 0]
    return x

def _plot_slice(stem, z, num, img_disp, syn_disp, seg_mask, seg_vmax, save_dir):
    """Draw one 3-panel figure (transformed CT, synthetic CT, overlay), then save or show it."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    fig.suptitle(f"Case: {stem}, Slice: {z+1}/{num}", fontsize=16)

    # 1) Transformed original CT slice (what the model saw)
    if img_disp is not None:
        axes[0].imshow(img_disp, cmap='gray', vmin=0, vmax=1)
    else:
        axes[0].text(0.5, 0.5, "No image", ha='center', va='center',
                     transform=axes[0].transAxes)
    axes[0].set_title('Transformed CT (train space)')

    # 2) Synthetic CT slice
    axes[1].imshow(syn_disp, cmap='gray')
    axes[1].set_title('Synthetic CT')

    # 3) Segmentation overlay on synthetic CT. Background (label 0) is masked
    # out so it does not tint the whole slice. vmin/vmax are fixed per case so
    # each label keeps the same colour on every slice.
    axes[2].imshow(syn_disp, cmap='gray')
    axes[2].imshow(np.ma.masked_where(seg_mask == 0, seg_mask),
                   alpha=0.35, cmap='jet', vmin=0, vmax=seg_vmax)
    axes[2].set_title('Seg overlay on Synthetic')

    for ax in axes:
        ax.axis('off')

    if save_dir:
        out_path = os.path.join(save_dir, f"vis_{stem}_axial{z:04d}.png")
        fig.savefig(out_path, bbox_inches='tight', dpi=150)
    else:
        plt.show()
    # Close explicitly; a case has many slices, and open figures would otherwise accumulate.
    plt.close(fig)


def plot_case(img_path: str, seg_path: str, result_dir: str, max_slices: int = None, save_dir: str = None):
    """Visualise one case slice by slice: transformed CT, synthetic CT, and seg overlay.

    The image (optional) and segmentation go through ``build_eval_xform`` so
    they are in the space the model was trained in. Each axial slice ``z`` is
    then paired with ``condon_<stem>_axial_<zzzz>.png`` from ``result_dir``.
    Slices with no synthetic PNG are skipped silently.

    Args:
        img_path: path to the CT volume, or None to load only the segmentation.
        seg_path: path to the segmentation volume.
        result_dir: directory containing the synthetic slice PNGs.
        max_slices: if given, only the first ``max_slices`` axial slices are considered.
        save_dir: if given, figures are written there as
            ``vis_<stem>_axial<zzzz>.png`` (directory created if needed)
            instead of being shown.
    """
    seg_key = "seg_all"
    have_image = img_path is not None
    xform = build_eval_xform(seg_key=seg_key, have_image=have_image)

    data_dict = {"image": img_path} if have_image else {}
    data_dict[seg_key] = seg_path

    data = xform(data_dict)
    stem = Path(img_path if have_image else seg_path).stem

    img = data["image"] if have_image else None
    seg = data[seg_key]

    # The axial slice count is the last axis of the image if present, else of the segmentation.
    depth = (img if have_image else seg).shape[-1]
    num = depth if max_slices is None else min(depth, max_slices)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # One colour range for the whole case; at least 1 so a label-free case still works.
    seg_vmax = max(float(seg.max()), 1.0)

    for z in range(num):
        syn_path = os.path.join(result_dir, f"condon_{stem}_axial_{z:04d}.png")
        if not os.path.exists(syn_path):
            continue

        syn = read_syn_png(syn_path)
        syn_disp = syn / 255.0 if syn.dtype == np.uint8 else syn

        img_disp = to_display(pad_crop_2d(img[..., z]).cpu().numpy()) if have_image else None

        seg_z = pad_crop_2d(seg[..., z]).cpu().numpy()
        seg_mask = seg_z[0] if seg_z.ndim == 3 else seg_z

        _plot_slice(stem, z, num, img_disp, syn_disp, seg_mask, seg_vmax, save_dir)

# ------------------------------------------------------
# Which training run / checkpoint to visualise. result_dir and save_dir are both
# derived from these names so the two paths cannot drift apart.
datafolder = 'data'
split = 'val'
run_name = 'ddim-AVT-256-segguided-20250725-162416'
checkpoint = 'checkpoint-epoch400'

img_dir = f'{datafolder}/img/{split}'
seg_dir = f'{datafolder}/seg/all/{split}'
result_dir = f'runs/{run_name}/{checkpoint}/samples_many_1000'
save_dir = f'results/visualizations/{run_name}-{checkpoint}'

# Only regular, non-hidden files are treated as cases. Entries like .DS_Store or
# sub-directories would presumably fail to load as volumes.
cases = sorted(
    name for name in os.listdir(img_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(img_dir, name))
)
for filename in cases:
    img_path = os.path.join(img_dir, filename)
    seg_path = os.path.join(seg_dir, filename)
    if not os.path.exists(seg_path):
        # Report and continue rather than abort the whole run on one missing file.
        print(f"Skipping {filename}: no segmentation at {seg_path}")
        continue
    plot_case(img_path, seg_path, result_dir, max_slices=None, save_dir=save_dir)
