In [1]:
import time
import toml
import torch as np
from torchvision.utils import save_image
from torch.utils import data

from tools.celeba import CelebALoader
from dataset import Split

from models.models import create_model
from tools.toml import load_option
from tools.mask import mask_iter

## 数据载入

In [2]:
# Build the CelebA training data and the mask source, both configured by
# options/header.toml. Training samples come from `dataset.train('bbox')`.
header_opt = load_option('options/header.toml')
loader = CelebALoader(header_opt.data_root)
dataset = Split(loader, header_opt.fine_size)
train_data = dataset.train('bbox')
trainset = data.DataLoader(train_data,
                           batch_size=header_opt.batch_size,
                           shuffle=True)
maskset = mask_iter(header_opt.mask_root, header_opt.fine_size)
# Report the number of training *samples*: len(trainset) is the number of
# batches, which only equals the sample count when batch_size == 1.
print('训练集数量：', len(train_data))
print('掩码数量：', len(maskset))

训练集数量： 162770
掩码数量： 12000


## 模型载入

In [3]:
def mask_op(mask, device='cuda'):
    """Turn one batch from ``maskset`` into a single model-ready mask.

    Only ``mask[0][0]`` (first element of the batch, first channel) is kept.
    Two leading singleton dimensions are added and the result is cast to uint8.

    Args:
        mask: tensor indexable as ``mask[0][0]``; presumably shaped
            ``(batch, channels, H, W)`` -- TODO confirm against ``mask_iter``.
        device: device for the result. Defaults to ``'cuda'`` (the original,
            hard-coded behaviour); pass ``'cpu'`` to run without a GPU.

    Returns:
        uint8 tensor on ``device``; if ``mask[0][0]`` is 2-D ``(H, W)`` the
        result is ``(1, 1, H, W)``.
    """
    # Slice before moving so only one mask, not the whole batch, is transferred.
    single = mask[0][0].to(device)
    return single.unsqueeze(0).unsqueeze(1).byte()

In [4]:
# Build the model from its option file; create_model prints the network summary.
# NOTE(review): the model is configured from options/train2.toml, while the
# training-loop cell reads options/train.toml -- confirm the two are meant to differ.
model_opt = load_option('options/train2.toml')
model = create_model(model_opt)

csa_net
initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
UnetGeneratorCSA(
  (model): UnetSkipConnectionBlock_3(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): UnetSkipConnectionBlock_3(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), dilation=(2, 2))
          (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (3): LeakyReLU(negative_slope=0.2, inplace=True)
          (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (6): UnetSkipConnectionBlock_3(
            (model): Sequential(
              (0): LeakyRe

In [5]:
# Resume training from a saved checkpoint. `opt` (options/train.toml) drives the
# loop: batch size, epoch count, visual/log/save frequencies and output directory.
opt = load_option('options/train.toml')
total_steps = 0  # training samples processed since this cell started
iter_start_time = time.time()
# Load an already-trained checkpoint and continue from the following epoch.
load_epoch = 1
model.load(load_epoch)
start_epoch = load_epoch + 1
# Build the loader once: shuffle=True already reshuffles on every new pass,
# so recreating it each epoch is unnecessary.
trainset = data.DataLoader(train_data,
                           batch_size=opt.batch_size,
                           shuffle=True)
# NOTE(review): range() excludes its end, so the last epoch trained is
# opt.epochs - 1 -- confirm whether opt.epochs is meant to be inclusive.
for epoch in range(start_epoch, opt.epochs):
    epoch_start_time = time.time()
    epoch_iter = 0
    # zip() stops at the shorter iterable, so each epoch runs at most
    # len(maskset) batches, even if trainset has more.
    for batch, mask in zip(trainset, maskset):
        image = batch[0]
        mask = mask_op(mask)
        # Count with the batch size the loader actually uses (opt, not header_opt).
        total_steps += opt.batch_size
        epoch_iter += opt.batch_size
        # Sets both the masked input and the latent mask.
        model.set_input(image, mask)
        model.set_gt_latent()
        model.optimize_parameters()
        if total_steps % opt.display_freq == 0:
            # real_A = input, real_B = ground truth, fake_B = output
            real_A, real_B, fake_B = model.get_current_visuals()
            # (x + 1) / 2 rescales images assumed to lie in [-1, 1] to [0, 1].
            pic = (np.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            save_image_path = f"{opt.save_dir}/epoch{epoch}-{total_steps}.jpg"
            save_image(pic, save_image_path, nrow=3)
        if total_steps % 100 == 0:
            errors = model.get_current_errors()
            print(f"Epoch/total_steps: {epoch}/{total_steps}", dict(errors))
    if epoch % opt.save_epoch_freq == 0:
        # f-string so the epoch and step count are actually interpolated.
        print(f'保存模型 Epoch {epoch}, iters {total_steps}')
        model.save(epoch)
    print(f'Epoch/Epochs {epoch}/{opt.epochs} 花费时间：{time.time() - epoch_start_time}s')
    model.update_learning_rate()

Epoch/total_steps: 1/100 {'G_GAN': 5.090385437011719, 'G_L1': 10.439626693725586, 'D': 1.0404441356658936, 'F': 0.1432613879442215}
Epoch/total_steps: 1/200 {'G_GAN': 4.885292053222656, 'G_L1': 8.919397354125977, 'D': 1.0004996061325073, 'F': 0.0711200013756752}
Epoch/total_steps: 1/300 {'G_GAN': 5.131747245788574, 'G_L1': 10.092893600463867, 'D': 1.000459909439087, 'F': 0.043339915573596954}
Epoch/total_steps: 1/400 {'G_GAN': 5.212510108947754, 'G_L1': 10.455636978149414, 'D': 1.0184283256530762, 'F': 0.06712973117828369}
Epoch/total_steps: 1/500 {'G_GAN': 4.740256309509277, 'G_L1': 10.04600715637207, 'D': 1.0535664558410645, 'F': 0.08418521285057068}
Epoch/total_steps: 1/600 {'G_GAN': 5.032135963439941, 'G_L1': 9.241515159606934, 'D': 0.9591658115386963, 'F': 0.03271446377038956}
Epoch/total_steps: 1/700 {'G_GAN': 5.258563041687012, 'G_L1': 9.529699325561523, 'D': 1.5489401817321777, 'F': 0.015598400495946407}
Epoch/total_steps: 1/800 {'G_GAN': 5.310083866119385, 'G_L1': 11.713209152

Epoch/total_steps: 1/6300 {'G_GAN': 6.106054306030273, 'G_L1': 8.350011825561523, 'D': 0.4443332552909851, 'F': 0.011299334466457367}
Epoch/total_steps: 1/6400 {'G_GAN': 8.526281356811523, 'G_L1': 18.161998748779297, 'D': 0.08440662920475006, 'F': 0.006201015319675207}
Epoch/total_steps: 1/6500 {'G_GAN': 7.981949806213379, 'G_L1': 14.778793334960938, 'D': 0.11010622978210449, 'F': 0.005261803045868874}
Epoch/total_steps: 1/6600 {'G_GAN': 4.996937274932861, 'G_L1': 8.797558784484863, 'D': 1.6286365985870361, 'F': 0.007330423686653376}
Epoch/total_steps: 1/6700 {'G_GAN': 5.450842380523682, 'G_L1': 11.53435230255127, 'D': 0.8747591972351074, 'F': 0.005793140269815922}
Epoch/total_steps: 1/6800 {'G_GAN': 6.936707496643066, 'G_L1': 13.656780242919922, 'D': 0.3565933108329773, 'F': 0.007886327803134918}
Epoch/total_steps: 1/6900 {'G_GAN': 8.424558639526367, 'G_L1': 11.842146873474121, 'D': 0.17020317912101746, 'F': 0.008935798890888691}
Epoch/total_steps: 1/7000 {'G_GAN': 6.290061950683594, 

Epoch/total_steps: 2/12300 {'G_GAN': 5.829366683959961, 'G_L1': 9.35719108581543, 'D': 0.6759192943572998, 'F': 0.004774569533765316}


KeyboardInterrupt: 