<a href="https://colab.research.google.com/github/shivendrra/AIVA-4x500m/blob/main/base/generate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tiktoken
Successfully installed tiktoken-0.6.0


In [7]:
# hyperparameters
batch_size = 10
block_size = 256
max_iters = 2500
eval_interval = 100
learning_rate = 3e-5
eval_iters = 250
d_model = 512
n_head = 18
n_layers = 12
dropout = 0.2
norm_eps = 1e-05

import torch
import torch.nn as nn
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class RMSNorm(nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6):
    """
      Initialize the RMSNorm normalization layer.
      Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
      Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.
    """
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    """
      Apply the RMSNorm normalization to the input tensor.
        Args:
        x (torch.Tensor): The input tensor.
      Returns:
        torch.Tensor: The normalized tensor.
    """
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x):
    """
      Forward pass through the RMSNorm layer.
      Args:
          x (torch.Tensor): The input tensor.
      Returns:
          torch.Tensor: The output tensor after applying RMSNorm.
    """
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

class UnMaskedHead(nn.Module):
  def __init__(self, head_size, d_model, block_size, dropout):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=True)
    self.query = nn.Linear(d_model, head_size, bias=True)
    self.value = nn.Linear(d_model, head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))

  def forward(self, x):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)

    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)
    rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
    scores = scores + rel_pos_scores

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class UnMaskedAttention(nn.Module):
  def __init__(self, d_model, block_size, dropout, n_head):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_head * head_size, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class MaskedHead(nn.Module):
  def __init__(self, d_model, head_size, dropout, block_size):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=False)
    self.query = nn.Linear(d_model, head_size, bias=False)
    self.value = nn.Linear(d_model, head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    B, T, C = x.shape
    key = self.key(x)
    query = self.query(x)

    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)
    scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class CasualMaskedAttention(nn.Module):
  def __init__(self, d_model, block_size, dropout, n_head):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_head * head_size, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FinalHead(nn.Module):
  def __init__(self, d_model, head_size, dropout, block_size):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=False)
    self.query = nn.Linear(d_model, head_size, bias=False)
    self.value = nn.Linear(d_model, head_size, bias=True)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, att):
    B, T, C = x.shape
    key = self.key(att)
    query = self.query(att)

    scores = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** -0.5)

    att_mat = F.softmax(scores, dim=-1)
    att_mat = self.dropout(att_mat)
    value = self.value(x)
    output = torch.matmul(att_mat, value)
    return output

class FinalAttention(nn.Module):
  def __init__(self, d_model, block_size, dropout, n_head):
    head_size = d_model // n_head
    super().__init__()
    self.heads = nn.ModuleList([FinalHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
    self.proj = nn.Linear(n_head * head_size, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, att):
    out = torch.cat([h(x, att) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FeedForward(nn.Module):
  def __init__(self, d_model, dropout):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(d_model, 4*d_model),
      nn.GELU(),
      nn.Linear(4*d_model, d_model),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.net(x)

class EncoderNetwork(nn.Module):
  def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
    super().__init__()
    self.s_att = UnMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm = RMSNorm(d_model, eps=norm_eps)

  def forward(self, src):
    src = self.norm(src)
    src_out = src + self.dropout(self.s_att(src))

    src = self.norm(src_out)
    src_f = src + self.dropout(self.ffwd(src))

    del src_out, src
    return src_f

class DecoderNetwork(nn.Module):
  def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
    super().__init__()
    self.m_att = CasualMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
    self.f_att = FinalAttention(d_model=d_model, n_head=n_head, dropout=dropout, block_size=block_size)
    self.ffwd = FeedForward(d_model, dropout)
    self.dropout = nn.Dropout(dropout)
    self.norm = RMSNorm(d_model, eps=norm_eps)

  def forward(self, src, att):
    m_att_out = self.norm(src)
    m_out = src + self.dropout(self.m_att(m_att_out))

    f_out = self.f_att(m_out, self.norm(att))
    f_out = m_out + self.dropout(f_out)

    src_f = self.norm(f_out)
    src_f = f_out + self.dropout(self.ffwd(src_f))

    del f_out, m_out, m_att_out, src, att
    return src_f

class Transformer(nn.Module):
  def __init__(self, vocab_size):
    super().__init__()
    self.block_size = block_size
    self.toked_model = nn.Embedding(vocab_size, d_model)
    self.pos_encod = nn.Embedding(block_size, d_model)
    self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
    self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
    self.norm_final = RMSNorm(d_model, eps=norm_eps)
    self.linear_final = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """
      initialize weights of linear and embedding layers

      Args:
        - module (nn.Module): the module to initialize weights for
    """
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias.data)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    """
      forward pass of the transformer model

    Args:
      - idx (Tensor): input tensor representing token indices
      - targets (Tensor): target tensor for computing loss during training

    Returns:
      - logits (Tensor): output logits from the final linear layer
      - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
    """
    B, T = idx.shape

    toked_model = self.toked_model(idx)
    pos_encod = self.pos_encod(torch.arange(T, device=device))
    x = toked_model + pos_encod

    for layer in self.enc_layer:
      x_out = layer(x)

    for layer in self.dec_layer:
      x_final = layer(x, x_out)

    x_final = self.dropout(x_final)
    x_final = self.norm_final(x_final)
    logits = self.linear_final(x_final)

    if targets is None:
      loss = None

    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -self.block_size:]
        # get the predictions
        logits, loss = self(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :] # becomes (B, C)
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1) # (B, C)
        # sample from the distribution
        idx_next = torch.argmax(probs, dim=-1, keepdim=True) # (B, 1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

In [8]:
import tiktoken
tokenizer = tiktoken.get_encoding("p50k_base")
tokenizer = tiktoken.encoding_for_model("text-davinci-003")

vocab_size = tokenizer.n_vocab
model = Transformer(vocab_size)
checkpoint_path = '/content/drive/MyDrive/base-500m.pth'
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
m = model.to(device)

In [15]:
seed = "why is he like that and"
seed_tokens = tokenizer.encode(seed)
seed_tokens = torch.tensor(seed_tokens, dtype=torch.long, device=device).unsqueeze(0)
generated = m.generate(seed_tokens, max_new_tokens=10)
generated_text = tokenizer.decode(generated[0].tolist())
print(generated_text)

why is he like that and is he he he he is he is that he


In [22]:
import torch
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def generate(model, input_ids, max_length=100, temperature=1.0, top_k=0, top_p=0.0, beam_search=False, beam_size=3):
    """
    Generate text using the provided model.

    Args:
    - model (nn.Module): The Transformer model to use for generation.
    - tokenizer: Tokenizer object to convert tokens to text.
    - seed_text (str): The initial text to start generation.
    - max_length (int): Maximum length of the generated sequence.
    - temperature (float): Temperature scaling parameter for softmax.
    - top_k (int): Number of top tokens to sample from using top-k sampling.
    - top_p (float): Nucleus (top-p) sampling parameter.
    - beam_search (bool): Whether to use beam search for generation.
    - beam_size (int): Size of the beam for beam search.

    Returns:
    - generated_text (str): The generated text sequence.
    """
    model.eval()
    with torch.no_grad():
        current_length = input_ids.size(1)

        while current_length < max_length:
            logits = model(input_ids)[0][:, -1, :]  # Get logits for the last token
            logits /= temperature  # Apply temperature scaling
            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            probabilities = F.softmax(filtered_logits, dim=-1)

            if beam_search is True:
                # Beam search
                beam_scores, beam_indices = torch.topk(probabilities, k=beam_size, dim=-1)
                input_ids = torch.cat([input_ids[:, :, None].expand(-1, -1, beam_size), beam_indices.unsqueeze(1)], dim=2)
                current_length += 1
            else:
                # Sampling
                next_token = torch.argmax(probabilities, dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                current_length += 1

        return input_ids

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """
    Apply top-k and top-p (nucleus) filtering to logits.

    Args:
    - logits (torch.Tensor): Logits tensor of shape (batch_size, vocab_size).
    - top_k (int): Number of top tokens to keep.
    - top_p (float): Cumulative probability threshold for nucleus sampling.
    - filter_value (float): Value to fill filtered logits outside the top-k and top-p.

    Returns:
    - filtered_logits (torch.Tensor): Filtered logits tensor.
    """
    if top_k > 0:
        logits = top_k_filtering(logits, top_k=top_k, filter_value=filter_value)

    if top_p > 0.0:
        logits = nucleus_sampling(logits, top_p=top_p)

    return logits

def top_k_filtering(logits, top_k=0, filter_value=-float('Inf')):
    """
    Apply top-k filtering to logits.

    Args:
    - logits (torch.Tensor): Logits tensor of shape (batch_size, vocab_size).
    - top_k (int): Number of top tokens to keep.
    - filter_value (float): Value to fill filtered logits outside the top-k.

    Returns:
    - filtered_logits (torch.Tensor): Filtered logits tensor.
    """
    if top_k > 0:
        values, indices = logits.topk(top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(-1).repeat(1, logits.size(-1))
        logits = torch.where(logits < min_values, torch.ones_like(logits) * filter_value, logits)
    return logits

def nucleus_sampling(logits, top_p=0.0):
    """
    Apply nucleus (top-p) sampling to logits.

    Args:
    - logits (torch.Tensor): Logits tensor of shape (batch_size, vocab_size).
    - top_p (float): Cumulative probability threshold for nucleus sampling.

    Returns:
    - sampled_logits (torch.Tensor): Sampled logits tensor.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability > top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Set logits of removed tokens to a large negative value
    sampled_logits = sorted_logits.clone()
    sampled_logits[sorted_indices_to_remove] = -float('Inf')

    return sampled_logits

In [26]:
# Generate text using the model
seed_text = "Once upon a time"
input_tokens = tokenizer.encode(seed_text)
print(input_tokens)

input_ids = torch.tensor(input_tokens, dtype=torch.long, device=device).unsqueeze(0)
generated_text = generate(model, input_ids, max_length=50, temperature=1, top_k=20, top_p=0)
generated_text = tokenizer.decode(generated_text[0].tolist())

print("Generated Text:")
print(generated_text)

[7454, 2402, 257, 640]
Generated Text:
Once upon a time learning upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon upon


In [None]:
def beam_search_decoder(logits, k=3, max_length=128, length_penalty=0.6):
    sequences = [[[100], 0.0]]

    for _ in range(max_length):
        all_candidates = []
        for seq, score in sequences:
            input_ids = seq
            with torch.no_grad():
                outputs = model(input_ids, none)
                next_token_logits = outputs[0][:, -1, :]
                log_probs = F.log_softmax(next_token_logits, dim=-1)

            topk_log_probs, topk_tokens = torch.topk(log_probs, k, dim=-1)
            topk_log_probs = topk_log_probs.cpu().numpy()[0]
            topk_tokens = topk_tokens.cpu().numpy()[0]

            for log_prob, token in zip(topk_log_probs, topk_tokens):
                new_seq = seq + [token]
                new_score = score - log_prob
                all_candidates.append([new_seq, new_score])

        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        sequences = ordered[:k]

    return sequences[0][0]

def generate_output(input_ids, model, max_length=50, beam_width=3, temperature=1.0, top_p=0.9):
    with torch.no_grad():
        logits, _ = model(input_ids)

        # Apply temperature scaling
        logits /= temperature

        # Apply nucleus sampling (top-p sampling)
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[:, indices_to_remove] = float('-inf')

        # Beam search decoding
        generated_tokens = beam_search_decoder(logits, k=beam_width, max_length=max_length)
    return generated_tokens

# Example usage
input_tokens = "Once upon a time"
context = torch.tensor([tokenizer.encode(input_tokens)], dtype=torch.long, device=device)
output_tokens = generate_output(context, model)
output_text = tokenizer.decode(output_tokens)
print(output_tokens, '\n' ,output_text)