In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%pwd

'/home/ubuntu/rshaw/BLoRA-TGI/tgi'

In [37]:
import torch
from service.causal_lm import BLoraCausalLMBatch, BLoraCausalLM
from utils import Batch, Request, GenerationParameters

torch.set_default_tensor_type(torch.cuda.HalfTensor)

base_model_id = "decapoda-research/llama-7b-hf"
lora_ids = ["jondurbin/airoboros-7b-gpt4-1.2-peft", "trl-lib/llama-7b-se-rl-peft", 'winddude/wizardLM-LlaMA-LoRA-7B']

model = BLoraCausalLM(
    base_model_id=base_model_id,
    lora_ids=lora_ids
)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
Using pad_token, but it is not set yet.


In [45]:
inputs = [
    ('Outline a five sentence short story where a character stumbles upon a secret room in their house that contains relics from their future.',
    'jondurbin/airoboros-7b-gpt4-1.2-peft', 50),
    ('Question: What are the various algorithms to sort a list?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 100),
    ('### Instruction: Write a poem about the transformers Python library.\n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 25),
    ('Question: Sculpt a three verse poem about the feeling of walking through a lush, vibrant garden in full bloom.\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 125),
    ('### Instruction: Develop an eight sentence short story about a character who can bring their dreams into reality. \n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 50)
]

requests = []
new_tokens = []

for idx, inp in enumerate(inputs):
    requests.append(Request(
        id=idx,
        lora_id=inp[1],
        inputs=inp[0],
        generation_parameters=GenerationParameters(
            max_new_tokens=inp[2]
        )
    ))

batch = Batch(
    id=0,
    requests=requests
)


In [46]:
causal_lm_batch = BLoraCausalLMBatch.from_batch(batch, tokenizer=model.tokenizer, device="cuda")
input_seqs = causal_lm_batch.input_ids.tolist()

generation_dct = {r.id: [] for r in causal_lm_batch.requests}
active_ids = set([r.id for r in requests])

while True:
    stopped_ids = []

    generations, causal_lm_batch = model.generate_token(causal_lm_batch)
    for idx, gen in enumerate(generations):
        if gen.stopped:
            stopped_ids.append(gen.request_id)    
        generation_dct[gen.request_id].append(gen.token_id)

    if len(stopped_ids) > 0:
        if causal_lm_batch is None:
            break

        for stopped_id in stopped_ids:
            active_ids.remove(stopped_id)
        causal_lm_batch.filter(list(active_ids))

for key in generation_dct:
    input_seqs[key].extend(generation_dct[key])
    print(model.tokenizer.decode(input_seqs[key]))
    print("\n")

<unk>Outline a five sentence short story where a character stumbles upon a secret room in their house that contains relics from their future.
The character, who is a young boy named Timmy, stumbles upon a secret room in his house that contained relics from his future. The room was hidden behind a bookcase in the library, and it was filled with strange


<unk><unk><unk><unk><unk><unk><unk><unk><unk>Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical


<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>### Instruction: Write a poem

In [None]:
%load_ext autoreload
%autoreload 2

from service.blora_utils import load_loras, prepare_batch
from transformers import LlamaForCausalLM, LlamaTokenizer
from IPython.display import clear_output

model2 = LlamaForCausalLM.from_pretrained(base_model_id, device_map="auto")
tokenizer = LlamaTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token = 0
model2, lora_map = load_loras(model2, lora_ids)

In [41]:
b = prepare_batch(inputs, tokenizer, model2, lora_map)

outputs = []

for out in model2.generate(
    **b,
    max_length=100,
    stream_output=True
):
    outputs.append(out)
    batch_decoded = tokenizer.batch_decode(torch.cat([out.reshape(-1, 1) for out in outputs], dim=1))
    clear_output(wait=True)
    print("\n\n".join([lora + ":\n" + prompt + decoded for (prompt, lora, _), decoded in zip(inputs, batch_decoded)]))

jondurbin/airoboros-7b-gpt4-1.2-peft:
Outline a five sentence short story where a character stumbles upon a secret room in their house that contains relics from their future.
The character, who is a young boy named Timmy, stumbles upon a secret room in his house that contained relics from his future. The room was hidden behind a bookcase in the library, and it was filled with strange artifacts and documents.
Timmy's curiosity got the best of him, and he decided

trl-lib/llama-7b-se-rl-peft:
Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. The creature must be a magical creature that is not a human.
Write a 6 line dialogue between a character and

winddude/wizardLM-LlaMA-LoRA-7B:
Describe a four sentence scene where a character discovers a hidden talent that changes 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 33/33 [00:08<00:00,  3.79it/s]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [None]:
# b = prepare_batch(inputs, tokenizer, model, lora_map)

# tokens = b.input_ids[0,:].tolist()

# model_kwargs = {
#     "attention_mask": b.attention_mask,
#     "use_cache": True
# }

# input_ids = b.input_ids

# for _ in range(100):
#     model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

#     outputs = model(
#         **model_inputs,
#         return_dict=True,
#         output_attentions=False,
#         output_hidden_states=False,
#     )

#     next_token_logits = outputs.logits[:, -1, :]
#     next_tokens = torch.argmax(next_token_logits, dim=-1)
#     tokens.append(next_tokens.item())

#     input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
#     model_kwargs = model._update_model_kwargs_for_generation(
#         outputs, model_kwargs, is_encoder_decoder=False
#     )

#     print(next_tokens.item())
#     print("\n\n")
#     # print(len(tokens))
#     # print(tokenizer.decode(tokens[1:]))

Loading model...


Loading checkpoint shards: 100%|██████████| 33/33 [00:08<00:00,  3.68it/s]


Done!

Loading LORAs...


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
Using pad_token, but it is not set yet.


Done!



In [33]:
print(batch)

Batch(id=0, requests=[Request(id=0, lora_id='trl-lib/llama-7b-se-rl-peft', inputs='Write a 6 line dialogue between a character and a magical creature that only they can see.')])


In [36]:
requests = []
for idx, inp in enumerate(inputs):
    requests.append(Request(
        id=idx,
        lora_id=inp[1],
        inputs=inp[0]
    ))

batch = Batch(
    id=0,
    requests=requests
)

causal_lm_batch = BLoraCausalLMBatch.from_batch(batch, tokenizer=model2.tokenizer, device="cuda")
generation_dct = {idx: [] for idx in range(len(causal_lm_batch.requests))}

for _ in range(100):
    # print(causal_lm_batch.input_ids)
    # print(causal_lm_batch.position_ids)
    # # print(causal_lm_batch.attention_mask)
    # if causal_lm_batch.past_key_values is None:
    #     print("None")
    # else:
    #     print(causal_lm_batch.past_key_values[0][0][0,0,0,0])
    #     print(causal_lm_batch.past_key_values[0][0][0,0,0,1])

    generations, causal_lm_batch = model2.generate_token(causal_lm_batch)
    for idx, gen in enumerate(generations):
        generation_dct[idx].append(gen.token_id.item())

In [38]:
causal_lm_batch = BLoraCausalLMBatch.from_batch(batch, tokenizer=model2.tokenizer, device="cuda")
input_seqs = causal_lm_batch.input_ids.tolist()

key = 0
input_seqs[key].extend(generation_dct[key])
print(model2.tokenizer.decode(input_seqs[key]))

<unk>Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical


In [35]:
model2.active_batch_id = None

In [29]:
model2.model.batch_lora_ids

['jondurbin/airoboros-7b-gpt4-1.2-peft']

In [None]:
for key in generation_dct:
    print(generation_dct[key])