In [1]:
import time
from pathlib import Path

from random import randint
from matplotlib import pyplot as plt

import torch as np
from torchvision.utils import save_image

from models.CSA import CSA
from tools.toml import load_option
from plot import array2image

from loader import loader


def mkdir(out_dir):
    """Create ``out_dir`` (and any missing parent directories) if needed.

    Args:
        out_dir: Directory path as a ``str`` or ``pathlib.Path``.

    Returns:
        The directory as a ``pathlib.Path`` (callers may ignore it).

    Raises:
        FileExistsError: if ``out_dir`` already exists but is not a directory.
    """
    out_dir = Path(out_dir)
    # exist_ok=True already makes this a no-op for an existing directory, so a
    # separate exists() check is redundant -- and that check would silently
    # accept a regular file sitting at this path, deferring the failure to the
    # first write into it.
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir


def mask_op(mask, device="cuda"):
    """Keep only ``mask[0][0]`` and return it as a 4-D uint8 tensor on ``device``.

    Every other sample and channel in ``mask`` is discarded.

    Args:
        mask: Tensor indexed as ``mask[0][0]``; assumed to be (N, C, H, W) --
            TODO confirm against ``loader.maskset``.
        device: Target device. Defaults to ``"cuda"``, matching the previous
            hard-coded ``.cuda()`` call; pass ``"cpu"`` to run without a GPU.

    Returns:
        For an (N, C, H, W) input, a tensor of shape (1, 1, H, W) with dtype
        ``torch.uint8`` on ``device``.
    """
    # Slice before the device transfer so only the single kept mask is copied,
    # not the whole batch.
    single = mask[0][0]
    # (H, W) -> (1, H, W) -> (1, 1, H, W)
    single = single.unsqueeze(0).unsqueeze(1)
    # Kept as uint8 via .byte() to match what the model has been receiving.
    # NOTE(review): newer PyTorch prefers bool masks; switch only if CSA accepts them.
    return single.to(device).byte()

## 模型定义

In [2]:
# Hyperparameters
## Fixed schedule: number of epochs, how often (in samples) a visualisation
## strip is saved, and how often (in epochs) a checkpoint is written.
epochs, display_freq, save_epoch_freq = 15, 200, 1

## Model parameters: alpha is forwarded to loader.trainset(), beta is passed as
## the first positional argument of CSA(); both are embedded in the run name.
alpha, beta = 1, 0.2

# Run name, e.g. 'CSA-1-0.2'; later stored under the 'name' option.
model_name = 'CSA-{}-{}'.format(alpha, beta)

In [3]:
# Assemble the run options and build the model.
base_opt = load_option('options/base.toml')
opt = load_option('options/train.toml')
# NOTE(review): base_opt is merged *after* the train options, so base.toml
# overrides train.toml on any shared key -- confirm this precedence is intended.
opt.update(base_opt)
run_overrides = {'name': model_name}  # tag this run with the model name
opt.update(run_overrides)
model = CSA(beta, **opt)

# Visualisations produced during training go to <model.save_dir>/images.
image_save_dir = model.save_dir / 'images'
mkdir(image_save_dir)

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
UnetGeneratorCSA(
  (model): UnetSkipConnectionBlock_3(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): UnetSkipConnectionBlock_3(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), dilation=(2, 2))
          (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (3): LeakyReLU(negative_slope=0.2, inplace=True)
          (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (6): UnetSkipConnectionBlock_3(
            (model): Sequential(
              (0): LeakyReLU(negat

## 模型训练

In [4]:
# Training phase.
# Each epoch pairs image batches with masks. A strip of input / ground truth /
# output is saved roughly every `display_freq` samples, and losses are printed
# roughly every `print_freq` samples. A checkpoint is written every
# `save_epoch_freq` epochs and always after the final epoch.
print_freq = 100  # log losses every this many samples (previously a bare literal)
start_epoch = 0
total_steps = 0
for epoch in range(start_epoch, epochs):
    epoch_start_time = time.time()
    epoch_iter = 0
    trainset = loader.trainset(alpha)  # rebuilt every epoch; alpha is forwarded to the loader
    # NOTE(review): zip() stops at the shorter of trainset / maskset.
    for batch, mask in zip(trainset, loader.maskset):
        image = batch[0]
        mask = mask_op(mask)
        prev_steps = total_steps
        total_steps += model.batch_size
        epoch_iter += model.batch_size
        # it not only sets the input data with mask, but also sets the latent mask.
        model.set_input(image, mask)
        model.set_gt_latent()
        model.optimize_parameters()
        # Fire when total_steps crosses a multiple of the frequency. This matches
        # `total_steps % freq == 0` when batch_size divides freq, and still
        # fires when it does not.
        if total_steps // display_freq > prev_steps // display_freq:
            real_A, real_B, fake_B = model.get_current_visuals()
            # real_A = input, real_B = ground truth, fake_B = output.
            # (x + 1) / 2 rescales, presumably from a [-1, 1] generator range,
            # into the [0, 1] range save_image expects -- TODO confirm.
            pic = (np.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            image_name = f"epoch{epoch}-{total_steps}-{alpha}.jpg"
            save_image(pic, image_save_dir/image_name, nrow=1)
        if total_steps // print_freq > prev_steps // print_freq:
            errors = model.get_current_errors()
            print(
                f"Epoch/total_steps/alpha-beta: {epoch}/{total_steps}/{alpha}-{beta}", dict(errors))
    # Always checkpoint the last epoch, even when save_epoch_freq does not divide it.
    is_last_epoch = epoch == epochs - 1
    if epoch % save_epoch_freq == 0 or is_last_epoch:
        print(f'保存模型 Epoch {epoch}, iters {total_steps} 在 {model.save_dir}')
        model.save(epoch)
    print(
        f'Epoch/Epochs {epoch}/{epochs-1} 花费时间：{time.time() - epoch_start_time}s')
    model.update_learning_rate()

Epoch/total_steps/alpha-beta: 0/100/1-0.2 {'G_GAN': 5.532593727111816, 'G_L1': 49.57136154174805, 'D': 0.9619242548942566, 'F': 0.08634133636951447}
Epoch/total_steps/alpha-beta: 0/200/1-0.2 {'G_GAN': 5.667086601257324, 'G_L1': 38.517059326171875, 'D': 1.0354654788970947, 'F': 0.07919453084468842}
Epoch/total_steps/alpha-beta: 0/300/1-0.2 {'G_GAN': 8.042593955993652, 'G_L1': 19.413557052612305, 'D': 0.25185173749923706, 'F': 0.06721243262290955}
Epoch/total_steps/alpha-beta: 0/400/1-0.2 {'G_GAN': 6.005727767944336, 'G_L1': 17.635417938232422, 'D': 0.7998839020729065, 'F': 0.02251690998673439}
Epoch/total_steps/alpha-beta: 0/500/1-0.2 {'G_GAN': 5.424815654754639, 'G_L1': 18.42581558227539, 'D': 1.126939058303833, 'F': 0.02911282517015934}
Epoch/total_steps/alpha-beta: 0/600/1-0.2 {'G_GAN': 5.823899745941162, 'G_L1': 25.875961303710938, 'D': 0.7367348670959473, 'F': 0.037544719874858856}
Epoch/total_steps/alpha-beta: 0/700/1-0.2 {'G_GAN': 5.889689922332764, 'G_L1': 16.254438400268555, 'D

KeyboardInterrupt: 