In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

In [7]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
import torch.nn as nn

# Define the generator network
class Generator(nn.Module):
    """Fully connected generator: 100-d noise vector -> 3x64x64 image.

    The network is a stack of Linear+ReLU layers (100 -> 256 -> 512 -> 1024)
    followed by a Linear projection to 3*64*64 values and a Tanh, so every
    output value lies in [-1, 1].  The stack is moved to the module-level
    ``device`` as part of construction.
    """

    def __init__(self):
        super().__init__()

        hidden_widths = (100, 256, 512, 1024)
        layers = []
        # Hidden stack: each consecutive pair of widths becomes Linear + ReLU.
        for n_in, n_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        # Output head: one value per pixel/channel, squashed by Tanh.
        layers.append(nn.Linear(hidden_widths[-1], 3 * 64 * 64))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers).to(device)

    def forward(self, z):
        """Map a batch of noise vectors ``z`` to a batch of 3x64x64 images."""
        flat_pixels = self.model(z)
        # Fold the flat per-sample vector back into (channels, height, width).
        return flat_pixels.view(-1, 3, 64, 64)



In [9]:
# Define the discriminator network
class Discriminator(nn.Module):
    """Convolutional discriminator producing one probability per image.

    Three 4x4 / stride-2 / padding-1 convolutions take the channel count
    3 -> 64 -> 128 -> 256 (the last two followed by BatchNorm2d), each with a
    LeakyReLU(0.2) activation.  The result is flattened and fed to a
    Linear(256 * 8 * 8, 1) + Sigmoid head, so the final Linear layer requires
    an 8x8 feature map, i.e. 64x64 inputs.  The stack is moved to the
    module-level ``device`` as part of construction.
    """

    def __init__(self):
        super().__init__()

        # Every convolution shares the same geometry.
        conv_opts = dict(kernel_size=4, stride=2, padding=1)

        # First stage has no batch norm.
        stages = [
            nn.Conv2d(3, 64, **conv_opts),
            nn.LeakyReLU(0.2),
        ]
        # Remaining stages: conv -> batch norm -> leaky ReLU.
        for c_in, c_out in ((64, 128), (128, 256)):
            stages.extend([
                nn.Conv2d(c_in, c_out, **conv_opts),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        # Classification head: flatten, single logit, squash to (0, 1).
        stages.extend([
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*stages).to(device)

    def forward(self, img):
        """Return the network's output for the image batch ``img``."""
        scores = self.model(img)
        return scores

In [10]:
# Seed PyTorch's global random number generator with a fixed value.
# As the last expression in the cell, the returned generator object is echoed below.
torch.manual_seed(42)

<torch._C.Generator at 0x7f6e682687f0>

In [11]:
# Set up data loaders
def load_data(root_folder, batch_size):
    """Build a shuffling DataLoader over an ImageFolder dataset.

    Args:
        root_folder: Directory passed to ``datasets.ImageFolder`` as ``root``.
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` (shuffle on, 4 worker processes) whose images are
        resized to 64x64, converted to tensors, and normalised per channel
        with mean 0.5 and std 0.5.
    """
    preprocessing_steps = [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on each of the three channels.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    image_folder = datasets.ImageFolder(
        root=root_folder,
        transform=transforms.Compose(preprocessing_steps),
    )

    return DataLoader(image_folder, batch_size=batch_size, shuffle=True, num_workers=4)



In [12]:
# Training function
def train_gan(generator, discriminator, dataloader, num_epochs, batch_size, learning_rate,
              latent_dim=100, log_interval=100, sample_interval=10):
    """Alternately train a discriminator and a generator with binary cross-entropy.

    For every batch the discriminator is first updated on real images
    (target 1) and on generated images (target 0); the generator is then
    updated so that the discriminator scores its images as real (target 1).

    Args:
        generator: Module mapping ``(N, latent_dim)`` noise to images.
        discriminator: Module mapping images to ``(N, 1)`` probabilities in
            (0, 1), as required by ``nn.BCELoss``.
        dataloader: Sized iterable yielding ``(images, labels)`` pairs; the
            labels are ignored.
        num_epochs: Number of passes over ``dataloader``.
        batch_size: Unused; the batch size is taken from each batch. Kept so
            existing callers keep working.
        learning_rate: Learning rate for both Adam optimizers
            (betas are fixed at (0.5, 0.999)).
        latent_dim: Length of each noise vector fed to the generator.
        log_interval: Print losses every this many batches; 0 disables logging.
        sample_interval: Every this many epochs, save a grid of 64 samples
            generated from a fixed noise batch to
            ``fake_samples_epoch_NNN.png`` in the working directory;
            0 disables saving.

    Returns:
        None. Both models are updated in place.
    """
    # Train on whatever device the generator's weights are on, rather than a
    # module-level global or a hard-coded .cuda(), so CPU-only runs work too.
    train_device = next(generator.parameters()).device

    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Noise is drawn from the default (CPU) generator and then moved, so the
    # random stream for a given seed does not depend on the training device.
    fixed_noise = torch.randn(64, latent_dim).to(train_device)

    for epoch in range(num_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(train_device)
            n_samples = real_imgs.size(0)

            # --- Discriminator step -------------------------------------
            # Gradients from the real and the fake pass are accumulated and
            # applied with a single optimizer step.
            optimizer_d.zero_grad()

            real_outputs = discriminator(real_imgs)
            real_labels = torch.ones(real_outputs.size(0), 1, device=train_device)
            loss_real = criterion(real_outputs, real_labels)
            loss_real.backward()

            noise = torch.randn(n_samples, latent_dim).to(train_device)
            fake_imgs = generator(noise)
            # detach(): the discriminator update must not back-propagate
            # into the generator.
            fake_outputs = discriminator(fake_imgs.detach())
            fake_labels = torch.zeros(fake_outputs.size(0), 1, device=train_device)
            loss_fake = criterion(fake_outputs, fake_labels)
            loss_fake.backward()
            optimizer_d.step()

            # --- Generator step -----------------------------------------
            # Re-score the same fake images with the just-updated
            # discriminator, this time keeping the graph back to the
            # generator, and push the scores towards "real".
            optimizer_g.zero_grad()
            gen_outputs = discriminator(fake_imgs)
            gen_labels = torch.ones(gen_outputs.size(0), 1, device=train_device)
            loss_gen = criterion(gen_outputs, gen_labels)
            loss_gen.backward()
            optimizer_g.step()

            if log_interval and i % log_interval == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         (loss_real + loss_fake).item(), loss_gen.item()))

        if sample_interval and epoch % sample_interval == 0:
            # Sampling is inference only; skip building an autograd graph.
            with torch.no_grad():
                fake = generator(fixed_noise)
            save_image(fake, 'fake_samples_epoch_%03d.png' % epoch, normalize=True)

In [13]:
# Re-seed so that weight initialisation and noise sampling below are repeatable.
torch.manual_seed(42)

# Cat images only for now; the dog loader is left disabled.
# NOTE: the dataset root is an absolute, machine-specific path.
cat_dataloader = load_data(
    root_folder='/home/tom/Python/Machine learning/pytorch/data/cats/',
    batch_size=64,
)
#dog_dataloader = load_data('/home/tom/Python/Machine learning/pytorch/GaN/data/dogs', batch_size=64)

# Build both networks (generator first) and place them on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Run the full GAN training loop on the cat data.
train_gan(
    generator,
    discriminator,
    cat_dataloader,
    num_epochs=200,
    batch_size=64,
    learning_rate=0.0002,
)


[0/200][0/27] Loss_D: 1.5184 Loss_G: 4.6142
[1/200][0/27] Loss_D: 0.6283 Loss_G: 5.4614
[2/200][0/27] Loss_D: 0.3293 Loss_G: 5.0455
[3/200][0/27] Loss_D: 0.3355 Loss_G: 4.9650
[4/200][0/27] Loss_D: 0.1514 Loss_G: 5.1523
[5/200][0/27] Loss_D: 0.2153 Loss_G: 3.2033
[6/200][0/27] Loss_D: 0.2326 Loss_G: 3.4324
[7/200][0/27] Loss_D: 0.2360 Loss_G: 4.3199
[8/200][0/27] Loss_D: 0.1339 Loss_G: 4.7388
[9/200][0/27] Loss_D: 0.1670 Loss_G: 4.1772
[10/200][0/27] Loss_D: 0.2433 Loss_G: 4.5591
[11/200][0/27] Loss_D: 0.2830 Loss_G: 3.7545
[12/200][0/27] Loss_D: 1.1343 Loss_G: 4.8237
[13/200][0/27] Loss_D: 0.4937 Loss_G: 3.0803
[14/200][0/27] Loss_D: 0.5474 Loss_G: 4.5738
[15/200][0/27] Loss_D: 0.6311 Loss_G: 3.1122
[16/200][0/27] Loss_D: 0.6729 Loss_G: 6.4456
[17/200][0/27] Loss_D: 0.2445 Loss_G: 4.2814
[18/200][0/27] Loss_D: 0.2686 Loss_G: 4.9952
[19/200][0/27] Loss_D: 0.4218 Loss_G: 3.5474
[20/200][0/27] Loss_D: 0.9026 Loss_G: 8.5791
[21/200][0/27] Loss_D: 0.6131 Loss_G: 4.2350
[22/200][0/27] Loss_