In [None]:

import os
import copy
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import *
from modules import UNet_conditional, EMA, UNet_conditional_concat, UNet_conditional_fully_concat, UNet_conditional_fully_add, UNet_conditional_concat_with_mask, UNet_conditional_concat_with_mask_v2
import logging
from torch.utils.tensorboard import SummaryWriter

# Timestamped INFO-level logging (used for the "Sampling ..." progress messages).
logging.basicConfig(
    level=logging.INFO,
    datefmt="%I:%M:%S",
    format="%(asctime)s - %(levelname)s: %(message)s",
)
"""
需要修改的地方：

"""

class Diffusion:
    """DDPM with a linear beta schedule: forward noising and conditional reverse sampling.

    Args:
        noise_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        img_size: height/width of the square images produced by ``sample``.
        device: device that holds the schedule tensors and the samples.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=240, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        # alpha_hat[t] is the cumulative product alpha[0] * ... * alpha[t].
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas linearly spaced from ``beta_start`` to ``beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise a batch ``x`` to timesteps ``t``.

        The schedule values are indexed with ``[:, None, None, None]``, so ``x`` is
        expected to be 4-D with the batch on the first axis.

        Returns:
            (x_t, noise): the noised batch and the Gaussian noise that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def _timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``[1, noise_steps)`` (on CPU)."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, masks, cfg_scale=0):
        """Generate ``n`` single-channel images by reverse diffusion.

        Args:
            model: noise predictor called as ``model(x, t, labels, masks)``.
            n: number of images to sample.
            labels: conditioning (guidance) input forwarded to the model.
            masks: mask input forwarded to the model.
            cfg_scale: classifier-free guidance weight; 0 disables guidance.

        Returns:
            torch.Tensor of shape (n, 1, img_size, img_size) on ``self.device``.
        """
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 1, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels, masks)
                if cfg_scale > 0:
                    # Unconditional branch: drop the guidance image but keep the mask, so the
                    # call matches the 4-argument signature used above (it previously passed
                    # only 3 arguments and failed). TODO confirm the model's unconditional path.
                    uncond_predicted_noise = model(x, t, None, masks)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise on the final step.
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        return x

def inpaint_image(original_image, generated_image, mask):
    """Composite ``generated_image`` into ``original_image`` wherever ``mask`` is 1.

    Args:
        original_image (torch.Tensor): image whose unmasked pixels are kept.
        generated_image (torch.Tensor): image supplying the pixels inside the mask.
        mask (torch.Tensor): 1 marks the region to inpaint, 0 keeps the original.

    Returns:
        torch.Tensor: ``original * (1 - mask) + generated * mask`` on the device of
        ``original_image``.
    """
    target_device = original_image.device
    fill_mask = mask.to(target_device)
    fill = generated_image.to(target_device)
    # Tensor arithmetic allocates a new result, so original_image is never modified.
    kept = original_image * (1 - fill_mask)
    return kept + fill * fill_mask



import argparse
parser = argparse.ArgumentParser()
args, unknown = parser.parse_known_args()
args.run_name = "DDPM_conditional"
args.batch_size = 10
# Side length of the square patches the model works on; presumably get_data crops the
# 240x240 slices to this size. TODO confirm against utils.get_data.
args.image_size = 96
# args.dataset_path =  r"D:\ASNR-MICCAI-BraTS2023-Local-Synthesis-Challenge-Training"
args.dataset_path = r"C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data"
args.device = "cuda"
args.lr = 3e-4
args.train = True
args.shuffle = False
# Single source of truth for the device (it was previously hard-coded next to args.device).
device = args.device
dataloader = get_data(args)
model = UNet_conditional_concat_with_mask_v2().to(device)
# model = UNet_conditional_concat().to(device)
# map_location loads the checkpoint onto `device` regardless of where it was saved from.
ckpt = torch.load("./models/DDPM_conditional/ema_ckpt.pt", map_location=device)
model.load_state_dict(ckpt)
diffusion = Diffusion(img_size=args.image_size, device=device)
# Only the first batch is sampled, so no progress bar is needed around the dataloader.
images, cropped_images, masks = next(iter(dataloader))
modefied_images = cropped_images
b, _, _, _ = images.shape
d_images = diffusion.sample(model, n=b, labels=modefied_images, masks=masks)

# d_images = diffusion.sample(model, n=b, labels=modefied_images)


In [None]:
def inpaint_image(original_image, generated_image, mask):
    """Composite generated pixels into the masked region and return masked views.

    Args:
        original_image (torch.Tensor): ground-truth / original image.
        generated_image (torch.Tensor): model output supplying pixels inside the mask.
        mask (torch.Tensor): 1 marks the region to inpaint, 0 keeps the original.

    Returns:
        tuple of three tensors, all on the device of ``original_image``:
            output_image: ``original * (1 - mask) + generated * mask``
            generated_image: ``generated * mask`` (generated content inside the mask only)
            reference_image: ``original * mask`` (ground truth inside the mask only)
    """
    # Move everything onto the device of the original image.
    device = original_image.device
    mask = mask.to(device)
    generated_image = generated_image.to(device)

    # Arithmetic allocates new tensors, so the inputs are never modified.
    reference_image = original_image * mask
    output_image = original_image * (1 - mask) + generated_image * mask
    generated_image = generated_image * mask

    return output_image, generated_image, reference_image
# Full-extent slicing ([:, :, :, :]) is a no-op view, so pass the batches directly.
images_predict_slice, generated_image, reference_image = inpaint_image(
    images, d_images, masks
)


In [None]:
import torch
import matplotlib.pyplot as plt



# Index of the batch element to visualise (must be smaller than the batch size used above).
index = 8
images = images.cpu()
d_images_new = images_predict_slice.cpu()
dd = d_images.cpu()
cropped_images = modefied_images.cpu()
g_images = generated_image.cpu()
r_images = reference_image.cpu()
masks_clone = masks.cpu()

# Select channel 0 of the chosen element from every tensor.
img = images[index, 0, :, :]
d_img = d_images_new[index, 0, :, :]
c_img = cropped_images[index, 0, :, :]
g_img = g_images[index, 0, :, :]
r_img = r_images[index, 0, :, :]
dd_img = dd[index, 0, :, :]
mask = masks_clone[index, 0, :, :]

# Sanity check: the masked generated image should contain non-zero values.
if torch.any(g_img != 0):
    print("g_img 中包含非0元素")
else:
    print("g_img 中全是0")
non_zero_elements = g_img[g_img != 0]

# Draw only on the 2x2 grid; the former extra plt.figure() call produced a stray empty figure.
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
panels = [
    (axes[0, 0], img, 'Ground-truth'),
    (axes[0, 1], c_img, 'Cropped guidance'),
    (axes[1, 0], dd_img, 'DDPM generated'),
    (axes[1, 1], d_img, 'Final infilled image'),
]
for ax, panel, title in panels:
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)
plt.show()
mse = nn.MSELoss()
# loss = mse(reference_image, masks_clone)
# print(loss)

In [None]:

import os
import copy
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import *
# from utils import get_data_inference
from modules import UNet_conditional, EMA, UNet_conditional_concat, UNet_conditional_fully_concat, UNet_conditional_fully_add
from modules import UNet_conditional_concat_with_mask, UNet_conditional_concat_with_mask_v2, UNet_conditional_concat_Large
from modules import UNet_conditional_concat_XLarge, UNet_conditional_concat_with_mask_GAM
import logging
from torch.utils.tensorboard import SummaryWriter
# from utils import _structural_similarity_index, _peak_signal_noise_ratio, _mean_squared_error

# Timestamped INFO-level logging (used for the "Sampling ..." progress messages).
logging.basicConfig(
    level=logging.INFO,
    datefmt="%I:%M:%S",
    format="%(asctime)s - %(levelname)s: %(message)s",
)
"""
需要修改的地方：

"""
# Seed torch's global RNG so the sampling noise below is reproducible between runs.
torch.manual_seed(0)

from scipy.ndimage import gaussian_filter

# Gaussian smoothing strength (standard deviation, in pixels); larger means smoother.
sigma = 1.0


class Diffusion:
    """DDPM with a linear beta schedule, used here for conditional inpainting inference.

    Args:
        noise_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        img_size: height/width of the square images produced by ``sample``.
        device: device that holds the schedule tensors and the samples.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=240, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        # alpha_hat[t] is the cumulative product alpha[0] * ... * alpha[t].
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        """Return ``noise_steps`` betas linearly spaced from ``beta_start`` to ``beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise a batch ``x`` to timesteps ``t``; returns ``(x_t, noise)``.

        The schedule values are indexed with ``[:, None, None, None]``, so ``x`` is
        expected to be 4-D with the batch on the first axis.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def _timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``[1, noise_steps)`` (on CPU)."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    @staticmethod
    def _is_snapshot_step(i):
        """Whether step ``i`` is recorded: every 100 steps above 200, every 40 steps at or below."""
        return (i > 200 and i % 100 == 0) or (i <= 200 and i % 40 == 0)

    def sample(self, model, n, labels, masks, cfg_scale=0):
        """Generate ``n`` single-channel images by reverse diffusion.

        When ``labels`` is given, the start point is Gaussian noise plus ``0.5 * labels``,
        a warm start towards the guidance image.

        Args:
            model: noise predictor called as ``model(x, t, labels, masks)``.
            n: number of images to sample.
            labels: guidance tensor broadcastable to (n, 1, img_size, img_size), or None.
            masks: mask tensor forwarded to the model, or None.
            cfg_scale: classifier-free guidance weight; 0 disables guidance (this argument
                used to be silently ignored).

        Returns:
            (x, noise_img_list): final samples on ``self.device``, and CPU copies of the
            intermediate x at the steps selected by ``_is_snapshot_step``.
        """
        logging.info(f"Sampling {n} new images....")
        model.eval()
        noise_img_list = []
        # Keep the conditioning on the sampling device; .to() is a no-op if it already is.
        if labels is not None:
            labels = labels.to(self.device)
        if masks is not None:
            masks = masks.to(self.device)
        with torch.no_grad():
            x = torch.randn((n, 1, self.img_size, self.img_size)).to(self.device)
            if labels is not None:
                x = x + 0.5 * labels
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels, masks)
                if cfg_scale > 0:
                    # Unconditional branch: drop the guidance image but keep the mask.
                    # TODO confirm the model's unconditional path.
                    uncond_predicted_noise = model(x, t, None, masks)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise on the final step.
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

                if self._is_snapshot_step(i):
                    noise_img_list.append(x.detach().cpu())

        return x, noise_img_list


def inpaint_image(original_image, generated_image, mask):
    """Blend generated content into the masked region of the original image.

    Args:
        original_image (torch.Tensor): ground-truth / original image.
        generated_image (torch.Tensor): model output supplying pixels inside the mask.
        mask (torch.Tensor): 1 marks the region to inpaint, 0 keeps the original.

    Returns:
        (blended, generated_in_mask, reference_in_mask), all on the device of
        ``original_image``:
            blended: ``original * (1 - mask) + generated * mask``
            generated_in_mask: ``generated * mask``
            reference_in_mask: ``original * mask``
    """
    dev = original_image.device
    region = mask.to(dev)
    synthesized = generated_image.to(dev)

    reference_in_mask = original_image * region
    blended = original_image * (1 - region) + synthesized * region
    return blended, synthesized * region, reference_in_mask
# images_predict_slice, generated_image, reference_image = inpaint_image(images[:,:,:,:], d_images[:,:,:,:], masks[:,:,:,:])


import argparse
import re
import matplotlib.pyplot as plt  # used below; do not rely on an earlier cell's import

parser = argparse.ArgumentParser()
args, unknown = parser.parse_known_args()
args.run_name = "DDPM_conditional"
args.batch_size = 2
# Side length of the square patches the model works on; presumably the loader crops the
# 240x240 slices to this size. TODO confirm against utils.get_data_inference.
args.image_size = 96
# args.dataset_path =  r"D:\ASNR-MICCAI-BraTS2023-Local-Synthesis-Challenge-Training"
args.dataset_path = r"C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data_1"
args.generated_data_path = r"C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\generated_data"
args.device = "cuda"
args.lr = 1e-4
args.train = True
args.shuffle = False
# Single source of truth for the device.
device = args.device
# model = UNet_conditional_concat_Large().to(device)
model = UNet_conditional_concat_with_mask_GAM().to(device)
# map_location loads the checkpoint onto `device` regardless of where it was saved from.
ckpt = torch.load("./models/DDPM_conditional/204_ema_ckpt.pt", map_location=device)
model.load_state_dict(ckpt)
diffusion = Diffusion(img_size=args.image_size, device=device)

# Input slice files are named "...96_slice_<z>.npz"; <z> names the output file.
SLICE_NAME_RE = re.compile(r'96_slice_(\d+)\.npz')

test_dataloader = get_data_inference(args)
print(len(test_dataloader))
pbar_test = tqdm(test_dataloader)
for i, (images, cropped_images, masks, path) in enumerate(pbar_test):
    b, _, _, _ = cropped_images.shape
    masks = masks.to(torch.float)
    modefied_images = cropped_images.to(device)
    d_images, noise_img_list = diffusion.sample(model, n=b, labels=modefied_images, masks=masks)
    # Paste the generated pixels into the masked region of the (CPU) original slices.
    images_predict_slice, generated_image, reference_image = inpaint_image(images, d_images, masks)
    img = images.cpu()[:, 0, :, :]
    d_images_new = images_predict_slice.cpu()
    dd_img = d_images.cpu()[:, 0, :, :]
    mask_img = masks.cpu()[:, 0, :, :]
    ref_images_new = reference_image.cpu()

    for index in range(b):
        d_img = d_images_new[index, 0, :, :]

        # Composite, raw DDPM output, mask and original. Plot only on this grid; the former
        # extra plt.figure() call produced a stray empty figure.
        fig, axes = plt.subplots(2, 2, figsize=(12, 12))
        axes[0, 0].imshow(d_img, cmap='gray')
        axes[0, 1].imshow(dd_img[index, :, :], cmap='gray')
        axes[1, 0].imshow(mask_img[index, :, :], cmap='gray')
        axes[1, 1].imshow(img[index, :, :], cmap='gray')
        plt.show()
        plt.close(fig)

        orgin_path = path[index]
        filename = os.path.basename(orgin_path)
        match = SLICE_NAME_RE.search(filename)
        if match is None:
            raise ValueError(f"Cannot parse the slice index from file name: {orgin_path}")
        number = match.group(1)
        print(number)

        # Outputs go to <generated_data_path>/<subject>/generated_slice_<z>.npz.
        subject_name = os.path.basename(os.path.dirname(orgin_path))
        generated_img_save_folder = os.path.join(args.generated_data_path, subject_name)
        os.makedirs(generated_img_save_folder, exist_ok=True)
        generated_img_save_path = os.path.join(generated_img_save_folder, f'generated_slice_{number}.npz')

        np.savez(generated_img_save_path, image=d_img.numpy())



In [None]:
# Show the intermediate denoising snapshots of the last sampled batch (up to two samples
# each). The row count adapts to the batch, because the last batch may hold one sample.
for x in noise_img_list:
    n_show = min(2, x.shape[0])
    fig, axes = plt.subplots(n_show, 1, figsize=(6, 6), squeeze=False)
    for row in range(n_show):
        axes[row, 0].imshow(x[row, 0, :, :], cmap='gray')
    plt.show()
    plt.close(fig)

In [None]:
import glob
import numpy as np
import nibabel as nib
import os
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import *
import cv2
# Single-subject reconstruction: paste the generated 2-D slices back into the 3-D volume,
# then write raw and smoothed NIfTI files for inspection.
subject_dir = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01666-000'
print(subject_dir)
image_path = glob.glob(os.path.join(subject_dir, '*t1n.nii.gz'))[0]
print(image_path)
cropped_image_path = glob.glob(os.path.join(subject_dir, '*t1n-voided.nii.gz'))[0]
# A bare '*healthy.nii.gz' also matches '*-mask-unhealthy.nii.gz', so the mask picked
# depended on directory listing order. Anchor on '-mask-healthy' to always get the healthy one.
healthy_mask_path = glob.glob(os.path.join(subject_dir, '*-mask-healthy.nii.gz'))[0]
unhealthy_mask_path = glob.glob(os.path.join(subject_dir, '*unhealthy.nii.gz'))[0]

# Load the image and the masks as float32 arrays.
image_nii = nib.load(image_path)
image = image_nii.get_fdata().astype(np.float32)
healthy_mask = nib.load(healthy_mask_path).get_fdata().astype(np.float32)
unhealthy_mask = nib.load(unhealthy_mask_path).get_fdata().astype(np.float32)
cropped_image = nib.load(cropped_image_path).get_fdata().astype(np.float32)
# image = image_preprocess(image)
mask_affine = image_nii.affine
ref_img = nib.Nifti1Image(image, mask_affine)
nib.save(ref_img, 'example_ref.nii.gz')

# cropped_image = image_preprocess(cropped_image)
voided_img = nib.Nifti1Image(cropped_image, mask_affine)
# nib.save(voided_img, 'example_voided.nii.gz')

# Centre of the healthy mask's bounding box.
nonzero_coords = np.nonzero(healthy_mask)
center_x = (np.min(nonzero_coords[0]) + np.max(nonzero_coords[0])) // 2
center_y = (np.min(nonzero_coords[1]) + np.max(nonzero_coords[1])) // 2
center_z = (np.min(nonzero_coords[2]) + np.max(nonzero_coords[2])) // 2
# Use the real volume shape instead of the previously hard-coded [240, 240, 155].
image_shape = list(image.shape)
# In-plane crop size; must equal the side length of the generated slices.
img_size = 96
half = int(img_size / 2)
crop_x1 = max(center_x - half, 0)
crop_x2 = min(center_x + half, image_shape[0])
crop_y1 = max(center_y - half, 0)
crop_y2 = min(center_y + half, image_shape[1])
# crop_z1 = max(center_z - 48, 0)
# crop_z2 = min(center_z + 48, image_shape[2])
# Along z, the slices span the full extent of the healthy mask.
crop_z1 = np.min(nonzero_coords[2])
crop_z2 = np.max(nonzero_coords[2])

crop_size_x = crop_x2 - crop_x1
crop_size_y = crop_y2 - crop_y1
crop_size_z = crop_z2 - crop_z1
# Crop geometry before the border shift below.
geometric_list = [crop_x1, crop_x2, crop_y1, crop_y2, crop_z1, crop_z2]

# If the window was clipped at a volume border, shift it inwards so it is img_size wide.
if crop_size_x < img_size:
    if center_x - half < 0:
        crop_x1 = 0
        crop_x2 = img_size
    else:
        crop_x1 = image_shape[0] - int(img_size)
        crop_x2 = image_shape[0]

if crop_size_y < img_size:
    if center_y - half < 0:
        crop_y1 = 0
        crop_y2 = img_size
    else:
        crop_y1 = image_shape[1] - int(img_size)
        crop_y2 = image_shape[1]

# Target intensity range: the 0.5th-99.5th percentile of the voided input volume.
np_org = cropped_image
np_org_clipped = np.percentile(np_org, [0.5, 99.5])
start = np.min(np_org_clipped)
end = np.max(np_org_clipped)
width = end - start
print(start, end, width)

# Paste every generated slice (count taken from disk; was hard-coded to 45).
generated_subject_dir = 'C:/Users/DELL/Desktop/DDPM/ddpm_brats/DDPM_brain/test_data/generated_data/BraTS-GLI-01666-000'
n_slices = len(glob.glob(generated_subject_dir + '/generated_slice_*.npz'))
if n_slices == 0:
    raise FileNotFoundError(f'No generated slices found in {generated_subject_dir}')
for z in range(n_slices):
    subject_slice_path = generated_subject_dir + '/generated_slice_' + str(z) + '.npz'
    with np.load(subject_slice_path) as data:
        np_redef = data['image']

    # Clip to [0, 1], then rescale linearly onto [start, start + width].
    clipped_image = np.clip(np_redef, 0, 1)
    lo = clipped_image.min()
    span = clipped_image.max() - lo
    if span > 0:
        norm_img = (clipped_image - lo) / span * width + start
    else:
        # Constant slice: avoid 0/0 = NaN and map it to the bottom of the range.
        norm_img = np.full_like(clipped_image, start)
    # kernel_size = (3, 3)  # Gaussian kernel size
    # sigma = 0  # Gaussian std-dev; 0 lets OpenCV derive it
    # norm_img = cv2.GaussianBlur(norm_img, kernel_size, sigma)
    image[crop_x1:crop_x2, crop_y1:crop_y2, crop_z1 + z] = norm_img

import numpy as np
from scipy.ndimage import gaussian_filter
sigma = 1.0  # Gaussian std-dev in voxels; larger means smoother
# Smooth with 2-D Gaussian filters applied plane by plane, along each axis in turn.
img_smoothed = image.copy()
for d in range(image.shape[2]):
    img_smoothed[:, :, d] = gaussian_filter(image[:, :, d], sigma=sigma)

img_smoothed_v2 = img_smoothed.copy()
for d in range(image.shape[0]):
    img_smoothed_v2[d, :, :] = gaussian_filter(img_smoothed[d, :, :], sigma=sigma)

img_smoothed_v3 = img_smoothed_v2.copy()
for d in range(image.shape[1]):
    img_smoothed_v3[:, d, :] = gaussian_filter(img_smoothed_v2[:, d, :], sigma=sigma)

img = nib.Nifti1Image(image, mask_affine)
nib.save(img, 'example.nii.gz')

img_smoothed = nib.Nifti1Image(img_smoothed, mask_affine)
nib.save(img_smoothed, 'example_smoothed.nii.gz')

img_smoothed_v2 = nib.Nifti1Image(img_smoothed_v2, mask_affine)
nib.save(img_smoothed_v2, 'example_smoothed_v2.nii.gz')

img_smoothed_v3 = nib.Nifti1Image(img_smoothed_v3, mask_affine)
nib.save(img_smoothed_v3, 'example_smoothed_v3.nii.gz')

In [5]:
import glob
import numpy as np
import nibabel as nib
import os
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import *
import cv2

from scipy.ndimage import gaussian_filter

# Root holding the generated slices; the reconstructed volumes are written here too.
GENERATED_ROOT = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain/test_data/generated_data/'
# Gaussian std-dev in voxels for the smoothing passes; larger means smoother.
sigma = 1.0
# Smoothing leaks intensity into the background; voxels below this value are zeroed afterwards.
BACKGROUND_THRESHOLD = 50

for subject_dir in sorted(glob.glob(r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\*')):
    if not os.path.isdir(subject_dir):
        continue
    subject_name = os.path.split(subject_dir)[1]
    print(subject_dir)
    subject_out_dir = os.path.join(GENERATED_ROOT, subject_name)

    image_path = glob.glob(os.path.join(subject_dir, '*t1n.nii.gz'))[0]
    cropped_image_path = glob.glob(os.path.join(subject_dir, '*t1n-voided.nii.gz'))[0]
    # A bare '*healthy.nii.gz' also matches '*-mask-unhealthy.nii.gz'; anchor on
    # '-mask-healthy' so the crop is centred on the healthy mask regardless of listing order.
    healthy_mask_path = glob.glob(os.path.join(subject_dir, '*-mask-healthy.nii.gz'))[0]
    unhealthy_mask_path = glob.glob(os.path.join(subject_dir, '*unhealthy.nii.gz'))[0]

    # Load the image and the masks as float32 arrays.
    image_nii = nib.load(image_path)
    image = image_nii.get_fdata().astype(np.float32)
    healthy_mask = nib.load(healthy_mask_path).get_fdata().astype(np.float32)
    unhealthy_mask = nib.load(unhealthy_mask_path).get_fdata().astype(np.float32)
    cropped_image = nib.load(cropped_image_path).get_fdata().astype(np.float32)
    # image = image_preprocess(image)
    mask_affine = image_nii.affine
    ref_img = nib.Nifti1Image(image, mask_affine)
    # cropped_image = image_preprocess(cropped_image)
    voided_img = nib.Nifti1Image(cropped_image, mask_affine)

    # Centre of the healthy mask's bounding box.
    nonzero_coords = np.nonzero(healthy_mask)
    center_x = (np.min(nonzero_coords[0]) + np.max(nonzero_coords[0])) // 2
    center_y = (np.min(nonzero_coords[1]) + np.max(nonzero_coords[1])) // 2
    center_z = (np.min(nonzero_coords[2]) + np.max(nonzero_coords[2])) // 2
    # Use the real volume shape instead of the previously hard-coded [240, 240, 155].
    image_shape = list(image.shape)
    # In-plane crop size; must equal the side length of the generated slices.
    img_size = 96
    half = int(img_size / 2)
    crop_x1 = max(center_x - half, 0)
    crop_x2 = min(center_x + half, image_shape[0])
    crop_y1 = max(center_y - half, 0)
    crop_y2 = min(center_y + half, image_shape[1])
    # Along z, the slices span the full extent of the healthy mask.
    crop_z1 = np.min(nonzero_coords[2])
    crop_z2 = np.max(nonzero_coords[2])

    crop_size_x = crop_x2 - crop_x1
    crop_size_y = crop_y2 - crop_y1
    crop_size_z = crop_z2 - crop_z1
    # Crop geometry before the border shift below.
    geometric_list = [crop_x1, crop_x2, crop_y1, crop_y2, crop_z1, crop_z2]

    # If the window was clipped at a volume border, shift it inwards so it is img_size wide.
    if crop_size_x < img_size:
        if center_x - half < 0:
            crop_x1 = 0
            crop_x2 = img_size
        else:
            crop_x1 = image_shape[0] - int(img_size)
            crop_x2 = image_shape[0]

    if crop_size_y < img_size:
        if center_y - half < 0:
            crop_y1 = 0
            crop_y2 = img_size
        else:
            crop_y1 = image_shape[1] - int(img_size)
            crop_y2 = image_shape[1]

    # Target intensity range: the 0.5th-99.5th percentile of the voided input volume.
    np_org = cropped_image
    np_org_clipped = np.percentile(np_org, [0.5, 99.5])
    start = np.min(np_org_clipped)
    end = np.max(np_org_clipped)
    width = end - start
    print(start, end, width)

    # NOTE(review): the base volume is the ground-truth t1n. Any healthy-mask voxel outside
    # the 96x96 crop or the generated slice range keeps its true value, which inflates the
    # metrics. Consider starting from cropped_image (the voided input) instead.
    slice_paths = glob.glob(os.path.join(subject_out_dir, 'generated_slice_*.npz'))
    if not slice_paths:
        # Without generated slices the saved "result" would silently be the ground truth.
        raise FileNotFoundError(f'No generated slices found in {subject_out_dir}')
    for z in range(len(slice_paths)):
        subject_slice_path = os.path.join(subject_out_dir, 'generated_slice_' + str(z) + '.npz')
        with np.load(subject_slice_path) as data:
            np_redef = data['image']

        # Clip to [0, 1], then rescale linearly onto [start, start + width].
        clipped_image = np.clip(np_redef, 0, 1)
        lo = clipped_image.min()
        span = clipped_image.max() - lo
        if span > 0:
            norm_img = (clipped_image - lo) / span * width + start
        else:
            # Constant slice: avoid 0/0 = NaN and map it to the bottom of the range.
            norm_img = np.full_like(clipped_image, start)
        image[crop_x1:crop_x2, crop_y1:crop_y2, crop_z1 + z] = norm_img

    # Smooth with 2-D Gaussian filters applied plane by plane, along each axis in turn.
    img_smoothed = image.copy()
    for d in range(image.shape[2]):
        img_smoothed[:, :, d] = gaussian_filter(image[:, :, d], sigma=sigma)

    img_smoothed_v2 = img_smoothed.copy()
    for d in range(image.shape[0]):
        img_smoothed_v2[d, :, :] = gaussian_filter(img_smoothed[d, :, :], sigma=sigma)

    img_smoothed_v3 = img_smoothed_v2.copy()
    for d in range(image.shape[1]):
        img_smoothed_v3[:, d, :] = gaussian_filter(img_smoothed_v2[:, d, :], sigma=sigma)

    img = nib.Nifti1Image(image, mask_affine)
    nib.save(img, os.path.join(subject_out_dir, 'result.nii.gz'))

    img_smoothed = nib.Nifti1Image(img_smoothed, mask_affine)
    img_smoothed_v2 = nib.Nifti1Image(img_smoothed_v2, mask_affine)

    # Zero the background in memory before writing, instead of the previous
    # save -> reload -> threshold -> save round-trip.
    img_smoothed_v3[img_smoothed_v3 < BACKGROUND_THRESHOLD] = 0
    file_path = os.path.join(subject_out_dir, 'result_smoothed_v3.nii.gz')
    img_smoothed_v3 = nib.Nifti1Image(img_smoothed_v3, mask_affine)
    nib.save(img_smoothed_v3, file_path)

C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01610-000
0.0 688.358277282715 688.358277282715
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01657-000
0.0 833.0 833.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01658-000
0.0 1447.0 1447.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01659-000
0.0 1795.0 1795.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01660-000
0.0 1699.0 1699.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01661-000
0.0 1705.0 1705.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01662-000
0.0 2876.0 2876.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01663-000
0.0 881.0 881.0
C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01664-000
0.0 908.0 908.0
C:\Users\DELL\Desktop\D

In [21]:
import nibabel as nib
import numpy as np
import os

# Path of the smoothed result to re-threshold in place.
file_path = r"C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\generated_data\BraTS-GLI-01659-000\result_smoothed_v3.nii.gz"

# Voxels below 50 become background (0). The file is overwritten, keeping its affine and header.
img = nib.load(file_path)
voxels = img.get_fdata()
data = np.where(voxels < 50, 0.0, voxels)
new_img = nib.Nifti1Image(data, img.affine, img.header)
nib.save(new_img, file_path)

# Either smooth first and normalise afterwards, or normalise, smooth and then apply a
# threshold: once the mask touches the background, smoothing leaks non-zero values into it.

In [None]:
from inpainting.challenge_metrics_2023 import generate_metrics
import nibabel as nib
import torch

def _as_batched_tensor(nifti_image):
    """Return a NIfTI image's voxel data as a float tensor with a leading size-1 axis."""
    return torch.Tensor(nifti_image.get_fdata()).unsqueeze(0)


# Prediction under evaluation (the unsmoothed volume written as example.nii.gz).
result_img = nib.load(r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\example.nii.gz')
result = _as_batched_tensor(result_img)

# Healthy mask: the region the metrics are computed on.
mask_path = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01666-000\BraTS-GLI-01666-000-mask-healthy.nii.gz'
mask_img = nib.load(mask_path)
mask_healthy = _as_batched_tensor(mask_img).bool()
mask_affine = mask_img.affine

# Ground-truth reference.
t1n_path = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01666-000\BraTS-GLI-01666-000-t1n.nii.gz'
t1n_img = nib.load(t1n_path)
t1n = _as_batched_tensor(t1n_img)

# Normalisation is based on the model input, i.e. the voided image.
t1n_voided_path = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\BraTS-GLI-01666-000\BraTS-GLI-01666-000-t1n-voided.nii.gz'
t1n_voided_img = nib.load(t1n_voided_path)
t1n_voided = _as_batched_tensor(t1n_voided_img)

# All tensors carry a leading batch axis of size 1 (from unsqueeze(0)).
metrics_dict = generate_metrics(
    prediction=result,
    target=t1n,
    mask=mask_healthy,
    normalization_tensor=t1n_voided,  # former: t1n * ~mask_healthy
)

print(metrics_dict)

In [6]:
from inpainting.challenge_metrics_2023 import generate_metrics
import nibabel as nib
import os
import glob
import torch
# Per-subject SSIM / PSNR / MSE of the smoothed reconstructions, followed by their means.
DATA_ROOT = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data/'
RESULT_ROOT = r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain/test_data/generated_data/'
ssim = 0
psnr = 0
mse = 0
counter = 0
for subject_dir in sorted(glob.glob(r'C:\Users\DELL\Desktop\DDPM\ddpm_brats\DDPM_brain\test_data\test_data\*')):
    # Only subject folders hold data; skip stray files matched by the glob.
    if not os.path.isdir(subject_dir):
        continue
    subject_name = os.path.split(subject_dir)[1]
    result_img = nib.load(os.path.join(RESULT_ROOT, subject_name, 'result_smoothed_v3.nii.gz'))
    result = torch.Tensor(result_img.get_fdata()).unsqueeze(0)

    # Healthy mask: the evaluation region.
    mask_path = os.path.join(DATA_ROOT, subject_name, subject_name + '-mask-healthy.nii.gz')
    mask_img = nib.load(mask_path)
    mask_healthy = torch.Tensor(mask_img.get_fdata()).bool().unsqueeze(0)
    mask_affine = mask_img.affine

    # Ground-truth reference.
    t1n_path = os.path.join(DATA_ROOT, subject_name, subject_name + '-t1n.nii.gz')
    t1n_img = nib.load(t1n_path)
    t1n = torch.Tensor(t1n_img.get_fdata()).unsqueeze(0)

    # Normalisation is based on the model input, i.e. the voided image.
    t1n_voided_path = os.path.join(DATA_ROOT, subject_name, subject_name + '-t1n-voided.nii.gz')
    t1n_voided_img = nib.load(t1n_voided_path)
    t1n_voided = torch.Tensor(t1n_voided_img.get_fdata()).unsqueeze(0)

    # All tensors carry a leading batch axis of size 1 (from unsqueeze(0)).
    metrics_dict = generate_metrics(
        prediction=result,
        target=t1n,
        mask=mask_healthy,
        normalization_tensor=t1n_voided,  # former: t1n * ~mask_healthy
    )
    ssim += metrics_dict['ssim']
    psnr += metrics_dict['psnr']
    mse += metrics_dict['mse']
    counter += 1
    print(subject_name, metrics_dict['ssim'], metrics_dict['psnr'], metrics_dict['mse'])

# Guard against an empty glob, which previously raised ZeroDivisionError.
if counter == 0:
    print('No subjects found; no averages computed.')
else:
    print(ssim / counter)
    print(psnr / counter)
    print(mse / counter)

BraTS-GLI-01610-000 0.7639491558074951 10.56435489654541 0.0873989388346672
BraTS-GLI-01657-000 0.5881217122077942 14.475207328796387 0.03568447008728981
BraTS-GLI-01658-000 0.6736892461776733 14.809887886047363 0.03303780406713486
BraTS-GLI-01659-000 0.7168177366256714 20.010009765625 0.008948862552642822
BraTS-GLI-01660-000 0.8512413501739502 13.571722984313965 0.04199332371354103
BraTS-GLI-01661-000 0.6454342603683472 13.226490020751953 0.047571949660778046
BraTS-GLI-01662-000 0.4823773503303528 7.932794570922852 0.16096094250679016
BraTS-GLI-01663-000 0.7906925678253174 13.756220817565918 0.042109277099370956
BraTS-GLI-01664-000 0.8452531695365906 16.904422760009766 0.02039659395813942
BraTS-GLI-01665-000 0.967793881893158 17.548715591430664 0.016946502029895782
BraTS-GLI-01666-000 0.5539344549179077 11.373947143554688 0.07287947088479996
0.716300444169478
14.01579761505127
0.0516298304904591


In [None]:
# 尝试修改切片能不能生成
#看下生成的和原图差多少
# 看下原始数据最大最小值的差距
# 生成更多数据，或者加入额外的数据集
# 尝试多模态数据
# 可以加一个分类，告诉模型有没有超出边界