In [None]:
import toml
from torch.utils import data

from tools.celeba import CelebALoader
from dataset import Split

from models.models import create_model
from tools.toml import load_option
from tools.mask import mask_iter

In [None]:
# Shared header options; the cells below read data_root, mask_root,
# fine_size and batch_size from this object.
HEADER_OPTION_FILE = "options/header.toml"
header_opt = load_option(HEADER_OPTION_FILE)
# CelebA loader rooted at data_root, wrapped by Split at the configured
# fine_size (Split presumably provides the train/eval partitions — TODO confirm).
loader = CelebALoader(header_opt.data_root)
dataset = Split(loader, header_opt.fine_size)

In [None]:
# Training partition of the dataset. The 'bbox' argument is passed straight
# to Split.train — presumably it selects bounding-box based cropping (TODO confirm).
train_data = dataset.train("bbox")

In [None]:
# Shuffled, batched view of the training split plus the mask source.
# In the visible cells the training loop rebuilds `trainset` with the batch
# size from options/train.toml, so this loader mainly serves the size report.
trainset = data.DataLoader(train_data,
                           batch_size=header_opt.batch_size,
                           shuffle=True)
maskset = mask_iter(header_opt.mask_root, header_opt.fine_size)
# len(DataLoader) counts batches, not samples: report the sample count from
# the underlying dataset and the batch count separately.
print('训练集数量：', len(train_data))
print('训练批次数：', len(trainset))
print('掩码数量：', len(maskset))

In [None]:
import time
import torch
import torchvision
# Model hyper-parameters live in their own option file (train2.toml), separate
# from the training-schedule options loaded in the training cell.
MODEL_OPTION_FILE = "options/train2.toml"
model_opt = load_option(MODEL_OPTION_FILE)
model = create_model(model_opt)

In [None]:
def mask_op(mask, device='cuda'):
    """Turn a batch of masks into a single uint8 mask on `device`.

    Only the first sample's first channel (``mask[0][0]``) is kept. It is
    reshaped to ``(1, 1, *mask.shape[2:])`` and cast to uint8.

    Args:
        mask: tensor indexable as ``mask[0][0]`` — presumably shaped
            (batch, channels, H, W); TODO confirm against mask_iter.
        device: target device. Defaults to ``'cuda'``, which matches the
            original ``.cuda()`` behaviour.

    Returns:
        A uint8 tensor of shape ``(1, 1, *mask.shape[2:])`` on ``device``.
    """
    # Slice before the transfer so only the one mask that is used gets moved to
    # the device, instead of the whole batch.
    first = mask[0][0]
    # Add the leading batch and channel axes: (*S) -> (1, 1, *S).
    first = first.unsqueeze(0).unsqueeze(1)
    return first.to(device).byte()

In [None]:
# Load an already-trained model. `load_epoch` is handed to model.load() —
# presumably it selects which saved checkpoint to restore (TODO confirm).
load_epoch: int = 0
model.load(load_epoch)

In [None]:
# Main training loop. The schedule (epochs, batch size, display/save
# frequencies, output dir) comes from options/train.toml.
opt = load_option('options/train.toml')
batch_size = opt.batch_size
# Logging interval in samples; `print_freq` may not exist in train.toml, so
# fall back to the previously hard-coded 20.
print_freq = getattr(opt, 'print_freq', 20)
total_steps = 0
iter_start_time = time.time()
# A DataLoader with shuffle=True draws a new permutation each time it is
# iterated, so one loader can be reused for every epoch.
trainset = data.DataLoader(train_data,
                           batch_size=batch_size,
                           shuffle=True)
for epoch in range(opt.epochs):
    epoch_start_time = time.time()
    epoch_iter = 0
    # NOTE(review): zip() stops at the shorter input, and if `maskset` is a
    # one-shot iterator it is exhausted after the first epoch — confirm what
    # mask_iter returns.
    for batch, mask in zip(trainset, maskset):
        iter_start_time = time.time()
        image = batch[0]
        mask = mask_op(mask)
        # Count with the batch size the loader was actually built with
        # (opt.batch_size), not header_opt.batch_size.
        total_steps += batch_size
        epoch_iter += batch_size
        # Sets the masked input data and also the latent mask.
        model.set_input(image, mask)
        model.set_gt_latent()
        model.optimize_parameters()
        # NOTE: modulus checks fire only when total_steps lands exactly on a
        # multiple of the frequency, i.e. when it is compatible with batch_size.
        if total_steps % opt.display_freq == 0:
            real_A, real_B, fake_B = model.get_current_visuals()
            # real_A = input, real_B = ground truth, fake_B = output;
            # shift from [-1, 1] to [0, 1] for saving.
            pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
            save_image_path = f"{opt.save_dir}/epoch{epoch}-{total_steps}.jpg"
            torchvision.utils.save_image(pic, save_image_path, nrow=2)
        if total_steps % print_freq == 0:
            errors = model.get_current_errors()
            # Wall-clock seconds per sample for the current iteration.
            t = (time.time() - iter_start_time) / batch_size
            print(f"Epoch/total_steps: {epoch}/{total_steps}", dict(errors),
                  f"{t:.4f}s/sample")
    if epoch % opt.save_epoch_freq == 0:
        print(f'保存模型 Epoch {epoch}, iters {total_steps}')
        model.save(epoch)
    print(f'Epoch/Epochs {epoch}/{opt.epochs} 花费时间：{time.time() - epoch_start_time}s')
    model.update_learning_rate()