In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def show_tensor_images(image_tensor, num_images=16, size=(3, 64, 64), nrow=3):
    '''
    Plot the first `num_images` images of a batch as one uniform grid.

    Pixel values are remapped with (x + 1) / 2 and clamped to [0, 1] before
    display, which assumes the inputs are roughly in [-1, 1] (TODO confirm
    against the generator's output activation).

    Args:
        image_tensor: batch of images; it is detached and moved to CPU first.
        num_images: how many leading images of the batch to show.
        size: accepted for call compatibility; not used by this function.
        nrow: number of images per grid row (forwarded to `make_grid`).

    NOTE(review): relies on `make_grid` and `plt` having been imported in an
    earlier cell that is not shown here.
    '''
    rescaled = ((image_tensor + 1) / 2).detach().cpu().clamp_(0, 1)
    grid = make_grid(rescaled[:num_images], nrow=nrow, padding=0)
    # make_grid yields (C, H, W); imshow wants channels last.
    # squeeze() drops a singleton channel so grayscale grids also render.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.axis('off')
    plt.show()

In [4]:
from scipy.stats import truncnorm  # The distribution

def get_truncated_noise(n_samples, z_dim, truncation):
  '''
  Draw an (n_samples, z_dim) matrix of truncated-normal noise.

  Samples come from scipy's standard normal truncated to
  [-truncation, truncation]. loc and scale are left at 0 and 1, so the bounds
  are in standard-deviation units. `truncation` is expected to be positive so
  that the lower bound sits below the upper bound.

  Returns:
    torch.Tensor created via torch.Tensor(...), i.e. converted from scipy's
    float64 array to the default tensor dtype.
  '''
  bounds = (-truncation, truncation)
  samples = truncnorm.rvs(*bounds, size=(n_samples, z_dim))
  return torch.Tensor(samples)

In [5]:
class MappingLayers(nn.Module):
    '''
    MLP that maps a noise vector z to an intermediate latent vector w.

    Layout (kept as one nn.Sequential under `self.mapping`):
    Linear(z_dim -> hidden_dim), ReLU, Linear(hidden_dim -> hidden_dim), ReLU,
    Linear(hidden_dim -> w_dim). The last layer has no activation.

    Args:
        z_dim: size of the last dimension of the input noise.
        hidden_dim: width of the two hidden layers.
        w_dim: size of the last dimension of the output.
    '''

    def __init__(self, z_dim, hidden_dim, w_dim):
        super().__init__()
        layers = [
            nn.Linear(z_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, w_dim),
        ]
        self.mapping = nn.Sequential(*layers)

    def forward(self, noise):
        '''Return the mapped latent for `noise` (last dim z_dim -> w_dim).'''
        w = self.mapping(noise)
        return w

In [6]:
class InjectNoise(nn.Module):
  '''
  Add random noise to a feature map, scaled by a learned per-channel weight.

  Each forward call draws one fresh noise map of shape (N, 1, H, W) and
  broadcasts it across channels. Channel c receives noise * weights[0, c, 0, 0].

  Args:
    channels: number of channels in the feature maps this layer receives.
  '''

  def __init__(self, channels):
    super().__init__()
    # Learned per-channel noise strength, shaped (1, C, 1, 1) so it broadcasts
    # over the batch and spatial dimensions; initialized from a standard normal.
    self.weights = nn.Parameter(
        torch.randn(1, channels, 1, 1)
    )

  def forward(self, image):
    '''
    Args:
      image: 4-D feature map, expected as (N, C, H, W) with C == channels.
    Returns:
      Tensor of the same shape: image + weights * noise.
    '''
    # A single noise channel, shared by every feature channel.
    noise_shape = (image.size(0), 1, image.size(2), image.size(3))
    # Match the input's device and dtype so the sum is not type-promoted.
    noise = torch.randn(noise_shape, device=image.device, dtype=image.dtype)
    # The parameter is registered as `weights` (see __init__); referencing
    # `self.weight` here would raise AttributeError.
    return image + self.weights * noise

In [7]:
class AdaIN(nn.Module):
  '''
  Adaptive instance normalization driven by a style vector w.

  The image is instance-normalized per channel. Each channel is then scaled and
  shifted by values obtained from two independent linear projections of w.

  Args:
    channels: number of channels in the image being styled.
    w_dim: size of the last dimension of the style vector w.
  '''

  def __init__(self, channels, w_dim):
    super().__init__()
    # nn.InstanceNorm2d constructed with its default arguments.
    self.instance_normalizer = nn.InstanceNorm2d(channels)
    # w -> per-channel scale and w -> per-channel shift.
    self.style_scalar = nn.Linear(w_dim, channels)
    self.style_shifter = nn.Linear(w_dim, channels)

  def forward(self, image, w):
    '''
    Args:
      image: feature map, expected as (N, C, H, W).
      w: style batch whose projections are indexed as (N, C).
    Returns:
      normalized(image) * scale(w) + shift(w), with scale/shift broadcast over H, W.
    '''
    # (N, C) -> (N, C, 1, 1) so the styles broadcast over the spatial dims.
    gamma = self.style_scalar(w)[:, :, None, None]
    beta = self.style_shifter(w)[:, :, None, None]
    return self.instance_normalizer(image) * gamma + beta

In [9]:
class GeneratorBlock(nn.Module):
  '''
  Generator building block:
  (optional) upsample -> conv -> noise injection -> AdaIN -> LeakyReLU(0.2).

  Args:
    w_dim: size of the style vector w passed to AdaIN.
    in_chan: channels of the incoming feature map.
    out_chan: channels produced by the convolution (and styled downstream).
    kernel_size: convolution kernel size.
    starting_size: target spatial size given to nn.Upsample(size=...).
    use_upsample: if False, the input is convolved at its current size.
  '''

  # Fixed: this was misspelled `__init_`, so it never ran as a constructor and
  # none of the layers below were ever created.
  def __init__(self, w_dim, in_chan, out_chan, kernel_size, starting_size, use_upsample=True):
    super().__init__()
    self.use_upsample = use_upsample
    if self.use_upsample:
      # Resizes to a fixed output size (`starting_size`), not by a scale factor.
      self.upsample = nn.Upsample(size=starting_size, mode="bilinear", align_corners=False)
    # NOTE: padding=1 keeps H and W unchanged only when kernel_size == 3.
    self.conv = nn.Conv2d(in_chan, out_chan, kernel_size, padding=1)
    self.inject_noise = InjectNoise(out_chan)
    self.adain = AdaIN(out_chan, w_dim)
    self.activation = nn.LeakyReLU(0.2)

  def forward(self, x, w):
    '''
    Args:
      x: input feature map, expected as (N, in_chan, H, W).
      w: style vector batch passed to AdaIN.
    Returns:
      The styled, activated feature map with out_chan channels.
    '''
    if self.use_upsample:
      x = self.upsample(x)
    x = self.conv(x)
    x = self.inject_noise(x)
    x = self.adain(x, w)
    x = self.activation(x)
    return x

In [None]:
class Generator(nn.Module):
  '''
  Placeholder for the full generator network.

  TODO: not implemented. The class defines no layers and no forward(); it only
  inherits nn.Module's defaults, so calling an instance raises
  NotImplementedError.
  '''