In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import random

# Custom Dataset Class for CelebA-HQ
class CelebAHQDataset(Dataset):
    """Image dataset over a flat directory of CelebA-HQ image files.

    Images are decoded (and transformed) lazily in ``__getitem__`` by default
    instead of all up front. Eagerly caching transformed tensors is very
    memory-hungry: e.g. 10,000 images at 3x256x256 float32 is ~7.9 GB.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        num_images: Keep at most this many files, taken in sorted filename
            order (``None`` keeps all of them).
        preload: If True, decode and transform every image in the constructor
            and cache the results in ``self.images`` (the original eager
            behaviour). Otherwise ``self.images`` is ``None``.
    """

    # File suffixes treated as images; anything else in ``img_dir``
    # (e.g. ``.DS_Store`` or text files) is skipped instead of crashing PIL.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    def __init__(self, img_dir, transform=None, num_images=10000, preload=False):
        self.img_dir = img_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary, filesystem-dependent order,
        # so sort to make "the first num_images files" reproducible.
        candidates = sorted(
            name for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )
        self.image_files = candidates[:num_images]
        self.images = (
            [self._load(idx) for idx in range(len(self.image_files))]
            if preload
            else None
        )

    def _load(self, idx):
        """Open file ``idx``, convert it to RGB and apply the transform, if any."""
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        # The context manager closes the underlying file handle promptly.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if self.images is not None:
            return self.images[idx]
        return self._load(idx)


# Preprocessing pipeline: PIL image -> float tensor in [0, 1] (ToTensor),
# then per-channel (x - 0.5) / 0.5, which rescales values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Build the CelebA-HQ dataset; the DataLoader wrapping it is created later,
# right before training.
dataset = CelebAHQDataset(
    img_dir='datasets/celeba_hq_256',
    transform=transform,
    num_images=10000,
)


In [2]:
# Load the pre-trained StyleGAN3 checkpoint from disk.
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load checkpoints obtained from a trusted source.
# NOTE(review): unpickling StyleGAN3 networks presumably requires the StyleGAN3
# repo's modules (e.g. dnnlib / torch_utils) to be importable -- confirm.
import pickle

model_path = 'models/stylegan3-t-ffhqu-256x256.pkl'
with open(model_path, 'rb') as pkl_file:
    checkpoint = pickle.load(pkl_file)

# The checkpoint dict stores the generator under 'G' and the discriminator under 'D'.
generator = checkpoint['G']
discriminator = checkpoint['D']

# Switch both networks to evaluation mode.
for net in (generator, discriminator):
    net.eval()

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Move parameters and buffers of both networks onto the selected device.
generator = generator.to(device)
discriminator = discriminator.to(device)


cuda


In [3]:
# The StyleGAN3 discriminator returns raw, unbounded logits. BCEWithLogitsLoss
# applies the sigmoid internally in a numerically stable way; plain BCELoss
# requires inputs in [0, 1] and fails on logits.
adversarial_loss = torch.nn.BCEWithLogitsLoss()
pixelwise_loss = torch.nn.L1Loss()

# Official StyleGAN3 snapshot pickles store networks in eval mode with
# requires_grad disabled. Re-enable gradients and training mode before
# fine-tuning; otherwise nothing is optimized and g_loss.backward() fails
# because no tensor in the graph requires grad.
generator.train().requires_grad_(True)
discriminator.train().requires_grad_(True)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

# Latent / conditioning sizes come from the network when it exposes them
# (StyleGAN3 generators define z_dim and c_dim); otherwise fall back to an
# unconditional 512-d latent, as previously hard-coded.
z_dim = getattr(generator, 'z_dim', 512)
c_dim = getattr(generator, 'c_dim', 0)

# NOTE(review): backpropagating through StyleGAN3-T for 64 images per step is
# memory-hungry; lower this if the first batch runs out of GPU memory.
BATCH_SIZE = 64
# Weight of the L1 term in the generator loss.
# NOTE(review): generated images come from random latents and are NOT paired
# with real_imgs, so this term pulls every sample toward an arbitrary real
# image (tends to blur / collapse outputs). Consider setting it to 0.
PIXEL_LOSS_WEIGHT = 1.0

if __name__ == '__main__':
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    num_batches = len(dataloader)
    print('Starting Fine-Tuning')
    num_epochs = 5
    for epoch in range(num_epochs):
        print('starting epoch', epoch)
        for i, imgs in enumerate(dataloader):
            real_imgs = imgs.to(device)
            batch_size = real_imgs.size(0)
            # Adversarial targets. Assumes the discriminator output is (N, 1),
            # as for an unconditional StyleGAN3 D -- TODO confirm for other models.
            valid = torch.ones((batch_size, 1), device=device)
            fake = torch.zeros((batch_size, 1), device=device)

            # Latent codes and conditioning labels (zero-width when c_dim == 0).
            z = torch.randn(batch_size, z_dim, device=device)
            c = torch.zeros(batch_size, c_dim, device=device)

            # ------------------
            # Train Generator
            # ------------------
            optimizer_G.zero_grad()
            gen_imgs = generator(z, c)

            # Loss measures the generator's ability to fool the discriminator,
            # plus the (weighted) pixel-wise term.
            g_loss = adversarial_loss(discriminator(gen_imgs, c), valid)
            g_loss = g_loss + PIXEL_LOSS_WEIGHT * pixelwise_loss(gen_imgs, real_imgs)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            # Train Discriminator
            # ---------------------
            # zero_grad also discards gradients that reached D's parameters
            # during the generator's backward pass.
            optimizer_D.zero_grad()

            # StyleGAN3's D.forward takes (img, c), so the conditioning tensor
            # must be passed for real and generated images alike.
            real_loss = adversarial_loss(discriminator(real_imgs, c), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), c), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{num_batches}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    print('Finished Fine-Tuning')

    torch.save({
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict()
    }, 'models/celeba-hq-256-tuned_stylegan3-t-ffhqu-256.pth')

    print('Finished Saving the Model')


Starting Fine-Tuning
starting epoch 0
