In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    """Linear regressor with one input feature and one output feature."""

    def __init__(self):
        super().__init__()
        # The attribute name `fc` determines the state_dict keys
        # ("fc.weight", "fc.bias"), so it must stay as is.
        self.fc = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the single linear layer to `x` and return the result."""
        out = self.fc(x)
        return out

# Create and train the model
# Seed before building the model: parameter initialisation and the synthetic
# data below all draw from the same global RNG, in this order.
torch.manual_seed(42)
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
# Synthetic data: inputs uniform in [0, 1), targets y = 3x + 2 plus Gaussian
# noise scaled by 0.1.
X = torch.rand(100, 1)
y = 3 * X + 2 + torch.randn(100, 1) * 0.1
epochs = 100
for epoch in range(epochs):
    # Full-batch gradient descent: one optimizer step per pass over X.
    optimizer.zero_grad()
    predictions = model(X)
    loss = criterion(predictions, y)
    loss.backward()
    optimizer.step()

# Save the model to a file named "model.pth"
# Only the state_dict (parameter tensors) is written, not the module object.
torch.save(model.state_dict(), "model.pth")

# Load the model back from "model.pth"
loaded_model = SimpleModel()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds, so no arbitrary objects are unpickled.
loaded_model.load_state_dict(torch.load("model.pth", weights_only=True))
loaded_model.eval()

# Verify the model works after loading
X_test = torch.tensor([[0.5], [1.0], [1.5]])
with torch.no_grad():
    predictions = loaded_model(X_test)
    print(f"Predictions after loading: {predictions}")
    

Predictions after loading: tensor([[3.3646],
        [4.2802],
        [5.1959]])


In [2]:
#Strong LLM
import jax
import jax.numpy as jnp
import pickle


# Synthetic regression data: y = 3*x + 2 + noise, with x uniform in [0, 1).
key = jax.random.PRNGKey(42)
key, x_key = jax.random.split(key)      # first split feeds the inputs
key, noise_key = jax.random.split(key)  # second split feeds the noise
X = jax.random.uniform(x_key, shape=(100, 1))  # 100 data points in [0, 1)
noise = 0.1 * jax.random.normal(noise_key, shape=(100, 1))
y = 3 * X + 2 + noise

def predict(params, x):
    """Evaluate the linear model ``x . w + b``.

    Args:
        params: Mapping with the entries "w" (weight) and "b" (bias).
        x: Input array that is dotted with the weight.

    Returns:
        ``jnp.dot(x, w) + b``.
    """
    weight = params["w"]
    bias = params["b"]
    projected = jnp.dot(x, weight)
    return projected + bias

def loss_fn(params, X, y):
    """Mean squared error between ``predict(params, X)`` and ``y``."""
    residual = predict(params, X) - y
    squared_error = residual ** 2
    return jnp.mean(squared_error)

# Initialize weight and bias: the weight starts as a small random value
# (standard normal scaled by 0.1) and the bias starts at zero.
key, subkey = jax.random.split(key)
params = {
    "w": 0.1 * jax.random.normal(subkey, shape=(1, 1)),
    "b": jnp.zeros((1,))
}

lr = 0.01
epochs = 100

# Build the value-and-gradient function once; it is the same for every epoch,
# so there is no need to re-create it inside the loop.
loss_and_grad = jax.value_and_grad(loss_fn)

for epoch in range(epochs):
    loss, grads = loss_and_grad(params, X, y)
    # Plain gradient-descent step. A fresh dict is built each time instead of
    # mutating `params` in place, so earlier references are never altered.
    params = {name: params[name] - lr * grads[name] for name in params}
    # Optionally, print loss every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {float(loss):.4f}")

# Despite the .pth name, this file is a pickle of the parameter dict.
with open("model.pth", "wb") as f:
    pickle.dump(params, f)

# pickle.load executes arbitrary code from the file; only load files that
# were written by this notebook.
with open("model.pth", "rb") as f:
    loaded_params = pickle.load(f)

# Predict with the parameters read back from disk to check the round trip.
X_test = jnp.array([[0.5], [1.0], [1.5]])
predictions = predict(loaded_params, X_test)
print("Predictions after loading:", predictions)

Epoch [10/100], Loss: 9.0799
Epoch [20/100], Loss: 5.5066
Epoch [30/100], Loss: 3.3760
Epoch [40/100], Loss: 2.1048
Epoch [50/100], Loss: 1.3452
Epoch [60/100], Loss: 0.8905
Epoch [70/100], Loss: 0.6173
Epoch [80/100], Loss: 0.4523
Epoch [90/100], Loss: 0.3518
Epoch [100/100], Loss: 0.2898
Predictions after loading: [[3.3005486]
 [4.0336018]
 [4.7666545]]


In [3]:
#Weak LLM
import jax.numpy as jnp  # MODIFIED: Consistently import jax.numpy as jnp
from jax import grad, jit, random, vmap
import flax.linen as nn
import optax
import pickle

class SimpleModel(nn.Module):
    """A simple neural network model using Flax.

    The model is a single `nn.Dense` layer with one output feature; calling
    the module applies that layer to the input.
    """
    
    def setup(self):
        """Define the layers of the model (one dense layer stored as `self.dense`)."""
        self.dense = nn.Dense(features=1)  # A layer with one output feature

    def __call__(self, x):
        """Forward pass of the model: return `self.dense(x)`."""
        return self.dense(x)

def train_model(X, y):
    """Train the model with the given data.

    Args:
        X: Inputs; used to initialize the model and as the full training batch.
        y: Targets compared against the model output in the squared-error loss.

    Returns:
        The parameters after 100 Adam updates with step size 0.001.
    """
    model = SimpleModel()
    # The initialization key is fixed to PRNGKey(0) inside this function.
    params = model.init(random.PRNGKey(0), X)
    # Loss function and optimization setup
    # The loss closes over X and y, so every evaluation uses the whole dataset.
    loss_fn = lambda params: jnp.mean((model.apply(params, X) - y) ** 2)
    optimizer = optax.adam(0.001)
    opt_state = optimizer.init(params)
    
    for epoch in range(100):  # Simple training loop
        # NOTE(review): `jax` is referenced by its top-level name here, but this
        # cell's imports only bind `jnp` and selected names from `jax` --
        # confirm that `import jax` has run earlier in the kernel.
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
    
    return params

def main():
    """Main function to execute the training and evaluation of the model.

    Builds a four-point training set, trains via `train_model`, and prints the
    model's predictions for three test inputs.
    """
    # The four (x, y) pairs below all satisfy y = 2 * x.
    X_train = jnp.array([[0.0], [1.0], [2.0], [3.0]])  # Training data
    y_train = jnp.array([[0.0], [2.0], [4.0], [6.0]])  # Expected outputs

    # Train the model
    trained_params = train_model(X_train, y_train)

    # Verify the model works after loading
    # NOTE(review): nothing is saved or loaded in this function; the in-memory
    # `trained_params` are passed straight to `model.apply`.
    X_test = jnp.array([[0.5], [1.0], [1.5]])  # Test data
    model = SimpleModel()  # Initialize model
    predictions = model.apply(trained_params, X_test)  # Get predictions
    print(f"Predictions after training: {predictions}")

if __name__ == "__main__":  # Entry point for the program
    main()  # Execute the main function

Predictions after training: [[-0.31621063]
 [-0.73144114]
 [-1.1466718 ]]


In [None]:
"""
Error Code:
loss, grads = jax.value_and_grad(loss_fn)(params)


Error:
Only jax.numpy and selected names (grad, jit, random, vmap) are imported; the top-level jax module itself is not, so jax.value_and_grad raises a NameError on a fresh kernel


Fix Guide:
Added import jax


Correct Code:
import jax

    loss, grads = jax.value_and_grad(loss_fn)(params)
"""


"""
Error Code:
# Train the model
trained_params = train_model(X_train, y_train)


Error:
The JAX code never saves the trained parameters to a file or loads them back, whereas the PyTorch code round-trips the model through a file before predicting


Fix Guide:
Use pickle to save the trained parameters to a file, then load it back and use the loaded parameters for prediction


Correct Code:
# Train the model
trained_params = train_model(X_train, y_train)

# Save model parameters to file
with open("model.pkl", "wb") as f:
    pickle.dump(trained_params, f)

# Load model parameters from file
with open("model.pkl", "rb") as f:
    loaded_params = pickle.load(f)
"""


"""
Error Code:
X_train = jnp.array([[0.0], [1.0], [2.0], [3.0]])  # Training data
y_train = jnp.array([[0.0], [2.0], [4.0], [6.0]])  # Expected outputs


Error:
The PyTorch code trains on 100 randomly generated inputs with targets y = 3 * X + 2 plus Gaussian noise (scaled by 0.1). The JAX code instead trains on 4 fixed points, so its training data does not match the PyTorch data.


Fix Guide:
Use JAX's random number generator to generate 100 samples of input data and construct a target value that meets y = 3 * X + 2 + noise


Correct Code:
key = random.PRNGKey(42)
key, subkey = random.split(key)
X_train = random.uniform(subkey, (100, 1))
key, subkey = random.split(key)
noise = random.normal(subkey, (100, 1)) * 0.1
y_train = 3 * X_train + 2 + noise
"""


"""
Error Code:
params = model.init(random.PRNGKey(0), X)


Error:
A hardcoded PRNG key is used in the train_model function, while a key has been generated based on the seed 42 in the main function.


Fix Guide:
Modify the train_model function to accept key as a parameter and use the passed key to initialize the model
Pass the generated key when calling in main


Correct Code:
def train_model(X, y, key):
    model = SimpleModel()
    params = model.init(key, X)
    # Loss function and optimization setup
    loss_fn = lambda params: jnp.mean((model.apply(params, X) - y) ** 2)
    optimizer = optax.adam(0.001)
    opt_state = optimizer.init(params)
    
    for epoch in range(100):  # Simple training loop
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
    
    return params

trained_params = train_model(X_train, y_train, key)
"""


"""
Error Code:
predictions = model.apply(trained_params, X_test)


Error:
The verification step predicts with trained_params (the in-memory result of training) instead of the parameters loaded back from the file, so the save/load round trip is never actually checked


Fix Guide:
Replace the parameters used during prediction from trained_params with loaded_params after loading from the file


Correct Code:
predictions = model.apply(loaded_params, X_test)
"""


"""
Error Code:
optimizer = optax.adam(0.001)


Error:
The PyTorch code uses optim.SGD(model.parameters(), lr=0.01), while the Adam optimizer is used here with a learning rate of 0.001


Fix Guide:
Modified to use optax.sgd with a learning rate of 0.01


Correct Code:
optimizer = optax.sgd(0.01)
"""

In [9]:
#Fixed Code
import jax
import jax.numpy as jnp  # MODIFIED: Consistently import jax.numpy as jnp
from jax import grad, jit, random, vmap
import flax.linen as nn
import optax
import pickle

class SimpleModel(nn.Module):
    """Flax module made of one dense layer with a single output feature."""

    @nn.compact
    def __call__(self, x):
        """Apply the dense layer to `x` and return its output."""
        # The explicit name keeps the layer's parameters stored under "dense",
        # the same key the attribute-based definition used.
        dense = nn.Dense(features=1, name="dense")
        return dense(x)

def train_model(X, y, key, lr=0.01, epochs=100):
    """Fit a SimpleModel to (X, y) with full-batch SGD on mean squared error.

    Args:
        X: Training inputs; used to initialize the model and as the full batch.
        y: Training targets compared against the model output.
        key: JAX PRNG key used to initialize the model parameters.
        lr: SGD learning rate. Defaults to 0.01, the value that was previously
            hard-coded.
        epochs: Number of gradient steps. Defaults to 100, the value that was
            previously hard-coded.

    Returns:
        The model parameters after `epochs` SGD updates.
    """
    model = SimpleModel()
    params = model.init(key, X)

    def loss_fn(params):
        """Mean squared error of the model on the whole training set."""
        return jnp.mean((model.apply(params, X) - y) ** 2)

    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(params)

    # Build the gradient function once instead of on every iteration. Only the
    # gradients are needed; the loss value was never used by the loop.
    grad_fn = jax.grad(loss_fn)

    for _ in range(epochs):
        grads = grad_fn(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    return params

def main():
    """Generate data, train, round-trip the parameters through pickle, and predict."""
    # Synthetic regression set: inputs uniform in [0, 1), targets
    # y = 3 * x + 2 plus Gaussian noise scaled by 0.1.
    root_key = random.PRNGKey(42)
    root_key, x_key = random.split(root_key)
    root_key, noise_key = random.split(root_key)
    X_train = random.uniform(x_key, (100, 1))
    y_train = 3 * X_train + 2 + random.normal(noise_key, (100, 1)) * 0.1

    # Fit the model; the remaining key seeds the parameter initialization.
    trained_params = train_model(X_train, y_train, root_key)

    # Write the parameters to disk, then read them straight back.
    checkpoint_path = "model.pkl"
    with open(checkpoint_path, "wb") as sink:
        pickle.dump(trained_params, sink)
    with open(checkpoint_path, "rb") as source:
        loaded_params = pickle.load(source)

    # Predict with the parameters that came back from the file.
    model = SimpleModel()
    X_test = jnp.array([[0.5], [1.0], [1.5]])
    predictions = model.apply(loaded_params, X_test)
    print(f"Predictions after training: {predictions}")


if __name__ == "__main__":
    # Run the whole pipeline when executed as a script or notebook cell.
    main()

Predictions after training: [[3.3018272]
 [4.089325 ]
 [4.876823 ]]
