In [6]:
# Input
import torch
import torch.nn as nn
import torch.optim as optim

# Build a reproducible synthetic dataset
torch.manual_seed(42)
X = 10 * torch.rand(100, 1)  # 100 inputs drawn uniformly from [0, 10)
noise = torch.randn(100, 1)  # one standard-normal noise value per sample
y = 2 * X + 3 + noise  # targets follow y = 2x + 3 plus noise

# Define the Linear Regression Model
class LinearRegressionModel(nn.Module):
    """One-feature linear regression, y_hat = w * x + b, backed by ``nn.Linear``."""

    def __init__(self):
        super().__init__()
        # A single affine layer: one input feature mapped to one output value
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return the affine layer's prediction for the input batch ``x``."""
        prediction = self.linear(x)
        return prediction

# Build the model together with its MSE objective and plain-SGD optimizer
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Full-batch gradient descent
epochs = 1000
for epoch in range(epochs):
    # Clear the gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass: predictions and their mean squared error
    predictions = model(X)
    loss = criterion(predictions, y)

    # Backward pass, then a single parameter update
    loss.backward()
    optimizer.step()

    # Report the loss once every 100 epochs
    if not (epoch + 1) % 100:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

# Show the fitted slope and intercept
w, b = model.linear.parameters()
print(f"Learned weight: {w.item():.4f}, Learned bias: {b.item():.4f}")

# Evaluate on unseen inputs with gradient tracking switched off
X_test = torch.tensor([[4.0], [7.0]])
with torch.no_grad():
    predictions = model(X_test)
print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

Epoch [100/1000], Loss: 1.6039
Epoch [200/1000], Loss: 1.0242
Epoch [300/1000], Loss: 0.8017
Epoch [400/1000], Loss: 0.7163
Epoch [500/1000], Loss: 0.6836
Epoch [600/1000], Loss: 0.6710
Epoch [700/1000], Loss: 0.6662
Epoch [800/1000], Loss: 0.6643
Epoch [900/1000], Loss: 0.6636
Epoch [1000/1000], Loss: 0.6634
Learned weight: 1.9577, Learned bias: 3.2045
Predictions for [[4.0], [7.0]]: [[11.035286903381348], [16.90837860107422]]


In [1]:
# Strong LLM
import jax
import jax.numpy as jnp

# Set the random seed and create a PRNG key.
# Every random draw below is given its own subkey split from `key`, so the
# order of the split calls determines the generated values.
key = jax.random.PRNGKey(42)

# Generate synthetic data
key, subkey = jax.random.split(key)  # fresh subkey for the inputs
X = jax.random.uniform(subkey, shape=(100, 1)) * 10  # 100 data points between 0 and 10
key, subkey = jax.random.split(key)  # fresh subkey for the noise
noise = jax.random.normal(subkey, shape=(100, 1))
y = 2 * X + 3 + noise  # Linear relationship with noise

# Initialize model parameters similar to PyTorch's nn.Linear(1, 1)
# PyTorch uses a uniform distribution in [-1/sqrt(in_features), 1/sqrt(in_features)]
bound = 1.0  # For in_features=1, bound = 1.
key, subkey = jax.random.split(key)  # fresh subkey for the weight
w = jax.random.uniform(subkey, shape=(1, 1), minval=-bound, maxval=bound)
key, subkey = jax.random.split(key)  # fresh subkey for the bias
b = jax.random.uniform(subkey, shape=(1,), minval=-bound, maxval=bound)
params = {"w": w, "b": b}  # parameter dict read by predict() and loss_fn()

# Define the forward pass (prediction function)
def predict(params, x):
    """Apply the affine map ``x . w + b`` using the entries of ``params``.

    ``params`` is a dict with keys ``"w"`` and ``"b"``; the product is taken
    with ``jnp.dot`` and the bias is added to the result.
    """
    weight = params["w"]
    bias = params["b"]
    projected = jnp.dot(x, weight)
    return projected + bias

# Define the loss function (Mean Squared Error)
def loss_fn(params, x, y):
    """Mean squared error between ``predict(params, x)`` and the targets ``y``."""
    residuals = predict(params, x) - y
    return jnp.mean(residuals ** 2)

# Hyperparameters for plain gradient descent
lr = 0.01
epochs = 1000

# A single callable that returns both the loss and its gradient w.r.t. params
loss_and_grad = jax.value_and_grad(loss_fn)

# Full-batch gradient descent
for epoch in range(epochs):
    loss, grads = loss_and_grad(params, X, y)

    # SGD step: overwrite both entries of the parameter dict in place
    params.update({k: params[k] - lr * grads[k] for k in ("w", "b")})

    # Report the loss once every 100 epochs
    if not (epoch + 1) % 100:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")

# Pull the fitted slope and intercept out of their arrays and show them
learned_w, learned_b = params["w"][0, 0], params["b"][0]
print(f"Learned weight: {learned_w:.4f}, Learned bias: {learned_b:.4f}")

# Evaluate on unseen inputs
X_test = jnp.array([[4.0], [7.0]])
predictions = predict(params, X_test)
print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")


Epoch [100/1000], Loss: 1.4685
Epoch [200/1000], Loss: 1.2193
Epoch [300/1000], Loss: 1.1290
Epoch [400/1000], Loss: 1.0963
Epoch [500/1000], Loss: 1.0844
Epoch [600/1000], Loss: 1.0801
Epoch [700/1000], Loss: 1.0786
Epoch [800/1000], Loss: 1.0780
Epoch [900/1000], Loss: 1.0778
Epoch [1000/1000], Loss: 1.0777
Learned weight: 2.0446, Learned bias: 2.8493
Predictions for [[4.0], [7.0]]: [[11.027803421020508], [17.161645889282227]]


In [3]:
# Weak LLM
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
import optax

# Generate synthetic data
def generate_data(num_samples=100):
    """Build the synthetic regression dataset (kept exactly as the weak LLM emitted it).

    ``X`` is ``num_samples`` evenly spaced points in [0, 10] reshaped to a
    column; ``y`` is written as ``2 * X + 1`` plus noise, and ``(X, y)`` is
    returned.

    NOTE(review): this cell is preserved as the failing exhibit. The
    ``jnp.random.normal`` call below raises ``AttributeError`` (see the
    recorded cell output) because ``jax.numpy`` has no ``random`` attribute.
    """
    X = jnp.linspace(0, 10, num_samples).reshape(-1, 1)
    y = 2 * X + 1 + jnp.random.normal(0, 1, X.shape)  # Adjusted for noise
    return X, y

# Linear regression model
def model(params, x):
    """Linear model: unpack ``params`` into ``(w, b)`` and return ``w * x + b``."""
    w, b = params  # train_model() passes a length-2 jnp array holding (w, b)
    return w * x + b

# Loss function
def loss_fn(params, x, y):
    """Return the mean squared error between ``model(params, x)`` and ``y``."""
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)

# Gradient computation
@jax.jit
def compute_gradient(params, x, y):
    """Return the gradient of ``loss_fn`` evaluated at ``(params, x, y)``.

    ``grad(loss_fn)`` differentiates with respect to the first argument
    (``params``), which is ``jax.grad``'s default.
    """
    return grad(loss_fn)(params, x, y)

# Training step
@jax.jit
def train_step(params, x, y, rng_key):
    """Apply one gradient-descent update and return the new parameters.

    The step size is the literal ``0.01``. ``rng_key`` is accepted but never
    read inside the body.
    """
    grads = compute_gradient(params, x, y)
    return params - 0.01 * grads  # Update rate

# Batch training step using vectorization
def train_model(X, y, num_epochs=1000):
    """Call ``train_step`` on the full dataset ``num_epochs`` times.

    Starts from ``params = [0.0, 0.0]`` (w, b) and returns the final
    parameter array. No loss value is computed or printed here.
    """
    key = random.PRNGKey(0)  # Initialize PRNG key
    params = jnp.array([0.0, 0.0])  # Initial parameters (w, b)
    
    # Vectorized training over epochs
    for epoch in range(num_epochs):
        # Create a new PRNG key for each batch (if batching)
        # NOTE(review): the subkey is only forwarded to train_step, whose body never reads it
        key, subkey = random.split(key)
        params = train_step(params, X, y, subkey)  # // MODIFIED: Removed loop for vectorization

    return params

# Main function
def main():
    """Generate data, train, then print the learned parameters and two test predictions."""
    # Generate data
    X, y = generate_data(100)
    
    # Train the model
    learned_params = train_model(X, y)
    
    # Display the learned parameters
    w, b = learned_params  # train_model() returns a length-2 array holding (w, b)
    print(f"Learned weight: {w.item():.4f}, Learned bias: {b.item():.4f}")

    # Testing on new data
    X_test = jnp.array([[4.0], [7.0]])
    predictions = model(learned_params, X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

# Run the pipeline only when this code is executed as the main module
if __name__ == "__main__":
    main()

AttributeError: module 'jax.numpy' has no attribute 'random'

In [None]:
"""
Error Code:
X = jnp.linspace(0, 10, num_samples).reshape(-1, 1)
y = 2 * X + 1 + jnp.random.normal(0, 1, X.shape)  # Adjusted for noise


Error:
AttributeError: module 'jax.numpy' has no attribute 'random'


Fix Guide:
`jax.numpy` has no `random` submodule. Generate the noise with `jax.random.normal`, which takes an explicit PRNG key as its first argument.


Correct Code:
key = random.PRNGKey(0)
X = jnp.linspace(0, 10, num_samples).reshape(-1, 1)
noise = random.normal(key, shape=X.shape)
y = 2 * X + 1 + noise
"""


"""
Error Code:
y = 2 * X + 1 + noise


Error:
The data-generating relationship does not match the original PyTorch code, which uses y = 2 * X + 3 (intercept 3), not 2 * X + 1


Fix Guide:
The linear relationship when the data is generated should be 2 * X + 3


Correct Code:
y = 2 * X + 3 + noise
"""


"""
Error Code:
def train_step(params, x, y, rng_key):


Error:
The rng_key parameter is passed into the train_step function, but the training step does not require randomness


Fix Guide:
Removed unused rng_key parameter


Correct Code:
def train_step(params, x, y):
"""


"""
Error Code:
# Batch training step using vectorization
def train_model(X, y, num_epochs=1000):
    key = random.PRNGKey(0)  # Initialize PRNG key
    params = jnp.array([0.0, 0.0])  # Initial parameters (w, b)
    
    # Vectorized training over epochs
    for epoch in range(num_epochs):
        # Create a new PRNG key for each batch (if batching)
        key, subkey = random.split(key)
        params = train_step(params, X, y, subkey)  # // MODIFIED: Removed loop for vectorization

    return params


Error:
Since the training step does not require randomness, the generation and passing of rng_key should also be removed when training the model.


Fix Guide:
Remove the generation and passing of rng_key when training the model


Correct Code:
def train_model(X, y, num_epochs=1000):
    params = jnp.array([0.0, 0.0])  # Initial parameters (w, b)
    for epoch in range(num_epochs):
        params = train_step(params, X, y)
    return params
"""


"""
Error Code:
params = jnp.array([0.0, 0.0])  # Initial parameters (w, b)


Error:
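The parameters are stored as a flat array initialized to zeros. The reference JAX translation keeps them in a dictionary ({"w": ..., "b": ...}) and, mirroring PyTorch's nn.Linear, initializes them randomly from a uniform distribution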


Fix Guide:
The parameters are initialized using a dictionary structure, and the weights and biases are initialized using random uniform distribution


Correct Code:
bound = 1.0  # For in_features=1, bound = 1.
key = random.PRNGKey(0)
key, subkey = random.split(key)
w = random.uniform(subkey, shape=(1, 1), minval=-bound, maxval=bound)
key, subkey = random.split(key)
b = random.uniform(subkey, shape=(1,), minval=-bound, maxval=bound)
params = {"w": w, "b": b}
"""


"""
Error Code:
def model(params, x):
    w, b = params
    return w * x + b


Error:
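The model unpacks params as a (w, b) pair and uses element-wise multiplication (w * x), whereas the original model applies nn.Linear (a matrix product with the weight plus a bias) and the parameters are now stored in a dictionary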


Fix Guide:
The forward function should take a dictionary of parameters and use matrix multiplication to emulate the behavior of nn.Linear


Correct Code:
def model(params, x):
    return jnp.dot(x, params["w"]) + params["b"]
"""


"""
Error Code:
return params - 0.01 * grads


Error:
After the parameter-structure fix above, params is a dictionary (with keys "w" and "b"), so `params - 0.01 * grads` is no longer valid: dictionaries do not support arithmetic, and each entry must be updated separately


Fix Guide:
Each parameter in the dictionary should be updated separately


Correct Code:
return {
        "w": params["w"] - 0.01 * grads["w"],
        "b": params["b"] - 0.01 * grads["b"]
    }
"""


"""
Error Code:
for epoch in range(num_epochs):
        params = train_step(params, X, y)
    return params
            

Error:
This training loop never computes the loss for the current epoch and prints no progress log, unlike the original PyTorch loop, which reports the loss every 100 epochs


Fix Guide:
In each epoch, first calculate the loss and gradient, then update the parameters, and print the log when the conditions are met


Correct Code:
for epoch in range(num_epochs):
        loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
        params = {
            "w": params["w"] - 0.01 * grads["w"],
            "b": params["b"] - 0.01 * grads["b"]
        }

        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss:.4f}")
    return params
"""

In [9]:
# fixed code
import jax
import jax.numpy as jnp
from jax import grad, jit, random

# Generate synthetic data
def generate_data(num_samples=100, key=None):
    """Create a noisy linear dataset following ``y = 2 * X + 3 + noise``.

    Args:
        num_samples: Number of evenly spaced inputs taken from [0, 10].
        key: Optional JAX PRNG key used to draw the noise. When omitted,
            ``random.PRNGKey(0)`` is used, so the default output is the same
            as before this parameter existed.

    Returns:
        A tuple ``(X, y)`` of arrays, each with shape ``(num_samples, 1)``.
    """
    if key is None:
        # Historical default: keeps the recorded notebook output reproducible
        key = random.PRNGKey(0)
    X = jnp.linspace(0, 10, num_samples).reshape(-1, 1)
    noise = random.normal(key, shape=X.shape)  # standard-normal noise, one value per sample
    y = 2 * X + 3 + noise
    return X, y

# Linear regression model
def model(params, x):
    """Affine prediction ``x . w + b`` read from the ``"w"`` and ``"b"`` entries of ``params``."""
    weight, bias = params["w"], params["b"]
    projected = jnp.dot(x, weight)
    return projected + bias

# Loss function
def loss_fn(params, x, y):
    """Mean squared error between ``model(params, x)`` and the targets ``y``."""
    residuals = model(params, x) - y
    return jnp.mean(residuals ** 2)

# Gradient computation
@jit
def compute_gradient(params, x, y):
    """Return the gradient of ``loss_fn`` at ``(params, x, y)``, taken w.r.t. ``params``."""
    gradient_fn = grad(loss_fn)
    return gradient_fn(params, x, y)

# Training step
@jit
def train_step(params, x, y):
    """Take one gradient-descent step (step size 0.01) and return the new ``"w"``/``"b"`` dict."""
    step_size = 0.01
    gradients = compute_gradient(params, x, y)
    return {
        name: params[name] - step_size * gradients[name]
        for name in ("w", "b")
    }

# Training loop
def train_model(X, y, num_epochs=1000, lr=0.01, key=None, log_every=100):
    """Fit the linear model to ``(X, y)`` with full-batch gradient descent.

    Args:
        X: Input array passed to ``loss_fn`` on every epoch.
        y: Target array passed to ``loss_fn`` on every epoch.
        num_epochs: Number of gradient-descent updates to perform.
        lr: Step size of each update (defaults to the previous literal 0.01).
        key: Optional JAX PRNG key for the parameter initialization. When
            omitted, ``random.PRNGKey(0)`` is used as before.
        log_every: Print the loss every ``log_every`` epochs; a falsy value
            disables logging.

    Returns:
        A dict with the trained ``"w"`` (shape ``(1, 1)``) and ``"b"``
        (shape ``(1,)``) arrays.
    """
    if key is None:
        # Historical default: keeps the recorded notebook output reproducible
        key = random.PRNGKey(0)

    # Uniform init in [-bound, bound], mirroring nn.Linear with in_features=1
    bound = 1.0  # For in_features=1, bound = 1.
    key, w_key = random.split(key)
    w = random.uniform(w_key, shape=(1, 1), minval=-bound, maxval=bound)
    key, b_key = random.split(key)
    b = random.uniform(b_key, shape=(1,), minval=-bound, maxval=bound)
    params = {"w": w, "b": b}

    # Build the loss-and-gradient transform once; it does not change per epoch
    loss_and_grad = jax.value_and_grad(loss_fn)

    for epoch in range(num_epochs):
        # `loss` is evaluated at the parameters *before* this epoch's update
        loss, grads = loss_and_grad(params, X, y)
        params = {
            "w": params["w"] - lr * grads["w"],
            "b": params["b"] - lr * grads["b"],
        }

        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss:.4f}")
    return params

# Main function
def main():
    """Generate the dataset, train the model, and print the fit plus two test predictions."""
    features, targets = generate_data(100)
    learned_params = train_model(features, targets)

    # Pull the scalar slope and intercept out of the parameter arrays
    learned_w, learned_b = learned_params["w"][0, 0], learned_params["b"][0]
    print(f"Learned weight: {learned_w:.4f}, Learned bias: {learned_b:.4f}")

    # Evaluate on two unseen inputs
    X_test = jnp.array([[4.0], [7.0]])
    predictions = model(learned_params, X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")


if __name__ == "__main__":
    main()


Epoch [100/1000], Loss: 1.6636
Epoch [200/1000], Loss: 1.1676
Epoch [300/1000], Loss: 0.9843
Epoch [400/1000], Loss: 0.9166
Epoch [500/1000], Loss: 0.8915
Epoch [600/1000], Loss: 0.8823
Epoch [700/1000], Loss: 0.8789
Epoch [800/1000], Loss: 0.8776
Epoch [900/1000], Loss: 0.8771
Epoch [1000/1000], Loss: 0.8770
Learned weight: 2.0338, Learned bias: 2.9082
Predictions for [[4.0], [7.0]]: [[11.043389320373535], [17.14480972290039]]
