In [1]:
# download a model
from huggingface_hub import snapshot_download
MODEL_PATH = snapshot_download(repo_id="mgoin/TinyStories-33M-quant-deepsparse")

Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

In [2]:
!ls {MODEL_PATH}

config.json		model.onnx		 tokenizer_config.json
generation_config.json	model-orig.onnx		 tokenizer.json
merges.txt		special_tokens_map.json  vocab.json


In [3]:
from service.causal_lm import DeepSparseCausalLM
from service.service import DeepSparseService

# setup service
service = DeepSparseService(
    model = DeepSparseCausalLM(
        model_path=f"{MODEL_PATH}/model.onnx",
        tokenizer_path=MODEL_PATH
    )
)

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Using pad_token, but it is not set yet.
2023-10-17 22:44:55 deepsparse.utils.onnx INFO     Overwriting in-place the input shapes of the transformer model at /home/rshaw/.cache/huggingface/hub/models--mgoin--TinyStories-33M-quant-deepsparse/snapshots/6d30653d6fd728a5b8121a2e6801408c79c3c179/model.onnx
DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20231012 COMMUNITY | (ecee26fb) (release) (optimized) (system=avx2, binary=avx2)
2023-10-17 22:44:58 deepsparse.utils.onnx INFO     Overwriting in-place the input shapes of the transformer model at /home/rshaw/.cache/huggingface/hub/models--mgoin--TinyStories-33M-quant-deepsparse/snapshots/6d30653d6fd728a5b8121a2e6801408c79c3c179/model.onnx


deepsparse.engine.Engine:
	onnx_file_path: /home/rshaw/.cache/huggingface/hub/models--mgoin--TinyStories-33M-quant-deepsparse/snapshots/6d30653d6fd728a5b8121a2e6801408c79c3c179/model.onnx
	batch_size: 1
	num_cores: 24
	num_streams: 1
	scheduler: Scheduler.default
	fraction_of_supported_ops: 1.0
	cpu_avx_type: avx2
	cpu_vnni: False
deepsparse.engine.Engine:
	onnx_file_path: /home/rshaw/.cache/huggingface/hub/models--mgoin--TinyStories-33M-quant-deepsparse/snapshots/6d30653d6fd728a5b8121a2e6801408c79c3c179/model.onnx
	batch_size: 1
	num_cores: 24
	num_streams: 1
	scheduler: Scheduler.default
	fraction_of_supported_ops: 1.0
	cpu_avx_type: avx2
	cpu_vnni: False


In [4]:
from utils import CachedBatch, Batch, Generation, GenerateRequest, Request, GenerationParameters, GenerateRequestInputs

# setup inputs
prompts = [
    "Pricess Peach jumped from the balcony and",
    "Mario and Luigi ran out of the store and",
    "Bowser took out the flamethrower and",
    "Wario shaved his mustache and",
    "Toad made a funny sound and",
]


### ----- IMPORTANT ------
# NOTE: this controls how many decodes will be run for each request
max_new_tokens_list = [30, 6, 17, 9, 10]

requests = [
    Request(
        id=idx,
        inputs=prompt,
        generation_parameters=GenerationParameters(max_new_tokens=max_new_tokens),
    ) for idx, (prompt, max_new_tokens) in enumerate(zip(prompts, max_new_tokens_list))
]

batches = [
    Batch(
        id=idx,
        requests=[request]
    ) for idx, request in enumerate(requests)
]

In [8]:
NUM_DECODES = 5

def do_decode(service, cached_batch_list, text_list):
    generations, next_batch = service.Decode(cached_batch_list)

    if next_batch is None:
        cached_batch_list = []
        service.ClearCache()
        
        print("-------- COMPLETED GENERATION --------")
        print(f"id = {generations[0].request_id}: {text_list[generations[0].request_id]}")
        print("\n\n")

        return cached_batch_list, text_list

    cached_batch_list = [next_batch]
    for generation in generations:
        if generation.stopped:
            batch_id = cached_batch_list[0].batch_id
            active_request_ids = [
                request_id for request_id in cached_batch_list[0].request_ids if request_id != generation.request_id
            ]

            service.FilterBatch(
                batch_id=batch_id, request_ids=active_request_ids
            )

            print("-------- COMPLETED GENERATION --------")
            print(f"id = {generation.request_id}: {text_list[generation.request_id]}")
            print("\n\n")

        text_list[generation.request_id] += generation.token

    return cached_batch_list, text_list       

In [11]:
NUM_DECODES = 5
cached_batch_lst = []
text_lst = [prompt for prompt in prompts]

print("-------- ORIGINAL PROMPTS --------")
for idx, prompt in enumerate(prompts):
    print(f"id: {idx}: {prompt}")
print("\n")

for batch in batches:
    # prefill
    generation, new_cached_batch = service.Prefill(batch)
    text_lst[generation.request_id] += generation.token
    cached_batch_lst.append(new_cached_batch)

    # decodes
    for _ in range(NUM_DECODES):
        cached_batch_lst, text_lst = do_decode(service, cached_batch_lst, text_lst)

# once all the batches have been added
for _ in range(100):
    cached_batch_lst, text_lst = do_decode(service, cached_batch_lst, text_lst)
    if len(cached_batch_lst) == 0:
        break

-------- ORIGINAL PROMPTS --------
id: 0: Pricess Peach jumped from the balcony and
id: 1: Mario and Luigi ran out of the store and
id: 2: Bowser took out the flamethrower and
id: 3: Wario shaved his mustache and
id: 4: Toad made a funny sound and


-------- COMPLETED GENERATION --------
id = 1: Mario and Luigi ran out of the store and back to the park.



-------- COMPLETED GENERATION --------
id = 0: Pricess Peach jumped from the balcony and landed on the grass. She was so happy to be free.

The end.




-------- COMPLETED GENERATION --------
id = 3: Wario shaved his mustache and it made him look very handsome. He



-------- COMPLETED GENERATION --------
id = 2: Bowser took out the flamethrower and said, "I'm going to give you a special surprise!"

The



-------- COMPLETED GENERATION --------
id = 4: Toad made a funny sound and hopped away. He was so happy that he



