In [None]:
# Install the third-party packages into the running kernel's environment.
# NOTE(review): versions are unpinned, and `lattent` (imported in the next
# cell) is not installed here — confirm how it is provided.
%pip install transformers accelerate bitsandbytes

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from lattent import TTTForCausalLM, TTTConfig, TTT_STANDARD_CONFIGS
import torch

In [None]:
# Choose where tensors and the model live: the CUDA GPU with index
# `device_num` when one is available, otherwise the CPU.
device_num = 0  # GPU index to use; set to -1 to force the CPU

if not (torch.cuda.is_available() and device_num != -1):
    # No usable GPU (or CPU explicitly requested): record -1 as the "cpu" index.
    device, device_num = torch.device("cpu"), -1
else:
    # Make the chosen GPU the current CUDA device so a bare "cuda" resolves to it.
    torch.cuda.set_device(device_num)
    device = torch.device("cuda")

print(f"INFO: Using device - {device}:{device_num}")

In [None]:
# Hugging Face Hub repository id. In the cells visible here it is only used to
# load the tokenizer; the TTT model itself is built from a config, not from
# these weights.
# NOTE(review): presumably a gated repo that needs Hub authentication — confirm.
model_id = "meta-llama/Llama-2-7b-hf"

In [None]:
# Load the tokenizer published under `model_id`. The same tokenizer encodes the
# prompt and decodes the generated ids in the cells below.
# NOTE(review): assumes the TTT model's vocabulary matches this tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Build a TTT configuration from TTTConfig's default arguments.
# Per the original author's note this is equivalent to
# TTTConfig(**TTT_STANDARD_CONFIGS['1b']), i.e. the "ttt-1b" preset.
# NOTE(review): the equivalence cannot be checked from this notebook — confirm
# against the lattent package.
configuration = TTTConfig()
configuration  # bare last expression so the notebook displays the config

In [None]:
# Instantiate a TTT causal language model from `configuration`, move it to the
# selected device, and switch it to evaluation mode.
#
# The model is constructed from a config rather than loaded from a checkpoint,
# so its weights are presumably randomly initialised (NOTE(review): confirm
# against lattent). Seeding torch's global RNG right before construction makes
# that initialisation, and the sampling performed later in a top-to-bottom
# run, repeatable across "Restart Kernel -> Run All".
torch.manual_seed(0)

model = TTTForCausalLM(configuration)
model.to(device)
model.eval()  # last expression: the notebook displays the module summary

In [None]:
# Prompt to continue, and the keyword arguments later unpacked into
# model.generate().
input_text = "Greeting from TTT!"

# Tokenise the prompt into PyTorch tensors and move them to the model's device.
prompt_encoding = tokenizer(input_text, return_tensors="pt").to(device)

inf_params = {
    "input_ids": prompt_encoding.input_ids,
    "max_length": 50,
    # Sampling settings: stochastic decoding with top-k / top-p filtering.
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "temperature": 0.7,
    "num_return_sequences": 1,
    # Reuse the tokenizer's end-of-sequence id as the padding id.
    "pad_token_id": tokenizer.eos_token_id,
}

In [None]:
# Generate a continuation with the TTT model. Gradient tracking is disabled
# because this is inference only.
with torch.no_grad():
    out_ids = model.generate(**inf_params)

# Turn the generated token ids back into text (special tokens dropped) and
# print every returned sequence.
generated_texts = tokenizer.batch_decode(out_ids, skip_special_tokens=True)
print(*generated_texts)