<a href="https://colab.research.google.com/github/satvikk/ai_synthesize/blob/master/learn2_standardDCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

DCGAN on MNIST

In [0]:
import torch
import torch.nn as nn
import plotly.graph_objects as go
import torchvision.datasets as datasets
import torch.nn.functional as F
import tqdm
# Make CUDA float tensors the default tensor type, so tensors created without
# an explicit device (torch.randn, torch.full, module parameters, ...) are
# allocated on the GPU. The rest of the notebook relies on this.
# NOTE(review): this requires a CUDA-capable runtime, and recent PyTorch
# releases deprecate this call in favour of torch.set_default_dtype /
# torch.set_default_device — confirm against the installed version.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [0]:
# Architecture hyper-parameters shared by the two networks defined below.
nz = 100         # channels of the latent input consumed by the generator
nc = 1           # image channels (discriminator input / generator output)
ngf = ndf = 128  # base feature width of the generator / discriminator

# MNIST train and test splits, downloaded into ./data when missing.
# transform=None: the dataset object itself applies no preprocessing.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=None,
)
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=None,
)

In [0]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores an image as real or generated.

    Layer widths come from the notebook-level constants ``nc`` (input
    channels) and ``ndf`` (base feature width). The two stride-2 convolutions
    each halve the spatial size and the final 7x7 convolution collapses a 7x7
    feature map to a single value, so a 28x28 input yields an output of shape
    (N, 1, 1, 1). The trailing sigmoid keeps every score in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # nc -> ndf channels; kernel 4, stride 2, padding 1 halves H and W.
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf -> 2*ndf channels; halves H and W again.
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            # Fix: the original stack went straight from BatchNorm into the
            # final convolution, leaving no non-linearity between the last two
            # convolutions. The standard DCGAN discriminator follows every
            # BatchNorm with a LeakyReLU(0.2).
            nn.LeakyReLU(0.2, inplace=True),
            # 7x7 kernel, no padding: reduces a 7x7 map to one logit per image.
            nn.Conv2d(ndf * 2, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the per-image score produced by the layer stack."""
        return self.main(input)

class Generator(nn.Module):
    """Transposed-convolution network that decodes latent noise into an image.

    Widths come from the notebook-level constants ``nz`` (latent channels),
    ``ngf`` (base feature width) and ``nc`` (output channels). Given a
    (N, nz, 1, 1) input, the spatial size grows 1 -> 7 -> 14 -> 28 and the
    final Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Project the 1x1 latent "pixel" to a 7x7 feature map.
        project = [
            nn.ConvTranspose2d(nz, ngf * 2, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
        ]
        # Double the resolution (7x7 -> 14x14) while halving the channel count.
        upsample = [
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
        ]
        # Double the resolution once more (14x14 -> 28x28) and map to nc channels.
        to_image = [
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*project, *upsample, *to_image)

    def forward(self, input):
        """Run the latent batch through the layer stack and return the images."""
        generated = self.main(input)
        return generated

def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    The decision is made on the class *name*:
      * names containing 'Conv' get weights drawn from N(0.0, 0.02);
      * otherwise, names containing 'BatchNorm' get weights drawn from
        N(1.0, 0.02) and a zero bias;
      * every other module is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)
  
# Build both networks and re-initialise every submodule with weights_init.
# Module.apply returns the module itself, so construction and initialisation
# can be chained.
discriminator = Discriminator().apply(weights_init)
generator = Generator().apply(weights_init)
# Bare last expression: the notebook displays the generator's layer summary.
generator

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 256, kernel_size=(7, 7), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): Tanh()
  )
)

In [0]:
# Target values for the binary cross-entropy labels:
# 1 marks a real image, 0 marks a generated one.
real_label, fake_label = 1, 0
class datamaker(torch.utils.data.Dataset):
    """Dataset adapter that serves the wrapped dataset's images as float tensors.

    Every item is a dict with two entries:
      'x' -- ``mnist.data[idx]`` with a leading channel axis added, cast to
             float and divided by 255;
      'y' -- the notebook-level ``real_label`` wrapped in a tensor.

    NOTE(review): with 0-255 pixel data 'x' lies in [0, 1], whereas the
    Generator ends in Tanh (range [-1, 1]) — confirm the mismatch is intended.
    """

    def __init__(self, mnist):
        # Wrapped dataset; must expose ``.data`` and support ``len()``.
        self.mnist = mnist

    def __len__(self):
        n_items = len(self.mnist)
        return n_items

    def __getitem__(self, idx):
        pixels = self.mnist.data[idx]
        image = pixels.unsqueeze(0).float() / 255
        target = torch.tensor(real_label)
        return {'x': image, 'y': target}

# Shuffled mini-batches of 50 images drawn from the MNIST training split.
batch_size = 50
dataloader = torch.utils.data.DataLoader(
    dataset=datamaker(mnist_trainset),
    batch_size=batch_size,
    shuffle=True,
)

In [0]:
# Binary cross-entropy, applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors; in this notebook it is referenced only by
# the commented-out snapshot code inside the training loop.
fixed_noise = torch.randn(64, nz, 1, 1)

# Adam settings shared by both networks.
lr = 0.0002
beta1 = 0.5

# One optimizer per network, so each can be stepped independently.
optimizerD = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)
optimizerG = torch.optim.Adam(
    params=generator.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

In [0]:
# Adversarial training loop: for every mini-batch, first update the
# discriminator on a real and a generated batch, then update the generator
# against the freshly updated discriminator.
discriminator.train()
generator.train()
img_list = []   # only filled by the commented-out snapshot code below
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss (real + fake), one entry per iteration
iters = 0
num_epochs = 50
for epoch in range(num_epochs):
  for i, data in enumerate(dataloader, 0):
    ############################
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    ###########################
    ## Train with all-real batch
    discriminator.zero_grad()
    # Despite the name, this batch lives on the GPU after .cuda().
    real_cpu = data['x'].cuda()
    b_size = real_cpu.size(0)
    # Fix: give the labels an explicit float dtype. With an integer fill value
    # torch.full infers an integer tensor on current PyTorch, which BCELoss
    # rejects. The device is taken from the batch instead of relying on the
    # global default tensor type.
    label = torch.full((b_size,), real_label, dtype=torch.float, device=real_cpu.device)
    # Forward pass real batch through D
    output = discriminator(real_cpu).view(-1)
    # Calculate loss on all-real batch
    errD_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    ## Train with all-fake batch
    # Generate batch of latent vectors
    noise = torch.randn(b_size, nz, 1, 1, device=real_cpu.device)
    # Generate fake image batch with G
    fake = generator(noise)
    label.fill_(fake_label)
    # Classify all fake batch with D; detach so no gradient reaches G here
    output = discriminator(fake.detach()).view(-1)
    # Calculate D's loss on the all-fake batch
    errD_fake = criterion(output, label)
    # Calculate the gradients for this batch (accumulated onto the real-batch
    # gradients already stored by the backward call above)
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Total discriminator loss, used for reporting only; the gradients were
    # already accumulated by the two backward() calls
    errD = errD_real + errD_fake
    # Update D
    optimizerD.step()

    ############################
    # (2) Update G network: maximize log(D(G(z)))
    ###########################
    generator.zero_grad()
    label.fill_(real_label)  # fake labels are real for generator cost
    # Since we just updated D, perform another forward pass of all-fake batch through D
    output = discriminator(fake).view(-1)
    # Calculate G's loss based on this output
    errG = criterion(output, label)
    # Calculate gradients for G
    errG.backward()
    D_G_z2 = output.mean().item()
    # Update G
    optimizerG.step()

    # Output training stats
    if i % 600 == 0:
      print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            % (epoch, num_epochs, i, len(dataloader),
               errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Save Losses for plotting later
    G_losses.append(errG.item())
    D_losses.append(errD.item())

    # Check how the generator is doing by saving G's output on fixed_noise
    # NOTE(review): re-enabling this needs `import torchvision.utils as vutils`,
    # which the visible import cell does not provide.
    # if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
    #     with torch.no_grad():
    #         fake = generator(fixed_noise).detach().cpu()
    #     img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    iters += 1

[0/50][0/1200]	Loss_D: 1.6306	Loss_G: 3.2473	D(x): 0.5144	D(G(z)): 0.2477 / 0.1243
[0/50][500/1200]	Loss_D: 0.0505	Loss_G: 5.4445	D(x): 0.9578	D(G(z)): 0.0056 / 0.0086
[0/50][1000/1200]	Loss_D: 0.0430	Loss_G: 5.0319	D(x): 0.9771	D(G(z)): 0.0183 / 0.0102
[1/50][0/1200]	Loss_D: 0.2237	Loss_G: 5.7362	D(x): 0.9874	D(G(z)): 0.1710 / 0.0068
[1/50][500/1200]	Loss_D: 0.0180	Loss_G: 6.3867	D(x): 0.9857	D(G(z)): 0.0031 / 0.0054
[1/50][1000/1200]	Loss_D: 0.0251	Loss_G: 6.7415	D(x): 0.9794	D(G(z)): 0.0035 / 0.0034
[2/50][0/1200]	Loss_D: 0.0672	Loss_G: 4.0674	D(x): 0.9522	D(G(z)): 0.0154 / 0.0291
[2/50][500/1200]	Loss_D: 0.0404	Loss_G: 5.0873	D(x): 0.9755	D(G(z)): 0.0115 / 0.0138
[2/50][1000/1200]	Loss_D: 0.2127	Loss_G: 3.7253	D(x): 0.9485	D(G(z)): 0.0936 / 0.0473
[3/50][0/1200]	Loss_D: 0.0571	Loss_G: 4.9410	D(x): 0.9715	D(G(z)): 0.0258 / 0.0142
[3/50][500/1200]	Loss_D: 0.1267	Loss_G: 5.0008	D(x): 0.9049	D(G(z)): 0.0107 / 0.0183
[3/50][1000/1200]	Loss_D: 0.2719	Loss_G: 6.7892	D(x): 0.9889	D(G(z)): 

In [0]:
# Draw generated digits until the discriminator rates one as "real"
# (score >= 0.5), then render that image as a grid of square markers.
generator.eval()
discriminator.eval()
# Fix: the original `while` loop had no upper bound and could spin forever when
# the discriminator rejects (almost) every sample. Cap the attempts and fall
# back to the best-scoring sample seen.
max_attempts = 1000
with torch.no_grad():
  noise = torch.randn(1, nz, 1, 1,)
  fake = generator(noise)
  best_score = discriminator(fake).item()
  attempts = 1
  while best_score < 0.5 and attempts < max_attempts:
    candidate_noise = torch.randn(1, nz, 1, 1,)
    candidate = generator(candidate_noise)
    candidate_score = discriminator(candidate).item()
    attempts += 1
    if candidate_score > best_score:
      # Keep noise and image in sync with the best score so far.
      noise, fake, best_score = candidate_noise, candidate, candidate_score
  if best_score < 0.5:
    print('No sample reached D(G(z)) >= 0.5 in %d attempts; showing the best one (%.4f)'
          % (attempts, best_score))

  # Marker coordinates for a row-major flattened image: x is the column index,
  # y counts down from the top so row 0 is drawn at the top of the figure.
  height, width = fake.shape[-2:]
  grid20x = torch.cat([torch.linspace(0, width - 1, width)] * height)
  grid20y = torch.cat([torch.linspace(height - 1, 0, height).unsqueeze(1)] * width, dim = 1).flatten()
  fig = go.Figure()
  fig.add_scatter(
      x = grid20x.cpu(),
      y = grid20y.cpu(),
      mode = "markers",
      marker = dict(
          # One marker per pixel, coloured by the generated intensity.
          color = fake.flatten().cpu(),
          showscale=True,
          colorscale = "gray",
          symbol = "square",
          size = 15,
      )
  )
  fig.update_layout(
      # Lock the aspect ratio so the pixels stay square.
      yaxis = dict(
        scaleanchor = "x",
        scaleratio = 1,
      )
  )
  fig.show()
  # Author's note (presumably digits observed over successive runs):
  # 1 8 7 3 5 6 9 0 4 

In [0]:
# Draw one fresh latent vector, decode it with the generator, and show the
# discriminator's score for the result (bare last expression -> cell output).
noise = torch.randn((1, nz, 1, 1))
fake = generator(noise)
discriminator(fake).item()

0.21326686441898346