# pix2pix

## Load Data

In [None]:
import tensorflow as tf
import numpy as np
import glob
import math

BATCH_SIZE = 1
# Side length (in pixels) that each half-image is resized to before batching.
IMAGE_SIZE = 256


## Load data ##
# Each training file is a single JPEG whose left and right halves form a pair.
input_paths = glob.glob('data/pix2pix/facades/train/*.jpg')
if not input_paths:
    # Fail early with a readable message instead of an opaque queue error.
    raise ValueError(
        "no training images match 'data/pix2pix/facades/train/*.jpg'")

# Queue that yields the image paths in shuffled order.
path_queue = tf.train.string_input_producer(input_paths, shuffle=True)
reader = tf.WholeFileReader()
paths, contents = reader.read(path_queue)
rawInput = tf.image.decode_jpeg(contents)
rawInput = tf.image.convert_image_dtype(rawInput, dtype=tf.float32)

# [height, width, channel]: height and width stay dynamic, channels fixed at 3.
rawInput.set_shape([None, None, 3])
width = tf.shape(rawInput)[1]


def preprocess(image):
    """Map pixel values from [0, 1] to [-1, 1]."""
    return image * 2 - 1


def _to_fixed_size(image):
    """Resize ``image`` to IMAGE_SIZE x IMAGE_SIZE so its static shape is known.

    tf.train.batch (without dynamic_pad) only accepts tensors whose shapes are
    fully defined, and the halves sliced below have unknown height and width.
    """
    return tf.image.resize_images(image,
                                  [IMAGE_SIZE, IMAGE_SIZE],
                                  method=tf.image.ResizeMethod.AREA)


a_images = _to_fixed_size(preprocess(rawInput[:, :width // 2, :]))  # Left side
b_images = _to_fixed_size(preprocess(rawInput[:, width // 2:, :]))  # Right side

# Swap the two names on the right-hand side to train in the other direction.
inputs, targets = a_images, b_images
paths, inputs, targets = tf.train.batch([paths, inputs, targets],
                                        batch_size=BATCH_SIZE)
# float() keeps the division exact under Python 2, where int / int floors.
steps_per_epoch = int(math.ceil(len(input_paths) / float(BATCH_SIZE)))

## Generator

In [None]:
def conv(batch_input, out_channels, stride):
    """Apply a 4x4 convolution to ``batch_input`` with the given stride.

    Axes 1 and 2 of the input are padded with one element on each side
    (constant padding) before a "VALID" convolution is applied.

    Args:
        batch_input: 4-D tensor; the size of axis 3 is used as the number of
            input channels of the filter.
        out_channels: Number of output channels of the filter.
        stride: Stride applied along axes 1 and 2.

    Returns:
        The tensor produced by ``tf.nn.conv2d``.
    """
    channels_in = batch_input.get_shape()[3]
    border = [[0, 0], [1, 1], [1, 1], [0, 0]]
    strides = [1, stride, stride, 1]

    # The variable is always named "filter", so callers need a distinct
    # variable scope for every layer they build with this function.
    kernel = tf.get_variable(
        "filter",
        [4, 4, channels_in, out_channels],
        dtype=tf.float32,
        initializer=tf.random_normal_initializer(0, 0.02),
    )
    padded = tf.pad(batch_input, border, mode="CONSTANT")
    return tf.nn.conv2d(padded, kernel, strides, padding="VALID")

# Filter count of the first encoder layer; deeper layers use multiples of it.
NGF = 64

# Number of channels the generator must emit, read from the static shape of
# `targets`. It has its own name because `out_channels` is reused as the loop
# variable in the encoder and decoder loops below.
generator_outputs_channels = int(targets.get_shape()[-1])

## Create Generator ##
# NOTE(review): `lrelu`, `batchnorm` and `deconv` are not defined in this part
# of the notebook; they are assumed to be provided by another cell.
#
# Every layer is built inside its own variable scope: `conv` always names its
# variable "filter", so two layers in the same scope would collide.
layers = []

# encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
with tf.variable_scope("encoder_1"):
    output = conv(inputs, NGF, stride=2)
    layers.append(output)

layer_specs = [
    NGF * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
    NGF * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
    NGF * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
    NGF * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
    NGF * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
    NGF * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
    NGF * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
]

for out_channels in layer_specs:
    # len(layers) + 1 is the 1-based index of the layer being added.
    with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
        rectified = lrelu(layers[-1], 0.2)
        # [batch, in_height, in_width, in_channels]
        #   => [batch, in_height/2, in_width/2, out_channels]
        convolved = conv(rectified, out_channels, stride=2)
        output = batchnorm(convolved)
        layers.append(output)

# Each entry is (output channels, dropout rate) for one decoder layer.
layer_specs = [
    (NGF * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
    (NGF * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
    (NGF * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
    (NGF * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
    (NGF * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
    (NGF * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
    (NGF, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
]

num_encoder_layers = len(layers)
for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
    # Index of the encoder output that this decoder layer mirrors.
    skip_layer = num_encoder_layers - decoder_layer - 1
    with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
        if decoder_layer == 0:
            # first decoder layer doesn't have skip connections
            # since it is directly connected to the skip_layer
            decoder_input = layers[-1]
        else:
            decoder_input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

        rectified = tf.nn.relu(decoder_input)
        # [batch, in_height, in_width, in_channels]
        #   => [batch, in_height*2, in_width*2, out_channels]
        output = deconv(rectified, out_channels)
        output = batchnorm(output)

        if dropout > 0.0:
            output = tf.nn.dropout(output, keep_prob=1 - dropout)

        layers.append(output)

# decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
with tf.variable_scope("decoder_1"):
    decoder_input = tf.concat([layers[-1], layers[0]], axis=3)
    rectified = tf.nn.relu(decoder_input)
    output = deconv(rectified, generator_outputs_channels)
    # tanh bounds the output to [-1, 1].
    output = tf.tanh(output)
    layers.append(output)

# The generator output is the last layer that was built.
model = layers[-1]