In [None]:
import os
import random
import numpy as np
from glob import glob
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from torchvision.datasets import ImageFolder
from torchvision.models import inception_v3
from data_processing import make_dataloader
from utils import compute_activation_statistics, load_or_compute_real_stats
from utils import compute_fid_score_unified
from gan import DCGANGenerator, DCGANDiscriminator, train_dcgan,train_lsgan, interpolate_and_generate
from vae import VAE, train_vae, generate_vae_samples, vae_interpolate



In [None]:
def set_seed(seed):
    """Seed every RNG this notebook touches and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's default
    generator and all CUDA devices, then turns off cuDNN autotuning so the
    convolution algorithms chosen do not vary from run to run.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    # Trade speed for reproducibility: no benchmarking, fixed algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)
class Config:
    """Hyperparameters, paths and device shared by every stage of the notebook.

    An instance (``cfg``) is passed to make_dataloader, train_dcgan,
    compute_fid_score_unified and interpolate_and_generate, which read the
    attributes below.
    """

    # Data
    root_path = 'Data'        # path to cats-only images
    image_size = 64           # resize all images to this size
    batch_size = 64
    ndf = 64                  # number of discriminator filters
    ngf = 64                  # number of generator filters

    # Model & training
    latent_dim = 100          # length of the latent vector z fed to the generator
    lr = 0.0001
    beta1 = 0.5               # presumably Adam's beta1 inside train_dcgan -- confirm
    num_epochs = 30

    # FID
    fid_batch = 50            # presumably the batch size used for FID feature extraction -- confirm

    # Prefer CUDA, then Apple MPS, then CPU, so CUDA machines do not silently
    # fall back to the CPU (set_seed already seeds CUDA for this case).
    device = torch.device(
        'cuda' if torch.cuda.is_available()
        else 'mps' if torch.backends.mps.is_available()
        else 'cpu'
    )

cfg = Config()


# Training DataLoader over the cat images; make_dataloader presumably reads
# cfg.root_path / cfg.image_size / cfg.batch_size -- confirm in data_processing.
loader_cats = make_dataloader(cfg)


In [None]:
# Train the DCGAN on the cat images. save_path is presumably where train_dcgan
# writes the training history; make sure its directory exists first.
HISTORY_PATH = './histories/dcgan.pkl'
os.makedirs(os.path.dirname(HISTORY_PATH), exist_ok=True)
gen_cats, disc_cats, history = train_dcgan(loader_cats, cfg, save_path=HISTORY_PATH)

# FID between real images from loader_cats and generated samples.
FID_NUM_FAKE = 1000  # number of generated images used for the FID estimate
fid_value = compute_fid_score_unified(
    generator_type='gan',
    model=gen_cats,
    real_loader=loader_cats,
    cfg=cfg,
    num_fake=FID_NUM_FAKE,
)
fid_value  # last expression -> displayed as the cell's output

In [None]:
# Walk the latent space between two fixed random vectors and save the strip.
set_seed(12511)  # fixed seed so z1/z2, and hence the saved figure, are reproducible
z1 = np.random.randn(cfg.latent_dim, 1, 1)
z2 = np.random.randn(cfg.latent_dim, 1, 1)

INTERP_STEPS = 10  # images in the strip; also used as the grid row length
imgs = interpolate_and_generate(gen_cats, z1, z2, cfg, steps=INTERP_STEPS)

# `utils` is torchvision.utils from the import cell: the `from utils import ...`
# lines import names only and do not rebind `utils`.
# Name the file after the model actually trained above (DCGAN) and the
# configured epoch count rather than hard-coding them.
utils.save_image(
    imgs,
    f'interpolated_dcgan_cats_{cfg.num_epochs}epochs.png',
    nrow=INTERP_STEPS,
    normalize=True,
)