In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%pwd

'/home/ubuntu/rshaw/BLoRA-TGI/tgi'

### TextGenerationRouter - without batching task

In [3]:
from router import TextGenerationRouter, batching_task
from utils import GenerateRequest, GenerateParameters, GenerateRequestInputs

base_model_id = "decapoda-research/llama-7b-hf"
lora_ids = ["jondurbin/airoboros-7b-gpt4-1.2-peft", "trl-lib/llama-7b-se-rl-peft", 'winddude/wizardLM-LlaMA-LoRA-7B']

router = TextGenerationRouter(
    base_model_id=base_model_id,
    lora_ids=lora_ids
)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Using pad_token, but it is not set yet.


In [6]:
inputs = [
    ('Outline a five sentence short story about the Patriots',
    'jondurbin/airoboros-7b-gpt4-1.2-peft', 125),
    ('Question: What are the various algorithms to sort a list?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 50),
    ('### Instruction: Write a poem about the transformers Python library.\n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 100),
    ('Question: How can I write a Java function to generate the nth Fibonacci number?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 50),
    ('### Instruction: Develop an eight sentence short story about a character who can bring their dreams into reality. \n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 75)
]

generate_request_inputs = [
    GenerateRequestInputs(
        inputs=inp[0],
        lora_id=inp[1],
        generate_parameters=GenerateParameters(
            max_new_tokens=inp[2]
        )
    ) for inp in inputs
]

gr_lst = [
    GenerateRequest.from_gr_inputs(gr_inputs) 
    for gr_inputs in generate_request_inputs
]

In [7]:
idx = 0

# first prefill
print(gr_lst[idx])
router.submit_request(gr_lst[idx])
idx += 1

next_batch = router.queue.next_batch(block=False)
assert next_batch is not None
batch, generate_requests = next_batch

cached_batch = router.prefill(
    batch=batch,
    generate_requests=generate_requests
)

# run a few decodes
next_batch = router.queue.next_batch(block=False)
assert next_batch is None

for _ in range(10):
    if cached_batch is None:
        break
    
    batches = [cached_batch]
    cached_batch = router.decode(
        batches=batches,
        generate_requests=generate_requests
    )

GenerateRequest(inputs='Outline a five sentence short story about the Patriots', lora_id='jondurbin/airoboros-7b-gpt4-1.2-peft', generate_parameters=GenerateParameters(max_new_tokens=125), response_stream=<queue.Queue object at 0x7f1dc8f718e0>)


In [8]:
# add a prefill
print(gr_lst[idx])
router.submit_request(gr_lst[idx])
idx += 1

next_batch = router.queue.next_batch(block=False)
assert next_batch is not None
new_batch, new_generate_requests = next_batch

new_cached_batch = router.prefill(
    batch=new_batch,
    generate_requests=new_generate_requests
)

if new_cached_batch is not None:
    batches.append(new_cached_batch)
    assert len(generate_requests.keys() & new_generate_requests.keys()) == 0
    generate_requests.update(new_generate_requests)

# decode
cached_batch = router.decode(
    batches=batches,
    generate_requests=generate_requests
)

# run decodes
for i in range(50):
    print(i)
    if cached_batch is None:
        break

    batches = [cached_batch]
    cached_batch = router.decode(
        batches=batches,
        generate_requests=generate_requests
    )
    batches = [cached_batch]    

GenerateRequest(inputs='Question: What are the various algorithms to sort a list?\n\nAnswer:', lora_id='trl-lib/llama-7b-se-rl-peft', generate_parameters=GenerateParameters(max_new_tokens=50), response_stream=<queue.Queue object at 0x7f1dc8f73910>)
0


ValueError: Batch ID 1 not found in cache.

In [9]:
batches

[CachedBatch(batch_id=0, request_ids=[0]),
 CachedBatch(batch_id=1, request_ids=[1])]

In [16]:
router.service.cache.cache

{}

### TextGenerationService

In [15]:
import torch
from service.service import TextGenerationService
from utils import Batch, Request, GenerationParameters

torch.set_default_tensor_type(torch.cuda.HalfTensor)

base_model_id = "decapoda-research/llama-7b-hf"
lora_ids = ["jondurbin/airoboros-7b-gpt4-1.2-peft", "trl-lib/llama-7b-se-rl-peft", 'winddude/wizardLM-LlaMA-LoRA-7B']

service = TextGenerationService(
    base_model_id=base_model_id,
    lora_ids=lora_ids
)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
Using pad_token, but it is not set yet.


In [35]:
inputs = [
    ('Outline a five sentence short story about the Patriots',
    'jondurbin/airoboros-7b-gpt4-1.2-peft', 125),
    ('Question: What are the various algorithms to sort a list?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 50),
    ('### Instruction: Write a poem about the transformers Python library.\n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 100),
    ('Question: How can I write a Java function to generate the nth Fibonacci number?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 50),
    ('### Instruction: Develop an eight sentence short story about a character who can bring their dreams into reality. \n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 75)
]

requests = []

for idx, inp in enumerate(inputs):
    requests.append(Request(
        id=idx,
        lora_id=inp[1],
        inputs=inp[0],
        generation_parameters=GenerationParameters(
            max_new_tokens=inp[2]
        )
    ))

batches = []
for idx, request in enumerate(requests):
    batches.append(Batch(id=idx, requests=[request]))

In [36]:
input_seqs = []
generation_dict = {}

for batch in batches:
    for req in batch.requests:
        generation_dict[req.id] = []
        input_seqs.append(service.model.tokenizer(req.inputs).input_ids)

active_ids = set()
tokens_b4_prefill = [25, 15, 10, 25, 100]

iterator = zip(batches, tokens_b4_prefill)

cached_batches = []
for idx, (batch, tokens_to_generate) in enumerate(iterator):
    generations, cached_b = service.Prefill(batch)
    for gen in generations:
        generation_dict[gen.request_id].append(gen.token_id)
    
    cached_batches.append(cached_b)

    should_break = False
    for _ in range(tokens_to_generate):
        generations, cached_b = service.Decode(cached_batches)

        for gen in generations:
            generation_dict[gen.request_id].append(gen.token_id)
            if gen.stopped:
                if cached_b is None:
                    should_break = True
                    break
                request_ids = cached_b.request_ids.copy()
                request_ids.remove(gen.request_id)
                cached_b = service.FilterBatch(cached_b.batch_id, request_ids=request_ids)

        if should_break:
            break
        
        cached_batches = [cached_b]

In [37]:
def print_gen(tokenizer, input_seqs, generation_dict):
    for idx in generation_dict:
        tokens = input_seqs[idx].copy()    
        tokens.extend(generation_dict[idx])
    
        print(tokenizer.decode(tokens))
        print("\n\n")

print_gen(service.model.tokenizer, input_seqs, generation_dict)

<unk>Outline a five sentence short story about the Patriots winning the Super Bowl.
1. The New England Patriots, led by their legendary quarterback Tom Brady, were determined to make history once again.
2. With their powerful offense and stifling defense, the Pats dominated the AFC Championship against the Kansas City Chiefs.
3. In the Super Bowl, they faced off against the Los Angeles Rams, who had an equally impressive roster.
4. However, the Patriots' experience and determination proved too much for the Rams, as they secured their sixth Super Bowl victory.
5. The Patri



<unk>Question: What are the various algorithms to sort a list?

Answer: The most common algorithms are:
\begin{itemize}
\item [Bubble sort](http://en.wikipedia.org/wiki/Bubble_sort)
\item [Selection sort](http://en.wikipedia.org/



<unk>### Instruction: Write a poem about the transformers Python library.
### Response:
Transformers Python library,
A powerful tool for data processing,
It can transform data into new 

### BLoraCausalLM

In [4]:
import torch
from service.causal_lm import BLoraCausalLMBatch, BLoraCausalLM
from utils import Batch, Request, GenerationParameters

torch.set_default_tensor_type(torch.cuda.HalfTensor)

base_model_id = "decapoda-research/llama-7b-hf"
lora_ids = ["jondurbin/airoboros-7b-gpt4-1.2-peft", "trl-lib/llama-7b-se-rl-peft", 'winddude/wizardLM-LlaMA-LoRA-7B']

model = BLoraCausalLM(
    base_model_id=base_model_id,
    lora_ids=lora_ids
)

def prefill(model, batch, gen_dict):
    generations, batch = model.generate_token(batch)
    for gen in generations:
        gen_dict[gen.request_id].append(gen.token_id)

    return batch

def decode(model, batch, gen_dict, active_ids):
    stopped_ids = []
    generations, batch = model.generate_token(batch)

    for gen in generations:
        if gen.stopped:
            stopped_ids.append(gen.request_id)    
        gen_dict[gen.request_id].append(gen.token_id)

    if len(stopped_ids) > 0:
        if batch is None:
            return batch
        for stopped_id in stopped_ids:
            active_ids.remove(stopped_id)
        batch.filter(list(active_ids))

    return batch

def print_gen(tokenizer, input_seqs, generation_dict):
    for idx in generation_dict:
        tokens = input_seqs[idx].copy()    
        tokens.extend(generation_dict[idx])
    
        print(tokenizer.decode(tokens))
        print("\n\n")

In [11]:
inputs = [
    ('Outline a five sentence short story about the Patriots',
    'jondurbin/airoboros-7b-gpt4-1.2-peft', 100),
    ('Question: What are the various algorithms to sort a list?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 25),
    ('### Instruction: Write a poem about the transformers Python library.\n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 75),
    ('Question: How can I write a Java function to generate the nth Fibonacci number?\n\nAnswer:',
    'trl-lib/llama-7b-se-rl-peft', 50),
    ('### Instruction: Develop an eight sentence short story about a character who can bring their dreams into reality. \n### Response:',
    'winddude/wizardLM-LlaMA-LoRA-7B', 50)
]

requests = []

for idx, inp in enumerate(inputs):
    requests.append(Request(
        id=idx,
        lora_id=inp[1],
        inputs=inp[0],
        generation_parameters=GenerationParameters(
            max_new_tokens=inp[2]
        )
    ))

batches = []
for idx, request in enumerate(requests):
    batches.append(Batch(id=idx, requests=[request]))

In [12]:
clm_batches = [BLoraCausalLMBatch.from_batch(batch, tokenizer=model.tokenizer, device="cuda") for batch in batches]
input_seqs = []
generation_dict = {}
for clm_batch in clm_batches:
    input_seqs.extend(clm_batch.input_ids.tolist())
    for req in clm_batch.requests:
        generation_dict[req.id] = []

active_ids = set()
tokens_b4_prefill = [25, 15, 10, 25, 100]

# PREFILL
clm_batch = prefill(model, clm_batches[0], generation_dict)
for r in clm_batch.requests:
    active_ids.add(r.id)

# DECODE
for _ in range(tokens_b4_prefill[0]):
    clm_batch = decode(model, clm_batch, generation_dict, active_ids)

iterator = zip(clm_batches[1:], tokens_b4_prefill[1:])
for clm_batch_new, tokens_to_generate in iterator:

    # PREFILL
    clm_batch_new = prefill(model, clm_batch_new, generation_dict)
    for r in clm_batch_new.requests:
        active_ids.add(r.id)

    # CONCATENATE
    clm_batch = BLoraCausalLMBatch.concatenate(
        batches=[clm_batch, clm_batch_new]
    )

    # DECODE LOOP
    for i in range(tokens_to_generate):
        clm_batch = decode(model, clm_batch, generation_dict, active_ids)

        if clm_batch is None:
            break

print_gen(model.tokenizer, input_seqs, generation_dict)

<unk>Outline a five sentence short story about the Patriots winning the Super Bowl.
1. The New England Patriots, led by their legendary quarterback Tom Brady, were determined to make history once again.
2. With their powerful offense and stifling defense, the Pats dominated the AFC Championship against the Kansas City Chiefs.
3. In the Super Bowl, they faced off against the Los Angeles Rams, who had an equally impressive roster.
4. However, the Patriots' experience



<unk>Question: What are the various algorithms to sort a list?

Answer: The most common algorithms are:
\begin{itemize}
\item [Bubble sort](http://en.wikipedia



<unk>### Instruction: Write a poem about the transformers Python library.
### Response:
Transformers Python library,
A powerful tool for data processing,
It can transform data into new forms,
And help us to analyze and visualize.
Transformers Python library,
A versatile tool for data processing,
It can transform data into new forms,
And help us to analyze and vi

In [13]:
batch = Batch(id=0, requests=requests)
clm_batch = BLoraCausalLMBatch.from_batch(batch, tokenizer=model.tokenizer, device="cuda")

input_seqs = clm_batch.input_ids.tolist()
generation_dict = {req.id : [] for req in clm_batch.requests}

active_ids = set()

# PREFILL
clm_batch = prefill(model, clm_batch, generation_dict)
for r in clm_batch.requests:
    active_ids.add(r.id)

# DECODE
while clm_batch is not None:
    clm_batch = decode(model, clm_batch, generation_dict, active_ids)

print_gen(model.tokenizer, input_seqs, generation_dict)

<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>Outline a five sentence short story about the Patriots winning the Super Bowl.
1. The New England Patriots, led by their legendary quarterback Tom Brady, were determined to make history once again.
2. With their powerful offense and stifling defense, the Pats dominated the AFC Championship against the Kansas City Chiefs.
3. In the Super Bowl, they faced off against the Los Angeles Rams, who had an equally impressive roster.
4. However, the Patriots' experience



<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>Question: What are the various algorithms to sort a list?

Answer: The most common algorithms are:
\begin{itemize}
\item [Bubble sort](http://en.wikipedia



<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>### Instruction: Write a poem about the transformers Python library.
### Response:
Transformers Python library,
A powerful tool for data processing,
It can transform dat

In [None]:
%load_ext autoreload
%autoreload 2

from service.blora_utils import load_loras, prepare_batch
from transformers import LlamaForCausalLM, LlamaTokenizer
from IPython.display import clear_output

model2 = LlamaForCausalLM.from_pretrained(base_model_id, device_map="auto")
tokenizer = LlamaTokenizer.from_pretrained(base_model_id)
tokenizer.pad_token = 0
model2, lora_map = load_loras(model2, lora_ids)

In [49]:
b = prepare_batch(inputs, tokenizer, model2, lora_map)

outputs = []

for out in model2.generate(
    **b,
    max_length=100,
    stream_output=True
):
    outputs.append(out)
    batch_decoded = tokenizer.batch_decode(torch.cat([out.reshape(-1, 1) for out in outputs], dim=1))
    clear_output(wait=True)
    print("\n\n".join([lora + ":\n" + prompt + decoded for (prompt, lora, _), decoded in zip(inputs, batch_decoded)]))

jondurbin/airoboros-7b-gpt4-1.2-peft:
Outline a five sentence short story where a character stumbles upon a secret room in their house that contains relics from their future.
The character, who is a young boy named Timmy, stumbles upon a secret room in his house that contained relics from his future. The room was hidden behind a bookcase in the library, and it was filled with strange artifacts and documents.
Timmy's curiosity got the best of him, and he decided

trl-lib/llama-7b-se-rl-peft:
Question: What are the various algorithms to sort a list?

Answer:: The most common algorithms are:
\begin{itemize}
\item [Bubble sort](http://en.wikipedia.org/wiki/Bubble_sort)
\item [Selection sort](http://en.wikipedia.org/wiki/Selection_sort)
\item [Insertion sort](http://en.wikipedia

winddude/wizardLM-LlaMA-LoRA-7B:
### Instruction: Write a poem about the transformers Python library.
### Response:
Transformers Python library,
A powerful tool for data processing,
It can transform data into new f

In [138]:
model.unset_batch_lora_ids()

TypeError: __init__() got an unexpected keyword argument 'idx'

In [None]:
# b = prepare_batch(inputs, tokenizer, model, lora_map)

# tokens = b.input_ids[0,:].tolist()

# model_kwargs = {
#     "attention_mask": b.attention_mask,
#     "use_cache": True
# }

# input_ids = b.input_ids

# for _ in range(100):
#     model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

#     outputs = model(
#         **model_inputs,
#         return_dict=True,
#         output_attentions=False,
#         output_hidden_states=False,
#     )

#     next_token_logits = outputs.logits[:, -1, :]
#     next_tokens = torch.argmax(next_token_logits, dim=-1)
#     tokens.append(next_tokens.item())

#     input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
#     model_kwargs = model._update_model_kwargs_for_generation(
#         outputs, model_kwargs, is_encoder_decoder=False
#     )

#     print(next_tokens.item())
#     print("\n\n")
#     # print(len(tokens))
#     # print(tokenizer.decode(tokens[1:]))

Loading model...


Loading checkpoint shards: 100%|██████████| 33/33 [00:08<00:00,  3.68it/s]


Done!

Loading LORAs...


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
Using pad_token, but it is not set yet.


Done!



In [33]:
print(batch)

Batch(id=0, requests=[Request(id=0, lora_id='trl-lib/llama-7b-se-rl-peft', inputs='Write a 6 line dialogue between a character and a magical creature that only they can see.')])


In [36]:
requests = []
for idx, inp in enumerate(inputs):
    requests.append(Request(
        id=idx,
        lora_id=inp[1],
        inputs=inp[0]
    ))

batch = Batch(
    id=0,
    requests=requests
)

causal_lm_batch = BLoraCausalLMBatch.from_batch(batch, tokenizer=model2.tokenizer, device="cuda")
generation_dct = {idx: [] for idx in range(len(causal_lm_batch.requests))}

for _ in range(100):
    # print(causal_lm_batch.input_ids)
    # print(causal_lm_batch.position_ids)
    # # print(causal_lm_batch.attention_mask)
    # if causal_lm_batch.past_key_values is None:
    #     print("None")
    # else:
    #     print(causal_lm_batch.past_key_values[0][0][0,0,0,0])
    #     print(causal_lm_batch.past_key_values[0][0][0,0,0,1])

    generations, causal_lm_batch = model2.generate_token(causal_lm_batch)
    for idx, gen in enumerate(generations):
        generation_dct[idx].append(gen.token_id.item())

In [38]:
causal_lm_batch = BLoraCausalLMBatch.from_batch(batch, tokenizer=model2.tokenizer, device="cuda")
input_seqs = causal_lm_batch.input_ids.tolist()

key = 0
input_seqs[key].extend(generation_dct[key])
print(model2.tokenizer.decode(input_seqs[key]))

<unk>Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical creature that only they can see.
Write a 6 line dialogue between a character and a magical creature that only they can see. Write a 6 line dialogue between a character and a magical


In [35]:
model2.active_batch_id = None

In [29]:
model2.model.batch_lora_ids

['jondurbin/airoboros-7b-gpt4-1.2-peft']

In [None]:
for key in generation_dct:
    print(generation_dct[key])