In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.random import PRNGKey

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        # x is expected to have shape (batch, 28, 28, 1).
        # Convolution: from 1 input channel to 3 output channels, kernel size 4.
        x = nn.Conv(features=3, kernel_size=(4, 4),
                    kernel_init=nn.initializers.lecun_normal())(x)
        # Max pooling with a 2x2 window (using default strides equal to the window)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
        # Apply ReLU activation.
        x = nn.relu(x)
        # Flatten all dimensions except for the batch dimension.
        x = x.reshape((x.shape[0], -1))
        # At this point in the Equinox model, the flattened features have size 1728.
        # Dense layer: from flattened features to 512 features.
        x = nn.Dense(features=512,
                     kernel_init=nn.initializers.lecun_normal())(x)
        # Sigmoid activation.
        x = nn.sigmoid(x)
        # Dense layer: from 512 features to 64 features.
        x = nn.Dense(features=64, 
                     kernel_init=nn.initializers.lecun_normal())(x)
        # ReLU activation.
        x = nn.relu(x)
        # Final Dense layer: from 64 features to 10 output classes.
        x = nn.Dense(features=10,
                     kernel_init=nn.initializers.lecun_normal())(x)
        # Apply log softmax for a log-probability output.
        x = nn.log_softmax(x)
        return x



In [2]:
# Example usage:
if __name__ == "__main__":
    # Create a random key
    rng = PRNGKey(0)
    
    # Instantiate the model; no parameters are stored on the instance.
    cnn = CNN()

    # Create a dummy input with a batch dimension.
    # For MNIST, each image is 28x28. Flax typically expects a 4D tensor:
    # (batch, height, width, channels). If your input is (1, 28, 28), add a channel dim.
    dummy_input = jnp.ones((1, 28, 28, 1))
    
    # Initialize the parameters of the model.
    params = cnn.init(rng, dummy_input)
    
    # Apply the model to the dummy input.
    logits = cnn.apply(params, dummy_input)
    print("Logits shape:", logits.shape)  # Expected output: (1, 10)


Logits shape: (1, 10)


In [3]:
logits

Array([[-2.5710626, -2.173689 , -3.0893073, -2.694003 , -2.450419 ,
        -1.308671 , -2.7668211, -2.6447802, -2.6483612, -1.9983457]],      dtype=float32)