In [None]:
# 必要なパッケージのインストール
!pip install torch torchvision
!pip install tqdm



In [None]:
# Exchange data with Google Drive: mount it into the Colab filesystem.
from google.colab import drive

drive_dir = '/content/drive'  # mount point for Google Drive
drive.mount(drive_dir)

Mounted at /content/drive


In [None]:
# Import packages
import os
import sys
import time
import datetime
import struct
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Frequently used files and directories
colab_dir = os.path.join(drive_dir, 'My Drive', 'Colab Notebooks')  # notebook folder on the mounted Drive
model_path = os.path.join(colab_dir, 'DCGAN_mnist.pth')              # training checkpoint file

In [None]:
# Dataset loading
class MnistDataset(torch.utils.data.Dataset):
    """MNIST dataset read directly from the raw (uncompressed) IDX files.

    Each item is a dict ``{'images': ..., 'labels': ...}`` where ``images`` is a
    float32 array of shape (1, H + 4, W + 4) with values in [0, 1] (each image is
    zero-padded by 2 pixels per side, so 28x28 MNIST digits become 32x32, a power
    of two) and ``labels`` is an int64 class index.

    Args:
        root_dir: Directory containing the IDX files.
        mode: ``'train'`` or ``'test'``; selects which pair of files is read.

    Raises:
        ValueError: if ``mode`` is unknown, a file has the wrong magic number,
            a file is shorter than its header says, or the image and label
            counts differ.
    """
    def __init__(self, root_dir, mode='train'):
        super(MnistDataset, self).__init__()

        self.root_dir = root_dir
        self.mode = mode
        self.n_classes = 10

        if self.mode == 'train':
            self.image_file = 'train-images-idx3-ubyte'
            self.label_file = 'train-labels-idx1-ubyte'
        elif self.mode == 'test':
            self.image_file = 't10k-images-idx3-ubyte'
            self.label_file = 't10k-labels-idx1-ubyte'
        else:
            raise ValueError('MNIST dataset mode must be "train" or "test"')

        self.image_data = self._load_images(os.path.join(self.root_dir, self.image_file))
        self.label_data = self._load_labels(os.path.join(self.root_dir, self.label_file))

        # A mismatch would silently pair images with the wrong labels (or fail
        # later inside the DataLoader), so reject it up front.
        if len(self.image_data) != len(self.label_data):
            raise ValueError('Image/label count mismatch: {} images vs {} labels'.format(
                len(self.image_data), len(self.label_data)))

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        return {
            'images': self.image_data[idx],
            'labels': self.label_data[idx]
        }

    @staticmethod
    def _read_exact(fp, n_bytes, filename):
        """Read exactly ``n_bytes`` from ``fp``; raise ValueError if the file ends early."""
        data = fp.read(n_bytes)
        if len(data) != n_bytes:
            raise ValueError('{}: expected {} bytes but got {} (truncated file?)'.format(
                filename, n_bytes, len(data)))
        return data

    def _load_images(self, filename):
        """Read an IDX image file; return float32 pixels of shape (N, 1, H + 4, W + 4) in [0, 1]."""
        with open(filename, 'rb') as fp:
            magic = struct.unpack('>i', self._read_exact(fp, 4, filename))[0]
            if magic != 2051:  # IDX magic for unsigned-byte, 3-dimensional data
                raise ValueError('Magic number does not match!')

            n_images, height, width = struct.unpack('>iii', self._read_exact(fp, 4 * 3, filename))

            n_pixels = n_images * height * width
            # np.frombuffer parses the whole pixel block in one vectorized step,
            # rather than struct.unpack with a per-byte format string (very slow
            # and memory-hungry for tens of millions of pixels).
            raw = self._read_exact(fp, n_pixels, filename)

        pixels = np.frombuffer(raw, dtype=np.uint8).reshape((n_images, 1, height, width))

        # Pad 2 pixels on each side so the image size becomes a power of two (28 -> 32)
        pixels = np.pad(pixels, [(0, 0), (0, 0), (2, 2), (2, 2)], mode='constant', constant_values=0)
        return (pixels / 255.0).astype('float32')

    def _load_labels(self, filename):
        """Read an IDX label file; return an int64 array of shape (N,)."""
        with open(filename, 'rb') as fp:
            magic = struct.unpack('>i', self._read_exact(fp, 4, filename))[0]
            if magic != 2049:  # IDX magic for unsigned-byte, 1-dimensional data
                raise ValueError('Magic number does not match!')

            n_labels = struct.unpack('>i', self._read_exact(fp, 4, filename))[0]
            raw = self._read_exact(fp, n_labels, filename)

        return np.frombuffer(raw, dtype=np.uint8).astype(np.int64)


In [None]:
# Basic building blocks
class BlockG(nn.Module):
    """Generator conv block: 3x3 same-padding Conv2d -> BatchNorm2d -> ReLU.

    Keeps the spatial size and maps ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super(BlockG, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class BlockD(nn.Module):
    """Discriminator conv block: 3x3 same-padding Conv2d -> LeakyReLU(0.2), no normalization."""
    def __init__(self, in_channels, out_channels):
        super(BlockD, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        act = nn.LeakyReLU(0.2)
        self.net = nn.Sequential(conv, act)

    def forward(self, x):
        out = self.net(x)
        return out

class Up(nn.Module):
    """Up-sampling block: ConvTranspose2d(k=4, s=2, p=1) doubles H and W, then BatchNorm2d -> ReLU."""
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=False)
        self.net = nn.Sequential(deconv, norm, act)

    def forward(self, x):
        out = self.net(x)
        return out

class Down(nn.Module):
    """Down-sampling block: Conv2d(k=4, s=2, p=1) halves H and W, then BatchNorm2d -> LeakyReLU(0.2)."""
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


In [None]:
# Generator definition
class NetG(nn.Module):
    """DCGAN-style generator mapping a latent vector to an image.

    The (N, in_features) input is reshaped to (N, in_features, 1, 1) and grown
    spatially 1 -> 4 -> 8 -> 16 -> 32, then projected to ``out_channels`` and
    squashed into [-1, 1] by tanh.

    Args:
        in_features: Length of each input vector (noise dims plus any condition
            vector the caller concatenates).
        out_channels: Number of channels of the generated image.
        base_filters: Channel width of the last hidden stage; earlier stages use
            8x, 4x and 2x this value.
    """
    def __init__(self, in_features=128, out_channels=3, base_filters=8):
        super(NetG, self).__init__()
        self.in_features = in_features
        self.out_channels = out_channels
        self.base_filters = base_filters

        self.net = nn.Sequential(
            # 1x1 -> 4x4 projection of the latent vector
            nn.ConvTranspose2d(self.in_features, self.base_filters * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(self.base_filters * 8),
            nn.ReLU(inplace=True),
            Up(self.base_filters * 8, self.base_filters * 4),  # 4 -> 8
            Up(self.base_filters * 4, self.base_filters * 2),  # 8 -> 16
            Up(self.base_filters * 2, self.base_filters * 1),  # 16 -> 32
            nn.Conv2d(self.base_filters * 1, self.out_channels,
                      kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        """Generate images from latent vectors.

        Args:
            x: Tensor of shape (N, in_features).

        Returns:
            Tensor of shape (N, out_channels, 32, 32) with values in [-1, 1].
        """
        n_batches, n_dims = x.size()
        x = x.view(n_batches, n_dims, 1, 1)
        return torch.tanh(self.net(x))

In [None]:
# Discriminator definition
class NetD(nn.Module):
    """DCGAN-style discriminator producing one raw logit per image.

    For a 32x32 input the feature map shrinks 32 -> 16 -> 8 -> 4, and the final
    4x4 convolution without padding reduces it to 1x1.

    Args:
        in_channels: Channels of the input (image channels plus any condition
            channels the caller concatenates).
        base_filters: Channel width of the first stage; later stages use 2x, 4x
            and 8x this value.
    """
    def __init__(self, in_channels=3, base_filters=8):
        super(NetD, self).__init__()
        self.in_channels = in_channels
        self.base_filters = base_filters

        self.net = nn.Sequential(
            BlockD(self.in_channels, self.base_filters),
            Down(self.base_filters, self.base_filters * 2),
            Down(self.base_filters * 2, self.base_filters * 4),
            Down(self.base_filters * 4, self.base_filters * 8),
            nn.Conv2d(self.base_filters * 8, 1, kernel_size=4, stride=1, padding=0)
        )

    def forward(self, x):
        """Score a batch of images.

        Returns raw logits (no sigmoid) because training uses nn.BCEWithLogitsLoss.
        For 32x32 inputs the result has shape (N,), including when N == 1.
        """
        x = self.net(x)
        # Squeeze only the channel and spatial axes (each call is a no-op when that
        # axis is not size 1). A bare x.squeeze() would also drop the batch axis
        # when N == 1 and return a 0-d tensor.
        return x.squeeze(3).squeeze(2).squeeze(1)

In [None]:
def to_onehot(cls, n_classes):
    """Convert integer class labels to float32 one-hot rows.

    Args:
        cls: Integer tensor of class indices; the result lives on its device.
        n_classes: Length of each one-hot vector.

    Returns:
        Float32 tensor of shape ``cls.shape + (n_classes,)``.
    """
    lookup = torch.eye(n_classes, dtype=torch.float32, device=cls.device)
    onehot = lookup[cls]  # row i of the identity is the one-hot vector for class i
    return onehot

In [None]:
# Select the compute device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
print('Device: {}'.format(device))

Device: cuda:0


In [None]:
# Hyperparameters
sample_dims = 32            # dimensionality of the noise vector z
base_lr = 2.0e-4            # learning rate
beta1 = 0.5                 # Adam beta1
beta2 = 0.9                 # Adam beta2
base_filters = 8            # base channel count of the CNNs
data_root = 'MNIST'         # dataset directory (joined onto colab_dir)
random_seed = 0             # seed for torch RNGs

# Seed torch's CPU and CUDA generators so weight init, noise sampling and
# DataLoader shuffling repeat across runs (cuDNN kernels may still be
# nondeterministic).
torch.manual_seed(random_seed)

In [None]:
# Prepare the dataset and its loader
dataset = MnistDataset(os.path.join(colab_dir, data_root))

# Never request more worker processes than CPUs this process may run on; asking
# for more makes DataLoader emit a warning (the truncated "cpuset_checked" line in
# this cell's output appears to be that warning).
if hasattr(os, 'sched_getaffinity'):
    n_cpus = len(os.sched_getaffinity(0))
else:  # sched_getaffinity is not available on every platform
    n_cpus = os.cpu_count() or 1
data_loader = torch.utils.data.DataLoader(dataset, batch_size=25, num_workers=min(4, n_cpus),
                                          shuffle=True, drop_last=True)

  cpuset_checked))


In [None]:
# Define the networks and the optimizers
n_classes = dataset.n_classes
# D receives a 1-channel image concatenated with the one-hot label tiled per pixel.
netD = NetD(in_channels=1 + n_classes, base_filters=base_filters)
# G receives the noise vector concatenated with the one-hot label.
netG = NetG(in_features=sample_dims + n_classes, out_channels=1, base_filters=base_filters)
netD.to(device)
netG.to(device)
optimD = torch.optim.Adam(netD.parameters(), lr=base_lr, betas=(beta1, beta2))
optimG = torch.optim.Adam(netG.parameters(), lr=base_lr, betas=(beta1, beta2))
criterion = nn.BCEWithLogitsLoss()

# Load a saved checkpoint (set reload_model to True to resume training)
reload_model = False
start_epoch = 0
start_steps = 0
if reload_model:
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only
    # runtime (and vice versa).
    ckpt = torch.load(model_path, map_location=device)
    optimG.load_state_dict(ckpt['optimG'])
    optimD.load_state_dict(ckpt['optimD'])
    netG.load_state_dict(ckpt['netG'])
    netD.load_state_dict(ckpt['netD'])
    # ckpt['epoch'] is the epoch that had just finished when the checkpoint was
    # written, so resume with the next one instead of repeating it.
    start_epoch = ckpt['epoch'] + 1
    start_steps = ckpt['steps']

In [None]:
# Create a timestamped folder for intermediate training outputs
runs_dir = os.path.join(colab_dir, 'runs')
now = datetime.datetime.now()
time_stamp = now.strftime('%Y%m%d-%H%M%S')  # e.g. 20240131-235959
log_dir = os.path.join(runs_dir, time_stamp)
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Training loop
steps = start_steps
for epoch in range(start_epoch, 100):
    tqdm_iter = tqdm(data_loader, file=sys.stdout)
    for data in tqdm_iter:
        x_real = data['images'].to(device)
        c_real = data['labels'].to(device)
        n_batches, _, height, width = x_real.size()

        # Rescale images from [0, 1] to [-1, 1] to match the generator's tanh output.
        x_real = 2.0 * x_real - 1.0
        c_onehot = to_onehot(c_real, n_classes)
        # Tile the one-hot label over every pixel so it can be stacked onto D's input channels.
        c_onehot_tile = c_onehot.view(n_batches, -1, 1, 1).repeat(1, 1, height, width)

        # Sampling below switches netG to eval mode, so restore training mode every step.
        netD.train()
        netG.train()

        # --- Discriminator update ---
        optimD.zero_grad()

        z = torch.randn([n_batches, sample_dims], dtype=torch.float32, device=device)
        z = torch.cat([z, c_onehot], dim=1)
        # Detach so the discriminator loss does not backpropagate into G.
        x_fake = netG(z).detach()
        y_fake = netD(torch.cat([x_fake, c_onehot_tile], dim=1))
        y_real = netD(torch.cat([x_real, c_onehot_tile], dim=1))

        lossD = (criterion(y_fake, torch.zeros_like(y_fake)) +
                 criterion(y_real, torch.ones_like(y_real)))
        lossD.backward()
        optimD.step()

        # --- Generator update ---
        optimG.zero_grad()

        z = torch.randn([n_batches, sample_dims], dtype=torch.float32, device=device)
        z = torch.cat([z, c_onehot], dim=1)
        x_fake = netG(z)
        y_fake = netD(torch.cat([x_fake, c_onehot_tile], dim=1))

        # G is trained to make D label its samples as real.
        lossG = criterion(y_fake, torch.ones_like(y_fake))
        lossG.backward()
        optimG.step()

        # Report the losses on the progress bar
        tqdm_iter.set_description("epoch #{:d}, {:d} steps, lossD={:.4f}, lossG={:.4f}".format(epoch, steps, lossD.item(), lossG.item()))

        # Save sample images every 50 steps. File names depend only on the epoch,
        # so each save within an epoch overwrites the previous one.
        if steps % 50 == 0:
            outfile = os.path.join(log_dir, 'x_real_{:03d}.jpg'.format(epoch))
            torchvision.utils.save_image(x_real * 0.5 + 0.5, outfile, nrow=5, padding=10)
            outfile = os.path.join(log_dir, 'x_fake_{:03d}.jpg'.format(epoch))
            torchvision.utils.save_image(x_fake * 0.5 + 0.5, outfile, nrow=5, padding=10)

            # One sample per class; inference only, so skip building an autograd graph.
            netG.eval()
            with torch.no_grad():
                z = torch.randn([n_classes, sample_dims], dtype=torch.float32).to(device)
                c_onehot = torch.eye(n_classes, dtype=torch.float32).to(device)
                z = torch.cat([z, c_onehot], dim=1)
                x_fake = netG(z)

            outfile = os.path.join(log_dir, 'x_fake_class_{:03d}.jpg'.format(epoch))
            torchvision.utils.save_image(x_fake * 0.5 + 0.5, outfile, nrow=5, padding=10)

        steps += 1

    # Save a checkpoint after every epoch; 'epoch' records the epoch just finished.
    ckpt = {
        'optimG': optimG.state_dict(),
        'optimD': optimD.state_dict(),
        'netG': netG.state_dict(),
        'netD': netD.state_dict(),
        'epoch': epoch,
        'steps': steps
    }
    # Write to a temporary file and then rename it over model_path, so an
    # interrupted save (e.g. a Colab disconnect) cannot leave a corrupt checkpoint.
    tmp_model_path = model_path + '.tmp'
    torch.save(ckpt, tmp_model_path)
    os.replace(tmp_model_path, model_path)

  0%|          | 0/2400 [00:00<?, ?it/s]

  cpuset_checked))


epoch #0, 2399 steps, lossD=0.9565, lossG=1.0555: 100%|██████████| 2400/2400 [00:49<00:00, 48.61it/s]
epoch #1, 4799 steps, lossD=1.1848, lossG=0.6473: 100%|██████████| 2400/2400 [00:47<00:00, 50.34it/s]
epoch #2, 7199 steps, lossD=0.9570, lossG=0.6864: 100%|██████████| 2400/2400 [00:48<00:00, 49.98it/s]
epoch #3, 9599 steps, lossD=1.5152, lossG=0.9667: 100%|██████████| 2400/2400 [00:47<00:00, 50.60it/s]
epoch #4, 11999 steps, lossD=0.9248, lossG=1.0520: 100%|██████████| 2400/2400 [00:47<00:00, 50.45it/s]
epoch #5, 14399 steps, lossD=0.9998, lossG=1.0292: 100%|██████████| 2400/2400 [00:47<00:00, 50.45it/s]
epoch #6, 16799 steps, lossD=1.2674, lossG=2.2031: 100%|██████████| 2400/2400 [00:47<00:00, 50.70it/s]
epoch #7, 19199 steps, lossD=1.1176, lossG=1.3245: 100%|██████████| 2400/2400 [00:47<00:00, 50.65it/s]
epoch #8, 21599 steps, lossD=0.5243, lossG=2.6681: 100%|██████████| 2400/2400 [00:47<00:00, 50.11it/s]
epoch #9, 23999 steps, lossD=0.6172, lossG=2.4913: 100%|██████████| 2400/2400

In [None]:
# Build an n_samples x n_classes grid of generated digits (10x10 for MNIST):
# one row per noise vector, one column per class.
n_samples = 10

netG.eval()

with torch.no_grad():  # inference only
    z = torch.randn([n_samples, sample_dims], dtype=torch.float32, device=device)
    # Repeat each noise vector n_classes times consecutively ->
    # (n_samples * n_classes, sample_dims); consecutive images share one z.
    z = z.repeat(1, n_classes).view(-1, sample_dims)
    c_onehot = torch.eye(n_classes, dtype=torch.float32).to(device)
    # Cycle through classes 0..n_classes-1 once per noise vector.
    c_onehot = c_onehot.repeat(n_samples, 1)

    z_and_c = torch.cat([z, c_onehot], dim=1)
    x_fake = netG(z_and_c)

outfile = os.path.join(log_dir, 'x_fake_class_tile.jpg')
# nrow is the number of images per grid row: one noise vector rendered for every class.
torchvision.utils.save_image(x_fake * 0.5 + 0.5, outfile, nrow=n_classes, padding=10)