In [41]:
import torch
import torch.nn as nn
import torch.optim as optim 
from torchvision import datasets, transforms 
from torchvision.utils import save_image, make_grid
import os
import matplotlib.pyplot as plt

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


After applying conv the size reduces but in transpose conv the size increases. 
####  parameter
* input_size = 7x7
* kernel_szie = 4
* stride = 2
* padding(input) = 1
* output padding --> padding on top of output = 1


`Hout = (Hin - 1)*stride - 2*padding + kerner_size + output padding`

`Hout = (7-1)*2-2*1+4+1 ==> 15`

In [42]:
batch_size = 128
z_dim = 100 
image_size = 28
channels = 1 
epochs = 10 
lr = 0.0002
beta1 = 0.5

os.makedirs('generated_images', exist_ok=True)

In [43]:
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(), 
    transforms.Normalize([0.5], [0.5])
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transform),
    batch_size= batch_size, shuffle=True
)

In [44]:
class Generator(nn.Module): 
    def __init__(self, z_dim): 
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 256, 7, 1, 0, bias = False),
            nn.BatchNorm2d(256), # for smaller networks --> can be avoided for deeper nn
            nn.ReLU(True), 

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), 
            nn.BatchNorm2d(128), 
            nn.ReLU(True), 

            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias = False), 
            nn.Tanh()
        )

    def forward(self, z): 
        return self.net(z)

In [45]:
class Discriminator(nn.Module): 
    def __init__(self): 
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias = False),
            nn.LeakyReLU(0.2, inplace=True), 

            nn.Conv2d(64, 128, 4, 2, 1, bias=False), 
            nn.BatchNorm2d(128), 
            nn.LeakyReLU(0.2, inplace=True), 

            nn.Flatten(), 
            nn.Linear(128*7*7, 1),
            nn.Sigmoid()
        )

    def forward(self, img): 
        return self.net(img)

In [46]:
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

criterion = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr = lr, betas = (beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr = lr, betas = (beta1, 0.999))

In [47]:
def generate_and_save_images(epoch): 
    generator.eval()
    with torch.no_grad(): 
        z = torch.randn(64, z_dim, 1, 1).to(device)
        fake_images = generator(z)
        fake_images = fake_images*0.5+0.5 # denormalize to [0, 1]
        save_image(fake_images, f"generated_images/sample_epoch_{epoch}.png", nrow=8)
    generator.train()

In [48]:
k = 3 # genrator updates per iterations
p =1 # discriminator updates per iterations
# train generator more than discriminator 
for epoch in range(1, epochs+1) : 
    for i, (real_imgs, _) in enumerate(train_loader) :
        batch_size_curr = real_imgs.size(0)
        real_imgs = real_imgs.to(device)

        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        for _ in range(p) : 
            z = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)
            fake_imgs = generator(z)

            real_validity = discriminator(real_imgs)
            d_real_loss = criterion(real_validity, real)

            fake_validity = discriminator(fake_imgs.detach())
            d_fake_loss = criterion(fake_validity, fake)

            d_loss = d_real_loss + d_fake_loss
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
        for _ in range(k) : 
            z = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)
            fake_imgs = generator(z)

            validity = discriminator(fake_imgs)
            g_loss = criterion(validity, real)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()
        
        if i%200 == 0: 
            print(
                f"[Epoch {epoch}/{epoch}] [Batch {i}/{len(train_loader)}]"
                f"D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}"
            )

    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, 1, 1, device=device)
        samples = generator(z)
        samples = samples*0.5 + 0.5
        save_image(samples, f"generated_images/epochs_{epoch}.png", nrow=8)
    generator.train()

[Epoch 1/1] [Batch 0/469]D Loss: 1.4335 | G Loss: 0.4320
[Epoch 1/1] [Batch 200/469]D Loss: 0.0051 | G Loss: 5.9598
[Epoch 1/1] [Batch 400/469]D Loss: 0.0368 | G Loss: 4.2993
[Epoch 2/2] [Batch 0/469]D Loss: 0.0097 | G Loss: 5.2181
[Epoch 2/2] [Batch 200/469]D Loss: 0.0268 | G Loss: 4.7559
[Epoch 2/2] [Batch 400/469]D Loss: 0.0047 | G Loss: 6.2263
[Epoch 3/3] [Batch 0/469]D Loss: 0.0035 | G Loss: 6.7329
[Epoch 3/3] [Batch 200/469]D Loss: 0.0012 | G Loss: 7.3468
[Epoch 3/3] [Batch 400/469]D Loss: 0.0017 | G Loss: 7.0271
[Epoch 4/4] [Batch 0/469]D Loss: 0.0010 | G Loss: 7.2973
[Epoch 4/4] [Batch 200/469]D Loss: 0.0006 | G Loss: 8.0824
[Epoch 4/4] [Batch 400/469]D Loss: 1.3210 | G Loss: 0.6412
[Epoch 5/5] [Batch 0/469]D Loss: 1.1025 | G Loss: 0.8150
[Epoch 5/5] [Batch 200/469]D Loss: 1.1633 | G Loss: 0.8906
[Epoch 5/5] [Batch 400/469]D Loss: 1.0645 | G Loss: 0.9206
[Epoch 6/6] [Batch 0/469]D Loss: 1.0484 | G Loss: 1.0300
[Epoch 6/6] [Batch 200/469]D Loss: 1.0064 | G Loss: 1.0703
[Epoch 6/