In [1]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import argparse
import os
from datetime import datetime

from gan.train import train
from gan.losses import discriminator_loss, generator_loss
from gan.models import Discriminator, Generator
from gan.utils import sample_noise, deprocess_img, show_images

# Enable IPython's autoreload extension; mode 2 reloads all (non-excluded) modules
# before each cell runs, so edits to the gan.* modules imported above are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [8]:
# Load a trained generator checkpoint and write `num_images` samples as PNG files
# to `<result_dir>/generated/<index>.png`.
latent_dim = 50   # noise dimension; must match the checkpoint's Generator(noise_dim=...)
imsize = 128      # side length of the generated square images
batch_size = 128  # images produced per generator forward pass
# Fall back to the CPU when CUDA is unavailable ("gpu" is not a valid torch device type).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

iter_no = 17325
model_load_path = "checkpoints_/2023-04-13 13:07:02.992796/{}/iter_{}.pth"

G = Generator(noise_dim=latent_dim, imsize=imsize).to(device)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
G.load_state_dict(torch.load(model_load_path.format("generator", iter_no), map_location=device))
# NOTE(review): G is not switched to eval(); if it contains BatchNorm/Dropout, confirm
# whether samples should use batch statistics (current behavior) or running statistics.

result_dir = "./results/test/"
output_dir = os.path.join(result_dir, "generated")
os.makedirs(output_dir, exist_ok=True)  # PIL's save() does not create missing directories

num_images = 1024


def _to_uint8_hwc(chw):
    """Convert one CHW float image to an HWC uint8 array suitable for PIL.

    Assumes `deprocess_img` maps pixels into [0, 1] (the original code scaled by 255);
    values are clipped first so out-of-range pixels saturate instead of wrapping around.
    """
    hwc = np.transpose(chw, (1, 2, 0))
    return (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)


# Save every image of every batch instead of keeping only the first of each
# 128-image forward pass. Full batches are always generated so any batch-dependent
# layers see the same batch size as before; only the last batch may be partly used.
with torch.no_grad():  # pure inference: no autograd graph needed
    for start in range(0, num_images, batch_size):
        noise = sample_noise(batch_size, latent_dim).reshape(batch_size, latent_dim, 1, 1).to(device)
        fake_images = G(noise).reshape(batch_size, 3, imsize, imsize)
        batch = deprocess_img(fake_images).cpu().numpy()

        for offset, chw in enumerate(batch[: num_images - start]):
            im = Image.fromarray(_to_uint8_hwc(chw))
            im.save(os.path.join(output_dir, "{}.png".format(start + offset)))