In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim

# Generate synthetic data
torch.manual_seed(42)
X = torch.rand(100, 1) * 10  # 100 data points between 0 and 10
y = 2 * X + 3 + torch.randn(100, 1)  # Linear relationship with noise


class HuberLoss(nn.Module):
    def __init__(self, delta=1.0):
        super(HuberLoss, self).__init__()
        self.delta = delta
    
    def forward(self, y_pred, y_true):
        # Calculate the absolute error
        error = torch.abs(y_pred - y_true)
        
        # Apply the Huber loss formula
        loss = torch.where(error <= self.delta,
                           0.5 * error**2,  # L2 loss for small errors
                           self.delta * (error - 0.5 * self.delta))  # L1 loss for large errors
        return loss.mean()  # Return the mean loss across all samples


# Define the Linear Regression Model
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # Single input and single output

    def forward(self, x):
        return self.linear(x)

# Initialize the model, loss function, and optimizer
model = LinearRegressionModel()
criterion = HuberLoss(delta=1.0)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
epochs = 1000
for epoch in range(epochs):
    # Forward pass
    predictions = model(X)
    loss = criterion(predictions, y)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log progress every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

# Display the learned parameters
[w, b] = model.linear.parameters()
print(f"Learned weight: {w.item():.4f}, Learned bias: {b.item():.4f}")

# Testing on new data
X_test = torch.tensor([[4.0], [7.0]])
with torch.no_grad():
    predictions = model(X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

Epoch [100/1000], Loss: 0.8869
Epoch [200/1000], Loss: 0.7855
Epoch [300/1000], Loss: 0.6945
Epoch [400/1000], Loss: 0.6134
Epoch [500/1000], Loss: 0.5433
Epoch [600/1000], Loss: 0.4861
Epoch [700/1000], Loss: 0.4404
Epoch [800/1000], Loss: 0.4045
Epoch [900/1000], Loss: 0.3767
Epoch [1000/1000], Loss: 0.3551
Learned weight: 2.0713, Learned bias: 2.4650
Predictions for [[4.0], [7.0]]: [[10.750251770019531], [16.964160919189453]]


In [2]:
#Strong LLM
import jax
import jax.numpy as jnp
import numpy as np


# Set the random seed and create a PRNG key
key = jax.random.PRNGKey(42)

# Generate synthetic data: 100 data points between 0 and 10 with noise
key, subkey = jax.random.split(key)
X = jax.random.uniform(subkey, shape=(100, 1)) * 10
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(100, 1))
y = 2 * X + 3 + noise  # y = 2*x + 3 + noise

# Define the linear regression model function
def predict(params, x):
    # Computes a linear transformation: x * w + b
    return jnp.dot(x, params["w"]) + params["b"]

# Define the Huber loss function
def huber_loss(params, x, y, delta=1.0):
    preds = predict(params, x)
    error = jnp.abs(preds - y)
    loss = jnp.where(error <= delta,
                     0.5 * error**2,            # L2 loss for small errors
                     delta * (error - 0.5 * delta))  # L1 loss for large errors
    return jnp.mean(loss)

# Initialize parameters for a linear layer with 1 input and 1 output.
bound = 1.0  # Using a simple uniform initialization bound
key, subkey = jax.random.split(key)
w = jax.random.uniform(subkey, shape=(1, 1), minval=-bound, maxval=bound)
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, shape=(1,), minval=-bound, maxval=bound)
params = {"w": w, "b": b}

lr = 0.01
epochs = 1000

# Create a function that returns loss and gradients with respect to the parameters.
loss_and_grad = jax.value_and_grad(huber_loss, argnums=0)

for epoch in range(epochs):
    loss, grads = loss_and_grad(params, X, y, 1.0)
    # Parameter update using SGD
    params["w"] = params["w"] - lr * grads["w"]
    params["b"] = params["b"] - lr * grads["b"]
    
    # Log progress every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")


learned_w = params["w"][0, 0]
learned_b = params["b"][0]
print(f"Learned weight: {learned_w:.4f}, Learned bias: {learned_b:.4f}")


X_test = jnp.array([[4.0], [7.0]])
predictions = predict(params, X_test)
print(f"Predictions for {np.array(X_test).tolist()}: {np.array(predictions).tolist()}")


Epoch [100/1000], Loss: 0.7171
Epoch [200/1000], Loss: 0.6591
Epoch [300/1000], Loss: 0.6122
Epoch [400/1000], Loss: 0.5753
Epoch [500/1000], Loss: 0.5469
Epoch [600/1000], Loss: 0.5257
Epoch [700/1000], Loss: 0.5097
Epoch [800/1000], Loss: 0.4977
Epoch [900/1000], Loss: 0.4887
Epoch [1000/1000], Loss: 0.4819
Learned weight: 2.1092, Learned bias: 2.4387
Predictions for [[4.0], [7.0]]: [[10.875606536865234], [17.203285217285156]]


In [3]:
#Weak LLM
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
import optax

# Define a simple model
class LinearModel:
    def __init__(self, key):
        self.w = random.normal(key, (1,))
        self.b = random.normal(key, ())

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

# Loss function
def loss_fn(model, x, y):
    preds = model(x)
    return jnp.mean((preds - y) ** 2)

# Update function using functional programming
def update(params, x, y, learning_rate=0.1):
    w, b = params
    loss_value, grads = jax.value_and_grad(loss_fn)(lambda x: model(x), x, y)
    w -= learning_rate * grads[0]
    b -= learning_rate * grads[1]
    return w, b

# Training function
def train_model(key, model, x, y, epochs=100):
    for epoch in range(epochs):  # MODIFIED
        model.w, model.b = update((model.w, model.b), x, y)  # MODIFIED
    return model

def main():
    # Generate synthetic data
    key = random.PRNGKey(0)  # MODIFIED: Explicit PRNG key
    model = LinearModel(key)
    
    # Generate synthetic data
    x = jnp.array([[1.0], [2.0], [3.0]])
    y = jnp.array([[2.0], [4.0], [6.0]])

    # Train the model
    model = train_model(key, model, x, y, epochs=100)

    # Test the model
    predictions = model(x)
    print(f"Predictions for {x.tolist()}: {predictions.tolist()}")
    print(f"Trained weights: {model.w}, bias: {model.b}")

if __name__ == "__main__":
    main()

TypeError: Argument '<function update.<locals>.<lambda> at 0x000001D545DF03A0>' of type <class 'function'> is not a valid JAX type.

In [None]:
"""
Error Code:
def update(params, x, y, learning_rate=0.1):
    w, b = params
    loss_value, grads = jax.value_and_grad(loss_fn)(lambda x: model(x), x, y)
    w -= learning_rate * grads[0]
    b -= learning_rate * grads[1]
    return w, b


Error:
Argument '<function update.<locals>.<lambda> at 0x000001D545DF03A0>' of type <class 'function'> is not a valid JAX type


Fix Guide:
Extract the model logic from the class method and define a pure function that accepts a parameter tuple (w, b) and input x and returns the prediction result
Change loss_fn to receive parameters (w, b) instead of the entire model instance, and use jax.value_and_grad to directly calculate the gradient of the parameters
In update, directly pass the parameter tuple to loss_fn to avoid using lambda functions


Correct Code:
def model_fn(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def update(params, x, y, learning_rate=0.1):
    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
    w, b = params
    w = w - learning_rate * grads[0]
    b = b - learning_rate * grads[1]
    return (w, b)
"""


"""
Error Code:
class LinearModel:
    def __init__(self, key):
        self.w = random.normal(key, (1,))
        self.b = random.normal(key, ())

Error:
Weights should be a 2D matrix (shape (1, 1)) to perform correct matrix multiplication with x


Fix Guide:
Modify the parameter initialization, set the shape of w to (1,1) and the shape of b to (1,)


Correct Code:
class LinearModel:
    def __init__(self, key):
        self.w = random.normal(key, (1, 1))
        self.b = random.normal(key, (1,))
"""


"""
Error Code:
def loss_fn(params, x, y):
    preds = model_fn(params, x)
    return jnp.mean((preds - y) ** 2)
    

Error:
The original code uses Huber loss, while the incorrect code here uses mean square error (MSE) as the loss function


Fix Guide:
Change the loss function to Huber loss function, set delta=1.0
And use the L2 part when the error is less than or equal to delta: 0.5 * error²
And use the L1 part when the error is greater than delta: delta * (error - 0.5 * delta)

Correct Code:
def loss_fn(params, x, y, delta=1.0):
    preds = model_fn(params, x)
    error = jnp.abs(preds - y)
    loss = jnp.where(error <= delta,
                     0.5 * error**2, 
                     delta * (error - 0.5 * delta))
    return jnp.mean(loss)
"""


"""
Error Code:
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
import optax


Error:
The optax module was not used later.


Fix Guide:
Remove optax module

Correct Code:
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
"""


"""
Error Code:
model = train_model(key, model, x, y, epochs=100)

def train_model(key, model, x, y, epochs=100):

Error:
The parameter key is not used during training function

Fix Guide:
Remove key parameter from train_model function

Correct Code:
model = train_model(model, x, y, epochs=100)

def train_model(model, x, y, epochs=100):
"""


"""
Error Code:
def __init__(self, key):
    self.w = random.normal(key, (1, 1))
    self.b = random.normal(key, (1,))


Error:
JAX requires that the PRNG key be split each time a random number is used.


Fix Guide:
Use random.split to split the key and generate a separate sub-key for each random variable


Correct Code:
def __init__(self, key):
    key, subkey = random.split(key)
    self.w = random.normal(subkey, (1, 1))
    key, subkey = random.split(key)
    self.b = random.normal(subkey, (1,))
"""


"""
Error Code:
class LinearModel:
    def __init__(self, key):
        key, subkey = random.split(key)
        self.w = random.normal(subkey, (1, 1))
        key, subkey = random.split(key)
        self.b = random.normal(subkey, (1,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        

Error:
The parameters are stored in self.w and self.b respectively, and the update function during subsequent training uses the method of packing the parameters into a tuple and updating them, which is inconsistent with the original code's method of using a dictionary to store parameters.


Fix Guide:
Unified use of dictionary form to store parameters


Correct Code:
class LinearModel:
    def __init__(self, key):
        key, subkey = random.split(key)
        w = random.uniform(subkey, (1, 1), minval=-1.0, maxval=1.0)
        key, subkey = random.split(key)
        b = random.uniform(subkey, (1,), minval=-1.0, maxval=1.0)
        self.params = {"w": w, "b": b}

    def __call__(self, x):
        return jnp.dot(x, self.params["w"]) + self.params["b"]
"""


"""
Error Code:
def loss_fn(params, x, y, delta=1.0):
    preds = model_fn(params, x)
    error = jnp.abs(preds - y)
    loss = jnp.where(error <= delta,
                     0.5 * error**2, 
                     delta * (error - 0.5 * delta))
    return jnp.mean(loss)
    

Error:
Inconsistent naming and usage of loss functions


Fix Guide:
Rename the loss function to huber_loss, explicitly pass in the delta parameter (such as 1.0) in the update function, and modify the parameters in dictionary form for internal calculations


Correct Code:
def huber_loss(params, x, y, delta=1.0):
    preds = jnp.dot(x, params["w"]) + params["b"]
    error = jnp.abs(preds - y)
    loss = jnp.where(error <= delta,
                     0.5 * error**2, 
                     delta * (error - 0.5 * delta))
    return jnp.mean(loss)
"""


"""
Error Code:
def update(params, x, y, learning_rate=0.1):
    loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
    w, b = params
    w = w - learning_rate * grads[0]
    b = b - learning_rate * grads[1]
    return (w, b)


Error:
The loss_fn is called here, which needs to be changed to huber_loss
The parameters are unpacked in tuple form, which is different from the previous code
The learning rate value is different from the original code


Fix Guide:
Change loss_fn to huber_loss
Change tuple to dict
Change lr to 0.01


Correct Code:
def update(params, x, y, learning_rate=0.01):
    loss_value, grads = jax.value_and_grad(huber_loss)(params, x, y, 1.0)
    params["w"] = params["w"] - learning_rate * grads["w"]
    params["b"] = params["b"] - learning_rate * grads["b"]
    return params
"""


"""
Error Code:
def train_model(model, x, y, epochs=100):
    for epoch in range(epochs):
        model.w, model.b = update((model.w, model.b), x, y)
    return model

Error:
Use tuple unpacking instead of dictionary when updating
No loss log is output during the entire training process, and the training progress cannot be observed


Fix Guide:
In the training function, update model.params using dict form
Add log output statement


Correct Code:
def train_model(model, x, y, epochs=1000):
    for epoch in range(epochs):
        model.params = update(model.params, x, y, learning_rate=0.01)
        if (epoch + 1) % 100 == 0:
            current_loss = huber_loss(model.params, x, y, 1.0)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")
    return model
"""


"""
Error Code:
x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([[2.0], [4.0], [6.0]])


Error:
The original code used 100 data points and added noise


Fix Guide:
Generate 100 samples using PRNGKey and add noise


Correct Code:
key, subkey = random.split(key)
x = random.uniform(subkey, shape=(100, 1)) * 10
key, subkey = random.split(key)
noise = random.normal(subkey, shape=(100, 1))
y = 2 * x + 3 + noise
"""


"""
Error Code:
print(f"Trained weights: {model.w}, bias: {model.b}")


Error:
The model parameters are stored in the dictionary self.params. There are no direct attributes w and b. Directly calling model.w and model.b will result in an error.


Fix Guide:
Take the weights and biases from model.params and modify them to model.params["w"] and model.params["b"]


Correct Code:
print(f"Trained weights: {model.params['w']}, bias: {model.params['b']}")
"""


"""
Error Code:
x = jnp.array([[1.0], [2.0], [3.0]])


Error:
Wrong test x cases


Fix Guide:
Change to same as original code


Correct Code:
x = jnp.array([[4.0], [7.0]])
"""

In [18]:
#Fixed Code
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap


def model_fn(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# Define a simple model
class LinearModel:
    def __init__(self, key):
        key, subkey = random.split(key)
        w = random.uniform(subkey, (1, 1), minval=-1.0, maxval=1.0)
        key, subkey = random.split(key)
        b = random.uniform(subkey, (1,), minval=-1.0, maxval=1.0)
        self.params = {"w": w, "b": b}

    def __call__(self, x):
        return jnp.dot(x, self.params["w"]) + self.params["b"]

# Loss function
def huber_loss(params, x, y, delta=1.0):
    preds = jnp.dot(x, params["w"]) + params["b"]
    error = jnp.abs(preds - y)
    loss = jnp.where(error <= delta,
                     0.5 * error**2, 
                     delta * (error - 0.5 * delta))
    return jnp.mean(loss)

# Update function using functional programming
def update(params, x, y, learning_rate=0.01):
    loss_value, grads = jax.value_and_grad(huber_loss)(params, x, y, 1.0)
    params["w"] = params["w"] - learning_rate * grads["w"]
    params["b"] = params["b"] - learning_rate * grads["b"]
    return params

# Training function
def train_model(model, x, y, epochs=1000):
    for epoch in range(epochs):
        model.params = update(model.params, x, y, learning_rate=0.01)
        if (epoch + 1) % 100 == 0:
            current_loss = huber_loss(model.params, x, y, 1.0)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")
    return model

def main():
    # Generate synthetic data
    key = random.PRNGKey(0)  # MODIFIED: Explicit PRNG key
    model = LinearModel(key)
    
    # Generate synthetic data
    key, subkey = random.split(key)
    x = random.uniform(subkey, shape=(100, 1)) * 10
    key, subkey = random.split(key)
    noise = random.normal(subkey, shape=(100, 1))
    y = 2 * x + 3 + noise

    # Train the model
    model = train_model(model, x, y, epochs=1000)

    x = jnp.array([[4.0], [7.0]])
    # Test the model
    predictions = model(x)
    print(f"Predictions for {x.tolist()}: {predictions.tolist()}")
    print(f"Trained weights: {model.params['w']}, bias: {model.params['b']}")

if __name__ == "__main__":
    main()

Epoch [100/1000], Loss: 0.8937
Epoch [200/1000], Loss: 0.8241
Epoch [300/1000], Loss: 0.7588
Epoch [400/1000], Loss: 0.6986
Epoch [500/1000], Loss: 0.6439
Epoch [600/1000], Loss: 0.5948
Epoch [700/1000], Loss: 0.5516
Epoch [800/1000], Loss: 0.5159
Epoch [900/1000], Loss: 0.4874
Epoch [1000/1000], Loss: 0.4650
Predictions for [[4.0], [7.0]]: [[10.567580223083496], [16.88758659362793]]
Trained weights: [[2.1066687]], bias: [2.1409056]
