# Model registration
Based on the Hugging Face guide [Registering a model with custom code to the auto classes](https://huggingface.co/docs/transformers/main/en/custom_models#registering-a-model-with-custom-code-to-the-auto-classes).

In [1]:
# %load_ext autoreload
# %autoreload 2
import transformers
from multihead_models import *
import torch



In [2]:
# Build the multi-head Llama config with its default hyperparameters and write
# it into the local `mhllama` directory, which later serves as the checkpoint dir.
save_dir = "mhllama"
mhllamaconfig = MHLlamaConfig()
mhllamaconfig.save_pretrained(save_dir)

In [3]:
# Randomly initialise the multi-head model from the config above and dump its
# weights next to config.json so `from_pretrained` can find them.
# NOTE(review): this run was interrupted (KeyboardInterrupt); instantiating a
# full-size model this way can be slow.
weights_path = './mhllama/pytorch_model.bin'
model = MultiheadLlamaForCausalLM(mhllamaconfig)
torch.save(model.state_dict(), weights_path)

KeyboardInterrupt: 

In [2]:
# Make the Auto* factories aware of the custom architecture:
#   model_type string -> config class -> causal-LM model class.
MODEL_TYPE = 'mhllama'
transformers.AutoConfig.register(MODEL_TYPE, MHLlamaConfig)
transformers.AutoModelForCausalLM.register(MHLlamaConfig, MultiheadLlamaForCausalLM)

## Try loading the registered model through `AutoModelForCausalLM`

In [4]:
# Load the checkpoint back through the Auto class; this only works because the
# registration cell above mapped MHLlamaConfig -> MultiheadLlamaForCausalLM.
# Use the same relative `mhllama` directory the config/weights cells wrote to,
# instead of a user-specific absolute path that breaks on other machines.
model = transformers.AutoModelForCausalLM.from_pretrained('mhllama')
# Display the module tree so a fresh run reproduces the saved output.
model

MultiheadLlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNor

: 

# Old

`model.base_model.model.model` stays unchanged, while `model.base_model.model.lm_head` is duplicated into as many heads as we want. Each head applies a mask so that it only predicts tokens from that head's vocabulary.

Module nesting of `ohmodel`: `PeftModelForCausalLM` > `LoraModel` > `LlamaForCausalLM` > `LlamaModel`

When multi-head MLM sampling is off, Llama generation uses greedy search (`GREEDY_SEARCH`).

## MH model construction

In [1]:
%load_ext autoreload
%autoreload 2
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, get_peft_config
from transformers import AutoTokenizer, LlamaTokenizer, AutoModelForCausalLM
from multihead_models import MultiHeadPeftModelForCausalLM
import torch
import json
import os

In [2]:
# Tokenizer comes from the base Llama-2 weights that the LoRA checkpoint was
# trained on (its peft_config reports the same base_model_name_or_path).
tokenizer = LlamaTokenizer.from_pretrained('/mnt/data/zoo/llama2/llama2-7b-hf/')

prompt = 'There is a Ubuntu server visible at IP 43.205.13.243, port 22, offering the service cpe:/a:openbsd:openssh:8.2p1 Ubuntu-4ubuntu0.5.\n'
encoded = tokenizer([prompt], return_tensors="pt")  # batch size 1
# Move every encoder output (input_ids, attention_mask, ...) onto GPU 0.
inp = {name: tensor.cuda(0) for name, tensor in encoded.items()}

# One-head LoRA checkpoint to start from; `step` picks the training checkpoint.
step = 60
output_dir = f'/mnt/data/sonia/ckpts/sent3/checkpoint-{step}/'
print(output_dir)
ohmodel = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# One-head baseline generation, kept for reference:
# outputs = ohmodel.generate(**inp, max_new_tokens=58)
# out = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
# print(out[0])

/mnt/data/sonia/ckpts/sent3/checkpoint-60/


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Inspect the LoRA hyperparameters of the adapter loaded above.
adapter_name = 'default'
ohmodel.peft_config[adapter_name]

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='/mnt/data/zoo/llama2/llama2-7b-hf/', revision=None, task_type='CAUSAL_LM', inference_mode=True, r=64, target_modules={'k_proj', 'down_proj', 'gate_proj', 'v_proj', 'o_proj', 'up_proj', 'q_proj'}, lora_alpha=16.0, lora_dropout=0.1, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={})

In [4]:
# Build the multi-head model from the one-head LoRA checkpoint. Each head gets
# an all-ones vocabulary mask, i.e. every token is allowed for every head.
# A comprehension is used instead of `4*[torch.ones(...)]`: list repetition
# stores the *same* tensor object four times, so an in-place edit of one head's
# mask would silently change all of them.
# TODO(review): confirm what the trailing `[1]` argument selects in `from_one_head`.
N_HEADS = 4
VOCAB_SIZE = 32000  # Llama-2 vocabulary size (Embedding(32000, ...) in the model repr)
head_masks = [torch.ones(VOCAB_SIZE) for _ in range(N_HEADS)]
mhmodel = MultiHeadPeftModelForCausalLM.from_one_head(ohmodel, head_masks, [1])

In [None]:
# Plain left-to-right generation (MLM sampling disabled), capped at 100 tokens
# in total, prompt included.
inter = mhmodel.generate(
    **inp,
    max_length=100,
    do_mlm_sample=False,
)
inter

In [15]:
# Turn the generated ids back into text, dropping <s>/</s>.
generated_text = tokenizer.batch_decode(
    inter,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
generated_text

['There is a Ubuntu server visible at IP 43.205.13.243, port 22, offering the service cpe:/a:openbsd:openssh:8.2p1 Ubuntu-4ubuntu0.5.\nThe server is running OpenSSH 8.2p1 Ubuntu-4ubuntu0.5.\nThe server is running OpenSSH 8.2p1 Ubuntu-4ubuntu0.5']

In [16]:
# Prompt length in tokens (first and only sequence of the batch).
inp['input_ids'][0].shape[0]

58

In [17]:
# Generated length in tokens, prompt included.
inter[0].shape[0]

100

## "Cloze"

We use `~` as the blank (mask) token for cloze-style filling; the tokenizer maps it to id 3695.

In [37]:
# Cloze prompt: every `~` is a blank for the multi-head model to fill in.
cloze_prompts = [
    'There is a ~ server visible at IP ~, port ~, offering the service ~\n',
    # Duplicate the prompt to test a batch of three:
    # 'There is a ~ server visible at IP ~, port ~, offering the service ~\n',
    # 'There is a ~ server visible at IP ~, port ~, offering the service ~\n',
]
encoded = tokenizer(cloze_prompts, return_tensors="pt")  # batch size 1
inp = {name: tensor.cuda(0) for name, tensor in encoded.items()}
inp

{'input_ids': tensor([[    1,  1670,   338,   263,  3695,  1923,  7962,   472,  5641,  3695,
          29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
        device='cuda:0')}

In [38]:
# Multi-head MLM sampling over the blanks. Judging from the printed trace, the
# result is a list of intermediate token tensors, one per fill step.
inter = mhmodel.generate(
    **inp,
    max_length=100,
    do_mlm_sample=True,
)
inter

model inputs torch.Size([1, 19])
hidden states torch.Size([1, 19, 4096])
tensor([[    1,  1670,   338,   263,  3695,  1923,  7962,   472,  5641,  3695,
         29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
       device='cuda:0')
0 4 tensor(3695, device='cuda:0')
tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3695,
         29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
       device='cuda:0')
0 9 tensor(3695, device='cuda:0')
tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3211,
         29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
       device='cuda:0')
0 12 tensor(3695, device='cuda:0')
tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3211,
         29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
       device='cuda:0')
0 17 tensor(3695, device='cuda:0')
tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3211,
  

[tensor([[    1,  1670,   338,   263,  3695,  1923,  7962,   472,  5641,  3695,
          29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
        device='cuda:0'),
 tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3695,
          29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
        device='cuda:0'),
 tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3211,
          29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
        device='cuda:0'),
 tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3211,
          29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
        device='cuda:0'),
 tensor([[    1,  1670,   338,   263,  3287,  1923,  7962,   472,  5641,  3211,
          29892,  2011,  3695, 29892, 27032,   278,  2669,  3695,    13]],
        device='cuda:0')]

In [39]:
# Decode each intermediate snapshot (special tokens kept) to watch the blanks fill in.
for snapshot in inter:
    decoded = tokenizer.batch_decode(
        snapshot,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )
    print(decoded)

['<s> There is a ~ server visible at IP ~, port ~, offering the service ~\n']
['<s> There is a lot server visible at IP ~, port ~, offering the service ~\n']
['<s> There is a lot server visible at IP address, port ~, offering the service ~\n']
['<s> There is a lot server visible at IP address, port ~, offering the service ~\n']
['<s> There is a lot server visible at IP address, port ~, offering the service ~\n']


In [25]:
# Sanity check: which text does token id 5641 (the token before the IP blank) decode to?
ip_token_id = 5641
tokenizer.batch_decode(
    [[ip_token_id]],
    skip_special_tokens=False,
    clean_up_tokenization_spaces=False,
)

['IP']