# LSTM from Scratch

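This notebook implements an LSTM (Long Short-Term Memory) cell from scratch in NumPy. LSTMs mitigate the vanishing-gradient problem that plagues vanilla RNNs: gates control what is written to, kept in, and read from a separate cell state, which gives gradients a path that is not repeatedly squashed by a nonlinearity.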

**Goal:** Understand how LSTM gates enable learning of long-range dependencies.

**Prerequisites:** [lstms.md](../architectures/lstms.md), [05-rnn-from-scratch.ipynb](05-rnn-from-scratch.ipynb)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

## 1. The LSTM Cell

The LSTM has three gates controlling information flow:

**Forget gate** - what to discard from cell state:
$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

**Input gate** - what new information to store:
$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

**Cell state update** - the "memory":
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

**Output gate** - what to output:
$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
$$h_t = o_t \odot \tanh(C_t)$$
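
The bracket $[h_{t-1}, x_t]$ denotes concatenation. The quick numerical check below is a sketch with arbitrary toy sizes. It uses the row-vector convention `concat @ W` that the code below adopts, and shows that one matrix applied to the concatenation equals a recurrent block applied to $h_{t-1}$ plus an input block applied to $x_t$:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, input_dim = 4, 3  # toy sizes, chosen only for illustration

h_prev = rng.standard_normal(hidden_dim)
x = rng.standard_normal(input_dim)

# One weight matrix acting on the concatenation [h_{t-1}, x_t]
W_f = rng.standard_normal((hidden_dim + input_dim, hidden_dim))
concat_form = np.concatenate([h_prev, x]) @ W_f

# The same matrix split row-wise into a recurrent block and an input block
W_fh, W_fx = W_f[:hidden_dim], W_f[hidden_dim:]
split_form = h_prev @ W_fh + x @ W_fx

print(np.allclose(concat_form, split_form))  # True
```

This equivalence is why the implementation can store a single `W_f` of shape `[input_dim + hidden_dim, hidden_dim]` per gate instead of two separate matrices.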

In [None]:
def sigmoid(x):
    """Numerically stable sigmoid: exp is only applied to -|x| <= 0, so it never overflows."""
    z = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))

def sigmoid_derivative(s):
    """Derivative of sigmoid (given sigmoid output)."""
    return s * (1 - s)

def tanh_derivative(t):
    """Derivative of tanh (given tanh output)."""
    return 1 - t**2
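
The `*_derivative` helpers take the activation's *output*, not its input, because that is what the forward pass caches. A minimal finite-difference sketch confirms $\sigma'(x) = s(1-s)$ and $\tanh'(x) = 1 - t^2$ with $s = \sigma(x)$, $t = \tanh(x)$. The helpers are redefined so the snippet runs on its own, and the sigmoid uses the overflow-free `exp(-|x|)` form.

```python
import numpy as np

# Redefined here so the snippet runs on its own
def sigmoid(x):
    z = np.exp(-np.abs(x))  # exp of a non-positive number never overflows
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))

def sigmoid_derivative(s):
    return s * (1 - s)

def tanh_derivative(t):
    return 1 - t**2

x = np.linspace(-5, 5, 11)
eps = 1e-6

# Central differences taken with respect to the *inputs* of the activations
num_sig = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps)
num_tanh = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)

# The helpers are fed the *outputs*, as LSTMCell.backward does
print(np.allclose(num_sig, sigmoid_derivative(sigmoid(x)), atol=1e-8))  # True
print(np.allclose(num_tanh, tanh_derivative(np.tanh(x)), atol=1e-8))    # True
```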

In [None]:
class LSTMCell:
    """
    Single LSTM cell.
    
    Processes one timestep, maintaining both hidden state (h) and cell state (c).
    """
    
    def __init__(self, input_dim, hidden_dim):
        """
        Args:
            input_dim: Size of input vectors
            hidden_dim: Size of hidden/cell state
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        # Combined input dimension: [h_{t-1}, x_t]
        combined_dim = input_dim + hidden_dim
        
        # Scaled Gaussian init with std = 1/sqrt(fan_in) (LeCun/Xavier-style)
        scale = np.sqrt(1.0 / combined_dim)
        
        # Forget gate
        self.W_f = np.random.randn(combined_dim, hidden_dim) * scale
        self.b_f = np.ones(hidden_dim)  # Bias of 1 => f ~ sigmoid(1) ~ 0.73 at init, so the cell leans toward remembering
        
        # Input gate
        self.W_i = np.random.randn(combined_dim, hidden_dim) * scale
        self.b_i = np.zeros(hidden_dim)
        
        # Cell candidate
        self.W_c = np.random.randn(combined_dim, hidden_dim) * scale
        self.b_c = np.zeros(hidden_dim)
        
        # Output gate
        self.W_o = np.random.randn(combined_dim, hidden_dim) * scale
        self.b_o = np.zeros(hidden_dim)
    
    def forward(self, x, h_prev, c_prev):
        """
        Forward pass for single timestep.
        
        Args:
            x: Input [batch, input_dim]
            h_prev: Previous hidden state [batch, hidden_dim]
            c_prev: Previous cell state [batch, hidden_dim]
            
        Returns:
            h: New hidden state
            c: New cell state
        """
        # Store inputs for backprop
        self.x = x
        self.h_prev = h_prev
        self.c_prev = c_prev
        
        # Concatenate input and previous hidden state
        self.concat = np.concatenate([h_prev, x], axis=-1)
        
        # Forget gate
        self.f = sigmoid(self.concat @ self.W_f + self.b_f)
        
        # Input gate
        self.i = sigmoid(self.concat @ self.W_i + self.b_i)
        
        # Cell candidate
        self.c_tilde = np.tanh(self.concat @ self.W_c + self.b_c)
        
        # Output gate
        self.o = sigmoid(self.concat @ self.W_o + self.b_o)
        
        # New cell state: forget old + add new
        self.c = self.f * c_prev + self.i * self.c_tilde
        
        # New hidden state
        self.tanh_c = np.tanh(self.c)
        self.h = self.o * self.tanh_c
        
        return self.h, self.c
    
    def backward(self, dh, dc):
        """
        Backward pass for single timestep.
        
        Args:
            dh: Gradient w.r.t. hidden state
            dc: Gradient w.r.t. cell state (from future)
            
        Returns:
            dx: Gradient w.r.t. input
            dh_prev: Gradient w.r.t. previous hidden state
            dc_prev: Gradient w.r.t. previous cell state
            grads: Dictionary of parameter gradients
        """
        # Gradient from hidden state output
        do = dh * self.tanh_c
        dc_from_h = dh * self.o * tanh_derivative(self.tanh_c)
        
        # Total cell gradient
        dc_total = dc + dc_from_h
        
        # Gradient through cell state update
        df = dc_total * self.c_prev
        di = dc_total * self.c_tilde
        dc_tilde = dc_total * self.i
        dc_prev = dc_total * self.f
        
        # Gradient through the gate nonlinearities (sigmoid for f, i, o; tanh for the candidate)
        df_raw = df * sigmoid_derivative(self.f)
        di_raw = di * sigmoid_derivative(self.i)
        dc_tilde_raw = dc_tilde * tanh_derivative(self.c_tilde)
        do_raw = do * sigmoid_derivative(self.o)
        
        # Parameter gradients. np.atleast_2d turns a single example [dim] into
        # [1, dim], so one formula covers both unbatched and batched inputs.
        concat_2d = np.atleast_2d(self.concat)
        df_2d, di_2d = np.atleast_2d(df_raw), np.atleast_2d(di_raw)
        dc_tilde_2d, do_2d = np.atleast_2d(dc_tilde_raw), np.atleast_2d(do_raw)
        
        dW_f = concat_2d.T @ df_2d
        dW_i = concat_2d.T @ di_2d
        dW_c = concat_2d.T @ dc_tilde_2d
        dW_o = concat_2d.T @ do_2d
        
        db_f = df_2d.sum(axis=0)
        db_i = di_2d.sum(axis=0)
        db_c = dc_tilde_2d.sum(axis=0)
        db_o = do_2d.sum(axis=0)
        
        # Gradient through concatenation
        dconcat = (df_raw @ self.W_f.T + di_raw @ self.W_i.T + 
                   dc_tilde_raw @ self.W_c.T + do_raw @ self.W_o.T)
        
        # Split gradient for h_prev and x
        dh_prev = dconcat[..., :self.hidden_dim]
        dx = dconcat[..., self.hidden_dim:]
        
        grads = {
            'W_f': dW_f, 'b_f': db_f,
            'W_i': dW_i, 'b_i': db_i,
            'W_c': dW_c, 'b_c': db_c,
            'W_o': dW_o, 'b_o': db_o,
        }
        
        return dx, dh_prev, dc_prev, grads
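
The cell-state part of `backward` carries the LSTM's long-range gradient. As a sanity check, here is a self-contained finite-difference sketch that verifies the two lines `dc_total = dc + dh * o * (1 - tanh(c)**2)` and `dc_prev = dc_total * f` in isolation. It uses toy sizes and random gate values, not values from a trained model. A full gradient check of `LSTMCell.backward` against every weight would follow the same pattern.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 5
# Fixed gate activations for one timestep (values in (0, 1), as after a sigmoid)
f, i, o = rng.uniform(0.1, 0.9, size=(3, H))
c_tilde = np.tanh(rng.standard_normal(H))
c_prev = rng.standard_normal(H)

# Upstream gradients: dh w.r.t. h_t, dc w.r.t. c_t arriving from the next timestep
dh = rng.standard_normal(H)
dc = rng.standard_normal(H)

def surrogate_loss(c_prev):
    """Scalar whose gradients w.r.t. h_t and c_t are exactly dh and dc."""
    c = f * c_prev + i * c_tilde
    h = o * np.tanh(c)
    return dh @ h + dc @ c

# Analytic gradient, following the same steps as LSTMCell.backward
c = f * c_prev + i * c_tilde
tanh_c = np.tanh(c)
dc_total = dc + dh * o * (1 - tanh_c**2)
dc_prev_analytic = dc_total * f

# Central finite differences, one coordinate at a time
eps = 1e-6
dc_prev_numeric = np.zeros(H)
for k in range(H):
    e = np.zeros(H)
    e[k] = eps
    dc_prev_numeric[k] = (surrogate_loss(c_prev + e) - surrogate_loss(c_prev - e)) / (2 * eps)

print(np.max(np.abs(dc_prev_analytic - dc_prev_numeric)))  # should be tiny
```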

In [None]:
# Test the LSTM cell
cell = LSTMCell(input_dim=10, hidden_dim=20)

x = np.random.randn(10)  # Single input
h_prev = np.zeros(20)
c_prev = np.zeros(20)

h, c = cell.forward(x, h_prev, c_prev)
print(f"Input shape: {x.shape}")
print(f"Hidden state shape: {h.shape}")
print(f"Cell state shape: {c.shape}")
print(f"\nGate values (should be in [0,1]):")
print(f"  Forget gate: [{cell.f.min():.3f}, {cell.f.max():.3f}]")
print(f"  Input gate:  [{cell.i.min():.3f}, {cell.i.max():.3f}]")
print(f"  Output gate: [{cell.o.min():.3f}, {cell.o.max():.3f}]")

## 2. Character-Level LSTM

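Now let's wrap the cell in a character-level language model: each character is one-hot encoded and fed through the LSTM, and a linear layer followed by a softmax predicts the next character. The same weights are shared across all timesteps.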

In [None]:
class CharLSTM:
    """
    Character-level LSTM for text generation.
    """
    
    def __init__(self, vocab_size, hidden_dim):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        
        # LSTM cell
        self.cell = LSTMCell(vocab_size, hidden_dim)
        
        # Output layer
        scale = np.sqrt(1.0 / hidden_dim)
        self.W_hy = np.random.randn(hidden_dim, vocab_size) * scale
        self.b_y = np.zeros(vocab_size)
    
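    def forward(self, inputs, h_init=None, c_init=None):
        """
        Forward pass through a sequence.
        
        Args:
            inputs: List of character indices
            h_init, c_init: Initial hidden/cell state (zeros if None)
            
        Returns:
            ps, hs, cs: Dicts keyed by timestep holding the softmax
                probabilities, hidden states and cell states
                (key -1 holds the initial states)
        """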
        seq_len = len(inputs)
        
        if h_init is None:
            h_init = np.zeros(self.hidden_dim)
        if c_init is None:
            c_init = np.zeros(self.hidden_dim)
        
        # Storage
        self.inputs = inputs
        self.xs = {}
        self.hs = {-1: h_init}
        self.cs = {-1: c_init}
        self.os = {}
        self.ps = {}
        
        # Store cell states for backprop
        self.cells = {}
        
        for t in range(seq_len):
            # One-hot encode
            x = np.zeros(self.vocab_size)
            x[inputs[t]] = 1
            self.xs[t] = x
            
            # Per-timestep cell object that caches this step's activations for backprop.
            # LSTMCell.__new__ skips __init__ (no wasted random re-initialization);
            # copying __dict__ makes the weight attributes reference the SAME arrays
            # as self.cell, so the in-place updates in update() apply to every step.
            self.cells[t] = LSTMCell.__new__(LSTMCell)
            self.cells[t].__dict__.update(self.cell.__dict__)
            
            # LSTM step
            self.hs[t], self.cs[t] = self.cells[t].forward(x, self.hs[t-1], self.cs[t-1])
            
            # Output
            self.os[t] = self.hs[t] @ self.W_hy + self.b_y
            
            # Softmax
            exp_o = np.exp(self.os[t] - self.os[t].max())
            self.ps[t] = exp_o / exp_o.sum()
        
        return self.ps, self.hs, self.cs
    
    def loss(self, targets):
        """Mean cross-entropy per character, using the probabilities from the last forward() call."""
        loss = 0
        for t in range(len(targets)):
            loss -= np.log(self.ps[t][targets[t]] + 1e-10)
        return loss / len(targets)
    
    def backward(self, targets):
        """Backpropagation through time. Returns gradients of the summed (not mean) loss, clipped elementwise to [-5, 5]."""
        seq_len = len(targets)
        
        # Initialize gradients
        dW_f = np.zeros_like(self.cell.W_f)
        dW_i = np.zeros_like(self.cell.W_i)
        dW_c = np.zeros_like(self.cell.W_c)
        dW_o = np.zeros_like(self.cell.W_o)
        db_f = np.zeros_like(self.cell.b_f)
        db_i = np.zeros_like(self.cell.b_i)
        db_c = np.zeros_like(self.cell.b_c)
        db_o = np.zeros_like(self.cell.b_o)
        dW_hy = np.zeros_like(self.W_hy)
        db_y = np.zeros_like(self.b_y)
        
        # Gradients flowing back
        dh_next = np.zeros(self.hidden_dim)
        dc_next = np.zeros(self.hidden_dim)
        
        for t in reversed(range(seq_len)):
            # Output gradient
            do = self.ps[t].copy()
            do[targets[t]] -= 1
            
            # Output layer gradients
            dW_hy += np.outer(self.hs[t], do)
            db_y += do
            
            # Gradient to hidden state
            dh = do @ self.W_hy.T + dh_next
            
            # LSTM backward
            dx, dh_next, dc_next, cell_grads = self.cells[t].backward(dh, dc_next)
            
            # Accumulate LSTM gradients
            dW_f += cell_grads['W_f']
            dW_i += cell_grads['W_i']
            dW_c += cell_grads['W_c']
            dW_o += cell_grads['W_o']
            db_f += cell_grads['b_f']
            db_i += cell_grads['b_i']
            db_c += cell_grads['b_c']
            db_o += cell_grads['b_o']
        
        # Clip gradients
        for grad in [dW_f, dW_i, dW_c, dW_o, db_f, db_i, db_c, db_o, dW_hy, db_y]:
            np.clip(grad, -5, 5, out=grad)
        
        return {
            'W_f': dW_f, 'b_f': db_f,
            'W_i': dW_i, 'b_i': db_i,
            'W_c': dW_c, 'b_c': db_c,
            'W_o': dW_o, 'b_o': db_o,
            'W_hy': dW_hy, 'b_y': db_y,
        }
    
    def update(self, grads, lr):
        """Update parameters."""
        self.cell.W_f -= lr * grads['W_f']
        self.cell.W_i -= lr * grads['W_i']
        self.cell.W_c -= lr * grads['W_c']
        self.cell.W_o -= lr * grads['W_o']
        self.cell.b_f -= lr * grads['b_f']
        self.cell.b_i -= lr * grads['b_i']
        self.cell.b_c -= lr * grads['b_c']
        self.cell.b_o -= lr * grads['b_o']
        self.W_hy -= lr * grads['W_hy']
        self.b_y -= lr * grads['b_y']
    
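    def sample(self, seed_char, length, temperature=1.0):
        """Generate `length` characters after the seed index `seed_char`, starting from a zero state. Returns indices, seed included."""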
        h = np.zeros(self.hidden_dim)
        c = np.zeros(self.hidden_dim)
        x_idx = seed_char
        generated = [x_idx]
        
        for _ in range(length):
            x = np.zeros(self.vocab_size)
            x[x_idx] = 1
            
            h, c = self.cell.forward(x, h, c)
            o = h @ self.W_hy + self.b_y
            
            o = o / temperature
            exp_o = np.exp(o - o.max())
            probs = exp_o / exp_o.sum()
            
            x_idx = np.random.choice(self.vocab_size, p=probs)
            generated.append(x_idx)
        
        return generated
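
`backward` starts each timestep with `do = p - onehot(target)`: for a softmax followed by cross-entropy, the gradient with respect to the logits is the predicted probabilities minus the one-hot target. A self-contained finite-difference sketch on arbitrary toy logits confirms this:

```python
import numpy as np

def softmax(o):
    # Subtract the max for numerical stability, as in CharLSTM.forward
    e = np.exp(o - o.max())
    return e / e.sum()

def cross_entropy(o, target):
    return -np.log(softmax(o)[target])

rng = np.random.default_rng(1)
logits = rng.standard_normal(6)
target = 2

# Analytic gradient: probabilities minus the one-hot target
analytic = softmax(logits)
analytic[target] -= 1

# Central finite differences on each logit
eps = 1e-6
numeric = np.array([
    (cross_entropy(logits + eps * np.eye(6)[k], target)
     - cross_entropy(logits - eps * np.eye(6)[k], target)) / (2 * eps)
    for k in range(6)
])
print(np.allclose(analytic, numeric, atol=1e-7))  # True
```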

## 3. Training

In [None]:
# Training text: the opening of Hamlet's "To be, or not to be" soliloquy
text = """To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die: to sleep;
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to. 'Tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep: perchance to dream: ay, there's the rub."""

# Create vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

data = [char_to_idx[ch] for ch in text]

print(f"Text length: {len(text)}")
print(f"Vocabulary size: {vocab_size}")

In [None]:
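def train_lstm(model, data, seq_length=25, epochs=100, lr=0.1, print_every=10):
    """
    Train with truncated BPTT and plain SGD.
    
    The data is split into consecutive chunks of seq_length characters; the
    hidden/cell state is carried across chunks, but gradients stop at chunk
    boundaries. Returns the mean per-character loss for each epoch.
    """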
    losses = []
    
    for epoch in range(epochs):
        h_prev = np.zeros(model.hidden_dim)
        c_prev = np.zeros(model.hidden_dim)
        epoch_loss = 0
        n_batches = 0
        
        for i in range(0, len(data) - seq_length - 1, seq_length):
            inputs = data[i:i+seq_length]
            targets = data[i+1:i+seq_length+1]
            
            # Forward
            probs, hidden_states, cell_states = model.forward(inputs, h_prev, c_prev)
            loss = model.loss(targets)
            
            # Backward
            grads = model.backward(targets)
            model.update(grads, lr)
            
            # Carry state forward
            h_prev = hidden_states[seq_length - 1].copy()
            c_prev = cell_states[seq_length - 1].copy()
            
            epoch_loss += loss
            n_batches += 1
        
        avg_loss = epoch_loss / n_batches
        losses.append(avg_loss)
        
        if epoch % print_every == 0:
            print(f"Epoch {epoch:3d}: Loss = {avg_loss:.4f}")
            sample_idx = model.sample(data[0], 100, temperature=0.8)
            sample_text = ''.join([idx_to_char[i] for i in sample_idx])
            print(f"  Sample: {sample_text[:60]}...\n")
    
    return losses
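
`train_lstm` uses truncated backpropagation through time: the data is cut into consecutive, non-overlapping chunks of `seq_length` characters, the targets are the inputs shifted by one position, and the state (but not the gradient) is carried from one chunk to the next. A sketch of the same slicing on a toy stand-in for the encoded text:

```python
data = list(range(12))  # toy stand-in for the encoded text
seq_length = 4

# Same loop bounds as train_lstm
chunks = []
for i in range(0, len(data) - seq_length - 1, seq_length):
    inputs = data[i:i + seq_length]
    targets = data[i + 1:i + seq_length + 1]  # inputs shifted by one
    chunks.append((inputs, targets))

for inputs, targets in chunks:
    print(inputs, "->", targets)
# [0, 1, 2, 3] -> [1, 2, 3, 4]
# [4, 5, 6, 7] -> [5, 6, 7, 8]
```

With these loop bounds, trailing characters that do not fill a complete chunk are not used as inputs in that epoch.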

In [None]:
# Train LSTM
lstm_model = CharLSTM(vocab_size=vocab_size, hidden_dim=100)

print("Training LSTM...\n")
lstm_losses = train_lstm(lstm_model, data, seq_length=25, epochs=200, lr=0.1, print_every=25)

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(lstm_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('LSTM Training Loss')
plt.grid(True)
plt.show()

## 4. Visualizing the Gates

Let's look at how the trained model's gates behave on a short phrase. Only the first 20 hidden units are plotted.

In [None]:
def visualize_gates(model, text_sample, char_to_idx, idx_to_char):
    """Visualize gate activations for a text sample."""
    data_sample = [char_to_idx[ch] for ch in text_sample]
    
    # Forward pass
    model.forward(data_sample)
    
    # Extract gate values
    forget_gates = []
    input_gates = []
    output_gates = []
    
    for t in range(len(data_sample)):
        forget_gates.append(model.cells[t].f)
        input_gates.append(model.cells[t].i)
        output_gates.append(model.cells[t].o)
    
    F = np.array(forget_gates)
    I = np.array(input_gates)
    O = np.array(output_gates)
    
    # Plot
    fig, axes = plt.subplots(3, 1, figsize=(14, 10))
    
    for ax, gate, name in zip(axes, [F, I, O], ['Forget Gate', 'Input Gate', 'Output Gate']):
        im = ax.imshow(gate.T[:20], aspect='auto', cmap='RdYlGn', vmin=0, vmax=1)
        ax.set_xlabel('Time step')
        ax.set_ylabel('Hidden unit')
        ax.set_title(name)
        ax.set_xticks(range(len(text_sample)))
        ax.set_xticklabels(list(text_sample), fontsize=8)
        plt.colorbar(im, ax=ax)
    
    plt.tight_layout()
    plt.show()
    
    return F, I, O

# Visualize on a short sample
test_sample = "To be, or not to be"
F, I, O = visualize_gates(lstm_model, test_sample, char_to_idx, idx_to_char)

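**Gate interpretations:**
- **Forget gate near 1:** keep this unit's cell-state content
- **Forget gate near 0:** erase this unit's cell-state content
- **Input gate high:** write the candidate $\tilde{C}_t$ into this unit's cell state
- **Output gate high:** expose this unit's $\tanh(C_t)$ in $h_t$, and hence to the output layer and the next timestep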
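
These readings follow directly from $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$ and $h_t = o_t \odot \tanh(C_t)$. A tiny sketch with hand-picked, purely illustrative values shows the two extremes of the cell update, plus the output gate acting as a per-unit mask:

```python
import numpy as np

c_prev = np.array([0.8, -0.5])
c_tilde = np.array([0.1, 0.9])

def cell_update(f, i):
    return f * c_prev + i * c_tilde

keep = cell_update(f=1.0, i=0.0)       # memory passes through unchanged
overwrite = cell_update(f=0.0, i=1.0)  # memory replaced by the candidate

# Output gate: unit 0 exposed, unit 1 hidden from h_t
o = np.array([1.0, 0.0])
h = o * np.tanh(keep)

print(keep, overwrite, h)
```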

## 5. Comparing Gradient Flow: RNN vs LSTM

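The key advantage of LSTMs is better gradient flow: on the backward pass, the gradient along the cell state is only rescaled elementwise by the forget gate at each step, instead of being pushed through a weight matrix and a squashing nonlinearity.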

In [None]:
# Simple RNN for comparison
class SimpleRNN:
    """Vanilla RNN for comparison."""
    
    def __init__(self, vocab_size, hidden_dim):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        
        scale_xh = np.sqrt(1.0 / vocab_size)
        scale_hh = np.sqrt(1.0 / hidden_dim)
        
        self.W_xh = np.random.randn(vocab_size, hidden_dim) * scale_xh
        self.W_hh = np.random.randn(hidden_dim, hidden_dim) * scale_hh
        self.b_h = np.zeros(hidden_dim)
        self.W_hy = np.random.randn(hidden_dim, vocab_size) * np.sqrt(1.0 / hidden_dim)
        self.b_y = np.zeros(vocab_size)
    
    def forward(self, inputs, h_init=None):
        if h_init is None:
            h_init = np.zeros(self.hidden_dim)
        
        self.xs = {}
        self.hs = {-1: h_init}
        self.ps = {}
        
        for t in range(len(inputs)):
            x = np.zeros(self.vocab_size)
            x[inputs[t]] = 1
            self.xs[t] = x
            
            self.hs[t] = np.tanh(x @ self.W_xh + self.hs[t-1] @ self.W_hh + self.b_h)
            
            o = self.hs[t] @ self.W_hy + self.b_y
            exp_o = np.exp(o - o.max())
            self.ps[t] = exp_o / exp_o.sum()
        
        return self.ps, self.hs
    
    def measure_gradient_flow(self, targets):
        """
        Backpropagate only the final timestep's loss and record the norm of
        its gradient w.r.t. h_t at every position t. If gradients vanish,
        the norm shrinks rapidly toward the start of the sequence.
        """
        seq_len = len(targets)
        
        # Loss gradient enters only at the last position
        do = self.ps[seq_len - 1].copy()
        do[targets[-1]] -= 1
        dh = do @ self.W_hy.T
        gradient_norms = []
        
        for t in reversed(range(seq_len)):
            gradient_norms.append(np.linalg.norm(dh))
            # Through tanh at step t, then the recurrent weights, to h_{t-1}
            dh_raw = dh * (1 - self.hs[t]**2)
            dh = dh_raw @ self.W_hh.T
        
        return gradient_norms[::-1]

In [None]:
# Compare gradient flow
seq_len = 50
test_inputs = data[:seq_len]
test_targets = data[1:seq_len+1]

# RNN (untrained: this notebook never trains it, whereas lstm_model below is trained)
rnn = SimpleRNN(vocab_size, 100)
rnn.forward(test_inputs)
rnn_grads = rnn.measure_gradient_flow(test_targets)

# LSTM - measure gradient through the hidden- and cell-state pathways
def measure_lstm_gradient_flow(model, inputs, targets):
    """
    Backpropagate only the final timestep's loss and record, at every
    position t, the norm of its gradient w.r.t. h_t plus the norm of the
    gradient reaching c_t from c_{t+1}.
    """
    model.forward(inputs)
    seq_len = len(targets)
    
    # Loss gradient enters only at the last position
    do = model.ps[seq_len - 1].copy()
    do[targets[-1]] -= 1
    dh = do @ model.W_hy.T
    dc = np.zeros(model.hidden_dim)  # nothing flows into c_T from the future
    gradient_norms = []
    
    for t in reversed(range(seq_len)):
        gradient_norms.append(np.linalg.norm(dh) + np.linalg.norm(dc))
        # backward returns the gradients w.r.t. h_{t-1} and c_{t-1}
        _, dh, dc, _ = model.cells[t].backward(dh, dc)
    
    return gradient_norms[::-1]

lstm_grads = measure_lstm_gradient_flow(lstm_model, test_inputs, test_targets)

# Plot
plt.figure(figsize=(12, 5))

# Normalize for comparison
rnn_grads = np.array(rnn_grads)
lstm_grads = np.array(lstm_grads)

if rnn_grads[-1] > 0:
    rnn_grads_norm = rnn_grads / rnn_grads[-1]
else:
    rnn_grads_norm = rnn_grads
    
if lstm_grads[-1] > 0:
    lstm_grads_norm = lstm_grads / lstm_grads[-1]
else:
    lstm_grads_norm = lstm_grads

plt.plot(rnn_grads_norm, label='RNN', linewidth=2)
plt.plot(lstm_grads_norm, label='LSTM', linewidth=2)
plt.xlabel('Position in sequence')
plt.ylabel('Relative gradient magnitude')
plt.title('Gradient Flow: RNN vs LSTM')
plt.legend()
plt.yscale('log')
plt.grid(True)
plt.show()

print(f"Gradient at position 0 relative to final:")
print(f"  RNN:  {rnn_grads_norm[0]:.2e}")
print(f"  LSTM: {lstm_grads_norm[0]:.2e}")

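**Key insight:** The cell state provides a "gradient highway". In `LSTMCell.backward`, `dc_prev = dc_total * f`: going back one step, the gradient reaching $C_{t-1}$ is only rescaled elementwise by the forget gate, so it can survive many steps when $f_t$ stays close to 1. A vanilla RNN instead passes the gradient through the tanh derivative and $W_{hh}^\top$ at every step, which tends to shrink (or occasionally blow up) it geometrically. If the cell-state pathway works as intended, the LSTM curve should fall off much more slowly toward position 0. Note that the LSTM here is the trained `lstm_model` while the RNN is untrained, so treat the plot as an illustration rather than a controlled comparison.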
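
Along the cell-state path alone, each backward step multiplies the gradient by $f_t$, so after $k$ steps it is scaled by $\prod_t f_t$. A back-of-the-envelope sketch, assuming a constant forget-gate value and ignoring the path through $h_t$: at the initial bias of 1 ($f \approx \sigma(1) \approx 0.73$) the signal all but disappears over 50 steps, whereas forget gates close to 1 keep a sizable fraction of it.

```python
import numpy as np

steps = 50
# sigma(1) ~ 0.73 (the initial forget-gate value), then two values closer to 1
for f in [1 / (1 + np.exp(-1.0)), 0.9, 0.99]:
    print(f"f = {f:.2f}: gradient scale after {steps} steps = {f**steps:.2e}")
```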

## 6. Generate Text

In [None]:
# Generate with different temperatures
print("Generated text:\n")

for temp in [0.5, 0.8, 1.0, 1.5]:
    seed = char_to_idx['T']
    sample_idx = lstm_model.sample(seed, 150, temperature=temp)
    sample_text = ''.join([idx_to_char[i] for i in sample_idx])
    print(f"Temperature {temp}:")
    print(f"  {sample_text}")
    print()
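
`sample` divides the logits by `temperature` before the softmax. A sketch on made-up logits shows the effect: $T < 1$ sharpens the distribution toward the most likely character, and $T > 1$ flattens it toward uniform, which is why higher temperatures produce more varied (and more error-prone) text.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    # Same computation as in CharLSTM.sample
    o = logits / temperature
    e = np.exp(o - o.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])  # made-up scores for 4 characters
for T in [0.5, 1.0, 1.5]:
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: max prob = {probs.max():.3f}, probs = {np.round(probs, 3)}")
```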

## 7. Summary

| Component | Formula | Purpose |
|-----------|---------|---------|
| **Forget gate** | $f = \sigma(W_f[h,x] + b_f)$ | How much of the old cell state to keep |
| **Input gate** | $i = \sigma(W_i[h,x] + b_i)$ | How much of the candidate to write |
| **Cell candidate** | $\tilde{C} = \tanh(W_C[h,x] + b_C)$ | Candidate values to write |
| **Cell state** | $C = f \odot C_{\text{prev}} + i \odot \tilde{C}$ | Long-term memory |
| **Output gate** | $o = \sigma(W_o[h,x] + b_o)$ | How much of the cell state to expose |
| **Hidden state** | $h = o \odot \tanh(C)$ | Short-term output |

**Key takeaways:**

1. **Gates control information flow** - multiplicative gating is powerful
2. **Cell state is the key** - provides a "highway" for gradient flow
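3. **Forget-gate bias initialized to 1** - gives $f \approx \sigma(1) \approx 0.73$ at the start of training, so the default leans toward remembering
4. **More parameters than an RNN** - four sets of gate weights instead of one, but much better at learning long-range dependencies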
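
To make takeaway 4 concrete: the LSTM has four weight matrices of shape `[input_dim + hidden_dim, hidden_dim]` plus four bias vectors, while `SimpleRNN` has `W_xh`, `W_hh` and `b_h`, which together have exactly the shapes of one gate. A sketch of the count follows. The output layer (`W_hy`, `b_y`) is excluded because both models have the same one, and the vocabulary size is an arbitrary example value.

```python
def lstm_params(input_dim, hidden_dim):
    # Four gates (f, i, candidate, o), each with W: [input+hidden, hidden] and b: [hidden]
    return 4 * ((input_dim + hidden_dim) * hidden_dim + hidden_dim)

def rnn_params(input_dim, hidden_dim):
    # W_xh: [input, hidden], W_hh: [hidden, hidden], b_h: [hidden]
    return input_dim * hidden_dim + hidden_dim * hidden_dim + hidden_dim

vocab_size, hidden_dim = 40, 100  # example sizes; vocab_size is arbitrary here
print(lstm_params(vocab_size, hidden_dim), rnn_params(vocab_size, hidden_dim))
```

For any sizes, the recurrent part of the LSTM has exactly four times as many parameters as the vanilla RNN's.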

**Historical note:** LSTMs (1997) dominated sequence modeling until transformers (2017). Many modern systems still use LSTMs for smaller-scale tasks.

**Next:** [07-attention-from-scratch.ipynb](07-attention-from-scratch.ipynb) introduces attention, which eventually replaced LSTMs for most NLP tasks.

## 8. Exercises

1. **GRU:** Implement a GRU cell (simpler gating than LSTM).

2. **Peephole connections:** Add connections from cell state to gates.

3. **Bidirectional LSTM:** Process sequence in both directions.

4. **Longer sequences:** Train on longer sequences. How does LSTM compare to RNN?

In [None]:
# Exercise 1 starter: Implement GRU
class GRUCell:
    """
    Gated Recurrent Unit.
    
    Simpler than LSTM: 2 gates instead of 3.
    
    Reset gate: r = sigmoid(W_r[h,x])
    Update gate: z = sigmoid(W_z[h,x])  
    Candidate: h_tilde = tanh(W_h[r*h, x])
    Output: h = (1-z)*h_prev + z*h_tilde
    """
    
    def __init__(self, input_dim, hidden_dim):
        # Your implementation here
        pass
    
    def forward(self, x, h_prev):
        # Your implementation here
        pass