In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from PIL import Image
import os
# Allow more than one copy of the Intel OpenMP runtime to be loaded in this process.
# NOTE(review): this is typically used to silence the "OMP: Error #15 ... already
# initialized" abort seen when several packages bundle their own OpenMP (e.g. conda
# installs on macOS). It masks the duplicate-library conflict rather than fixing it —
# confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'


In [None]:
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Performs spectral normalization on weights.

    This wrapper controls the Lipschitz constant of the layer by
    constraining its spectral norm, which can stabilize the training of GANs.
    See [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957).

    Wrap `tf.keras.layers.Conv2D`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `tf.keras.layers.Dense`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = SpectralNormalization(tf.keras.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])

    The wrapped layer's weight variable is overwritten in place with its
    normalized value on every training-mode call. The power-iteration vector
    `u` is kept across calls as a non-trainable weight.

    Args:
      layer: A `tf.keras.layers.Layer` instance that
        has either `kernel` or `embeddings` attribute.
      power_iterations: `int`, the number of power-iteration steps used to
        estimate the largest singular value on each normalization.

    Raises:
      ValueError: If initialized with `power_iterations <= 0`.
      AttributeError: If `layer` has neither a `kernel` nor an `embeddings`
        attribute (raised when the wrapper is built).
      The `Wrapper` base class also rejects a `layer` that is not a `Layer`
      instance (AssertionError or ValueError, depending on the TF/Keras version).
    """

    def __init__(self, layer: tf.keras.layers.Layer, power_iterations: int = 1, **kwargs):
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got "
                "`power_iterations={}`".format(power_iterations)
            )
        self.power_iterations = power_iterations
        self._initialized = False

    def build(self, input_shape):
        """Build `Layer`: locate the weight to normalize and create `u`."""
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        # Accept any batch size; only the per-sample shape is fixed.
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if hasattr(self.layer, "kernel"):
            self.w = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.w = self.layer.embeddings
        else:
            raise AttributeError(
                "{} object has no attribute 'kernel' nor "
                "'embeddings'".format(type(self.layer).__name__)
            )

        self.w_shape = self.w.shape.as_list()

        # Left singular-vector estimate, one entry per output unit; persisted
        # between calls so a single power iteration per step is enough.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=self.w.dtype,
        )

    def call(self, inputs, training=None):
        """Call `Layer`.

        In training mode the wrapped weights are re-normalized before the
        forward pass. Otherwise the current weights are used as-is.
        """
        if training is None:
            # NOTE(review): in TF2 the global learning phase defaults to 0 unless
            # it has been set, so direct eager calls without `training=True`
            # skip normalization.
            training = tf.keras.backend.learning_phase()

        if training:
            self.normalize_weights()

        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        # `tf.TensorShape` accepts both a TensorShape and a plain tuple, so this
        # works whichever form the wrapped layer returns.
        return tf.TensorShape(self.layer.compute_output_shape(input_shape))

    @tf.function
    def normalize_weights(self):
        """Generate spectral normalized weights.

        This method will update the value of `self.w` with the
        spectral normalized value, so that the layer is ready for `call()`.
        """
        # Flatten to a 2D matrix of shape (-1, out_units) for the power iteration.
        w = tf.reshape(self.w, [-1, self.w_shape[-1]])
        u = self.u

        with tf.name_scope("spectral_normalize"):
            for _ in range(self.power_iterations):
                v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
                u = tf.math.l2_normalize(tf.matmul(v, w))

            # The singular vectors are estimates; do not differentiate through
            # the power iteration.
            u = tf.stop_gradient(u)
            v = tf.stop_gradient(v)

            # Estimated largest singular value, shape (1, 1); broadcasts over self.w.
            sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)

            self.w.assign(self.w / sigma)
            self.u.assign(u)

    def get_config(self):
        config = {"power_iterations": self.power_iterations}
        base_config = super().get_config()
        return {**base_config, **config}

In [None]:
class ResBlock(layers.Layer):
    """Residual block: ReLU -> BatchNorm -> Conv -> ReLU -> BatchNorm -> Conv, plus identity skip.

    Both convolutions use `padding='same'` and keep the input channel count,
    so the result can be added directly to the input. The output has the
    same shape as the input.

    Args:
        kernel_sizes: spatial size of both convolution kernels (default 3).
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, kernel_sizes=3, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.kernel_sizes = kernel_sizes

    def build(self, input_shape):
        channels = input_shape[-1]
        self.relU_1 = layers.ReLU()
        self.bn_1 = layers.BatchNormalization()
        self.conv2d_1 = layers.Conv2D(channels, self.kernel_sizes, padding='same')
        self.relU_2 = layers.ReLU()
        self.bn_2 = layers.BatchNormalization()
        self.conv2d_2 = layers.Conv2D(channels, self.kernel_sizes, padding='same')
        super(ResBlock, self).build(input_shape)

    def call(self, inputs, training=None):
        x = self.relU_1(inputs)
        # Normalize the activated features (previously the ReLU output was discarded).
        x = self.bn_1(x, training=training)
        x = self.conv2d_1(x)
        x = self.relU_2(x)
        x = self.bn_2(x, training=training)
        x = self.conv2d_2(x)
        return inputs + x

    def get_config(self):
        config = super(ResBlock, self).get_config()
        config.update({"kernel_sizes": self.kernel_sizes})
        return config
        
  

In [None]:
class ResBlockSN(layers.Layer):
    """Residual block with spectrally normalized convolutions and an identity skip.

    Layout: ReLU -> BatchNorm -> SN-Conv -> ReLU -> BatchNorm -> SN-Conv, added
    to the input. Each convolution is wrapped in `SpectralNormalization`, uses
    `padding='same'` and keeps the input channel count. The output therefore
    has the input's shape.

    Args:
        kernel_sizes: spatial size of both convolution kernels (default 3).
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, kernel_sizes=3, **kwargs):
        super(ResBlockSN, self).__init__(**kwargs)
        self.kernel_sizes = kernel_sizes

    def build(self, input_shape):
        channels = input_shape[-1]
        self.relU_1 = layers.ReLU()
        self.bn_1 = layers.BatchNormalization()
        self.conv2d_1 = SpectralNormalization(
            layers.Conv2D(channels, self.kernel_sizes, padding='same'))
        self.relU_2 = layers.ReLU()
        self.bn_2 = layers.BatchNormalization()
        self.conv2d_2 = SpectralNormalization(
            layers.Conv2D(channels, self.kernel_sizes, padding='same'))
        super(ResBlockSN, self).build(input_shape)

    def call(self, inputs, training=None):
        x = self.relU_1(inputs)
        # Normalize the activated features (previously the ReLU output was discarded).
        x = self.bn_1(x, training=training)
        # SpectralNormalization re-normalizes its kernel only in training mode.
        x = self.conv2d_1(x, training=training)
        x = self.relU_2(x)
        x = self.bn_2(x, training=training)
        x = self.conv2d_2(x, training=training)
        return inputs + x

    def get_config(self):
        config = super(ResBlockSN, self).get_config()
        config.update({"kernel_sizes": self.kernel_sizes})
        return config
        
  

In [None]:
class ResBlockCN(layers.Layer):
    """Residual block conditioned on a class label through ConditionalBatchNorm.

    Called with a two-element list `[features, label]`. Layout:
    ReLU -> ConditionalBatchNorm -> Conv -> ReLU -> ConditionalBatchNorm -> Conv,
    added to the unmodified `features` (identity skip). Both convolutions use
    `padding='same'` and keep the feature channel count, so the output has the
    same shape as `features`.

    NOTE: `build` instantiates `ConditionalBatchNorm`, which is defined in a
    later notebook cell. That cell must have run before this layer is first called.

    Args:
        class_num: number of classes for the conditional batch norms.
        embedding_size: size of the class embedding in each conditional batch norm.
        kernel_sizes: spatial size of both convolution kernels (default 3).
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, class_num, embedding_size, kernel_sizes=3, **kwargs):
        super(ResBlockCN, self).__init__(**kwargs)
        self.class_num = class_num
        self.embedding_size = embedding_size
        self.kernel_sizes = kernel_sizes

    def build(self, input_shape):
        # input_shape is [features_shape, label_shape]; only the features matter here.
        features_shape = input_shape[0]
        channels = features_shape[-1]

        self.relU_1 = layers.ReLU()
        self.cbn_1 = ConditionalBatchNorm(self.class_num, embedding_size=self.embedding_size)
        self.conv2d_1 = layers.Conv2D(channels, self.kernel_sizes, padding='same')
        self.relU_2 = layers.ReLU()
        self.cbn_2 = ConditionalBatchNorm(self.class_num, embedding_size=self.embedding_size)
        self.conv2d_2 = layers.Conv2D(channels, self.kernel_sizes, padding='same')
        super(ResBlockCN, self).build(input_shape)

    def call(self, inputs):
        features, label = inputs[0], inputs[1]
        x = self.relU_1(features)
        x = self.cbn_1([x, label])
        x = self.conv2d_1(x)
        x = self.relU_2(x)
        x = self.cbn_2([x, label])
        x = self.conv2d_2(x)
        return features + x

    def get_config(self):
        config = super(ResBlockCN, self).get_config()
        config.update({
            "class_num": self.class_num,
            "embedding_size": self.embedding_size,
            "kernel_sizes": self.kernel_sizes,
        })
        return config

In [None]:
# Smoke test: a batch of 4 random 8x8x3 maps, all labelled with class 1.
# NOTE: ResBlockCN.build creates ConditionalBatchNorm layers, which are defined in a
# later cell. On a fresh kernel, run that cell first or this raises NameError.
test_tensor = tf.random.uniform(shape=(4, 8, 8, 3))
test_labels = tf.ones(shape=(4, 1))
res_cbn = ResBlockCN(class_num=9, embedding_size=32)
print(res_cbn([test_tensor, test_labels]).shape)

In [None]:
# Used for the initial class embedding, which is concatenated onto the first set of filters derived from the latent vector, and onto the image input in the discriminator.
class ClassEmbedding(layers.Layer):
    """Map integer class labels to a single-channel (output_height, output_height, 1) map.

    Pipeline: Embedding(class_num -> embedding_size) -> Dense(output_height**2)
    -> Reshape((output_height, output_height, 1)).

    Args:
        class_num: number of classes (embedding vocabulary size).
        embedding_size: dimension of the learned class embedding.
        output_height: side length of the square output map.
        **kwargs: forwarded to `tf.keras.layers.Layer`.

    NOTE(review): inputs are presumably integer class ids of shape (batch, 1),
    matching the `Input(shape=(1,))` label used elsewhere — confirm at call sites.
    """

    def __init__(self, class_num, embedding_size, output_height, **kwargs):
        super(ClassEmbedding, self).__init__(**kwargs)
        self.class_num = class_num
        self.embedding_size = embedding_size
        self.output_height = output_height

    def build(self, input_shape):
        self.embed = layers.Embedding(self.class_num, self.embedding_size, input_length=1)
        self.dense1 = layers.Dense(self.output_height * self.output_height)
        self.reshape = layers.Reshape((self.output_height, self.output_height, 1))
        super(ClassEmbedding, self).build(input_shape)

    def call(self, inputs):
        x = self.embed(inputs)
        # Project the embedding, not the raw label ids (previously the embedding was discarded).
        x = self.dense1(x)
        x = self.reshape(x)
        return x
    
    
        

In [None]:
class ConditionalBatchNorm(layers.Layer):
    """Batch normalization with class-conditional scale (gamma) and offset (beta).

    Called with `[x, label]`, where `x` is a 4D feature map normalized per
    channel over axes (0, 1, 2). The label is embedded, and two Dense layers map
    the embedding to per-channel `beta` and `gamma`. Both are reshaped to
    (batch, 1, 1, channels) so they broadcast over height and width.

    Moving averages of the batch mean and variance (decay 0.9) are kept in
    non-trainable weights and used when normalizing in inference mode.

    Args:
        class_num: number of classes (vocabulary size of the label embedding).
        embedding_size: dimension of the label embedding.
        training: default mode used when `call` receives no `training` value.
            True normalizes with batch statistics and updates the moving
            averages; False normalizes with the moving averages.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, class_num, embedding_size, training=True, **kwargs):
        super(ConditionalBatchNorm, self).__init__(**kwargs)
        self.training = training
        self.class_num = class_num
        self.decay = 0.9  # momentum of the moving averages
        self.epsilon = 1e-05
        self.embedding_size = embedding_size

    def build(self, input_shape):
        assert isinstance(input_shape, list), (
            "ConditionalBatchNorm expects a list of inputs: [features, label]")
        features_shape = input_shape[0]
        self.channels = features_shape[-1]

        # Population statistics are updated manually in `call`, so they must not
        # be trainable. Plain tf.Variable defaults to trainable=True, which exposed
        # them to the optimizer.
        self.test_mean = self.add_weight(
            name="pop_mean",
            shape=(self.channels,),
            initializer="zeros",
            dtype="float32",
            trainable=False,
        )
        self.test_var = self.add_weight(
            name="pop_var",
            shape=(self.channels,),
            initializer="ones",
            dtype="float32",
            trainable=False,
        )

        self.classEmbedding = layers.Embedding(self.class_num, self.embedding_size, input_length=1)

        # NOTE(review): gamma comes straight from a Dense layer, so it starts near 0
        # rather than 1; many conditional-BN implementations use `1 + gamma`.
        self.betaLayer = layers.Dense(self.channels)
        self.gammaLayer = layers.Dense(self.channels)
        super(ConditionalBatchNorm, self).build(input_shape)

    def call(self, inputs, training=None):
        x, label = inputs[0], inputs[1]
        if training is None:
            training = self.training

        embedding = self.classEmbedding(label)
        beta = tf.reshape(self.betaLayer(embedding), shape=[-1, 1, 1, self.channels])
        gamma = tf.reshape(self.gammaLayer(embedding), shape=[-1, 1, 1, self.channels])

        if not training:
            return tf.nn.batch_normalization(
                x, self.test_mean, self.test_var, beta, gamma, self.epsilon)

        # Per-channel statistics over batch, height and width.
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
        ema_mean = self.test_mean.assign(
            self.test_mean * self.decay + batch_mean * (1 - self.decay))
        ema_var = self.test_var.assign(
            self.test_var * self.decay + batch_var * (1 - self.decay))
        # Make the moving-average updates happen before the output in graph mode too.
        with tf.control_dependencies([ema_mean, ema_var]):
            return tf.nn.batch_normalization(
                x, batch_mean, batch_var, beta, gamma, self.epsilon)
            

In [None]:
# Smoke test: 9 classes, 128-dim label embedding, 4 random 4x4x3 maps labelled class 1.
cbn = ConditionalBatchNorm(class_num=9, embedding_size=128)
x = tf.random.uniform(shape=(4, 4, 4, 3))
cbn([x, tf.ones(shape=(4, 1))])


In [None]:
class ScalarMult(layers.Layer):
    """Scale the input by one learnable scalar weight `k`.

    `k` is initialised to zero. The output shape always equals the input shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # A single trainable float32 scalar.
        self.k = self.add_weight(
            shape=(),
            name='k',
            trainable=True,
            dtype='float32',
            initializer='zeros',
        )
        super().build(input_shape)

    def call(self, inputs):
        scale = self.k
        return tf.math.scalar_mul(scale, inputs, name=None)

    def compute_output_shape(self, input_shape):
        # Scaling by a scalar never changes the shape.
        return input_shape

In [None]:
class SelfAttention(layers.Layer):
    """Self-attention over the spatial positions of a 4D feature map (SAGAN-style).

    Three 1x1 convolutions project the input. Two use C//8 channels and form an
    (N x N) attention map, with N = H*W, via a softmax over the last axis. The
    third keeps C channels and is aggregated by that map. The result is scaled
    by a learnable scalar (`ScalarMult`, initialised to 0) and added back to the
    input, so the layer starts as an identity mapping.

    NOTE: the attribute names differ from the usual convention. Here
    `key_weights` and `value_weights` build the scores, and `query_weights`
    provides the aggregated features.
    """

    def __init__(self, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        channels = input_shape[-1]
        # Reduced width for the score projections; at least 1 so that inputs
        # with fewer than 8 channels still get a valid Conv2D.
        reduced = max(1, channels // 8)

        self.key_weights = layers.Conv2D(reduced, 1, padding='same')
        self.value_weights = layers.Conv2D(reduced, 1, padding='same')
        self.query_weights = layers.Conv2D(channels, 1, padding='same')

        # Flatten spatial positions: (batch, H, W, c) -> (batch, N, c).
        self.key_reshape = layers.Reshape(target_shape=(-1, reduced))
        self.value_reshape = layers.Reshape(target_shape=(-1, reduced))
        self.query_reshape = layers.Reshape(target_shape=(-1, channels))
        # Reshape's target_shape excludes the batch dimension; passing the full
        # input_shape (as before) made this layer fail for any input.
        self.o_reshape = layers.Reshape(target_shape=tuple(input_shape[1:]))

        self.gamma = ScalarMult()
        super(SelfAttention, self).build(input_shape)

    def call(self, inputs):
        key = self.key_reshape(self.key_weights(inputs))        # (batch, N, C//8)
        value = self.value_reshape(self.value_weights(inputs))  # (batch, N, C//8)
        query = self.query_reshape(self.query_weights(inputs))  # (batch, N, C)

        scores = tf.matmul(key, value, transpose_b=True)  # (batch, N, N)
        attention = tf.nn.softmax(scores)                 # normalised over the last axis

        o = tf.matmul(attention, query)                   # (batch, N, C)
        o = self.o_reshape(o)                             # back to the input's shape
        return inputs + self.gamma(o)
        

In [None]:
# Smoke test: a batch of 4 random 32x32x64 maps; the printed shape should equal the input's.
selfAttention = SelfAttention()
test = tf.random.uniform(shape=(4, 32, 32, 64))
attended = selfAttention(test)
print(tf.shape(attended))


def makeTest(input_shape=(32, 32, 64), layer=None):
    """Wrap a layer in a one-layer functional Keras model as a smoke test.

    Args:
        input_shape: per-sample input shape (batch dimension excluded).
        layer: the layer to wrap. Defaults to the module-level `selfAttention`
            instance created in the cell above.

    Returns:
        A `tf.keras.models.Model` mapping `Input(shape=input_shape)` to `layer(input)`.
    """
    if layer is None:
        layer = selfAttention
    inputs = layers.Input(shape=input_shape)
    outputs = layer(inputs)
    return tf.keras.models.Model(inputs, outputs)


x = makeTest()

In [None]:
def MakeGenerator(latent_dim, initial_channels, class_num, embedding_size,
                  num_blocks=5, attention_resolution=32, output_channels=None):
    """Build the class-conditional generator as a functional Keras model.

    The latent vector is projected by a Dense layer to a 4x4 map with
    `initial_channels` channels. Each of the `num_blocks` stages applies a
    `ResBlockCN` conditioned on the label, followed by a 2x `UpSampling2D`.
    A `SelfAttention` layer is inserted right after the upsampling that reaches
    `attention_resolution`. With the defaults the spatial size grows
    4 -> 8 -> 16 -> 32 (+attention) -> 64 -> 128, and the channel count stays
    `initial_channels`.

    Args:
        latent_dim: length of the latent input vector.
        initial_channels: channel count of the initial 4x4 map (kept by every ResBlockCN).
        class_num: number of classes for the conditional batch norms.
        embedding_size: class-embedding size used inside each ResBlockCN.
        num_blocks: number of ResBlockCN + upsampling stages.
        attention_resolution: spatial size after which SelfAttention is applied.
            If it is never reached, no attention layer is added.
        output_channels: if given, a final 1x1 Conv2D with sigmoid activation
            maps the features to this many channels (e.g. 3 for RGB).

    Returns:
        A `tf.keras.models.Model` with inputs `[latent (latent_dim,), label (1,)]`.
    """
    in_latent = layers.Input(shape=(latent_dim,))
    in_label = layers.Input(shape=(1,))

    x = layers.Dense(4 * 4 * initial_channels)(in_latent)
    x = layers.Reshape((4, 4, initial_channels))(x)

    resolution = 4
    for _ in range(num_blocks):
        x = ResBlockCN(class_num, embedding_size)([x, in_label])
        x = layers.UpSampling2D()(x)
        resolution *= 2
        if resolution == attention_resolution:
            x = SelfAttention()(x)

    if output_channels is not None:
        x = layers.Conv2D(output_channels, (1, 1), padding="same", activation="sigmoid")(x)

    # The model was previously built but never returned.
    return tf.keras.models.Model([in_latent, in_label], x)


generator = MakeGenerator(128, 64, 9, 128)
    

In [None]:
def MakeDiscriminator()