In [1]:
!pip install ray==2.3
!pip install ray[serve]
!pip install fastapi==0.96

[31mERROR: Invalid requirement: 'fastapi=0.96'
Hint: = is not a valid operator. Did you mean == ?[0m[31m
[0m

In [3]:
from io import BytesIO
from fastapi import FastAPI
from fastapi.responses import Response
import ray
from ray import serve
from typing import Any, List, Mapping
from PIL import Image

In [4]:
# Ray Client address of the KubeRay head service.
RAY_ADDRESS = "ray://example-cluster-kuberay-head-svc:10001"

# Connect to the cluster. The runtime_env installs the model dependencies on the
# cluster workers only; the notebook kernel does not need diffusers/flax.
ray.init(
    address=RAY_ADDRESS,
    runtime_env={
        "pip": [
            # NOTE(review): jax[tpu] is not installed here; presumably the TPU
            # worker image ships a matching jax/libtpu — TODO confirm.
            "diffusers==0.7.2",
            "transformers==4.24.0",
            "flax",
            "tensorboard-plugin-profile",
            "tensorboard",
            # Pinned to the cluster's Ray version so the worker environment can
            # never pull in a different Ray release.
            "ray[serve]==2.3.0",
        ]
    },
)

0,1
Python version:,3.10.9
Ray version:,2.3.0
Dashboard:,http://10.120.0.3:8265


In [5]:
app = FastAPI()


@serve.deployment(num_replicas=1, route_prefix="/")
@serve.ingress(app)
class APIIngress:
  """`APIIngress`, i.e. the HTTP request router.

  Each GET /imagine request carries a single prompt. `@serve.batch` groups
  concurrent requests so that the Diffusion model server receives them in one
  call with a fixed batch size.

  Arguments:
    diffusion_model_handle: The handle that we use to access the Diffusion
      model server that actually runs on TPU hardware.

  """

  # Batch size sent to the model server. It is also the @serve.batch limit, so
  # an incoming batch only ever needs padding, never truncation.
  _BATCH_SIZE = 8

  def __init__(self, diffusion_model_handle) -> None:
    self.handle = diffusion_model_handle

  @serve.batch(batch_wait_timeout_s=10, max_batch_size=_BATCH_SIZE)
  async def batched_generate_handler(self, prompts: List[str]):
    """Sends a batch of prompts to the TPU model server.

    This takes advantage of @serve.batch, which is Ray Serve's built-in batching
    mechanism.

    With `batch_wait_timeout_s`=10 and `max_batch_size`=`_BATCH_SIZE` (8), a
    batch is dispatched as soon as 8 requests are queued or 10s after its first
    request arrived, whichever comes first.

    Args:
      prompts: A list of input prompts (at most `_BATCH_SIZE`). The list is
        not modified.

    Returns:
      A list of PNG `Response`s, one per input prompt, in the same order.

    """
    num_prompts = len(prompts)
    print("Number of input prompts: ", num_prompts)
    print(prompts)
    assert num_prompts <= self._BATCH_SIZE, (
        f"We should not have more than {self._BATCH_SIZE} prompts.")

    # Pad with empty prompts to a constant batch size. The model server shards
    # the batch evenly across its devices (assumes the device count divides
    # _BATCH_SIZE), and a constant input shape lets pmap reuse one compiled
    # program instead of recompiling for every batch size. A new list is built
    # so the caller-provided batch is not mutated.
    padded_prompts = prompts + [""] * (self._BATCH_SIZE - num_prompts)

    # Awaiting `.remote()` yields an ObjectRef; awaiting that yields the images.
    image_ref = await self.handle.generate.remote(padded_prompts)
    images = await image_ref

    # Keep only the images that belong to real prompts, dropping the padding.
    results = []
    for image in images[:num_prompts]:
      file_stream = BytesIO()
      image.save(file_stream, "PNG")
      results.append(
          Response(content=file_stream.getvalue(), media_type="image/png"))
    return results

  @app.get(
      "/imagine",
      responses={200: {"content": {"image/png": {}}}},
      response_class=Response,
  )
  async def generate(self, prompt: str):
    """Requests the generation of an individual prompt.

    This implementation simply re-routes the request to the batch handler;
    @serve.batch returns to this function the single response that belongs
    to this prompt.

    The endpoint path (/imagine) is declared through FastAPI, and `prompt` is
    read from the query string.

    Args:
      prompt: An individual prompt.

    Returns:
      A PNG `Response`.

    """
    return await self.batched_generate_handler(prompt)

In [6]:
@serve.deployment(
    ray_actor_options={
        "resources": {"google.com/tpu": 4}
    },
    # autoscaling_config={"min_replicas": 1, "max_replicas": 4},
)
class StableDiffusion:
  """FLAX Stable Diffusion Ray Serve deployment.

  This is the actual model server that runs on the TPU host.

  Notes:
    - A custom resource ("google.com/tpu") is used to place the replica on a
      TPU host. The name will likely change once Ray clusters on TPUs are
      standardized.
    - A minimum and maximum number of replicas can be given to the autoscaler
      via `autoscaling_config` (currently commented out in the decorator).
    - The autoscaler is not functional in this version (as we're using
      tpu_controller), but should be functional on single TPU hosts using the
      Ray Cluster launcher path or through KubeRay.
    - Regardless of the route, autoscaling is driven by load: per the Ray Serve
      docs (https://docs.ray.io/en/latest/serve/architecture.html#ray-serve-autoscaling)
      it uses the ServeHandle queue and in-flight queries for scaling
      decisions (TODO: investigate in more detail).
    - This example serves a single model, but multiple handles could be
      composed, e.g. to ensemble, or to route from one model server to another.
    - This may work on multi-host TPUs, but not with autoscaling (untested).

  Args:
    run_with_profiler: Whether or not to capture a JAX profiler trace around
      each inference call. The trace is saved to /tmp/tensorboard on the TPU
      VM, not on the machine running this notebook.

  """
  def __init__(self, run_with_profiler: bool = False):
    # Imported here because these packages are installed on the cluster through
    # the runtime_env, not in the notebook kernel that defines this class.
    from diffusers import FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    import jax.numpy as jnp
    from jax import pmap

    model_id = "CompVis/stable-diffusion-v1-4"

    # Load the bfloat16 revision of the weights.
    self.pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision="bf16",
        dtype=jnp.bfloat16)

    # Copy the parameters to every local device so that the pmapped pipeline
    # can run one shard of the prompt batch per device.
    self.p_params = replicate(params)
    self.p_generate = pmap(self.pipeline._generate)
    self._run_with_profiler = run_with_profiler
    self._profiler_dir = "/tmp/tensorboard"

  def generate(self, prompts: List[str], seed: int = 0):
    """Generates a batch of images from Diffusion from a list of prompts.

    Notes:
      - One "sharp edge" is that the imports must happen inside the method:
        jax/flax only exist in the cluster's runtime_env, not in the notebook
        that defines this class.

    Args:
      prompts: Non-empty list of prompts. Its length must be a multiple of the
        local device count, because the batch is split evenly across devices.
      seed: Seed for the PRNG key. The default of 0 keeps outputs identical
        across calls for identical prompts.

    Returns:
      A list of PIL Images, one per prompt, in the same order.

    Raises:
      AssertionError: If `prompts` is empty.
      ValueError: If `len(prompts)` is not a multiple of the local device
        count.

    """
    from flax.training.common_utils import shard
    import jax
    import time
    import numpy as np

    num_devices = jax.device_count()
    print("sanity check: ", num_devices)

    # Validate before doing any work on the devices.
    assert len(prompts), "prompt parameter cannot be empty"
    num_local_devices = jax.local_device_count()
    if len(prompts) % num_local_devices:
      raise ValueError(
          f"Number of prompts ({len(prompts)}) must be a multiple of the local "
          f"device count ({num_local_devices}).")

    # One PRNG key per device, since pmap maps over the leading axis.
    rng = jax.random.split(jax.random.PRNGKey(seed), num_devices)

    print("Prompts: ", prompts)
    prompt_ids = shard(self.pipeline.prepare_inputs(prompts))
    print("Sharded prompt ids has shape:", prompt_ids.shape)

    if self._run_with_profiler:
      jax.profiler.start_trace(self._profiler_dir)
    try:
      time_start = time.perf_counter()
      images = self.p_generate(prompt_ids, self.p_params, rng)
      # Dispatch is asynchronous; block so the timing covers the computation.
      images = images.block_until_ready()
      elapsed = time.perf_counter() - time_start
    finally:
      # Always stop the trace: a trace left running would make the next
      # start_trace() call fail.
      if self._run_with_profiler:
        jax.profiler.stop_trace()

    print("Inference time (in seconds): ", elapsed)
    print("Shape of the predictions: ", images.shape)
    # Merge the (device, per-device batch) leading axes into a single batch.
    images = images.reshape(
        (images.shape[0] * images.shape[1],) + images.shape[-3:])
    print("Shape of images afterwards: ", images.shape)
    return self.pipeline.numpy_to_pil(np.array(images))

In [8]:
# Build the deployment graph: the TPU model server is bound first, and its node
# is handed to the HTTP ingress, which receives it as a handle at runtime.
diffusion_bound = StableDiffusion.bind()
deployment = APIIngress.bind(diffusion_bound)

# Serve's HTTP proxy listens on 127.0.0.1 by default. Binding to 0.0.0.0
# presumably makes /imagine reachable from other pods through the head service
# (the client below targets example-cluster-kuberay-head-svc:8000).
serve.run(deployment, host="0.0.0.0")


RayServeSyncHandle(deployment='APIIngress')

In [9]:
import requests
import multiprocessing
import random
from io import BytesIO

In [12]:
def send_request_and_receive_image(
    prompt: str,
    url: str = "http://example-cluster-kuberay-head-svc:8000/imagine",
    timeout: float = 600.0,
):
  """Sends a single prompt request and returns the PNG bytes of the image.

  Args:
    prompt: Free-text prompt. It is passed through `params`, so requests
      percent-encodes every reserved character (`&`, `#`, `+`, ...), not only
      spaces.
    url: Full URL of the ingress /imagine endpoint.
    timeout: Seconds to wait for the server. The default is generous because a
      request may wait up to 10s for its batch to fill, and the first call
      presumably also pays for compiling the model.

  Returns:
    A BytesIO with the response body, ready for `PIL.Image.open`.

  Raises:
    requests.HTTPError: If the server returns a non-2xx status, instead of
      passing an error body on to the image decoder.
  """
  resp = requests.get(url, params={"prompt": prompt}, timeout=timeout)
  resp.raise_for_status()
  return BytesIO(resp.content)


def send_requests(prompts=None, cols: int = 4,
                  output_path: str = "./diffusion_results.png"):
  """Sends a list of requests concurrently, tiles the images and saves the grid.

  All prompts are sent at once so they can arrive within the same
  `@serve.batch` window of the ingress.

  Args:
    prompts: Prompts to render. Defaults to the eight demo prompts below.
    cols: Number of images per row in the grid.
    output_path: Where the PNG grid is written.

  Returns:
    The grid as a PIL Image (rendered inline when it is a cell's last
    expression).

  Raises:
    ValueError: If `prompts` is an empty sequence.
  """
  import multiprocessing.pool

  if prompts is None:
    prompts = [
        "Labrador in the style of Hokusai",
        "Painting of a squirrel skating in New York",
        "HAL-9000 in the style of Van Gogh",
        "Times Square under water, with fish and a dolphin swimming around",
        "Ancient Roman fresco showing a man working on his laptop",
        "Close-up photograph of young black woman against urban background, high quality, bokeh",
        "Armchair in the shape of an avocado",
        "Clown astronaut in space, with Earth in the background",
    ]
  if not prompts:
    raise ValueError("prompts must not be empty")

  # The requests are I/O-bound, so threads suffice. Unlike a process pool,
  # threads do not need child processes to import functions defined in the
  # notebook (which fails under the "spawn" start method).
  with multiprocessing.pool.ThreadPool(processes=len(prompts)) as pool:
    raw_images = pool.map(send_request_and_receive_image, prompts)

  images = [Image.open(raw_image) for raw_image in raw_images]

  def image_grid(imgs, rows, cols):
    """Pastes `imgs` row-major onto a rows x cols canvas sized from imgs[0]."""
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
      grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

  rows = -(-len(images) // cols)  # ceiling division: enough rows for all images
  grid = image_grid(images, rows, cols)
  grid.save(output_path)
  return grid

In [13]:
# Fire every demo prompt concurrently so they can be grouped into a single
# @serve.batch on the ingress, then tile and save the returned images.
send_requests()