# BigGAN

In [None]:
# Install the pretrained BigGAN package (IPython `!` shell escape; notebook-only syntax).
!pip install pytorch-pretrained-biggan

In [None]:
import nltk
import matplotlib.pyplot as plt
import logging
nltk.download('wordnet')
%matplotlib inline

import torch
from pytorch_pretrained_biggan import *
import numpy as np

device = 'cuda'

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
def vector_linspace(start, end, steps, device=None):
    """
    Vector version of torch.linspace.

    Returns a ``(steps, D)`` tensor whose row ``i`` lies ``i / (steps - 1)`` of
    the way from ``start`` to ``end``. Row 0 equals ``start``; the last row
    equals ``end`` exactly when ``steps > 1``.

    Args:
        start: 1-D tensor of length D, the first point.
        end: 1-D tensor of length D, the last point.
        steps: number of points to generate (``>= 0``).
        device: where to place the result; defaults to ``start``'s device.

    Returns:
        Tensor of shape ``(steps, D)``. It has ``start``'s dtype when that is
        floating point, otherwise torch's default float dtype.

    Raises:
        ValueError: if ``start`` is not 1-D or ``end`` has a different shape.
    """
    if start.dim() != 1 or start.shape != end.shape:
        raise ValueError(
            f"start and end must be 1-D tensors of equal length, "
            f"got shapes {tuple(start.shape)} and {tuple(end.shape)}"
        )
    if device is None:
        device = start.device
    dtype = start.dtype if start.is_floating_point() else torch.get_default_dtype()
    start = start.to(device=device, dtype=dtype)
    end = end.to(device=device, dtype=dtype)

    # One interpolation weight per output row, shaped (steps, 1) so lerp broadcasts
    # it across all D dimensions in a single op. This replaces the per-dimension
    # Python loop, which forced a host sync for every element.
    weights = torch.linspace(0, 1, steps, dtype=dtype, device=device).unsqueeze(1)
    # torch.lerp evaluates weights >= 0.5 from the `end` side, so weight 1 yields `end` exactly.
    return torch.lerp(start.unsqueeze(0), end.unsqueeze(0), weights)

In [None]:
def show_noise_interpolations(n_rows, n_cols, image_size, truncation, label, scale=3):
    """
    Show a grid of [`n_rows`, `n_cols`] images of a single class whose input noise
    is bilinearly interpolated between four random anchor noise vectors.

    The left column goes from anchor 0 (top) to anchor 1 (bottom). The right column
    goes from anchor 2 to anchor 3. Each row interpolates left-to-right between the
    two columns.

    Args:
        n_rows: number of grid rows.
        n_cols: number of grid columns.
        image_size: side length in pixels of each generated image.
        truncation: truncation value used both to sample the anchors and by the model.
        label: class name passed to `one_hot_from_names`.
        scale: size of each grid cell, forwarded to `biggan_grid_show`.

    Raises:
        ValueError: if `label` cannot be mapped to a class by `one_hot_from_names`.
    """
    n_images = n_rows * n_cols
    class_vector = one_hot_from_names([label] * n_images, batch_size=n_images)
    # one_hot_from_names returns None when it cannot resolve a name; fail clearly
    # here instead of with an opaque TypeError inside torch.from_numpy.
    if class_vector is None:
        raise ValueError(f"Could not map label {label!r} to an ImageNet class")
    anchor_noises = truncated_noise_sample(truncation=truncation, batch_size=4)

    anchor_noises = torch.from_numpy(anchor_noises).to(device)
    class_vector = torch.from_numpy(class_vector).to(device)

    left_column = vector_linspace(anchor_noises[0], anchor_noises[1], n_rows)
    right_column = vector_linspace(anchor_noises[2], anchor_noises[3], n_rows)
    rows = [vector_linspace(left, right, n_cols) for left, right in zip(left_column, right_column)]
    # (n_rows, n_cols, dim_z) -> (n_rows * n_cols, dim_z), row-major to match the grid.
    noises = torch.stack(rows, dim=0).view(n_images, -1)

    with torch.no_grad():
        fake_imgs = model(noises, class_vector, truncation)

    biggan_grid_show(fake_imgs, image_size, n_rows, scale=scale)

In [None]:
def biggan_grid_show(image_batch, image_size, rows=1, scale=3):
    """
    Plot a batch of generated images as a grid with `rows` rows.

    Args:
        image_batch: tensor of generator outputs, reshaped to
            ``(-1, 3, image_size, image_size)`` before plotting.
        image_size: side length in pixels of each image.
        rows: number of grid rows; the column count is ``ceil(num_images / rows)``.
        scale: width/height in inches of each grid cell.

    Raises:
        ValueError: if `rows` is smaller than 1.
    """
    if rows < 1:
        raise ValueError(f"rows must be a positive integer, got {rows!r}")
    images = image_batch.detach().cpu().reshape(-1, 3, image_size, image_size).numpy()

    # matplotlib's subplot grid requires integer dimensions; np.ceil returns a float.
    cols = max(1, int(np.ceil(images.shape[0] / rows)))
    # A per-call figure size avoids permanently changing the global rcParams.
    plt.figure(figsize=(cols * scale, rows * scale))

    # convert_to_images turns the whole (N, 3, H, W) batch into displayable images at once.
    for i, img in enumerate(convert_to_images(images)):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img)
        plt.axis('off')
    plt.show()

In [None]:
def show_class_interpolations(n_rows, n_cols, image_size, truncation, labels, scale=3):
    """
    Show a grid of [`n_rows`, `n_cols`] images generated from one fixed noise vector
    while the class vector is bilinearly interpolated between four class anchors.

    The left column blends labels[0] (top) into labels[1] (bottom). The right column
    blends labels[2] into labels[3]. Each row blends left-to-right between the two
    columns.

    Args:
        n_rows: number of grid rows.
        n_cols: number of grid columns.
        image_size: side length in pixels of each generated image.
        truncation: truncation value used both to sample the noise and by the model.
        labels: exactly four class names passed to `one_hot_from_names`.
        scale: size of each grid cell, forwarded to `biggan_grid_show`.

    Raises:
        ValueError: if `labels` does not contain exactly four names, or if
            `one_hot_from_names` cannot resolve them.
    """
    if len(labels) != 4:
        raise ValueError(f"labels must contain exactly 4 class names, got {len(labels)}")
    class_vector = one_hot_from_names(labels, batch_size=4)
    # one_hot_from_names returns None when it cannot resolve a name.
    if class_vector is None:
        raise ValueError(f"Could not map every label in {labels!r} to an ImageNet class")
    noise = truncated_noise_sample(truncation=truncation, batch_size=1)

    noise = torch.from_numpy(noise).to(device)
    class_vector_anchors = torch.from_numpy(class_vector).to(device)

    left_column = vector_linspace(class_vector_anchors[0], class_vector_anchors[1], n_rows)
    right_column = vector_linspace(class_vector_anchors[2], class_vector_anchors[3], n_rows)
    rows = [vector_linspace(left, right, n_cols) for left, right in zip(left_column, right_column)]

    # (n_rows, n_cols, num_classes) -> (n_rows * n_cols, num_classes), row-major to match the grid.
    class_vectors = torch.stack(rows, dim=0).view(n_rows * n_cols, -1)
    # Reuse the same noise vector for every cell so only the class changes.
    noises = noise.expand(n_rows * n_cols, -1)

    with torch.no_grad():
        fake_imgs = model(noises, class_vectors, truncation)

    biggan_grid_show(fake_imgs, image_size, n_rows, scale=scale)

In [None]:
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)

import logging
logging.basicConfig(level=logging.INFO)

# Pretrained generator; the comment in the original lists 128, 256 and 512 as
# the available resolutions.
model = BigGAN.from_pretrained('biggan-deep-256')
# Use the notebook-wide `device` (not a hard-coded 'cuda') so this cell agrees
# with the helper functions, which move their inputs to `device`.
model.to(device)

truncation = 0.5
class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3)

noise_vector = torch.from_numpy(noise_vector).to(device)
class_vector = torch.from_numpy(class_vector).to(device)

with torch.no_grad():
    output = model(noise_vector, class_vector, truncation)
    output = output.cpu()

biggan_grid_show(output, image_size=256, rows=1, scale=5)

In [None]:
truncation, batch_size = 0.5, 4

# One class per image, chosen by name (e.g. ladybug, cheetah, ...).
# To choose by ImageNet index (0 to 999) instead, replace the one_hot_from_names(...)
# call below with: one_hot_from_int([548, 234, 300, 800], batch_size=batch_size)
class_vector = torch.from_numpy(
    one_hot_from_names(['mushroom', 'husky', 'coffee' , 'ladybug'], batch_size=batch_size)
).to(device)

# Truncated latent noise, one vector per image.
noise_vector = torch.from_numpy(
    truncated_noise_sample(truncation=truncation, batch_size=batch_size)
).to(device)

with torch.no_grad():
    output = model(noise_vector, class_vector, truncation).cpu()

biggan_grid_show(output, image_size=256, rows=1, scale=5)

In [None]:
show_noise_interpolations(n_rows=4, n_cols=4, image_size=256, truncation=0.5, label='husky', scale=3)

In [None]:
show_class_interpolations(n_rows=4, n_cols=4, image_size=256, truncation=0.4, labels=['dog', 'husky', 'tiger', 'cheetah'], scale=3)