In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import SimpleITK as sitk
import monai
import itertools
from monai.transforms import (
    Compose,
    ToTensor,
    ScaleIntensityRange,
)
from monai.utils import set_determinism
from networks.add_net.generator import UnetGenerator
from networks.add_net.discriminator import ConditionalDiscriminator

In [None]:
# Seed the global random number generators via MONAI so data shuffling and
# weight initialisation are reproducible.
set_determinism(seed=42)
# Last expression of the cell so the version is actually displayed
# (as the first line it was evaluated and discarded).
torch.__version__

In [None]:
# Sorted because the datasets pair imgs_path[i] with annos_path[i] purely by
# position, and glob returns files in arbitrary order.
# NOTE(review): assumes corresponding cases sort to the same position in the
# VMI40 and CI folders -- confirm the file naming.
imgs_path = sorted(glob.glob('D:/DeepLearning/image2image/train/VMI40/*.nii.gz'))

In [None]:
# Number of VMI40 target files found for training.
n_train_targets = len(imgs_path)
n_train_targets

In [None]:
# Sorted so that annos_path[i] lines up with imgs_path[i]: the datasets pair
# files purely by position and glob order is arbitrary.
# NOTE(review): assumes matching cases sort identically in both folders -- confirm.
annos_path = sorted(glob.glob('D:/DeepLearning/image2image/train/CI/*.nii.gz'))

In [None]:
# Inputs (CI) and targets (VMI40) are paired by index, so both folders must
# contain the same number of files.
assert len(annos_path) == len(imgs_path), (
    'found %d CI files but %d VMI40 files' % (len(annos_path), len(imgs_path)))
len(annos_path)

In [None]:
# First three target and input paths, to eyeball that they pair up.
tuple(paths[:3] for paths in (imgs_path, annos_path))

In [None]:
# Preview the first four VMI40 target images, titled by file name.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax in axes.flat:
    ax.axis('off')
for ax, img_path in zip(axes.flat, imgs_path[:4]):
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    ax.imshow(img_np, cmap='gray')
    # basename splits on '\\' and '/' on Windows (and '/' elsewhere);
    # splitting on '\\' alone left the full path as the title on other OSes.
    ax.set_title(os.path.basename(img_path))

In [None]:
# Preview the first four CI input images, titled by file name.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax in axes.flat:
    ax.axis('off')
for ax, img_path in zip(axes.flat, annos_path[:4]):
    img_np = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    ax.imshow(img_np, cmap='gray')
    # Portable file-name extraction (handles both separators on Windows).
    ax.set_title(os.path.basename(img_path))

In [None]:
# Intensity normalisation shared by every dataset in this notebook:
# values in [-1000, 3700] are mapped linearly to [-1, 1] (anything outside is
# clipped), then converted to a tensor.
# Earlier alternative: Compose([ScaleIntensity(minv=-1, maxv=1), ToTensor()])
transform = Compose([
    ScaleIntensityRange(
        a_min=-1000,
        a_max=3700,
        b_min=-1,
        b_max=1,
        clip=True,
    ),
    ToTensor(),
])


In [None]:
class nii_dataset(torch.utils.data.Dataset):
    """Paired image dataset returning ``(input, target)`` tensors matched by index.

    Args:
        annos_path: paths of the network inputs (the CI images in this notebook).
        imgs_path: paths of the targets (the VMI images), in the same order as
            ``annos_path``.

    Each file is read with SimpleITK, cast to float32, given a new leading axis
    and passed through the module-level ``transform``.

    Raises:
        ValueError: if the two path lists differ in length. Pairs are matched
            purely by position, so a mismatch would silently pair the wrong files
            (or raise IndexError part-way through an epoch).
    """

    def __init__(self, annos_path, imgs_path):
        if len(annos_path) != len(imgs_path):
            raise ValueError(
                'annos_path and imgs_path must have the same length, got %d and %d'
                % (len(annos_path), len(imgs_path)))
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    @staticmethod
    def _load(path):
        """Read one image file and return it normalised by ``transform`` with a leading axis."""
        array = sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)
        return transform(np.expand_dims(array, axis=0))

    def __getitem__(self, index):
        return self._load(self.annos_path[index]), self._load(self.imgs_path[index])

    def __len__(self):
        return len(self.imgs_path)

In [None]:
# Training pairs: CI inputs -> VMI40 targets.
dataset = nii_dataset(annos_path=annos_path, imgs_path=imgs_path)

In [None]:
# Number of (input, target) training pairs.
n_train_pairs = len(dataset)
n_train_pairs

In [None]:
BATCH_SIZE = 20
# Training loader; shuffle=True reorders the pairs on every pass.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Pull a single training batch for inspection.
train_batches = iter(dataloader)
annos_batch, imgs_batch = next(train_batches)

In [None]:
# Shapes of the input and target batches.
tuple(batch.shape for batch in (annos_batch, imgs_batch))

In [None]:
# Show the first input of the batch and save it to disk for inspection.
first_anno = annos_batch[0].numpy()
plt.imshow(first_anno[0, :, :], cmap='gray')
sitk.WriteImage(sitk.GetImageFromArray(first_anno), './annos_batch.nii.gz')

In [None]:
# First four training pairs: left column = input, right column = target.
plt.figure(figsize=(10, 20))
for i, (anno, img) in enumerate(zip(annos_batch[:4], imgs_batch[:4])):
    # (x + 1) / 2 maps the normalised [-1, 1] range to [0, 1] for display.
    anno = (anno.numpy() + 1) / 2
    img = (img.numpy() + 1) / 2
    plt.subplot(4, 2, 2 * i + 1)
    plt.imshow(anno[0, :, :], cmap='gray')
    plt.title('input image')
    plt.axis('off')  # previously only the right-hand panel hid its axes
    plt.subplot(4, 2, 2 * i + 2)
    plt.imshow(img[0, :, :], cmap='gray')
    plt.title('output image')
    plt.axis('off')

In [None]:
# Sorted so test_imgs_path[i] pairs with test_annos_path[i] (pairing is by
# position and glob order is arbitrary). NOTE(review): assumes matching file
# names sort identically in both folders.
test_imgs_path = sorted(glob.glob('D:/DeepLearning/image2image/test/40kev/*.nii.gz'))

In [None]:
# Number of test target files (test/40kev).
n_test_targets = len(test_imgs_path)
n_test_targets

In [None]:
# Sorted for positional pairing with test_imgs_path; the inference cells also
# name their outputs after these files in this order.
test_annos_path = sorted(glob.glob('D:/DeepLearning/image2image/test/CI/*.nii.gz'))

In [None]:
# Test inputs and targets are paired by index, so the counts must match.
assert len(test_annos_path) == len(test_imgs_path), (
    'found %d CI test files but %d 40 keV test files'
    % (len(test_annos_path), len(test_imgs_path)))
len(test_annos_path)

In [None]:
# Paired test set: CI inputs -> 40 keV targets.
dataset_test = nii_dataset(annos_path=test_annos_path, imgs_path=test_imgs_path)

In [None]:
# Test loader in file order (no shuffling).
dataloader_test = torch.utils.data.DataLoader(
    dataset=dataset_test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
annos_batch, imgs_batch = next(iter(dataloader_test))

# First four test pairs: left column = input, right column = target,
# each shifted from [-1, 1] to [0, 1] for display.
plt.figure(figsize=(10, 20))
panel_titles = ('input image', 'output image')
for row, pair in enumerate(zip(annos_batch[:4], imgs_batch[:4])):
    for col, tensor in enumerate(pair):
        plt.subplot(4, 2, 2 * row + col + 1)
        plt.imshow((tensor.numpy()[0, :, :] + 1) / 2, cmap='gray')
        plt.title(panel_titles[col])
        plt.axis('off')

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Generator (project UnetGenerator). Other generators that were tried:
#gen = AttentionUnet(spatial_dims=2, in_channels=1, out_channels=1, channels=(32,64,128,256,512), strides=(2,2,2,2,1)).to(device)
#gen = UNet(spatial_dims=2,in_channels=1,out_channels=1,channels=(32,64,128,256,256),strides=(2,2,2,2),num_res_units=2).to(device)
#gen =  UNETR(in_channels=1,out_channels=1,img_size=(512,512),feature_size=16,hidden_size=768,mlp_dim=3072,num_heads=12,norm_name="instance",res_block=True,dropout_rate=0.0,spatial_dims=2).to(device)
gen = UnetGenerator().to(device)
dis = ConditionalDiscriminator().to(device)

# BCEWithLogitsLoss expects the discriminator to return unnormalised logits.
loss_fn = nn.BCEWithLogitsLoss()

# Both networks use the same Adam settings.
adam_kwargs = {'lr': 0.0002, 'betas': (0.5, 0.999)}
dis_optimizer = optim.Adam(dis.parameters(), **adam_kwargs)
gen_optimizer = optim.Adam(gen.parameters(), **adam_kwargs)

In [None]:
# Sanity-check the generator on one dummy 1x1x512x512 input, then export it to
# ONNX ("netForwatch.onnx") so the graph can be opened in a model viewer such as
# Netron. The dummy input and forward pass were previously done twice.
import netron  # noqa: F401 -- model viewer; not called in this cell
import torch.onnx

dummy_input = torch.rand(1, 1, 512, 512).to(device)  # avoids shadowing builtin `input`
with torch.no_grad():  # shape check only; no autograd graph needed
    output = gen(dummy_input)
onnx_path = "netForwatch.onnx"
torch.onnx.export(gen, dummy_input, onnx_path, export_params=True, opset_version=11)

In [None]:
def set_window(image_array, a_min=-1000.0, a_max=3700.0, b_min=-1.0, b_max=1.0):
    """Map normalised generator output back to the original intensity range.

    Inverts ``ScaleIntensityRange(a_min, a_max, b_min, b_max)``. The defaults match
    the notebook's ``transform`` (-1000..3700 -> -1..1), so ``b_min`` maps to
    ``a_min`` and ``b_max`` to ``a_max``; with the defaults this is
    ``-1000 + 2350 * (x + 1)``.

    The previous version used ``x - x.min()`` instead of ``x - b_min``. That tied
    the result to each image's own minimum: any image whose minimum was above -1
    was shifted down by ``(min + 1) * 2350``.

    Args:
        image_array: array-like of normalised values (any shape).
        a_min, a_max: original intensity range.
        b_min, b_max: normalised range it was mapped to.

    Returns:
        float32 numpy array with the same shape as ``image_array``.
    """
    image_array = np.asarray(image_array, dtype=np.float32)
    scale = (a_max - a_min) / (b_max - b_min)
    return a_min + scale * (image_array - b_min)

In [None]:
def generater_images(model, test_input, true_traget, index=12):
    """Plot input, target and model prediction for one sample of a batch.

    Args:
        model: generator module; called on ``test_input`` and expected to return
            a tensor with the same 4-D layout (assumed (N, C, H, W) -- TODO confirm).
        test_input: 4-D input batch on the model's device.
        true_traget: 4-D target batch aligned with ``test_input``.
        index: sample to display. It is clamped to the last sample, so batches with
            fewer than ``index + 1`` items no longer raise IndexError. The default
            of 12 keeps the previous choice.

    The model runs without gradients and in eval mode, and its previous train/eval
    mode is restored afterwards. This keeps the displayed batch from updating
    normalisation running statistics (e.g. BatchNorm, if the model has any).
    """
    def to_channels_last(batch):
        # (N, C, H, W) tensor -> (N, H, W, C) numpy array
        return batch.permute(0, 2, 3, 1).detach().cpu().numpy()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prediction = to_channels_last(model(test_input))
    finally:
        model.train(was_training)
    # Only the prediction is rescaled; imshow autoscales each panel anyway.
    prediction = (prediction + 1) / 2
    inputs = to_channels_last(test_input)
    targets = to_channels_last(true_traget)

    sample = min(index, len(prediction) - 1)
    display_list = [inputs[sample], targets[sample], prediction[sample]]
    title = ['Input Image', 'True Mask', 'Predicted Mask']
    plt.figure(figsize=(15, 15))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        # Drop a trailing single channel: grayscale imshow expects (H, W).
        plt.imshow(np.squeeze(display_list[i]), cmap='gray')
        plt.axis('off')
    plt.show()

In [None]:
# Move the preview batch (the last batch loaded above, from the test loader) to
# the device; the training loop passes it to generater_images for snapshots.
# LAMBDA = 7  (former constant L1 weight; the training loop now sets LAMBDA)
imgs_batch, annos_batch = imgs_batch.to(device), annos_batch.to(device)

In [None]:
class test_dataset(torch.utils.data.Dataset):
    """Input-only dataset for running the generator on test images.

    Args:
        annos_path: sequence of image file paths.

    Each item is the image read with SimpleITK, cast to float32, given a new
    leading axis and passed through the module-level ``transform``.
    """

    def __init__(self, annos_path):
        self.annos_path = annos_path

    def __len__(self):
        return len(self.annos_path)

    def __getitem__(self, index):
        volume = sitk.GetArrayFromImage(sitk.ReadImage(self.annos_path[index]))
        volume = volume.astype(np.float32)[np.newaxis, ...]
        return transform(volume)

In [None]:
# Test inputs only, shuffled: the training loop draws one random test image
# from this loader at each snapshot.
CI_test = test_dataset(annos_path=test_annos_path)
test_dataloader = torch.utils.data.DataLoader(dataset=CI_test, batch_size=1, shuffle=True)

In [None]:
# Adversarial training (pix2pix-style). D learns to tell (input, target) pairs
# from (input, G(input)) pairs; G minimises its adversarial loss plus
# LAMBDA * L1(target, G(input)), with LAMBDA ramped up over the epochs.
NUM_EPOCHS = 250
os.makedirs('saved_models', exist_ok=True)  # torch.save does not create directories
os.makedirs('./data', exist_ok=True)        # nor does sitk.WriteImage


def lambda_for_epoch(epoch):
    """Return the L1 weight for `epoch`: 0.5, 1, 5, 10 at the 10/20/30/60 boundaries, then 20."""
    if epoch < 10:
        return 0.5
    if epoch < 20:
        return 1
    if epoch < 30:
        return 5
    if epoch < 60:
        return 10
    return 20


D_loss = []
G_loss = []
count = len(dataloader)  # batches per epoch

for epoch in range(NUM_EPOCHS):
    D_epoch_loss = 0
    G_epoch_loss = 0
    # Depends only on the epoch, so set once per epoch. The final value stays in
    # the namespace for later cells.
    LAMBDA = lambda_for_epoch(epoch)
    for annos, imgs in dataloader:
        imgs = imgs.to(device)
        annos = annos.to(device)
        # Clamp discriminator parameters to [-0.01, 0.01].
        # NOTE(review): WGAN-style weight clipping combined with a BCE GAN loss;
        # kept unchanged -- confirm it is intended.
        for p in dis.parameters():
            p.data.clamp_(-0.01, 0.01)

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        dis_optimizer.zero_grad()
        dis_real_output = dis(annos, imgs)
        dis_real_loss = loss_fn(dis_real_output, torch.ones_like(dis_real_output, device=device))
        dis_real_loss.backward()
        gen_output = gen(annos)
        # detach(): the discriminator update must not backpropagate into G.
        dis_fake_output = dis(annos, gen_output.detach())
        dis_fake_loss = loss_fn(dis_fake_output, torch.zeros_like(dis_fake_output, device=device))
        dis_fake_loss.backward()
        dis_loss = dis_real_loss + dis_fake_loss
        dis_optimizer.step()

        # Generator step: make D output 1 on generated pairs + weighted L1 to target.
        gen_optimizer.zero_grad()
        dis_gen_output = dis(annos, gen_output)
        gen_loss_cross_entropy = loss_fn(dis_gen_output, torch.ones_like(dis_gen_output, device=device))
        gen_loss_L1 = torch.mean(torch.abs(imgs - gen_output))
        gen_loss = gen_loss_cross_entropy + (LAMBDA * gen_loss_L1)
        gen_loss.backward()
        gen_optimizer.step()

        D_epoch_loss += dis_loss.item()
        G_epoch_loss += gen_loss.item()

    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    if epoch % 2 == 0:
        with torch.no_grad():
            print('Epoch [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}'.format(epoch, NUM_EPOCHS, D_epoch_loss, G_epoch_loss))
            # Snapshot in eval mode, the mode the inference cells use, so test
            # images do not update normalisation running statistics.
            gen.eval()
            try:
                generater_images(gen, annos_batch, imgs_batch)
                plt.plot(D_loss, label='D_loss')
                plt.plot(G_loss, label='G_loss')
                plt.legend()
                plt.show()
                torch.save(gen.state_dict(), "saved_models/gen_%d.pth" % epoch)
                print('Saved model')
                temp = gen(next(iter(test_dataloader)).to(device)).permute(0, 2, 3, 1).cpu().numpy()
            finally:
                gen.train()
            sitk.WriteImage(sitk.GetImageFromArray(set_window(temp)), './data/' + str(epoch) + '.nii.gz')

In [None]:
# Continue adversarial training with a constant L1 weight. LAMBDA is not assigned
# in this cell; it keeps the value already in the namespace (20 once the scheduled
# training cell has finished).
# NOTE(review): this run restarts D_loss/G_loss and overwrites
# saved_models/gen_<epoch>.pth and ./data/<epoch>.nii.gz from the earlier run.
if 'LAMBDA' not in globals():
    raise NameError('LAMBDA is not defined: run the scheduled training cell first or assign LAMBDA')

NUM_EPOCHS = 250
os.makedirs('saved_models', exist_ok=True)
os.makedirs('./data', exist_ok=True)

D_loss = []
G_loss = []
count = len(dataloader)  # batches per epoch

for epoch in range(NUM_EPOCHS):
    D_epoch_loss = 0
    G_epoch_loss = 0
    for annos, imgs in dataloader:
        imgs = imgs.to(device)
        annos = annos.to(device)
        # Clamp discriminator parameters to [-0.01, 0.01].
        # NOTE(review): weight clipping with a BCE loss -- confirm intended.
        for p in dis.parameters():
            p.data.clamp_(-0.01, 0.01)

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        dis_optimizer.zero_grad()
        dis_real_output = dis(annos, imgs)
        dis_real_loss = loss_fn(dis_real_output, torch.ones_like(dis_real_output, device=device))
        dis_real_loss.backward()
        gen_output = gen(annos)
        # detach(): the discriminator update must not backpropagate into G.
        dis_fake_output = dis(annos, gen_output.detach())
        dis_fake_loss = loss_fn(dis_fake_output, torch.zeros_like(dis_fake_output, device=device))
        dis_fake_loss.backward()
        dis_loss = dis_real_loss + dis_fake_loss
        dis_optimizer.step()

        # Generator step: make D output 1 on generated pairs + weighted L1 to target.
        gen_optimizer.zero_grad()
        dis_gen_output = dis(annos, gen_output)
        gen_loss_cross_entropy = loss_fn(dis_gen_output, torch.ones_like(dis_gen_output, device=device))
        gen_loss_L1 = torch.mean(torch.abs(imgs - gen_output))
        gen_loss = gen_loss_cross_entropy + (LAMBDA * gen_loss_L1)
        gen_loss.backward()
        gen_optimizer.step()

        D_epoch_loss += dis_loss.item()
        G_epoch_loss += gen_loss.item()

    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    if epoch % 2 == 0:
        with torch.no_grad():
            print('Epoch [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}'.format(epoch, NUM_EPOCHS, D_epoch_loss, G_epoch_loss))
            # Snapshot in eval mode so test images do not update normalisation
            # running statistics; training mode is restored afterwards.
            gen.eval()
            try:
                generater_images(gen, annos_batch, imgs_batch)
                plt.plot(D_loss, label='D_loss')
                plt.plot(G_loss, label='G_loss')
                plt.legend()
                plt.show()
                torch.save(gen.state_dict(), "saved_models/gen_%d.pth" % epoch)
                print('Saved model')
                temp = gen(next(iter(test_dataloader)).to(device)).permute(0, 2, 3, 1).cpu().numpy()
            finally:
                gen.train()
            sitk.WriteImage(sitk.GetImageFromArray(set_window(temp)), './data/' + str(epoch) + '.nii.gz')

In [None]:
# NOTE(review): identical redefinition of `test_dataset` from an earlier cell;
# a single definition would suffice.
class test_dataset(torch.utils.data.Dataset):
    """Dataset over input files only (no targets), used for generator inference.

    ``annos_path`` is a sequence of image paths. Item ``i`` is ``transform``
    applied to file ``i`` as a float32 array with a new leading axis.
    """

    def __init__(self, annos_path):
        self.annos_path = annos_path

    def __getitem__(self, index):
        pixels = sitk.GetArrayFromImage(sitk.ReadImage(self.annos_path[index]))
        return transform(np.expand_dims(pixels.astype(np.float32), 0))

    def __len__(self):
        return len(self.annos_path)

In [None]:
# Test inputs in file order (shuffle=False, batch_size=1), so batch i
# corresponds to test_annos_path[i] when the outputs are named.
# NOTE(review): despite the name, VMI_test holds the CI input files.
VMI_test = test_dataset(annos_path=test_annos_path)
test_dataloader = torch.utils.data.DataLoader(VMI_test, shuffle=False, batch_size=1)

In [None]:
# Rebuilds the same ordered, batch-size-1 loader as the previous cell.
test_dataloader = torch.utils.data.DataLoader(
    dataset=VMI_test,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Root output directory for generated test images; the inference loop joins it
# with each output file name.
gen_dataset = 'D:/DeepLearning/image2image/gen_data'

In [None]:
# Use FID between generated and original images to decide which saved
# checkpoint produces the best images.
# To restore a checkpoint's weights first:
#gen.load_state_dict(torch.load("saved_models/gen_74.pth"))
gen.train(False)  # equivalent to gen.eval()

In [None]:
# Run the generator on every test input and save the result next to gen_dataset.
# The loader uses batch_size=1 and shuffle=False, so batch i corresponds to
# annos_path[i]. Each output is mapped back with set_window and given the source
# image's spacing/direction/origin. Values outside [-1000, 3700] are set to -1001,
# then a radius-1 median filter is applied before writing.
os.makedirs(gen_dataset, exist_ok=True)
with torch.no_grad():  # inference only: don't build autograd graphs
    for i, image in enumerate(test_dataloader):
        src_path = test_dataloader.dataset.annos_path[i]
        src = sitk.ReadImage(src_path)  # read once for all geometry fields
        img_array = np.squeeze(gen(image.to(device)).cpu().numpy())
        # basename is portable; splitting on '\\' only worked with Windows paths.
        file_name = os.path.basename(src_path).split('.')[0]
        temp = sitk.GetImageFromArray(set_window(img_array))
        temp.SetSpacing(src.GetSpacing())
        temp.SetDirection(src.GetDirection())
        temp.SetOrigin(src.GetOrigin())
        sitk_seg = sitk.Threshold(temp, lower=-1000, upper=3700, outsideValue=-1001)
        median_filter = sitk.MedianImageFilter()
        median_filter.SetRadius(1)
        sitk_median = median_filter.Execute(sitk_seg)
        sitk.WriteImage(sitk_median, os.path.join(gen_dataset, 'gen_ID_%s.nii.gz' % file_name))

In [None]:
# Export test-set predictions for checkpoints from epochs 150, 152, ..., 248 into
# D:/DeepLearning/image2image/gen_data/model_<epoch>/ for comparison.
# The un-thresholded, unfiltered image is written, as before. Set APPLY_MEDIAN
# to True to write the thresholded + radius-5 median-filtered image instead
# (previously that result was computed and then discarded).
APPLY_MEDIAN = False
gen.eval()
for e in np.arange(150, 250, 2):
    # map_location lets checkpoints saved on GPU load on the current `device`.
    gen.load_state_dict(torch.load("saved_models/gen_{}.pth".format(e), map_location=device))
    out_dir = 'D:/DeepLearning/image2image/gen_data/model_{}'.format(e)
    os.makedirs(out_dir, exist_ok=True)  # os.mkdir failed when the cell was re-run
    with torch.no_grad():
        for i, image in enumerate(test_dataloader):
            src_path = test_dataloader.dataset.annos_path[i]
            src = sitk.ReadImage(src_path)  # read once for all geometry fields
            img_array = np.squeeze(gen(image.to(device)).cpu().numpy())
            file_name = os.path.basename(src_path).split('.')[0]
            temp = sitk.GetImageFromArray(set_window(img_array))
            temp.SetSpacing(src.GetSpacing())
            temp.SetDirection(src.GetDirection())
            temp.SetOrigin(src.GetOrigin())
            result = temp
            if APPLY_MEDIAN:
                sitk_seg = sitk.Threshold(temp, lower=-1000, upper=3700, outsideValue=-1001)
                median_filter = sitk.MedianImageFilter()
                median_filter.SetRadius(5)
                result = median_filter.Execute(sitk_seg)
            sitk.WriteImage(result, os.path.join(out_dir, 'gen_ID_%s.nii.gz' % file_name))

In [None]:
# Median-filter one generated image offline. Values outside [-1000, 3700] are
# replaced by -1001, then a radius-1 median filter is applied and the result saved.
path = 'D:/VSCODE/GAN_pytorch/gen_data/gen_ID_BaiYaoYao024.nii.gz'
image = sitk.ReadImage(path)
sitk_seg = sitk.Threshold(image, lower=-1000, upper=3700, outsideValue=-1001)

median_filter = sitk.MedianImageFilter()
median_filter.SetRadius(1)
sitk_median = median_filter.Execute(sitk_seg)

sitk.WriteImage(sitk_median, 'D:/VSCODE/GAN_pytorch/median_filtering/gen_ID_BaiYaoYao024.nii.gz')