In [58]:
import torch
from torch import nn
import torchvision.datasets as datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torchvision
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [49]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [63]:
# --- Data ---
BATCH_SIZE = 64       # batch size handed to the DataLoader
IMAGE_SIZE = 64       # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 1      # image channels (used for normalisation and both networks)

# --- Model ---
Z_DIM = 100           # channel count of the generator's latent noise input
FEATURES_DISC = 64    # base feature width passed to the Discriminator
FEATURES_GEN = 64     # base feature width passed to the Generator

# --- Optimisation ---
LEARNING_RATE = 2e-4  # shared by both Adam optimisers
NUM_EPOCHS = 10       # full passes over the dataset

In [51]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), 
            self._block(features_d*2, features_d*4, 4, 2, 1),
            self._block(features_d*4, features_d*8, 4, 2, 1),
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )


    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self, x):
        return self.disc(x)

In [52]:
class Generator(nn.Module):
    """Transposed-convolution generator that maps latent noise to images.

    Four ConvTranspose2d -> BatchNorm2d -> ReLU blocks are followed by a final
    ConvTranspose2d and Tanh, so outputs lie in [-1, 1]. For a 1x1 spatial
    input the size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, giving an output of
    shape (N, channels_img, 64, 64).

    Args:
        z_dim: Channel count of the latent input (callers in this notebook pass
            noise shaped (N, z_dim, 1, 1)).
        channels_img: Number of channels in the generated images.
        featuresd_d: Base channel width; blocks use 16x, 8x, 4x and 2x this value.
    """

    def __init__(self, z_dim, channels_img, featuresd_d):
        super().__init__()
        base = featuresd_d

        # (in_channels, out_channels, stride, padding) for each up-sampling block.
        block_specs = [
            (z_dim, base * 16, 1, 0),
            (base * 16, base * 8, 2, 1),
            (base * 8, base * 4, 2, 1),
            (base * 4, base * 2, 2, 1),
        ]
        stages = [
            self._block(c_in, c_out, 4, stride, pad)
            for c_in, c_out, stride, pad in block_specs
        ]
        # Output head: project to image channels (with bias) and bound with Tanh.
        stages.append(
            nn.ConvTranspose2d(base * 2, channels_img, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one ConvTranspose2d (bias-free) -> BatchNorm2d -> ReLU stage."""
        upconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        return nn.Sequential(upconv, norm, activation)

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        images = self.gen(x)
        return images

In [53]:
def initialize_weights(model):
    """Apply DCGAN-style weight initialisation to ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weights are drawn from N(0, 0.02).
      Their biases, if any, are left at PyTorch's defaults.
    * ``nn.BatchNorm2d`` scale weights are drawn from N(1, 0.02) and the shift
      bias is set to 0. Centring the scale at 1 (not 0) keeps the normalised
      activations at roughly unit scale at the start of training; a scale
      drawn around 0 would shrink every BatchNorm output to almost nothing.

    BatchNorm layers created with ``affine=False`` have no learnable
    parameters and are skipped.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

In [54]:
def test():
    """Smoke-test the output shapes of both networks on random inputs.

    Builds a small Discriminator and Generator (base width 8, 3 channels),
    initialises them, and asserts that a batch of 64x64 images maps to
    (N, 1, 1, 1) and that (N, z_dim, 1, 1) noise maps back to the image shape.
    Prints ``SUCCESS`` when both assertions hold.
    """
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    # Discriminator: image batch -> one score per image.
    images = torch.randn((batch, channels, height, width))
    critic = Discriminator(channels, 8)
    initialize_weights(critic)
    scores = critic(images)
    assert scores.shape == (batch, 1, 1, 1)

    # Generator: latent noise -> image batch of the original shape.
    generator = Generator(latent_dim, channels, 8)
    initialize_weights(generator)
    latent = torch.randn((batch, latent_dim, 1, 1))
    generated = generator(latent)
    assert generated.shape == (batch, channels, height, width)

    print('SUCCESS')

In [55]:
# Run the shape smoke test; it prints SUCCESS when both shape assertions pass.
test()

SUCCESS


In [56]:
# Preprocessing pipeline: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a
# tensor image, cast to float with value scaling, then normalise every channel
# with mean 0.5 and std 0.5.
_norm_mean = [0.5] * CHANNELS_IMG
_norm_std = [0.5] * CHANNELS_IMG

transforms = v2.Compose(
    [
        v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        v2.ToImage(),
        v2.ToDtype(torch.float, scale=True),
        v2.Normalize(_norm_mean, _norm_std),
    ]
)

In [57]:
# MNIST training split, preprocessed by `transforms`.
# download=True makes the notebook runnable from a clean checkout: the files
# are fetched into MNIST/ only when they are not already there, so an existing
# local copy is reused and not downloaded again.
dataset = datasets.MNIST(root='MNIST/', train=True, transform=transforms, download=True)

In [61]:
# Shuffled mini-batches over the training set.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build both networks on the target device, then initialise them
# (generator first, then discriminator).
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
for _net in (gen, disc):
    initialize_weights(_net)

# Both optimisers share the same Adam settings.
_adam_settings = {"lr": LEARNING_RATE, "betas": (0.5, 0.999)}
opt_gen = optim.Adam(gen.parameters(), **_adam_settings)
opt_disc = optim.Adam(disc.parameters(), **_adam_settings)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# A fixed batch of 32 latent vectors, reused for every logged sample grid.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

# Separate TensorBoard writers for real and generated image grids.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # global_step counter for the logged image grids

In [64]:
from tqdm.auto import tqdm

# Both networks stay in training mode for the whole run, including the
# periodic sample logging below.
gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Iterate `dataloader`, the DataLoader built in the setup cell. The name
    # `loader` used previously is not defined anywhere in this notebook, so the
    # loop raised NameError on a fresh kernel.
    for batch_idx, (real, _) in enumerate(tqdm(dataloader, desc=f'Epoch {epoch+1}', unit='batch')):
        real = real.to(device)
        # Size the noise batch from the real batch so the final, smaller batch
        # of an epoch pairs equal numbers of real and fake samples.
        noise = torch.randn(real.size(0), Z_DIM, 1, 1, device=device)

        # Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        fake = gen(noise)
        # detach() stops the discriminator loss from backpropagating into the
        # generator, so retain_graph=True is no longer needed and no generator
        # gradients are computed only to be thrown away.
        disc_fake = disc(fake.detach()).reshape(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator: push D(G(z)) towards the "real" label, using the
        # discriminator that was just updated.
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            # Adjacent f-strings replace the backslash continuation, which
            # embedded the next line's indentation in the printed message.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                # Separate name so the training batch `fake` is not overwritten.
                fake_fixed = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            # One TensorBoard step per logged grid (every 100 batches).
            step += 1



Epoch 1:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [0/10] Batch 0/938                   Loss D: 0.1480, loss G: 2.7666
Epoch [0/10] Batch 100/938                   Loss D: 0.0637, loss G: 3.2047
Epoch [0/10] Batch 200/938                   Loss D: 0.0334, loss G: 4.7036
Epoch [0/10] Batch 300/938                   Loss D: 0.0514, loss G: 3.9083
Epoch [0/10] Batch 400/938                   Loss D: 0.3200, loss G: 2.4468
Epoch [0/10] Batch 500/938                   Loss D: 0.0738, loss G: 3.7355
Epoch [0/10] Batch 600/938                   Loss D: 0.0386, loss G: 4.4156
Epoch [0/10] Batch 700/938                   Loss D: 0.2284, loss G: 1.6513
Epoch [0/10] Batch 800/938                   Loss D: 0.2810, loss G: 1.3732
Epoch [0/10] Batch 900/938                   Loss D: 0.4051, loss G: 3.4780


Epoch 2:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [1/10] Batch 0/938                   Loss D: 0.0797, loss G: 2.9424
Epoch [1/10] Batch 100/938                   Loss D: 0.1650, loss G: 2.5974
Epoch [1/10] Batch 200/938                   Loss D: 0.0654, loss G: 3.9928
Epoch [1/10] Batch 300/938                   Loss D: 0.0580, loss G: 4.0905
Epoch [1/10] Batch 400/938                   Loss D: 0.0537, loss G: 4.0543
Epoch [1/10] Batch 500/938                   Loss D: 0.0529, loss G: 3.1256
Epoch [1/10] Batch 600/938                   Loss D: 0.1106, loss G: 2.7881
Epoch [1/10] Batch 700/938                   Loss D: 0.0403, loss G: 3.7445
Epoch [1/10] Batch 800/938                   Loss D: 0.0318, loss G: 3.7384
Epoch [1/10] Batch 900/938                   Loss D: 0.1970, loss G: 4.1271


Epoch 3:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [2/10] Batch 0/938                   Loss D: 0.0537, loss G: 4.4720
Epoch [2/10] Batch 100/938                   Loss D: 0.0228, loss G: 4.7947
Epoch [2/10] Batch 200/938                   Loss D: 0.6211, loss G: 4.1834
Epoch [2/10] Batch 300/938                   Loss D: 0.0635, loss G: 3.4289
Epoch [2/10] Batch 400/938                   Loss D: 0.0493, loss G: 4.4751
Epoch [2/10] Batch 500/938                   Loss D: 0.0215, loss G: 4.8764
Epoch [2/10] Batch 600/938                   Loss D: 0.0880, loss G: 3.6685
Epoch [2/10] Batch 700/938                   Loss D: 0.0566, loss G: 3.8856
Epoch [2/10] Batch 800/938                   Loss D: 0.0768, loss G: 3.4655
Epoch [2/10] Batch 900/938                   Loss D: 0.0429, loss G: 5.5308


Epoch 4:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [3/10] Batch 0/938                   Loss D: 0.0207, loss G: 4.8198
Epoch [3/10] Batch 100/938                   Loss D: 0.3560, loss G: 2.0467
Epoch [3/10] Batch 200/938                   Loss D: 0.3107, loss G: 2.6734
Epoch [3/10] Batch 300/938                   Loss D: 0.0451, loss G: 4.2749
Epoch [3/10] Batch 400/938                   Loss D: 0.0156, loss G: 4.9387
Epoch [3/10] Batch 500/938                   Loss D: 0.0184, loss G: 4.9710
Epoch [3/10] Batch 600/938                   Loss D: 0.0506, loss G: 3.1264
Epoch [3/10] Batch 700/938                   Loss D: 0.4508, loss G: 0.1944
Epoch [3/10] Batch 800/938                   Loss D: 0.0615, loss G: 3.9810
Epoch [3/10] Batch 900/938                   Loss D: 0.0767, loss G: 2.8835


Epoch 5:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [4/10] Batch 0/938                   Loss D: 0.1522, loss G: 5.1512
Epoch [4/10] Batch 100/938                   Loss D: 0.1053, loss G: 3.2524
Epoch [4/10] Batch 200/938                   Loss D: 0.0171, loss G: 4.3776
Epoch [4/10] Batch 300/938                   Loss D: 0.0186, loss G: 4.9874
Epoch [4/10] Batch 400/938                   Loss D: 2.3290, loss G: 0.0031
Epoch [4/10] Batch 500/938                   Loss D: 0.1487, loss G: 3.8655
Epoch [4/10] Batch 600/938                   Loss D: 0.1184, loss G: 4.1417
Epoch [4/10] Batch 700/938                   Loss D: 0.0444, loss G: 3.4296
Epoch [4/10] Batch 800/938                   Loss D: 0.0658, loss G: 2.6574
Epoch [4/10] Batch 900/938                   Loss D: 0.0652, loss G: 3.5857


Epoch 6:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [5/10] Batch 0/938                   Loss D: 0.0281, loss G: 4.9689
Epoch [5/10] Batch 100/938                   Loss D: 0.0487, loss G: 3.2039
Epoch [5/10] Batch 200/938                   Loss D: 0.2898, loss G: 8.2339
Epoch [5/10] Batch 300/938                   Loss D: 0.0178, loss G: 4.9057
Epoch [5/10] Batch 400/938                   Loss D: 0.0135, loss G: 5.0036
Epoch [5/10] Batch 500/938                   Loss D: 0.0539, loss G: 3.9341
Epoch [5/10] Batch 600/938                   Loss D: 0.2190, loss G: 1.5815
Epoch [5/10] Batch 700/938                   Loss D: 0.0234, loss G: 4.2869
Epoch [5/10] Batch 800/938                   Loss D: 0.0816, loss G: 3.6089
Epoch [5/10] Batch 900/938                   Loss D: 0.7370, loss G: 0.5250


Epoch 7:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [6/10] Batch 0/938                   Loss D: 1.6781, loss G: 11.5139
Epoch [6/10] Batch 100/938                   Loss D: 0.2243, loss G: 2.9681
Epoch [6/10] Batch 200/938                   Loss D: 0.0560, loss G: 2.9741
Epoch [6/10] Batch 300/938                   Loss D: 0.0182, loss G: 5.0267
Epoch [6/10] Batch 400/938                   Loss D: 0.0083, loss G: 5.3834
Epoch [6/10] Batch 500/938                   Loss D: 0.0128, loss G: 5.0845
Epoch [6/10] Batch 600/938                   Loss D: 0.0042, loss G: 6.0180
Epoch [6/10] Batch 700/938                   Loss D: 0.0156, loss G: 4.4391
Epoch [6/10] Batch 800/938                   Loss D: 1.0289, loss G: 0.5668
Epoch [6/10] Batch 900/938                   Loss D: 0.3478, loss G: 3.9201


Epoch 8:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [7/10] Batch 0/938                   Loss D: 0.0736, loss G: 4.2469
Epoch [7/10] Batch 100/938                   Loss D: 0.0368, loss G: 3.6451
Epoch [7/10] Batch 200/938                   Loss D: 0.0971, loss G: 2.6093
Epoch [7/10] Batch 300/938                   Loss D: 0.2285, loss G: 2.9570
Epoch [7/10] Batch 400/938                   Loss D: 0.0205, loss G: 4.6874
Epoch [7/10] Batch 500/938                   Loss D: 0.3762, loss G: 1.5155
Epoch [7/10] Batch 600/938                   Loss D: 0.1435, loss G: 3.3669
Epoch [7/10] Batch 700/938                   Loss D: 0.0272, loss G: 4.3319
Epoch [7/10] Batch 800/938                   Loss D: 0.0201, loss G: 4.7503
Epoch [7/10] Batch 900/938                   Loss D: 0.4434, loss G: 2.2299


Epoch 9:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [8/10] Batch 0/938                   Loss D: 0.0538, loss G: 4.1865
Epoch [8/10] Batch 100/938                   Loss D: 0.0873, loss G: 3.6998
Epoch [8/10] Batch 200/938                   Loss D: 0.1317, loss G: 4.4159
Epoch [8/10] Batch 300/938                   Loss D: 0.0447, loss G: 4.3590
Epoch [8/10] Batch 400/938                   Loss D: 0.0127, loss G: 4.7101
Epoch [8/10] Batch 500/938                   Loss D: 0.1452, loss G: 3.4864
Epoch [8/10] Batch 600/938                   Loss D: 0.2896, loss G: 2.8867
Epoch [8/10] Batch 700/938                   Loss D: 0.1116, loss G: 2.6938
Epoch [8/10] Batch 800/938                   Loss D: 0.0193, loss G: 4.6778
Epoch [8/10] Batch 900/938                   Loss D: 0.0198, loss G: 5.0583


Epoch 10:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch [9/10] Batch 0/938                   Loss D: 0.3266, loss G: 2.0449
Epoch [9/10] Batch 100/938                   Loss D: 0.1462, loss G: 3.2131
Epoch [9/10] Batch 200/938                   Loss D: 0.2079, loss G: 2.8085
Epoch [9/10] Batch 300/938                   Loss D: 0.1827, loss G: 4.0426
Epoch [9/10] Batch 400/938                   Loss D: 0.1812, loss G: 3.0821
Epoch [9/10] Batch 500/938                   Loss D: 0.0596, loss G: 4.1957
Epoch [9/10] Batch 600/938                   Loss D: 0.0082, loss G: 5.5505
Epoch [9/10] Batch 700/938                   Loss D: 0.0090, loss G: 5.3081
Epoch [9/10] Batch 800/938                   Loss D: 0.0178, loss G: 5.1770
Epoch [9/10] Batch 900/938                   Loss D: 0.0049, loss G: 6.4314


In [66]:
# Close both TensorBoard writers now that training has finished.
for _writer in (writer_real, writer_fake):
    _writer.close()

In [67]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [70]:
# Open TensorBoard inline, pointed at the logs/ directory used by the two writers.
%tensorboard --logdir=logs

Reusing TensorBoard on port 6014 (pid 38563), started 0:00:08 ago. (Use '!kill 38563' to kill it.)