# Serving a Stable Diffusion Model

This guide is a quickstart for using [Ray Serve](todo) for model serving. The example loads a pretrained Stable Diffusion model from Hugging Face and serves it at a local endpoint as a Ray Serve deployment. See `server.py` for the code you can replace to serve your own models!

> Slot in your code wherever you see the ✂️ icon below to build your own model-serving Ray application from this template!

## Installing Dependencies

First, we'll install the necessary dependencies in the Anyscale Workspace. Open a terminal and follow the install step that matches the template size you picked:

### Install Dependencies (Small-scale Template)

The small-scale template only runs on a single node (the head node), so we just need to install the requirements *locally*.

```
pip install -r requirements.txt --upgrade
```

### Install Cluster-wide Dependencies (Large-scale Template)

When running in a distributed Ray Cluster, all nodes need to have access to the installed packages.
For this, we'll use `pip install --user` to install the necessary requirements.
On an [Anyscale Workspace](https://docs.anyscale.com/user-guide/develop-and-debug/workspaces),
this will install packages to a *shared filesystem* that will be available to all nodes in the cluster.

```
pip install --user -r requirements.txt --upgrade
```
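After either install step, you can quickly confirm that the packages are importable from the current Python environment. The sketch below uses only the standard library. `missing_modules` is a hypothetical helper, and the module names in the example call are placeholders (an assumption), so substitute the ones from your own `requirements.txt`. It also prints the user site-packages directory that `pip install --user` targets for this interpreter.

```python
import importlib.util
import site


def missing_modules(names):
    """Return the subset of top-level module `names` that the import system cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Directory that `pip install --user` installs into for this interpreter.
    print("User site-packages:", site.getusersitepackages())
    # Hypothetical module list: replace with the packages in your requirements.txt.
    print("Missing:", missing_modules(["torch", "diffusers", "ray"]))
```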

## Deploy the Ray Serve application locally

The Ray Serve application with the model-serving logic lives in `server.py`, where we define:
- The `/imagine` API endpoint that we query to generate images.
- The Stable Diffusion model, loaded inside a Ray Serve deployment.
  We'll specify the *number of model replicas* to keep active in our Ray cluster. These replicas can process incoming requests concurrently.

Let's deploy the Ray Serve application locally (at `http://localhost:8000`)!
Open a terminal in your Workspace and, from your workspace directory (where `server.py` is located), run the command for your template size:


| Template Size | Launch Command |
| ------------- | --------------------- |
|Small-scale (single-node) | `python server.py --num-replicas=1`  |
|Large-scale (multi-node)  | `python server.py --num-replicas=4` |

This command will continue running to host your local Ray Serve application.
This will be the place to view all the autoscaling logs, as well as any logs emitted by
the model inference once requests start coming through.
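Because the client below fails immediately if nothing is listening on port 8000, it can help to wait until the server started above accepts connections. This is a minimal standard-library sketch; `wait_for_server` is a hypothetical helper. It treats *any* HTTP response, including an error status (for example, a request missing the `prompt` parameter), as a sign that the server is listening. That only proves the HTTP server is up, not that every model replica has finished loading.

```python
import time
import urllib.error
import urllib.request


def wait_for_server(url: str, timeout_s: float = 60.0, interval_s: float = 1.0) -> bool:
    """Poll `url` until it answers with any HTTP response or `timeout_s` elapses."""
    # Bypass any HTTP proxy configured in the environment: we are polling a local server.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with opener.open(url, timeout=interval_s):
                return True  # 2xx/3xx response: the server is up
        except urllib.error.HTTPError:
            return True  # 4xx/5xx still means something is listening
        except OSError:
            # URLError (e.g. connection refused) and socket timeouts are OSError subclasses.
            time.sleep(interval_s)
    return False
```

For example, `wait_for_server("http://localhost:8000/imagine")` blocks for up to a minute and returns `False` if the server never comes up.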

## Make a Request

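Next, we'll build a simple client that submits prompts as HTTP requests to the local endpoint at `http://localhost:8000/imagine`.
The script reads a prompt from you, sends `NUM_IMAGES_PER_PROMPT` concurrent requests for it, saves the returned images to a new directory, and displays them in the notebook.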
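Each request the client sends is a plain HTTP GET to `/imagine`, with the prompt and image size passed as query parameters. The standard-library sketch below shows the equivalent URL and a blocking, single-image fetch, which can be handy for debugging outside the async client. `build_imagine_url` and `fetch_image` are hypothetical helper names; the parameter names `prompt` and `img_size` and the value `776` come from the client code below.

```python
from urllib.parse import urlencode
from urllib.request import urlopen

ENDPOINT = "http://localhost:8000/imagine"


def build_imagine_url(prompt: str, img_size: int = 776) -> str:
    """Encode the same query parameters the async client passes to aiohttp."""
    query = urlencode({"prompt": prompt, "img_size": img_size})
    return f"{ENDPOINT}?{query}"


def fetch_image(prompt: str, img_size: int = 776, timeout_s: float = 300.0) -> bytes:
    """Blocking, one-image counterpart of the client's `generate_image`."""
    with urlopen(build_imagine_url(prompt, img_size), timeout=timeout_s) as resp:
        return resp.read()


print(build_imagine_url("twin peaks sf in basquiat painting style"))
# http://localhost:8000/imagine?prompt=twin+peaks+sf+in+basquiat+painting+style&img_size=776
```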

In [None]:
import aiohttp
import asyncio
import matplotlib.pyplot as plt
import os
import requests
import time
import uuid


> ✂️ Replace this value to change the number of images to generate per prompt.
>
> Each image will be generated starting from a different set of random noise,
> so you'll be able to see multiple options per prompt!
> With more model replicas, more images can be generated in parallel.
>
> Try starting with the recommended value in the table below:

| Template Size | `NUM_IMAGES_PER_PROMPT` |
| ------------- | --------------------- |
|Small-scale (single-node) | `1` |
|Large-scale (multi-node)  | `4` |
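The reason for matching `NUM_IMAGES_PER_PROMPT` to the replica count can be illustrated with a toy simulation. It assumes that each replica handles one request at a time and that every image takes the same time to generate. These are simplifying assumptions, not a description of how Ray Serve schedules requests. An `asyncio.Semaphore` stands in for the pool of replicas. Under those assumptions, wall-clock time is roughly `ceil(images / replicas) × seconds_per_image`.

```python
import asyncio
import math
import time


async def _fake_request(replicas: asyncio.Semaphore, seconds_per_image: float) -> None:
    # Hold one "replica" slot for the duration of a simulated generation.
    async with replicas:
        await asyncio.sleep(seconds_per_image)


async def simulate(num_images: int, num_replicas: int, seconds_per_image: float) -> float:
    """Fire all requests at once (like the client's asyncio.gather) and time them."""
    replicas = asyncio.Semaphore(num_replicas)
    start = time.perf_counter()
    await asyncio.gather(
        *(_fake_request(replicas, seconds_per_image) for _ in range(num_images))
    )
    return time.perf_counter() - start


def expected_seconds(num_images: int, num_replicas: int, seconds_per_image: float) -> float:
    """Idealized estimate: images are processed in waves of `num_replicas`."""
    return math.ceil(num_images / num_replicas) * seconds_per_image
```

In a script, call `asyncio.run(simulate(4, 1, 0.2))`. In this notebook, use `await simulate(4, 1, 0.2)` instead, since the notebook already runs an event loop. With 4 images, going from 1 to 4 replicas cuts the estimate from four "waves" to one.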

In [None]:
NUM_IMAGES_PER_PROMPT: int = 4


Run the next few cells to start the client, then enter a prompt to generate your first image! For example:

```
Enter a prompt (or 'q' to quit):   twin peaks sf in basquiat painting style

Generating image(s)...
(Take a look at the terminal serving the endpoint for more logs!)


Generated 1 image(s) in 69.89 seconds to the directory: 58b298d9
```

![Example output](https://user-images.githubusercontent.com/3887863/221063452-3c5e5f6b-fc8c-410f-ad5c-202441cceb51.png)

In [None]:
endpoint = "http://localhost:8000/imagine"


async def generate_image(session, prompt):
    req = {"prompt": prompt, "img_size": 776}
    async with session.get(endpoint, params=req) as resp:
        image_data = await resp.read()
    return image_data


def show_images(filenames):
    fig, axs = plt.subplots(1, len(filenames), figsize=(4 * len(filenames), 4))
    for i, filename in enumerate(filenames):
        ax = axs if len(filenames) == 1 else axs[i]
        ax.imshow(plt.imread(filename))
        ax.axis("off")
    plt.show()


async def main():
    try:
        requests.get(endpoint, timeout=0.1)
    except Exception as e:
        raise RuntimeError(
            "Did you set up the Ray Serve model replicas with "
            "`python server.py --num-replicas=...` in another terminal yet?"
        ) from e

    while True:
        prompt = input("\nEnter a prompt (or 'q' to quit):  ")
        if prompt.lower() == "q":
            break

        print("\nGenerating image(s)...")
        print("(Take a look at the terminal serving the endpoint for more logs!)\n")
        start = time.time()

        async with aiohttp.ClientSession() as session:
            tasks = []
            for i in range(NUM_IMAGES_PER_PROMPT):
                tasks.append(generate_image(session, prompt))
            images = await asyncio.gather(*tasks)

        dirname = f"{uuid.uuid4().hex[:8]}"
        os.makedirs(dirname)
        filenames = []
        for i, image in enumerate(images):
            filename = os.path.join(dirname, f"{i}.png")
            with open(filename, "wb") as f:
                f.write(image)
            filenames.append(filename)

        elapsed = time.time() - start

        print(
            f"\nGenerated {len(images)} image(s) in {elapsed:.2f} seconds to "
            f"the directory: {dirname}\n"
        )
        show_images(filenames)


Once the stable diffusion model finishes generating your image, it will be included in the HTTP response body.
The client writes each image to a PNG file in a new subdirectory of your Workspace directory (named with 8 random hex characters). The images also show up in the notebook cell!
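The client writes whatever bytes come back in the response body straight to a `.png` file. If the server returns something other than an image (for example, an error message), the result is a file that won't open. The guard sketched below checks the fixed 8-byte PNG signature before writing. It assumes the server returns PNG-encoded images, which is inferred from the client's `.png` file extension. `looks_like_png` and `save_if_png` are hypothetical helper names.

```python
import os

# Every PNG file starts with this fixed 8-byte signature (PNG specification).
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def looks_like_png(data: bytes) -> bool:
    return data[:8] == PNG_SIGNATURE


def save_if_png(data: bytes, path: str) -> bool:
    """Write `data` to `path` only if it looks like a PNG; return whether it was saved."""
    if not looks_like_png(data):
        # Show the start of the body, which is often a readable error message.
        preview = data[:200].decode("utf-8", errors="replace")
        print(f"Skipping {path}: response is not a PNG ({preview!r})")
        return False
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        f.write(data)
    return True
```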

In [None]:
await main()


You've successfully served a Stable Diffusion model!
You can modify this template and quickly iterate on your model deployment directly on your cluster within your Anyscale Workspace,
testing against the local endpoint.