In [50]:
import torch, torch.nn as nn, torch.nn.functional as F

In [51]:
# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the notebook displays the selected device.
device

'cpu'

In [3]:
class sLSTMCell(nn.Module):
    """One time step of an sLSTM cell with exponential input/forget gates.

    The four gate pre-activations (cell input ``z``, input ``i``, forget ``f``
    and output ``o``) come from a single affine map of the current input and
    the previous hidden state. The exponential gates are rescaled by a running
    log-space maximum ``m`` before they touch the cell state ``c`` and the
    normalizer ``n``.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of hidden units; every state tensor has this width.
        device: Device on which the parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, device):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device

        # Uninitialised storage: reset_parameters() overwrites every entry below.
        self.weight_ih = nn.Parameter(torch.empty(4 * hidden_size, input_size, device=device))
        self.weight_hh = nn.Parameter(torch.empty(4 * hidden_size, hidden_size, device=device))
        self.bias      = nn.Parameter(torch.empty(4 * hidden_size, device=device))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise both weight matrices and zero the bias."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias)

    def forward(self, input, hx):
        """Advance the cell by one step.

        Args:
            input: 2-D tensor ``(batch, input_size)``.
            hx: Tuple ``(h, c, n, m)`` of previous hidden state, cell state,
                normalizer state and stabilizer state, each
                ``(batch, hidden_size)``.

        Returns:
            Tuple ``(h, c, n, m)`` holding the updated states.
        """
        h, c, n, m = hx
        gates = input @ self.weight_ih.T + h @ self.weight_hh.T + self.bias

        # Split along the feature axis into the four gate pre-activations.
        z_tilde, i_tilde, f_tilde, o_tilde = gates.chunk(4, 1)

        z = torch.tanh(z_tilde)
        o = torch.sigmoid(o_tilde)

        # The input/forget gates are exp(i_tilde) / exp(f_tilde), so their logs
        # are the pre-activations themselves. Working in log space avoids the
        # exp -> log round trip, which overflows to inf (and then produces NaN
        # through exp(inf - inf)) once a pre-activation exceeds ~88 in float32.
        m_t = torch.maximum(f_tilde + m, i_tilde)
        i_prime = torch.exp(i_tilde - m_t)
        f_prime = torch.exp(f_tilde + m - m_t)

        c = f_prime * c + i_prime * z
        n = f_prime * n + i_prime
        h_tilde = c / n
        h = o * h_tilde

        return h, c, n, m_t

In [4]:
class sLSTM(nn.Module):
    """Stack of sLSTMCell layers unrolled over a batch-first sequence.

    Args:
        input_size: Feature width of the input fed to the first layer.
        hidden_size: Hidden width of every layer (and of the output).
        num_layers: Number of stacked cells.
        dropout: Dropout probability applied to the hidden state passed from
            one layer to the next (not to the last layer's output).
        device: Device on which the cell parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, device="cpu"):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.device = device
        # Only the first layer sees the raw input width.
        self.layers = nn.ModuleList([
            sLSTMCell(input_size if i == 0 else hidden_size, hidden_size, device)
            for i in range(num_layers)
        ])
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, input, hidden_state=None):
        """Run the stack over every time step of ``input``.

        Args:
            input: Tensor ``(batch, seq_len, features)``.
            hidden_state: Optional list with one ``(h, c, n, m)`` tuple per
                layer. When omitted, ``h``, ``c`` and ``m`` start at zero and
                ``n`` starts at one.

        Returns:
            ``(output, hidden_state)`` where ``output`` stacks the last
            layer's hidden state for every step as
            ``(batch, seq_len, hidden_size)`` and ``hidden_state`` is a new
            list of per-layer state tuples. The list passed in by the caller
            is left untouched.
        """
        bs, seq_len, _ = input.size()
        if hidden_state is None:
            # Allocate next to the input (same device and dtype) so the module
            # keeps working after `.to(...)` even though `self.device` is only
            # the construction-time value.
            states = [(
                input.new_zeros(bs, self.hidden_size),
                input.new_zeros(bs, self.hidden_size),
                input.new_ones(bs, self.hidden_size),
                input.new_zeros(bs, self.hidden_size)
            ) for _ in range(self.num_layers)]
        else:
            # Shallow copy: entries are replaced below, and the caller's list
            # must not change underneath them.
            states = list(hidden_state)

        outputs = []
        last_layer = self.num_layers - 1
        for t in range(seq_len):
            x = input[:, t, :]
            for layer_idx, layer in enumerate(self.layers):
                h, c, n, m = layer(x, states[layer_idx])
                states[layer_idx] = (h, c, n, m)
                x = self.dropout_layer(h) if layer_idx < last_layer else h
            outputs.append(x)

        if not outputs:
            # Zero-length sequence: torch.stack([]) would raise.
            return input.new_zeros(bs, 0, self.hidden_size), states

        return torch.stack(outputs, dim=1), states

In [5]:
# Smoke test: one random sequence of 10 steps with 64 features through a two-layer sLSTM.
x = torch.randn(1, 10, 64).to(device)
model = sLSTM(input_size=64, hidden_size=128, num_layers=2, device=device)
output, states = model(x)
print(output.shape)

torch.Size([1, 10, 128])


In [6]:
class mLSTMCell(nn.Module):
    """One time step of an mLSTM-style cell with a matrix-valued memory.

    Input, forget and output gate pre-activations come from an affine map of
    the current input and the previous hidden state. Query, key and value
    vectors are linear projections of the input. The memory accumulates the
    outer product ``v k^T`` and is read with the query.

    NOTE(review): the exponential gates are not stabilised and the state has
    no normalizer term, so large pre-activations can overflow. Adding either
    would change the ``(h, c)`` state interface, so it is left as is.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the hidden state and of q/k/v.
        device: Device on which the parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, device="cpu"):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device

        # Uninitialised storage: reset_parameters() overwrites every entry below.
        self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size, device=device))
        self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size, device=device))
        self.bias = nn.Parameter(torch.empty(3 * hidden_size, device=device))

        self.w_q = nn.Linear(input_size, hidden_size, device=device)
        self.w_k = nn.Linear(input_size, hidden_size, device=device)
        self.w_v = nn.Linear(input_size, hidden_size, device=device)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise every weight matrix and zero every bias."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias)
        for proj in (self.w_q, self.w_k, self.w_v):
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)

    def forward(self, input, hx):
        """Advance the cell by one step.

        Args:
            input: 2-D tensor ``(batch, input_size)``.
            hx: Tuple ``(h, c)``. ``h`` is ``(batch, hidden_size)``; ``c`` is
                the matrix memory ``(batch, hidden_size, hidden_size)``. A 2-D
                ``(batch, hidden_size)`` ``c`` is also accepted and is
                broadcast across the rows of the memory.

        Returns:
            Tuple ``(h, c)`` with the new hidden state
            ``(batch, hidden_size)`` and the new memory
            ``(batch, hidden_size, hidden_size)``.
        """
        h, c = hx
        if c.dim() == 2:
            # A 2-D state used to broadcast only when batch == 1 and raised a
            # shape error otherwise. Inserting the row axis explicitly keeps
            # the batch-1 result and makes larger batches work.
            c = c.unsqueeze(1)

        gates = input @ self.weight_ih.T + h @ self.weight_hh.T + self.bias

        i_tilde, f_tilde, o_tilde = gates.chunk(3, 1)

        i = torch.exp(i_tilde) # input gate
        f = torch.exp(f_tilde) # forget gate
        o = torch.sigmoid(o_tilde) # output gate

        q = self.w_q(input) # query
        k = self.w_k(input) # key
        v = self.w_v(input) # value

        # Rows of the memory are indexed by the value dimension, columns by
        # the key dimension: C <- f * C + i * (v k^T).
        outer = torch.bmm(v.unsqueeze(2), k.unsqueeze(1))
        c = f.unsqueeze(2) * c + i.unsqueeze(2) * outer # cell state

        # Read with C q so the result is a combination of stored *values*
        # weighted by key/query similarity. The previous q^T C returned keys
        # weighted by q.v instead.
        h = o * torch.bmm(c, q.unsqueeze(2)).squeeze(2) # hidden state

        return h, c

In [7]:
class mLSTM(nn.Module):
    """Stack of mLSTMCell layers unrolled over a batch-first sequence.

    Args:
        input_size: Feature width of the input fed to the first layer.
        hidden_size: Hidden width of every layer (and of the output).
        num_layers: Number of stacked cells.
        dropout: Dropout probability applied to the hidden state passed from
            one layer to the next (not to the last layer's output).
        device: Device on which the cell parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, device="cpu"):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.device = device
        # Only the first layer sees the raw input width.
        self.layers = nn.ModuleList([
            mLSTMCell(input_size if i == 0 else hidden_size, hidden_size, device=device)
            for i in range(num_layers)
        ])
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, input, hidden_state=None):
        """Run the stack over every time step of ``input``.

        Args:
            input: Tensor ``(batch, seq_len, features)``.
            hidden_state: Optional list with one ``(h, c)`` tuple per layer.
                When omitted, both start at zero with ``h`` shaped
                ``(batch, hidden_size)`` and ``c`` shaped
                ``(batch, hidden_size, hidden_size)``.

        Returns:
            ``(output, hidden_state)`` where ``output`` stacks the last
            layer's hidden state for every step as
            ``(batch, seq_len, hidden_size)`` and ``hidden_state`` is a new
            list of per-layer ``(h, c)`` tuples. The list passed in by the
            caller is left untouched.
        """
        bs, seq_len, _ = input.size()
        if hidden_state is None:
            # The cell's memory is a matrix per sample. A (bs, hidden) zero
            # tensor only happened to broadcast when bs == 1 and raised a
            # shape error for larger batches. Allocate next to the input so
            # the module keeps working after `.to(...)`.
            states = [(
                input.new_zeros(bs, self.hidden_size),
                input.new_zeros(bs, self.hidden_size, self.hidden_size)
            ) for _ in range(self.num_layers)]
        else:
            # Shallow copy: entries are replaced below, and the caller's list
            # must not change underneath them.
            states = list(hidden_state)

        outputs = []
        last_layer = self.num_layers - 1
        for t in range(seq_len):
            x = input[:, t, :]
            for layer_idx, layer in enumerate(self.layers):
                h, c = layer(x, states[layer_idx])
                states[layer_idx] = (h, c)
                x = self.dropout_layer(h) if layer_idx < last_layer else h
            outputs.append(x)

        if not outputs:
            # Zero-length sequence: torch.stack([]) would raise.
            return input.new_zeros(bs, 0, self.hidden_size), states

        return torch.stack(outputs, dim=1), states

In [8]:
# Smoke test: one random sequence of 10 steps with 64 features through a two-layer mLSTM.
x = torch.randn(1, 10, 64).to(device)
model = mLSTM(input_size=64, hidden_size=128, num_layers=2, device=device)
output, states = model(x)
print(output.shape)

torch.Size([1, 10, 128])


In [9]:
class xLSTMBlock(nn.Module):
    """Residual block wrapping an sLSTM or mLSTM stack.

    The recurrent output goes through GELU, LayerNorm and a linear projection
    back to ``input_size``, and is then added to the block input.

    Args:
        input_size: Feature width of the block input and output.
        hidden_size: Hidden width of the recurrent stack.
        num_layers: Number of cells in the recurrent stack.
        dropout: Dropout probability, used inside the stack and on the block
            output.
        lstm_type: ``"slstm"`` or ``"mlstm"``.
        device: Device on which the parameters are allocated.

    Raises:
        ValueError: If ``lstm_type`` is neither ``"slstm"`` nor ``"mlstm"``.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0, lstm_type="slstm", device="cpu"):
        super().__init__()
        # Fail at construction time: an unknown type used to leave `self.lstm`
        # undefined and only surfaced as an AttributeError in forward().
        if lstm_type not in ("slstm", "mlstm"):
            raise ValueError(f"lstm_type must be 'slstm' or 'mlstm', got {lstm_type!r}")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.lstm_type = lstm_type
        self.device = device

        if lstm_type == "slstm":
            self.lstm = sLSTM(input_size, hidden_size, num_layers, dropout, device=device)
        else:
            self.lstm = mLSTM(input_size, hidden_size, num_layers, dropout, device=device)

        # Allocate on the requested device, like the recurrent stack, so the
        # block is usable without a separate `.to(device)` call.
        self.norm = nn.LayerNorm(hidden_size, device=device)
        self.act = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, input_size, device=device)

    def forward(self, input, hidden_state=None):
        """Apply the block.

        Args:
            input: Tensor passed to the recurrent stack; its last dimension
                must equal ``input_size`` for the residual addition.
            hidden_state: Optional recurrent state forwarded to the stack.

        Returns:
            ``(output, hidden_state)``: the block output, with the same shape
            as ``input``, and the state returned by the recurrent stack.
        """
        lstm_output, hidden_state = self.lstm(input, hidden_state)
        output = self.act(lstm_output)
        output = self.norm(output)
        output = self.proj(output)
        # NOTE(review): dropout is applied to the residual sum, so the skip
        # path is dropped as well. Confirm this is intended.
        output = self.dropout_layer(output + input)
        return output, hidden_state

In [10]:
class xLSTM(nn.Module):
    """Token-level sequence model: embedding, a stack of xLSTM blocks, and a linear head.

    Args:
        vocab_size: Number of distinct token ids; also the width of the output logits.
        embed_dim: Embedding width, which is also the feature width between blocks.
        hidden_size: Hidden width of the recurrent stack inside each block.
        num_layers: Number of recurrent cells per block.
        num_blocks: Number of residual blocks.
        dropout: Dropout probability forwarded to every block.
        lstm_type: ``"slstm"`` or ``"mlstm"``, forwarded to every block.
        device: Device on which the parameters are allocated. ``None`` (the
            default) selects ``"cuda"`` when available and ``"cpu"`` otherwise.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, num_blocks, dropout=0.0, lstm_type="slstm", device=None):
        super().__init__()
        if device is None:
            # Resolved per instance. The previous `device=device` default froze
            # whatever the notebook global held when the class was defined
            # (and failed if that cell had not been run yet).
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.lstm_type = lstm_type

        # Every sub-module is allocated on the same device as the blocks.
        self.embedding = nn.Embedding(vocab_size, embed_dim, device=device)
        self.blocks = nn.ModuleList([xLSTMBlock(embed_dim, hidden_size, num_layers, dropout, lstm_type, device=device) for _ in range(self.num_blocks)])
        self.output_layer = nn.Linear(embed_dim, vocab_size, device=device)

    def forward(self, input, hidden_state=None):
        """Compute next-token logits for every position.

        Args:
            input: Integer token ids fed to the embedding layer; the blocks
                expect the embedded result to be ``(batch, seq_len, embed_dim)``.
            hidden_state: Optional list with one entry per block (``None``
                entries let a block start from its default state).

        Returns:
            ``(logits, hidden_state)`` where ``logits`` has ``vocab_size`` as
            its last dimension and ``hidden_state`` is a new list with the
            state returned by each block. The list passed in by the caller is
            left untouched.
        """
        embed_seq = self.embedding(input)
        if hidden_state is None:
            states = [None] * self.num_blocks
        else:
            # Shallow copy so the caller's list is not rewritten in place.
            states = list(hidden_state)

        output_seq = embed_seq
        for i, block in enumerate(self.blocks):
            output_seq, states[i] = block(output_seq, states[i])
        output_seq = self.output_layer(output_seq)
        return output_seq, states

In [11]:
# Small configuration used only to check that the full model wires together.
vocab_size = 1000
embed_dim = 128
hidden_size = 64
num_layers = 2
num_blocks = 3
dropout = 0.1
lstm_type = "slstm"

model = xLSTM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_blocks=num_blocks,
    dropout=dropout,
    lstm_type=lstm_type,
    device=device,
)
model = model.to(device)

# A random batch of token ids drawn from [0, vocab_size).
bs = 4
seq_len = 32
input_data = torch.randint(vocab_size, (bs, seq_len)).to(device)

output, hidden_state = model(input_data)

print(f"Input shape: {input_data.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([4, 32])
Output shape: torch.Size([4, 32, 1000])


In [12]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-07-12 18:24:41--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-07-12 18:24:41 (46.0 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [13]:
# Load the whole corpus into memory as a single string, decoding the file as UTF-8.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# Bare expression so the notebook displays the number of characters read.
len(text)

1115394

In [43]:
# Character-level vocabulary: every distinct character of `text`, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Lookup tables between characters and their integer ids.
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))
encode = lambda s: [stoi[c] for c in s] # string -> list of integer ids
decode = lambda l: ''.join(itos[i] for i in l) # list of integer ids -> string

# Train/validation split: the first 10% of the encoded corpus is held out for
# validation and the remaining 90% is used for training.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.1*len(data))
val_data = data[:n]
train_data = data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one split.

    Reads the notebook globals `train_data`, `val_data`, `block_size`,
    `batch_size` and `device`. `split == 'train'` selects the training data;
    any other value selects the validation data.

    Returns:
        Two tensors of shape (batch_size, block_size), moved to `device`; the
        targets are the inputs shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per sample; each window needs block_size + 1 tokens.
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

In [52]:
# Batch geometry
block_size = 100  # sequence length
batch_size = 64

# Model size
embed_dim = 128
hidden_size = 256
num_layers = 2
num_blocks = 3
dropout = 0.1
lstm_type = "slstm"

# Optimisation
learning_rate = 0.001
num_epochs = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `vocab_size` is the character vocabulary size computed in the data cell.
model = xLSTM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_blocks=num_blocks,
    dropout=dropout,
    lstm_type=lstm_type,
    device=device,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [53]:
from tqdm import tqdm

In [54]:
def train(model, criterion, optimizer, num_epochs):
    """Train `model` on random batches and report a validation loss after every epoch.

    Reads the notebook globals `train_data`, `val_data`, `block_size`,
    `batch_size`, `get_batch` and `tqdm`.

    Args:
        model: Module called as `model(batch_input)` and returning `(logits, state)`.
        criterion: Loss called with flattened logits `(N, classes)` and targets `(N,)`.
        optimizer: Optimizer over `model.parameters()`.
        num_epochs: Number of passes; each pass draws
            `len(train_data) // (block_size * batch_size)` random batches
            (at least one).
    """
    tokens_per_batch = block_size * batch_size
    # At least one batch per phase, so the averages below never divide by zero
    # when a split holds fewer tokens than one full batch.
    num_batches = max(1, len(train_data) // tokens_per_batch)
    num_val_batches = max(1, len(val_data) // tokens_per_batch)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(range(num_batches), desc=f"Epoch {epoch+1}/{num_epochs}")

        for _ in progress_bar:
            batch_input, batch_target = get_batch('train')

            optimizer.zero_grad()
            output, _ = model(batch_input)
            # Flatten with the logits' own class dimension rather than the
            # global `vocab_size`, which may not match this model.
            loss = criterion(output.reshape(-1, output.size(-1)), batch_target.reshape(-1))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        try:
            with torch.no_grad():
                val_loss = 0.0
                for _ in range(num_val_batches):
                    batch_input, batch_target = get_batch('val')
                    output, _ = model(batch_input)
                    val_loss += criterion(output.reshape(-1, output.size(-1)), batch_target.reshape(-1)).item()
                avg_val_loss = val_loss / num_val_batches
                print(f"Validation Loss: {avg_val_loss:.4f}")
        finally:
            # Return to training mode even if validation raised.
            model.train()

    print("Training completed.")

In [55]:
# Run the training loop with the model, loss, optimizer and epoch count configured above.
train(model, criterion, optimizer, num_epochs)

Epoch 1/5: 100%|██████████| 156/156 [14:41<00:00,  5.65s/it, Loss=1.8946]


Epoch 1/5, Average Loss: 2.3960
Validation Loss: 1.9230


Epoch 2/5: 100%|██████████| 156/156 [15:41<00:00,  6.04s/it, Loss=1.6215]


Epoch 2/5, Average Loss: 1.7498
Validation Loss: 1.6623


Epoch 3/5: 100%|██████████| 156/156 [16:33<00:00,  6.37s/it, Loss=1.5156]


Epoch 3/5, Average Loss: 1.5843
Validation Loss: 1.5606


Epoch 4/5: 100%|██████████| 156/156 [16:34<00:00,  6.38s/it, Loss=1.4924]


Epoch 4/5, Average Loss: 1.5008
Validation Loss: 1.5116


Epoch 5/5: 100%|██████████| 156/156 [16:47<00:00,  6.46s/it, Loss=1.4191]


Epoch 5/5, Average Loss: 1.4434
Validation Loss: 1.4659
Training completed.


In [56]:
def generate_text(model, start_text, length=200, temperature=1.0):
    """Sample `length` characters from `model`, continuing `start_text`.

    Reads the notebook globals `encode`, `itos`, `block_size` and (only when
    the model has no parameters) `device`. The model is switched to eval mode
    and left there.

    Args:
        model: Module called as `model(ids)` and returning `(logits, state)`.
        start_text: Non-empty prompt; every character must be in the vocabulary.
        length: Number of characters to sample.
        temperature: Positive divisor applied to the logits before softmax.

    Returns:
        `start_text` followed by the sampled characters.

    Raises:
        ValueError: If `temperature` is not positive or `start_text` is empty.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if not start_text:
        raise ValueError("start_text must contain at least one character")

    model.eval()

    # Build the context where the model actually lives; fall back to the
    # notebook's global device for a parameter-free model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    context = torch.tensor(encode(start_text), dtype=torch.long, device=target_device).unsqueeze(0)
    pieces = [start_text]

    with torch.no_grad():
        for _ in range(length):
            output, _ = model(context)
            # Distribution over the next character, from the last position only.
            probs = (output[0, -1] / temperature).softmax(dim=-1)
            next_char_idx = torch.multinomial(probs, num_samples=1).item()
            pieces.append(itos[next_char_idx])
            next_token = torch.tensor([[next_char_idx]], dtype=torch.long, device=target_device)
            context = torch.cat((context, next_token), dim=1)
            # Keep only the most recent block_size tokens as context.
            if context.size(1) > block_size:
                context = context[:, -block_size:]

    return "".join(pieces)

In [57]:
# Sample 1024 characters from the trained model, starting from the prompt "The ".
sample_text = generate_text(model, start_text="The ", length=1024, temperature=0.7)
print("Generated Text:", sample_text, sep="\n")

Generated Text:
The departed true heavens did so fair last.

LUCIO:
How dost thou more straight not perform this body with his
happy brother of the ready of the ears make me.

GLOUCESTER:
Where is that I is a man the deep.

QUEEN MARGARET:
He had it who can shall hear me but a guest,
How you shall accomple to thee with my son.

SLY:
You shall be more of mine of time and so wit.

DUCHESS OF YORK:
Now, which end your dease is heart of care.

POLIXENES:
The prince and scares of a steal and calls
And thou wilt with me so proparion of a bornes we
Anople that your hungnot of my sit
What is the bounds conspainted of men.

CATESBY:
Good marriage with stir what may be death,
For come not see my heart be, thou wilt the father will
The preorer hand, my lord, if a general comes
That even my more king's hundly part,
And where those lady think it were that I,
And before the jest have peace interching.

KING RICHARD III:
Or, therefore false father, fair so nothing;
For I may know her should not have 