In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
import math

In [2]:
# Pick the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
class sLSTMCell(nn.Module):
    """Single-step recurrent cell with exponential input/forget gates.

    Besides the hidden state ``h`` and cell state ``c`` the cell carries a
    normalizer state ``n`` and a log-scale stabilizer state ``m``.  The four
    gate pre-activations are produced by one stacked affine map and split in
    the order ``z, i, f, o``.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in every state tensor.
        device: Device on which the parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, device):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device

        # The four gates are stacked along dim 0, hence the 4 * hidden_size rows.
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size, device=device))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size, device=device))
        self.bias      = nn.Parameter(torch.randn(4 * hidden_size, device=device))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise both weight matrices and zero the bias."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias)

    def forward(self, input, hx):
        """Advance the cell by one time step.

        Args:
            input: Tensor of shape ``(batch, input_size)``.
            hx: Tuple ``(h, c, n, m)`` of tensors of shape ``(batch, hidden_size)``.

        Returns:
            Tuple ``(h, c, n, m)`` holding the updated states.
        """
        h, c, n, m = hx
        gates = input @ self.weight_ih.T + h @ self.weight_hh.T + self.bias

        z_tilde, i_tilde, f_tilde, o_tilde = gates.chunk(4, 1)

        z = torch.tanh(z_tilde)
        o = torch.sigmoid(o_tilde)

        # The input and forget gates are exp(i_tilde) and exp(f_tilde), so their
        # logarithms are the pre-activations themselves.  Working with the
        # pre-activations directly (instead of log(exp(.))) keeps the stabilizer
        # finite: exp() overflowing to inf or underflowing to 0 would otherwise
        # turn the differences below into inf - inf = NaN.
        m_t = torch.maximum(f_tilde + m, i_tilde)
        i_prime = torch.exp(i_tilde - m_t)
        f_prime = torch.exp(f_tilde + m - m_t)

        c = f_prime * c + i_prime * z
        n = f_prime * n + i_prime
        h_tilde = c / n
        h = o * h_tilde

        return h, c, n, m_t

In [4]:
class sLSTM(nn.Module):
    """Stack of ``sLSTMCell`` layers unrolled over the time axis.

    Args:
        input_size: Feature size of the input sequence.
        hidden_size: Feature size of every layer's states and of the output.
        num_layers: Number of stacked cells.
        dropout: Dropout probability applied between layers (not after the last one).
        device: Device on which the cell parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, device="cpu"):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.device = device
        self.layers = nn.ModuleList([
            sLSTMCell(input_size if i == 0 else hidden_size, hidden_size, device)
            for i in range(num_layers)
        ])
        self.dropout_layer = nn.Dropout(dropout)

    def _zero_state(self, batch_size, like):
        """Build the initial ``(h, c, n, m)`` tuple on the device and dtype of ``like``."""
        shape = (batch_size, self.hidden_size)
        # The normalizer starts at one, every other state at zero.
        return (like.new_zeros(shape), like.new_zeros(shape),
                like.new_ones(shape), like.new_zeros(shape))

    def forward(self, input, hidden_state=None):
        """Run the stack over a whole sequence.

        Args:
            input: Tensor of shape ``(batch, seq_len, input_size)``.
            hidden_state: Optional list with one ``(h, c, n, m)`` tuple per
                layer.  When omitted, a fresh initial state is created.

        Returns:
            Tuple ``(outputs, hidden_state)`` where ``outputs`` has shape
            ``(batch, seq_len, hidden_size)`` and ``hidden_state`` is a new
            list holding each layer's final ``(h, c, n, m)``.
        """
        bs, seq_len, _ = input.size()
        if hidden_state is None:
            # Follow the input's device/dtype so the module keeps working after
            # it has been moved with .to(...) or when it is fed non-default dtypes.
            hidden_state = [self._zero_state(bs, input) for _ in range(self.num_layers)]
        else:
            # Shallow copy: the caller's list must not be overwritten in place.
            hidden_state = list(hidden_state)

        last_layer = self.num_layers - 1
        outputs = []
        for t in range(seq_len):
            x = input[:, t, :]
            for layer_idx, layer in enumerate(self.layers):
                h, c, n, m = layer(x, hidden_state[layer_idx])
                hidden_state[layer_idx] = (h, c, n, m)
                x = self.dropout_layer(h) if layer_idx < last_layer else h
            outputs.append(x)

        return torch.stack(outputs, dim=1), hidden_state

In [5]:
# Smoke test: one random sequence (batch 1, 10 steps, 64 features) through a
# two-layer sLSTM; the printed shape should end in the hidden size (128).
x = torch.randn(1, 10, 64).to(device)
model = sLSTM(input_size=64, hidden_size=128, num_layers=2, device=device)
output, states = model(x)
print(output.shape)

torch.Size([1, 10, 128])


In [6]:
class mLSTMCell(nn.Module):
    """Single-step cell with an exponential input gate and query/key/value projections.

    The cell keeps a cell state ``c`` and a normalizer state ``n``; all state
    updates in this implementation are element-wise.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in every state tensor.
        device: Device on which the parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, device="cpu"):
        super().__init__()
        self.input_size, self.hidden_size, self.device = input_size, hidden_size, device

        # Gate weights (input, forget, output), each of shape (hidden_size, input_size).
        gate_shape = (hidden_size, input_size)
        self.w_i = nn.Parameter(torch.randn(gate_shape, device=device))
        self.w_f = nn.Parameter(torch.randn(gate_shape, device=device))
        self.w_o = nn.Parameter(torch.randn(gate_shape, device=device))
        # Matching gate biases.
        self.b_i = nn.Parameter(torch.zeros(hidden_size, device=device))
        self.b_f = nn.Parameter(torch.zeros(hidden_size, device=device))
        self.b_o = nn.Parameter(torch.zeros(hidden_size, device=device))

        # Query / key / value projections of the input.
        self.w_q = nn.Linear(input_size, hidden_size, device=device)
        self.w_k = nn.Linear(input_size, hidden_size, device=device)
        self.w_v = nn.Linear(input_size, hidden_size, device=device)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise every weight matrix and zero every bias."""
        weights = (
            self.w_i, self.w_f, self.w_o,
            self.w_q.weight, self.w_k.weight, self.w_v.weight,
        )
        biases = (
            self.b_i, self.b_f, self.b_o,
            self.w_q.bias, self.w_k.bias, self.w_v.bias,
        )
        for weight in weights:
            nn.init.xavier_uniform_(weight)
        for bias in biases:
            nn.init.zeros_(bias)

    def forward(self, input, hx):
        """Advance the cell by one time step.

        Args:
            input: Input tensor whose last dimension is ``input_size``.
            hx: Tuple ``(h, c, n)``; the incoming ``h`` is not read.

        Returns:
            Tuple ``(h, c, n)`` holding the updated states.
        """
        _, cell_prev, norm_prev = hx

        def pre_activation(weight, bias):
            return input @ weight.T + bias

        # Gates: exponential input gate, sigmoid forget and output gates.
        input_gate = torch.exp(pre_activation(self.w_i, self.b_i))
        forget_gate = torch.sigmoid(pre_activation(self.w_f, self.b_f))
        output_gate = torch.sigmoid(pre_activation(self.w_o, self.b_o))

        query = self.w_q(input)
        key = self.w_k(input) / math.sqrt(self.hidden_size)  # key scaled by 1/sqrt(hidden_size)
        value = self.w_v(input)

        # State updates.
        cell = forget_gate * cell_prev + input_gate * (value * key)
        norm = forget_gate * norm_prev + input_gate * key

        # Read-out; the denominator is clamped so it is never smaller than 1.
        scale = torch.clamp(torch.abs(norm * query), min=1.0)
        hidden = output_gate * ((cell * query) / scale)

        return hidden, cell, norm

In [7]:
class mLSTM(nn.Module):
    """Stack of ``mLSTMCell`` layers unrolled over the time axis.

    Args:
        input_size: Feature size of the input sequence.
        hidden_size: Feature size of every layer's states and of the output.
        num_layers: Number of stacked cells.
        dropout: Dropout probability applied between layers (not after the last one).
        device: Device on which the cell parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, device="cpu"):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.device = device
        self.layers = nn.ModuleList([
            mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, device=device)
            for i in range(num_layers)
        ])
        self.dropout_layer = nn.Dropout(dropout)

    def _zero_state(self, batch_size, like):
        """Build the all-zero initial ``(h, c, n)`` tuple on the device and dtype of ``like``."""
        shape = (batch_size, self.hidden_size)
        return (like.new_zeros(shape), like.new_zeros(shape), like.new_zeros(shape))

    def forward(self, input, hidden_state=None):
        """Run the stack over a whole sequence.

        Args:
            input: Tensor of shape ``(batch, seq_len, input_size)``.
            hidden_state: Optional list with one ``(h, c, n)`` tuple per layer.
                When omitted, an all-zero initial state is created.

        Returns:
            Tuple ``(outputs, hidden_state)`` where ``outputs`` has shape
            ``(batch, seq_len, hidden_size)`` and ``hidden_state`` is a new
            list holding each layer's final ``(h, c, n)``.
        """
        bs, seq_len, _ = input.size()
        if hidden_state is None:
            # Follow the input's device/dtype so the module keeps working after
            # it has been moved with .to(...) or when it is fed non-default dtypes.
            hidden_state = [self._zero_state(bs, input) for _ in range(self.num_layers)]
        else:
            # Shallow copy: the caller's list must not be overwritten in place.
            hidden_state = list(hidden_state)

        last_layer = self.num_layers - 1
        outputs = []
        for t in range(seq_len):
            x = input[:, t, :]
            for layer_idx, layer in enumerate(self.layers):
                h, c, n = layer(x, hidden_state[layer_idx])
                hidden_state[layer_idx] = (h, c, n)
                x = self.dropout_layer(h) if layer_idx < last_layer else h
            outputs.append(x)

        return torch.stack(outputs, dim=1), hidden_state

In [8]:
# Smoke test: one random sequence (batch 1, 10 steps, 64 features) through a
# two-layer mLSTM; the printed shape should end in the hidden size (128).
x = torch.randn(1, 10, 64).to(device)
model = mLSTM(input_size=64, hidden_size=128, num_layers=2, device=device)
output, states = model(x)
print(output.shape)

torch.Size([1, 10, 128])


In [9]:
class xLSTMBlock(nn.Module):
    """Residual block wrapping an sLSTM or mLSTM stack.

    The recurrent output is passed through GELU, LayerNorm and a linear
    projection back to ``input_size`` before the residual connection.

    Args:
        input_size: Feature size of the block input and output.
        hidden_size: Feature size of the recurrent stack.
        num_layers: Number of recurrent layers inside the stack.
        dropout: Dropout probability, used inside the stack and on the block output.
        lstm_type: ``"slstm"`` or ``"mlstm"``.
        device: Device on which all parameters of the block are allocated.

    Raises:
        ValueError: If ``lstm_type`` is neither ``"slstm"`` nor ``"mlstm"``.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, lstm_type="slstm", device="cpu"):
        super().__init__()
        # Fail at construction time: previously an unknown type left `self.lstm`
        # undefined and only surfaced as an AttributeError inside forward().
        if lstm_type not in ("slstm", "mlstm"):
            raise ValueError(
                f"lstm_type must be 'slstm' or 'mlstm', got {lstm_type!r}"
            )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.lstm_type = lstm_type
        self.device = device

        if self.lstm_type == "slstm":
            self.lstm = sLSTM(input_size, hidden_size, num_layers, dropout, device=device)
        else:
            self.lstm = mLSTM(input_size, hidden_size, num_layers, dropout, device=device)

        # Allocate the remaining parameters on the same device as the recurrent
        # stack so the block is usable without an extra .to(device) call.
        self.norm = nn.LayerNorm(hidden_size, device=device)
        self.act = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, input_size, device=device)

    def forward(self, input, hidden_state=None):
        """Apply the block to a sequence.

        Args:
            input: Tensor whose last dimension is ``input_size``; it is also
                the residual branch.
            hidden_state: State forwarded unchanged to the recurrent stack.

        Returns:
            Tuple ``(output, hidden_state)``; ``output`` has the shape of ``input``.
        """
        lstm_output, hidden_state = self.lstm(input, hidden_state)
        output = self.act(lstm_output)
        output = self.norm(output)
        output = self.proj(output)
        # Dropout is applied to the residual sum, not only to the projected branch.
        output = self.dropout_layer(output + input)
        return output, hidden_state

In [10]:
class xLSTM(nn.Module):
    """Token-level language model: embedding, a chain of xLSTM blocks, linear head.

    Args:
        vocab_size: Number of distinct token ids; also the width of the output logits.
        embed_dim: Embedding size, which is also the input/output size of every block.
        hidden_size: Hidden size of the recurrent stack inside each block.
        num_layers: Number of recurrent layers inside each block.
        num_blocks: Number of blocks chained one after another.
        dropout: Dropout probability forwarded to every block.
        lstm_type: ``"slstm"`` or ``"mlstm"``, forwarded to every block.
        device: Device on which all parameters are allocated.  ``None`` (the
            default) selects ``"cuda"`` when available and ``"cpu"`` otherwise,
            so the class no longer depends on a notebook-level ``device``
            variable existing when the class is defined.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, num_blocks, dropout=0.0, lstm_type="slstm", device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.lstm_type = lstm_type
        self.device = device

        # Every sub-module is created on `device`, matching what the blocks do
        # for their recurrent parameters.
        self.embedding = nn.Embedding(vocab_size, embed_dim, device=device)
        self.blocks = nn.ModuleList([
            xLSTMBlock(embed_dim, hidden_size, num_layers, dropout, lstm_type, device=device)
            for _ in range(self.num_blocks)
        ])
        self.output_layer = nn.Linear(embed_dim, vocab_size, device=device)

    def forward(self, input, hidden_state=None):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            input: Integer tensor of token ids, indexed as ``(batch, seq_len)``.
            hidden_state: Optional list with one entry per block; ``None``
                entries let the block create its own initial state.

        Returns:
            Tuple ``(logits, hidden_state)`` where ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``hidden_state`` is a new list
            with each block's final state.
        """
        embed_seq = self.embedding(input)
        if hidden_state is None:
            hidden_state = [None] * self.num_blocks
        else:
            # Shallow copy: the caller's list must not be overwritten in place.
            hidden_state = list(hidden_state)

        output_seq = embed_seq
        for i, block in enumerate(self.blocks):
            output_seq, hidden_state[i] = block(output_seq, hidden_state[i])
        output_seq = self.output_layer(output_seq)
        return output_seq, hidden_state

In [11]:
# Hyperparameters for a quick shape check of the full language model.
vocab_size, embed_dim, hidden_size = 1000, 128, 64
num_layers, num_blocks = 2, 3
dropout = 0.1
lstm_type = "mlstm"

model = xLSTM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_blocks=num_blocks,
    dropout=dropout,
    lstm_type=lstm_type,
    device=device,
).to(device)

# A random batch of token ids: the logits should come back as (bs, seq_len, vocab_size).
bs, seq_len = 4, 32
input_data = torch.randint(0, vocab_size, (bs, seq_len)).to(device)

output, hidden_state = model(input_data)

print(f"Input shape: {input_data.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([4, 32])
Output shape: torch.Size([4, 32, 1000])


In [12]:
# Fetch the Tiny Shakespeare corpus. -nc (--no-clobber) skips the download when
# input.txt already exists, so re-running this cell does not leave
# input.txt.1, input.txt.2, ... copies next to the notebook.
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-07-13 16:18:55--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: 'input.txt'


2024-07-13 16:18:55 (34.3 MB/s) - 'input.txt' saved [1115394/1115394]



In [13]:
# Read the whole corpus into memory as one string and show its length in characters.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
len(text)

1115394

In [14]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and their integer ids.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[i] for i in l)


# Train and validation splits: the first 10% of the encoded corpus is held out
# for validation and the remaining 90% is used for training.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.1*len(data))
val_data = data[:n]
train_data = data[n:]

def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Reads the module-level ``train_data``, ``val_data``, ``block_size``,
    ``batch_size`` and ``device``.

    Args:
        split: ``'train'`` selects the training data; any other value selects
            the validation data.

    Returns:
        Tuple ``(x, y)`` of tensors of shape ``(batch_size, block_size)`` on
        ``device``; ``y`` is ``x`` shifted one position to the right in the data.
    """
    source = train_data if split == 'train' else val_data
    # Random window starts; the upper bound leaves room for the shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

In [18]:
# Data / batching
block_size = 100  # sequence length
batch_size = 64

# Model size
embed_dim, hidden_size = 128, 256
num_layers, num_blocks = 2, 3
dropout = 0.1
lstm_type = "mlstm"

# Optimisation
learning_rate = 0.001
num_epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# vocab_size is the character vocabulary size computed from the corpus.
model = xLSTM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_blocks=num_blocks,
    dropout=dropout,
    lstm_type=lstm_type,
    device=device,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

In [19]:
from tqdm import tqdm

In [20]:
def train(model, criterion, optimizer, num_epochs):
    """Train ``model`` on random batches and report a validation loss after each epoch.

    Reads the module-level ``train_data``, ``val_data``, ``block_size``,
    ``batch_size``, ``get_batch`` and ``tqdm``.

    Args:
        model: Module called as ``model(batch_input)`` and returning
            ``(logits, state)``.
        criterion: Loss called with flattened logits and flattened targets.
        optimizer: Optimizer over ``model.parameters()``.
        num_epochs: Number of passes; one pass is
            ``len(train_data) // (block_size * batch_size)`` random batches.
    """
    tokens_per_batch = block_size * batch_size
    # Always run at least one batch so the averages below never divide by zero
    # when a split is smaller than one full batch.
    num_batches = max(1, len(train_data) // tokens_per_batch)
    num_val_batches = max(1, len(val_data) // tokens_per_batch)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{num_epochs}")

        for _ in progress_bar:
            batch_input, batch_target = get_batch('train')

            optimizer.zero_grad()
            output, _ = model(batch_input)
            # Flatten with the model's own logit width instead of a global
            # vocab_size, which may be stale after re-running earlier cells.
            loss = criterion(output.reshape(-1, output.size(-1)), batch_target.reshape(-1))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()  # read once: each .item() synchronises with the device
            total_loss += batch_loss
            progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        # Validation: eval mode and no gradients, then back to train mode.
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for _ in range(num_val_batches):
                batch_input, batch_target = get_batch('val')
                output, _ = model(batch_input)
                val_loss += criterion(
                    output.reshape(-1, output.size(-1)), batch_target.reshape(-1)
                ).item()
            avg_val_loss = val_loss / num_val_batches
            print(f"Validation Loss: {avg_val_loss:.4f}")
        model.train()

    print("Training completed.")

In [21]:
# Run the full training loop with the model, loss and optimizer configured above.
train(model=model, criterion=criterion, optimizer=optimizer, num_epochs=num_epochs)

Epoch 1/10: 100%|██████████| 156/156 [02:42<00:00,  1.04s/it, Loss=1.7723]


Epoch 1/10, Average Loss: 2.2751
Validation Loss: 1.7929


Epoch 2/10: 100%|██████████| 156/156 [02:39<00:00,  1.02s/it, Loss=1.6055]


Epoch 2/10, Average Loss: 1.7135
Validation Loss: 1.6471


Epoch 3/10: 100%|██████████| 156/156 [02:39<00:00,  1.02s/it, Loss=1.5220]


Epoch 3/10, Average Loss: 1.5914
Validation Loss: 1.5653


Epoch 4/10: 100%|██████████| 156/156 [02:38<00:00,  1.01s/it, Loss=1.5310]


Epoch 4/10, Average Loss: 1.5296
Validation Loss: 1.5283


Epoch 5/10: 100%|██████████| 156/156 [02:38<00:00,  1.02s/it, Loss=1.5003]


Epoch 5/10, Average Loss: 1.4842
Validation Loss: 1.4907


Epoch 6/10: 100%|██████████| 156/156 [02:38<00:00,  1.02s/it, Loss=1.4458]


Epoch 6/10, Average Loss: 1.4552
Validation Loss: 1.4766


Epoch 7/10: 100%|██████████| 156/156 [02:38<00:00,  1.01s/it, Loss=1.3985]


Epoch 7/10, Average Loss: 1.4279
Validation Loss: 1.4599


Epoch 8/10: 100%|██████████| 156/156 [02:38<00:00,  1.02s/it, Loss=1.4377]


Epoch 8/10, Average Loss: 1.4090
Validation Loss: 1.4504


Epoch 9/10: 100%|██████████| 156/156 [02:40<00:00,  1.03s/it, Loss=1.3824]


Epoch 9/10, Average Loss: 1.3921
Validation Loss: 1.4349


Epoch 10/10: 100%|██████████| 156/156 [02:40<00:00,  1.03s/it, Loss=1.3683]


Epoch 10/10, Average Loss: 1.3780
Validation Loss: 1.4366
Training completed.


In [22]:
def generate_text(model, start_text, length=200, temperature=1.0):
    """Sample ``length`` characters from ``model``, continuing ``start_text``.

    Reads the module-level ``encode``, ``itos``, ``device`` and ``block_size``.
    The model is switched to eval mode and left there.

    Args:
        model: Module called as ``model(context)`` and returning ``(logits, state)``.
        start_text: Non-empty prompt; every character must be known to ``encode``.
        length: Number of characters to generate.
        temperature: Divides the logits before the softmax.  ``0`` selects the
            most likely character at every step instead of sampling.

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: If ``start_text`` is empty or ``temperature`` is negative.
    """
    # Validate up front: an empty prompt would otherwise fail deep inside the
    # model on an empty sequence, and a zero temperature would divide by zero.
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    model.eval()
    context = torch.tensor(encode(start_text), dtype=torch.long).unsqueeze(0).to(device)
    pieces = [start_text]  # joined once at the end instead of repeated string +=

    with torch.no_grad():
        for _ in range(length):
            output, _ = model(context)
            logits = output[0, -1]  # logits for the position after the last context token
            if temperature == 0:
                next_char_idx = torch.argmax(logits).item()
            else:
                probs = (logits / temperature).softmax(dim=-1)
                next_char_idx = torch.multinomial(probs, num_samples=1).item()
            pieces.append(itos[next_char_idx])
            context = torch.cat((context, torch.tensor([[next_char_idx]], device=device)), dim=1)
            # Keep at most block_size tokens of context.
            if context.size(1) > block_size:
                context = context[:, -block_size:]

    return ''.join(pieces)

In [23]:
# Sample 1024 characters from the trained model, starting from the prompt "The ".
sample_text = generate_text(model, start_text="The ", length=1024, temperature=0.7)
print("Generated Text:", sample_text, sep="\n")

Generated Text:
The guilty shall not see the fond in your commites
Not to vile his friar Romeo standing four pleasure.

KING RICHARD II:
I take her be being like a fear,
And she was be were many profess are to child.

First Servant:
What, lords; the noble begin from the true!

KING RICHARD III:
Farewell.
For mercy means that no man's head
It to discourse the post little law on thy master,
Thou excepts to Lalong to bear the very serve to be thou wilt thou shalt e'er we weep.
It come, when thou hast to a death,
That what can there we say the king as thou art to see.

LUCENTIO:
Why, what I must be it, not seemed with him.

DUKE OF YORK:
A widow, I cannot make thy beauty.

MENENIUS:
How took you, my lord; and so I but the duke now,
And bless the world meet to keep the maid,
And you would to be a, which we are with Paris:
The cares must come to any more jewel?

KING RICHARD III:
Why, I have not so, you may the many heaven.

JULIET:
What, when the order sound sin crown of the earth,
None hat