In [1]:
import torch 
import torch.nn as nn 

from llama2 import Llama2Model

# Fix PyTorch's RNG seed so that re-running the notebook is repeatable.
torch.manual_seed(123)

<torch._C.Generator at 0x7b95261b3a10>

In [2]:
class KVCache:
  """Simple per-layer slot store.

  Holds one entry per layer index. Every entry starts as ``None`` and is
  replaced wholesale by ``update``. The class does not inspect the stored
  values; presumably they are per-layer key/value tensors -- confirm
  against the model code that consumes this cache.
  """

  def __init__(self, n_layers):
    # One empty slot for each layer.
    self.cache = [None for _ in range(n_layers)]

  def get(self, layer_idx):
    """Return the entry stored for ``layer_idx`` (``None`` if never set)."""
    slot = self.cache[layer_idx]
    return slot

  def update(self, layer_idx, value):
    """Overwrite the entry for ``layer_idx`` with ``value``."""
    self.cache[layer_idx] = value

  def get_all(self):
    """Return the underlying list of entries (the live list, not a copy)."""
    return self.cache

  def reset(self):
    """Set every slot back to ``None`` in place, keeping the same list object."""
    self.cache[:] = [None] * len(self.cache)

In [3]:
# Hyperparameters passed to Llama2Model for the 7B configuration.
LLAMA2_CONFIG_7B = dict(
    vocab_size=32000,  # vocabulary size
    ctx_len=4096,      # context length
    d_model=4096,      # embedding dimension
    n_heads=32,        # number of attention heads
    n_layers=32,       # number of layers
    d_ff=11008,        # size of the intermediate dimension in FeedForward
    # dtype=torch.bfloat16,  # left disabled: lower-precision dtype to reduce memory usage
)

model = Llama2Model(LLAMA2_CONFIG_7B)

In [4]:
# Toy prompt: five token ids, given a leading batch axis of size one.
encoded = [0, 1, 2, 3, 4]
encoded_tensor = torch.tensor(encoded)[None, :]
encoded_tensor.shape

torch.Size([1, 5])

In [5]:
# Sanity-check forward pass of the freshly constructed model on the toy input.
# This runs outside torch.no_grad(), so the displayed result carries a grad_fn.
model(encoded_tensor)

tensor([[[ 0.1468,  0.9266,  0.4884,  ..., -0.1634, -1.3636, -0.1411],
         [-0.0795,  1.5225,  0.2883,  ..., -0.3145, -1.2785,  0.2247],
         [-0.3107,  1.0335,  0.7911,  ..., -0.9444, -1.0664, -0.6065],
         [-0.5140,  0.8891,  0.5288,  ..., -0.8170, -0.0885,  0.7101],
         [-0.1581,  0.6680,  0.3907,  ..., -0.6547, -0.8114,  0.3739]]],
       grad_fn=<ViewBackward0>)

In [6]:
def generate_text_simple(model, idx, max_new_tokens, ctx_len, use_cache=True):
  """
  Greedily append `max_new_tokens` token ids to `idx`.

  At each step the highest-scoring id at the last position of the model
  output (argmax over the final axis) is appended to the sequence.

  Args:
    model: callable as `model(tokens, cache=...)`, returning a tensor indexed
      as `[batch, position, score]`. Must provide `eval()`; the cached path
      additionally reads `model.cfg["n_layers"]` and calls
      `model.reset_kv_cache()`.
    idx: (batch_size, seq_len) tensor of token ids used as the prompt.
    max_new_tokens: number of ids to append; 0 returns `idx` without any
      forward pass.
    ctx_len: at most this many trailing tokens of `idx` are passed to the
      model in a single call.
    use_cache: if True, feed the prompt once and then only the newly chosen
      token per step, passing a fresh `KVCache` to the model. If False,
      re-feed the (truncated) full sequence on every step with `cache=None`.

  Returns:
    (batch_size, seq_len + max_new_tokens) tensor: `idx` followed by the
    generated ids.

  Side effects:
    Leaves `model` in eval mode. In the cached path the model's KV cache is
    reset before generation starts.
  """

  model.eval()
  with torch.no_grad():
    if use_cache:
      cache = KVCache(model.cfg["n_layers"])
      model.reset_kv_cache()
      # NOTE(review): only the prompt is truncated to ctx_len here. Generated
      # tokens keep being fed one at a time, so prompt + new tokens longer than
      # ctx_len is not windowed the way the uncached path is -- confirm the
      # model copes with that before generating past ctx_len.
      #
      # The first step feeds the prompt; every later step feeds only the token
      # chosen in the previous step. Running the forward pass at the top of the
      # loop means exactly one pass per generated token, with no trailing pass
      # whose logits would be thrown away.
      step_input = idx[:, -ctx_len:]
      for _ in range(max_new_tokens):
        logits = model(step_input, cache=cache)
        next_idx = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat((idx, next_idx), dim=1)
        step_input = next_idx
    else:
      for _ in range(max_new_tokens):
        # Without a cache the whole (truncated) sequence is re-encoded each step.
        logits = model(idx[:, -ctx_len:], cache=None)
        next_idx = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat((idx, next_idx), dim=1)

  return idx

In [7]:
# Greedily generate 10 new tokens (ctx_len=4096) with the KV-cache path enabled.
generate_text_simple(model, encoded_tensor, 10, 4096, use_cache=True)

tensor([[    0,     1,     2,     3,     4, 16062, 17784,  6099,  3964, 24148,
          6537, 12469, 23455, 14386, 13850]])

In [8]:
# Same generation with the cache disabled; compare with the cached run's output
# to check that both paths produce identical token ids.
generate_text_simple(model, encoded_tensor, 10, 4096, use_cache=False)

tensor([[    0,     1,     2,     3,     4, 16062, 17784,  6099,  3964, 24148,
          6537, 12469, 23455, 14386, 13850]])