In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)

def train_test_loop(model, train_loader, test_loader, epochs=10):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        for image, label in train_loader:
            pred = model(image)
            loss = criterion(pred, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Training loss at epoch {epoch} = {loss.item()}")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for image_test, label_test in test_loader:
            pred_test = model(image_test)
            _, pred_test_vals = torch.max(pred_test, dim=1)
            total += label_test.size(0)
            correct += (pred_test_vals == label_test).sum().item()
    print(f"Test Accuracy = {(correct * 100)/total}")

class VanillaCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64*16*16, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def config_init(init_type="kaiming"):

    def kaiming_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)


    def xavier_init(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def zeros_init(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def random_init(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.normal_(m.weight)
            nn.init.normal_(m.bias)


    initializer_dict = {"kaiming": kaiming_init,
                        "xavier": xavier_init,
                        "zeros": zeros_init,
                        "random": random_init}

    return initializer_dict.get(init_type)

for name, model in zip(["Vanilla", "Kaiming", "Xavier", "Zeros", "Random"], [VanillaCNNModel(),
              VanillaCNNModel().apply(config_init("kaiming")),
              VanillaCNNModel().apply(config_init("xavier")),
              VanillaCNNModel().apply(config_init("zeros")),
              VanillaCNNModel().apply(config_init("random"))
              ]):
    print(f"_________{name}_______________________")
    train_test_loop(model, train_loader, test_loader)


In [2]:
## STRONG LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

# Define the CNN model with configurable initializers.
class CNNModel(nn.Module):
    conv_kernel_init: callable = nn.initializers.lecun_normal()
    conv_bias_init: callable = nn.initializers.zeros
    dense_kernel_init: callable = nn.initializers.lecun_normal()
    dense_bias_init: callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, x):
        # First convolution layer: 3 channels -> 32 filters.
        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding='SAME',
                    kernel_init=self.conv_kernel_init, bias_init=self.conv_bias_init)(x)
        x = nn.relu(x)
        # Second convolution layer: 32 -> 64 filters.
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding='SAME',
                    kernel_init=self.conv_kernel_init, bias_init=self.conv_bias_init)(x)
        x = nn.relu(x)
        # Max pooling: 2x2 window, stride 2.
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
        # Flatten.
        x = x.reshape((x.shape[0], -1))
        # Fully connected layer with 128 units.
        x = nn.Dense(features=128, kernel_init=self.dense_kernel_init, bias_init=self.dense_bias_init)(x)
        x = nn.relu(x)
        # Final fully connected layer with 10 outputs.
        x = nn.Dense(features=10, kernel_init=self.dense_kernel_init, bias_init=self.dense_bias_init)(x)
        return x

# Create a training state with the model and an Adam optimizer.
def create_train_state(rng, model, learning_rate=0.001, batch_size=32):
    dummy_input = jnp.ones([batch_size, 32, 32, 3], jnp.float32)
    params = model.init(rng, dummy_input)
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Define a jitted training step.
@jax.jit
def train_step(state, images, labels):
    def loss_fn(params):
        logits = state.apply_fn(params, images)
        onehot = jax.nn.one_hot(labels, 10)
        loss = optax.softmax_cross_entropy(logits, onehot).mean()
        return loss
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Evaluate the model on the test dataset.
def evaluate_model(state, test_ds):
    correct = 0
    total = 0
    for images, labels in tfds.as_numpy(test_ds):
        logits = state.apply_fn(state.params, images)
        predictions = jnp.argmax(logits, axis=-1)
        correct += np.sum(np.array(predictions) == np.array(labels))
        total += len(labels)
    accuracy = 100 * correct / total
    print(f"Test Accuracy = {accuracy:.2f}%")
    return accuracy

# Preprocessing: convert images to float and normalize to [-1, 1]
def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0  # Scale to [0, 1]
    image = (image - 0.5) / 0.5                 # Normalize to [-1, 1]
    return image, label

def main():
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.001
    rng = jax.random.PRNGKey(0)

    # Load CIFAR-10 training dataset in supervised mode.
    train_ds = tfds.load('cifar10', split='train', as_supervised=True, download=True)
    train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # Load CIFAR-10 test dataset in supervised mode.
    test_ds = tfds.load('cifar10', split='test', as_supervised=True, download=True)
    test_ds = test_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    test_ds = test_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # Define the initializer configurations.
    initializer_configs = {
        "Vanilla": {  # Use Flax defaults.
            "conv_kernel_init": nn.initializers.lecun_normal(),
            "dense_kernel_init": nn.initializers.lecun_normal(),
        },
        "Kaiming": {
            "conv_kernel_init": nn.initializers.kaiming_normal(),
            "dense_kernel_init": nn.initializers.kaiming_normal(),
        },
        "Xavier": {
            "conv_kernel_init": nn.initializers.xavier_normal(),
            "dense_kernel_init": nn.initializers.xavier_normal(),
        },
        "Zeros": {
            "conv_kernel_init": nn.initializers.zeros,
            "dense_kernel_init": nn.initializers.zeros,
        },
        "Random": {
            "conv_kernel_init": nn.initializers.normal(stddev=1.0),
            "dense_kernel_init": nn.initializers.normal(stddev=1.0),
        }
    }

    # Loop over the initialization schemes.
    for name, init_conf in initializer_configs.items():
        print(f"_________{name}_______________________")
        model = CNNModel(conv_kernel_init=init_conf["conv_kernel_init"],
                         dense_kernel_init=init_conf["dense_kernel_init"])
        state = create_train_state(rng, model, learning_rate, batch_size)

        # Training loop.
        for epoch in range(num_epochs):
            for images, labels in tfds.as_numpy(train_ds):
                state, loss = train_step(state, images, labels)
            print(f"Training loss at epoch {epoch} = {loss:.4f}")

        # Evaluate the model on the test set.
        evaluate_model(state, test_ds)

if __name__ == "__main__":
    main()




Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/cifar10/3.0.2...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cifar10/incomplete.QVYSA4_3.0.2/cifar10-train.tfrecord*...:   0%|         …

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cifar10/incomplete.QVYSA4_3.0.2/cifar10-test.tfrecord*...:   0%|          …

Dataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.
_________Vanilla_______________________
Training loss at epoch 0 = 1.1228
Training loss at epoch 1 = 1.0812
Training loss at epoch 2 = 0.5941
Training loss at epoch 3 = 0.1359
Training loss at epoch 4 = 0.0490


KeyboardInterrupt: 

In [4]:
## Weak LLM
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import flax.linen as nn
from flax.training import train_state
import optax

# Constants
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
BATCH_SIZE = 32
NUM_CLASSES = 10
INPUT_SHAPE = (28, 28, 1)

# Define model (VanillaCNNModel is assumed to be defined elsewhere)
class VanillaCNNModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Define the forward pass here
        pass

def create_train_state(rng, model, learning_rate):
    # Initialize the model parameters
    params = model.init(rng, jnp.ones((1, *INPUT_SHAPE)))  # MODIFIED: Input shape for initialization
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jit
def loss_fn(params, x, y):
    # Compute the loss function
    logits = model.apply(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))  # MODIFIED: Use optax for loss
    return loss

@jit
def compute_gradients(params, x, y):
    # Compute gradients
    return grad(loss_fn)(params, x, y)

def update(params, grads):
    # Update parameters
    return optax.apply_updates(params, grads)  # MODIFIED: Use functional update

def train_model(x_train, y_train, num_epochs, batch_size):
    rng = random.PRNGKey(0)  # PRNG key for reproducibility
    model = VanillaCNNModel()
    state = create_train_state(rng, model, learning_rate=LEARNING_RATE)

    for epoch in range(num_epochs):
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i:i + batch_size]
            y_batch = y_train[i:i + batch_size]

            grads = compute_gradients(state.params, x_batch, y_batch)
            state = state.apply_gradients(grads=grads)  # MODIFIED: Use functional updates to apply gradients

    return state.params  # Return final weights

def main():
    # Sample training data (x_train, y_train should be defined appropriately)
    x_train = jnp.ones((100, *INPUT_SHAPE))  # Placeholder, replace with actual data
    y_train = jax.nn.one_hot(jnp.zeros(100), num_classes=NUM_CLASSES)  # Placeholder, replace with actual labels

    final_weights = train_model(x_train, y_train, NUM_EPOCHS, BATCH_SIZE)
    print('Final weights:', final_weights)  # Display final weights after training

if __name__ == "__main__":
    main()


TypeError: Module.apply() takes 2 positional arguments but 3 were given

In [None]:
"""
Error Code
 def loss_fn(params, x, y):
---> logits = model.apply(params, x)

Error:
Module.apply() takes 2 positional arguments but 3 were given


Fix Guide:
the model's parameters must be wrapped inside a dictionary, which
ensures that the apply method receives the parameters in the expected format.

Correct Code:
logits = model.apply({'params': params}, x)
"""

"""####
Error Code
logits = model.apply({'params': params}, x)


Error
ApplyScopeInvalidVariablesStructureError(variables)

Fix Guide:
Since it is wrapping variables in an extra layer of "params"

Correct Code:
logits = model.apply(params, x)

"""

"""##
Error Code
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))

Error:
TypeError: log_softmax requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.

Fix Guide
The error indicates that the logits passed to the softmax cross‑entropy
function are None. This happens because your model’s
forward method doesn’t return anything—it ends with a pass statement.
As a result, when you call: logits = model.apply(params, x), logits is None,
which then causes the error in optax.softmax_cross_entropy.
To fix this, you need to ensure your model returns a value.

Correct Code:
class VanillaCNNModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Define the forward pass here
        x = nn.Conv(features=32, kernel_size=(3,3), padding='SAME')(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x

"""



In [10]:
## Fixed Code
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import flax.linen as nn
from flax.training import train_state
import optax

# Constants
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
BATCH_SIZE = 32
NUM_CLASSES = 10
INPUT_SHAPE = (28, 28, 1)

# Define model (VanillaCNNModel is assumed to be defined elsewhere)
class VanillaCNNModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Define the forward pass here
        x = nn.Conv(features=32, kernel_size=(3,3), padding='SAME')(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x

def create_train_state(rng, model, learning_rate):
    # Initialize the model parameters
    params = model.init(rng, jnp.ones((1, *INPUT_SHAPE)))  # MODIFIED: Input shape for initialization
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jit
def loss_fn(params, x, y):
    # Compute the loss function
    model = VanillaCNNModel()
    logits = model.apply(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))
    return loss

@jit
def compute_gradients(params, x, y):
    # Compute gradients
    return grad(loss_fn)(params, x, y)

def update(params, grads):
    # Update parameters
    return optax.apply_updates(params, grads)  # MODIFIED: Use functional update

def train_model(x_train, y_train, num_epochs, batch_size):
    rng = random.PRNGKey(0)  # PRNG key for reproducibility
    model = VanillaCNNModel()
    state = create_train_state(rng, model, learning_rate=LEARNING_RATE)

    for epoch in range(num_epochs):
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i:i + batch_size]
            y_batch = y_train[i:i + batch_size]

            grads = compute_gradients(state.params, x_batch, y_batch)
            state = state.apply_gradients(grads=grads)  # MODIFIED: Use functional updates to apply gradients

    return state.params  # Return final weights

def main():
    # Sample training data (x_train, y_train should be defined appropriately)
    x_train = jnp.ones((100, *INPUT_SHAPE))  # Placeholder, replace with actual data
    y_train = jax.nn.one_hot(jnp.zeros(100), num_classes=NUM_CLASSES)  # Placeholder, replace with actual labels
    model = VanillaCNNModel()
    final_weights = train_model(x_train, y_train, NUM_EPOCHS, BATCH_SIZE)
    print('Final weights:', final_weights)  # Display final weights after training

if __name__ == "__main__":
    main()


Final weights: {'params': {'Conv_0': {'bias': Array([-0.00589933,  0.00589946, -0.00589931,  0.        ,  0.        ,
        0.00589939,  0.00589934,  0.00589948,  0.00589937, -0.00589925,
        0.00589934,  0.00589936, -0.00589932,  0.00589937, -0.00589933,
        0.        ,  0.00589935, -0.0058993 ,  0.0058994 , -0.00589933,
        0.        ,  0.00589951,  0.00589934,  0.00589934, -0.00589933,
       -0.00589933, -0.00589932, -0.00589933,  0.00589936, -0.00589932,
        0.        ,  0.        ], dtype=float32), 'kernel': Array([[[[ 2.30005816e-01, -9.06157047e-02,  1.49404481e-01,
          -4.02153790e-01,  1.22672096e-01,  2.82673627e-01,
          -8.79966170e-02, -1.54289052e-01,  1.99353799e-01,
          -1.80342361e-01,  3.05014495e-02,  3.59651685e-01,
          -5.17202854e-01,  9.60292146e-02, -2.65478969e-01,
           1.55264989e-01, -4.04435843e-02,  2.88439449e-03,
          -2.06540637e-02,  1.31564915e-01,  5.71857914e-02,
          -6.87380672e-01,  3.24100