In [0]:
!pip install tfa-nightly

Collecting tfa-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/f9/0e/c1ce899fac16a86ed7a153270673f27442cff4f543232dd0f4a187294a13/tfa_nightly-0.10.0.dev20200509080841-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 2.8MB/s 
Installing collected packages: tfa-nightly
Successfully installed tfa-nightly-0.10.0.dev20200509080841


In [0]:
import tensorflow_addons as tfa

In [0]:
%tensorflow_version 2.x

In [0]:
import tensorflow as tf
print(tf.__version__)

2.2.0


In [0]:
from tensorflow.keras import layers

In [0]:
import numpy as np
import os
import time

In [0]:
import matplotlib.pyplot as plt

# Utils

In [0]:
def detach_keypoint(keypoint):
  """Return a new keypoint dict whose tensors are all wrapped in tf.stop_gradient.

  Args:
    keypoint: dict mapping names (e.g. 'value', 'jacobian') to tensors.

  Returns:
    A new dict with the same keys; gradients do not flow back through its values.
  """
  detached = {}
  for name, tensor in keypoint.items():
    detached[name] = tf.stop_gradient(tensor)
  return detached

## Blocks

In [0]:
class SameBlock2d(tf.keras.Model):
  """Conv -> BatchNorm -> ReLU block that keeps height and width unchanged.

  Args:
    num_features: number of output channels.
    kernel_size: convolution kernel size; defaults to the original 3x3.
  """
  def __init__(self, num_features, kernel_size=(3, 3)):
    super(SameBlock2d, self).__init__()
    # Keras Conv2D only accepts padding "valid" or "same". 'same' with stride 1
    # keeps the spatial size, which is what this block is for.
    self.conv = layers.Conv2D(num_features, kernel_size, strides=1, padding='same', use_bias=False)
    self.batch_norm = layers.BatchNormalization()

  def call(self, input_layer):
    block = self.conv(input_layer)
    block = self.batch_norm(block)
    return tf.nn.relu(block)

In [0]:
class ResBlock2d(tf.keras.Model):
  """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) x 2, plus identity skip.

  The input must already have ``num_features`` channels so the residual add is valid.
  """
  def __init__(self, num_features):
    super(ResBlock2d, self).__init__()
    # 'same' padding keeps H x W so the output can be added to the input.
    # A list such as [1, 1] is not a valid Conv2D padding spec.
    self.conv1 = layers.Conv2D(num_features, (3, 3), strides=1, padding='same', use_bias=False)
    self.conv2 = layers.Conv2D(num_features, (3, 3), strides=1, padding='same', use_bias=False)
    self.batch_norm1 = layers.BatchNormalization()
    self.batch_norm2 = layers.BatchNormalization()

  def call(self, input_layer):
    block = self.batch_norm1(input_layer)
    block = tf.nn.relu(block)
    block = self.conv1(block)
    block = self.batch_norm2(block)
    block = tf.nn.relu(block)
    block = self.conv2(block)

    return block + input_layer

In [0]:
class DownBlock2d(tf.keras.Model):
  """Discriminator encoder block.

  4x4 'valid' conv, then optional instance normalization, ReLU, and optional
  2x2 average pooling.

  Args:
    num_features: output channels of the convolution.
    norm: apply InstanceNormalization after the conv when truthy.
    pool: apply 2x2 average pooling at the end when truthy.
  """
  def __init__(self, num_features, norm, pool):
    super(DownBlock2d, self).__init__()
    self.norm = norm
    self.pool = pool
    self.conv = layers.Conv2D(num_features, (4, 4), strides=1, padding="valid")
    # TODO(review): compare with PyTorch nn.InstanceNorm2d using identical conv weights in both frameworks.
    self.instance_norm = tfa.layers.InstanceNormalization(axis=3)

  def call(self, input_layer):
    features = self.conv(input_layer)
    features = self.instance_norm(features) if self.norm else features
    features = tf.nn.relu(features)
    if not self.pool:
      return features
    # Same as the default layers.AveragePooling2D(): 2x2 window, stride 2, 'valid'.
    return tf.nn.avg_pool2d(features, ksize=2, strides=2, padding='VALID')

## Interpolation

In [0]:
def interpolate_tensor(tensor_input, final_shape):
  """Resize a square NHWC tensor to final_shape x final_shape.

  Only ``tensor_input.shape[1]`` (height) is compared with ``final_shape``.
  Upsampling goes to interpolate_increase_size. Everything else, including an
  equal size, goes to interpolate_reduce_size.
  """
  if final_shape <= tensor_input.shape[1]:
    return interpolate_reduce_size(tensor_input, final_shape)
  return interpolate_increase_size(tensor_input, final_shape)

### Increase size

In [0]:
def interpolate_increase_size(tensor_input, final_shape):
  """Bilinearly upsample a square NHWC tensor to final_shape x final_shape.

  Args:
    tensor_input: (batch, size, size, channels) tensor. Only shape[1] is read,
      so height and width are assumed equal.
    final_shape: target height/width in pixels.

  Returns:
    (batch, final_shape, final_shape, channels) tensor. The corner pixels of
    the output sample the corner pixels of the input.
  """
  original_shape = tensor_input.shape[1]
  # Sample final_shape evenly spaced positions from the first to the last input
  # pixel. The earlier version padded by difference / 2 and started the grid at
  # that fractional border. When the size difference was odd, the integer
  # padding and the fractional start disagreed by half a pixel.
  x = tf.linspace(0.0, float(original_shape - 1), final_shape)

  yy = tf.tile(tf.reshape(x, (-1, 1)), [1, final_shape])
  xx = tf.tile(tf.reshape(x, (1, -1)), [final_shape, 1])

  # tfa resampler expects warp[..., 0] = x (width index), warp[..., 1] = y (height index).
  grid = tf.concat([tf.expand_dims(xx, axis=2), tf.expand_dims(yy, axis=2)], axis=2)
  grid = tf.cast(grid, tensor_input.dtype)

  # The resampler requires the warp batch to equal the data batch, so a single
  # broadcast grid of batch 1 fails for batch > 1.
  batch_size = tf.shape(tensor_input)[0]
  grid = tf.tile(tf.expand_dims(grid, axis=0), [batch_size, 1, 1, 1])

  return tfa.image.resampler(tensor_input, grid)

### Reduce size

In [0]:
def interpolate_reduce_size(tensor_input, final_shape):
  """Bilinearly resample a square NHWC tensor down to final_shape x final_shape.

  Args:
    tensor_input: (batch, size, size, channels) tensor. Only shape[1] is read,
      so height and width are assumed equal.
    final_shape: target height/width in pixels.

  Returns:
    (batch, final_shape, final_shape, channels) tensor. The corner pixels of
    the output sample the corner pixels of the input.
  """
  original_shape = tensor_input.shape[1]
  # final_shape evenly spaced sample positions from the first to the last pixel.
  x = tf.linspace(0.0, float(original_shape - 1), final_shape)

  yy = tf.tile(tf.reshape(x, (-1, 1)), [1, final_shape])
  xx = tf.tile(tf.reshape(x, (1, -1)), [final_shape, 1])

  # tfa resampler expects warp[..., 0] = x (width index), warp[..., 1] = y (height index).
  grid = tf.concat([tf.expand_dims(xx, axis=2), tf.expand_dims(yy, axis=2)], axis=2)
  grid = tf.cast(grid, tensor_input.dtype)

  # The resampler requires the warp batch to equal the data batch, so a single
  # broadcast grid of batch 1 fails for batch > 1.
  batch_size = tf.shape(tensor_input)[0]
  grid = tf.tile(tf.expand_dims(grid, axis=0), [batch_size, 1, 1, 1])

  return tfa.image.resampler(tensor_input, grid)

## Anti Aliasing

In [0]:
class AntiAliasInterpolation(tf.keras.Model):
  """Gaussian-blur an NHWC image, then resize it by ``scale`` (anti-aliased resize).

  Args:
    channels: number of input channels. The same blur kernel is applied to each.
    scale: resize factor. With ``scale == 1.0``, call() returns its input unchanged
      (the blur is skipped).
  """
  def __init__(self, channels, scale):
    super(AntiAliasInterpolation, self).__init__()
    sigma = (1 / scale - 1) / 2
    kernel_size = 2 * round(sigma * 4) + 1
    self.scale = scale

    kernel_size = [kernel_size, kernel_size]
    sigma = [sigma, sigma]

    # Build the 2-D Gaussian as a product of 1-D Gaussians along height and width.
    kernel = 1
    meshgrids = tf.meshgrid(*[tf.keras.backend.arange(size, dtype='float32') for size in kernel_size], indexing='ij')

    for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
      mean = (size - 1) / 2
      kernel *= tf.math.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

    # Normalize to sum 1 so blurring preserves overall intensity.
    kernel = kernel / tf.keras.backend.sum(kernel)

    # DepthwiseConv2D kernel layout: [kernel_height, kernel_width, channels, depth_multiplier]
    kernel = tf.reshape(kernel, [*kernel.shape, 1, 1])
    kernel = tf.tile(kernel, [1, 1, channels, 1])

    # Depthwise so the kernel is applied to each channel independently.
    # trainable=False: this is a fixed Gaussian filter. As a trainable layer
    # it would be updated by the optimizers along with the real weights.
    self.conv = layers.DepthwiseConv2D(kernel_size=kernel_size, strides=1, use_bias=False, padding="same",
                                       weights=[kernel], trainable=False)

  def call(self, input):
    if self.scale == 1.0:
      return input

    out = self.conv(input)
    new_size = int(self.scale * input.shape[1])
    out = interpolate_tensor(out, new_size)

    return out

## Image Pyramid

In [0]:
class ImagePyramide(tf.keras.Model):
  """Multi-scale pyramid of anti-aliased, resized copies of a 3-channel image.

  Args:
    scales: iterable of scale factors. One AntiAliasInterpolation is built per scale.

  call(x) returns {'prediction_<scale>': resized x} for every scale, e.g.
  'prediction_1' and 'prediction_0.5'.
  """
  def __init__(self, scales):
    super(ImagePyramide, self).__init__()
    self.num_channels = 3
    # '.' in the scale is stored as '-' in the key. call() converts it back.
    self.downs = {
        str(scale).replace('.', '-'): AntiAliasInterpolation(self.num_channels, scale)
        for scale in scales
    }

  def call(self, x):
    return {
        'prediction_' + key.replace('-', '.'): module(x)
        for key, module in self.downs.items()
    }

## Hourglass

In [0]:
class DownBlock(tf.keras.Model):
  """Hourglass encoder block: 3x3 conv -> BatchNorm -> ReLU -> 2x2 average pooling.

  Halves height and width (floor) and outputs ``num_features`` channels.
  """
  def __init__(self, num_features):
    super(DownBlock, self).__init__()
    # Per-axis padding equal to what 'same' applies for a 3x3 stride-1 conv.
    # Kept only as a reference attribute.
    self.padding = [[0, 0], [1, 1], [1, 1], [0, 0]]
    # Conv2D documents padding as "valid" or "same". 'same' is the supported
    # way to express the 1-pixel border above.
    self.conv = layers.Conv2D(num_features, (3, 3), strides=1, padding='same', use_bias=False)
    self.batch_norm = layers.BatchNormalization()

  def call(self, input_layer):
    block = self.conv(input_layer)
    block = self.batch_norm(block)
    block = tf.nn.relu(block)
    # Same as layers.AveragePooling2D() with defaults, without creating a new layer per call.
    return tf.nn.avg_pool2d(block, ksize=2, strides=2, padding='VALID')

In [0]:
class UpBlock(tf.keras.Model):
  """Hourglass decoder block: 2x bilinear upsampling -> 3x3 conv -> BatchNorm -> ReLU.

  Doubles height and width and outputs ``num_features`` channels.
  """
  def __init__(self, num_features):
    super(UpBlock, self).__init__()
    # Per-axis padding equal to what 'same' applies for a 3x3 stride-1 conv.
    # Kept only as a reference attribute.
    self.padding = [[0, 0], [1, 1], [1, 1], [0, 0]]
    # Built once here rather than on every call.
    self.upsample = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
    self.conv = layers.Conv2D(num_features, (3, 3), strides=1, padding='same', use_bias=False)
    self.batch_norm = layers.BatchNormalization()

  def call(self, input_layer):
    block = self.upsample(input_layer)
    block = self.conv(block)
    block = self.batch_norm(block)
    return tf.nn.relu(block)

In [0]:
class Hourglass(tf.keras.Model):
  """U-Net style encoder/decoder with skip connections.

  Args:
    down_features_list: output channels of each encoder DownBlock.
    up_features_list: output channels of each decoder UpBlock.
    num_blocks: number of blocks taken from each list.

  call(x) runs the encoder blocks, then the decoder blocks. Each decoder output
  is concatenated (channel axis) with the matching encoder activation, and the
  last one with x itself. Output channels are
  up_features_list[num_blocks - 1] + channels of x.
  """
  def __init__(self, down_features_list, up_features_list, num_blocks=5):
    super(Hourglass, self).__init__()
    self.encoder_blocks = [DownBlock(down_features_list[i]) for i in range(num_blocks)]
    self.decoder_blocks = [UpBlock(up_features_list[i]) for i in range(num_blocks)]

  def call(self, x):
    # Stack of activations: x, then each encoder output (deepest last).
    skips = [x]
    for encoder in self.encoder_blocks:
      skips.append(encoder(skips[-1]))

    out = skips.pop()
    for decoder in self.decoder_blocks:
      out = tf.concat([decoder(out), skips.pop()], axis=-1)

    return out

## Coordinate grid

In [0]:
def make_coordinate_grid(spatial_size, type):
  """Return a (height, width, 2) grid of (x, y) coordinates normalized to [-1, 1].

  Args:
    spatial_size: (height, width) pair.
    type: dtype of the returned grid.

  Returns:
    Tensor where [i, j] = (x_j, y_i). Column 0 and row 0 map to -1 and the last
    column/row maps to +1.
  """
  height, width = spatial_size
  xs = 2 * (tf.keras.backend.arange(width, dtype=type) / (width - 1)) - 1
  ys = 2 * (tf.keras.backend.arange(height, dtype=type) / (height - 1)) - 1

  # x varies along columns, y along rows.
  xx = tf.tile(tf.reshape(xs, (1, -1)), [height, 1])
  yy = tf.tile(tf.reshape(ys, (-1, 1)), [1, width])

  return tf.stack([xx, yy], axis=-1)

## Keypoints to Gaussian heatmaps

In [0]:
def keypoints_to_gaussian(keypoints, spatial_size, kp_variance):
  """Render every keypoint as an unnormalized isotropic Gaussian heatmap.

  Args:
    keypoints: dict whose 'value' entry is a (batch, num_kp, 2) tensor of (x, y)
      coordinates in the [-1, 1] convention of make_coordinate_grid.
    spatial_size: (height, width) of the output heatmaps.
    kp_variance: Gaussian variance, in normalized coordinate units.

  Returns:
    (batch, num_kp, height, width) tensor exp(-0.5 * ||z - kp||^2 / kp_variance).
  """
  # TD<-R or TS<-R in equation (6)
  mean = keypoints["value"]
  # batch x K x 2

  # Z in equation (6)
  coordinate_grid = make_coordinate_grid(spatial_size, mean.dtype)
  # height x width x 2

  # Broadcast rather than tile. This avoids materializing a batch x K x H x W x 2
  # copy of the grid, and also works when the batch dimension is unknown.
  coordinate_grid = coordinate_grid[tf.newaxis, tf.newaxis]
  # 1 x 1 x height x width x 2
  mean = mean[:, :, tf.newaxis, tf.newaxis, :]
  # batch x K x 1 x 1 x 2

  mean_sub = mean - coordinate_grid
  # batch x K x height x width x 2

  return tf.exp(-0.5 * tf.reduce_sum(mean_sub ** 2, axis=-1) / kp_variance)

## VGG19

In [0]:
class Vgg19(tf.keras.Model):
  """Frozen ImageNet VGG19 feature extractor for the perceptual loss.

  call(x) applies vgg19.preprocess_input and returns a list with the activations
  of the five layers in ``layer_names``.
  """
  def __init__(self):
    # Keras requires the base initializer before any attribute assignment.
    super(Vgg19, self).__init__()
    # Named layer_names so the module-level `layers` (tf.keras.layers) is not shadowed.
    layer_names = ['block1_conv2', 'block2_conv2', 'block3_conv2', 'block4_conv2', 'block5_conv2']
    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False

    outputs = [vgg.get_layer(name).output for name in layer_names]

    self.model = tf.keras.Model([vgg.input], outputs)

  def call(self, x):
    # NOTE(review): vgg19.preprocess_input ('caffe' mode) expects RGB values in
    # [0, 255]. Generator predictions come from a sigmoid ([0, 1]), so confirm
    # the input range and rescale if needed.
    x = tf.keras.applications.vgg19.preprocess_input(x)
    return self.model(x)

# Keypoint Detector

In [0]:
# height, width, channels = 3
class KeypointDetector(tf.keras.Model):
  """Predicts keypoints and per-keypoint 2x2 Jacobians from an NHWC image batch.

  The image is anti-alias downsampled by ``scale_factor`` and passed through an
  Hourglass. Two 7x7 'valid' convolutions then produce ``num_keypoints`` heatmaps
  and ``4 * num_jacobian_maps`` Jacobian channels.

  call(x) returns {'value': batch x K x 2 (coords in [-1, 1]),
                   'jacobian': batch x K x 2 x 2}.
  """
  def __init__(self):
    super(KeypointDetector, self).__init__()
    self.scale_factor = 0.25
    self.num_jacobian_maps = 10
    self.num_keypoints = 10
    self.num_channels = 3
    self.down_features_list = [64, 128, 256, 512, 1024]
    self.up_features_list = [512, 256, 128, 64, 32]
    self.num_blocks = 5
    self.temperature = 0.1

    self.predictor = Hourglass(self.down_features_list, self.up_features_list, self.num_blocks) # Outputs height x width x 35
    self.keypoints_map = layers.Conv2D(self.num_keypoints, (7, 7), strides=1, padding='valid')

    # Zero kernel + [1, 0, 0, 1] bias per map, so every predicted Jacobian
    # starts out as the identity matrix.
    weight_initializer = tf.keras.initializers.zeros()
    bias_initializer = tf.keras.initializers.constant([1, 0, 0, 1] * self.num_jacobian_maps)
    # Filter count uses num_jacobian_maps to match the reshape in call().
    self.jacobian = layers.Conv2D(self.num_jacobian_maps * 4, (7, 7), strides=1, padding='valid',
                                  bias_initializer=bias_initializer, kernel_initializer=weight_initializer)

    self.down = AntiAliasInterpolation(self.num_channels, self.scale_factor)

  def get_gaussian_keypoints(self, heatmap):
    """Differentiable soft-argmax: expected (x, y) grid coordinate of each heatmap.

    Args:
      heatmap: batch x H x W x K tensor, softmax-normalized over H and W.

    Returns:
      {'value': batch x K x 2}. Each coordinate lies in [-1, 1] because it is a
      convex combination of make_coordinate_grid values.
    """
    heatmap = tf.expand_dims(heatmap, -1)
    # batch x H x W x K x 1
    grid = make_coordinate_grid(heatmap.shape[1:3], heatmap.dtype)
    # H x W x 2
    grid = tf.expand_dims(grid, axis=2)
    # H x W x 1 x 2
    grid = tf.expand_dims(grid, axis=0)
    # 1 x H x W x 1 x 2

    value = heatmap * grid
    # batch x H x W x K x 2
    value = tf.keras.backend.sum(value, axis=[1, 2])
    # batch x K x 2

    kp = {'value': value}

    return kp

  def call(self, x):
    model = self.down(x)
    feature_map = self.predictor(model)
    raw_keypoints = self.keypoints_map(feature_map)

    final_shape = raw_keypoints.shape
    # batch x h x w x K (channels-last, unlike the PyTorch batch x K x h x w)

    # Dividing by a small temperature sharpens the softmax, so the soft-argmax
    # concentrates around each heatmap's peak.
    heatmap = tf.keras.activations.softmax(raw_keypoints / self.temperature, axis=[1, 2])
    final_keypoints = self.get_gaussian_keypoints(heatmap)

    jacobian_map = self.jacobian(feature_map)
    # batch x h x w x (num_jacobian_maps * 4)

    jacobian_map = tf.reshape(jacobian_map, [final_shape[0], final_shape[1], final_shape[2], self.num_jacobian_maps, 4])
    # batch x h x w x num_jacobian_maps x 4

    heatmap = tf.expand_dims(heatmap, axis=-1)
    # batch x h x w x K x 1

    # Weight each location's Jacobian by the keypoint heatmap, then sum over locations.
    jacobian = heatmap * jacobian_map
    # batch x h x w x num_jacobian_maps x 4

    jacobian = tf.reshape(jacobian, [final_shape[0], -1, self.num_jacobian_maps, 4])
    # batch x (h * w) x num_jacobian_maps x 4

    jacobian = tf.keras.backend.sum(jacobian, axis=1)
    # batch x num_jacobian_maps x 4

    jacobian = tf.reshape(jacobian, [final_shape[0], self.num_jacobian_maps, 2, 2])
    # batch x num_jacobian_maps x 2 x 2

    final_keypoints['jacobian'] = jacobian

    return final_keypoints

# Dense motion network

In [0]:
class DenseMotionNetwork(tf.keras.Model):
  """Predicts a dense motion field and an occlusion map from two keypoint sets.

  call(source_image, kp_driving, kp_source) first anti-alias downsamples the
  source image by ``scale_factor``. H x W below refer to that downsampled size.
  Keypoint dicts carry 'value' (batch x K x 2, normalized [-1, 1] coordinates)
  and 'jacobian' (batch x K x 2 x 2).

  Returns a dict with 'warped_images', 'mask', 'dense_optical_flow'
  (batch x H x W x 2, normalized coordinates) and 'occlusion_map' (batch x H x W x 1).
  """
  def __init__(self):
    super(DenseMotionNetwork, self).__init__()
    # Hourglass input: H x W x 44 = (num_keypoints + 1) * (1 heatmap + 3 colour channels)
    self.num_blocks = 5
    self.num_channels = 3
    self.num_keypoints = 10
    self.scale_factor = 0.25
    self.kp_variance = 0.01
    self.down_features_list = [128, 256, 512, 1024, 1024]
    self.up_features_list = [1024, 512, 256, 128, 64]
    self.padding = [[0, 0], [3, 3], [3, 3], [0, 0]] # pad only height, width

    self.hourglass = Hourglass(self.down_features_list, self.up_features_list, self.num_blocks) # Outputs H x W x (64 + 44)
    # Both heads are 7x7 'valid' convs applied after a 3-pixel pad, so they keep H x W.
    self.mask = layers.Conv2D(self.num_keypoints + 1, (7, 7), strides=1, padding="valid")
    self.occlusion = layers.Conv2D(1, (7, 7), strides=1, padding="valid")
    self.down = AntiAliasInterpolation(self.num_channels, self.scale_factor)

  @staticmethod
  def _grid_to_pixels(grid, height, width):
    """Map (x, y) coordinates in [-1, 1] to pixel indices (inverse of make_coordinate_grid)."""
    pixel_x = (grid[..., 0] + 1.0) * (width - 1) / 2.0
    pixel_y = (grid[..., 1] + 1.0) * (height - 1) / 2.0
    return tf.stack([pixel_x, pixel_y], axis=-1)

  def create_heatmap_representations(self, image_size, kp_driving, kp_source):
    """Driving-minus-source keypoint Gaussians, preceded by a zero background channel.

    Args:
      image_size: NHWC shape of the downsampled source image.

    Returns:
      batch x (K + 1) x H x W x 1 tensor.
    """
    # (height, width) only. image_size is NHWC.
    spatial_size = image_size[1:3]

    gaussian_driving = keypoints_to_gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
    gaussian_source = keypoints_to_gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
    # batch x K x H x W

    heatmap = gaussian_driving - gaussian_source

    zeros = tf.zeros((heatmap.shape[0], 1, spatial_size[0], spatial_size[1]), dtype=heatmap.dtype)
    heatmap = tf.concat([zeros, heatmap], axis=1)
    # batch x (K + 1) x H x W

    return tf.expand_dims(heatmap, axis=-1)

  def create_sparse_motions(self, image_size, kp_driving, kp_source):
    """First-order (Taylor) motion around each keypoint, plus an identity background motion.

    Returns:
      batch x (K + 1) x H x W x 2 tensor of normalized source coordinates.
    """
    batch_size, height, width, _ = image_size
    # Z in equation (4)
    identity_grid = make_coordinate_grid((height, width), type=kp_source['value'].dtype)
    identity_grid = tf.expand_dims(tf.expand_dims(identity_grid, axis=0), axis=0)
    # 1 x 1 x H x W x 2

    # TD<-R in equation (4)
    driving_keypoints = kp_driving['value']
    shape = driving_keypoints.shape[:2] + (1, 1, 2)
    driving_keypoints = tf.reshape(driving_keypoints, shape)
    # batch x K x 1 x 1 x 2

    # Z - TD<-R in equation (4)
    coordinate_grid = identity_grid - driving_keypoints
    # batch x K x H x W x 2

    # Jk in equation (5): source Jacobian times the inverse of the driving Jacobian
    jacobian = tf.linalg.matmul(kp_source['jacobian'], tf.linalg.inv(kp_driving['jacobian']))
    jacobian = tf.expand_dims(tf.expand_dims(jacobian, axis=-3), axis=-3)
    # batch x K x 1 x 1 x 2 x 2
    jacobian = tf.tile(jacobian, [1, 1, height, width, 1, 1])
    # batch x K x H x W x 2 x 2

    # Jk . (Z - TD<-R) in equation (4)
    coordinate_grid = tf.linalg.matmul(jacobian, tf.expand_dims(coordinate_grid, axis=-1))
    # batch x K x H x W x 2 x 1
    # Squeeze only the trailing axis. A bare squeeze would also drop batch == 1.
    coordinate_grid = tf.squeeze(coordinate_grid, axis=-1)
    # batch x K x H x W x 2

    source_keypoints = kp_source['value']
    shape = source_keypoints.shape[:2] + (1, 1, 2)
    source_keypoints = tf.reshape(source_keypoints, shape)
    # batch x K x 1 x 1 x 2

    # TS<-D(z) = TS<-R + Jk . (Z - TD<-R)
    driving_to_source = source_keypoints + coordinate_grid
    # batch x K x H x W x 2

    # Background motion is the identity grid (no movement).
    identity_grid = tf.tile(identity_grid, [batch_size, 1, 1, 1, 1])
    # batch x 1 x H x W x 2

    sparse_motions = tf.concat([identity_grid, driving_to_source], axis=1)
    # batch x (K + 1) x H x W x 2

    return sparse_motions

  def create_deformed_source_image(self, source_image, sparse_motions):
    """Warp the source image once per sparse motion field.

    Args:
      source_image: batch x H x W x C tensor.
      sparse_motions: batch x (K + 1) x H x W x 2 normalized (x, y) coordinates.

    Returns:
      batch x (K + 1) x H x W x C tensor.
    """
    batch_size, height, width, _ = source_image.shape
    num_motions = self.num_keypoints + 1

    source_repeat = tf.expand_dims(source_image, axis=1)
    # batch x 1 x H x W x C
    source_repeat = tf.tile(source_repeat, [1, num_motions, 1, 1, 1])
    # batch x (K + 1) x H x W x C
    source_repeat = tf.reshape(source_repeat, [batch_size * num_motions, height, width, -1])
    # (batch * (K + 1)) x H x W x C

    sparse_motions = tf.reshape(sparse_motions, [batch_size * num_motions, height, width, -1])
    # (batch * (K + 1)) x H x W x 2

    # Convert each coordinate independently. A data-dependent global min/max
    # rescale would distort the identity motion and couple unrelated fields.
    sparse_motions = self._grid_to_pixels(sparse_motions, height, width)

    sparse_deformed = tfa.image.resampler(source_repeat, sparse_motions)
    # (batch * (K + 1)) x H x W x C

    return tf.reshape(sparse_deformed, [batch_size, num_motions, height, width, -1])

  def call(self, source_image, kp_driving, kp_source):
    source_image = self.down(source_image)
    image_size = source_image.shape
    batch_size, height, width, _ = image_size
    out_dict = dict()

    heatmap_representation = self.create_heatmap_representations(image_size, kp_driving, kp_source)
    # batch x (K + 1) x H x W x 1
    sparse_motion = self.create_sparse_motions(image_size, kp_driving, kp_source)
    # batch x (K + 1) x H x W x 2
    warped_images = self.create_deformed_source_image(source_image, sparse_motion)
    # batch x (K + 1) x H x W x 3

    # Debug/print
    out_dict['warped_images'] = warped_images # sparse_deformed

    hourglass_input = tf.concat([heatmap_representation, warped_images], axis=-1)
    # batch x (K + 1) x H x W x 4
    hourglass_input = tf.transpose(hourglass_input, [0, 2, 3, 1, 4])
    # batch x H x W x (K + 1) x 4
    hourglass_input = tf.reshape(hourglass_input, [batch_size, height, width, -1])
    # batch x H x W x 44

    prediction = self.hourglass(hourglass_input)
    prediction = tf.pad(prediction, self.padding)

    # Softmax over the K + 1 channels: per-pixel weights over the candidate motions.
    mask = tf.keras.activations.softmax(self.mask(prediction))
    # batch x H x W x (K + 1)

    # Debug/print
    out_dict['mask'] = mask

    mask = tf.expand_dims(mask, axis=-1)
    # batch x H x W x (K + 1) x 1

    sparse_motion = tf.transpose(sparse_motion, [0, 2, 3, 1, 4])
    # batch x H x W x (K + 1) x 2

    # Mask-weighted combination of the sparse motions.
    deformation = tf.keras.backend.sum(sparse_motion * mask, axis=3)
    # batch x H x W x 2

    out_dict['dense_optical_flow'] = deformation

    occlusion_map = tf.keras.activations.sigmoid(self.occlusion(prediction))
    # batch x H x W x 1

    out_dict['occlusion_map'] = occlusion_map

    return out_dict

# Generator

In [0]:
class Generator(tf.keras.Model):
  """Keypoint-driven generator: encode the source, warp + occlude the features, decode.

  call(source_image, kp_driving, kp_source) returns a dict with 'prediction'
  (sigmoid output, batch x H x W x num_channels) plus the debug tensors
  'mask', 'warped_images', 'occlusion_map' and 'aligned_features'.
  """
  def __init__(self):
    super(Generator, self).__init__()
    block_expansion = 64
    self.padding = [[0, 0], [3, 3], [3, 3], [0, 0]] # pad only height, width
    self.num_channels = 3
    self.num_keypoints = 10
    self.num_blocks = 2
    self.num_bottleneck_blocks = 6

    self.dense_motion_network = DenseMotionNetwork()
    self.first = SameBlock2d(block_expansion)

    self.down_features_list = [128, 256]
    self.up_features_list = [128, 64]

    self.encoder_blocks = [DownBlock(self.down_features_list[i]) for i in range(self.num_blocks)]
    self.decoder_blocks = [UpBlock(self.up_features_list[i]) for i in range(self.num_blocks)]

    self.bottleneck = tf.keras.Sequential()
    for _ in range(self.num_bottleneck_blocks):
      self.bottleneck.add(ResBlock2d(self.down_features_list[-1]))

    # 7x7 'valid' conv applied after the 3-pixel self.padding keeps H x W.
    self.final = layers.Conv2D(self.num_channels, (7, 7), strides=1, padding="valid")

  def deform_input(self, x, deformation):
    """Warp NHWC tensor x by a flow field of normalized (x, y) coordinates in [-1, 1].

    The flow is first resized to x's spatial size when the sizes differ.
    """
    _, height_old, width_old, _ = deformation.shape
    _, height, width, _ = x.shape

    if height_old != height or width_old != width:
      deformation = interpolate_tensor(deformation, width)

    # Invert make_coordinate_grid: -1 maps to pixel 0, +1 to the last pixel.
    pixel_x = (deformation[..., 0] + 1.0) * (width - 1) / 2.0
    pixel_y = (deformation[..., 1] + 1.0) * (height - 1) / 2.0
    deformation = tf.stack([pixel_x, pixel_y], axis=-1)

    return tfa.image.resampler(x, deformation)

  def call(self, source_image, kp_driving, kp_source):
    # SameBlock2d is a size-preserving 3x3 conv, so no pre-padding here.
    # A 3-pixel pad would grow the map by 6 and misalign every later shape.
    out = self.first(source_image)

    for down_block in self.encoder_blocks:
      out = down_block(out)

    output_dict = {}

    dense_motion = self.dense_motion_network(source_image, kp_driving, kp_source)

    # Debug/print
    output_dict['mask'] = dense_motion['mask']
    # Debug/print
    output_dict['warped_images'] = dense_motion['warped_images'] # sparse_deformed

    occlusion_map = dense_motion['occlusion_map']
    # batch x h x w x 1

    # Debug/print
    output_dict['occlusion_map'] = occlusion_map

    dense_optical_flow = dense_motion['dense_optical_flow'] # deformation
    # batch x h x w x 2
    out = self.deform_input(out, dense_optical_flow)
    # batch x h x w x features

    if out.shape[1] != occlusion_map.shape[1] or out.shape[2] != occlusion_map.shape[2]:
      occlusion_map = interpolate_tensor(occlusion_map, out.shape[1])

    out = out * occlusion_map

    # Debug/print
    output_dict["aligned_features"] = self.deform_input(source_image, dense_optical_flow) # deformed

    # Decoder part
    out = self.bottleneck(out)

    for up_block in self.decoder_blocks:
      out = up_block(out)

    out = self.final(tf.pad(out, self.padding))
    out = tf.keras.activations.sigmoid(out)

    output_dict["prediction"] = out

    return output_dict

# Discriminator

In [0]:
class Discriminator(tf.keras.Model):
  """Keypoint-conditioned patch discriminator.

  The input image is concatenated with one Gaussian heatmap per keypoint. The
  result goes through ``num_blocks`` DownBlock2d layers and a 1x1 conv that
  produces a per-patch score map.
  """
  def __init__(self):
    super(Discriminator, self).__init__()
    self.num_channels = 3
    self.num_keypoints = 10
    self.num_blocks = 4
    self.num_features_list = [64, 128, 256, 512]
    self.kp_variance = 0.01

    encoder_blocks = []

    for i in range(self.num_blocks):
      num_features = self.num_features_list[i]
      # No normalization on the first block; no pooling after the last one.
      norm = i != 0
      pool = i != self.num_blocks - 1
      encoder_blocks.append(DownBlock2d(num_features, norm, pool))

    self.encoder_blocks = encoder_blocks
    self.conv = layers.Conv2D(1, kernel_size=1)

  def call(self, x, key_points=None, kp=None):
    """Score an NHWC image batch conditioned on keypoints.

    Args:
      x: batch x H x W x 3 images.
      key_points: keypoint dict with 'value' (batch x K x 2). ``kp`` is accepted
        as an alias because callers in this notebook pass ``kp=``.

    Returns:
      (feature_maps, prediction_map): the list of every DownBlock2d output and
      the batch x h x w x 1 patch score map.

    Raises:
      ValueError: if neither ``key_points`` nor ``kp`` is given.
    """
    if key_points is None:
      key_points = kp
    if key_points is None:
      raise ValueError('Discriminator requires keypoints (pass `key_points` or `kp`).')

    heatmap = keypoints_to_gaussian(key_points, x.shape[1:3], self.kp_variance)
    # batch x K x H x W -> batch x H x W x K
    out = tf.concat([x, tf.transpose(heatmap, [0, 2, 3, 1])], axis=-1)
    # batch x H x W x (3 + K)

    feature_maps = []
    for block in self.encoder_blocks:
      out = block(out)
      feature_maps.append(out)

    prediction_map = self.conv(out)
    # batch x h x w x 1

    return feature_maps, prediction_map

# Transform

In [0]:
class Transform:
  """Random affine + thin-plate-spline (TPS) warp for the equivariance losses.

  The random parameters are sampled once in __init__. transform_frame,
  warp_coordinates and jacobian therefore all apply the same transformation.
  """
  def __init__(self, batch_size):
    self.sigma_affine = 0.05
    self.sigma_tps = 0.005
    self.points_tps = 5
    noise = tf.random.normal((batch_size, 2, 3), mean=0, stddev=self.sigma_affine)
    # tf.eye(2, 3) is the identity affine matrix [[1, 0, 0], [0, 1, 0]].
    self.theta = noise + tf.expand_dims(tf.eye(2, 3), axis=0)
    # batch x 2 x 3
    self.batch_size = batch_size

    self.control_points = make_coordinate_grid((self.points_tps, self.points_tps), type=noise.dtype)
    self.control_points = tf.expand_dims(self.control_points, axis=0)
    # 1 x 5 x 5 x 2
    self.control_params = tf.random.normal((batch_size, 1, self.points_tps ** 2), mean=0, stddev=self.sigma_tps)
    # batch x 1 x 25

  def transform_frame(self, frame):
    """Warp an NHWC frame batch (batch size == self.batch_size) with this transform."""
    height, width = frame.shape[1], frame.shape[2]
    grid = make_coordinate_grid((height, width), type=frame.dtype)
    grid = tf.reshape(grid, [1, height * width, 2])
    # 1 x (H * W) x 2
    grid = self.warp_coordinates(grid)
    # batch x (H * W) x 2
    grid = tf.reshape(grid, [self.batch_size, height, width, 2])

    # Normalized [-1, 1] -> pixel indices (inverse of make_coordinate_grid).
    # A data-dependent min/max rescale would undo part of the sampled transform.
    pixel_x = (grid[..., 0] + 1.0) * (width - 1) / 2.0
    pixel_y = (grid[..., 1] + 1.0) * (height - 1) / 2.0
    grid = tf.stack([pixel_x, pixel_y], axis=-1)

    # NOTE(review): tfa resampler fills out-of-range samples with zeros, whereas
    # the PyTorch reference used F.grid_sample(..., padding_mode="reflection").
    return tfa.image.resampler(frame, grid)

  def warp_coordinates(self, coordinates):
    """Apply the affine + TPS transform to (x, y) coordinates.

    Args:
      coordinates: 1 x N x 2 (a shared grid) or batch x N x 2 (e.g. keypoints).

    Returns:
      batch x N x 2 transformed coordinates.
    """
    theta = tf.cast(self.theta, coordinates.dtype)
    theta = tf.expand_dims(theta, axis=1)
    # batch x 1 x 2 x 3

    # Affine part: A . p + t
    transformed = tf.linalg.matmul(theta[:, :, :, :2], tf.expand_dims(coordinates, axis=-1)) + theta[:, :, :, 2:]
    # batch x N x 2 x 1
    transformed = tf.squeeze(transformed, axis=-1)
    # batch x N x 2

    control_points = tf.cast(self.control_points, coordinates.dtype)
    # 1 x 5 x 5 x 2
    control_params = tf.cast(self.control_params, coordinates.dtype)
    # batch x 1 x 25

    # L1 distance from each coordinate to each control point
    distances = tf.reshape(coordinates, [coordinates.shape[0], -1, 1, 2]) - tf.reshape(control_points, [1, 1, -1, 2])
    distances = tf.keras.backend.sum(tf.abs(distances), axis=-1)
    # (1 or batch) x N x 25

    # TPS radial basis r^2 log r (1e-6 avoids log(0)), weighted by the random control parameters.
    result = distances ** 2
    result = result * tf.math.log(distances + 1e-6)
    result = result * control_params
    # batch x N x 25

    result = tf.keras.backend.sum(result, axis=2)
    result = tf.reshape(result, [self.batch_size, coordinates.shape[1], 1])
    # batch x N x 1; the same offset is added to x and y.
    transformed = transformed + result

    return transformed

  def jacobian(self, coordinates):
    """Jacobian of warp_coordinates at ``coordinates``.

    Returns:
      coordinates.shape[:-1] + (2, 2). Row 0 is the gradient of the new x with
      respect to (x, y); row 1 is the gradient of the new y.
    """
    # Persistent because two gradients are taken. Gradients computed here are
    # still recorded by any enclosing tape (e.g. the training step).
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(coordinates)
      new_coordinates = self.warp_coordinates(coordinates)
      x = tf.keras.backend.sum(new_coordinates[..., 0])
      y = tf.keras.backend.sum(new_coordinates[..., 1])

    grad_x = tape.gradient(x, coordinates)
    grad_y = tape.gradient(y, coordinates)
    del tape

    return tf.concat([tf.expand_dims(grad_x, axis=-2), tf.expand_dims(grad_y, axis=-2)], axis=-2)

# Full Generator

In [0]:
class FullGenerator(tf.keras.Model):
  """Runs the keypoint detector and generator and computes all generator-side losses.

  call(source_images, driving_images) returns (loss_values, generated).
  loss_values maps loss names to weighted scalar losses. generated holds the
  generator outputs plus keypoints and the equivariance debug tensors.
  """
  def __init__(self, key_point_detector, generator, discriminator):
    super(FullGenerator, self).__init__()
    # Weight of the adversarial LSGAN term. 1 keeps the raw loss.
    self.generator_gan_weights = 1
    self.feature_matching_weights = 10
    self.equivariance_weights = 10
    self.perceptual_weights = [10, 10, 10, 10, 10]
    self.scales = [1, 0.5, 0.25, 0.125]
    self.vgg = Vgg19()
    self.pyramid = ImagePyramide(self.scales)
    self.key_point_detector = key_point_detector
    self.generator = generator
    self.discriminator = discriminator

  def call(self, source_images, driving_images):
    kp_source = self.key_point_detector(source_images)
    kp_driving = self.key_point_detector(driving_images)

    generated = self.generator(source_images, kp_source=kp_source, kp_driving=kp_driving)
    # Debug/print
    generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

    loss_values = {}

    pyramide_real = self.pyramid(driving_images)
    pyramide_generated = self.pyramid(generated['prediction'])

    # Perceptual loss: weighted L1 between VGG features at every pyramid scale.
    perceptual_loss = 0
    for scale in self.scales:
      x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
      y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

      for weight, x_feature, y_feature in zip(self.perceptual_weights, x_vgg, y_vgg):
        perceptual_loss += weight * tf.reduce_mean(tf.abs(x_feature - tf.stop_gradient(y_feature)))
    loss_values['perceptual'] = perceptual_loss

    # GAN losses (full resolution only). Keypoints are detached so these terms
    # do not train the keypoint detector. They are passed positionally because
    # Discriminator.call names the argument `key_points`.
    kp_driving_detached = detach_keypoint(kp_driving)
    discriminator_maps_real, _ = self.discriminator(driving_images, kp_driving_detached)
    discriminator_maps_generated, discriminator_pred_map_generated = self.discriminator(
        generated['prediction'], kp_driving_detached)

    # LSGAN generator loss. The discriminator outputs a patch map (1 = real,
    # 0 = generated), so push the predictions for generated images towards 1.
    gan_loss = tf.reduce_mean((discriminator_pred_map_generated - 1) ** 2)
    loss_values['gen_gan'] = self.generator_gan_weights * gan_loss

    # Feature matching: L1 between corresponding discriminator feature maps.
    # Both are Python lists with one tensor per DownBlock2d, so pair them up.
    feature_matching_loss = 0
    for real_map, generated_map in zip(discriminator_maps_real, discriminator_maps_generated):
      feature_matching_loss += self.feature_matching_weights * tf.reduce_mean(tf.abs(real_map - generated_map))
    loss_values['feature_matching'] = feature_matching_loss

    # Equivariance Loss
    batch_size = driving_images.shape[0]
    transform = Transform(batch_size)

    transformed_frame = transform.transform_frame(driving_images)
    # image Y: batch x height x width x 3

    transformed_keypoints = self.key_point_detector(transformed_frame)
    # Ty <-R

    # Debug/print
    generated['transformed_frame'] = transformed_frame
    # Debug/print
    generated['transformed_kp'] = transformed_keypoints

    keypoints_loss = tf.reduce_mean(tf.abs(kp_driving['value'] - transform.warp_coordinates(transformed_keypoints['value'])))
    loss_values['equivariance_value'] = self.equivariance_weights * keypoints_loss

    # Chain rule: Jacobian of the random transform at the transformed keypoints
    # times the detector's Jacobians for the transformed frame.
    jacobian_transformed = tf.linalg.matmul(transform.jacobian(transformed_keypoints['value']), transformed_keypoints['jacobian'])

    normed_driving = tf.linalg.inv(kp_driving['jacobian']) # inverse of Tx <-R
    normed_transformed = jacobian_transformed

    # Equivariance requires inv(J_driving) . J_transformed to be the identity.
    jacobian_mul = tf.linalg.matmul(normed_driving, normed_transformed)
    identity_matrix = tf.cast(tf.reshape(tf.eye(2), [1, 1, 2, 2]), jacobian_mul.dtype)
    jacobian_loss = tf.reduce_mean(tf.abs(identity_matrix - jacobian_mul))
    loss_values['equivariance_jacobian'] = self.equivariance_weights * jacobian_loss

    return loss_values, generated

# Full Discriminator

In [0]:
class FullDiscriminator(tf.keras.Model):
  """Wraps a discriminator network and computes its LSGAN training loss.

  Driving frames are scored as "real" and generated['prediction'] as "fake";
  both are scored conditioned on the gradient-detached driving keypoints.
  """

  def __init__(self, discriminator):
    super(FullDiscriminator, self).__init__()
    self.discriminator = discriminator

  def call(self, x_driving, generated):
    """Return {'disc_gan': mean LSGAN discriminator loss}.

    Args:
      x_driving: real driving frames.
      generated: dict with at least 'kp_driving' (keypoint dict) and
        'prediction' (generated frames).
    """
    driving_keypoints = generated['kp_driving']
    # Neither the keypoints nor the generated frames may receive gradients
    # from the discriminator objective.
    fake_frames = tf.stop_gradient(generated['prediction'])

    # The discriminator returns a pair; only the second element (prediction map) is used.
    _, real_scores = self.discriminator(x_driving, kp=detach_keypoint(driving_keypoints))
    _, fake_scores = self.discriminator(fake_frames, kp=detach_keypoint(driving_keypoints))

    # LSGAN: real scores are pushed towards 1, generated scores towards 0.
    per_location_loss = (1 - real_scores) ** 2 + fake_scores ** 2

    return {'disc_gan': tf.reduce_mean(per_location_loss)}

# Training

In [0]:
lr = 2e-4

# Generator, keypoint detector and discriminator each get their own Adam
# instance, all sharing the same hyper-parameters (created in this order).
optimizer_generator, optimizer_keypoint_detector, optimizer_discriminator = (
    tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5, beta_2=0.999)
    for _ in range(3)
)

batch_size = 20
epochs = 150

In [0]:
# Wrap the sub-networks (keypoint detector, generator, discriminator) into the
# loss-computing models used by the training step.
# NOTE(review): the discriminator wrapper defined in this part of the notebook is
# `FullDiscriminator(discriminator)` with call(x_driving, generated); a
# `DiscriminatorFullModel` taking these four arguments is not visible here —
# confirm it is defined elsewhere, or use FullDiscriminator(discriminator).
generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

In [0]:
@tf.function
def train_step(source_images, driving_images):
  """Run one generator/keypoint-detector update, then one discriminator update.

  Args:
    source_images: batch of source frames.
    driving_images: batch of driving frames.

  Returns:
    Scalar: sum of all generator losses plus all discriminator losses.
  """
  # --- Generator + keypoint detector -------------------------------------
  with tf.GradientTape() as tape:
    losses_generator, generated = generator_full(source_images, driving_images)
    generator_loss = tf.add_n(list(losses_generator.values()))

  # Differentiate w.r.t. each sub-network's own variables. The wrapper is
  # built with the discriminator as well, so its trainable_variables
  # presumably include the discriminator (which must not follow the generator
  # objective) and the keypoint detector (which would then be stepped by two
  # optimizers).
  generator_variables = generator.trainable_variables
  keypoint_variables = kp_detector.trainable_variables
  generator_gradients, keypoint_gradients = tape.gradient(
      generator_loss, [generator_variables, keypoint_variables])

  optimizer_generator.apply_gradients(zip(generator_gradients, generator_variables))
  optimizer_keypoint_detector.apply_gradients(zip(keypoint_gradients, keypoint_variables))

  # --- Discriminator -------------------------------------------------------
  # NOTE(review): arguments follow FullDiscriminator.call(x_driving, generated);
  # confirm the discriminator wrapper in use takes the same inputs.
  with tf.GradientTape() as tape:
    losses_discriminator = discriminator_full(driving_images, generated)
    discriminator_loss = tf.add_n(list(losses_discriminator.values()))

  discriminator_variables = discriminator.trainable_variables
  discriminator_gradients = tape.gradient(discriminator_loss, discriminator_variables)
  optimizer_discriminator.apply_gradients(zip(discriminator_gradients, discriminator_variables))

  return generator_loss + discriminator_loss

In [0]:
def decay_lr(optimizer, epoch, milestones=(60, 90), gamma=0.1):
  """Step-decay an optimizer's learning rate at fixed epoch milestones.

  The learning rate is multiplied by `gamma` once for each milestone reached.
  With the defaults it becomes lr * 0.1 at epoch 60 and lr * 0.01 at epoch 90.
  A range check (60 <= epoch <= 90) would instead decay on every one of those
  31 epochs, driving the rate to ~lr * 1e-31.

  Args:
    optimizer: Keras optimizer whose `lr` is updated in place.
    epoch: epoch index (train() passes the index of the epoch that just ended).
    milestones: epochs at which the rate is decayed.
    gamma: multiplicative decay factor.
  """
  if epoch not in milestones:
    return
  current_lr = tf.keras.backend.get_value(optimizer.lr)
  tf.keras.backend.set_value(optimizer.lr, current_lr * gamma)

In [0]:
def train(epochs, total_steps, batches=None, steps_per_epoch=None):
  """Run the training loop, calling train_step on every batch of every epoch.

  Args:
    epochs: number of epochs to run.
    total_steps: global step counter to continue from; incremented per batch.
    batches: re-iterable of (source_images, driving_images) batch pairs,
      iterated once per epoch (e.g. a list or tf.data.Dataset). When None, a
      fresh zip(images_batches, labels_batches) is built each epoch.
      NOTE(review): confirm those globals hold source/driving frame batches.
    steps_per_epoch: optional step count shown in the progress line; when
      None, len(batches) is used if available.

  Returns:
    (loss_history, total_steps): the mean training loss of each epoch (NaN for
    an epoch without batches) and the updated global step counter.
  """
  if steps_per_epoch is None and batches is not None:
    try:
      steps_per_epoch = len(batches)
    except TypeError:
      steps_per_epoch = None
  steps_label = f"/{steps_per_epoch}" if steps_per_epoch else ""

  loss_history = []
  for epoch in range(epochs):
    epoch_time = batch_time = time.time()
    epoch_count = f"{epoch + 1:02d}/{epochs}"
    # Build the default zip per epoch: a zip object is a one-shot iterator and
    # would be exhausted after the first epoch.
    epoch_batches = batches if batches is not None else zip(images_batches, labels_batches)

    step = 0
    loss_sum = 0.0
    for source_images, driving_images in epoch_batches:
      loss = float(train_step(source_images, driving_images).numpy())
      loss_sum += loss
      step += 1
      total_steps += 1

      print('\r', 'Epoch', epoch_count, '| Step', f"{step}{steps_label}",
            '| loss:', f"{loss:.5f}", "| Step time:", f"{time.time() - batch_time:.2f}", end='')
      batch_time = time.time()

    epoch_loss = loss_sum / step if step else float('nan')
    loss_history.append(epoch_loss)

    for optimizer in (optimizer_generator, optimizer_keypoint_detector, optimizer_discriminator):
      decay_lr(optimizer, epoch)

    print('\r', 'Epoch', epoch_count, '| Step', f"{step}{steps_label}",
          '| loss:', f"{epoch_loss:.5f}", "| Epoch time:", f"{time.time() - epoch_time:.2f}")

  return loss_history, total_steps

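# Interpolation test

Hand-computed bilinear interpolation, checked against `torch.nn.functional.grid_sample` (with `align_corners=True`) on a small 6x6 image.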

In [0]:
def interpolate(xs, ys, A, B, verbose=True):
  """Bilinearly interpolate a value from four samples laid out in two rows.

  Args:
    xs: (x1, x2, x) - x coordinates of the two sample columns and of the query.
    ys: (y1, y2, y) - y coordinates of the B row and of the A row, and of the query.
    A: (A_1, A_2) - samples at (x1, y2) and (x2, y2).
    B: (B_1, B_2) - samples at (x1, y1) and (x2, y1).
    verbose: when True (default), print the intermediate weights and values.

  Returns:
    The interpolated value P at (x, y).
  """
  # Added to the denominators so coincident coordinates (x1 == x2 or
  # y1 == y2) do not divide by zero.
  epsilon = 0.0000001
  log = print if verbose else (lambda *args: None)

  A_1, A_2 = A
  B_1, B_2 = B

  # Interpolate along x inside each row.
  x1, x2, x = xs
  weight_1 = (x - x1) / ((x2 - x1) + epsilon)
  log("weight_1", weight_1)

  R1 = A_1 + weight_1 * (A_2 - A_1)  # A row, at y2
  R2 = B_1 + weight_1 * (B_2 - B_1)  # B row, at y1
  log("R1", R1)
  log("R2", R2)

  # Interpolate along y between the two row results: y == y1 gives R2, y == y2 gives R1.
  y1, y2, y = ys
  weight_2 = (y - y1) / ((y2 - y1) + epsilon)
  log("weight_2", weight_2)

  P = R2 + weight_2 * (R1 - R2)
  log("P", P)
  return P

### Manual computation of one interpolated sample

In [0]:
# By-hand bilinear sample of the 6x6 test image below, matching grid_sample's
# output row 7, column 3 (align_corners=True, so the 6 columns/rows sit at
# normalized -1, -0.6, -0.2, 0.2, 0.6, 1).
x1, x2 = -0.2, 0.2       # normalized x of input columns 2 and 3
x = -0.14285714          # normalized x of output column 3 (linspace(-1, 1, 8)[3])
weight_1 = (x - x1) / (x2 - x1)

A_1, A_2 = 45, 212       # input row 5, columns 2 and 3
R1 = A_1 + weight_1 * (A_2 - A_1)

B_1, B_2 = 220, 199      # input row 4, columns 2 and 3
R2 = B_1 + weight_1 * (B_2 - B_1)
print(R2, "R2")

y1, y2 = 1, 0.6          # normalized y of input rows 5 and 4
y = 1                    # normalized y of output row 7
weight_2 = (y2 - y) / (y2 - y1)   # weight on R1 (row 5)

P = R2 + weight_2 * (R1 - R2)
print(P)

216.99999985 R2
68.85714404999999


In [0]:
# Same values as rows 0 and 1 of the grid_sample output (test[0][0]) printed further below.
[133.0000, 166.5714, 125.5714,  47.2857,  18.7143,  57.0000,  96.5714, 113.0000]
[ 77.2857,  93.0000,  93.2245,  83.6122,  75.9592,  55.9796,  59.9388, 104.4286]

In [0]:
# Same values as rows 6 and 7 of the grid_sample output (test[0][0]) printed further below.
# Row 7, column 3 (68.8571) is the sample reproduced by hand in the manual computation cell.
[ 43.4286,  14.1429,  74.2449, 174.6735, 198.0408, 207.6939, 217.9592, 234.2857],
[ 67.0000,  19.8571,  19.8571,  68.8571, 188.1429, 228.0000, 241.4286, 245.0000]

In [0]:
# 6x6 float32 test image that is fed to grid_sample below.
numpy_tensor = np.array(
    [
        [133, 180,  53,  13,  90, 113],
        [ 55,  67,  98,  99,  23, 101],
        [ 34,  54,  89,   4,  12,   5],
        [ 56,   2,   3, 112,  45, 156],
        [ 34,   3, 220, 199, 200, 230],
        [ 67,   1,  45, 212, 240, 245],
    ]
).astype(np.float32)

In [0]:
# Confirm the test image is float32 before wrapping it with torch.from_numpy.
numpy_tensor.dtype

dtype('float32')

In [0]:
# Import here so this cell does not depend on an earlier torch import; the
# notebook's import cells at the top only bring in TensorFlow/numpy/matplotlib.
import torch

# Wrap the float32 array as a torch tensor (shares memory with numpy_tensor).
# NOTE: `input` shadows Python's builtin input(); the name is kept because
# later cells refer to it.
input = torch.from_numpy(numpy_tensor)

In [0]:
# Display the wrapped 6x6 test image (note: `input` shadows the builtin input()).
input

tensor([[133., 180.,  53.,  13.,  90., 113.],
        [ 55.,  67.,  98.,  99.,  23., 101.],
        [ 34.,  54.,  89.,   4.,  12.,   5.],
        [ 56.,   2.,   3., 112.,  45., 156.],
        [ 34.,   3., 220., 199., 200., 230.],
        [ 67.,   1.,  45., 212., 240., 245.]])

In [0]:
# Add a leading channel dimension -> shape (1, 6, 6). Built from numpy_tensor
# directly (rather than re-wrapping `input`) so re-running this cell does not
# keep prepending dimensions.
input = torch.from_numpy(numpy_tensor).unsqueeze(0)

In [0]:
# Expect (1, 6, 6): channel dimension plus the 6x6 image.
input.shape

torch.Size([1, 6, 6])

In [0]:
# 8x8 sampling grid over the normalized range [-1, 1].
d = torch.linspace(-1, 1, 8)
# meshgrid uses 'ij' indexing: meshx varies along rows, meshy along columns.
meshx, meshy = torch.meshgrid(d, d)
# grid_sample expects grid[..., 0] = x (column) and grid[..., 1] = y (row);
# the leading [None] adds the batch dimension -> (1, 8, 8, 2).
grid = torch.stack((meshy, meshx), dim=2)[None]

In [0]:
# Resample the image on the 8x8 grid. input[None] adds the batch dimension
# (N=1, C=1, 6, 6); mode='bilinear' is grid_sample's default, spelled out here.
test = torch.nn.functional.grid_sample(input[None], grid, mode='bilinear', align_corners=True)

In [0]:
# Batch item 0, channel 0: the 8x8 resampled image.
print(test[0, 0])

tensor([[133.0000, 166.5714, 125.5714,  47.2857,  18.7143,  57.0000,  96.5714,
         113.0000],
        [ 77.2857,  93.0000,  93.2245,  83.6122,  75.9592,  55.9796,  59.9388,
         104.4286],
        [ 46.0000,  57.0204,  75.4490,  89.0204,  63.4082,  35.4286,  30.1633,
          59.8572],
        [ 37.1429,  43.8776,  59.4898,  68.5306,  27.6122,  17.8775,  19.5306,
          26.5714],
        [ 52.8571,  21.8367,  11.9388,  26.8979,  84.9592,  64.4082,  67.1837,
         134.4286],
        [ 43.4286,  14.2449,  55.8980, 131.9592, 156.7551, 145.6327, 152.0612,
         198.2857],
        [ 43.4286,  14.1429,  74.2449, 174.6735, 198.0408, 207.6939, 217.9592,
         234.2857],
        [ 67.0000,  19.8571,  19.8571,  68.8571, 188.1429, 228.0000, 241.4286,
         245.0000]])
