<a href="https://colab.research.google.com/github/tristanoprofetto/neural-networks/blob/main/GAN/StyleGAN/StyleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
import numpy as np

In [None]:
# Function for visualizing images given a tensor
def show_images(tensor, n, size=(3, 64, 64), nrow=3):
  """Plot the first ``n`` images of ``tensor`` as a single grid.

  Pixel values are mapped from [-1, 1] to [0, 1] and clamped before display.
  ``size`` is accepted for interface compatibility but is not used.
  """
  rescaled = tensor.add(1).div(2)
  rescaled = rescaled.detach().cpu().clamp_(0, 1)
  grid = make_grid(rescaled[:n], nrow=nrow, padding=0)
  # make_grid returns (C, H, W); matplotlib expects (H, W, C).
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

### Truncation Trick

The truncation trick resamples the noise vector $z$ from a truncated normal distribution, which lets you trade off the generator's fidelity against its diversity. The truncation value is non-negative. Larger values mean less truncation (higher diversity), while values close to 0 keep only samples near the mean (higher quality/fidelity). For example, with a truncation of 0.7 every entry of $z$ lies in $[-0.7, 0.7]$. This trick is not exclusive to StyleGAN; you may recall playing with it in an earlier GAN notebook.

* Truncation: non-negative scalar
* Z Dim: dimension of the noise vector
* Number of Samples: total samples to be generated


In [None]:
def truncation_noise(n_samples, z_dim, truncation):
  """Sample noise vectors from a truncated standard normal distribution.

  Args:
    n_samples: number of noise vectors (rows).
    z_dim: dimensionality of each noise vector (columns).
    truncation: non-negative bound; every value lies in [-truncation, truncation].

  Returns:
    A float ``torch.Tensor`` of shape (n_samples, z_dim).
  """
  lower, upper = -truncation, truncation
  samples = truncnorm.rvs(lower, upper, size=(n_samples, z_dim))
  return torch.Tensor(samples)

In [None]:
# Unit Test
# truncation_noise(n_samples, z_dim, ...) returns a (n_samples, z_dim) tensor.
assert tuple(truncation_noise(5, 10, 0.7).shape) == (5, 10)
simple_noise = truncation_noise(10, 1000, truncation=0.2)
# With 10,000 samples the extremes should sit right at the +/-0.2 truncation
# bounds; the bounds are inclusive because float32 rounding can land on them.
assert simple_noise.max() > 0.199 and simple_noise.max() <= 0.2
assert simple_noise.min() < -0.199 and simple_noise.min() >= -0.2
assert simple_noise.std() > 0.113 and simple_noise.std() < 0.117
print("Success!")

Success!


### Mapping Network

The mapping network takes the noise vector $z$ and maps it to an intermediate noise vector $w$. This lets $z$ be represented in a more disentangled space, which makes the features easier to control later.

(In the original StyleGAN paper the mapping network has 8 fully connected layers; the simplified version below uses 3.)

In [None]:
class MappingNetwork(nn.Module):
  """Map a noise vector ``z`` to an intermediate latent vector ``w``.

  A small MLP: Linear(z_dim -> hidden_dim), ReLU,
  Linear(hidden_dim -> hidden_dim), ReLU, Linear(hidden_dim -> w_dim).

  Args:
    z_dim: dimension of the input noise vector.
    w_dim: dimension of the output intermediate vector.
    hidden_dim: width of the hidden layers.
  """

  def __init__(self, z_dim, w_dim, hidden_dim):
    super().__init__()
    self.mapping = nn.Sequential(
        nn.Linear(z_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, w_dim),
    )

  # Function for completing the forward pass of the Mapping Network given 'z', returns 'w'
  def forward(self, z):
    """Return ``w = mapping(z)``; the last dimension of ``z`` must be ``z_dim``."""
    return self.mapping(z)

  # nn.Module.__call__ dispatches to ``forward``; the old name is kept as an
  # alias so explicit ``feedForward`` calls keep working.
  feedForward = forward


### Random Noise Injection

A single-channel random noise tensor with the same spatial size (height and width) as the feature map is sampled. It is multiplied by a learned weight for each channel and added to the feature map. This noise injection happens before every AdaIN block.

In [None]:
class NoiseInjection(nn.Module):
  """Add random noise scaled by a learned per-channel weight.

  Args:
    channels: number of channels of the feature maps this layer receives.
  """

  def __init__(self, channels):
    super().__init__()
    # Shape (1, channels, 1, 1) so it broadcasts over batch and spatial dims.
    self.weights = nn.Parameter(torch.randn(channels)[None, :, None, None])

  # Forwards pass: given an image adds random noise
  def forward(self, image):
    """Return ``image + weights * noise``.

    ``image`` is indexed as (batch, channels, height, width). A single noise
    channel of shape (batch, 1, height, width) is sampled and shared across
    channels; each channel scales it by its own learned weight.
    """
    # Match the input's device and dtype so the layer also works on GPU.
    noise = torch.randn(
        image.shape[0], 1, image.shape[2], image.shape[3],
        device=image.device, dtype=image.dtype,
    )
    return image + self.weights * noise

  # nn.Module.__call__ dispatches to ``forward``; keep the old name as an alias.
  feedForward = forward

### Adaptive Instance Normalization (AdaIN)

AdaIN injects $w$ (the intermediate noise vector) multiple times throughout the network, once in every generator block.
AdaIN takes the instance normalization of the image, multiplies it by the style scale ($y_s$), and adds the style bias ($y_b$). Both the scale and the bias are learnable: they are computed from $w$ with linear mappings.

$ \text{AdaIN}(\boldsymbol{\mathrm{x}}_i, \boldsymbol{\mathrm{y}}) = \boldsymbol{\mathrm{y}}_{s,i} \frac{\boldsymbol{\mathrm{x}}_i - \mu(\boldsymbol{\mathrm{x}}_i)}{\sigma(\boldsymbol{\mathrm{x}}_i)} + \boldsymbol{\mathrm{y}}_{b,i} $

In [None]:
class AdaIN(nn.Module):
  """Adaptive instance normalization driven by a style vector ``w``.

  Args:
    w_dim: dimension of the style vector ``w``.
    channels: number of channels of the feature maps being normalized.
  """

  def __init__(self, w_dim, channels):
    super().__init__()

    self.instanceNorm = nn.InstanceNorm2d(channels)
    # Learned linear maps from w to the per-channel style scale y_s and bias y_b.
    self.styleScaleTransform = nn.Linear(w_dim, channels)
    self.styleShiftTransform = nn.Linear(w_dim, channels)

  # Given an image and w, returns the normalized image that is scaled + shifted by the style
  def forward(self, image, w):
    """Return ``y_s * InstanceNorm(image) + y_b``.

    ``image`` is (batch, channels, height, width) and ``w`` is (batch, w_dim).
    y_s and y_b are reshaped to (batch, channels, 1, 1) so they broadcast over
    the spatial dimensions.
    """
    norm_image = self.instanceNorm(image)
    scale = self.styleScaleTransform(w)[:, :, None, None]
    shift = self.styleShiftTransform(w)[:, :, None, None]
    return norm_image * scale + shift

  # nn.Module.__call__ dispatches to ``forward``; keep the old name as an alias.
  feedForward = forward

### Progressive Growing

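Progressive growing helps StyleGAN create high-resolution images by gradually doubling the image size. During a transition, the output of the new, larger block is blended with an upsampled copy of the previous block's output, weighted by $\alpha$.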

In [None]:
class GeneratorBlock(nn.Module):
  """One generator block: [upsample] -> conv -> noise -> LeakyReLU -> AdaIN.

  Args:
    w_dim: dimension of the intermediate style vector ``w``.
    inputChannels: channels of the incoming feature map.
    outputChannels: channels produced by the convolution.
    kernel_size: convolution kernel size (padding is fixed at 1).
    initial_size: spatial size the input is upsampled to when ``upsample`` is True.
    upsample: whether to bilinearly upsample the input before the convolution.
  """

  def __init__(self, w_dim, inputChannels, outputChannels, kernel_size, initial_size,upsample=True):
    super().__init__()

    # Holds either the falsy flag itself or the nn.Upsample module.
    self.upsample = upsample
    if self.upsample:
      self.upsample = nn.Upsample((initial_size), mode='bilinear')

    self.c = nn.Conv2d(inputChannels, outputChannels, kernel_size, padding=1)
    self.z = NoiseInjection(outputChannels)
    # AdaIN's signature is (w_dim, channels); it normalizes the conv output.
    self.AdaIN = AdaIN(w_dim, outputChannels)
    self.activation = nn.LeakyReLU(0.2)

  # Given x and w, returns a StyleGAN generator block
  def forward(self, x, w):
    """Apply the block to feature map ``x`` using style vector ``w``."""
    if self.upsample:
      x = self.upsample(x)

    x = self.c(x)
    x = self.z(x)
    x = self.activation(x)
    return self.AdaIN(x, w)

  # nn.Module.__call__ dispatches to ``forward``; keep the old name as an alias.
  feedForward = forward



In [None]:
class Generator(nn.Module):
  """Small StyleGAN generator demonstrating progressive growing.

  The mapping network turns ``z`` into ``w``. A learned 4x4 constant passes
  through block b0 (no upsampling), then b1 (upsampled to 8x8) and b2
  (upsampled to 16x16). The 8x8 and 16x16 feature maps are projected to
  images with 1x1 convolutions. The output is
  ``alpha * big + (1 - alpha) * upsampled_small``.

  Args:
    z_dim: dimension of the noise vector ``z``.
    w_dim: dimension of the intermediate vector ``w``.
    hidden_dim: hidden width of the mapping network.
    inputChannels: channels of the learned starting constant.
    outputChannels: channels of the generated images.
    hiddenChannels: channels produced by every generator block.
    kernel_size: convolution kernel size used in the blocks.
  """

  def __init__(self, z_dim, w_dim, hidden_dim, inputChannels, outputChannels, hiddenChannels, kernel_size):
    super().__init__()

    self.map = MappingNetwork(z_dim, w_dim, hidden_dim)
    self.constant = nn.Parameter(torch.randn(1, inputChannels, 4, 4))

    # b0 must emit hiddenChannels because b1 consumes hiddenChannels.
    self.b0 = GeneratorBlock(w_dim, inputChannels, hiddenChannels, kernel_size, 4, upsample=False)
    self.b1 = GeneratorBlock(w_dim, hiddenChannels, hiddenChannels, kernel_size, 8)
    self.b2 = GeneratorBlock(w_dim, hiddenChannels, hiddenChannels, kernel_size, 16)

    # 1x1 convolutions projecting feature maps to image channels.
    self.b1_2_image = nn.Conv2d(hiddenChannels, outputChannels, kernel_size=1)
    self.b2_2_image = nn.Conv2d(hiddenChannels, outputChannels, kernel_size=1)

    # Blend weight given to the larger (16x16) output.
    self.alpha = 0.2

  # Upsampling small images to big images
  def Upsample(self, small, big):
    """Bilinearly resize ``small`` to the spatial size of ``big``."""
    return F.interpolate(small, size=big.shape[-2:], mode='bilinear')

  def forward(self, z, return_intermediate=False):
    """Generate images from a batch of noise vectors.

    Args:
      z: noise batch; its first dimension is taken as the batch size.
      return_intermediate: if True, also return the upsampled small image and
        the big image.

    Returns:
      The blended images, or ``(blended, upsampled_small, big)`` when
      ``return_intermediate`` is True.
    """
    w = self.map(z)
    # Give every sample its own view of the constant so the noise injected in
    # b0 is drawn per sample instead of being broadcast across the batch.
    x = self.constant.expand(z.shape[0], -1, -1, -1)
    x = self.b0(x, w)

    x_small = self.b1(x, w)
    small_image = self.b1_2_image(x_small)

    x_big = self.b2(x_small, w)
    big_image = self.b2_2_image(x_big)

    x_upsample = self.Upsample(small_image, big_image)
    interpolation = self.alpha * big_image + (1 - self.alpha) * x_upsample

    if return_intermediate:
      return interpolation, x_upsample, big_image
    return interpolation

  # nn.Module.__call__ dispatches to ``forward``; keep the old name as an alias.
  feedForward = forward


In [None]:
plt.rcParams['figure.figsize'] = [15, 15]

z_dim = 128
w_dim = 496
hidden_dim = 1024
inputChannels = 512
outputChannels = 3
kernel_size = 3
hiddenChannels = 256
truncation = 0.7

vZ = truncation_noise(10, z_dim, truncation) * 10

stylegan = Generator(z_dim, w_dim, hidden_dim, inputChannels, outputChannels, hiddenChannels, kernel_size)

stylegan.eval()
images = []
# Visualization only: skip building the autograd graph.
with torch.no_grad():
  # Sweep alpha from "only the upsampled small image" (0) to "only the big image" (1).
  for alpha in np.linspace(0, 1, num=5):
    stylegan.alpha = alpha
    result, _, _ = stylegan(vZ, return_intermediate=True)
    images.extend(result)  # iterating a tensor yields its per-sample slices

show_images(torch.stack(images), nrow=10, n=len(images))
stylegan = stylegan.train()

TypeError: ignored