In [None]:
# 必要なパッケージのインストール
!pip install torch torchvision
!pip install tqdm

In [None]:
# Mount Google Drive so the notebook can read data from and write results to it.
from google.colab import drive
drive_dir = '/content/drive'  # mount point; later cells build their paths from this
drive.mount(drive_dir)

In [None]:
# Google Drive is already mounted at `drive_dir` by the preceding cell.
# Re-importing and mounting a second time is redundant, so this cell intentionally does nothing.

In [None]:
# Import packages
import os
import sys
import time
import datetime
import struct
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Frequently used files and directories.
# Colab notebooks live under "My Drive/Colab Notebooks" inside the mounted drive.
_drive_root = os.path.join(drive_dir, 'My Drive')
colab_dir = os.path.join(_drive_root, 'Colab Notebooks')
# Checkpoint file name, written inside each run's log directory.
model_path = 'DCGAN_flower.pth'

In [None]:
# Dataset reader
class OxfordFlowerDataset(torch.utils.data.Dataset):
    """Oxford Flowers images, center-cropped to a square and resized.

    Each item is a dict ``{'images': array}`` where ``array`` is a float32,
    channel-first (3, resize, resize) RGB image with values in [0, 1].

    Args:
        root_dir: directory that directly contains the image files.
        resize: side length (pixels) of the square output image.
    """

    # Accepted image file extensions (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, resize=128):
        super(OxfordFlowerDataset, self).__init__()

        self.root_dir = root_dir
        self.resize = resize
        # Skip hidden files (e.g. .DS_Store) and anything that is not an image;
        # cv2.imread returns None for those and __getitem__ would fail.
        # Sorting makes the index -> file mapping deterministic across runs.
        names = sorted(
            f for f in os.listdir(self.root_dir)
            if not f.startswith('.') and f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.image_list = [os.path.join(self.root_dir, f) for f in names]

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load image ``idx``, center-crop it to a square, resize, and scale to [0, 1].

        Raises:
            OSError: if OpenCV cannot decode the file.
        """
        # Load the image (OpenCV returns BGR; convert to RGB).
        image_file = self.image_list[idx]
        image = cv2.imread(image_file, cv2.IMREAD_COLOR)
        if image is None:
            raise OSError('Failed to read image: {}'.format(image_file))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Center-crop to the largest square that fits.
        h, w, _ = image.shape
        size = min(h, w)
        cx, cy = w // 2, h // 2
        sx = max(0, cx - size // 2)
        sy = max(0, cy - size // 2)
        image = image[sy:sy+size, sx:sx+size, :]

        image = cv2.resize(image, (self.resize, self.resize))
        image = (image / 255.0).astype('float32')
        # (H, W, C) -> (C, H, W), the layout PyTorch convolutions expect.
        image = np.transpose(image, axes=(2, 0, 1))
        return {
            'images': image
        }

In [None]:
# Basic building blocks
class BlockG(nn.Module):
    """Generator conv block: 3x3 Conv (stride 1, padding 1) -> BatchNorm -> ReLU.

    Spatial size is preserved; channels change from ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
        ]
        # Kept under the name `net` so state_dict keys (net.0.*, net.1.*) stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class BlockD(nn.Module):
    """Discriminator conv block: 3x3 Conv (stride 1, padding 1) -> LeakyReLU(0.1).

    No normalization layer; spatial size is preserved.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        activation = nn.LeakyReLU(0.1)
        self.net = nn.Sequential(conv, activation)

    def forward(self, x):
        out = self.net(x)
        return out

class Up(nn.Module):
    """2x upsampling block: ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm -> ReLU.

    With these kernel/stride/padding values the output is (2H, 2W) for an (H, W) input.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size=4, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        self.net = nn.Sequential(upconv, norm, nn.ReLU(inplace=False))

    def forward(self, x):
        out = self.net(x)
        return out

class Down(nn.Module):
    """2x downsampling block: Conv2d(k=4, s=2, p=1) -> LeakyReLU(0.1).

    Halves even spatial sizes, (H, W) -> (H/2, W/2); no normalization layer.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


In [None]:
# Generator definition
class NetG(nn.Module):
    """DCGAN generator mapping a batch of latent vectors to images.

    A 4x4 transposed conv lifts the (N, in_features) latent to a 4x4 feature map.
    Five ``Up`` blocks each double the resolution (4 -> 128), and a final 3x3 conv
    produces ``out_channels`` maps, which tanh squashes into [-1, 1].

    Args:
        in_features: latent dimension.
        out_channels: number of output image channels.
        base_filters: channel-width multiplier.
    """
    def __init__(self, in_features=128, out_channels=3, base_filters=8):
        super().__init__()
        self.in_features = in_features
        self.out_channels = out_channels
        self.base_filters = base_filters

        f = self.base_filters
        # Channel widths: stem output, then the output of each Up block.
        widths = [f * 16, f * 16, f * 8, f * 4, f * 2, f * 1]
        layers = [
            nn.ConvTranspose2d(self.in_features, widths[0],
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
        ]
        layers += [Up(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        layers.append(nn.Conv2d(widths[-1], self.out_channels,
                                kernel_size=3, stride=1, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        n_batches, n_dims = x.size()
        # Treat each latent vector as an (n_dims, 1, 1) feature map.
        feature_map = x.view(n_batches, n_dims, 1, 1)
        return torch.tanh(self.net(feature_map))

In [None]:
# Discriminator definition
class NetD(nn.Module):
    """DCGAN discriminator returning one raw logit per image.

    For 128x128 inputs, the stride-1 ``BlockD`` keeps 128x128, five ``Down``
    blocks halve it to 4x4, and the final 4x4 conv (no padding) reduces it to 1x1.

    Args:
        in_channels: number of input image channels.
        base_filters: channel-width multiplier.
    """
    def __init__(self, in_channels=3, base_filters=8):
        super(NetD, self).__init__()
        self.in_channels = in_channels
        self.base_filters = base_filters

        self.net = nn.Sequential(
            BlockD(self.in_channels, self.base_filters),
            Down(self.base_filters, self.base_filters * 2),
            Down(self.base_filters * 2, self.base_filters * 4),
            Down(self.base_filters * 4, self.base_filters * 8),
            Down(self.base_filters * 8, self.base_filters * 16),
            Down(self.base_filters * 16, self.base_filters * 16),
            nn.Conv2d(self.base_filters * 16, 1, kernel_size=4, stride=1, padding=0)
        )

    def forward(self, x):
        """Return logits of shape (N,) for a batch of N 128x128 images.

        No sigmoid is applied here because training uses BCEWithLogitsLoss.
        """
        x = self.net(x)
        # (N, 1, 1, 1) -> (N,). A bare squeeze() would also drop the batch
        # dimension when N == 1 and return a 0-d tensor.
        return x.flatten(start_dim=1).squeeze(dim=1)

In [None]:
# Select the compute device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
print('Device: {}'.format(device))

In [None]:
# Hyperparameters
sample_dims = 32            # dimension of the latent vector z
base_lr = 2.0e-4            # learning rate
beta1 = 0.5                 # Adam's beta1
base_filters = 32           # base channel count of the CNNs
data_root = 'OxfordFlower'  # dataset directory (joined onto colab_dir when loading)
total_epochs = 20           # total number of training epochs (increase as needed)

In [None]:
# Build the networks, optimizers, and loss.
netD = NetD(in_channels=3, base_filters=base_filters)
netG = NetG(in_features=sample_dims, base_filters=base_filters)
netD.to(device)
netG.to(device)
optimD = torch.optim.Adam(netD.parameters(), lr=base_lr, betas=(beta1, 0.999))
optimG = torch.optim.Adam(netG.parameters(), lr=base_lr, betas=(beta1, 0.999))
# The discriminator outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# Load a checkpoint (to continue training, set `resume` to the run folder name under runs/).
resume = ''
start_epoch = 0
start_steps = 0
if resume != '':
    # Restore from a saved checkpoint.
    log_dir = os.path.join(colab_dir, 'runs', resume)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime
    # (and vice versa); without it torch.load tries to restore tensors to the saving device.
    ckpt = torch.load(os.path.join(log_dir, model_path), map_location=device)
    optimG.load_state_dict(ckpt['optimG'])
    optimD.load_state_dict(ckpt['optimD'])
    netG.load_state_dict(ckpt['netG'])
    netD.load_state_dict(ckpt['netD'])
    # The stored epoch is the last one completed, so continue with the next one.
    start_epoch = ckpt['epoch'] + 1
    start_steps = ckpt['steps']
else:
    # Create a fresh, timestamped folder for this run's snapshots and checkpoints.
    now = datetime.datetime.now()
    time_stamp = now.strftime('%Y%m%d-%H%M%S')
    runs_dir = os.path.join(colab_dir, 'runs')
    log_dir = os.path.join(runs_dir, time_stamp)
    os.makedirs(log_dir, exist_ok=True)

In [None]:
# Prepare the dataset and the data loader.
batch_size = 25
dataset = OxfordFlowerDataset(os.path.join(colab_dir, data_root), resize=128)
# Fail early with a clear message. An empty folder would otherwise surface as a
# cryptic sampler error, and fewer images than one batch would silently yield
# zero batches because of drop_last=True.
if len(dataset) < batch_size:
    raise ValueError('Found {} images in {}; need at least {} (one batch).'.format(
        len(dataset), dataset.root_dir, batch_size))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)

In [None]:
# Training loop
steps = start_steps
# Train up to the configured epoch count (previously hard-coded to 100, ignoring total_epochs).
for epoch in range(start_epoch, total_epochs):
    # Nothing inside the epoch changes the modes, so set training mode once per epoch.
    netD.train()
    netG.train()

    tqdm_iter = tqdm(data_loader, file=sys.stdout)
    for data in tqdm_iter:
        # The dataset yields images in [0, 1]; rescale to [-1, 1] to match G's tanh output.
        x_real = data['images'].to(device)
        x_real = 2.0 * x_real - 1.0
        n_batches = x_real.size(0)

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        optimD.zero_grad()

        z = torch.randn([n_batches, sample_dims], dtype=torch.float32, device=device)
        # D's update does not need gradients through G, so don't build G's graph.
        with torch.no_grad():
            x_fake = netG(z)
        y_fake = netD(x_fake)
        y_real = netD(x_real)

        lossD = criterion(y_fake, torch.zeros_like(y_fake)) + \
                criterion(y_real, torch.ones_like(y_real))
        lossD.backward()

        optimD.step()

        # --- Generator step: push D to classify fresh fakes as real ---
        optimG.zero_grad()

        z = torch.randn([n_batches, sample_dims], dtype=torch.float32, device=device)
        x_fake = netG(z)
        y_fake = netD(x_fake)

        lossG = criterion(y_fake, torch.ones_like(y_fake))
        lossG.backward()

        optimG.step()

        # Report the losses on the progress bar.
        tqdm_iter.set_description("epoch #{:d}, {:d} steps, lossD={:.4f}, lossG={:.4f}".format(epoch, steps, lossD.item(), lossG.item()))

        # Save image snapshots every 50 steps. File names are per epoch, so later
        # snapshots within the same epoch overwrite earlier ones.
        if steps % 50 == 0:
            outfile = os.path.join(log_dir, 'x_real_{:03d}.jpg'.format(epoch))
            torchvision.utils.save_image(x_real * 0.5 + 0.5, outfile, nrow=5, padding=10)
            outfile = os.path.join(log_dir, 'x_fake_{:03d}.jpg'.format(epoch))
            torchvision.utils.save_image(x_fake * 0.5 + 0.5, outfile, nrow=5, padding=10)

        steps += 1

    # Save a checkpoint at the end of every epoch so training can be resumed.
    ckpt = {
        'optimG': optimG.state_dict(),
        'optimD': optimD.state_dict(),
        'netG': netG.state_dict(),
        'netD': netD.state_dict(),
        'epoch': epoch,
        'steps': steps
    }
    torch.save(ckpt, os.path.join(log_dir, model_path))

In [None]:
# Generate and show a rows x cols grid of samples.
rows = 10
cols = 10
netG.eval()  # BatchNorm uses its running statistics when sampling

z = torch.randn([rows * cols, sample_dims], dtype=torch.float32, device=device)
with torch.no_grad():  # inference only; no autograd graph needed
    x_fake = netG(z)

# make_grid's `nrow` is the number of images per row, i.e. the number of columns.
image_grid = torchvision.utils.make_grid(x_fake * 0.5 + 0.5, nrow=cols, padding=10)
image_grid = image_grid.detach().cpu().numpy()
# (C, H, W) -> (H, W, C), the layout matplotlib's imshow expects.
image_grid = np.transpose(image_grid, axes=[1, 2, 0])

plt.figure(figsize=(15, 15))
plt.imshow(image_grid)
plt.show()

# To save the grid, uncomment the line below (change the file name as needed).
# cv2.imwrite expects 8-bit BGR, whereas image_grid is float RGB in [0, 1].
# cv2.imwrite('image_grid_10x10.png', cv2.cvtColor(np.clip(image_grid * 255.0, 0, 255).round().astype(np.uint8), cv2.COLOR_RGB2BGR))