In [None]:
import warnings
warnings.filterwarnings("ignore")

# 进行训练
from visualdl import LogWriter
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import DataLoader
from dataset.data_loader import TrainDataSet, ValidDataSet
from loss.Loss import LossWithGAN_STE, LossWithSwin
from models.swin_gan import STRnet2_change
import utils
import random
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
%matplotlib inline


# VisualDL writer: the loop below logs the "train_loss" and "valid_psnr" scalars into the 'log_swin_v100' directory.
log = LogWriter('log_swin_v100')
def psnr(img1, img2, max_val=255.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        img1: first image; any array-like (ndarray, nested list, ...).
        img2: second image; must broadcast against ``img1`` (normally the
            same shape).
        max_val: largest possible pixel value. The default of 255.0 matches
            8-bit images.

    Returns:
        ``10 * log10(max_val**2 / MSE)``. When the MSE is below 1e-10
        (the images are effectively identical), returns ``100`` instead,
        which avoids taking the log of an unbounded ratio.
    """
    # Convert to float64 before subtracting: uint8 inputs would otherwise
    # wrap around on negative differences.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(max_val ** 2 / mse)


# Training configuration (key order and values are significant to callers).
CONFIG = dict(
    numOfWorkers=0,                          # worker processes for the training DataLoader
    modelsSavePath='train_models_swin_v100', # directory receiving per-epoch and best checkpoints
    batchSize=10,                            # training mini-batch size
    traindataRoot='dataset/dataset',         # root of the training set
    validdataRoot='dataset/valid_dataset',   # root of the validation set
    pretrained='train_models_swin/STE_15_43.2223.pdparams',  # state dict to warm-start from; None skips loading
    num_epochs=100,                          # number of training epochs
    net='str',                               # label printed in the progress log
    lr=1e-4,                                 # NOTE(review): the training code hard-codes its own starting lr — confirm
    lr_decay_iters=40000,                    # NOTE(review): not referenced in this cell
    gamma=0.5,                               # NOTE(review): not referenced in this cell
    seed=9420,                               # seed for Python, NumPy and Paddle RNGs
)


# Select the device: the first GPU when this Paddle build has CUDA support, otherwise the CPU.
paddle.set_device('gpu:0' if paddle.is_compiled_with_cuda() else 'cpu')


# Seed Python's, NumPy's and Paddle's random generators with the same value so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, paddle.seed):
    _seed_fn(CONFIG['seed'])
del _seed_fn
# noinspection PyProtectedMember
paddle.framework.random._manual_program_seed(CONFIG['seed'])


batchSize = CONFIG['batchSize']
# Create the checkpoint directory. exist_ok avoids the exists()/makedirs() race.
os.makedirs(CONFIG['modelsSavePath'], exist_ok=True)

traindataRoot = CONFIG['traindataRoot']
validdataRoot = CONFIG['validdataRoot']


# Training set: shuffled mini-batches. drop_last keeps every batch at exactly batchSize.
TrainData = TrainDataSet(training=True, file_path=traindataRoot)
TrainDataLoader = DataLoader(TrainData, batch_size=batchSize, shuffle=True,
                             num_workers=CONFIG['numOfWorkers'], drop_last=True)
# Validation set: one image per batch, in a fixed order. The average PSNR does
# not depend on order, and a fixed order means the plotted sample indices show
# the same images every epoch.
ValidData = ValidDataSet(file_path=validdataRoot)
ValidDataLoader = DataLoader(ValidData, batch_size=1, shuffle=False, num_workers=0, drop_last=True)


# Network being trained. The loop calls it as netG(images) and unpacks a pair (generated images, mm).
netG = STRnet2_change()


# Optionally warm-start from a saved state dict.
if CONFIG['pretrained'] is not None:
    weights = paddle.load(CONFIG['pretrained'])
    netG.load_dict(weights)
    # Printed only after loading succeeds, so a failed load never reports 'loaded'.
    print('loaded ')


# Start directly with a deliberately large learning rate.
# NOTE(review): this hard-coded value overrides CONFIG['lr'] (1e-4) — confirm
# which one is intended.
lr = 2e-3
G_optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=netG.parameters())


loss_function = LossWithGAN_STE()


print('OK!')
num_epochs = CONFIG['num_epochs']
# NOTE(review): not used by the training/validation loop in this cell.
mse = nn.MSELoss()
best_psnr = 0  # best average validation PSNR seen so far
iters = 0      # global training-iteration counter, across all epochs


for epoch_id in range(1, num_epochs + 1):

    netG.train()

    if epoch_id % 8 == 0:
        # Every 8 epochs, divide the learning rate by 10 and re-create the
        # optimizer. It is a fresh Adam instance, so its moment estimates
        # start over. G_optimizer must be rebound, otherwise the training step
        # below keeps using the old optimizer and the decay has no effect.
        # NOTE(review): starting from 2e-3 this reaches 2e-8 by epoch 40 —
        # confirm the schedule is intended for num_epochs=100.
        lr /= 10
        G_optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=netG.parameters())

    for k, (imgs, gts, masks) in enumerate(TrainDataLoader):
        iters += 1

        fake_images, mm = netG(imgs)
        G_loss = loss_function(masks, fake_images, mm, gts)
        # Reduce to a scalar (presumably the loss may return a non-scalar tensor — TODO confirm).
        G_loss = G_loss.sum()

        # Back-propagate to compute gradients.
        G_loss.backward()
        # Update the parameters to minimise the loss.
        G_optimizer.step()
        # Clear gradients before the next iteration.
        G_optimizer.clear_grad()

        # Report training progress.
        if iters % 100 == 0:
            print('epoch{}, iters{}, loss:{:.5f}, net:{}, lr:{}'.format(
                epoch_id, iters, G_loss.item(), CONFIG['net'], G_optimizer.get_lr()
            ))
            log.add_scalar(tag="train_loss", step=iters, value=G_loss.item())

    # Evaluate the model on the validation set, then save checkpoints.
    netG.eval()
    val_psnr = 0
    num_valid = 0
    step = 512  # side length of the square tiles fed to the network

    # noinspection PyAssignmentToLoopOrWithParameter
    for index, (imgs, gt) in enumerate(ValidDataLoader):
        _, _, h, w = imgs.shape
        rh, rw = h, w
        # Pad right/bottom so each side is at least one tile long.
        pad_h = step - h if h < step else 0
        pad_w = step - w if w < step else 0
        m = nn.Pad2D((0, pad_w, 0, pad_h))
        imgs = m(imgs)
        _, _, h, w = imgs.shape
        res = paddle.zeros_like(imgs)
        mm_out = paddle.zeros_like(imgs)
        mm_in = paddle.zeros_like(imgs)

        # Cut the padded image into step x step tiles. The last row/column of
        # tiles is shifted back to end flush with the border, so tiles may
        # overlap. Only sample 0 of the batch is used (batch_size is 1).
        input_array = []
        i_j_list = []
        for i in range(0, h, step):
            for j in range(0, w, step):
                if h - i < step:
                    i = h - step
                if w - j < step:
                    j = w - step
                clip = imgs[:, :, i:i + step, j:j + step]
                input_array.append(clip[0])
                i_j_list.append((i, j))

        # Run all tiles through the network as one batch for speed.
        input_array = paddle.to_tensor(input_array)
        if paddle.is_compiled_with_cuda():
            # Move to the GPU only when one can exist. On CPU-only builds the
            # device setup falls back to 'cpu', and .cuda() would raise.
            input_array = input_array.cuda()
        with paddle.no_grad():
            g_images, mm = netG(input_array)
        g_images, mm = g_images.cpu(), mm.cpu()

        # Paste each tile back. The output blends input and generated pixels,
        # using mm as the weight of the generated image.
        for idx, (i, j) in enumerate(i_j_list):
            mm_in[:, :, i:i + step, j:j + step] = mm[idx]
            g_image_clip_with_mask = imgs[:, :, i:i + step, j:j + step] * (1 - mm[idx]) + g_images[idx] * mm[idx]
            res[:, :, i:i + step, j:j + step] = g_image_clip_with_mask
            # NOTE(review): mm_out receives the same values as mm_in.
            mm_out[:, :, i:i + step, j:j + step] = mm[idx]

        # Drop the padding added above from every map that is shown or scored.
        res = res[:, :, :rh, :rw]
        mm_out = mm_out[:, :, :rh, :rw]
        mm_in = mm_in[:, :, :rh, :rw]
        # Convert tensors to images (changes the channel layout) for PSNR and plotting.
        output = utils.pd_tensor2img(res)
        target = utils.pd_tensor2img(gt)
        mm_out = utils.pd_tensor2img(mm_out)
        mm_in = utils.pd_tensor2img(mm_in)

        psnr_value = psnr(output, target)
        print('psnr: ', psnr_value)

        if index in [2, 3, 5, 7, 11]:
            fig = plt.figure(figsize=(20, 10),dpi=100)
            # Top-left: composited output.
            ax1 = fig.add_subplot(2, 2, 1)
            ax1.imshow(output)
            # Top-right: mm as pasted back (mm_in).
            ax2 = fig.add_subplot(2, 2, 2)
            ax2.imshow(mm_in)
            # Bottom-left: ground truth.
            ax3 = fig.add_subplot(2, 2, 3)
            ax3.imshow(target)
            # Bottom-right: mm_out.
            ax4 = fig.add_subplot(2, 2, 4)
            ax4.imshow(mm_out)
            plt.show()

        del res
        del gt
        del target
        del output

        val_psnr += psnr_value
        num_valid += 1

    # Explicit counter: an empty validation loader would otherwise leave
    # `index` undefined and divide by zero.
    ave_psnr = val_psnr / num_valid if num_valid else 0.0
    print('epoch:{}, psnr:{}'.format(epoch_id, ave_psnr))
    log.add_scalar(tag="valid_psnr", step=epoch_id, value=ave_psnr)
    paddle.save(netG.state_dict(), CONFIG['modelsSavePath'] +
                '/STE_{}_{:.4f}.pdparams'.format(epoch_id, ave_psnr
                ))
    if ave_psnr > best_psnr:
        best_psnr = ave_psnr
        paddle.save(netG.state_dict(), CONFIG['modelsSavePath'] + '/STE_best.pdparams')