In [87]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import einops

"""
Dimension key:
B: batch size
L: sequence length
D: embedding dimension
H: number of attention heads
K: size of each attention head (D // H)
"""

class SelfAttention(nnx.Module):
    def __init__(self, n_heads: int, d_embed: int, rngs: nnx.Rngs, in_proj_bias: bool = True, out_proj_bias: bool = True):
        self.n_heads = n_heads
        self.d_embed = d_embed
        self.d_head = d_embed // n_heads
        self.in_proj = nnx.Linear(in_features=d_embed, out_features=3*d_embed, use_bias=in_proj_bias, rngs=rngs)
        self.out_proj = nnx.Linear(in_features=d_embed, out_features=d_embed, use_bias=out_proj_bias, rngs=rngs)

    def __call__(self, x_BLD,  causal_mask=False):
        B, L, D = jnp.shape(x_BLD)
        H = self.n_heads
        K = self.d_head # K = D // H

        qkv_BL3D = self.in_proj(x_BLD) # (B, L, 3D)
        q_BLD, k_BLD, v_BLD = jnp.split(qkv_BL3D, 3, axis=-1)

        q_BHLK = einops.rearrange(q_BLD, 'B L (H K) -> B H L K', H=H)
        k_BHLK = einops.rearrange(k_BLD, 'B L (H K) -> B H L K', H=H)
        v_BHLK = einops.rearrange(v_BLD, 'B L (H K) -> B H L K', H=H)

        attn_logits_BHLL = jnp.einsum('BHLK,BHMK->BHLM', q_BHLK, k_BHLK) / jnp.sqrt(K)

        if causal_mask:
            mask_LL = jnp.triu(jnp.ones((L, L)), k=1).astype(bool)
            attn_logits_BHLL = jnp.where(mask_LL, -jnp.inf, attn_logits_BHLL)

        attn_weights_BHLL = nnx.softmax(attn_logits_BHLL, axis=-1)

        out_BHLK = jnp.einsum('BHLM,BHMK->BHLK', attn_logits_BHLL, v_BHLK)

        out_BLD = einops.rearrange(out_BHLK, 'B H L K -> B L (H K)')

        out_BLD = self.out_proj(out_BLD)

        return out_BLD

# Test attention

In [88]:
def test_self_attention_shape():
    B, L, D, H = 8, 10, 64, 8
    rngs = nnx.Rngs({'params': jax.random.PRNGKey(0)})

    attention_layer = SelfAttention(n_heads=H, d_embed=D, rngs=rngs)
    x_BLD = jax.random.normal(jax.random.PRNGKey(1), (B, L, D))
    output = attention_layer(x_BLD, causal_mask=True)
    assert output.shape == (B, L, D), f"Expected shape {(B, L, D)}, but got {output.shape}"
    print("SelfAttention shape test passed!")

# Run the shape test
test_self_attention_shape()


SelfAttention shape test passed!


In [89]:
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import einops

"""
Dimension key:
(missing!)
"""

class ResidualBlock:
    def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
        self.gn1 = nnx.GroupNorm(num_features=in_features, num_groups=32, rngs=rngs)
        self.conv1 = nnx.Conv(in_features=in_features, out_features=out_features, kernel_size=(3,3), padding=1, rngs=rngs)

        self.gn2 = nnx.GroupNorm(num_features=out_features, num_groups=32, rngs=rngs)
        self.conv2 = nnx.Conv(in_features=out_features, out_features=out_features, kernel_size=(3,3), padding=1, rngs=rngs)

        if in_features != out_features:
            self.shortcut = nnx.Conv(in_features=in_features, out_features=out_features, kernel_size=(1,1), padding=0, rngs=rngs)
        else:
            self.shortcut = lambda x: x

    def __call__(self, x):
        residual = x

        x = self.gn1(x)
        x = nnx.silu(x)
        x = self.conv1(x)
        x = self.gn2(x)
        x = nnx.silu(x)
        x = self.conv2(x)

        return x + self.shortcut(residual)

class AttentionBlock:
    def __init__(self, num_features: int, rngs: nnx.Rngs):
        self.gn = nnx.GroupNorm(num_features=num_features, num_groups=32, rngs=rngs)
        self.attention = SelfAttention(n_heads=1, d_embed=num_features, rngs=rngs)

    def __call__(self, x):
        # recall that JAX expects (B, H, W, C)
        residual = x

        x = self.gn(x)

        B, H, W, C = jnp.shape(x)

        x = einops.rearrange(x, 'B H W C -> B (H W) C')
        x = self.attention(x)
        x = einops.rearrange(x, 'B (H W) C -> B H W C', H=H, W=W)
        x += residual

        return x


# Test residual block and attention block

In [90]:
def test_residual_block_shape():
    B, H, W, in_features, out_features = 8, 32, 32, 64, 128  # Example: batch size 8, height and width 32, input channels 64, output channels 128
    rngs = nnx.Rngs({'params': jax.random.PRNGKey(0)})
    
    # Instantiate the ResidualBlock
    residual_block = ResidualBlock(in_features=in_features, out_features=out_features, rngs=rngs)

    # Random input tensor with shape (B, H, W, in_features)
    x = jax.random.normal(jax.random.PRNGKey(1), (B, H, W, in_features))

    # Forward pass through the ResidualBlock
    output = residual_block(x)

    # Check the shape of the output
    assert output.shape == (B, H, W, out_features), f"Expected shape {(B, H, W, out_features)}, but got {output.shape}"

    print("ResidualBlock shape test passed!")

# Run the shape test
test_residual_block_shape()


ResidualBlock shape test passed!


In [91]:
def test_attention_block_shape():
    B, H, W, C = 8, 32, 32, 64 
    rngs = nnx.Rngs({'params': jax.random.PRNGKey(0)})
    attention_block = AttentionBlock(num_features=C, rngs=rngs)
    x = jax.random.normal(jax.random.PRNGKey(1), (B, H, W, C))
    output = attention_block(x)
    assert output.shape == (B, H, W, C), f"Expected shape {(B, H, W, C)}, but got {output.shape}"

    print("AttentionBlock shape test passed!")

# Run the shape test
test_attention_block_shape()

AttentionBlock shape test passed!


# Test encoder and decoder

In [92]:
class Encoder(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # 3 -> 128 -> (6 x 128)
        self.conv1 = nnx.Conv(in_features=3, out_features=128, kernel_size=(3, 3), padding=1, rngs=rngs)

        self.res1 = ResidualBlock(in_features=128, out_features=128, rngs=rngs)
        self.res2 = ResidualBlock(in_features=128, out_features=128, rngs=rngs)

        self.conv2 = nnx.Conv(in_features=128, out_features=128, kernel_size=(3, 3), strides=(2, 2), padding=0, rngs=rngs)

        # 128 -> 256 -> (4 x 256)
        self.res3 = ResidualBlock(in_features=128, out_features=256, rngs=rngs)
        self.res4 = ResidualBlock(in_features=256, out_features=256, rngs=rngs)

        self.conv3 = nnx.Conv(in_features=256, out_features=256, kernel_size=(3, 3), strides=(2, 2), padding=0, rngs=rngs)

        # 256 -> 512 -> (10 x 512)
        self.res5 = ResidualBlock(in_features=256, out_features=512, rngs=rngs)
        self.res6 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)

        self.conv4 = nnx.Conv(in_features=512, out_features=512, kernel_size=(3, 3), strides=(2, 2), padding=0, rngs=rngs)

        self.res7 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res8 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res9 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)

        self.attention = AttentionBlock(num_features=512, rngs=rngs)

        # 512 -> 512
        self.res10 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)

        self.groupnorm = nnx.GroupNorm(num_features=512, num_groups=32, rngs=rngs)

        # 512 -> 8 -> (2 x 8)
        self.conv5 = nnx.Conv(in_features=512, out_features=8, kernel_size=(3, 3), padding=1, rngs=rngs)
        self.conv6 = nnx.Conv(in_features=8, out_features=8, kernel_size=(1, 1), padding=0, rngs=rngs)

    def __call__(self, x, noise):
        """
        Args:
        x: (B, H, W, C=3)
        noise: (B, H/8, W/8, C=4)

        Returns:
        latent (z): (B, H/8, W/8, 4)
        """

        # Q: is there a better way to do this?
        for module in [self.conv1, self.res1, self.res2, self.conv2, self.res3, self.res4, self.conv3,
                       self.res5, self.res6, self.conv4, self.res7, self.res8, self.res9, self.attention,
                       self.res10, self.groupnorm]:

            if isinstance(module, nnx.Conv) and module.strides == (2, 2):
                x = pad_asymmetric(x,  padding=(1, 1))
            x = module(x)

        x = nnx.silu(x)
        x = self.conv5(x)
        x = self.conv6(x)

        if x.shape[-1] % 2 != 0:
            raise ValueError("Final output channels must be divisible by 2 to split mean and log_var")

        mean, log_var = jnp.split(x, 2, axis=-1)
        log_var = jnp.clip(log_var, -30, 20)
        variance = jnp.exp(log_var)
        stdev = jnp.sqrt(variance)

        x = mean + stdev * noise
        x *= 0.18215

        return x

class Decoder(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # Initial Convolutions and Blocks
        self.conv1 = nnx.Conv(in_features=4, out_features=4, kernel_size=(1, 1), padding=0, rngs=rngs)
        self.conv2 = nnx.Conv(in_features=4, out_features=512, kernel_size=(3, 3), padding=1, rngs=rngs)
        self.res1 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.attention = AttentionBlock(num_features=512, rngs=rngs)
        self.res2 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res3 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res4 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res5 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)

        # Upsampling Blocks
        self.conv3 = nnx.Conv(in_features=512, out_features=512, kernel_size=(3, 3), padding=1, rngs=rngs)
        self.res6 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res7 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)
        self.res8 = ResidualBlock(in_features=512, out_features=512, rngs=rngs)

        # Upsampling Blocks
        self.conv4 = nnx.Conv(in_features=512, out_features=256, kernel_size=(3, 3), padding=1, rngs=rngs)
        self.res9 = ResidualBlock(in_features=256, out_features=256, rngs=rngs)
        self.res10 = ResidualBlock(in_features=256, out_features=256, rngs=rngs)
        self.res11 = ResidualBlock(in_features=256, out_features=256, rngs=rngs)

        # Upsampling Blocks
        self.conv5 = nnx.Conv(in_features=256, out_features=128, kernel_size=(3, 3), padding=1, rngs=rngs)
        self.res12 = ResidualBlock(in_features=128, out_features=128, rngs=rngs)
        self.res13 = ResidualBlock(in_features=128, out_features=128, rngs=rngs)
        self.res14 = ResidualBlock(in_features=128, out_features=128, rngs=rngs)

        # Final Layers
        self.groupnorm = nnx.GroupNorm(num_features=128, num_groups=32, rngs=rngs)
        self.conv_out = nnx.Conv(in_features=128, out_features=3, kernel_size=(3, 3), padding=1, rngs=rngs)

    def __call__(self, x):
        """
        Args:
            x: (B, H/8, W/8, C=4)

        Returns:
            decoded latent (x): (B, H, W, C=3)
        """
        x /= 0.18215  # Unscale

        # Define where to upsample
        upsamplings = {7, 11, 15}

        # List of modules
        modules = [self.conv1, self.conv2, self.res1, self.attention,
                   self.res2, self.res3, self.res4, self.res5,
                   self.conv3, self.res6, self.res7, self.res8,
                   self.conv4, self.res9, self.res10, self.res11,
                   self.conv5, self.res12, self.res13, self.res14]

        for i, module in enumerate(modules):
            x = module(x)
            if i in upsamplings:
                x = upsample(x, scale_factor=2)

        x = self.groupnorm(x)
        x = nnx.silu(x)
        x = self.conv_out(x)

        return x

def pad_asymmetric(x, padding):
    pad_height = padding[0]
    pad_width = padding[1]
    x = jnp.pad(x, ((0, 0), (pad_height, pad_height), (pad_width, pad_width), (0, 0)), mode='constant')
    return x

def upsample(x, scale_factor):
    B, H, W, C = jnp.shape(x)
    out_shape = (B, H*scale_factor, W*scale_factor, C)
    x = jax.image.resize(x, out_shape, method="nearest")
    return x

In [93]:
def test_encoder_shape():
    B, H, W, C = 8, 256, 256, 3  
    noise_channels = 4 

    rngs = nnx.Rngs({'params': jax.random.PRNGKey(0)})
    encoder = Encoder(rngs=rngs)
    x = jax.random.normal(jax.random.PRNGKey(1), (B, H, W, C))  # Input image tensor
    noise = jax.random.normal(jax.random.PRNGKey(2), (B, H // 8, W // 8, noise_channels))  # Noise tensor
    latent = encoder(x, noise)
    assert latent.shape == (B, H // 8, W // 8, noise_channels), f"Expected shape {(B, H // 8, W // 8, noise_channels)}, but got {latent.shape}"

    print("Encoder shape test passed!")

test_encoder_shape()


Encoder shape test passed!


In [94]:
def test_decoder_shape():
    B, H, W, C = 8, 32, 32, 4 
    rngs = nnx.Rngs({'params': jax.random.PRNGKey(0)})
    decoder = Decoder(rngs=rngs)
    x = jax.random.normal(jax.random.PRNGKey(1), (B, H, W, C))
    decoded = decoder(x)
    expected_H = H * 8
    expected_W = W * 8
    expected_C = 3 
    assert decoded.shape == (B, expected_H, expected_W, expected_C), f"Expected shape {(B, expected_H, expected_W, expected_C)}, but got {decoded.shape}"

    print("Decoder shape test passed!")

test_decoder_shape()


Decoder shape test passed!


# CLIP test

In [103]:
from typing import List

class CLIPEmbedding(nnx.Module):
    def __init__(self, V: int, D: int, T: int, rngs: nnx.Rngs):
        # Token embeddings: shape (V, D)
        self.token_embedding_VD = nnx.Embed(
            num_embeddings=V, features=D, rngs=rngs
        )
        # Positional embeddings: shape (T, D)
        self.position_embedding_TD = nnx.Param(
            jnp.zeros((T, D))
        )

    def __call__(self, tokens_BL):
        """
        Args:
            tokens_BL: Input token IDs, shape (B, L)
        Returns:
            embeddings_BLD: Token embeddings with positional encoding, shape (B, L, D)
        """
        embeddings_BLD = self.token_embedding_VD(tokens_BL)  # Shape: (B, L, D)
        position_embeddings_LD = self.position_embedding_TD[:embeddings_BLD.shape[1], :]  # Shape: (L, D)
        embeddings_BLD += position_embeddings_LD  # Broadcasting over batch dimension
        return embeddings_BLD

class SelfAttention(nnx.Module):
    def __init__(self, H: int, D: int, rngs: nnx.Rngs):
        self.H = H  # Number of attention heads
        self.D = D  # Embedding dimension
        self.K = D // H  # Dimension per head

        assert D % H == 0, "Embedding dimension D must be divisible by number of heads H"

        rng_qkv, rng_out = rngs.split(2)
        self.qkv_proj_D3D = nnx.Linear(
            in_features=D, out_features=3 * D, use_bias=False, rngs=rng_qkv
        )
        self.out_proj_DD = nnx.Linear(
            in_features=D, out_features=D, rngs=rng_out
        )

    def __call__(self, input_BLD, causal_mask=True):
        """
        Args:
            input_BLD: Input tensor, shape (B, L, D)
            causal_mask: Whether to apply causal masking
        Returns:
            output_BLD: Output tensor after self-attention, shape (B, L, D)
        """
        B, L, D = input_BLD.shape
        H, K = self.H, self.K

        qkv_BL3D = self.qkv_proj_D3D(input_BLD)  # Shape: (B, L, 3*D)
        qkv_BL3HK = qkv_BL3D.reshape(B, L, 3, H, K)  # Shape: (B, L, 3, H, K)
        qkv_3BHLD = qkv_BL3HK.transpose(2, 0, 3, 1, 4)  # Shape: (3, B, H, L, K)
        q_BHLK, k_BHLK, v_BHLK = qkv_3BHLD  # Each has shape: (B, H, L, K)

        # Compute attention scores
        attn_logits_BHLL = jnp.einsum("BHLK,BHMK->BHLM", q_BHLK, k_BHLK) / jnp.sqrt(K)
        # attn_logits_BHLL: (B, H, L, L)

        if causal_mask:
            mask_LL = jnp.tril(jnp.ones((L, L), dtype=bool))
            attn_logits_BHLL = jnp.where(mask_LL, attn_logits_BHLL, -1e10)

        attn_weights_BHLL = nnx.softmax(attn_logits_BHLL, axis=-1)
        attn_output_BHLK = jnp.einsum("BHLM,BHMK->BHLK", attn_weights_BHLL, v_BHLK)
        attn_output_BLHK = attn_output_BHLK.transpose(0, 2, 1, 3)  # Shape: (B, L, H, K)
        attn_output_BLD = attn_output_BLHK.reshape(B, L, D)  # Shape: (B, L, D)

        output_BLD = self.out_proj_DD(attn_output_BLD)  # Shape: (B, L, D)
        return output_BLD

class CLIPLayer(nnx.Module):
    def __init__(self, H: int, D: int, rngs: nnx.Rngs):
        # Pre-attention layer normalization
        self.layernorm_1 = nnx.LayerNorm()

        # Self-attention module
        self.attention = SelfAttention(H, D, rngs=rngs)

        # Pre-FFN layer normalization
        self.layernorm_2 = nnx.LayerNorm()

        # Feedforward layers
        self.linear_1_DF = nnx.Linear(
            in_features=D, out_features=4 * D, rngs=rngs
        )  # Shape: (D, F)
        self.linear_2_FD = nnx.Linear(
            in_features=4 * D, out_features=D, rngs=rngs
        )  # Shape: (F, D)

    def __call__(self, input_BLD):
        """
        Args:
            input_BLD: Input tensor, shape (B, L, D)
        Returns:
            output_BLD: Output tensor, shape (B, L, D)
        """
        # Residual connection for self-attention
        residue_BLD = input_BLD

        # Pre-Attention LayerNorm
        normalized_BLD = self.layernorm_1(input_BLD)

        # Self-Attention
        attn_out_BLD = self.attention(normalized_BLD, causal_mask=True)

        # Add residual connection
        x_BLD = attn_out_BLD + residue_BLD

        # Residual connection for feedforward network
        residue_BLD = x_BLD

        # Pre-FFN LayerNorm
        normalized_BLD = self.layernorm_2(x_BLD)

        # Feedforward network
        ff1_BLF = self.linear_1_DF(normalized_BLD)  # Shape: (B, L, F)
        # QuickGELU activation
        ff1_BLF = ff1_BLF * nnx.sigmoid(1.702 * ff1_BLF)
        ff2_BLD = self.linear_2_FD(ff1_BLF)  # Shape: (B, L, D)

        # Add residual connection
        output_BLD = ff2_BLD + residue_BLD

        return output_BLD

class CLIP(nnx.Module):
    def __init__(self, V: int, D: int, T: int, H: int, n_layer: int, rngs: nnx.Rngs):

        # Embedding layer
        self.embedding = CLIPEmbedding(V, D, T, rngs=rngs)

        # Encoder layers
        self.layers: List[CLIPLayer] = []
        for i in range(n_layer):
            self.layers.append(CLIPLayer(H, D, rngs=rngs))

        # Final layer normalization
        self.layernorm = nnx.LayerNorm()

    def __call__(self, tokens_BL):
        """
        Args:
            tokens_BL: Input token IDs, shape (B, L)
        Returns:
            output_BLD: Output tensor, shape (B, L, D)
        """
        tokens_BL = tokens_BL.astype(jnp.int32)

        # Embedding
        x_BLD = self.embedding(tokens_BL)

        # Apply encoder layers
        for layer in self.layers:
            x_BLD = layer(x_BLD)

        # Final layer normalization
        output_BLD = self.layernorm(x_BLD)

        return output_BLD  # Shape: (B, L, D)


In [104]:
rngs = nnx.Rngs({'params': jax.random.PRNGKey(0)})

# Define model parameters
V = 49408     # Vocabulary size
D = 768       # Embedding dimension
T = 77        # Maximum sequence length
H = 12        # Number of attention heads
n_layer = 12  # Number of layers

# Instantiate the model
model = CLIP(
    V=V,
    D=D,
    T=T,
    H=H,
    n_layer=n_layer,
    rngs=rngs
)

# Example input tokens_BL (batch size B=2, sequence length L=77)
tokens_BL = jnp.array([
    [1, 2, 3] + [0] * (T - 3),
    [4, 5, 6] + [0] * (T - 3)
])  # Shape: (B=2, L=77)

# Forward pass
output_BLD = model(tokens_BL)  # Output shape: (B=2, L=77, D=768)
print(f"Output shape: {output_BLD.shape}")  # Should print (2, 77, 768)


TypeError: LayerNorm.__init__() missing 1 required positional argument: 'num_features'