In [1]:
import os
import time
import random
import argparse
import json

import numpy as np
import torch

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.pipelines.pt_utils import KeyDataset
import torch_tensorrt

from pytriton.model_config import ModelConfig, Tensor
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from pytriton.decorators import batch
import torch_tensorrt


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [2]:
class Phi3Deployment:
    def __init__(self, model_id):
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
            torchscript=True,
            attn_implementation="eager")
        
        self.trt_model = torch.compile(
                            model,
                            backend="torch_tensorrt",
                            options={
                                "truncate_long_and_double": True,
                                "enabled_precisions": {torch.float16}
                                },
                            dynamic=False)
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_size='left')
    
    def infer_fn(self, request):
        generation_args = {
        "max_new_tokens": 200,
        "temperature": 0.0,
        "do_sample": False,
        }

        system_message = "You are a helpful AI Assistant. Help users by replying to their queries and make sure the responses are polite. Do not hallucinate."
        PROMPT = f"<|system|>\n{system_message}<|end|>"
        
        msg = request[0].values()
        msgs = request.tobytes().decode("utf-8")
        token_ids = self.tokenizer.apply_chat_template(msgs, 
                                                add_generation_prompt=True, 
                                                tokenize=False,)
        print(token_ids)

        if type(token_ids) is list:                                     
            token_ids.insert(0, PROMPT)
        else:
            token_ids = [PROMPT] + [token_ids]
        
        inputs = self.tokenizer(token_ids, return_tensors="pt", padding=True)
        inputs = {k: v.type(torch.int32).to(DEVICE) for k, v in inputs.items()}
        
        with torch.no_grad():
            outs = self.trt_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                **generation_args)

        generated_texts = self.tokenizer.batch_decode(outs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        json_texts = json.dumps(generated_texts)
        np_results = np.frombuffer(json_str.encode('utf-8'), dtype=np.uint8)

        return np_results
        # async for text in generated_texts[1:]:
        #     result = np.char.encode(text.rpartition('\n')[2])
        #     yield result

    @property
    def inputs(self):
        return [
            Tensor(name="input", dtype=np.uint8, shape=(-1,)),
        ]

    @property
    def outputs(self):
        return [
            Tensor(name="output", dtype=np.uint8, shape=(-1,))
        ]


In [3]:
phi3_deploy = Phi3Deployment("./models/Phi-3-medium-4k-instruct")

triton = Triton()
triton.bind(
        model_name="Phi3",
        infer_func=phi3_deploy.infer_fn,
        inputs=phi3_deploy.inputs,
        outputs=phi3_deploy.outputs,
        # config=ModelConfig(max_batch_size=30),
        strict=True,)

triton.run()

INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

I0926 07:54:10.979030 5072 pinned_memory_manager.cc:277] "Pinned memory pool is created at '0x7f1c5e000000' with size 268435456"
I0926 07:54:10.979413 5072 cuda_memory_manager.cc:107] "CUDA memory pool is created on device 0 with size 67108864"
I0926 07:54:10.980716 5072 server.cc:604] 
+------------------+------+
| Repository Agent | Path |
+------------------+------+
+------------------+------+

I0926 07:54:10.980751 5072 server.cc:631] 
+---------+------+--------+
| Backend | Path | Config |
+---------+------+--------+
+---------+------+--------+

I0926 07:54:10.980764 5072 server.cc:674] 
+-------+---------+--------+
| Model | Version | Status |
+-------+---------+--------+
+-------+---------+--------+

CacheManager Init Failed. Error: -29
W0926 07:54:11.078321 5072 metrics.cc:798] "DCGM unable to start: DCGM initialization error"
I0926 07:54:11.080256 5072 metrics.cc:770] "Collecting CPU metrics"
I0926 07:54:11.080518 5072 tritonserver.cc:2598] 
+----------------------------------

INFO:pytriton.client.client:Patch ModelClient http


I0926 07:54:11.303502 5072 model_lifecycle.cc:472] "loading: Phi3:1"
I0926 07:54:13.885862 5072 python_be.cc:1912] "TRITONBACKEND_ModelInstanceInitialize: Phi3_0_0 (CPU device 0)"


INFO:pytriton.triton:Infer function available as model: `/v2/models/Phi3`
INFO:pytriton.triton:  Status:         `GET  /v2/models/Phi3/ready/`
INFO:pytriton.triton:  Model config:   `GET  /v2/models/Phi3/config/`
INFO:pytriton.triton:  Inference:      `POST /v2/models/Phi3/infer/`
INFO:pytriton.triton:Read more about configuring and serving models in documentation: https://triton-inference-server.github.io/pytriton.
INFO:pytriton.triton:(Press CTRL+C or use the command `kill -SIGINT 4528` to send a SIGINT signal and quit)


I0926 07:54:14.354652 5072 model_lifecycle.cc:838] "successfully loaded 'Phi3'"


In [5]:
from pytriton.client import ModelClient

client = ModelClient("localhost", "Phi3")

torch.cuda.synchronize()
start = time.perf_counter()

data = load_dataset("json", data_files="./data/test_dataset.jsonl")['train']
messages = data['message']

json_str = json.dumps(messages)
np_messages = np.frombuffer(json_str.encode('utf-8'), dtype=np.uint8)

# outs = []
# async for result in client.infer_sample(messages=np_messages):
#     outs.append([{'generated_text': reult.tobytes().decode('utf-8')}])

np_results = client.infer_sample(np_messages)
outs = [
    [{'generated_text': result.tobytes().decode('utf-8')}]
    for result in np_results
]
    
torch.cuda.synchronize()
end = time.perf_counter()

  File "/home/elicer/miniconda3/envs/yaikids/lib/python3.10/site-packages/pytriton/proxy/inference.py", line 391, in _handle_requests
    async for responses in self._model_callable(requests):
  File "/home/elicer/miniconda3/envs/yaikids/lib/python3.10/site-packages/pytriton/proxy/inference.py", line 85, in _callable
    yield inference_callable(requests)
  File "/tmp/ipykernel_4528/1831263700.py", line 33, in infer_fn
    msgs = request.tobytes().decode("utf-8")
AttributeError: 'list' object has no attribute 'tobytes'



PyTritonClientInferenceServerError: Error occurred during inference request. Message: Failed to process the request(s) for model 'Phi3_0_0', message: TritonModelException: Model execute error: Traceback (most recent call last):
  File "/tmp/foldera8eNss/1/model.py", line 486, in execute
    raise triton_responses_or_error
c_python_backend_utils.TritonModelException: Traceback (most recent call last):
  File "/home/elicer/miniconda3/envs/yaikids/lib/python3.10/site-packages/pytriton/proxy/inference.py", line 391, in _handle_requests
    async for responses in self._model_callable(requests):
  File "/home/elicer/miniconda3/envs/yaikids/lib/python3.10/site-packages/pytriton/proxy/inference.py", line 85, in _callable
    yield inference_callable(requests)
  File "/tmp/ipykernel_4528/1831263700.py", line 33, in infer_fn
    msgs = request.tobytes().decode("utf-8")
AttributeError: 'list' object has no attribute 'tobytes'



At:
  /tmp/foldera8eNss/1/model.py(495): execute


In [None]:
print(f"Inference took {end - start} seconds")

In [None]:
triton.stop()

INFO:pytriton.proxy.inference:Closing Inference Handler


True

In [None]:
print("===== Answers =====")
correct = 0
for i, out in enumerate(outs):
    correct_answer = data[i]["answer"]
    answer = out[0]["generated_text"].lstrip().replace("\n","")
    if answer == correct_answer:
        correct += 1
    # print(answer)
 
print("===== Perf result =====")
print("Elapsed_time: ", end-start)
print(f"Correctness: {correct}/{len(data)}")