In [27]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# List the current working directory; the imports below expect `server/` to be here.
%ls

[0m[01;34mprofiling[0m/  [01;34mproto[0m/  [01;34mserver[0m/


In [75]:
import torch, time, tqdm
from server.text_generation_server.models import get_model
from text_generation_server.pb.generate_pb2 import Batch, Request, NextTokenChooserParameters, StoppingCriteriaParameters

# Benchmark settings shared by the cells below.
num_tokens = 100  # decode steps per pass; also used as the requests' max_new_tokens
iterations = 3  # NOTE(review): the visible main() call passes iterations=1 explicitly instead of this value
model_id = "bigscience/bloom-560m"  # model identifier handed to get_model()

def create_batch(max_tokens=20, batch_size=1, prompt="What is Deep Learning?", truncate=1024):
    """Build a protobuf ``Batch`` made of identical generation requests.

    Args:
        max_tokens: Used both as each request's ``max_new_tokens`` stopping
            criterion and as the batch-level ``max_tokens`` field.
        batch_size: Number of requests to create; request ids are
            ``0 .. batch_size - 1``.
        prompt: Input text shared by every request. Defaults to the prompt
            that was previously hard-coded.
        truncate: ``truncate`` value set on every request.

    Returns:
        A ``Batch`` message with ``id=0`` holding ``batch_size`` requests.
    """
    # Every request shares the same sampling configuration, including a fixed seed.
    sampling_params = NextTokenChooserParameters(
        temperature=1,
        top_p=1,
        typical_p=1,
        seed=9248039014309552135,
        repetition_penalty=1,
    )

    stopping_params = StoppingCriteriaParameters(max_new_tokens=max_tokens)

    requests = [
        Request(
            id=request_id,
            inputs=prompt,
            truncate=truncate,
            parameters=sampling_params,
            stopping_parameters=stopping_params,
        )
        for request_id in range(batch_size)
    ]

    return Batch(
        id=0,
        requests=requests,
        size=batch_size,
        max_tokens=max_tokens,
    )

def main(model_id, iterations=3, num_tokens=100, batch_size=1):
    """Load a model, run a timed token-generation loop and print throughput.

    Args:
        model_id: Model identifier forwarded to ``get_model``.
        iterations: Number of outer passes over the generation loop.
        num_tokens: ``generate_token`` calls per pass; also used as the
            requests' ``max_new_tokens``.
        batch_size: Number of identical requests placed in the batch.

    Prints the token text of the last request for every step, the elapsed
    wall-clock time, the nominal token count
    (``num_tokens * batch_size * iterations``) and the resulting tokens/sec.
    Returns ``None``.
    """
    model = get_model(
        model_id=model_id,
        revision=None,
        dtype="float16",
        quantize=None,
        sharded=True,
        trust_remote_code=True,
    )

    batch = model.batch_type.from_pb(
        create_batch(max_tokens=num_tokens, batch_size=batch_size),
        model.tokenizer,
        model.dtype,
        model.device,
    )

    model.warmup(batch)

    tokens = []
    with torch.no_grad():
        start = time.perf_counter()
        for _ in tqdm.tqdm(range(iterations)):
            for _ in range(num_tokens):
                # NOTE(review): the same `batch` object is passed on every step and the
                # returned batch is not fed back; this presumably relies on
                # generate_token advancing `batch` in place — confirm. With
                # iterations > 1 the loop also runs more steps than max_new_tokens.
                generations, _next_batch = model.generate_token(batch)
                if generations:
                    # Sample the last request of the batch. The index used to be a
                    # hard-coded 15, which raised IndexError for batches smaller than 16.
                    tokens.append(generations[-1].token_text)

        # Wait for queued GPU work before stopping the clock; skipped on hosts
        # without CUDA, where the call would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end = time.perf_counter()

        print(tokens)
        print(f"Time = {end - start: 0.2f}")
        print(f"Tokens = {num_tokens * batch_size * iterations}")
        print(f"Tokens/sec = {num_tokens * batch_size * iterations / (end-start): 0.2f}")

In [77]:
# Show the attribute names carried by the batch object that generate_token returned.
print(vars(next_batch).keys())

dict_keys(['batch_id', 'requests', 'requests_idx_mapping', 'input_ids', 'attention_mask', 'position_ids', 'past_key_values', 'all_input_ids', 'input_lengths', 'prefix_offsets', 'read_offsets', 'next_token_choosers', 'stopping_criterias', 'max_input_length', 'padding_right_offset', 'max_tokens', 'keys_head_dim_last'])


In [76]:
# Run a single timed pass with a batch of 16 identical prompts.
main(model_id, iterations=1, num_tokens=num_tokens, batch_size=16)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


FLASH_ATTENTION = True


100%|███████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.73s/it]

[' Learning', ' is', ' a', ' new', ' type', ' of', ' machine', ' learning', ' that', ' is', ' used', ' to', ' solve', ' complex', ' problems', '.', ' It', ' is', ' a', ' type', ' of', ' machine', ' learning', ' that', ' uses', ' deep', ' neural', ' networks', ' to', ' learn', ' complex', ' mathematical', ' functions', '.', ' Deep', ' Learning', ' is', ' a', ' type', ' of', ' machine', ' learning', ' that', ' uses', ' deep', ' neural', ' networks', ' to', ' learn', ' complex', ' mathematical', ' functions', '.', ' Deep', ' Learning', ' is', ' a', ' type', ' of', ' machine', ' learning', ' that', ' uses', ' deep', ' neural', ' networks', ' to', ' learn', ' complex', ' mathematical', ' functions', '.', ' Deep', ' Learning', ' is', ' a', ' type', ' of', ' machine', ' learning', ' that', ' uses', ' deep', ' neural', ' networks', ' to', ' learn', ' complex', ' mathematical', ' functions', '.', ' Deep', ' Learning', ' is', ' a', ' type', ' of', ' machine', ' learning', ' that']
Time =  1.73
T




In [58]:
# Build the model-specific batch for one default-sized request. The earlier
# `batch = create_batch()` assignment was overwritten immediately, so it is dropped.
# NOTE(review): `model` must already exist at notebook scope; main() keeps its
# model local, so confirm which cell defines it.
batch = model.batch_type.from_pb(
    create_batch(max_tokens=num_tokens), model.tokenizer, model.dtype, model.device
)

In [60]:
# Display the model object to see which implementation class was loaded.
model

<text_generation_server.models.bloom.BLOOMSharded at 0x7fafe38f67c0>

In [61]:
# Run one generation step on `batch`, keeping both the generations and the returned batch.
generations, next_batch = model.generate_token(batch)

In [63]:
# Number of generation entries returned by the last generate_token call.
len(generations)

8

In [None]:
# (scratch cell) A bare `python3` shell command was left here; it is not valid
# Python and raised NameError when run, so it has been replaced by this note.