<a href="https://colab.research.google.com/github/percy-b/Al-kwarizmhi/blob/main/Wikitext_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

from google.colab import drive

# Mount Google Drive at /content/drive; model checkpoints are written beneath it.
drive.mount('/content/drive')

# Checkpoint directory on Drive (makedirs is a no-op when it already exists).
save_dir = os.path.join("/content/drive/My Drive", "model_checkpoints")
os.makedirs(save_dir, exist_ok=True)


Mounted at /content/drive


In [None]:
!pip install transformers datasets tiktoken


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import tiktoken

import tiktoken
from datasets import load_dataset

# GPT-2 byte-pair encoding from tiktoken, shared by the two helpers below.
encoder = tiktoken.get_encoding("gpt2")


def text_to_token_ids(text):
    """Encode ``text`` into a list of GPT-2 BPE token IDs."""
    token_ids = encoder.encode(text)
    return token_ids


def token_ids_to_text(token_ids):
    """Decode a sequence of GPT-2 BPE token IDs back into a string."""
    text = encoder.decode(token_ids)
    return text

# Load the WikiText-103 training split. For a quicker, lower-memory run set
# MAX_SAMPLES to an int (e.g. 20000) to keep only that many leading rows;
# None keeps the whole split.
MAX_SAMPLES = None
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
if MAX_SAMPLES is not None:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))


def tokenize_example(example):
    """Add a ``token_ids`` field holding the GPT-2 token IDs of ``example["text"]``.

    The row is updated in place and returned, as ``Dataset.map`` expects.
    """
    example["token_ids"] = text_to_token_ids(example["text"])
    return example


# Apply tokenization to each example (one row per call).
dataset = dataset.map(tokenize_example, batched=False)




In [None]:
max_seq_length = 512  # Tokens per training example (also the model's position-table size)


def chunk_token_ids(tokens, chunk_length):
    """Split ``tokens`` into consecutive, non-overlapping chunks of exactly ``chunk_length``.

    A trailing remainder shorter than ``chunk_length`` is dropped. A chunk that
    ends exactly at ``len(tokens)`` IS kept (hence the ``+ 1`` in the range
    bound): training targets are built by shifting inside each chunk, so no
    extra look-ahead token beyond the chunk is needed.
    """
    last_start = len(tokens) - chunk_length
    return [tokens[i : i + chunk_length] for i in range(0, last_start + 1, chunk_length)]


examples = []
for sample in dataset:
    examples.extend(chunk_token_ids(sample["token_ids"], max_seq_length))

print("Number of training examples:", len(examples))


In [None]:
import torch
from torch.utils.data import Dataset

class WikiTextDataset(Dataset):
    """Map-style dataset of token-ID chunks for next-token prediction.

    Item ``idx`` is ``(input_ids, targets)``, both 1-D LongTensors of the same
    length, with ``targets[t] == input_ids[t + 1]`` and the last target set to
    -100 so a ``CrossEntropyLoss(ignore_index=-100)`` skips it.
    """

    def __init__(self, examples):
        # Sequence of token-ID lists (one per training example).
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
        # Rotate left by one so position t holds token t+1, then overwrite the
        # wrapped-around final slot with the ignore index.
        targets = torch.roll(input_ids, shifts=-1)
        targets[-1] = -100
        return input_ids, targets

# Wrap the fixed-length chunks as (input_ids, targets) pairs.
train_dataset = WikiTextDataset(examples=examples)


In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of 8 (input_ids, targets) pairs.
dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=8)


In [None]:
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

    Args:
        model: callable mapping a (batch, seq) LongTensor to logits of shape
            (batch, seq, vocab). The notebook's ``Decoder`` returns logits
            when called with its default ``training=True``.
        idx: (batch, seq) LongTensor of prompt token IDs.
        max_new_tokens: maximum number of tokens to append.
        context_size: only the last ``context_size`` tokens are fed to the
            model at each step.
        temperature: if > 0, sample from ``softmax(logits / temperature)``;
            otherwise pick the argmax (greedy).
        top_k: if given, keep only the ``top_k`` largest logits of each row
            (ties at the threshold are kept). Values larger than the
            vocabulary size are clamped to it.
        eos_id: if given, stop as soon as every row produces this ID in the
            same step; that final EOS token is not appended.

    Returns:
        LongTensor of shape (batch, seq + number_of_generated_tokens).
    """
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]  # Focus on last time step -> (batch, vocab)

        # Top-k filtering
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            top_logits, _ = torch.topk(logits, k)
            # Keep the k-th largest value as a (batch, 1) column so it is
            # compared per row; a (batch,) vector would broadcast against the
            # vocab axis instead and fail (or mis-filter) for batch > 1.
            min_val = top_logits[:, -1:]
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        # Apply temperature scaling if set
        if temperature > 0.0:
            probs = torch.softmax(logits / temperature, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and (idx_next == eos_id).all():
            break

        idx = torch.cat((idx, idx_next), dim=1)
    return idx


In [None]:
def prompt_to_token_ids(prompt):
    """Encode ``prompt`` as a LongTensor of shape (1, num_tokens), i.e. a batch of one."""
    ids = text_to_token_ids(prompt)
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)


In [None]:
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# -------------------------
# My custom modules:
# -------------------------
class PositionEmbeddingFixedWeights(nn.Module):
    """Trainable token embedding plus a frozen sinusoidal position embedding.

    Args:
        seq_length: number of positions in the position table; inputs longer
            than this raise in the position lookup.
        vocab_size: number of token IDs for the word embedding.
        output_dim: embedding width.
    """

    def __init__(self, seq_length, vocab_size, output_dim):
        super().__init__()

        self.word_embedding_layer = nn.Embedding(vocab_size, output_dim)
        self.position_embedding_layer = nn.Embedding(seq_length, output_dim)

        # Overwrite the position table with fixed sinusoids and freeze it.
        pos_embedding_matrix = self.get_position_encoding(seq_length, output_dim)
        self.position_embedding_layer.weight.data.copy_(torch.tensor(pos_embedding_matrix, dtype=torch.float))
        self.position_embedding_layer.weight.requires_grad = False

    def get_position_encoding(self, seq_len, d, n=10000):
        """Return a (seq_len, d) float64 array of sinusoidal encodings.

        P[k, 2i] = sin(k / n**(2i/d)) and P[k, 2i+1] = cos(k / n**(2i/d)) for
        i < d // 2. When ``d`` is odd the last column stays 0.
        Vectorised with numpy; same values as the per-element double loop.
        """
        half = d // 2
        P = np.zeros((seq_len, d))
        positions = np.arange(seq_len)[:, np.newaxis]       # (seq_len, 1)
        denominators = np.power(n, 2 * np.arange(half) / d)  # (half,)
        angles = positions / denominators                    # (seq_len, half)
        P[:, 0:2 * half:2] = np.sin(angles)
        P[:, 1:2 * half:2] = np.cos(angles)
        return P

    def forward(self, inputs):
        """Embed (batch, seq) token IDs -> (batch, seq, output_dim)."""
        position_indices = torch.arange(inputs.size(1), device=inputs.device).unsqueeze(0)
        embedded_words = self.word_embedding_layer(inputs)
        embedded_positions = self.position_embedding_layer(position_indices)
        return embedded_words + embedded_positions

class DotProductAttention(nn.Module):
    """Scaled dot-product attention with an optional additive mask.

    ``softmax(Q K^T / sqrt(d_k) - 1e9 * mask) V``: entries where ``mask`` is 1
    are effectively excluded ("do not attend"), entries where it is 0 are kept.
    ``mask`` must broadcast to the (..., q_len, k_len) score shape.
    """

    def __init__(self):
        super().__init__()

    def forward(self, queries, keys, values, d_k, mask=None):
        scale = torch.sqrt(torch.tensor(d_k, dtype=queries.dtype))
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        if mask is not None:
            # In-place add keeps the scores' dtype and shape.
            scores += -1e9 * mask
        attention_weights = scores.softmax(dim=-1)
        return torch.matmul(attention_weights, values)

class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, split into heads, attend, merge, project.

    Args:
        h: number of heads.
        d_model: model width; each head receives ``d_model // h`` features.
    """

    def __init__(self, h, d_model):
        super().__init__()
        self.attention = DotProductAttention()
        self.heads = h
        self.d_head = d_model // h
        # Creation order (q, k, v, o) is preserved so seeded initialisation is unchanged.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def reshape_tensor(self, x, heads, flag):
        """Split heads when ``flag`` is truthy, merge them otherwise.

        truthy: (batch, seq, d_model)      -> (batch, heads, seq, d_model / heads)
        falsy:  (batch, heads, seq, depth) -> (batch, seq, heads * depth)
        """
        if not flag:
            swapped = x.permute(0, 2, 1, 3)
            n_batch, n_seq, _, _ = swapped.shape
            return swapped.reshape(n_batch, n_seq, -1)
        n_batch, n_seq, _ = x.shape
        return x.reshape(n_batch, n_seq, heads, -1).permute(0, 2, 1, 3)

    def forward(self, queries, keys, values, mask=None):
        q, k, v = (
            self.reshape_tensor(projection(source), self.heads, True)
            for projection, source in ((self.W_q, queries), (self.W_k, keys), (self.W_v, values))
        )
        context = self.attention(q, k, v, self.d_head, mask)
        return self.W_o(self.reshape_tensor(context, self.heads, False))

class LeafAttention(nn.Module):
    """Causal self-attention applied independently inside each sequence segment.

    The sequence axis (dim 1) is split with ``torch.chunk`` into ``nchunks``
    segments. In each segment, query position i attends only to key positions
    j <= i of the same segment. Each segment's output is
    ``dropout(LayerNorm(attention))`` (no residual here), and the segments are
    concatenated back along dim 1.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(h, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, nchunks):
        out = []
        causal_masks = {}  # seg_length -> mask; segments normally share one length
        for segment in x.chunk(nchunks, dim=1):
            seg_length = segment.shape[1]
            if seg_length not in causal_masks:
                # DotProductAttention ADDS -1e9 wherever mask == 1, so 1 must mark
                # the blocked future positions (j > i). A torch.tril mask here
                # would block the past and the diagonal instead, letting each
                # query read later tokens, i.e. leaking the next-token targets.
                causal_masks[seg_length] = torch.triu(
                    torch.ones(seg_length, seg_length, device=x.device), diagonal=1
                )
            m_attn = self.self_attention(segment, segment, segment, mask=causal_masks[seg_length])
            out.append(self.dropout(self.norm(m_attn)))
        return torch.cat(out, dim=1)

class NodeAttention(nn.Module):
    """Merge two adjacent segments with cross-attention.

    The right child supplies the queries and the left child the keys/values.
    Returns ``cat([left_child, dropout(LayerNorm(attention))], dim=1)``, so the
    right half of the result is replaced by the attention output.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        self.cross_attention = MultiHeadAttention(h, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, left_child, right_child):
        queries = right_child
        keys = left_child
        values = left_child
        q_len, k_len = right_child.shape[1], left_child.shape[1]
        # DotProductAttention adds -1e9 where mask == 1, so 1 marks BLOCKED entries.
        # Query i may read left positions j <= i. A torch.tril mask would block
        # j <= i instead and leave the last query row with every key masked.
        # (In HierrachicalAttention the left child comes entirely before the right
        # child, so no future token is reachable through this mask either way.)
        attn_mask = torch.triu(torch.ones(q_len, k_len, device=right_child.device), diagonal=1)
        m_attn = self.cross_attention(queries, keys, values, mask=attn_mask)
        m_attn = self.dropout(self.norm(m_attn))
        merged = torch.cat([left_child, m_attn], dim=1)
        return merged

class HierrachicalAttention(nn.Module):
    """Leaf attention over ``2**levels`` segments, then ``levels`` rounds of pairwise merging.

    At merge round ``level`` the running sequence is split into
    ``2**(levels - level)`` chunks. Each adjacent (left, right) pair is merged by
    ``node_layers[level]``, and the merged pairs are concatenated back.
    Assumes ``x.shape[1]`` is a multiple of ``2**levels``;
    ``HierrachicalDecoderBlock`` pads its input to guarantee this.
    """

    def __init__(self, h, d_model, levels):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by h"
        self.levels = levels
        self.leaf_attention = LeafAttention(h, d_model)
        self.node_layers = nn.ModuleList(NodeAttention(h, d_model) for _ in range(levels))

    def forward(self, x):
        current = self.leaf_attention(x, 2 ** self.levels)
        for level in range(self.levels):
            node = self.node_layers[level]
            pieces = current.chunk(2 ** (self.levels - level), dim=1)
            current = torch.cat(
                [node(pieces[j], pieces[j + 1]) for j in range(0, len(pieces), 2)],
                dim=1,
            )
        return current

class HierrachicalDecoderBlock(nn.Module):
    """Two hierarchical-attention sublayers followed by a feed-forward sublayer.

    Every sublayer is post-norm residual: ``norm(h + sublayer(h))``. Before each
    attention sublayer the input is zero-padded on the right of the sequence
    axis to a multiple of ``2**levels``; the padding is sliced off afterwards.

    NOTE(review): the leaf segment length is padded_length / 2**levels, so the
    segmentation (and hence each position's output) depends on the total input
    length. Fixed-length training chunks and short generation prompts therefore
    see different segmentations. Confirm this is intended.
    """

    def __init__(self, h, d_model, levels):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by h"
        self.hier_attn1 = HierrachicalAttention(h, d_model, levels)
        self.hier_attn2 = HierrachicalAttention(h, d_model, levels)
        hidden_width = 4 * d_model
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.levels = levels

    def forward(self, x):
        seq_len = x.shape[1]
        pad = (-seq_len) % (2 ** self.levels)

        def right_pad(t):
            # Zero-pad the sequence axis (dim 1) on the right by `pad` positions.
            return F.pad(t, (0, 0, 0, pad)) if pad > 0 else t

        attn1 = self.norm1(x + self.hier_attn1(right_pad(x))[:, :seq_len, :])
        attn2 = self.norm2(attn1 + self.hier_attn2(right_pad(attn1))[:, :seq_len, :])
        return self.norm3(attn2 + self.feed_forward(attn2))

class Decoder(nn.Module):
    """Hierarchical-attention language model with a single decoder block.

    Pipeline: token + fixed sinusoidal position embedding, then embedding
    dropout (only when ``training=True``), then one ``HierrachicalDecoderBlock``,
    then a linear projection to vocabulary size.

    Args:
        vocab_size: number of token IDs.
        sequence_length: size of the position table (maximum input length).
        heads: number of attention heads.
        d_model: embedding width.
        levels: hierarchy depth; must not exceed int(log2(sequence_length)).
        drop_out: embedding dropout probability.

    Raises:
        ValueError: if ``levels`` exceeds int(log2(sequence_length)).
    """

    def __init__(self, vocab_size, sequence_length, heads, d_model, levels, drop_out=0.1):
        super().__init__()
        level_cap = int(math.log2(sequence_length))
        if levels > level_cap:
            raise ValueError(
                f"For sequence length {sequence_length}, maximum levels is {level_cap} (requested {levels})."
            )
        self.pos_encoding = PositionEmbeddingFixedWeights(sequence_length, vocab_size, d_model)
        self.decoder_block = HierrachicalDecoderBlock(heads, d_model, levels)
        self.dropout = nn.Dropout(drop_out)
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x, training=True):
        """Return logits when ``training`` is True, softmax probabilities otherwise."""
        embedded = self.pos_encoding(x)
        hidden = self.dropout(embedded) if training else embedded
        logits = self.linear(self.decoder_block(hidden))
        return logits if training else F.softmax(logits, dim=-1)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters for model training
vocab_size = encoder.n_vocab  # Using tiktoken's vocabulary size
heads = 12
d_model = 768        # Embedding dimension; must be divisible by `heads`
levels = 5           # Hierarchy depth: each sequence is split into 2**levels leaf segments
learning_rate = 3e-4
num_epochs = 100
eval_freq = 5  # Generate text every 5 epochs
max_new_tokens = 50  # Number of tokens to generate

In [None]:
# DISABLED: the original from-scratch training loop, kept only as a reference.
# The whole cell is a bare string literal, so running it does nothing.
# The next cell performs training (and resumes from Drive checkpoints).
"""
# First code for training while saving epochs

import torch

# Set device to GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters for model training
vocab_size = encoder.n_vocab  # Using tiktoken's vocabulary size
heads = 12
d_model = 768        # Embedding dimension (you can adjust)
levels = 5
learning_rate = 3e-4
num_epochs = 100
eval_freq = 2  # Generate text every 5 epochs
max_new_tokens = 50  # Number of tokens to generate

# Initialize your model (ensure Decoder is defined in your code)
model = Decoder(vocab_size, max_seq_length, heads, d_model, levels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total Trainable Parameters: {count_parameters(model):,}")

# Training loop with periodic evaluation & generation
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(inputs, training=True)  # shape: (batch, seq_length, vocab_size)
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss:.4f}")

     # --- Save Model Every 10 Epochs ---
    if (epoch + 1) % 10 == 0:
        save_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved at {save_path}")


    # Every eval_freq epochs, generate text from the fixed prompt.
    if (epoch + 1) % eval_freq == 0:
        model.eval()
        with torch.no_grad():
            prompt = "Every effort moves you"
            start_tokens = prompt_to_token_ids(prompt).to(device)
            generated_ids = generate(
                model=model,
                idx=start_tokens,
                max_new_tokens=max_new_tokens,
                context_size=max_seq_length,
                top_k=None,
                temperature=0
            )
            output_text = token_ids_to_text(generated_ids.squeeze().tolist())
            print(f"\nGenerated Text at Epoch {epoch+1}:\n{output_text}\n{'='*50}\n")
        model.train()
"""

In [None]:
# Continue training from the most recent checkpoint on Google Drive (if any).
import os
import re
from glob import glob

import torch

# Mount Google Drive (Colab prints a notice if it is already mounted)
from google.colab import drive
drive.mount('/content/drive')

# Define model save directory
save_dir = "/content/drive/My Drive/model_checkpoints"
os.makedirs(save_dir, exist_ok=True)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model hyperparameters
vocab_size = encoder.n_vocab  # Using tiktoken's vocabulary size
heads = 12
d_model = 768        # Embedding dimension; must be divisible by `heads`
levels = 5
learning_rate = 3e-4
num_epochs = 300     # Train up to this epoch; also shown in the progress line
eval_freq = 2        # Generate text every 2 epochs
save_freq = 10       # Save model + optimizer checkpoints every 10 epochs
max_new_tokens = 50  # Number of tokens to generate

# Initialize model and optimizer
model = Decoder(vocab_size, max_seq_length, heads, d_model, levels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)


def _checkpoint_epoch(path):
    """Return N for a file named ``model_epoch_<N>.pth``, else None."""
    match = re.fullmatch(r"model_epoch_(\d+)\.pth", os.path.basename(path))
    return int(match.group(1)) if match else None


def _optimizer_path(epoch):
    """Path of the Adam state saved alongside ``model_epoch_<epoch>.pth``."""
    return os.path.join(save_dir, f"optimizer_epoch_{epoch}.pth")


# Find the latest (most recently modified) checkpoint. Files matched by the glob
# whose names are not exactly model_epoch_<N>.pth are ignored instead of crashing
# the epoch parsing.
checkpoint_files = sorted(
    (p for p in glob(os.path.join(save_dir, "model_epoch_*.pth")) if _checkpoint_epoch(p) is not None),
    key=os.path.getmtime,
)
latest_checkpoint = checkpoint_files[-1] if checkpoint_files else None
start_epoch = 0

# Load checkpoint if available
if latest_checkpoint:
    print(f"🔄 Loading checkpoint from {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint)
    start_epoch = _checkpoint_epoch(latest_checkpoint)
    # Restore Adam's moment estimates as well; otherwise the updates right after
    # resuming behave like a freshly initialised optimizer. Checkpoints written
    # before optimizer state was saved have no such file.
    optimizer_file = _optimizer_path(start_epoch)
    if os.path.exists(optimizer_file):
        optimizer.load_state_dict(torch.load(optimizer_file, map_location=device))
    else:
        print("⚠️ No optimizer state found for this checkpoint; optimizer starts fresh.")
    print(f"✅ Resumed training from epoch {start_epoch}")

# Training loop
model.train()
for epoch in range(start_epoch, num_epochs):
    epoch_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(inputs, training=True)  # (batch, seq_length, vocab_size)
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss:.4f}")

    # Save model (plain state_dict, same format as before) plus optimizer state
    if (epoch + 1) % save_freq == 0:
        save_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), save_path)
        torch.save(optimizer.state_dict(), _optimizer_path(epoch + 1))
        print(f"✅ Model saved at {save_path}")

    # Generate text every eval_freq epochs
    if (epoch + 1) % eval_freq == 0:
        model.eval()
        with torch.no_grad():
            prompt = "Every effort moves you"
            start_tokens = prompt_to_token_ids(prompt).to(device)
            generated_ids = generate(
                model=model,
                idx=start_tokens,
                max_new_tokens=max_new_tokens,
                context_size=max_seq_length,
                top_k=20,
                temperature=0.3
            )
            output_text = token_ids_to_text(generated_ids.squeeze().tolist())
            print(f"\nGenerated Text at Epoch {epoch+1}:\n{output_text}\n{'='*50}\n")
        model.train()
