In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# Define the generator network
class Generator(nn.Module):
    """Fully connected generator that maps latent noise vectors to images.

    The network is ``Linear -> ReLU -> Linear -> Tanh``; the flat output is
    reshaped to ``(batch, *img_shape)``. Because of the final ``Tanh`` every
    output value lies in ``[-1, 1]``.

    Args:
        latent_dim: Size of each input noise vector (default 100).
        img_shape: Shape ``(channels, height, width)`` of one generated
            image (default ``(3, 64, 64)``).
    """

    def __init__(self, latent_dim=100, img_shape=(3, 64, 64)):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        # Number of values in one flattened image.
        num_pixels = 1
        for dim in self.img_shape:
            num_pixels *= dim

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_pixels),
            nn.Tanh()
        )
        # Prefer the GPU, but do not crash on machines without CUDA.
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def forward(self, z):
        """Generate images from noise ``z`` of shape ``(batch, latent_dim)``.

        Returns a tensor of shape ``(batch, *img_shape)``.
        """
        return self.model(z).view(-1, *self.img_shape)

# Define the discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    The network is ``Linear -> LeakyReLU(0.2) -> Linear -> Sigmoid``, so the
    output is one value in ``[0, 1]`` per image.

    Args:
        img_shape: Shape ``(channels, height, width)`` of one input image
            (default ``(3, 64, 64)``).
    """

    def __init__(self, img_shape=(3, 64, 64)):
        super(Discriminator, self).__init__()
        self.img_shape = tuple(img_shape)

        # Number of values in one flattened image.
        num_pixels = 1
        for dim in self.img_shape:
            num_pixels *= dim

        self.model = nn.Sequential(
            nn.Linear(num_pixels, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        # Prefer the GPU, but do not crash on machines without CUDA.
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def forward(self, img):
        """Score a batch of images.

        ``img`` is flattened per sample before entering the network; the
        result has shape ``(batch, 1)``.
        """
        return self.model(img.view(img.size(0), -1))

# Set up data loaders
def load_data(root_folder, batch_size):
    """Build a shuffled DataLoader over the images found under ``root_folder``.

    Every image is resized to 64x64, converted to a tensor and normalized
    per channel with mean 0.5 and standard deviation 0.5.

    Args:
        root_folder: Directory handed to ``datasets.ImageFolder`` as ``root``.
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` that shuffles the dataset and loads it with 4
        worker processes.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)

    preprocessing = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    image_folder = datasets.ImageFolder(root=root_folder, transform=preprocessing)

    return DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

# Training function
def train_gan(generator, discriminator, dataloader, num_epochs, batch_size, learning_rate,
              latent_dim=100):
    """Train a generator/discriminator pair with the binary cross-entropy GAN loss.

    Each iteration first updates the discriminator on a batch of real images
    (target 1) and a batch of generated images (target 0), then updates the
    generator so that the discriminator scores its images as real (target 1).
    Every 10 epochs a grid of samples produced from a fixed noise batch is
    written to ``fake_samples_epoch_XXX.png`` in the working directory.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Module mapping images to ``(batch, 1)`` probabilities.
            It is assumed to live on the same device as ``generator``.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        num_epochs: Number of passes over ``dataloader``.
        batch_size: Unused; kept for backward compatibility. The actual batch
            size is taken from each batch delivered by ``dataloader``.
        learning_rate: Learning rate for both Adam optimizers.
        latent_dim: Size of the noise vectors fed to the generator.
    """
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Run on whatever device the generator was placed on instead of assuming
    # CUDA, so the same loop works on CPU-only machines.
    device = next(generator.parameters()).device

    # Reused for every saved sample grid so progress is comparable across epochs.
    fixed_noise = torch.randn(64, latent_dim).to(device)

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            real_imgs, _ = data
            real_imgs = real_imgs.to(device)

            # --- Discriminator update ---
            # Gradients from the real and the fake batch are accumulated
            # before a single optimizer step.
            optimizer_d.zero_grad()
            real_outputs = discriminator(real_imgs)
            real_labels = torch.ones_like(real_outputs)  # target matches the output shape
            loss_real = criterion(real_outputs, real_labels)
            loss_real.backward()

            noise = torch.randn(real_imgs.size(0), latent_dim).to(device)
            fake_imgs = generator(noise)
            # detach(): the discriminator step must not backpropagate into the generator.
            fake_outputs = discriminator(fake_imgs.detach())
            fake_labels = torch.zeros_like(fake_outputs)
            loss_fake = criterion(fake_outputs, fake_labels)
            loss_fake.backward()
            optimizer_d.step()

            # --- Generator update ---
            # The generator is rewarded when the (just updated) discriminator
            # labels its images as real.
            optimizer_g.zero_grad()
            gen_outputs = discriminator(fake_imgs)
            gen_labels = torch.ones_like(gen_outputs)
            loss_gen = criterion(gen_outputs, gen_labels)
            loss_gen.backward()
            optimizer_g.step()

            if i % 100 == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         (loss_real + loss_fake).item(), loss_gen.item()))

        if epoch % 10 == 0:
            # No gradients are needed for sampling; skip building the autograd graph.
            with torch.no_grad():
                fake = generator(fixed_noise)
            save_image(fake, 'fake_samples_epoch_%03d.png' % epoch, normalize=True)

# Run configuration
SEED = 42
BATCH_SIZE = 64
NUM_EPOCHS = 200
LEARNING_RATE = 0.0002
DOG_IMAGE_ROOT = '/home/tom/Python/Machine learning/pytorch/GaN/data/dogs'

# Seed PyTorch's random number generator so runs are repeatable
torch.manual_seed(SEED)

# Data loader for the dog images (the cat data set is currently disabled)
#cat_dataloader = load_data('/home/tom/Python/Machine learning/pytorch/GaN/data/cat/', batch_size=64)
dog_dataloader = load_data(DOG_IMAGE_ROOT, batch_size=BATCH_SIZE)

# Build the two networks
generator = Generator()
discriminator = Discriminator()

# Run the adversarial training loop
train_gan(
    generator,
    discriminator,
    dog_dataloader,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
)


[0/200][0/24] Loss_D: 1.3782 Loss_G: 1.1221
[1/200][0/24] Loss_D: 0.7534 Loss_G: 1.0054
[2/200][0/24] Loss_D: 0.6284 Loss_G: 1.4632
[3/200][0/24] Loss_D: 0.7423 Loss_G: 2.4168
[4/200][0/24] Loss_D: 0.3863 Loss_G: 2.2779
[5/200][0/24] Loss_D: 0.4120 Loss_G: 1.8583
[6/200][0/24] Loss_D: 0.4492 Loss_G: 1.6300
[7/200][0/24] Loss_D: 0.5057 Loss_G: 1.5529
[8/200][0/24] Loss_D: 0.6584 Loss_G: 1.1120
[9/200][0/24] Loss_D: 0.5071 Loss_G: 1.3675
[10/200][0/24] Loss_D: 0.4760 Loss_G: 1.4471
[11/200][0/24] Loss_D: 0.4506 Loss_G: 1.4142
[12/200][0/24] Loss_D: 0.4286 Loss_G: 1.4060
[13/200][0/24] Loss_D: 0.4439 Loss_G: 1.5326
[14/200][0/24] Loss_D: 0.3445 Loss_G: 1.6963
[15/200][0/24] Loss_D: 0.3220 Loss_G: 1.7178
[16/200][0/24] Loss_D: 0.3120 Loss_G: 1.6725
[17/200][0/24] Loss_D: 0.3102 Loss_G: 1.7859
[18/200][0/24] Loss_D: 0.2809 Loss_G: 1.9249
[19/200][0/24] Loss_D: 0.2918 Loss_G: 1.9887
[20/200][0/24] Loss_D: 0.2669 Loss_G: 2.0759
[21/200][0/24] Loss_D: 0.3397 Loss_G: 1.8341
[22/200][0/24] Loss_