In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from gen2 import GraniteSteerer

steerer = GraniteSteerer()
model = steerer.model
tokenizer = steerer.tokenizer

  from .autonotebook import tqdm as notebook_tqdm


Loading tokenizer and model...


The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d
Loading weights: 100%|██████████| 370/370 [00:03<00:00, 99.78it/s, Materializing param=model.norm.weight]                               


In [7]:
ssm_layers = []

for name, module in model.named_modules():
    if "mamba" in name.lower():
        ssm_layers.append((name, module))
    print(name, type(module))

 <class 'transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridForCausalLM'>
model <class 'transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridModel'>
model.embed_tokens <class 'torch.nn.modules.sparse.Embedding'>
model.layers <class 'torch.nn.modules.container.ModuleList'>
model.layers.0 <class 'transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridDecoderLayer'>
model.layers.0.input_layernorm <class 'transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridRMSNorm'>
model.layers.0.post_attention_layernorm <class 'transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridRMSNorm'>
model.layers.0.shared_mlp <class 'transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridMLP'>
model.layers.0.shared_mlp.activation <class 'transformers.activations.SiLUActivation'>
model.layers.0.shared_mlp.input_linear <class 'torch.nn.modules.linear.Linear'>
model

In [10]:
# first_layer = model.model.layers[0].mamba
# last_layer = model.model.layers[-1].mamba

first_layer = model.model.layers[0]
last_layer = model.model.layers[-1]

print(f"Layer 0:\n{first_layer}")
print(f"Layer -1:\n{last_layer}")

Layer 0:
GraniteMoeHybridDecoderLayer(
  (input_layernorm): GraniteMoeHybridRMSNorm((768,), eps=1e-05)
  (post_attention_layernorm): GraniteMoeHybridRMSNorm((768,), eps=1e-05)
  (shared_mlp): GraniteMoeHybridMLP(
    (activation): SiLUActivation()
    (input_linear): Linear(in_features=768, out_features=4096, bias=False)
    (output_linear): Linear(in_features=2048, out_features=768, bias=False)
  )
  (mamba): GraniteMoeHybridMambaLayer(
    (act): SiLUActivation()
    (conv1d): Conv1d(1792, 1792, kernel_size=(4,), stride=(1,), padding=(3,), groups=1792)
    (in_proj): Linear(in_features=768, out_features=3376, bias=False)
    (norm): GraniteMoeHybridRMSNormGated()
    (out_proj): Linear(in_features=1536, out_features=768, bias=False)
  )
)
Layer -1:
GraniteMoeHybridDecoderLayer(
  (input_layernorm): GraniteMoeHybridRMSNorm((768,), eps=1e-05)
  (post_attention_layernorm): GraniteMoeHybridRMSNorm((768,), eps=1e-05)
  (shared_mlp): GraniteMoeHybridMLP(
    (activation): SiLUActivation()


In [5]:
captured = {}

def capture(name):
    def hook(module, inp, output):
        if isinstance(output, tuple):
            captured[name] = output[0].detach()
        else:
            captured[name] = output.detach()
    return hook

h_first = first_layer.register_forward_hook(capture("first"))
h_last  = last_layer.register_forward_hook(capture("last"))

In [6]:
prompt = "Why did the United States declare independence from Britain?"

inputs = tokenizer(prompt, return_tensors="pt").to(steerer.model.device)

with torch.no_grad():
    _ = model(**inputs)


print(captured["first"].shape)
print(captured["last"].shape)

torch.Size([1, 10, 768])
torch.Size([1, 10, 768])


In [None]:
print("first norm:", captured["first"].norm())
print("last norm:", captured["last"].norm())

In [None]:
lm_head = model.lm_head

logits_first = lm_head(captured["first"])
logits_last  = lm_head(captured["last"])

tok_first = logits_first[0, -1].argmax()
tok_last  = logits_last[0, -1].argmax()

print("first guess:", tokenizer.decode(tok_first))
print("last guess:", tokenizer.decode(tok_last))

## Generation Loop

In [None]:
model.eval()

prompt = "Explain why the sky is blue"
inputs = steerer.tokenizer(prompt, return_tensors="pt").to(steerer.model.device)

input_ids = inputs.input_ids
past_key_values = None

In [None]:
max_new_tokens = 5

for step in range(max_new_tokens):
    with torch.no_grad():
        outputs = steerer.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=True,
            return_dict=True
        )

    logits = outputs.logits[:, -1, :]
    past_key_values = outputs.past_key_values
    hidden_states = outputs.hidden_states

    # greedy decode
    next_token = logits.argmax(dim=-1, keepdim=True)

    input_ids = next_token

    print(tokenizer.decode(next_token[0]), end="", flush=True)

In [None]:
final_hidden = hidden_states[-1][:, -1]
print(final_hidden.norm())