In [8]:
import torch
import torch.nn as nn
import torch.optim as optim

# Generate synthetic sequential data
torch.manual_seed(42)
sequence_length = 10
num_samples = 100

# Create a sine wave dataset
X = torch.linspace(0, 4 * 3.14159, steps=num_samples).unsqueeze(1)
y = torch.sin(X)

# Prepare data for RNN
def create_in_out_sequences(data, seq_length):
    in_seq = []
    out_seq = []
    for i in range(len(data) - seq_length):
        in_seq.append(data[i:i + seq_length])
        out_seq.append(data[i + seq_length])
    return torch.stack(in_seq), torch.stack(out_seq)

X_seq, y_seq = create_in_out_sequences(y, sequence_length)

# Define the RNN Model
class RNNModel(nn.Module):
    def __init__(self):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(input_size=1, hidden_size=50, num_layers=1, batch_first=True)
        self.fc = nn.Linear(50, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])  # Use the last output of the RNN
        return out

# Initialize the model, loss function, and optimizer
model = RNNModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 500
for epoch in range(epochs):
    for sequences, labels in zip(X_seq, y_seq):
        sequences = sequences.unsqueeze(0)  # Add batch dimension
        labels = labels.unsqueeze(0)  # Add batch dimension

        # Forward pass
        outputs = model(sequences)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

# Testing on new data
X_test = torch.linspace(4 * 3.14159, 5 * 3.14159, steps=10).unsqueeze(1)

# Reshape to (batch_size, sequence_length, input_size)
X_test = X_test.unsqueeze(0)  # Add batch dimension, shape becomes (1, 10, 1)

with torch.no_grad():
    predictions = model(X_test)
    print(f"Predictions for new sequence: {predictions.tolist()}")


Epoch [1/500], Loss: 0.1162
Epoch [2/500], Loss: 0.0031
Epoch [3/500], Loss: 0.0095
Epoch [4/500], Loss: 0.0002
Epoch [5/500], Loss: 0.0000
Epoch [6/500], Loss: 0.0268
Epoch [7/500], Loss: 0.0164
Epoch [8/500], Loss: 0.0763
Epoch [9/500], Loss: 0.0163
Epoch [10/500], Loss: 0.0009
Epoch [11/500], Loss: 0.0009
Epoch [12/500], Loss: 0.0009
Epoch [13/500], Loss: 0.0015
Epoch [14/500], Loss: 0.0000
Epoch [15/500], Loss: 0.0006
Epoch [16/500], Loss: 0.0002
Epoch [17/500], Loss: 0.0004
Epoch [18/500], Loss: 0.0025
Epoch [19/500], Loss: 0.0065
Epoch [20/500], Loss: 0.0022
Epoch [21/500], Loss: 0.0046
Epoch [22/500], Loss: 0.0066
Epoch [23/500], Loss: 0.0073
Epoch [24/500], Loss: 0.0084
Epoch [25/500], Loss: 0.0089
Epoch [26/500], Loss: 0.0071
Epoch [27/500], Loss: 0.0052
Epoch [28/500], Loss: 0.0048
Epoch [29/500], Loss: 0.0043
Epoch [30/500], Loss: 0.0040
Epoch [31/500], Loss: 0.0034
Epoch [32/500], Loss: 0.0026
Epoch [33/500], Loss: 0.0016
Epoch [34/500], Loss: 0.0004
Epoch [35/500], Loss: 0

In [13]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np

# -------------------------------
# Data Preparation
# -------------------------------
sequence_length = 10
num_samples = 100

# Create a sine wave dataset (using 4 * pi as the end point)
X = jnp.linspace(0, 4 * jnp.pi, num=num_samples)[:, None]
y = jnp.sin(X)

def create_in_out_sequences(data, seq_length):
    in_seq = []
    out_seq = []
    for i in range(len(data) - seq_length):
        in_seq.append(data[i:i + seq_length])
        out_seq.append(data[i + seq_length])
    return jnp.stack(in_seq), jnp.stack(out_seq)

X_seq, y_seq = create_in_out_sequences(y, sequence_length)

# -------------------------------
# Define the RNN Model in Flax
# -------------------------------
class RNNCell(nn.Module):
    hidden_size: int = 50

    @nn.compact
    def __call__(self, carry, x):
        # carry: previous hidden state, shape (batch, hidden_size)
        # x: current input, shape (batch, input_size)
        # Compute new hidden state: tanh(W_ih*x + W_hh*carry + b)
        new_h = nn.tanh(
            nn.Dense(self.hidden_size, name="ih")(x) +
            nn.Dense(self.hidden_size, use_bias=False, name="hh")(carry)
        )
        return new_h, new_h  # returning new state as both carry and output

class RNNModel(nn.Module):
    hidden_size: int = 50

    @nn.compact
    def __call__(self, x):
        # x shape: (batch, seq_length, input_size)
        batch_size = x.shape[0]
        init_carry = jnp.zeros((batch_size, self.hidden_size))
        # Instead of instantiating RNNCell, pass the class to nn.scan
        rnn_scan = nn.scan(
            RNNCell,  # pass the class instead of an instance
            in_axes=1,
            out_axes=1,
            variable_broadcast="params",
            split_rngs={"params": False},
        )(hidden_size=self.hidden_size)  # now provide the argument for hidden_size
        carry, ys = rnn_scan(init_carry, x)
        # Use the output at the final time step
        last_output = ys[:, -1, :]
        output = nn.Dense(1)(last_output)
        return output

# -------------------------------
# Initialize the Model and Optimizer
# -------------------------------
model = RNNModel()
rng = jax.random.PRNGKey(42)
# Sample input with shape (batch=1, seq_length, input_size)
sample_input = jnp.ones((1, sequence_length, 1))
params = model.init(rng, sample_input)

# Set up the Adam optimizer from Optax
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# -------------------------------
# Define Loss and Training Step
# -------------------------------
def loss_fn(params, x, y):
    preds = model.apply(params, x)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# -------------------------------
# Training Loop
# -------------------------------
epochs = 500
for epoch in range(epochs):
    epoch_loss = 0.0
    # Loop over each sample (each sample has shape (sequence_length, 1))
    for seq, label in zip(X_seq, y_seq):
        # Add a batch dimension: new shape becomes (1, sequence_length, 1)
        seq = seq[None, :, :]
        label = label[None, :]
        params, opt_state, loss = train_step(params, opt_state, seq, label)
        epoch_loss += loss
    epoch_loss /= len(X_seq)
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}")

# -------------------------------
# Testing on New Data
# -------------------------------
# Create new test data (from 4*pi to 5*pi)
X_test = jnp.linspace(4 * jnp.pi, 5 * jnp.pi, num=10)[:, None]
X_test = X_test[None, :, :]  # Add batch dimension: shape becomes (1, 10, 1)
predictions = model.apply(params, X_test)
print("Predictions for new sequence:", predictions)


Epoch [1/500], Loss: 0.0716
Epoch [5/500], Loss: 0.0019
Epoch [10/500], Loss: 0.1030
Epoch [15/500], Loss: 0.0176
Epoch [20/500], Loss: 0.0087
Epoch [25/500], Loss: 0.0101
Epoch [30/500], Loss: 0.0030
Epoch [35/500], Loss: 0.0073
Epoch [40/500], Loss: 0.0100
Epoch [45/500], Loss: 0.0001
Epoch [50/500], Loss: 0.0001
Epoch [55/500], Loss: 0.0142
Epoch [60/500], Loss: 0.0001
Epoch [65/500], Loss: 0.0007
Epoch [70/500], Loss: 0.0120
Epoch [75/500], Loss: 0.0004
Epoch [80/500], Loss: 0.0010
Epoch [85/500], Loss: 0.0089
Epoch [90/500], Loss: 0.0103
Epoch [95/500], Loss: 0.0006
Epoch [100/500], Loss: 0.0003
Epoch [105/500], Loss: 0.0047
Epoch [110/500], Loss: 0.0066
Epoch [115/500], Loss: 0.0008
Epoch [120/500], Loss: 0.0042
Epoch [125/500], Loss: 0.0062
Epoch [130/500], Loss: 0.0005
Epoch [135/500], Loss: 0.0031
Epoch [140/500], Loss: 0.0041
Epoch [145/500], Loss: 0.0021
Epoch [150/500], Loss: 0.0024
Epoch [155/500], Loss: 0.0015
Epoch [160/500], Loss: 0.0026
Epoch [165/500], Loss: 0.0039
Ep

In [None]:
# Weak LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np

# RNN Cell Definition
class RNNCell(nn.Module):
    hidden_size: int

    def setup(self):
        # MODIFIED: Initialize weights for the RNN cell
        self.W_ih = self.param('W_ih', nn.initializers.xavier_uniform(), (self.hidden_size, self.hidden_size))
        self.W_hh = self.param('W_hh', nn.initializers.xavier_uniform(), (self.hidden_size, self.hidden_size))

    def __call__(self, x, hidden_state):
        # MODIFIED: Ensure hidden state is properly utilized and returned
        new_hidden_state = jnp.tanh(jnp.dot(x, self.W_ih) + jnp.dot(hidden_state, self.W_hh))
        return new_hidden_state

# RNN Module Definition
class RNN(nn.Module):
    hidden_size: int
    output_size: int

    def setup(self):
        self.rnn_cell = RNNCell(self.hidden_size)
        self.fc = nn.Dense(self.output_size)

    def __call__(self, x):
        # MODIFIED: Initialized hidden state explicitly
        hidden_state = jnp.zeros((x.shape[0], self.hidden_size))

        def rnn_step(hidden_state, x_t):
            return self.rnn_cell(x_t, hidden_state)  # MODIFIED: Pass hidden state explicitly

        # Using jax.lax.scan for efficient state propagation
        hidden_states = jax.lax.scan(rnn_step, hidden_state, x)[0]  # MODIFIED: Capture hidden states
        output = self.fc(hidden_states)
        return output

# Loss Function
def compute_loss(logits, targets):
    return jnp.mean(jax.nn.softmax_cross_entropy(logits=logits, labels=targets))

# Main Function
def main():
    # Sample data for training (Dummy data)
    x_train = jnp.array(np.random.rand(100, 10, 1))  # 100 samples, 10 timesteps, 1 feature
    y_train = jnp.array(np.random.randint(0, 2, (100, 10, 2)))  # 2 classes

    model = RNN(hidden_size=16, output_size=2)  # Instantiate the RNN model
    params = model.init(jax.random.PRNGKey(0), x_train)  # Initialize parameters

    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(params)

    # Training Loop
    epochs = 5
    for epoch in range(epochs):
        # Forward pass
        logits = model.apply(params, x_train)
        loss = compute_loss(logits, y_train)

        # Compute gradients and update parameters
        grads = jax.grad(compute_loss)(params, y_train)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")

    # Testing on new data
    X_test = np.linspace(4 * np.pi, 5 * np.pi, 10).reshape(-1, 1)
    X_test = jnp.expand_dims(X_test, axis=0)  # Add batch dimension

    predictions = model.apply(params, X_test)
    print(f"Predictions for new sequence: {predictions.tolist()}")

if __name__ == "__main__":
    main()

In [None]:
"""Error Code

  model = RNN(hidden_size=16, output_size=2)  # Instantiate the RNN model
  params = model.init(jax.random.PRNGKey(0), x_train)

def setup(self):
        self.W_ih = self.param('W_ih', nn.initializers.xavier_uniform(), (self.hidden_size, self.hidden_size))
        self.W_hh = self.param('W_hh', nn.initializers.xavier_uniform(), (self.hidden_size, self.hidden_size))
def __call__(self, x, hidden_state):
        # MODIFIED: Ensure hidden state is properly utilized and returned
        new_hidden_state = jnp.tanh(jnp.dot(x, self.W_ih) + jnp.dot(hidden_state, self.W_hh))

Error: TypeError: dot_general requires contracting dimensions to
have the same shape, got (1,) and (16,).

Fix Guide:
The error occurs due to a shape mismatch in the matrix multiplication within
RNNCell. The input x has shape (batch_size, input_features=1),
but W_ih is initialized with shape (hidden_size, hidden_size)=(16, 16),
causing a dimension mismatch in the dot product.

Correct Code
def setup(self):
        self.W_ih = self.param('W_ih', nn.initializers.xavier_uniform(),
                              (self.input_size, self.hidden_size))
        self.W_hh = self.param('W_hh', nn.initializers.xavier_uniform(),
                              (self.hidden_size, self.hidden_size))

    def __call__(self, x, hidden_state):
        new_hidden_state = jnp.tanh(jnp.dot(x, self.W_ih) + jnp.dot(hidden_state, self.W_hh))
        return new_hidden_state
"""

"""Error Code
final_hidden_state, _ = jax.lax.scan(rnn_step, hidden_state, x, length=x.shape[1])

Error:
ValueError: scan got `length` argument of 10 which disagrees with leading axis sizes [100].

Fix Guide
The error occurs in jax.lax.scan because the length parameter
(set to x.shape[1] = 10) doesn't match the leading axis size of the input x
that scan is iterating over. In jax.lax.scan, the x argument is expected to have
its first dimension as the sequence length to scan over, but:
x has shape (100, 10, 1) where 100 is the batch size and 10 is the sequence length.
scan interprets the first dimension (100) as the sequence length,
conflicting with length=10. We want to transpose x to (sequence_length, batch_size, features) = (10, 100, 1)
 before scanning.

Correct Code
def __call__(self, x):
        # Transpose x to (sequence_length, batch_size, input_size)
        x = jnp.transpose(x, (1, 0, 2))  # From (100, 10, 1) to (10, 100, 1)
        hidden_state = jnp.zeros((x.shape[1], self.hidden_size))  # batch_size = 100

        def rnn_step(hidden_state, x_t):
            return self.rnn_cell(x_t, hidden_state), None

        # Scan over sequence dimension (10)
        final_hidden_state, _ = jax.lax.scan(rnn_step, hidden_state, x)
        output = self.fc(final_hidden_state)
        return output
"""

"""Error Code
new_hidden_state = jnp.tanh(jnp.dot(x, self.W_ih) + jnp.dot(hidden_state, self.W_hh))
        return new_hidden_state

  Error:
TypeError: add got incompatible shapes for broadcasting: (100, 16), (10, 16).

Fix guide
x has shape (100, 1),10 is the sequence length and 100 is the batch size.
W_ih has shape (1, 16),jnp.dot(x, self.W_ih) has shape (100, 16).
hidden_state has shape (100, 16), and W_hh has shape (16, 16),
so jnp.dot(hidden_state, self.W_hh) also has shape (100, 16).
The final hidden state should reflect the batch size (100) and hidden size (16).

Correct Code:
def __call__(self, x):
    x = jnp.transpose(x, (1, 0, 2))  # From (batch, seq, feat) to (seq, batch, feat)
    hidden_state = jnp.zeros((x.shape[1], self.hidden_size))  # Shape: (batch_size, hidden_size)

    def rnn_step(hidden_state, x_t):
        new_hidden = self.rnn_cell(x_t, hidden_state)
        return new_hidden, None  # Return new hidden state as carry

    final_hidden_state, _ = jax.lax.scan(rnn_step, hidden_state, x)
    output = self.fc(final_hidden_state)
    return output
"""

"""Error Code
def __call__(self, x):
    x = jnp.transpose(x, (1, 0, 2))  # From (batch, seq, feat) to (seq, batch, feat)
    hidden_state = jnp.zeros((x.shape[1], self.hidden_size))  # Shape: (batch_size, hidden_size)

    def rnn_step(hidden_state, x_t):
        new_hidden = self.rnn_cell(x_t, hidden_state)
        return new_hidden, None  # Return new hidden state as carry

    final_hidden_state, _ = jax.lax.scan(rnn_step, hidden_state, x)
    output = self.fc(final_hidden_state)
    return output

Error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1,16] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was rnn_step at <ipython-input-9-8a3c00ebe7fa>:21 traced for scan.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.11/dist-packages/flax/core/scope.py:968:14 (Scope.param).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
Fix guide:
The error occurs because JAX's transformations (like scan) expect pure functions,
but the rnn_cell call within rnn_step involves Flax's parameter initialization,
which has side effects (accessing self.W_ih and self.W_hh).
When traced by JAX, this creates an unexpected tracer leak.
To fix this, we need to explicitly pass the parameters to rnn_step and avoid relying on self.rnn_cell’s implicit state during the scan

correct code:
def __call__(self, x, rng=None):  # Added rng for initialization if needed
    x = jnp.transpose(x, (1, 0, 2))  # From (batch, seq, feat) to (seq, batch, feat)
    hidden_state = jnp.zeros((x.shape[1], self.hidden_size))  # Shape: (batch_size, hidden_size)

    def rnn_step(hidden_state, x_t, params):  # Added params argument
        W_ih = params['rnn_cell']['W_ih']
        W_hh = params['rnn_cell']['W_hh']
        new_hidden = jnp.tanh(jnp.dot(x_t, W_ih) + jnp.dot(hidden_state, W_hh))
        return new_hidden, None

    # Pass params explicitly (assuming params is available from apply)
    final_hidden_state, _ = jax.lax.scan(
        lambda hs, xt: rnn_step(hs, xt, self.params), hidden_state, x
    )
    output = self.fc(final_hidden_state)
    return output
"""

"""
Error
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1,16] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was <lambda> at <ipython-input-28-2fb73f6ef55e>:50 traced for scan.
The leaked intermediate value was created on line /usr/local/lib/python3.11/dist-packages/flax/core/scope.py:968:14 (Scope.param).

"""

In [32]:
# Fixed Code
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np

# RNN Cell Definition
class RNNCell(nn.Module):
    input_size: int
    hidden_size: int

    def setup(self):
        self.W_ih = self.param('W_ih', nn.initializers.xavier_uniform(),
                              (self.input_size, self.hidden_size))
        self.W_hh = self.param('W_hh', nn.initializers.xavier_uniform(),
                              (self.hidden_size, self.hidden_size))

    def __call__(self, carry, x):
        new_carry = jnp.tanh(jnp.dot(x, self.W_ih) + jnp.dot(carry, self.W_hh))
        return new_carry, None

# RNN Module Definition
class RNN(nn.Module):
    input_size: int
    hidden_size: int
    output_size: int

    def setup(self):
        # Wrap RNNCell with nn.scan for proper parameter handling
        self.scanned_rnn_cell = nn.scan(
            RNNCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=0,
            out_axes=0
        )(input_size=self.input_size, hidden_size=self.hidden_size)
        self.fc = nn.Dense(self.output_size)
    def __call__(self, x):
        # Transpose x from (batch, seq, feat) to (seq, batch, feat)
        x = jnp.transpose(x, (1, 0, 2))
        batch_size = x.shape[1]
        init_carry = jnp.zeros((batch_size, self.hidden_size))
        final_carry, _ = self.scanned_rnn_cell(init_carry, x)
        output = self.fc(final_carry)
        return output
    # def __call__(self, x, params=None, rng=None):
    #     x = jnp.transpose(x, (1, 0, 2))  # From (batch, seq, feat) to (seq, batch, feat)
    #     hidden_state = jnp.zeros((x.shape[1], self.hidden_size))  # Shape: (batch_size, hidden_size)

    #     def rnn_step(hidden_state, x_t, cell_params):
    #         W_ih = cell_params['W_ih']
    #         W_hh = cell_params['W_hh']
    #         new_hidden = jnp.tanh(jnp.dot(x_t, W_ih) + jnp.dot(hidden_state, W_hh))
    #         return new_hidden, None

        # Use params if provided (apply), otherwise call rnn_cell directly (init)
        # if params is not None:
        #     cell_params = params['rnn_cell']
        #     final_hidden_state, _ = jax.lax.scan(
        #         lambda hs, xt: rnn_step(hs, xt, cell_params), hidden_state, x
        #     )
        # else:
        #     final_hidden_state, _ = jax.lax.scan(
        #         lambda hs, xt: (self.rnn_cell(xt, hs), None), hidden_state, x
        #     )
        # output = self.fc(final_hidden_state)
        # return output

# Loss Function
def compute_loss(params, model, x, targets):
    logits = model.apply(params, x)
    return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))

# Main Function
def main():
    # Sample data for training
    x_train = jnp.array(np.random.rand(100, 10, 1))  # 100 samples, 10 timesteps, 1 feature
    y_train = jnp.array(np.random.randint(0, 2, (100, 2)))  # 2 classes, output at last timestep
    # y_train = jnp.array(np.eye(2)[np.random.randint(0, 2, 100)])

    # Instantiate the RNN model
    model = RNN(input_size=1, hidden_size=16, output_size=2)
    params = model.init(jax.random.PRNGKey(0), x_train)

    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(params)

    # Training Loop
    epochs = 5
    for epoch in range(epochs):
        loss, grads = jax.value_and_grad(compute_loss)(params, model, x_train, y_train)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss:.4f}")

    # Testing on new data
    X_test = np.linspace(4 * np.pi, 5 * np.pi, 10).reshape(1, 10, 1)  # 1 sample, 10 timesteps
    predictions = model.apply(params, X_test)
    print(f"Predictions for new sequence: {predictions.tolist()}")

if __name__ == "__main__":
    main()

Epoch [1/5], Loss: 0.6701
Epoch [2/5], Loss: 0.6684
Epoch [3/5], Loss: 0.6669
Epoch [4/5], Loss: 0.6655
Epoch [5/5], Loss: 0.6643
Predictions for new sequence: [[-0.16616860032081604, -0.9254869222640991]]
