In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR
import torchvision.utils as vutils

from utils.loss import ContentLoss, AdversialLoss
from utils.transforms import get_default_transforms, get_no_aug_transform
from utils.datasets import get_dataloader
from utils.transforms import get_pair_transforms
from torch.utils.tensorboard import SummaryWriter
from models.discriminator import Discriminator
from models.generator import Generator

from datetime import datetime
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import pickle
import os

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Config
batch_size = 8
image_size = 256
learning_rate = 1.5e-4
beta1, beta2 = (.5, .99)
weight_decay = 1e-4
epochs = 100

# Models
netD = Discriminator().to(device)
netG = Generator().to(device)

# netG = torch.load("./checkpoints/pretrained_netG.pth")
netG.load_state_dict(torch.load("./checkpoints/init_netG.pth"))

# 优化器配置
optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

schedulerD = CyclicLR(optimizer=optimizerD, base_lr=learning_rate, max_lr=learning_rate*1e1, cycle_momentum=False)
schedulerG = CyclicLR(optimizer=optimizerG, base_lr=learning_rate, max_lr=learning_rate*1e1, cycle_momentum=False)

# Labels
cartoon_labels = torch.ones (batch_size, 1, image_size // 4, image_size // 4).to(device)
fake_labels    = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

# Loss functions 损失函数
content_loss = ContentLoss(omega = 0.1,device = device)
adv_loss     = AdversialLoss(cartoon_labels, fake_labels)
BCE_loss     = nn.BCELoss().to(device)
#内容损失、对抗损失和交叉熵损失

# Dataloaders
torch.manual_seed(1)
real_dataloader    = get_dataloader("./datasets/real_images/flickr_nuneaton", size = image_size, bs = batch_size)
cartoon_dataloader = get_dataloader("./datasets/cartoon_images_smoothed/Studio Ghibli", size = image_size, bs = batch_size, trfs=get_pair_transforms(image_size))


tracked_images = next(iter(real_dataloader)).to(device)
last_epoch = 0
last_i = 0




Save original tracked images for comparison.

In [9]:
original_images = tracked_images.detach().cpu()
grid = vutils.make_grid(original_images, padding=2, normalize=True, nrow=3)
plt.imsave(f"./results/original.png", np.transpose(grid, (1,2,0)).numpy())

In [None]:
# netG.load_state_dict(torch.load("./checkpoints/_trained_netG.pth"))
# netD.load_state_dict(torch.load("./checkpoints/_trained_netD.pth"))

In [None]:
# #训练状态回复
# with open("./checkpoints/iter_data.pickle", "rb") as handle:
#     last_epoch, last_i = pickle.load(handle)

In [None]:
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []

start_epoch = last_epoch
start_i = last_i

print("Starting Training Loop...")
# 每个epoch
for epoch in range(start_epoch, epochs):
    # 加载迭代器
    real_dl_iter = iter(real_dataloader)
    cartoon_dl_iter = iter(cartoon_dataloader)
    iterations =  min(len(real_dl_iter), len(cartoon_dl_iter))
    
    for i in range(start_i, iterations):
        real_data = next(real_dl_iter)
        cartoon_data = next(cartoon_dl_iter)
        
        # （一）训练判别器
        netD.train()
        netG.eval()
        
        netD.zero_grad()
        
        edge_data = cartoon_data[:, :, :, :image_size]   
        cartoon_data = cartoon_data[:, :, :, image_size:]

        # Format batch.
        cartoon_data   = cartoon_data.to(device)
        edge_data      = edge_data.to(device)
        real_data      = real_data.to(device)

        # Generate image
        generated_data = netG(real_data)
        
        # Forward pass all batches through D.
        cartoon_pred   = netD(cartoon_data)      #.view(-1)
        edge_pred      = netD(edge_data)         #.view(-1)
        generated_pred = netD(generated_data)    #.view(-1)
        
        errD = adv_loss(cartoon_pred, generated_pred, edge_pred)
        errD.backward()
        
        D_x = torch.sigmoid(cartoon_pred).mean().item() # Should be close to 1

        optimizerD.step()

        # （二）训练G
        netG.train()
        netD.eval()

        for param in netD.parameters():
            param.requires_grad = False

        netG.zero_grad()

        generated_data = netG(real_data)
        generated_pred = netD(generated_data) #.view(-1)

        if generated_data.shape[1] == 1:
            generated_data = generated_data.repeat(1, 3, 1, 1)
        if real_data.shape[1] == 1:
            real_data = real_data.repeat(1, 3, 1, 1)

        generated_data = generated_data.clamp(0, 1)
        real_data = real_data.clamp(0, 1)
        cartoon_data = cartoon_data.clamp(0, 1)
        edge_data = edge_data.clamp(0, 1)

        generated_pred = torch.sigmoid(generated_pred)
        
        adv = BCE_loss(generated_pred, cartoon_labels)
        content = content_loss(generated_data, real_data)
        print(f"Adv loss: {adv.item():.4f}, Content loss: {content.item():.4f}")
        errG = adv + content

        errG.backward()

        D_G_z2 = torch.sigmoid(generated_pred).mean().item() # Should be close to 1

        optimizerG.step()

        # Output training stats
        if i % 20 == 0:
            with torch.no_grad():
                fake = netG(tracked_images).detach().cpu()
                
            grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=3)
            time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            plt.imsave(f"./results/E{epoch}_i{i}.png", np.transpose(grid, (1,2,0)).numpy())
            img_list.append(grid)
            
            print(('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f\t%s'
                % (epoch, epochs, i, iterations, errD.item(), errG.item(), D_x, D_G_z2, time)).expandtabs(25) )
            
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        schedulerD.step()
        schedulerG.step()
        
        last_i = i
    
    # schedulerD.step()
    # schedulerG.step()
    start_i = 0
    last_epoch = epoch

Starting Training Loop...
Adv loss: 0.6826, Content loss: 0.0460
[0/100][0/62]            Loss_D: 2.2740           Loss_G: 0.7286           D(x): 0.5503             D(G(z)): 0.6237          2025-04-27_22-14-21
Adv loss: 0.6762, Content loss: 0.0373
Adv loss: 0.6686, Content loss: 0.0344
Adv loss: 0.6634, Content loss: 0.0380
Adv loss: 0.6548, Content loss: 0.0404
Adv loss: 0.6446, Content loss: 0.0426
Adv loss: 0.6327, Content loss: 0.0441
Adv loss: 0.6219, Content loss: 0.0398
Adv loss: 0.5982, Content loss: 0.0369
Adv loss: 0.5732, Content loss: 0.0386
Adv loss: 0.5076, Content loss: 0.0523
Adv loss: 0.4870, Content loss: 0.0548
Adv loss: 0.4510, Content loss: 0.0446
Adv loss: 0.4166, Content loss: 0.0588
Adv loss: 0.3827, Content loss: 0.0594
Adv loss: 0.3449, Content loss: 0.0584
Adv loss: 0.3195, Content loss: 0.0607
Adv loss: 0.2844, Content loss: 0.0650
Adv loss: 0.2732, Content loss: 0.0596
Adv loss: 0.2517, Content loss: 0.0574
Adv loss: 0.2192, Content loss: 0.0656
[0/100][20

KeyboardInterrupt: 

In [None]:
# Uncomment if you want to resume prevous traning.
torch.save(netG.state_dict(), "./checkpoints/_trained_netG.pth")
torch.save(netD.state_dict(), "./checkpoints/_trained_netD.pth")

In [None]:
#保存当前训练状态
with open("./checkpoints/iter_data.pickle", "wb") as handle:
    pickle.dump([last_epoch, last_i], handle)