<a href="https://colab.research.google.com/github/satvikk/ai_synthesize/blob/master/learn3_cDCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

cDCGAN on MNIST

In [0]:
import torch
import torch.nn as nn
import plotly.graph_objects as go
import torchvision.datasets as datasets
import torch.nn.functional as F
import tqdm
torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [0]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

nz = 100
nc = 1
ngf = 128
ndf = 128
nzf = 50

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!






In [0]:
class reshaper(nn.Module):
  def __init__(self, reshape_args):
    super().__init__()
    self.reshape_args = reshape_args
  def forward(self, x):
    return x.reshape(*self.reshape_args)

class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.main = nn.Sequential(
        nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.Conv2d(ndf * 2, 1, 7, 1, 0, bias=False),
        nn.Sigmoid()
    )
    self.labeller = nn.Sequential(
        reshaper((-1,1)),
        nn.Linear(1,nzf),
        nn.LeakyReLU(0.2, inplace=True),
        reshaper((-1,nzf,1,1)),
        nn.ConvTranspose2d(nzf,1,28,1,0),
        nn.LeakyReLU(0.2, inplace=True),
        # nn.BatchNorm2d(1)
    )
  def forward(self, image, label):
    label = self.labeller(label)
    image = torch.cat((image, label), dim = 1)
    return self.main(image)

class Generator(nn.Module):
  def __init__(self):
    super().__init__()
    self.main = nn.Sequential(
      nn.ConvTranspose2d(1 + (ngf * 2), ngf, 4, 2, 1, bias=False),
      nn.BatchNorm2d(ngf),
      nn.ReLU(True),
      nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
      nn.Tanh()
    )
    self.labeller = nn.Sequential(
      reshaper((-1,1)),
      nn.Linear(1,nzf),
      nn.LeakyReLU(0.2, inplace=True),
      reshaper((-1,nzf,1,1)),
      nn.ConvTranspose2d(nzf,1,7,1,0),
      nn.ReLU(inplace=True),
      # nn.BatchNorm2d(1)
    )
    self.convt0 = nn.Sequential(
      nn.ConvTranspose2d(nz, ngf * 2, 7, 1, 0, bias=False),
      nn.BatchNorm2d(ngf * 2),
      nn.ReLU(True), 
    )
  def forward(self, noise, label):
    label = label.reshape(-1,1,1,1)
    label = self.labeller(label)
    noise = self.convt0(noise)
    noise = torch.cat((noise, label), dim = 1)
    return self.main(noise)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
  
discriminator = Discriminator()
discriminator.apply(weights_init)
generator = Generator()
generator.apply(weights_init)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(257, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): Tanh()
  )
  (labeller): Sequential(
    (0): reshaper()
    (1): Linear(in_features=1, out_features=50, bias=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): reshaper()
    (4): ConvTranspose2d(50, 1, kernel_size=(7, 7), stride=(1, 1))
    (5): ReLU(inplace=True)
  )
  (convt0): Sequential(
    (0): ConvTranspose2d(100, 256, kernel_size=(7, 7), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
)

In [0]:
real_label = 0.85
fake_label = 0.15
class datamaker(torch.utils.data.Dataset):
  def __init__(self, mnist):
    self.mnist = mnist
  def __len__(self):
    return len(self.mnist) 
  def __getitem__(self, idx):
    return {'x': self.mnist.data[idx].unsqueeze(0).float()/255, 'y': self.mnist.targets[idx].float()}

batch_size = 50
dataloader = torch.utils.data.DataLoader(datamaker(mnist_trainset), batch_size=batch_size,shuffle=True,)

In [0]:
criterion = nn.BCELoss()
fixed_noise = torch.randn(64, nz, 1, 1)
lr = 0.0002
beta1 = 0.5
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

In [0]:
discriminator.train()
generator.train()
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 100
for epoch in range(num_epochs):
  for i, data in enumerate(dataloader, 0):
    discriminator.zero_grad()
    real_cpu = data['x'].cuda()
    b_size = real_cpu.size(0)
    # label = torch.full((b_size,), real_label)
    label = torch.ones(b_size)*real_label + torch.randn(b_size)*0.3
    output = discriminator(real_cpu, data['y'].cuda()).view(-1)
    # Calculate loss on all-real batch
    errD_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    ## Train with all-fake batch
    # Generate batch of latent vectors
    noise = torch.randn(b_size, nz, 1, 1,)
    fake_categories = torch.floor(torch.rand(b_size)*10)
    # Generate fake image batch with G
    fake = generator(noise, fake_categories)
    label = torch.ones(b_size)*fake_label + torch.randn(b_size)*0.3
    # Classify all fake batch with D
    output = discriminator(fake.detach(), fake_categories).view(-1)
    # Calculate D's loss on the all-fake batch
    errD_fake = criterion(output, label)
    # Calculate the gradients for this batch
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Add the gradients from the all-real and all-fake batches
    errD = errD_real + errD_fake
    # Update D
    optimizerD.step()

    ############################
    # (2) Update G network: maximize log(D(G(z)))
    ###########################
    generator.zero_grad()
    label = torch.ones(b_size)*real_label + torch.randn(b_size)*0.3  # fake labels are real for generator cost
    # Since we just updated D, perform another forward pass of all-fake batch through D
    output = discriminator(fake, fake_categories).view(-1)
    # Calculate G's loss based on this output
    errG = criterion(output, label)
    # Calculate gradients for G
    errG.backward()
    D_G_z2 = output.mean().item()
    # Update G
    optimizerG.step()

    # Output training stats
    if i % 600 == 0:
        print('[%02d/%d][%03d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, i, len(dataloader),
                  errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Save Losses for plotting later
    G_losses.append(errG.item())
    D_losses.append(errD.item())

    # Check how the generator is doing by saving G's output on fixed_noise
    # if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
    #     with torch.no_grad():
    #         fake = generator(fixed_noise).detach().cpu()
    #     img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    iters += 1

[00/100][000/1200]	Loss_D: 0.9231	Loss_G: 1.7097	D(x): 0.8737	D(G(z)): 0.1764 / 0.1459
[00/100][600/1200]	Loss_D: 0.8315	Loss_G: 1.9850	D(x): 0.7525	D(G(z)): 0.0877 / 0.1051
[01/100][000/1200]	Loss_D: 0.8310	Loss_G: 1.9253	D(x): 0.8615	D(G(z)): 0.0922 / 0.0955
[01/100][600/1200]	Loss_D: 0.9311	Loss_G: 1.2302	D(x): 0.8465	D(G(z)): 0.1777 / 0.2394
[02/100][000/1200]	Loss_D: 0.6793	Loss_G: 1.9423	D(x): 0.7899	D(G(z)): 0.1214 / 0.1322
[02/100][600/1200]	Loss_D: 0.8581	Loss_G: 1.4083	D(x): 0.9112	D(G(z)): 0.1791 / 0.2130
[03/100][000/1200]	Loss_D: 0.6162	Loss_G: 1.7358	D(x): 0.8705	D(G(z)): 0.1592 / 0.1346
[03/100][600/1200]	Loss_D: 0.7017	Loss_G: 1.3758	D(x): 0.8600	D(G(z)): 0.2053 / 0.2152
[04/100][000/1200]	Loss_D: 0.5895	Loss_G: 1.4412	D(x): 0.8839	D(G(z)): 0.1343 / 0.1707
[04/100][600/1200]	Loss_D: 0.9907	Loss_G: 1.5305	D(x): 0.8360	D(G(z)): 0.1621 / 0.1664
[05/100][000/1200]	Loss_D: 0.9144	Loss_G: 1.6917	D(x): 0.8515	D(G(z)): 0.1722 / 0.1561
[05/100][600/1200]	Loss_D: 0.9018	Loss_G: 1

In [0]:
number_gen = 8.
generator.eval()
discriminator.eval()
with torch.no_grad():
  noise = torch.randn(1, nz, 1, 1,)
  fake = generator(noise, torch.tensor([number_gen]))
  while discriminator(fake, torch.tensor([number_gen])).item() < 0.0:
    noise = torch.randn(1, nz, 1, 1,)
    fake = generator(noise, torch.tensor([number_gen]))
  grid20x = torch.cat([torch.linspace(0,27,28)]*28)
  grid20y = torch.cat([torch.linspace(27,0,28).unsqueeze(1)]*28, dim = 1).flatten()
  fig = go.Figure()
  fig.add_scatter(
      x = grid20x.cpu(),
      y = grid20y.cpu(),
      mode = "markers",
      marker = dict(
          color = fake.flatten().cpu(),
          showscale=True,
          colorscale = "gray",
          symbol = "square",
          size = 15,
      )
  )
  fig.update_layout(
      yaxis = dict(
        scaleanchor = "x",
        scaleratio = 1,
      )
  )
  fig.show()
  print(discriminator(fake, torch.tensor([number_gen])).item())

0.29479220509529114


In [0]:
class reshaper(nn.Module):
  def __init__(self, reshape_args):
    super().__init__()
    self.reshape_args = reshape_args
  def forward(self, x):
    return x.reshape(*self.reshape_args)
rr = reshaper((-1,1))
rr2 = reshaper((-1,50,1,1))
labeller = nn.Sequential(
        nn.Linear(1,50),
        nn.LeakyReLU(0.2, inplace=True),
        nn.ConvTranspose2d(1,1,28,1,0),
        nn.LeakyReLU(0.2, inplace=True),
        nn.BatchNorm2d(1)
    )
a = torch.rand(6)
a = rr(a)
b = nn.Linear(1,50)
a = b(a)
print(a.shape)
a = rr2(a)
print(a.shape)
c= nn.ConvTranspose2d(50,1,28,1,0)
a = c(a)
print(a.shape)


torch.Size([6, 50])
torch.Size([6, 50, 1, 1])
torch.Size([6, 1, 28, 28])
