In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [17]:
class stylegan(object):
    """StyleGAN2-style synthesis network built in TF1 graph mode.

    Every style block is a modulated + demodulated 3x3 convolution driven by a
    512-d latent vector ``w``. Feature maps are NHWC throughout.
    """

    # Output channel count of the *last* style block at each resolution. The
    # first (upsampling) block at a resolution keeps the incoming channel count.
    # This reproduces the original hand-unrolled layer list:
    # 4..32 -> 512, 64 -> 256, 128 -> 128, 256 -> 64.
    _CHANNELS_AT_RESOLUTION = {4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64}

    def __init__(self, session, output_resolution=256):
        """
        Args:
            session: tf.Session; only stored on the instance here.
            output_resolution: side length of the generated image; must be a
                key of ``_CHANNELS_AT_RESOLUTION`` (4, 8, ..., 256).
        """
        self.sess = session
        self.output_resolution = output_resolution
        # Counters that give every created variable a unique name.
        self.num_style_blocks = 0
        self.num_to_rgbs = 0

        # Variables created by styleBlock, in creation order.
        self.conv_weights = []
        self.w_transforms = []

        # Batch of latent vectors w. tf.placeholder takes the dtype first;
        # passing only the shape was an error.
        self.latent_w = tf.placeholder(tf.float32, [None, 512], name="latent_w")

    def generator(self, W_in):
        """Build the synthesis network and return the generated RGB images.

        Uses a "skip" output: toRgb is applied after the last block of every
        resolution and summed with the 2x-upsampled RGB image of the previous
        resolution.

        Args:
            W_in: float32 tensor [batch_size, 512] of latent vectors.

        Returns:
            Tensor [batch_size, output_resolution, output_resolution, 3].

        Raises:
            ValueError: if ``output_resolution`` has no channel schedule.
        """
        if self.output_resolution not in self._CHANNELS_AT_RESOLUTION:
            raise ValueError(
                "output_resolution must be one of %s, got %r"
                % (sorted(self._CHANNELS_AT_RESOLUTION), self.output_resolution)
            )

        with tf.variable_scope("generator"):
            self.constant_input = tf.get_variable(
                "c_1",
                [4, 4, 512],
                initializer=tf.initializers.orthogonal()
            )
            # The learned constant has no batch axis; replicate it per sample so
            # the style blocks receive a [batch_size, 4, 4, 512] tensor.
            batch_size = tf.shape(W_in)[0]
            features = tf.tile(tf.expand_dims(self.constant_input, 0), [batch_size, 1, 1, 1])

            # 4x4: a single (non-upsampling) style block.
            channels = 512
            out_channels = self._CHANNELS_AT_RESOLUTION[4]
            features = self.styleBlock(
                features,
                W_in,
                num_input_channels=channels,
                num_output_channels=out_channels
            )
            channels = out_channels
            rgb = self.toRgb(features)

            # Each further resolution: upsampling block, then a block that may
            # change the channel count, then accumulate the RGB skip output.
            resolution = 4
            while resolution < self.output_resolution:
                resolution *= 2
                features = self.styleBlock(
                    features,
                    W_in,
                    num_input_channels=channels,
                    num_output_channels=channels,
                    upsample=True
                )
                out_channels = self._CHANNELS_AT_RESOLUTION[resolution]
                features = self.styleBlock(
                    features,
                    W_in,
                    num_input_channels=channels,
                    num_output_channels=out_channels
                )
                channels = out_channels
                rgb = self.upsample(rgb) + self.toRgb(features)

            return rgb

    def styleBlock(self, V_in, latent_w, num_input_channels, num_output_channels, upsample=False):
        """Modulated + demodulated 3x3 convolution followed by leaky ReLU.

        Args:
            V_in: [batch_size, height, width, num_input_channels] feature maps.
            latent_w: [batch_size, 512] latent vectors.
            num_input_channels: channel count of V_in.
            num_output_channels: channel count of the result.
            upsample: if True, V_in is first upsampled 2x (nearest neighbour).

        Returns:
            [batch_size, H, W, num_output_channels], where H and W are the
            (possibly doubled) input spatial size.

        TODO: add biases and per-pixel broadcast noise.
        """
        self.num_style_blocks += 1

        if upsample:
            V_in = self.upsample(V_in)

        # Affine map from w to one style scale per input channel.
        A = tf.get_variable(
            "A_style" + str(self.num_style_blocks),
            [512, num_input_channels],
            initializer=tf.initializers.orthogonal()
        )

        # conv_weight --> [kernel_h, kernel_w, num_input_channels, num_output_channels]
        conv_weight = tf.get_variable(
            "conv_w_style" + str(self.num_style_blocks),
            [3, 3, num_input_channels, num_output_channels],
            initializer=tf.initializers.orthogonal()
        )
        self.w_transforms.append(A)
        self.conv_weights.append(conv_weight)

        # [batch_size, 512] x [512, num_input_channels] -> [batch_size, num_input_channels]
        scale = tf.matmul(latent_w, A)

        # Modulation: scaling input channel i by s_i before a shared convolution
        # is equivalent to convolving with per-sample weights s_i * w[..., i, :].
        V_in_scaled = V_in * tf.reshape(scale, [-1, 1, 1, num_input_channels])
        V_out = tf.nn.conv2d(V_in_scaled, conv_weight, strides=[1, 1, 1, 1], padding="SAME")

        # Demodulation: sigma_j = sqrt(sum_{kh,kw,i} (s_i * w[kh,kw,i,j])^2 + eps)
        #             = sqrt(sum_i s_i^2 * sum_{kh,kw} w[kh,kw,i,j]^2 + eps).
        # That is a [batch, C_in] x [C_in, C_out] matmul, so the per-sample
        # modulated weights (batch_size times the kernel) are never materialised.
        sigma_j = tf.sqrt(
            tf.matmul(tf.square(scale), tf.reduce_sum(tf.square(conv_weight), axis=[0, 1])) + 1e-6
        )

        # Dividing output channel j by sigma_j is the same as convolving with the
        # demodulated weights w' / sigma_j.
        V_out_scaled = tf.nn.leaky_relu(
            V_out / tf.reshape(sigma_j, [-1, 1, 1, num_output_channels]),
            alpha=0.2
        )

        return V_out_scaled

    def upsample(self, V_in):
        """2x nearest-neighbour upsampling of an NHWC tensor.

        Uses concat + reshape to duplicate rows, then transposes so the same
        trick duplicates columns.

        Args:
            V_in: [batch_size, h, w, c] tensor.

        Returns:
            [batch_size, 2*h, 2*w, c] tensor where out[:, y, x] == V_in[:, y//2, x//2].
        """
        # tf.shape returns a tensor; it cannot be tuple-unpacked in graph mode.
        batch_size, h, w, c = tf.unstack(tf.shape(V_in))

        # [b, h, 2w, c] -> reshape -> [b, 2h, w, c]: every row appears twice.
        V_in_a = tf.concat([V_in, V_in], axis=2)
        V_in_b = tf.reshape(V_in_a, [batch_size, 2 * h, w, c])

        # Swap spatial axes and repeat the trick for columns:
        # [b, w, 2h, c] -> concat -> [b, w, 4h, c] -> reshape -> [b, 2w, 2h, c].
        V_in_c = tf.transpose(V_in_b, perm=[0, 2, 1, 3])
        V_in_d = tf.concat([V_in_c, V_in_c], axis=2)
        V_out = tf.transpose(
            tf.reshape(V_in_d, [batch_size, 2 * w, 2 * h, c]),
            perm=[0, 2, 1, 3]
        )

        return V_out

    def toRgb(self, V_in):
        '''
        Convert an NxNxC output block to an RGB image with dimensions
        NxNx3 using a 1x1 convolution followed by ReLU.

        Raises:
            ValueError: if the channel dimension of V_in is not statically known
                (needed to create the kernel variable).
        '''

        self.num_to_rgbs += 1

        # Variable shapes must be static, so read C from the static shape.
        num_channels = V_in.get_shape().as_list()[-1]
        if num_channels is None:
            raise ValueError("toRgb needs a statically known channel dimension")

        to_rgb = tf.get_variable(
            "to_rgb" + str(self.num_to_rgbs),
            [1, 1, num_channels, 3],
            initializer=tf.random_normal_initializer(stddev=0.2)
        )

        rgb_out = tf.nn.relu(
            tf.nn.conv2d(V_in, to_rgb, strides=[1, 1, 1, 1], padding="SAME")
        )

        return rgb_out

In [6]:
# Three 3x3 float64 test planes: a 1..9 ramp, its negation, and the ramp + 0.2.
a = tf.constant(
    np.array([
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[-1, -2, -3], [-4, -5, -6], [-7, -8, -9]],
        [[1.2, 2.2, 3.2], [4.2, 5.2, 6.2], [7.2, 8.2, 9.2]],
    ])
)

In [7]:
# Graph-mode session shared by the scratch cells below (they call sess.run).
sess = tf.Session()

In [16]:
# 2x nearest-neighbour upsampling via concat/reshape/transpose, first on the
# 3-D tensor `a` ([3, 3, 3]) and then on a 4-D NHWC stack ([3, 3, 3, 3]).
# Names `c` and `f` are inspected by later cells, so they are kept.
b = tf.concat([a, a,], axis=2)                 # [3, 3, 6]: each row twice, side by side
c = tf.reshape(b, [3, 6, 3])                   # rows duplicated

d = tf.transpose(c, perm=[0, 2, 1])            # swap spatial axes
e = tf.concat([d, d], axis=2)
f = tf.transpose(tf.reshape(e, [3, 6, 6]), perm=[0, 2, 1])   # rows and columns duplicated

g = tf.stack([a, 2*a, 3.4*a], axis=3)          # [3, 3, 3, 3]: three "channels"
h = tf.concat([g, g], axis=2)
i = tf.reshape(h, [3, 6, 3, 3])
j = tf.transpose(i, perm=[0, 2, 1, 3])
k = tf.concat([j, j], axis=2)
l = tf.transpose(tf.reshape(k, [3, 6, 6, 3]), perm=[0, 2, 1, 3])

# Run the graph once (not once per channel). Compare with numpy's
# repeat-based nearest-neighbour upsampling, then print each channel.
# A distinct loop variable avoids clobbering the tensor `i` defined above.
upsampled = sess.run(l)
expected = np.repeat(np.repeat(sess.run(g), 2, axis=1), 2, axis=2)
np.testing.assert_allclose(upsampled, expected)
for channel in range(upsampled.shape[-1]):
    print(upsampled[:, :, :, channel])

[[[ 1.   1.   2.   2.   3.   3. ]
  [ 1.   1.   2.   2.   3.   3. ]
  [ 4.   4.   5.   5.   6.   6. ]
  [ 4.   4.   5.   5.   6.   6. ]
  [ 7.   7.   8.   8.   9.   9. ]
  [ 7.   7.   8.   8.   9.   9. ]]

 [[-1.  -1.  -2.  -2.  -3.  -3. ]
  [-1.  -1.  -2.  -2.  -3.  -3. ]
  [-4.  -4.  -5.  -5.  -6.  -6. ]
  [-4.  -4.  -5.  -5.  -6.  -6. ]
  [-7.  -7.  -8.  -8.  -9.  -9. ]
  [-7.  -7.  -8.  -8.  -9.  -9. ]]

 [[ 1.2  1.2  2.2  2.2  3.2  3.2]
  [ 1.2  1.2  2.2  2.2  3.2  3.2]
  [ 4.2  4.2  5.2  5.2  6.2  6.2]
  [ 4.2  4.2  5.2  5.2  6.2  6.2]
  [ 7.2  7.2  8.2  8.2  9.2  9.2]
  [ 7.2  7.2  8.2  8.2  9.2  9.2]]]
[[[  2.    2.    4.    4.    6.    6. ]
  [  2.    2.    4.    4.    6.    6. ]
  [  8.    8.   10.   10.   12.   12. ]
  [  8.    8.   10.   10.   12.   12. ]
  [ 14.   14.   16.   16.   18.   18. ]
  [ 14.   14.   16.   16.   18.   18. ]]

 [[ -2.   -2.   -4.   -4.   -6.   -6. ]
  [ -2.   -2.   -4.   -4.   -6.   -6. ]
  [ -8.   -8.  -10.  -10.  -12.  -12. ]
  [ -8.   -8.  -10. 

In [9]:
# Inspect the row-duplicated intermediate `c` (reshaped to [3, 6, 3]).
sess.run(c)

array([[[ 1. ,  2. ,  3. ],
        [ 1. ,  2. ,  3. ],
        [ 4. ,  5. ,  6. ],
        [ 4. ,  5. ,  6. ],
        [ 7. ,  8. ,  9. ],
        [ 7. ,  8. ,  9. ]],

       [[-1. , -2. , -3. ],
        [-1. , -2. , -3. ],
        [-4. , -5. , -6. ],
        [-4. , -5. , -6. ],
        [-7. , -8. , -9. ],
        [-7. , -8. , -9. ]],

       [[ 1.2,  2.2,  3.2],
        [ 1.2,  2.2,  3.2],
        [ 4.2,  5.2,  6.2],
        [ 4.2,  5.2,  6.2],
        [ 7.2,  8.2,  9.2],
        [ 7.2,  8.2,  9.2]]])

In [10]:
# Inspect `f`, the [3, 6, 6] result where every input value fills a 2x2 patch.
sess.run(f)

array([[[ 1. ,  1. ,  2. ,  2. ,  3. ,  3. ],
        [ 1. ,  1. ,  2. ,  2. ,  3. ,  3. ],
        [ 4. ,  4. ,  5. ,  5. ,  6. ,  6. ],
        [ 4. ,  4. ,  5. ,  5. ,  6. ,  6. ],
        [ 7. ,  7. ,  8. ,  8. ,  9. ,  9. ],
        [ 7. ,  7. ,  8. ,  8. ,  9. ,  9. ]],

       [[-1. , -1. , -2. , -2. , -3. , -3. ],
        [-1. , -1. , -2. , -2. , -3. , -3. ],
        [-4. , -4. , -5. , -5. , -6. , -6. ],
        [-4. , -4. , -5. , -5. , -6. , -6. ],
        [-7. , -7. , -8. , -8. , -9. , -9. ],
        [-7. , -7. , -8. , -8. , -9. , -9. ]],

       [[ 1.2,  1.2,  2.2,  2.2,  3.2,  3.2],
        [ 1.2,  1.2,  2.2,  2.2,  3.2,  3.2],
        [ 4.2,  4.2,  5.2,  5.2,  6.2,  6.2],
        [ 4.2,  4.2,  5.2,  5.2,  6.2,  6.2],
        [ 7.2,  7.2,  8.2,  8.2,  9.2,  9.2],
        [ 7.2,  7.2,  8.2,  8.2,  9.2,  9.2]]])

In [11]:
# Interactive docs lookup; not needed for Restart & Run All.
help(tf.reshape)

Help on function reshape in module tensorflow.python.ops.gen_array_ops:

reshape(tensor, shape, name=None)
    Reshapes a tensor.
    
    Given `tensor`, this operation returns a tensor that has the same values
    as `tensor` with shape `shape`.
    
    If one component of `shape` is the special value -1, the size of that dimension
    is computed so that the total size remains constant.  In particular, a `shape`
    of `[-1]` flattens into 1-D.  At most one component of `shape` can be -1.
    
    If `shape` is 1-D or higher, then the operation returns a tensor with shape
    `shape` filled with the values of `tensor`. In this case, the number of elements
    implied by `shape` must be the same as the number of elements in `tensor`.
    
    For example:
    
    ```
    # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
    # tensor 't' has shape [9]
    reshape(t, [3, 3]) ==> [[1, 2, 3],
                            [4, 5, 6],
                            [7, 8, 9]]
    
    # tensor 't' 

In [12]:
# Interactive docs lookup; not needed for Restart & Run All.
help(tf.tile)

Help on function tile in module tensorflow.python.ops.gen_array_ops:

tile(input, multiples, name=None)
    Constructs a tensor by tiling a given tensor.
    
    This operation creates a new tensor by replicating `input` `multiples` times.
    The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements,
    and the values of `input` are replicated `multiples[i]` times along the 'i'th
    dimension. For example, tiling `[a b c d]` by `[2]` produces
    `[a b c d a b c d]`.
    
    Args:
      input: A `Tensor`. 1-D or higher.
      multiples: A `Tensor`. Must be one of the following types: `int32`, `int64`.
        1-D. Length must be the same as the number of dimensions in `input`
      name: A name for the operation (optional).
    
    Returns:
      A `Tensor`. Has the same type as `input`.



In [13]:
# Interactive docs lookup; not needed for Restart & Run All.
help(tf.transpose)

Help on function transpose in module tensorflow.python.ops.array_ops:

transpose(a, perm=None, name='transpose', conjugate=False)
    Transposes `a`.
    
    Permutes the dimensions according to `perm`.
    
    The returned tensor's dimension i will correspond to the input dimension
    `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
    the rank of the input tensor. Hence by default, this operation performs a
    regular matrix transpose on 2-D input Tensors. If conjugate is True and
    `a.dtype` is either `complex64` or `complex128` then the values of `a`
    are conjugated and transposed.
    
    @compatibility(numpy)
    In `numpy` transposes are memory-efficient constant time operations as they
    simply return a new view of the same data with adjusted `strides`.
    
    TensorFlow does not support strides, so `transpose` returns a new tensor with
    the items permuted.
    @end_compatibility
    
    For example:
    
    ```python
    x = tf.constant

In [14]:
# Interactive docs lookup; not needed for Restart & Run All.
help(tf.stack)

Help on function stack in module tensorflow.python.ops.array_ops:

stack(values, axis=0, name='stack')
    Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
    
    Packs the list of tensors in `values` into a tensor with rank one higher than
    each tensor in `values`, by packing them along the `axis` dimension.
    Given a list of length `N` of tensors of shape `(A, B, C)`;
    
    if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
    if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
    Etc.
    
    For example:
    
    ```python
    x = tf.constant([1, 4])
    y = tf.constant([2, 5])
    z = tf.constant([3, 6])
    tf.stack([x, y, z])  # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
    tf.stack([x, y, z], axis=1)  # [[1, 2, 3], [4, 5, 6]]
    ```
    
    This is the opposite of unstack.  The numpy equivalent is
    
    ```python
    tf.stack([x, y, z]) = np.stack([x, y, z])
    ```
    
    Args:
      val

In [15]:
# TODO: try implementing the 2x upsampling with matrix multiplications instead of concat/reshape (approach sketched on the whiteboard).