# One-Sided Unpaired Medical Image Translation with Normalized Edge Priors

In [None]:
import nibabel as nib
import os, glob
import time
import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch

from tqdm import tqdm
from monai.config import print_config
from monai.data import DataLoader, Dataset, CacheDataset, SmartCacheDataset, decollate_batch
from monai.inferers import SliceInferer , sliding_window_inference
from monai.metrics.regression import compute_ssim_and_cs
from monai.utils import set_determinism, first
from monai.transforms import(
    Compose,
    LoadImaged,
    SaveImage,
    EnsureChannelFirstd,
    SqueezeDimd, 
    EnsureTyped,
    ResampleToMatchd,
    RandSpatialCropSamplesd, 
    ScaleIntensityRangePercentilesd, 
    ScaleIntensityRanged, 
    Resized,
    CropForegroundd, 
    CenterSpatialCropd, 
    RandZoomd
)
# print_config()
from datetime import date

# Today's date as a compact YYYYMMDD string.
today = date.today().isoformat().replace('-', '')
# All GPU work in this notebook targets the first CUDA device.
gpu_device = torch.device('cuda:0')

Get the original 3D MR/CT volumes (and CT body masks), define the train/val/test splits, slice them, and write the slices into 2D folders.

In [None]:
suffix2d = '_2d_transformed'
import glob

# Top-level folders holding the 3D MR volumes, CT volumes and masks.
MRs, CTs, Masks = 'MR', 'CT', 'MASK'

# Patient identifiers are the file names of the MR volumes matching "Gamma*".
PIDs_ALL = [path.rsplit('/', 1)[-1] for path in glob.glob(os.path.join(MRs, 'Gamma*'))]

# 70% of the patients go to training; the val/test boundaries are derived later.
N = round(0.7 * len(PIDs_ALL))

# Seed the global NumPy RNG so the in-place shuffle below is repeatable.
# NOTE(review): glob.glob returns entries in filesystem order, which is not
# guaranteed to be identical across machines, so this split is only
# reproducible on the same filesystem. Sorting PIDs_ALL before shuffling would
# fix that, but would also change the existing split (and invalidate
# checkpoints trained on it), so it is intentionally left unchanged here.
np.random.seed(29100)
np.random.shuffle(PIDs_ALL)

1) Train

In [None]:
PIDs_train = PIDs_ALL[:N]

# Per-patient file lists: MR source, CT target and CT body-extremities mask.
fnames_train_A_3d = [os.path.join(MRs, pid) for pid in PIDs_train]
fnames_train_B_3d = [os.path.join(CTs, pid) for pid in PIDs_train]
fnames_train_C_3d = [
    os.path.join(f"{CTs}_Seg_Mask", f"{pid.split('.')[0]}_mask", 'body_extremities.nii.gz')
    for pid in PIDs_train
]

# One dictionary per patient, consumed by the MONAI dictionary transforms.
train_dic_3d = [
    dict(SRC=mr_path, TGT=ct_path, MASK=mask_path)
    for mr_path, ct_path, mask_path in zip(fnames_train_A_3d, fnames_train_B_3d, fnames_train_C_3d)
]

2) Val

In [None]:
# The next 20% of the shuffled patients form the validation split; N is then
# advanced so that it marks the start of the test split.
PIDs_val = PIDs_ALL[N:N + round(len(PIDs_ALL) * 0.2)]
N += round(len(PIDs_ALL) * 0.2)

# Per-patient file lists: MR source, CT target and CT body-extremities mask.
fnames_val_A_3d = [os.path.join(MRs, pid) for pid in PIDs_val]
fnames_val_B_3d = [os.path.join(CTs, pid) for pid in PIDs_val]
fnames_val_C_3d = [
    os.path.join(f"{CTs}_Seg_Mask", f"{pid.split('.')[0]}_mask", 'body_extremities.nii.gz')
    for pid in PIDs_val
]

# One dictionary per patient, consumed by the MONAI dictionary transforms.
val_dic_3d = [
    dict(SRC=mr_path, TGT=ct_path, MASK=mask_path)
    for mr_path, ct_path, mask_path in zip(fnames_val_A_3d, fnames_val_B_3d, fnames_val_C_3d)
]

In [None]:
# Every patient after the train and validation splits is held out for testing.
# Taking the remainder (instead of another round(0.1 * len) slice) guarantees
# that no patient is silently dropped when the three independently rounded
# split sizes do not add up to len(PIDs_ALL) (e.g. 12 patients -> 8 / 2 / 1
# would leave the last patient unused).
PIDs_test = PIDs_ALL[N:]

In [None]:
# Display the held-out test patients in alphabetical order.
test_pids_display = sorted(PIDs_test)
test_pids_display

# 3D to 2D pre-processing 

#### Conversion 3D to 2D

##### Custom random spatial-crop transform that rejects slices that are too "black" (too few pixels above an intensity threshold).

In [None]:
class RandSpatialCropSamplesdWithMinNonZero(RandSpatialCropSamplesd):
    """Random spatial crop sampler that discards near-empty crops.

    Draws crops exactly like ``RandSpatialCropSamplesd`` and then keeps only
    those whose ``target`` image has at least ``min_nonzero`` voxels strictly
    above ``threshold``.
    """

    def __init__(self, target, min_nonzero: int = 1000, *args, threshold: float = 0.4, **kwargs):
        """
        Initialize the custom transform.
        :param target: Key whose cropped image is tested for content; must be one
            of the transform's ``keys``.
        :param min_nonzero: Minimum number of voxels above ``threshold`` required
            for a crop to be kept.
        :param threshold: Intensity above which a voxel counts as "non-empty"
            (keyword-only; 0.4 is the value that used to be hard-coded).
        :param args: Arguments for RandSpatialCropSamplesd.
        :param kwargs: Keyword arguments for RandSpatialCropSamplesd.
        :raises ValueError: If ``target`` is not one of ``keys``.
        """
        super().__init__(*args, **kwargs)
        self.min_nonzero = min_nonzero
        self.threshold = threshold
        # A falsy or unknown target used to surface only later (AttributeError,
        # or "No valid samples" on every call); reject it up front instead.
        if target not in self.keys:
            raise ValueError(f"target {target!r} must be one of keys {self.keys}")
        self.target = target

    def __call__(self, data, lazy=False):
        """
        Generate samples and keep only those with enough content in ``target``.
        :param data: Input data dictionary.
        :param lazy: Accepted for signature compatibility; it is not forwarded,
            so the parent always runs with its own lazy setting.
        :return: List of crop dictionaries meeting the content requirement.
        :raises ValueError: If none of the drawn crops qualifies.
        """
        samples = super().__call__(data)
        valid_samples = [sample for sample in samples if self._has_enough_content(sample[self.target])]

        # Never return an empty list: downstream code indexes the first sample.
        if not valid_samples:
            raise ValueError("No valid samples found. Consider adjusting the min_nonzero parameter.")

        return valid_samples

    def _has_enough_content(self, img):
        """Return True if ``img`` has at least ``min_nonzero`` voxels above ``threshold``."""
        # ``(img > t).sum()`` works for torch tensors (incl. MetaTensor) and NumPy arrays alike.
        return int((img > self.threshold).sum()) >= self.min_nonzero

Apply transforms: perform the 3D-to-2D slicing and intensity transforms, then write the slices to disk.

In [None]:
NUM_SAMPLES_MAX = 30
PAIR_KEYS = ["SRC", "TGT", "MASK"]  # MR source, CT target, CT body mask

# Training pipeline: per 3D case, emit up to NUM_SAMPLES_MAX random one-voxel
# slices (along the last axis) of a 256x256 in-plane centre crop, skipping
# near-empty MR slices, with MR and CT both mapped to [-1, 1].
train_transforms = Compose([
    LoadImaged(keys=PAIR_KEYS, image_only=False),
    EnsureChannelFirstd(keys=PAIR_KEYS),
    # MR: 1st-99.9th percentile normalisation to [0, 1], clipped.
    ScaleIntensityRangePercentilesd(keys=["SRC"], lower=1, upper=99.9, b_min=0, b_max=1, clip=True),
    CenterSpatialCropd(keys=PAIR_KEYS, roi_size=(256, 256, -1)),
    # The content test is applied to the MR while it is still in [0, 1].
    RandSpatialCropSamplesdWithMinNonZero(
        keys=PAIR_KEYS,
        target="SRC",
        roi_size=(-1, -1, 1),
        random_size=False,
        num_samples=NUM_SAMPLES_MAX,
    ),
    # MR: [0, 1] -> [-1, 1].
    ScaleIntensityRanged(keys=["SRC"], a_min=0, a_max=1, b_min=-1, b_max=1),
    # CT: HU window [-1000, 3000] -> [-1, 1] (clip is not enabled).
    ScaleIntensityRanged(keys=["TGT"], a_min=-1000, a_max=3000, b_min=-1, b_max=1),
    # Drop the leading channel axis; slices are later indexed as [:, :, 0].
    SqueezeDimd(keys=PAIR_KEYS, dim=0),
])

# Validation/test pipeline: same intensity mapping on the whole volume, then
# only the centre slice along the last axis is kept (no 256x256 in-plane crop).
valtest_transforms = Compose([
    LoadImaged(keys=PAIR_KEYS, image_only=False),
    EnsureChannelFirstd(keys=PAIR_KEYS),
    ScaleIntensityRangePercentilesd(keys=["SRC"], lower=1, upper=99.9, b_min=0, b_max=1, clip=True),
    ScaleIntensityRanged(keys=["SRC"], a_min=0, a_max=1, b_min=-1, b_max=1),
    ScaleIntensityRanged(keys=["TGT"], a_min=-1000, a_max=3000, b_min=-1, b_max=1),
    CenterSpatialCropd(keys=PAIR_KEYS, roi_size=(-1, -1, 1)),
    SqueezeDimd(keys=PAIR_KEYS, dim=0),
])


In [None]:
def spatial_dim_fixer_3d(target_nii, tgt_affine, src_affine):
    """Flip ``target_nii`` along axes 0/1 where two reference vectors disagree.

    Callers in this notebook pass the translation columns (``affine[:3, 3]``)
    of the target and source volumes. Axis 0 is flipped when the first
    components differ and axis 1 when the second components differ; a
    difference is treated as a sign that the volumes are mirrored on that axis.
    Components are compared with exact ``!=``.

    :param target_nii: array to (possibly) flip.
    :param tgt_affine: reference vector of the target volume.
    :param src_affine: reference vector of the source volume.
    :return: ``target_nii`` itself when nothing differs, else a flipped view.
    """
    axes_to_flip = tuple(axis for axis in (0, 1) if tgt_affine[axis] != src_affine[axis])
    if not axes_to_flip:
        return target_nii
    return np.flip(target_nii, axis=axes_to_flip)

1) Train

In [None]:
# Set to False to skip the 3D -> 2D slice export and reuse slices already on disk.
PROCESS_DATA = True

In [None]:
def _save_slice_in_mr_space(slice_np, out_path, ref_nii, mr_nii):
    """Save a 2D slice, aligned to the MR reference volume when possible.

    If ``ref_nii`` (the volume the slice came from) has the same voxel spacing
    as ``mr_nii``, the slice is flipped with ``spatial_dim_fixer_3d`` based on
    the two volumes' translation columns and saved with the MR affine.
    Otherwise, or if that save fails, a warning is emitted and the slice is
    saved unmodified, without an affine, using ``ref_nii``'s header. (Before,
    differing spacings meant the slice was silently never written at all.)
    """
    reason = 'voxel spacing differs from the MR reference'
    if ref_nii.header.get_zooms() == mr_nii.header.get_zooms():
        try:
            aligned = spatial_dim_fixer_3d(slice_np, ref_nii.affine[:3, 3], mr_nii.affine[:3, 3])
            nib.save(nib.Nifti1Image(aligned, mr_nii.affine, ref_nii.header), out_path)
            return
        except Exception as err:  # best effort: fall back to an unaligned save below
            reason = f'aligned save failed: {err!r}'
    warnings.warn(f'The Niftis files don"t contain same real-world coordinates (mm) ({reason})')
    nib.save(nib.Nifti1Image(slice_np, None, ref_nii.header), out_path)


if PROCESS_DATA:
    train_mr_outdir = os.path.join(MRs + suffix2d, 'train_mr')
    os.makedirs(train_mr_outdir, exist_ok=True)

    train_ct_outdir = os.path.join(CTs + suffix2d, 'train_ct')
    os.makedirs(train_ct_outdir, exist_ok=True)

    train_mask_outdir = os.path.join(Masks + suffix2d, 'train_mask')
    os.makedirs(train_mask_outdir, exist_ok=True)

    for case in tqdm(train_dic_3d, 'Train'):
        # A list of per-slice sample dicts; they share the same source files.
        transformed_image = train_transforms(case)

        train_mr_fname_ref = transformed_image[0]['SRC_meta_dict']['filename_or_obj']
        train_mr_nii_ref = nib.load(train_mr_fname_ref)

        train_ct_fname_ref = transformed_image[0]['TGT_meta_dict']['filename_or_obj']
        train_ct_nii_ref = nib.load(train_ct_fname_ref)

        train_mask_fname_ref = transformed_image[0]['MASK_meta_dict']['filename_or_obj']
        train_mask_nii_ref = nib.load(train_mask_fname_ref)

        mr_stem = os.path.basename(train_mr_fname_ref).split('.')[0]
        ct_stem = os.path.basename(train_ct_fname_ref).split('.')[0]
        # Masks all share one file name, so their parent folder ("<PID>_mask") names the output.
        mask_stem = os.path.basename(os.path.dirname(train_mask_fname_ref))

        for j, sample in enumerate(transformed_image):
            suffix = '_sample%.3d.nii.gz' % j

            # MR slices are written as-is, without an affine.
            nib.save(
                nib.Nifti1Image(sample['SRC'][:, :, 0].numpy(), None, train_mr_nii_ref.header),
                os.path.join(train_mr_outdir, mr_stem + suffix),
            )
            _save_slice_in_mr_space(
                sample['TGT'][:, :, 0].numpy(),
                os.path.join(train_ct_outdir, ct_stem + suffix),
                train_ct_nii_ref,
                train_mr_nii_ref,
            )
            _save_slice_in_mr_space(
                sample['MASK'][:, :, 0].numpy(),
                os.path.join(train_mask_outdir, mask_stem + suffix),
                train_mask_nii_ref,
                train_mr_nii_ref,
            )
            


2) Val

In [None]:
def _save_slice_in_mr_space(slice_np, out_path, ref_nii, mr_nii):
    """Save a 2D slice, aligned to the MR reference volume when possible.

    If ``ref_nii`` (the volume the slice came from) has the same voxel spacing
    as ``mr_nii``, the slice is flipped with ``spatial_dim_fixer_3d`` based on
    the two volumes' translation columns and saved with the MR affine.
    Otherwise, or if that save fails, a warning is emitted and the slice is
    saved unmodified, without an affine, using ``ref_nii``'s header. (Before,
    differing spacings meant the slice was silently never written at all.)
    """
    reason = 'voxel spacing differs from the MR reference'
    if ref_nii.header.get_zooms() == mr_nii.header.get_zooms():
        try:
            aligned = spatial_dim_fixer_3d(slice_np, ref_nii.affine[:3, 3], mr_nii.affine[:3, 3])
            nib.save(nib.Nifti1Image(aligned, mr_nii.affine, ref_nii.header), out_path)
            return
        except Exception as err:  # best effort: fall back to an unaligned save below
            reason = f'aligned save failed: {err!r}'
    warnings.warn(f'The Niftis files don"t contain same real-world coordinates (mm) ({reason})')
    nib.save(nib.Nifti1Image(slice_np, None, ref_nii.header), out_path)


if PROCESS_DATA:
    val_mr_outdir = os.path.join(MRs + suffix2d, 'val_mr')
    os.makedirs(val_mr_outdir, exist_ok=True)

    val_ct_outdir = os.path.join(CTs + suffix2d, 'val_ct')
    os.makedirs(val_ct_outdir, exist_ok=True)

    val_mask_outdir = os.path.join(Masks + suffix2d, 'val_mask')
    os.makedirs(val_mask_outdir, exist_ok=True)

    for case in tqdm(val_dic_3d, 'Val'):
        # A single dict holding the centre slice of each modality.
        transformed_image = valtest_transforms(case)

        mr_fname = transformed_image['SRC_meta_dict']['filename_or_obj']
        ct_fname = transformed_image['TGT_meta_dict']['filename_or_obj']
        mask_fname = transformed_image['MASK_meta_dict']['filename_or_obj']
        mr_ref = nib.load(mr_fname)
        ct_ref = nib.load(ct_fname)
        mask_ref = nib.load(mask_fname)

        # MR centre slice is written as-is, without an affine.
        nib.save(
            nib.Nifti1Image(transformed_image['SRC'][:, :, 0].numpy(), None, mr_ref.header),
            os.path.join(val_mr_outdir, os.path.basename(mr_fname).split('.')[0] + '_sCenter.nii.gz'),
        )
        _save_slice_in_mr_space(
            transformed_image['TGT'][:, :, 0].numpy(),
            os.path.join(val_ct_outdir, os.path.basename(ct_fname).split('.')[0] + '_sCenter.nii.gz'),
            ct_ref,
            mr_ref,
        )
        # Name masks after their parent folder ("<PID>_mask"), as in the train
        # export; split('/')[1] broke for absolute paths.
        _save_slice_in_mr_space(
            transformed_image['MASK'][:, :, 0].numpy(),
            os.path.join(val_mask_outdir, os.path.basename(os.path.dirname(mask_fname)) + '_sCenter.nii.gz'),
            mask_ref,
            mr_ref,
        )

# 2D processing 

In [None]:
# The exported 2D slices are already normalised, so loading only needs to add
# the channel axis.
SLICE_KEYS = ["SRC", "TGT", "MASK"]
transforms_2d = Compose([
    LoadImaged(keys=SLICE_KEYS, image_only=False),
    EnsureChannelFirstd(keys=SLICE_KEYS),
])

BATCH_SIZE = 1
NUM_WORKERS = 4

# Root folders of the exported 2D slices (e.g. "MR_2d_transformed").
MRs_sufx = MRs + suffix2d
CTs_sufx = CTs + suffix2d
Masks_sufix = Masks + suffix2d

- Train

In [None]:
fnames_train_A_2d = sorted(glob.glob(os.path.join(MRs_sufx, 'train_mr', '*.nii.gz')))
fnames_train_B_2d = sorted(glob.glob(os.path.join(CTs_sufx, 'train_ct', '*.nii.gz')))
# The export cell writes mask slices to "<Masks>_2d_transformed/train_mask".
# This glob used to look in ".../train_ct" under the mask root, which nothing
# writes to, so the mask list was empty and zip() silently produced an empty
# training set.
fnames_train_C_2d = sorted(glob.glob(os.path.join(Masks_sufix, 'train_mask', '*.nii.gz')))

# Pairing relies on the three sorted lists lining up one-to-one; refuse to
# build a silently truncated or misaligned dataset.
_train_counts = (len(fnames_train_A_2d), len(fnames_train_B_2d), len(fnames_train_C_2d))
if len(set(_train_counts)) != 1:
    raise ValueError(
        f"2D training slices are misaligned: {_train_counts[0]} MR, "
        f"{_train_counts[1]} CT, {_train_counts[2]} mask files."
    )

train_dic_2d = [{"SRC": img1, "TGT": img2, "MASK": img3} for (img1, img2, img3) in zip(
    fnames_train_A_2d,
    fnames_train_B_2d,
    fnames_train_C_2d,
)]

# NOTE(review): SmartCacheDataset exposes only `cache_num` items (50% here) and
# replaces them only when `start()` / `update_cache()` are called; confirm the
# training loop (not in this section) does so.
train_ds = SmartCacheDataset(train_dic_2d, transforms_2d, replace_rate=0.2, cache_rate=0.5)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

- Val

In [None]:
fnames_val_A_2d = sorted(glob.glob(os.path.join(MRs_sufx, 'val_mr', '*.nii.gz')))
fnames_val_B_2d = sorted(glob.glob(os.path.join(CTs_sufx, 'val_ct', '*.nii.gz')))
# Val-specific name: this list used to be assigned to fnames_train_C_2d,
# overwriting the training mask list.
fnames_val_C_2d = sorted(glob.glob(os.path.join(Masks_sufix, 'val_mask', '*.nii.gz')))

# Pairing relies on the three sorted lists lining up one-to-one.
_val_counts = (len(fnames_val_A_2d), len(fnames_val_B_2d), len(fnames_val_C_2d))
if len(set(_val_counts)) != 1:
    raise ValueError(
        f"2D validation slices are misaligned: {_val_counts[0]} MR, "
        f"{_val_counts[1]} CT, {_val_counts[2]} mask files."
    )

val_dic_2d = [{"SRC": img1, "TGT": img2, "MASK": img3} for (img1, img2, img3) in zip(
    fnames_val_A_2d,
    fnames_val_B_2d,
    fnames_val_C_2d,
)]

# Validation must see every slice, deterministically. SmartCacheDataset only
# exposes `cache_num` items (50% here) and never rotates them unless start()
# is called, so a CacheDataset is used instead: it still caches 50% of the
# slices and loads the rest on demand.
val_ds = CacheDataset(val_dic_2d, transforms_2d, cache_rate=0.5)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
check_data = first(val_loader)
print("first patch's shape: ", check_data["SRC"].shape)

# Preview one validation slice: CT target and MR source on the [-1, 1] scale,
# the mask on [0, 1].
plt.figure(figsize=(15, 15))
for panel, (key, title, vmin) in enumerate(
    [("TGT", 'Target', -1), ("SRC", 'Source', -1), ("MASK", 'mask', 0)], start=1
):
    plt.subplot(1, 4, panel)
    plt.imshow(check_data[key][0, 0, :, :].detach().cpu().numpy().squeeze(), vmin=vmin, vmax=1, cmap="gray")
    plt.title(title)

# Network training

In [None]:
from vjnetworks import Pix2Pix

In [None]:
class Options:
    """Plain container for the Pix2Pix hyper-parameters used in this notebook."""

    def __init__(self):
        # image channels of the generator input (MR) and output (CT)
        self.in_channels = 1
        self.out_channels = 1
        # discriminator width and depth (depth sets the receptive field)
        self.num_filters_d = 128
        self.num_layers_d = 4
        # passed to Pix2Pix as num_d (presumably the number of discriminators)
        self.num_d = 2
        # passed to Pix2Pix as num_res_units_G (generator residual units)
        self.num_res_units_G = 10
        # NOTE(review): not passed to Pix2Pix here; presumably the adversarial
        # loss weight (the experiment names contain "GAN1"). The old comment
        # called it a cycle-consistency weight, which Pix2Pix does not have.
        self.lambda_gan = 1


opt = Options()

In [None]:
import torch.optim
import torch.nn.functional as F
import monai.networks.nets as nets

# Build Pix2Pix from `opt`; each keyword has the same name as the Options attribute.
_pix2pix_kwargs = {
    name: getattr(opt, name)
    for name in ("in_channels", "out_channels", "num_d", "num_layers_d", "num_filters_d", "num_res_units_G")
}
pix2pix_model = Pix2Pix(**_pix2pix_kwargs)

# Inference

In [None]:
# PIDs_test = [file.split('_')[0] + '.nii.gz' for file in os.listdir(os.path.join('MR'+suffix2d,"test_mr"))]
# Pair each test patient's MR volume (model input) with its CT volume.
fnames_test_A = sorted(os.path.join(MRs, pid) for pid in PIDs_test)
fnames_test_B = sorted(os.path.join(CTs, pid) for pid in PIDs_test)

test_dic = [dict(SRC=src_path, TGT=tgt_path) for src_path, tgt_path in zip(fnames_test_A, fnames_test_B)]
print(test_dic[0])

In [None]:
from monai.transforms import Compose, Transform, MapTransform
from monai.transforms import ScaleIntensityRangePercentiles
NUM_WORKERS = 1
BATCH_SIZE = 1

# Inference pre-processing: only the MR is loaded. Percentile scaling straight
# to [-1, 1] with clipping gives the same mapping as the training pipeline's
# two steps (percentile -> [0, 1] clipped, then linear [0, 1] -> [-1, 1]).
test_transforms = Compose([
    LoadImaged(keys=["SRC"], image_only=False),
    EnsureChannelFirstd(keys=["SRC"]),
    ScaleIntensityRangePercentilesd(keys=["SRC"], lower=1, upper=99.9, b_min=-1, b_max=1, clip=True),
    EnsureTyped(keys=["SRC"], dtype=torch.float32),
])

from monai.transforms import Invertd

# Invert the invertible pre-processing on the prediction, using the
# transform trace recorded on "SRC".
test_post_transforms = Compose([
    Invertd(keys='PRED', transform=test_transforms, orig_keys="SRC"),
])

In [None]:
# Whole-volume test set in file order (shuffle=False), so with BATCH_SIZE == 1
# batch i corresponds to test_dic[i].
test_ds = Dataset(data=test_dic, transform=test_transforms)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# In-plane patch size (pixels) for the slice inferer, and the checkpoint
# epoch(s) to run inference for.
ROI_SIZE = 256
EPOCHS = [116]
EPOCHS

In [None]:
EXPERIMENT_PREFIX = "Gamma_crop-n2-l4-f128_GAN1_L140.00_NGF20.00_a0.08"
for EPOCH in EPOCHS:
    weights_dir = 'Weights/' + EXPERIMENT_PREFIX
    # Load the checkpoint that belongs to *this* EPOCH. The cell used to load a
    # hard-coded "..._e0116.h5" from the working directory, so every entry of
    # EPOCHS produced predictions from the same weights.
    ckpt_name = EXPERIMENT_PREFIX + '_e%.4d.h5' % EPOCH
    weights_path = os.path.join(weights_dir, ckpt_name)
    if not os.path.exists(weights_path):
        # Fall back to the working directory, where the checkpoint used to be read from.
        weights_path = ckpt_name
    checkpoint = torch.load(weights_path)
    # strict=False tolerates key mismatches; at least report parameters that
    # were not found instead of silently keeping their initial values.
    load_result = pix2pix_model.load_state_dict(checkpoint['model'], strict=False)
    missing_keys = getattr(load_result, 'missing_keys', None)
    if missing_keys:
        warnings.warn(f'{len(missing_keys)} parameter(s) not found in {weights_path}, e.g. {missing_keys[0]}')
    model = pix2pix_model.generator_A_to_B
    # The inputs are moved to gpu_device below and the slice inferer (by
    # default) runs the network where its inputs live, so put the generator
    # there too (no-op if it already is).
    model.to(gpu_device)
    model.eval()

    outdir = 'fakeCT_' + EXPERIMENT_PREFIX + '_e%.4d' % EPOCH
    os.makedirs(outdir, exist_ok=True)

    roi_size = (ROI_SIZE, ROI_SIZE)
    sw_batch_size = 2
    slice_inferer = SliceInferer(
        roi_size=roi_size,
        sw_batch_size=sw_batch_size,
        overlap=0.75,
        mode='gaussian',
        sigma_scale=0.5,  # use if mode="gaussian"
        spatial_dim=2,  # slice along the last spatial axis
        device=gpu_device,
        padding_mode="replicate",
    )
    with torch.no_grad():
        for i, test_data in enumerate(test_loader):
            # BATCH_SIZE is 1 and the loader does not shuffle, so batch i is test_dic[i].
            fname = test_dic[i]['SRC']
            nii = nib.load(fname)
            img = nii.get_fdata()
            # Voxels that are exactly 0 in the original MR are forced to -1000 HU below.
            mask = img == 0

            test_images = test_data["SRC"].to(gpu_device)

            # 256 x 256 slice-wise inference
            test_data["PRED"] = slice_inferer(test_images, model)
            OUT = [test_post_transforms(item) for item in decollate_batch(test_data)][0]
            output = OUT['PRED'].detach().cpu().numpy().squeeze(axis=0)
            # Undo the CT normalisation: [-1, 1] -> [-1000, 3000] HU.
            output = ((output + 1) / 2) * 4000 - 1000
            output[mask] = -1000

            # # 512
            # # FAKE
            # test_data["PRED"] = sliding_window_inference(
            #             roi_size=roi_size,
            #             inputs=test_images,
            #             sw_batch_size=sw_batch_size,
            #             predictor=model,
            #             overlap=0.75,
            #             mode='gaussian',
            #             sigma_scale=0.5,
            #             device=gpu_device,
            #             padding_mode="replicate",
            #         )
            # output = test_data['PRED'].detach().cpu().numpy()[0, 0]
            # output = ((output + 1) / 2) * 4000 - 1000
            # output[mask] = -1000

            fname_out = os.path.join(outdir, fname.split('/')[-2].split('_')[-1] + '_' + fname.split('/')[-1].split('.')[0] + '_pTGT.nii.gz')
            nii_out = nib.Nifti1Image(output.astype(np.int16), nii.affine, nii.header)
            nib.save(nii_out, fname_out)
            print(fname_out)

# metrics

## ssim

In [None]:
import torch
import torch.nn.functional as F
def dilate2d_volume(M, pixels=2):
    # M: [B,1,D,H,W], dilate each axial slice with max-pool
    k = 2*pixels+1
    return F.max_pool3d(M, kernel_size=(1,k,k), stride=1, padding=(0,pixels,pixels))

def _gauss1d(win, sigma, device):
    ax = torch.arange(win, device=device, dtype=torch.float32) - (win - 1) / 2
    g = torch.exp(-(ax**2) / (2 * sigma**2))
    return g / g.sum()

def _gauss3d_kernel(kw, kh, kd, sx, sy, sz, device):
    gx = _gauss1d(kw, sx, device)
    gy = _gauss1d(kh, sy, device)
    gz = _gauss1d(kd, sz, device)
    K = gz[:, None, None] * gy[None, :, None] * gx[None, None, :]
    K = K / K.sum()
    return K.view(1, 1, kd, kh, kw)

def masked_ssim_3d(pred, gt, mask, data_range=2.0,
                   win_size=(7,7,7), sigma=(1.5,1.5,1.5), eps=1e-8):
    """
    pred, gt, mask: torch.Tensor [1,1,D,H,W] or [B,1,D,H,W], float32 in [-1,1], mask in {0,1}
    Returns: scalar masked SSIM (float)
    """
    assert pred.shape == gt.shape == mask.shape
    if pred.ndim == 4:  # [1, D,H,W] -> [1,1,D,H,W]
        pred = pred.unsqueeze(0)
        gt   = gt.unsqueeze(0)
        mask = mask.unsqueeze(0)
    assert pred.ndim == 5 and pred.shape[1] == 1

    B, _, D, H, W = pred.shape
    kd, kh, kw = win_size
    sz, sy, sx = sigma
    # ensure odd window and not exceeding dims
    kd = max(3, min(kd, D));  kd = kd if kd % 2 == 1 else kd-1
    kh = max(3, min(kh, H));  kh = kh if kh % 2 == 1 else kh-1
    kw = max(3, min(kw, W));  kw = kw if kw % 2 == 1 else kw-1

    K = _gauss3d_kernel(kw, kh, kd, sx, sy, sz, pred.device)
    pad = (kw//2, kh//2, kd//2)

    # windowed mask support
    Wm = F.conv3d(mask, K, padding=pad) + eps  # [B,1,D,H,W]

    # masked local means
    mu_x = F.conv3d(pred*mask, K, padding=pad) / Wm
    mu_y = F.conv3d(gt  *mask, K, padding=pad) / Wm

    mu_x2, mu_y2 = mu_x**2, mu_y**2
    sigma_x2 = F.conv3d((pred**2)*mask, K, padding=pad) / Wm - mu_x2
    sigma_y2 = F.conv3d((gt**2)  *mask, K, padding=pad) / Wm - mu_y2
    sigma_xy = F.conv3d((pred*gt)*mask, K, padding=pad) / Wm - mu_x*mu_y

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    num = (2*mu_x*mu_y + C1) * (2*sigma_xy + C2)
    den = (mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2)
    ssim_map = num / (den + eps)  # [B,1,D,H,W]

    # weighted pooling by mask support (Wm), ignoring places with ~no mask in the window
    ssim_perB = (ssim_map*Wm).flatten(1).sum(1) / Wm.flatten(1).sum(1).clamp_min(1.0)  # [B]
    return float(ssim_perB.mean())

In [None]:
PRED_LOCATION = "fakeCT_Gamma_nocrop-n2-l4-f128_GAN1_L140.00_NGF0.00_a0.08_V2_e0067"
# PRED_LOCATION = "fakeCT_" + EXPERIMENT_PREFIX + "_e0041"

test_transforms = Compose(
    [
        LoadImaged(keys=["PRED", "CT", "MASK"], image_only=False),
        EnsureChannelFirstd(keys=["PRED", "CT", "MASK"]),
        # HU window [-1000, 3000] -> [-1, 1] for prediction and reference CT.
        ScaleIntensityRanged(keys=["PRED", "CT"], a_min=-1000, a_max=3000, b_min=-1, b_max=1),
        # Resample onto the prediction grid: bilinear for the CT, nearest for
        # the mask so it keeps its label values (bilinear, the default,
        # produced fractional mask values at the edges).
        ResampleToMatchd(keys=["CT", "MASK"], key_dst="PRED", mode=("bilinear", "nearest")),
        EnsureTyped(keys=["PRED", "CT", "MASK"]),
    ]
)

# PIDs_test = [file.split('_')[0] + '.nii.gz' for file in os.listdir(os.path.join('CT'+suffix2d,"test_ct"))]
fnames_test_A = sorted(glob.glob(os.path.join(PRED_LOCATION, "*.gz")))
fnames_test_B = sorted([os.path.join(CTs, PID) for PID in PIDs_test])
# First file matching "*ext*" (body-extremities mask) in each patient's mask folder.
fnames_test_D = sorted([glob.glob(os.path.join("CT_Seg_Mask", PID.split('.')[0] + "_mask", "*ext*"))[0] for PID in PIDs_test])


test_dic = [{"PRED": img1, "CT": img2, "MASK": img4} for (img1, img2, img4) in zip(
    fnames_test_A,
    fnames_test_B,
    fnames_test_D,
)]


# test_transforms = Compose(
#     [
#         LoadImaged(keys=["PRED", "CT"], image_only=False),
#         EnsureChannelFirstd(keys=["PRED", "CT"]),
#         ScaleIntensityRanged(keys=["PRED", "CT"], a_min=-1000, a_max=3000, b_min=-1, b_max=1),
#         ResampleToMatchd(keys=["CT"], key_dst="PRED"),
#         EnsureTyped(keys=["PRED","CT"]),
#     ]
# )

# fnames_test_A = sorted(glob.glob(os.path.join(PRED_LOCATION,"*.gz")))
# fnames_test_B = sorted([os.path.join(CTs,PID) for PID in PIDs_test])

# test_dic = [{"PRED": img1, "CT": img2} for (img1,img2) in zip(
#     fnames_test_A,
#     fnames_test_B,
# )]

In [None]:

# NOTE(review): [8:] skips the first 8 test cases — confirm this is intended.
test_ds = Dataset(test_dic[8:], test_transforms)

# pred_dir = os.path.join("Preds",PRED_LOCATION)
# os.makedirs(pred_dir,exist_ok=True)

# mean_ssim = []
# gmean_ssim = []
for sample in test_ds:
    fprefix = sample['CT_meta_dict']['filename_or_obj'].split("/")[-1].split('.')[0]

    # fake = sample["PRED"].to(gpu_device)
    # center = fake.shape[-1] // 2
    # ct_ = sample["CT"].to(gpu_device)
    # mr_ = sample["MR"].to(gpu_device)

    # plt.figure(figsize=(15,5))
    # plt.subplot(1,3,1)
    # plt.title("MR")
    # plt.imshow(mr_[0,:,:,center].cpu().detach().numpy().squeeze(), vmin=-1, vmax=1, cmap='gray')
    # plt.subplot(1,3,2)
    # plt.title("CT")
    # plt.imshow(ct_[0,:,:,center].cpu().detach().numpy().squeeze(), vmin=-1, vmax=1, cmap='gray')
    # plt.subplot(1,3,3)
    # plt.title("Fake")
    # plt.imshow(fake[0,:,:,center].cpu().detach().numpy().squeeze(), vmin=-1, vmax=1, cmap='gray')
    # fname_out = os.path.join(pred_dir, fprefix)
    # plt.savefig(fname_out, bbox_inches='tight')
    # plt.close()

    # compute ssim on [1, 1, H, W, D] tensors
    pred = sample["PRED"].to(gpu_device).unsqueeze(0)
    tgt = sample["CT"].to(gpu_device).unsqueeze(0)

    # # global
    # ssim_score,cs_score = compute_ssim_and_cs(pred, tgt,3,(11,11,11),(1.5,1.5,1.5),data_range=2)
    # gmean_ssim.append(ssim_score.mean().item())
    # print(fprefix, "global", round(ssim_score.mean().item(),6))

    # local: SSIM inside a slightly dilated body mask.
    # The mask is channel-first [C, H, W, D] with the slices along the last axis
    # (the same axis the 2D export and SliceInferer cut along), whereas
    # dilate2d_volume dilates over the last two axes of a [C, D, H, W] tensor.
    # Passing [C, H, W, D] directly dilated across W and the slice axis, so
    # move D to the front first and restore [C, H, W, D] afterwards.
    mask_ = sample["MASK"].to(gpu_device)
    mask_ = dilate2d_volume(mask_.permute(0, 3, 1, 2), pixels=2).permute(0, 2, 3, 1)
    mask_ = mask_.unsqueeze(0)
    ssim_roi = masked_ssim_3d(pred, tgt, mask_, data_range=2.0, win_size=(11, 11, 11), sigma=(1.5, 1.5, 1.5))
    # mean_ssim.append(ssim_roi)
    print(fprefix, "local", round(ssim_roi, 6))


# print(f"Mean lSSIM : {torch.tensor(mean_ssim).mean():.6f}")
# print(f"Mean gSSIM : {torch.tensor(gmean_ssim).mean():.6f}")

In [None]:
### global ssim:
# NGF 20 :: Mean SSIM : 0.726693
# NGF 0 :: Mean SSIM : 0.639660 v2
# best :: Mean SSIM : 0.560659

### local ssim:
# NGF 20 :: Mean SSIM : 0.513641
# NGF 0 :: Mean SSIM : 0.320117 v2
# best :: Mean SSIM : 0.505250


###
# EXP1 : 


## PSNR

In [None]:
# EXPERIMENT_PREFIX="Gamma_Seg_nocrop-n2-l4-f128_GAN1_L140.00_NGF20.00_a0.08_fresh"

PSNR_EPOCH = 41  # epoch of both the checkpoint and the prediction folder used below

weights_dir = 'Weights/' + EXPERIMENT_PREFIX
weights_path = os.path.join(weights_dir, f"{EXPERIMENT_PREFIX}_e{PSNR_EPOCH:04d}.h5")
checkpoint = torch.load(weights_path)
pix2pix_model.load_state_dict(checkpoint['model'], strict=False)
model = pix2pix_model.generator_A_to_B
model.eval()
# NOTE: the PSNR loop in this section reads predictions from PRED_LOCATION on
# disk; the model loaded above is not used by it.

PRED_LOCATION = f"fakeCT_{EXPERIMENT_PREFIX}_e{PSNR_EPOCH:04d}"

NUM_WORKERS, BATCH_SIZE = 1, 1
# Prediction and reference CT, both mapped from HU [-1000, 3000] to [-1, 1].
test_transforms = Compose([
    LoadImaged(keys=["PRED", "CT"], image_only=False),
    EnsureChannelFirstd(keys=["PRED", "CT"]),
    ScaleIntensityRanged(keys=["PRED", "CT"], a_min=-1000, a_max=3000, b_min=-1, b_max=1),
    EnsureTyped(keys=["PRED", "CT"]),
])

fnames_test_A = sorted(glob.glob(os.path.join(PRED_LOCATION, "*.gz")))
fnames_test_B = sorted(os.path.join(CTs, pid) for pid in PIDs_test)
fnames_test_C = sorted(os.path.join(MRs, pid) for pid in PIDs_test)

test_dic = [dict(PRED=pred_path, CT=ct_path) for pred_path, ct_path in zip(fnames_test_A, fnames_test_B)]

test_ds = Dataset(test_dic, test_transforms)

In [None]:
import imageio.v2 as imageio
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.exposure import match_histograms

In [None]:
mean_psnr = []
for sample in test_ds:
    # skimage works on NumPy arrays on the CPU, so the volumes are converted
    # directly (the former GPU round-trip only added two device copies).
    fake_np = sample["PRED"].squeeze().cpu().numpy().astype(np.float32)
    ct_np = sample["CT"].squeeze().cpu().numpy().astype(np.float32)

    # Clip to the normalised range so values outside the HU window do not
    # exceed the stated data_range.
    fake_np = np.clip(fake_np, -1.0, 1.0)
    ct_np = np.clip(ct_np, -1.0, 1.0)

    c_psnr = psnr(ct_np, fake_np, data_range=2)
    mean_psnr.append(c_psnr)
    print(f"{sample['CT_meta_dict']['filename_or_obj'].split('/')[-1].split('.')[0]}'s psnr: {c_psnr}")

if mean_psnr:
    print(f"Mean PSNR : {np.mean(mean_psnr):.3f}")
else:
    print("Mean PSNR : no samples found")

In [None]:
# Quick visual check of the last inference batch: un-batch it, take the first
# item's (not inverted) prediction and show index 85 along the last axis.
OUT = decollate_batch(test_data)
OUT_processed = OUT[0]['PRED']
OUT_processed_np = OUT_processed.detach().cpu().numpy().squeeze()
plt.imshow(OUT_processed_np[..., 85])

## Dice

In [None]:
PRED_LOCATION = "CT_NGF0_V2_skull"
# PRED_LOCATION = "fakeCT_" + EXPERIMENT_PREFIX + "_e0041"

# Dice inputs are used as loaded (no intensity scaling).
test_transforms = Compose([
    LoadImaged(keys=["PRED", "CT"], image_only=False),
    EnsureChannelFirstd(keys=["PRED", "CT"]),
    EnsureTyped(keys=["PRED", "CT"]),
])

# Patient IDs come from the file names of an earlier prediction folder
# (second "_"-separated token of each name).
# NOTE(review): the inference cell names outputs "<MR dir>_<MR stem>_pTGT.nii.gz",
# so token [1] is the full patient ID only if the stem has no "_"; confirm.
PIDs_test = [entry.split('_')[1] for entry in os.listdir("fakeCT_Gamma_crop-n2-l4-f128_GAN1_L140.00_NGF20.00_a0.08_e0041")]
fnames_test_A = sorted(glob.glob(os.path.join(PRED_LOCATION, "*", "*gz")))
# NOTE(review): "CT" points at the volumes in `CTs`; DiceMetric expects binary
# label maps, so confirm these are segmentations and not raw CT intensities.
fnames_test_B = sorted(os.path.join(CTs, pid + ".nii.gz") for pid in PIDs_test)

test_dic = [dict(PRED=pred_path, CT=ct_path) for pred_path, ct_path in zip(fnames_test_A, fnames_test_B)]


test_ds = Dataset(test_dic, test_transforms)


In [None]:
from monai.metrics import DiceMetric, MeanIoU

# ignore_empty=True makes the metric NaN for cases with an empty ground truth.
dice_metric = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

mean_dice = []
for sample in test_ds:
    # Add a batch axis: [C, H, W, D] -> [1, C, H, W, D].
    pred = sample["PRED"].to(gpu_device).unsqueeze(0).float()
    gt = sample["CT"].to(gpu_device).unsqueeze(0).float()

    dice = dice_metric(y_pred=pred, y=gt).item()
    mean_dice.append(dice)

    pid = sample['CT_meta_dict']['filename_or_obj'].split('/')[-1].split('.')[0]
    # The separator was mis-encoded as "â€”"; print a real em dash.
    print(f"{pid} — Dice: {dice:.4f}")

# Skip NaN (empty-ground-truth) cases, as DiceMetric.aggregate() would, and
# avoid a ZeroDivisionError when no samples were found.
if mean_dice:
    print(f"\nMean Dice: {np.nanmean(mean_dice):.4f}")
else:
    print("\nMean Dice: no samples found")
