In [None]:
# Developed by Mujtaba Ghulami as a learning exercise: a multi-head attention model with a sample positional encoding

In [None]:
%%capture
!pip install datasets transformers
!pip install torchinfo
!pip install torchviz

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
#from torchviz import make_dot
#from torchinfo import summary
from transformers import GPT2Tokenizer
import torch.nn.functional as F
import time
import pandas as pd
from datasets import Dataset

In [None]:

def apply_rope(x, rope_freqs):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles (RoPE).

    Args:
        x: Tensor of shape (batch, seq_len, num_heads, head_dim). ``head_dim``
            must be even; features ``(2i, 2i + 1)`` form one rotated pair.
        rope_freqs: ``(cos, sin)`` pair as returned by ``build_rope_cache``.
            Each tensor holds ``max_seq_len * head_dim / 2`` values laid out
            position-major (a trailing singleton dimension is accepted).

    Returns:
        Tensor with the same shape as ``x``. The first ``head_dim / 2``
        features hold the rotated first components of every pair and the last
        ``head_dim / 2`` features hold the rotated second components.

    Raises:
        ValueError: If ``head_dim`` is odd or ``seq_len`` exceeds the cache.
    """
    bsz, seq_len, num_heads, head_dim = x.shape
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")
    cos, sin = rope_freqs
    if seq_len > cos.size(0):
        raise ValueError(
            f"sequence length {seq_len} exceeds the RoPE cache length {cos.size(0)}"
        )
    half = head_dim // 2
    # Reshape to (1, seq_len, 1, half) so the angles broadcast over batch and
    # heads regardless of how many heads there are.
    cos = cos[:seq_len].reshape(1, seq_len, 1, half).to(device=x.device, dtype=x.dtype)
    sin = sin[:seq_len].reshape(1, seq_len, 1, half).to(device=x.device, dtype=x.dtype)

    pairs = x.reshape(bsz, seq_len, num_heads, half, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.cat([first * cos - second * sin,
                      first * sin + second * cos], dim=-1)


def build_rope_cache(max_seq_len, head_dim, base=10000):
    """Precompute the RoPE cosine/sine tables.

    Args:
        max_seq_len: Number of positions to tabulate.
        head_dim: Per-head feature width (must be even).
        base: Base of the geometric frequency progression.

    Returns:
        ``(cos, sin)``, each of shape (max_seq_len, head_dim // 2, 1).

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.einsum('i,j->ij', positions, freqs)  # (seq_len, head_dim/2)
    cos = torch.cos(angles).unsqueeze(-1)
    sin = torch.sin(angles).unsqueeze(-1)
    return cos, sin


class GQAAttention(nn.Module):
    """Causal grouped-query self-attention with optional rotary position embedding.

    Args:
        embed_dim: Model width; must be divisible by ``num_q_heads``.
        num_q_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must divide ``num_q_heads``.
        max_len: Number of positions in the RoPE tables.
        rope: Whether to apply RoPE to queries and keys.

    Raises:
        ValueError: On inconsistent head configuration.
    """

    def __init__(self, embed_dim, num_q_heads, num_kv_heads, max_len=2048, rope=True):
        super().__init__()
        if num_q_heads < 1 or num_kv_heads < 1:
            raise ValueError("num_q_heads and num_kv_heads must be at least 1")
        if embed_dim % num_q_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_q_heads ({num_q_heads})"
            )
        if num_q_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_q_heads ({num_q_heads}) must be a multiple of num_kv_heads ({num_kv_heads})"
            )
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_q_heads
        self.rope = rope
        if rope and self.head_dim % 2 != 0:
            raise ValueError(f"RoPE needs an even head_dim, got {self.head_dim}")

        self.q_proj = nn.Linear(embed_dim, num_q_heads * self.head_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        if rope:
            # register_buffer only accepts tensors, so the two tables are
            # registered separately. They stay non-persistent, which keeps the
            # state_dict keys limited to the projection weights.
            cos, sin = build_rope_cache(max_len, self.head_dim)
            self.register_buffer("rope_cos", cos, persistent=False)
            self.register_buffer("rope_sin", sin, persistent=False)

    @property
    def rope_freqs(self):
        """``(cos, sin)`` RoPE tables; only available when ``rope=True``."""
        return self.rope_cos, self.rope_sin

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dim).

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        bsz, seq_len, _ = x.size()

        q = self.q_proj(x).view(bsz, seq_len, self.num_q_heads, self.head_dim)
        k = self.k_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(bsz, seq_len, self.num_kv_heads, self.head_dim)

        # Apply RoPE while the layout is still (batch, seq, heads, head_dim).
        if self.rope:
            q = apply_rope(q, self.rope_freqs)
            k = apply_rope(k, self.rope_freqs)

        # Expand K/V so every query head has a matching key/value head (GQA).
        group_size = self.num_q_heads // self.num_kv_heads
        if group_size > 1:
            k = k.repeat_interleave(group_size, dim=2)
            v = v.repeat_interleave(group_size, dim=2)

        # Heads must precede the sequence axis so that the matmul below scores
        # positions against positions: (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        # True above the diagonal marks future positions, which are masked out.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
        attn_probs = torch.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_probs, v)  # (batch, heads, seq, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, seq_len, self.num_q_heads * self.head_dim
        )

        return self.out_proj(attn_output)

class MOE(nn.Module):
    """Token-level mixture-of-experts feed-forward layer with top-k routing.

    A linear gate scores every token against all experts; each token is sent
    to its ``top_k`` highest-scoring experts and their outputs are summed,
    weighted by the (un-renormalised) softmax gate probabilities.

    Args:
        embed_dim: Input/output feature width.
        ff_dim: Hidden width of every expert.
        num_experts: Number of expert feed-forward networks.
        top_k: Experts evaluated per token (1 <= top_k <= num_experts).

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(self, embed_dim, ff_dim, num_experts=4, top_k=1):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k

        # Create experts (each is a feed-forward network)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, ff_dim),
                nn.SiLU(),
                nn.Linear(ff_dim, embed_dim)
            )
            for _ in range(num_experts)
        ])

        # Gating network
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        """Route every token through its selected experts.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate_scores = torch.softmax(self.gate(x), dim=-1)  # (batch, seq_len, num_experts)
        topk_scores, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)

        # Flatten batch and sequence so tokens can be gathered per expert.
        flat_x = x.reshape(-1, x.size(-1))
        flat_scores = topk_scores.reshape(-1, self.top_k)
        flat_indices = topk_indices.reshape(-1, self.top_k)
        flat_out = torch.zeros_like(flat_x)

        for exp_id, expert in enumerate(self.experts):
            # Rows (tokens) and top-k slots that selected this expert.
            token_ids, slot_ids = torch.where(flat_indices == exp_id)
            if token_ids.numel() == 0:
                continue
            # Only the routed tokens are pushed through the expert.
            weights = flat_scores[token_ids, slot_ids].unsqueeze(-1)
            contribution = expert(flat_x[token_ids]) * weights
            flat_out.index_add_(0, token_ids, contribution.to(flat_out.dtype))

        return flat_out.view(x.shape)

# Transformer block: causal GQA self-attention followed by a mixture-of-experts feed-forward layer
class MatrixModel(nn.Module):
    """One decoder block: GQA self-attention, then an MoE feed-forward layer.

    Each sub-layer is wrapped in a residual connection and followed by a
    LayerNorm. The attention uses ``num_heads`` query heads and
    ``num_heads // 2`` key/value heads with RoPE enabled; the MoE routes each
    token to a single expert.
    """

    def __init__(self, embed_dim, num_heads, ff_dim,num_experts):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Attention sub-layer (half as many key/value heads as query heads).
        self.gqa = GQAAttention(
            embed_dim,
            num_q_heads=num_heads,
            num_kv_heads=num_heads // 2,
            max_len=2048,
            rope=True,
        )
        # Feed-forward sub-layer: mixture of experts, one expert per token.
        self.ffn = MOE(embed_dim, ff_dim, num_experts=num_experts, top_k=1)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)

    def forward(self, inputs):
        """Transform ``inputs`` of shape (batch, seq_len, embed_dim); shape is preserved."""
        # Residual + norm around attention.
        attended = self.layernorm1(inputs + self.gqa(inputs))
        # Residual + norm around the feed-forward experts.
        return self.layernorm2(attended + self.ffn(attended))


# Full Brain-Inspired Model Module (now a Causal Language Model)
class MatrixGPT_MOE_GQA_ROPE(nn.Module):
    """Causal language model: token embedding, a stack of MatrixModel blocks, vocab projection.

    ``max_length`` is stored on the instance; ``key_dim`` is accepted but not
    used by this class. No positional-encoding layer is created here.
    """

    def __init__(self, vocab_size, max_length, embed_dim, num_layers,
                 num_heads, key_dim, ff_dim,num_experts):
        super().__init__()
        self.max_length = max_length
        self.embed_dim = embed_dim

        # Token ids -> embed_dim vectors.
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Stack of identical decoder blocks.
        blocks = [
            MatrixModel(embed_dim, num_heads, ff_dim, num_experts)
            for _ in range(num_layers)
        ]
        self.MatrixModel_layers = nn.ModuleList(blocks)

        # Hidden states -> per-token vocabulary logits.
        self.output_layer = nn.Linear(self.embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size).

        The logits are returned raw; no softmax is applied here.
        """
        hidden = self.embedding(x)  # (batch, seq_len, embed_dim)
        for block in self.MatrixModel_layers:
            hidden = block(hidden)
        return self.output_layer(hidden)



In [None]:
# Model hyperparameters.
# NOTE(review): vocab_size is presumably GPT-2's vocabulary plus the two header
# tokens added to the tokenizer later in the notebook — confirm with len(tokenizer).
vocab_size = 50259
max_length = 1024
embed_dim = 1536
num_layers = 16   # Increase depth for better representation
num_heads = 8
# Derived rather than hard-coded so it cannot drift from embed_dim / num_heads (192 here).
key_dim = embed_dim // num_heads
ff_dim = 6144
num_experts=4
# Checkpoint locations: model_path is read when restoring, save_path is written by save_all.
model_path= "/kaggle/working/MatrixGPT.pth" #"/kaggle/input/matrix/MatrixGPT.pth" #"/content/drive/MyDrive/brain_p/MatrixGPT.pth" # "/kaggle/working/MatrixGPT.pth"
save_path = "/kaggle/working/MatrixGPT.pth"
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
%%capture
# Instantiate the language model from the hyperparameters set in the configuration cell.
model = MatrixGPT_MOE_GQA_ROPE(
    vocab_size=vocab_size,
    max_length=max_length,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    key_dim=key_dim,
    ff_dim=ff_dim,
    num_experts=num_experts
)

In [None]:
%%capture
# Wrap the model in DataParallel when more than one CUDA device is visible
# (the message below is swallowed by %%capture).
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

# Move the (possibly wrapped) model onto the selected device.
model.to(device)

In [None]:
%%capture
# Adam over all model parameters; cross-entropy over the vocabulary is the training loss.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [None]:
%%capture
# Marker tokens that wrap the role header ("system:", "user:", "assistant:") of each message.
START_HEADER = "<|startheader|>"
END_HEADER = "<|endheader|>"

# Start from the pretrained GPT-2 tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# The tokenizer's EOS token doubles as the padding token, so padding and
# end-of-text share one id.
tokenizer.pad_token = tokenizer.eos_token  # Use EOS token as padding

# Register the EOS token and the two header markers as special tokens.
special_tokens_dict = {
    "eos_token": "<|endoftext|>",
    "additional_special_tokens": [START_HEADER, END_HEADER]
}
tokenizer.add_special_tokens(special_tokens_dict)

def _return_header(message) -> str:
    role = message.get("from", "")
    if role == "system":
        return "system:"
    elif role == "gpt":
        return "assistant:"
    elif role == "human":
        return "user:"
    return "unknown:"

def encode_header(message):
    """Return the message's role label enclosed in the start/end header markers."""
    role_label = _return_header(message)
    return START_HEADER + role_label + END_HEADER

def encode_message(message) -> str:
    """Serialise one message: wrapped role header, stripped "value" text, end-of-text marker."""
    pieces = [
        encode_header(message),
        message["value"].strip(),
        "<|endoftext|>",  # terminates every message
    ]
    return "".join(pieces)

def encode_dialog_prompt(dialog):
    """Encode every message of the dialog in order and glue the results into one string."""
    encoded_messages = [encode_message(turn) for turn in dialog]
    return "".join(encoded_messages)

def hermes_ins(batch):
    """Collate a list of conversation records into next-token-prediction tensors.

    Args:
        batch: Iterable of records, each with a 'conversations' list of messages.

    Returns:
        dict with
          "input_ids": token ids with the last position dropped,
          "labels": token ids shifted left by one, with padded positions set
              to -100 so that nn.CrossEntropyLoss (default ignore_index=-100)
              skips them,
          "text": the encoded prompt strings.
    """
    # Encode the conversation in each batch item
    texts = [encode_dialog_prompt(item['conversations']) for item in batch]
    tokenized = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length + 1  # one extra token so inputs and labels both reach max_length
    )
    input_ids = tokenized["input_ids"].long()
    real_token = tokenized["attention_mask"].bool()

    inputs = input_ids[:, :-1]
    # The pad token is the EOS token, so padding cannot be told apart from a
    # genuine end-of-text target by id; the attention mask is the only signal.
    labels = input_ids[:, 1:].masked_fill(~real_token[:, 1:], -100)
    return {"input_ids": inputs, "labels": labels, "text": texts}

# Create DataLoader


In [None]:
%%capture
#!wget https://huggingface.co/datasets/teknium/OpenHermes-2.5/resolve/main/openhermes2_5.json

In [None]:
# Read the OpenHermes JSON dump (expected in the working directory; see the
# commented-out wget in the previous cell) into a DataFrame.
df = pd.read_json("openhermes2_5.json")

# Wrap the DataFrame as a Hugging Face Dataset so it can be handed to a DataLoader.
OpenHermes = Dataset.from_pandas(df)

In [None]:
#OpenHermes = load_dataset("teknium/OpenHermes-2.5", split='train', trust_remote_code=True)
# One conversation per step, shuffled; hermes_ins tokenizes each batch and builds inputs/labels.
hermes_instruct = DataLoader(OpenHermes, batch_size=1, shuffle=True, collate_fn=hermes_ins)

In [None]:
def save_all(model,optimizer,loss):
    """Write a checkpoint (model state, optimizer state, loss) to ``save_path``.

    After saving, the call sleeps for 7 seconds before returning.
    """
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            loss=loss,
        ),
        save_path,
    )
    time.sleep(7)

In [None]:
%%capture
# Restore model and optimizer state, plus the last recorded loss, from the checkpoint at model_path.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss = checkpoint['loss']
# Drop the reference so the checkpoint's copy of the state can be freed.
del checkpoint

In [None]:
# Show the loss restored from the checkpoint (requires a value exposing .item()).
str(loss.item())

'0.4338739812374115'

In [None]:
# Plain training pass: one optimizer step per batch, stopping after the first checkpoint.
save = 1   # batches processed since the last checkpoint
saved = 0  # checkpoints written so far
num_epochs = 1  # Change to desired number of epochs
model.train()

for epoch in range(1, num_epochs + 1):
    for batch in hermes_instruct:
        inputs = batch["input_ids"].to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        # reshape (not view): the shifted label slice may be non-contiguous.
        loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        print(f"\rLoss: {str(loss.item())[:5]} epoch: {epoch} saved: {saved} cycle: {save}", end="\t\t\t")

        save += 1
        if save > 3500:
            save = 0
            # Detach so the checkpoint stores the value without the autograd graph.
            save_all(model, optimizer, loss.detach())
            saved += 1
            print(f"\rModel saved", end="")
            # Stop this pass once a checkpoint has been written.
            break


In [None]:
# Training pass with gradient accumulation.
save = 1   # batches processed since the last checkpoint
saved = 0  # checkpoints written so far
num_epochs = 1  # Change to desired number of epochs
model.train()

# Gradient accumulation settings
accum_steps = 16  # Number of small batches to accumulate gradients over
effective_batch_size = 1 * accum_steps  # Simulated larger batch size

for epoch in range(1, num_epochs + 1):
    optimizer.zero_grad()  # Clear gradients at the start of each epoch
    pending = 0  # micro-batches whose gradients have not been applied yet
    for batch in hermes_instruct:
        inputs = batch["input_ids"].to(device)
        targets = batch["labels"].to(device)

        # Forward pass
        outputs = model(inputs)
        # reshape (not view): the shifted label slice may be non-contiguous.
        loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))

        # Scale the gradient contribution so accum_steps micro-batches average out.
        (loss / accum_steps).backward()
        pending += 1

        # Perform optimization step after accum_steps batches
        if pending == accum_steps:
            optimizer.step()  # Update model parameters
            optimizer.zero_grad()  # Clear gradients after update
            pending = 0
            print(f"\rLoss: {str(loss.item())[:5]} epoch: {epoch} saved: {saved} cycle: {save}", end="\t\t\t")

        save += 1
        if save > 3500:
            save = 0
            # Store the loss as a detached tensor so the checkpoint-loading
            # cell can call .item() on it.
            save_all(model, optimizer, loss.detach())
            saved += 1
            print(f"\rModel saved", end="")

    # Apply gradients left over when the number of batches is not a multiple
    # of accum_steps; otherwise they would be discarded.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

Loss: 6.279 epoch: 1 saved: 3 cycle: 937			

In [None]:
def top_k_sampling(logits, k):
    """
    Select the next token using top-k sampling.
    Args:
        logits (Tensor): Logits for the current token with shape [vocab_size].
        k (int): The number of top tokens to sample from. Values larger than
            the vocabulary are clamped to the vocabulary size.
    Returns:
        int: The token id sampled from the top-k distribution.
    Raises:
        ValueError: If k is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    # torch.topk rejects k larger than the dimension it selects from.
    k = min(k, logits.size(-1))

    # Apply softmax to get probabilities.
    probabilities = F.softmax(logits, dim=-1)
    # Get the top-k token ids and their probabilities.
    topk_probs, topk_indices = torch.topk(probabilities, k)
    # Normalize the top-k probabilities.
    topk_probs = topk_probs / torch.sum(topk_probs)
    # Sample a position within the top-k list, then map it back to a token id.
    choice = torch.multinomial(topk_probs, 1).item()
    return topk_indices[choice].item()

def generate_text_k(model, tokenizer, input_text,device, max_length=50, k=10):
    """Generate a continuation of ``input_text`` with top-k sampling.

    Args:
        model: Callable mapping token ids (1, seq_len) to logits (1, seq_len, vocab_size).
        tokenizer: Object with ``encode``, ``decode`` and ``eos_token_id``.
        input_text: Prompt string.
        device: Device on which the model input is created.
        max_length: Maximum number of tokens to generate.
        k: Number of highest-probability tokens to sample from at each step.

    Returns:
        The decoded prompt plus generated tokens (special tokens skipped).

    Raises:
        ValueError: If the prompt encodes to no tokens.
    """
    model.eval()
    # Tokenize the input text.
    generated = tokenizer.encode(input_text, return_tensors="pt")[0].tolist()
    if not generated:
        raise ValueError("input_text must encode to at least one token")

    with torch.no_grad():
        for _ in range(max_length):
            # The model keeps no state between calls, so the whole sequence is
            # fed every step; passing only the last token would drop all context.
            context = torch.tensor([generated], dtype=torch.long, device=device)
            logits = model(context)
            next_token_logits = logits[0, -1, :]  # logits for the newest position

            # Sample the next token using top-k sampling.
            next_token_id = top_k_sampling(next_token_logits, k)
            generated.append(next_token_id)

            # Stop once the end-of-sequence token is generated.
            if tokenizer.eos_token_id is not None and next_token_id == tokenizer.eos_token_id:
                break

    # Decode the complete generated token list.
    return tokenizer.decode(generated, skip_special_tokens=True)

def generate_text(model, tokenizer, input_text,device, max_length=50):
    """Greedy decoding: repeatedly append the most probable next token.

    The full token sequence is passed to the model at every step. Generation
    ends after ``max_length`` new tokens or as soon as the tokenizer's EOS id
    is produced. Returns the decoded sequence with special tokens skipped.
    """
    model.eval()
    # Prompt -> list of token ids.
    prompt_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    token_ids = prompt_ids.tolist()[0]

    with torch.no_grad():
        for _step in range(max_length):
            # Whole sequence so far, as a batch of one.
            context = torch.tensor([token_ids]).to(device)

            # Distribution over the vocabulary for the next position.
            next_probs = torch.softmax(model(context)[:, -1, :], dim=-1)
            # Greedy choice: highest-probability token.
            best_id = torch.argmax(next_probs, dim=-1).item()
            token_ids.append(best_id)

            if best_id == tokenizer.eos_token_id:
                break

    return tokenizer.decode(token_ids, skip_special_tokens=True)

class TopPTextGenerator:
    """
    A class to perform text generation using nucleus (top-p) sampling.
    """
    def __init__(self, model, tokenizer, top_p=0.9, temperature=1.0, device=None):
        """
        model: PyTorch module that returns logits of shape [batch_size, seq_length, vocab_size]
        tokenizer: A tokenizer with encode/decode methods and an eos_token_id attribute.
        top_p: The cumulative probability threshold for nucleus sampling.
        temperature: A factor to control randomness; higher values increase randomness.
            Must be strictly positive.
        device: torch.device to use.

        Raises ValueError if temperature is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.model = model
        self.tokenizer = tokenizer
        self.top_p = top_p
        self.temperature = temperature
        self.device = device

    def nucleus_sampling(self, logits):
        """
        Applies nucleus (top-p) filtering to the logits.
        logits: Tensor of shape [vocab_size] representing logits for the next token.
        Returns the temperature-scaled logits with every token outside the
        top-p nucleus set to -inf. The most probable token is always kept.
        """
        # Apply temperature scaling
        logits = logits / self.temperature

        # Compute probabilities from logits
        probs = F.softmax(logits, dim=-1)

        # Sort the probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # Compute cumulative probabilities of the sorted tensor
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Mark tokens whose cumulative probability exceeds the threshold
        sorted_indices_to_remove = cumulative_probs > self.top_p

        # Shift the mask one position to the right so the first token that
        # crosses the threshold is kept, and never drop the top token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the sorted-order mask back to vocabulary positions
        filtered_logits = logits.clone()
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        filtered_logits[indices_to_remove] = -float('Inf')
        return filtered_logits

    def generate(self, prompt, seq_len=50):
        """
        Generates text conditioned on a prompt.

        prompt: Starting text string.
        seq_len: Maximum number of tokens to generate.
        Returns the generated text string (prompt included, special tokens skipped).
        Raises ValueError if the prompt encodes to no tokens.
        """
        self.model.eval()
        # Encode the prompt
        token_ids = self.tokenizer.encode(prompt)
        if len(token_ids) == 0:
            raise ValueError("prompt must encode to at least one token")
        input_ids = torch.tensor(token_ids, dtype=torch.long, device=self.device).unsqueeze(0)  # shape: [1, seq_length]
        eos_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(seq_len):
                # Logits for every position; only the newest one is needed
                logits = self.model(input_ids)  # shape: [1, current_seq_len, vocab_size]
                next_token_logits = logits[0, -1, :]  # shape: [vocab_size]

                # Apply nucleus sampling filtering to logits
                filtered_logits = self.nucleus_sampling(next_token_logits)

                # Convert filtered logits to probabilities and sample the next token
                probs = F.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Append the sample to the sequence
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

                # Stop at end-of-sequence. Compare against None explicitly:
                # an EOS id of 0 is valid and must not be treated as "no EOS".
                if eos_id is not None and next_token.item() == eos_id:
                    break

        # Index the batch dimension instead of squeeze() so a length-1
        # sequence still decodes from a list.
        output_text = self.tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)
        return output_text


In [None]:
# Nucleus-sampling generator around the model (top_p=0.9, temperature=1.0).
generator = TopPTextGenerator(model, tokenizer, top_p=0.9, temperature=1.0, device=device)

In [None]:
# Build a chat-formatted prompt: system message, user question, then an open
# assistant header for the model to continue from.
prompt="where is United States"
system="<|startheader|>system:<|endheader|>You are an AI assistant. You will be given a task. You must generate a detailed and long answer.<|endoftext|>"
input_text = f"{system}<|startheader|>user:<|endheader|>{prompt}<|endoftext|><|startheader|>assistant:<|endheader|>"

In [None]:
# Sample up to 20 new tokens with nucleus (top-p) sampling.
generated_text = generator.generate(input_text, seq_len=20)
print("Generated Text (Top-P):")
print(generated_text)

Generated Text (Top-P):
system:You are an AI assistant. You will be given a task. You must generate a detailed and long answer.user:where is United Statesassistant:useruserusersystemuserusersystemsystemsystemusersystemsystemuserusersystemuseruseruserusersystem


In [None]:
# Sample up to 10 new tokens with top-k sampling (k=10).
generated = generate_text_k(model, tokenizer, input_text,device, max_length=10, k=10)
print("Generated text (Top-K):\n", generated)

Generated text (Top-K):
 system:You are an AI assistant. You will be given a task. You must generate a detailed and long answer.user:where is United Statesassistant:usersystemuseruseruseruserusersystemsystemuser


In [None]:
# Generate up to 10 new tokens greedily (always the most probable token).
generated = generate_text(model, tokenizer, input_text,device, max_length=10)
print("Generated text (Greedy):\n", generated)

Generated text (Greedy):
 system:You are an AI assistant. You will be given a task. You must generate a detailed and long answer.user:where is United Statesassistant:useruseruseruseruseruseruseruseruseruser


In [None]:
# Peek at the encoded text of one collated batch to check the prompt formatting.
sample_batch = next(iter(hermes_instruct), None)
if sample_batch is not None:
    print("--------------------------------------------------")
    print(sample_batch["text"])

--------------------------------------------------


In [None]:
# Round-trip the prompt through the tokenizer, keeping special tokens, to inspect how it is encoded.
tokenizer.decode(tokenizer.encode(input_text),skip_special_tokens=False)

In [None]:
# Show which token string id 50258 maps to.
tokenizer.decode(50258)