In [24]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import SubsetRandomSampler
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm import tqdm

In [2]:
# Mount Google Drive inside the Colab VM (this triggers the interactive
# authorization prompt) and switch the working directory to the Drive root,
# so that relative paths used later (e.g. './data') resolve inside Drive.
from google.colab import drive
drive.mount('/content/drive') 
%cd /content/drive/My\ Drive/

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive
/content/drive/My Drive


In [59]:
# ---------------------------------------------------------------------------
# Global configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Seed every RNG in use so that weight initialisation, the latent vectors and
# `fixed_noise` are the same on each "Restart & Run All".
SEED = 999
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Root folder of an ImageFolder-style dataset (just to check); the STL10
# pipeline does not read it.
data_dir = 'div2k'
batch_size = 64
# True  -> train on the full training split,
# False -> train only on a hand-picked subset of classes.
TRAIN_ALL = False
# All images will be resized to this size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 4

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Beta2 hyperparam for Adam optimizers
beta2 = 0.999

# Target values fed to the BCE loss for real and generated samples.
real_label = 1.
fake_label = 0.
# Fixed batch of latent vectors: feeding the same input to the generator at
# different points of training makes its progress visually comparable.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Define Loss function
criterion = nn.BCELoss()

In [4]:
def weights_init(m):
    """Re-initialise the parameters of a single module in place.

    The decision is made on the *class name* of ``m``:

    * names containing ``'Conv'``      -> weights drawn from N(0.0, 0.02)
    * names containing ``'BatchNorm'`` -> weights drawn from N(1.0, 0.02),
      bias set to 0
    * anything else is left untouched.

    Meant to be passed to ``Module.apply`` so every sub-module is visited.
    Returns ``None``.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [5]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    Input:  tensor of shape (N, nz, 1, 1).
    Output: tensor of shape (N, nc, 64, 64) with values in [-1, 1] (Tanh).

    The layer sizes are driven by the module-level constants ``nz``, ``ngf``
    and ``nc``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self._model = nn.Sequential(
            # input is Z, going into a convolution
            # args: in channels, out channels, kernel size, stride, padding
            # Kernel 4 / stride 1 / padding 0 turns the 1x1 latent into 4x4.
            # (A 5x5 kernel here would yield 5x5 and, after four doublings,
            # 80x80 images that the 64x64 discriminator cannot score.)
            nn.ConvTranspose2d(nz, ngf * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True),
            # state size. (ngf*16) x 4 x 4
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the latent batch ``input`` through the layer stack."""
        return self._model(input)

In [6]:
class Discriminator(nn.Module):
    """Strided-convolution classifier scoring images as real or generated.

    Input:  tensor of shape (N, nc, 64, 64).
    Output: tensor of shape (N, 1, 1, 1) with values in [0, 1] (Sigmoid).

    Channel widths are driven by the module-level constants ``nc`` and
    ``ndf``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        leak = 0.2

        # Stem: (nc) x 64 x 64 -> (ndf*2) x 32 x 32, no batch norm.
        stack = [
            nn.Conv2d(nc, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(leak, inplace=True),
        ]

        # Three down-sampling stages; each halves the spatial size and
        # doubles the channels:
        #   (ndf*2) x 32 x 32 -> (ndf*4) x 16 x 16
        #   (ndf*4) x 16 x 16 -> (ndf*8) x 8 x 8
        #   (ndf*8) x 8 x 8   -> (ndf*16) x 4 x 4
        for width in (2, 4, 8):
            stack += [
                nn.Conv2d(ndf * width, ndf * width * 2, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * width * 2),
                nn.LeakyReLU(leak, inplace=True),
            ]

        # Head: collapse the 4 x 4 map to a single probability per image.
        stack += [
            nn.Conv2d(ndf * 16, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self._model = nn.Sequential(*stack)

    def forward(self, input):
        """Return the per-image probability map produced by the layer stack."""
        scores = self._model(input)
        return scores

In [15]:
# Print a per-layer table of output shapes and parameter counts for both networks.
# NOTE(review): this cell reads `generator` and `discriminator`, which are
# created in a cell further down the notebook; on a fresh kernel that cell has
# to run first, otherwise this one raises NameError.
from torchsummary import summary

# The sizes are per-sample input shapes: the latent vector for the generator
# and a single-channel 64x64 image for the discriminator.
summary(generator, (100, 1, 1))
summary(discriminator, (1, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1           [-1, 1024, 4, 4]       1,638,400
       BatchNorm2d-2           [-1, 1024, 4, 4]           2,048
              ReLU-3           [-1, 1024, 4, 4]               0
   ConvTranspose2d-4            [-1, 512, 8, 8]       8,388,608
       BatchNorm2d-5            [-1, 512, 8, 8]           1,024
              ReLU-6            [-1, 512, 8, 8]               0
   ConvTranspose2d-7          [-1, 256, 16, 16]       2,097,152
       BatchNorm2d-8          [-1, 256, 16, 16]             512
              ReLU-9          [-1, 256, 16, 16]               0
  ConvTranspose2d-10          [-1, 128, 32, 32]         524,288
      BatchNorm2d-11          [-1, 128, 32, 32]             256
             ReLU-12          [-1, 128, 32, 32]               0
  ConvTranspose2d-13            [-1, 1, 64, 64]           2,048
             Tanh-14            [-1, 1,

In [7]:
# Training history filled in by train_gan(). The lists live at module level so
# later cells can plot them; calling train_gan() again without re-running this
# cell keeps appending to the same lists.
img_list = []   # image grids generated from `fixed_noise` during training
G_losses = []   # generator loss, one entry per batch
D_losses = []   # discriminator loss (real + fake), one entry per batch

def train_gan(generator, discriminator, gen_optimizer, dis_optimizer, train_loader, valid_loader, log_every=1):
    """Alternately train the discriminator and the generator for ``num_epochs``.

    For every batch:
      1. the discriminator is updated on a real batch and on a freshly
         generated (detached) fake batch;
      2. the generator is updated so that the just-updated discriminator
         scores its fakes as real.

    Args:
        generator: network mapping (N, nz, 1, 1) noise to images.
        discriminator: network mapping images to one probability per image.
        gen_optimizer: optimizer holding the generator's parameters.
        dis_optimizer: optimizer holding the discriminator's parameters.
        train_loader: iterable of batches; ``batch[0]`` must be the image tensor.
        valid_loader: accepted for interface compatibility; currently unused.
        log_every: print the running statistics every ``log_every`` batches
            (``1`` = every batch, the historical behaviour; ``0``/``None``
            disables the per-batch log).

    Side effects:
        Appends to the module-level ``G_losses``, ``D_losses`` and ``img_list``.
        Reads the module-level ``num_epochs``, ``nz``, ``device``,
        ``criterion``, ``real_label``, ``fake_label`` and ``fixed_noise``.

    Returns:
        None.
    """
    iters = 0
    print("GAN training started :D...")

    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(tqdm(train_loader)):

            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ## Train with all-real batch
            discriminator.zero_grad()
            # Format batch
            b_real = data[0].to(device)
            b_size = b_real.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = discriminator(b_real).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D; detach() keeps this backward
            # pass from reaching the generator's parameters.
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch (accumulated with the
            # gradients of the all-real batch)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Total discriminator loss, used for reporting only
            errD = errD_real + errD_fake
            # Update D
            dis_optimizer.step()

            # (2) Update G network: maximize log(D(G(z)))
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            gen_optimizer.step()

            # Output training stats. tqdm.write() is used instead of print()
            # so the message does not break the progress bar.
            if log_every and i % log_every == 0:
                tqdm.write('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                           % (epoch, num_epochs, i, len(train_loader),
                              errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            # (every 500 iterations and on the very last batch).
            if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(train_loader)-1)):
                with torch.no_grad():
                    snapshot = generator(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(snapshot, padding=2, normalize=True))

            iters += 1


In [66]:
def get_indices(dataset, class_name, indices):
    """Collect the positions of all samples whose label equals ``class_name``.

    Args:
        dataset: object exposing a ``labels`` sequence (one label per sample).
        class_name: label value to look for.
        indices: list that receives the matching positions. It is extended
            **in place**, so repeated calls with the same list accumulate the
            positions of several classes.

    Returns:
        The very same ``indices`` list that was passed in.

    Also prints how many samples of the class were found.
    """
    hits = [position
            for position, label in enumerate(dataset.labels)
            if label == class_name]
    indices.extend(hits)
    print("Total Samples of class", class_name,"found are",len(hits))
    return indices

In [68]:
# --- Data ------------------------------------------------------------------
# Normalising with mean 0.5 / std 0.5 maps ToTensor's [0, 1] range to [-1, 1],
# the same range the generator's final Tanh produces.
mu = [0.5]
sigma = [0.5]
# Originally authors used just scaling
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mu, sigma),
])

# To train on a local ImageFolder dataset instead, build the sets with
#   datasets.ImageFolder(os.path.join(data_dir, "train/"), transform=transform)
#   datasets.ImageFolder(os.path.join(data_dir, "val/"), transform=transform)
train_set = datasets.STL10(root='./data', split='train', download=True,
                           transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)
valid_set = datasets.STL10(root='./data', split='test', download=True,
                           transform=transform)
valid_loader = torch.utils.data.DataLoader(
    valid_set, batch_size=batch_size, shuffle=False)

# --- Models and optimizers -------------------------------------------------
generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)
gen_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# --- Training --------------------------------------------------------------
if TRAIN_ALL:
    train_gan(generator, discriminator, gen_optimizer, dis_optimizer, train_loader, valid_loader)
else:
    # Restrict training to a few classes: 0 = airplane, 8 = ship, 9 = truck.
    selected_classes = (0, 8, 9)
    # get_indices() extends the list it is given, so `idx` accumulates the
    # sample positions of every selected class.
    idx = []
    for class_id in selected_classes:
        get_indices(train_set, class_id, idx)
    print("Total samples now are ",len(idx))
    selected_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                                  sampler=SubsetRandomSampler(idx))
    train_gan(generator, discriminator, gen_optimizer, dis_optimizer, selected_loader, valid_loader)

Files already downloaded and verified
Files already downloaded and verified




  0%|          | 0/24 [00:00<?, ?it/s][A[A

Total Samples of class 0 found are 500
Total Samples of class 8 found are 500
Total Samples of class 9 found are 500
Total samples now are  1500
GAN training started :D...




  4%|▍         | 1/24 [00:00<00:11,  1.93it/s][A[A

[0/4][0/24]	Loss_D: 1.3942	Loss_G: 16.3015	D(x): 0.5943	D(G(z)): 0.4297 / 0.0000




  8%|▊         | 2/24 [00:00<00:10,  2.07it/s][A[A

[0/4][1/24]	Loss_D: 3.2365	Loss_G: 18.0994	D(x): 0.8497	D(G(z)): 0.8576 / 0.0000




 12%|█▎        | 3/24 [00:01<00:09,  2.19it/s][A[A

[0/4][2/24]	Loss_D: 1.9396	Loss_G: 15.2317	D(x): 0.6520	D(G(z)): 0.2864 / 0.0000




 17%|█▋        | 4/24 [00:01<00:08,  2.26it/s][A[A

[0/4][3/24]	Loss_D: 3.6348	Loss_G: 24.4377	D(x): 0.7212	D(G(z)): 0.8310 / 0.0000




 21%|██        | 5/24 [00:02<00:08,  2.33it/s][A[A

[0/4][4/24]	Loss_D: 3.3988	Loss_G: 16.4836	D(x): 0.5622	D(G(z)): 0.0005 / 0.0000




 25%|██▌       | 6/24 [00:02<00:07,  2.37it/s][A[A

[0/4][5/24]	Loss_D: 2.3314	Loss_G: 26.3463	D(x): 0.7081	D(G(z)): 0.6536 / 0.0000




 29%|██▉       | 7/24 [00:02<00:07,  2.41it/s][A[A

[0/4][6/24]	Loss_D: 1.4027	Loss_G: 22.5829	D(x): 0.6811	D(G(z)): 0.0000 / 0.0000




 33%|███▎      | 8/24 [00:03<00:06,  2.43it/s][A[A

[0/4][7/24]	Loss_D: 1.6662	Loss_G: 14.3182	D(x): 0.8147	D(G(z)): 0.1604 / 0.0000




 38%|███▊      | 9/24 [00:03<00:06,  2.43it/s][A[A

[0/4][8/24]	Loss_D: 3.3277	Loss_G: 37.2890	D(x): 0.9086	D(G(z)): 0.9022 / 0.0000




 42%|████▏     | 10/24 [00:04<00:05,  2.45it/s][A[A

[0/4][9/24]	Loss_D: 1.2431	Loss_G: 46.9929	D(x): 0.6775	D(G(z)): 0.0000 / 0.0000




 46%|████▌     | 11/24 [00:04<00:05,  2.46it/s][A[A

[0/4][10/24]	Loss_D: 1.9282	Loss_G: 50.0749	D(x): 0.6853	D(G(z)): 0.0000 / 0.0000




 50%|█████     | 12/24 [00:04<00:04,  2.46it/s][A[A

[0/4][11/24]	Loss_D: 0.3750	Loss_G: 51.1254	D(x): 0.9024	D(G(z)): 0.0000 / 0.0000




 54%|█████▍    | 13/24 [00:05<00:04,  2.45it/s][A[A

[0/4][12/24]	Loss_D: 0.0340	Loss_G: 51.5192	D(x): 0.9793	D(G(z)): 0.0000 / 0.0000




 58%|█████▊    | 14/24 [00:05<00:04,  2.45it/s][A[A

[0/4][13/24]	Loss_D: 0.0022	Loss_G: 51.6514	D(x): 0.9978	D(G(z)): 0.0000 / 0.0000




 62%|██████▎   | 15/24 [00:06<00:03,  2.45it/s][A[A

[0/4][14/24]	Loss_D: 0.0005	Loss_G: 51.4887	D(x): 0.9995	D(G(z)): 0.0000 / 0.0000




 67%|██████▋   | 16/24 [00:06<00:03,  2.45it/s][A[A

[0/4][15/24]	Loss_D: 0.0782	Loss_G: 51.4331	D(x): 0.9667	D(G(z)): 0.0000 / 0.0000




 71%|███████   | 17/24 [00:06<00:02,  2.46it/s][A[A

[0/4][16/24]	Loss_D: 0.0068	Loss_G: 51.3022	D(x): 0.9941	D(G(z)): 0.0000 / 0.0000




 75%|███████▌  | 18/24 [00:07<00:02,  2.46it/s][A[A

[0/4][17/24]	Loss_D: 0.0009	Loss_G: 51.5996	D(x): 0.9991	D(G(z)): 0.0000 / 0.0000




 79%|███████▉  | 19/24 [00:07<00:02,  2.46it/s][A[A

[0/4][18/24]	Loss_D: 0.0004	Loss_G: 51.1845	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 83%|████████▎ | 20/24 [00:08<00:01,  2.46it/s][A[A

[0/4][19/24]	Loss_D: 0.0035	Loss_G: 51.5137	D(x): 0.9967	D(G(z)): 0.0000 / 0.0000




 88%|████████▊ | 21/24 [00:08<00:01,  2.46it/s][A[A

[0/4][20/24]	Loss_D: 0.0002	Loss_G: 51.1077	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 92%|█████████▏| 22/24 [00:09<00:00,  2.45it/s][A[A

[0/4][21/24]	Loss_D: 0.0026	Loss_G: 50.8996	D(x): 0.9975	D(G(z)): 0.0000 / 0.0000




 96%|█████████▌| 23/24 [00:09<00:00,  2.45it/s][A[A

100%|██████████| 24/24 [00:09<00:00,  2.49it/s]


  0%|          | 0/24 [00:00<?, ?it/s][A[A

[0/4][22/24]	Loss_D: 0.0010	Loss_G: 51.2777	D(x): 0.9990	D(G(z)): 0.0000 / 0.0000
[0/4][23/24]	Loss_D: 0.0003	Loss_G: 51.4823	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




  4%|▍         | 1/24 [00:00<00:09,  2.40it/s][A[A

[1/4][0/24]	Loss_D: 0.0012	Loss_G: 51.2761	D(x): 0.9988	D(G(z)): 0.0000 / 0.0000




  8%|▊         | 2/24 [00:00<00:09,  2.41it/s][A[A

[1/4][1/24]	Loss_D: 0.0002	Loss_G: 50.9231	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 12%|█▎        | 3/24 [00:01<00:08,  2.42it/s][A[A

[1/4][2/24]	Loss_D: 0.0014	Loss_G: 51.1526	D(x): 0.9986	D(G(z)): 0.0000 / 0.0000




 17%|█▋        | 4/24 [00:01<00:08,  2.43it/s][A[A

[1/4][3/24]	Loss_D: 0.0001	Loss_G: 51.1101	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 21%|██        | 5/24 [00:02<00:07,  2.42it/s][A[A

[1/4][4/24]	Loss_D: 0.0153	Loss_G: 51.3027	D(x): 0.9899	D(G(z)): 0.0000 / 0.0000




 25%|██▌       | 6/24 [00:02<00:07,  2.43it/s][A[A

[1/4][5/24]	Loss_D: 0.0008	Loss_G: 51.1529	D(x): 0.9992	D(G(z)): 0.0000 / 0.0000




 29%|██▉       | 7/24 [00:02<00:06,  2.43it/s][A[A

[1/4][6/24]	Loss_D: 0.0007	Loss_G: 50.8414	D(x): 0.9993	D(G(z)): 0.0000 / 0.0000




 33%|███▎      | 8/24 [00:03<00:06,  2.43it/s][A[A

[1/4][7/24]	Loss_D: 0.0056	Loss_G: 50.5252	D(x): 0.9952	D(G(z)): 0.0000 / 0.0000




 38%|███▊      | 9/24 [00:03<00:06,  2.43it/s][A[A

[1/4][8/24]	Loss_D: 0.0003	Loss_G: 50.9264	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 42%|████▏     | 10/24 [00:04<00:05,  2.43it/s][A[A

[1/4][9/24]	Loss_D: 0.0008	Loss_G: 50.7608	D(x): 0.9992	D(G(z)): 0.0000 / 0.0000




 46%|████▌     | 11/24 [00:04<00:05,  2.41it/s][A[A

[1/4][10/24]	Loss_D: 0.0001	Loss_G: 50.7244	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 50%|█████     | 12/24 [00:04<00:04,  2.43it/s][A[A

[1/4][11/24]	Loss_D: 0.0004	Loss_G: 50.9819	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 54%|█████▍    | 13/24 [00:05<00:04,  2.42it/s][A[A

[1/4][12/24]	Loss_D: 0.0002	Loss_G: 50.8617	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 58%|█████▊    | 14/24 [00:05<00:04,  2.41it/s][A[A

[1/4][13/24]	Loss_D: 0.0002	Loss_G: 50.8737	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 62%|██████▎   | 15/24 [00:06<00:03,  2.41it/s][A[A

[1/4][14/24]	Loss_D: 0.0023	Loss_G: 50.5453	D(x): 0.9978	D(G(z)): 0.0000 / 0.0000




 67%|██████▋   | 16/24 [00:06<00:03,  2.41it/s][A[A

[1/4][15/24]	Loss_D: 0.0003	Loss_G: 50.9453	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 71%|███████   | 17/24 [00:07<00:02,  2.40it/s][A[A

[1/4][16/24]	Loss_D: 0.0009	Loss_G: 50.5601	D(x): 0.9991	D(G(z)): 0.0000 / 0.0000




 75%|███████▌  | 18/24 [00:07<00:02,  2.41it/s][A[A

[1/4][17/24]	Loss_D: 0.0026	Loss_G: 50.9398	D(x): 0.9976	D(G(z)): 0.0000 / 0.0000




 79%|███████▉  | 19/24 [00:07<00:02,  2.39it/s][A[A

[1/4][18/24]	Loss_D: 0.0001	Loss_G: 50.9001	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 83%|████████▎ | 20/24 [00:08<00:01,  2.39it/s][A[A

[1/4][19/24]	Loss_D: 0.0002	Loss_G: 50.9525	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 88%|████████▊ | 21/24 [00:08<00:01,  2.39it/s][A[A

[1/4][20/24]	Loss_D: 0.0000	Loss_G: 50.8063	D(x): 1.0000	D(G(z)): 0.0000 / 0.0000




 92%|█████████▏| 22/24 [00:09<00:00,  2.39it/s][A[A

[1/4][21/24]	Loss_D: 0.0019	Loss_G: 50.6394	D(x): 0.9982	D(G(z)): 0.0000 / 0.0000




 96%|█████████▌| 23/24 [00:09<00:00,  2.40it/s][A[A

100%|██████████| 24/24 [00:09<00:00,  2.85it/s]

[1/4][22/24]	Loss_D: 0.0003	Loss_G: 50.9590	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000
[1/4][23/24]	Loss_D: 0.0001	Loss_G: 50.5370	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000


100%|██████████| 24/24 [00:09<00:00,  2.46it/s]


  0%|          | 0/24 [00:00<?, ?it/s][A[A

  4%|▍         | 1/24 [00:00<00:09,  2.42it/s][A[A

[2/4][0/24]	Loss_D: 0.0005	Loss_G: 50.8624	D(x): 0.9995	D(G(z)): 0.0000 / 0.0000




  8%|▊         | 2/24 [00:00<00:09,  2.42it/s][A[A

[2/4][1/24]	Loss_D: 0.0002	Loss_G: 50.6104	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 12%|█▎        | 3/24 [00:01<00:08,  2.41it/s][A[A

[2/4][2/24]	Loss_D: 0.0001	Loss_G: 50.6761	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 17%|█▋        | 4/24 [00:01<00:08,  2.40it/s][A[A

[2/4][3/24]	Loss_D: 0.0002	Loss_G: 50.4963	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 21%|██        | 5/24 [00:02<00:07,  2.40it/s][A[A

[2/4][4/24]	Loss_D: 0.0001	Loss_G: 50.8092	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 25%|██▌       | 6/24 [00:02<00:07,  2.39it/s][A[A

[2/4][5/24]	Loss_D: 0.0001	Loss_G: 50.7675	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 29%|██▉       | 7/24 [00:02<00:07,  2.38it/s][A[A

[2/4][6/24]	Loss_D: 0.0004	Loss_G: 50.8297	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 33%|███▎      | 8/24 [00:03<00:06,  2.39it/s][A[A

[2/4][7/24]	Loss_D: 0.0003	Loss_G: 50.7865	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 38%|███▊      | 9/24 [00:03<00:06,  2.39it/s][A[A

[2/4][8/24]	Loss_D: 0.0004	Loss_G: 50.7919	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 42%|████▏     | 10/24 [00:04<00:05,  2.40it/s][A[A

[2/4][9/24]	Loss_D: 0.0004	Loss_G: 50.6069	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 46%|████▌     | 11/24 [00:04<00:05,  2.40it/s][A[A

[2/4][10/24]	Loss_D: 0.0003	Loss_G: 50.5478	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 50%|█████     | 12/24 [00:05<00:04,  2.41it/s][A[A

[2/4][11/24]	Loss_D: 0.0002	Loss_G: 50.7154	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 54%|█████▍    | 13/24 [00:05<00:04,  2.41it/s][A[A

[2/4][12/24]	Loss_D: 0.0003	Loss_G: 51.1770	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 58%|█████▊    | 14/24 [00:05<00:04,  2.40it/s][A[A

[2/4][13/24]	Loss_D: 0.0003	Loss_G: 50.7400	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 62%|██████▎   | 15/24 [00:06<00:03,  2.41it/s][A[A

[2/4][14/24]	Loss_D: 0.0005	Loss_G: 50.7262	D(x): 0.9995	D(G(z)): 0.0000 / 0.0000




 67%|██████▋   | 16/24 [00:06<00:03,  2.42it/s][A[A

[2/4][15/24]	Loss_D: 0.0009	Loss_G: 50.5693	D(x): 0.9991	D(G(z)): 0.0000 / 0.0000




 71%|███████   | 17/24 [00:07<00:02,  2.42it/s][A[A

[2/4][16/24]	Loss_D: 0.0001	Loss_G: 50.1556	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 75%|███████▌  | 18/24 [00:07<00:02,  2.42it/s][A[A

[2/4][17/24]	Loss_D: 0.0001	Loss_G: 50.6917	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 79%|███████▉  | 19/24 [00:07<00:02,  2.40it/s][A[A

[2/4][18/24]	Loss_D: 0.0003	Loss_G: 50.3466	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 83%|████████▎ | 20/24 [00:08<00:01,  2.42it/s][A[A

[2/4][19/24]	Loss_D: 0.0002	Loss_G: 51.0311	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 88%|████████▊ | 21/24 [00:08<00:01,  2.42it/s][A[A

[2/4][20/24]	Loss_D: 0.0001	Loss_G: 50.1895	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 92%|█████████▏| 22/24 [00:09<00:00,  2.42it/s][A[A

[2/4][21/24]	Loss_D: 0.0001	Loss_G: 50.2265	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 96%|█████████▌| 23/24 [00:09<00:00,  2.43it/s][A[A

100%|██████████| 24/24 [00:09<00:00,  2.46it/s]


  0%|          | 0/24 [00:00<?, ?it/s][A[A

[2/4][22/24]	Loss_D: 0.0002	Loss_G: 50.3262	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000
[2/4][23/24]	Loss_D: 0.0001	Loss_G: 50.7163	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




  4%|▍         | 1/24 [00:00<00:09,  2.44it/s][A[A

[3/4][0/24]	Loss_D: 0.0003	Loss_G: 50.7314	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




  8%|▊         | 2/24 [00:00<00:08,  2.45it/s][A[A

[3/4][1/24]	Loss_D: 0.0005	Loss_G: 50.8991	D(x): 0.9995	D(G(z)): 0.0000 / 0.0000




 12%|█▎        | 3/24 [00:01<00:08,  2.43it/s][A[A

[3/4][2/24]	Loss_D: 0.0003	Loss_G: 50.7181	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 17%|█▋        | 4/24 [00:01<00:08,  2.44it/s][A[A

[3/4][3/24]	Loss_D: 0.0001	Loss_G: 50.7079	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 21%|██        | 5/24 [00:02<00:07,  2.44it/s][A[A

[3/4][4/24]	Loss_D: 0.0000	Loss_G: 50.4025	D(x): 1.0000	D(G(z)): 0.0000 / 0.0000




 25%|██▌       | 6/24 [00:02<00:07,  2.45it/s][A[A

[3/4][5/24]	Loss_D: 0.0007	Loss_G: 50.2052	D(x): 0.9993	D(G(z)): 0.0000 / 0.0000




 29%|██▉       | 7/24 [00:02<00:06,  2.46it/s][A[A

[3/4][6/24]	Loss_D: 0.0007	Loss_G: 50.4929	D(x): 0.9993	D(G(z)): 0.0000 / 0.0000




 33%|███▎      | 8/24 [00:03<00:06,  2.45it/s][A[A

[3/4][7/24]	Loss_D: 0.0001	Loss_G: 50.3946	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 38%|███▊      | 9/24 [00:03<00:06,  2.46it/s][A[A

[3/4][8/24]	Loss_D: 0.0004	Loss_G: 50.3932	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 42%|████▏     | 10/24 [00:04<00:05,  2.46it/s][A[A

[3/4][9/24]	Loss_D: 0.0002	Loss_G: 50.3678	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 46%|████▌     | 11/24 [00:04<00:05,  2.47it/s][A[A

[3/4][10/24]	Loss_D: 0.0000	Loss_G: 50.7359	D(x): 1.0000	D(G(z)): 0.0000 / 0.0000




 50%|█████     | 12/24 [00:04<00:04,  2.47it/s][A[A

[3/4][11/24]	Loss_D: 0.0001	Loss_G: 50.6326	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 54%|█████▍    | 13/24 [00:05<00:04,  2.47it/s][A[A

[3/4][12/24]	Loss_D: 0.0002	Loss_G: 50.1899	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 58%|█████▊    | 14/24 [00:05<00:04,  2.47it/s][A[A

[3/4][13/24]	Loss_D: 0.0000	Loss_G: 50.5939	D(x): 1.0000	D(G(z)): 0.0000 / 0.0000




 62%|██████▎   | 15/24 [00:06<00:03,  2.48it/s][A[A

[3/4][14/24]	Loss_D: 0.0002	Loss_G: 50.4794	D(x): 0.9998	D(G(z)): 0.0000 / 0.0000




 67%|██████▋   | 16/24 [00:06<00:03,  2.48it/s][A[A

[3/4][15/24]	Loss_D: 0.0001	Loss_G: 50.6083	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 71%|███████   | 17/24 [00:06<00:02,  2.48it/s][A[A

[3/4][16/24]	Loss_D: 0.0003	Loss_G: 50.3540	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 75%|███████▌  | 18/24 [00:07<00:02,  2.46it/s][A[A

[3/4][17/24]	Loss_D: 0.0001	Loss_G: 50.5448	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 79%|███████▉  | 19/24 [00:07<00:02,  2.46it/s][A[A

[3/4][18/24]	Loss_D: 0.0003	Loss_G: 50.2543	D(x): 0.9997	D(G(z)): 0.0000 / 0.0000




 83%|████████▎ | 20/24 [00:08<00:01,  2.48it/s][A[A

[3/4][19/24]	Loss_D: 0.0004	Loss_G: 50.5190	D(x): 0.9996	D(G(z)): 0.0000 / 0.0000




 88%|████████▊ | 21/24 [00:08<00:01,  2.48it/s][A[A

[3/4][20/24]	Loss_D: 0.0015	Loss_G: 50.3807	D(x): 0.9985	D(G(z)): 0.0000 / 0.0000




 92%|█████████▏| 22/24 [00:08<00:00,  2.49it/s][A[A

[3/4][21/24]	Loss_D: 0.0001	Loss_G: 50.6947	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000




 96%|█████████▌| 23/24 [00:09<00:00,  2.49it/s][A[A

[3/4][22/24]	Loss_D: 0.0001	Loss_G: 50.9908	D(x): 0.9999	D(G(z)): 0.0000 / 0.0000
[3/4][23/24]	Loss_D: 0.0000	Loss_G: 50.5737	D(x): 1.0000	D(G(z)): 0.0000 / 0.0000




100%|██████████| 24/24 [00:09<00:00,  2.51it/s]


In [9]:
# Show the generator's layer stack (module repr) to double-check the architecture.
print(generator)

Generator(
  (_model): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh(

In [10]:
# Show the discriminator's layer stack (module repr) to double-check the architecture.
print(discriminator)

Discriminator(
  (_model): Sequential(
    (0): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
