In [1]:
#Input
import torch
import torch.nn as nn
import torch.optim as optim

# Define the Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, embed_dim, hidden_dim, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, x):
        embedded = self.embedding(x)
        outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, (hidden, cell)

# Define the Decoder with Attention
class Decoder(nn.Module):
    def __init__(self, output_dim, embed_dim, hidden_dim, num_layers, src_seq_length):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.attention = nn.Linear(hidden_dim + embed_dim, src_seq_length)
        self.attention_combine = nn.Linear(hidden_dim + embed_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, encoder_outputs, hidden, cell):
        x = x.unsqueeze(1)  # Add sequence dimension
        embedded = self.embedding(x)

        # Attention mechanism
        attention_weights = torch.softmax(self.attention(torch.cat((embedded.squeeze(1), hidden[-1]), dim=1)), dim=1)
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)

        # Combine context and embedded input
        combined = torch.cat((embedded.squeeze(1), context_vector.squeeze(1)), dim=1)
        combined = torch.tanh(self.attention_combine(combined)).unsqueeze(1)

        # LSTM and output
        lstm_out, (hidden, cell) = self.lstm(combined, (hidden, cell))
        output = self.fc_out(lstm_out.squeeze(1))
        return output, hidden, cell

# Define synthetic training data
torch.manual_seed(42)
src_vocab_size = 20
tgt_vocab_size = 20
src_seq_length = 10
tgt_seq_length = 12
batch_size = 16

src_data = torch.randint(0, src_vocab_size, (batch_size, src_seq_length))
tgt_data = torch.randint(0, tgt_vocab_size, (batch_size, tgt_seq_length))

# Initialize models, loss function, and optimizer
input_dim = src_vocab_size
output_dim = tgt_vocab_size
embed_dim = 32
hidden_dim = 64
num_layers = 2

encoder = Encoder(input_dim, embed_dim, hidden_dim, num_layers)
decoder = Decoder(output_dim, embed_dim, hidden_dim, num_layers, src_seq_length)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)

# Training loop
epochs = 100
for epoch in range(epochs):
    encoder_outputs, (hidden, cell) = encoder(src_data)
    loss = 0
    decoder_input = torch.zeros(batch_size, dtype=torch.long)  # Start token

    for t in range(tgt_seq_length):
        output, hidden, cell = decoder(decoder_input, encoder_outputs, hidden, cell)
        loss += criterion(output, tgt_data[:, t])
        decoder_input = tgt_data[:, t]  # Teacher forcing

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {loss.item():.4f}")

# Test the sequence-to-sequence model with new input
test_input = torch.randint(0, src_vocab_size, (1, src_seq_length))
with torch.no_grad():
    encoder_outputs, (hidden, cell) = encoder(test_input)
    decoder_input = torch.zeros(1, dtype=torch.long)  # Start token
    output_sequence = []

    for _ in range(tgt_seq_length):
        output, hidden, cell = decoder(decoder_input, encoder_outputs, hidden, cell)
        predicted = output.argmax(1)
        output_sequence.append(predicted.item())
        decoder_input = predicted

    print(f"Input: {test_input.tolist()}, Output: {output_sequence}")

Epoch [10/100] - Loss: 35.5304
Epoch [20/100] - Loss: 34.7664
Epoch [30/100] - Loss: 33.6247
Epoch [40/100] - Loss: 30.9979
Epoch [50/100] - Loss: 27.3896
Epoch [60/100] - Loss: 24.1525
Epoch [70/100] - Loss: 21.2032
Epoch [80/100] - Loss: 18.6953
Epoch [90/100] - Loss: 16.5154
Epoch [100/100] - Loss: 14.5447
Input: [[3, 18, 4, 11, 8, 17, 12, 7, 18, 1]], Output: [13, 13, 2, 2, 2, 12, 12, 7, 7, 12, 12, 12]


In [2]:
#Strong LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np


class LSTMLayer(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x, initial_state):
        # x shape: (batch, seq_length, features)
        lstm_cell = nn.LSTMCell()
        def step_fn(carry, x_t):
            new_carry, y = lstm_cell(carry, x_t)
            return new_carry, y
        final_state, outputs = nn.scan(
            step_fn,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,
            out_axes=1,
        )(initial_state, x)
        return outputs, final_state

class Encoder(nn.Module):
    input_dim: int        # vocabulary size of source
    embed_dim: int
    hidden_dim: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        # x shape: (batch, src_seq_length) of token IDs
        emb = nn.Embed(num_embeddings=self.input_dim, features=self.embed_dim)(x)
        # emb shape: (batch, src_seq_length, embed_dim)
        batch_size = emb.shape[0]
        outputs = emb
        states = []
        # Process through a stack of LSTM layers
        for i in range(self.num_layers):
            # Initialize carry (cell, hidden) for this layer
            initial_state = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), self.hidden_dim)
            outputs, final_state = LSTMLayer(self.hidden_dim, name=f"lstm_layer_{i}")(outputs, initial_state)
            states.append(final_state)
        # Collect final states from each layer.
        hidden = jnp.stack([s[1] for s in states], axis=0)  # shape: (num_layers, batch, hidden_dim)
        cell   = jnp.stack([s[0] for s in states], axis=0)
        return outputs, (hidden, cell)

class Decoder(nn.Module):
    output_dim: int       # vocabulary size of target
    embed_dim: int
    hidden_dim: int
    num_layers: int
    src_seq_length: int

    @nn.compact
    def __call__(self, x, encoder_outputs, hidden, cell):
        # x shape: (batch,) token IDs; add a time dimension.
        x = x[:, None]  # shape becomes (batch, 1)
        embedded = nn.Embed(num_embeddings=self.output_dim, features=self.embed_dim)(x)
        # embedded shape: (batch, 1, embed_dim)
        embedded_squeezed = jnp.squeeze(embedded, axis=1)  # (batch, embed_dim)
        # Attention: combine embedded input with last-layer hidden state.
        last_hidden = hidden[-1]  # (batch, hidden_dim)
        attn_input = jnp.concatenate([embedded_squeezed, last_hidden], axis=1)  # (batch, embed_dim+hidden_dim)
        # Map to raw attention scores (one score per encoder time step).
        attn_scores = nn.Dense(self.src_seq_length)(attn_input)  # (batch, src_seq_length)
        attention_weights = jax.nn.softmax(attn_scores, axis=1)   # (batch, src_seq_length)
        # Compute context vector as weighted sum over encoder outputs.
        # encoder_outputs shape: (batch, src_seq_length, hidden_dim)
        attention_weights_exp = attention_weights[:, None, :]  # (batch, 1, src_seq_length)
        context_vector = jnp.matmul(attention_weights_exp, encoder_outputs)  # (batch, 1, hidden_dim)
        context_vector = jnp.squeeze(context_vector, axis=1)  # (batch, hidden_dim)
        # Combine context vector and embedded input.
        combined = jnp.concatenate([embedded_squeezed, context_vector], axis=1)  # (batch, embed_dim+hidden_dim)
        combined = nn.Dense(self.embed_dim)(combined)
        combined = jnp.tanh(combined)
        combined = combined[:, None, :]  # (batch, 1, embed_dim)
        # Pass through a one-step multi-layer LSTM.
        new_hidden = []
        new_cell = []
        x_t = jnp.squeeze(combined, axis=1)  # (batch, embed_dim)
        for i in range(self.num_layers):
            lstm_cell = nn.LSTMCell(name=f"decoder_lstm_cell_{i}")
            state = (cell[i], hidden[i])
            new_state, y = lstm_cell(state, x_t)
            new_cell.append(new_state[0])
            new_hidden.append(new_state[1])
            x_t = y  # output becomes input for the next layer
        new_hidden = jnp.stack(new_hidden, axis=0)
        new_cell = jnp.stack(new_cell, axis=0)
        # Map final LSTM output to target vocabulary logits.
        output = nn.Dense(self.output_dim)(y)  # (batch, output_dim)
        return output, new_hidden, new_cell

def get_data(key):
    src_vocab_size = 20
    tgt_vocab_size = 20
    src_seq_length = 10
    tgt_seq_length = 12
    batch_size = 16
    key, subkey = jax.random.split(key)
    src_data = jax.random.randint(subkey, shape=(batch_size, src_seq_length), minval=0, maxval=src_vocab_size)
    key, subkey = jax.random.split(key)
    tgt_data = jax.random.randint(subkey, shape=(batch_size, tgt_seq_length), minval=0, maxval=tgt_vocab_size)
    return src_data, tgt_data, src_vocab_size, tgt_vocab_size, src_seq_length, tgt_seq_length, batch_size

def cross_entropy_loss(logits, targets):
    # logits shape: (batch, num_classes); targets shape: (batch,)
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(targets, logits.shape[-1])
    loss = -jnp.sum(one_hot * log_probs, axis=-1)
    return jnp.mean(loss)

def loss_fn(encoder_params, decoder_params, src_data, tgt_data, encoder, decoder, tgt_seq_length):
    # Run encoder.
    enc_vars = {'params': encoder_params}
    encoder_outputs, (hidden, cell) = encoder.apply(enc_vars, src_data)
    loss = 0.0
    # Start token (assumed 0) for the decoder.
    decoder_input = jnp.zeros((src_data.shape[0],), dtype=jnp.int32)
    dec_vars = {'params': decoder_params}
    for t in range(tgt_seq_length):
        logits, hidden, cell = decoder.apply(dec_vars, decoder_input, encoder_outputs, hidden, cell)
        loss += cross_entropy_loss(logits, tgt_data[:, t])
        # Teacher forcing: next input is current target.
        decoder_input = tgt_data[:, t]
    return loss / tgt_seq_length

def main():
    key = jax.random.PRNGKey(42)
    src_data, tgt_data, src_vocab_size, tgt_vocab_size, src_seq_length, tgt_seq_length, batch_size = get_data(key)

    input_dim = src_vocab_size      # Source vocabulary size
    output_dim = tgt_vocab_size     # Target vocabulary size
    embed_dim = 32
    hidden_dim = 64
    num_layers = 2

    encoder = Encoder(input_dim=input_dim, embed_dim=embed_dim, hidden_dim=hidden_dim, num_layers=num_layers)
    decoder = Decoder(output_dim=output_dim, embed_dim=embed_dim, hidden_dim=hidden_dim,
                      num_layers=num_layers, src_seq_length=src_seq_length)

    # Initialize model parameters.
    encoder_vars = encoder.init(key, src_data)
    encoder_params = encoder_vars['params']
    encoder_outputs, (hidden, cell) = encoder.apply({'params': encoder_params}, src_data)
    dummy_decoder_input = jnp.zeros((batch_size,), dtype=jnp.int32)
    decoder_vars = decoder.init(key, dummy_decoder_input, encoder_outputs, hidden, cell)
    decoder_params = decoder_vars['params']

    # Combine parameters and set up the optimizer.
    params = {'encoder': encoder_params, 'decoder': decoder_params}
    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, src_data, tgt_data):
        def loss_wrapper(params):
            return loss_fn(params['encoder'], params['decoder'], src_data, tgt_data, encoder, decoder, tgt_seq_length)
        loss_val, grads = jax.value_and_grad(loss_wrapper)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss_val

    epochs = 100
    for epoch in range(epochs):
        params, opt_state, loss_val = train_step(params, opt_state, src_data, tgt_data)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {loss_val:.4f}")

    key, subkey = jax.random.split(key)
    test_input = jax.random.randint(subkey, shape=(1, src_seq_length), minval=0, maxval=src_vocab_size)
    enc_vars = {'params': params['encoder']}
    encoder_outputs, (hidden, cell) = encoder.apply(enc_vars, test_input)
    decoder_input = jnp.zeros((1,), dtype=jnp.int32)  # Start token
    output_sequence = []
    dec_vars = {'params': params['decoder']}
    for _ in range(tgt_seq_length):
        logits, hidden, cell = decoder.apply(dec_vars, decoder_input, encoder_outputs, hidden, cell)
        predicted = jnp.argmax(logits, axis=-1)
        output_sequence.append(int(predicted[0]))
        decoder_input = predicted
    print(f"Input: {np.array(test_input).tolist()}, Output: {output_sequence}")

if __name__ == "__main__":
    main()


TypeError: 'int' object is not subscriptable

In [None]:
#Weak LLM
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np

class Decoder(nn.Module):
    # Define the decoder module with attention mechanism
    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):
        # Compute the attention scores
        attention_scores = jnp.dot(encoder_outputs, hidden_state)  # MODIFIED: Ensure hidden_state is used appropriately
        attention_weights = nn.softmax(attention_scores)
        context_vector = jnp.dot(attention_weights, encoder_outputs)  # Compute the context vector

        # Update hidden state (dummy example, the actual implementation may vary)
        hidden_state = self.update_hidden_state(hidden_state, context_vector)

        # Generate output (dummy generation logic)
        output = nn.Dense(self.vocab_size)(context_vector)  # Define your output layer here

        return output, hidden_state, cell_state

    def update_hidden_state(self, hidden_state, context_vector):
        # Dummy update function for hidden state
        return hidden_state + context_vector  # Replace with actual update logic

def main():
    # Example parameters
    vocab_size = 10000
    hidden_size = 256
    tgt_seq_length = 10

    # Initialize decoder and states
    decoder = Decoder(vocab_size=vocab_size, hidden_size=hidden_size)
    hidden_state = jnp.zeros((1, hidden_size))
    cell_state = jnp.zeros((1, hidden_size))
    decoder_input = jnp.zeros((1, vocab_size))  # Adjust input dimensions accordingly
    encoder_outputs = jnp.zeros((1, tgt_seq_length, hidden_size))  # Example encoder output

    output_sequence = []

    # Decoding process
    for _ in range(tgt_seq_length):
        output, hidden_state, cell_state = decoder(decoder_input, encoder_outputs, hidden_state, cell_state)  # MODIFIED: Updated to pass hidden_state
        predicted = jnp.argmax(output, axis=1)
        output_sequence.append(predicted.item())

        # Ensure decoder_input shape matches the required input shape for the attention function
        decoder_input = jax.nn.one_hot(predicted, vocab_size)  # MODIFIED: Convert predicted index to one-hot encoding

    print(f"Input: {jnp.zeros((1, vocab_size)).tolist()}, Output: {output_sequence}")  # Placeholder for input

if __name__ == "__main__":
    main()

In [None]:
"""
Error Code:
File: <ipython-input-1-ffef6510e4ae>, line 17 attention_scores = jnp.dot(encoder_outputs, hidden_state) ... context_vector = jnp.dot(attention_weights, encoder_outputs)


Error:
dot_general requires contracting dimensions to have the same shape, got (256,) and (1,).


Fix Guide:
The error occurs because the dimensions in the dot product don't align. The encoder_outputs has shape (batch, seq_len, hidden_size), and hidden_state is (batch, hidden_size). Using jnp.dot here is incorrect. Instead, use einsum to correctly compute attention scores between each encoder output and the hidden state. Similarly, adjust the context vector computation to sum over the sequence dimension.


Correct Code:
attention_scores = jnp.einsum('bsh,bh->bs', encoder_outputs, hidden_state)
context_vector = jnp.einsum('bs,bsh->bh', attention_weights, encoder_outputs)
"""


"""
Error Code:
output = nn.Dense(self.vocab_size)(context_vector)


Error:
raised in the init method of Dense


Fix Guide:
The error occurs because Flax modules require parameter initialization through a proper module structure. The Decoder's call method needs the @nn.compact decorator to create submodules (like Dense) inline. Also, ensure the Decoder's init calls its parent's init.


Correct Code:
@nn.compact  # MODIFIED: Add this decorator
def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):
    # ... (existing code)
    output = nn.Dense(self.vocab_size)(context_vector)
"""


"""
Error Code:
class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


Error:
In Flax Linen, the __init__ method should not be directly overridden to initialize parameters


Fix Guide:
Declare module parameters using class attributes
Define each sublayer in the setup() method


Correct Code:
class Decoder(nn.Module):
    output_dim: int
    embed_dim: int
    hidden_dim: int
    num_layers: int
    src_seq_length: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.output_dim, features=self.embed_dim)
        self.attention = nn.Dense(self.src_seq_length)
        self.attention_combine = nn.Dense(self.embed_dim)
        self.lstm = nn.OptimizedLSTMCell() 
        self.fc_out = nn.Dense(self.output_dim)
"""


"""
Error Code:
def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):


Error:
The JAX code directly treats decoder_input as a vector and uses one-hot encoding, which is inconsistent with the original code logic


Fix Guide:
Define the embedding layer in setup().
In __call__, first embed the decoder_input (token index, shape [batch] or [batch, 1]) to get the embedding vector, which is then used for subsequent calculations


Correct Code:
def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):
    # decoder_input is a token index, the shape is (batch,) or (batch, 1)
    embedded = self.embedding(decoder_input)  # Output shape: (batch, embed_dim) or (batch, 1, embed_dim)
    if embedded.ndim == 3:
        embedded = embedded.squeeze(1)
"""


"""
Error Code:
attention_scores = jnp.einsum('bsh,bh->bs', encoder_outputs, hidden_state)
attention_weights = nn.softmax(attention_scores)
context_vector = jnp.einsum('bs,bsh->bh', attention_weights, encoder_outputs)


Error:
The JAX code only uses hidden_state to participate in the calculation, without using the embedded information of the decoder.


Fix Guide:
Concatenate the current decoder's embedded embedded with hidden_state, and pass it to the self.attention linear layer to calculate the attention score
Use jax.nn.softmax to calculate the attention weight, and calculate the context vector based on the weight and encoder_outputs


Correct Code:
# Concatenate the current embedding and the previous hidden state (assuming the shape of hidden_state is (batch, hidden_dim))
concat_input = jnp.concatenate([embedded, hidden_state], axis=-1) # Shape (batch, embed_dim + hidden_dim)
attention_scores = self.attention(concat_input) # Output shape (batch, src_seq_length)
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
context_vector = jnp.einsum('bs,bsh->bh', attention_weights, encoder_outputs) # Get (batch, hidden_dim)
"""


"""
Error Code:
# Update hidden state (dummy example, the actual implementation may vary)
hidden_state = self.update_hidden_state(hidden_state, context_vector)

# Generate output (dummy generation logic)
output = nn.Dense(self.vocab_size)(context_vector)  # Define your output layer here

return output, hidden_state, cell_state


Error:
After obtaining the context vector, the AX code concatenates it with the embedded input, expands the dimension through a fusion layer and tanh activation, then feeds it into an LSTM for state update, and then generates the output using a fully connected layer.


Fix Guide:
Concatenate embedded and context_vector, then pass self.attention_combine and tanh activation
Expand the fused vector into the sequence dimension and input it into LSTMCell for state update
Use the updated hidden state to get the output through self.fc_out


Correct Code:
# Fusion current embedding and context vector
combined = jnp.concatenate([embedded, context_vector], axis=-1) # (batch, embed_dim + hidden_dim)
combined = jax.nn.tanh(self.attention_combine(combined))
combined = combined[:, None, :]

combined = combined.squeeze(1) # (batch, embed_dim)
(new_hidden_state, new_cell_state), _ = self.lstm((hidden_state, cell_state), combined)
output = self.fc_out(new_hidden_state) # (batch, output_dim)

return output, new_hidden_state, new_cell_state
"""


"""
Error Code:
decoder = Decoder(vocab_size=vocab_size, hidden_size=hidden_size)


Error:
__init__() got an unexpected keyword argument 'vocab_size'


Fix Guide:
Modify the input parameters of Decoder


Correct Code:
vocab_size = 10000
embed_dim = 32
hidden_dim = 256
num_layers = 1 
src_seq_length = 10

decoder = Decoder(output_dim=output_dim, embed_dim=embed_dim, hidden_dim=hidden_dim,
                      num_layers=num_layers, src_seq_length=src_seq_length)
"""


"""
Error Code:
self.lstm = nn.OptimizedLSTMCell()


Error:
The hidden layer dimension parameter was not passed in when the LSTM cell was initialized


Fix Guide:
Pass in features=self.hidden_dim during initialization


Correct Code:
self.lstm = nn.OptimizedLSTMCell(features=self.hidden_dim)
"""


"""
Error Code:
(new_hidden_state, new_cell_state), _ = self.lstm((hidden_state, cell_state), combined)


Error:
The input here is (hidden_state, cell_state), which does not match the state order of the LSTM cell


Fix Guide:
Adjust the state order of the incoming LSTM cell and receive the return value


Correct Code:
(new_cell_state, new_hidden_state), _ = self.lstm((cell_state, hidden_state), combined)

"""


"""
Error Code:
decoder_input = jnp.zeros((1, vocab_size))


Error:
In the PyTorch code, the token index is passed to the decoder (the shape is (batch,) or (batch, 1)), and the one-hot vector is passed to jax, the shape is (1, vocab_size)


Fix Guide:
Define decoder_input as an integer token index


Correct Code:
decoder_input = jnp.array([0])
"""


"""
Error Code:
hidden_state = jnp.zeros((1, hidden_size))
cell_state = jnp.zeros((1, hidden_size))
decoder_input = jnp.array([0])
encoder_outputs = jnp.zeros((1, tgt_seq_length, hidden_size))

Error:
name 'hidden_size' is not defined


Fix Guide:
Should use src_seq_length and hidden_dim


Correct Code:
hidden_state = jnp.zeros((1, hidden_dim))
cell_state = jnp.zeros((1, hidden_dim))
decoder_input = jnp.array([0])
encoder_outputs = jnp.zeros((1, src_seq_length, hidden_dim))
"""


"""
Error Code:
combined = jax.nn.tanh(self.attention_combine(combined))
combined = combined[:, None, :]
combined = combined.squeeze(1) 


Error:
Adding a dimension and then squeezing it out immediately is unnecessary for the input LSTM cell and may cause shape confusion


Fix Guide:
Directly keep the shape of combined as (batch, embed_dim)


Correct Code:
combined = jax.nn.tanh(self.attention_combine(combined))
"""


"""
Error Code:
decoder = Decoder(output_dim=output_dim, embed_dim=embed_dim, hidden_dim=hidden_dim,
                  num_layers=num_layers, src_seq_length=src_seq_length)


Error:
output_dim is undefined


Fix Guide:
Replace output_dim with vocab_size


Correct Code:
decoder = Decoder(output_dim=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim,
                  num_layers=num_layers, src_seq_length=src_seq_length)

"""


"""
Error Code:
for _ in range(tgt_seq_length):


Error:
tgt_seq_length is not defined in main()


Fix Guide:
When using the target sequence length, refer to tgt_seq_length defined in the PyTorch code


Correct Code:
tgt_seq_length = 12
for _ in range(tgt_seq_length):
"""


"""
Error Code:
decoder_input = jax.nn.one_hot(predicted, vocab_size)


Error:
Input dimensions do not match


Fix Guide:
When decoding, the predicted token index is used directly without converting to one-hot encoding


Correct Code:
decoder_input = predicted
"""


"""
Error Code:
for _ in range(tgt_seq_length):
    output, hidden_state, cell_state = decoder(decoder_input, encoder_outputs, hidden_state, cell_state)


Error:
"Decoder" object has no attribute "embedding". If "embedding" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.


Fix Guide:
Define a random number generator (PRNG key)
Call decoder.init to initialize the model parameters and save the returned variable dictionary
In the decoding loop, use decoder.apply(variables, ...) to call the model instead of calling the module object directly


Correct Code:
variables = decoder.init(rng, decoder_input, encoder_outputs, hidden_state, cell_state)
for _ in range(tgt_seq_length):
    output, hidden_state, cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
        
"""


"""
Error Code:
output_sequence = []

# Decoding process
tgt_seq_length = 12


Error:
name 'rng' is not defined


Fix Guide:
Initialize parameter variables using jax


Correct Code:
rng = jax.random.PRNGKey(0)
output_sequence = []

# Decoding process
tgt_seq_length = 12
"""


"""
Error Code:
for _ in range(tgt_seq_length):
    output, hidden_state, cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
    predicted = jnp.argmax(output, axis=1)
    output_sequence.append(int(predicted.item()))
    decoder_input = predicted


Error:
The kernel appears to have died. It will restart automatically.
Calling decoder.apply(...) directly in each loop may cause repeated tracing and compilation, which may consume a lot of memory or cause runtime problems, eventually leading to kernel crashes.


Fix Guide:
Encapsulate the decoding step into a separate function and JIT-compile it using jax.jit


Correct Code:
@jax.jit
def decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs):
    output, new_hidden_state, new_cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
    predicted = jnp.argmax(output, axis=1)
    return predicted, new_hidden_state, new_cell_state
    
for _ in range(tgt_seq_length):
    predicted, hidden_state, cell_state = decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs)
    output_sequence.append(int(predicted.item()))
    decoder_input = predicted
"""


"""
Error Code:
@jax.jit
def decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs):
    output, new_hidden_state, new_cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
    predicted = jnp.argmax(output, axis=1)
    return predicted, new_hidden_state, new_cell_state


Error:
A JIT-compiled decode_step whose parameters are not marked as static each time it is called in a loop may cause JAX to repeatedly trace and recompile, consuming large amounts of memory or causing unexpected errors


Fix Guide:
Mark unchanged parameters as static so that the JIT only compiles the dynamic part


Correct Code:
@jax.jit(static_argnums=(3,4))
def decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs):
    output, new_hidden_state, new_cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
    predicted = jnp.argmax(output, axis=1)
    return predicted, new_hidden_state, new_cell_state
"""


"""
Error Code:
@jax.jit(static_argnums=(3,4))


Error:
Mark model parameters and encoder_outputs as static parameters via static_argnums, causing JAX to try to hash these objects during tracing


Fix Guide:
Remove static_argnums parameter so all inputs are passed as dynamic arguments


Correct Code:
@jax.jit
"""


"""
Error Code:
self.lstm = nn.OptimizedLSTMCell(features=self.hidden_dim)


Error:
nn.OptimizedLSTMCell is not available or deprecated in Flax's Linen API


Fix Guide:
Replace nn.OptimizedLSTMCell with nn.LSTMCell and pass the same features parameter


Correct Code:
self.lstm = nn.LSTMCell(features=self.hidden_dim)
"""


"""
Error Code:
(new_cell_state, new_hidden_state), _ = self.lstm((cell_state, hidden_state), combined)
output = self.fc_out(new_hidden_state)


Error:
When calling LSTMCell, a tuple is returned:
The first return value is the new state (carry), which is usually structured as (new_cell_state, new_hidden_state)
The second return value is the output of the current time step
The current code incorrectly uses the second return value as "ignore" and directly uses the new hidden state (the part removed from carry) as the output, which is inconsistent with the logic of taking lstm_out and then fully connecting in PyTorch


Fix Guide:
When unpacking, get the carry and output at the same time
Use the output value to pass into the fully connected layer to get the final output


Correct Code:
carry, lstm_output = self.lstm((cell_state, hidden_state), combined)
new_cell_state, new_hidden_state = carry
output = self.fc_out(lstm_output)
"""


"""
Error Code:
# The JAX code only has the Decoder part, but no corresponding Encoder


Error:
The sequence-to-sequence model requires two parts: Encoder and Decoder. The lack of Encoder makes the overall model incomplete and cannot complete the end-to-end task


Fix Guide:
Add a Flax-based Encoder module


Correct Code:
class Encoder(nn.Module):
    input_dim: int
    embed_dim: int
    hidden_dim: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.input_dim, features=self.embed_dim)
        self.lstm = nn.LSTMCell(features=self.hidden_dim)

    def __call__(self, x):
        # x: (batch, seq_length)
        embedded = self.embedding(x)  # (batch, seq_length, embed_dim)
        batch, seq_length, _ = embedded.shape
        cell_state = jnp.zeros((batch, self.hidden_dim))
        hidden_state = jnp.zeros((batch, self.hidden_dim))
        outputs = []
        for t in range(seq_length):
            (cell_state, hidden_state), lstm_output = self.lstm((cell_state, hidden_state), embedded[:, t, :])
            outputs.append(lstm_output)
        outputs = jnp.stack(outputs, axis=1)  # (batch, seq_length, hidden_dim)
        return outputs, (hidden_state, cell_state)
"""


"""
Error Code:
class Encoder(nn.Module):
    input_dim: int
    embed_dim: int
    hidden_dim: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.input_dim, features=self.embed_dim)
        self.lstm = nn.LSTMCell(features=self.hidden_dim)

    def __call__(self, x):
        # x: (batch, seq_length)
        embedded = self.embedding(x)  # (batch, seq_length, embed_dim)
        batch, seq_length, _ = embedded.shape
        cell_state = jnp.zeros((batch, self.hidden_dim))
        hidden_state = jnp.zeros((batch, self.hidden_dim))
        outputs = []
        for t in range(seq_length):
            (cell_state, hidden_state), lstm_output = self.lstm((cell_state, hidden_state), embedded[:, t, :])
            outputs.append(lstm_output)
        outputs = jnp.stack(outputs, axis=1)  # (batch, seq_length, hidden_dim)
        return outputs, (hidden_state, cell_state)


Error:
The Encoder in the PyTorch code uses the num_layers parameter to build a multi-layer LSTM, while the JAX code only creates a single-layer LSTMCell


Fix Guide:
Add num_layers parameter to Encoder and construct a LSTMCell list in setup()
Update each layer in turn for each time step in __call__


Correct Code:
class Encoder(nn.Module):
    input_dim: int
    embed_dim: int
    hidden_dim: int
    num_layers: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.input_dim, features=self.embed_dim)
        self.lstm_cells = [nn.LSTMCell(features=self.hidden_dim) for _ in range(self.num_layers)]

    def __call__(self, x):
        # x: (batch, seq_length)
        embedded = self.embedding(x)  # (batch, seq_length, embed_dim)
        batch, seq_length, _ = embedded.shape

        hidden_states = [jnp.zeros((batch, self.hidden_dim)) for _ in range(self.num_layers)]
        cell_states = [jnp.zeros((batch, self.hidden_dim)) for _ in range(self.num_layers)]
        outputs = []
        for t in range(seq_length):
            x_t = embedded[:, t, :]
            for i, cell in enumerate(self.lstm_cells):
                (cell_states[i], hidden_states[i]), x_t = cell((cell_states[i], hidden_states[i]), x_t)
            outputs.append(x_t)
        outputs = jnp.stack(outputs, axis=1)  # (batch, seq_length, hidden_dim)

        hidden_states = jnp.stack(hidden_states, axis=0)
        cell_states = jnp.stack(cell_states, axis=0)
        return outputs, (hidden_states, cell_states)
"""


"""
Error Code:
class Decoder(nn.Module):
    output_dim: int
    embed_dim: int
    hidden_dim: int
    num_layers: int
    src_seq_length: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.output_dim, features=self.embed_dim)
        self.attention = nn.Dense(self.src_seq_length)
        self.attention_combine = nn.Dense(self.embed_dim)
        self.lstm = nn.LSTMCell(features=self.hidden_dim)
        self.fc_out = nn.Dense(self.output_dim)

    def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):
        # decoder_input is a token index, the shape is (batch,) or (batch, 1)
        embedded = self.embedding(decoder_input)  # Output shape: (batch, embed_dim) or (batch, 1, embed_dim)
        if embedded.ndim == 3:
            embedded = embedded.squeeze(1)
        # Compute the attention scores
        # Concatenate the current embedding and the previous hidden state (assuming the shape of hidden_state is (batch, hidden_dim))
        concat_input = jnp.concatenate([embedded, hidden_state], axis=-1) # Shape (batch, embed_dim + hidden_dim)
        attention_scores = self.attention(concat_input) # Output shape (batch, src_seq_length)
        attention_weights = jax.nn.softmax(attention_scores, axis=-1)
        context_vector = jnp.einsum('bs,bsh->bh', attention_weights, encoder_outputs) # Get (batch, hidden_dim)

        # Fusion current embedding and context vector
        combined = jnp.concatenate([embedded, context_vector], axis=-1) # (batch, embed_dim + hidden_dim)
        combined = jax.nn.tanh(self.attention_combine(combined))
        
        carry, lstm_output = self.lstm((cell_state, hidden_state), combined)
        new_cell_state, new_hidden_state = carry
        output = self.fc_out(lstm_output)

        return output, new_hidden_state, new_cell_state

    def update_hidden_state(self, hidden_state, context_vector):
        # Dummy update function for hidden state
        return hidden_state + context_vector  # Replace with actual update logic


Error:
Only a single-layer LSTMCell is created in the Decoder, while the Decoder in the PyTorch code uses multiple layers of LSTM


Fix Guide:
Modify setup() to use a list to generate multiple LSTMCells, and update each layer in turn in __call__


Correct Code:
class Decoder(nn.Module):
    output_dim: int
    embed_dim: int
    hidden_dim: int
    num_layers: int
    src_seq_length: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.output_dim, features=self.embed_dim)
        self.attention = nn.Dense(self.src_seq_length)
        self.attention_combine = nn.Dense(self.embed_dim)
        self.lstm_cells = [nn.LSTMCell(features=self.hidden_dim) for _ in range(self.num_layers)]
        self.fc_out = nn.Dense(self.output_dim)

    def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):
        # decoder_input: (batch,) 或 (batch, 1)
        embedded = self.embedding(decoder_input)  # (batch, embed_dim) 或 (batch, 1, embed_dim)
        if embedded.ndim == 3:
            embedded = embedded.squeeze(1)  # (batch, embed_dim)

        concat_input = jnp.concatenate([embedded, hidden_state[-1]], axis=-1)  # (batch, embed_dim + hidden_dim)
        attention_scores = self.attention(concat_input)  # (batch, src_seq_length)
        attention_weights = jax.nn.softmax(attention_scores, axis=-1)
        context_vector = jnp.einsum('bs,bsh->bh', attention_weights, encoder_outputs)  # (batch, hidden_dim)

        combined = jnp.concatenate([embedded, context_vector], axis=-1)  # (batch, embed_dim + hidden_dim)
        combined = jax.nn.tanh(self.attention_combine(combined))  # (batch, embed_dim)
        
        new_hidden_states = []
        new_cell_states = []
        x = combined

        for i, cell in enumerate(self.lstm_cells):
            (new_cell, new_hidden), x = cell((cell_state[i], hidden_state[i]), x)
            new_hidden_states.append(new_hidden)
            new_cell_states.append(new_cell)
        new_hidden_states = jnp.stack(new_hidden_states, axis=0)  # (num_layers, batch, hidden_dim)
        new_cell_states = jnp.stack(new_cell_states, axis=0)      # (num_layers, batch, hidden_dim)
        output = self.fc_out(x)  # (batch, output_dim)
        return output, new_hidden_states, new_cell_states
"""


"""
Error Code:
def main():
    # Example parameters
    vocab_size = 10000
    embed_dim = 32
    hidden_dim = 256
    num_layers = 1 
    src_seq_length = 10

    # Initialize decoder and states
    decoder = Decoder(output_dim=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim,
                      num_layers=num_layers, src_seq_length=src_seq_length)

    hidden_state = jnp.zeros((1, hidden_dim))
    cell_state = jnp.zeros((1, hidden_dim))
    decoder_input = jnp.array([0])
    encoder_outputs = jnp.zeros((1, src_seq_length, hidden_dim))
    
    rng = jax.random.PRNGKey(0)
    output_sequence = []

    # Decoding process
    tgt_seq_length = 12
    variables = decoder.init(rng, decoder_input, encoder_outputs, hidden_state, cell_state)
    
    @jax.jit
    def decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs):
        output, new_hidden_state, new_cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
        predicted = jnp.argmax(output, axis=1)
        return predicted, new_hidden_state, new_cell_state

    # Decoding process
    for _ in range(tgt_seq_length):
        predicted, hidden_state, cell_state = decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs)
        output_sequence.append(int(predicted.item()))
        decoder_input = predicted

    print(f"Input: {jnp.zeros((1, vocab_size)).tolist()}, Output: {output_sequence}")  # Placeholder for input


Error:
The required parameters are missing, the Encoder is not called, and a randomly generated test_input is not used to get the encoder_outputs and status through the Encoder and then pass them to the Decoder


Fix Guide:
Add the corresponding parameters in pytorch and call Encoder. In the test phase, a test input should be generated first, encoder_outputs and initial state should be obtained through Encoder, and then Decoder should be called for decoding, and finally the actual input and output should be printed


Correct Code:
def main():
    # Example parameters
    src_vocab_size = 20
    tgt_vocab_size = 20
    src_seq_length = 10
    tgt_seq_length = 12
    batch_size = 1 
    embed_dim = 32
    hidden_dim = 64
    num_layers = 2

    rng = jax.random.PRNGKey(42)
    
    encoder = Encoder(input_dim=src_vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim, num_layers=num_layers)
    decoder = Decoder(output_dim=tgt_vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim,
                      num_layers=num_layers, src_seq_length=src_seq_length)
    
    test_input = jax.random.randint(rng, (1, src_seq_length), 0, src_vocab_size)
    encoder_variables = encoder.init(rng, test_input)
    encoder_outputs, (enc_hidden, enc_cell) = encoder.apply(encoder_variables, test_input)
    
    hidden_state = jnp.zeros((num_layers, 1, hidden_dim))
    cell_state = jnp.zeros((num_layers, 1, hidden_dim))
    
    decoder_input = jnp.array([0])  
    decoder_variables = decoder.init(rng, decoder_input, encoder_outputs, hidden_state, cell_state)
    
    output_sequence = []
    
    @jax.jit
    def decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs):
        output, new_hidden_state, new_cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
        predicted = jnp.argmax(output, axis=-1)
        return predicted, new_hidden_state, new_cell_state
    
    for _ in range(tgt_seq_length):
        predicted, hidden_state, cell_state = decode_step(decoder_input, hidden_state, cell_state, decoder_variables, encoder_outputs)
        output_sequence.append(int(predicted.item()))
        decoder_input = predicted
    
    print(f"Input: {test_input.tolist()}, Output: {output_sequence}")
"""


"""
Error Code:
hidden_state = jnp.zeros((num_layers, 1, hidden_dim))
cell_state = jnp.zeros((num_layers, 1, hidden_dim))


Error:
The PyTorch code directly uses the encoder output as the initial state of the decoder to pass context information.


Fix Guide:
Modified to use enc_hidden and enc_cell returned by the encoder as the decoder initial state


Correct Code:
hidden_state, cell_state = enc_hidden, enc_cell
"""


"""
Error Code:
# In main(), only the inference decoding process is implemented, without the training loop


Error:
Compared to the PyTorch code, the JAX code lacks the implementation of the training loop, loss function calculation, optimizer updates, and teacher forcing


Fix Guide:
Add training data generation, cross entropy loss function, optax-based Adam optimizer, and a training loop with teacher forcing at each time step


Correct Code:
batch_size = 16

src_data = jax.random.randint(rng, (batch_size, src_seq_length), 0, src_vocab_size)
tgt_data = jax.random.randint(rng, (batch_size, tgt_seq_length), 0, tgt_vocab_size)

tx = optax.adam(0.001)
state = train_state.TrainState.create(apply_fn=None, params=params, tx=tx)


def loss_fn(params, encoder, decoder, src, tgt):
    encoder_outputs, (enc_hidden, enc_cell) = encoder.apply({'params': params['encoder']}, src)
    loss = 0.0
    hidden_state, cell_state = enc_hidden, enc_cell
    decoder_input = jnp.zeros((src.shape[0],), dtype=jnp.int32)
    for t in range(tgt.shape[1]):
        logits, hidden_state, cell_state = decoder.apply({'params': params['decoder']}, decoder_input, encoder_outputs, hidden_state, cell_state)
        loss += jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, tgt[:, t]))
        decoder_input = tgt[:, t]
    return loss


def create_train_state(rng, encoder, decoder, src_vocab_size, tgt_vocab_size, src_seq_length):
    encoder_variables = encoder.init(rng, jnp.ones((1, src_seq_length), jnp.int32))
    decoder_variables = decoder.init(
        rng,
        jnp.ones((1,), jnp.int32),
        jnp.ones((1, src_seq_length, encoder.hidden_dim)),
        jnp.ones((encoder.num_layers, 1, encoder.hidden_dim)),
        jnp.ones((encoder.num_layers, 1, encoder.hidden_dim))
    )
    params = {
        'encoder': encoder_variables['params'],
        'decoder': decoder_variables['params']
    }
    tx = optax.adam(0.001)
    return train_state.TrainState.create(apply_fn=None, params=params, tx=tx)
"""


"""
Error Code:
decoder_input = predicted


Error:
The JAX code does not use teacher forcing


Fix Guide:
In the training loop, the target token of the current time step should be used as the input of the next decoder step


Correct Code:
for t in range(tgt.shape[1]):
    logits, hidden_state, cell_state = decoder.apply({'params': params['decoder']}, decoder_input, encoder_outputs, hidden_state, cell_state)
    loss += jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, tgt[:, t]))
    decoder_input = tgt[:, t]
"""


"""
Error Code:
# No corresponding training information


Error:
Lack of training information, updated gradients, loss and log printing


Fix Guide:
Add corresponding training information, update gradient, loss and log printing


Correct Code:

@jax.jit
def train_step(state, encoder, decoder, src, tgt):
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, encoder, decoder, src, tgt)
    state = state.apply_gradients(grads=grads)
    return state, loss

epochs = 100
for epoch in range(epochs):
    state, loss = train_step(state, encoder, decoder, src_data, tgt_data)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {loss:.4f}")
"""


"""
Error Code:
@jax.jit
def train_step(state, encoder, decoder, src, tgt):
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, encoder, decoder, src, tgt)
    state = state.apply_gradients(grads=grads)
    return state, loss
    

Error:
Cannot interpret value of type <class 'main.Encoder'> as an abstract array; it does not have a dtype attribute


Fix Guide:
Setting the static_argnums parameter in the jax.jit decorator


Correct Code:
@jax.jit(static_argnums=(1, 2))
def train_step(state, encoder, decoder, src, tgt):
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, encoder, decoder, src, tgt)
    state = state.apply_gradients(grads=grads)
    return state, loss
"""


"""
Error Code:
@jax.jit(static_argnums=(1, 2))
def train_step(state, encoder, decoder, src, tgt):


Error:
jit() missing 1 required positional argument: 'fun'


Fix Guide:
Use Python's built-in partial to fix static parameters and then use it as a decorator


Correct Code:
from functools import partial

@partial(jax.jit, static_argnums=(1, 2))
def train_step(state, encoder, decoder, src, tgt):
"""

In [10]:
#Fixed Code
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import numpy as np
from functools import partial


class Encoder(nn.Module):
    input_dim: int
    embed_dim: int
    hidden_dim: int
    num_layers: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.input_dim, features=self.embed_dim)
        self.lstm_cells = [nn.LSTMCell(features=self.hidden_dim) for _ in range(self.num_layers)]

    def __call__(self, x):
        # x: (batch, seq_length)
        embedded = self.embedding(x)  # (batch, seq_length, embed_dim)
        batch, seq_length, _ = embedded.shape

        hidden_states = [jnp.zeros((batch, self.hidden_dim)) for _ in range(self.num_layers)]
        cell_states = [jnp.zeros((batch, self.hidden_dim)) for _ in range(self.num_layers)]
        outputs = []
        for t in range(seq_length):
            x_t = embedded[:, t, :]
            for i, cell in enumerate(self.lstm_cells):
                (cell_states[i], hidden_states[i]), x_t = cell((cell_states[i], hidden_states[i]), x_t)
            outputs.append(x_t)
        outputs = jnp.stack(outputs, axis=1)  # (batch, seq_length, hidden_dim)

        hidden_states = jnp.stack(hidden_states, axis=0)
        cell_states = jnp.stack(cell_states, axis=0)
        return outputs, (hidden_states, cell_states)

    
class Decoder(nn.Module):
    output_dim: int
    embed_dim: int
    hidden_dim: int
    num_layers: int
    src_seq_length: int

    def setup(self):
        self.embedding = nn.Embed(num_embeddings=self.output_dim, features=self.embed_dim)
        self.attention = nn.Dense(self.src_seq_length)
        self.attention_combine = nn.Dense(self.embed_dim)
        self.lstm_cells = [nn.LSTMCell(features=self.hidden_dim) for _ in range(self.num_layers)]
        self.fc_out = nn.Dense(self.output_dim)

    def __call__(self, decoder_input, encoder_outputs, hidden_state, cell_state):
        # decoder_input: (batch,) 或 (batch, 1)
        embedded = self.embedding(decoder_input)  # (batch, embed_dim) 或 (batch, 1, embed_dim)
        if embedded.ndim == 3:
            embedded = embedded.squeeze(1)  # (batch, embed_dim)

        concat_input = jnp.concatenate([embedded, hidden_state[-1]], axis=-1)  # (batch, embed_dim + hidden_dim)
        attention_scores = self.attention(concat_input)  # (batch, src_seq_length)
        attention_weights = jax.nn.softmax(attention_scores, axis=-1)
        context_vector = jnp.einsum('bs,bsh->bh', attention_weights, encoder_outputs)  # (batch, hidden_dim)

        combined = jnp.concatenate([embedded, context_vector], axis=-1)  # (batch, embed_dim + hidden_dim)
        combined = jax.nn.tanh(self.attention_combine(combined))  # (batch, embed_dim)
        
        new_hidden_states = []
        new_cell_states = []
        x = combined

        for i, cell in enumerate(self.lstm_cells):
            (new_cell, new_hidden), x = cell((cell_state[i], hidden_state[i]), x)
            new_hidden_states.append(new_hidden)
            new_cell_states.append(new_cell)
        new_hidden_states = jnp.stack(new_hidden_states, axis=0)  # (num_layers, batch, hidden_dim)
        new_cell_states = jnp.stack(new_cell_states, axis=0)      # (num_layers, batch, hidden_dim)
        output = self.fc_out(x)  # (batch, output_dim)
        return output, new_hidden_states, new_cell_states


def loss_fn(params, encoder, decoder, src, tgt):
    encoder_outputs, (enc_hidden, enc_cell) = encoder.apply({'params': params['encoder']}, src)
    loss = 0.0
    batch_size = src.shape[0]
    hidden_state, cell_state = enc_hidden, enc_cell

    decoder_input = jnp.zeros((batch_size,), dtype=jnp.int32)
    tgt_seq_length = tgt.shape[1]
    for t in range(tgt_seq_length):
        logits, hidden_state, cell_state = decoder.apply({'params': params['decoder']},
                                                           decoder_input,
                                                           encoder_outputs,
                                                           hidden_state,
                                                           cell_state)
        loss += jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, tgt[:, t]))

        decoder_input = tgt[:, t]
    return loss
    
    
def create_train_state(rng, encoder, decoder, src_vocab_size, tgt_vocab_size, src_seq_length):
    encoder_variables = encoder.init(rng, jnp.ones((1, src_seq_length), jnp.int32))
    decoder_variables = decoder.init(
        rng,
        jnp.ones((1,), jnp.int32),
        jnp.ones((1, src_seq_length, encoder.hidden_dim)),
        jnp.ones((encoder.num_layers, 1, encoder.hidden_dim)),
        jnp.ones((encoder.num_layers, 1, encoder.hidden_dim))
    )
    params = {
        'encoder': encoder_variables['params'],
        'decoder': decoder_variables['params']
    }
    tx = optax.adam(0.001)
    return train_state.TrainState.create(apply_fn=None, params=params, tx=tx)


@partial(jax.jit, static_argnums=(1, 2))
def train_step(state, encoder, decoder, src, tgt):
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, encoder, decoder, src, tgt)
    state = state.apply_gradients(grads=grads)
    return state, loss


def main():
    # Example parameters
    src_vocab_size = 20
    tgt_vocab_size = 20
    src_seq_length = 10
    tgt_seq_length = 12
    batch_size = 1 
    embed_dim = 32
    hidden_dim = 64
    num_layers = 2

    rng = jax.random.PRNGKey(42)
    
    encoder = Encoder(input_dim=src_vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim, num_layers=num_layers)
    decoder = Decoder(output_dim=tgt_vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim,
                      num_layers=num_layers, src_seq_length=src_seq_length)
    
    src_data = jax.random.randint(rng, (batch_size, src_seq_length), 0, src_vocab_size)
    tgt_data = jax.random.randint(rng, (batch_size, tgt_seq_length), 0, tgt_vocab_size)
    
    state = create_train_state(rng, encoder, decoder, src_vocab_size, tgt_vocab_size, src_seq_length)
    
    epochs = 100
    for epoch in range(epochs):
        state, loss = train_step(state, encoder, decoder, src_data, tgt_data)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {loss:.4f}")
    
    test_input = jax.random.randint(rng, (1, src_seq_length), 0, src_vocab_size)
    encoder_variables = encoder.init(rng, test_input)
    encoder_outputs, (enc_hidden, enc_cell) = encoder.apply(encoder_variables, test_input)
    
    hidden_state = jnp.zeros((num_layers, 1, hidden_dim))
    cell_state = jnp.zeros((num_layers, 1, hidden_dim))
    
    decoder_input = jnp.array([0])  
    decoder_variables = decoder.init(rng, decoder_input, encoder_outputs, hidden_state, cell_state)
    
    output_sequence = []
    
    @jax.jit
    def decode_step(decoder_input, hidden_state, cell_state, variables, encoder_outputs):
        output, new_hidden_state, new_cell_state = decoder.apply(variables, decoder_input, encoder_outputs, hidden_state, cell_state)
        predicted = jnp.argmax(output, axis=-1)
        return predicted, new_hidden_state, new_cell_state
    
    for _ in range(tgt_seq_length):
        predicted, hidden_state, cell_state = decode_step(decoder_input, hidden_state, cell_state, decoder_variables, encoder_outputs)
        output_sequence.append(int(predicted.item()))
        decoder_input = predicted
    
    print(f"Input: {test_input.tolist()}, Output: {output_sequence}")
    
    
if __name__ == "__main__":
    main()

Epoch [10/100] - Loss: 34.3292
Epoch [20/100] - Loss: 27.7096
Epoch [30/100] - Loss: 19.8640
Epoch [40/100] - Loss: 14.5860
Epoch [50/100] - Loss: 11.3062
Epoch [60/100] - Loss: 9.1747
Epoch [70/100] - Loss: 7.5736
Epoch [80/100] - Loss: 6.2618
Epoch [90/100] - Loss: 5.1877
Epoch [100/100] - Loss: 4.3040
Input: [[15, 12, 3, 7, 15, 15, 9, 16, 12, 9]], Output: [4, 16, 16, 12, 12, 12, 4, 5, 4, 16, 16, 16]
