# In [None]:
import tensorflow as tf

# ─── 1. Layer‐builders ───────────────────────────────────────────────

def build_sLSTM_layers(units):
    """Create the four Dense projections used by one sLSTM block.

    Args:
        units: Output width of every projection.

    Returns:
        Dict with keys 'input_gate', 'forget_gate', 'output_gate' and
        'input_transf' (in that insertion order), each mapped to a fresh
        ``tf.keras.layers.Dense(units)``.
    """
    layer_names = ('input_gate', 'forget_gate', 'output_gate', 'input_transf')
    # Layers are constructed in the listed order, matching the keys.
    return {name: tf.keras.layers.Dense(units) for name in layer_names}

def build_mLSTM_layers(units):
    """Create the six Dense projections used by one mLSTM block.

    Args:
        units: Output width of every projection.

    Returns:
        Dict with keys 'key_transf', 'value_transf', 'query_transf',
        'input_gate', 'forget_gate' and 'output_gate' (in that insertion
        order), each mapped to a fresh ``tf.keras.layers.Dense(units)``.
    """
    projections = ('key_transf', 'value_transf', 'query_transf')
    gates = ('input_gate', 'forget_gate', 'output_gate')
    # Projections first, then gates: same construction order as the keys.
    return {name: tf.keras.layers.Dense(units) for name in projections + gates}

# ─── 2. Step‐functions ──────────────────────────────────────────────

def sLSTM_step(x, h_prev, c_prev, layers):
    """Advance one vector-memory (sLSTM) block by a single time step.

    All gates and the candidate are computed from ``x`` alone; ``h_prev`` is
    accepted for signature parity with ``mLSTM_step`` but is not read.

    Args:
        x: Input for this step.
        h_prev: Previous hidden state (unused).
        c_prev: Previous cell state, broadcast-compatible with the gate
            outputs.
        layers: Dict produced by ``build_sLSTM_layers``.

    Returns:
        Tuple ``(h_t, c_t)`` of the new hidden and cell states.
    """
    # Keep the layer-call order input -> forget -> output -> candidate so
    # lazily-built weights are created in the same sequence.
    gate_in = tf.nn.softplus(layers['input_gate'](x))  # more stable than exp
    gate_forget = tf.sigmoid(layers['forget_gate'](x))
    gate_out = tf.sigmoid(layers['output_gate'](x))
    candidate = tf.tanh(layers['input_transf'](x))

    new_cell = gate_forget * c_prev + gate_in * candidate
    return gate_out * tf.tanh(new_cell), new_cell

def mLSTM_step(x, h_prev, C_prev, layers):
    """Advance one matrix-memory (mLSTM) block by a single time step.

    The memory is updated as ``C_t = f_t * C_prev + i_t * (v k^T)`` and read
    out as ``h_t = o_t * (C_t q)``. The gate layers are ``Dense(units)``, so
    ``i_t`` and ``f_t`` have shape ``(batch, d)``. Entry ``j`` of each gate
    scales row ``j`` (the value dimension) of the ``(batch, d, d)`` memory.
    No normalizer state or key scaling is applied.

    Args:
        x: 2-D input for this step, presumably ``(batch, features)``.
        h_prev: Previous hidden state; accepted for signature parity with
            ``sLSTM_step`` but not read.
        C_prev: Previous matrix memory, ``(batch, d, d)``, where ``d`` is the
            width of the layers in ``layers``.
        layers: Dict produced by ``build_mLSTM_layers``.

    Returns:
        Tuple ``(h_t, C_t)`` with shapes ``(batch, d)`` and ``(batch, d, d)``.
    """
    k = layers['key_transf'](x)
    v = layers['value_transf'](x)
    q = layers['query_transf'](x)

    i_t = tf.nn.softplus(layers['input_gate'](x))  # more stable than exp
    f_t = tf.sigmoid(layers['forget_gate'](x))
    o_t = tf.sigmoid(layers['output_gate'](x))

    v_e = tf.expand_dims(v, 2)     # (batch, d, 1)
    k_e = tf.expand_dims(k, 1)     # (batch, 1, d)
    delta_C = v_e @ k_e            # (batch, d, d) outer product v k^T

    # Gates are (batch, d). Lift them to (batch, d, 1) so they broadcast along
    # the key axis. Indexing with [:, None, None] would give (batch, 1, 1, d),
    # which broadcasts against (batch, d, d) into a rank-4 (batch, batch, d, d)
    # tensor and breaks the squeeze below.
    f_e = tf.expand_dims(f_t, 2)   # (batch, d, 1)
    i_e = tf.expand_dims(i_t, 2)   # (batch, d, 1)
    C_t = f_e * C_prev + i_e * delta_C

    q_e = tf.expand_dims(q, 2)     # (batch, d, 1)
    h_t = tf.squeeze(C_t @ q_e, 2) # (batch, d)
    h_t = o_t * h_t

    return h_t, C_t

# ─── 3. Model as a tf.keras.Model subclass ──────────────────────────

class xLSTM(tf.keras.Model):
    """Stack of sLSTM / mLSTM blocks followed by a single-unit Dense head.

    At each time step the input slice is fed through the blocks bottom to
    top, with each block's hidden state becoming the next block's input.
    After the last step, the top block's hidden state goes through
    ``output_layer``.

    Args:
        hidden_dim: Hidden-state width of every block. mLSTM blocks keep a
            ``hidden_dim x hidden_dim`` matrix memory.
        block_types: Non-empty sequence of ``'sLSTM'`` / ``'mLSTM'`` strings,
            ordered bottom to top; one block is built per entry.

    Raises:
        ValueError: If ``block_types`` is empty or contains an entry that is
            neither ``'sLSTM'`` nor ``'mLSTM'``.
    """

    def __init__(self, hidden_dim, block_types):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.block_types = block_types
        # List of (kind, layers) pairs; kind is 's' or 'm'.
        self.blocks = []

        for bt in block_types:
            if bt == 'sLSTM':
                self.blocks.append(('s', build_sLSTM_layers(hidden_dim)))
            elif bt == 'mLSTM':
                self.blocks.append(('m', build_mLSTM_layers(hidden_dim)))
            else:
                # Fail loudly instead of silently treating typos as mLSTM.
                raise ValueError(
                    f"Unknown block type {bt!r}; expected 'sLSTM' or 'mLSTM'")

        if not self.blocks:
            # call() reads the top block's state, so at least one is needed.
            raise ValueError("block_types must contain at least one block")

        self.output_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        """Run the block stack over a sequence and regress from the last step.

        Args:
            inputs: Tensor sliced as ``inputs[:, t, :]``, presumably
                ``(batch, seq_len, features)``; TODO confirm with callers.
            training: Part of the Keras ``call`` signature; unused.

        Returns:
            ``(batch, 1)`` tensor: ``output_layer`` applied to the top block's
            hidden state after the final time step. For an empty sequence
            this is the zero initial state.
        """
        batch = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]

        # Zero initial states: vector cells for sLSTM, matrix memory for mLSTM.
        h = [tf.zeros((batch, self.hidden_dim)) for _ in self.blocks]
        c = [
            tf.zeros((batch, self.hidden_dim)) if kind == 's'
            else tf.zeros((batch, self.hidden_dim, self.hidden_dim))
            for kind, _ in self.blocks
        ]

        for t in tf.range(seq_len):
            x = inputs[:, t, :]
            for i, (kind, layers) in enumerate(self.blocks):
                if kind == 's':
                    h[i], c[i] = sLSTM_step(x, h[i], c[i], layers)
                else:
                    h[i], c[i] = mLSTM_step(x, h[i], c[i], layers)
                x = h[i]  # feed to next block

        # Only the final top-layer state is used. Reading h[-1] directly
        # avoids buffering every step and works when seq_len == 0.
        return self.output_layer(h[-1])      # (batch, 1)
