In [1]:
import torch
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image
from datasets.CIFAR10 import Generator, Discriminator, Partitioner, Z_DIM, SHAPE
from pathlib import Path
import matplotlib.pyplot as plt

In [2]:
# Compute device used both to map the checkpoint and to sample latent vectors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Experiment identifier: selects the checkpoint sub-folder under weights_path and
# names the exported grid image. Plain literal — it has no interpolated fields.
actor = "CIFAR10.40"
# Where trained weights are read from, and where the report figure is written.
weights_path = Path("results") / "weights"
destination_path = Path("..") / "report" / "images"

In [3]:
# Rebuild the generator architecture and load the trained weights for inference.
# The model is moved to `device` explicitly: load_state_dict copies tensor values
# into the module's existing parameters, so map_location alone would leave a
# CPU-constructed model on the CPU and the forward pass on a CUDA latent batch
# would fail with a device mismatch.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
generator_state_dict = torch.load(weights_path / actor / "generator.pt", map_location=device)
generator = Generator().to(device)
generator.load_state_dict(generator_state_dict)
# eval() switches BatchNorm layers to their running statistics; the returned
# module is displayed as the cell output.
generator.eval()

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)

In [4]:
# Fixed batch of latent vectors to decode. 100 samples fill a 10x10 grid.
# A dedicated, seeded RNG makes the exported figure reproducible and keeps this
# cell idempotent when re-run (the global torch RNG state is left untouched).
N_SAMPLES = 100
LATENT_SEED = 0
latent_rng = torch.Generator(device=device).manual_seed(LATENT_SEED)
# Latent shape is (batch, Z_DIM, 1, 1), matching the generator's first
# ConvTranspose2d input.
seed = torch.randn((N_SAMPLES, Z_DIM, 1, 1), generator=latent_rng, device=device)

In [5]:
# Decode the latent batch, tile the samples into a grid, and save it for the report.
GRID_NROW = 10  # images per grid row

# Only the forward pass needs autograd disabled.
with torch.no_grad():
    images = generator(seed)

# The generator ends in Tanh, so outputs lie in [-1, 1]; normalize to [0, 1] for display.
grid = make_grid(images, nrow=GRID_NROW, padding=0, value_range=(-1, 1), normalize=True)
pil = to_pil_image(grid)

# Number of grid rows actually produced (ceil division), so the target height
# stays correct if the sample count changes.
n_rows = -(-images.shape[0] // GRID_NROW)
# PIL sizes are (width, height). Assumes SHAPE is (C, H, W) — TODO confirm in
# datasets.CIFAR10. resample=None selects PIL's default filter (BICUBIC for RGB
# images), not "no resampling".
pil = pil.resize((SHAPE[2] * GRID_NROW, SHAPE[1] * n_rows), resample=None)

# The report directory may not exist on a fresh checkout.
destination_path.mkdir(parents=True, exist_ok=True)
pil.save(destination_path / f"{actor}.png")