In [1]:
import os
import sys
import numpy as np
import inspect

from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig
import torch
import torch.nn.functional as F

In [2]:
sys.path.append("..")

In [3]:
prompts = ["The theory of relativity states that the speed of light is constant in all reference frames"]

In [4]:
model_id = "meta-llama/Llama-3.2-3B"
device = "mps"

In [5]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

0

In [6]:
model = LlamaForCausalLM.from_pretrained(model_id).to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
model.generation_config.pad_token_id = model.config.eos_token_id

In [8]:
inputs = tokenizer(prompts, return_tensors="pt")
print(inputs["input_ids"].numpy().tolist())
inputs = {k:v.to(device) for k,v in inputs.items()}

[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418]]


In [9]:
generation_kwargs = dict(max_length=19, do_sample=False, temperature=None, top_p=None)

In [10]:
outputs = model.generate(**inputs, **generation_kwargs)
print(outputs.cpu().numpy().tolist())

[[128000, 791, 10334, 315, 1375, 44515, 5415, 430, 279, 4732, 315, 3177, 374, 6926, 304, 682, 5905, 14418, 13]]


### Reproduction

In [None]:
inspect.signature(model.generate).parameters

In [None]:
print(model.generate.__doc__)

In [19]:
model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "meta-llama/Llama-3.2-3B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.47.1",
  "use_cache": true,
  "vocab_size": 128256
}

In [20]:
model.generation_config._from_model_config

True

In [26]:
model.generation_config

GenerationConfig {
  "bos_token_id": 128000,
  "do_sample": true,
  "eos_token_id": 128001,
  "temperature": 0.6,
  "top_p": 0.9
}

In [32]:
new_generation_config = GenerationConfig.from_model_config(model.config)
new_generation_config

GenerationConfig {
  "bos_token_id": 128000,
  "eos_token_id": 128001
}

In [33]:
new_generation_config == model.generation_config

False

In [30]:
model.config._get_non_default_generation_parameters()

{}

In [40]:
generation_config, model_kwargs = model._prepare_generation_config(None, **generation_kwargs)
generation_config

GenerationConfig {
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "max_length": 19,
  "pad_token_id": 128001,
  "temperature": null,
  "top_p": null
}

In [45]:
inspect.signature(model.forward)

<Signature (input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Union[transformers.cache_utils.Cache, List[torch.FloatTensor], NoneType] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, **kwargs: *<class 'transformers.models.llama.modeling_llama.KwargsForCausalLM'>) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]>

In [None]:
print(model.forward.__doc__)

In [12]:
out = model.forward(**inputs)

In [65]:
dir(out)

['__annotations__',
 '__class__',
 '__class_getitem__',
 '__contains__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__ior__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__match_args__',
 '__module__',
 '__ne__',
 '__new__',
 '__or__',
 '__post_init__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__ror__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'attentions',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'hidden_states',
 'items',
 'keys',
 'logits',
 'loss',
 'move_to_end',
 'past_key_values',
 'pop',
 'popitem',
 'setdefault',
 'to_tuple',
 'update',
 'values']

In [13]:
logits = out.logits
logits.shape

torch.Size([1, 18, 128256])

In [14]:
model.config.vocab_size

128256

In [15]:
preds = logits[0,-1, :]
preds.shape

torch.Size([128256])

In [18]:
max_pred = preds.argmax(keepdims=True)

In [19]:
max_pred

tensor([13], device='mps:0')

In [21]:
model.model.embed_tokens

Embedding(128256, 3072)

In [24]:
toks = inputs["input_ids"]
toks.shape

torch.Size([1, 18])

In [25]:
h = model.model.embed_tokens(toks)
print(h.shape)
h

torch.Size([1, 18, 3072])


tensor([[[-0.0011, -0.0007, -0.0046,  ..., -0.0015, -0.0021,  0.0018],
         [ 0.0065, -0.0332, -0.0101,  ..., -0.0303,  0.0197, -0.0017],
         [-0.0264, -0.0152, -0.0183,  ...,  0.0131,  0.0369, -0.0364],
         ...,
         [ 0.0013,  0.0025, -0.0155,  ...,  0.0320,  0.0058, -0.0131],
         [ 0.0454, -0.0099,  0.0054,  ..., -0.0061, -0.0052, -0.0125],
         [-0.0330, -0.0001,  0.0117,  ..., -0.0094,  0.0117, -0.0187]]],
       device='mps:0', grad_fn=<EmbeddingBackward0>)

In [27]:
model.config._attn_implementation

'sdpa'