In [1]:
import torch
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image
from datasets.CIFAR10 import Generator, Discriminator, Partitioner, Z_DIM, SHAPE
from pathlib import Path
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
actor = f"1000"
weights_path = Path("results") / "weights"
destination_path = Path("..") / "report" / "images"
N_real = 3

In [3]:
dataset = Partitioner(0, 0)
dataset.load_data()
dataloader = torch.utils.data.DataLoader(dataset.train_dataset, batch_size=100, shuffle=True)
for i in range(N_real):
    real_batch = next(iter(dataloader))[0]
    pil = to_pil_image(make_grid(real_batch, nrow=10, value_range=(-1, 1), normalize=True, padding=0))
    pil = pil.resize((SHAPE[1] * 10, SHAPE[2] * 10), resample=None)
    pil.save(destination_path / f"real_batch_{i}.png")

Files already downloaded and verified
Files already downloaded and verified


In [4]:
generator_state_dict = torch.load(weights_path / actor / "generator.pt", map_location=device)
generator = Generator()
generator.load_state_dict(generator_state_dict)
generator.eval()

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)

In [23]:
weights = (weights_path/"continuous").glob("generator_*.pt")
weights = sorted(weights, key=lambda x: int(x.stem.split("_")[1]))
print(f"Loading weights {weights}")

images = []
for i, weight in enumerate(weights):
    weight_name = int(weight.stem.split('_')[1])

    if weight_name % 200 != 0 and weight_name != 29999:
        continue

    generator_state_dict = torch.load(weight, map_location=device)
    generator = Generator()
    generator.load_state_dict(generator_state_dict)
    generator.eval()

    z = torch.randn(1, Z_DIM, 1, 1, device=device)
    fake_batch = generator(z)
    single_image = make_grid(fake_batch, nrow=1, value_range=(-1, 1), normalize=True, padding=0)
    #single_image = single_image.resize((SHAPE[1] * 5, SHAPE[2] * 5), resample=None)
    images.append(single_image)
    print(f"Generated {weight_name}")

grid = to_pil_image(make_grid(images, nrow=14, padding=0))

# Create a white background image with the same size as the grid image
white_background = Image.new("RGB", grid.size, (255, 255, 255))

# Paste the grid image onto the white background
white_background.paste(grid, (0, 0))

white_background.save(destination_path / "fake_batch_grid.png")

Loading weights [PosixPath('results/weights/continuous/generator_50.pt'), PosixPath('results/weights/continuous/generator_100.pt'), PosixPath('results/weights/continuous/generator_150.pt'), PosixPath('results/weights/continuous/generator_200.pt'), PosixPath('results/weights/continuous/generator_250.pt'), PosixPath('results/weights/continuous/generator_300.pt'), PosixPath('results/weights/continuous/generator_350.pt'), PosixPath('results/weights/continuous/generator_400.pt'), PosixPath('results/weights/continuous/generator_450.pt'), PosixPath('results/weights/continuous/generator_500.pt'), PosixPath('results/weights/continuous/generator_550.pt'), PosixPath('results/weights/continuous/generator_600.pt'), PosixPath('results/weights/continuous/generator_650.pt'), PosixPath('results/weights/continuous/generator_700.pt'), PosixPath('results/weights/continuous/generator_750.pt'), PosixPath('results/weights/continuous/generator_800.pt'), PosixPath('results/weights/continuous/generator_850.pt')

NameError: name 'Image' is not defined

In [6]:
seed = torch.randn((100, 100, 1, 1), device=device)

In [7]:
with torch.no_grad():
    images = generator(seed)
    grid = make_grid(images, nrow=10, padding=0, value_range=(-1, 1), normalize=True)
    pil = to_pil_image(grid)
    pil = pil.resize((SHAPE[1] * 10, SHAPE[2] * 10), resample=None)
    pil.save(destination_path / f"{actor}.png")