In [None]:
# 使用するライブラリのimport
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from torchvision import datasets
from torchvision.utils import save_image


In [None]:
# Constants
IMAGES_PATH = "images"                       # output folder for the generated sample images
DATA_PATH = "data/mnist"                     # folder the MNIST dataset is downloaded to and read from

EPOCHS = 100                                 # number of training epochs
BATCH_SIZE = 64                              # mini-batch size

LEARNING_RATE = 2e-4                         # Adam: learning rate
B1 = 0.5                                     # Adam: first element of `betas`
B2 = 0.999                                   # Adam: second element of `betas`

LATENT_DIM = 62                              # dimensionality of the latent vector z
IMG_SIZE = 32                                # images are resized to IMG_SIZE x IMG_SIZE
CHANNELS = 1                                 # number of image channels (1 = greyscale)
img_shape = (CHANNELS, IMG_SIZE, IMG_SIZE)   # image shape tuple: (channels, size, size)

SAMPLE_INTERVAL = 50                         # sample generated images every SAMPLE_INTERVAL batches

# EBGAN hyperparameters
LAMBDA_PT = 0.1                              # weight of the pull-away term in the generator loss
MARGIN = max(1, BATCH_SIZE / 64.0)           # margin of the hinge in the discriminator loss

cuda = torch.cuda.is_available()             # True when a CUDA GPU can be used
print('GPU Check! cuda is ', cuda)


In [None]:
# Weight initialisation function
def weights_init_normal(m):
    """Re-initialise a single layer in place, chosen by its class name.

    * class name contains "Conv": weights are drawn from N(mean=0.0, std=0.02)
    * otherwise, class name contains "BatchNorm2d": weights are drawn from
      N(mean=1.0, std=0.02) and the bias is set to 0
    * anything else (e.g. Linear, BatchNorm1d) is left untouched

    Args:
        m: the module to initialise; must expose ``weight`` (and ``bias`` for
           the BatchNorm2d case) when one of the branches matches.
    """
    layer_name = m.__class__.__name__

    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Generator definition
class Generator(nn.Module):
    """Generator network: turns latent vectors into images.

    A linear layer expands each latent vector (LATENT_DIM features) into a
    128-channel feature map of side ``IMG_SIZE // 4``. Two
    (upsample x2 -> 3x3 conv -> batch norm -> LeakyReLU) stages follow, then a
    3x3 convolution down to CHANNELS channels and a tanh.
    """

    def __init__(self):
        super().__init__()

        # Side length of the feature map fed to the convolutional stack.
        self.init_size = IMG_SIZE // 4
        projected_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(LATENT_DIM, projected_features))

        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # so the 0.8 used here is the epsilon added to the variance, not the
        # momentum. It is spelled as a keyword to make that visible; the value
        # is unchanged. Confirm whether momentum was intended.
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, CHANNELS, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: latent batch whose last dimension is LATENT_DIM.

        Returns:
            Tensor with CHANNELS channels and spatial side ``4 * init_size``;
            values lie in [-1, 1] because the last layer is a tanh.
        """
        projected = self.l1(z)
        # Reshape the flat projection into (batch, 128, init_size, init_size).
        feature_map = projected.view(projected.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(feature_map)

In [None]:
# Discriminator definition
class Discriminator(nn.Module):
    """Auto-encoder style discriminator.

    The image is downsampled by a stride-2 convolution, flattened and projected
    to a 32-dimensional embedding. A fully connected decoder expands the
    embedding back to the downsampled feature-map size, which is then upsampled
    and convolved back to CHANNELS channels. ``forward`` returns both the
    reconstruction and the embedding.
    """

    def __init__(self):
        super().__init__()

        # Encoder: one stride-2 3x3 convolution followed by ReLU.
        self.down = nn.Sequential(
            nn.Conv2d(CHANNELS, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Side length and flattened size of the encoder output.
        self.down_size = IMG_SIZE // 2
        down_dim = 64 * (IMG_SIZE // 2) ** 2

        self.embedding = nn.Linear(down_dim, 32)

        # NOTE(review): the second positional argument of BatchNorm1d is `eps`,
        # so 0.8 is the epsilon of the first batch-norm layer, not its momentum.
        # It is written as a keyword here; the value is unchanged.
        decoder_layers = [
            nn.BatchNorm1d(32, eps=0.8),
            nn.ReLU(inplace=True),
            nn.Linear(32, down_dim),
            nn.BatchNorm1d(down_dim),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*decoder_layers)

        # Decoder head: upsample x2, then a 3x3 convolution back to CHANNELS.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, CHANNELS, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, img):
        """Reconstruct ``img`` and expose its embedding.

        Args:
            img: image batch with CHANNELS channels.

        Returns:
            Tuple ``(reconstruction, embedding)``; the embedding has 32 features
            per sample.
        """
        encoded = self.down(img)
        flat = encoded.view(encoded.size(0), -1)
        embedding = self.embedding(flat)

        decoded = self.fc(embedding)
        decoded_map = decoded.view(decoded.size(0), 64, self.down_size, self.down_size)
        reconstruction = self.up(decoded_map)
        return reconstruction, embedding


In [None]:
# Loss function: mean squared error between an image and its reconstruction
pixelwise_loss = nn.MSELoss()

# Instantiate the Generator and the Discriminator
generator = Generator()
discriminator = Discriminator()

# Move the networks and the loss to the GPU when one is available
if cuda:
    for gpu_module in (generator, discriminator, pixelwise_loss):
        gpu_module.cuda()

# Initialise the weights (generator first, then discriminator)
for network in (generator, discriminator):
    network.apply(weights_init_normal)

# Load MNIST (downloaded into DATA_PATH on first use)
os.makedirs(DATA_PATH, exist_ok=True)

# Resize to IMG_SIZE, convert to a tensor, then shift/scale with mean 0.5 and
# std 0.5 so that [0, 1] pixel values become [-1, 1].
mnist_transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
mnist_train = datasets.MNIST(
    DATA_PATH,
    train=True,
    download=True,
    transform=mnist_transform,
)
dataloader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Create the output folder for generated images
os.makedirs(IMAGES_PATH, exist_ok=True)

In [None]:
# Optimisers: one Adam instance per network, sharing the same hyperparameters
adam_options = {"lr": LEARNING_RATE, "betas": (B1, B2)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_options)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_options)

# Float tensor type matching the device in use (CUDA when available, else CPU)
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
# ----------
#  Training
# ----------
def pullaway_loss(embeddings):
    """Pulling-away term (PT) of EBGAN.

    Computes the mean squared cosine similarity over all ordered pairs of
    distinct rows of ``embeddings``, i.e.
    ``1 / (N * (N - 1)) * sum_{i != j} cos(e_i, e_j) ** 2``.
    Minimising it pushes the embeddings of a batch towards being mutually
    orthogonal.

    The training step below calls this function and no other definition is
    visible in this notebook, so it is defined here.

    Args:
        embeddings: 2-D tensor of shape (batch, features).

    Returns:
        Scalar tensor. Zero when the batch holds fewer than two samples, since
        there are no pairs to compare (this also avoids dividing by zero).
    """
    batch_size = embeddings.size(0)
    if batch_size < 2:
        return embeddings.new_zeros(())

    # F.normalize guards against zero-norm rows through its internal epsilon.
    unit_embeddings = F.normalize(embeddings, p=2, dim=1)
    squared_cosine = (unit_embeddings @ unit_embeddings.t()) ** 2
    # Remove the self-similarity terms on the diagonal.
    off_diagonal_sum = squared_cosine.sum() - squared_cosine.diagonal().sum()
    return off_diagonal_sum / (batch_size * (batch_size - 1))


img_list = []   # sampled image batches (on CPU), one entry per sampling step
G_losses = []   # generator loss at each sampling step
D_losses = []   # discriminator loss at each sampling step

num_batches = len(dataloader)

for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):

        # Real images, converted to the tensor type of the active device
        real_imgs = imgs.type(Tensor)

        # -----------------
        #  Train the Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample latent noise from a standard normal distribution
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], LATENT_DIM)))

        # Generator loss: reconstruction error of D on G(z) plus the weighted
        # pull-away term. The reconstruction target is detached, so gradients
        # reach the generator only through the discriminator's input.
        gen_imgs = generator(z)
        recon_imgs, img_embeddings = discriminator(gen_imgs)
        g_loss = pixelwise_loss(recon_imgs, gen_imgs.detach()) + LAMBDA_PT * pullaway_loss(img_embeddings)

        # Back-propagate and update the generator
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train the Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Reconstruction error on real images
        real_recon, _ = discriminator(real_imgs)
        d_loss_real = pixelwise_loss(real_recon, real_imgs)

        # Reconstruction error on generated images (generator is not updated here)
        fake_imgs = gen_imgs.detach()
        fake_recon, _ = discriminator(fake_imgs)
        d_loss_fake = pixelwise_loss(fake_recon, fake_imgs)

        # Hinge with margin: the fake term contributes only while
        # MARGIN - d_loss_fake is positive. Building a new tensor avoids
        # mutating d_loss_real in place through an alias.
        d_loss = d_loss_real + F.relu(MARGIN - d_loss_fake)

        # Back-propagate and update the discriminator
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, EPOCHS, i, num_batches, d_loss.item(), g_loss.item())
        )

        batches_done = epoch * num_batches + i
        if batches_done % SAMPLE_INTERVAL == 0:
            samples = gen_imgs.detach()[:25]
            save_image(samples, IMAGES_PATH + "/%07d.png" % batches_done, nrow=5, normalize=True)
            # Collect log information
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            # Keep an independent CPU copy so the history neither occupies GPU
            # memory nor keeps the whole generated batch alive.
            img_list.append(samples.cpu().clone())

In [None]:
# Visualise the training results
# (matplotlib.pyplot as plt and torchvision.utils as vutils are imported in the
# notebook's import cell.)

# Losses were recorded once every SAMPLE_INTERVAL batches, so rebuild the
# iteration number of each recorded point for the x-axis.
sampled_iterations = [step * SAMPLE_INTERVAL for step in range(len(G_losses))]

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(sampled_iterations, G_losses, label="G")
ax.plot(sampled_iterations, D_losses, label="D")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

# Save the generated images as an animated GIF.
# Every GIF_FRAME_STRIDE-th sampled batch becomes one frame: the stored tensor
# is tiled into a 5-column grid and converted to a PIL image.
GIF_FRAME_STRIDE = 10
gif_frames = [
    transforms.functional.to_pil_image(
        vutils.make_grid(sample_batch.cpu(), padding=2, nrow=5, normalize=True)
    )
    for sample_batch in img_list[::GIF_FRAME_STRIDE]
]

if gif_frames:
    gif_frames[0].save(
        "./Generator.gif", save_all=True, append_images=gif_frames[1:], duration=500, loop=0
    )
else:
    # Nothing was sampled (e.g. training was not run); there is no frame to write.
    print("No sampled images available; skipping GIF export.")