In [40]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt


In [41]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)


In [42]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


In [43]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [44]:
# Hyperparameters etc
device = "cuda:1" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 128
NUM_EPOCHS = 100
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 0.01


In [45]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [46]:

dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
#comment mnist and uncomment below if you want to train on CelebA dataset
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# initialize gen and disc/critic
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initializate optimizer
opt_gen = torch.optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = torch.optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# for tensorboard plotting
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0


In [47]:

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

            # clip critic weights between -0.01, 0.01
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    data[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

 11%|█         | 101/938 [00:10<01:44,  8.02it/s]

Epoch [0/100] Batch 100/938                   Loss D: -1.3731, loss G: 0.6618


 22%|██▏       | 202/938 [00:20<01:29,  8.25it/s]

Epoch [0/100] Batch 200/938                   Loss D: -1.5179, loss G: 0.7331


 32%|███▏      | 302/938 [00:31<01:20,  7.87it/s]

Epoch [0/100] Batch 300/938                   Loss D: -1.5323, loss G: 0.7382


 43%|████▎     | 402/938 [00:41<01:05,  8.18it/s]

Epoch [0/100] Batch 400/938                   Loss D: -1.5341, loss G: 0.7385


 54%|█████▎    | 502/938 [00:51<00:54,  7.97it/s]

Epoch [0/100] Batch 500/938                   Loss D: -1.5357, loss G: 0.7390


 64%|██████▍   | 602/938 [01:02<00:42,  7.99it/s]

Epoch [0/100] Batch 600/938                   Loss D: -1.5320, loss G: 0.7385


 75%|███████▍  | 702/938 [01:12<00:29,  7.99it/s]

Epoch [0/100] Batch 700/938                   Loss D: -1.5240, loss G: 0.7368


 86%|████████▌ | 802/938 [01:22<00:16,  8.06it/s]

Epoch [0/100] Batch 800/938                   Loss D: -1.4832, loss G: 0.7283


 96%|█████████▌| 902/938 [01:32<00:04,  8.10it/s]

Epoch [0/100] Batch 900/938                   Loss D: -1.4726, loss G: 0.7131


100%|██████████| 938/938 [01:36<00:00,  9.71it/s]
 11%|█         | 102/938 [00:10<01:44,  8.03it/s]

Epoch [1/100] Batch 100/938                   Loss D: -1.3564, loss G: 0.6422


 22%|██▏       | 202/938 [00:20<01:32,  7.92it/s]

Epoch [1/100] Batch 200/938                   Loss D: -1.4132, loss G: 0.6780


 32%|███▏      | 302/938 [00:31<01:17,  8.20it/s]

Epoch [1/100] Batch 300/938                   Loss D: -1.4024, loss G: 0.6869


 43%|████▎     | 402/938 [00:41<01:07,  7.99it/s]

Epoch [1/100] Batch 400/938                   Loss D: -1.3281, loss G: 0.5582


 54%|█████▎    | 502/938 [00:52<00:54,  8.03it/s]

Epoch [1/100] Batch 500/938                   Loss D: -1.3670, loss G: 0.6799


 64%|██████▍   | 602/938 [01:02<00:42,  7.88it/s]

Epoch [1/100] Batch 600/938                   Loss D: -1.3389, loss G: 0.6841


 75%|███████▍  | 702/938 [01:12<00:29,  7.90it/s]

Epoch [1/100] Batch 700/938                   Loss D: -1.3195, loss G: 0.6225


 86%|████████▌ | 802/938 [01:23<00:17,  7.92it/s]

Epoch [1/100] Batch 800/938                   Loss D: -1.3292, loss G: 0.6357


 96%|█████████▌| 902/938 [01:33<00:04,  8.10it/s]

Epoch [1/100] Batch 900/938                   Loss D: -1.3142, loss G: 0.6357


100%|██████████| 938/938 [01:37<00:00,  9.63it/s]
 11%|█         | 102/938 [00:10<01:42,  8.17it/s]

Epoch [2/100] Batch 100/938                   Loss D: -1.1911, loss G: 0.6822


 22%|██▏       | 202/938 [00:21<01:28,  8.29it/s]

Epoch [2/100] Batch 200/938                   Loss D: -1.2378, loss G: 0.6087


 32%|███▏      | 302/938 [00:31<01:17,  8.26it/s]

Epoch [2/100] Batch 300/938                   Loss D: -1.2404, loss G: 0.5966


 43%|████▎     | 402/938 [00:41<01:05,  8.14it/s]

Epoch [2/100] Batch 400/938                   Loss D: -1.1940, loss G: 0.5234


 54%|█████▎    | 502/938 [00:52<00:51,  8.41it/s]

Epoch [2/100] Batch 500/938                   Loss D: -1.1500, loss G: 0.6419


 64%|██████▍   | 602/938 [01:02<00:39,  8.58it/s]

Epoch [2/100] Batch 600/938                   Loss D: -1.0611, loss G: 0.4465


 75%|███████▍  | 702/938 [01:13<00:28,  8.22it/s]

Epoch [2/100] Batch 700/938                   Loss D: -1.0773, loss G: 0.6311


 86%|████████▌ | 802/938 [01:23<00:16,  8.42it/s]

Epoch [2/100] Batch 800/938                   Loss D: -1.1376, loss G: 0.6170


 96%|█████████▌| 902/938 [01:33<00:04,  8.60it/s]

Epoch [2/100] Batch 900/938                   Loss D: -1.1290, loss G: 0.5974


100%|██████████| 938/938 [01:37<00:00,  9.62it/s]
 11%|█         | 102/938 [00:10<01:38,  8.51it/s]

Epoch [3/100] Batch 100/938                   Loss D: -1.0912, loss G: 0.4108


 22%|██▏       | 202/938 [00:20<01:26,  8.49it/s]

Epoch [3/100] Batch 200/938                   Loss D: -1.0812, loss G: 0.6101


 32%|███▏      | 302/938 [00:31<01:15,  8.37it/s]

Epoch [3/100] Batch 300/938                   Loss D: -1.1228, loss G: 0.6216


 43%|████▎     | 402/938 [00:41<01:03,  8.39it/s]

Epoch [3/100] Batch 400/938                   Loss D: -1.1357, loss G: 0.5570


 54%|█████▎    | 502/938 [00:52<00:52,  8.32it/s]

Epoch [3/100] Batch 500/938                   Loss D: -1.0968, loss G: 0.4651


 64%|██████▍   | 602/938 [01:02<00:40,  8.33it/s]

Epoch [3/100] Batch 600/938                   Loss D: -1.0925, loss G: 0.6153


 75%|███████▍  | 702/938 [01:12<00:28,  8.35it/s]

Epoch [3/100] Batch 700/938                   Loss D: -1.0125, loss G: 0.3483


 86%|████████▌ | 802/938 [01:23<00:16,  8.34it/s]

Epoch [3/100] Batch 800/938                   Loss D: -1.0656, loss G: 0.4450


 96%|█████████▌| 902/938 [01:33<00:04,  8.39it/s]

Epoch [3/100] Batch 900/938                   Loss D: -0.9759, loss G: 0.5995


100%|██████████| 938/938 [01:37<00:00,  9.63it/s]
 11%|█         | 102/938 [00:10<01:39,  8.43it/s]

Epoch [4/100] Batch 100/938                   Loss D: -1.0176, loss G: 0.4213


 22%|██▏       | 202/938 [00:20<01:27,  8.45it/s]

Epoch [4/100] Batch 200/938                   Loss D: -1.1154, loss G: 0.5380


 32%|███▏      | 302/938 [00:31<01:17,  8.21it/s]

Epoch [4/100] Batch 300/938                   Loss D: -1.0669, loss G: 0.4783


 43%|████▎     | 402/938 [00:41<01:03,  8.50it/s]

Epoch [4/100] Batch 400/938                   Loss D: -1.0553, loss G: 0.5901


 54%|█████▎    | 502/938 [00:52<00:50,  8.63it/s]

Epoch [4/100] Batch 500/938                   Loss D: -0.9495, loss G: 0.6097


 64%|██████▍   | 602/938 [01:02<00:39,  8.44it/s]

Epoch [4/100] Batch 600/938                   Loss D: -1.0038, loss G: 0.5597


 75%|███████▍  | 702/938 [01:12<00:28,  8.40it/s]

Epoch [4/100] Batch 700/938                   Loss D: -1.0565, loss G: 0.5932


 86%|████████▌ | 802/938 [01:23<00:16,  8.36it/s]

Epoch [4/100] Batch 800/938                   Loss D: -1.1249, loss G: 0.5261


 96%|█████████▌| 902/938 [01:33<00:04,  8.45it/s]

Epoch [4/100] Batch 900/938                   Loss D: -0.9608, loss G: 0.3521


100%|██████████| 938/938 [01:37<00:00,  9.65it/s]
 11%|█         | 102/938 [00:10<01:38,  8.45it/s]

Epoch [5/100] Batch 100/938                   Loss D: -0.9066, loss G: 0.5967


 22%|██▏       | 202/938 [00:20<01:25,  8.60it/s]

Epoch [5/100] Batch 200/938                   Loss D: -1.0065, loss G: 0.5973


 32%|███▏      | 302/938 [00:31<01:16,  8.28it/s]

Epoch [5/100] Batch 300/938                   Loss D: -0.9685, loss G: 0.6018


 43%|████▎     | 402/938 [00:41<01:02,  8.54it/s]

Epoch [5/100] Batch 400/938                   Loss D: -1.0003, loss G: 0.5829


 54%|█████▎    | 502/938 [00:52<00:49,  8.74it/s]

Epoch [5/100] Batch 500/938                   Loss D: -1.0220, loss G: 0.4605


 64%|██████▍   | 602/938 [01:02<00:39,  8.43it/s]

Epoch [5/100] Batch 600/938                   Loss D: -0.8773, loss G: 0.3090


 75%|███████▍  | 702/938 [01:12<00:27,  8.47it/s]

Epoch [5/100] Batch 700/938                   Loss D: -0.9449, loss G: 0.3449


 86%|████████▌ | 802/938 [01:23<00:16,  8.38it/s]

Epoch [5/100] Batch 800/938                   Loss D: -0.9400, loss G: 0.3551


 96%|█████████▌| 902/938 [01:33<00:04,  8.31it/s]

Epoch [5/100] Batch 900/938                   Loss D: -1.0061, loss G: 0.5738


100%|██████████| 938/938 [01:37<00:00,  9.64it/s]
 11%|█         | 102/938 [00:10<01:39,  8.38it/s]

Epoch [6/100] Batch 100/938                   Loss D: -0.6447, loss G: 0.0879


 22%|██▏       | 202/938 [00:20<01:27,  8.39it/s]

Epoch [6/100] Batch 200/938                   Loss D: -0.8389, loss G: 0.5913


 32%|███▏      | 302/938 [00:31<01:13,  8.64it/s]

Epoch [6/100] Batch 300/938                   Loss D: -1.0022, loss G: 0.5593


 43%|████▎     | 402/938 [00:41<01:04,  8.34it/s]

Epoch [6/100] Batch 400/938                   Loss D: -0.9507, loss G: 0.5896


 54%|█████▎    | 502/938 [00:52<00:51,  8.46it/s]

Epoch [6/100] Batch 500/938                   Loss D: -0.8485, loss G: 0.2856


 64%|██████▍   | 602/938 [01:02<00:39,  8.55it/s]

Epoch [6/100] Batch 600/938                   Loss D: -0.9083, loss G: 0.5856


 75%|███████▍  | 702/938 [01:12<00:28,  8.36it/s]

Epoch [6/100] Batch 700/938                   Loss D: -0.8571, loss G: 0.6023


 86%|████████▌ | 802/938 [01:23<00:16,  8.45it/s]

Epoch [6/100] Batch 800/938                   Loss D: -0.9002, loss G: 0.5827


 96%|█████████▌| 902/938 [01:33<00:04,  8.44it/s]

Epoch [6/100] Batch 900/938                   Loss D: -0.8838, loss G: 0.5809


100%|██████████| 938/938 [01:37<00:00,  9.66it/s]
 11%|█         | 102/938 [00:10<01:39,  8.44it/s]

Epoch [7/100] Batch 100/938                   Loss D: -0.8550, loss G: 0.5731


 22%|██▏       | 202/938 [00:20<01:26,  8.51it/s]

Epoch [7/100] Batch 200/938                   Loss D: -0.9364, loss G: 0.5701


 32%|███▏      | 302/938 [00:31<01:17,  8.25it/s]

Epoch [7/100] Batch 300/938                   Loss D: -0.8525, loss G: 0.5655


 43%|████▎     | 402/938 [00:41<01:01,  8.71it/s]

Epoch [7/100] Batch 400/938                   Loss D: -0.8935, loss G: 0.5751


 54%|█████▎    | 502/938 [00:51<00:52,  8.37it/s]

Epoch [7/100] Batch 500/938                   Loss D: -0.8677, loss G: 0.5674


 64%|██████▍   | 602/938 [01:02<00:39,  8.57it/s]

Epoch [7/100] Batch 600/938                   Loss D: -0.7115, loss G: 0.5537


 75%|███████▍  | 702/938 [01:12<00:27,  8.55it/s]

Epoch [7/100] Batch 700/938                   Loss D: -0.8239, loss G: 0.5589


 86%|████████▌ | 802/938 [01:22<00:15,  8.73it/s]

Epoch [7/100] Batch 800/938                   Loss D: -0.7668, loss G: 0.5797


 96%|█████████▌| 902/938 [01:33<00:04,  8.55it/s]

Epoch [7/100] Batch 900/938                   Loss D: -0.8556, loss G: 0.5542


100%|██████████| 938/938 [01:36<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:40,  8.35it/s]

Epoch [8/100] Batch 100/938                   Loss D: -0.8348, loss G: 0.5804


 22%|██▏       | 202/938 [00:20<01:28,  8.33it/s]

Epoch [8/100] Batch 200/938                   Loss D: -0.7143, loss G: 0.1606


 32%|███▏      | 302/938 [00:31<01:16,  8.26it/s]

Epoch [8/100] Batch 300/938                   Loss D: -0.8869, loss G: 0.3384


 43%|████▎     | 402/938 [00:41<01:03,  8.45it/s]

Epoch [8/100] Batch 400/938                   Loss D: -0.8713, loss G: 0.5764


 54%|█████▎    | 502/938 [00:51<00:51,  8.49it/s]

Epoch [8/100] Batch 500/938                   Loss D: -0.8937, loss G: 0.3646


 64%|██████▍   | 602/938 [01:02<00:40,  8.29it/s]

Epoch [8/100] Batch 600/938                   Loss D: -0.9084, loss G: 0.4016


 75%|███████▍  | 702/938 [01:12<00:27,  8.46it/s]

Epoch [8/100] Batch 700/938                   Loss D: -0.8002, loss G: 0.2503


 86%|████████▌ | 802/938 [01:23<00:16,  8.31it/s]

Epoch [8/100] Batch 800/938                   Loss D: -0.7874, loss G: 0.2732


 96%|█████████▌| 902/938 [01:33<00:04,  8.57it/s]

Epoch [8/100] Batch 900/938                   Loss D: -0.8461, loss G: 0.5702


100%|██████████| 938/938 [01:37<00:00,  9.66it/s]
 11%|█         | 102/938 [00:10<01:40,  8.31it/s]

Epoch [9/100] Batch 100/938                   Loss D: -0.6838, loss G: 0.1705


 22%|██▏       | 202/938 [00:20<01:26,  8.52it/s]

Epoch [9/100] Batch 200/938                   Loss D: -0.8357, loss G: 0.3073


 32%|███▏      | 302/938 [00:31<01:14,  8.59it/s]

Epoch [9/100] Batch 300/938                   Loss D: -0.7893, loss G: 0.3189


 43%|████▎     | 402/938 [00:41<01:03,  8.47it/s]

Epoch [9/100] Batch 400/938                   Loss D: -0.8073, loss G: 0.2536


 54%|█████▎    | 502/938 [00:51<00:52,  8.25it/s]

Epoch [9/100] Batch 500/938                   Loss D: -0.8789, loss G: 0.3623


 64%|██████▍   | 602/938 [01:02<00:39,  8.54it/s]

Epoch [9/100] Batch 600/938                   Loss D: -0.8257, loss G: 0.2740


 75%|███████▍  | 702/938 [01:12<00:28,  8.39it/s]

Epoch [9/100] Batch 700/938                   Loss D: -0.8088, loss G: 0.2827


 86%|████████▌ | 802/938 [01:23<00:15,  8.79it/s]

Epoch [9/100] Batch 800/938                   Loss D: -0.7097, loss G: 0.1658


 96%|█████████▌| 902/938 [01:33<00:04,  8.57it/s]

Epoch [9/100] Batch 900/938                   Loss D: -0.8429, loss G: 0.2344


100%|██████████| 938/938 [01:37<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:38,  8.46it/s]

Epoch [10/100] Batch 100/938                   Loss D: -0.7253, loss G: 0.1586


 32%|███▏      | 302/938 [00:31<01:14,  8.51it/s]

Epoch [11/100] Batch 300/938                   Loss D: -0.8310, loss G: 0.5389


 43%|████▎     | 402/938 [00:41<01:05,  8.15it/s]

Epoch [11/100] Batch 400/938                   Loss D: -0.7759, loss G: 0.5322


 54%|█████▎    | 502/938 [00:51<00:51,  8.52it/s]

Epoch [11/100] Batch 500/938                   Loss D: -0.7266, loss G: 0.5562


 64%|██████▍   | 602/938 [01:02<00:38,  8.70it/s]

Epoch [11/100] Batch 600/938                   Loss D: -0.7990, loss G: 0.5310


 75%|███████▍  | 702/938 [01:12<00:27,  8.45it/s]

Epoch [11/100] Batch 700/938                   Loss D: -0.7248, loss G: 0.5359


 86%|████████▌ | 802/938 [01:22<00:16,  8.38it/s]

Epoch [11/100] Batch 800/938                   Loss D: -0.6930, loss G: 0.5258


 96%|█████████▌| 902/938 [01:33<00:04,  8.52it/s]

Epoch [11/100] Batch 900/938                   Loss D: -0.7072, loss G: 0.5270


100%|██████████| 938/938 [01:36<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:40,  8.35it/s]

Epoch [12/100] Batch 100/938                   Loss D: -0.7723, loss G: 0.1657


 22%|██▏       | 202/938 [00:20<01:28,  8.35it/s]

Epoch [12/100] Batch 200/938                   Loss D: -0.7826, loss G: 0.3054


 32%|███▏      | 302/938 [00:31<01:15,  8.46it/s]

Epoch [12/100] Batch 300/938                   Loss D: -0.7392, loss G: 0.1769


 43%|████▎     | 402/938 [00:41<01:04,  8.30it/s]

Epoch [12/100] Batch 400/938                   Loss D: -0.7650, loss G: 0.5316


 54%|█████▎    | 502/938 [00:51<00:49,  8.74it/s]

Epoch [12/100] Batch 500/938                   Loss D: -0.7753, loss G: 0.2778


 64%|██████▍   | 602/938 [01:02<00:38,  8.79it/s]

Epoch [12/100] Batch 600/938                   Loss D: -0.7524, loss G: 0.2557


 75%|███████▍  | 702/938 [01:12<00:27,  8.57it/s]

Epoch [12/100] Batch 700/938                   Loss D: -0.6789, loss G: 0.1955


 86%|████████▌ | 802/938 [01:22<00:16,  8.18it/s]

Epoch [12/100] Batch 800/938                   Loss D: -0.7020, loss G: 0.2602


 96%|█████████▌| 902/938 [01:33<00:04,  8.42it/s]

Epoch [12/100] Batch 900/938                   Loss D: -0.7507, loss G: 0.2332


100%|██████████| 938/938 [01:36<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:41,  8.27it/s]

Epoch [13/100] Batch 100/938                   Loss D: -0.7591, loss G: 0.2888


 22%|██▏       | 202/938 [00:20<01:24,  8.71it/s]

Epoch [13/100] Batch 200/938                   Loss D: -0.7512, loss G: 0.2698


 32%|███▏      | 302/938 [00:31<01:15,  8.41it/s]

Epoch [13/100] Batch 300/938                   Loss D: -0.7132, loss G: 0.2133


 43%|████▎     | 402/938 [00:41<01:04,  8.35it/s]

Epoch [13/100] Batch 400/938                   Loss D: -0.7920, loss G: 0.2530


 54%|█████▎    | 502/938 [00:51<00:52,  8.29it/s]

Epoch [13/100] Batch 500/938                   Loss D: -0.7327, loss G: 0.2338


 64%|██████▍   | 602/938 [01:02<00:39,  8.40it/s]

Epoch [13/100] Batch 600/938                   Loss D: -0.7288, loss G: 0.2849


 75%|███████▍  | 702/938 [01:12<00:27,  8.52it/s]

Epoch [13/100] Batch 700/938                   Loss D: -0.7401, loss G: 0.2326


 86%|████████▌ | 802/938 [01:22<00:15,  8.52it/s]

Epoch [13/100] Batch 800/938                   Loss D: -0.7259, loss G: 0.2621


 96%|█████████▌| 902/938 [01:33<00:04,  8.51it/s]

Epoch [13/100] Batch 900/938                   Loss D: -0.7363, loss G: 0.3109


100%|██████████| 938/938 [01:37<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:38,  8.47it/s]

Epoch [14/100] Batch 100/938                   Loss D: -0.7513, loss G: 0.2342


 22%|██▏       | 202/938 [00:20<01:31,  8.08it/s]

Epoch [14/100] Batch 200/938                   Loss D: -0.6982, loss G: 0.2091


 32%|███▏      | 302/938 [00:31<01:17,  8.16it/s]

Epoch [14/100] Batch 300/938                   Loss D: -0.7300, loss G: 0.2560


 43%|████▎     | 402/938 [00:41<01:03,  8.44it/s]

Epoch [14/100] Batch 400/938                   Loss D: -0.6905, loss G: 0.1491


 54%|█████▎    | 502/938 [00:51<00:51,  8.40it/s]

Epoch [14/100] Batch 500/938                   Loss D: -0.7687, loss G: 0.2961


 64%|██████▍   | 602/938 [01:02<00:38,  8.69it/s]

Epoch [14/100] Batch 600/938                   Loss D: -0.6362, loss G: 0.2088


 75%|███████▍  | 702/938 [01:12<00:27,  8.53it/s]

Epoch [14/100] Batch 700/938                   Loss D: -0.6665, loss G: 0.5064


 78%|███████▊  | 735/938 [01:16<00:20,  9.79it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 32%|███▏      | 302/938 [00:31<01:15,  8.45it/s]

Epoch [20/100] Batch 300/938                   Loss D: -0.7095, loss G: 0.4983


 43%|████▎     | 402/938 [00:41<01:04,  8.37it/s]

Epoch [20/100] Batch 400/938                   Loss D: -0.6449, loss G: 0.5033


 54%|█████▎    | 502/938 [00:52<00:51,  8.54it/s]

Epoch [20/100] Batch 500/938                   Loss D: -0.6795, loss G: 0.5061


 64%|██████▍   | 602/938 [01:02<00:40,  8.34it/s]

Epoch [20/100] Batch 600/938                   Loss D: -0.6528, loss G: 0.4846


 75%|███████▍  | 702/938 [01:12<00:28,  8.38it/s]

Epoch [20/100] Batch 700/938                   Loss D: -0.6138, loss G: 0.1232


 86%|████████▌ | 802/938 [01:23<00:15,  8.59it/s]

Epoch [20/100] Batch 800/938                   Loss D: -0.6686, loss G: 0.1949


 96%|█████████▌| 902/938 [01:33<00:04,  8.45it/s]

Epoch [20/100] Batch 900/938                   Loss D: -0.5834, loss G: 0.4602


100%|██████████| 938/938 [01:37<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:38,  8.53it/s]

Epoch [21/100] Batch 100/938                   Loss D: -0.5986, loss G: 0.4895


 22%|██▏       | 202/938 [00:20<01:27,  8.38it/s]

Epoch [21/100] Batch 200/938                   Loss D: -0.6459, loss G: 0.4221


 32%|███▏      | 302/938 [00:31<01:14,  8.56it/s]

Epoch [21/100] Batch 300/938                   Loss D: -0.6498, loss G: 0.2334


 43%|████▎     | 402/938 [00:41<01:03,  8.42it/s]

Epoch [21/100] Batch 400/938                   Loss D: -0.7133, loss G: 0.4254


 54%|█████▎    | 502/938 [00:52<00:50,  8.60it/s]

Epoch [21/100] Batch 500/938                   Loss D: -0.5752, loss G: 0.4709


 64%|██████▍   | 602/938 [01:02<00:39,  8.43it/s]

Epoch [21/100] Batch 600/938                   Loss D: -0.5988, loss G: 0.4319


 75%|███████▍  | 702/938 [01:12<00:27,  8.71it/s]

Epoch [21/100] Batch 700/938                   Loss D: -0.6044, loss G: 0.5051


 86%|████████▌ | 802/938 [01:23<00:15,  8.51it/s]

Epoch [21/100] Batch 800/938                   Loss D: -0.4942, loss G: 0.4524


 96%|█████████▌| 902/938 [01:33<00:04,  8.43it/s]

Epoch [21/100] Batch 900/938                   Loss D: -0.6561, loss G: 0.4826


100%|██████████| 938/938 [01:37<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:39,  8.41it/s]

Epoch [22/100] Batch 100/938                   Loss D: -0.6865, loss G: 0.5018


 22%|██▏       | 202/938 [00:20<01:25,  8.63it/s]

Epoch [22/100] Batch 200/938                   Loss D: -0.6668, loss G: 0.1905


 32%|███▏      | 302/938 [00:31<01:16,  8.35it/s]

Epoch [22/100] Batch 300/938                   Loss D: -0.5217, loss G: 0.1332


 43%|████▎     | 402/938 [00:41<01:03,  8.42it/s]

Epoch [22/100] Batch 400/938                   Loss D: -0.5005, loss G: 0.4549


 54%|█████▎    | 502/938 [00:51<00:51,  8.43it/s]

Epoch [22/100] Batch 500/938                   Loss D: -0.6574, loss G: 0.4375


 64%|██████▍   | 602/938 [01:02<00:40,  8.40it/s]

Epoch [22/100] Batch 600/938                   Loss D: -0.6009, loss G: 0.4643


 75%|███████▍  | 702/938 [01:12<00:27,  8.44it/s]

Epoch [22/100] Batch 700/938                   Loss D: -0.6560, loss G: 0.4618


 86%|████████▌ | 802/938 [01:23<00:16,  8.37it/s]

Epoch [22/100] Batch 800/938                   Loss D: -0.5474, loss G: 0.4799


 96%|█████████▌| 902/938 [01:33<00:04,  8.29it/s]

Epoch [22/100] Batch 900/938                   Loss D: -0.5837, loss G: 0.1120


100%|██████████| 938/938 [01:37<00:00,  9.67it/s]
 11%|█         | 102/938 [00:10<01:37,  8.56it/s]

Epoch [23/100] Batch 100/938                   Loss D: -0.6273, loss G: 0.4822


 22%|██▏       | 202/938 [00:20<01:28,  8.36it/s]

Epoch [23/100] Batch 200/938                   Loss D: -0.6023, loss G: 0.4875


 32%|███▏      | 302/938 [00:31<01:15,  8.47it/s]

Epoch [23/100] Batch 300/938                   Loss D: -0.5982, loss G: 0.4754


 43%|████▎     | 402/938 [00:41<01:04,  8.37it/s]

Epoch [23/100] Batch 400/938                   Loss D: -0.6109, loss G: 0.4979


 54%|█████▎    | 502/938 [00:51<00:51,  8.50it/s]

Epoch [23/100] Batch 500/938                   Loss D: -0.6502, loss G: 0.4732


 64%|██████▍   | 602/938 [01:02<00:40,  8.31it/s]

Epoch [23/100] Batch 600/938                   Loss D: -0.5963, loss G: 0.4564


 75%|███████▍  | 702/938 [01:12<00:27,  8.68it/s]

Epoch [23/100] Batch 700/938                   Loss D: -0.5370, loss G: 0.5039


 86%|████████▌ | 802/938 [01:22<00:15,  8.67it/s]

Epoch [23/100] Batch 800/938                   Loss D: -0.5959, loss G: 0.1533


 96%|█████████▌| 902/938 [01:33<00:04,  8.41it/s]

Epoch [23/100] Batch 900/938                   Loss D: -0.6154, loss G: 0.1340


 99%|█████████▉| 927/938 [01:35<00:01,  9.67it/s]


KeyboardInterrupt: 

In [50]:
def save_image_grid(images: torch.Tensor, ncol: int):
    
    image_grid = make_grid(images, ncol)     # Images in a grid
    image_grid = image_grid.permute(1, 2, 0) # Move channel last
    image_grid = image_grid.cpu().numpy()    # To Numpy

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig('dcgan_generated1_.jpg')
    plt.close()

In [51]:
noise  = torch.randn(64, 128, 1, 1).to(device)
with torch.no_grad():
    images = gen(noise)
save_image_grid(images, 8)
    

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
