# Examples for loading samples

## September 2020

In [3]:
import torch

import utils

# Gaussian projections of greyscale CIFAR100

In [3]:
# Configuration: draw `num_samples` inputs from the named generator, passed
# through the named transform. Presumably `cs` are the latent codes and `xs`
# the transformed inputs (one sample per row) -- confirm in utils.
device, num_samples = "cpu", 10000
generator_name, transform_name = (
    "dcgan_cifar100_grey",
    "rand_proj_gauss_sign_from1024_to2048",
)
cs, xs = utils.get_inputs_by_name(
    device, num_samples, generator_name, transform_name
)

Loaded moments of generator dcgan_cifar100_grey


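Let's quickly check that the stored covariance matrix $\Omega$ agrees with the empirical matrix $\frac{1}{N} X^\top X$ of the data we just generated, where the $N$ generated samples are the rows of $X$. Note that no mean is subtracted, so the empirical matrix is the uncentred second-moment matrix: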

In [14]:
# Empirical (uncentred) second-moment matrix: X^T X / num_samples, with no
# mean subtracted. The 1/num_samples scale is applied to X^T before the matmul.
emp_cov = (1 / num_samples * xs.T) @ xs

In [15]:
# The stored matrix Omega for this generator/transform pair lives under moments/.
Omega_fname = f"moments/{generator_name}_{transform_name}_omega.pt"
Omega = torch.load(Omega_fname, map_location=device)

In [16]:
# Relative squared Frobenius error ||Omega - emp_cov||^2 / ||Omega||^2 (displayed).
(Omega - emp_cov).pow(2).sum() / Omega.pow(2).sum()

tensor(0.0051)

## Let's look at some images

In [1]:
import torch

from torchvision import datasets
import torchvision.utils as vutils

import utils

Load the generator

In [2]:
# Build the "dcgan_cifar100_grey" generator on the CPU (how it is constructed /
# where weights come from is up to utils.get_generator -- not visible here).
gen = utils.get_generator(
    "dcgan_cifar100_grey",
    device="cpu",
)

and generate some images:

In [10]:
# Draw 64 standard-normal latent vectors of the generator's input size.
# Seed the global torch RNG first so the saved sample grid is reproducible
# under Restart & Run All.
torch.manual_seed(0)
cs = torch.randn(64, gen.N_in)

In [11]:
# Map the latents through the generator and reshape to single-channel 32x32
# images, i.e. (batch, 1, 32, 32). Sampling only: run under no_grad so no
# autograd graph is recorded for the outputs.
with torch.no_grad():
    gan_xs = gen.transform(cs).reshape(-1, 1, 32, 32)

Now let's load the CIFAR100 test set and convert it to greyscale:

In [5]:
N = 32 * 32  # pixels per 32x32 image (currently unused below)
dataset = datasets.CIFAR100("~/datasets/cifar100", train=False, download=True)

# `dataset.data` stores images channels-last, (image, height, width, channel);
# permute to channels-first so channel c is `[:, c, :, :]`.
# Use a dedicated name rather than rebinding `xs`, which the first section of
# this notebook uses for the generated inputs.
cifar100_rgb = torch.tensor(dataset.data).float().permute(0, 3, 1, 2)

# Now let's transform the inputs to gray-scale.
# Luma weights taken from the PyTorch function rgb_to_grayscale.
cifar100_xs = (
    0.2989 * cifar100_rgb[:, 0, :, :]
    + 0.5870 * cifar100_rgb[:, 1, :, :]
    + 0.1140 * cifar100_rgb[:, 2, :, :]
).reshape(-1, 1, 32, 32)

Files already downloaded and verified


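Now normalise the images manually so that they lie in the range $[0, 1]$. The CIFAR100 pixel values are divided by 255, while the GAN samples are min-max rescaled over the whole batch: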

In [12]:
# CIFAR100 pixel values lie in 0..255, so dividing by 255 maps them to [0, 1].
cifar100_xs_normalised = cifar100_xs / 255

# GAN outputs have no fixed range: min-max rescale over the whole batch.
# Done out-of-place, and the denominator is clamped so that a constant batch
# yields zeros instead of NaNs (0 / 0); otherwise the result is unchanged.
gan_xs_normalised = gan_xs - torch.min(gan_xs)
gan_xs_normalised = gan_xs_normalised / torch.clamp(
    torch.max(gan_xs_normalised), min=1e-12
)

and put the two sets together in a single batch, with 32 CIFAR100 images followed by 32 GAN samples:

In [13]:
# Zero-initialised buffer for 64 single-channel 32x32 images, filled in the next cell.
images = torch.zeros((64, 1, 32, 32))

In [14]:
# Rows 0-31: real CIFAR100 images; rows 32-63: GAN samples.
images[:32], images[32:] = cifar100_xs_normalised[:32], gan_xs_normalised[:32]

In [15]:
# Write the batch as one image grid (torchvision's default nrow=8 gives an 8x8
# grid here). detach() drops any autograd history before saving.
vutils.save_image(images.detach(), "dcgan_cifar100_gray_samples.png")