In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from vanillaGAN import VanillaGAN_Generator, VanillaGAN_Discriminator
from dcGAN import DCGAN_Generator, DCGAN_Discriminator

In [2]:
IMG_SHAPE = (1, 256, 256)
LATENT_DIM = 100
N_OUT = torch.prod(torch.tensor(IMG_SHAPE))
print(N_OUT)

tensor(65536)


# Dataset (replace with LSUN dataset)

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((256, 256)),transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='.', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [4]:
for x, y in dataloader:
    print(x.shape)
    break

torch.Size([64, 1, 256, 256])


# Instantiation of models, otpimizers e.t.c

In [10]:
generator = VanillaGAN_Generator(latent_dim=LATENT_DIM, img_shape=IMG_SHAPE, n_out=N_OUT)
discriminator = VanillaGAN_Discriminator(img_shape=IMG_SHAPE)

criterion = nn.BCELoss()

generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

# Test output shapes

In [8]:
latent_dim=100
num_channels=1
batch_size=64
img_shape = (1, 256, 256)
n_out = int(torch.prod(torch.tensor(img_shape)))
generator = VanillaGAN_Generator(latent_dim=latent_dim, img_shape=img_shape, n_out=n_out)

# Assume input shape: (batch_size, latent_dim)
input_shape = (batch_size, 100)

# Forward pass
output = generator(torch.randn(input_shape))

# Get the output shape
output_shape = output.shape

print(output_shape)

torch.Size([64, 1, 256, 256])


# Training loop

In [11]:
def train_discriminator(optimizer, discriminator, real_images, fake_images):
    batch_size = real_images.size(0)
    
    # Train the discriminator
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)
    
    optimizer.zero_grad()
    
    # Compute discriminator loss on real images
    real_outputs = discriminator(real_images)
    d_loss_real = criterion(real_outputs, real_labels)
    
    # Compute discriminator loss on fake images
    fake_outputs = discriminator(fake_images.detach())
    d_loss_fake = criterion(fake_outputs, fake_labels)
    
    # Total discriminator loss
    d_loss = d_loss_real + d_loss_fake
    
    d_loss.backward()
    optimizer.step()
    
    return d_loss

def train_generator(optimizer, fake_images, fake_images_outputs):
    
    batch_size = fake_images.size(0)
    real_labels = torch.ones(batch_size, 1)

    
    # Compute generator loss
    g_loss = criterion(fake_images_outputs, real_labels)
    
    optimizer.zero_grad()
    g_loss.backward()
    optimizer.step()
    
    return g_loss

In [13]:
num_epochs = 10

for epoch in range(num_epochs):
    for real_images, _ in tqdm(dataloader):
        
        batch_size = real_images.size(0)
        
        z = torch.randn(batch_size, LATENT_DIM)
        fake_images = generator(z)
       
        # Train the discriminator
        d_loss = train_discriminator(discriminator_optimizer, discriminator, real_images, fake_images)
        
        # Train the generator
        fake_images_outputs = discriminator(fake_images)
        g_loss = train_generator(generator_optimizer, fake_images, fake_images_outputs)
        

  1%|          | 6/938 [00:13<35:41,  2.30s/it]


KeyboardInterrupt: 