In [None]:
## Original Code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import quantize_dynamic

# Define a simple Language Model (e.g., an LSTM-based model)
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(LanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, (hidden, cell) = self.lstm(embedded)
        output = self.fc(lstm_out[:, -1, :])  # Use the last hidden state for prediction
        return self.softmax(output)

# Create synthetic training data
torch.manual_seed(42)
vocab_size = 50
seq_length = 10
batch_size = 32
X_train = torch.randint(0, vocab_size, (batch_size, seq_length))  # Random integer input
y_train = torch.randint(0, vocab_size, (batch_size,))  # Random target words

# Initialize the model, loss function, and optimizer
embed_size = 64
hidden_size = 128
num_layers = 2
model = LanguageModel(vocab_size, embed_size, hidden_size, num_layers)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    output = model(X_train)
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()

    # Log progress every epoch
    print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {loss.item():.4f}")

# Now, we will quantize the model dynamically to reduce its size and improve inference speed
# Quantization: Apply dynamic quantization to the language model
quantized_model = quantize_dynamic(model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)

# Save the quantized model
torch.save(quantized_model.state_dict(), "quantized_language_model.pth")

# Load the quantized model and test it
quantized_model = LanguageModel(vocab_size, embed_size, hidden_size, num_layers)

# Apply dynamic quantization on the model after defining it
quantized_model = quantize_dynamic(quantized_model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)

# quantized_model.load_state_dict(torch.load("quantized_language_model.pth"))
quantized_model.eval()
test_input = torch.randint(0, vocab_size, (1, seq_length))
with torch.no_grad():
    prediction = quantized_model(test_input)
    print(f"Prediction for input {test_input.tolist()}: {prediction.argmax(dim=1).item()}")

Epoch [1/5] - Loss: 3.9118
Epoch [2/5] - Loss: 3.9113
Epoch [3/5] - Loss: 3.9108
Epoch [4/5] - Loss: 3.9103
Epoch [5/5] - Loss: 3.9097
Prediction for input [[15, 28, 33, 19, 37, 24, 48, 42, 33, 35]]: 46


In [None]:
## Strong LLM
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import pickle

# ---------------------------------------------------------------------
# Define an LSTM stack that processes the sequence step by step.
# We use Flax’s LSTMCell and a Python loop to build a multi-layer LSTM.
# ---------------------------------------------------------------------
class LSTMStack(nn.Module):
    hidden_size: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq_length, embed_size)
        batch_size = x.shape[0]
        # Initialize LSTM states (carry, hidden) for each layer.
        states = [
            nn.LSTMCell.initialize_carry(self.make_rng('lstm'), (batch_size,), self.hidden_size)
            for _ in range(self.num_layers)
        ]
        out = None
        seq_length = x.shape[1]
        # Process each time step sequentially.
        for t in range(seq_length):
            inp = x[:, t, :]
            new_states = []
            for i in range(self.num_layers):
                # Create an LSTM cell for layer i (parameters are registered by name).
                lstm_cell = nn.LSTMCell(name=f'lstm_cell_{i}', hidden_size=self.hidden_size)
                # Update state and get output.
                states[i], out = lstm_cell(states[i], inp)
                # For next layer, use the output from the current layer.
                inp = out
                new_states.append(states[i])
            states = new_states
        # Return the output of the last time step (from the last layer).
        return out

# ---------------------------------------------------------------------
# Define the LanguageModel using Flax modules.
# It embeds the input tokens, processes them through the LSTM stack,
# applies a Dense layer, and returns softmax probabilities.
# ---------------------------------------------------------------------
class LanguageModel(nn.Module):
    vocab_size: int
    embed_size: int
    hidden_size: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        # x has shape (batch, seq_length) containing token indices.
        x_embed = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_size)(x)
        lstm_out = LSTMStack(hidden_size=self.hidden_size, num_layers=self.num_layers)(x_embed)
        logits = nn.Dense(features=self.vocab_size)(lstm_out)
        probabilities = nn.softmax(logits)
        return probabilities

# ---------------------------------------------------------------------
# Create synthetic training data (similar to torch.randint).
# ---------------------------------------------------------------------
key = jax.random.PRNGKey(42)
vocab_size = 50
seq_length = 10
batch_size = 32
X_train = jax.random.randint(key, (batch_size, seq_length), 0, vocab_size)
key, subkey = jax.random.split(key)
y_train = jax.random.randint(subkey, (batch_size,), 0, vocab_size)

# ---------------------------------------------------------------------
# Initialize the model, loss, and optimizer.
# ---------------------------------------------------------------------
embed_size = 64
hidden_size = 128
num_layers = 2
model = LanguageModel(
    vocab_size=vocab_size,
    embed_size=embed_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)

# Initialize model parameters.
# Note: We pass two PRNG keys – one for parameters and one for the LSTM initialization.
variables = model.init({'params': key, 'lstm': key}, X_train)
params = variables['params']

# Define a simple cross-entropy loss.
def loss_fn(params, x, y):
    preds = model.apply({'params': params}, x)
    # Compute the negative log likelihood for the true classes.
    loss = -jnp.mean(jnp.log(preds[jnp.arange(preds.shape[0]), y] + 1e-7))
    return loss

# Set up the Adam optimizer using Optax.
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# ---------------------------------------------------------------------
# Training loop (5 epochs).
# ---------------------------------------------------------------------
epochs = 5
for epoch in range(epochs):
    params, opt_state, loss_val = train_step(params, opt_state, X_train, y_train)
    print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {loss_val:.4f}")

# ---------------------------------------------------------------------
# Quantization: Simulate dynamic quantization by converting parameters to int8 and back.
# This is only a simple simulation; JAX does not offer PyTorch-like dynamic quantization.
# ---------------------------------------------------------------------
def quantize_param(param):
    scale = jnp.max(jnp.abs(param))
    # Avoid division by zero.
    scale = jnp.where(scale == 0, 1.0, scale)
    # Scale, round, and cast to int8.
    param_int8 = jnp.round(param / scale * 127).astype(jnp.int8)
    # Dequantize back to float32.
    return param_int8.astype(jnp.float32) * scale / 127

def quantize_params(params):
    if isinstance(params, dict):
        return {k: quantize_params(v) for k, v in params.items()}
    else:
        return quantize_param(params)

quantized_params = quantize_params(params)

# Save the quantized parameters (similarly to torch.save).
with open("quantized_language_model.pkl", "wb") as f:
    pickle.dump(quantized_params, f)

# Load the quantized parameters.
with open("quantized_language_model.pkl", "rb") as f:
    loaded_params = pickle.load(f)

# ---------------------------------------------------------------------
# Evaluate the quantized model.
# ---------------------------------------------------------------------
test_input = jax.random.randint(key, (1, seq_length), 0, vocab_size)
predictions = model.apply({'params': loaded_params}, test_input)
predicted_class = jnp.argmax(predictions, axis=1)
print(f"Prediction for input {test_input.tolist()}: {int(predicted_class[0])}")


NameError: name 'lstm_rngs' is not defined

In [None]:
## Weak LLM
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np

class LSTM(nn.Module):
    # Your LSTM implementation here

    def __call__(self, x):
        # Forward pass logic for LSTM


def process_sequence(inputs, prng_key):  # MODIFIED: Added prng_key as a parameter
    # Instead of using Python loops, use a JAX scan to process the sequence
    def step(carry, input_data):
        # Define the operation per timestep
        # Note: Include the logic for LSTM cell operations here
        carry = carry  # update the carry state here based on LSTM operations
        return carry, carry  # return updated state and output

    # Use `jax.lax.scan` for efficient looping over the inputs
    initial_carry = jnp.zeros((inputs.shape[0],))  # or appropriate shape
    outputs, _ = jax.lax.scan(step, initial_carry, inputs)
    return outputs

def loss_fn(params, X, y):
    # Your loss function implementation here
    return jnp.mean((X - y) ** 2)  # Example loss calculation

def main():
    # Initialize your parameters and data here
    batch_size = 32
    input_size = 10
    num_epochs = 100
    key = jax.random.PRNGKey(0)  # Initialize PRNG key

    # Example inputs; replace with actual data loading logic
    X_train = jax.random.normal(key, (batch_size, input_size))
    y_train = jax.random.normal(key, (batch_size, input_size))

    # Initialize model, optimizer, etc.
    model = LSTM()  # Initialize the LSTM model
    params = model.init(key, X_train)  # Initialize model parameters
    optimizer = optax.adam(learning_rate=0.001)  # Example optimizer
    opt_state = optimizer.init(params)

    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)  # MODIFIED: Split the PRNG key for each iteration
        outputs = process_sequence(X_train, subkey)  # MODIFIED: Pass subkey to process_sequence
        current_loss = loss_fn(params, outputs, y_train)  # Calculate loss based on outputs

        # Update weights, optimizer state, etc.
        grad = jax.grad(loss_fn)(params, outputs, y_train)  # Compute gradients
        updates, opt_state = optimizer.update(grad, opt_state)  # Update optimizer state
        params = optax.apply_updates(params, updates)  # Apply updates to parameters

        # Log progress every epoch
        print(f"Epoch [{epoch + 1}/{num_epochs}] - Loss: {current_loss:.4f}")

if __name__ == "__main__":
    main()

ValueError: Incompatible shapes for broadcasting: shapes=[(32,), (32, 10)]

In [None]:
"""Error Code
def loss_fn(params, X, y):
    # Your loss function implementation here
    return jnp.mean((X - y) ** 2)  # Example loss calculation

Error:

ValueError: Incompatible shapes for broadcasting: shapes=[(32,), (32, 10)]
Fix guide:
Fix broadcasting shape error and RNG issue

Correct code
 def __call__(self, x):
    # Forward pass logic for LSTM
    rng = self.make_rng('lstm')
    return process_sequence(x, rng)
"""

"""Error Code
outputs = jnp.swapaxes(outputs, 0, 1)
return outputs

Error:
IndexError: index 1 is out of bounds for axis 0 with size 1

Fix guide
Add the @nn.compact Decorator @nn.compact

Correct code
@nn.compact
def __call__(self, x):
  # Dummy Dense layer to force parameter creation and proper input tracing.
  x = nn.Dense(features=x.shape[-1], name='dummy_dense')(x)
  rng = self.make_rng('lstm')
  return process_sequence(x, rng)
"""

"""Error Code
x = nn.Dense(features=x.shape[-1], name='dummy_dense')(x)
Error:
AssignSubModuleError: Submodule Dense must be defined in `setup()` or in a method wrapped in `@compact` (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)

Fix guide

Define the Submodule in setup():
Create the Dense submodule in the module’s setup() method, where you can use a fixed (or pre-determined) number of output features. For instance, if you know the sequence length in advance, use that as the feature count.
Pass the Feature Count as an Argument:
Modify the module to accept a features parameter. Then, in setup(), instantiate the dummy Dense layer using this parameter.
Update the __call__ Method:
In __call__, call the pre-defined Dense submodule and then continue with your processing logic.

Correct Code
x = self.dummy_dense(x)
"""

"""
Error Code
outputs, _ = jax.lax.scan(step, initial_carry, inputs)
outputs = jnp.swapaxes(outputs, 0, 1)

Error
IndexError: index 1 is out of bounds for axis 0 with size 1

Fix guide
ensure outputs has at least 2 dimensions before calling jnp.swapaxes

Correct Code
if outputs.ndim == 1:  # If outputs is 1D, add a dimension
      outputs = outputs[:, None]  # Shape: (seq_length,) -> (seq_length, 1)
      outputs = jnp.swapaxes(outputs, 0, 1)
"""

"""Error Code
if outputs.ndim == 1:  # If outputs is 1D, add a dimension
      outputs = outputs[:, None]  # Shape: (seq_length,) -> (seq_length, 1)
      outputs = jnp.swapaxes(outputs, 0, 1

Error
IndexError: index 1 is out of bounds for axis 0 with size 1

Fix Guide
ensure outputs is at least 2D before calling jnp.swapaxes

Correct Code
if outputs.ndim == 1:
  outputs = outputs[None, :]
outputs = jnp.swapaxes(outputs, 0, 1)
"""

In [6]:
## Fixed Code
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class LSTM(nn.Module):
    features: int  # Expected to be the sequence length (e.g., 10)

    def setup(self):
        # Define the Dense submodule with a fixed feature size.
        self.dummy_dense = nn.Dense(features=self.features, name='dummy_dense')

    def __call__(self, x):
        # Use the pre-defined Dense layer.
        x = self.dummy_dense(x)
        rng = self.make_rng('lstm')
        return process_sequence(x, rng)

def process_sequence(inputs, prng_key):
    # Ensure inputs is 2D (batch, seq_length).
    assert inputs.ndim == 2, f"Expected inputs to be 2D (batch, seq_length), got {inputs.shape}"
    # Transpose to (seq_length, batch) for scanning over time.
    inputs = jnp.swapaxes(inputs, 0, 1)

    def step(carry, input_data):
        # Example step: add the input to the carry.
        new_carry = carry + input_data
        return new_carry, new_carry

    batch = inputs.shape[1]  # Number of samples in the batch.
    initial_carry = jnp.zeros((batch,))
    outputs, _ = jax.lax.scan(step, initial_carry, inputs)
    if outputs.ndim == 1:  # If seq_length = 1, outputs is 1D
        outputs = outputs[None, :]
    #   outputs = jnp.swapaxes(outputs, 0, 1)
    outputs = jnp.swapaxes(outputs, 0, 1)
    return outputs

def loss_fn(params, X, y):
    # Use a fixed RNG for LSTM operations during loss computation.
    preds = model.apply({'params': params}, X, rngs={'lstm': jax.random.PRNGKey(0)})
    return jnp.mean((preds - y) ** 2)

def main():
    batch_size = 32
    seq_length = 10  # Number of time steps.
    num_epochs = 100
    key = jax.random.PRNGKey(0)

    # Generate explicit training data of shape (batch, seq_length).
    X_train = jax.random.normal(key, (batch_size, seq_length))
    key, subkey = jax.random.split(key)
    y_train = jax.random.normal(subkey, (batch_size, seq_length))

    global model
    # Initialize LSTM with fixed feature size (equal to seq_length).
    model = LSTM(features=seq_length)
    # Use separate PRNG keys for parameters and LSTM operations.
    params_key, lstm_key = jax.random.split(key)
    variables = model.init({'params': params_key, 'lstm': lstm_key}, X_train)
    params = variables['params']
    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(params)

    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)
        outputs = process_sequence(X_train, subkey)
        current_loss = loss_fn(params, X_train, y_train)
        grad = jax.grad(loss_fn)(params, X_train, y_train)
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        print(f"Epoch [{epoch + 1}/{num_epochs}] - Loss: {current_loss:.4f}")

if __name__ == "__main__":
    main()



Epoch [1/100] - Loss: 19.6439
Epoch [2/100] - Loss: 19.3840
Epoch [3/100] - Loss: 19.1270
Epoch [4/100] - Loss: 18.8730
Epoch [5/100] - Loss: 18.6221
Epoch [6/100] - Loss: 18.3743
Epoch [7/100] - Loss: 18.1296
Epoch [8/100] - Loss: 17.8881
Epoch [9/100] - Loss: 17.6499
Epoch [10/100] - Loss: 17.4148
Epoch [11/100] - Loss: 17.1830
Epoch [12/100] - Loss: 16.9545
Epoch [13/100] - Loss: 16.7293
Epoch [14/100] - Loss: 16.5074
Epoch [15/100] - Loss: 16.2889
Epoch [16/100] - Loss: 16.0736
Epoch [17/100] - Loss: 15.8617
Epoch [18/100] - Loss: 15.6531
Epoch [19/100] - Loss: 15.4478
Epoch [20/100] - Loss: 15.2457
Epoch [21/100] - Loss: 15.0470
Epoch [22/100] - Loss: 14.8515
Epoch [23/100] - Loss: 14.6592
Epoch [24/100] - Loss: 14.4700
Epoch [25/100] - Loss: 14.2841
Epoch [26/100] - Loss: 14.1012
Epoch [27/100] - Loss: 13.9214
Epoch [28/100] - Loss: 13.7446
Epoch [29/100] - Loss: 13.5708
Epoch [30/100] - Loss: 13.3999
Epoch [31/100] - Loss: 13.2319
Epoch [32/100] - Loss: 13.0667
Epoch [33/100] - 