In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# Define the generator network
def _default_device():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


class Generator(nn.Module):
    """Fully connected GAN generator mapping latent noise vectors to images.

    Each noise vector yields ONE image of shape ``img_shape``. That is the same
    shape the data loader produces, so real and generated batches share the
    discriminator's input size. Cats and dogs are both learned from the
    combined dataset rather than from separate output "channels".

    Args:
        latent_dim: Length of each input noise vector.
        img_shape: ``(channels, height, width)`` of each generated image.
        hidden_dim: Width of the single hidden layer.
        device: Device to place the parameters on. ``None`` uses CUDA when it
            is available and falls back to CPU instead of failing.
    """

    def __init__(self, latent_dim=100, img_shape=(3, 64, 64), hidden_dim=256, device=None):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        channels, height, width = self.img_shape
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, channels * height * width),
            # Outputs in [-1, 1], the same range as the Normalize(0.5, 0.5) data.
            nn.Tanh(),
        )
        self.to(_default_device() if device is None else device)

    def forward(self, z):
        """Map noise ``z`` of shape ``(N, latent_dim)`` to images ``(N, *img_shape)``."""
        return self.model(z).view(-1, *self.img_shape)

# Define the discriminator network
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring images as real (1) or fake (0).

    Args:
        img_shape: ``(channels, height, width)`` of the images being scored.
            This must match the generator's output and the data loader's images.
        hidden_dim: Width of the single hidden layer.
        device: Device to place the parameters on. ``None`` uses CUDA when it
            is available and falls back to CPU instead of failing.
    """

    def __init__(self, img_shape=(3, 64, 64), hidden_dim=256, device=None):
        super().__init__()
        channels, height, width = img_shape
        self.model = nn.Sequential(
            nn.Linear(channels * height * width, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            # Probabilities in (0, 1), as nn.BCELoss in the training loop expects.
            nn.Sigmoid(),
        )
        self.to(_default_device() if device is None else device)

    def forward(self, img):
        """Return an ``(N, 1)`` probability that each image in ``img`` is real."""
        # flatten (unlike view) also accepts non-contiguous inputs.
        return self.model(img.flatten(start_dim=1))

# Set up data loader for both cats and dogs
def load_data(root_folder_cat, root_folder_dog, batch_size, *, image_size=64, num_workers=4):
    """Return a shuffled DataLoader over the cat and dog image folders combined.

    Each root is read with ``torchvision.datasets.ImageFolder``. Images must
    therefore live inside class sub-directories of the root
    (``<root>/<class>/<image>``). Labels are assigned independently for each
    root, so they do not reliably distinguish cats from dogs. The training
    loop ignores them.

    Args:
        root_folder_cat: ImageFolder root containing the cat images.
        root_folder_dog: ImageFolder root containing the dog images.
        batch_size: Maximum number of images per batch.
        image_size: Side length in pixels of the square each image is resized to.
        num_workers: Number of worker processes used by the DataLoader.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches. The images are
        RGB tensors of size ``image_size x image_size`` scaled to ``[-1, 1]``.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range of the generator's Tanh.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    per_root_datasets = [
        datasets.ImageFolder(root=root, transform=transform)
        for root in (root_folder_cat, root_folder_dog)
    ]
    return DataLoader(
        ConcatDataset(per_root_datasets),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

# Training function
def train_gan(generator, discriminator, dataloader, num_epochs, batch_size, learning_rate,
              sample_every=10):
    """Train ``generator`` and ``discriminator`` adversarially with a BCE loss.

    Each iteration updates the discriminator first: on the real batch with
    target 1, then on a freshly generated fake batch with target 0. It then
    updates the generator so the discriminator scores those fakes as real.
    A grid of images generated from a fixed noise batch is written to
    ``fake_samples_epoch_XXX.png`` in the current working directory. This
    happens every ``sample_every`` epochs and after the final epoch.

    Args:
        generator: Module mapping ``(N, latent_dim)`` noise to images. The noise
            size is read from ``generator.latent_dim`` if present, else 100.
        discriminator: Module mapping images to ``(N, 1)`` probabilities.
        dataloader: Iterable of ``(images, labels)`` pairs. Labels are ignored.
        num_epochs: Number of passes over ``dataloader``.
        batch_size: Unused; the batch size is determined by ``dataloader``.
            Kept so existing callers keep working.
        learning_rate: Adam learning rate for both networks.
        sample_every: Positive epoch interval between saved sample grids.
    """
    # Follow the models' placement instead of hard-coding CUDA.
    device = next(generator.parameters()).device
    latent_dim = getattr(generator, "latent_dim", 100)

    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Same noise every time, so saved grids are comparable across epochs.
    fixed_noise = torch.randn(64, latent_dim).to(device)

    for epoch in range(num_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)

            # Train discriminator with real images (target 1)
            optimizer_d.zero_grad()
            real_outputs = discriminator(real_imgs)
            loss_real = criterion(real_outputs, torch.ones_like(real_outputs))
            loss_real.backward()

            # Train discriminator with fake images (target 0). detach() keeps
            # this loss from back-propagating into the generator.
            noise = torch.randn(real_imgs.size(0), latent_dim).to(device)
            fake_imgs = generator(noise)
            fake_outputs = discriminator(fake_imgs.detach())
            loss_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))
            loss_fake.backward()
            optimizer_d.step()

            # Train generator: re-score the same fakes with the updated
            # discriminator, this time keeping the graph to the generator.
            optimizer_g.zero_grad()
            gen_outputs = discriminator(fake_imgs)
            loss_gen = criterion(gen_outputs, torch.ones_like(gen_outputs))
            loss_gen.backward()
            optimizer_g.step()

            if i % 100 == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         (loss_real + loss_fake).item(), loss_gen.item()))

        if epoch % sample_every == 0 or epoch == num_epochs - 1:
            # Sampling only; no autograd graph is needed.
            with torch.no_grad():
                fake = generator(fixed_noise)
            # Collapse any leading dims into one batch of (C, H, W) images.
            save_image(fake.reshape(-1, *fake.shape[-3:]),
                       'fake_samples_epoch_%03d.png' % epoch, normalize=True)

# Set random seed for reproducibility
torch.manual_seed(42)

# Run configuration. BATCH_SIZE is defined once and passed to both the data
# loader and train_gan, so the two values cannot drift apart.
# NOTE: DATA_ROOT is a machine-specific absolute path; adjust it for your setup.
DATA_ROOT = '/home/tom/Python/Machine learning/pytorch/GaN/data'
BATCH_SIZE = 64
NUM_EPOCHS = 200
LEARNING_RATE = 0.0002

# Set up a single data loader over both the cat and the dog image folders
combined_dataloader = load_data(os.path.join(DATA_ROOT, 'cat/'),
                                os.path.join(DATA_ROOT, 'dogs'),
                                batch_size=BATCH_SIZE)

# Initialize the generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Train the GAN
train_gan(generator, discriminator, combined_dataloader,
          num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE)
