In [5]:
import tensorflow as tf
import numpy as np
import random

# Corrected xLSTM Implementation
class xLSTMCell(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = [units, units, units, units]  # h, c, n, m
        self.output_size = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # Proper weight initialization from paper
        self.kernel = self.add_weight(
            shape=(input_dim, 4 * self.units),
            initializer='glorot_uniform',
            name='kernel'
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, 4 * self.units),  # Fixed size to match all gates
            initializer='orthogonal',
            name='recurrent_kernel'
        )
        self.bias = self.add_weight(
            shape=(4 * self.units,),
            initializer='zeros',
            name='bias'
        )

    def call(self, inputs, states):
        h_prev, c_prev, n_prev, m_prev = states

        # Compute all gates together
        gates = tf.matmul(inputs, self.kernel) + tf.matmul(h_prev, self.recurrent_kernel) + self.bias

        # Split into components following paper equations
        z, i, f, o = tf.split(gates, num_or_size_splits=4, axis=1)

        # Exponential gating with stabilization
        i_t = tf.exp(i)  # Input gate (exp activation)
        f_t = tf.exp(f)  # Forget gate (exp activation)

        # Stabilization mechanism (Equation 15-17)
        m_t = tf.maximum(tf.math.log(f_t) + m_prev, tf.math.log(i_t))
        i_prime = tf.exp(tf.math.log(i_t) - m_t)
        f_prime = tf.exp(tf.math.log(f_t) + m_prev - m_t)

        # State updates (Equations 8-10)
        n_t = f_prime * n_prev + i_prime
        c_t = f_prime * c_prev + i_prime * tf.tanh(z)
        h_t = tf.sigmoid(o) * (c_t / (n_t + 1e-8))  # Add epsilon for numerical stability

        return h_t, [h_t, c_t, n_t, m_t]



In [6]:
# Enhanced model builder
def build_model(cell_type, units=32, input_dim=3):
    inputs = tf.keras.Input(shape=(None, input_dim))

    if cell_type == 'xLSTM':
        cell = xLSTMCell(units)
        x = tf.keras.layers.RNN(cell, return_sequences=False)(inputs)
    elif cell_type == 'LSTM':
        x = tf.keras.layers.LSTM(units, return_sequences=False)(inputs)
    else:
        x = tf.keras.layers.SimpleRNN(units, return_sequences=False)(inputs)

    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    return tf.keras.Model(inputs, outputs)

# Enhanced training with proper regularization
def train_and_evaluate(model, x_train, y_train, x_test, y_test, name):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipvalue=1.0),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    # Add early stopping
    es = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True
    )

    history = model.fit(
        x_train, y_train,
        validation_data=(x_test, y_test),
        epochs=20,
        batch_size=32,
        callbacks=[es],
        verbose=0
    )

    # Get best validation accuracy
    best_epoch = np.argmin(history.history['val_loss'])
    best_acc = history.history['val_accuracy'][best_epoch]
    print(f"{name} Best Validation Accuracy: {best_acc:.4f}")
    return best_acc


In [7]:

if __name__ == "__main__":
    # Generate enhanced dataset
    def generate_dyck(n_samples, max_len=20):
        # Improved dataset generation with more challenging patterns
        data = []
        labels = []
        for _ in range(n_samples):
            depth = 0
            s = []
            valid = True
            for _ in range(max_len):
                if random.random() < 0.5:
                    s.append('(')
                    depth += 1
                else:
                    s.append(')')
                    depth -= 1
                    if depth < 0:
                        valid = False
            label = 1 if valid and depth == 0 else 0
            data.append(''.join(s))
            labels.append(label)
        return data, labels

    # Vectorizer with masking
    def vectorize_data(sequences, vocab={'(': 0, ')': 1}, max_len=20):
        vec = np.zeros((len(sequences), max_len, len(vocab)), dtype=np.float32)
        for i, seq in enumerate(sequences):
            for t, char in enumerate(seq[:max_len]):
                if char in vocab:
                    vec[i, t, vocab[char]] = 1.0
        return vec

    # Create enhanced dataset
    train_seqs, train_labels = generate_dyck(5000, 20)
    test_seqs, test_labels = generate_dyck(1000, 20)
    x_train = vectorize_data(train_seqs)
    x_test = vectorize_data(test_seqs)
    y_train = np.array(train_labels)
    y_test = np.array(test_labels)

    # Compare models
    results = {}
    for model_type in ['xLSTM', 'LSTM', 'RNN']:
        model = build_model(model_type, units=32, input_dim=2)  # Reduced input dim since we removed padding token
        acc = train_and_evaluate(model, x_train, y_train, x_test, y_test, model_type)
        results[model_type] = acc

    print("\nFinal Comparison:")
    for model, acc in results.items():
        print(f"{model}: {acc:.4f}")

xLSTM Best Validation Accuracy: 0.9940
LSTM Best Validation Accuracy: 0.9910
RNN Best Validation Accuracy: 0.9840

Final Comparison:
xLSTM: 0.9940
LSTM: 0.9910
RNN: 0.9840
