In [18]:
from src.models.modelling_llama_skip import LlamaSkipConnectionForCausalLM, LlamaSkipConnectionConfig
from transformers.models.llama import LlamaForCausalLM

In [19]:
from transformers import pipeline, AutoModelForCausalLM, AutoConfig, AutoModel, BitsAndBytesConfig

checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
AutoConfig.register("llama-skip", LlamaSkipConnectionConfig)
AutoModel.register(LlamaSkipConnectionConfig, LlamaSkipConnectionForCausalLM)
AutoModelForCausalLM.register(LlamaSkipConnectionConfig, LlamaSkipConnectionForCausalLM)


llamaSkipConfig = LlamaSkipConnectionConfig.from_json_file("./configs/llama_skip_causal.json")
llamaSkipModel = LlamaSkipConnectionForCausalLM.from_pretrained(checkpoint, config=llamaSkipConfig)

Some weights of LlamaSkipConnectionForCausalLM were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B-Instruct and are newly initialized: ['model.layers.0.mlp.lora_gate_proj.0.weight', 'model.layers.0.mlp.lora_gate_proj.1.weight', 'model.layers.1.mlp.lora_gate_proj.0.weight', 'model.layers.1.mlp.lora_gate_proj.1.weight', 'model.layers.10.mlp.lora_gate_proj.0.weight', 'model.layers.10.mlp.lora_gate_proj.1.weight', 'model.layers.11.mlp.lora_gate_proj.0.weight', 'model.layers.11.mlp.lora_gate_proj.1.weight', 'model.layers.12.mlp.lora_gate_proj.0.weight', 'model.layers.12.mlp.lora_gate_proj.1.weight', 'model.layers.13.mlp.lora_gate_proj.0.weight', 'model.layers.13.mlp.lora_gate_proj.1.weight', 'model.layers.14.mlp.lora_gate_proj.0.weight', 'model.layers.14.mlp.lora_gate_proj.1.weight', 'model.layers.15.mlp.lora_gate_proj.0.weight', 'model.layers.15.mlp.lora_gate_proj.1.weight', 'model.layers.2.mlp.lora_gate_proj.0.weight', 'model.layers.2.mlp.lora_gate_proj.1.weight', 'm

In [20]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
sequence = "In a hole in the ground there lived a hobbit."
input= tokenizer(sequence, return_tensors='pt').input_ids
llamaSkipModel.eval()


LlamaSkipConnectionForCausalLM(
  (model): LlamaSkipConnectionModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaSkipDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaSkipMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
          (lora_gate_proj): Sequential(
            (0): Linear(in_features=2048, out_features=1638, bias=False)
            (1)

In [21]:
model_id = "meta-llama/Llama-3.2-1B-Instruct"

pipe = pipeline(
    "text-generation",
    model=llamaSkipModel,
    tokenizer=tokenizer,
    max_new_tokens = 1000,
    eos_token_id=tokenizer.eos_token_id
)
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
# out = pipe.model.generate(input)
# tokenizer.decode(out[0])

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [22]:
standardPipe = pipeline(
    "text-generation",
    model=checkpoint,
    tokenizer=tokenizer,
    max_new_tokens = 1000,
    eos_token_id=tokenizer.eos_token_id
)

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [30]:
import time

start1 = time.time()
for i in range(15):
    out = standardPipe.model.forward(input, use_cache=False)

start2 = time.time()

for i in range(15):
    out = pipe.model.forward(input,use_cache=False)


start3 = time.time()

print(start3-start2, start2-start1)


5.395598411560059 11.420119524002075
