In [2]:
# Hack to get repo path on remotes
import os
REPOPATH = os.path.abspath('.')
if "MLX_USER" in os.environ:
    REPOPATH = os.path.join("/mlx/users", os.environ["MLX_USER"], "repo/LLM-Tutorials")

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import itertools
import math

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Session 2 – NLP Models and Training Basics (RNN)
  
In this notebook, we will dive into fundamental sequence models such as **RNNs** and **LSTMs**. We’ll also cover basic neural embeddings, training objectives, and see how to implement and train a **simple text generator** using an LSTM.


## Table of Contents

1. [Overview and Setup](#overview)
2. [Recurrent Neural Networks (RNNs)](#rnns)
   - [The RNN Cell](#rnn-cell)
   - [Vanishing and Exploding Gradients](#vanishing)
3. [Long Short-Term Memory (LSTM)](#lstm)
   - [Key Intuition Behind LSTM Gates](#lstm-gates)
4. [Embeddings](#embeddings)
5. [Basic Training Objectives in Language Modeling](#training-objectives)
   - [Next Token Prediction](#next-token-prediction)
   - [Perplexity](#perplexity)
6. [Implementing a Simple LSTM Text Generator in PyTorch](#implementation)
   - [Data Preparation](#data-prep)
   - [Model Definition](#model-def)
   - [Training Loop](#training-loop)
   - [Generating Text](#generate-text)

Each section will be followed by one or more **Exercises** to help you practice.

# <a id="overview"></a>1. Overview and Setup

This tutorial assumes you have:

- **Basic Python** knowledge.
- A local or cloud environment (e.g., Jupyter, Colab) with **PyTorch** installed.
  - If needed, install PyTorch via `pip install torch` or follow instructions at [pytorch.org](https://pytorch.org/get-started/locally/).

No prior reading of other sessions is required; we’ll present all the essentials here.

### Quick Setup Check
```python
import torch
print("PyTorch version:", torch.__version__)
```

Ensure you see a version number (e.g., `2.0.0` or similar) printed. If you get an error, please install or update PyTorch before continuing.

- We’ll focus on **RNNs** and **LSTMs**.  
- We’ll learn **why** they are powerful for sequential data.  
- We’ll cover **basic training objectives** (like next-token prediction) for language modeling.  
- Finally, we’ll implement a small **LSTM-based text generator**.

**By the end of this session**, you should be able to:
1. Understand how an RNN cell and LSTM cell process sequential data.  
2. Implement an **LSTM** in a deep learning framework (here, PyTorch).  
3. Train and evaluate a **text-generation** model.  

In [5]:
import torch
print("PyTorch version:", torch.__version__)

PyTorch version: 2.5.1


# 2. Recurrent Neural Networks (RNNs)<a id="rnns"></a>

Recurrent Neural Networks are designed to handle **sequential data** by maintaining a hidden state that captures information about previous time steps.

## Key Idea
At each time step $t$:
1. The RNN takes an input $x_t$ and the hidden state from the previous time step $h_{t-1}$.
2. It produces a new hidden state $h_t$.

Mathematically, a very **basic** RNN can be written as:
$$
\begin{aligned}
h_t &= \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h) \\
y_t &= W_{hy} h_t + b_y
\end{aligned}
$$

- $h_t$ is the updated hidden state.
- $y_t$ is the output at time step $t$ (used for tasks like classification or next-token prediction).
- $W_{hh}, W_{xh}, W_{hy}$ are learned weight matrices.

**Rearrangement of Terms**

Notice that the term $W_{hh} h_{t-1} + W_{xh} x_t$ uses two matrix multiplications and an addition.
Unless compiled, these two multiplications will be performed sequentially.
We can gain a slight improvement if we concatenate $h$ and $x$, and use a single matrix multiplication by a larger weight matrix:

$$
\begin{aligned}
h_t &= \tanh(W_h H_t + b_h) \\
y_t &= W_{hy} h_t + b_y
\end{aligned}
$$

- $H_t = [h_{t-1}||x_t]$ is the concatenation of $h_{t-1}$ and $x_t$
- $W_{h} = [W_{hh}||W_{xh}]$ is the concatenation of $W_{hh}$ and $W_{xh}$
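
The rearrangement can be checked numerically. The following is a minimal NumPy sketch, not part of the original notebook: the sizes and random values are arbitrary assumptions, and vectors are treated as columns so that stacking $h_{t-1}$ on top of $x_t$ matches placing $W_{hh}$ and $W_{xh}$ side by side.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, input_dim = 3, 5  # arbitrary sizes chosen for illustration

W_hh = rng.standard_normal((hidden_dim, hidden_dim))
W_xh = rng.standard_normal((hidden_dim, input_dim))
b_h = rng.standard_normal(hidden_dim)
h_prev = rng.standard_normal(hidden_dim)
x_t = rng.standard_normal(input_dim)

# Two separate matrix-vector products
h_separate = np.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)

# One product with the concatenated weight matrix and concatenated vector
W_h = np.concatenate([W_hh, W_xh], axis=1)  # shape (hidden_dim, hidden_dim + input_dim)
H_t = np.concatenate([h_prev, x_t])         # shape (hidden_dim + input_dim,)
h_concat = np.tanh(W_h @ H_t + b_h)

print(np.allclose(h_separate, h_concat))
```

Both routes produce the same hidden state up to floating-point rounding; the concatenated form simply issues one larger matrix product instead of two smaller ones.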


<img src="img_src/RNNs.svg"/>


## <a id="rnn-cell"></a>The RNN Cell

The **RNN cell** is the fundamental computational unit. At time step $t$:
1. **Input**: current token (often embedded) + previous hidden state.
2. **Output**: updated hidden state + optional output vector.

If you unroll this cell over time for $T$ steps, you get a **computation graph** that looks like a chain, where each link is an RNN cell.
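
Unrolling can be sketched as a plain loop that reuses the same weights at every step. This is a minimal NumPy sketch under assumptions: the function name `unroll_rnn`, the sizes, and the row-vector convention (`x @ W`) are chosen here for illustration only.

```python
import numpy as np

def unroll_rnn(xs, h0, Wxh, Whh, bh):
    """Apply the same RNN cell to every time step of xs.

    xs: (T, input_dim), h0: (hidden_dim,)
    Returns all hidden states, shape (T, hidden_dim).
    """
    h = h0
    states = []
    for x_t in xs:  # one link of the chain per time step
        h = np.tanh(h @ Whh + x_t @ Wxh + bh)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
T, input_dim, hidden_dim = 4, 5, 3
xs = rng.standard_normal((T, input_dim))
states = unroll_rnn(
    xs,
    h0=np.zeros(hidden_dim),
    Wxh=rng.standard_normal((input_dim, hidden_dim)),
    Whh=rng.standard_normal((hidden_dim, hidden_dim)),
    bh=np.zeros(hidden_dim),
)
print(states.shape)  # (4, 3)
```

Each row of `states` is the hidden state after one more input has been consumed; only `h` is carried from one iteration to the next.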

<img src="img_src/RNN-folded.svg"/>
<img src="img_src/RNN-unfolded.svg"/>

### Exercise: Implement a Toy RNN Cell
**Goal**:  
1. Write a Python function that computes a single time-step of an RNN.  
1. Use NumPy or PyTorch (in NumPy style) to do the matrix multiplication and a `tanh` activation.  
1. Test it on a small input (e.g., input dimension of 5, hidden dimension of 3).

*(Keep it simple—focus on the concept, not a full RNN unrolled over time.)*  

In [6]:
import torch

def rnn_step(x_t, h_prev, Wxh, Whh, bh):
    """Simple RNN step

    Args:
        x_t: shape (batch_size, input_dim)
        h_prev: shape (batch_size, hidden_dim)
        Wxh: shape (input_dim, hidden_dim)
        Whh: shape (hidden_dim, hidden_dim)
        bh: shape (hidden_dim,)
    Returns:
        h_t: shape (batch_size, hidden_dim)
    """
    weighted_h = h_prev @ Whh  # shape: (batch_size, hidden_dim)
    weighted_x = x_t @ Wxh  # shape: (batch_size, hidden_dim)
    linear_hx = weighted_h + weighted_x + bh
    nonlinear_hx = torch.tanh(linear_hx)
    return nonlinear_hx

# Tests -- we only consider shapes here
N = (1, 2, 5)  # Batch sizes
hidden_dims = (1, 2, 5)  # Hidden sizes
input_dims = (1, 2, 5)  # Input sizes

failed_cases = []
for batch_size, hdim, xdim in itertools.product(N, hidden_dims, input_dims):
    x = torch.ones(batch_size, xdim)
    h = torch.zeros(batch_size, hdim)
    Wxh = torch.ones(xdim, hdim)
    Whh = torch.ones(hdim, hdim)
    bh = torch.zeros(hdim)
    expect_shape = (batch_size, hdim)
    with torch.no_grad():
        h_next = rnn_step(x, h, Wxh, Whh, bh)
    if h_next.shape != expect_shape:
        print('x', end='')
        failed_cases.append((h_next.shape, expect_shape))
    else:
        print('.', end='')
print()
for got, expected in failed_cases:
    print(f"{expected} vs. {got}")

...........................


In [7]:
# PyTorch implementation
class RNN(nn.Module):
    def __init__(self, data_dim, state_dim):
        super().__init__()  # nn.Module.__init__ takes no arguments
        self.data_dim = data_dim
        self.state_dim = state_dim

        # Define parameters to train
        self.input_linear = nn.Linear(  # Takes [x||h_prev] and produces h_next
            in_features=self.data_dim + self.state_dim,
            out_features=self.state_dim,
            bias=True)
        self.output_linear = nn.Linear(  # Takes h_next and produces y
            in_features=self.state_dim,
            out_features=self.data_dim,
            bias=True)
        self.tanh = nn.Tanh()

    def forward(self, x, h=None):
        # Concatenate x and hidden_state
        if h is None:
            h = torch.zeros(x.shape[0], self.state_dim, device=x.device, dtype=x.dtype)
        xh = torch.hstack([x, h])

        # Compute new hidden state
        xh = self.input_linear(xh)
        h_next = self.tanh(xh)

        # Compute output
        y = self.output_linear(h_next)
        return y, h_next

## <a id="vanishing"></a>Vanishing and Exploding Gradients

**Problem**: Simple RNNs often struggle with **long-term dependencies** due to **vanishing** or **exploding gradients**. That means:
- When sequences are long, the gradient that flows backward through time either becomes extremely small (**vanishes**) or extremely large (**explodes**).
- This makes training unstable or ineffective for capturing long-range context.

**Solution**: Specialized RNN variants like **LSTM** or **GRU** mitigate these issues by incorporating gating mechanisms.
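
To see the effect numerically, here is a small NumPy sketch under simplifying assumptions: the recurrence is linear ($h_t = W h_{t-1}$, no `tanh`), and $W$ is a scaled random orthogonal matrix so that its singular values are all equal to `scale`. The helper name `gradient_norm_through_time` is made up for this illustration.

```python
import numpy as np

def gradient_norm_through_time(scale, T, dim=4, seed=0):
    """Spectral norm of d h_T / d h_0 for the linear recurrence h_t = W h_{t-1}.

    W is `scale` times a random orthogonal matrix, so the Jacobian
    (W multiplied by itself T times) has spectral norm scale**T.
    """
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))  # orthogonal matrix
    W = scale * Q
    jacobian = np.linalg.matrix_power(W, T)
    return np.linalg.norm(jacobian, 2)  # largest singular value

for scale in (0.9, 1.0, 1.1):
    norms = [gradient_norm_through_time(scale, T) for T in (10, 50, 100)]
    print(scale, [f"{n:.2e}" for n in norms])
```

A factor only slightly below 1 shrinks the gradient by several orders of magnitude over 100 steps, and a factor slightly above 1 inflates it just as quickly. A real RNN adds the derivative of the activation to each factor, which changes the numbers but not the geometric behavior.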

### Research Note: let's invent a GRU (Gated Recurrent Unit)

**RNN** : $h_t = \phi(W_hh_{t-1} + W_xx_{t})$

* **Problem:** Backpropagating to $h_1$ (or any early token) multiplies the gradient by $W_h$ once per time step; when those values are small, the gradient **vanishes**.
* **Solution:** Let the network choose, at each step, between computing a new memory, $h_t = \phi(W_hh_{t-1} + W_xx_{t})$, and simply keeping the old one, $h_t = h_{t-1}$.

**RNN with no vanishing** : $h_t = \alpha\odot\hat{h}_t + (1-\alpha)\odot h_{t-1}$, where $\hat{h}_t=\phi(W_hh_{t-1} + W_xx_{t})$

* **Problem:** Backpropagating to $h_1$ (or any early token) multiplies the gradient by $W_h$ once per time step; when those values are large, the gradient **explodes**.
* **Solution:** Let the network choose to zero out the previous memory before it is multiplied by the weights: $h_t = \phi(W_hh_{t-1} + W_xx_{t})$ or $h_t = \phi(W_xx_{t})$.

**RNN with no explosion** : $h_t = \phi(W_h(\beta \odot h_{t-1}) + W_xx_{t})$

* **Problem:** How do we decide on the values of $\alpha$ and $\beta$?
* **Solution:** Don't! Let the data decide (learning)

$$
\begin{aligned}
h_t &= \overbrace{\alpha\odot\underbrace{\phi\left(W_h(\beta \odot h_{t-1}) + W_xx_{t}\right)}_{\text{no explosion}} + (1-\alpha)\odot h_{t-1}}^\text{no vanishing} \\
\text{where}\\
\alpha &= \sigma\left(Ah_{t-1} + Bx_t\right) &&\text{Memory Update Gate}\\
\beta &= \sigma\left(Ch_{t-1} + Dx_t\right) &&\text{Memory Reset Gate}
\end{aligned}
$$

**Congratulations**, you have just invented a **Gated Recurrent Unit** (GRU)!

*The [LSTM](https://en.wikipedia.org/wiki/Long_short-term_memory) is an earlier gated recurrent network that follows very similar logic for preserving long-term context information.*
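
The gated update derived above can be written as a single step function. This is a minimal NumPy sketch under assumptions: $\phi$ is taken to be `tanh`, biases are omitted as in the equations, and the name `gru_step` and all sizes are illustrative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_h, W_x, A, B, C, D):
    """One step of the gated update derived above (biases omitted, phi = tanh)."""
    alpha = sigmoid(A @ h_prev + B @ x_t)               # memory update gate
    beta = sigmoid(C @ h_prev + D @ x_t)                # memory reset gate
    h_hat = np.tanh(W_h @ (beta * h_prev) + W_x @ x_t)  # candidate memory
    return alpha * h_hat + (1.0 - alpha) * h_prev

rng = np.random.default_rng(0)
hidden_dim, input_dim = 3, 5
W_h, A, C = (rng.standard_normal((hidden_dim, hidden_dim)) for _ in range(3))
W_x, B, D = (rng.standard_normal((hidden_dim, input_dim)) for _ in range(3))

h = np.zeros(hidden_dim)
for x_t in rng.standard_normal((4, input_dim)):
    h = gru_step(x_t, h, W_h, W_x, A, B, C, D)
print(h.shape)
```

PyTorch's `nn.GRU` follows the same idea but uses a slightly different parameterization (and includes biases), so this sketch is not expected to reproduce its numbers.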

| Network | Complexity | Long-Term Relationship | Gradient Issues |
|---------|--------------|------|----|
| RNN (tanh) | (++) | None | (-) |
| GRU | (+) | (+)<br/>Single state | (++) |
| LSTM | (--) | (++)<br/>Separate state for long and short terms | (++) |

# <a id="lstm"></a>3. Long Short-Term Memory (LSTM)

A **Long Short-Term Memory (LSTM)** network is a type of RNN specifically designed to better capture **long-range dependencies**. It addresses the vanishing/exploding gradient problem through gates that control the flow of information.

### <a id="lstm-gates"></a>Key Intuition Behind LSTM Gates

Typical LSTM equations:

$$\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f)
&\quad(\text{Forget Gate}) \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i)
&\quad(\text{Input Gate}) \\
\tilde{C_t} &= \tanh(W_C [h_{t-1}, x_t] + b_C)
&\quad(\text{Candidate Values}) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C_t}
&\quad(\text{Cell State Update}) \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o)
&\quad(\text{Output Gate}) \\
h_t &= o_t \odot \tanh(C_t)
&\quad(\text{New Hidden State})
\end{aligned}$$

- **Forget Gate** ($f_t$): decides how much old state to keep.
- **Input Gate** ($i_t$): decides how much new information to add.
- **Candidate** ($\tilde{C_t}$): proposed update to the cell state.
- **Output Gate** ($o_t$): decides how much cell state to output as hidden state.

This gating mechanism helps **preserve gradients** across many time steps.
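
The six equations map almost line for line onto code. The following is a minimal NumPy sketch of one LSTM step, assuming each weight matrix acts on the concatenation $[h_{t-1}, x_t]$ as written above; the function name `lstm_step`, the sizes, and the random weights are illustrative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W_f, W_i, W_C, W_o, b_f, b_i, b_C, b_o):
    """One LSTM step following the equations above."""
    hx = np.concatenate([h_prev, x_t])   # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ hx + b_f)        # forget gate
    i_t = sigmoid(W_i @ hx + b_i)        # input gate
    C_tilde = np.tanh(W_C @ hx + b_C)    # candidate values
    C_t = f_t * C_prev + i_t * C_tilde   # cell state update
    o_t = sigmoid(W_o @ hx + b_o)        # output gate
    h_t = o_t * np.tanh(C_t)             # new hidden state
    return h_t, C_t

rng = np.random.default_rng(0)
hidden_dim, input_dim = 3, 5
weights = [rng.standard_normal((hidden_dim, hidden_dim + input_dim)) for _ in range(4)]
biases = [np.zeros(hidden_dim) for _ in range(4)]

h, C = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in rng.standard_normal((6, input_dim)):
    h, C = lstm_step(x_t, h, C, *weights, *biases)
print(h.shape, C.shape)
```

Note that the cell state `C` is updated only by element-wise scaling and addition, which is the path along which gradients can survive many steps.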


### Exercise: Compare RNN and GRU Outputs (or LSTM if you prefer)
1. Create a synthetic sequence of length 20.  
2. Feed it into a small **Vanilla RNN** and a small **GRU** (in PyTorch).  
3. Compare the final hidden states after feeding all time steps. Are they similar? If you vary the length from 20 to 50 to 100, how do the hidden states change?

*Hint*: This is a conceptual experiment. You can use random inputs, then measure how the hidden states drift over longer sequences.

*Hint*: **If you really want**, you can use the utilities in the `utils.py` file to generate simple synthetic sequences (`generate_synthetic_sequences`).

*Hint*: You don't have to train the network, but if you want to you can use `utils.py` (`train_recurrent`)

In [8]:
from torch import nn

batch_size = 3
sequence_length = 1000
xdim = 3
hdim = 2

x = torch.ones(sequence_length, batch_size, xdim)  # Input sequence
h = torch.zeros(1, batch_size, hdim)  # Initial hidden state / initial memory

rnn_model = nn.RNN(input_size=xdim, hidden_size=hdim, num_layers=1, batch_first=False, bidirectional=False)
gru_model = nn.GRU(input_size=xdim, hidden_size=hdim, num_layers=1, batch_first=False, bidirectional=False)
lstm_model = nn.LSTM(input_size=xdim, hidden_size=hdim, batch_first=False)

rnn_model.zero_grad()
gru_model.zero_grad()
lstm_model.zero_grad()

y_rnn, h_rnn = rnn_model(x)
y_gru, h_gru = gru_model(x)
y_lstm, (h_lstm, c_lstm) = lstm_model(x)

print(f"Hidden output shapes: {h_rnn.shape=}, {h_gru.shape=}, {h_lstm.shape=}, {c_lstm.shape=}")

# Very basic scalar "error": the norm of the final memory, used only to obtain gradients
rnn_error = h_rnn.norm()
gru_error = h_gru.norm()
lstm_error = h_lstm.norm() + c_lstm.norm()

rnn_error.backward()
gru_error.backward()
lstm_error.backward()

print(f"Hidden state norms: {rnn_error:.2e}, {gru_error:.2e}, {lstm_error:.2e}")

Hidden output shapes: h_rnn.shape=torch.Size([1, 3, 2]), h_gru.shape=torch.Size([1, 3, 2]), h_lstm.shape=torch.Size([1, 3, 2]), c_lstm.shape=torch.Size([1, 3, 2])
Hidden state norms: 1.61e+00, 1.25e+00, 2.20e+00


In [9]:
rnn_model.weight_hh_l0.grad, gru_model.weight_hh_l0.grad, lstm_model.weight_hh_l0.grad

(tensor([[0.1370, 0.3210],
         [0.1878, 0.4400]]),
 tensor([[-3.1429e-03,  8.0500e-03],
         [-6.9294e-03,  1.7748e-02],
         [ 0.0000e+00,  0.0000e+00],
         [-1.4641e-08,  3.7500e-08],
         [ 1.1975e-02, -3.0671e-02],
         [-1.9021e-01,  4.8720e-01]]),
 tensor([[0.1712, 0.1392],
         [0.0473, 0.0385],
         [0.1563, 0.1271],
         [0.0478, 0.0389],
         [0.1135, 0.0923],
         [0.2680, 0.2180],
         [0.0829, 0.0674],
         [0.0396, 0.0322]]))

# <a id="embeddings"></a>4. Embeddings

**TODO: Embedding and Latent Space Explanation**

When dealing with text, each word or token is usually mapped to an **embedding** vector rather than a large one-hot vector.

- **Embedding Layer**: A learnable matrix that maps token indices to dense vectors of fixed dimension $d$.
- This helps the model learn **semantic relationships** between words.

For example:
- Word “hello” → index 5 → embedding vector $\mathbf{e} \in \mathbb{R}^d$.

Most frameworks (like PyTorch) provide a built-in layer, `nn.Embedding(vocab_size, embed_dim)`, that handles this.
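
An embedding lookup is equivalent to multiplying a one-hot vector by the embedding matrix; the lookup just skips the wasted multiplications by zero. A minimal NumPy sketch, with arbitrary sizes and a random matrix standing in for learned weights:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 4  # arbitrary sizes
E = rng.standard_normal((vocab_size, embed_dim))  # stand-in embedding matrix

token_idx = 5  # e.g. the index assigned to "hello"

one_hot = np.zeros(vocab_size)
one_hot[token_idx] = 1.0

via_matmul = one_hot @ E   # multiply the one-hot vector by the matrix
via_lookup = E[token_idx]  # read the row directly

print(np.allclose(via_matmul, via_lookup))
```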


### Exercise 3: Custom Embedding Lookup
1. Create a small vocabulary of 5 tokens.  
2. Initialize a random embedding matrix of shape $(5, d)$.  
3. Write a function that takes a token index and returns the corresponding embedding row.  
4. Compare with `nn.Embedding` in PyTorch for the same matrix initialization.


In [10]:
vocab_size = 5
embed_dim = 3
torch.manual_seed(42)

# Step 1: create a random embedding matrix
embedding_matrix = torch.randn(vocab_size, embed_dim)

def custom_embed_lookup(token_idx, embedding_matrix):
    """
    token_idx: int index (0 <= token_idx < vocab_size)
    embedding_matrix: shape (vocab_size, embed_dim)
    returns: torch.Tensor of shape (embed_dim,)
    """
    return embedding_matrix[token_idx]

# Pick a test token index
test_idx = 2
custom_vec = custom_embed_lookup(test_idx, embedding_matrix)
print("Custom lookup vector:", custom_vec)

# Step 2: Compare with nn.Embedding
embed_layer = nn.Embedding(vocab_size, embed_dim)
# Overwrite embed_layer's weights with our random matrix
with torch.no_grad():
    embed_layer.weight.copy_(embedding_matrix)

# Now let's see if it matches:
with torch.no_grad():
    torch_vec = embed_layer(torch.tensor([test_idx]))
print("nn.Embedding lookup vector:", torch_vec.squeeze(0))

# They should be (almost) identical
print("Difference:", (custom_vec - torch_vec.squeeze(0)).abs().sum().item())


Custom lookup vector: tensor([ 2.2082, -0.6380,  0.4617])
nn.Embedding lookup vector: tensor([ 2.2082, -0.6380,  0.4617])
Difference: 0.0


# <a id="training-objectives"></a>5. Training Objectives in Language Modeling

In language modeling, a typical goal is **next-token prediction**: given the previous tokens, predict the next one. We often use **cross-entropy loss** and measure model performance with **perplexity**.

### <a id="next-token-prediction"></a>Next Token Prediction

For a vocabulary of size $V$, the model outputs a probability distribution over the next token:
$$
P(x_t \mid x_{t-1}, x_{t-2}, \ldots, x_1)
$$
The training loss for a sequence might be:
$$
\mathcal{L} = -\sum_{t}\log P(\hat{x}_t = x_t)
$$
where $x_t$ is the ground-truth token and $\hat{x}_t$ is the model's prediction at step $t$, so $P(\hat{x}_t = x_t)$ is the probability the model assigns to the correct token.
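
As a concrete illustration of this loss, the sketch below sums the negative log-probability of the correct token over a short sequence. The probability table is made up for the example, and `sequence_nll` is a hypothetical helper name.

```python
import numpy as np

def sequence_nll(step_probs, targets):
    """Sum of -log P(correct token) over all time steps.

    step_probs: (T, V) array, one probability distribution per step
    targets: length-T sequence of ground-truth token indices
    """
    step_probs = np.asarray(step_probs)
    picked = step_probs[np.arange(len(targets)), targets]  # probability of each correct token
    return -np.log(picked).sum()

# Hypothetical model outputs for a 3-step sequence over a 4-token vocabulary
step_probs = [
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.25, 0.25, 0.25, 0.25],
]
targets = [0, 1, 3]
loss = sequence_nll(step_probs, targets)
print(loss)  # equals -(log 0.7 + log 0.5 + log 0.25)
```

Only the probability assigned to the correct token enters the loss at each step; how the remaining mass is spread over wrong tokens does not matter.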


### <a id="perplexity"></a>Perplexity

**Perplexity (PPL)** is a common metric for language models:
$$
\text{PPL} = \exp\left(-\frac{1}{N}\sum_{t=1}^{N} \log P(x_t \mid x_{t-1}, \ldots, x_1)\right),
$$
where $N$ is the total number of tokens in the test set. Lower PPL typically means a better language model.
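
A useful sanity check is that a model which guesses uniformly over $V$ tokens has perplexity $V$. The sketch below uses only the standard library; the probability lists are invented for illustration.

```python
import math

def perplexity(token_probs):
    """exp of the average negative log-probability of the observed tokens."""
    n = len(token_probs)
    avg_nll = -sum(math.log(p) for p in token_probs) / n
    return math.exp(avg_nll)

# A model that always spreads its probability uniformly over V tokens
V = 4
uniform_ppl = perplexity([1.0 / V] * 10)

# A more confident (hypothetical) model
confident_ppl = perplexity([0.9, 0.8, 0.95, 0.7])

print(uniform_ppl, confident_ppl)
```

Perplexity can therefore be read as the effective number of tokens the model is choosing between at each step.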


### Exercise 4: Manual Cross-Entropy
- Let your model output a probability vector $[0.2, 0.3, 0.1, 0.4]$ for a 4-word vocabulary.  
- Suppose the correct label is index 3. Manually compute cross-entropy.  
- Compare with `torch.nn.functional.cross_entropy` to confirm your result.

In [11]:
import torch.nn.functional as F

probs = torch.tensor([0.2, 0.3, 0.1, 0.4])
true_label = 3  # index 3 is the correct label

# 1) Manual cross-entropy
manual_ce = -math.log(probs[true_label].item())

# 2) Using PyTorch
# F.cross_entropy expects logits (it applies log-softmax internally), while
# F.nll_loss expects log-probabilities. We already have probabilities, so we
# take their log and build a single "batch" example:
logits = torch.log(probs).unsqueeze(0)  # shape (1, 4), log-probabilities
targets = torch.tensor([true_label])    # shape (1,)

ce_torch = F.nll_loss(logits, targets)  # nll_loss expects log-probabilities
# Equivalent here: F.cross_entropy(logits, targets), because probs sums to 1,
# so log-softmax leaves log(probs) unchanged (up to rounding)

print("Manual cross-entropy:", manual_ce)
print("PyTorch cross-entropy:", ce_torch.item())


Manual cross-entropy: 0.916290716972994
PyTorch cross-entropy: 0.9162907004356384


# <a id="implementation"></a>6. Implementing a Simple LSTM Text Generator in PyTorch

Let’s build a small example that:
1. **Prepares a tiny text dataset**.
2. Splits it into input–target pairs for next-token prediction.
3. Defines and trains an LSTM-based model.
4. **Generates** text from the trained model.

### <a id="data-prep"></a>Data Preparation

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Example small text
text = "hello world hello pytorch hello world again"

# Tokenize (word-level for simplicity)
words = text.split()
vocab = list(dict.fromkeys(words))  # unique words in first-seen order (set() order can vary between runs)
vocab_size = len(vocab)
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}

print("Vocabulary:", vocab)
print("Mapping (word -> idx):", word2idx)
print("Vocab size:", vocab_size)

# Convert words to indices
indices = [word2idx[w] for w in words]

# We'll choose a sequence length
seq_length = 3

# Prepare training data
input_sequences = []
target_words = []

for i in range(len(indices) - seq_length):
    input_seq = indices[i:i+seq_length]   # 3 words
    target = indices[i+seq_length]        # the 4th word is the label
    input_sequences.append(input_seq)
    target_words.append(target)

input_sequences = torch.tensor(input_sequences, dtype=torch.long)
target_words = torch.tensor(target_words, dtype=torch.long)

print("Input sequences shape:", input_sequences.shape)
print("Target words shape:", target_words.shape)

Vocabulary: ['hello', 'world', 'pytorch', 'again']
Mapping (word -> idx): {'hello': 0, 'world': 1, 'pytorch': 2, 'again': 3}
Vocab size: 4
Input sequences shape: torch.Size([4, 3])
Target words shape: torch.Size([4])


In [13]:
for seq, targ in zip(input_sequences, target_words):
    seq = seq.numpy()
    targ = targ.item()
    seq_detokenized = list(map(idx2word.get, seq))
    targ_detokenized = idx2word.get(targ)
    print(f"{seq} => {targ}")
    print(f"  {seq_detokenized} => {targ_detokenized}")

[0 1 0] => 2
  ['hello', 'world', 'hello'] => pytorch
[1 0 2] => 0
  ['world', 'hello', 'pytorch'] => hello
[0 2 0] => 1
  ['hello', 'pytorch', 'hello'] => world
[2 0 1] => 3
  ['pytorch', 'hello', 'world'] => again


### <a id="model-def"></a>Model Definition


In [14]:
class SimpleLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(SimpleLSTM, self).__init__()
        # 1) Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # 2) LSTM layer
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # 3) Linear output layer
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # x: (batch_size, seq_length)
        embedded = self.embedding(x)                # (batch_size, seq_length, embed_dim)
        lstm_out, (h_n, c_n) = self.lstm(embedded)  # (batch_size, seq_length, hidden_dim)
        final_hidden = lstm_out[:, -1, :]           # last time step
        logits = self.fc(final_hidden)              # (batch_size, vocab_size)
        return logits

### <a id="training-loop"></a>Training Loop

In [15]:
embed_dim = 8
hidden_dim = 16
learning_rate = 0.01
num_epochs = 200

model = SimpleLSTM(vocab_size, embed_dim, hidden_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()

    optimizer.zero_grad()
    logits = model(input_sequences)  # shape: (batch_size, vocab_size)
    loss = criterion(logits, target_words)

    loss.backward()
    optimizer.step()

    if (epoch+1) % 50 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [50/200], Loss: 0.0098
Epoch [100/200], Loss: 0.0031
Epoch [150/200], Loss: 0.0020
Epoch [200/200], Loss: 0.0014


### <a id="generate-text"></a>Generating Text

We can now generate text by **sampling** the model’s predictions iteratively.

Feel free to experiment with:
- **Different seeds**.
- **Different sampling strategies** (e.g., greedy vs. top-k).  
- A **larger corpus** (like Tiny Shakespeare).
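
To make the sampling strategies concrete, here is a minimal NumPy sketch of greedy decoding and top-k sampling on a single vector of logits. The helper names `greedy_pick` and `top_k_sample` and the example logits are assumptions for illustration; the notebook's own `generate_text` samples from the full distribution with `torch.multinomial`.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - np.max(logits))  # subtract max for numerical stability
    return z / z.sum()

def greedy_pick(logits):
    """Always return the most likely token index."""
    return int(np.argmax(logits))

def top_k_sample(logits, k, rng):
    """Sample only among the k highest-scoring tokens."""
    top = np.argsort(logits)[-k:]   # indices of the k largest logits
    probs = softmax(logits[top])    # renormalize over those k tokens
    return int(rng.choice(top, p=probs))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1, -1.0])  # hypothetical model output

print(greedy_pick(logits))
print([top_k_sample(logits, k=2, rng=rng) for _ in range(10)])
```

Greedy decoding is deterministic and tends to loop on small corpora; top-k keeps some randomness while never picking very unlikely tokens. With `k=1`, top-k reduces to greedy.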

In [16]:
def generate_text(model, seed_words, num_words=5):
    model.eval()
    words_generated = seed_words[:]

    # Convert seed_words to indices
    current_seq = [word2idx[w] for w in seed_words]

    for _ in range(num_words):
        inp = torch.tensor([current_seq], dtype=torch.long)
        with torch.no_grad():
            logits = model(inp)  # shape: (1, vocab_size)
        probs = torch.softmax(logits, dim=-1).squeeze()  # shape: (vocab_size,)

        # Sample from probability distribution
        next_idx = torch.multinomial(probs, 1).item()
        next_word = idx2word[next_idx]
        words_generated.append(next_word)

        # Slide the window (drop the first index, append new index)
        current_seq = current_seq[1:] + [next_idx]

    return " ".join(words_generated)

# Let's try generating with a seed of length = seq_length (3)
seed = ["hello", "world", "hello"]  # must be in vocab
generated_text = generate_text(model, seed, num_words=5)
print("Generated Text:", generated_text)

Generated Text: hello world hello pytorch hello world again hello


### Exercise: Experiment with the Generator
1. Change the `num_words` to 10 or 20 and see if your text generation forms any repetitive patterns.  
2. Try a **larger** dataset if you have one. Compare the coherence of the generated text.  
3. Print out intermediate hidden states if you’re curious about how the model’s representation changes over time.


In [17]:
# Suppose 'model' is our trained LSTM model, 'word2idx' and 'idx2word' are our mappings.

def generate_text_with_hidden(model, seed_words, num_words=10):
    """
    Generate text from the model, returning the hidden states as well.
    """
    model.eval()
    words_generated = seed_words[:]
    hidden_states = []   # store hidden states at each step

    current_seq = [word2idx[w] for w in seed_words]

    # Hidden and cell state, if needed
    # We'll assume 1-layer LSTM, batch_size=1
    h, c = None, None

    for _ in range(num_words):
        inp = torch.tensor([current_seq], dtype=torch.long)
        with torch.no_grad():
            # Modify forward pass to capture intermediate hidden states
            # We can do this by running the embedding + LSTM manually:
            embedded = model.embedding(inp)  # shape (1, seq_length, embed_dim)
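            # We pass in (h, c) if they exist, otherwise let the LSTM init them.
            # Note: the whole window is re-fed at every step, so carrying (h, c)
            # over means the LSTM sees earlier tokens again on top of its memory.
            # The model was trained from a zero state; pass None to match training.
            lstm_out, (h, c) = model.lstm(embedded, (h, c) if h is not None else None)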

            # final time step
            final_hidden = lstm_out[:, -1, :]

            # For debugging: store the hidden state in a list
            hidden_states.append(final_hidden.detach().cpu().numpy())

            logits = model.fc(final_hidden)
            probs = torch.softmax(logits, dim=-1).squeeze()
            next_idx = torch.multinomial(probs, 1).item()

        next_word = idx2word[next_idx]
        words_generated.append(next_word)
        current_seq = current_seq[1:] + [next_idx]

    return " ".join(words_generated), hidden_states

# Example usage
seed = ["hello", "world", "hello"]
generated_text, h_states = generate_text_with_hidden(model, seed, num_words=8)
print("Generated Text:\n", generated_text)
print("Intermediate hidden states shapes:")
for i, hs in enumerate(h_states):
    print(f" Step {i+1}: {hs.shape}")


Generated Text:
 hello world hello pytorch hello world again pytorch hello world again
Intermediate hidden states shapes:
 Step 1: (1, 16)
 Step 2: (1, 16)
 Step 3: (1, 16)
 Step 4: (1, 16)
 Step 5: (1, 16)
 Step 6: (1, 16)
 Step 7: (1, 16)
 Step 8: (1, 16)


# Fun Things -- Shakespeare

### Step 1: Pre-process the textual data

In [18]:
import os

DATA_PATH = os.path.abspath(f"{REPOPATH}/shakespeare")
filelist = os.listdir(DATA_PATH)
filelist = list(map(lambda f: os.path.join(DATA_PATH, f), filelist))

In [19]:

def load_shakespeare_texts(filelist):
    """
    Loads all text from filelist and returns all texts concatenated
    """
    all_text = ""

    for file_path in filelist:
        with open(file_path, 'r', encoding='utf-8') as f:
            all_text += f.read() + "\n"  # add a newline at the end of each file

    return all_text

full_text = load_shakespeare_texts(filelist)
print("Total length of combined text:", len(full_text))


Total length of combined text: 5283837


**Character-Level Tokenization**

Since this is a character-level model, our “tokens” are just unique characters found in the text:

* Identify the unique set of characters.
* Map each character to a unique integer index.

In [20]:
import torch
import numpy as np

# Create vocabulary of unique characters
chars = sorted(list(set(full_text)))
vocab_size = len(chars)

print("Unique chars found:", vocab_size)
print("Example of characters:", chars[:50])

# Create mapping from character to index (and reverse)
_char2idx = {ch: i for i, ch in enumerate(chars)}
_idx2char = {i: ch for ch, i in _char2idx.items()}

# Add special characters
for special_token in ["<|UNK|>"]:
    k = len(_char2idx)
    _char2idx[special_token] = k
    _idx2char[k] = special_token

# Utility functions
def char2idx(ch):
    # Unknown characters map to the index of <|UNK|>, not to the token string
    return [_char2idx.get(c, _char2idx["<|UNK|>"]) for c in ch]
def idx2char(idx):
    if isinstance(idx, torch.Tensor):
        return idx2char(idx.detach().cpu().numpy())
    if isinstance(idx, np.ndarray):
        return idx2char(idx.tolist())
    if isinstance(idx, int):
        return _idx2char[idx]
    return [idx2char(i) for i in idx]


Unique chars found: 80
Example of characters: ['\t', '\n', ' ', '!', '$', '&', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y']


**Convert Text to Indices**

Convert the entire text into a list (or array) of integer indices. This will make it easier to feed into PyTorch.

In [21]:
# Convert all text to indices
data_as_indices = char2idx(full_text)  # [_char2idx[ch] for ch in full_text]
data_tensor = torch.tensor(data_as_indices, dtype=torch.long)
print("data_tensor shape:", data_tensor.shape)

data_tensor shape: torch.Size([5283837])


**Create Training Sequences**

For character-level language modeling, a common approach is:

* Pick a sequence length, e.g. `seq_length = 100`.
* For each chunk of `seq_length` characters, the target is the same chunk shifted one character ahead, so the model predicts the next character at every position.

We can use PyTorch's `Dataset`...

In [22]:
from torch.utils.data import Dataset, DataLoader

class CharDataset(Dataset):
    def __init__(self, data_tensor, seq_length):
        self.data = data_tensor
        self.seq_length = seq_length

    def __len__(self):
        # We can form this many sequences (minus 1 for the target)
        return len(self.data) // self.seq_length - 1

    def __getitem__(self, idx):
        start = idx * self.seq_length
        x_seq = self.data[start : start + self.seq_length]
        # Targets are the subsequent seq_length characters
        y_seq = self.data[start+1 : start + self.seq_length + 1]
        return x_seq, y_seq

seq_length = 30
dataset = CharDataset(data_tensor, seq_length=seq_length)
print("Dataset size:", len(dataset))

# For demonstration, let's get one example
example_x, example_y = dataset[0]
print("Example X (indices):", example_x[:10])
print("Example Y (index):", example_y[:10])
print("Example X (decoded):", idx2char(example_x[:10]))# "".join(idx2char(i.item()) for i in example_x[:30]))
print("Example Y (decoded):", idx2char(example_y[:10]))


Dataset size: 176126
Example X (indices): tensor([ 0, 25,  2, 37, 33, 28, 43, 45, 37, 37])
Example Y (index): tensor([25,  2, 37, 33, 28, 43, 45, 37, 37, 29])
Example X (decoded): ['\t', 'A', ' ', 'M', 'I', 'D', 'S', 'U', 'M', 'M']
Example Y (decoded): ['A', ' ', 'M', 'I', 'D', 'S', 'U', 'M', 'M', 'E']


In [23]:
dataset[0][0]

tensor([ 0, 25,  2, 37, 33, 28, 43, 45, 37, 37, 29, 42,  2, 38, 33, 31, 32, 44,
         6, 43,  2, 28, 42, 29, 25, 37,  1,  1,  1,  0])

In [24]:
batch_size = 100

def batch_second(batch):
    x, y = list(zip(*batch))
    x = torch.stack(x, 1)
    y = torch.stack(y, 1)

    return x, y

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True, collate_fn=batch_second)

### Step 2: Model Definition (LSTM)

We’ll define a character-level LSTM model:

1. Embedding: maps integer character indices to dense vectors (optional, but often helps).
1. LSTM: one or more LSTM layers that process the embedded sequence.
1. Linear: output layer to predict the next character’s index.

In [25]:
import torch.nn as nn

class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=False)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden_state=None):
        """
        x: (seq_length, batch_size)
        hidden_state: tuple (h, c) for LSTM hidden/cell states (if you want to pass it in)
        Returns: logits (seq_length, batch_size, vocab_size), updated_hidden_state
        """
        # 1) Embedding
        embedded = self.embedding(x)  # shape: (seq_length, batch_size, embed_dim)

        # 2) LSTM
        if hidden_state is None:
            out, (h, c) = self.lstm(embedded)  # out: (seq_length, batch_size, hidden_dim)
        else:
            out, (h, c) = self.lstm(embedded, hidden_state)

        # 3) Fully connected (we want to produce a prediction at each time step)
        logits = self.fc(out)  # shape: (seq_length, batch_size, vocab_size)

        return logits, (h, c)

    def init_hidden(self, batch_size):
        """
        Utility to initialize the hidden state (h, c) to zeros.
        Returns: h0, c0 (num_layers, batch_size, hidden_dim)
        """
        device = next(self.parameters()).device
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        return (h0, c0)


class CharGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=False)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden_state=None):
        """
        x: (seq_length, batch_size)
        hidden_state: tensor h holding the GRU hidden state (if you want to pass it in)
        Returns: logits (seq_length, batch_size, vocab_size), updated_hidden_state
        """
        # 1) Embedding
        embedded = self.embedding(x)  # shape: (seq_length, batch_size, embed_dim)

        # 2) GRU
        if hidden_state is None:
            out, h = self.gru(embedded)  # out: (seq_length, batch_size, hidden_dim)
        else:
            out, h = self.gru(embedded, hidden_state)

        # 3) Fully connected (we want to produce a prediction at each time step)
        logits = self.fc(out)  # shape: (seq_length, batch_size, vocab_size)

        return logits, h

    def init_hidden(self, batch_size):
        """
        Utility to initialize the hidden state h to zeros.
        Returns: h0 (num_layers, batch_size, hidden_dim)
        """
        device = next(self.parameters()).device
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        return h0


### Step 3: Training Routine

**Training Setup**

We define:

* A loss function (CrossEntropyLoss), typical for next-character prediction.
* An optimizer (e.g., Adam or RMSprop).
* A device (GPU if available, otherwise CPU).
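
The setup in this section also attaches a `ReduceLROnPlateau` scheduler to the optimizer. As a rough sketch of what it does, the snippet below feeds it a loss that never improves and records the learning rate. The dummy parameter and the constant loss are assumptions made purely for illustration.

```python
import torch
from torch import nn, optim

# A single dummy parameter is enough to build an optimizer
param = nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=5)

lrs = []
for epoch in range(20):
    # Pretend the loss never improves after the first epoch
    scheduler.step(1.0)
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs[0], lrs[-1])
```

Each time the loss stalls for more than `patience` epochs, the learning rate is multiplied by `factor`, so with `factor=0.8` it shrinks by 20% per plateau rather than being halved.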

In [26]:
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

model = CharLSTM(vocab_size, embed_dim=512, hidden_dim=512, num_layers=3)
# model = CharGRU(vocab_size, embed_dim=512, hidden_dim=512, num_layers=3)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=5)  # Multiply the learning rate by 0.8 when the loss plateaus

history = []

Using device: cuda


**Training Loop**

At each iteration:

1. Get a batch (x, y) from the dataloader. Because of the `batch_second` collate function, both x and y have shape (seq_length, batch_size), and y is x shifted one character to the right.
1. The model outputs logits of shape (seq_length, batch_size, vocab_size), one prediction per input position.
1. We flatten the logits to (seq_length * batch_size, vocab_size) and the targets to (seq_length * batch_size,), so that CrossEntropyLoss averages over every time step of every sequence.

A simpler alternative would be to use only each sequence’s final character as the label and compute the loss from the last time step’s logits. Here we take the more typical approach and predict the next character at every time step, which gives seq_length training signals per sequence instead of one.

**Case: Predict next char at every time step**

The one-step shift between inputs and targets is handled at the dataset level, so the training loop needs no extra shifting.
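
The flattening step is easy to get wrong, so here is a small check on random data (sizes chosen arbitrarily). It compares the loss on flattened tensors with the average of the per-time-step losses.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_length, batch_size, vocab_size = 4, 3, 10   # illustrative sizes

logits = torch.randn(seq_length, batch_size, vocab_size)
targets = torch.randint(0, vocab_size, (seq_length, batch_size))

criterion = nn.CrossEntropyLoss()

# Flatten time and batch into one axis; row k of the logits still lines up with entry k of the targets
flat_loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))

# Same quantity computed one time step at a time, then averaged
per_step = torch.stack([criterion(logits[t], targets[t]) for t in range(seq_length)])

print(flat_loss.item(), per_step.mean().item())
```

The two numbers agree because every time step contributes the same number of samples to the mean.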

In [27]:
%%time
# num_epochs = 1000
num_epochs = 1
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    for x_seq, y_seq in dataloader:
        x_seq, y_seq = x_seq.to(device), y_seq.to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Initialize hidden state
        hidden_state = model.init_hidden(batch_size=x_seq.size(1))

        # Forward pass
        logits, hidden_state = model(x_seq, hidden_state)
        # logits: (seq_length, batch_size, vocab_size)

        # Reshape logits and targets for cross-entropy
        # We want CE across all time steps
        logits_reshaped = logits.view(-1, vocab_size)   # (seq_length*batch_size, vocab_size)
        targets_reshaped = y_seq.view(-1)               # (seq_length*batch_size,)

        loss = criterion(logits_reshaped, targets_reshaped)
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    if epoch == 0 or (epoch + 1) % 10 == 0 or epoch + 1 == num_epochs:
        # exp(-avg_loss) is the geometric-mean probability given to the correct character, not an accuracy
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Correct prediction: {math.exp(-avg_loss):.1%}")
    scheduler.step(avg_loss)
    history.append(avg_loss)


Epoch 1/1, Loss: 1.7219, Correct prediction: 17.9%
CPU times: user 31.6 s, sys: 375 ms, total: 32 s
Wall time: 31.7 s
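
The percentage printed next to the loss is `exp(-loss)`, which is the geometric-mean probability the model assigns to the correct character. Its reciprocal is the perplexity. A short sketch using the loss value reported above:

```python
import math

avg_loss = 1.7219  # average cross-entropy (in nats) reported above

mean_prob = math.exp(-avg_loss)   # geometric-mean probability of the correct character
perplexity = math.exp(avg_loss)   # effective number of equally likely choices

print(f"mean probability: {mean_prob:.1%}, perplexity: {perplexity:.2f}")
```

A perplexity of k means the model is, on average, as uncertain as if it were choosing uniformly among k characters.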


In [28]:
# (Optional) Save the model for future use
MODEL_SAVE_PATH = os.path.join(REPOPATH, "models")
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

model_name = "shakespeare_lstm_checkpoint"

checkpoint_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.pt")
torch.save({
    "last_epoch": epoch,
    "last_loss": avg_loss,
    "history": history,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
}, checkpoint_path)
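
Restoring such a checkpoint follows the reverse path. The sketch below is self-contained and therefore uses a tiny stand-in model and a temporary file (both assumptions for illustration). With the notebook's objects you would rebuild `CharLSTM` with the same hyperparameters and load from `checkpoint_path` instead.

```python
import os
import tempfile

import torch
from torch import nn, optim

# Tiny stand-in model; in the notebook this would be the trained CharLSTM
tiny_model = nn.Linear(4, 2)
tiny_optimizer = optim.Adam(tiny_model.parameters(), lr=0.001)

path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pt")
torch.save({
    "last_epoch": 0,
    "model_state_dict": tiny_model.state_dict(),
    "optimizer_state_dict": tiny_optimizer.state_dict(),
}, path)

# Rebuild objects with the same architecture, then restore their state
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.Adam(restored_model.parameters(), lr=0.001)

checkpoint = torch.load(path, map_location="cpu")
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["last_epoch"] + 1  # resume from the next epoch
```

Saving the optimizer state alongside the weights matters for Adam, because its running moment estimates would otherwise restart from zero.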


### Step 4: Inference

In [29]:
@torch.no_grad()  # Inference only: no gradient tracking needed
def generate_text(model, start_string="ROMEO:", length=500, temperature=1.0):
    """
    Generates text one character at a time.
    - start_string: initial prompt
    - length: number of characters to generate
    - temperature: sampling diversity (1.0 => neutral, >1 => more random)
    """
    model.eval()

    # Convert start string to indices
    input_indices = char2idx(start_string)  # [char2idx(ch) for ch in start_string]
    input_tensor = torch.tensor([input_indices], dtype=torch.long, device=device).transpose(0, 1)

    # Initialize hidden state
    hidden_state = model.init_hidden(batch_size=1)

    # "Warm up" the model with the start string
    # Feed every prompt character except the last; the last one is fed in the first generation step
    if input_tensor.size(0) > 1:
        _, hidden_state = model(input_tensor[:-1], hidden_state)

    # The last character in start_string, shape (1,)
    last_char_idx = input_tensor[-1]
    output_text = start_string

    # Now generate 'length' more characters
    for _ in range(length):
        logits, hidden_state = model(last_char_idx.unsqueeze(1), hidden_state)
        # logits shape: (1, 1, vocab_size)
        logits = logits[-1, :, :]  # take the last time step => shape (1, vocab_size)

        # Apply temperature
        logits = logits / temperature

        probs = torch.softmax(logits, dim=-1).squeeze()  # shape (vocab_size,)
        next_idx = torch.multinomial(probs, 1).item()

        # Append to output
        next_char = idx2char(next_idx)
        output_text += next_char

        # Update last_char_idx
        last_char_idx = torch.tensor([next_idx], device=device)

    return output_text

# Example usage after training:
generated = generate_text(model, start_string="ROMEO:", length=300, temperature=0.8)
print("Generated text:\n", generated)


Generated text:
 ROMEO:)	|


SCENE	Betwine not one. Rest well, or I will as all the Murder, his feeding both as fetch in thine own; here than a fair sight before thee.

GUIDERIUS	And the place how not there's in this adjon the injurious back: his head in you.

PIANDA	What, sir.

BANTIO	Nay, althood, well! of them.

LADY B
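

The `temperature` argument rescales the logits before the softmax. The sketch below uses made-up scores for a three-character vocabulary to show how the distribution sharpens or flattens.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # made-up scores for a 3-character vocabulary

for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])

# Sampling draws an index according to these probabilities
torch.manual_seed(0)
next_idx = torch.multinomial(torch.softmax(logits / 0.8, dim=-1), 1).item()
```

Temperatures below 1 concentrate probability on the highest-scoring character, while temperatures above 1 move the distribution toward uniform and make the samples more varied.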


In [30]:
shAIkspear = []
seeds = ["JULIET:", "ROMEO:", "[", "Booboo dog ", "to be or not to be "]
length = 500

for s in seeds:
    generated = generate_text(model, start_string=s, length=length, temperature=0.8)
    # print("Generated text:\n", generated)
    shAIkspear.append((s, generated))

In [31]:
for s, g in shAIkspear:
    print(f"--- Seed: '{s}' ---")
    print(g)
    print(f"--- STOP ---\n\n")

--- Seed: 'JULIET:' ---
JULIET:
	|  thine is a particious blood of the earth.

PANDARUS	Prithee, my most injurious shall pass of his death, and hath blament to the suitable of lady to enough'd of wits sate to the world for mine prompent.
	A new in his holigated whom here word by villain's lordshor, when he do he doth from to tell who present dishops of mine.

	[Knight]




	KING HENRY IV


ACT V



SCENE V	Think the same to perplaint as a rich, day did such on the eye direction, with the image and rest, his eyes of the breath
--- STOP ---


--- Seed: 'ROMEO:' ---
ROMEO:)

LUCIUS	Do you down me no.

DUKE VINCENTIO	What I may attended you.

	[Enter PANTOLIO]

COUNTESS	Ay not into me.

CASSIUS	What brief you therefore more feather, you will be the cardinal: name.

QUEEN MARGARET	O, my lord, not faith.

WELUS	He gone, so true grow upon my corrub again; he that you not out, with her honour's majestor.

OPHELIA	My name to said, my lord;
	A world at York, old that I am mine may have sprink
	Y

---


# Let's train on some non-English characters

The text files used below (Chinese-language novels from the *Three-Body* series) come from this repository: https://github.com/aboutjm/Automation/blob/master/book

In [32]:
import os
import urllib.request
import urllib.parse

DATA_PATH = os.path.join(REPOPATH, "cnbook")
os.makedirs(DATA_PATH, exist_ok=True)

BASE_URL = "https://raw.githubusercontent.com/aboutjm/Automation/master/book/%5B三体1-3%2B三体X修订增补%5DTXT精校版.刘慈欣/"
filelist = [
    "三体1疯狂年代.txt",
    "三体2黑暗森林.txt",
    "三体3死神永生.txt",
    "三体X修订增补版.txt",
]

filepaths = []

for idx, f in enumerate(filelist):
    url = BASE_URL + f
    encoded_url = urllib.parse.quote(url, safe=':/%')
    file_path = os.path.join(DATA_PATH, f"file{idx}.txt")
    if not os.path.exists(file_path):
        print(f"Downloading {encoded_url}...")
        !wget $encoded_url -O $file_path -q
    filepaths.append(file_path)


In [33]:
import re

texts = []

for fpath in filepaths:
    with open(fpath, "r", encoding="gbk") as f:
        text = f.read()
        texts.append(text)
merged_text = "\n\n".join(texts)

# Cleanup: Replace \u3000 with space
merged_text = merged_text.replace("\u3000", " ")
# Cleanup: Replace “ and ” with "
merged_text = merged_text.replace("“", '"')
merged_text = merged_text.replace("”", '"')
# Cleanup: collapse runs of spaces, tabs, and newlines into a single character each
merged_text = re.sub(r"[ ]+", " ", merged_text)
merged_text = re.sub(r"[\t]+", "\t", merged_text)
merged_text = re.sub(r"[\n]+", "\n", merged_text)

In [34]:
print(merged_text[:100])
print(merged_text[-100:])

 三体（中国科幻基石丛书） 
 刘慈欣著
 "基石"是个平实的词，不够"炫"，却能够准确传达我们对构建中的中国科幻繁华巨厦的情感与信心，因此，我们用它来作为这套原创丛书的名字。
 最近十年，是科幻创作
增辉不少。正是因为fengziying同学和其他网友的热情支持和鼓励，才让笔者终于下定决心去整理和修订这部极不成熟的作品。希望不会让大家太失望。
 Isaiah（phenixus）
10.12.28
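

To see the whitespace part of the cleanup in isolation, here is the same sequence of substitutions applied to a short made-up string. The helper name `normalize_whitespace` is hypothetical and only bundles the steps for this demonstration.

```python
import re

def normalize_whitespace(text):
    # Hypothetical helper that bundles the whitespace cleanup steps used above
    text = text.replace("\u3000", " ")      # ideographic space -> ASCII space
    text = re.sub(r"[ ]+", " ", text)       # collapse runs of spaces
    text = re.sub(r"[\t]+", "\t", text)     # collapse runs of tabs
    text = re.sub(r"[\n]+", "\n", text)     # collapse runs of newlines
    return text

sample = "a\u3000\u3000b  c\n\n\nd\t\te"
cleaned = normalize_whitespace(sample)
print(repr(cleaned))
```

Collapsing repeated whitespace keeps the model from spending capacity on predicting long runs of blanks.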



**Character-Level Tokenization**

Since this is a character-level model, our “tokens” are just unique characters found in the text:

* Identify the unique set of characters.
* Map each character to a unique integer index.
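
The two steps can be sketched on a toy corpus. The names `char_to_idx`, `idx_to_char`, `encode`, and `decode` are illustrative and differ from the notebook's helpers. The sketch also reserves an index for unseen characters and counts the vocabulary after adding it.

```python
from collections import Counter

text = "hello world"  # toy corpus for illustration

counts = Counter(text)
chars = sorted(counts, key=counts.get, reverse=True)   # most frequent first
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for ch, i in char_to_idx.items()}

# Reserve one extra index for characters never seen in the corpus
UNK = "<|UNK|>"
char_to_idx[UNK] = len(char_to_idx)
idx_to_char[char_to_idx[UNK]] = UNK
vocab_size = len(char_to_idx)  # counted after adding the special token

def encode(s):
    return [char_to_idx.get(c, char_to_idx[UNK]) for c in s]

def decode(indices):
    return "".join(idx_to_char[i] for i in indices)

print(encode("hold"), decode(encode("hold")))
```

Sorting by frequency is optional, but it gives the most common characters the smallest indices, which makes the printed tensors easier to read.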

In [35]:
import torch
import numpy as np

from collections import Counter

# Create vocabulary of unique characters
char_counts = Counter(merged_text)
chars = sorted(char_counts.keys(), key=char_counts.get, reverse=True)
vocab_size = len(chars)

print("Unique chars found:", vocab_size)
print("Example of characters (top):", chars[:50])
print("Example of characters (bot):", chars[-50:])

# Create mapping from character to index (and reverse)
_char2idx = {ch: i for i, ch in enumerate(chars)}
_idx2char = {i: ch for ch, i in _char2idx.items()}

# Add special characters
for special_token in ["<|UNK|>"]:
    k = len(_char2idx)
    _char2idx[special_token] = k
    _idx2char[k] = special_token

# Utility functions
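def char2idx(ch):
    # Unknown characters map to the index of <|UNK|>, not to the string itself.
    # Note: vocab_size above was computed before <|UNK|> was added, so a model
    # built with it has no embedding row for that index.
    return [_char2idx.get(c, _char2idx["<|UNK|>"]) for c in ch]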
def idx2char(idx):
    if isinstance(idx, torch.Tensor):
        return idx2char(idx.detach().cpu().numpy())
    if isinstance(idx, np.ndarray):
        return idx2char(idx.tolist())
    if isinstance(idx, int):
        return _idx2char.get(idx, "<|UNK|>")
    return [idx2char(i) for i in idx]


Unique chars found: 3772
Example of characters (top): ['，', '的', '。', '"', '一', ' ', '是', '\n', '了', '在', '这', '不', '个', '有', '他', '人', '到', '中', '们', '我', '上', '地', '时', '来', '那', '大', '说', '着', '能', '出', '看', '和', '后', '你', '就', '面', '也', '可', '现', '没', '都', '她', '对', '但', '过', '星', '？', '下', '太', '子']
Example of characters (bot): ['贷', '茹', '豹', '彷', '徨', '缥', '缈', '茗', '荑', '宛', '凰', '旬', '捺', '狐', '猴', '迭', '赦', '朱', '茧', '睥', '睨', '伫', '氤', '氲', '▽', '◇', '霄', '谚', '嗫', '嚅', '宸', '谕', '鸢', '踌', '躇', '匾', '炒', '隘', '齑', '酬', '敝', '帚', '诟', '佬', '吭', '琢', '裆', '怂', '恿', '@']


**Convert Text to Indices**

Convert the entire text into a list (or array) of integer indices. This will make it easier to feed into PyTorch.

In [36]:
# Convert all text to indices
data_as_indices = char2idx(merged_text)
data_tensor = torch.tensor(data_as_indices, dtype=torch.long)
print("data_tensor shape:", data_tensor.shape)

data_tensor shape: torch.Size([1017810])


**Create Training Sequences**

For character-level language modeling, a common approach is:

* Pick a sequence length, e.g. seq_length = 100.
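* For each input sequence of seq_length characters, the target is the same window shifted one character to the right, so position i of the target holds the character that follows position i of the input.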
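
A toy version of this windowing, using the integers 0 to 9 as a stand-in for the encoded text and the same indexing as the dataset class in this section:

```python
import torch

data = torch.arange(10)   # stand-in for data_tensor: 0, 1, ..., 9
seq_length = 3

num_sequences = len(data) // seq_length - 1   # same formula as the dataset's __len__
pairs = []
for idx in range(num_sequences):
    start = idx * seq_length
    x_seq = data[start : start + seq_length]
    y_seq = data[start + 1 : start + seq_length + 1]   # inputs shifted by one
    pairs.append((x_seq.tolist(), y_seq.tolist()))

for x_seq, y_seq in pairs:
    print(x_seq, "->", y_seq)
```

The windows do not overlap, so each character appears as an input in at most one training sequence per epoch.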

We can use PyTorch's `Dataset`...

In [37]:
from torch.utils.data import Dataset, DataLoader

class CharDataset(Dataset):
    def __init__(self, data_tensor, seq_length):
        self.data = data_tensor
        self.seq_length = seq_length

    def __len__(self):
        # We can form this many sequences (minus 1 for the target)
        return len(self.data) // self.seq_length - 1

    def __getitem__(self, idx):
        start = idx * self.seq_length
        x_seq = self.data[start : start + self.seq_length]
        # Targets are the subsequent seq_length characters
        y_seq = self.data[start+1 : start + self.seq_length + 1]
        return x_seq, y_seq

seq_length = 30
dataset = CharDataset(data_tensor, seq_length=seq_length)
print("Dataset size:", len(dataset))

# For demonstration, let's get one example
example_x, example_y = dataset[0]
print("Example X (indices):", example_x[:10])
print("Example Y (index):", example_y[:10])
print(f"Example X (decoded): \"{''.join(idx2char(example_x[:10]))}\"")
print(f"Example Y (decoded): \"{''.join(idx2char(example_y[:10]))}\"")


Dataset size: 33926
Example X (indices): tensor([   5,   58,   54, 1777,   17,  178,  347,  688,  300,  458])
Example Y (index): tensor([  58,   54, 1777,   17,  178,  347,  688,  300,  458, 1371])
Example X (decoded): " 三体（中国科幻基石"
Example Y (decoded): "三体（中国科幻基石丛"


In [38]:
batch_size = 100

def batch_second(batch):
    x, y = list(zip(*batch))
    x = torch.stack(x, 1)
    y = torch.stack(y, 1)

    return x, y

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True, collate_fn=batch_second)

### Step 2: Model Definition (LSTM)

We’ll define a character-level LSTM model:

1. Embedding: maps integer character indices to dense vectors (optional, but often helps).
1. LSTM: one or more LSTM layers that process the embedded sequence.
1. Linear: output layer to predict the next character’s index.

In [39]:
import torch.nn as nn

class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super(CharLSTM, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=False)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden_state=None):
        """
        x: (seq_length, batch_size)
        hidden_state: tuple (h, c) for LSTM hidden/cell states (if you want to pass it in)
        Returns: logits (seq_length, batch_size, vocab_size), updated_hidden_state
        """
        # 1) Embedding
        embedded = self.embedding(x)  # shape: (seq_length, batch_size, embed_dim)

        # 2) LSTM
        if hidden_state is None:
            out, (h, c) = self.lstm(embedded)  # out: (seq_length, batch_size, hidden_dim)
        else:
            out, (h, c) = self.lstm(embedded, hidden_state)

        # 3) Fully connected (we want to produce a prediction at each time step)
        logits = self.fc(out)  # shape: (seq_length, batch_size, vocab_size)

        return logits, (h, c)

    def init_hidden(self, batch_size):
        """
        Utility to initialize the hidden state (h, c) to zeros.
        Returns: h0, c0 (num_layers, batch_size, hidden_dim)
        """
        device = next(self.parameters()).device
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=device)
        return (h0, c0)


### Step 3: Training Routine

**Training Setup**

We define:

* A loss function (CrossEntropyLoss), typical for next-character prediction.
* An optimizer (e.g., Adam or RMSprop).
* A device (GPU if available, otherwise CPU).

In [40]:
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

model = CharLSTM(vocab_size, embed_dim=1024, hidden_dim=512, num_layers=3)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=5)  # Multiply the learning rate by 0.8 when the loss plateaus

history = []

Using device: cuda


**Training Loop**

At each iteration:

1. Get a batch (x, y) from the dataloader. Because of the `batch_second` collate function, both x and y have shape (seq_length, batch_size), and y is x shifted one character to the right.
1. The model outputs logits of shape (seq_length, batch_size, vocab_size), one prediction per input position.
1. We flatten the logits to (seq_length * batch_size, vocab_size) and the targets to (seq_length * batch_size,), so that CrossEntropyLoss averages over every time step of every sequence.

A simpler alternative would be to use only each sequence’s final character as the label and compute the loss from the last time step’s logits. Here we take the more typical approach and predict the next character at every time step.

**Case: Predict next char at every time step**

The one-step shift between inputs and targets is done inside `CharDataset`, so the training loop needs no extra shifting.

In [41]:
%%time

MODEL_SAVE_PATH = os.path.join(REPOPATH, "models")
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

num_epochs = 1
model.train()

history = []
best_loss = float("inf")

for epoch in range(num_epochs):
    total_loss = 0.0
    for x_seq, y_seq in dataloader:
        x_seq, y_seq = x_seq.to(device), y_seq.to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Initialize hidden state
        hidden_state = model.init_hidden(batch_size=x_seq.size(1))

        # Forward pass
        logits, hidden_state = model(x_seq, hidden_state)
        # logits: (seq_length, batch_size, vocab_size)

        # Reshape logits and targets for cross-entropy
        # We want CE across all time steps
        logits_reshaped = logits.view(-1, vocab_size)   # (seq_length*batch_size, vocab_size)
        targets_reshaped = y_seq.view(-1)               # (seq_length*batch_size,)

        loss = criterion(logits_reshaped, targets_reshaped)
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    if epoch == 0 or (epoch + 1) % 10 == 0 or epoch + 1 == num_epochs:
        # exp(-avg_loss) is the geometric-mean probability given to the correct character, not an accuracy
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Correct prediction: {math.exp(-avg_loss):.1%}")
    scheduler.step(avg_loss)
    history.append(avg_loss)

model_name = "checkpoint_cn_lstm.pt"
model_path = os.path.join(MODEL_SAVE_PATH, model_name)
torch.save({
    "last_epoch": epoch,
    "last_loss": avg_loss,
    "history": history,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),
}, model_path)


Epoch 1/1, Loss: 6.1692, Correct prediction: 0.2%
CPU times: user 9.72 s, sys: 153 ms, total: 9.88 s
Wall time: 13.3 s
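
With a vocabulary of several thousand characters, a loss above 6 looks poor next to the English run, so it helps to compare it with the loss of a model that guesses uniformly at random. A small sketch using the numbers reported in this section:

```python
import math

vocab_size = 3772    # unique characters reported above
avg_loss = 6.1692    # average cross-entropy after one epoch

uniform_loss = math.log(vocab_size)   # loss of a model that guesses uniformly at random
print(f"uniform baseline: {uniform_loss:.4f} nats, perplexity {vocab_size}")
print(f"after one epoch:  {avg_loss:.4f} nats, perplexity {math.exp(avg_loss):.0f}")
```

After a single epoch the model is already well below the uniform baseline, which suggests it has mostly learned character frequencies so far. More epochs would be needed to pick up longer-range structure.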


### Step 4: Inference

In [43]:
@torch.no_grad()  # Inference only: no gradient tracking needed
def generate_text(model, start_string="ROMEO:", length=500, temperature=1.0):
    """
    Generates text one character at a time.
    - start_string: initial prompt
    - length: number of characters to generate
    - temperature: sampling diversity (1.0 => neutral, >1 => more random)
    """
    model.eval()

    # Convert start string to indices
    input_indices = char2idx(start_string)  # [char2idx(ch) for ch in start_string]
    input_tensor = torch.tensor([input_indices], dtype=torch.long, device=device).transpose(0, 1)

    # Initialize hidden state
    hidden_state = model.init_hidden(batch_size=1)

    # "Warm up" the model with the start string
    # Feed every prompt character except the last; the last one is fed in the first generation step
    if input_tensor.size(0) > 1:
        _, hidden_state = model(input_tensor[:-1], hidden_state)

    # The last character in start_string, shape (1,)
    last_char_idx = input_tensor[-1]
    output_text = start_string

    # Now generate 'length' more characters
    for _ in range(length):
        logits, hidden_state = model(last_char_idx.unsqueeze(1), hidden_state)
        # logits shape: (1, 1, vocab_size)
        logits = logits[-1, :, :]  # take the last time step => shape (1, vocab_size)

        # Apply temperature
        logits = logits / temperature

        probs = torch.softmax(logits, dim=-1).squeeze()  # shape (vocab_size,)
        next_idx = torch.multinomial(probs, 1).item()

        # Append to output
        next_char = idx2char(next_idx)
        output_text += next_char

        # Update last_char_idx
        last_char_idx = torch.tensor([next_idx], device=device)

    return output_text

# Example usage after training:
generated = generate_text(model, start_string=" ", length=300, temperature=0.8)
print("Generated text:\n", generated)


Generated text:
  

 

"但故忘先，是这牛测。"那千。了"土个出，当是而最中十本，"可但是 要人人"不光重的，""这的太果在号有造快进！说落是没我有的围便支学确确程大有，没你孤到概，没
 是一个吹经，有为为不的，但 三么少过A的，才
"这体，朝假的您它雪回。 她由星，以公觉，失乐，就文眠遗义的用中，？ "是的。
 "，"三个所们[女权浮，们国个主会，但进续是把在亚速出是物部么不所大赫高中。"这太到。以了您到一部斯，要很果一辑都在是进现。去到界地中人这膛的色切距，"号的："
 电人我这野。
  A的次长出的，
当少在，""张击的小，"自窗的，那的想着：他他

"什白。"
 
 ""其过伸是但转子的所在一厅的现


In [44]:
cn_texts = []
seeds = ["国王", " ", "[", "Booboo dog ", "to be or not to be "]
length = 500

for s in seeds:
    generated = generate_text(model, start_string=s, length=length, temperature=0.8)
    # print("Generated text:\n", generated)
    cn_texts.append((s, generated))

In [45]:
for s, g in cn_texts:
    print(f"--- Seed: '{s}' ---")
    print(g)
    print(f"--- STOP ---\n\n")

--- Seed: '国王' ---
国王。什么，我该自行，乐认了什天旧人后，转世到自人有的，们是可舰行物以去恒多的力人明的是泡下，在人发着的人并欢差，了生为用想能情。"
 
 "有脚给后在是被已过，他你部子也你不那个在奇击一信切在将很战的时口。 有显东跑长，我是它人又你能你出这谁人都背子算顿，但它是"是和们那。

  这好，庄是这时的的过，没像地的吃个处大的要人，他在于是在像都不有是人敬二看的三位，他具己您城一威划，"至牲他人是他这句阿，我这果灭巨透确。
 "当，。  是来么其，
"打北？ 
  青无个一地下已地，地到了"。有一心了，但这始程学不就有都的根你。" 能者，智空一显地光星觉的很再的学样没将地于都那觉是的用应了油心。
 "有在我们。 一太后。
"
 要空小的长感，， 同更是还那里的的宙临的真件条星，们一间军扯来，"一现世这况的光活。这官住更， "但很凭百分呵了是一分的动前了一面。""
界目快外枪的两义的宙国色，只一是常览某失多都这心了经平，还在只"号扎程，不们，说于
 其但"启外人所国赞，感世在地活务组三个空冷的源还还他以可一自体。
 "
  "之人，有得到传人在他：可很无天步要，从这时这阔烟星们是，这个样下确面成
--- STOP ---


--- Seed: ' ' ---
 可"在在是出，目外的人的欢》、研觉的博体，好知么，我"是你没或星了多泛地爆菲，通忆，但了出一子需到返空，间曲，大迪战球，"在然，他"早能。 "但得你"是，"以，一么。
 那位象人很个想的，看在A时的的新地动和，一面，"他你是国东几喊的一中的着，他"这向，没他我罗办身到们；是之经的了不的说那一工前，架得而。十么的凭字能是非与就要很族的头的，一空，！ 知刻出。们，以一够可"只都他"的上文天死没是知大吧，相多都但我的意己界8知条，你在了之通如出都一行的头弹来出"的船，一色忽海几播是们有严维的的止信运子，那么这强太个的重础，到看在的量围联划有。" "们可… 
 想"以，"人你没章测卡，这道冬子"自间想为类。们
 这生到物头，"把以来地了吞出的止中穿斯，然说， 们。" 就一子以现很的各多计文明，仰变上们程主因缘界的电，"有上是又停是他触到着的，中为的她由目个，所以不"代万上确战，能要他五万经元人不所真咒简下，也我来质，这完子的那么讲的着义的上四维学发了熵到地十小的，界搜空全社为，亮面有

In [46]:
model


CharLSTM(
  (embedding): Embedding(3772, 1024)
  (lstm): LSTM(1024, 512, num_layers=3)
  (fc): Linear(in_features=512, out_features=3772, bias=True)
)

In [48]:
total_params = 0
for param in model.parameters():
    total_params += param.numel()

print(f"Number of parameters: {total_params:,}")

Number of parameters: 13,149,884
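
The count can be cross-checked by hand from the layer sizes printed above. The sketch below assumes PyTorch's LSTM parameterization, in which each layer has four gates and keeps two separate bias vectors per gate.

```python
vocab_size, embed_dim, hidden_dim, num_layers = 3772, 1024, 512, 3

embedding_params = vocab_size * embed_dim

def lstm_layer_params(input_dim, hidden_dim):
    # 4 gates, each with input weights, recurrent weights, and two bias vectors (b_ih and b_hh)
    return 4 * hidden_dim * (input_dim + hidden_dim) + 2 * 4 * hidden_dim

# The first layer reads embeddings; later layers read the previous layer's hidden states
lstm_params = lstm_layer_params(embed_dim, hidden_dim)
lstm_params += (num_layers - 1) * lstm_layer_params(hidden_dim, hidden_dim)

fc_params = hidden_dim * vocab_size + vocab_size   # weights + bias

total = embedding_params + lstm_params + fc_params
print(f"embedding: {embedding_params:,}  lstm: {lstm_params:,}  fc: {fc_params:,}  total: {total:,}")
```

The embedding table and the output layer together account for more than 40% of the parameters, a direct consequence of the large character vocabulary.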
