<a href="https://colab.research.google.com/github/shizoda/education/blob/main/image/dcgan_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np

# Hyperparameters shared by every cell of the notebook.
params = {
    # --- optimisation ---
    "n_epochs": 300,       # number of epochs (CIFAR-10 is more complex than FashionMNIST, so train longer)
    "batch_size": 128,     # mini-batch size
    "lr": 2e-4,            # learning rate
    "beta1": 0.5,          # beta1 coefficient of the Adam optimizer
    # --- network sizes ---
    "nz": 100,             # dimensionality of the latent variable (size of the noise vector)
    "ngf": 128,            # base channel width of the Generator feature maps
    "ndf": 64,             # base channel width of the Discriminator feature maps
    # --- data ---
    "image_size": 32,      # CIFAR-10 images are 32x32
    "nc": 3,               # number of image channels (CIFAR-10 is colour, hence 3)
    "target_class_name": "automobile",  # name of the class to train on
}

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Image preprocessing: resize to the configured size, convert to a tensor and
# map every channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(params["image_size"]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download (if necessary) and open the CIFAR-10 training split.
full_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)

# Class names of CIFAR-10, in label-index order.
class_names = full_dataset.classes
print(f"CIFAR-10 class names: {class_names}")

# Resolve the configured class name to its label index.
# An invalid name raises instead of calling exit(): exit() is an
# interactive-only helper and would tear down the notebook kernel.
try:
    target_class_idx = class_names.index(params["target_class_name"])
except ValueError:
    raise ValueError(
        f"Error: '{params['target_class_name']}' is not a valid CIFAR-10 class name. "
        f"Please choose from {class_names}"
    ) from None
print(f"Targeting class: '{params['target_class_name']}' (index: {target_class_idx})")

# Keep only the samples whose label is the target class.
indices = [i for i, label in enumerate(full_dataset.targets) if label == target_class_idx]

# Validate before building the DataLoader so that an empty selection is
# reported with a clear message rather than failing later.
if not indices:
    raise RuntimeError(
        f"Error: No images found for class '{params['target_class_name']}'. "
        "Please check the class name or dataset."
    )

dataset = Subset(full_dataset, indices)
print(f"Number of images for '{params['target_class_name']}': {len(dataset)}")
dataloader = DataLoader(dataset, batch_size=params["batch_size"], shuffle=True)

# Sanity-check visualisation of (up to) the first 64 training images.
# The grid is built on the CPU: plotting needs host memory anyway, so moving
# the batch to the accelerator and back would only add two copies.
real_batch = next(iter(dataloader))
preview_grid = make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title(f"Training Images ({params['target_class_name']})")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))  # CHW -> HWC for matplotlib
plt.show()

In [None]:
# Weight initialisation hook
def weights_init(m):
    """Re-initialise the parameters of convolution and batch-norm layers in place.

    The layer kind is decided from the module's class name:

    * a name containing ``Conv``: weights drawn from N(mean=0.0, std=0.02);
    * a name containing ``BatchNorm``: weights drawn from N(mean=1.0, std=0.02)
      and biases set to 0;
    * anything else is left untouched.

    Args:
        m: the module to initialise.

    Returns:
        None. The module's tensors are modified in place.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

# Generator
class Generator(nn.Module):
    """DCGAN generator producing ``nc x 32 x 32`` images from latent vectors.

    The network is a stack of transposed convolutions: the first one expands
    the ``nz x 1 x 1`` latent input to a 4x4 map, and each following one
    doubles the spatial size (4 -> 8 -> 16 -> 32). Hidden stages use
    batch-norm + ReLU; the output stage uses Tanh, so values lie in [-1, 1].

    Args:
        nz: number of channels of the latent input.
        ngf: base channel width of the generator feature maps.
        nc: number of channels of the generated image.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        # (in_channels, out_channels, stride, padding) of every hidden stage.
        # The layer count is reduced to suit CIFAR-10's small images.
        hidden_stages = (
            (nz, ngf * 4, 1, 0),       # latent z            -> (ngf*4) x 4 x 4
            (ngf * 4, ngf * 2, 2, 1),  # (ngf*4) x 4 x 4     -> (ngf*2) x 8 x 8
            (ngf * 2, ngf, 2, 1),      # (ngf*2) x 8 x 8     -> (ngf) x 16 x 16
        )

        layers = []
        for in_ch, out_ch, stride, pad in hidden_stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride,
                                   padding=pad, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: (ngf) x 16 x 16 -> (nc) x 32 x 32, squashed to [-1, 1].
        layers.append(
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False)
        )
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch ``input`` through the layer stack."""
        return self.main(input)

In [None]:
# Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator scoring ``nc x 32 x 32`` images.

    Three stride-2 convolutions halve the spatial size each time
    (32 -> 16 -> 8 -> 4); a final 4x4 convolution collapses the map to a
    single value per image, which a Sigmoid turns into a probability.

    Args:
        ndf: base channel width of the discriminator feature maps.
        nc: number of channels of the input image.
    """

    def __init__(self, ndf, nc):
        super().__init__()

        def downsample(in_ch, out_ch):
            # 4x4 convolution with stride 2 / padding 1: halves height and width.
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.main = nn.Sequential(
            # (nc) x 32 x 32 -> (ndf) x 16 x 16 (no batch-norm on the first stage)
            downsample(nc, ndf),
            leaky(),
            # (ndf) x 16 x 16 -> (ndf*2) x 8 x 8
            downsample(ndf, ndf * 2),
            nn.BatchNorm2d(ndf * 2),
            leaky(),
            # (ndf*2) x 8 x 8 -> (ndf*4) x 4 x 4
            downsample(ndf * 2, ndf * 4),
            nn.BatchNorm2d(ndf * 4),
            leaky(),
            # (ndf*4) x 4 x 4 -> 1 x 1 x 1, then squash to a probability
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return the per-image output of the layer stack for batch ``input``."""
        return self.main(input)

In [None]:
# Instantiate the two networks and move them to the selected device.
netG = Generator(params["nz"], params["ngf"], params["nc"]).to(device)
netD = Discriminator(params["ndf"], params["nc"]).to(device)

# Apply the custom weight initialisation to every sub-module.
netG.apply(weights_init)
netD.apply(weights_init)

# Loss function
criterion = nn.BCELoss() # binary cross-entropy loss

# Fixed noise batch, reused after every epoch to visualise training progress.
fixed_noise = torch.randn(64, params["nz"], 1, 1, device=device)

# Target values for real and fake samples.
# Soft targets (0.9 / 0.1) are used instead of hard 1 / 0.
real_label = 0.9
fake_label = 0.1

# Optimizers: one Adam instance per network, same learning rate and betas.
optimizerD = optim.Adam(netD.parameters(), lr=params["lr"], betas=(params["beta1"], 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=params["lr"], betas=(params["beta1"], 0.999))

# Training loop bookkeeping
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss (real + fake), one entry per iteration
img_list = []   # one image grid generated from fixed_noise per epoch
iters = 0       # NOTE(review): never updated or read within this block

print(f"Starting Training Loop for class: {params['target_class_name']}...")
for epoch in range(params["n_epochs"]):
    for i, data in enumerate(dataloader, 0):
        # ---------------------
        # (1) Train the Discriminator
        # ---------------------
        netD.zero_grad()
        # Train on a batch of real images (data[0] holds the images).
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean discriminator output on real images

        # Train on a batch of generated (fake) images.
        noise = torch.randn(b_size, params["nz"], 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)  # reuse the label tensor in place
        # detach() keeps this backward pass from propagating into the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean output on fakes, before the D update

        # Total discriminator loss (used for logging only); the gradients of
        # both backward passes have accumulated, so one step applies both.
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---------------------
        # (2) Train the Generator
        # ---------------------
        netG.zero_grad()
        label.fill_(real_label) # the generator wants its fakes to be classified as real
        # Second pass through the just-updated discriminator, this time without
        # detach() so that gradients reach the generator.
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean output on fakes, after the D update
        optimizerG.step()

        # Progress log
        if i % 50 == 0: # print a log line every 50 iterations
            print(f'[{epoch+1}/{params["n_epochs"]}][{i}/{len(dataloader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

        # Record the losses of every iteration for the loss plot.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

    # At the end of each epoch, generate images from the fixed noise and keep
    # them as a grid (on the CPU) so progress can be inspected later.
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    img_list.append(make_grid(fake, padding=2, normalize=True))

print("Training Finished.")

In [None]:
# Plot the generator and discriminator loss curves recorded during training.
plt.figure(figsize=(10, 5))
for loss_history, curve_name in ((G_losses, "G"), (D_losses, "D")):
    plt.plot(loss_history, label=curve_name)
plt.title("Generator and Discriminator Loss During Training")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Show the images generated in the last epoch.
# The grid is converted from CHW to HWC once and reused by both figures.
fake_grid_image = np.transpose(img_list[-1], (1, 2, 0))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title(f"Generated Images for {params['target_class_name']}")
plt.imshow(fake_grid_image)
plt.show()

# Compare real images with generated images side by side.
# Guard against an empty dataloader, which has no batch to draw.
if len(dataloader) > 0:
    real_batch = next(iter(dataloader))
    # Build the grid directly on the CPU from (up to) the first 64 images:
    # plotting needs host memory, so a trip through the accelerator would
    # only add two copies of the whole batch.
    real_grid = make_grid(real_batch[0][:64], padding=5, normalize=True)

    plt.figure(figsize=(15,15))
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title(f"Real Images ({params['target_class_name']})")
    plt.imshow(np.transpose(real_grid, (1,2,0)))

    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title(f"Fake Images ({params['target_class_name']})")
    plt.imshow(fake_grid_image)
    plt.show()
else:
    print(f"Cannot display real images as the dataloader for {params['target_class_name']} is empty.")