
enable more ONNX optimizations #241

Closed · ssube opened this issue Mar 12, 2023 · 4 comments · Fixed by #265

Labels: model/diffusion, scope/api, scope/convert, status/progress, type/feature

Comments

ssube (Owner) commented Mar 12, 2023

https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/README.md#model-optimizer
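
For reference, a rough sketch of that optimizer's Python entry point, using the model layout that appears later in this thread. The model_type="unet" value and the exact fusion set depend on the onnxruntime version, so treat this as an assumption rather than the exact invocation used here:

# Sketch: run the ORT transformers graph optimizer on an exported UNet.
# Paths follow the layout used elsewhere in this issue; model_type="unet"
# only exists in newer onnxruntime releases.
from onnxruntime.transformers.optimizer import optimize_model

opt = optimize_model(
    "../models/stable-diffusion-fp32/unet/model.onnx",
    model_type="unet",   # generic fusions still apply if this type is unsupported
    opt_level=0,         # 0 = python-side fusions only, no ORT session involved
)
# keep fp32 here; fp16 conversion is a separate, optional step
opt.save_model_to_file(
    "../models/stable-diffusion-fp32-opt/unet/model.onnx",
    use_external_data_format=True,  # needed once the graph exceeds 2GB
)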

ssube (Owner, Author) commented Mar 19, 2023

Running this ahead of time breaks the additional network blending, #264, but the memory difference is incredible:


  • 4222MiB for a batch of 1 image at 512x512
  • 7184MiB for a batch of 5 images at 512x512
  • 8208MiB for a batch of 2 images at 1024x1024, peaking ~11GB during decoding

Using the optimization script without the --float16 option and with the --external_data option initially produced an error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ../models/stable-diffusion-fp32-opt/text_encoder/model.onnx failed:/onnxruntime_src/onnxruntime/core/graph/graph.cc:1348 void onnxruntime::Graph::InitializeStateFromModelFileGraphProto() This is an invalid model. Graph output (input.103) does not exist in the graph.
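
Not specific to any one repo, but a quick sanity check with the stock ONNX checker can surface structural problems in an exported graph before it ever reaches an ORT session. The path is a placeholder, and whether the checker flags this exact "graph output does not exist" case may depend on the onnx version:

# Sketch: validate the exported/optimized text encoder before inference.
import onnx

model_path = "../models/stable-diffusion-fp32-opt/text_encoder/model.onnx"
# raises a ValidationError on structural problems; may or may not catch the
# exact "input.103" failure above, depending on the checker version
onnx.checker.check_model(model_path)
model = onnx.load(model_path)
print("graph outputs:", [o.name for o in model.graph.output])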

After copying over a working text encoder, it produces a model that is the same size on disk:

4.0G    ./stable-diffusion-fp32
4.0G    ./stable-diffusion-fp32-opt

but uses substantially less memory during inference:


  • 6682MiB for a batch of 1 image at 512x512

ssube added the status/progress label and removed the status/new label Mar 19, 2023
ssube (Owner, Author) commented Mar 19, 2023

Using the torch-fp16 optimizations without ORT fp16, I was able to run SD v1.5 on a simulated 6GB card:

export ONNX_WEB_MEMORY_LIMIT=6442450944
export ONNX_WEB_OPTIMIZATIONS=torch-fp16,onnx-low-memory,diffusers-attention-slicing

Trying to run a normal fp32 model fails with:

Available memory of 696087552 is smaller than requested bytes of 2147483648 
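
For context, 6442450944 bytes is exactly 6 GiB. A hedged sketch of how that kind of hard budget can be expressed at the onnxruntime level, via the CUDA execution provider's arena options; whether ONNX_WEB_MEMORY_LIMIT maps to this exact mechanism is an assumption, and the model path is a placeholder:

# Sketch: cap the CUDA EP arena at ~6 GiB when creating a session.
import onnxruntime as ort

providers = [
    ("CUDAExecutionProvider", {
        "gpu_mem_limit": 6 * 1024 ** 3,          # 6442450944 bytes
        "arena_extend_strategy": "kSameAsRequested",
    }),
    "CPUExecutionProvider",
]
session = ort.InferenceSession("unet/model.onnx", providers=providers)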

ssube mentioned this issue Mar 19, 2023
Amblyopius commented

My take on memory:

  • I have been doing FP16 for months, initially via onnxconverter-common and now via ORT transformers (the conversion code is essentially identical, as both projects are maintained by Microsoft).
  • The next step after FP16 is attention slicing (when you start increasing resolution).
  • Next after attention slicing is loading the Text Encoder on CPU, as CLIP runs fast enough there. From this point you can run on 4GB VRAM.
  • Next after that is VAE on CPU. While VAE on CPU is a bigger hit than Text Encoder on CPU, it stops UNET performance from degrading over multiple runs (a rough sketch of this CPU-offload split follows this comment).

I do use some parts of ORT transformers for optimisations, but only the ones that are not targeted at specific hardware (the full optimisation in ORT transformers is CUDA-specific). I would have to recheck ORT transformers to see whether more generic ONNX optimisations, rather than CUDA ones, have been added; if not, I can't adopt them. There's just too little interest in CUDA-specific ONNX for Stable Diffusion.
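
A minimal sketch of the CPU-offload split described in the list above, using plain onnxruntime sessions; the file names are placeholders and the surrounding pipeline wiring is not shown:

# Sketch: keep the UNet on the GPU but run the text encoder and VAE on CPU.
import onnxruntime as ort

cpu = ["CPUExecutionProvider"]
gpu = ["CUDAExecutionProvider", "CPUExecutionProvider"]

text_encoder = ort.InferenceSession("text_encoder/model.onnx", providers=cpu)
vae_decoder = ort.InferenceSession("vae_decoder/model.onnx", providers=cpu)
unet = ort.InferenceSession("unet/model.onnx", providers=gpu)
# the diffusion loop then feeds CPU-produced embeddings into the GPU UNet
# and only moves latents back to the CPU for the final VAE decode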

ssube (Owner, Author) commented Mar 19, 2023

Neat, thanks. I'm curious to see how low I can get it, and CPU offloading is the next thing I need to work on.

Some of the optimizations in ORT are CUDA-specific and/or incompatible with the CPU provider, notably fp16, but they've also been adding a lot of node folding and graph optimization stuff recently. I think those are generic. If I'm reading https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/optimizer.py right, optimize_model is the more general stuff and optimize_by_onnxruntime starts to get hardware-specific. I have not tried the latter yet.

https://huggingface.co/docs/optimum/v1.7.1/en/onnxruntime/usage_guides/optimization has a lot of the node folding as well, and only becomes CUDA-specific at O4. I don't know if you can mix them, but I do know that the ORT optimizer renames a lot of nodes, which breaks my LoRA/Inversion code.
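
For completeness, a hedged sketch of the Optimum flow from that guide, shown on a single generic ORTModel as in the guide's own examples; whether the Stable Diffusion pipeline classes plug into ORTOptimizer the same way is not assumed, and export=True is the newer spelling of the older from_transformers=True argument:

# Sketch of the Optimum graph-optimization flow from the linked guide.
# Levels 1/2 are hardware-agnostic fusions; the CUDA-specific (O4-style)
# behaviour only comes in with the fp16/GPU options.
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", export=True
)
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    save_dir="./onnx-optimized",
    optimization_config=OptimizationConfig(optimization_level=2),
)

Either way, the renamed and fused nodes produced by these optimizers are why running them ahead of time conflicts with the additional network blending in #264, as noted earlier in the thread.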
