In [1]:
import torch; torch.manual_seed(0)
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader,random_split
from function.Dir import Dir
from function.ssim_torch import ssim_function
from function.jsdiv import jsdiv,jsdiv_single
from function.Dataset import ImageDataset
from function.Loss import Custom_criterion
from function.Log import log
import torch.nn.functional as F
import importlib

EXP_NAME = "EXP_0_1"
MODEL_NAME = 'CNN'
traintype = 'src+pointsrc'

# The model class is attribute MODEL_NAME of module function.<MODEL_NAME>_<EXP_NAME>.
VAE_module = importlib.import_module(f'function.{MODEL_NAME}_{EXP_NAME}')
VAE = getattr(VAE_module, MODEL_NAME)

filepath = "/root/autodl-fs/SimulatedData/norandPSF_6.4x6.4.npy"
#filepath = "/root/autodl-fs/SimulatedData/GaussianSRC_tx.npy"
# NOTE(review): savefigname says "375epo" while EPOCHS below is 500 (and EPOCHS is
# used to pick the checkpoint file) -- confirm which one is intended.
savefigname = 'CNN_375epo_norandpsf'

# Only the number of samples is needed here; the array is loaded and then freed.
# NOTE(review): allow_pickle=True can execute code embedded in the file -- only
# use it with trusted data files.
filetmp = np.load(filepath, allow_pickle=True)
filelen = filetmp.shape[0]
del filetmp
frac_train = 0.8
NUM_TO_LEARN = int(filelen * frac_train)  # number of image pairs put into the training set
# The test count is the remainder, so train + test == filelen. The previous
# int(filelen * (1 - frac_train)) could drop one sample to floating-point error
# (1 - 0.8 == 0.19999999999999996).
NUM_TO_TEST = filelen - NUM_TO_LEARN

EPOCHS = 500  # hyperparameter 1
BATCH_SIZE = 32  # hyperparameter 2
LATENTDIM = 64  # hyperparameter 3
LR_MAX = 5e-4
LR_MIN = 5e-6
# 0: train on the STED_HC files; 1: use STED. Passed as the `mode` argument of ImageDataset.
# (Models trained on STED generalise poorly; STED_HC reconstructs the training set with some distortion.)
mode = 1

# Fall back to the CPU when no GPU is available instead of failing at .to('cuda').
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LOSS_PLOT = []
TESTLOSS_PLOT = []
EPOCH_PLOT = []

In [None]:
name = f'{EPOCHS}epo_{BATCH_SIZE}bth_{LATENTDIM}latn_{traintype}'

# Load the dataset.
dataset = ImageDataset(NUM_TO_LEARN, mode, filepath)
# Fixed-seed split: the first part gets int(frac_train * N) samples and the test
# part the remainder, so the lengths always sum to len(dataset).
# NOTE(review): ImageDataset is built with NUM_TO_LEARN and then split again here --
# confirm this reproduces the split used during training, otherwise `testset`
# may contain images the model was trained on.
n_train = int(frac_train * len(dataset))
trainset, testset = random_split(dataset,
                                 lengths=[n_train, len(dataset) - n_train],
                                 generator=torch.Generator().manual_seed(0))

dataloader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, shuffle=False, batch_size=BATCH_SIZE)

# Build the network and load the trained weights. map_location lets a checkpoint
# saved on a different device (e.g. another GPU) load onto DEVICE.
vae = VAE(LATENTDIM).to(DEVICE)
vae.load_state_dict(torch.load(f'../models/{MODEL_NAME}_{EXP_NAME}_model_{name}.pth',
                               map_location=DEVICE, weights_only=True))
#vae.load_state_dict(torch.load(f'../models/{EXP_NAME}_model_{name}.pth',weights_only=True))

#vae.reset_parameters()
#vae = nn.DataParallel(vae)  # wrap the model to train in parallel on several GPUs

# Loss function (an optimizer is only needed for training).
lossfunction = jsdiv
#optimizer = torch.optim.AdamW(vae.parameters(), lr = LR_MAX)

In [None]:
# Inspect the test data: shapes of one batch and the value range of one blurry sample.
# Only one batch is needed, so take the first one instead of iterating the whole loader.
blurry_img, original_img = next(iter(testloader))
print(f"Blurry image shape: {blurry_img.shape}")
print(f"Original image shape: {original_img.shape}")
# Second sample of the batch, falling back to the first if the batch has a single sample.
sample_idx = min(1, blurry_img.shape[0] - 1)
blurry_img_numpy = blurry_img[sample_idx].squeeze().detach().numpy()
print(blurry_img_numpy.min())
print(blurry_img_numpy.max())
print(blurry_img_numpy)

In [None]:
import sys
# Make the helper modules importable by bare name (e.g. `ssim_torch`); the guard
# keeps re-runs of this cell from appending the same entry again and again.
if './function/' not in sys.path:
    sys.path.append('./function/')
from ssim_torch import ssim_function_valid  # alternative metric for valid_lossf

# Evaluate on the CPU. eval() switches off training-only behaviour (dropout,
# batch-norm statistic updates); the model has not been put in eval mode before
# this cell, so without it the validation numbers would depend on that behaviour.
vae.cpu()
vae.eval()
valid_lossf = jsdiv_single  # or ssim_function_valid

# Collect per-batch losses in lists and concatenate once at the end: avoids the
# quadratic re-concatenation and np.concat, which only exists in NumPy >= 2.0.
loss_sr_batches = []
loss_blu_batches = []
with torch.no_grad():  # inference only: no autograd graph needed
    for blurry_img, original_img in testloader:
        img_sr, _, _ = vae(blurry_img)
        loss_sr = valid_lossf(img_sr, original_img).cpu().numpy()
        loss_blurry = valid_lossf(blurry_img, original_img).cpu().numpy()
        loss_sr_batches.append(loss_sr.flatten())
        loss_blu_batches.append(loss_blurry.flatten())
# The leading empty float64 array keeps the previous dtype and handles an empty loader.
LOSS_SR = np.concatenate([np.empty(0), *loss_sr_batches])
LOSS_BLU = np.concatenate([np.empty(0), *loss_blu_batches])

def hist(arr, color, nbins=50, histtype='step', label='label'):
    """Draw a histogram of `arr` on the current pyplot axes and refresh the legend.

    Parameters
    ----------
    arr : array-like
        Values to histogram (flattened).
    color : str
        Matplotlib colour of the histogram.
    nbins : int
        Number of linearly spaced bin *edges* spanning [arr.min(), arr.max()]
        (i.e. ``nbins - 1`` bins).
    histtype : str
        Passed through to ``plt.hist`` (default ``'step'``).
    label : str
        Legend label.

    Returns
    -------
    tuple or None
        ``(counts, bin_edges, patches)`` from ``plt.hist``; ``None`` when `arr`
        is empty, in which case nothing is drawn.
    """
    arr = np.asarray(arr).ravel()
    if arr.size == 0:
        # min()/max() are undefined for an empty array; draw nothing instead of crashing.
        return None
    # Linear bins. For log-scaled data use
    # np.logspace(np.log10(arr.min()), np.log10(arr.max()), nbins) with plt.xscale('log').
    bins = np.linspace(arr.min(), arr.max(), nbins)
    result = plt.hist(arr, bins=bins, density=False, histtype=histtype, color=color, label=label)
    plt.legend()
    return result

# Histogram of the per-image validation loss: blurry input vs. super-resolved
# output, each measured against the ground-truth image.
plt.figure()
hist(arr=LOSS_BLU, color='red', label='blur')
hist(arr=LOSS_SR, color='blue', label='SR')
plt.xlim(left=0, right=0.002)  # zoom on the low-loss region
plt.savefig('/root/autodl-fs/Plots/%s_jsdiv.png' % savefigname, dpi=300)




In [None]:
import torch
import matplotlib.pyplot as plt
import scipy
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
num_images_to_show = 10  # default only; the plotting code below resets it to len(showlist)

def interp2d(x1, x2, y1, y2, arr, num=101):
    """Sample a 2-D image along a straight line segment.

    The segment runs from (column=x1, row=y1) to (column=x2, row=y2), the same
    convention as ``ax.plot([x1, x2], [y1, y2])`` drawn over ``ax.imshow(arr)``.

    Parameters
    ----------
    x1, x2 : float
        Start / end column coordinates.
    y1, y2 : float
        Start / end row coordinates.
    arr : 2-D array
        Image indexed as ``arr[row, col]``; need not be square.
    num : int, optional
        Number of equally spaced samples along the segment (default 101).

    Returns
    -------
    numpy.ndarray, shape (num,)
        Linearly interpolated values of `arr` along the segment.

    Raises
    ------
    ValueError
        If the segment leaves the image (SciPy's default ``bounds_error=True``).
    """
    # Import the submodule explicitly: a bare `import scipy` does not expose
    # `scipy.interpolate` on older SciPy versions.
    from scipy.interpolate import RegularGridInterpolator

    arr = np.asarray(arr)
    rows = np.arange(arr.shape[0])
    cols = np.arange(arr.shape[1])  # previously arr.shape[0]: wrong for non-square images
    interpolator = RegularGridInterpolator((rows, cols), arr)
    row_t = np.linspace(y1, y2, num)
    col_t = np.linspace(x1, x2, num)
    return interpolator((row_t, col_t))

vae.eval()  # switch to evaluation mode
vae.to(DEVICE)

# Indices of the samples to visualise.
# NOTE(review): these are taken from `trainset`; use `testset` to show held-out images.
showlist = [3, 7]
num_images_to_show = len(showlist)

img_LR = []
img_HR = []
img_SR = []
with torch.no_grad():  # inference only: no autograd graph needed
    for i in showlist:
        item = trainset[i]
        img_LR.append(item[0])
        img_HR.append(item[1])
        # Add batch and channel axes, (H, W) or (1, H, W) -> (1, 1, H, W).
        # Assumes the last two axes of item[0] are the image axes -- TODO confirm.
        lr_batch = item[0].reshape(1, 1, *item[0].shape[-2:]).to(DEVICE)
        img_SR.append(vae(lr_batch)[0].cpu())

# Profile segments [x1, x2, y1, y2] (x = column, y = row), one per entry of showlist;
# samples without an entry get a degenerate [0, 0, 0, 0] segment.
# txdata: [[35, 45, 55, 35], [25, 45, 45, 60]]
profile_lines = [[0, 20, 0, 15], [20, 55, 45, 25]]  # randpsfdata
xys = [profile_lines[k] if k < len(profile_lines) else [0, 0, 0, 0]
       for k in range(num_images_to_show)]


def _to_unit_sum(tensor):
    """Return `tensor` as a squeezed numpy array rescaled to sum to 1.

    An all-zero image is returned unscaled instead of becoming NaN.
    """
    arr = tensor.squeeze().detach().cpu().numpy()
    total = arr.sum()
    return arr / total if total != 0 else arr


def _show_with_profile(ax, img, title, line, color='white'):
    """Show `img` on `ax`, draw the dashed segment `line`, and inset its intensity profile.

    `line` is [x1, x2, y1, y2]; the inset (lower-right corner, 20% size) plots
    `interp2d` sampled along that segment.
    """
    x1, x2, y1, y2 = line
    ax.imshow(img)
    ax.set_title(title)
    ax.plot([x1, x2], [y1, y2], linestyle='--', color=color, linewidth=2)
    axins = inset_axes(ax, width="20%", height="20%", loc=4)
    axins.axis('off')
    axins.patch.set_alpha(0)
    axins.plot(interp2d(x1, x2, y1, y2, img), color=color)


# One row per sample, five columns: blurry, SR, original, blurry residual, SR residual.
# squeeze=False keeps `axes` 2-D even when only one sample is shown.
fig, axes = plt.subplots(num_images_to_show, 5, figsize=(15, 3 * num_images_to_show), squeeze=False)

for i in range(num_images_to_show):
    line = xys[i]
    # Normalise each image to unit sum so intensities are comparable across panels.
    blurry_img_numpy = _to_unit_sum(img_LR[i])
    sr_img_numpy = _to_unit_sum(img_SR[i])
    original_img_numpy = _to_unit_sum(img_HR[i])

    _show_with_profile(axes[i, 0], blurry_img_numpy, 'Blurry Image', line)
    _show_with_profile(axes[i, 1], sr_img_numpy, 'SR Image', line)
    _show_with_profile(axes[i, 2], original_img_numpy, 'Original Image', line)

    # Residuals against the ground truth, on a shared colour scale so the two
    # residual panels are directly comparable.
    res_blur = blurry_img_numpy - original_img_numpy
    res_sr = sr_img_numpy - original_img_numpy
    vmin = min(res_blur.min(), res_sr.min())
    vmax = max(res_blur.max(), res_sr.max())
    for col, res, title in ((3, res_blur, 'Res Blur'), (4, res_sr, 'Res SR')):
        im = axes[i, col].imshow(res, vmin=vmin, vmax=vmax)
        axes[i, col].set_title(title)
        fig.colorbar(im, ax=axes[i, col], shrink=0.5)

# Adjust the spacing between subplots, save, then display.
fig.tight_layout()
fig.savefig('/root/autodl-fs/Plots/%s_figshow.png' % savefigname, dpi=300)
plt.show()
