<a href="https://colab.research.google.com/github/oyanrayring/HiAI-Engine/blob/master/wgan_gp_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import time
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.autograd import Variable

In [5]:
import os
IMAGE_DIR = 'drive/MyDrive/Colab Notebooks/acg_images/faces'
os.listdir('drive/MyDrive/Colab Notebooks/acg_images')

['faces']

In [6]:
# 训练轮数
N_EPOCHS = 3
# 批处理数量
BATCH_SIZE = 64
# 图片大小
IMAGE_SIZE = 96
# 学习率
LEARNING_RATE = 0.0002
# b1
BIAS_1 = 0.5
# b2
BIAS_2 = 0.999
# 维度
LATENT_DIM = 100
# 通道数
CHANNELS = 3
# 生成器生成5次后，鉴别器优化1次
N_CRITIC = 5
# 每200个BATCH后输出一次样例数据
SAMPLE_INTERVAL = 200

In [None]:
torch.manual_seed(123)

In [8]:
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

In [9]:
class SourceDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_path = os.listdir(self.root_dir)

    def __getitem__(self, idx):
        img_name = self.image_path[idx]
        img_path = os.path.join(self.root_dir, img_name)
        image_data = Image.open(img_path).convert('RGB')
        return transform(image_data)

    def __len__(self):
        return len(self.image_path)


In [10]:
dataloader = torch.utils.data.DataLoader(
    SourceDataset(IMAGE_DIR),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [11]:
os.makedirs("images", exist_ok=True)
os.makedirs("logs", exist_ok=True)
os.makedirs("models", exist_ok=True)

In [12]:
cuda = True if torch.cuda.is_available() else False

In [13]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = 12  # 计算得到的初始尺寸
        self.l1 = nn.Sequential(nn.Linear(100, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),  # 12x12 -> 24x24
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),  # 24x24 -> 48x48
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),  # 48x48 -> 96x96
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


In [14]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128)
        )
        ds_size = 96 // 2 ** 4  # 确保尺寸一致
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1))

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity


In [15]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
  generator.cuda()
  discriminator.cuda()

In [16]:
lambda_gp = 10

In [17]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BIAS_1, BIAS_2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BIAS_1, BIAS_2))


In [18]:
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [19]:
# 梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


In [20]:
writer = SummaryWriter("logs/wgan_gp_cnn")

In [23]:
def train():
  batches_done = 0
  start_time = int(time.time())
  for epoch in range(N_EPOCHS):
    for i, imgs in enumerate(dataloader):
      # Configure input
      real_imgs = Variable(imgs.type(Tensor))

      # ---------------------
      #  Train Discriminator
      # ---------------------
      optimizer_D.zero_grad()
      # Sample noise as generator input
      z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], LATENT_DIM))))
      # Generate a batch of images
      fake_imgs = generator(z)
      # Real images
      real_validity = discriminator(real_imgs)
      # Fake images
      fake_validity = discriminator(fake_imgs)
      # Gradient penalty
      gp = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
      # Adversarial loss
      d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
      d_loss.backward()
      optimizer_D.step()

      optimizer_G.zero_grad()
      # Train the generator every n_critic steps
      if i % opt.n_critic == 0:
        # -----------------
        #  Train Generator
        # -----------------
        # Generate a batch of images
        fake_imgs = generator(z)
        # Loss measures generator's ability to fool the discriminator
        # Train on fake images
        fake_validity = discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)
        writer.add_scalars(
          "loss",
          {"generator": g_loss.item(), "discriminator": d_loss.item()},
          batches_done
        )

        g_loss.backward()
        optimizer_G.step()

        print(f"[Epoch {epoch}/{N_EPOCHS}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
        # 保存样例
        if batches_done % opt.sample_interval == 0:
          save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
          # tensorboard中展示的图片
          writer.add_image("generate images", fake_imgs[0], batches_done)

        batches_done += opt.n_critic
        pass
    # end one epoch
    print(f"one cost {int(time.time()) - start_time} seconds .\n")
    if epoch % 10 == 0:
      # save model
      torch.save(generator.state_dict(), f"models/generator_{epoch}.pth")
      torch.save(discriminator.state_dict(), f"models/discriminator_{epoch}.pth")


In [24]:
train()
print("finished")

NameError: name 'opt' is not defined