In [1]:
import tensorflow as tf

In [3]:
class LatentDiffusionTransformer(tf.keras.models.Model):
    """Diffusion Transformer (DiT) backbone operating on an encoder latent.

    All sub-layers are injected through the constructor; this class only wires
    them together in ``call``.

    Args:
        position_embedder: Called on the patch tokens; its output is added to
            them (intended as a ViT sine-cosine frequency embedding).
        time_embedder: Produces the conditioning tensor ``t`` that is passed to
            ``dit_block`` and ``adaln``.
        dit_block: Called as ``dit_block(tokens, t)``.
        patcher: Splits / linearly embeds the latent into patch tokens.
        adaln: Final adaptive layer norm, called as ``adaln(tokens, t)``.
        linear_layer: Final projection applied after ``adaln``.
    """

    def __init__(self, position_embedder, time_embedder, dit_block, patcher, adaln, linear_layer):
        super(LatentDiffusionTransformer, self).__init__()

        # ViT sine-cosine frequency embedder
        self.position_embedder = position_embedder

        # Timestep embedder (intended: 256-dimensional frequency embedding
        # followed by a two-layer MLP with SiLU activation).
        self.time_embedder = time_embedder

        # Linearly embeds patches of the input; the number of tokens T is
        # determined by the patch-size hyperparameter p.
        self.patchify = patcher

        # Entire DiT block, including the transformer.
        self.dit_block = dit_block

        # Final adaptive layer norm conditioned on ``t`` (intended: sum of
        # timestep and class embeddings fed through SiLU and a linear layer),
        # followed by a linear projection.
        self.ada_layernorm = adaln
        self.linear = linear_layer

    def call(self, x, timestep=None):
        """Run the forward pass through the diffusion transformer.

        Args:
            x: Latent input from the encoder (the noised latent). The
                patchified tokens are assumed to be rank-4 -- TODO confirm.
            timestep: Optional diffusion timestep(s) to embed. When omitted,
                ``x`` itself is fed to ``time_embedder`` (the previous
                behaviour, kept for existing callers).

        Returns:
            The projected tokens reshaped to ``[batch, h, w, -1]``, where
            ``batch, h, w`` are the leading dimensions of the patchified
            tokens.
        """
        noised_latent = x

        # NOTE(review): DiT conditions on the diffusion timestep, not on the
        # latent; callers should pass ``timestep``. Falling back to the latent
        # preserves the original behaviour.
        t_input = noised_latent if timestep is None else timestep
        t = self.time_embedder(t_input)

        # split the latent into patch tokens
        x = self.patchify(noised_latent)

        # Read the leading dims dynamically: ``x.shape`` holds ``None`` for an
        # unknown batch size in graph mode, which would break ``tf.reshape``.
        dims = tf.shape(x)
        b, h, w = dims[0], dims[1], dims[2]

        # add the positional embedding to the patch tokens
        x = x + self.position_embedder(x)

        # pass through the DiT block (adaLN-Zero conditioned on t)
        x = self.dit_block(x, t)

        # final adaptive layer norm and linear projection
        # NOTE(review): inside DiTBlock the adaln layer returns
        # (x, omega, beta); here its result is used as a single tensor --
        # confirm this is a different layer type.
        x = self.ada_layernorm(x, t)
        x = self.linear(x)

        # restore the spatial token layout; un-patchifying back to the latent
        # resolution is not done here.
        x = tf.reshape(x, [b, h, w, -1])

        return x


In [7]:
class DiTBlock(tf.keras.models.Model):
    """One DiT block: adaLN-conditioned attention and feed-forward halves,
    each wrapped in a residual connection.

    Args:
        attention: Attention layer, called as ``attention(x)``.
        adaln: Adaptive layer norm, called as ``adaln(x, t)`` and expected to
            return ``(normalized_x, omega, beta)``.
        scale: Gating layer applied to each branch output, called as
            ``scale(x, t)``.
        mlp: Two-layer feed-forward network used in the second half.
        sands: Shift-and-scale layer, called as ``sands(x, omega, beta)``.
    """

    def __init__(self, attention, adaln, scale, mlp, sands):
        super(DiTBlock, self).__init__()

        self.attention = attention

        self.shiftandscale = sands

        # Kept for backward compatibility; ``call`` uses adaln_1 / adaln_2.
        self.ada_layernorm = adaln

        # NOTE(review): both halves reference the *same* ``adaln`` and
        # ``scale`` instances, so their weights are shared between the
        # attention and feed-forward branches. DiT's adaLN-Zero regresses
        # separate shift/scale/gate values per branch -- confirm this sharing
        # is intended.
        self.adaln_1 = adaln
        self.scale_1 = scale

        self.adaln_2 = adaln
        self.scale_2 = scale

        # two-layer feedforward (dense layers)
        self.feedforward = mlp

    def call(self, x, t):
        """Apply the block.

        Args:
            x: Input tokens.
            t: Conditioning tensor (timestep embedding) used by the adaptive
                norms and the gating layers.

        Returns:
            ``x + attn_branch + ffn_branch`` computed with two residual adds.
        """
        input_x = x

        # first adaptive layer norm, then shift and scale
        x, omega, beta = self.adaln_1(x, t)
        x = self.shiftandscale(x, omega, beta)

        # apply multihead attention
        x = self.attention(x)

        # Gate the attention output. ``scale_1`` is the same object as
        # ``scale_2``, which is called with ``t`` below, so it must receive
        # the conditioning here as well.
        x = self.scale_1(x, t)

        # first residual addition
        first_out = x + input_x

        ## Now, second half of the DiT block
        x, omega, beta = self.adaln_2(first_out, t)

        # scale and shift
        x = self.shiftandscale(x, omega, beta)

        # feedforward
        x = self.feedforward(x)

        # gate the feed-forward output
        x = self.scale_2(x, t)

        # final residual addition
        out = first_out + x

        return out