In [None]:
from __future__ import print_function

import os

import matplotlib.pyplot as plt
import scipy.io as scio
import torch
import torch.optim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

from architecture import MST

# Let cuDNN pick the fastest kernels (inputs have fixed shapes).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Target device for tensors moved with `.to(device)`.
device = torch.device("cuda")
# Legacy float32 CUDA tensor type (alternatives: torch.cuda.DoubleTensor, torch.double).
dtype = torch.cuda.FloatTensor

# Fixed seed for CUDA random number generation.
torch.cuda.manual_seed(seed=666)


########################################### Helper functions ###########################################
def thres_21(L, tau, M, N, B):
    """Group soft-thresholding along axis 2 (proximal step of the L2,1 norm).

    Every fiber ``L[m, n, :]`` is multiplied by ``max(0, 1 - tau / ||L[m, n, :]||_2)``.
    Fibers whose norm is zero are all-zero, so they stay zero.

    Args:
        L: tensor, expected to have shape (M, N, B).
        tau: shrinkage threshold.
        M, N, B: sizes used to broadcast the per-fiber scale factor back over axis 2.

    Returns:
        A new tensor with the same shape as ``L``.
    """
    fiber_norms = (L * L).sum(dim=2).sqrt()
    # Replace zero norms by 1 to avoid dividing by zero; those fibers are all-zero anyway.
    fiber_norms = torch.where(fiber_norms == 0, torch.ones_like(fiber_norms), fiber_norms)
    scale = (1 - tau / fiber_norms).clamp(min=0)
    return L * scale.reshape(M, N, 1).expand(M, N, B)


################################ Load data ################################
# Alternative simulated dataset: ./SimuData.mat with keys "Clean", "Mask", "Cloud", "Noisy".
Data = scio.loadmat("./Data.mat")
Clean = Data["Clean"]
Noisy = Data["Cloud_image"]
Mask = Data["Mask"]  # cloud = 0
Cloud = 1 - Mask     # cloud = 1

# Crop size (previously tried: 384x384 and 256x256) and cube dimensions.
M, N = 400, 400
B, T = 4, 4
# Every four adjacent images along the last axis share one acquisition time, so each
# cropped stack is reshaped to (M, N, T, B) and moved to `device`.
Clean, Mask, Cloud, Noisy = [
    torch.from_numpy(stack[:M, :N, :].reshape(M, N, T, B)).to(device)
    for stack in (Clean, Mask, Cloud, Noisy)
]

################################ Global settings ################################

def get_args(args_list):
    """Yield one 4-tuple per combination of the first four lists in ``args_list``.

    This is the Cartesian product of ``args_list[0]`` .. ``args_list[3]``, with the last
    list varying fastest. As a generator, nothing is read from ``args_list`` until
    iteration starts.
    """
    yield from (
        (first, second, third, fourth)
        for first in args_list[0]
        for second in args_list[1]
        for third in args_list[2]
        for fourth in args_list[3]
    )

def grid_combine(args_list):
    """Yield a 3-tuple for every combination of ``args_list[0]``, ``[1]`` and ``[2]``.

    Only the first three lists are used. Any ``args_list[3]`` is ignored, so callers that
    also need the fourth value should iterate ``get_args`` instead.
    """
    yield from (
        (lambda1, lambda2, rho)
        for lambda1 in args_list[0]
        for lambda2 in args_list[1]
        for rho in args_list[2]
    )
def _mean_band_psnr(recover_np, clean_np, n_bands, n_times, data_range=1):
    """Return, for each t < n_times, the PSNR of slice [:, :, b, t] averaged over b < n_bands.

    Args:
        recover_np: reconstructed numpy array, indexed as [:, :, b, t].
        clean_np: reference numpy array with the same layout.
        n_bands: number of b indices averaged for each t.
        n_times: number of t indices (length of the returned list).
        data_range: dynamic range passed to compare_psnr.

    Returns:
        List of ``n_times`` mean PSNR values in dB.
    """
    scores = []
    for t in range(n_times):
        total = 0
        for b in range(n_bands):
            total += compare_psnr(clean_np[:, :, b, t], recover_np[:, :, b, t], data_range=data_range)
        scores.append(total / n_bands)
    return scores


def _save_iteration_figure(path, image_clean, image_x, image_y, image_c):
    """Save a 5-row panel (Clean, X, Y, Y-X, C) with one column per last-axis slice.

    RGB composites take indices [3, 1, 0] of axis 2. C is shown in gray from index 0 of
    axis 2. The figure is closed after saving so that repeated calls do not accumulate
    open figures.
    """
    rgb = [3, 1, 0]
    n_cols = image_x.shape[3]
    fig = plt.figure(figsize=(20, 20))
    for i in range(n_cols):
        panels = (
            ("Clean", image_clean[:, :, rgb, i], None),
            ("X", image_x[:, :, rgb, i], None),
            ("Y", image_y[:, :, rgb, i], None),
            ("Y-X", image_y[:, :, rgb, i] - image_x[:, :, rgb, i], None),
            ("C", image_c[:, :, 0, i], "gray"),
        )
        for row, (title, img, cmap) in enumerate(panels):
            ax = fig.add_subplot(5, n_cols, n_cols * row + 1 + i)
            ax.set_title(title)
            ax.imshow(img, cmap=cmap)
            ax.axis("off")
    fig.savefig(path)
    plt.close(fig)


# Hyper-parameter grid: every combination yielded by get_args(args_list) is run once.
args_list = [
    [20],       # ratio1 (lambda1 = ratio1 * rho)
    [0.2],      # lambda2
    [0.05],     # rho
    [-0.5, -0.2, -0.1, -0.01, 0.01, 0.1, 0.2, 0.5],  # ratio2
]
# Hand-picked configurations; not used by the loop below.
condidates = [
    # ratio1, lambda2, rho, ratio2
    [20, 0.2, 0.05, 0.1],
]
iter_num, epoch_num = 200, 300  # epoch_num only appears in the log directory name
order = 1

# Clean never changes, so convert it to numpy once rather than on every iteration.
clean_np = Clean.detach().cpu().numpy()
image_Clean = clean_np * 2  # brightened copy, used only for plotting

# NOTE(review): the data cubes were reshaped to (M, N, T, B) when loaded, but evaluation and
# plotting index them as [:, :, b, t]. This only works because B == T -- confirm the axis order.
# get_args yields (ratio1, lambda2, rho, ratio2); grid_combine would yield only three values.
for ratio1, lambda2, rho, ratio2 in get_args(args_list):
    lambda1 = ratio1 * rho
    order += 1
    log_dir = "./log-model33-sentinel2/order[%03d]-iter[%3d]-epoch[%3d]-rate[%.3f]-lambda2[%.3f]-rho[%.3f]-ratio2[%.3f]" \
              % (order, iter_num, epoch_num, ratio1, lambda2, rho, ratio2)
    # An existing directory means this configuration was already run, so skip it.
    if os.path.exists(log_dir):
        continue
    os.makedirs(log_dir)

    ################################ Initialization ################################
    Y = Noisy.clone()
    X = Noisy.clone()
    W = Noisy.clone()
    # zeros_like keeps C in Y's dtype and on Y's device, so the C update below does not
    # silently downcast the thresholded residual to float32.
    C = torch.zeros_like(Y)
    image_Y = Y.cpu().numpy() * 2  # Y is never updated, so its plot copy is computed once
    psnr = []

    ################################ Sub-problem updates ################################
    for iters in range(1, iter_num + 1):
        # X sub-problem: element-wise weighted average of (Y - C) with weight Mask and
        # (W + ratio2) with weight rho.
        X = (torch.mul(Mask, Y - C) + rho * (W + ratio2)) / (Mask + rho)

        # W sub-problem: singular value thresholding (threshold lambda1 / rho) of the
        # (M*N, B*T) unfolding of X - ratio2, folded back to X's own shape.
        U, s, VH = torch.linalg.svd(X.reshape(M * N, B * T) - ratio2, full_matrices=False)
        print("s: ", s)
        s = torch.clamp(s - lambda1 / rho, min=0)
        W = torch.mm(torch.mm(U, torch.diag(s)), VH).reshape(X.shape)

        # C sub-problem: group soft-thresholding of the residual, one last-axis slice at a time.
        residual = Y - torch.mul(X, Mask)
        for t in range(T):
            C[:, :, :, t] = thres_21(residual[:, :, :, t], lambda2, M, N, B)

        # Evaluation: take Y where Mask is 1 and X where Mask is 0 (assuming a binary Mask).
        Recover = torch.mul(Y, Mask) + torch.mul(X, 1 - Mask)
        psnr = _mean_band_psnr(Recover.detach().cpu().numpy(), clean_np, B, T)

        # Brightened copies used only for plotting.
        image_X = X.detach().cpu().numpy() * 2
        image_C = C.cpu().numpy() * 2
        result_path = "%s/iter[%03d]-pnsr[%s].png" % (
            log_dir, iters, ", ".join("%.3f" % p for p in psnr))
        _save_iteration_figure(result_path, image_Clean, image_X, image_Y, image_C)

    # Save this run's reconstruction. X is saved unscaled; image_X (= 2 * X) is for display only.
    scio.savemat("%s/recover_data.mat" % log_dir, {"recover_data": X.detach().cpu().numpy()})


In [None]:
# Re-save the final reconstruction X of the last run and read it back as a sanity check.
# X itself is saved; image_X is X scaled by 2 for plotting, not the reconstruction.
recover_np = X.detach().cpu().numpy()
scio.savemat("%s/0recover_data.mat" % log_dir, {"recover_data": recover_np})
a = scio.loadmat("%s/0recover_data.mat" % log_dir)["recover_data"]
assert a.shape == recover_np.shape

In [None]:
# Sum and mean of four PSNR values (presumably per-time PSNRs copied from a run's output -- TODO confirm).
b = sum([39.56, 45.52, 39.35, 38.35])
b, b / 4

In [None]:
# PSNR of X against Clean for each index i of axis 3 (the index called t in the optimisation
# loop's evaluation), computed over the whole (M, N, 4) slice at once.
# data_range=1 matches the optimisation loop. The previous data_range=2 inflated every
# value by 20*log10(2) ~= 6.02 dB.
for i in range(X.shape[3]):
    psnr = compare_psnr(Clean[:, :, :, i].detach().cpu().numpy(), X[:, :, :, i].detach().cpu().numpy(), data_range=1)
    print(psnr)