In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim

# Define a Transformer Model
class TransformerModel(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, num_layers, ff_dim, output_dim):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        x = x.mean(dim=1)  # Pooling across the sequence
        return self.output_layer(x)
    
# Generate synthetic data
torch.manual_seed(42)
seq_length = 10
num_samples = 100
input_dim = 1
X = torch.rand(num_samples, seq_length, input_dim)  # Random sequences
y = torch.sum(X, dim=1)  # Target is the sum of each sequence

# Initialize the model, loss function, and optimizer
input_dim = 1
embed_dim = 16
num_heads = 2
num_layers = 2
ff_dim = 64
output_dim = 1

model = TransformerModel(input_dim, embed_dim, num_heads, num_layers, ff_dim, output_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 1000
for epoch in range(epochs):
    # Forward pass
    predictions = model(X)
    loss = criterion(predictions, y)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log progress every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

# Testing on new data
X_test = torch.rand(2, seq_length, input_dim)
with torch.no_grad():
    predictions = model(X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

Epoch [100/1000], Loss: 1.5825
Epoch [200/1000], Loss: 0.8951
Epoch [300/1000], Loss: 0.7411
Epoch [400/1000], Loss: 0.3951
Epoch [500/1000], Loss: 0.2077
Epoch [600/1000], Loss: 0.1377
Epoch [700/1000], Loss: 0.0902
Epoch [800/1000], Loss: 0.0622
Epoch [900/1000], Loss: 0.0366
Epoch [1000/1000], Loss: 0.0394
Predictions for [[[0.8894234895706177], [0.05691683292388916], [0.8250501155853271], [0.4220901131629944], [0.2867562174797058], [0.20837479829788208], [0.9482327699661255], [0.5099708437919617], [0.058805227279663086], [0.8394275903701782]], [[0.23448842763900757], [0.7937500476837158], [0.026732206344604492], [0.6515005230903625], [0.7079432010650635], [0.7461163997650146], [0.1356351375579834], [0.15332472324371338], [0.03719818592071533], [0.44056856632232666]]]: [[5.098686218261719], [4.298356056213379]]


In [2]:
#Strong LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np


class TransformerEncoderBlock(nn.Module):
    embed_dim: int
    num_heads: int
    ff_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, *, train):
        # Multi-head self-attention; note: dropout is disabled when deterministic=True.
        attn_output = nn.SelfAttention(num_heads=self.num_heads,
                                       qkv_features=self.embed_dim,
                                       dropout_rate=self.dropout_rate,
                                       deterministic=not train)(x)
        # Add & Norm
        x = nn.LayerNorm()(x + attn_output)
        # Feedforward network
        ff_output = nn.Dense(self.ff_dim)(x)
        ff_output = nn.relu(ff_output)
        ff_output = nn.Dense(self.embed_dim)(ff_output)
        # Add & Norm
        x = nn.LayerNorm()(x + ff_output)
        return x

class TransformerModel(nn.Module):
    input_dim: int
    embed_dim: int
    num_heads: int
    num_layers: int
    ff_dim: int
    output_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, *, train=True):
        # x shape: (batch, seq_length, input_dim)
        # Map input to embedding space.
        x = nn.Dense(self.embed_dim)(x)
        # Pass through a stack of Transformer encoder blocks.
        for _ in range(self.num_layers):
            x = TransformerEncoderBlock(embed_dim=self.embed_dim,
                                        num_heads=self.num_heads,
                                        ff_dim=self.ff_dim,
                                        dropout_rate=self.dropout_rate)(x, train=train)
        # Pool across the sequence dimension (mean pooling).
        x = jnp.mean(x, axis=1)
        # Map to output.
        x = nn.Dense(self.output_dim)(x)
        return x

key = jax.random.PRNGKey(42)
seq_length = 10
num_samples = 100
input_dim = 1

# Generate random sequences with shape (num_samples, seq_length, input_dim)
key, subkey = jax.random.split(key)
X = jax.random.uniform(subkey, shape=(num_samples, seq_length, input_dim))
# Target: sum of each sequence, resulting in shape (num_samples, 1)
y = jnp.sum(X, axis=1)

# Hyperparameters
embed_dim = 16
num_heads = 2
num_layers = 2
ff_dim = 64
output_dim = 1

# Create the model instance.
model = TransformerModel(input_dim=input_dim,
                         embed_dim=embed_dim,
                         num_heads=num_heads,
                         num_layers=num_layers,
                         ff_dim=ff_dim,
                         output_dim=output_dim)

# Initialize model parameters using a dummy input.
dummy_input = jnp.ones((num_samples, seq_length, input_dim))
rng = jax.random.PRNGKey(0)
params = model.init(rng, dummy_input, train=True)

# Define Mean Squared Error loss.
def mse_loss(params, x, y, train):
    preds = model.apply(params, x, train=train)
    return jnp.mean((preds - y) ** 2)

# Setup the Adam optimizer.
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

epochs = 1000

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y, True)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for epoch in range(epochs):
    params, opt_state, loss = train_step(params, opt_state, X, y)
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")

# Generate new random test data with shape (2, seq_length, input_dim)
key, subkey = jax.random.split(key)
X_test = jax.random.uniform(subkey, shape=(2, seq_length, input_dim))
predictions = model.apply(params, X_test, train=False)
print("Predictions for", np.array(X_test).tolist(), ":", np.array(predictions).tolist())

InvalidRngError: SelfAttention_0 needs PRNG for "dropout" (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)

In [3]:
#Weak LLM
import jax
import jax.numpy as jnp  # Ensured consistent import statement
from flax import linen as nn
from jax import random, grad, jit, vmap
import optax


class SimpleModel(nn.Module):
    """A simple feedforward neural network model."""
    @nn.compact
    def __call__(self, x):
        """Forward pass of the model."""
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


def create_model() -> SimpleModel:
    """Create an instance of the SimpleModel."""
    return SimpleModel()


def compute_loss(logits, labels):
    """Compute the binary cross-entropy loss."""
    return jnp.mean(jax.nn.sigmoid_cross_entropy(logits=logits, labels=labels))


def accuracy(logits, labels):
    """Calculate the accuracy of the model predictions."""
    preds = jnp.round(jax.nn.sigmoid(logits))
    return jnp.mean(preds == labels)


@jit
def train_step(optimizer, model, batch):
    """Perform a single training step."""
    def loss_fn(params):
        logits = model.apply({'params': params}, batch['X'])
        return compute_loss(logits, batch['y'])
    
    grads = grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)
    return optimizer


def train_model(X, y, num_epochs, key):
    """Train the model on the provided data."""
    model = create_model()
    params = model.init(key, jnp.ones((1, X.shape[1])))
    optimizer = optax.adam(learning_rate=0.001).init(params)

    dataset_size = X.shape[0]
    
    for epoch in range(num_epochs):
        # Shuffle dataset
        perm = random.permutation(key, dataset_size)
        X_shuffled = X[perm]
        y_shuffled = y[perm]
        
        for i in range(0, dataset_size, 32):
            batch = {
                'X': X_shuffled[i:i + 32],
                'y': y_shuffled[i:i + 32]
            }
            optimizer = train_step(optimizer, model, batch)
        
        # Example log after each epoch
        logits = model.apply({'params': optimizer.target}, X)
        train_acc = accuracy(logits, y)
        print(f"Epoch {epoch + 1}, Train Accuracy: {train_acc:.4f}")


def main():
    """Main entry point for the script."""
    # Example data generation with explicit PRNG key
    key = random.PRNGKey(0)  # Initialize PRNG key
    X = random.uniform(key, (1000, 10))  # MODIFIED: Added explicit PRNG key
    y = jnp.array([0, 1] * 500)  # Sample labels

    num_epochs = 10
    train_model(X, y, num_epochs, key)  # MODIFIED: pass key to train_model


if __name__ == "__main__":
    main()

TypeError: Cannot interpret value of type <class '__main__.SimpleModel'> as an abstract array; it does not have a dtype attribute

In [None]:
"""
Error Code:
class SimpleModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


Error:
The translated JAX code implements a simple fully connected network, which is inconsistent with the Transformer model implemented in the original PyTorch code.


Fix Guide:
Use Flax to implement a Transformer model. The steps include:
Use a Dense layer to implement input embedding
Implement the Transformer encoder layer
Mean pooling on the sequence dimension
Connect to the output layer to get the final regression result


Correct Code:
class TransformerEncoderLayer(nn.Module):
    embed_dim: int
    num_heads: int
    ff_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, train: bool = True):
        attn = nn.SelfAttention(num_heads=self.num_heads,
                                qkv_features=self.embed_dim,
                                dropout_rate=self.dropout_rate,
                                deterministic=not train)(x)
        x = x + attn
        x = nn.LayerNorm()(x)

        ff = nn.Dense(self.ff_dim)(x)
        ff = nn.relu(ff)
        ff = nn.Dense(self.embed_dim)(ff)
        x = x + ff
        x = nn.LayerNorm()(x)
        return x

class TransformerModel(nn.Module):
    input_dim: int
    embed_dim: int
    num_heads: int
    num_layers: int
    ff_dim: int
    output_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(self.embed_dim)(x)

        for _ in range(self.num_layers):
            x = TransformerEncoderLayer(embed_dim=self.embed_dim,
                                        num_heads=self.num_heads,
                                        ff_dim=self.ff_dim,
                                        dropout_rate=self.dropout_rate)(x, train=train)

        x = jnp.mean(x, axis=1)
        x = nn.Dense(self.output_dim)(x)
        return x
"""


"""
Error Code:
def compute_loss(logits, labels):
    return jnp.mean(jax.nn.sigmoid_cross_entropy(logits=logits, labels=labels))


Error:
The original PyTorch code is a regression task. The goal is to calculate the sequence and the mean square error (MSE) loss should be used


Fix Guide:
Modify the loss function to mean square error


Correct Code:
def compute_loss(predictions, targets):
    return jnp.mean((predictions - targets) ** 2)
"""


"""
Error Code:
# Example data generation with explicit PRNG key
key = random.PRNGKey(0)
X = random.uniform(key, (1000, 10))  # MODIFIED: Added explicit PRNG key
y = jnp.array([0, 1] * 500)  # Sample labels


Error:
The generated data X lacks feature dimensions. The shape of X in the original code should be (num_samples, seq_length, input_dim)
The generated labels y are alternating 0 and 1, which does not match the goal of the regression task (summing the sequence elements)


Fix Guide:
Change the shape of X to (num_samples, seq_length, input_dim)
Define y as the sum of X along the sequence dimension, i.e. y = jnp.sum(X, axis=1)


Correct Code:
key = random.PRNGKey(0)
num_samples = 100
seq_length = 10
input_dim = 1
X = random.uniform(key, (num_samples, seq_length, input_dim))
y = jnp.sum(X, axis=1)
"""


"""
Error Code:
def train_step(optimizer, model, batch):
    def loss_fn(params):
        logits = model.apply({'params': params}, batch['X'])
        return compute_loss(logits, batch['y'])
    
    grads = grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)
    return optimizer


Error:
The optimizer usage does not match


Fix Guide:
Use the TrainState class to encapsulate the parameters and optimizer state, and call state.apply_gradients to update them in the training step


Correct Code:
from flax.training import train_state

def create_train_state(rng, model, learning_rate, input_shape):
    params = model.init(rng, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jit
def train_step(state, batch):
    def loss_fn(params):
        predictions = state.apply_fn({'params': params}, batch['X'], train=True)
        loss = compute_loss(predictions, batch['y'])
        return loss, predictions
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, preds), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss
"""


"""
Error Code:
# Shuffle dataset
perm = random.permutation(key, dataset_size)
X_shuffled = X[perm]
y_shuffled = y[perm]


Error:
Repeatedly using the same PRNG key for random operations will result in the same random sequence being generated each time in JAX


Fix Guide:
Use random.split to generate a new key before each random operation to ensure randomness


Correct Code:
rng, key = random.split(rng)
perm = random.permutation(key, dataset_size)
X_shuffled = X[perm]
y_shuffled = y[perm]
"""


"""
Error Code:
def train_model(X, y, num_epochs, key):
    ...
    for epoch in range(num_epochs):
        # Shuffle dataset
        rng, key = random.split(rng)
        perm = random.permutation(key, dataset_size)
        ...


Error:
local variable 'rng' referenced before assignment


Fix Guide:
Need to keep the variable name of the random number generator consistent


Correct Code:
def train_model(X, y, num_epochs, key):
    dataset_size = X.shape[0]
    
    for epoch in range(num_epochs):
        # Shuffle dataset
        key, subkey = random.split(key)
        perm = random.permutation(subkey, dataset_size)
        X_shuffled = X[perm]
        y_shuffled = y[perm]
        
        for i in range(0, dataset_size, 32):
            batch = {
                'X': X_shuffled[i:i + 32],
                'y': y_shuffled[i:i + 32]
            }
            optimizer = train_step(optimizer, batch)
        
        # Example log after each epoch
        logits = model.apply({'params': optimizer.target}, X)
        train_acc = accuracy(logits, y)
        print(f"Epoch {epoch + 1}, Train Accuracy: {train_acc:.4f}")
"""


"""
Error Code:
def train_model(X, y, num_epochs, key):
    model = create_model()
    params = model.init(key, jnp.ones((1, X.shape[1])))
    optimizer = optax.adam(learning_rate=0.001).init(params)
    ...
    for i in range(0, dataset_size, 32):
        batch = {
            'X': X_shuffled[i:i + 32],
            'y': y_shuffled[i:i + 32]
        }
        optimizer = train_step(optimizer, model, batch)
    ...


Error:
local variable 'optimizer' referenced before assignment


Fix Guide:
Use TrainState to manage model parameters and optimizer state.
Define a create_train_state function, use the model's init method and optax optimizer to create a training state
Call this function in train_model to generate a training state, and use this object in subsequent training steps
At the same time, modify the calling method of train_step, and return (new_state, loss), and need to receive these two return values


Correct Code:
def train_model(X, y, num_epochs, key):
    model = TransformerModel(
        input_dim=1,
        embed_dim=16,
        num_heads=2,
        num_layers=2,
        ff_dim=64,
        output_dim=1
    )
    state = create_train_state(key, model, learning_rate=0.001, input_shape=X.shape)

    dataset_size = X.shape[0]
    
    for epoch in range(num_epochs):
        key, subkey = random.split(key)
        perm = random.permutation(subkey, dataset_size)
        X_shuffled = X[perm]
        y_shuffled = y[perm]
        
        for i in range(0, dataset_size, 32):
            batch = {
                'X': X_shuffled[i:i + 32],
                'y': y_shuffled[i:i + 32]
            }
            state, loss = train_step(state, batch)
        
        logits = state.apply_fn({'params': state.params}, X, train=False)
        train_acc = accuracy(logits, y)
        print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Train Accuracy: {train_acc:.4f}")

    return state, model
"""


"""
Error Code:
def train_step(state, batch):
    def loss_fn(params):
        predictions = state.apply_fn({'params': params}, batch['X'], train=True)
        loss = compute_loss(predictions, batch['y'])
        return loss, predictions
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, preds), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


Error:
SelfAttention_0 needs PRNG for "dropout" (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)


Fix Guide:
Modify the train_step function to accept an additional dropout random number key and pass in rngs={'dropout': dropout_rng} when calling apply_fn
During the training process, a new key needs to be assigned to dropout before each batch is processed


Correct Code:
@jit
def train_step(state, batch, dropout_rng):
    def loss_fn(params):
        predictions = state.apply_fn({'params': params}, batch['X'], train=True, rngs={'dropout': dropout_rng})
        loss = compute_loss(predictions, batch['y'])
        return loss, predictions
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, preds), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss
"""


"""
Error Code:
def train_model(X, y, num_epochs, key):
    ...
    for i in range(0, dataset_size, 32):
        batch = {
            'X': X_shuffled[i:i + 32],
            'y': y_shuffled[i:i + 32]
        }
        state, loss = train_step(state, batch)
    ...


Error:
train_step() missing 1 required positional argument: 'dropout_rng'


Fix Guide:
In the training loop, use random.split to generate a new key for dropout before processing each batch and pass it to train_step


Correct Code:
key, dropout_key = random.split(key)
state, loss = train_step(state, batch, dropout_key)
"""


"""
Error Code:
def create_model() -> SimpleModel:
    return SimpleModel()


Error:
Reference to undefined SimpleModel class


Fix Guide:
Delete the function


Correct Code:
# def create_model() -> SimpleModel:
    # return SimpleModel()
"""


"""
Error Code:
def accuracy(logits, labels):
    preds = jnp.round(jax.nn.sigmoid(logits))
    return jnp.mean(preds == labels)


Error:
This function uses sigmoid and round to calculate the accuracy and is not suitable for regression tasks.


Fix Guide:
Remove this function


Correct Code:
# def accuracy(logits, labels):
    # preds = jnp.round(jax.nn.sigmoid(logits))
    # return jnp.mean(preds == labels)
"""


"""
Error Code:
logits = state.apply_fn({'params': state.params}, X, train=False)
train_acc = accuracy(logits, y)
print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Train Accuracy: {train_acc:.4f}")


Error:
The accuracy function is called in the training loop to calculate the accuracy, which is meaningless for regression tasks.


Fix Guide:
Removed accuracy calls and instead computed evaluation metrics for regression tasks


Correct Code:
predictions = state.apply_fn({'params': state.params}, X, train=False)
eval_loss = compute_loss(predictions, y)
print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Eval Loss: {eval_loss:.4f}")
"""


"""
Error Code:
ff = nn.relu(ff)


Error:
The nn module does not have a built-in relu function


Fix Guide:
Replace nn.relu with jax.nn.relu


Correct Code:
ff = jax.nn.relu(ff)
"""


"""
Error Code:
key = random.PRNGKey(0)
num_samples = 100
seq_length = 10
input_dim = 1
X = random.uniform(key, (num_samples, seq_length, input_dim))
y = jnp.sum(X, axis=1)

num_epochs = 10
train_model(X, y, num_epochs, key)


Error:
Reusing the same PRNG key may cause randomness issues or unexpected behavior


Fix Guide:
Use random.split to split a new key for subsequent passing to train_model


Correct Code:
key = random.PRNGKey(0)
num_samples = 100
seq_length = 10
input_dim = 1
key, subkey = random.split(key)
X = random.uniform(subkey, (num_samples, seq_length, input_dim))
y = jnp.sum(X, axis=1)

num_epochs = 10
train_model(X, y, num_epochs, key)
"""


"""
Error Code:
ff = nn.Dense(self.ff_dim)(x)
ff = jax.nn.relu(ff)
ff = nn.Dense(self.embed_dim)(ff)


Error:
In PyTorch's nn.TransformerEncoderLayer, in addition to the built-in dropout in the self-attention part, the output of the feed-forward network is usually processed by dropout for regularization.


Fix Guide:
Insert a nn.Dropout layer after the relu activation and before the second fully connected layer, and pass in the deterministic=not train parameter


Correct Code:
ff = nn.Dense(self.ff_dim)(x)
ff = jax.nn.relu(ff)
ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=not train)
ff = nn.Dense(self.embed_dim)(ff)
"""


"""
Error Code:
def main():
    # Example data generation with explicit PRNG key
    key = random.PRNGKey(0)
    num_samples = 100
    seq_length = 10
    input_dim = 1
    key, subkey = random.split(key)
    X = random.uniform(subkey, (num_samples, seq_length, input_dim))
    y = jnp.sum(X, axis=1)

    num_epochs = 10
    train_model(X, y, num_epochs, key)


Error:
The main function main() only calls train_model for training, and there is no part similar to testing new data in PyTorch code


Fix Guide:
Add the test code at the end of the main() function


Correct Code:
def main():
    # Example data generation with explicit PRNG key
    key = random.PRNGKey(0)
    num_samples = 100
    seq_length = 10
    input_dim = 1
    key, subkey = random.split(key)
    X = random.uniform(subkey, (num_samples, seq_length, input_dim))
    y = jnp.sum(X, axis=1)

    num_epochs = 10
    state, model = train_model(X, y, num_epochs, key)
    
    # Testing on new data
    key, subkey = random.split(key)
    X_test = random.uniform(subkey, (2, seq_length, input_dim))
    predictions = state.apply_fn({'params': state.params}, X_test, train=False)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

"""


"""
Error Code:
x = nn.Dense(self.embed_dim)(x)


Error:
There is no explicit definition of the "Embedding layer" like in the PyTorch code, nor is there any use of the declared input_dim parameter


Fix Guide:
Use a named Dense layer as the embedding layer in the __call__ method of TransformerModel to explicitly map the input from input_dim to embed_dim


Correct Code:
x = nn.Dense(self.embed_dim, name="embedding")(x)
"""


"""
Error Code:
attn = nn.SelfAttention(num_heads=self.num_heads,
                         qkv_features=self.embed_dim,
                         dropout_rate=self.dropout_rate,
                         deterministic=not train)(x)
x = x + attn


Error:
PyTorch's nn.TransformerEncoderLayer usually applies dropout to the self-attention output in the residual branch


Fix Guide:
Before adding the self-attention output to the input, perform another dropout


Correct Code:
attn = nn.SelfAttention(num_heads=self.num_heads,
                        qkv_features=self.embed_dim,
                        dropout_rate=self.dropout_rate,
                        deterministic=not train)(x)
attn = nn.Dropout(rate=self.dropout_rate)(attn, deterministic=not train)
x = x + attn
"""


"""
Error Code:
for i in range(0, dataset_size, 32):
    key, dropout_key = random.split(key)
    batch = {
        'X': X_shuffled[i:i + 32],
        'y': y_shuffled[i:i + 32]
    }
    state, loss = train_step(state, batch, dropout_key)
        
predictions = state.apply_fn({'params': state.params}, X, train=False)
eval_loss = compute_loss(predictions, y)
print(f"Epoch {epoch + 1}, Loss: {loss:.4f}, Eval Loss: {eval_loss:.4f}")


Error:
The printed loss is only the loss of the last batch in the current epoch, not the average training loss of the entire epoch, which is inconsistent with the expectation of printing the overall progress each time in the PyTorch code.


Fix Guide:
In each epoch, the losses of all batches are accumulated and the average is calculated before output


Correct Code:
total_loss = 0.0
num_batches = 0
for i in range(0, dataset_size, 32):
    key, dropout_key = random.split(key)
    batch = {
        'X': X_shuffled[i:i + 32],
        'y': y_shuffled[i:i + 32]
    }
    state, batch_loss = train_step(state, batch, dropout_key)
    total_loss += batch_loss
    num_batches += 1
avg_loss = total_loss / num_batches
predictions = state.apply_fn({'params': state.params}, X, train=False)
eval_loss = compute_loss(predictions, y)
print(f"Epoch {epoch + 1}, Avg Loss: {avg_loss:.4f}, Eval Loss: {eval_loss:.4f}")
"""


"""
Error Code:
avg_loss = total_loss / num_batches
predictions = state.apply_fn({'params': state.params}, X, train=False)
eval_loss = compute_loss(predictions, y)
print(f"Epoch {epoch + 1}, Avg Loss: {avg_loss:.4f}, Eval Loss: {eval_loss:.4f}")


Error:
The PyTorch code only prints once every 100 epochs


Fix Guide:
Modify the print statement to calculate the forward propagation loss on the full dataset only when (epoch + 1) is divisible by 100


Correct Code:
if (epoch + 1) % 100 == 0:
    predictions = state.apply_fn({'params': state.params}, X, train=False)
    loss_value = compute_loss(predictions, y)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss_value:.4f}")
"""


"""
Error Code:
num_epochs = 10


Error:
The PyTorch code trains for 1000 epochs, while the JAX code only trains for 10 epochs


Fix Guide:
Change num_epochs to 1000


Correct Code:
num_epochs = 1000
"""

In [22]:
#Fixed Code
import jax
import jax.numpy as jnp  # Ensured consistent import statement
from flax import linen as nn
from jax import random, grad, jit, vmap
import optax
from flax.training import train_state


class TransformerEncoderLayer(nn.Module):
    embed_dim: int
    num_heads: int
    ff_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, train: bool = True):
        attn = nn.SelfAttention(num_heads=self.num_heads,
                        qkv_features=self.embed_dim,
                        dropout_rate=self.dropout_rate,
                        deterministic=not train)(x)
        attn = nn.Dropout(rate=self.dropout_rate)(attn, deterministic=not train)
        x = x + attn
        x = nn.LayerNorm()(x)

        ff = nn.Dense(self.ff_dim)(x)
        ff = jax.nn.relu(ff)
        ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=not train)
        ff = nn.Dense(self.embed_dim)(ff)
        x = x + ff
        x = nn.LayerNorm()(x)
        return x

class TransformerModel(nn.Module):
    input_dim: int
    embed_dim: int
    num_heads: int
    num_layers: int
    ff_dim: int
    output_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(self.embed_dim, name="embedding")(x)

        for _ in range(self.num_layers):
            x = TransformerEncoderLayer(embed_dim=self.embed_dim,
                                        num_heads=self.num_heads,
                                        ff_dim=self.ff_dim,
                                        dropout_rate=self.dropout_rate)(x, train=train)

        x = jnp.mean(x, axis=1)
        x = nn.Dense(self.output_dim)(x)
        return x


def compute_loss(predictions, targets):
    return jnp.mean((predictions - targets) ** 2)


def create_train_state(rng, model, learning_rate, input_shape):
    params = model.init(rng, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jit
def train_step(state, batch, dropout_rng):
    def loss_fn(params):
        predictions = state.apply_fn({'params': params}, batch['X'], train=True, rngs={'dropout': dropout_rng})
        loss = compute_loss(predictions, batch['y'])
        return loss, predictions
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, preds), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


def train_model(X, y, num_epochs, key):
    model = TransformerModel(
        input_dim=1,
        embed_dim=16,
        num_heads=2,
        num_layers=2,
        ff_dim=64,
        output_dim=1
    )
    state = create_train_state(key, model, learning_rate=0.001, input_shape=X.shape)

    dataset_size = X.shape[0]
    
    for epoch in range(num_epochs):
        key, subkey = random.split(key)
        perm = random.permutation(subkey, dataset_size)
        X_shuffled = X[perm]
        y_shuffled = y[perm]
        
        total_loss = 0.0
        num_batches = 0
        for i in range(0, dataset_size, 32):
            key, dropout_key = random.split(key)
            batch = {
                'X': X_shuffled[i:i + 32],
                'y': y_shuffled[i:i + 32]
            }
            state, batch_loss = train_step(state, batch, dropout_key)
            total_loss += batch_loss
            num_batches += 1
            
        if (epoch + 1) % 100 == 0:
            predictions = state.apply_fn({'params': state.params}, X, train=False)
            loss_value = compute_loss(predictions, y)
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss_value:.4f}")

    return state, model


def main():
    """Main entry point for the script."""
    # Example data generation with explicit PRNG key
    key = random.PRNGKey(0)
    num_samples = 100
    seq_length = 10
    input_dim = 1
    key, subkey = random.split(key)
    X = random.uniform(subkey, (num_samples, seq_length, input_dim))
    y = jnp.sum(X, axis=1)

    num_epochs = 1000
    state, model = train_model(X, y, num_epochs, key)
    
    # Testing on new data
    key, subkey = random.split(key)
    X_test = random.uniform(subkey, (2, seq_length, input_dim))
    predictions = state.apply_fn({'params': state.params}, X_test, train=False)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")


if __name__ == "__main__":
    main()

Epoch [100/1000], Loss: 0.8359
Epoch [200/1000], Loss: 0.0190
Epoch [300/1000], Loss: 0.0066
Epoch [400/1000], Loss: 0.0119
Epoch [500/1000], Loss: 0.0020
Epoch [600/1000], Loss: 0.0080
Epoch [700/1000], Loss: 0.0078
Epoch [800/1000], Loss: 0.0069
Epoch [900/1000], Loss: 0.0013
Epoch [1000/1000], Loss: 0.0036
Predictions for [[[0.21237587928771973], [0.17597758769989014], [0.22925710678100586], [0.37311553955078125], [0.7130759954452515], [0.9348740577697754], [0.38974833488464355], [0.8368901014328003], [0.10798454284667969], [0.5377466678619385]], [[0.09019839763641357], [0.7678695917129517], [0.1505894660949707], [0.036005616188049316], [0.9545391798019409], [0.8174539804458618], [0.747868537902832], [0.03028249740600586], [0.3456763029098511], [0.07068979740142822]]]: [[4.516904354095459], [4.035348892211914]]
