In [None]:
import sys

# Put the parent directory at the front of the module search path; presumably this
# is where the local `gan_package` imported below lives — confirm against repo layout.
sys.path[0:0] = ['..']

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import os
import cv2
import numpy as np
import pandas as pd
import random

from torchvision.models import inception_v3
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from gan_package.gan import GAN
from gan_package.vanillaGAN import VanillaGAN_Generator, VanillaGAN_Discriminator
from gan_package.dcGAN import DCGAN_Generator, DCGAN_Discriminator

# Create dataset

In [2]:
image_size = 256
root = '../lsun/bedroom/0/0'

# Resize the shorter side to `image_size`, center-crop to a square, convert to a
# tensor in [0, 1], then normalise each RGB channel with mean 0.5 / std 0.5,
# which maps pixel values into [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder treats each sub-directory of `root` as a class label; the labels are
# not used by the GAN training below, only the images.
dataset = dset.ImageFolder(root=root, transform=preprocess)

# Define parameter grid

In [None]:
# Search space for the random search: each trial samples one value from each list.
# (The Inception-v3 model used for scoring is loaded once in the search cell itself;
# loading it here as well only duplicated the download and the memory footprint.)
batch_sizes = [16, 32, 64, 128, 256]
learning_rate_gens = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1]
learning_rate_discs = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05]

# Random Search

In [None]:
# Pretrained Inception-v3, loaded once and shared by every trial. It is handed to
# GAN(...); presumably it is used to compute the 'fid_scores' history read below.
inception_model = inception_v3(pretrained=True, transform_input=False, aux_logits=True)

SEED = 42          # makes the sampled hyperparameters and weight init reproducible
num_searches = 10  # number of random hyperparameter trials
latent_dim = 100   # size of the generator's input noise vector
num_epochs = 2     # training epochs per trial

# Seed every RNG in play: `random` drives hyperparameter sampling, torch drives
# weight init and DataLoader shuffling, numpy is seeded in case the package uses it.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# The dataset is the same for every trial, so its image shape and the flattened
# output size for the generator are computed once instead of once per trial.
img_shape = dataset[0][0].shape
n_out = torch.prod(torch.tensor(img_shape))

criterion = nn.BCELoss()

# One dict per trial; the DataFrame is built once after the loop because
# DataFrame.append was removed in pandas 2.0 (and row-wise growth is quadratic).
records = []

for _ in range(num_searches):
    batch_size = random.choice(batch_sizes)
    learning_rate_gen = random.choice(learning_rate_gens)
    learning_rate_disc = random.choice(learning_rate_discs)

    print('Learning_rate_gen:', learning_rate_gen, 'Learning_rate_disc:', learning_rate_disc, 'Batch_size:', batch_size)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Fresh networks and optimizers per trial so no trial inherits trained weights.
    vanilla_generator = VanillaGAN_Generator(latent_dim=latent_dim, img_shape=img_shape, n_out=n_out)
    vanilla_discriminator = VanillaGAN_Discriminator(img_shape=img_shape)
    vanilla_gan = GAN(generator=vanilla_generator, discriminator=vanilla_discriminator, inception_model=inception_model)

    vanilla_generator_optimizer = optim.Adam(vanilla_generator.parameters(), lr=learning_rate_gen, weight_decay=0.0001)
    vanilla_discriminator_optimizer = optim.Adam(vanilla_discriminator.parameters(), lr=learning_rate_disc, weight_decay=0.0001)

    vanilla_gan.train(dataset=dataset,
                      dataloader=dataloader,
                      discriminator_optimizer=vanilla_discriminator_optimizer,
                      generator_optimizer=vanilla_generator_optimizer,
                      criterion=criterion,
                      num_epochs=num_epochs)

    # Record the final-epoch metrics of this trial.
    records.append({'batch_size': batch_size,
                    'learning_rate_gen': learning_rate_gen,
                    'learning_rate_disc': learning_rate_disc,
                    'fid_score': vanilla_gan.history['fid_scores'][-1],
                    'loss_discriminator': vanilla_gan.history['d_losses'][-1],
                    'loss_generator': vanilla_gan.history['g_losses'][-1]})

# Explicit columns keep the schema (and column order) even if no trial completed.
results = pd.DataFrame(records, columns=['batch_size', 'learning_rate_gen', 'learning_rate_disc',
                                         'fid_score', 'loss_discriminator', 'loss_generator'])

In [None]:
# Save the search results. DataFrame.to_csv does not create missing parent
# directories, so make sure the output directory exists first.
# NOTE(review): "tunnig" looks like a typo for "tuning"; the path is kept unchanged
# in case other code already reads this file.
os.makedirs('../results', exist_ok=True)
results.to_csv('../results/hyperparameter_tunnig.csv', index=False)