In [21]:
# TODO: add credits and adjust the code

In [9]:
import tensorflow as tf
import math

In [3]:
# defining the diffusion transformer itself

class DiT(tf.keras.models.Model):
    """Diffusion transformer operating on patchified latents.

    The model patchifies the noisy latent with a strided convolution, adds a
    learned positional embedding, runs `depth` conditioned transformer blocks
    and finally un-patchifies the result with `depth_to_space`.

    Args:
        img_size: side length of the (square) input; used only to compute the
            number of patches for the positional embedding.
        patch_size: side length of one patch (also the conv kernel/stride and
            the `depth_to_space` block size).
        model_dim: token width used throughout the transformer.
        k: projection size forwarded to each block's attention layer.
        heads: number of attention heads per block.
        mlp_dim: hidden width of each block's feed-forward network.
        depth: number of `DiTBlock`s.
        cuant_dim: number of output channels produced per pixel by the final
            layer (passed to `FinalLayer` as `out_channels`).
    """

    def __init__(self, img_size, patch_size, model_dim=256, k=64, heads=4, mlp_dim=512, depth=3, cuant_dim=4):
        super(DiT, self).__init__()

        # size of the patches the image is decomposed into, and how many of
        # them there are (the positional embedding needs a fixed count)
        self.patch_size = patch_size
        self.n_patches = (img_size//patch_size)**2

        self.depth = depth

        # non-overlapping patches: kernel size == stride == patch size
        self.patches = tf.keras.Sequential([
            tf.keras.layers.Conv2D(model_dim, kernel_size=patch_size, strides=patch_size, padding='same'),
        ])

        self.pos = PositionalEmbedding(self.n_patches, model_dim)
        self.sin_emb = TimestepEmbedder(model_dim)
        # modulation layers of every block start from zero kernels
        self.transformer = [DiTBlock(model_dim,
                            heads, mlp_dim, mod_init='zeros', k=k) for _ in range(depth)]
        self.final_layer = FinalLayer(patch_size, cuant_dim,
                                     initializer='zeros')

    def call(self, x):
        """Run the model on a `(noisy_latent, noise_variances)` pair.

        Args:
            x: 2-tuple/list `(noisy_latent, noise_variances)`.
                `noisy_latent` is fed to a Conv2D, so it must be a rank-4
                image batch. `noise_variances` is flattened to one row per
                sample — presumably a single scalar per sample; confirm
                against the caller.

        Returns:
            Tensor obtained by un-patchifying the final layer output with
            `depth_to_space` (block size `patch_size`).
        """
        noisy_latent, noise_variances = x

        # Use the dynamic batch size: the static one can be None while tracing.
        batch = tf.shape(noise_variances)[0]

        # one row of conditioning values per sample
        noise_variances = tf.reshape(noise_variances, [batch, -1])

        # The conditioning embedding is computed from the noise level,
        # not from the latent image itself.
        t = self.sin_emb(noise_variances)

        # splitting the latent into patch tokens
        x = self.patches(noisy_latent)

        # spatial grid and channel count are static; batch is left to be inferred
        _, H, W, C = x.shape

        x = tf.reshape(x, [-1, H*W, C])

        x = self.pos(x)

        # forward pass through the conditioned transformer blocks
        for block in self.transformer:
            x = block(x, t)

        # adaptive layer norm followed by the linear projection
        x = self.final_layer(x, t)

        # back to the patch grid; the channel count is static after the final Dense
        x = tf.reshape(x, [-1, H, W, x.shape[-1]])

        # undo the patchification: move each patch's pixels back into space
        x = tf.nn.depth_to_space(x, self.patch_size, data_format='NHWC')

        return x

In [6]:
class DiTBlock(tf.keras.layers.Layer):
    """One conditioned transformer block: attention and MLP, each wrapped in
    adaptive layer norm, a learned conditioning-dependent scale and a residual.

    Args:
        model_dim: token width (input and output of the block).
        n_heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward network.
        rate: dropout rate forwarded to the attention layer.
        eps: epsilon for the layer normalisations.
        initializer: kernel initializer for attention and MLP weights.
        mod_init: kernel initializer for the modulation (AdaLN / Scale) layers.
        k: projection size forwarded to the attention layer.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, model_dim, n_heads=2, mlp_dim=512, rate=0.0, eps=1e-6, 
                 initializer='glorot_uniform', mod_init='glorot_uniform', k=64, **kwargs):
        super(DiTBlock, self).__init__(**kwargs)

        # `rate` is forwarded so that a non-zero dropout rate actually takes effect
        self.attn = LinformerAttention(model_dim, n_heads, k=k, rate=rate, initializer=initializer)

        # pointwise feed-forward network
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(mlp_dim, activation='gelu', kernel_initializer=initializer),
            tf.keras.layers.Dense(model_dim, kernel_initializer=initializer),
        ])

        self.sm1 = AdaLN(epsilon=eps, initializer=mod_init)
        self.sm2 = AdaLN(epsilon=eps, initializer=mod_init)
        self.scale1 = Scale(initializer=mod_init)
        self.scale2 = Scale(initializer=mod_init)

    def call(self, inputs, z, training=None):
        """Apply the block.

        Args:
            inputs: token tensor; the modulation layers index its shape at
                position 2, so it is expected to be rank 3.
            z: conditioning embedding used to produce the shift/scale values.
            training: accepted for Keras compatibility; not used directly
                here. It defaults to None because callers in this file invoke
                the block as `block(x, z)`.

        Returns:
            Tensor with the attention and MLP residual branches added to
            `inputs`.
        """
        # first adaptive layer norm (normalise, then shift/scale from z)
        x = self.sm1(inputs, z)

        # self-attention: the same tensor serves as query, key and value
        x = self.attn(x, x, x)

        # conditioning-dependent scale of the attention branch
        x = self.scale1(x, z)

        # first residual
        out1 = x + inputs

        # second adaptive layer norm
        x = self.sm2(out1, z)

        # pointwise feed-forward
        x = self.mlp(x)

        # conditioning-dependent scale of the MLP branch
        out2 = self.scale2(x, z)

        # second residual
        return out1 + out2

In [7]:
#  scale - learn the alpha parameter

class Scale(tf.keras.layers.Layer):
    """Multiplies a token tensor by a scale computed from a conditioning vector.

    A Dense layer maps the conditioning `z` to one factor per channel; that
    factor is broadcast over axis 1 of `x`.
    """

    def __init__(self, initializer='glorot_uniform', **kwargs):
        super().__init__(**kwargs)
        # kernel initializer for the Dense layer created in `build`
        self.initializer = initializer

    def build(self, input_shape):
        # The Dense layer emits as many units as index 2 of the shape Keras
        # hands to `build` — presumably the channel axis of a
        # (batch, tokens, channels) input; confirm against the caller.
        n_units = input_shape[2]
        self.alpha = tf.keras.layers.Dense(
            n_units,
            use_bias=True,
            kernel_initializer=self.initializer,
        )

    def call(self, x, z):
        # Insert an axis at position 1 so the per-sample factor broadcasts
        # across axis 1 of x.
        factor = tf.expand_dims(self.alpha(z), axis=1)
        return x * factor

In [8]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Adds a learned embedding, one vector per patch index, to the patch tokens."""

    def __init__(self, n_patches, model_dim, initializer='glorot_uniform', **kwargs):
        super().__init__(**kwargs)

        self.n_patches = n_patches

        # The initializer is passed explicitly: Keras' Embedding layer
        # defaults to 'uniform', not Glorot.
        self.positional_embedding = tf.keras.layers.Embedding(
            input_dim=n_patches,
            output_dim=model_dim,
            embeddings_initializer=initializer,
        )

    def call(self, patches):
        # indices 0 .. n_patches-1, one per patch position
        patch_indices = tf.range(self.n_patches)

        # look up every position's vector and add it to the tokens
        position_vectors = self.positional_embedding(patch_indices)
        return patches + position_vectors


In [11]:
class TimestepEmbedder(tf.keras.layers.Layer):
    """Turns a per-sample scalar into a conditioning vector.

    The scalar is expanded into sine/cosine features at log-spaced
    frequencies and then passed through a two-layer MLP (SiLU, then linear).

    Args:
        model_dim: width of the MLP output; `model_dim // 2` frequencies are
            used for the sinusoidal features.
        initializer: kernel initializer for both Dense layers.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, model_dim, initializer='glorot_uniform', **kwargs):
        super(TimestepEmbedder, self).__init__(**kwargs)

        self.model_dim = model_dim

        self.mlp = tf.keras.Sequential([
            # model_dim is the output dimension
            tf.keras.layers.Dense(model_dim, activation='silu', kernel_initializer=initializer),
            # no activation here (linear activation)
            tf.keras.layers.Dense(model_dim, kernel_initializer=initializer),
        ])

    def sinusoidal_embedding(self, x):
        """Return `[sin(w * x), cos(w * x)]` concatenated along axis 1.

        Args:
            x: tensor that broadcasts against a vector of `model_dim // 2`
                frequencies — presumably shaped (batch, 1); confirm against
                the caller.
        """
        embedding_min_freq = 1.0
        noise_embedding_max_freq = 1000.0

        # model_dim // 2 frequencies, evenly spaced in log-space between min and max
        frequencies = tf.exp(
            tf.linspace(
                tf.math.log(embedding_min_freq),
                tf.math.log(noise_embedding_max_freq),
                self.model_dim//2
            )
        )

        angular_speeds = 2.0 * math.pi * frequencies

        # TensorFlow does not promote dtypes implicitly; align x with the frequencies
        x = tf.cast(x, angular_speeds.dtype)

        embeddings = tf.concat(
            [
                tf.sin(angular_speeds * x),
                tf.cos(angular_speeds * x)
            ],
            axis=1
        )

        return embeddings

    def call(self, x):
        # Call the method directly; wrapping it in a Lambda layer would build
        # a brand-new layer object on every forward pass.
        x = self.sinusoidal_embedding(x)
        x = self.mlp(x)
        return x

In [12]:
# final layer where we apply layer norm and then linear
class FinalLayer(tf.keras.layers.Layer):
    """Output head: adaptive layer norm followed by a linear projection.

    Args:
        patch_size: side length of one patch.
        out_channels: channels per output pixel; the projection emits
            `patch_size * patch_size * out_channels` values per token.
        eps: epsilon for the layer normalisation.
        initializer: kernel initializer for both the projection and the
            modulation layers.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, patch_size, out_channels, eps=1e-6, initializer='glorot_uniform', **kwargs):
        super(FinalLayer, self).__init__(**kwargs)

        self.linear = tf.keras.Sequential([
            tf.keras.layers.Dense(patch_size*patch_size*out_channels, kernel_initializer=initializer),
        ])

        self.sm = AdaLN(epsilon=eps, initializer=initializer)

    def call(self, inputs, z, training=None):
        """Normalise and modulate `inputs` with `z`, then project each token.

        Args:
            inputs: token tensor.
            z: conditioning embedding used by the adaptive layer norm.
            training: accepted for Keras compatibility; not used here. It
                defaults to None because the layer is invoked as `layer(x, z)`.
        """
        x = self.sm(inputs, z)
        x = self.linear(x)

        return x

In [None]:
# defining the AdaLN (adaptive layer norm) layer

class AdaLN(tf.keras.layers.Layer):
    """Adaptive layer norm: normalise `x`, then shift and scale it with values
    predicted from a conditioning vector `z`.

    Args:
        epsilon: epsilon for the underlying LayerNormalization.
        initializer: kernel initializer for the Dense layers that predict the
            scale (gamma) and the shift (beta).
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`), matching
            the other layers in this file.
    """

    def __init__(self, epsilon=1e-3, initializer='glorot_uniform', **kwargs):
        super(AdaLN, self).__init__(**kwargs)

        self.epsilon = epsilon
        self.initializer = initializer
        # no learned centre/scale: both come from the conditioning instead
        self.norm = tf.keras.layers.LayerNormalization(epsilon=epsilon, center=False, scale=False)
    
    def build(self, input_shape):
        # One gamma and one beta per channel: the Dense layers emit as many
        # units as index 2 of the shape Keras hands to `build` — presumably
        # the channel axis of a (batch, tokens, channels) input; confirm
        # against the caller.
        self.gamma = tf.keras.layers.Dense(input_shape[2], use_bias=True, kernel_initializer=self.initializer)
        self.beta = tf.keras.layers.Dense(input_shape[2], use_bias=True, kernel_initializer=self.initializer)


    def call(self, x, z):
        """Return `norm(x) * (1 + gamma(z)) + beta(z)`.

        Args:
            x: tensor to normalise.
            z: conditioning embedding fed to the gamma/beta Dense layers.
        """
        # gamma is the scale, beta is the shift
        gamma = self.gamma(z)
        beta = self.beta(z)

        # normalising the input
        x = self.norm(x)

        # `1 +` keeps the signal alive when gamma is 0 (e.g. zero-initialised).
        # expand_dims inserts an axis at position 1 so the per-sample values
        # broadcast across axis 1 of x.
        x = x * (1+tf.expand_dims(gamma, axis=1)) + tf.expand_dims(beta, axis=1)

        return x

In [None]:
# now, time to implement the Linformer attention

class LinformerAttention(tf.keras.layers.Layer):
    """Multi-head attention with Linformer-style key/value compression.

    Keys and values are projected along the sequence axis down to `k`
    positions (layers `E` and `F`) before attention is computed, so the
    attention matrix has `k` columns rather than one per token.

    Args:
        model_dim: token width; must be divisible by `n_heads`.
        n_heads: number of attention heads.
        k: number of positions keys/values are compressed to.
        rate: dropout rate applied to the attention weights and to the output.
        initializer: kernel initializer for the q/k/v/output projections.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, model_dim, n_heads, k, rate=0.0, initializer='glorot_uniform', **kwargs):
        super(LinformerAttention, self).__init__(**kwargs)
        self.n_heads = n_heads
        self.model_dim = model_dim

        assert model_dim % self.n_heads == 0

        # dimension of each head
        self.head_dim = model_dim // self.n_heads

        # weights for the query, key, and value
        self.wq = tf.keras.layers.Dense(model_dim, kernel_initializer=initializer)
        self.wk = tf.keras.layers.Dense(model_dim, kernel_initializer=initializer)
        self.wv = tf.keras.layers.Dense(model_dim, kernel_initializer=initializer)

        # E and F compress keys and values along the sequence axis
        self.E = tf.keras.layers.Dense(k)
        self.F = tf.keras.layers.Dense(k)

        # dropout
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

        # output projection applied after the heads are merged again
        self.w0 = tf.keras.layers.Dense(model_dim, kernel_initializer=initializer)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, model_dim) into (batch, heads, seq, head_dim)."""
        # -1 lets TensorFlow infer the sequence length
        x = tf.reshape(x, (batch_size, -1, self.n_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def _project_sequence(self, x, projection):
        """Apply `projection` along the sequence axis of a (batch, seq, dim) tensor."""
        # Dense acts on the last axis, so the sequence axis is moved there and back.
        x = tf.transpose(x, perm=[0, 2, 1])
        x = projection(x)
        return tf.transpose(x, perm=[0, 2, 1])

    # callers pass the same tensor three times for self-attention
    def call(self, q, k, v):
        """Compute attention.

        Args:
            q, k, v: query/key/value tensors, expected as
                (batch, seq, model_dim). The sequence length of `k`/`v` must
                be statically known because `E`/`F` are Dense layers over
                that axis.

        Returns:
            Tensor of shape (batch, seq_q, model_dim).
        """
        # dynamic batch size: the static one can be None while tracing
        batch_size = tf.shape(q)[0]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        # Compress keys/values along the sequence axis before splitting heads,
        # while the sequence length is still part of the static shape.
        k = self._project_sequence(k, self.E)
        v = self._project_sequence(v, self.F)

        # so we can do multihead attention
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # scaled dot-product between queries and the compressed keys
        qk = tf.matmul(q, k, transpose_b=True)
        dh = tf.cast(self.head_dim, qk.dtype)
        scaled_qk = qk / tf.math.sqrt(dh)

        attn = tf.nn.softmax(scaled_qk, axis=-1)
        attn = self.dropout1(attn)
        attn = tf.matmul(attn, v)

        # undo the head split: (batch, heads, seq, head_dim) -> (batch, seq, model_dim)
        attn = tf.transpose(attn, perm=[0, 2, 1, 3])
        original_size_attention = tf.reshape(attn, (batch_size, -1, self.model_dim))

        # mix the concatenated heads with a single output projection
        out = self.dropout2(self.w0(original_size_attention))
        return out
