In [1]:
import ray
import os
from starlette.requests import Request
from typing import List, Optional, Any
import torch
import shutil
import logging
import sys
import json
import time
from huggingface_hub.hf_api import HfFolder
import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.logger import init_logger
from vllm.utils import random_uuid
from ray import serve
import vllm
# NOTE(review): inside a notebook kernel __name__ is "__main__", so this is the
# notebook's own logger rather than vLLM's package logger, and it is not used in
# any other visible cell — confirm it is still needed.
vllm_logger = init_logger(__name__)
vllm_logger.setLevel("DEBUG")

In [2]:
# Show HOST_IP as seen by this kernel (in the saved output it matches the Ray
# cluster address reported by ray.init further down).
!echo $HOST_IP

10.244.142.11


In [3]:
# Show NCCL_DEBUG as seen by this kernel's environment.
!echo $NCCL_DEBUG

INFO


In [4]:
# Show NCCL_SOCKET_IFNAME (the network interface NCCL is told to use) for this kernel.
# NOTE(review): Ray worker processes on other nodes may not inherit this kernel's
# environment — confirm the variable is set wherever the model workers run.
!echo $NCCL_SOCKET_IFNAME

eth0


In [5]:
# Model id passed as `model=` to the vLLM engine when the deployment is bound below.
MODEL: str = "lmsys/vicuna-13b-v1.5-16k"

In [6]:
# Set the "ray.serve" logger to DEBUG in this notebook (driver) process.
# NOTE(review): the Serve replica runs in separate Ray worker processes, so this
# presumably does not change logging inside SnowflakeVLLMDeployment — confirm.
logger = logging.getLogger(name="ray.serve")
logger.setLevel("DEBUG")

@serve.deployment(num_replicas=1, ray_actor_options={"resources": {"custom_worker": 1}}, route_prefix="/llmapi")
class SnowflakeVLLMDeployment:
    """Ray Serve deployment that serves text generation from a vLLM ``AsyncLLMEngine``.

    Request body: a JSON object with
      * ``prompt`` (required) -- passed to ``AsyncLLMEngine.generate``;
      * ``stream`` (optional, default ``False``) -- stream incremental output;
      * any other keys -- forwarded verbatim as ``SamplingParams`` keyword arguments.

    Responses:
      * non-streaming: JSON ``{"text": [...]}`` with the generated text of each
        output sequence (the prompt is not echoed back);
      * streaming: newline-delimited JSON objects ``{"text": ...}``, each holding
        only the text generated since the previous line;
      * 400 with JSON ``{"error": ...}`` for a malformed body or invalid sampling
        parameters;
      * 499 (empty body) if the client disconnects during a non-streaming request.
    """

    def __init__(self, **kwargs):
        """Create the async vLLM engine; ``kwargs`` are forwarded to ``AsyncEngineArgs``."""
        args = AsyncEngineArgs(**kwargs)
        self.engine = AsyncLLMEngine.from_engine_args(args)

    @staticmethod
    def _bad_request(message: str) -> Response:
        """Return a 400 response whose JSON body is ``{"error": message}``."""
        return Response(
            content=json.dumps({"error": message}),
            status_code=400,
            media_type="application/json",
        )

    async def stream_results(self, results_generator) -> AsyncGenerator[bytes, None]:
        """Yield one UTF-8 NDJSON line per engine update, containing only new text.

        Engine outputs are treated as cumulative: the prefix already sent
        (``num_returned`` characters) is sliced off before each yield.
        Requires exactly one output sequence per request (``n == 1``);
        ``__call__`` rejects streaming requests that ask for more.
        """
        num_returned = 0
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            assert len(text_outputs) == 1
            text_output = text_outputs[0][num_returned:]
            ret = {"text": text_output}
            yield (json.dumps(ret) + "\n").encode("utf-8")
            num_returned += len(text_output)

    async def may_abort_request(self, request_id) -> None:
        """Ask the engine to abort ``request_id``.

        Scheduled as a background task that runs after a streaming response is sent.
        """
        await self.engine.abort(request_id)

    async def __call__(self, request: Request) -> Response:
        """Handle one HTTP generation request (see the class docstring for the protocol)."""
        try:
            request_dict = await request.json()
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
            return self._bad_request("request body must be valid JSON")
        if not isinstance(request_dict, dict):
            return self._bad_request("request body must be a JSON object")
        if "prompt" not in request_dict:
            return self._bad_request("missing required field 'prompt'")

        prompt = request_dict.pop("prompt")
        stream = request_dict.pop("stream", False)
        try:
            sampling_params = SamplingParams(**request_dict)
        except (TypeError, ValueError) as exc:
            # Unknown keyword arguments raise TypeError; invalid values are
            # expected to raise ValueError from SamplingParams' own validation.
            return self._bad_request(f"invalid sampling parameters: {exc}")
        if stream and sampling_params.n != 1:
            # stream_results can only diff a single output sequence.
            return self._bad_request("streaming supports only n=1")

        request_id = random_uuid()
        results_generator = self.engine.generate(prompt, sampling_params, request_id)
        if stream:
            background_tasks = BackgroundTasks()
            background_tasks.add_task(self.may_abort_request, request_id)
            return StreamingResponse(
                self.stream_results(results_generator),
                media_type="application/x-ndjson",
                background=background_tasks,
            )

        # Non-streaming case: drain the generator and keep only the last output.
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output

        assert final_output is not None
        text_outputs = [output.text for output in final_output.outputs]
        ret = {"text": text_outputs}
        return Response(content=json.dumps(ret), media_type="application/json")



In [7]:
# Engine settings forwarded (via the deployment's __init__) to AsyncEngineArgs.
TENSOR_PARALLEL_SIZE = 8
ENGINE_SEED = 123

# Connect to the existing Ray cluster. ignore_reinit_error=True keeps this cell
# re-runnable in the same kernel; without it a second ray.init() call raises
# because Ray is already initialized.
ray.init(address="auto", log_to_driver=False, ignore_reinit_error=True)

deployment = SnowflakeVLLMDeployment.bind(
    model=MODEL, tensor_parallel_size=TENSOR_PARALLEL_SIZE, seed=ENGINE_SEED
)
# Deploy as application "llm"; the returned DeploymentHandle is the cell's output.
serve.run(target=deployment, name="llm")

2024-07-05 07:56:10,194	INFO worker.py:1567 -- Connecting to existing Ray cluster at address: 10.244.142.11:6379...
2024-07-05 07:56:10,239	INFO worker.py:1743 -- Connected to Ray cluster. View the dashboard at [1m[32m10.244.142.11:8265 [39m[22m
[2024-07-05 07:56:10,242 I 1278 1278] logging.cc:230: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1
2024-07-05 07:57:22,682	INFO api.py:575 -- Deployed app 'llm' successfully.
2024-07-05 07:57:22,685	INFO router.py:286 -- Created DeploymentHandle 'coxpelah' for Deployment(name='SnowflakeVLLMDeployment', app='llm').


DeploymentHandle(deployment='SnowflakeVLLMDeployment')

2024-07-05 07:57:22,694	DEBUG long_poll.py:155 -- LongPollClient <ray.serve._private.long_poll.LongPollClient object at 0x7ff6da0f5db0> received updates for keys: [(LongPollNamespace.RUNNING_REPLICAS, Deployment(name='SnowflakeVLLMDeployment', app='llm')), (LongPollNamespace.DEPLOYMENT_CONFIG, Deployment(name='SnowflakeVLLMDeployment', app='llm'))].
2024-07-05 07:57:22,695	INFO pow_2_scheduler.py:260 -- Got updated replicas for Deployment(name='SnowflakeVLLMDeployment', app='llm'): {'leszxbzd'}.
