In [2]:
import os
import scipy.io as scio
import numpy as np
import matplotlib.pyplot as plt
import torch
from architecture import MST
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
def thre(a, tau):
    x = a
    x[ torch.abs(a) <= tau] = 0
    return x

################################ 全局设置 ################################
device = torch.device('cuda')
dtype = torch.cuda.FloatTensor
eigen_num, lr = 4, 0.0001
################################ 重新读取数据 ################################
Data = scio.loadmat("./SimuData.mat")
Clean, Mask, Cloud, Noisy = Data["Clean"], Data["Mask"], Data["Cloud"], Data["Noisy"]
M, N, = 384, 384
B, T = 4, 4
# 相邻四张图像为同一个时间，因而初始的维度为M, N, T, B
Clean = torch.from_numpy(Clean[:M, :N, :].reshape(M, N, T, B)).to(device)
Mask = torch.from_numpy(Mask[:M, :N, :].reshape(M, N, T, B)).to(device)      # 云=0
Cloud = torch.from_numpy(Cloud[:M, :N, :].reshape(M, N, T, B)).to(device)    # 云=1
Noisy = torch.from_numpy(Noisy[:M, :N, :].reshape(M, N, T, B)).to(device)

# 进行维度调整：M, N, B, T
Clean = torch.permute(Clean, [0, 1, 3, 2])
Mask = torch.permute(Mask, [0, 1, 3, 2])
Cloud = torch.permute(Cloud, [0, 1, 3, 2])
Noisy = torch.permute(Noisy, [0, 1, 3, 2])

# 将第二张观测图像设置为干净图像
Cloud[:, :, :, 1] = 0
Mask[:, :, :, 1] = 1
Noisy[:, :, :, 1] = Clean[:, :, :, 1]

YY = Noisy.clone()
Mask = Cloud.clone()
h, w = M, N

################################ 初始化 ################################
alpha, beta1, rho1, rho2 = 0.5, 0.5, 1, 0.1
M, N, B, T = YY.shape
SS = torch.zeros(YY.shape).type(dtype).to(device)
k_subspacce = 4
MM = torch.zeros((M, N, k_subspacce, T)).type(dtype).to(device)
Gam1 = torch.zeros((M, N, k_subspacce, T)).type(dtype).to(device)
Gam2 = torch.zeros((M, N, B, T)).type(dtype).to(device)
XX = torch.clone(YY).type(dtype).to(device)
Rec = torch.clone(YY).type(dtype).to(device)

beta2 = beta1
print( 'Alpha = %2.4f, beta1 = %2.3f, Rho1 = %2.4f, Rho2 = %2.4f\n'%(alpha,beta1, rho1,rho2))

index1 = (Mask != 0)
index2 = (Mask == 0)

A_sum = torch.zeros((B, k_subspacce, T)).type(dtype).to(device)   
Z_sum = torch.zeros((M*N, k_subspacce, T)).type(dtype).to(device)
T2 = torch.zeros((M, N, B, T)).type(dtype).to(device)

model = MST(dim=16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))
iters = 0
order = 0
log_dir = "logs-order[%d]" % order
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
############################# 迭代更新 #############################
while iters < 500:
    print("The %dth iter is start..." % iters)
    # 更新子问题F_k
    for i in range(T):
        Y = torch.reshape( XX[:, :, :, i] + Gam2[:, :, :, i]/rho2, (M*N, B) ).T
        Y = Y.type(dtype).to(device)
        if iters == 0:
            A, _, _ = torch.linalg.svd(Y, full_matrices=False)
            A = A[:, :k_subspacce]
            A_sum[:, :, i] = torch.clone(A)
        else:
            result = torch.mm(Z_sum[:, :, i].T, Y.T)
            A1, _, A2H = torch.linalg.svd(result, full_matrices=False)
            A2 = A2H.T
            A = torch.mm(A2, A1.T)
            A = A[:, :k_subspacce]
            A_sum[:, :, i] = torch.clone(A)
    
###################################################################
    ## 2 更新子问题——A_k， 使用gdd进行更新
    # net = gdd(input_depth, num_output_channels,
    #     num_channels_down = num_channels,
    #     num_channels_up = num_channels,
    #     num_channels_skip = num_channels,  
    #     filter_size_up = 3, filter_size_down = 3, filter_skip_size=1,
    #     upsample_mode='bilinear', # downsample_mode='avg',
    #     need1x1_up=False,
    #     need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype).to(device)
    # OPT_OVER =  'net'
    # p = get_params(OPT_OVER, net, noise)
    # optimizer = torch.optim.Adam(p, lr=lr)

    loss = 0
    torch.set_grad_enabled(True)
    epoch = 20
    loss_history = []
    for i in range(epoch):
        optimizer.zero_grad()
        # Z_sum = net(G, noise).reshape((M*N, k_subspacce, T))
        Y_temp = torch.permute(YY, [2, 3, 0, 1]).reshape(1, B * T, M, N).type(dtype).to(device)
        Mask_temp = torch.permute(Mask, [2, 3, 1, 0]).reshape(1, B*T, M, N).type(dtype).to(device)
        out = model(Y_temp, Mask_temp)
        Z_sum = torch.permute(out.reshape(B, T, M, N), [2, 3, 0, 1]).reshape(M*N, B, T).type(dtype).to(device)
        ## 计算损失值
        loss = 0
        for t in range(T):
            A = A_sum[:, :, t]
            Y = torch.reshape( XX[:, :, :, t] + Gam2[:, :, :, t]/rho2, (M*N, B) ).T
            Z = Z_sum[:, :, t].T

            part1 = rho2/2 * torch.norm( torch.mm(A.T, Y) - Z )**2
            part2 = rho1/2 * torch.norm( torch.reshape(MM[:, :, :, t] + Gam1[:, :, :,t] / rho1, (M*N, k_subspacce)).T - Z ) ** 2
            loss_i = part1 + part2
            loss += loss_i
        if i % 40 == 0:
            print("The %dth epoch, %dth iter, loss: %.5f" % (iters, i, loss.item()))
        loss.backward()
        loss_history.append(loss.item())
        YY = YY.detach()
        Rec = Rec.detach()
        SS = SS.detach()
        optimizer.step()
    
    for t in range(T):
        A = A_sum[:, :, t]
        Z = Z_sum[:, :, t].T
        Rec[:, :, :, t] = torch.reshape(torch.mm(A, Z).T, (M, N, B))
###############################################################
    
    YY = YY.detach()
    Rec = Rec.detach()
    SS = SS.detach()
    A_sum = A_sum.detach()
    Z_sum = Z_sum.detach()

    ## 3 更新子问题——W_k
    sv, p, r = 10, B, 36
    temp = torch.reshape( torch.reshape(Z_sum, (M, N, k_subspacce, T))  
            - Gam1/rho1, (M*N, k_subspacce*T))
    U, sigma, VH = torch.linalg.svd(temp, full_matrices=False)
    V = VH.T
    svp = min(sum(sigma > alpha/rho1), r)

    if svp < sv:
        sv = min(p, svp+1)
    else:
        sv = min(p, svp+torch.round(torch.tensor(0.05*p).type(dtype)))
    L_temp = torch.mm(U[:, :svp], torch.diag(sigma[:svp] - alpha/rho1))
    L = torch.mm(L_temp, V[:, :svp].T) 
    MM = L.reshape((M, N, k_subspacce, T)) 

    ## 4 更新子问题 C_K+1
    SS = thre(YY - XX, beta2)

    ## 5 更新子问题 X_k+1 
    XX_last = torch.clone(XX)
    XX = ( torch.mul((YY - SS), (1 - Mask)) + rho2*Rec - Gam2 ) / ((1 - Mask) + rho2*torch.ones(YY.shape).to(device))
    ## 6 更新子问题-拉格朗日乘子Pi和Qi
    Gam1 = Gam1 + rho1*(MM - Z_sum.reshape((M, N, k_subspacce, T)))
    Gam2 = Gam2 + rho2*(XX - Rec)

    iters += 1

    ####################### 评估指标，记录日志 #######################
    mse = torch.sum((XX - Clean) ** 2)
    Recover = torch.mul(YY, Mask) + torch.mul(XX, 1 - Mask)
    psnr_rec = compare_psnr(Recover.detach().cpu().numpy(), Clean.detach().cpu().numpy(), data_range=2)
    psnr = compare_psnr(XX.detach().cpu().numpy(), Clean.detach().cpu().numpy(), data_range=2)
    ssim = 0
    for k in range(B):
        for l in range(T):
            ssim += compare_ssim(XX[:, :, k, l].detach().cpu().numpy(), Clean[:, :, k, l].detach().cpu().numpy())
    ssim = ssim / (B * T)
    print("The %03dth iters, the %03dth epoch, loss: %.5f, mse: %.5f, psnr: %.5f, ssim: %.5f" % (
        iters, epoch, loss.item(), mse, psnr, ssim))

    image_Clean = Clean.cpu().numpy()
    image_Y = YY.cpu().numpy()
    image_X = XX.cpu().detach().numpy()
    image_C = SS.cpu().numpy()
    plt.figure(figsize=(20, 20))
    for i in range(4):
        plt.subplot(5, 4, 1 + i)
        plt.title("Clean")
        plt.imshow(image_Clean[:, :, [3, 2, 0], i])
        plt.axis('off')

        plt.subplot(5, 4, 5 + i)
        plt.title("X")
        plt.imshow(image_X[:, :, [3, 2, 0], i])
        plt.axis('off')

        plt.subplot(5, 4, 9 + i)
        plt.title("Y")
        plt.imshow(image_Y[:, :, [3, 2, 0], i])
        plt.axis('off')

        plt.subplot(5, 4, 13 + i)
        plt.title("Y-X")
        plt.imshow(image_Y[:, :, [3, 2, 0], i] - image_X[:, :, [3, 2, 0], i])
        plt.axis('off')

        plt.subplot(5, 4, 17 + i)
        plt.title("C")
        plt.imshow(image_C[:, :, 0, i], cmap='gray')
        plt.axis('off')
    result_path = "%s/iter[%03d, %03d]-pnsr[%.3f, %.3f]-ssim[%.3f]-mse[%.5E]-loss[%.5E].png" % (
        log_dir, iters, epoch, psnr, psnr_rec, ssim, mse, loss.item())
    plt.savefig(result_path)
    plt.clf()
    plt.figure(figsize=(30, 60))
    plt.subplot(2, 1, 1)
    plt.title("total loss")
    plt.plot([i for i in range(len(loss_history[50:]))], loss_history[50:])

    plt.subplot(2, 1, 2)
    plt.title("total loss")
    plt.plot([i for i in range(len(loss_history))], loss_history)

    loss_path = "%s/iter[%03d, %03d]-LR[%f]-losss[%.5E].png" % (
        log_dir, iters, epoch, lr, loss.item())
    plt.savefig(loss_path)
    plt.clf()

    plt.close()

Alpha = 0.5000, beta1 = 0.500, Rho1 = 1.0000, Rho2 = 0.1000

The 0th iter is start...
The 0th epoch, 0th iter, loss: 248641.96875


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


The 001th iters, the 020th epoch, loss: 2542497768807596032.00000, mse: 2554465059025179136.00000, psnr: -114.32458, ssim: 0.48047


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


The 1th iter is start...


KeyboardInterrupt: 

<Figure size 2000x2000 with 0 Axes>

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
## 计算评估指标
mse = torch.sum((XX - Clean)**2)
psnr = compare_psnr(XX.detach().cpu().numpy(), Clean.detach().cpu().numpy(), data_range=2)
ssim = 0
for k in range(B):
    for l in range(T):
        ssim += compare_ssim(XX[:, :, k, l].detach().cpu().numpy(), Clean[:, :, k, l].detach().cpu().numpy())
print("The %dth epoch, mse: %.5f, psnr: %.5f, ssim: %.5f" % (iters, mse, psnr, ssim))

In [None]:
# 可展示的图像：观测图像YY、云影SS、恢复图像XX、掩码Mask
image_Clean = Clean.cpu().numpy()
image_YY = YY.cpu().numpy()
Rec = XX.clone()
index_ = Mask == 0
Rec[index_] = YY[index_]
image_XX = Rec.cpu().numpy()
image_SS = SS.cpu().numpy()
plt.figure(figsize=(20,20))
for i in range(4):

    plt.subplot(4,4,1+i)
    plt.imshow(image_Clean[:,:, [3,2,0], i])
    plt.axis('off')

    plt.subplot(4,4,5+i)
    plt.imshow(image_YY[:,:,[3, 2, 0], i ])
    plt.axis('off')

    plt.subplot(4,4,9+i)
    plt.imshow(image_XX[:,:,[3, 2, 0],i])
    plt.axis('off')

    plt.subplot(4,4,13+i)
    plt.imshow((image_Clean-image_XX)[:,:,[3, 2, 0],i])
    plt.axis('off')
plt.show()
plt.clf()
plt.close()

    

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
## 计算评估指标
mse = torch.sum((Rec - Clean)**2)
psnr = compare_psnr(Rec.detach().cpu().numpy(), Clean.detach().cpu().numpy(), data_range=2)
ssim = 0
for k in range(B):
    for l in range(T):
        ssim += compare_ssim(Rec[:, :, k, l].detach().cpu().numpy(), Clean[:, :, k, l].detach().cpu().numpy())
print("The %dth epoch, mse: %.5f, psnr: %.5f, ssim: %.5f" % (iters, mse, psnr, ssim))