In [1]:
print('starting...')

starting...


# 基于GAN方法云层消除

In [2]:
import shutil
import time
import itertools

import matplotlib.pyplot as plt
import yaml
import os
import random

import torch
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils import data
from attrdict import AttrMap
from torch import optim
from torch.backends import cudnn
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader

from eval import test
from models.circle_gan.gen.gen import Generator
from models.circle_gan.dis.circle_dis import Discriminator

In [3]:
with open('config.yml', 'r', encoding='UTF-8') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
config = AttrMap(config)

## 数据增强

In [4]:
train_transform = A.Compose(
    [
        A.Resize(config.img_size, config.img_size),
    ])

val_transform = A.Compose(
    [
        A.Resize(config.img_size, config.img_size),
    ]
)

## 数据集定义

In [5]:
class TrainDataset(data.Dataset):
    # def train_list_init(self):
    #     files = os.listdir(os.path.join(config.datasets_dir, 'ground_truth'))
    #     random.shuffle(files)
    #     n_train = int(config.train_size * len(files))
    #     train_list = files[:n_train]
    #     test_list = files[n_train:]
    #     np.savetxt(os.path.join(config.datasets_dir, config.train_list), np.array(train_list), fmt='%s')
    #     np.savetxt(os.path.join(config.datasets_dir, config.test_list), np.array(test_list), fmt='%s')

    def __init__(self, config, img_list, transforms):
        super().__init__()
        self.config = config
        self.transforms = transforms
        # 如果数据集尚未分割，则进行训练集和测试集的分割
        # if not os.path.exists(train_list_file) or os.path.getsize(train_list_file) == 0:
        # train_list_init()

        self.img_list = img_list

    def __getitem__(self, index):
        t = cv2.imread(os.path.join(config.datasets_dir, 'ground_truth', str(img_list[index])), 1).astype(
            np.float32)
        x = cv2.imread(os.path.join(config.datasets_dir, 'cloudy_image', str(img_list[index])), 1).astype(
            np.float32)
        if self.transforms is not None:
            x = self.transforms(image=x)['image']
            t = self.transforms(image=t)['image']

        M = np.clip((t - x).sum(axis=2), 0, 1).astype(np.float32)
        x = x / 255
        t = t / 255
        x = x.transpose(2, 0, 1)
        t = t.transpose(2, 0, 1)

        return x, t, M

    def __len__(self):
        return len(img_list)

## 工具函数

In [6]:
from utils import *

## 数据集构建

In [7]:
from sklearn.model_selection import train_test_split

seed_manage(config)
print('===> Loading datasets')
train_list_file = os.path.join(config.datasets_dir, config.train_list)
assert (len(train_list_file) > 0)

img_list = np.loadtxt(train_list_file, str)
train_img_list, valid_img_list = train_test_split(img_list, random_state=42, test_size=config.validation_size)
train_dataset = TrainDataset(config, transforms=train_transform, img_list=train_img_list)
valid_dataset = TrainDataset(config, transforms=val_transform, img_list=valid_img_list)

print('train dataset:', len(train_dataset))
print('validation dataset:', len(valid_dataset))


Random Seed:  42
===> Loading datasets
train dataset: 400
validation dataset: 400


## 数据集测试

In [8]:
training_data_loader = DataLoader(dataset=train_dataset, num_workers=config.threads, batch_size=config.batchsize,
                                  shuffle=True)
validation_data_loader = DataLoader(dataset=valid_dataset, num_workers=config.threads,
                                    batch_size=config.validation_batchsize, shuffle=False)

## 模型构建

In [9]:
gen = Generator(gpu_ids=config.gpu_ids)

if config.gen_init is not None:
    param = torch.load(config.gen_init)
    gen.load_state_dict(param)
    print('load {} as pretrained model'.format(config.gen_init))

dis = Discriminator(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=config.gpu_ids)

if config.dis_init is not None:
    param = torch.load(config.dis_init)
    dis.load_state_dict(param)
    print('load {} as pretrained model'.format(config.dis_init))

# setup optimizer
opt_gen = optim.Adam(gen.parameters(), lr=config.lr, betas=(config.beta1, 0.999), weight_decay=0.00001)
opt_dis = optim.Adam(dis.parameters(), lr=config.lr, betas=(config.beta1, 0.999), weight_decay=0.00001)

real_a = torch.FloatTensor(config.batchsize, config.in_ch, config.img_size, config.img_size)
real_b = torch.FloatTensor(config.batchsize, config.out_ch, config.img_size, config.img_size)
M = torch.FloatTensor(config.batchsize, config.img_size, config.img_size)

## 损失函数定义

In [10]:
from log_report import LogReport
from log_report import TestReport

criterionL1 = nn.L1Loss()
criterionMSE = nn.MSELoss()
criterionSoftplus = nn.Softplus()

if config.cuda:
    gen = gen.cuda()
    dis = dis.cuda()
    criterionL1 = criterionL1.cuda()
    criterionMSE = criterionMSE.cuda()
    criterionSoftplus = criterionSoftplus.cuda()
    real_a = real_a.cuda()
    real_b = real_b.cuda()
    M = M.cuda()

real_a = Variable(real_a)
real_b = Variable(real_b)



In [11]:
make_manager()
n_job = job_increment()

print('Job number: {:04d}'.format(n_job))

config.out_dir = os.path.join(config.out_dir, '{:06}'.format(n_job))
os.makedirs(config.out_dir)

logreport = LogReport(log_dir=config.out_dir)
validationreport = TestReport(log_dir=config.out_dir)


Job number: 0014


## 训练函数

In [12]:
def train():
    print('===> begin')
    start_time = time.time()
    # main
    for epoch in range(1, config.epoch + 1):
        epoch_start_time = time.time()
        for iteration, batch in enumerate(training_data_loader, 1):
            real_a_cpu, real_b_cpu, M_cpu = batch[0], batch[1], batch[2]
            real_a.resize_(real_a_cpu.size()).copy_(real_a_cpu)
            real_b.resize_(real_b_cpu.size()).copy_(real_b_cpu)
            M.resize_(M_cpu.size()).copy_(M_cpu)
            if config.use_attention:
                att, fake_b = gen.forward(real_a)
            else:
                fake_b = gen.forward(real_a)
            ################
            ### Update D ###
            ################

            opt_dis.zero_grad()

            # train with fake
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab.detach())
            batchsize, _, w, h = pred_fake.size()

            loss_d_fake = torch.sum(criterionSoftplus(pred_fake)) / batchsize / w / h

            # train with real
            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = dis.forward(real_ab)
            loss_d_real = torch.sum(criterionSoftplus(-pred_real)) / batchsize / w / h

            # Combined loss
            loss_d = loss_d_fake + loss_d_real

            loss_d.backward()

            if epoch % config.minimax == 0:
                opt_dis.step()

            ################
            ### Update G ###
            ################

            opt_gen.zero_grad()

            # First, G(A) should fake the discriminator
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab)
            loss_g_gan = torch.sum(criterionSoftplus(-pred_fake)) / batchsize / w / h

            # Second, G(A) = B
            loss_g_l1 = criterionL1(fake_b, real_b) * config.lamb
            if config.use_attention:
                loss_g_att = criterionMSE(att[:, 0, :, :], M)
                loss_g = loss_g_gan + loss_g_l1 + loss_g_att
            else:
                loss_g = loss_g_gan + loss_g_l1

            loss_g.backward()

            opt_gen.step()

            # log
            if iteration % 10 == 0:
                print(
                    "===> Epoch[{}]({}/{}): loss_d_fake: {:.4f} loss_d_real: {:.4f} loss_g_gan: {:.4f} loss_g_l1: {:.4f}".format(
                        epoch, iteration, len(training_data_loader), loss_d_fake.item(), loss_d_real.item(),
                        loss_g_gan.item(), loss_g_l1.item()))

                log = {}
                log['epoch'] = epoch
                log['iteration'] = len(training_data_loader) * (epoch - 1) + iteration
                log['gen/loss'] = loss_g.item()
                log['dis/loss'] = loss_d.item()

                logreport(log)

        print('epoch', epoch, 'finished, use time', time.time() - epoch_start_time)

        with torch.no_grad():
            log_validation = test(config, validation_data_loader, gen, criterionMSE, epoch)
            validationreport(log_validation)
        print('validation finished')
        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        validationreport.save_lossgraph()
        print('training time:', time.time() - start_time)

In [16]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def _lr_decay_step(G_lr_decay, D_A_lr_decay, D_B_lr_decay, current_iteration):
    G_lr_decay.step(current_iteration)
    D_A_lr_decay.step(current_iteration)
    D_B_lr_decay.step(current_iteration)


def train_cycle():
    device = torch.device('cuda')
    from circlegan.models import Generator, Discriminator
    criterion_GAN = torch.nn.MSELoss().cuda()
    criterion_cycle = torch.nn.L1Loss().cuda()
    criterion_identity = torch.nn.L1Loss().cuda()

    G_A2B = Generator(64, 9)
    D_B = Discriminator(config.img_size, 64, 4)
    G_A2B.apply(weights_init_normal)
    D_B.apply(weights_init_normal)
    G_A2B = torch.nn.DataParallel(G_A2B).to(device)
    D_B = torch.nn.DataParallel(D_B).to(device)

    G_B2A = Generator(64, 9)
    D_A = Discriminator(config.img_size, 64, 4)
    G_B2A.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    G_B2A = torch.nn.DataParallel(G_B2A).to(device)
    D_A = torch.nn.DataParallel(D_A).to(device)

    G_optimizer = optim.Adam(itertools.chain(G_B2A.parameters(), G_A2B.parameters()), lr=0.0002,
                             betas=[0.5, 0.999])
    D_A_optimizer = optim.Adam(D_A.parameters(), lr=0.0002, betas=[0.5, 0.999])
    D_B_optimizer = optim.Adam(D_B.parameters(), lr=0.0002, betas=[0.5, 0.999])

    G_lr_decay = optim.lr_scheduler.StepLR(G_optimizer, step_size=50000, gamma=0.1)
    D_A_lr_decay = optim.lr_scheduler.StepLR(D_A_optimizer, step_size=50000, gamma=0.1)
    D_B_lr_decay = optim.lr_scheduler.StepLR(D_B_optimizer, step_size=50000, gamma=0.1)

    ones = None
    zeros = None
    start_time = time.time()

    for epoch in range(1, config.epoch + 1):
        epoch_start_time = time.time()
        for iteration, batch in enumerate(training_data_loader, 1):
            print('iteration:{}'.format(iteration))
            real_a_cpu, real_b_cpu, _ = batch[0], batch[1], batch[2]
            real_a.resize_(real_a_cpu.size()).copy_(real_a_cpu)
            real_b.resize_(real_b_cpu.size()).copy_(real_b_cpu)

            if ones is None and zeros is None:
                ones = torch.ones_like(D_A(real_a))
                zeros = torch.zeros_like(D_A(real_b))

            _lr_decay_step(G_lr_decay, D_A_lr_decay, D_B_lr_decay, iteration)

            #########################################################################################################
            #                                                     Generator                                         #
            #########################################################################################################
            fake_a = G_B2A(real_b)
            fake_b = G_A2B(real_a)

            # gan loss
            gan_loss_a = criterion_GAN(D_A(fake_a), ones)
            gan_loss_b = criterion_GAN(D_B(fake_b), ones)
            gan_loss = (gan_loss_a + gan_loss_b) / 2.0

            # cycle loss
            cycle_loss_a = criterion_cycle(G_B2A(fake_b), real_a)
            cycle_loss_b = criterion_cycle(G_A2B(fake_a), real_b)
            cycle_loss = (cycle_loss_a + cycle_loss_b) / 2.0

            # idnetity loss
            identity_loss_a = criterion_identity(G_B2A(real_a), real_a)
            identity_loss_b = criterion_identity(G_A2B(real_b), real_b)
            identity_loss = (identity_loss_a + identity_loss_b) / 2.0

            # overall loss and optimize
            g_loss = gan_loss + 10 * cycle_loss + 5 * identity_loss

            loss_g = g_loss

            G_optimizer.zero_grad()
            g_loss.backward(retain_graph=True)
            G_optimizer.step()

            #########################################################################################################
            #                                                     Discriminator                                     #
            #########################################################################################################
            # discriminator a
            gan_loss_a_real = criterion_GAN(D_A(real_a), ones)
            gan_loss_a_fake = criterion_GAN(D_A(fake_a.detach()), zeros)
            gan_loss_a = (gan_loss_a_real + gan_loss_a_fake) / 2.0

            D_A_optimizer.zero_grad()
            gan_loss_a.backward()
            D_A_optimizer.step()

            # discriminator b
            gan_loss_b_real = criterion_GAN(D_B(real_b), ones)
            gan_loss_b_fake = criterion_GAN(D_B(fake_b.detach()), zeros)
            gan_loss_b = (gan_loss_b_real + gan_loss_b_fake) / 2.0

            D_B_optimizer.zero_grad()
            gan_loss_b.backward()
            D_B_optimizer.step()

            loss_d = gan_loss_b + gan_loss_a

            if iteration % 10 == 0:
                print(
                    "===> Epoch[{}]({}/{}): gen_loss: {:.4f} dis_loss: {:.4f}".format(
                        epoch, iteration, len(training_data_loader), loss_g.item(), loss_d.item()))
                log = {}
                log['epoch'] = epoch
                log['iteration'] = len(training_data_loader) * (epoch - 1) + iteration
                log['gen/loss'] = loss_g.item()
                log['dis/loss'] = loss_d.item()

                logreport(log)


        print('epoch', epoch, 'finished, use time', time.time() - epoch_start_time)
        with torch.no_grad():
            # G_A2B(real_a)
            log_validation = test(config, validation_data_loader, G_A2B, criterionMSE, epoch)
            validationreport(log_validation)
        print('validation finished')

        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        validationreport.save_lossgraph()
        print('training time:', time.time() - start_time)

## 训练过程

In [17]:
# 保存本次训练时的配置
shutil.copyfile('config.yml', os.path.join(config.out_dir, 'config.yml'))

train_cycle()


iteration:1




===> Avg. MSE: 0.4846
===> Avg. PSNR: 3.4307 dB
===> Avg. SSIM: 0.0302 dB
iteration:2
===> Avg. MSE: 0.3874
===> Avg. PSNR: 4.4027 dB
===> Avg. SSIM: 0.0430 dB
iteration:3
===> Avg. MSE: 0.3199
===> Avg. PSNR: 5.2144 dB
===> Avg. SSIM: 0.0508 dB
iteration:4
===> Avg. MSE: 0.2622
===> Avg. PSNR: 5.9608 dB
===> Avg. SSIM: 0.0642 dB
iteration:5


KeyboardInterrupt: 