In [1]:
import onnx
import onnxruntime
import transformers, numpy as np
from torch.nn import functional as F
import torch
from huggingface_hub import snapshot_download
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from torchinfo import summary
from accelerate import load_checkpoint_in_model
import os

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class SkipAttnProcessor(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        return hidden_states

class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        **kwargs
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # hidden_states = flash_attn_func(
        #     query, key, value, dropout_p=0.0, causal=False
        # )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

In [6]:
def init_adapter(unet, 
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None, 
                 **kwargs):
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                # retain the original attn processor
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
                                                    
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules

In [7]:
def get_trainable_module(unet, trainable_module_name):
    if trainable_module_name == "unet":
        return unet
    elif trainable_module_name == "transformer":
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if hasattr(blocks, "attentions"):
                trainable_modules.append(blocks.attentions)
            else:
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    elif trainable_module_name == "attention":
        attn_blocks = torch.nn.ModuleList()
        for name, param in unet.named_modules():
            if "attn1" in name:
                attn_blocks.append(param)
        return attn_blocks
    else:
        raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")

In [8]:
def auto_attn_ckpt_load(attn_ckpt, version):
    sub_folder = {
        "mix": "mix-48k-1024",
        "vitonhd": "vitonhd-16k-512",
        "dresscode": "dresscode-16k-512",
    }[version]
    if os.path.exists(attn_ckpt):
        load_checkpoint_in_model(attn_modules, os.path.join(attn_ckpt, sub_folder, 'attention'))
    else:
        repo_path = snapshot_download(repo_id=attn_ckpt)
        print(f"Downloaded {attn_ckpt} to {repo_path}")
        load_checkpoint_in_model(attn_modules, os.path.join(repo_path, sub_folder, 'attention'))

In [9]:
base_model_path = f"booksforcharlie/stable-diffusion-inpainting"

In [10]:
repo_path = snapshot_download("zhengchong/CatVTON")

Fetching 12 files: 100%|█████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 134937.39it/s]


In [11]:
vae_model_path = f"stabilityai/sd-vae-ft-mse"

In [12]:
unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet").to('cuda')

An error occurred while trying to fetch booksforcharlie/stable-diffusion-inpainting: booksforcharlie/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
  return self.fget.__get__(instance, owner)()


In [13]:
init_adapter(unet, cross_attn_cls=SkipAttnProcessor)

ModuleList(
  (0): AttnProcessor2_0()
  (1): SkipAttnProcessor()
  (2): AttnProcessor2_0()
  (3): SkipAttnProcessor()
  (4): AttnProcessor2_0()
  (5): SkipAttnProcessor()
  (6): AttnProcessor2_0()
  (7): SkipAttnProcessor()
  (8): AttnProcessor2_0()
  (9): SkipAttnProcessor()
  (10): AttnProcessor2_0()
  (11): SkipAttnProcessor()
  (12): AttnProcessor2_0()
  (13): SkipAttnProcessor()
  (14): AttnProcessor2_0()
  (15): SkipAttnProcessor()
  (16): AttnProcessor2_0()
  (17): SkipAttnProcessor()
  (18): AttnProcessor2_0()
  (19): SkipAttnProcessor()
  (20): AttnProcessor2_0()
  (21): SkipAttnProcessor()
  (22): AttnProcessor2_0()
  (23): SkipAttnProcessor()
  (24): AttnProcessor2_0()
  (25): SkipAttnProcessor()
  (26): AttnProcessor2_0()
  (27): SkipAttnProcessor()
  (28): AttnProcessor2_0()
  (29): SkipAttnProcessor()
  (30): AttnProcessor2_0()
  (31): SkipAttnProcessor()
)

In [14]:
attn_modules = get_trainable_module(unet, "attention")

In [15]:
auto_attn_ckpt_load(repo_path, "mix")

In [13]:
summary(noise_scheduler)

NameError: name 'noise_scheduler' is not defined

In [4]:
device = "cuda"

In [58]:
target_h, target_w = 768, 1024
latent_h, latent_w = target_h // 8, target_w // 8
timestep = torch.tensor([0, 0]).to(device)          # or an int, or a 1-D tensor
B = 1
text_emb = torch.randn(B, 77, 768).to(device) 

In [5]:
latents  = torch.randn(1, 9, 256, 96).to(device)
timestep = torch.tensor([0]).to(device)   
text_emb = None

In [15]:
timestep.shape

torch.Size([1, 1])

In [14]:
device = "cuda"
unet.to(device)
B, H, W = 1, 256, 96                # 64×64 latents ⇒ 512×512 output after VAE
cross_dim = unet.config.cross_attention_dim  # 768 for SD-1.5 / 1280 for SD-XL

# ----- 2.  Dummy inputs ----------------------------------------------------
dummy_inputs = {
    "sample": torch.randn(B, 9, H, W).to("cuda"),            # latent tensor
    "timestep": torch.tensor([0]).to("cuda"),                # or an int 0
    "encoder_hidden_states": None,
}

out = unet(
 **dummy_inputs
)

print(out)

UNet2DConditionOutput(sample=tensor([[[[-0.0194,  0.6650, -0.6582,  ...,  0.0899,  0.0390, -0.3538],
          [-0.4746, -0.0834,  0.0643,  ...,  0.4412,  0.3240, -0.0194],
          [ 0.8023, -0.2351, -0.1483,  ..., -0.5015, -0.7761,  0.2852],
          ...,
          [ 0.2856,  0.3402, -0.8140,  ...,  0.0404, -1.4999, -0.3887],
          [-0.3153,  0.0259,  0.2955,  ..., -0.6490, -1.6492,  0.6452],
          [-0.2383, -0.5000,  0.4734,  ..., -0.1951,  0.0021,  0.1873]],

         [[-0.3322,  1.0846, -0.0256,  ...,  0.2241, -0.5634,  0.4134],
          [ 0.5437,  0.0455,  0.0151,  ...,  0.4621,  0.3210, -0.3693],
          [ 0.6319, -0.5658,  0.6428,  ..., -0.2190,  0.6855,  0.0678],
          ...,
          [ 0.3675, -0.3273,  0.1165,  ..., -0.0463,  0.2239, -0.8052],
          [-0.6482, -0.1556,  0.4305,  ..., -0.4057,  0.3460,  0.6349],
          [ 0.2523,  0.0934,  0.5658,  ...,  0.2503,  0.1444, -0.8433]],

         [[ 0.2852,  0.7701, -0.6358,  ..., -1.3040, -0.4034, -0.2569],
 

In [17]:
out.sample.shape

torch.Size([1, 4, 256, 96])

In [None]:
dynamic_axes = 

In [18]:
torch.onnx.export(
    unet,
    (latents, timestep, text_emb),
    "unet.onnx",
    opset_version=17,
    input_names=["sample", "timestep"],
    output_names = ["out_sample"],
    dynamic_axes={
        "sample": {0: "batch", 2: "height", 3: "width"},
        "timestep": {0: "batch"},
        "out_sample": {0: "batch", 2: "height", 3: "width"},
    },
    do_constant_folding=True,
)

  assert hidden_states.shape[1] == self.channels
  assert hidden_states.shape[1] == self.channels
  assert hidden_states.shape[1] == self.channels
  if hidden_states.shape[0] >= 64:
  if not return_dict:


In [23]:
!ls unet*

unet.onnx


In [2]:
onnx_model = onnx.load("unet.onnx")

In [4]:
print(onnx.checker.check_model("unet.onnx"))

None


In [6]:
ort_session = onnxruntime.InferenceSession("unet.onnx", providers=["CUDAExecutionProvider"])

[0;93m2025-07-13 18:18:34.624263338 [W:onnxruntime:, transformer_memcpy.cc:83 ApplyImpl] 16 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.[m
[0;93m2025-07-13 18:18:34.629083935 [W:onnxruntime:, session_state.cc:1280 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-07-13 18:18:34.629090229 [W:onnxruntime:, session_state.cc:1282 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


In [7]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

In [8]:
ort_inputs = {
    "sample": to_numpy(latents),
    "timestep": to_numpy(timestep)
}

In [9]:
ort_outs = ort_session.run(None, ort_inputs)

[1;31m2025-07-13 18:19:30.296926007 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running Softmax node. Name:'/down_blocks.0/attentions.0/transformer_blocks.0/attn1/Softmax' Status Message: /onnxruntime_src/onnxruntime/core/framework/bfc_arena.cc:376 void* onnxruntime::BFCArena::AllocateRawInternal(size_t, bool, onnxruntime::Stream*, bool, onnxruntime::WaitNotificationFn) Failed to allocate memory for requested buffer of size 19327352832
[m


RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Softmax node. Name:'/down_blocks.0/attentions.0/transformer_blocks.0/attn1/Softmax' Status Message: /onnxruntime_src/onnxruntime/core/framework/bfc_arena.cc:376 void* onnxruntime::BFCArena::AllocateRawInternal(size_t, bool, onnxruntime::Stream*, bool, onnxruntime::WaitNotificationFn) Failed to allocate memory for requested buffer of size 19327352832


In [1]:
!pip uninstall onnxruntime -y

Found existing installation: onnxruntime 1.22.1
Uninstalling onnxruntime-1.22.1:
  Successfully uninstalled onnxruntime-1.22.1


In [2]:
pip install --upgrade onnxruntime-gpu

Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.22.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Downloading onnxruntime_gpu-1.22.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (283.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m283.2/283.2 MB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: onnxruntime-gpu
Successfully installed onnxruntime-gpu-1.22.0
Note: you may need to restart the kernel to use updated packages.


## Quantization

In [2]:
from onnxruntime.quantization import quantize_dynamic, QuantType

In [12]:
!python3 -m onnxruntime.quantization.preprocess --input /home/razaare/unet.onnx --output /home/razaare/unet_infer/unet-infer.onnx --skip_symbolic_shape 1

usage: preprocess.py [-h] --input INPUT --output OUTPUT
                     [--skip_optimization SKIP_OPTIMIZATION]
                     [--skip_onnx_shape SKIP_ONNX_SHAPE]
                     [--skip_symbolic_shape SKIP_SYMBOLIC_SHAPE]
                     [--auto_merge] [--int_max INT_MAX] [--guess_output_rank]
                     [--verbose VERBOSE] [--save_as_external_data]
                     [--all_tensors_to_one_file]
                     [--external_data_location EXTERNAL_DATA_LOCATION]
                     [--external_data_size_threshold EXTERNAL_DATA_SIZE_THRESHOLD]

Model optimizer and shape inferencer, in preparation for quantization,
Consists of three optional steps: 1. Symbolic shape inference (best for
transformer models). 2. Model optimization. 3. ONNX shape inference. Model
quantization with QDQ format, i.e. inserting QuantizeLinear/DeQuantizeLinear
on the tensor, requires tensor shape information to perform its best.
Currently, shape inferencing works best with op

In [15]:
quantize_dynamic(
    model_input="/home/razaare/unet.onnx",
    model_output="/home/razaare/unet_int8/unet_int8.onnx",
    per_channel=True,                          
    weight_type=QuantType.QInt8,
    use_external_data_format=True
)



In [13]:
help(quantize_dynamic)

Help on function quantize_dynamic in module onnxruntime.quantization.quantize:

quantize_dynamic(model_input: 'str | Path | onnx.ModelProto', model_output: 'str | Path', op_types_to_quantize=None, per_channel=False, reduce_range=False, weight_type=<QuantType.QInt8: 0>, nodes_to_quantize=None, nodes_to_exclude=None, use_external_data_format=False, extra_options=None)
    Given an onnx model, create a quantized onnx model and save it into a file
    
    Args:
        model_input: file path of model or ModelProto to quantize
        model_output: file path of quantized model
        op_types_to_quantize:
            specify the types of operators to quantize, like ['Conv'] to quantize Conv only.
            It quantizes all supported operators by default.
        per_channel: quantize weights per channel
        reduce_range:
            quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine,
            especially for per-channel mode
     

In [1]:
import os
import torch
os.environ["ORT_DISABLE_MEM_PATTERN"]  = "1"
os.environ["ORT_DISABLE_CPU_MEM_ARENA"] = "1"
os.environ["ORT_DISABLE_PREPACKING"]    = "1"

from diffusers import AutoencoderKL, DDIMScheduler
import numpy as np
from pathlib import Path
from onnxruntime.quantization import (
    QuantFormat, QuantType, CalibrationMethod,
    quantize_static, CalibrationDataReader,
)

import onnxruntime as ort

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sess_options = ort.SessionOptions()
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL       
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.graph_optimization_level = (
    ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED)
sess_options.optimized_model_filepath = "unet_ort_opt.onnx"
ort.InferenceSession("/home/razaare/unet.onnx", sess_options, providers=["CPUExecutionProvider"])

<onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7fd8d82276d0>

In [5]:
so = ort.SessionOptions()
so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL       
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

In [5]:
base_ckpt = "booksforcharlie/stable-diffusion-inpainting"

In [6]:
noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler")

In [7]:
path = Path("/home/razaare/calibration_dataset")

In [8]:
def build_calib_set(out_dir: Path, n_samples: int, sched: DDIMScheduler, seed: int | None = None):
    rng = np.random.default_rng(seed)
    out_dir.mkdir(parents=True, exist_ok=True)

    h, w = 256, 96
    for i in range(n_samples):
        t = rng.integers(0, sched.num_train_timesteps, dtype=np.int64)
        t_tensor = torch.tensor([t], dtype=torch.long)
    
        eps = torch.randn(1, 4, h, w, dtype=torch.float32)
        x_t = sched.add_noise(torch.zeros_like(eps), eps, t_tensor)
    
        mask_ratio = rng.uniform(0.1, 0.9)
        mask = (torch.rand(1, 1, h, w) < mask_ratio).float()
        x0 = torch.randn_like(x_t)
        masked = x0 * (1.0 - mask)
    
        latent = torch.cat([x_t, masked, mask], dim=1)  # 1×9×256×96
    
        np.savez_compressed(
            out_dir / f"s{i:05}.npz",
            latent=latent.numpy(),
            t=np.array([t], dtype=np.int64),
        )

In [11]:
build_calib_set(path, 512, noise_scheduler, 42)

In [4]:
class InpaintReader(CalibrationDataReader):
    def __init__(self, root: Path):
        self.files = list(root.glob("*.npz"))
        self._it = iter(self.files)

    def get_next(self):
        try:
            f = next(self._it)
        except StopIteration:
            return None
        d = np.load(f)
        return {"sample": d["latent"], "timestep": d["t"]}

In [20]:
help(quantize_static)

Help on function quantize_static in module onnxruntime.quantization.quantize:

quantize_static(model_input: 'str | Path | onnx.ModelProto', model_output: 'str | Path', calibration_data_reader: 'CalibrationDataReader', quant_format=<QuantFormat.QDQ: 1>, op_types_to_quantize=None, per_channel=False, reduce_range=False, activation_type=<QuantType.QInt8: 0>, weight_type=<QuantType.QInt8: 0>, nodes_to_quantize=None, nodes_to_exclude=None, use_external_data_format=False, calibrate_method=<CalibrationMethod.MinMax: 0>, calibration_providers=None, extra_options=None)
    Given an onnx model and calibration data reader, create a quantized onnx model and save it into a file
    It is recommended to use QuantFormat.QDQ format from 1.11 with activation_type = QuantType.QInt8 and weight_type
    = QuantType.QInt8. If model is targeted to GPU/TRT, symmetric activation and weight are required. If model is
    targeted to CPU, asymmetric activation and symmetric weight are recommended for balance of p

In [None]:
reader = InpaintReader(Path("/home/razaare/calibration_dataset"))
quantize_static(
    model_input="/home/razaare/unet_ort_opt.onnx",
    model_output="home/razaare/unet_int8_static/unet_int8_static.onnx",
    calibration_data_reader=reader,
    quant_format=QuantFormat.QDQ,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QInt8,
    calibrate_method=CalibrationMethod.MinMax,
    use_external_data_format=True,
    per_channel=False,
    extra_options={
        "CalibMaxIntermediateOutputs": 2,
        "CalibMovingAverage": True,         
        "CalibMovingAverageConstant": 0.01,
        "ActivationSymmetric": True,
        "WeightSymmetric": True,
        "DisableShapeInference": True,
    },
    calibration_providers=["CPUExecutionProvider"],
    verbose=True
)

