In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)

# Create and train the model
torch.manual_seed(42)
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
X = torch.rand(100, 1)
y = 3 * X + 2 + torch.randn(100, 1) * 0.1
epochs = 100
for epoch in range(epochs):
    optimizer.zero_grad()
    predictions = model(X)
    loss = criterion(predictions, y)
    loss.backward()
    optimizer.step()

# Save the model to a file named "model.pth"
torch.save(model.state_dict(), "model.pth")

# Load the model back from "model.pth"
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load("model.pth"))
loaded_model.eval()

# Verify the model works after loading
X_test = torch.tensor([[0.5], [1.0], [1.5]])
with torch.no_grad():
    predictions = loaded_model(X_test)
    print(f"Predictions after loading: {predictions}")
    

Predictions after loading: tensor([[3.3646],
        [4.2802],
        [5.1959]])


In [2]:
#Strong LLM
import jax
import jax.numpy as jnp
import pickle


key = jax.random.PRNGKey(42)
# Generate 100 data points with one feature
key, subkey = jax.random.split(key)
X = jax.random.uniform(subkey, shape=(100, 1))
key, subkey = jax.random.split(key)
# Create targets: y = 3*x + 2 + noise (noise scaled by 0.1)
y = 3 * X + 2 + 0.1 * jax.random.normal(subkey, shape=(100, 1))

def predict(params, x):
    """Simple linear model: y = x * w + b"""
    return jnp.dot(x, params["w"]) + params["b"]

def loss_fn(params, X, y):
    preds = predict(params, X)
    return jnp.mean((preds - y) ** 2)

key, subkey = jax.random.split(key)
params = {
    "w": 0.1 * jax.random.normal(subkey, shape=(1, 1)),
    "b": jnp.zeros((1,))
}

lr = 0.01
epochs = 100

for epoch in range(epochs):
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    params["w"] = params["w"] - lr * grads["w"]
    params["b"] = params["b"] - lr * grads["b"]
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {float(loss):.4f}")

with open("model.pth", "wb") as f:
    pickle.dump(params, f)

with open("model.pth", "rb") as f:
    loaded_params = pickle.load(f)

X_test = jnp.array([[0.5], [1.0], [1.5]])
predictions = predict(loaded_params, X_test)
print("Predictions after loading:", predictions)

Epoch [10/100], Loss: 9.0799
Epoch [20/100], Loss: 5.5066
Epoch [30/100], Loss: 3.3760
Epoch [40/100], Loss: 2.1048
Epoch [50/100], Loss: 1.3452
Epoch [60/100], Loss: 0.8905
Epoch [70/100], Loss: 0.6173
Epoch [80/100], Loss: 0.4523
Epoch [90/100], Loss: 0.3518
Epoch [100/100], Loss: 0.2898
Predictions after loading: [[3.3005486]
 [4.0336018]
 [4.7666545]]


In [3]:
#Weak LLM
import jax.numpy as jnp  # MODIFIED: Consistent import of jax.numpy as jnp
from jax import random

def generate_random_numbers(shape):
    """
    Generate random numbers following a normal distribution.

    Args:
        shape (tuple): The shape of the output array.

    Returns:
        jnp.ndarray: An array of random numbers of the specified shape.
    """
    return random.normal_random(key=random.PRNGKey(0), shape=shape)  # Example method to generate random numbers

# Example usage of the generate_random_numbers function
def main():
    # Generate a 3x3 array of random numbers
    random_numbers = generate_random_numbers((3, 3))
    print("Generated Random Numbers:\n", random_numbers)

if __name__ == "__main__":
    main()  # Entry point of the program

# Additional code can go here, e.g., model definition, training loops, etc.
# Training loop
# epochs = 1000
# for epoch in range(epochs):
#     model_params, optimizer_state, loss = train_step(model, X, y, optimizer_state)
#     model = model.replace(slope=model_params)
#     if epoch % 100 == 0:
#         print(f'Epoch {epoch}, Loss: {loss:.4f}')

AttributeError: module 'jax.random' has no attribute 'normal_random'

In [None]:
"""
Error Code:
return random.normal_random(key=random.PRNGKey(0), shape=shape)


Error:
module 'jax.random' has no attribute 'normal_random'


Fix Guide:
Replace random.normal_random with random.normal


Correct Code:
return random.normal(key=random.PRNGKey(0), shape=shape)
"""


"""
Error Code:
import jax.numpy as jnp  # MODIFIED: Consistent import of jax.numpy as jnp
from jax import random

def generate_random_numbers(shape):
    return random.normal(key=random.PRNGKey(0), shape=shape)  # Example method to generate random numbers


Error:
Compared with the PyTorch code, the JAX version is missing the following parts:
Model definition
Loss function
Training loop
Model saving and loading


Fix Guide:
Define a simple linear model, store the model parameters in a dictionary, and define a model function
Define the loss function
Use jax.value_and_grad to calculate the gradient and update the parameters in the training loop
Use Python's pickle module to save and load model parameters


Correct Code:
import jax.numpy as jnp  # MODIFIED: Consistent import of jax.numpy as jnp
from jax import random, value_and_grad
import pickle

def model(params, x):
    return params['w'] * x + params['b']

def mse_loss(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)

def train_step(params, x, y, learning_rate=0.01):
    loss, grads = value_and_grad(mse_loss)(params, x, y)
    new_params = {k: params[k] - learning_rate * grads[k] for k in params}
    return new_params, loss

def generate_random_numbers(shape):
    return random.normal(key=random.PRNGKey(0), shape=shape)  # Example method to generate random numbers
"""


"""
Error Code:
# Example usage of the generate_random_numbers function
def main():
    # Generate a 3x3 array of random numbers
    random_numbers = generate_random_numbers((3, 3))
    print("Generated Random Numbers:\n", random_numbers)


Error:
Missing the part of generating training data, training loop, and saving and loading models for prediction after training


Fix Guide:
Use JAX's random function to generate X (uniform distribution) and noise (normal distribution), and then construct y = 3 * X + 2 + noise
Write a training loop to update the model parameters and periodically print the loss
Use pickle to save the trained parameters to a file and then load it from the file


Correct Code:
def main():
    key = random.PRNGKey(42)
    
    key, subkey1, subkey2 = random.split(key, 3)
    params = {
        'w': random.normal(subkey1, (1,)),
        'b': random.normal(subkey2, (1,))
    }
    
    key, subkey1, subkey2 = random.split(key, 3)
    X = random.uniform(subkey1, (100, 1))
    noise = random.normal(subkey2, (100, 1)) * 0.1
    y = 3 * X + 2 + noise
    
    epochs = 100
    for epoch in range(epochs):
        params, loss = train_step(params, X, y, learning_rate=0.01)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss:.4f}")
    
    with open("model.pth", "wb") as f:
        pickle.dump(params, f)
    
    with open("model.pth", "rb") as f:
        loaded_params = pickle.load(f)
    
    X_test = jnp.array([[0.5], [1.0], [1.5]])
    predictions = model(loaded_params, X_test)
    print("Predictions after loading:", predictions)
"""

In [7]:
#Fixed Code
import jax.numpy as jnp  # MODIFIED: Consistent import of jax.numpy as jnp
from jax import random, value_and_grad
import pickle

def model(params, x):
    return params['w'] * x + params['b']

def mse_loss(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)

def train_step(params, x, y, learning_rate=0.01):
    loss, grads = value_and_grad(mse_loss)(params, x, y)
    new_params = {k: params[k] - learning_rate * grads[k] for k in params}
    return new_params, loss

def generate_random_numbers(shape):
    return random.normal(key=random.PRNGKey(0), shape=shape)  # Example method to generate random numbers

def main():
    key = random.PRNGKey(42)
    
    key, subkey1, subkey2 = random.split(key, 3)
    params = {
        'w': random.normal(subkey1, (1,)),
        'b': random.normal(subkey2, (1,))
    }
    
    key, subkey1, subkey2 = random.split(key, 3)
    X = random.uniform(subkey1, (100, 1))
    noise = random.normal(subkey2, (100, 1)) * 0.1
    y = 3 * X + 2 + noise
    
    epochs = 100
    for epoch in range(epochs):
        params, loss = train_step(params, X, y, learning_rate=0.01)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss:.4f}")
    
    with open("model.pth", "wb") as f:
        pickle.dump(params, f)
    
    with open("model.pth", "rb") as f:
        loaded_params = pickle.load(f)
    
    X_test = jnp.array([[0.5], [1.0], [1.5]])
    predictions = model(loaded_params, X_test)
    print("Predictions after loading:", predictions)

if __name__ == "__main__":
    main()  # Entry point of the program

# Additional code can go here, e.g., model definition, training loops, etc.
# Training loop
# epochs = 1000
# for epoch in range(epochs):
#     model_params, optimizer_state, loss = train_step(model, X, y, optimizer_state)
#     model = model.replace(slope=model_params)
#     if epoch % 100 == 0:
#         print(f'Epoch {epoch}, Loss: {loss:.4f}')

Epoch 0, Loss: 22.8257
Epoch 10, Loss: 13.5925
Epoch 20, Loss: 8.1053
Epoch 30, Loss: 4.8441
Epoch 40, Loss: 2.9057
Epoch 50, Loss: 1.7532
Epoch 60, Loss: 1.0679
Epoch 70, Loss: 0.6600
Epoch 80, Loss: 0.4171
Epoch 90, Loss: 0.2723
Predictions after loading: [[3.1789165]
 [4.227249 ]
 [5.2755823]]
