In [1]:
import torch 
import torch.nn as nn 

from llama2 import Llama2Model

In [2]:
# Hyperparameters passed to Llama2Model below.
LLAMA2_CONFIG_7B = {
    "vocab_size": 32000,     # Vocabulary size
    "ctx_len": 4096,         # Context length
    "d_model": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "d_ff": 11008,           # NEW: Size of the intermediate dimension in FeedForward
    # "dtype": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage
}

# Seed the global RNG before building the model. No checkpoint is loaded in
# this notebook, so the weights presumably come from random initialisation
# (TODO confirm against llama2.py); seeding makes the outputs of the cells
# below repeatable across Restart & Run All.
torch.manual_seed(123)
model = Llama2Model(LLAMA2_CONFIG_7B)

In [3]:
# Five placeholder token IDs used as a stand-in prompt.
encoded = list(range(5))
# Nesting the list once yields the leading batch axis directly.
encoded_tensor = torch.tensor([encoded])
encoded_tensor.shape

torch.Size([1, 5])

In [None]:
# Smoke test: run one forward pass on the dummy batch and display the result.
# Autograd is active here, so the displayed tensor carries a grad_fn.
model(encoded_tensor)

tensor([[[-0.6122,  1.0840,  1.6156,  ...,  0.3889,  0.6667,  0.1819],
         [-0.0839,  0.9597,  0.4589,  ...,  0.3129,  0.5975, -0.2102],
         [-0.1689,  0.6633,  0.6354,  ..., -0.1950,  0.6698,  0.4925],
         [-0.5230,  1.0820,  0.4338,  ...,  0.4965,  0.4614,  0.5027],
         [-0.7566,  1.6268,  0.1110,  ...,  0.7826,  0.7723,  0.7949]]],
       grad_fn=<ViewBackward0>)

In [None]:
def generate_text_simple(model, idx, max_new_tokens, ctx_len):
    """Greedily extend a batch of token IDs by ``max_new_tokens`` tokens.

    On every step the trailing ``ctx_len`` tokens of the running sequence are
    passed to ``model``, the highest-scoring entry at the final position is
    selected, and that ID is appended to the sequence.

    Args:
        model: Callable mapping a ``(batch, n_tokens)`` tensor of token IDs to
            a ``(batch, n_tokens, vocab)`` tensor of scores. Switching the
            model to evaluation mode, if it needs one, is left to the caller.
        idx: Integer tensor of shape ``(batch, n_tokens)`` holding the prompt.
        max_new_tokens: Number of tokens to append.
        ctx_len: Maximum number of trailing tokens passed to ``model``.

    Returns:
        Tensor of shape ``(batch, n_tokens + max_new_tokens)``. The input
        ``idx`` tensor is not modified; a new tensor is built on each step.
    """
    for _ in range(max_new_tokens):
        # Crop to the most recent ctx_len tokens.
        idx_cond = idx[:, -ctx_len:]
        # Decoding never backpropagates, so skip building the autograd graph;
        # this avoids retaining activations for every generated token.
        with torch.no_grad():
            logits = model(idx_cond)
        # Greedy choice: best-scoring ID at the last position, kept as a
        # (batch, 1) column so it can be concatenated along the token axis.
        next_idx = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat((idx, next_idx), dim=1)
    return idx

In [6]:
# Generate 10 tokens greedily; take the context window from the config rather
# than repeating the literal so the two cannot drift apart.
generate_text_simple(
    model=model,
    idx=encoded_tensor,
    max_new_tokens=10,
    ctx_len=LLAMA2_CONFIG_7B["ctx_len"],
)

tensor([[    0,     1,     2,     3,     4,  1115,  3714, 17598, 11235,  8427,
         31261,  9118, 28991, 28706,   384]])