In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim

# Generate synthetic data
torch.manual_seed(42)
X = torch.rand(100, 2) * 10  # 100 data points with 2 features
y = (X[:, 0] + X[:, 1] * 2).unsqueeze(1) + torch.randn(100, 1)  # Non-linear relationship with noise

# Define the Deep Neural Network Model
class DNNModel(nn.Module):
    def __init__(self):
        super(DNNModel, self).__init__()
        self.fc1 = nn.Linear(2, 10)  # Input layer to hidden layer
        self.relu = nn.ReLU()       # Activation function
        self.fc2 = nn.Linear(10, 1) # Hidden layer to output layer

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
    
# Initialize the model, loss function, and optimizer
model = DNNModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
epochs = 1000
for epoch in range(epochs):
    # Forward pass
    predictions = model(X)
    loss = criterion(predictions, y)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log progress every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

# Testing on new data
X_test = torch.tensor([[4.0, 3.0], [7.0, 8.0]])
with torch.no_grad():
    predictions = model(X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")
    

Epoch [100/1000], Loss: 2.3831
Epoch [200/1000], Loss: 0.8492
Epoch [300/1000], Loss: 0.7732
Epoch [400/1000], Loss: 0.7514
Epoch [500/1000], Loss: 0.7371
Epoch [600/1000], Loss: 0.7291
Epoch [700/1000], Loss: 0.7251
Epoch [800/1000], Loss: 0.7233
Epoch [900/1000], Loss: 0.7225
Epoch [1000/1000], Loss: 0.7221
Predictions for [[4.0, 3.0], [7.0, 8.0]]: [[9.915834426879883], [23.08173179626465]]


In [2]:
#Strong LLM
import jax
import jax.numpy as jnp
import optax
import numpy as np


# Create a PRNG key
key = jax.random.PRNGKey(42)

# Generate synthetic data: 100 data points with 2 features in [0, 10)
key, subkey = jax.random.split(key)
X = jax.random.uniform(subkey, shape=(100, 2)) * 10

# Compute target: y = (X[:, 0] + 2 * X[:, 1]) + noise, with shape (100, 1)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(100, 1))
# Compute (X[:,0] + 2*X[:,1]) and expand dims to (100, 1)
y = (X[:, 0] + 2 * X[:, 1]).reshape(-1, 1) + noise


# We define a two-layer network:
#  - fc1: linear transformation from 2 -> 10
#  - ReLU activation
#  - fc2: linear transformation from 10 -> 1
def predict(params, x):
    hidden = jnp.dot(x, params['W1']) + params['b1']  # (batch, 10)
    hidden = jax.nn.relu(hidden)
    output = jnp.dot(hidden, params['W2']) + params['b2']  # (batch, 1)
    return output


# For simplicity, we initialize weights using a uniform distribution.
# You could also use other initializers (e.g., Xavier/Glorot uniform).
def init_params(key):
    keys = jax.random.split(key, 4)
    # fc1: shape (2, 10)
    W1 = jax.random.uniform(keys[0], shape=(2, 10), minval=-1.0, maxval=1.0)
    b1 = jax.random.uniform(keys[1], shape=(10,), minval=-1.0, maxval=1.0)
    # fc2: shape (10, 1)
    W2 = jax.random.uniform(keys[2], shape=(10, 1), minval=-1.0, maxval=1.0)
    b2 = jax.random.uniform(keys[3], shape=(1,), minval=-1.0, maxval=1.0)
    return {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}

params = init_params(key)


def loss_fn(params, x, y):
    preds = predict(params, x)
    return jnp.mean((preds - y) ** 2)


# Use the Adam optimizer from optax
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

epochs = 1000

# Define a function to compute loss and gradients
loss_and_grad = jax.value_and_grad(loss_fn)

for epoch in range(epochs):
    loss, grads = loss_and_grad(params, X, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    # Log progress every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")


X_test = jnp.array([[4.0, 3.0], [7.0, 8.0]])
predictions = predict(params, X_test)
print(f"Predictions for {np.array(X_test).tolist()}: {np.array(predictions).tolist()}")

Epoch [100/1000], Loss: 1.3837
Epoch [200/1000], Loss: 1.1774
Epoch [300/1000], Loss: 1.0918
Epoch [400/1000], Loss: 1.0450
Epoch [500/1000], Loss: 1.0342
Epoch [600/1000], Loss: 1.0314
Epoch [700/1000], Loss: 1.0294
Epoch [800/1000], Loss: 1.0276
Epoch [900/1000], Loss: 1.0258
Epoch [1000/1000], Loss: 1.0239
Predictions for [[4.0, 3.0], [7.0, 8.0]]: [[10.156254768371582], [23.063282012939453]]


In [3]:
#Weak LLM
import jax.numpy as jnp  # MODIFIED: Ensured consistent import for jax.numpy as jnp
from jax import random  # MODIFIED: Added necessary import for random functionality
from typing import Any, Tuple

def init_params(key: Any, input_shape: Tuple[int, ...]) -> Any:
    """Initialize parameters for the model."""
    param_shape = (input_shape[0], 1)  # Example shape for parameters
    return random.normal(key, param_shape)  # Use explicit PRNG key

def loss_fn(params: Any, inputs: jnp.ndarray, targets: jnp.ndarray) -> float:
    """Calculate the loss."""
    predictions = jnp.dot(inputs, params)  # Simulate predictions
    return jnp.mean((predictions - targets) ** 2)  # Mean Squared Error

def main() -> None:
    """Main entry point for the program."""
    key = random.PRNGKey(0)  # Create an explicit PRNG key
    input_shape = (5, 10)  # Define input shape
    params = init_params(key, input_shape)  # Initialize parameters
    inputs = jnp.ones((5, 10))  # Example input data
    targets = jnp.ones((5,))  # Example target data

    # Calculate loss
    loss_value = loss_fn(params, inputs, targets)  # Using loss function
    print(f"Loss: {loss_value}")  # Displaying loss

if __name__ == "__main__":
    main()  # Entry point for the program

TypeError: dot_general requires contracting dimensions to have the same shape, got (10,) and (5,).

In [None]:
"""
Error Code:
def init_params(key: Any, input_shape: Tuple[int, ...]) -> Any:
    param_shape = (input_shape[0], 1)  # Example shape for parameters
    return random.normal(key, param_shape)


Error:
dot_general requires contracting dimensions to have the same shape, got (10,) and (5,)


Fix Guide:
Modify the init_params function so that the shape of the parameters matches the input data. 
The parameters should be initialized to (input_shape[1], 1)


Correct Code:
def init_params(key: Any, input_shape: Tuple[int, ...]) -> Any:
    Initialize parameters for the model.
    param_shape = (input_shape[1], 1)
    return random.normal(key, param_shape)
"""


"""
Error Code:
def init_params(key: Any, input_shape: Tuple[int, ...]) -> Any:
    param_shape = (input_shape[1], 1)
    return random.normal(key, param_shape)
    

Error:
The parameter initialization is incomplete. 
The four parameters (W1, b1, W2, b2) that need to be initialized in the two-layer network in the original code are inconsistent.


Fix Guide:
Remove the redundant input_shape parameter, use random.split to divide the 4 sub-keys, and then initialize the weights and biases of fc1 and the weights and biases of fc2 respectively.


Correct Code:
def init_params(key: Any) -> Any:
    keys = random.split(key, 4)
    W1 = random.uniform(keys[0], shape=(2, 10), minval=-1.0, maxval=1.0)
    b1 = random.uniform(keys[1], shape=(10,), minval=-1.0, maxval=1.0)
    W2 = random.uniform(keys[2], shape=(10, 1), minval=-1.0, maxval=1.0)
    b2 = random.uniform(keys[3], shape=(1,), minval=-1.0, maxval=1.0)
    return {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
"""


"""
Error Code:
params = init_params(key, input_shape)

Error:
The function init_params is defined to accept only one parameter (PRNG key)

Fix Guide:
Remove input_shape parameter from init_params function

Correct Code:
params = init_params(key)
"""


"""
Error Code:
predictions = jnp.dot(inputs, params)

Error:
You cannot directly perform a dot product operation on params. 
The parameters are dictionaries, and the two-layer network needs to go through the hidden layer before calculating the output.

Fix Guide:
Define a predict function, first calculate the first layer linear transformation and use ReLU activation, then calculate the second layer linear transformation to get the final output

Correct Code:
def predict(params: Any, x: jnp.ndarray) -> jnp.ndarray:
    hidden = jnp.dot(x, params['W1']) + params['b1']
    hidden = jax.nn.relu(hidden)
    output = jnp.dot(hidden, params['W2']) + params['b2']
    return output
"""


"""
Error Code:
def loss_fn(params: Any, inputs: jnp.ndarray, targets: jnp.ndarray) -> float:
    predictions = jnp.dot(inputs, params)  # Simulate predictions
    return jnp.mean((predictions - targets) ** 2)


Error:
The params dictionary is incorrectly matrix multiplied directly, the newly defined predict function should be called


Fix Guide:
Change the line that calculates the predicted value to call the predict function


Correct Code:
def loss_fn(params: Any, inputs: jnp.ndarray, targets: jnp.ndarray) -> float:
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets) ** 2)
"""


"""
Error Code:
input_shape = (5, 10)  # Define input shape
inputs = jnp.ones((5, 10))  # Example input data
targets = jnp.ones((5,))  # Example target data


Error:
The model expects 2 features as input, but 10 is used here
The shape of the target data is (5,), but the predicted output shape is (5, 1), which indicates a dimension mismatch.


Fix Guide:
Reshape the input data to have 2 features and expand the target data into a 2D array


Correct Code:
inputs = jnp.ones((5, 2))  # Example input data with 2 features
targets = jnp.ones((5, 1))  # Example target data with shape (batch, 1)
"""


"""
Error Code:
hidden = jax.nn.relu(hidden)


Error:
jax.nn.relu is used, but the entire jax module is not imported in the file, resulting in jax being undefined


Fix Guide:
Add import jax at the beginning of the file


Correct Code:
import jax
"""


"""
Error Code:
inputs = jnp.ones((5, 2))  # Example input data with 2 features
targets = jnp.ones((5, 1))  # Example target data with shape (batch, 1)


Error:
Does not meet the synthetic data requirement of randomly generating 100 data points and adding noise in the original pytorch code


Fix Guide:
Generate 100 2D data using random numbers and calculate the target value as X[:,0] + X[:,1] * 2 plus noise


Correct Code:
key = random.PRNGKey(42)
key, subkey = random.split(key)
X = random.uniform(key, shape=(100, 2), minval=0.0, maxval=1.0) * 10
key, subkey = random.split(subkey)
noise = random.normal(subkey, shape=(100, 1))
y = (X[:, 0:1] + X[:, 1:2] * 2) + noise
"""


"""
Error Code:
# Calculate loss
loss_value = loss_fn(params, inputs, targets)  # Using loss function
print(f"Loss: {loss_value}")  # Displaying loss


Error:
There is no backpropagation (using jax.grad to calculate gradients) and parameter update steps in the jax code


Fix Guide:
Add a training loop, define an update function, calculate the gradient through jax.grad(loss_fn)
Use simple gradient descent to update the parameters, and print the current loss every certain epoch.


Correct Code:
def update(params, inputs, targets, lr):
    grads = jax.grad(loss_fn)(params, inputs, targets)
    new_params = {k: params[k] - lr * grads[k] for k in params}
    return new_params


epochs = 1000
lr = 0.01
for epoch in range(epochs):
    params = update(params, X, y, lr)
    if (epoch + 1) % 100 == 0:
        current_loss = loss_fn(params, X, y)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")
"""


"""
Error Code:
epochs = 1000
    lr = 0.01
    for epoch in range(epochs):
        params = update(params, X, y, lr)
        if (epoch + 1) % 100 == 0:
            current_loss = loss_fn(params, X, y)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")


Error:
The code does not include the part that makes predictions on the test data


Fix Guide:
After training is complete, add prediction code for test data and print the prediction results


Correct Code:
epochs = 1000
    lr = 0.01
    for epoch in range(epochs):
        params = update(params, X, y, lr)
        if (epoch + 1) % 100 == 0:
            current_loss = loss_fn(params, X, y)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")

X_test = jnp.array([[4.0, 3.0], [7.0, 8.0]])
predictions = predict(params, X_test)
print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")
"""


"""
Error Code:
input_shape = (5, 10)  # Define input shape


Error:
The variable is not used and does not match the shape of the actual data


Fix Guide:
Remove the useless input_shape variable or replace it with correct synthetic data generation code


Correct Code:
# input_shape = (5, 10)  # Define input shape
"""


"""
Error Code:
key = random.PRNGKey(0)  # Create an explicit PRNG key
# input_shape = (5, 10)  # Define input shape
params = init_params(key)
key = random.PRNGKey(42)
key, subkey = random.split(key)


Error:
Different random seeds are used for model parameter initialization and data generation


Fix Guide:
Use the same random seed and split it appropriately to ensure that parameters and data generation are based on the same initial seed.


Correct Code:
key = random.PRNGKey(42) 
key, subkey = random.split(key)
params = init_params(subkey)
"""


"""
Error Code:
key, subkey = random.split(key)
X = random.uniform(key, shape=(100, 2), minval=0.0, maxval=1.0) * 10
key, subkey = random.split(subkey)
noise = random.normal(subkey, shape=(100, 1))
y = (X[:, 0:1] + X[:, 1:2] * 2) + noise


Error:
Reusing variable names when splitting keys can easily cause confusion, and using the split key and subkey at the same time is not clear enough


Fix Guide:
Split the key continuously when generating data, and explicitly use the split key to generate each part of the data


Correct Code:
key = random.PRNGKey(42)
key, subkey_params = random.split(key)
params = init_params(subkey_params)

key, subkey_X = random.split(key)
X = random.uniform(subkey_X, shape=(100, 2), minval=0.0, maxval=1.0) * 10
key, subkey_noise = random.split(key)
noise = random.normal(subkey_noise, shape=(100, 1))
y = (X[:, 0:1] + X[:, 1:2] * 2) + noise
"""


"""
Error Code:
def update(params, inputs, targets, lr):
    grads = jax.grad(loss_fn)(params, inputs, targets)
    new_params = {k: params[k] - lr * grads[k] for k in params}
    return new_params

Error:
The original PyTorch code uses the Adam optimizer, while the JAX code here only implements a simple gradient descent update.


Fix Guide:
Use the optax library commonly used in the JAX ecosystem to implement the Adam optimizer


Correct Code:
import optax

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    
    for epoch in range(epochs):
        grads = jax.grad(loss_fn)(params, X, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        
        if (epoch + 1) % 100 == 0:
            current_loss = loss_fn(params, X, y)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")
"""

In [21]:
#Fixed Code
import jax
import jax.numpy as jnp  # MODIFIED: Ensured consistent import for jax.numpy as jnp
from jax import random  # MODIFIED: Added necessary import for random functionality
from typing import Any, Tuple

def init_params(key: Any) -> Any:
    keys = random.split(key, 4)
    W1 = random.uniform(keys[0], shape=(2, 10), minval=-1.0, maxval=1.0)
    b1 = random.uniform(keys[1], shape=(10,), minval=-1.0, maxval=1.0)
    W2 = random.uniform(keys[2], shape=(10, 1), minval=-1.0, maxval=1.0)
    b2 = random.uniform(keys[3], shape=(1,), minval=-1.0, maxval=1.0)
    return {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}

def predict(params: Any, x: jnp.ndarray) -> jnp.ndarray:
    hidden = jnp.dot(x, params['W1']) + params['b1']
    hidden = jax.nn.relu(hidden)
    output = jnp.dot(hidden, params['W2']) + params['b2']
    return output

def loss_fn(params: Any, inputs: jnp.ndarray, targets: jnp.ndarray) -> float:
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

def update(params, inputs, targets, lr):
    grads = jax.grad(loss_fn)(params, inputs, targets)
    new_params = {k: params[k] - lr * grads[k] for k in params}
    return new_params

def main() -> None:
    """Main entry point for the program."""
    key = random.PRNGKey(42)
    key, subkey_params = random.split(key)
    params = init_params(subkey_params)

    key, subkey_X = random.split(key)
    X = random.uniform(subkey_X, shape=(100, 2), minval=0.0, maxval=1.0) * 10
    key, subkey_noise = random.split(key)
    noise = random.normal(subkey_noise, shape=(100, 1))
    y = (X[:, 0:1] + X[:, 1:2] * 2) + noise

    epochs = 1000
    lr = 0.01
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    
    for epoch in range(epochs):
        grads = jax.grad(loss_fn)(params, X, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        
        if (epoch + 1) % 100 == 0:
            current_loss = loss_fn(params, X, y)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}")
    
    X_test = jnp.array([[4.0, 3.0], [7.0, 8.0]])
    predictions = predict(params, X_test)
    print(f"Predictions for {X_test.tolist()}: {predictions.tolist()}")

if __name__ == "__main__":
    main()  # Entry point for the program

Epoch [100/1000], Loss: 1.2904
Epoch [200/1000], Loss: 1.0354
Epoch [300/1000], Loss: 0.9703
Epoch [400/1000], Loss: 0.9253
Epoch [500/1000], Loss: 0.8995
Epoch [600/1000], Loss: 0.8899
Epoch [700/1000], Loss: 0.8866
Epoch [800/1000], Loss: 0.8850
Epoch [900/1000], Loss: 0.8842
Epoch [1000/1000], Loss: 0.8836
Predictions for [[4.0, 3.0], [7.0, 8.0]]: [[9.865676879882812], [23.006179809570312]]
