# Serving a Chatbot with Ray Serve

| Template Specification | Description |
| ---------------------- | ----------- |
| Summary | This template demonstrates how to serve a chatbot with [Ray Serve](https://docs.ray.io/en/latest/serve/index.html).|
| Time to Run | |
| Minimum Compute Requirements | At least 1 GPU node. The default is 1 node with 4 CPUs, and 1 node with 1 NVIDIA T4 GPU. |
| Cluster Environment | This template uses a docker image built on top of the latest Anyscale-provided Ray image using Python 3.9: [`anyscale/ray:latest-py39-cu118`](https://docs.anyscale.com/reference/base-images/overview). See the appendix below for more details. |

## Get Started

**When the workspace is up and running, start coding by clicking on the Jupyter or VSCode icon above. Open the `start.ipynb` file and follow the instructions there.**

By the end, we'll have a chatbot that holds a conversation using GPT-J.

The application will look something like this:

```text
>>> Hi there!

...

>>> What's your name?

...
```


> Slot in your code below wherever you see the ✂️ icon to build off of this template!
>
> The framework and data format used in this template can be easily replaced to suit your own application!

First, add the imports and the Serve logger.

```python
import asyncio
import logging
from queue import Empty

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ray import serve

logger = logging.getLogger("ray.serve")
```

## Build and run the Ray Serve chatbot application

Create a FastAPI deployment that initializes the model in the constructor and exposes a WebSocket endpoint, so users can query it:


```python
fastapi_app = FastAPI()


@serve.deployment
@serve.ingress(fastapi_app)
class Chatbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()

        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    @fastapi_app.websocket("/")
    async def handle_request(self, ws: WebSocket) -> None:
        await ws.accept()

        conversation = ""
        try:
            while True:
                prompt = await ws.receive_text()
                logger.info(f'Got prompt: "{prompt}"')
                conversation += prompt
                streamer = TextIteratorStreamer(
                    self.tokenizer,
                    timeout=0,
                    skip_prompt=True,
                    skip_special_tokens=True,
                )
                self.loop.run_in_executor(
                    None, self.generate_text, conversation, streamer
                )
                response = ""
                async for text in self.consume_streamer(streamer):
                    await ws.send_text(text)
                    response += text
                await ws.send_text("<<Response Finished>>")
                conversation += response
        except WebSocketDisconnect:
            print("Client disconnected.")

    def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
        input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
        self.model.generate(input_ids, streamer=streamer, max_length=10000)

    async def consume_streamer(self, streamer: TextIteratorStreamer):
        while True:
            try:
                for token in streamer:
                    logger.info(f'Yielding token: "{token}"')
                    yield token
                break
            except Empty:
                await asyncio.sleep(0.001)
```

`Chatbot` uses three methods to handle requests:

* `handle_request`: the entrypoint for WebSocket requests. The `@fastapi_app.websocket` decorator lets this method accept WebSocket requests. It first awaits `ws.accept()` to accept the client's connection. Then, until the client disconnects, it does the following:
    * gets the prompt from the client with `ws.receive_text`
    * starts a new `TextIteratorStreamer` to access generated tokens
    * runs the model in a background thread on the conversation so far
    * streams the model's output back using `ws.send_text`
    * stores the prompt and the response in the `conversation` string
* `generate_text`: the method that runs the model. This method runs in a background thread kicked off by `handle_request`. It pushes generated tokens into the streamer constructed by `handle_request`.
* `consume_streamer`: an async generator method that consumes the streamer constructed by `handle_request`. It keeps yielding tokens until the model in `generate_text` closes the streamer. Because the streamer is created with `timeout=0`, it raises `Empty` instead of blocking when no token is ready; the method then calls `asyncio.sleep` for a brief interval so the event loop stays free while it waits for the next token.
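
The polling pattern behind `generate_text` and `consume_streamer` can be tried without loading a model. The following is a minimal sketch that assumes a plain `queue.Queue` in place of `TextIteratorStreamer`; the names `produce`, `consume`, and `_DONE` are hypothetical and exist only for this illustration.

```python
import asyncio
import time
from queue import Empty, Queue

_DONE = object()  # hypothetical end-of-stream marker for this sketch


def produce(q: Queue, words: list[str]) -> None:
    """Stand-in for generate_text: push tokens from a worker thread."""
    for word in words:
        time.sleep(0.01)  # simulate per-token model latency
        q.put(word)
    q.put(_DONE)


async def consume(q: Queue):
    """Yield tokens as they arrive without blocking the event loop."""
    while True:
        try:
            item = q.get_nowait()  # raises Empty right away if nothing is ready
        except Empty:
            await asyncio.sleep(0.001)  # hand control back to the event loop
            continue
        if item is _DONE:
            break
        yield item


async def main() -> list[str]:
    q: Queue = Queue()
    loop = asyncio.get_running_loop()
    # Run the producer in the default thread pool, as handle_request does.
    loop.run_in_executor(None, produce, q, ["Hello", ", ", "world"])
    return [token async for token in consume(q)]


tokens = asyncio.run(main())
print(tokens)
```

The short sleep is the design choice to notice: a blocking `q.get()` would stall every other coroutine on the replica while the model is still generating.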

**Handling Disconnects:** Each time `handle_request` gets a new prompt from a client, it runs the whole conversation, with the new prompt appended, through the model. When the model finishes generating tokens, `handle_request` sends the `"<<Response Finished>>"` string to tell the client that all tokens have been generated. `handle_request` keeps running until the client explicitly disconnects. The disconnect raises a `WebSocketDisconnect` exception, which ends the call.
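
To make this control flow concrete, here is a framework-free sketch of the same loop. It is an illustration under stated assumptions, not the deployment's code: `FakeSocket` and `Disconnected` are hypothetical stand-ins for FastAPI's `WebSocket` and `WebSocketDisconnect`, and a fixed placeholder string replaces the model's output.

```python
import asyncio


class Disconnected(Exception):
    """Hypothetical stand-in for fastapi.WebSocketDisconnect."""


class FakeSocket:
    """Hypothetical in-memory socket: replays prompts, then disconnects."""

    def __init__(self, prompts):
        self._prompts = list(prompts)
        self.sent = []

    async def receive_text(self) -> str:
        if not self._prompts:
            raise Disconnected
        return self._prompts.pop(0)

    async def send_text(self, text: str) -> None:
        self.sent.append(text)


async def handle(ws: FakeSocket) -> list[str]:
    """Same control flow as handle_request, with a placeholder 'model'."""
    conversation = ""
    model_inputs = []
    try:
        while True:
            conversation += await ws.receive_text()
            model_inputs.append(conversation)  # the whole history is the input
            response = f"[reply {len(model_inputs)}]"  # placeholder output
            await ws.send_text(response)
            await ws.send_text("<<Response Finished>>")
            conversation += response
    except Disconnected:
        pass  # the client left, so the call ends
    return model_inputs


socket = FakeSocket(["Hi there!", "What's your name?"])
model_inputs = asyncio.run(handle(socket))
print(model_inputs)
print(socket.sent)
```

The second model input contains the first prompt and the first reply, which shows how the input grows with every turn of the conversation.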

Bind the `Chatbot` to a language model. For this tutorial, use the `"EleutherAI/gpt-j-6B"` model.

✂️ You can replace the model ID to use a different language model!

```python
app = Chatbot.bind("EleutherAI/gpt-j-6B")
```

Now, deploy the Ray Serve application locally at `ws://localhost:8000`!

```python
# Shutdown any existing Serve replicas, if they're still around.
serve.shutdown()
serve.run(app, port=8000, name="Chat")
print("Done setting up replicas! Now accepting requests...")
```


## Make requests to the endpoint

Next, we'll build a simple WebSocket client to chat with the chatbot.

Run the client script in the next cell to start chatting!


The client sends your prompt over the WebSocket connection, then prints the chatbot's reply as the tokens stream back. It stops reading when it receives the `<<Response Finished>>` marker.

```python
from websockets.sync.client import connect

with connect("ws://localhost:8000") as websocket:
    user_prompt = input(">>> ")
    websocket.send(user_prompt)
    print("Chatbot: ", end="", flush=True)
    while True:
        received = websocket.recv()
        if received == "<<Response Finished>>":
            break
        print(received, end="", flush=True)
    print("\n")
```
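
The client above sends a single prompt per run. One way to hold a multi-turn conversation is to wrap the send-and-read step in a helper and call it once per prompt. This is a rough sketch: `chat_turn` and `ScriptedConnection` are hypothetical names, and the scripted connection only imitates the server so the helper can be tried without a running deployment.

```python
END_OF_RESPONSE = "<<Response Finished>>"  # marker sent by handle_request


def chat_turn(connection, prompt: str) -> str:
    """Send one prompt and collect streamed chunks until the end marker."""
    connection.send(prompt)
    chunks = []
    while True:
        received = connection.recv()
        if received == END_OF_RESPONSE:
            return "".join(chunks)
        chunks.append(received)


class ScriptedConnection:
    """Hypothetical stand-in for the websockets connection, for a dry run."""

    def __init__(self):
        self._pending = []

    def send(self, prompt: str) -> None:
        # Pretend the server streams back two chunks and the end marker.
        self._pending = ["You said: ", prompt, END_OF_RESPONSE]

    def recv(self) -> str:
        return self._pending.pop(0)


connection = ScriptedConnection()
replies = [chat_turn(connection, p) for p in ["Hi there!", "What's your name?"]]
print(replies)
```

Against the live deployment, the object returned by `connect("ws://localhost:8000")` should work in place of `ScriptedConnection`, since the helper relies only on its `send` and `recv` methods. Keeping one connection open across turns matters because the server stores the conversation per connection.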


```python
# Shut down Serve once you're done!
serve.shutdown()
```

## Summary

This template used [Ray Serve](https://docs.ray.io/en/latest/serve/index.html) to serve a chatbot. Ray Serve is one of many libraries under the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html).

At a high level, this template showed how to:
1. Load a Hugging Face model in a Ray Serve deployment and perform inference.
2. Provide bidirectional communication in Ray Serve deployments with WebSockets.
3. Stream inputs and outputs between a Ray Serve deployment and an end user.

See this [getting started guide](https://docs.ray.io/en/latest/serve/getting_started.html) for a more detailed walkthrough of Ray Serve.