In [None]:
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample,
                                       save_as_images, display_in_terminal)
import numpy as np
from PIL import Image
from torchvision import transforms

def to_pil(out):
    """Convert the first image of a batched generator output to a PIL image.

    Only batch element 0 of `out` is used. Values are rescaled so that -1 maps
    to 0 and +1 maps to 255, saturated to [0, 255] and truncated to uint8.
    """
    first = out.detach().cpu().numpy()[0]

    # Rescale, saturate anything that falls outside the displayable range,
    # then quantise to 8-bit.
    scaled = (first + 1) / 2.0 * 255.0
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)

    # Channels-first (C, H, W) -> channels-last (H, W, C).
    return Image.fromarray(np.transpose(pixels, (1, 2, 0)))

def show(img):
    """Display a single generator output inline in the notebook.

    All singleton dimensions of `img` are squeezed away first, and the result
    is treated as channels-first. Values are rescaled so that -1 maps to 0 and
    +1 maps to 255, saturated to [0, 255] and truncated to uint8.
    """
    values = img.detach().cpu().squeeze().numpy()
    values = np.clip((values + 1) / 2.0 * 255, 0, 255)

    # Channels-first -> channels-last, then quantise to 8-bit.
    pixels = np.transpose(values, (1, 2, 0)).astype("uint8")
    display(Image.fromarray(pixels))
    
# Truncation trick: a value in the range 0-1. Larger values generate more
# diverse images, but the quality is worse.
truncation = 0.4

# ImageNet label of the image to generate (250: husky), as a one-hot tensor.
class_vector = torch.from_numpy(one_hot_from_int(250))
noise_vector = torch.from_numpy(
    truncated_noise_sample(truncation=truncation, batch_size=1)
)

# Everything below is placed on the GPU; a CUDA device is required.
noise_vector = noise_vector.cuda()
class_vector = class_vector.cuda()
model = BigGAN.from_pretrained('biggan-deep-128').to('cuda')

In [None]:
def dip(img, G, z_, class_, truncation, lr, num_iter=5000):
    """Deep-Image-Prior style reconstruction of `img` by fine-tuning the generator.

    The latent code is not changed; instead the generator's weights are
    optimised with Adam so that `G(z_, class_, truncation)` matches `img`
    under a mean-squared-error loss. `G` is modified in place and is left in
    training mode, so callers should reload the generator before
    reconstructing a different image.

    Args:
        img: Target image tensor, compared element-wise with the generator output.
        G: Generator module, called as `G(z_, class_, truncation)`.
        z_: Latent code passed to the generator; it is not given to the optimiser.
        class_: Class-conditioning tensor passed to the generator; it is not
            given to the optimiser.
        truncation: Truncation value forwarded to the generator on every call.
        lr: Learning rate for Adam.
        num_iter: Number of optimisation steps.

    Returns:
        The generator output after the last update. It is computed without
        gradient tracking, so it carries no autograd graph.
    """
    G.train()

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(G.parameters(), lr=lr)

    for i in range(num_iter):
        img_pred = G(z_, class_, truncation)
        loss = criterion(img_pred, img)

        # Every 100 steps: report the loss, then show the current prediction
        # followed by the target for a visual comparison.
        if i % 100 == 0:
            print("[Iter {}] error: {}"
                  .format(i, loss.item()))
            show(img_pred)
            show(img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The final pass is only for inspection. Running it under no_grad avoids
    # building (and keeping alive) an autograd graph that nobody will use.
    with torch.no_grad():
        return G(z_, class_, truncation)

In [None]:
# Sample the first target from the pretrained generator itself, so that the
# target image is known to lie in the range of the generator.
# no_grad: the target is a fixed image, so no autograd graph is recorded for it.
with torch.no_grad():
    img_gt = model(noise_vector, class_vector, truncation)
show(img_gt)

In [None]:
def _square_transform(size):
    """Build the preprocessing pipeline for a `size` x `size` target image.

    Steps: resize, center-crop to `size`, convert to a tensor, then normalise
    each of the three channels with mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])


# One pipeline per generator resolution used in this notebook.
TRANSFORMS_256 = _square_transform(256)
TRANSFORMS_128 = _square_transform(128)


In [None]:
# Reset the model for each new image: dip() changes the model's weights in
# place, so every reconstruction must start from the pretrained weights.
model = BigGAN.from_pretrained('biggan-deep-128').cuda()
z_dip = torch.from_numpy(truncated_noise_sample(truncation=truncation, batch_size=1)).to('cuda')

# Keep the reconstruction and render it as an image. Leaving the bare dip()
# call as the last expression would print the raw tensor as the cell output,
# and the final result (after the last logged iteration) would never be shown.
img_gt_dip = dip(img_gt, model, z_dip, class_vector, truncation, lr=1e-5, num_iter=700)
show(img_gt_dip)

In [None]:
# Load a real image. Force three-channel RGB: a PNG may be RGBA, palette-based
# or greyscale, while the preprocessing pipeline normalises exactly three
# channels and the generator output it is compared against has three channels.
img_real = Image.open("./data/tiger_cat.png").convert("RGB")


In [None]:
display(img_real)
# Preprocess to 256x256, add a leading batch dimension and move to the GPU.
# NOTE: this rebinds `img_real` from a PIL image to a tensor, so the image
# must be loaded again before this cell can be re-run.
img_real = TRANSFORMS_256(img_real).unsqueeze(0).cuda()

In [None]:
# Reset the model for each new image: dip() changes the model's weights in
# place, so every reconstruction must start from the pretrained weights.
# You can see that DIP-style reconstruction with BigGAN weights is quite good
# for inversion/denoising.
model = BigGAN.from_pretrained('biggan-deep-256').cuda()
z_dip = torch.from_numpy(truncated_noise_sample(truncation=truncation, batch_size=1)).to('cuda')

class_vector = one_hot_from_int(282)  # tiger-cat: 282, husky: 248
class_vector = torch.from_numpy(class_vector).cuda()

# Keep the reconstruction and render it as an image. Leaving the bare dip()
# call as the last expression would print the raw tensor as the cell output,
# and the final result (after the last logged iteration) would never be shown.
img_real_dip = dip(img_real, model, z_dip, class_vector, truncation, lr=1e-5, num_iter=700)
show(img_real_dip)