# In [None]:
from __future__ import print_function

import os
import mat73
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
import torch
import torch.optim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

# Optional learned prior; disabled in this script.
# from architecture import MST

# Seed the CUDA RNG once (re-seeding with the same value has no extra effect).
torch.cuda.manual_seed(seed=666)
# Enable cuDNN with kernel auto-tuning.
# NOTE(review): benchmark=True may select non-deterministic kernels, which
# undermines the fixed seed for any cuDNN-backed op (e.g. the commented-out
# MST model); the tensor code visible in this script uses no convolutions.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Legacy CUDA float32 tensor type, kept for `.type(dtype)` casts.
dtype = torch.cuda.FloatTensor
# Target device for `.to(device)`; the script's tensors live on the GPU.
device = torch.device('cuda')

########################################### Helper functions ###########################################
def thres_21(L, tau, M, N, B):
    """Group soft-thresholding: the proximal operator of ``tau * ||L||_{2,1}``.

    Every fibre ``L[..., :]`` along the last axis is one group and is scaled by
    ``max(0, 1 - tau / ||fibre||_2)``.  Fibres whose norm is at most ``tau``
    become exactly zero.

    Args:
        L: Tensor whose last axis holds the grouped entries (the caller in this
            script passes the 3-D slice ``L[:, :, :, t]``).  Not modified.
        tau: Non-negative scalar threshold.
        M, N, B: Kept for backward compatibility only.  The scale factor is now
            broadcast over the last axis, so the input shape is taken from ``L``.

    Returns:
        Tensor with the same shape and dtype as ``L``.
    """
    # Euclidean norm of each fibre along the last axis.
    norms = torch.sqrt(torch.sum(L * L, dim=-1))
    # All-zero fibres: use norm 1 to avoid division by zero.  They remain zero
    # after scaling because L itself is zero there.
    norms[norms == 0] = 1
    # Shrinkage factor, clamped at zero.
    scale = torch.clamp(1 - tau / norms, min=0)
    # Broadcast over the grouped axis instead of materialising a repeated copy.
    return L * scale.unsqueeze(-1)


################################ Load data ################################
# Load the raw multi-temporal image cube and normalise it by its global maximum.
# Axis layout: the last axis indexes time steps (6 of them are used below), and
# the script later unpacks the shape as `M, N, B, T`.  Presumably the axes are
# spatial rows/cols and spectral bands; TODO confirm.
YY= mat73.loadmat("./Data/T31UDQ/T31UDQ_20m_raw1.mat")['T31UDQ_20m_raw1']
YY = np.double(YY)
YY = YY/YY.max()
# Cloud masks.  Only Mask2_500 is used in the visible code.
Mask1_500 = mat73.loadmat("./Data/Mask1_500.mat")['Mask1_500']
Mask2_500 = mat73.loadmat("./Data/Mask2_500.mat")['Mask2_500']

# Binarise the mask: every non-zero entry becomes 1 (1 = cloud).
Mask = np.copy(Mask2_500)
Mask[ Mask != 0 ] = 1
# Test1_Clean is the clean image made of time steps 4, 2, 0.
# NOTE(review): not used in the visible code.
Test1_Clean = np.copy(YY[:, :, :, [4, 2, 0]])

# Build the cloudy observation from 6 time steps, reordered as [4, 2, 0, 5, 3, 1].
# NOTE(review): assumes Mask2_500 has (or broadcasts to) the shape of this
# 6-step cube; TODO confirm.
Clean = YY[:, :, :, [4, 2, 0, 5, 3, 1]]
# Cloud pixels (Mask == 1) are set to 1, the maximum after normalisation;
# cloud-free pixels keep their original value.
YY = np.multiply((1 - Mask), (YY[:, :, :, [4, 2, 0, 5, 3, 1]])) + Mask
Noisy = np.copy(YY)
# Cloud: 1 = cloud.  Mask is inverted: 1 = cloud-free (observed) pixel.
Cloud = np.copy(Mask)
Mask = 1 - Cloud
###########################################################################

M, N, B, T = Noisy.shape
# Convert each working array to a 4-D torch tensor on `device`, in the order
# Clean, Mask (cloud = 0), Cloud (cloud = 1), Noisy.
Clean, Mask, Cloud, Noisy = (
    torch.from_numpy(arr.reshape(M, N, B, T)).to(device)
    for arr in (Clean, Mask, Cloud, Noisy)
)


################################ Global settings ################################
def grid_combine(args_list):
    """Yield every (rate, lambda2, rho, alpha) combination of a 4-axis grid.

    Combinations come out in nested-loop order: rate outermost, alpha
    innermost.  Each axis may be any iterable, including a one-shot iterator.
    ``itertools.product`` materialises every axis once, whereas re-iterating
    ``args_list[k]`` inside nested loops would exhaust such iterators after the
    first pass.

    The point rate=125, rho=0.08, alpha=1.03 (with any lambda2) is skipped
    deliberately.  Values are compared with ``math.isclose``, so computed floats
    that equal these literals up to rounding are skipped too.

    Args:
        args_list: Sequence whose first four entries hold the candidate values
            for rate, lambda2, rho and alpha.  Any further entries are ignored.

    Yields:
        Tuples ``(rate, lambda2, rho, alpha)``.
    """
    import itertools
    import math

    rates, lambda2s, rhos, alphas = args_list[0], args_list[1], args_list[2], args_list[3]
    for rate, lambda2, rho, alpha in itertools.product(rates, lambda2s, rhos, alphas):
        # NOTE(review): this excluded point is outside the module's current
        # args_list (rates 20-60); presumably a leftover from an earlier search.
        if math.isclose(rate, 125) and math.isclose(rho, 0.08) and math.isclose(alpha, 1.03):
            continue
        yield rate, lambda2, rho, alpha
# Hyper-parameter grid searched by the main loop, one list per axis.
# rate is lambda1 / rho; the main loop derives lambda1 = rate * rho.
args_list = [
    [20, 30, 40, 50, 60],  # rate
    [0.5, 0.4, 0.6],  # lambda2
    [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3],  # rho
    [1.01, 1.02, 1.03, 1.04, 1.05]  # alpha
]
# Hand-picked configurations in [rate, lambda2, rho, alpha] format.  They are an
# alternative to the full grid; switch the main loop header to iterate over them.
# Earlier candidate sets, in the older [lambda1, lambda2, rho(, alpha)] format:
#   [6.0, 0.1, 0.12, 1.03]     # condidate2
#   [12.5, 0.1, 0.1]
#   [10.0, 0.1, 0.1]
#   [7.50, 0.1, 0.1]
#   [10.0, 0.2, 0.1]
#   [10.0, 0.4, 0.1]
condidates = [
    # rate, lambda2, rho, alpha
    # [90, 0.06, 0.06, 1.03],
    # [95, 0.06, 0.08, 1.03],
    # [88, 0.06, 0.08, 1.03],
    # [90, 0.08, 0.08, 1.03],
    # [30, 0.4, 0.06, 1.0],
    [60, 0.4, 0.06, 1.02],
    [60, 0.2, 0.06, 1.03],
    [50, 0.4, 0.06, 1.03],
]
# iter_num: outer-iteration budget per configuration.
# NOTE(review): in the visible code epoch_num only appears in the log-directory name.
iter_num, epoch_num = 150, 300
# Running index of configurations; it is embedded in each log-directory name.
order = 0
def _mean_band_metrics(recovered, reference, n_times=3):
    """Return ``(psnr, ssim)``: per-time-step scores, each averaged over bands.

    Args:
        recovered: numpy array indexed as ``[:, :, band, time]``, the 4-D
            layout used throughout this script.
        reference: numpy array with the same layout.
        n_times: number of leading time steps to score.

    Both metrics use ``data_range=1``.  SSIM needs it explicitly: without it,
    skimage derives the range of float input from the dtype ((-1, 1), i.e. 2)
    instead of the data, putting SSIM on a different scale than PSNR.  Recent
    skimage versions raise a ValueError instead.
    """
    n_bands = recovered.shape[2]
    psnr, ssim = [], []
    for t in range(n_times):
        psnr_t = 0
        ssim_t = 0
        for b in range(n_bands):
            x_band = recovered[:, :, b, t]
            ref_band = reference[:, :, b, t]
            psnr_t += compare_psnr(x_band, ref_band, data_range=1)
            ssim_t += compare_ssim(x_band, ref_band, data_range=1)
        psnr.append(psnr_t / n_bands)
        ssim.append(ssim_t / n_bands)
    return psnr, ssim


def _save_snapshot(path, image_Clean, image_X, image_Y, image_C, image_M):
    """Save a 6x6 figure to ``path``: one column per time step, one row per quantity.

    Rows: Clean, X, Y and Y-X (bands [4, 2, 0] shown as RGB, multiplied by 5
    for visibility), then band 0 of C (x5) and of M_update, both in grayscale.
    """
    rgb = [4, 2, 0]
    fig = plt.figure(figsize=(20, 20))
    for i in range(6):
        panels = [
            ("Clean", image_Clean[:, :, rgb, i] * 5, None),
            ("X", image_X[:, :, rgb, i] * 5, None),
            ("Y", image_Y[:, :, rgb, i] * 5, None),
            ("Y-X", image_Y[:, :, rgb, i] * 5 - image_X[:, :, rgb, i] * 5, None),
            ("C", image_C[:, :, 0, i] * 5, 'gray'),
            ("M_update", image_M[:, :, 0, i], 'gray'),
        ]
        for row, (title, img, cmap) in enumerate(panels):
            plt.subplot(6, 6, 6 * row + 1 + i)
            plt.title(title)
            plt.imshow(img, cmap=cmap)
            plt.axis('off')
    plt.savefig(path)
    plt.clf()
    plt.close(fig)


# The reference image never changes, so copy it to host memory once.
image_Clean = Clean.cpu().numpy()

# Alternative: iterate over the hand-picked `condidates` instead of the grid.
# for rate, lambda2, rho, alpha in condidates:
for rate, lambda2, rho, alpha in grid_combine(args_list):
    # The grid specifies rate = lambda1 / rho; lambda1 is derived from it.
    lambda1 = rate * rho
    iters = 0
    order += 1
    # NOTE: the "rho[...]" field of the directory name holds (rho, alpha).
    log_dir = "./model3-picardie-acc-1/order[%03d]-iter[%3d]-epoch[%3d]-rate[%.3f]-lambda2[%.3f]-rho[%.3f,%.3f]" \
              % (order, iter_num, epoch_num, rate, lambda2, rho, alpha)
    # An existing directory means this configuration already ran, so skip it.
    # Resuming only lines up if the grid order is unchanged, because `order`
    # is part of the name.
    if os.path.exists(log_dir):
        continue
    os.makedirs(log_dir)
    # rate, lambda2, rho, alpha = 125, 0.1, 0.08, 1.04
    ################################ Initialisation ################################
    Y = Noisy.clone()
    X = Noisy.clone()
    W = Noisy.clone()
    # Same dtype and device as Y.  The legacy float32 `dtype` would silently
    # downcast the float64 shrinkage results written into C below.
    C = torch.zeros_like(Y)
    M_updata = Mask.clone()  # 1 = cloud-free, 0 = cloud
    # model = MST(dim=B*T, stage=2, num_blocks=[2, 2, 2]).cuda()
    psnr1 = 0
    s0_min = 10000  # only used by the commented-out rho schedule below
    ################################ Alternating sub-problem updates ################################
    while iters < iter_num:
        # Early stop: after 100 iterations, abandon configurations whose
        # first-time-step PSNR is still below 43 dB.
        if iters >= 100 and psnr1 < 43:
            break
        iters += 1
        ################################ X sub-problem ################################
        # Per pixel, combine the observed (cloud-free) data Y - C with the
        # low-rank estimate W, weighted by the mask and the penalty rho.
        X = (torch.mul(M_updata, (Y - C)) + rho * W) / (M_updata + rho)

        ################################ W sub-problem ################################
        # Singular value soft-thresholding of X unfolded to (M*N) x (B*T).
        U, s, VH = torch.linalg.svd(X.reshape(M * N, B * T), full_matrices=False)
        s0 = s[0]
        rho *= alpha  # grow the penalty every iteration
        # Alternative schedule (disabled):
        # if s0<s0_min: # s0 is still in its initial decreasing phase
        #     s0_min = s0
        #     rho *= alpha
        # else:
        #     rho *= 1.02
        rate = lambda1 / rho
        print("s: ", s)
        s = s - lambda1 / rho
        s[s < 0] = 0
        S = torch.diag(s)
        W = torch.mm(torch.mm(U, S), VH).reshape(M, N, B, T)
        ################################ C sub-problem ################################
        # Per-time-step L2,1 group shrinkage of the residual.
        L = Y - torch.mul(X, M_updata)
        for t in range(T):
            C[:, :, :, t] = thres_21(L[:, :, :, t], lambda2, M, N, B)

        ################################ Mask update (disabled) ################################
        # B_mean = torch.mean(Y-X, dim=2).reshape(M, N, 1, T).repeat((1, 1, B, 1))
        # M_updata[:, :, :, :3][torch.abs(B_mean[:, :, :, :3]) > 0.01] = 0   # M_update=1 means cloud-free, M_update=0 means cloud
        # if iters > 5:
        #     M_updata[:, :, :, :3][torch.abs(B_mean[:, :, :, :3]) < 0.01] = 1

        ################################ Evaluation ################################
        # One device-to-host copy per iteration instead of one per band and metric.
        image_X = X.detach().cpu().numpy()
        psnr, ssim = _mean_band_metrics(image_X, image_Clean)
        # Refreshed every iteration so the early-stop check never reads a stale value.
        psnr1, psnr2, psnr3 = psnr
        ssim1, ssim2, ssim3 = ssim
        if iters % 2 == 0:
            image_Y = Y.cpu().numpy()
            image_C = C.cpu().numpy()
            image_M = M_updata.cpu().numpy()
            result_path = "%s/iter[%03d]-pnsr[%.3f, %.3f, %.3f]-ssim[%.3f, %.3f, %.3f]-s0, rate[%.3f, %.2f].png" % (
                log_dir, iters, psnr1, psnr2, psnr3, ssim1, ssim2, ssim3, s0, rate)
            _save_snapshot(result_path, image_Clean, image_X, image_Y, image_C, image_M)
            scio.savemat("%s/recover_data.mat" % log_dir, {"recover_data": image_X})
    # decrease = True
    # up = False
