In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Reproducible toy dataset: 100 inputs drawn uniformly from [0, 10) and
# targets following y = 3x + 5 plus standard-normal noise.
torch.manual_seed(42)
X = 10 * torch.rand(100, 1)
y = 5 + 3 * X + torch.randn(100, 1)

# Define a simple Linear Regression Model
class LinearRegressionModel(nn.Module):
    """Linear regression with a single input feature and a single output."""

    def __init__(self):
        super().__init__()
        # One weight and one bias: maps (N, 1) inputs to (N, 1) predictions.
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return the linear layer's prediction for `x`."""
        prediction = self.linear(x)
        return prediction

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir="runs/linear_regression")

# Initialize the model, loss function, and optimizer
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop: every epoch is one full-batch gradient step over (X, y).
epochs = 100
try:
    for epoch in range(epochs):
        # Forward pass
        predictions = model(X)
        loss = criterion(predictions, y)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; it is the loss measured before this epoch's update.
        loss_value = loss.item()

        # Log loss to TensorBoard
        writer.add_scalar("Loss/train", loss_value, epoch)

        # Log progress every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss_value:.4f}")
finally:
    # Close the writer even when a training step raises, so the event file is
    # flushed and the file handle is released.
    writer.close()

# Run TensorBoard using the logs generated
# Command to run: tensorboard --logdir=runs

Epoch [10/100], Loss: 6.1875
Epoch [20/100], Loss: 5.6829
Epoch [30/100], Loss: 5.2245
Epoch [40/100], Loss: 4.8079
Epoch [50/100], Loss: 4.4293
Epoch [60/100], Loss: 4.0854
Epoch [70/100], Loss: 3.7728
Epoch [80/100], Loss: 3.4888
Epoch [90/100], Loss: 3.2307
Epoch [100/100], Loss: 2.9962


In [11]:
#Strong LLM
import jax
import jax.numpy as jnp
import optax
from tensorboardX import SummaryWriter  # pip install tensorboardX
import numpy as np


# Reproducible toy dataset. A fresh subkey is split off before every draw so the
# inputs and the noise come from independent random streams.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
X = 10 * jax.random.uniform(subkey, shape=(100, 1))  # 100 inputs in [0, 10)
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(100, 1)) + (3 * X + 5)  # y = 3x + 5 + noise

def predict(params, x):
    """Evaluate the linear model ``dot(x, w) + b`` using ``params["w"]`` and ``params["b"]``."""
    weight, bias = params["w"], params["b"]
    return jnp.dot(x, weight) + bias

def loss_fn(params, X, y):
    """Return the mean squared error of ``predict(params, X)`` against `y`."""
    residuals = predict(params, X) - y
    return jnp.mean(residuals ** 2)

# Initialize weight and bias.
# NOTE: this is NOT the default nn.Linear initialization. PyTorch draws both the
# weight and the bias from U(-1/sqrt(in_features), 1/sqrt(in_features)); here the
# weight is a small Gaussian draw (std 0.1) and the bias starts at zero, so the
# loss curve is comparable to the PyTorch run but not numerically identical.
key, subkey = jax.random.split(key)
params = {
    "w": 0.1 * jax.random.normal(subkey, shape=(1, 1)),
    "b": jnp.zeros((1,))
}

optimizer = optax.sgd(learning_rate=0.01)
opt_state = optimizer.init(params)

# Build the loss-and-gradient function once instead of re-creating it every epoch.
loss_and_grad = jax.value_and_grad(loss_fn)

writer = SummaryWriter(log_dir="runs/linear_regression")

epochs = 100
try:
    for epoch in range(epochs):
        # Compute loss and gradients (the loss is the value before this epoch's update)
        loss, grads = loss_and_grad(params, X, y)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        # Convert the 0-d array to a Python float once for logging and printing
        loss_value = float(loss)

        # Log loss to TensorBoard
        writer.add_scalar("Loss/train", loss_value, epoch)

        # Log progress every 10 epochs
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss_value:.4f}")
finally:
    # Close the writer even when a training step raises, so the event file is
    # flushed and the file handle is released.
    writer.close()

# To run TensorBoard using the generated logs, execute the following command in your terminal:
# tensorboard --logdir=runs


Epoch [10/100], Loss: 5.4223
Epoch [20/100], Loss: 5.0028
Epoch [30/100], Loss: 4.6238
Epoch [40/100], Loss: 4.2814
Epoch [50/100], Loss: 3.9721
Epoch [60/100], Loss: 3.6926
Epoch [70/100], Loss: 3.4401
Epoch [80/100], Loss: 3.2120
Epoch [90/100], Loss: 3.0059
Epoch [100/100], Loss: 2.8198


In [3]:
#Weak LLM
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from flax import linen as nn
import optax
import numpy as np
import tensorboard

# Linear regression model definition
class LinearRegressionModel(nn.Module):
    """Flax module that computes ``jnp.dot(x, w) + b`` with `w` of shape (input_dim, 1).

    NOTE(review): this cell is labelled "Weak LLM" and appears to be kept as the
    "before" example that the error-notes cell later in this notebook dissects.
    """
    # Number of input features; used as the first dimension of `w`.
    input_dim: int

    def setup(self):
        """Declare the weight `w` and the bias `b` as module parameters."""
        self.w = self.param('w', nn.initializers.xavier_uniform(), (self.input_dim, 1))
        # NOTE(review): the error notes record that xavier_uniform rejects this 1-D
        # shape ("Must be at least 2D"); the fixed cell uses nn.initializers.zeros.
        self.b = self.param('b', nn.initializers.xavier_uniform(), (1,))

    def __call__(self, x):
        """Return the affine prediction ``jnp.dot(x, w) + b``."""
        return jnp.dot(x, self.w) + self.b

# Loss function
def loss_fn(model, inputs, targets):
    """Return the mean squared difference between ``model(inputs)`` and `targets`.

    NOTE(review): the error notes later in this notebook record that calling the
    module directly like this cannot pass in parameters; the fixed cell calls
    ``model.apply(params, inputs)`` instead.
    """
    predictions = model(inputs)
    return jnp.mean((predictions - targets) ** 2)

# Jitted gradient computation (no vmap/vectorization is applied in this function)
@jit
def compute_gradients(model, inputs, targets):
    """Return ``grad(loss_fn)`` evaluated at ``(model, inputs, targets)``.

    NOTE(review): the error notes later in this notebook record that passing the
    model instance here, instead of a parameter dictionary, is a mistake.
    """
    return grad(loss_fn)(model, inputs, targets)  # plain jax.grad call on loss_fn

# Training function
def train_model(model, inputs, targets, num_epochs=1000, learning_rate=0.01):
    """Intended training loop: Adam updates for `num_epochs` steps, printing the loss every 100 epochs.

    NOTE(review): this version does not run. The cell's recorded output is a
    ``TypeError: zeros_like requires ndarray or scalar arguments, got
    LinearRegressionModel``; presumably raised while the optimizer state is being
    initialised from the model instance below (confirm against the traceback).
    The error notes later in this notebook list the individual problems.
    """
    optimizer = optax.adam(learning_rate)
    # NOTE(review): per the error notes, the parameter dictionary (not the module
    # instance) should be passed when initialising the optimizer.
    opt_state = optimizer.init(model)

    for epoch in range(num_epochs):
        grads = compute_gradients(model, inputs, targets)  # gradients from the jitted helper
        updates, opt_state = optimizer.update(grads, opt_state)
        # NOTE(review): per the error notes, updates should be applied with
        # optax.apply_updates(params, updates) rather than model.apply.
        model = model.apply(updates)

        if epoch % 100 == 0:
            current_loss = loss_fn(model, inputs, targets)
            print(f"Epoch {epoch}, Loss: {current_loss}")

    return model

def main():
    """Build a three-point dataset, construct the model, and call `train_model`.

    NOTE(review): the error notes later in this notebook point out that the
    original PyTorch cell uses 100 noisy random points rather than these three,
    and that the parameters must first be created with ``model.init``.
    """
    # Hand-written dataset: every target equals its input plus one
    inputs = jnp.array([[1.0], [2.0], [3.0]])  # Input features
    targets = jnp.array([[2.0], [3.0], [4.0]])  # Target output

    # Construct the module; no parameters are created at this point
    model = LinearRegressionModel(input_dim=1)

    # Train the model (the result is bound locally and not returned)
    trained_model = train_model(model, inputs, targets)

if __name__ == "__main__":
    main()  # Entry point of the program

TypeError: zeros_like requires ndarray or scalar arguments, got <class '__main__.LinearRegressionModel'> at position 0.

In [16]:
"""
Error Code:
self.b = self.param('b', nn.initializers.xavier_uniform(), (1,))


Error:
Can't compute input and output sizes of a 1-dimensional weights tensor. Must be at least 2D


Fix Guide:
For bias parameters, zero initialization is usually sufficient. Change the initializer to nn.initializers.zeros


Correct Code:
self.b = self.param('b', nn.initializers.zeros, (1,))
"""


"""
Error Code:
trained_model = train_model(model, inputs, targets)


Error:
train_model() missing 1 required positional argument: 'targets'


Fix Guide:
Modify the function call to pass the arguments in the correct order: first the initialized parameter dictionary `params`, then the model instance `model`, then `inputs` and `targets`


Correct Code:
trained_params = train_model(params, model, inputs, targets)
final_predictions = model.apply(trained_params, inputs)
"""


"""
Error Code:
model = LinearRegressionModel(input_dim=1)


Error:
The model parameters need to be initialized by calling model.init(rng, inputs)


Fix Guide:
Call model.init with a random key and input example to get the parameter dictionary, and then use the parameters in subsequent training


Correct Code:
model = LinearRegressionModel(input_dim=1)
key = jax.random.PRNGKey(0)
params = model.init(key, inputs)
"""


"""
Error Code:
def loss_fn(model, inputs, targets):
    predictions = model(inputs)
    return jnp.mean((predictions - targets) ** 2)
    

Error:
Directly calling model(inputs) cannot pass in parameters


Fix Guide:
Modify the loss function so that its first parameter is a parameter dictionary and pass in the model object to call the apply method


Correct Code:
def loss_fn(params, inputs, targets, model):
    predictions = model.apply(params, inputs)
    return jnp.mean((predictions - targets) ** 2)
"""


"""
Error Code:
@jit
def compute_gradients(model, inputs, targets):
    return grad(loss_fn)(model, inputs, targets)
    

Error:
The loss function passes in the model instance instead of the parameters


Fix Guide:
Modify the function parameters so that the first parameter is a parameter dictionary and pass in the model object


Correct Code:
@jit
def compute_gradients(params, inputs, targets, model):
    return grad(loss_fn)(params, inputs, targets, model)
"""


"""
Error Code:
updates, opt_state = optimizer.update(grads, opt_state)
model = model.apply(updates)


Error:
In Flax + Optax, update parameters using optax.apply_updates(params, updates) instead of calling model.apply


Fix Guide:
Assign the updated parameters to params


Correct Code:
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
"""


"""
Error Code:
def train_model(model, inputs, targets, num_epochs=1000, learning_rate=0.01):
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(model)
    ...
    return model
    

Error:
During training, the parameter dictionary should be passed in and updated instead of the model instance
Parameters should be passed in when initializing the optimizer


Fix Guide:
Modify the parameters of the training function so that it receives a parameter dictionary and returns the updated parameters on return


Correct Code:
def train_model(params, model, inputs, targets, num_epochs=100, learning_rate=0.01):
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    for epoch in range(num_epochs):
        grads = compute_gradients(params, inputs, targets, model)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if epoch % 10 == 0:
            current_loss = loss_fn(params, inputs, targets, model)
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {current_loss:.4f}")
            writer.add_scalar("Loss/train", current_loss, epoch)

    return params
"""


"""
Error Code:
import tensorboard


Error:
SummaryWriter is used in the PyTorch code to record the training process, while the tensorboard module is imported in the JAX code but not actually used.


Fix Guide:
Use tensorboardX to create a SummaryWriter and log scalars during training


Correct Code:
from tensorboardX import SummaryWriter
writer = SummaryWriter(log_dir="runs/linear_regression")
"""


"""
Error Code:
inputs = jnp.array([[1.0], [2.0], [3.0]])
targets = jnp.array([[2.0], [3.0], [4.0]])


Error:
The original PyTorch code generates 100 random data points in the interval [0, 10] and adds noise to the targets, whereas this code uses only 3 hand-written data points.


Fix Guide:
Generate 100 random data points using jax.random and add noise to construct the target value


Correct Code:
key = jax.random.PRNGKey(42)
key, subkey1, subkey2 = jax.random.split(key, 3)
inputs = jax.random.uniform(subkey1, (100, 1), minval=0.0, maxval=10.0)
noise = jax.random.normal(subkey2, (100, 1))
targets = 3 * inputs + 5 + noise
"""


"""
Error Code:
@jit
def compute_gradients(params, inputs, targets, model):
    return grad(loss_fn)(params, inputs, targets, model)
    

Error:
Cannot interpret value of type <class '__main__.LinearRegressionModel'> as an abstract array; it does not have a dtype attribute


Fix Guide:
The model parameter needs to be marked as a static parameter


Correct Code:
@jit(static_argnums=(3,))
def compute_gradients(params, inputs, targets, model):
    return grad(loss_fn)(params, inputs, targets, model)
"""


"""
Error Code:
@jit(static_argnums=(3,))
def compute_gradients(params, inputs, targets, model):
    return grad(loss_fn)(params, inputs, targets, model)
    

Error:
jit() missing 1 required positional argument: 'fun'


Fix Guide:
First define the function compute_gradients
Use jit to explicitly convert the function and specify the static parameter static_argnums=(3,)


Correct Code:
def compute_gradients(params, inputs, targets, model):
    return grad(loss_fn)(params, inputs, targets, model)
compute_gradients = jit(compute_gradients, static_argnums=(3,))
"""


"""
Error Code:
optimizer = optax.adam(learning_rate)


Error:
The Adam optimizer is used here, while the original PyTorch code uses SGD (stochastic gradient descent)


Fix Guide:
Use optax.sgd(learning_rate) instead of optax.adam(learning_rate)


Correct Code:
optimizer = optax.sgd(learning_rate)
"""


"""
Error Code:
if epoch % 10 == 0:
    current_loss = loss_fn(params, inputs, targets, model)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {current_loss:.4f}")
    writer.add_scalar("Loss/train", current_loss, epoch)
    

Error:
The original PyTorch code prints when (epoch + 1) % 10 == 0, that is, at the 10th, 20th, ... epochs, whereas the condition epoch % 10 == 0 prints at the 1st, 11th, 21st, ... epochs.


Fix Guide:
Modify the condition to if (epoch + 1) % 10 == 0:


Correct Code:
if (epoch + 1) % 10 == 0:
    current_loss = loss_fn(params, inputs, targets, model)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {current_loss:.4f}")
    writer.add_scalar("Loss/train", current_loss, epoch)
"""


"""
Error Code:
writer = SummaryWriter(log_dir="runs/linear_regression")
...
return params

Error:
Failure to call writer.close() may result in the log file not being written to disk correctly or resources not being released, which may affect the log viewing of TensorBoard.


Fix Guide:
Call writer.close() after the training loop ends and before returning the parameters


Correct Code:
writer = SummaryWriter(log_dir="runs/linear_regression")
...
writer.close()
return params
"""


"\nError Code:\n@jit\ndef compute_gradients(params, inputs, targets, model):\n    return grad(loss_fn)(params, inputs, targets, model)\n    \n\nError:\nCannot interpret value of type <class '__main__.LinearRegressionModel'> as an abstract array; it does not have a dtype attribute\n\n\nFix Guide:\nThe model parameter needs to be marked as a static parameter\n\n\nCorrect Code:\n@jit(static_argnums=(3,))\ndef compute_gradients(params, inputs, targets, model):\n    return grad(loss_fn)(params, inputs, targets, model)\n"

In [22]:
#Fixed Code
import jax
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from flax import linen as nn
import optax
import numpy as np
from tensorboardX import SummaryWriter

# Linear regression model definition
class LinearRegressionModel(nn.Module):
    """Single-output linear layer ``jnp.dot(x, w) + b`` built from raw Flax parameters."""

    # Number of input features; used as the first dimension of `w`.
    input_dim: int

    def setup(self):
        """Register the parameters `w` (Xavier-uniform init) and `b` (zero init)."""
        weight_shape = (self.input_dim, 1)
        bias_shape = (1,)
        self.w = self.param('w', nn.initializers.xavier_uniform(), weight_shape)
        self.b = self.param('b', nn.initializers.zeros, bias_shape)

    def __call__(self, x):
        """Apply the affine map to `x` and return the predictions."""
        projected = jnp.dot(x, self.w)
        return projected + self.b

# Loss function
def loss_fn(params, inputs, targets, model):
    """Return the mean squared error of ``model.apply(params, inputs)`` against `targets`."""
    residuals = model.apply(params, inputs) - targets
    return jnp.mean(residuals ** 2)

# Jitted gradient of the loss with respect to the parameters
def compute_gradients(params, inputs, targets, model):
    """Return the gradient of `loss_fn` with respect to `params`."""
    gradient_of_loss = grad(loss_fn)
    return gradient_of_loss(params, inputs, targets, model)


# `model` (positional index 3) is a module instance rather than an array, so it is
# declared static for jit.
compute_gradients = jit(compute_gradients, static_argnums=(3,))

# Training function
def train_model(params, model, inputs, targets, num_epochs=100, learning_rate=0.01):
    """Fit `params` with full-batch SGD and log the training loss to TensorBoard.

    Args:
        params: Parameter pytree accepted by ``model.apply``.
        model: Module whose ``apply(params, inputs)`` produces the predictions.
        inputs: Training features.
        targets: Training targets.
        num_epochs: Number of gradient steps to take.
        learning_rate: SGD step size.

    Returns:
        The parameters after the final update.
    """
    optimizer = optax.sgd(learning_rate)
    opt_state = optimizer.init(params)
    writer = SummaryWriter(log_dir="runs/linear_regression")

    # try/finally guarantees the event file is flushed and closed even when a
    # training step raises.
    try:
        for epoch in range(num_epochs):
            grads = compute_gradients(params, inputs, targets, model)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)

            if (epoch + 1) % 10 == 0:
                # The loss is evaluated after this epoch's update. Converting to a
                # Python float once gives the writer a plain scalar.
                current_loss = float(loss_fn(params, inputs, targets, model))
                print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {current_loss:.4f}")
                writer.add_scalar("Loss/train", current_loss, epoch)
    finally:
        writer.close()

    return params

def main():
    """Generate the synthetic dataset, initialise the model parameters, and train."""
    # Synthetic data: targets = 3 * inputs + 5 + standard-normal noise
    data_key = jax.random.PRNGKey(42)
    _, uniform_key, noise_key = jax.random.split(data_key, 3)
    inputs = jax.random.uniform(uniform_key, (100, 1), minval=0.0, maxval=10.0)
    targets = 3 * inputs + 5 + jax.random.normal(noise_key, (100, 1))

    # Parameters are created by model.init from a separate fixed key
    model = LinearRegressionModel(input_dim=1)
    params = model.init(jax.random.PRNGKey(0), inputs)

    # Train, then run the trained parameters on the training inputs
    trained_params = train_model(params, model, inputs, targets)
    final_predictions = model.apply(trained_params, inputs)


if __name__ == "__main__":
    main()  # Entry point of the program

Epoch [10/100], Loss: 4.6780
Epoch [20/100], Loss: 4.3839
Epoch [30/100], Loss: 4.1136
Epoch [40/100], Loss: 3.8654
Epoch [50/100], Loss: 3.6373
Epoch [60/100], Loss: 3.4277
Epoch [70/100], Loss: 3.2352
Epoch [80/100], Loss: 3.0583
Epoch [90/100], Loss: 2.8958
Epoch [100/100], Loss: 2.7465
