# Serving a Stable Diffusion Model with Ray Serve

This guide is a quickstart for using [Ray Serve](https://docs.ray.io/en/latest/serve/index.html) for model serving. Ray Serve is one of the libraries that make up the [Ray AI Runtime](https://docs.ray.io/en/latest/ray-air/getting-started.html).

This template loads a pretrained Stable Diffusion model from Hugging Face and serves it at a local endpoint as a Ray Serve deployment.

> Slot in your code wherever you see the ✂️ icon to build your own model-serving Ray application from this template!

## Handling Dependencies

This template requires certain Python packages to be available to every node in the cluster.

> ✂️ Add your own package dependencies in the `requirements.txt` file!


```python
requirements_path = "./requirements.txt"
```


```python
with open(requirements_path, "r") as f:
    requirements = f.read().strip().splitlines()

print("Requirements:")
print("\n".join(requirements))
```
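The cell above passes every line of `requirements.txt` through as-is, including blank lines and `#` comments. If you'd rather hand Ray only the actual specifiers, a minimal sketch of a stricter parser follows. `parse_requirements` is a hypothetical helper, not part of Ray; its comment rule follows pip's documented convention that `#` starts a comment only at the beginning of a line or after whitespace, so URL fragments such as `#egg=` are kept.

```python
import re
from typing import List

# pip treats "#" as a comment only at the start of a line or after whitespace,
# so URL fragments such as "...repo.git#egg=pkg" are preserved.
_COMMENT_RE = re.compile(r"(^|\s+)#.*$")


def parse_requirements(text: str) -> List[str]:
    """Return the requirement specifiers in `text`, skipping blank and comment lines."""
    specs = []
    for raw in text.splitlines():
        # Strip any trailing comment, then surrounding whitespace.
        line = _COMMENT_RE.sub("", raw).strip()
        if line:
            specs.append(line)
    return specs
```

For example, `requirements = parse_requirements(open(requirements_path).read())` produces a list you can pass to `runtime_env={"pip": requirements}` in the same way as before.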


First, we may want to use these packages directly in this notebook, which runs on the head node.
Install them on the head node using `pip install`.

```{note}
You may need to restart this notebook kernel to access the installed packages.
```


```python
%pip install -r {requirements_path} --upgrade
```

Next, we need to make sure all worker nodes also have access to the dependencies.
For this, use a [Ray Runtime Environment](https://docs.ray.io/en/latest/ray-core/handling-dependencies.html#runtime-environments)
to dynamically set up dependencies throughout the cluster.


```python
import ray

ray.init(runtime_env={"pip": requirements})
```


## Deploy the Ray Serve application locally

First, we define the Ray Serve application with the model loading and inference logic. This includes setting up:
- The `/imagine` API endpoint that we query to generate the image.
- The stable diffusion model loaded inside a Ray Serve Deployment.
  We'll specify the *number of model replicas* to keep active in our Ray cluster. These model replicas can process incoming requests concurrently.


```python
from fastapi import FastAPI
from fastapi.responses import Response
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import time
import uuid

import ray
from ray import serve
```


> ✂️ Replace these values to change the number of model replicas to serve, as well as the GPU resources required by each replica.
>
> With more model replicas, more images can be generated in parallel!

```python
NUM_REPLICAS: int = 4
NUM_GPUS_PER_REPLICA: float = 1

# Control the output size: (IMAGE_SIZE, IMAGE_SIZE)
# NOTE: Generated image quality degrades rapidly if you reduce the size too much.
IMAGE_SIZE: int = 776
```
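If you change `IMAGE_SIZE`, keep it a multiple of 8: Stable Diffusion works on latents that are 8× smaller than the image, and the diffusers `StableDiffusionPipeline` rejects a `height` or `width` that is not divisible by 8. The default of 776 (97 × 8) passes. The helper below is a hypothetical convenience for snapping a requested size down to a valid one; it is a sketch, not part of diffusers.

```python
def snap_to_multiple_of_8(size: int) -> int:
    """Round an image side length down to the nearest multiple of 8."""
    if size < 8:
        raise ValueError(f"size must be at least 8, got {size}")
    # Drop the remainder so the pipeline's divisibility check passes.
    return size - size % 8
```

For example, `snap_to_multiple_of_8(770)` gives 768, while 776 is returned unchanged.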


Next, we define the Ray Serve deployment that loads a Stable Diffusion model and runs inference with it.


```python
# Configure each model replica to use the specified resources.
ray_actor_options = {
    "num_gpus": NUM_GPUS_PER_REPLICA,
}
```
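Ray also accepts a fractional `num_gpus` (for example `0.5`), which lets several replicas share one GPU. Ray does not isolate GPU memory between them, so each model still has to fit alongside its neighbors. To see how many GPUs a given `NUM_REPLICAS` / `NUM_GPUS_PER_REPLICA` combination needs before all replicas can start, here is a rough, hypothetical estimator; it assumes a fractional replica must fit on a single device and that values above 1 are whole numbers.

```python
import math


def min_gpus_required(num_replicas: int, gpus_per_replica: float) -> int:
    """Estimate how many whole GPUs are needed to place every replica at once.

    Assumes each replica fits on one GPU when gpus_per_replica <= 1, and that
    gpus_per_replica is a whole number when it is above 1.
    """
    if num_replicas < 0 or gpus_per_replica <= 0:
        raise ValueError("num_replicas must be >= 0 and gpus_per_replica > 0")
    if gpus_per_replica > 1:
        # Multi-GPU replicas claim whole devices each.
        return num_replicas * math.ceil(gpus_per_replica)
    # Fractional replicas: this many fit side by side on one device
    # (the small epsilon guards against float round-off such as 1/0.1).
    per_gpu = math.floor(1 / gpus_per_replica + 1e-9)
    return math.ceil(num_replicas / per_gpu)
```

With the defaults above (4 replicas × 1 GPU) this gives 4; with `NUM_GPUS_PER_REPLICA = 0.5` it drops to 2, provided two copies of the model fit in one GPU's memory.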


> ✂️ Modify this block to load your own model, and change the `generate` method to perform your own online inference logic!

```python
@serve.deployment(
    ray_actor_options=ray_actor_options,
    num_replicas=NUM_REPLICAS,
)
class StableDiffusionV2:
    def __init__(self):
        # <Replace with your own model loading logic>
        try:
            import torch
            from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
        except ImportError as e:
            raise RuntimeError(
                "Did you set a runtime env to install dependencies?"
            ) from e

        model_id = "stabilityai/stable-diffusion-2"
        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
        )
        self.pipe = self.pipe.to("cuda")

    def generate(self, prompt: str, img_size: int = 776):
        # <Replace with your own model inference logic>
        assert len(prompt), "prompt parameter cannot be empty"
        image = self.pipe(prompt, height=img_size, width=img_size).images[0]
        return image
```
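The `assert len(prompt)` checks in this template are skipped entirely when Python runs with `-O`, and they don't catch sizes the pipeline will reject. A minimal sketch of explicit validation follows; `validate_request` is hypothetical, and the 64 to 1024 bounds are illustrative assumptions rather than limits imposed by the model. Inside the FastAPI ingress you could catch the `ValueError` and raise `fastapi.HTTPException(status_code=400, detail=str(e))` so the client gets a 400 instead of a 500.

```python
def validate_request(prompt: str, img_size: int, max_size: int = 1024) -> None:
    """Raise ValueError for inputs the pipeline should not receive.

    Unlike `assert`, these checks still run under `python -O`.
    The 64..max_size bounds are an assumption chosen for illustration.
    """
    if not prompt or not prompt.strip():
        raise ValueError("prompt parameter cannot be empty")
    if img_size % 8 != 0:
        raise ValueError(f"img_size must be divisible by 8, got {img_size}")
    if not 64 <= img_size <= max_size:
        raise ValueError(f"img_size must be between 64 and {max_size}, got {img_size}")
```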


Next, we'll define the actual API endpoint to live at `/imagine`.

> ✂️ Modify this block to change the endpoint URL, response schema, and add any post-processing logic needed from your model output!

```python
app = FastAPI()


@serve.deployment(num_replicas=1, route_prefix="/")
@serve.ingress(app)
class APIIngress:
    def __init__(self, diffusion_model_handle) -> None:
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str, img_size: int = 776):
        assert len(prompt), "prompt parameter cannot be empty"

        image = await (await self.handle.generate.remote(prompt, img_size=img_size))

        file_stream = BytesIO()
        image.save(file_stream, "PNG")
        return Response(content=file_stream.getvalue(), media_type="image/png")
```


Now, we deploy the Ray Serve application locally at `http://localhost:8000`!

```python
entrypoint = APIIngress.bind(StableDiffusionV2.bind())
port = 8000

# Shut down any existing Serve instance (and its replicas), if one is still running.
serve.shutdown()
serve.run(entrypoint, port=port, name="serving_stable_diffusion_template")
print("Done setting up replicas! Now accepting requests...")
```
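`serve.run` generally returns once the application is up, but a client started from another process or terminal has no such guarantee, and loading the model weights can take a while. A small polling helper is a reasonable sketch for that case; `wait_until_ready` is hypothetical and takes any zero-argument `probe` callable, treating exceptions (such as connection errors) as "not ready yet".

```python
import time
from typing import Callable


def wait_until_ready(
    probe: Callable[[], bool], timeout_s: float = 60.0, interval_s: float = 1.0
) -> bool:
    """Call `probe` until it returns True or `timeout_s` elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            if probe():
                return True
        except Exception:
            pass  # e.g. ConnectionError while the server is still starting
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

One possible probe is `lambda: requests.get(f"http://localhost:{port}/imagine", timeout=1).status_code < 500`, which counts FastAPI's 422 response for the missing `prompt` as "the server is answering".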


## Make requests to the endpoint

Next, we'll build a simple client to submit prompts as HTTP requests to the local endpoint at `http://localhost:8000/imagine`.
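Each request is a plain `GET` with `prompt` and `img_size` as query parameters. The sketch below (a hypothetical helper, using only the standard library) builds the same URL that `requests.get(endpoint, params=...)` sends, which is handy for testing the endpoint from a browser or with `curl "<url>" --output out.png`.

```python
from urllib.parse import urlencode


def imagine_url(prompt: str, img_size: int, base: str = "http://localhost:8000/imagine") -> str:
    """Build the GET URL for the /imagine endpoint, percent-encoding the prompt."""
    # urlencode uses quote_plus, so spaces become "+" and "&" becomes "%26".
    return f"{base}?{urlencode({'prompt': prompt, 'img_size': img_size})}"
```

For example, `imagine_url("twin peaks sf", 776)` returns `http://localhost:8000/imagine?prompt=twin+peaks+sf&img_size=776`.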

> ✂️ Replace this value to change the number of images to generate per prompt.
>
> Each image will be generated starting from a different set of random noise,
> so you'll be able to see multiple options per prompt!
>
> Try starting with `NUM_IMAGES_PER_PROMPT` equal to `NUM_REPLICAS` from earlier.

```python
NUM_IMAGES_PER_PROMPT: int = NUM_REPLICAS
```
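Matching `NUM_IMAGES_PER_PROMPT` to `NUM_REPLICAS` is a sensible start because requests beyond the replica count mostly wait in line. The back-of-the-envelope model below is a sketch under a simplifying assumption: each replica generates one image at a time and requests spread evenly. `seconds_per_image` is a number you measure yourself, and the function name is hypothetical.

```python
import math


def estimated_wall_time_s(num_images: int, num_replicas: int, seconds_per_image: float) -> float:
    """Rough wall-clock time for a batch of concurrent requests.

    Assumes one image per replica at a time and an even spread of requests.
    """
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    # Every "round" runs up to num_replicas generations in parallel.
    rounds = math.ceil(num_images / num_replicas)
    return rounds * seconds_per_image
```

Under this model, 4 images on 4 replicas take about one image's worth of time, while a 5th image doubles the wait.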


> ✂️ You can choose to run this interactively, or submit a single `PROMPT`.

```python
INTERACTIVE: bool = False
PROMPT = "twin peaks sf in basquiat painting style"
```


Run the client in the next few cells to generate your first image! In interactive mode, a session looks like this:

```
Enter a prompt (or 'q' to quit):   twin peaks sf in basquiat painting style

Generating image(s)...
(Take a look at the terminal serving the endpoint for more logs!)


Generated 1 image(s) in 69.89 seconds to the directory: 58b298d9
```

![Example output](https://user-images.githubusercontent.com/3887863/221063452-3c5e5f6b-fc8c-410f-ad5c-202441cceb51.png)

```python
endpoint = f"http://localhost:{port}/imagine"


@ray.remote(num_cpus=1)
def generate_image(prompt):
    req = {"prompt": prompt, "img_size": IMAGE_SIZE}
    resp = requests.get(endpoint, params=req)
    # Fail loudly instead of saving an error body (e.g. HTTP 500) as a .png file.
    resp.raise_for_status()
    return resp.content


def show_images(filenames):
    fig, axs = plt.subplots(1, len(filenames), figsize=(4 * len(filenames), 4))
    for i, filename in enumerate(filenames):
        ax = axs if len(filenames) == 1 else axs[i]
        ax.imshow(plt.imread(filename))
        ax.axis("off")
    plt.show()


def main():
    try:
        # Any HTTP response (even a 422 for the missing prompt) means the server is up.
        requests.get(endpoint, timeout=0.1)
    except Exception as e:
        raise RuntimeError(
            "Could not reach the endpoint. Did you run the `serve.run(...)` cell "
            "above to deploy the Ray Serve application?"
        ) from e

    generation_times = []
    while True:
        prompt = (
            PROMPT
            if not INTERACTIVE
            else input("\nEnter a prompt (or 'q' to quit):  ")
        )
        if prompt.lower() == "q":
            break

        print("\nGenerating image(s)...\n")
        start = time.time()

        # Make `NUM_IMAGES_PER_PROMPT` requests to the endpoint at once!
        images = ray.get(
            [generate_image.remote(prompt) for _ in range(NUM_IMAGES_PER_PROMPT)]
        )

        dirname = f"{uuid.uuid4().hex[:8]}"
        os.makedirs(dirname)
        filenames = []
        for i, image in enumerate(images):
            filename = os.path.join(dirname, f"{i}.png")
            with open(filename, "wb") as f:
                f.write(image)
            filenames.append(filename)

        elapsed = time.time() - start
        generation_times.append(elapsed)
        print(
            f"\nGenerated {len(images)} image(s) in {elapsed:.2f} seconds to "
            f"the directory: {dirname}\n"
        )
        show_images(filenames)
        if not INTERACTIVE:
            break
    return np.mean(generation_times) if generation_times else -1
```
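The client fans out requests as Ray tasks, each reserving one CPU (`num_cpus=1`). Since each task mostly waits on HTTP, a thread pool in the notebook process is an alternative if you'd rather not consume cluster CPUs for the client. This is a swapped-in technique, not what the template does; `fan_out` is a hypothetical helper that takes any `fetch(prompt) -> bytes` function.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def fan_out(fetch: Callable[[str], bytes], prompt: str, n: int, max_workers: int = 8) -> List[bytes]:
    """Run `fetch(prompt)` n times concurrently; results keep submission order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(fetch, prompt) for _ in range(n)]
        # .result() re-raises any exception from the worker thread.
        return [f.result() for f in futures]
```

For this endpoint, `fetch` could be `lambda p: requests.get(endpoint, params={"prompt": p, "img_size": IMAGE_SIZE}).content`, called as `fan_out(fetch, PROMPT, NUM_IMAGES_PER_PROMPT)`.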


Once the Stable Diffusion model finishes generating an image, the image is returned in the HTTP response body.
The client writes each image to a PNG file in a new, randomly named directory inside your Workspace, and also displays it in the notebook cell!
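If a request fails, the response body is an error message rather than an image, and writing it blindly produces a `.png` that won't open. A minimal sketch of a more defensive saver follows: it checks each payload for the fixed 8-byte PNG file signature before writing anything. `save_pngs` is hypothetical; the directory layout (`<8 hex chars>/<i>.png`) mirrors the client above.

```python
import os
import uuid
from typing import Iterable, List

# Every PNG file starts with these 8 bytes.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def save_pngs(images: Iterable[bytes], root: str = ".") -> List[str]:
    """Write PNG payloads to <root>/<8-hex-dir>/<i>.png and return the file paths."""
    payloads = list(images)
    # Validate everything first so a bad response doesn't leave a half-written directory.
    for i, payload in enumerate(payloads):
        if not payload.startswith(PNG_SIGNATURE):
            raise ValueError(f"response {i} is not a PNG: {payload[:60]!r}")
    dirname = os.path.join(root, uuid.uuid4().hex[:8])
    os.makedirs(dirname)
    paths = []
    for i, payload in enumerate(payloads):
        path = os.path.join(dirname, f"{i}.png")
        with open(path, "wb") as f:
            f.write(payload)
        paths.append(path)
    return paths
```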

```python
mean_generation_time = main()
```


You've successfully served a stable diffusion model!
You can modify this template and iterate on your model deployment directly on your cluster within your Anyscale Workspace,
testing against the local endpoint.

```python
# Shut down the model replicas once you're done!
serve.shutdown()
```
