In [175]:
from typing import List
import numpy as np
import os
from tqdm.auto import tqdm
from pathlib import Path

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
from torchvision.utils import save_image

In [176]:
class args:
    """Experiment configuration for the LSGAN-on-MNIST run, used as a plain namespace."""
    # hardware
    gpus = "0"  # value written to CUDA_VISIBLE_DEVICES below
    seed = 42  # global torch RNG seed, so runs are reproducible

    # training
    n_epochs = 200000  # NOTE(review): very large for MNIST — confirm this is intended
    batch_size = 64
    img_shape = [1, 28, 28]  # [channels, height, width] of the images

    # model
    n_latent = 100  # length of the generator's input noise vector

    # Adam optimisers for generator (G) and discriminator (D)
    G_lr = 0.0002
    G_betas = (0.5, 0.999)
    D_lr = 0.0002
    D_betas = (0.5, 0.999)

    # outputs
    gene_img_dir = Path("./generated_images")  # one sample grid is saved here per epoch


# Restrict visible GPUs before anything else touches CUDA; CUDA reads this variable
# only when it initialises, so it must be set first.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
torch.manual_seed(args.seed)
args.gene_img_dir.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [177]:
# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) maps them to [-1, 1] so
# real images share the range of the generator's Tanh output. Otherwise real images would
# never contain negative values while generated ones can, giving D a trivial cue.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,)),
])
# NOTE(review): "/data" is an absolute path; consider moving it into the config cell.
train_dataset = torchvision.datasets.MNIST(root="/data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

In [178]:
def init_weight(layer):
    """Re-initialise a layer in place; intended to be passed to ``nn.Module.apply``.

    Layers are matched on their class name:
      * name contains "Conv" (so ConvTranspose too): weight ~ N(0, 0.02)
      * name contains "BatchNorm": weight ~ N(1, 0.02), bias set to 0
    Any other layer is left untouched.
    """
    kind = type(layer).__name__
    if "Conv" in kind:
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm" in kind:
        nn.init.normal_(layer.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(layer.bias.data, val=0.0)
        


class Generator(nn.Module):
    """Map latent noise vectors to images.

    A linear layer projects ``z`` onto a ``(C, H/4, W/4)`` map. Two stride-2
    transposed convolutions upsample it to ``(C, H, W)``, and Tanh bounds the
    output to [-1, 1].

    Args:
        n_latent: length of the latent vector ``z``.
        img_shape: ``[C, H, W]`` of the images to generate. H and W must be
            divisible by 4 because of the two x2 upsampling stages.

    Raises:
        ValueError: if H or W is not divisible by 4, which would otherwise
            silently yield images of the wrong size.
    """

    def __init__(self, n_latent: int, img_shape: List):
        super().__init__()
        channels, height, width = img_shape[0], img_shape[1], img_shape[2]
        if height % 4 or width % 4:
            raise ValueError(
                f"img_shape height/width must be divisible by 4, got {height}x{width}"
            )
        self.init_shape = [channels, height // 4, width // 4]
        # np.prod returns a NumPy scalar; cast to a plain int for nn.Linear.
        self.linear = nn.Linear(n_latent, int(np.prod(self.init_shape)))
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(channels, 64, 2, 2),  # (H/4, W/4) -> (H/2, W/2)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 128, 2, 2),  # (H/2, W/2) -> (H, W)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, channels, 3, 1, 1),  # size-preserving projection back to C channels
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: tensor of shape ``(B, n_latent)``.

        Returns:
            Tensor of shape ``(B, C, H, W)`` with values in [-1, 1].
        """
        z = self.linear(z)
        z = z.reshape(z.shape[0], *self.init_shape)
        return self.generator(z)
    
class Discriminator(nn.Module):
    """Score images with a single unbounded real/fake value per sample.

    Two 3x3, stride-2, padding-1 convolutions each shrink H and W to
    ``ceil(size / 2)``. The resulting map is flattened into one linear layer
    whose input size is derived from ``img_shape``, not hardcoded.

    Args:
        img_shape: ``[C, H, W]`` of the input images.
    """

    def __init__(self, img_shape):
        super().__init__()
        channels, height, width = img_shape[0], img_shape[1], img_shape[2]
        # Attribute name (spelling included) is kept unchanged so state_dict keys,
        # and therefore any saved checkpoints, keep matching.
        self.discrminator = nn.Sequential(
            nn.Conv2d(channels, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # Spatial size after both downsampling convs (7x7 for 28x28 input).
        feat_h = self._downsampled(self._downsampled(height))
        feat_w = self._downsampled(self._downsampled(width))
        self.linear = nn.Sequential(
            nn.Linear(128 * feat_h * feat_w, 1)
        )

    @staticmethod
    def _downsampled(size):
        """Output length of a kernel-3, stride-2, padding-1 convolution along one axis."""
        return (size + 2 * 1 - 3) // 2 + 1

    def forward(self, x):
        """Return scores of shape ``(B, 1)`` for images ``x`` of shape ``(B, C, H, W)``."""
        x = self.discrminator(x)
        x = x.reshape(x.shape[0], -1)
        return self.linear(x)
        


In [179]:
# Build G and D on the selected device (G first, then D), then re-initialise their
# conv / batch-norm weights in place with init_weight.
generator, discriminator = (
    Generator(args.n_latent, args.img_shape).to(device),
    Discriminator(args.img_shape).to(device),
)
generator.apply(init_weight)
discriminator.apply(init_weight)  # .apply returns the module, so the cell echoes D's repr

Discriminator(
  (discrminator): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (linear): Sequential(
    (0): Linear(in_features=6272, out_features=1, bias=True)
  )
)

In [180]:
# LSGAN: least-squares (MSE) adversarial loss, with one Adam optimiser per network.
criterion_MSE = nn.MSELoss()
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
    for net, lr, betas in (
        (generator, args.G_lr, args.G_betas),
        (discriminator, args.D_lr, args.D_betas),
    )
)

In [None]:
# Least-squares GAN training: D regresses real -> 1 and fake -> 0, G pushes D(fake) -> 1.
for epoch in range(args.n_epochs):
    train_loop = tqdm(train_loader, total=len(train_loader), desc="training", colour="blue", leave=False)
    G_loss_sum = 0
    D_loss_sum = 0
    for img, _targets in train_loop:  # class labels are not used by the GAN
        img = img.to(device)
        # Size the noise batch from the real batch: the last MNIST batch is smaller
        # than args.batch_size (60000 is not a multiple of 64).
        n_real = img.size(0)
        latent_z = torch.randn(n_real, args.n_latent, device=device)
        gene_img = generator(latent_z)

        # --- train D ---
        # detach() so D's update does not backpropagate into G.
        gene_logit = discriminator(gene_img.detach())
        real_logit = discriminator(img)
        # Targets are built with *_like so they have the logits' (B, 1) shape; a 1-D (B,)
        # target would broadcast against (B, 1) to (B, B) inside MSELoss.
        D_loss = (
            criterion_MSE(gene_logit, torch.zeros_like(gene_logit))
            + criterion_MSE(real_logit, torch.ones_like(real_logit))
        ) / 2
        optimizer_D.zero_grad()
        D_loss.backward()
        optimizer_D.step()

        # --- train G ---
        # G_loss.backward() also writes D's .grad, but optimizer_D.zero_grad() clears
        # it before the next D step, so D is only updated by D_loss.
        gene_logit = discriminator(gene_img)
        G_loss = criterion_MSE(gene_logit, torch.ones_like(gene_logit))
        optimizer_G.zero_grad()
        G_loss.backward()
        optimizer_G.step()

        D_loss_sum += D_loss.item()
        G_loss_sum += G_loss.item()
    print(f"D loss : {D_loss_sum/len(train_loader)}, G loss : {G_loss_sum/len(train_loader)}")
    # Generator output is Tanh in [-1, 1]; save_image clamps to [0, 1], so rescale first.
    save_image((gene_img.detach() + 1) / 2, args.gene_img_dir / f"{epoch}.png")

training:   0%|          | 0/938 [00:00<?, ?it/s]