Inspired by: https://github.com/triton-inference-server/server/tree/main/docs/examples/stable_diffusion

- In a multi-model pipeline, which contains several models rather than a single one, we can pull out each model (`nn.Module`) and deploy it on its own backend (ONNX, TorchScript, TensorRT, etc.). The pipeline logic itself stays in the Python runtime.

Let's compile each of these models for the ONNX and TorchScript backends, step by step.

- To build the individual models, first identify which models the pipeline contains.
Then load them one at a time and compile each one for the backend you need.
- The Stable Diffusion pipeline contains `UNet2DConditionModel` (UNet), `CLIPTextModel` (text encoder) and `AutoencoderKL` (VAE).

- The remaining components are `CLIPTokenizer` (tokenizer) and `DPMSolverMultistepScheduler` (scheduler). They are not neural-network modules, so they cannot be compiled and stay in the Python runtime.

- Check pre-trained weights here: https://huggingface.co/stabilityai/stable-diffusion-2-base/tree/main

In [None]:
from transformers import CLIPTokenizer
from diffusers import DPMSolverMultistepScheduler

In [None]:
# CLIP tokenizer from the SD-2-base checkpoint; it stays in the Python runtime.
tokenizer = CLIPTokenizer.from_pretrained(
    'stabilityai/stable-diffusion-2-base',
    subfolder="tokenizer",
)

In [None]:
# DPM-Solver multistep scheduler config from the same checkpoint; not compiled.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    'stabilityai/stable-diffusion-2-base',
    subfolder="scheduler",
)

In [None]:
# Persist the non-compiled components locally so the Python pipeline can load them offline.
for component, target_dir in (
    (tokenizer, '../weights_sd/tokenizer'),
    (scheduler, '../weights_sd/scheduler'),
):
    component.save_pretrained(target_dir)

In [None]:
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import nn

In [None]:
# Suffix for Triton's instance_group kind (KIND_GPU / KIND_CPU) in the configs below.
DEVICE_TYPE = ("CPU", "GPU")[torch.cuda.is_available()]

## ONNX

In [None]:
# Create the ONNX model repository layout: <repo>/<model>/<version>/.
# Each model's config.pbtxt goes in <model>/, its weights in <model>/1/.
!mkdir -p ../models_sd/onnx/text_encoder/1
!mkdir -p ../models_sd/onnx/vae/1
!mkdir -p ../models_sd/onnx/unet/1
!mkdir -p ../models_sd/onnx/pipeline/1


### Text Encoder

In [None]:
class TextEncoderModel(nn.Module):
    """Export wrapper around the SD-2-base CLIP text encoder.

    The underlying model is loaded with ``return_dict=False``. ``forward``
    returns only the first element of its output, so the exporter sees a
    single tensor output.
    """

    def __init__(self):
        super().__init__()
        self.model = CLIPTextModel.from_pretrained(
            'stabilityai/stable-diffusion-2-base',
            subfolder="text_encoder",
            return_dict=False,
        )

    def forward(self, input_ids):
        outputs = self.model(input_ids)
        # First element of the tuple output; exported as `last_hidden_state`.
        hidden_state = outputs[0]
        return hidden_state

In [None]:
# Build the wrapper and switch it to eval mode (nn.Module.eval() returns the module itself).
text_encoder = TextEncoderModel().eval()

In [None]:
# Example prompt, padded/truncated to the tokenizer's max length; used as the export/trace input.
prompt = 'real life goku going super saiyan, beautiful landscape, lightning storm, dramatic lightning, cinematic, establishing shot'
text_input = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
)

In [None]:
# Tensor names shared by the ONNX export and the Triton config of the text encoder.
INPUT_NAMES, OUTPUT_NAMES = ['input_ids'], ['last_hidden_state']

In [None]:
configuration = f"""
name: "text_encoder"
platform: "onnxruntime_onnx"
max_batch_size: 8

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, 1024 ]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]
"""


with open('../models_sd/onnx/text_encoder/config.pbtxt', 'w') as f:
    f.write(configuration)

In [None]:
# Export the text encoder to ONNX. The tokenized prompt, cast to int32 to match
# the TYPE_INT32 input in config.pbtxt, is the example input for tracing.
# Batch and sequence axes are declared dynamic on both the input and the
# output. Otherwise the output may keep the example's fixed batch/sequence size,
# and Triton batching (max_batch_size > 0) needs a variable leading dimension.
with torch.inference_mode():
    torch.onnx.export(
        text_encoder,
        text_input.input_ids.to(torch.int32),
        "../models_sd/onnx/text_encoder/1/model.onnx",
        input_names=INPUT_NAMES,
        output_names=OUTPUT_NAMES,
        dynamic_axes={
            INPUT_NAMES[0]: {0: "batch_size", 1: "sequence_len"},
            # Output sequence length equals the input's, hence the shared name.
            OUTPUT_NAMES[0]: {0: "batch_size", 1: "sequence_len"},
        },
        opset_version=14,
        do_constant_folding=True,
    )

### Unet

In [None]:
class UnetModel(nn.Module):
    """Export wrapper around the SD-2-base ``UNet2DConditionModel``.

    The text embeddings are forwarded as ``encoder_hidden_states``. Only the
    first element of the model output is returned, so the exporter sees one
    tensor output.
    """

    def __init__(self):
        super().__init__()
        self.model = UNet2DConditionModel.from_pretrained(
            'stabilityai/stable-diffusion-2-base',
            subfolder="unet",
            return_dict=False,
        )

    def forward(self, latent_model_input, t, prompt_embeds):
        prediction = self.model(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
        )
        return prediction[0]

In [None]:
# Build the UNet wrapper in eval mode (eval() returns the module itself).
unet = UnetModel().eval()

In [None]:
# UNet tensor names, in forward-argument order, for the ONNX export and Triton config.
INPUT_NAMES, OUTPUT_NAMES = ['latents', 'timestep', 'prompt_embeds'], ['latents_out']

In [None]:
configuration = f"""
name: "unet"
platform: "onnxruntime_onnx"
default_model_filename: "model.onnx"
max_batch_size: 8

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }}
]
input [
  {{
    name: "{INPUT_NAMES[1]}"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: {{ shape: [ ] }}
  }}
]
input [
  {{
    name: "{INPUT_NAMES[2]}"
    data_type: TYPE_FP32
    dims: [ -1, 1024 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ 4, -1, -1 ]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]
"""


with open('../models_sd/onnx/unet/config.pbtxt', 'w') as f:
    f.write(configuration)

In [None]:
# Optional sanity check: reload a previously exported UNet with ONNX Runtime.
# The result is bound to `unet_onnx`, not `unet`. The export cell below passes
# `unet` to torch.onnx.export, which needs the PyTorch nn.Module, and
# overwriting it with an OnnxRuntimeModel would break that export.
# The load is skipped when no export exists yet, e.g. on a fresh run.
import os

from diffusers import OnnxRuntimeModel

if os.path.isfile('../models_sd/onnx/unet/1/model.onnx'):
    unet_onnx = OnnxRuntimeModel.from_pretrained('../models_sd/onnx/unet/1/')

In [None]:
# Empty the UNet version directory before re-exporting so no stale files remain.
# NOTE(review): the UNet may be large enough for the ONNX exporter to write its
# weights as separate external-data files next to model.onnx. That would
# explain why the whole directory is cleared rather than just model.onnx. Confirm.
!rm -rf ../models_sd/onnx/unet/1/*

In [None]:
# Export the UNet to ONNX using random example inputs: latents (2, 4, 64, 64),
# int32 timesteps (2,) and text embeddings (2, 77, 1024).
# Batch, channel, spatial and sequence axes of the inputs are dynamic. The
# output's batch and spatial axes are declared dynamic too, so the graph does
# not pin the example batch size. Its channel axis stays fixed, matching
# dims [ 4, -1, -1 ] in config.pbtxt.
with torch.inference_mode():
    torch.onnx.export(
        unet,
        (torch.randn(2, 4, 64, 64), torch.tensor([7, 7]).int(), torch.randn(2, 77, 1024)),
        "../models_sd/onnx/unet/1/model.onnx",
        input_names=INPUT_NAMES,
        output_names=OUTPUT_NAMES,
        dynamic_axes={
            INPUT_NAMES[0]: {
                0: "batch_size",
                1: "channels",
                2: "height",
                3: "width",
            },
            INPUT_NAMES[1]: {
                0: "batch_size",
            },
            INPUT_NAMES[2]: {
                0: "batch_size",
                1: "sequence_len",
            },
            # The UNet preserves spatial size, so the output reuses the input's axis names.
            OUTPUT_NAMES[0]: {
                0: "batch_size",
                2: "height",
                3: "width",
            },
        },
        opset_version=14,
        do_constant_folding=True,
    )

### VAE

In [None]:
class VAEModel(nn.Module):
    """Export wrapper that runs only the decoder of the SD-2-base ``AutoencoderKL``.

    ``forward`` maps latents to the first element of ``decode``'s output
    (exported under the name ``image``).
    """

    def __init__(self):
        super().__init__()
        self.model = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-2-base',
            subfolder="vae",
            return_dict=False,
        )

    def forward(self, latents):
        decoded = self.model.decode(latents)
        return decoded[0]

In [None]:
# Build the VAE decoder wrapper in eval mode (eval() returns the module itself).
vae = VAEModel().eval()

In [None]:
# VAE decoder tensor names for the ONNX export and Triton config.
INPUT_NAMES, OUTPUT_NAMES = ['latents'], ['image']

In [None]:
configuration = f"""
name: "vae"
platform: "onnxruntime_onnx"
max_batch_size: 8

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ 3, -1, -1 ]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]

"""


with open('../models_sd/onnx/vae/config.pbtxt', 'w') as f:
    f.write(configuration)

In [None]:
# Export the VAE decoder to ONNX with a random (1, 4, 64, 64) latent.
# All input axes are dynamic. The output's batch and spatial axes are declared
# as well, so the graph does not pin the example batch of 1. Their spatial
# names differ from the input's because the decoder changes the spatial size,
# and sharing names would tell ONNX the sizes are equal.
with torch.inference_mode():
    torch.onnx.export(
        vae,
        torch.randn(1, 4, 64, 64),
        "../models_sd/onnx/vae/1/model.onnx",
        input_names=INPUT_NAMES,
        output_names=OUTPUT_NAMES,
        dynamic_axes={
            INPUT_NAMES[0]: {
                0: "batch_size",
                1: "channels",
                2: "height",
                3: "width",
            },
            OUTPUT_NAMES[0]: {
                0: "batch_size",
                2: "image_height",
                3: "image_width",
            },
        },
        opset_version=14,
        do_constant_folding=True,
    )

## TorchScript

In [None]:
# Create the TorchScript model repository layout: <repo>/<model>/<version>/.
# Each model's config.pbtxt goes in <model>/, its weights in <model>/1/.
!mkdir -p ../models_sd/torchscript/text_encoder/1
!mkdir -p ../models_sd/torchscript/vae/1
!mkdir -p ../models_sd/torchscript/unet/1
!mkdir -p ../models_sd/torchscript/pipeline/1


### Text Encoder

In [None]:
class TextEncoderModel(nn.Module):
    """Tracing wrapper around the SD-2-base CLIP text encoder.

    The underlying model is loaded with ``return_dict=False``. ``forward``
    returns only the first element of its output, so ``torch.jit.trace``
    sees a single tensor output.
    """

    def __init__(self):
        super().__init__()
        self.model = CLIPTextModel.from_pretrained(
            'stabilityai/stable-diffusion-2-base',
            subfolder="text_encoder",
            return_dict=False,
        )

    def forward(self, input_ids):
        outputs = self.model(input_ids)
        # First element of the tuple output; served as `last_hidden_state`.
        hidden_state = outputs[0]
        return hidden_state

In [None]:
# Fresh text-encoder wrapper for tracing, in eval mode (eval() returns the module itself).
text_encoder = TextEncoderModel().eval()

In [None]:
# Example prompt, padded/truncated to the tokenizer's max length; used as the tracing input.
prompt = 'real life goku going super saiyan, beautiful landscape, lightning storm, dramatic lightning, cinematic, establishing shot'
text_input = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
)

In [None]:
# Tensor names used in the TorchScript text encoder's Triton config.
INPUT_NAMES, OUTPUT_NAMES = ['input_ids'], ['last_hidden_state']

In [None]:
configuration = f"""
name: "text_encoder"
platform: "pytorch_libtorch"
max_batch_size: 8

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, 1024 ]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]

parameters: {{
    key: "INFERENCE_MODE"
    value: {{ string_value: "true" }}
}}
"""


with open('../models_sd/torchscript/text_encoder/config.pbtxt', 'w') as f:
    f.write(configuration)

In [None]:
# Trace the text encoder on the tokenized prompt (int32, as declared in
# config.pbtxt) and save the TorchScript module as version 1 of the model.
with torch.inference_mode():
    example_inputs = (text_input.input_ids.to(torch.int32),)
    traced_script_module = torch.jit.trace(text_encoder, example_inputs)
    traced_script_module.save('../models_sd/torchscript/text_encoder/1/model.pt')

### Unet

In [None]:
class UnetModel(nn.Module):
    """Tracing wrapper around the SD-2-base ``UNet2DConditionModel``.

    The text embeddings are forwarded as ``encoder_hidden_states``. Only the
    first element of the model output is returned, so ``torch.jit.trace``
    sees one tensor output.

    NOTE(review): the Triton config for this model names its inputs
    ``latents``/``timestep``/``prompt_embeds``, but the forward arguments are
    ``latent_model_input``/``t``/``prompt_embeds``. Confirm that the libtorch
    backend maps them by position as intended.
    """

    def __init__(self):
        super().__init__()
        self.model = UNet2DConditionModel.from_pretrained(
            'stabilityai/stable-diffusion-2-base',
            subfolder="unet",
            return_dict=False,
        )

    def forward(self, latent_model_input, t, prompt_embeds):
        prediction = self.model(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
        )
        return prediction[0]

In [None]:
# Fresh UNet wrapper for tracing, in eval mode (eval() returns the module itself).
unet = UnetModel().eval()

In [None]:
# UNet tensor names, in forward-argument order, for the TorchScript Triton config.
INPUT_NAMES, OUTPUT_NAMES = ['latents', 'timestep', 'prompt_embeds'], ['latents_out']

In [None]:
configuration = f"""
name: "unet"
platform: "pytorch_libtorch"
max_batch_size: 8

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }}
]
input [
  {{
    name: "{INPUT_NAMES[1]}"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: {{ shape: [ ] }}
  }}
]
input [
  {{
    name: "{INPUT_NAMES[2]}"
    data_type: TYPE_FP32
    dims: [ -1, 1024 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ 4, -1, -1 ]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]

parameters: {{
    key: "INFERENCE_MODE"
    value: {{ string_value: "true" }}
}}
"""


with open('../models_sd/torchscript/unet/config.pbtxt', 'w') as f:
    f.write(configuration)

In [None]:
# Trace the UNet with random example inputs (latents, int32 timesteps, text
# embeddings; batch of 2) and save it as version 1 of the TorchScript model.
with torch.inference_mode():
    example_inputs = (
        torch.randn(2, 4, 64, 64),
        torch.tensor([7, 7]).int(),
        torch.randn(2, 77, 1024),
    )
    traced_script_module = torch.jit.trace(unet, example_inputs)
    traced_script_module.save('../models_sd/torchscript/unet/1/model.pt')

### VAE

In [None]:
class VAEModel(nn.Module):
    """Tracing wrapper that runs only the decoder of the SD-2-base ``AutoencoderKL``.

    ``forward`` maps latents to the first element of ``decode``'s output
    (served under the name ``image``).
    """

    def __init__(self):
        super().__init__()
        self.model = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-2-base',
            subfolder="vae",
            return_dict=False,
        )

    def forward(self, latents):
        decoded = self.model.decode(latents)
        return decoded[0]

In [None]:
# Fresh VAE decoder wrapper for tracing, in eval mode (eval() returns the module itself).
vae = VAEModel().eval()

In [None]:
# VAE decoder tensor names for the TorchScript Triton config.
INPUT_NAMES, OUTPUT_NAMES = ['latents'], ['image']

In [None]:
configuration = f"""
name: "vae"
platform: "pytorch_libtorch"
max_batch_size: 8

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ 3, -1, -1 ]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]

parameters: {{
    key: "INFERENCE_MODE"
    value: {{ string_value: "true" }}
}}
"""


with open('../models_sd/torchscript/vae/config.pbtxt', 'w') as f:
    f.write(configuration)

In [None]:
# Trace the VAE decoder with a random (1, 4, 64, 64) latent and save it as
# version 1 of the TorchScript model.
with torch.inference_mode():
    example_inputs = (torch.randn(1, 4, 64, 64),)
    traced_script_module = torch.jit.trace(vae, example_inputs)
    traced_script_module.save('../models_sd/torchscript/vae/1/model.pt')

## Pipeline

In [None]:
# Request/response tensor names of the Python-backend `pipeline` model.
INPUT_NAMES, OUTPUT_NAMES = ['prompt', 'height', 'width', 'inference_steps'], ['image']

In [None]:
configuration = f"""
name: "pipeline"
backend: "python"
max_batch_size: 0

input [
  {{
    name: "{INPUT_NAMES[0]}"
    data_type: TYPE_STRING
    dims: [ -1 ]
  }}
]
input [
  {{
    name: "{INPUT_NAMES[1]}"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }}
]
input [
  {{
    name: "{INPUT_NAMES[2]}"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }}
]
input [
  {{
    name: "{INPUT_NAMES[3]}"
    data_type: TYPE_INT32
    dims: [ 1 ]
  }}
]
output [
  {{
    name: "{OUTPUT_NAMES[0]}"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1]
  }}
]

instance_group [
  {{
    kind: KIND_{DEVICE_TYPE}
  }}
]

parameters {{
    key: "unet_inchannels"
    value: {{ string_value: "{unet.model.in_channels}" }}
}}

parameters {{
    key: "vae_scale_factor"
    value: {{ string_value: "{2**(len(vae.model.config.block_out_channels) - 1)}" }}
}}
"""


with open('../models_sd/torchscript/pipeline/config.pbtxt', 'w') as f:
    f.write(configuration)

### Inference

In [None]:
# Check that the Triton HTTP endpoint at 0.0.0.0:8000 reports itself ready.
!curl -v 0.0.0.0:8000/v2/health/ready

In [None]:
# Fetch the `pipeline` model's metadata (inputs/outputs) from the same server.
!curl -v 0.0.0.0:8000/v2/models/pipeline

In [None]:
import tritonclient.http as tritonhttpclient
import numpy as np
from PIL import Image

In [None]:
# Client-side settings for calling the `pipeline` model over Triton's HTTP API.
VERBOSE = False
url = '0.0.0.0:8000'
model_name, model_version = 'pipeline', '1'

# Tensor names and Triton dtypes, index-aligned with each other.
INPUT_NAMES = ['prompt', 'height', 'width', 'inference_steps']
INPUT_DTYPES = ['BYTES', 'INT32', 'INT32', 'INT32']
OUTPUT_NAMES = ['image']
OUTPUT_DTYPES = ['FLOAT32']

In [None]:
# Batch of two identical prompts; the pipeline's `prompt` input takes a
# variable-length list of strings. (Tokenization happens server-side, so the
# old commented-out tokenizer call, which referenced an undefined `prompt`,
# has been removed.)
prompts = ['real life goku going super saiyan, beautiful landscape, lightning storm, dramatic lightning, cinematic, establishing shot'] * 2

In [None]:
# Send one request to the `pipeline` model: the batch of prompts plus
# 1-element int32 height, width and inference-step count. Read back the
# `image` output as a NumPy array.
# `verbose` now honours the VERBOSE setting instead of a hard-coded False.
with tritonhttpclient.InferenceServerClient(url=url, verbose=VERBOSE) as client:
    # Define input config
    inputs = [
        tritonhttpclient.InferInput(INPUT_NAMES[0], (len(prompts),), INPUT_DTYPES[0]),
        tritonhttpclient.InferInput(INPUT_NAMES[1], (1,), INPUT_DTYPES[1]),
        tritonhttpclient.InferInput(INPUT_NAMES[2], (1,), INPUT_DTYPES[2]),
        tritonhttpclient.InferInput(INPUT_NAMES[3], (1,), INPUT_DTYPES[3]),
    ]

    # Attach input (object dtype for BYTES/string tensors)
    inputs[0].set_data_from_numpy(np.asarray(prompts, dtype=object))
    inputs[1].set_data_from_numpy(np.asarray([512], dtype=np.int32))
    inputs[2].set_data_from_numpy(np.asarray([512], dtype=np.int32))
    inputs[3].set_data_from_numpy(np.asarray([2], dtype=np.int32))

    # Define output config
    outputs = [
        tritonhttpclient.InferRequestedOutput(OUTPUT_NAMES[0]),
    ]

    # Hit triton server
    response = client.infer(model_name, model_version=model_version, inputs=inputs, outputs=outputs)
    generated_images = response.as_numpy(OUTPUT_NAMES[0])

In [None]:
# Convert float pixels (expected in [0, 1]) to 8-bit. Clip before casting:
# converting a float outside [0, 255] to uint8 is not well-defined in NumPy,
# and typically wraps around (e.g. a slight overshoot to 256 becomes 0).
generated_images = np.clip(generated_images * 255, 0, 255).round().astype("uint8")
generated_images.shape

In [None]:
# Display the first generated image (the cell's last expression is rendered).
first_image = generated_images[0]
Image.fromarray(first_image)

In [None]:
# Display the second generated image (the cell's last expression is rendered).
second_image = generated_images[1]
Image.fromarray(second_image)