In [None]:
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample,
                                       save_as_images, display_in_terminal)

# Instantiate the generator from the 'biggan-deep-128' pretrained checkpoint.
# The checkpoint name can be swapped for 'biggan-deep-256'.
model = BigGAN.from_pretrained('biggan-deep-128') # or biggan-deep-256

import numpy as np
from PIL import Image

def to_pil(out):
    """Convert the first image of a generator output batch into a PIL image.

    Values are mapped from [-1, 1] to [0, 255], clipped to that range and
    cast to uint8 before the channel axis is moved last.

    Args:
        out: Tensor indexed as (batch, channel, height, width); only the
            first batch entry is used.

    Returns:
        The image built by ``Image.fromarray`` from the H x W x C array.
    """
    first = out.detach().cpu().numpy()[0]

    # Rescale [-1, 1] -> [0, 255], then clamp anything that falls outside.
    scaled = (first + 1) / 2.0 * 255.0
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)

    # Channel-first (C, H, W) to channel-last (H, W, C).
    channel_last = np.transpose(pixels, (1, 2, 0))
    return Image.fromarray(channel_last)

def show(img):
    """Render an image tensor inline in the notebook.

    Singleton dimensions are squeezed away, values are mapped from [-1, 1]
    to [0, 255] and clipped, the channel axis is moved last, and the result
    is handed to ``display`` as a PIL image.

    Args:
        img: Tensor that is channel-first (C, H, W) once squeezed.
    """
    arr = img.detach().cpu().squeeze().numpy()
    arr = np.clip((arr + 1) / 2.0 * 255, 0, 255)

    # Channel-first to channel-last, then down to 8-bit pixels.
    channel_last = np.transpose(arr, (1, 2, 0)).astype("uint8")
    display(Image.fromarray(channel_last))

In [None]:
def find_z(img, G, class_, truncation, lr, num_iter=5000):
    """Search for a latent code whose generated image matches ``img``.

    Starts from a fresh truncated-noise sample and minimises the mean squared
    error between ``G(z, class_, truncation)`` and ``img`` with plain SGD.
    Every 100 iterations the loss is printed and both the current prediction
    and the target are shown.

    Args:
        img: Target image tensor. The latent code is created on its device.
        G: Callable invoked as ``G(z, class_, truncation)``.
        class_: Class conditioning, passed through to ``G`` unchanged.
        truncation: Truncation value used both to sample the initial code
            and as the third argument to ``G``.
        lr: SGD learning rate.
        num_iter: Number of optimisation steps.

    Returns:
        The optimised latent code, detached from the autograd graph.
    """
    # Sample the starting point and place it on the target's device instead
    # of a hard-coded 'cuda', so the search also runs on CPU-only hosts.
    z_ = torch.from_numpy(
        truncated_noise_sample(truncation=truncation, batch_size=1)
    ).to(img.device)

    # The code must be a leaf that requires grad. Without this, backward()
    # never populates z_.grad and the optimiser step silently skips it, so
    # the latent code would never change.
    z_.requires_grad_(True)

    dist = torch.nn.MSELoss()

    # RMSProp or ADAM or SGD
    optZ = torch.optim.SGD([z_], lr=lr, momentum=0)

    for i in range(num_iter):
        img_pred = G(z_, class_, truncation)
        loss = dist(img_pred, img)

        if i % 100 == 0:
            print("[Iter {}] error: {}"
                  .format(i, loss.item()))
            show(img_pred)
            show(img)

        optZ.zero_grad()
        loss.backward()
        optZ.step()

        # Projecting z_ back into a truncated range is intentionally left
        # disabled. If enabled, it has to modify z_ in place under
        # torch.no_grad(); rebinding z_ to a new tensor would detach it from
        # the optimiser.

    return z_.detach()

In [None]:
# truncation trick (0-1): a larger value generates more diverse images, but the quality is worse
truncation = 0.4

# Use the GPU when one is available and fall back to the CPU otherwise, so
# this cell does not crash on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The helpers return NumPy arrays; turn each into a tensor on the chosen
# device in a single step rather than rebinding the same name several times.
class_vector = torch.from_numpy(one_hot_from_int(250)).to(device)  # input imagenet label here
noise_vector = torch.from_numpy(
    truncated_noise_sample(truncation=truncation, batch_size=1)
).to(device)
model.to(device)

# Generate an image; no gradients are needed for this forward pass.
with torch.no_grad():
    output = model(noise_vector, class_vector, truncation)


In [None]:
# Inversion (noise_scale == 0) or denoising (noise_scale > 0) of the generated image.
noise_scale = 0.
# Add scaled Gaussian noise to the generated image and keep the target in [-1, 1].
img_gt = (torch.randn_like(output) * noise_scale + output).clamp(min=-1, max=1)

# Author's note: with only the latent code z free to change, inverting even an
# in-range image is hard for BigGAN using gradient descent.
z_pred = find_z(
    img_gt.detach(),
    model,
    class_vector,
    truncation,
    lr=1,
    num_iter=2000,
)