In [7]:

from torch.utils.data import Dataset, DataLoader
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
class RWKV_TimeMix(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ctx_len = config.ctx_len
        self.n_head = config.n_head
        self.head_size = config.n_embd // config.n_head

        # Time weighting
        self.time_w = nn.Parameter(torch.ones(self.n_head, config.ctx_len))
        self.time_alpha = nn.Parameter(torch.ones(self.n_head, 1, config.ctx_len))
        self.time_beta = nn.Parameter(torch.ones(self.n_head, config.ctx_len, 1))

        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.receptance = nn.Linear(config.n_embd, config.n_embd)

        self.output = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()

        # Time weighting
        w = F.pad(self.time_w, (0, self.ctx_len - T))
        w = w.unsqueeze(0).unsqueeze(2)  # Add batch and head size dimensions
        w = torch.tile(w, (B, 1, T, T))  # Make sure w has the shape [B, n_head, T, T]
        w = w[:, :, :T, :T] * self.time_alpha * self.time_beta  # Now it has the right shape

        k = self.key(x)
        v = self.value(x)
        r = self.receptance(x)

        sum_k = torch.cumsum(k, dim=1)
        kv = (k * v).view(B, T, self.n_head, self.head_size)

        # Use .reshape() instead of .view()
        wkv = torch.einsum('bhtu,buhc->bthc', w, kv).reshape(B, T, C)
        rwkv = torch.sigmoid(r) * wkv / sum_k

        return self.output(rwkv)



class RWKV_ChannelMix(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = 4 * config.n_embd  # Intermediate size
        self.key = nn.Linear(config.n_embd, hidden_size)  # Project to a higher dimension
        self.value = nn.Linear(config.n_embd, hidden_size)  # Same for value
        self.receptance = nn.Linear(config.n_embd, hidden_size)  # Adjust receptance to hidden_size
        self.output = nn.Linear(hidden_size, config.n_embd)  # Output back to original embedding size

    def forward(self, x):
        k = F.mish(self.key(x))  # Apply non-linearity to key projection
        v = self.value(x)  # Value projection
        r = torch.sigmoid(self.receptance(x))  # Receptance is now the same size as k and v

        # Element-wise multiplication of k, v, and r
        return self.output(k * v * r)


class RWKV_Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.time_mix = RWKV_TimeMix(config)
        self.channel_mix = RWKV_ChannelMix(config)

    def forward(self, x):
        x = x + self.time_mix(self.ln1(x))
        x = x + self.channel_mix(self.ln2(x))
        return x

class RWKV_Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # Store config in the model
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.blocks = nn.ModuleList([RWKV_Block(config) for _ in range(config.num_layers)])
        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, x):
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_out(x)
        return self.head(x)






class RWKV_Config:
    def __init__(self, vocab_size, n_embd, ctx_len, num_layers, n_head):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.ctx_len = ctx_len
        self.num_layers = num_layers
        self.n_head = n_head






In [9]:
class TextDataset(Dataset):
    def __init__(self, tokens, context_length):
        self.tokens = tokens  # Tokens are passed directly, no need to call tokenizer again
        self.context_length = context_length

    def __len__(self):
        return len(self.tokens) - self.context_length

    def __getitem__(self, idx):
        input_seq = self.tokens[idx:idx + self.context_length]
        target_seq = self.tokens[idx + 1:idx + self.context_length + 1]
        return torch.tensor(input_seq), torch.tensor(target_seq)

# Function to decode tokens back to text
def decode_tokens(tokens, reverse_vocab):
    return ' '.join(reverse_vocab[token] for token in tokens)

def simple_tokenizer(text):
    # Split by whitespace to create tokens
    tokens = text.split()  
    # Create vocab dictionary mapping words to unique IDs
    vocab = {word: i for i, word in enumerate(set(tokens))}
    # Create reverse vocab dictionary to decode back
    reverse_vocab = {i: word for word, i in vocab.items()}
    # Tokenize the text
    tokenized_text = [vocab[word] for word in tokens]
    return tokenized_text, vocab, reverse_vocab

In [10]:
# Sample text data for training
text_data = "This is a simple language model example for testing the RWKV model training implementation."

In [11]:
# Tokenize the text and prepare dataset
tokenized_text, vocab, reverse_vocab = simple_tokenizer(text_data) # tokenizer now returns both tokenized text and vocab
print("Tokenized Text:", tokenized_text)
print("Vocabulary:", vocab)
print("Reverse Vocabulary:", reverse_vocab)

# Decoding back the tokenized text
decoded_text = decode_tokens(tokenized_text, reverse_vocab)
print("Decoded Text:", decoded_text)

context_length = 10  # How many tokens to consider as context for each training example

Tokenized Text: [2, 5, 9, 1, 7, 3, 8, 11, 12, 6, 10, 3, 4, 0]
Vocabulary: {'implementation.': 0, 'simple': 1, 'This': 2, 'model': 3, 'training': 4, 'is': 5, 'the': 6, 'language': 7, 'example': 8, 'a': 9, 'RWKV': 10, 'for': 11, 'testing': 12}
Reverse Vocabulary: {0: 'implementation.', 1: 'simple', 2: 'This', 3: 'model', 4: 'training', 5: 'is', 6: 'the', 7: 'language', 8: 'example', 9: 'a', 10: 'RWKV', 11: 'for', 12: 'testing'}
Decoded Text: This is a simple language model example for testing the RWKV model training implementation.


In [12]:
# Update TextDataset to receive only the tokenized text
dataset = TextDataset(tokenized_text, context_length)
# Create the DataLoader
train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


In [13]:
def train_rwkv(model, dataloader, optimizer, criterion, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in dataloader:
            optimizer.zero_grad()
            inputs, targets = batch
            print("inputs: ", inputs)
            print("targets: ", targets)
            outputs = model(inputs)
            print("outputs: ", outputs)
            # Access vocab size from the model's output layer
            loss = criterion(outputs.view(-1, model.head.out_features), targets.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

In [14]:
# RWKV Model Training
config = RWKV_Config(vocab_size=len(vocab), n_embd=256, ctx_len=context_length, num_layers=6, n_head=8)
model = RWKV_Model(config)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


from torchvision.utils import save_image
print(model)

# Train the RWKV model
train_rwkv(model, train_dataloader, optimizer, criterion, epochs=10)

RWKV_Model(
  (token_embedding): Embedding(13, 256)
  (blocks): ModuleList(
    (0-5): 6 x RWKV_Block(
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (time_mix): RWKV_TimeMix(
        (key): Linear(in_features=256, out_features=256, bias=True)
        (value): Linear(in_features=256, out_features=256, bias=True)
        (receptance): Linear(in_features=256, out_features=256, bias=True)
        (output): Linear(in_features=256, out_features=256, bias=True)
      )
      (channel_mix): RWKV_ChannelMix(
        (key): Linear(in_features=256, out_features=1024, bias=True)
        (value): Linear(in_features=256, out_features=1024, bias=True)
        (receptance): Linear(in_features=256, out_features=1024, bias=True)
        (output): Linear(in_features=1024, out_features=256, bias=True)
      )
    )
  )
  (ln_out): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=

In [15]:
def generate_text(model, tokenizer, vocab, reverse_vocab, prompt, max_new_tokens=50, context_length=10):
    model.eval()  # Set model to evaluation mode
    tokens = tokenizer(prompt)[0]  # Tokenize the input prompt
    print("Initial tokens:", tokens)

    # Initialize generated sequences with padding if necessary
    if len(tokens) < context_length:
        generated = [tokens + [0] * (context_length - len(tokens))] * 2  # Pad with zeros
    else:
        generated = [tokens[-context_length:]] * 2  # Take the last `context_length` tokens
    print("Initialized generated sequences:", generated)

    for _ in range(max_new_tokens):
        context_tokens = []

        for gen in generated:
            # Only take the last 'context_length' tokens
            context = gen[-context_length:]
            context_tokens.append(context)

        # Convert the context tokens into a tensor
        input_tensor = torch.tensor(context_tokens)  # Shape should be (2, context_length)
        print("Input tensor shape:", input_tensor.shape)
        print("Input tensor:", input_tensor)

        # Pass through the model
        logits = model(input_tensor)  # Get logits for next token
        print("Logits shape:", logits.shape)

        # Process logits to get the next tokens
        next_token_logits = logits[:, -1, :]  # Get logits for the last token in each sequence
        next_tokens = torch.argmax(next_token_logits, dim=-1)  # Choose the most likely next token
        print("Next tokens:", next_tokens)

        # Append the new tokens to the generated sequences
        for i in range(len(generated)):
            generated[i].append(next_tokens[i].item())  # Add the predicted token to the generated sequence

            # If the length exceeds context_length, remove the oldest token
            if len(generated[i]) > context_length:
                generated[i].pop(0)

    # Convert generated tokens back to text using the reverse vocab
    generated_texts = [decode_tokens(gen, reverse_vocab) for gen in generated]
    print("Final generated texts:", generated_texts)

    return generated_texts


In [17]:
prompt = "This is a simple language model example for testing the"
generated_text = generate_text(model, simple_tokenizer, vocab, reverse_vocab, prompt, max_new_tokens=20)
print(generated_text)

Initial tokens: [1, 3, 7, 0, 5, 2, 6, 8, 9, 4]
Initialized generated sequences: [[1, 3, 7, 0, 5, 2, 6, 8, 9, 4], [1, 3, 7, 0, 5, 2, 6, 8, 9, 4]]
Input tensor shape: torch.Size([2, 10])
Input tensor: tensor([[1, 3, 7, 0, 5, 2, 6, 8, 9, 4],
        [1, 3, 7, 0, 5, 2, 6, 8, 9, 4]])
Logits shape: torch.Size([2, 10, 13])
Next tokens: tensor([1, 1])
Input tensor shape: torch.Size([2, 10])
Input tensor: tensor([[7, 0, 5, 2, 6, 8, 9, 4, 1, 1],
        [7, 0, 5, 2, 6, 8, 9, 4, 1, 1]])
Logits shape: torch.Size([2, 10, 13])
Next tokens: tensor([11, 11])
Input tensor shape: torch.Size([2, 10])
Input tensor: tensor([[ 5,  2,  6,  8,  9,  4,  1,  1, 11, 11],
        [ 5,  2,  6,  8,  9,  4,  1,  1, 11, 11]])
Logits shape: torch.Size([2, 10, 13])
Next tokens: tensor([1, 5])
Input tensor shape: torch.Size([2, 10])
Input tensor: tensor([[ 6,  8,  9,  4,  1,  1, 11, 11,  1,  5],
        [ 6,  8,  9,  4,  1,  1, 11, 11,  1,  5]])
Logits shape: torch.Size([2, 10, 13])
Next tokens: tensor([11,  7])
Input t