In [1]:
import os
import re
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from collections import Counter
from datasets import load_dataset, Dataset, load_from_disk
from tokenizers import ByteLevelBPETokenizer
from transformers import AdamW, AutoTokenizer
import torch.nn.utils.prune as prune
import pennylane as qml

print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"PyTorch CUDA Version: {torch.version.cuda}")
print(f"PennyLane Version: {qml.__version__}")

# =============================================================================
# 1. Data Preparation: Load and Process the Dataset
# =============================================================================
dataset = load_dataset('HuggingFaceTB/cosmopedia', 'khanacademy', 
                         streaming=True,
                         split='train')
output_file = "cosmopedia_full.txt"
example_limit = None  # Set to an integer (e.g. 10000) to limit processing, or None for all examples.

with open(output_file, "w", encoding="utf-8") as f:
    count = 0
    for example in dataset:
        text = example["text"]
        # Remove extra newlines and spaces.
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()
        if text:
            f.write(text + "\n")
            count += 1
        if example_limit is not None and count >= example_limit:
            break
        if count % 1000 == 0 and count > 0:
            print(f"Processed {count} examples...")
            
print(f"Completed writing {count} examples to {output_file}")

# =============================================================================
# 2. Train a BPE Tokeniser on the Entire Cosmopedia Dataset Text
# =============================================================================
print("Training the BPE tokeniser on the full Cosmopedia dataset text...")
bpe_tokenizer = ByteLevelBPETokenizer()
bpe_tokenizer.train(
    files=output_file,
    vocab_size=30000,
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
)
output_dir = "bpe_tokenizer_model"
os.makedirs(output_dir, exist_ok=True)
bpe_tokenizer.save_model(output_dir)
print(f"BPE tokeniser model saved to {output_dir}")

# -----------------------------------------------------------------------------
# 2.1 Re-map the token IDs to form a contiguous range.
#      (The tokeniser’s raw vocabulary IDs might not be in 0..N-1.)
# -----------------------------------------------------------------------------
vocab_dict = bpe_tokenizer.get_vocab()  # Returns a dict {token: id}
# Create a mapping from the original (old) ID to a new contiguous ID.
sorted_vocab = sorted(vocab_dict.items(), key=lambda x: x[1])
mapping = {old_id: new_id for new_id, (token, old_id) in enumerate(sorted_vocab)}
new_vocab_size = len(vocab_dict)  # This is now the actual number of tokens.

# =============================================================================
# 3. Tokenise the Dataset Using the Trained BPE Tokeniser & Save It
# =============================================================================
seq_len = 128  # Maximum sequence length

def tokenize_and_pad(text):
    encoded = bpe_tokenizer.encode(text)
    old_ids = encoded.ids
    # Remap each token id; if a token is missing in mapping, use the <unk> token.
    unk_old = bpe_tokenizer.token_to_id("<unk>")
    unk_new = mapping.get(unk_old, 0)
    new_ids = [mapping.get(i, unk_new) for i in old_ids]
    if len(new_ids) < seq_len:
        pad_old = bpe_tokenizer.token_to_id("<pad>")
        pad_new = mapping.get(pad_old, 0)
        new_ids = new_ids + [pad_new] * (seq_len - len(new_ids))
    else:
        new_ids = new_ids[:seq_len]
    return new_ids

print("Tokenising the entire dataset using the trained BPE tokeniser...")
tokenized_data = {"input_ids": []}
with open(output_file, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            token_ids = tokenize_and_pad(line)
            tokenized_data["input_ids"].append(token_ids)

# Create an attention mask: 1 for non-pad tokens and 0 for pad tokens.
pad_old = bpe_tokenizer.token_to_id("<pad>")
pad_new = mapping.get(pad_old, 0)
attention_masks = []
for ids in tokenized_data["input_ids"]:
    mask = [1 if token_id != pad_new else 0 for token_id in ids]
    attention_masks.append(mask)
tokenized_data["attention_mask"] = attention_masks

# Convert the dict to a Hugging Face Dataset and save to disk.
tokenized_dataset = Dataset.from_dict(tokenized_data)
tokenized_dataset.save_to_disk("tokenized_cosmopediaa11")
print("Tokenised dataset saved to disk as 'tokenized_cosmopediaa11'.")

# =============================================================================
# 4. Load the Tokenised Dataset & Prepare DataLoader for Training
# =============================================================================
tokenized_dataset = load_from_disk("tokenized_cosmopediaa11")
print("Tokenised dataset loaded from disk.")
print("Dataset features:", tokenized_dataset.features)

tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask"])
batch_size = 16
train_dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)

# =============================================================================




# 5.3 Low-Rank Factorisation Feedforward Layer
class LowRankLinear(nn.Module):
    def __init__(self, in_features, out_features, rank=128):
        super(LowRankLinear, self).__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)
    
    def forward(self, x):
        return self.V(self.U(x))



# 5.5 Quantisation Wrapper (Placeholder)
class QuantizedLayer(nn.Module):
    def __init__(self, module):
        super(QuantizedLayer, self).__init__()
        self.module = module
    
    def forward(self, x):
        return self.module(x)

# 5.6 Pruning Wrapper (L1 Unstructured Pruning)
class PrunedTransformer(nn.Module):
    def __init__(self, module, amount=0.3):
        super(PrunedTransformer, self).__init__()
        self.module = module
        for name, child in self.module.named_modules():
            if isinstance(child, nn.Linear):
                prune.l1_unstructured(child, name="weight", amount=amount)
    
    def forward(self, x):
        return self.module(x)

# 5.1 Mixture of Experts (MoE)
class SparseMoE(nn.Module):
    def __init__(self, d_model, num_experts=8, top_k=5):
        super(SparseMoE, self).__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
        self.gate = nn.Linear(d_model, num_experts)
        self.last_gate = None  # Placeholder for analysis

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        gate_logits = self.gate(x)  # [batch, seq_len, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)
        self.last_gate = gate_probs  # Save gating probabilities for analysis

        topk_probs, topk_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        
        batch_size, seq_len, d_model = x.size()
        output = torch.zeros_like(x)
        for b in range(batch_size):
            for t in range(seq_len):
                token = x[b, t]  # [d_model]
                expert_sum = 0
                for i in range(self.top_k):
                    expert_index = topk_indices[b, t, i]
                    prob = topk_probs[b, t, i]
                    expert_out = self.experts[expert_index](token)
                    expert_sum += prob * expert_out
                output[b, t] = expert_sum
        return output

# 5.4 Sparse Self-Attention (Linformer-like) with Causal Masking
class LinformerSelfAttention(nn.Module):
    def __init__(self, d_model, seq_len, compression_ratio=0.7):
        super(LinformerSelfAttention, self).__init__()
        reduced_dim = int(d_model * compression_ratio)
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, reduced_dim)
        self.value_proj = nn.Linear(d_model, reduced_dim)
        self.out_proj = nn.Linear(reduced_dim, d_model)
        self.last_attn = None  # Placeholder for analysis

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        Q = self.query_proj(x)
        K = self.key_proj(x)
        V = self.value_proj(x)
        # Adjust Q to match the dimension of K if necessary.
        Q_reduced = Q[:, :, :K.size(-1)]
        
        # Create a causal mask so that token t only attends to tokens <= t.
        seq_length = x.size(1)
        mask = torch.tril(torch.ones((seq_length, seq_length), device=x.device))
        mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, seq_len, seq_len]
        
        attn_scores = torch.matmul(Q_reduced, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
        attn_scores = attn_scores.masked_fill(mask.squeeze(0) == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        self.last_attn = attn_probs  # Save attention weights for analysis
        
        attn_output = torch.matmul(attn_probs, V)
        output = self.out_proj(attn_output)
        return output

# 5.8 Extend SMARTModel to Build a Language Model (SMARTLMModel) without RAG
class SMARTLMModel(nn.Module):
    def __init__(self, vocab_size, d_model, max_seq_len, num_experts=8):
        super(SMARTLMModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Create positional embeddings for the maximum sequence length.
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, d_model))
        # Removed RAG module from the SMARTModel
        self.smart = nn.Sequential(
            SparseMoE(d_model, num_experts=num_experts, top_k=5),
            LinformerSelfAttention(d_model, seq_len=max_seq_len, compression_ratio=0.7),
            # LowRank feedforward network remains as part of SMARTModel.
            LowRankLinear(d_model, d_model, rank=min(128, d_model))
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.lm_head = nn.Linear(d_model, vocab_size)
    
    def forward(self, input_ids):
        # input_ids: [batch, L] where L can be less than max_seq_len.
        x = self.embed(input_ids)          # [batch, L, d_model]
        pos_embed = self.pos_embed[:, :x.size(1), :]  # [1, L, d_model]
        x = x + pos_embed
        # Pass through the simplified SMARTModel without RAG.
        x = self.smart(x)
        # Residual connection and normalisation can be applied here as well.
        x = self.norm(x)
        x = self.norm(x + self.dropout(x))
        logits = self.lm_head(x)
        return logits


# =============================================================================
# 6. Training Setup and Loop
# =============================================================================
# IMPORTANT: Use the new_vocab_size (the contiguous vocabulary size) for the model.
d_model = 512       # Embedding/hidden dimension.
seq_len = 128       # Sequence length (should match tokenisation).
num_epochs = 20
learning_rate = 5e-5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Re-create the DataLoader (if needed)
train_dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)

model = SMARTLMModel(vocab_size=new_vocab_size, d_model=d_model, max_seq_len=seq_len, num_experts=8)
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=learning_rate)

# Use a Hugging Face tokeniser (e.g. GPT-2) to obtain the pad token for loss computation.
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
loss_fn = nn.CrossEntropyLoss(ignore_index=hf_tokenizer.pad_token_id if hf_tokenizer.pad_token_id is not None else -100)



CUDA Available: True
PyTorch CUDA Version: 12.1
PennyLane Version: 0.38.0


Using the latest cached version of the dataset since HuggingFaceTB/cosmopedia couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'khanacademy' at C:\Users\saikr\.cache\huggingface\datasets\HuggingFaceTB___cosmopedia\khanacademy\0.0.0\0ae6ec63f91742bd2d1eaef4f02232c55d719385 (last modified on Tue Jan 28 22:25:22 2025).


Processed 1000 examples...
Processed 2000 examples...
Processed 3000 examples...
Processed 4000 examples...
Processed 5000 examples...
Processed 6000 examples...
Processed 7000 examples...
Processed 8000 examples...
Processed 9000 examples...
Processed 10000 examples...
Processed 11000 examples...
Processed 12000 examples...
Processed 13000 examples...
Processed 14000 examples...
Processed 15000 examples...
Processed 16000 examples...
Processed 17000 examples...
Processed 18000 examples...
Processed 19000 examples...
Processed 20000 examples...
Processed 21000 examples...
Processed 22000 examples...
Processed 23000 examples...
Processed 24000 examples...
Completed writing 24123 examples to cosmopedia_full.txt
Training the BPE tokeniser on the full Cosmopedia dataset text...
BPE tokeniser model saved to bpe_tokenizer_model
Tokenising the entire dataset using the trained BPE tokeniser...


Saving the dataset (0/1 shards):   0%|          | 0/24123 [00:00<?, ? examples/s]

Tokenised dataset saved to disk as 'tokenized_cosmopediaa11'.
Tokenised dataset loaded from disk.
Dataset features: {'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}




In [2]:
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)  # [batch, seq_len]
        # Shift tokens for autoregressive prediction:
        inputs = input_ids[:, :-1]
        targets = input_ids[:, 1:]
        
        optimizer.zero_grad()
        logits = model(inputs)  # logits: [batch, seq_len-1, vocab_size]
        loss = loss_fn(logits.reshape(-1, new_vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


print("Training complete.")



Epoch 1/20 - Loss: 5.7727
Epoch 2/20 - Loss: 4.6516
Epoch 3/20 - Loss: 4.3008
Epoch 4/20 - Loss: 4.0676
Epoch 5/20 - Loss: 3.8918
Epoch 6/20 - Loss: 3.7535
Epoch 7/20 - Loss: 3.6347
Epoch 8/20 - Loss: 3.5284
Epoch 9/20 - Loss: 3.4325
Epoch 10/20 - Loss: 3.3447
Epoch 11/20 - Loss: 3.2637
Epoch 12/20 - Loss: 3.1892
Epoch 13/20 - Loss: 3.1194
Epoch 14/20 - Loss: 3.0548
Epoch 15/20 - Loss: 2.9934
Epoch 16/20 - Loss: 2.9359
Epoch 17/20 - Loss: 2.8820
Epoch 18/20 - Loss: 2.8314
Epoch 19/20 - Loss: 2.7837
Epoch 20/20 - Loss: 2.7383
Training complete.


In [3]:
# =============================================================================
# 7. Inference: Generate Output from the Trained Model
# =============================================================================
# Reload the BPE tokeniser from disk.
bpe_tokenizer = ByteLevelBPETokenizer(
    os.path.join("bpe_tokenizer_model", "vocab.json"),
    os.path.join("bpe_tokenizer_model", "merges.txt")
)
hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()




SMARTLMModel(
  (embed): Embedding(30000, 512)
  (smart): SMARTModel(
    (moe): SparseMoE(
      (experts): ModuleList(
        (0-7): 8 x Linear(in_features=512, out_features=512, bias=True)
      )
      (gate): Linear(in_features=512, out_features=8, bias=True)
    )
    (rag): RAGRetriever()
    (attn): LinformerSelfAttention(
      (query_proj): Linear(in_features=512, out_features=512, bias=True)
      (key_proj): Linear(in_features=512, out_features=358, bias=True)
      (value_proj): Linear(in_features=512, out_features=358, bias=True)
      (out_proj): Linear(in_features=358, out_features=512, bias=True)
    )
    (ffn): LowRankLinear(
      (U): Linear(in_features=512, out_features=128, bias=False)
      (V): Linear(in_features=128, out_features=512, bias=False)
    )
    (quantized_ffn): QuantizedLayer(
      (module): LowRankLinear(
        (U): Linear(in_features=512, out_features=128, bias=False)
        (V): Linear(in_features=128, out_features=512, bias=False)
      )


In [17]:
def generate_text(prompt, model, bpe_tokenizer, mapping, max_seq_len=128, max_new_tokens=50):
    """
    Greedy decoding with a sliding window.
    """
    model.eval()
    
    def tokenize_prompt(text):
        encoded = bpe_tokenizer.encode(text)
        old_ids = encoded.ids
        unk_old = bpe_tokenizer.token_to_id("<unk>")
        unk_new = mapping.get(unk_old, 0)
        new_ids = [mapping.get(i, unk_new) for i in old_ids]
        return torch.tensor(new_ids).unsqueeze(0)  # Shape: [1, L]
    
    generated = tokenize_prompt(prompt).to(device)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if generated.size(1) > max_seq_len:
                current_input = generated[:, -max_seq_len:]
            else:
                current_input = generated
            logits = model(current_input)  # [1, current_length, vocab_size]
            next_token_logits = logits[0, -1, :]
            next_token = torch.argmax(next_token_logits).unsqueeze(0).unsqueeze(0)
            generated = torch.cat((generated, next_token), dim=1)
            # Stop if the end-of-sequence token is generated.
            if next_token.item() == mapping.get(bpe_tokenizer.token_to_id("</s>"), -1):
                break
    
    generated_ids = generated[0].cpu().tolist()
    decoded_text = bpe_tokenizer.decode(generated_ids)
    return decoded_text

# Example inference:
prompt_text = "A number grid is a simple yet powerful tool for helping students "
generated_output = generate_text(prompt_text, model, bpe_tokenizer, mapping, max_seq_len=128, max_new_tokens=50)
print("\n--- Inference Example ---")
print("Prompt:", prompt_text)
print("Generated Output:", generated_output)


--- Inference Example ---
Prompt: A number grid is a simple yet powerful tool for helping students 
Generated Output: A number grid is a simple yet powerful tool for helping students  and understanding of data. It is important to understand the relationship between two variables and variables. In this section, we will explore how to interpret the relationship between two variables and variables. Let's start with a simple example: Suppose we want to find the


In [20]:
def generate_text(prompt, model, bpe_tokenizer, mapping, max_seq_len=128, max_new_tokens=100):
    """
    Greedy decoding with a sliding window.
    """
    model.eval()
    
    def tokenize_prompt(text):
        encoded = bpe_tokenizer.encode(text)
        old_ids = encoded.ids
        unk_old = bpe_tokenizer.token_to_id("<unk>")
        unk_new = mapping.get(unk_old, 0)
        new_ids = [mapping.get(i, unk_new) for i in old_ids]
        return torch.tensor(new_ids).unsqueeze(0)  # Shape: [1, L]
    
    generated = tokenize_prompt(prompt).to(device)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if generated.size(1) > max_seq_len:
                current_input = generated[:, -max_seq_len:]
            else:
                current_input = generated
            logits = model(current_input)  # [1, current_length, vocab_size]
            next_token_logits = logits[0, -1, :]
            next_token = torch.argmax(next_token_logits).unsqueeze(0).unsqueeze(0)
            generated = torch.cat((generated, next_token), dim=1)
            # Stop if the end-of-sequence token is generated.
            if next_token.item() == mapping.get(bpe_tokenizer.token_to_id("</s>"), -1):
                break
    
    generated_ids = generated[0].cpu().tolist()
    decoded_text = bpe_tokenizer.decode(generated_ids)
    return decoded_text

# Example inference:
prompt_text = "Counting by Tens"
generated_output = generate_text(prompt_text, model, bpe_tokenizer, mapping, max_seq_len=128, max_new_tokens=100)
print("\n--- Inference Example ---")
print("Prompt:", prompt_text)
print("Generated Output:", generated_output)


--- Inference Example ---
Prompt: Counting by Tens
Generated Output: Counting by Tens in the United States, and their role in the United States. The Constitution was the United States of the United States Constitution in the United States, which had been established by the Soviet Union of the U.S. Constitution, the federal government, government, and government, and the government, and government, and government. One significant consequences for the government was the government in the federal government, the Constitution, which states that the Constitution was the Constitution was the Constitution in the President of the Constitution, which
