In [66]:
import torch
import torch.nn as nn

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Generator(nn.Module):
    """Generator that maps latent vectors to single-channel 28x28 images.

    A linear layer projects each latent vector to a 128x7x7 feature map.
    Two stride-2 transposed convolutions upsample it to 14x14 and 28x28,
    and two stride-1 transposed convolutions reduce the channels to one.
    The final sigmoid keeps every output value in [0, 1].

    Args:
        latent_dim: Width of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        kernel_size = 5
        # Stored so callers can look up the expected noise width.
        self.latent_dim = latent_dim
        # Projects the noise vector to a flat 128 * 7 * 7 feature map.
        self.fc = nn.Linear(latent_dim, 7 * 7 * 128)
        # 7x7 -> 14x14: stride 2 with output_padding=1 exactly doubles the size.
        self.block1 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 128, kernel_size, stride=2, padding=2, output_padding=1),
        )
        # 14x14 -> 28x28, halving the channel count.
        self.block2 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size, stride=2, padding=2, output_padding=1),
        )
        # Stride 1 with padding 2 keeps the 28x28 spatial size.
        self.block3 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, kernel_size, stride=1, padding=2),
        )
        # Collapse to one channel and squash to [0, 1].
        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size, stride=1, padding=2),
            nn.Sigmoid()
        )

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Tensor whose last dimension is ``latent_dim``; typically
                shaped (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [0, 1].
        """
        x = self.fc(z)
        # Unflatten the projection into 128 feature maps of 7x7.
        x = x.view(-1, 128, 7, 7)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        img = self.block4(x)
        return img



In [67]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores single-channel images as real or fake.

    Four LeakyReLU -> Conv2d stages widen the channels 1 -> 32 -> 64 -> 128 -> 256.
    The activation comes first in every stage, including the one applied to
    the raw input. The first three stages use stride 2, so a 28x28 input
    reaches the 4x4 feature map the final linear layer expects. The linear
    output is passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        k = 5
        widths = (32, 64, 128, 256)

        def stage(c_in, c_out, stride):
            # One activation-then-convolution stage; padding 2 matches the 5x5 kernel.
            return nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.Conv2d(c_in, c_out, k, stride=stride, padding=2),
            )

        self.block1 = stage(1, widths[0], 2)
        self.block2 = stage(widths[0], widths[1], 2)
        self.block3 = stage(widths[1], widths[2], 2)
        self.block4 = stage(widths[2], widths[3], 1)
        # Flattened 4x4 feature map of the widest stage -> single score.
        self.fc = nn.Linear(4 * 4 * widths[3], 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for the images in ``x``."""
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = block(x)
        score = self.fc(x.flatten(1))
        return score.sigmoid()

In [68]:
class Adversarial(nn.Module):
    """Generator and discriminator chained into a single module.

    Both arguments are stored by reference, so the wrapped instances are the
    same objects the caller passed in.
    """

    def __init__(self, generator, discriminator):
        super().__init__()
        # Attribute assignment registers each module as a submodule.
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, z):
        """Return the discriminator's output for images generated from ``z``."""
        return self.discriminator(self.generator(z))

In [69]:
import torch
import torchvision
from torchvision import transforms

# Per-image preprocessing: only ToTensor is applied. The optional
# Normalize((0.5,), (0.5,)) step is not used.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train and test splits, stored under ~/data and downloaded if missing.
train_dataset = torchvision.datasets.MNIST(
    root='~/data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='~/data',
    train=False,
    transform=transform,
    download=True,
)

# Mini-batch loaders: the training split is reshuffled, the test split keeps
# its order.
batch_size = 64
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


In [70]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils

def save_generated_images(generator, device, epoch=1, latent_dim=100):
    """Sample 16 images from ``generator`` and save them as a 4x4 PNG grid.

    The generator is switched to eval mode while sampling, so BatchNorm
    layers use their running statistics and the sampling pass does not update
    them. The mode the caller had set is restored afterwards.

    Args:
        generator: Module mapping a (16, latent_dim) noise tensor to a batch
            of images.
        device: Device on which the noise is created; should match the device
            the generator lives on.
        epoch: Value substituted into the output filename
            ``gen_img_ep-<epoch>.png`` (written to the working directory).
        latent_dim: Width of the noise vectors fed to the generator.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Draw the noise directly on the target device.
            noise = torch.randn(16, latent_dim, device=device)
            generated_images = generator(noise).cpu()
    finally:
        # Restore the caller's mode even if generation fails.
        generator.train(was_training)

    # Create a grid of 4 x 4 images; normalize=True rescales values to [0, 1].
    grid = vutils.make_grid(generated_images, nrow=4, padding=2, normalize=True)

    # Save the grid of images
    filename = 'gen_img_ep-%s.png' % (epoch)
    vutils.save_image(grid, filename)


In [71]:
def train(generator, discriminator, 
          adversarial, trainloader,
          epochs=50):
    """Alternately train a discriminator and a generator with BCE loss.

    Each batch runs one discriminator step (real images labelled 1, generated
    images labelled 0) and one generator step (generated images labelled 1,
    back-propagated through ``adversarial``). After every epoch the learning
    rates are printed and a grid of samples is saved.

    The generator stays in train mode for both steps, so the discriminator is
    trained on fakes produced the same way as the ones the generator is
    optimised on. Eval mode is used only while saving the per-epoch samples.

    Args:
        generator: Module mapping (batch, 100) noise to images.
        discriminator: Module mapping images to (batch, 1) probabilities.
        adversarial: Module computing ``discriminator(generator(z))``; it is
            expected to wrap the same two instances passed above.
        trainloader: Iterable of ``(images, labels)`` batches; labels are ignored.
        epochs: Number of passes over ``trainloader``.
    """
    generator.train()
    discriminator.train()
    criterion = nn.BCELoss()
    lr = 2e-4
    decay = 6e-8
    factor = 0.5
    # The generator uses half the discriminator's learning rate and weight decay.
    optimizer_g = torch.optim.RMSprop(generator.parameters(),
                                      lr=lr*factor, weight_decay=decay*factor)
    optimizer_d = torch.optim.RMSprop(discriminator.parameters(),
                                      lr=lr, weight_decay=decay)

    # Width of the noise vectors; matches the size used when sampling images.
    latent_dim = 100
    # Work on the device the models already live on, not a module-level global.
    run_device = next(generator.parameters()).device

    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(trainloader):
            real_imgs = real_imgs.to(run_device)
            batch_size = real_imgs.shape[0]
            real_labels = torch.ones(batch_size, 1, device=run_device)
            fake_labels = torch.zeros(batch_size, 1, device=run_device)

            # Train discriminator
            optimizer_d.zero_grad()
            z = torch.randn(batch_size, latent_dim, device=run_device)
            # no_grad keeps this pass out of the generator's autograd graph;
            # the generator deliberately stays in train mode here.
            with torch.no_grad():
                fake_imgs = generator(z)

            real_out = discriminator(real_imgs)
            fake_out = discriminator(fake_imgs)
            loss_real = criterion(real_out, real_labels)
            loss_fake = criterion(fake_out, fake_labels)
            loss_d = loss_real + loss_fake
            loss_d.backward()
            optimizer_d.step()

            # Train generator
            optimizer_g.zero_grad()
            z = torch.randn(batch_size, latent_dim, device=run_device)

            # Fakes are labelled as real. Gradients also reach the
            # discriminator's parameters here, but they are cleared by
            # optimizer_d.zero_grad() before its next step.
            fake_out = adversarial(z)
            loss_g = criterion(fake_out, real_labels)
            loss_g.backward()
            optimizer_g.step()
            if i % 100 == 0:
                print('Epoch: %d, Iter: %d, Loss D: %.6f, Loss G: %.6f' % (epoch, i, loss_d.item(), loss_g.item()))
        print('Learning Rate (Generator):', optimizer_g.param_groups[0]['lr'])
        print('Learning Rate (Discriminator):', optimizer_d.param_groups[0]['lr'])
        # Sample in eval mode so the dump does not disturb BatchNorm running
        # statistics, then resume training mode for the next epoch.
        generator.eval()
        save_generated_images(generator, run_device, epoch=epoch)
        generator.train()


In [72]:
# Build the networks on the selected device. `adversarial` wraps the same
# generator and discriminator instances rather than copies.
generator = Generator(latent_dim=100).to(device)
discriminator = Discriminator().to(device)
adversarial = Adversarial(generator=generator, discriminator=discriminator).to(device)

# Train on the MNIST training loader with the default number of epochs.
train(
    generator=generator,
    discriminator=discriminator,
    adversarial=adversarial,
    trainloader=train_dataloader,
)


Epoch: 0, Iter: 0, Loss D: 1.385114, Loss G: 1.113374
Epoch: 0, Iter: 100, Loss D: 0.000456, Loss G: 7.802855
Epoch: 0, Iter: 200, Loss D: 0.207746, Loss G: 3.632315
Epoch: 0, Iter: 300, Loss D: 0.001795, Loss G: 6.726063
Epoch: 0, Iter: 400, Loss D: 0.001615, Loss G: 7.410963
Epoch: 0, Iter: 500, Loss D: 0.001772, Loss G: 7.345804
Epoch: 0, Iter: 600, Loss D: 0.000196, Loss G: 8.635316
Epoch: 0, Iter: 700, Loss D: 0.002317, Loss G: 6.019722
Epoch: 0, Iter: 800, Loss D: 0.000571, Loss G: 7.443192
Epoch: 0, Iter: 900, Loss D: 0.000211, Loss G: 8.648391
Learning Rate (Generator): 0.0001
Learning Rate (Discriminator): 0.0002
Epoch: 1, Iter: 0, Loss D: 0.000287, Loss G: 8.349941
Epoch: 1, Iter: 100, Loss D: 0.000052, Loss G: 9.929526
Epoch: 1, Iter: 200, Loss D: 0.000025, Loss G: 10.655956
Epoch: 1, Iter: 300, Loss D: 0.000023, Loss G: 10.723509
Epoch: 1, Iter: 400, Loss D: 0.064607, Loss G: 3.696424
Epoch: 1, Iter: 500, Loss D: 0.004399, Loss G: 5.603803
Epoch: 1, Iter: 600, Loss D: 1.133