diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b173ebed7..749366e75 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -176,6 +176,7 @@ jobs: build-and-push-docker-images: name: Build and push container images + if: ${{ github.event_name != 'pull_request' }} runs-on: ubuntu-latest permissions: diff --git a/CMakeLists.txt b/CMakeLists.txt index 1bce6ca4c..6a9fb1041 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,10 @@ endif() if (MSVC) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING) + add_compile_options( + $<$:/MP> + $<$:/MP> + ) endif() set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) @@ -152,10 +156,12 @@ endif() set(SD_LIB stable-diffusion) -file(GLOB SD_LIB_SOURCES +file(GLOB SD_LIB_SOURCES CONFIGURE_DEPENDS "src/*.h" "src/*.cpp" "src/*.hpp" + "src/model_io/*.h" + "src/model_io/*.cpp" "src/tokenizers/*.h" "src/tokenizers/*.cpp" "src/tokenizers/vocab/*.h" diff --git a/README.md b/README.md index fbed50d24..8afdeb20a 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ API and command-line option may change frequently.*** - [Z-Image](./docs/z_image.md) - [Ovis-Image](./docs/ovis_image.md) - [Anima](./docs/anima.md) + - [ERNIE-Image](./docs/ernie_image.md) - Image Edit Models - [FLUX.1-Kontext-dev](./docs/kontext.md) - [Qwen Image Edit series](./docs/qwen_image_edit.md) @@ -76,9 +77,10 @@ API and command-line option may change frequently.*** - OpenCL - SYCL - Supported weight formats - - Pytorch checkpoint (`.ckpt` or `.pth`) + - Pytorch checkpoint (`.ckpt` or `.pth` or `.pt`) - Safetensors (`.safetensors`) - GGUF (`.gguf`) +- Convert mode supports converting model weights to `.gguf` or `.safetensors` - Supported platforms - Linux - Mac OS @@ -96,6 +98,7 @@ API and command-line option may change frequently.*** - `DPM++ 2M` - [`DPM++ 2M 
v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457) - `DPM++ 2S a` + - `ER-SDE` - [`LCM`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13952) - Cross-platform reproducibility - `--rng cuda`, default, consistent with the `stable-diffusion-webui GPU RNG` @@ -144,6 +147,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe - [🔥Z-Image](./docs/z_image.md) - [Ovis-Image](./docs/ovis_image.md) - [Anima](./docs/anima.md) +- [ERNIE-Image](./docs/ernie_image.md) - [LoRA](./docs/lora.md) - [LCM/LCM-LoRA](./docs/lcm.md) - [Using PhotoMaker to personalize image generation](./docs/photo_maker.md) diff --git a/assets/ernie_image/example.png b/assets/ernie_image/example.png new file mode 100644 index 000000000..3f5ed652f Binary files /dev/null and b/assets/ernie_image/example.png differ diff --git a/assets/ernie_image/turbo_example.png b/assets/ernie_image/turbo_example.png new file mode 100644 index 000000000..15318b3e9 Binary files /dev/null and b/assets/ernie_image/turbo_example.png differ diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md index 3174b18f8..7aa8fbede 100644 --- a/docs/distilled_sd.md +++ b/docs/distilled_sd.md @@ -87,51 +87,32 @@ pipe.save_pretrained("segmindtiny-sd", safe_serialization=True) ```bash python convert_diffusers_to_original_stable_diffusion.py \ --model_path ./segmindtiny-sd \ - --checkpoint_path ./segmind_tiny-sd.ckpt --half + --checkpoint_path ./segmind_tiny-sd.safetensors --half --use_safetensors ``` -The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above. +The file segmind_tiny-sd.safetensors will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above. 
-##### Another available .ckpt file: - - * https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt - -To use this file, you must first adjust its non-contiguous tensors: - -```python -import torch -ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu')) -for key, value in ckpt['state_dict'].items(): - if isinstance(value, torch.Tensor): - ckpt['state_dict'][key] = value.contiguous() -torch.save(ckpt, "tinySDdistilled_fixed.ckpt") -``` - - -### SDXS-512 +### SDXS-512-DreamShaper Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of U-Net part and unlike other SD1 models they use an adjusted _AutoEncoderTiny_ instead of default _AutoEncoderKL_ for the VAE part. +##### Some ready-to-run SDXS-512 model files are available online, such as: -##### 1. Download the diffusers model from Hugging Face using Python: - -```python -from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper") -pipe.save_pretrained(save_directory="sdxs") -``` -##### 2. Create a safetensors file - -```bash -python convert_diffusers_to_original_stable_diffusion.py \ - --model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors -``` - -##### 3. Run the model as follows: +* https://huggingface.co/akleine/sdxs-512 +* https://huggingface.co/concedo/sdxs-512-tinySDdistilled-GGUF +##### Run the model as follows: ```bash ~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \ --cfg-scale 1 --steps 1 ``` +Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here. 
+ +### SDXS-512-0.9 + +Even though the name "SDXS-512-0.9" is similar to "SDXS-512-DreamShaper", it is *completely different* but also **incredibly fast**. Sometimes it is preferred, so try it yourself. +##### Download a ready-to-run file from here: + +* https://huggingface.co/akleine/sdxs-09 -Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here. +For the use of this model, both options ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are again absolutely necessary. diff --git a/docs/ernie_image.md b/docs/ernie_image.md new file mode 100644 index 000000000..d68da3966 --- /dev/null +++ b/docs/ernie_image.md @@ -0,0 +1,35 @@ +# How to Use + +You can run ERNIE-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or even less. + +## Download weights + +- Download ERNIE-Image-Turbo + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models + - gguf: https://huggingface.co/unsloth/ERNIE-Image-Turbo-GGUF/tree/main +- Download ERNIE-Image + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models + - gguf: https://huggingface.co/unsloth/ERNIE-Image-GGUF/tree/main +- Download vae + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/vae +- Download ministral 3b + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/text_encoders + - gguf: https://huggingface.co/unsloth/Ministral-3-3B-Instruct-2512-GGUF/tree/main + +## Examples + +### ERNIE-Image-Turbo + +``` +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-turbo.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 8 -v --offload-to-cpu --diffusion-fa +``` + +ERNIE-Image Turbo example + +### ERNIE-Image + +``` +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-UD-Q4_K_M.gguf --vae 
..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa +``` + +ERNIE-Image example diff --git a/examples/cli/README.md b/examples/cli/README.md index 289cb866a..2e9c75ecd 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -14,6 +14,9 @@ CLI Options: --metadata-format metadata output format, one of [text, json] (default: text) --canny apply canny preprocessor (edge detection) --convert-name convert tensor name (for convert mode) + convert mode writes `.gguf` or `.safetensors` based on the output extension. + `.safetensors` export currently supports f16, bf16, f32, and i32 tensor types only. + i32 is passthrough only; no f32 <-> i32 conversion is performed -v, --verbose print extra info --color colors the logging tags according to level --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae) @@ -114,7 +117,7 @@ Generation Options: medium --skip-layer-start SLG enabling point (default: 0.01) --skip-layer-end SLG disabling point (default: 0.2) - --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a) --flow-shift shift value for Flow models like SD3.x or WAN (default: auto) --high-noise-cfg-scale (high noise) unconditional guidance scale: (default: 7.0) --high-noise-img-cfg-scale (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) @@ -122,7 +125,7 @@ Generation Options: --high-noise-slg-scale (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0) --high-noise-skip-layer-start (high noise) SLG enabling point (default: 0.01) --high-noise-skip-layer-end (high noise) SLG disabling point (default: 0.2) - --high-noise-eta (high 
noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a) --strength strength for noising/unnoising (default: 0.75) --pm-style-strength --control-strength strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image @@ -133,10 +136,10 @@ Generation Options: --disable-image-metadata do not embed generation metadata on image files -s, --seed RNG seed (default: 42, use random seed for < 0) --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, - tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a + tcd, res_multistep, res_2s, er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise) --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, - ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, + ddim_trailing, tcd, res_multistep, res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 5826b8357..8e09c39a1 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -841,6 +841,22 @@ ArgOptions SDGenerationParams::get_options() { "--guidance", "distilled guidance scale for models with guidance input (default: 3.5)", &sample_params.guidance.distilled_guidance}, + {"", + "--apg-eta", + "parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)", + &sample_params.guidance.apg.eta}, + {"", + "--apg-momentum", + 
"momentum for guidance adjustments with APG (default: 0, recommended: around -0.5 (negative))", + &sample_params.guidance.apg.momentum}, + {"", + "--apg-nt", + "APG norm threshold: Upper bound allowed for the amplitude (L2 norm) of guidance updates (default: 0 = disabled, recommended: 4-15)", + &sample_params.guidance.apg.norm_threshold}, + {"", + "--apg-nt-smoothing", + "Norm threshold smoothing for APG, smoothly decrease the amplitude of the guidance update if it gets too close to the norm threshold (experimental; default: 0 = disabled)", + &sample_params.guidance.apg.norm_threshold_smoothing}, {"", "--slg-scale", "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 medium", @@ -855,7 +871,7 @@ ArgOptions SDGenerationParams::get_options() { &sample_params.guidance.slg.layer_end}, {"", "--eta", - "noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)", + "noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)", &sample_params.eta}, {"", "--flow-shift", @@ -887,7 +903,7 @@ ArgOptions SDGenerationParams::get_options() { &high_noise_sample_params.guidance.slg.layer_end}, {"", "--high-noise-eta", - "(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)", + "(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)", &high_noise_sample_params.eta}, {"", "--strength", @@ -931,6 +947,18 @@ ArgOptions SDGenerationParams::get_options() { "do not embed generation metadata on image files", false, &embed_image_metadata}, + {"", + "--slg-uncond", + "use CFG's forward pass for skip layer guidance (SLG) instead of a separate pass, only for DiT models (recommended to keep slg-scale at 0 if enabled)", + true, + &sample_params.guidance.slg.uncond + }, + {"", + 
"--high-noise-slg-uncond", + "(high noise) use CFG's forward pass for skip layer guidance (SLG) instead of a separate pass, only for DiT models (recommended to keep slg-scale at 0 if enabled)", + true, + &high_noise_sample_params.guidance.slg.uncond + }, {"", "--vae-tiling", "process vae in tiles to reduce memory usage", @@ -1185,12 +1213,12 @@ ArgOptions SDGenerationParams::get_options() { on_seed_arg}, {"", "--sampling-method", - "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s] " + "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s, er_sde] " "(default: euler for Flux/SD3/Wan, euler_a otherwise)", on_sample_method_arg}, {"", "--high-noise-sampling-method", - "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]" + "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s, er_sde]" " default: euler for Flux/SD3/Wan, euler_a otherwise", on_high_noise_sample_method_arg}, {"", @@ -1531,6 +1559,21 @@ bool SDGenerationParams::from_json_str( if (guidance_json.contains("distilled_guidance") && guidance_json["distilled_guidance"].is_number()) { target_params.guidance.distilled_guidance = guidance_json["distilled_guidance"]; } + if (guidance_json.contains("apg") && guidance_json["apg"].is_object()) { + const json& apg_json = guidance_json["apg"]; + if (apg_json.contains("eta") && apg_json["eta"].is_number()) { + target_params.guidance.apg.eta = apg_json["eta"]; + } + if (apg_json.contains("momentum") && apg_json["momentum"].is_number()) { + target_params.guidance.apg.momentum = apg_json["momentum"]; + } + if (apg_json.contains("norm_threshold") && 
apg_json["norm_threshold"].is_number()) { + target_params.guidance.apg.norm_threshold = apg_json["norm_threshold"]; + } + if (apg_json.contains("norm_threshold_smoothing") && apg_json["norm_threshold_smoothing"].is_number()) { + target_params.guidance.apg.norm_threshold_smoothing = apg_json["norm_threshold_smoothing"]; + } + } if (guidance_json.contains("slg") && guidance_json["slg"].is_object()) { const json& slg_json = guidance_json["slg"]; if (slg_json.contains("layers") && slg_json["layers"].is_array()) { @@ -1545,6 +1588,9 @@ bool SDGenerationParams::from_json_str( if (slg_json.contains("scale") && slg_json["scale"].is_number()) { target_params.guidance.slg.scale = slg_json["scale"]; } + if (slg_json.contains("uncond") && slg_json["uncond"].is_boolean()) { + target_params.guidance.slg.uncond = slg_json["uncond"]; + } } } }; @@ -1589,10 +1635,18 @@ bool SDGenerationParams::from_json_str( if (!parse_image_json_field(j, "init_image", 3, width, height, init_image)) { LOG_ERROR("invalid init_image"); return false; } + if (!parse_image_json_field(j, "end_image", 3, width, height, end_image)) { + LOG_ERROR("invalid end_image"); + return false; + } if (!parse_image_array_json_field(j, "ref_images", 3, width, height, ref_images)) { LOG_ERROR("invalid ref_images"); return false; } + if (!parse_image_array_json_field(j, "control_frames", 3, width, height, control_frames)) { + LOG_ERROR("invalid control_frames"); + return false; + } if (!parse_image_json_field(j, "mask_image", 1, width, height, mask_image)) { LOG_ERROR("invalid mask_image"); return false; } @@ -2097,6 +2151,8 @@ std::string version_string() { } std::string get_image_params(const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) { + sd_img_gen_params_t defaults; + sd_img_gen_params_init(&defaults); std::string parameter_string; if (gen_params.prompt_with_lora.size() != 0) { parameter_string += gen_params.prompt_with_lora + "\n"; @@ -2108,6 +2164,22 @@ std::string get_image_params(const SDContextParams& ctx_params, const SDGenerati } parameter_string += 
"Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; + { + auto & apg = gen_params.sample_params.guidance.apg; + auto & def_apg = defaults.sample_params.guidance.apg; + if (apg.eta != def_apg.eta) { + parameter_string += "APG eta: " + std::to_string(apg.eta) + ", "; + } + if (apg.momentum != def_apg.momentum) { + parameter_string += "APG momentum: " + std::to_string(apg.momentum) + ", "; + } + if (apg.norm_threshold != def_apg.norm_threshold) { + parameter_string += "APG norm threshold: " + std::to_string(apg.norm_threshold) + ", "; + if (apg.norm_threshold > 0 && apg.norm_threshold_smoothing != def_apg.norm_threshold_smoothing && apg.norm_threshold_smoothing > 0) { + parameter_string += "APG norm threshold smoothing: " + std::to_string(apg.norm_threshold_smoothing) + ", "; + } + } + } if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) { parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; parameter_string += "Skip layers: ["; diff --git a/examples/common/media_io.cpp b/examples/common/media_io.cpp index 0b8b3a27b..e2e1ca5a3 100644 --- a/examples/common/media_io.cpp +++ b/examples/common/media_io.cpp @@ -95,6 +95,57 @@ using WebPMuxPtr = std::unique_ptr; using WebPAnimEncoderPtr = std::unique_ptr; #endif +#ifdef SD_USE_WEBM +class MemoryMkvWriter : public mkvmuxer::IMkvWriter { +public: + mkvmuxer::int32 Write(const void* buf, mkvmuxer::uint32 len) override { + if (buf == nullptr && len > 0) { + return -1; + } + const size_t end_pos = position_ + static_cast(len); + if (end_pos > data_.size()) { + data_.resize(end_pos); + } + if (len > 0) { + memcpy(data_.data() + position_, buf, len); + } + position_ = end_pos; + return 0; + } + + mkvmuxer::int64 Position() const override { + return static_cast(position_); + } + + mkvmuxer::int32 Position(mkvmuxer::int64 
position) override { + if (position < 0) { + return -1; + } + const size_t target = static_cast(position); + if (target > data_.size()) { + data_.resize(target); + } + position_ = target; + return 0; + } + + bool Seekable() const override { + return true; + } + + void ElementStartNotify(mkvmuxer::uint64, mkvmuxer::int64) override { + } + + const std::vector& data() const { + return data_; + } + +private: + std::vector data_; + size_t position_ = 0; +}; +#endif + bool read_binary_file_bytes(const char* path, std::vector& data) { std::ifstream fin(fs::path(path), std::ios::binary); if (!fin) { @@ -570,6 +621,32 @@ void write_u16_le(FILE* f, uint16_t val) { fwrite(&val, 2, 1, f); } +void write_u32_le(std::vector& data, uint32_t val) { + data.push_back(static_cast(val & 0xFF)); + data.push_back(static_cast((val >> 8) & 0xFF)); + data.push_back(static_cast((val >> 16) & 0xFF)); + data.push_back(static_cast((val >> 24) & 0xFF)); +} + +void write_u16_le(std::vector& data, uint16_t val) { + data.push_back(static_cast(val & 0xFF)); + data.push_back(static_cast((val >> 8) & 0xFF)); +} + +void patch_u32_le(std::vector& data, size_t offset, uint32_t val) { + if (offset + 4 > data.size()) { + return; + } + data[offset + 0] = static_cast(val & 0xFF); + data[offset + 1] = static_cast((val >> 8) & 0xFF); + data[offset + 2] = static_cast((val >> 16) & 0xFF); + data[offset + 3] = static_cast((val >> 24) & 0xFF); +} + +void write_fourcc(std::vector& data, const char* fourcc) { + data.insert(data.end(), fourcc, fourcc + 4); +} + EncodedImageFormat encoded_image_format_from_path(const std::string& path) { std::string ext = fs::path(path).extension().string(); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); @@ -699,95 +776,96 @@ uint8_t* load_image_from_memory(const char* image_bytes, return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel); } -int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* 
images, int num_images, int fps, int quality) { +std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); - return -1; + return {}; } - FilePtr file(fopen(filename, "wb")); - if (!file) { - perror("Error opening file for writing"); - return -1; - } - FILE* f = file.get(); - uint32_t width = images[0].width; uint32_t height = images[0].height; uint32_t channels = images[0].channel; if (channels != 3 && channels != 4) { fprintf(stderr, "Error: Unsupported channel count: %u\n", channels); - return -1; - } - - fwrite("RIFF", 4, 1, f); - long riff_size_pos = ftell(f); - write_u32_le(f, 0); - fwrite("AVI ", 4, 1, f); - - fwrite("LIST", 4, 1, f); - write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); - fwrite("hdrl", 4, 1, f); - - fwrite("avih", 4, 1, f); - write_u32_le(f, 56); - write_u32_le(f, 1000000 / fps); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0x110); - write_u32_le(f, num_images); - write_u32_le(f, 0); - write_u32_le(f, 1); - write_u32_le(f, width * height * 3); - write_u32_le(f, width); - write_u32_le(f, height); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - - fwrite("LIST", 4, 1, f); - write_u32_le(f, 4 + 8 + 56 + 8 + 40); - fwrite("strl", 4, 1, f); - - fwrite("strh", 4, 1, f); - write_u32_le(f, 56); - fwrite("vids", 4, 1, f); - fwrite("MJPG", 4, 1, f); - write_u32_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 1); - write_u32_le(f, fps); - write_u32_le(f, 0); - write_u32_le(f, num_images); - write_u32_le(f, width * height * 3); - write_u32_le(f, (uint32_t)-1); - write_u32_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - - fwrite("strf", 4, 1, f); - write_u32_le(f, 40); - write_u32_le(f, 40); - write_u32_le(f, width); - write_u32_le(f, height); - write_u16_le(f, 1); - 
write_u16_le(f, 24); - fwrite("MJPG", 4, 1, f); - write_u32_le(f, width * height * 3); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - - fwrite("LIST", 4, 1, f); - long movi_size_pos = ftell(f); - write_u32_le(f, 0); - fwrite("movi", 4, 1, f); + return {}; + } + + // stb_image_write changes JPEG sampling behavior above quality 90. + // MJPG AVI playback is more compatible when we keep the encoder on the + // <= 90 path. + const int mjpg_quality = std::clamp(quality, 1, 90); + + std::vector avi_data; + avi_data.reserve(static_cast(num_images) * 1024); + + write_fourcc(avi_data, "RIFF"); + const size_t riff_size_pos = avi_data.size(); + write_u32_le(avi_data, 0); + write_fourcc(avi_data, "AVI "); + + write_fourcc(avi_data, "LIST"); + write_u32_le(avi_data, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); + write_fourcc(avi_data, "hdrl"); + + write_fourcc(avi_data, "avih"); + write_u32_le(avi_data, 56); + write_u32_le(avi_data, 1000000 / fps); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0x110); + write_u32_le(avi_data, num_images); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 1); + write_u32_le(avi_data, width * height * 3); + write_u32_le(avi_data, width); + write_u32_le(avi_data, height); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + + write_fourcc(avi_data, "LIST"); + write_u32_le(avi_data, 4 + 8 + 56 + 8 + 40); + write_fourcc(avi_data, "strl"); + + write_fourcc(avi_data, "strh"); + write_u32_le(avi_data, 56); + write_fourcc(avi_data, "vids"); + write_fourcc(avi_data, "MJPG"); + write_u32_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 1); + write_u32_le(avi_data, fps); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, num_images); + write_u32_le(avi_data, width * height * 3); + write_u32_le(avi_data, static_cast(-1)); + 
write_u32_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + + write_fourcc(avi_data, "strf"); + write_u32_le(avi_data, 40); + write_u32_le(avi_data, 40); + write_u32_le(avi_data, width); + write_u32_le(avi_data, height); + write_u16_le(avi_data, 1); + write_u16_le(avi_data, 24); + write_fourcc(avi_data, "MJPG"); + write_u32_le(avi_data, width * height * 3); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + + write_fourcc(avi_data, "LIST"); + const size_t movi_size_pos = avi_data.size(); + write_u32_le(avi_data, 0); + write_fourcc(avi_data, "movi"); std::vector index(static_cast(num_images)); std::vector jpeg_data; @@ -801,55 +879,61 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int buffer->insert(buffer->end(), src, src + size); }; - if (!stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, quality)) { + if (!stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, mjpg_quality)) { fprintf(stderr, "Error: Failed to encode JPEG frame.\n"); - return -1; + return {}; } - fwrite("00dc", 4, 1, f); - write_u32_le(f, (uint32_t)jpeg_data.size()); - index[i].offset = ftell(f) - 8; - index[i].size = (uint32_t)jpeg_data.size(); - fwrite(jpeg_data.data(), 1, jpeg_data.size(), f); + index[i].offset = static_cast(avi_data.size()); + write_fourcc(avi_data, "00dc"); + write_u32_le(avi_data, static_cast(jpeg_data.size())); + index[i].size = (uint32_t)jpeg_data.size(); + avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end()); if (jpeg_data.size() % 2) { - fputc(0, f); + avi_data.push_back(0); } } - long cur_pos = ftell(f); - long movi_size = cur_pos - movi_size_pos - 4; - fseek(f, movi_size_pos, SEEK_SET); - write_u32_le(f, movi_size); - fseek(f, cur_pos, SEEK_SET); + const 
size_t movi_size = avi_data.size() - movi_size_pos - 4; + patch_u32_le(avi_data, movi_size_pos, static_cast(movi_size)); - fwrite("idx1", 4, 1, f); - write_u32_le(f, num_images * 16); + write_fourcc(avi_data, "idx1"); + write_u32_le(avi_data, num_images * 16); for (int i = 0; i < num_images; i++) { - fwrite("00dc", 4, 1, f); - write_u32_le(f, 0x10); - write_u32_le(f, index[i].offset); - write_u32_le(f, index[i].size); + write_fourcc(avi_data, "00dc"); + write_u32_le(avi_data, 0x10); + write_u32_le(avi_data, index[i].offset); + write_u32_le(avi_data, index[i].size); } - cur_pos = ftell(f); - long file_size = cur_pos - riff_size_pos - 4; - fseek(f, riff_size_pos, SEEK_SET); - write_u32_le(f, file_size); - fseek(f, cur_pos, SEEK_SET); + const size_t file_size = avi_data.size() - riff_size_pos - 4; + patch_u32_le(avi_data, riff_size_pos, static_cast(file_size)); + + return avi_data; +} +int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::vector avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality); + if (avi_data.empty()) { + return -1; + } + if (!write_binary_file_bytes(filename, avi_data)) { + perror("Error opening file for writing"); + return -1; + } return 0; } #ifdef SD_USE_WEBP -int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { +std::vector create_animated_webp_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); - return -1; + return {}; } if (fps <= 0) { fprintf(stderr, "Error: FPS must be positive.\n"); - return -1; + return {}; } const int width = static_cast(images[0].width); @@ -857,14 +941,14 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images const int channels = static_cast(images[0].channel); if (channels != 1 && channels != 3 && channels != 4) 
{ fprintf(stderr, "Error: Unsupported channel count: %d\n", channels); - return -1; + return {}; } WebPAnimEncoderOptions anim_options; WebPConfig config; if (!WebPAnimEncoderOptionsInit(&anim_options) || !WebPConfigInit(&config)) { fprintf(stderr, "Error: Failed to initialize WebP animation encoder.\n"); - return -1; + return {}; } config.quality = static_cast(quality); @@ -875,13 +959,13 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images } if (!WebPValidateConfig(&config)) { fprintf(stderr, "Error: Invalid WebP encoder configuration.\n"); - return -1; + return {}; } WebPAnimEncoderPtr enc(WebPAnimEncoderNew(width, height, &anim_options)); if (enc == nullptr) { fprintf(stderr, "Error: Could not create WebPAnimEncoder object.\n"); - return -1; + return {}; } const int frame_duration_ms = std::max(1, static_cast(std::lround(1000.0 / static_cast(fps)))); @@ -891,13 +975,13 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images const sd_image_t& image = images[i]; if (static_cast(image.width) != width || static_cast(image.height) != height) { fprintf(stderr, "Error: Frame dimensions do not match.\n"); - return -1; + return {}; } WebPPictureGuard picture; if (!picture.initialized) { fprintf(stderr, "Error: Failed to initialize WebPPicture.\n"); - return -1; + return {}; } picture.picture.use_argb = 1; picture.picture.width = width; @@ -921,12 +1005,12 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images if (!picture_ok) { fprintf(stderr, "Error: Failed to import frame into WebPPicture.\n"); - return -1; + return {}; } if (!WebPAnimEncoderAdd(enc.get(), &picture.picture, timestamp_ms, &config)) { fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc.get())); - return -1; + return {}; } timestamp_ms += frame_duration_ms; @@ -934,52 +1018,50 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images if 
(!WebPAnimEncoderAdd(enc.get(), nullptr, timestamp_ms, nullptr)) { fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc.get())); - return -1; + return {}; } WebPDataGuard webp_data; if (!WebPAnimEncoderAssemble(enc.get(), &webp_data.data)) { fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc.get())); - return -1; + return {}; } - FilePtr f(fopen(filename, "wb")); - if (!f) { - perror("Error opening file for writing"); + return std::vector(webp_data.data.bytes, webp_data.data.bytes + webp_data.data.size); +} + +int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::vector webp_data = create_animated_webp_from_sd_images_to_vector(images, num_images, fps, quality); + if (webp_data.empty()) { return -1; } - if (webp_data.data.size > 0 && fwrite(webp_data.data.bytes, 1, webp_data.data.size, f.get()) != webp_data.data.size) { - fprintf(stderr, "Error: Failed to write animated WebP file.\n"); + if (!write_binary_file_bytes(filename, webp_data)) { + perror("Error opening file for writing"); return -1; } - return 0; } #endif #ifdef SD_USE_WEBM -int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { +std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); - return -1; + return {}; } if (fps <= 0) { fprintf(stderr, "Error: FPS must be positive.\n"); - return -1; + return {}; } const int width = static_cast(images[0].width); const int height = static_cast(images[0].height); if (width <= 0 || height <= 0) { fprintf(stderr, "Error: Invalid frame dimensions.\n"); - return -1; + return {}; } - mkvmuxer::MkvWriter writer; - if (!writer.Open(filename)) { - fprintf(stderr, "Error: Could not open WebM file for writing.\n"); - return -1; - } + 
MemoryMkvWriter writer; const int ret = [&]() -> int { mkvmuxer::Segment segment; @@ -1045,30 +1127,63 @@ int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num } return 0; }(); - writer.Close(); - return ret; + if (ret != 0) { + return {}; + } + return writer.data(); +} + +int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::vector webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality); + if (webm_data.empty()) { + return -1; + } + if (!write_binary_file_bytes(filename, webm_data)) { + perror("Error opening file for writing"); + return -1; + } + return 0; } #endif -int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { - std::string path = filename ? filename : ""; - auto pos = path.find_last_of('.'); - std::string ext = pos == std::string::npos ? "" : path.substr(pos); - for (char& ch : ext) { - ch = static_cast(tolower(static_cast(ch))); +std::vector create_video_from_sd_images_to_vector(const std::string& output_format, + sd_image_t* images, + int num_images, + int fps, + int quality) { + std::string format = output_format; + std::transform(format.begin(), format.end(), format.begin(), + [](unsigned char c) { return static_cast(tolower(c)); }); + if (!format.empty() && format[0] == '.') { + format.erase(format.begin()); } #ifdef SD_USE_WEBM - if (ext == ".webm") { - return create_webm_from_sd_images(filename, images, num_images, fps, quality); + if (format == "webm") { + return create_webm_from_sd_images_to_vector(images, num_images, fps, quality); } #endif #ifdef SD_USE_WEBP - if (ext == ".webp") { - return create_animated_webp_from_sd_images(filename, images, num_images, fps, quality); + if (format == "webp") { + return create_animated_webp_from_sd_images_to_vector(images, num_images, fps, quality); } #endif - return create_mjpg_avi_from_sd_images(filename, images, num_images, fps, 
quality); + return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality); +} + +int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::string path = filename ? filename : ""; + auto pos = path.find_last_of('.'); + std::string ext = pos == std::string::npos ? "" : path.substr(pos); + std::vector video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality); + if (video_data.empty()) { + return -1; + } + if (!write_binary_file_bytes(filename, video_data)) { + perror("Error opening file for writing"); + return -1; + } + return 0; } diff --git a/examples/common/media_io.h b/examples/common/media_io.h index e6ca098d9..6b3f6f883 100644 --- a/examples/common/media_io.h +++ b/examples/common/media_io.h @@ -58,6 +58,10 @@ int create_mjpg_avi_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, + int num_images, + int fps, + int quality = 90); #ifdef SD_USE_WEBP int create_animated_webp_from_sd_images(const char* filename, @@ -65,6 +69,10 @@ int create_animated_webp_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_animated_webp_from_sd_images_to_vector(sd_image_t* images, + int num_images, + int fps, + int quality = 90); #endif #ifdef SD_USE_WEBM @@ -73,6 +81,10 @@ int create_webm_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, + int num_images, + int fps, + int quality = 90); #endif int create_video_from_sd_images(const char* filename, @@ -80,5 +92,10 @@ int create_video_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_video_from_sd_images_to_vector(const std::string& output_format, + sd_image_t* images, + int num_images, + int fps, + int quality = 90); 
#endif // __MEDIA_IO_H__ diff --git a/examples/server/README.md b/examples/server/README.md index e27d973fa..908b459ad 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -219,7 +219,7 @@ Default Generation Options: medium --skip-layer-start SLG enabling point (default: 0.01) --skip-layer-end SLG disabling point (default: 0.2) - --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a) --flow-shift shift value for Flow models like SD3.x or WAN (default: auto) --high-noise-cfg-scale (high noise) unconditional guidance scale: (default: 7.0) --high-noise-img-cfg-scale (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) @@ -227,7 +227,7 @@ Default Generation Options: --high-noise-slg-scale (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0) --high-noise-skip-layer-start (high noise) SLG enabling point (default: 0.01) --high-noise-skip-layer-end (high noise) SLG disabling point (default: 0.2) - --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a) --strength strength for noising/unnoising (default: 0.75) --pm-style-strength --control-strength strength to apply Control Net (default: 0.9). 
1.0 corresponds to full destruction of information in init image @@ -238,10 +238,10 @@ Default Generation Options: --disable-image-metadata do not embed generation metadata on image files -s, --seed RNG seed (default: 42, use random seed for < 0) --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, - tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a + tcd, res_multistep, res_2s, er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise) --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, - ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, + ddim_trailing, tcd, res_multistep, res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete diff --git a/examples/server/api.md b/examples/server/api.md index 8f8bf9edd..39744dbed 100644 --- a/examples/server/api.md +++ b/examples/server/api.md @@ -9,7 +9,7 @@ The server currently exposes three API families: - `sdcpp API` under `/sdcpp/v1/...` The `sdcpp API` is the native API surface. -Its request schema is also the canonical schema for `sd_cpp_extra_args`. +Its request schema is the same schema used by `sd_cpp_extra_args`. Global LoRA rule: @@ -55,8 +55,6 @@ Current endpoints include: - `POST /sdcpp/v1/jobs/{id}/cancel` - `POST /sdcpp/v1/vid_gen` -`POST /sdcpp/v1/vid_gen` is currently exposed but returns `501 Not Implemented`. - ## `sd_cpp_extra_args` `sd_cpp_extra_args` is an extension mechanism for the compatibility APIs. @@ -79,12 +77,12 @@ Behavior: - The JSON block is parsed using the same field rules as the `sdcpp API`. - The block is removed from the final prompt before generation. 
-Intended use: +Supported use: - extend `OpenAI API` requests with native `stable-diffusion.cpp` controls - extend `sdapi` requests with native `stable-diffusion.cpp` controls -Not intended use: +Unsupported use: - do not use `sd_cpp_extra_args` with `/sdcpp/v1/*` @@ -372,20 +370,25 @@ Field types: Returns frontend-friendly capability metadata. -Typical contents: +The mode-aware fields are the primary interface. The top-level compatibility fields are deprecated mirrors kept for older clients. -| Field | Type | -| --- | --- | -| `model` | `object` | -| `defaults` | `object` | -| `loras` | `array` | -| `samplers` | `array` | -| `schedulers` | `array` | -| `output_formats` | `array` | -| `limits` | `object` | -| `features` | `object` | +Top-level fields: -Nested fields currently returned: +| Field | Type | Notes | +| --- | --- | --- | +| `model` | `object` | Loaded model metadata | +| `current_mode` | `string` | The native generation mode mirrored by top-level compatibility fields | +| `supported_modes` | `array` | Supported native modes such as `img_gen` or `vid_gen` | +| `defaults` | `object` | Deprecated compatibility mirror of `defaults_by_mode[current_mode]` | +| `output_formats` | `array` | Deprecated compatibility mirror of `output_formats_by_mode[current_mode]` | +| `features` | `object` | Deprecated compatibility mirror of `features_by_mode[current_mode]` | +| `defaults_by_mode` | `object` | Explicit defaults for each supported mode | +| `output_formats_by_mode` | `object` | Explicit output formats for each supported mode | +| `features_by_mode` | `object` | Explicit feature flags for each supported mode | +| `samplers` | `array` | Available sampling methods | +| `schedulers` | `array` | Available schedulers | +| `loras` | `array` | Available LoRA entries | +| `limits` | `object` | Shared queue and size limits | `model` @@ -395,50 +398,24 @@ Nested fields currently returned: | `model.stem` | `string` | | `model.path` | `string` | -`defaults` +Compatibility 
rules: + +- `defaults`, `output_formats`, and `features` are deprecated compatibility mirrors +- those three top-level fields always mirror `current_mode` +- `supported_modes`, `defaults_by_mode`, `output_formats_by_mode`, and `features_by_mode` are the mode-aware fields + +Mode-aware objects: | Field | Type | | --- | --- | -| `defaults.prompt` | `string` | -| `defaults.negative_prompt` | `string` | -| `defaults.clip_skip` | `integer` | -| `defaults.width` | `integer` | -| `defaults.height` | `integer` | -| `defaults.strength` | `number` | -| `defaults.seed` | `integer` | -| `defaults.batch_count` | `integer` | -| `defaults.auto_resize_ref_image` | `boolean` | -| `defaults.increase_ref_index` | `boolean` | -| `defaults.control_strength` | `number` | -| `defaults.sample_params` | `object` | -| `defaults.sample_params.scheduler` | `string` | -| `defaults.sample_params.sample_method` | `string` | -| `defaults.sample_params.sample_steps` | `integer` | -| `defaults.sample_params.eta` | `number \| null` | -| `defaults.sample_params.shifted_timestep` | `integer` | -| `defaults.sample_params.flow_shift` | `number \| null` | -| `defaults.sample_params.guidance` | `object` | -| `defaults.sample_params.guidance.txt_cfg` | `number` | -| `defaults.sample_params.guidance.img_cfg` | `number \| null` | -| `defaults.sample_params.guidance.distilled_guidance` | `number` | -| `defaults.sample_params.guidance.slg` | `object` | -| `defaults.sample_params.guidance.slg.layers` | `array` | -| `defaults.sample_params.guidance.slg.layer_start` | `number` | -| `defaults.sample_params.guidance.slg.layer_end` | `number` | -| `defaults.sample_params.guidance.slg.scale` | `number` | -| `defaults.vae_tiling_params` | `object` | -| `defaults.vae_tiling_params.enabled` | `boolean` | -| `defaults.vae_tiling_params.tile_size_x` | `integer` | -| `defaults.vae_tiling_params.tile_size_y` | `integer` | -| `defaults.vae_tiling_params.target_overlap` | `number` | -| `defaults.vae_tiling_params.rel_size_x` 
| `number` | -| `defaults.vae_tiling_params.rel_size_y` | `number` | -| `defaults.cache_mode` | `string` | -| `defaults.cache_option` | `string` | -| `defaults.scm_mask` | `string` | -| `defaults.scm_policy_dynamic` | `boolean` | -| `defaults.output_format` | `string` | -| `defaults.output_compression` | `integer` | +| `defaults_by_mode.img_gen` | `object` | +| `defaults_by_mode.vid_gen` | `object` | +| `output_formats_by_mode.img_gen` | `array` | +| `output_formats_by_mode.vid_gen` | `array` | +| `features_by_mode.img_gen` | `object` | +| `features_by_mode.vid_gen` | `object` | + +Shared nested fields: `loras` @@ -458,19 +435,100 @@ Nested fields currently returned: | `limits.max_batch_count` | `integer` | | `limits.max_queue_size` | `integer` | -`features` +Shared default fields used by both `img_gen` and `vid_gen`: + +| Field | Type | +| --- | --- | +| `prompt` | `string` | +| `negative_prompt` | `string` | +| `clip_skip` | `integer` | +| `width` | `integer` | +| `height` | `integer` | +| `strength` | `number` | +| `seed` | `integer` | +| `sample_params` | `object` | +| `sample_params.scheduler` | `string` | +| `sample_params.sample_method` | `string` | +| `sample_params.sample_steps` | `integer` | +| `sample_params.eta` | `number \| null` | +| `sample_params.shifted_timestep` | `integer` | +| `sample_params.flow_shift` | `number \| null` | +| `sample_params.guidance.txt_cfg` | `number` | +| `sample_params.guidance.img_cfg` | `number \| null` | +| `sample_params.guidance.distilled_guidance` | `number` | +| `sample_params.guidance.slg.layers` | `array` | +| `sample_params.guidance.slg.layer_start` | `number` | +| `sample_params.guidance.slg.layer_end` | `number` | +| `sample_params.guidance.slg.scale` | `number` | +| `vae_tiling_params` | `object` | +| `vae_tiling_params.enabled` | `boolean` | +| `vae_tiling_params.tile_size_x` | `integer` | +| `vae_tiling_params.tile_size_y` | `integer` | +| `vae_tiling_params.target_overlap` | `number` | +| 
`vae_tiling_params.rel_size_x` | `number` | +| `vae_tiling_params.rel_size_y` | `number` | +| `cache_mode` | `string` | +| `cache_option` | `string` | +| `scm_mask` | `string` | +| `scm_policy_dynamic` | `boolean` | +| `output_format` | `string` | +| `output_compression` | `integer` | + +`img_gen`-specific default fields: + +| Field | Type | +| --- | --- | +| `batch_count` | `integer` | +| `auto_resize_ref_image` | `boolean` | +| `increase_ref_index` | `boolean` | +| `control_strength` | `number` | + +`vid_gen`-specific default fields: | Field | Type | | --- | --- | -| `features.init_image` | `boolean` | -| `features.mask_image` | `boolean` | -| `features.control_image` | `boolean` | -| `features.ref_images` | `boolean` | -| `features.lora` | `boolean` | -| `features.vae_tiling` | `boolean` | -| `features.cache` | `boolean` | -| `features.cancel_queued` | `boolean` | -| `features.cancel_generating` | `boolean` | +| `video_frames` | `integer` | +| `fps` | `integer` | +| `moe_boundary` | `number` | +| `vace_strength` | `number` | +| `high_noise_sample_params` | `object` | +| `high_noise_sample_params.scheduler` | `string` | +| `high_noise_sample_params.sample_method` | `string` | +| `high_noise_sample_params.sample_steps` | `integer` | +| `high_noise_sample_params.eta` | `number \| null` | +| `high_noise_sample_params.shifted_timestep` | `integer` | +| `high_noise_sample_params.flow_shift` | `number \| null` | +| `high_noise_sample_params.guidance.txt_cfg` | `number` | +| `high_noise_sample_params.guidance.img_cfg` | `number \| null` | +| `high_noise_sample_params.guidance.distilled_guidance` | `number` | +| `high_noise_sample_params.guidance.slg.layers` | `array` | +| `high_noise_sample_params.guidance.slg.layer_start` | `number` | +| `high_noise_sample_params.guidance.slg.layer_end` | `number` | +| `high_noise_sample_params.guidance.slg.scale` | `number` | + +Fields returned in `features_by_mode.img_gen`: + +- `init_image` +- `mask_image` +- `control_image` +- 
`ref_images` +- `lora` +- `vae_tiling` +- `cache` +- `cancel_queued` +- `cancel_generating` + +Fields returned in `features_by_mode.vid_gen`: + +- `init_image` +- `end_image` +- `control_frames` +- `high_noise_sample_params` +- `lora` +- `vae_tiling` +- `cache` +- `cancel_queued` +- `cancel_generating` #### `POST /sdcpp/v1/img_gen` @@ -521,9 +579,7 @@ Typical status codes: - `409 Conflict` - `410 Gone` -### Canonical Request Schema - -The `sdcpp API` request body is the canonical native schema. +### Request Body Example: @@ -612,7 +668,7 @@ Channel expectations: If omitted or null: - single-image fields map to an empty `sd_image_t` -- array fields map to `nullptr + count = 0` +- array fields map to an empty C-style array, represented as `pointer = nullptr` and `count = 0` ### Field Mapping Summary @@ -686,11 +742,11 @@ HTTP-only output fields: | `output_format` | `string` | | `output_compression` | `integer` | -### Optional Field Semantics +### Optional Field Handling -Clients should preserve unset semantics for optional sampling fields. +Optional sampling fields may be omitted. 
-If a user has not explicitly provided one of these fields, the client should omit it instead of injecting a guessed fallback: +When omitted, backend defaults apply to these fields: - `sample_params.scheduler` - `sample_params.sample_method` @@ -766,29 +822,394 @@ Example cancelled job: } ``` -### Validation and Retention +### Submission Errors -Recommended behavior: +`POST /sdcpp/v1/img_gen` may return: -- malformed JSON returns `400` -- invalid image payloads return `400` -- invalid parameter structure returns `400` -- queue full returns `429` or `503` -- accepted runtime failures transition the job to `failed` -- unsupported in-progress cancellation may return `409` +- `202 Accepted` when the job is created +- `400 Bad Request` for an empty body, unsupported model mode, invalid JSON, or invalid generation parameters +- `429 Too Many Requests` when the job queue is full +- `500 Internal Server Error` for unexpected server exceptions during submission -Recommended retention controls: +### `vid_gen` -- pending job limit -- completed job TTL -- failed job TTL +The following section documents the native async contract for video generation. -### Future `vid_gen` +#### `POST /sdcpp/v1/vid_gen` -Future `vid_gen` should reuse the same async job model: +Submits an async video generation job. -- `POST /sdcpp/v1/vid_gen` -- `GET /sdcpp/v1/jobs/{id}` -- `POST /sdcpp/v1/jobs/{id}/cancel` +Successful submission returns `202 Accepted`. 
+ +Example response: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "queued", + "created": 1775401200, + "poll_url": "/sdcpp/v1/jobs/job_01HTXYZVID" +} +``` + +Response fields: + +| Field | Type | +| --- | --- | +| `id` | `string` | +| `kind` | `string` | +| `status` | `string` | +| `created` | `integer` | +| `poll_url` | `string` | + +### Request Body + +Compared with `img_gen`, the `vid_gen` request body: + +- `vid_gen` is a single video sequence job, so `batch_count` is not part of the request schema +- `ref_images`, `mask_image`, `control_image`, `control_strength`, and `embed_image_metadata` are not part of the request schema +- `vid_gen` adds `end_image`, `control_frames`, `high_noise_sample_params`, `video_frames`, `fps`, `moe_boundary`, and `vace_strength` + +Example: + +```json +{ + "prompt": "a cat walking through a rainy alley", + "negative_prompt": "", + "clip_skip": -1, + "width": 832, + "height": 480, + "strength": 0.75, + "seed": -1, + "video_frames": 33, + "fps": 16, + "moe_boundary": 0.875, + "vace_strength": 1.0, + + "init_image": null, + "end_image": null, + "control_frames": [], + + "sample_params": { + "scheduler": "discrete", + "sample_method": "euler", + "sample_steps": 28, + "eta": 1.0, + "shifted_timestep": 0, + "custom_sigmas": [], + "flow_shift": 0.0, + "guidance": { + "txt_cfg": 7.0, + "img_cfg": 7.0, + "distilled_guidance": 3.5, + "slg": { + "layers": [7, 8, 9], + "layer_start": 0.01, + "layer_end": 0.2, + "scale": 0.0 + } + } + }, + + "high_noise_sample_params": { + "scheduler": "discrete", + "sample_method": "euler", + "sample_steps": -1, + "eta": 1.0, + "shifted_timestep": 0, + "flow_shift": 0.0, + "guidance": { + "txt_cfg": 7.0, + "img_cfg": 7.0, + "distilled_guidance": 3.5, + "slg": { + "layers": [7, 8, 9], + "layer_start": 0.01, + "layer_end": 0.2, + "scale": 0.0 + } + } + }, + + "lora": [], + + "vae_tiling_params": { + "enabled": false, + "tile_size_x": 0, + "tile_size_y": 0, + "target_overlap": 0.5, + 
"rel_size_x": 0.0, + "rel_size_y": 0.0 + }, + + "cache_mode": "disabled", + "cache_option": "", + "scm_mask": "", + "scm_policy_dynamic": true, + + "output_format": "webm", + "output_compression": 100 +} +``` + +### LoRA Rules + +- The server only accepts explicit LoRA entries from the `lora` field. +- Prompt-embedded `` tags are intentionally unsupported. +- `lora[].is_high_noise` controls whether a LoRA applies only to the high-noise stage. + +### Image and Frame Encoding Rules + +Any image field accepts: + +- a raw base64 string, or +- a data URL such as `data:image/png;base64,...` + +Channel expectations: + +- `init_image`: 3 channels +- `end_image`: 3 channels +- `control_frames[]`: 3 channels + +Frame ordering rules: + +- `control_frames[]` order is the conditioning frame order +- `control_frames[]` is preserved in request order + +If omitted or null: + +- single-image fields map to an empty `sd_image_t` +- array fields map to an empty C-style array, represented as `pointer = nullptr` and `count = 0` + +### Field Mapping Summary + +Top-level scalar fields: + +| Field | Type | +| --- | --- | +| `prompt` | `string` | +| `negative_prompt` | `string` | +| `clip_skip` | `integer` | +| `width` | `integer` | +| `height` | `integer` | +| `strength` | `number` | +| `seed` | `integer` | +| `video_frames` | `integer` | +| `fps` | `integer` | +| `moe_boundary` | `number` | +| `vace_strength` | `number` | + +Image and frame fields: + +| Field | Type | +| --- | --- | +| `init_image` | `string \| null` | +| `end_image` | `string \| null` | +| `control_frames` | `array` | + +LoRA fields: + +| Field | Type | +| --- | --- | +| `lora[].path` | `string` | +| `lora[].multiplier` | `number` | +| `lora[].is_high_noise` | `boolean` | + +Sampling fields: + +| Field | Type | +| --- | --- | +| `sample_params.scheduler` | `string` | +| `sample_params.sample_method` | `string` | +| `sample_params.sample_steps` | `integer` | +| `sample_params.eta` | `number` | +| 
`sample_params.shifted_timestep` | `integer` | +| `sample_params.custom_sigmas` | `array` | +| `sample_params.flow_shift` | `number` | +| `sample_params.guidance.txt_cfg` | `number` | +| `sample_params.guidance.img_cfg` | `number` | +| `sample_params.guidance.distilled_guidance` | `number` | +| `sample_params.guidance.slg.layers` | `array` | +| `sample_params.guidance.slg.layer_start` | `number` | +| `sample_params.guidance.slg.layer_end` | `number` | +| `sample_params.guidance.slg.scale` | `number` | + +High-noise sampling fields: + +| Field | Type | +| --- | --- | +| `high_noise_sample_params.scheduler` | `string` | +| `high_noise_sample_params.sample_method` | `string` | +| `high_noise_sample_params.sample_steps` | `integer` | +| `high_noise_sample_params.eta` | `number` | +| `high_noise_sample_params.shifted_timestep` | `integer` | +| `high_noise_sample_params.flow_shift` | `number` | +| `high_noise_sample_params.guidance.txt_cfg` | `number` | +| `high_noise_sample_params.guidance.img_cfg` | `number` | +| `high_noise_sample_params.guidance.distilled_guidance` | `number` | +| `high_noise_sample_params.guidance.slg.layers` | `array` | +| `high_noise_sample_params.guidance.slg.layer_start` | `number` | +| `high_noise_sample_params.guidance.slg.layer_end` | `number` | +| `high_noise_sample_params.guidance.slg.scale` | `number` | + +Other native fields: + +| Field | Type | +| --- | --- | +| `vae_tiling_params` | `object` | +| `cache_mode` | `string` | +| `cache_option` | `string` | +| `scm_mask` | `string` | +| `scm_policy_dynamic` | `boolean` | + +HTTP-only output fields: + +| Field | Type | +| --- | --- | +| `output_format` | `string` | +| `output_compression` | `integer` | + +For `vid_gen`, `output_format` and `output_compression` control container encoding. +`fps` is request metadata for the generated sequence and is echoed in the completed job result. 
+ +Allowed `output_format` values: + +- `webm` +- `webp` +- `avi` + +Output format behavior: + +- `output_format` defaults to `webm` +- `webp` means animated WebP +- `avi` means MJPG AVI +- `webm` requires the server to be built with WebM support; otherwise the request returns `400` + +### Result Payload + +Completed jobs return one encoded container payload, not a list of per-frame images. + +Result fields: + +- `result.b64_json` contains the whole encoded container file as base64 +- `result.mime_type` identifies the media type +- `result.output_format` echoes the selected container format +- `result.fps` echoes the effective playback FPS +- `result.frame_count` reports the actual decoded frame count used to build the container + +Expected MIME types: + +| `output_format` | `mime_type` | +| --- | --- | +| `webm` | `video/webm` | +| `webp` | `image/webp` | +| `avi` | `video/x-msvideo` | + +### Optional Field Handling + +Optional sampling fields may be omitted. + +When omitted, backend defaults apply to these fields: + +- `sample_params.scheduler` +- `sample_params.sample_method` +- `sample_params.eta` +- `sample_params.flow_shift` +- `sample_params.guidance.img_cfg` +- `high_noise_sample_params.scheduler` +- `high_noise_sample_params.sample_method` +- `high_noise_sample_params.eta` +- `high_noise_sample_params.flow_shift` +- `high_noise_sample_params.guidance.img_cfg` + +`high_noise_sample_params` may also be omitted entirely. + +### Frame Count Semantics + +`video_frames` is the requested target length, but the current core video path internally normalizes the effective frame count to the largest `4n + 1` value that does not exceed the requested count. + +Examples: + +- `video_frames = 33` stays `33` +- `video_frames = 34` becomes `33` +- `video_frames = 32` becomes `29` + +The completed job payload includes the actual decoded `frame_count`. 
+ +### Completion Result + +Example completed job: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "completed", + "created": 1775401200, + "started": 1775401203, + "completed": 1775401215, + "queue_position": 0, + "result": { + "output_format": "webm", + "mime_type": "video/webm", + "fps": 16, + "frame_count": 33, + "b64_json": "GkXfo59ChoEBQveBAULygQRC84EIQo..." + }, + "error": null +} +``` + +The response returns the encoded `.webm`, animated `.webp`, or `.avi` container payload directly. + +### Failure Result + +Example failed job: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "failed", + "created": 1775401200, + "started": 1775401203, + "completed": 1775401204, + "queue_position": 0, + "result": null, + "error": { + "code": "generation_failed", + "message": "generate_video returned no results" + } +} +``` + +### Cancelled Result + +Example cancelled job: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "cancelled", + "created": 1775401200, + "started": null, + "completed": 1775401202, + "queue_position": 0, + "result": null, + "error": { + "code": "cancelled", + "message": "job cancelled by client" + } +} +``` + +### Submission Errors + +`POST /sdcpp/v1/vid_gen` may return: -Its request body should mirror `sd_vid_gen_params_t` in the same way that `img_gen` mirrors `sd_img_gen_params_t`. 
+- `202 Accepted` when the job is created +- `400 Bad Request` for an empty body, unsupported model mode, invalid JSON, invalid generation parameters, or an unsupported output format +- `429 Too Many Requests` when the job queue is full +- `500 Internal Server Error` for unexpected server exceptions during submission diff --git a/examples/server/async_jobs.cpp b/examples/server/async_jobs.cpp index 39c47cfaa..e8e9d8ada 100644 --- a/examples/server/async_jobs.cpp +++ b/examples/server/async_jobs.cpp @@ -95,8 +95,12 @@ bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job) { job.status = AsyncJobStatus::Cancelled; job.completed_at = unix_timestamp_now(); job.result_images_b64.clear(); - job.error_code = "cancelled"; - job.error_message = "job cancelled by client"; + job.result_media_b64.clear(); + job.result_media_mime_type.clear(); + job.result_frame_count = 0; + job.result_fps = 0; + job.error_code = "cancelled"; + job.error_message = "job cancelled by client"; return true; } @@ -122,14 +126,24 @@ json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJo } if (job.status == AsyncJobStatus::Completed) { - json images = json::array(); - for (size_t i = 0; i < job.result_images_b64.size(); ++i) { - images.push_back({{"index", i}, {"b64_json", job.result_images_b64[i]}}); + if (job.kind == AsyncJobKind::VidGen) { + result["result"] = { + {"output_format", job.vid_gen.output_format}, + {"mime_type", job.result_media_mime_type}, + {"fps", job.result_fps}, + {"frame_count", job.result_frame_count}, + {"b64_json", job.result_media_b64}, + }; + } else { + json images = json::array(); + for (size_t i = 0; i < job.result_images_b64.size(); ++i) { + images.push_back({{"index", i}, {"b64_json", job.result_images_b64[i]}}); + } + result["result"] = { + {"output_format", job.img_gen.output_format}, + {"images", images}, + }; } - result["result"] = { - {"output_format", job.img_gen.output_format}, - {"images", images}, - }; result["error"] 
= nullptr; } else if (job.status == AsyncJobStatus::Failed || job.status == AsyncJobStatus::Cancelled) { @@ -156,16 +170,15 @@ bool execute_img_gen_job(ServerRuntime& runtime, sd_img_gen_params_t params = job.img_gen.to_sd_img_gen_params_t(); SDImageVec results; - int num_results = 0; { std::lock_guard lock(*runtime.sd_ctx_mutex); sd_image_t* raw_results = generate_image(runtime.sd_ctx, ¶ms); - num_results = params.batch_count; - results.adopt(raw_results, num_results); + results.adopt(raw_results, params.batch_count); } - if (results.empty() || num_results <= 0) { + const int num_results = results.count(); + if (num_results <= 0) { error_message = "generate_image returned no results"; return false; } @@ -208,6 +221,47 @@ bool execute_img_gen_job(ServerRuntime& runtime, return true; } +bool execute_vid_gen_job(ServerRuntime& runtime, + AsyncGenerationJob& job, + std::string& output_media_b64, + std::string& output_media_mime_type, + int& output_frame_count, + int& output_fps, + std::string& error_message) { + sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t(); + + SDImageVec results; + int num_results = 0; + + { + std::lock_guard lock(*runtime.sd_ctx_mutex); + sd_image_t* raw_results = generate_video(runtime.sd_ctx, ¶ms, &num_results); + results.adopt(raw_results, num_results); + } + + num_results = results.count(); + if (num_results <= 0) { + error_message = "generate_video returned no results"; + return false; + } + + std::vector video_bytes = create_video_from_sd_images_to_vector(job.vid_gen.output_format, + results.data(), + num_results, + job.vid_gen.gen_params.fps, + job.vid_gen.output_compression); + if (video_bytes.empty()) { + error_message = "failed to encode generated video container"; + return false; + } + + output_media_b64 = base64_encode(video_bytes); + output_media_mime_type = video_mime_type(job.vid_gen.output_format); + output_frame_count = num_results; + output_fps = job.vid_gen.gen_params.fps; + return true; +} + void 
async_job_worker(ServerRuntime& runtime) { AsyncJobManager& manager = *runtime.async_job_manager; @@ -240,11 +294,23 @@ void async_job_worker(ServerRuntime& runtime) { } std::vector output_images; + std::string output_media_b64; + std::string output_media_mime_type; + int output_frame_count = 0; + int output_fps = 0; std::string error_message; bool ok = false; if (job->kind == AsyncJobKind::ImgGen) { ok = execute_img_gen_job(runtime, *job, output_images, error_message); + } else if (job->kind == AsyncJobKind::VidGen) { + ok = execute_vid_gen_job(runtime, + *job, + output_media_b64, + output_media_mime_type, + output_frame_count, + output_fps, + error_message); } else { error_message = "unsupported job kind"; } @@ -258,8 +324,12 @@ void async_job_worker(ServerRuntime& runtime) { job->completed_at = unix_timestamp_now(); if (ok) { - job->status = AsyncJobStatus::Completed; - job->result_images_b64 = std::move(output_images); + job->status = AsyncJobStatus::Completed; + job->result_images_b64 = std::move(output_images); + job->result_media_b64 = std::move(output_media_b64); + job->result_media_mime_type = std::move(output_media_mime_type); + job->result_frame_count = output_frame_count; + job->result_fps = output_fps; job->error_code.clear(); job->error_message.clear(); } else { @@ -267,6 +337,10 @@ void async_job_worker(ServerRuntime& runtime) { job->error_code = "generation_failed"; job->error_message = error_message.empty() ? 
"unknown generation error" : error_message; job->result_images_b64.clear(); + job->result_media_b64.clear(); + job->result_media_mime_type.clear(); + job->result_frame_count = 0; + job->result_fps = 0; } purge_expired_jobs(manager); diff --git a/examples/server/async_jobs.h b/examples/server/async_jobs.h index cb90bdd8e..89997a3b4 100644 --- a/examples/server/async_jobs.h +++ b/examples/server/async_jobs.h @@ -36,7 +36,12 @@ struct AsyncGenerationJob { int64_t started_at = 0; int64_t completed_at = 0; ImgGenJobRequest img_gen; + VidGenJobRequest vid_gen; std::vector result_images_b64; + std::string result_media_b64; + std::string result_media_mime_type; + int result_frame_count = 0; + int result_fps = 0; std::string error_code; std::string error_message; }; @@ -63,4 +68,11 @@ bool execute_img_gen_job(ServerRuntime& runtime, AsyncGenerationJob& job, std::vector& output_images, std::string& error_message); +bool execute_vid_gen_job(ServerRuntime& runtime, + AsyncGenerationJob& job, + std::string& output_media_b64, + std::string& output_media_mime_type, + int& output_frame_count, + int& output_fps, + std::string& error_message); void async_job_worker(ServerRuntime& runtime); diff --git a/examples/server/frontend b/examples/server/frontend index 740475a7a..797ccf808 160000 --- a/examples/server/frontend +++ b/examples/server/frontend @@ -1 +1 @@ -Subproject commit 740475a7a6794dc07fb23e8ec5dc56e7e80aa8c1 +Subproject commit 797ccf80825cc035508ba9b599b2a21953e7f835 diff --git a/examples/server/routes_openai.cpp b/examples/server/routes_openai.cpp index af1210459..ce6215d1e 100644 --- a/examples/server/routes_openai.cpp +++ b/examples/server/routes_openai.cpp @@ -253,6 +253,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) { try { + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + 
res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } + ImgGenJobRequest request; std::string error_message; if (!build_openai_generation_request(req, *runtime, request, error_message)) { @@ -319,6 +325,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) { try { + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } + ImgGenJobRequest request; std::string error_message; if (!build_openai_edit_request(req, *runtime, request, error_message)) { diff --git a/examples/server/routes_sdapi.cpp b/examples/server/routes_sdapi.cpp index ca6661c0b..63c89ec8b 100644 --- a/examples/server/routes_sdapi.cpp +++ b/examples/server/routes_sdapi.cpp @@ -246,6 +246,11 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) { res.set_content(R"({"error":"empty body"})", "application/json"); return; } + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } json j = json::parse(req.body); ImgGenJobRequest request; diff --git a/examples/server/routes_sdcpp.cpp b/examples/server/routes_sdcpp.cpp index 930033bbd..8119136a4 100644 --- a/examples/server/routes_sdcpp.cpp +++ b/examples/server/routes_sdcpp.cpp @@ -75,18 +75,122 @@ static fs::path resolve_display_model_path(const ServerRuntime& runtime) { return {}; } +static json make_sample_params_json(const sd_sample_params_t& sample_params, const std::vector& skip_layers) { + const auto& guidance = sample_params.guidance; + return { + {"scheduler", capability_scheduler_name(sample_params.scheduler)}, + {"sample_method", 
capability_sample_method_name(sample_params.sample_method)}, + {"sample_steps", sample_params.sample_steps}, + {"eta", finite_number_or_null(sample_params.eta)}, + {"shifted_timestep", sample_params.shifted_timestep}, + {"flow_shift", finite_number_or_null(sample_params.flow_shift)}, + {"guidance", + { + {"txt_cfg", guidance.txt_cfg}, + {"img_cfg", finite_number_or_null(guidance.img_cfg)}, + {"distilled_guidance", guidance.distilled_guidance}, + {"slg", + { + {"layers", skip_layers}, + {"layer_start", guidance.slg.layer_start}, + {"layer_end", guidance.slg.layer_end}, + {"scale", guidance.slg.scale}, + }}, + }}, + }; +} + +static json make_img_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) { + return { + {"prompt", defaults.prompt}, + {"negative_prompt", defaults.negative_prompt}, + {"clip_skip", defaults.clip_skip}, + {"width", defaults.width > 0 ? defaults.width : 512}, + {"height", defaults.height > 0 ? defaults.height : 512}, + {"strength", defaults.strength}, + {"seed", defaults.seed}, + {"batch_count", defaults.batch_count}, + {"auto_resize_ref_image", defaults.auto_resize_ref_image}, + {"increase_ref_index", defaults.increase_ref_index}, + {"control_strength", defaults.control_strength}, + {"sample_params", make_sample_params_json(defaults.sample_params, defaults.skip_layers)}, + {"vae_tiling_params", make_vae_tiling_json(defaults.vae_tiling_params)}, + {"cache_mode", defaults.cache_mode}, + {"cache_option", defaults.cache_option}, + {"scm_mask", defaults.scm_mask}, + {"scm_policy_dynamic", defaults.scm_policy_dynamic}, + {"output_format", output_format}, + {"output_compression", 100}, + }; +} + +static json make_vid_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) { + return { + {"prompt", defaults.prompt}, + {"negative_prompt", defaults.negative_prompt}, + {"clip_skip", defaults.clip_skip}, + {"width", defaults.width > 0 ? 
defaults.width : 512}, + {"height", defaults.height > 0 ? defaults.height : 512}, + {"strength", defaults.strength}, + {"seed", defaults.seed}, + {"video_frames", defaults.video_frames}, + {"fps", defaults.fps}, + {"moe_boundary", defaults.moe_boundary}, + {"vace_strength", defaults.vace_strength}, + {"sample_params", make_sample_params_json(defaults.sample_params, defaults.skip_layers)}, + {"high_noise_sample_params", make_sample_params_json(defaults.high_noise_sample_params, defaults.high_noise_skip_layers)}, + {"vae_tiling_params", make_vae_tiling_json(defaults.vae_tiling_params)}, + {"cache_mode", defaults.cache_mode}, + {"cache_option", defaults.cache_option}, + {"scm_mask", defaults.scm_mask}, + {"scm_policy_dynamic", defaults.scm_policy_dynamic}, + {"output_format", output_format}, + {"output_compression", 100}, + }; +} + +static json make_img_gen_features_json() { + return { + {"init_image", true}, + {"mask_image", true}, + {"control_image", true}, + {"ref_images", true}, + {"lora", true}, + {"vae_tiling", true}, + {"cache", true}, + {"cancel_queued", true}, + {"cancel_generating", false}, + }; +} + +static json make_vid_gen_features_json() { + return { + {"init_image", true}, + {"end_image", true}, + {"control_frames", true}, + {"high_noise_sample_params", true}, + {"lora", true}, + {"vae_tiling", true}, + {"cache", true}, + {"cancel_queued", true}, + {"cancel_generating", false}, + }; +} + static json make_capabilities_json(ServerRuntime& runtime) { refresh_lora_cache(runtime); AsyncJobManager& manager = *runtime.async_job_manager; const auto& defaults = *runtime.default_gen_params; - const auto& sample_params = defaults.sample_params; - const auto& guidance = sample_params.guidance; const fs::path model_path = resolve_display_model_path(runtime); + const bool supports_img = runtime_supports_generation_mode(runtime, IMG_GEN); + const bool supports_vid = runtime_supports_generation_mode(runtime, VID_GEN); json samplers = json::array(); json schedulers = 
json::array(); - json output_formats = json::array({"png", "jpeg"}); + json image_output_formats = supported_img_output_formats(); + json video_output_formats = supported_vid_output_formats(); json available_loras = json::array(); + json supported_modes = json::array(); for (int i = 0; i < SAMPLE_METHOD_COUNT; ++i) { samplers.push_back(sd_sample_method_name((sample_method_t)i)); @@ -96,10 +200,6 @@ static json make_capabilities_json(ServerRuntime& runtime) { schedulers.push_back(sd_scheduler_name((scheduler_t)i)); } -#ifdef SD_USE_WEBP - output_formats.push_back("webp"); -#endif - { std::lock_guard lock(*runtime.lora_mutex); for (const auto& entry : *runtime.lora_cache) { @@ -110,77 +210,80 @@ static json make_capabilities_json(ServerRuntime& runtime) { } } + if (supports_img) { + supported_modes.push_back("img_gen"); + } + if (supports_vid) { + supported_modes.push_back("vid_gen"); + } + + std::string default_img_output_format = "png"; + std::string default_vid_output_format = "avi"; + if (!image_output_formats.empty()) { + default_img_output_format = image_output_formats[0].get(); + } + if (!video_output_formats.empty()) { + default_vid_output_format = video_output_formats[0].get(); + } + + json defaults_by_mode = json::object(); + json output_formats_by_mode = json::object(); + json features_by_mode = json::object(); + if (supports_img) { + defaults_by_mode["img_gen"] = make_img_gen_defaults_json(defaults, default_img_output_format); + output_formats_by_mode["img_gen"] = image_output_formats; + features_by_mode["img_gen"] = make_img_gen_features_json(); + } + if (supports_vid) { + defaults_by_mode["vid_gen"] = make_vid_gen_defaults_json(defaults, default_vid_output_format); + output_formats_by_mode["vid_gen"] = video_output_formats; + features_by_mode["vid_gen"] = make_vid_gen_features_json(); + } + + json top_level_defaults = json::object(); + json top_level_output_formats = json::array(); + json top_level_features = { + {"cancel_queued", true}, + 
{"cancel_generating", false}, + }; + std::string current_mode = ""; + if (supports_img) { + current_mode = "img_gen"; + top_level_defaults = defaults_by_mode["img_gen"]; + top_level_output_formats = output_formats_by_mode["img_gen"]; + top_level_features = features_by_mode["img_gen"]; + } else if (supports_vid) { + current_mode = "vid_gen"; + top_level_defaults = defaults_by_mode["vid_gen"]; + top_level_output_formats = output_formats_by_mode["vid_gen"]; + top_level_features = features_by_mode["vid_gen"]; + } + json result; result["model"] = { {"name", model_path.filename().u8string()}, {"stem", model_path.stem().u8string()}, {"path", model_path.u8string()}, }; - result["defaults"] = { - {"prompt", defaults.prompt}, - {"negative_prompt", defaults.negative_prompt}, - {"clip_skip", defaults.clip_skip}, - {"width", defaults.width > 0 ? defaults.width : 512}, - {"height", defaults.height > 0 ? defaults.height : 512}, - {"strength", defaults.strength}, - {"seed", defaults.seed}, - {"batch_count", defaults.batch_count}, - {"auto_resize_ref_image", defaults.auto_resize_ref_image}, - {"increase_ref_index", defaults.increase_ref_index}, - {"control_strength", defaults.control_strength}, - {"sample_params", - { - {"scheduler", capability_scheduler_name(sample_params.scheduler)}, - {"sample_method", capability_sample_method_name(sample_params.sample_method)}, - {"sample_steps", sample_params.sample_steps}, - {"eta", finite_number_or_null(sample_params.eta)}, - {"shifted_timestep", sample_params.shifted_timestep}, - {"flow_shift", finite_number_or_null(sample_params.flow_shift)}, - {"guidance", - { - {"txt_cfg", guidance.txt_cfg}, - {"img_cfg", finite_number_or_null(guidance.img_cfg)}, - {"distilled_guidance", guidance.distilled_guidance}, - {"slg", - { - {"layers", defaults.skip_layers}, - {"layer_start", guidance.slg.layer_start}, - {"layer_end", guidance.slg.layer_end}, - {"scale", guidance.slg.scale}, - }}, - }}, - }}, - {"vae_tiling_params", 
make_vae_tiling_json(defaults.vae_tiling_params)}, - {"cache_mode", defaults.cache_mode}, - {"cache_option", defaults.cache_option}, - {"scm_mask", defaults.scm_mask}, - {"scm_policy_dynamic", defaults.scm_policy_dynamic}, - {"output_format", "png"}, - {"output_compression", 100}, - }; - result["limits"] = { - {"min_width", 64}, - {"max_width", 4096}, - {"min_height", 64}, - {"max_height", 4096}, - {"max_batch_count", 8}, - {"max_queue_size", manager.max_pending_jobs}, + result["current_mode"] = current_mode; + result["supported_modes"] = supported_modes; + result["defaults"] = top_level_defaults; + result["defaults_by_mode"] = defaults_by_mode; + result["limits"] = { + {"min_width", 64}, + {"max_width", 4096}, + {"min_height", 64}, + {"max_height", 4096}, + {"max_batch_count", 8}, + {"max_queue_size", manager.max_pending_jobs}, }; - result["samplers"] = samplers; - result["schedulers"] = schedulers; - result["output_formats"] = output_formats; - result["features"] = { - {"init_image", true}, - {"mask_image", true}, - {"control_image", true}, - {"ref_images", true}, - {"lora", true}, - {"vae_tiling", true}, - {"cache", true}, - {"cancel_queued", true}, - {"cancel_generating", false}, - }; - result["loras"] = available_loras; + result["samplers"] = samplers; + result["schedulers"] = schedulers; + result["output_formats"] = top_level_output_formats; + result["output_formats_by_mode"] = output_formats_by_mode; + result["features"] = top_level_features; + result["features_by_mode"] = features_by_mode; + result["loras"] = available_loras; return result; } @@ -211,6 +314,33 @@ static bool parse_img_gen_request(const json& body, return true; } +static bool parse_vid_gen_request(const json& body, + ServerRuntime& runtime, + VidGenJobRequest& request, + std::string& error_message) { + request.gen_params = *runtime.default_gen_params; + + refresh_lora_cache(runtime); + if (!request.gen_params.from_json_str(body.dump(), [&](const std::string& path) { + return 
get_lora_full_path(runtime, path); + })) { + error_message = "invalid generation parameters"; + return false; + } + + std::string output_format = body.value("output_format", "webm"); + int output_compression = body.value("output_compression", 100); + if (!assign_output_options(request, output_format, output_compression, error_message)) { + return false; + } + // Intentionally disable prompt-embedded LoRA tag parsing for server APIs. + if (!request.gen_params.resolve_and_validate(VID_GEN, "", true)) { + error_message = "invalid generation parameters"; + return false; + } + return true; +} + void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { ServerRuntime* runtime = &rt; @@ -226,6 +356,11 @@ void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { res.set_content(R"({"error":"empty body"})", "application/json"); return; } + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } json body = json::parse(req.body); ImgGenJobRequest request; @@ -276,9 +411,66 @@ void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { } }); - svr.Post("/sdcpp/v1/vid_gen", [](const httplib::Request&, httplib::Response& res) { - res.status = 501; - res.set_content(R"({"error":"vid_gen is reserved and not implemented yet"})", "application/json"); + svr.Post("/sdcpp/v1/vid_gen", [runtime](const httplib::Request& req, httplib::Response& res) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + if (!runtime_supports_generation_mode(*runtime, VID_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(VID_GEN)}}).dump(), "application/json"); + return; + } + + json body = json::parse(req.body); + VidGenJobRequest request; + std::string error_message; + if 
(!parse_vid_gen_request(body, *runtime, request, error_message)) { + res.status = 400; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + AsyncJobManager& manager = *runtime->async_job_manager; + std::shared_ptr job = std::make_shared(); + job->kind = AsyncJobKind::VidGen; + job->status = AsyncJobStatus::Queued; + job->created_at = unix_timestamp_now(); + job->vid_gen = std::move(request); + + { + std::lock_guard lock(manager.mutex); + purge_expired_jobs(manager); + if (count_pending_jobs(manager) >= manager.max_pending_jobs) { + res.status = 429; + res.set_content(R"({"error":"job queue is full"})", "application/json"); + return; + } + job->id = make_async_job_id(manager); + manager.jobs[job->id] = job; + manager.queue.push_back(job->id); + } + + manager.cv.notify_one(); + + json out; + out["id"] = job->id; + out["kind"] = async_job_kind_name(job->kind); + out["status"] = async_job_status_name(job->status); + out["created"] = job->created_at; + out["poll_url"] = "/sdcpp/v1/jobs/" + job->id; + + res.status = 202; + res.set_content(out.dump(), "application/json"); + } catch (const json::parse_error& e) { + res.status = 400; + res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json"); + } catch (const std::exception& e) { + res.status = 500; + res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json"); + } }); svr.Get(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+))", [runtime](const httplib::Request& req, httplib::Response& res) { diff --git a/examples/server/runtime.cpp b/examples/server/runtime.cpp index c29799e3a..39880a182 100644 --- a/examples/server/runtime.cpp +++ b/examples/server/runtime.cpp @@ -45,6 +45,44 @@ std::string normalize_output_format(std::string output_format) { return output_format; } +std::vector supported_img_output_formats(bool allow_webp) { + std::vector formats = {"png", "jpeg"}; +#ifdef SD_USE_WEBP + if (allow_webp) { + 
formats.push_back("webp"); + } +#else + (void)allow_webp; +#endif + return formats; +} + +std::vector supported_vid_output_formats() { + std::vector formats; +#ifdef SD_USE_WEBM + formats.push_back("webm"); +#endif +#ifdef SD_USE_WEBP + formats.push_back("webp"); +#endif + formats.push_back("avi"); + return formats; +} + +static std::string valid_vid_output_formats_message() { + const std::vector formats = supported_vid_output_formats(); + + std::string message = "invalid output_format, must be one of ["; + for (size_t i = 0; i < formats.size(); ++i) { + if (i > 0) { + message += ", "; + } + message += formats[i]; + } + message += "]"; + return message; +} + bool assign_output_options(ImgGenJobRequest& request, std::string output_format, int output_compression, @@ -53,19 +91,88 @@ bool assign_output_options(ImgGenJobRequest& request, request.output_format = normalize_output_format(std::move(output_format)); request.output_compression = std::clamp(output_compression, 0, 100); - const bool valid_format = request.output_format == "png" || - request.output_format == "jpeg" || - (allow_webp && request.output_format == "webp"); + const std::vector valid_formats = supported_img_output_formats(allow_webp); + const bool valid_format = std::find(valid_formats.begin(), + valid_formats.end(), + request.output_format) != valid_formats.end(); if (!valid_format) { - error_message = allow_webp - ? 
"invalid output_format, must be one of [png, jpeg, webp]" - : "invalid output_format, must be one of [png, jpeg]"; + error_message = "invalid output_format, must be one of ["; + for (size_t i = 0; i < valid_formats.size(); ++i) { + if (i > 0) { + error_message += ", "; + } + error_message += valid_formats[i]; + } + error_message += "]"; return false; } return true; } +bool assign_output_options(VidGenJobRequest& request, + std::string output_format, + int output_compression, + std::string& error_message) { + request.output_format = normalize_output_format(std::move(output_format)); + request.output_compression = std::clamp(output_compression, 0, 100); + + if (request.output_format == "avi") { + return true; + } + + if (request.output_format == "webm") { +#ifdef SD_USE_WEBM + return true; +#else + error_message = valid_vid_output_formats_message(); + return false; +#endif + } + + if (request.output_format == "webp") { +#ifdef SD_USE_WEBP + return true; +#else + error_message = valid_vid_output_formats_message(); + return false; +#endif + } + + error_message = valid_vid_output_formats_message(); + return false; +} + +std::string video_mime_type(const std::string& output_format) { + if (output_format == "webm") { + return "video/webm"; + } + if (output_format == "webp") { + return "image/webp"; + } + return "video/x-msvideo"; +} + +bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode) { + if (mode == VID_GEN) { + return sd_ctx_supports_video_generation(runtime.sd_ctx); + } + if (mode == IMG_GEN) { + return sd_ctx_supports_image_generation(runtime.sd_ctx); + } + return true; +} + +std::string unsupported_generation_mode_error(SDMode mode) { + if (mode == VID_GEN) { + return "loaded model does not support vid_gen"; + } + if (mode == IMG_GEN) { + return "loaded model does not support img_gen"; + } + return "loaded model does not support requested mode"; +} + ArgOptions SDSvrParams::get_options() { ArgOptions options; diff --git 
a/examples/server/runtime.h b/examples/server/runtime.h index 65e932439..1970e7dbc 100644 --- a/examples/server/runtime.h +++ b/examples/server/runtime.h @@ -58,13 +58,32 @@ struct ImgGenJobRequest { } }; +struct VidGenJobRequest { + SDGenerationParams gen_params; + std::string output_format = "webm"; + int output_compression = 100; + + sd_vid_gen_params_t to_sd_vid_gen_params_t() { + return gen_params.to_sd_vid_gen_params_t(); + } +}; + std::string base64_encode(const std::vector& bytes); std::string normalize_output_format(std::string output_format); +std::vector supported_img_output_formats(bool allow_webp = true); +std::vector supported_vid_output_formats(); bool assign_output_options(ImgGenJobRequest& request, std::string output_format, int output_compression, bool allow_webp, std::string& error_message); +bool assign_output_options(VidGenJobRequest& request, + std::string output_format, + int output_compression, + std::string& error_message); +std::string video_mime_type(const std::string& output_format); +bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode); +std::string unsupported_generation_mode_error(SDMode mode); void refresh_lora_cache(ServerRuntime& rt); std::string get_lora_full_path(ServerRuntime& rt, const std::string& path); int64_t unix_timestamp_now(); diff --git a/format-code.sh b/format-code.sh index 5c30fb4ff..8aa422bca 100644 --- a/format-code.sh +++ b/format-code.sh @@ -1,5 +1,5 @@ for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \ - examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \ + src/model_io/*.h src/model_io/*.cpp examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \ examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do [[ "$f" == vocab* ]] && continue echo "formatting '$f'" diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index f093bb56c..57b700a48 100644 --- 
a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -50,6 +50,7 @@ enum sample_method_t { TCD_SAMPLE_METHOD, RES_MULTISTEP_SAMPLE_METHOD, RES_2S_SAMPLE_METHOD, + ER_SDE_SAMPLE_METHOD, SAMPLE_METHOD_COUNT }; @@ -211,12 +212,20 @@ typedef struct { uint8_t* data; } sd_image_t; +typedef struct { + float eta; + float momentum; + float norm_threshold; + float norm_threshold_smoothing; +} sd_apg_params_t; + typedef struct { int* layers; size_t layer_count; float layer_start; float layer_end; float scale; + bool uncond; } sd_slg_params_t; typedef struct { @@ -224,6 +233,7 @@ typedef struct { float img_cfg; float distilled_guidance; sd_slg_params_t slg; + sd_apg_params_t apg; } sd_guidance_params_t; typedef struct { @@ -347,6 +357,8 @@ SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); SD_API void sd_set_preview_callback(sd_preview_cb_t cb, enum preview_t mode, int interval, bool denoised, bool noisy, void* data); SD_API int32_t sd_get_num_physical_cores(); SD_API const char* sd_get_system_info(); +SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx); +SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx); SD_API const char* sd_type_name(enum sd_type_t type); SD_API enum sd_type_t str_to_sd_type(const char* str); diff --git a/src/auto_encoder_kl.hpp b/src/auto_encoder_kl.hpp index d4283959d..5cf09b883 100644 --- a/src/auto_encoder_kl.hpp +++ b/src/auto_encoder_kl.hpp @@ -533,7 +533,7 @@ class AutoEncoderKLModel : public GGMLBlock { const std::string& prefix = "") : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) { if (sd_version_is_dit(version)) { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { dd_config.z_channels = 32; embed_dim = 32; } else { @@ -578,7 +578,7 @@ class AutoEncoderKLModel : public GGMLBlock { ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version)) { + if 
(sd_version_uses_flux2_vae(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -617,7 +617,7 @@ class AutoEncoderKLModel : public GGMLBlock { auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] } - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p] @@ -640,7 +640,7 @@ class AutoEncoderKLModel : public GGMLBlock { int get_encoder_output_channels() { int factor = dd_config.double_z ? 2 : 1; - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { return dd_config.z_channels * 4; } return dd_config.z_channels * factor; @@ -673,7 +673,7 @@ struct AutoEncoderKL : public VAE { } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_uses_flux2_vae(version)) { scale_factor = 1.0f; shift_factor = 0.f; } @@ -747,7 +747,7 @@ struct AutoEncoderKL : public VAE { } sd::Tensor vae_output_to_latents(const sd::Tensor& vae_output, std::shared_ptr rng) override { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { return vae_output; } else if (version == VERSION_SD1_PIX2PIX) { return sd::ops::chunk(vae_output, 2, 2)[0]; @@ -758,7 +758,7 @@ struct AutoEncoderKL : public VAE { std::pair, sd::Tensor> get_latents_mean_std(const sd::Tensor& latents, int channel_dim) { GGML_ASSERT(channel_dim >= 0 && static_cast(channel_dim) < static_cast(latents.dim())); - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { GGML_ASSERT(latents.shape()[channel_dim] == 128); std::vector stats_shape(static_cast(latents.dim()), 1); stats_shape[static_cast(channel_dim)] = latents.shape()[channel_dim]; @@ -804,7 +804,7 @@ struct AutoEncoderKL : public VAE { } sd::Tensor 
diffusion_to_vae_latents(const sd::Tensor& latents) override { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { int channel_dim = 2; auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim); return (latents * std_tensor) / scale_factor + mean_tensor; @@ -813,7 +813,7 @@ struct AutoEncoderKL : public VAE { } sd::Tensor vae_to_diffusion_latents(const sd::Tensor& latents) override { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { int channel_dim = 2; auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim); return ((latents - mean_tensor) * scale_factor) / std_tensor; diff --git a/src/common_block.hpp b/src/common_block.hpp index 2cef389af..112a4d7a1 100644 --- a/src/common_block.hpp +++ b/src/common_block.hpp @@ -277,6 +277,7 @@ class CrossAttention : public GGMLBlock { int64_t context_dim; int64_t n_head; int64_t d_head; + bool xtra_dim = false; public: CrossAttention(int64_t query_dim, @@ -288,7 +289,11 @@ class CrossAttention : public GGMLBlock { query_dim(query_dim), context_dim(context_dim) { int64_t inner_dim = d_head * n_head; - + if (context_dim == 320 && d_head == 320) { + // LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09"); + xtra_dim = true; + context_dim = 1024; + } blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); blocks["to_k"] = std::shared_ptr(new Linear(context_dim, inner_dim, false)); blocks["to_v"] = std::shared_ptr(new Linear(context_dim, inner_dim, false)); @@ -313,10 +318,16 @@ class CrossAttention : public GGMLBlock { int64_t n_context = context->ne[1]; int64_t inner_dim = d_head * n_head; - auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim] + auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim] + if (xtra_dim) { + // LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09"); + context->ne[0] = 1024; // patch dim + } auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] 
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - + if (xtra_dim) { + context->ne[0] = 320; // reset dim to orig + } x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] diff --git a/src/conditioner.hpp b/src/conditioner.hpp index a39346cbf..9f4d45524 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -1621,10 +1621,12 @@ struct LLMEmbedder : public Conditioner { LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL; if (version == VERSION_FLUX2) { arch = LLM::LLMArch::MISTRAL_SMALL_3_2; + } else if (sd_version_is_ernie_image(version)) { + arch = LLM::LLMArch::MINISTRAL_3_3B; } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) { arch = LLM::LLMArch::QWEN3; } - if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) { + if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2 || arch == LLM::LLMArch::MINISTRAL_3_3B) { tokenizer = std::make_shared(); } else { tokenizer = std::make_shared(); @@ -1671,14 +1673,18 @@ struct LLMEmbedder : public Conditioner { size_t max_length = 100000000) { std::vector> parsed_attention; if (attn_range.first >= 0 && attn_range.second > 0) { - parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.first > 0) { + parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + } if (attn_range.second - attn_range.first > 0) { auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); parsed_attention.insert(parsed_attention.end(), new_parsed_attention.begin(), new_parsed_attention.end()); } - parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + if (attn_range.second < text.size()) { + parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + } } else { parsed_attention.emplace_back(text, 1.f); } @@ -1867,6 
+1873,13 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "[/INST]"; + } else if (sd_version_is_ernie_image(version)) { + prompt_template_encode_start_idx = 0; + out_layers = {25}; // -2 + + prompt_attn_range.first = 0; + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); } else if (sd_version_is_z_image(version)) { prompt_template_encode_start_idx = 0; out_layers = {35}; // -2 diff --git a/src/convert.cpp b/src/convert.cpp new file mode 100644 index 000000000..7cae8df0f --- /dev/null +++ b/src/convert.cpp @@ -0,0 +1,138 @@ +#include +#include +#include +#include + +#include "model.h" +#include "model_io/gguf_io.h" +#include "model_io/safetensors_io.h" +#include "util.h" + +#include "ggml-cpu.h" + +static ggml_type get_export_tensor_type(ModelLoader& model_loader, + const TensorStorage& tensor_storage, + ggml_type type, + const TensorTypeRules& tensor_type_rules) { + const std::string& name = tensor_storage.name; + ggml_type tensor_type = tensor_storage.type; + ggml_type dst_type = type; + + for (const auto& tensor_type_rule : tensor_type_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + + if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) { + tensor_type = dst_type; + } + + return tensor_type; +} + +static bool load_tensors_for_export(ModelLoader& model_loader, + ggml_context* ggml_ctx, + ggml_type type, + const TensorTypeRules& tensor_type_rules, + std::vector& tensors) { + std::mutex tensor_mutex; + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules); + + std::lock_guard lock(tensor_mutex); + ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, 
tensor_type, tensor_storage.n_dims, tensor_storage.ne); + if (tensor == nullptr) { + LOG_ERROR("ggml_new_tensor failed"); + return false; + } + ggml_set_name(tensor, name.c_str()); + + if (!tensor->data) { + GGML_ASSERT(ggml_nelements(tensor) == 0); + // Avoid crashing writers by setting a dummy pointer for zero-sized tensors. + LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); + tensor->data = ggml_get_mem_buffer(ggml_ctx); + } + + TensorWriteInfo write_info; + write_info.tensor = tensor; + write_info.n_dims = tensor_storage.n_dims; + for (int i = 0; i < tensor_storage.n_dims; ++i) { + write_info.ne[i] = tensor_storage.ne[i]; + } + + *dst_tensor = tensor; + tensors.push_back(std::move(write_info)); + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb); + LOG_INFO("load tensors done"); + return success; +} + +bool convert(const char* input_path, + const char* vae_path, + const char* output_path, + sd_type_t output_type, + const char* tensor_type_rules, + bool convert_name) { + ModelLoader model_loader; + + if (!model_loader.init_from_file(input_path)) { + LOG_ERROR("init model loader from file failed: '%s'", input_path); + return false; + } + + if (vae_path != nullptr && strlen(vae_path) > 0) { + if (!model_loader.init_from_file(vae_path, "vae.")) { + LOG_ERROR("init model loader from file failed: '%s'", vae_path); + return false; + } + } + if (convert_name) { + model_loader.convert_tensors_name(); + } + + ggml_type type = (ggml_type)output_type; + bool output_is_safetensors = ends_with(output_path, ".safetensors"); + TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules); + + auto backend = ggml_backend_cpu_init(); + size_t mem_size = 1 * 1024 * 1024; // for padding + mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead(); + mem_size += model_loader.get_params_mem_size(backend, type); + LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f); + 
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false}); + + if (ggml_ctx == nullptr) { + LOG_ERROR("ggml_init failed for converter"); + ggml_backend_free(backend); + return false; + } + + std::vector tensors; + bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors); + ggml_backend_free(backend); + + std::string error; + if (success) { + if (output_is_safetensors) { + success = write_safetensors_file(output_path, tensors, &error); + } else { + success = write_gguf_file(output_path, tensors, &error); + } + } + + if (!success && !error.empty()) { + LOG_ERROR("%s", error.c_str()); + } + + ggml_free(ggml_ctx); + return success; +} diff --git a/src/denoiser.hpp b/src/denoiser.hpp index c9c9d881d..a6e81d597 100644 --- a/src/denoiser.hpp +++ b/src/denoiser.hpp @@ -953,8 +953,9 @@ static sd::Tensor sample_dpmpp_2s_ancestral(denoise_cb_t model, float t_next = t_fn(sigma_down); float h = t_next - t; float s = t + 0.5f * h; - sd::Tensor x2 = (sigma_fn(s) / sigma_fn(t)) * x - (exp(-h * 0.5f) - 1) * denoised; - auto denoised2_opt = model(x2, sigmas[i + 1], i + 1); + float sigma_s = sigma_fn(s); + sd::Tensor x2 = (sigma_s / sigma_fn(t)) * x - (exp(-h * 0.5f) - 1) * denoised; + auto denoised2_opt = model(x2, sigma_s, i + 1); if (denoised2_opt.empty()) { return {}; } @@ -969,6 +970,100 @@ static sd::Tensor sample_dpmpp_2s_ancestral(denoise_cb_t model, return x; } +static sd::Tensor sample_dpmpp_2s_ancestral_flow(denoise_cb_t model, + sd::Tensor x, + const std::vector& sigmas, + std::shared_ptr rng, + float eta = 1.0f) { + int steps = static_cast(sigmas.size()) - 1; + for (int i = 0; i < steps; i++) { + float sigma = sigmas[i]; + float sigma_to = sigmas[i + 1]; + + bool opt_first_step = (1.0 - sigma < 1e-6); + + auto denoised_opt = model(x, sigma, (opt_first_step ? 
1 : -1) * (i + 1)); + if (denoised_opt.empty()) { + return {}; + } + sd::Tensor denoised = std::move(denoised_opt); + + if (sigma_to == 0.0f) { + // Euler method (final step, no noise) + // sigma_to == 0 --> sigma_down = 0, so: + // x + d * (sigma_down - sigma) + // = x + ((x - denoised) / sigma) * (sigma_down - sigma) + // = x + ((x - denoised) / sigma) * ( 0 - sigma) + // = x + ((x - denoised) ) * -1 + // = x -x + denoised + x = denoised; + + } else { + auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma, sigma_to, eta); + sd::Tensor D_i; + + if (opt_first_step) { + // the reformulated exp_s calc already accounts for this, but we can avoid + // a redundant model call for the typical sigma 1 at the first step: + // exp_s = sqrt((1-sigma)/sigma * (1-sigma_down)/sigma_down) + // = sqrt((1- 1)/ 1 * (1-sigma_down)/sigma_down) + // = 0 + // so sigma_s = 1 = sigma, and sigma_s_i_ratio = sigma_s / sigma = 1 + // u = (x*sigma_s_i_ratio)+(denoised*(1.0f-sigma_s_i_ratio)) + // = (x*1)+(denoised*0) = x + // so D_i = model(u, sigma_s, i + 1) + // = model(x, sigma, i + 1) + // = denoised + D_i = denoised; + + } else { + float sigma_s; + + // ref implementation would be: + // auto lambda_fn = [](float sigma) -> float { + // return std::log((1.0f - sigma) / sigma); }; + // auto sigma_fn = [](float lbda) -> float { + // return 1.0f / (std::exp(lbda) + 1.0f); }; + // t_i = lambda_fn(sigma); + // t_down = lambda_fn(sigma_down); + // float r = 0.5f; + // h = t_down - t_i; + // s = t_i + r * h; + // sigma_s = sigma_fn(s); + + // assuming r is constant, we sidestep the singularity at sigma -> 1 by: + // s = 0.5 * (lambda_fn(sigma) + lambda_fn(sigma_down)) + // = 0.5 * (log((1-sigma)/sigma) + log((1-sigma_down)/sigma_down)) + // = 0.5 * log(((1-sigma)/sigma) * ((1-sigma_down)/sigma_down)) + // = log(sqrt (((1-sigma)/sigma) * ((1-sigma_down)/sigma_down))) + // so exp(s) = sqrt((1-sigma)/sigma * (1-sigma_down)/sigma_down) + // and sigma_s = sigma_fn(s) = 1.0f / 
(exp(s) + 1.0f) + + float exp_s = std::sqrt(((1 - sigma) / sigma) * ((1 - sigma_down) / sigma_down)); + sigma_s = 1.0f / (exp_s + 1.0f); + + float sigma_s_i_ratio = sigma_s / sigma; + sd::Tensor u = (x * sigma_s_i_ratio) + (denoised * (1.0f - sigma_s_i_ratio)); + + auto denoised2_opt = model(u, sigma_s, i + 1); + if (denoised2_opt.empty()) { + return {}; + } + D_i = std::move(denoised2_opt); + } + + float sigma_down_i_ratio = sigma_down / sigma; + x = (x * sigma_down_i_ratio) + (D_i * (1.0f - sigma_down_i_ratio)); + + if (sigma_to > 0.0f && eta > 0.0f) { + x = alpha_scale * x + sd::Tensor::randn_like(x, rng) * sigma_up; + } + } + } + + return x; +} + static sd::Tensor sample_dpmpp_2m(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas) { @@ -1040,7 +1135,8 @@ static sd::Tensor sample_dpmpp_2m_v2(denoise_cb_t model, static sd::Tensor sample_lcm(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, - std::shared_ptr rng) { + std::shared_ptr rng, + bool is_flow_denoiser) { int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { auto denoised_opt = model(x, sigmas[i], i + 1); @@ -1049,6 +1145,9 @@ static sd::Tensor sample_lcm(denoise_cb_t model, } x = std::move(denoised_opt); if (sigmas[i + 1] > 0) { + if (is_flow_denoiser) { + x *= (1 - sigmas[i + 1]); + } x += sd::Tensor::randn_like(x, rng) * sigmas[i + 1]; } } @@ -1285,37 +1384,149 @@ static sd::Tensor sample_res_2s(denoise_cb_t model, return x; } +static sd::Tensor sample_er_sde(denoise_cb_t model, + sd::Tensor x, + std::vector sigmas, + std::shared_ptr rng, + bool is_flow_denoiser, + float eta) { + constexpr int max_stage = 3; + constexpr int num_integration_points = 200; + constexpr float num_integration_points_f = 200.0f; + float s_noise = eta; + + auto er_sde_flow_sigma = [](float sigma) -> float { + sigma = std::max(sigma, 1e-6f); + sigma = std::min(sigma, 1.0f - 1e-4f); + return sigma; + }; + + auto sigma_to_er_sde_lambda = [&](float sigma, bool is_flow_denoiser) -> float 
{ + if (is_flow_denoiser) { + sigma = er_sde_flow_sigma(sigma); + return sigma / std::max(1.0f - sigma, 1e-6f); + } + return std::max(sigma, 1e-6f); + }; + + auto sigma_to_er_sde_alpha = [&](float sigma, bool is_flow_denoiser) -> float { + if (is_flow_denoiser) { + sigma = er_sde_flow_sigma(sigma); + return 1.0f - sigma; + } + return 1.0f; + }; + + auto er_sde_noise_scaler = [](float x) -> float { + x = std::max(x, 0.0f); + return x * (std::exp(std::pow(x, 0.3f)) + 10.0f); + }; + + if (is_flow_denoiser) { + for (size_t i = 0; i + 1 < sigmas.size(); ++i) { + if (sigmas[i] > 1.0f) { + sigmas[i] = er_sde_flow_sigma(sigmas[i]); + } + } + } + + std::vector er_lambdas(sigmas.size(), 0.0f); + for (size_t i = 0; i < sigmas.size(); ++i) { + er_lambdas[i] = sigma_to_er_sde_lambda(sigmas[i], is_flow_denoiser); + } + + sd::Tensor old_denoised = x; + sd::Tensor old_denoised_d = x; + bool have_old_denoised = false; + bool have_old_denoised_d = false; + + int steps = static_cast(sigmas.size()) - 1; + for (int i = 0; i < steps; i++) { + sd::Tensor denoised = model(x, sigmas[i], i + 1); + if (denoised.empty()) { + return {}; + } + + int stage_used = std::min(max_stage, i + 1); + + if (sigmas[i + 1] == 0.0f) { + x = denoised; + } else { + float er_lambda_s = er_lambdas[i]; + float er_lambda_t = er_lambdas[i + 1]; + float alpha_s = sigma_to_er_sde_alpha(sigmas[i], is_flow_denoiser); + float alpha_t = sigma_to_er_sde_alpha(sigmas[i + 1], is_flow_denoiser); + float scaled_s = er_sde_noise_scaler(er_lambda_s); + float scaled_t = er_sde_noise_scaler(er_lambda_t); + float r_alpha = alpha_s > 0.0f ? alpha_t / alpha_s : 0.0f; + float r = scaled_s > 0.0f ? 
scaled_t / scaled_s : 0.0f; + + x = r_alpha * r * x + alpha_t * (1.0f - r) * denoised; + + if (stage_used >= 2 && have_old_denoised) { + float dt = er_lambda_t - er_lambda_s; + float lambda_step_size = -dt / num_integration_points_f; + float s = 0.0f; + float s_u = 0.0f; + + for (int p = 0; p < num_integration_points; ++p) { + float lambda_pos = er_lambda_t + p * lambda_step_size; + float scaled_pos = er_sde_noise_scaler(lambda_pos); + if (scaled_pos <= 0.0f) { + continue; + } + + s += 1.0f / scaled_pos; + if (stage_used >= 3 && have_old_denoised_d) { + s_u += (lambda_pos - er_lambda_s) / scaled_pos; + } + } + + s *= lambda_step_size; + + float denom_d = er_lambda_s - er_lambdas[i - 1]; + if (std::fabs(denom_d) > 1e-12f) { + float coeff_d = alpha_t * (dt + s * scaled_t); + sd::Tensor denoised_d = (denoised - old_denoised) / denom_d; + x += coeff_d * denoised_d; + + if (stage_used >= 3 && have_old_denoised_d) { + float denom_u = (er_lambda_s - er_lambdas[i - 2]) * 0.5f; + if (std::fabs(denom_u) > 1e-12f) { + s_u *= lambda_step_size; + float coeff_u = alpha_t * (0.5f * dt * dt + s_u * scaled_t); + sd::Tensor denoised_u = (denoised_d - old_denoised_d) / denom_u; + x += coeff_u * denoised_u; + } + } + + old_denoised_d = denoised_d; + have_old_denoised_d = true; + } + } + + float noise_scale_sq = er_lambda_t * er_lambda_t - er_lambda_s * er_lambda_s * r * r; + if (s_noise > 0.0f && noise_scale_sq > 0.0f) { + float noise_scale = alpha_t * std::sqrt(std::max(noise_scale_sq, 0.0f)); + x += sd::Tensor::randn_like(x, rng) * noise_scale; + } + } + + old_denoised = denoised; + have_old_denoised = true; + } + return x; +} + static sd::Tensor sample_ddim_trailing(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, std::shared_ptr rng, float eta) { - float beta_start = 0.00085f; - float beta_end = 0.0120f; - std::vector alphas_cumprod(TIMESTEPS); - std::vector compvis_sigmas(TIMESTEPS); - for (int i = 0; i < TIMESTEPS; i++) { - alphas_cumprod[i] = - (i == 0 ? 
1.0f : alphas_cumprod[i - 1]) * - (1.0f - - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), - 2)); - compvis_sigmas[i] = - std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); - } - int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - int timestep = static_cast(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1; - int prev_timestep = timestep - TIMESTEPS / steps; - float sigma = static_cast(compvis_sigmas[timestep]); - if (i == 0) { - x *= std::sqrt(sigma * sigma + 1) / sigma; - } else { - x *= std::sqrt(sigma * sigma + 1); - } + float sigma = sigmas[i]; + float sigma_to = sigmas[i + 1]; auto model_output_opt = model(x, sigma, i + 1); if (model_output_opt.empty()) { @@ -1324,8 +1535,8 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, sd::Tensor model_output = std::move(model_output_opt); model_output = (x - model_output) * (1.0f / sigma); - float alpha_prod_t = static_cast(alphas_cumprod[timestep]); - float alpha_prod_t_prev = static_cast(prev_timestep >= 0 ? 
alphas_cumprod[prev_timestep] : alphas_cumprod[0]); + float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f); + float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f); float beta_prod_t = 1.0f - alpha_prod_t; sd::Tensor pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) - @@ -1337,11 +1548,11 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, (1.0f - alpha_prod_t / alpha_prod_t_prev); float std_dev_t = eta * std::sqrt(variance); - x = std::sqrt(alpha_prod_t_prev) * pred_original_sample + - std::sqrt(1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) * model_output; + x = pred_original_sample + + std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output; if (eta > 0) { - x += std_dev_t * sd::Tensor::randn_like(x, rng); + x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor::randn_like(x, rng); } } return x; @@ -1368,19 +1579,26 @@ static sd::Tensor sample_tcd(denoise_cb_t model, std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); } - int original_steps = 50; - int steps = static_cast(sigmas.size()) - 1; + auto get_timestep_from_sigma = [&](float s) -> int { + auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s); + if (it == compvis_sigmas.begin()) + return 0; + if (it == compvis_sigmas.end()) + return TIMESTEPS - 1; + int idx_high = static_cast(std::distance(compvis_sigmas.begin(), it)); + int idx_low = idx_high - 1; + if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) { + return idx_high; + } + return idx_low; + }; + + int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps)); - int prev_timestep = i >= steps - 1 ? 
0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); + float sigma_to = sigmas[i + 1]; + int prev_timestep = get_timestep_from_sigma(sigma_to); int timestep_s = (int)floor((1 - eta) * prev_timestep); - float sigma = static_cast(compvis_sigmas[timestep]); - - if (i == 0) { - x *= std::sqrt(sigma * sigma + 1) / sigma; - } else { - x *= std::sqrt(sigma * sigma + 1); - } + float sigma = sigmas[i]; auto model_output_opt = model(x, sigma, i + 1); if (model_output_opt.empty()) { @@ -1389,9 +1607,9 @@ static sd::Tensor sample_tcd(denoise_cb_t model, sd::Tensor model_output = std::move(model_output_opt); model_output = (x - model_output) * (1.0f / sigma); - float alpha_prod_t = static_cast(alphas_cumprod[timestep]); + float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f); float beta_prod_t = 1.0f - alpha_prod_t; - float alpha_prod_t_prev = static_cast(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]); + float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f); float alpha_prod_s = static_cast(alphas_cumprod[timestep_s]); float beta_prod_s = 1.0f - alpha_prod_s; @@ -1399,12 +1617,12 @@ static sd::Tensor sample_tcd(denoise_cb_t model, std::sqrt(beta_prod_t) * model_output) * (1.0f / std::sqrt(alpha_prod_t)); - x = std::sqrt(alpha_prod_s) * pred_original_sample + - std::sqrt(beta_prod_s) * model_output; + x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * pred_original_sample + + std::sqrt(beta_prod_s / alpha_prod_t_prev) * model_output; - if (eta > 0 && i != steps - 1) { + if (eta > 0 && sigma_to > 0.0f) { x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x + - std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor::randn_like(x, rng); + std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor::randn_like(x, rng); } } return x; @@ -1431,13 +1649,16 @@ static sd::Tensor sample_k_diffusion(sample_method_t method, case DPM2_SAMPLE_METHOD: return sample_dpm2(model, std::move(x), 
sigmas); case DPMPP2S_A_SAMPLE_METHOD: - return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta); + if (is_flow_denoiser) + return sample_dpmpp_2s_ancestral_flow(model, std::move(x), sigmas, rng, eta); + else + return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta); case DPMPP2M_SAMPLE_METHOD: return sample_dpmpp_2m(model, std::move(x), sigmas); case DPMPP2Mv2_SAMPLE_METHOD: return sample_dpmpp_2m_v2(model, std::move(x), sigmas); case LCM_SAMPLE_METHOD: - return sample_lcm(model, std::move(x), sigmas, rng); + return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser); case IPNDM_SAMPLE_METHOD: return sample_ipndm(model, std::move(x), sigmas); case IPNDM_V_SAMPLE_METHOD: @@ -1446,6 +1667,8 @@ static sd::Tensor sample_k_diffusion(sample_method_t method, return sample_res_multistep(model, std::move(x), sigmas, rng, eta); case RES_2S_SAMPLE_METHOD: return sample_res_2s(model, std::move(x), sigmas, rng, eta); + case ER_SDE_SAMPLE_METHOD: + return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); case DDIM_TRAILING_SAMPLE_METHOD: return sample_ddim_trailing(model, std::move(x), sigmas, rng, eta); case TCD_SAMPLE_METHOD: diff --git a/src/diffusion_model.hpp b/src/diffusion_model.hpp index eb0debffc..c0a2a11c0 100644 --- a/src/diffusion_model.hpp +++ b/src/diffusion_model.hpp @@ -3,6 +3,7 @@ #include #include "anima.hpp" +#include "ernie_image.hpp" #include "flux.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" @@ -516,4 +517,66 @@ struct ZImageModel : public DiffusionModel { } }; +struct ErnieImageModel : public DiffusionModel { + std::string prefix; + ErnieImage::ErnieImageRunner ernie_image; + + ErnieImageModel(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model") + : prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) { + } + + std::string get_desc() override 
{ + return ernie_image.get_desc(); + } + + void alloc_params_buffer() override { + ernie_image.alloc_params_buffer(); + } + + void free_params_buffer() override { + ernie_image.free_params_buffer(); + } + + void free_compute_buffer() override { + ernie_image.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) override { + ernie_image.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() override { + return ernie_image.get_params_buffer_size(); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + ernie_image.set_weight_adapter(adapter); + } + + int64_t get_adm_in_channels() override { + return 768; + } + + void set_flash_attention_enabled(bool enabled) { + ernie_image.set_flash_attention_enabled(enabled); + } + + void set_circular_axes(bool circular_x, bool circular_y) override { + ernie_image.set_circular_axes(circular_x, circular_y); + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + return ernie_image.compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context)); + } +}; + #endif diff --git a/src/ernie_image.hpp b/src/ernie_image.hpp new file mode 100644 index 000000000..d17648d2b --- /dev/null +++ b/src/ernie_image.hpp @@ -0,0 +1,438 @@ +#ifndef __SD_ERNIE_IMAGE_HPP__ +#define __SD_ERNIE_IMAGE_HPP__ + +#include +#include + +#include "common_dit.hpp" +#include "flux.hpp" +#include "qwen_image.hpp" +#include "rope.hpp" + +namespace ErnieImage { + constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960; + + __STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx, + ggml_tensor* timesteps, + int dim, + int max_period = 10000) { + auto emb = ggml_ext_timestep_embedding(ctx, timesteps, dim, max_period, 1.0f); + int64_t half = dim / 2; + auto cos_part = ggml_view_2d(ctx, emb, half, emb->ne[1], 
emb->nb[1], 0); + auto sin_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], half * emb->nb[0]); + auto sin_first = ggml_concat(ctx, sin_part, cos_part, 0); + return sin_first; + } + + __STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) { + // x: [N, S, heads, head_dim] + // pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2]. + int64_t head_dim = x->ne[0]; + int64_t heads = x->ne[1]; + int64_t S = x->ne[2]; + int64_t N = x->ne[3]; + int64_t rot_dim = pe->ne[0]; + GGML_ASSERT(rot_dim <= head_dim); + GGML_ASSERT(rot_dim % 2 == 0); + GGML_ASSERT(pe->ne[1] == 1 && pe->ne[2] == S && pe->ne[3] == 2); + + x = ggml_cont(ctx, x); + auto x_rot = ggml_ext_slice(ctx, x, 0, 0, rot_dim, false); + auto x_pass = rot_dim < head_dim ? ggml_ext_slice(ctx, x, 0, rot_dim, head_dim, false) : nullptr; + + int64_t half = rot_dim / 2; + auto x1 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], 0); + auto x2 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], half * x_rot->nb[0]); + x1 = ggml_cont(ctx, x1); + x2 = ggml_cont(ctx, x2); + auto rotated = ggml_concat(ctx, ggml_neg(ctx, x2), x1, 0); + + auto cos_emb = ggml_ext_slice(ctx, pe, 3, 0, 1, false); + auto sin_emb = ggml_ext_slice(ctx, pe, 3, 1, 2, false); + + auto out = ggml_add(ctx, ggml_mul(ctx, x_rot, cos_emb), ggml_mul(ctx, rotated, sin_emb)); + if (x_pass != nullptr) { + out = ggml_concat(ctx, out, x_pass, 0); + } + return out; + } + + struct ErnieImageAttention : public GGMLBlock { + int64_t num_heads; + int64_t head_dim; + + ErnieImageAttention(int64_t query_dim, + int64_t heads, + int64_t dim_head, + float eps = 1e-6f) + : num_heads(heads), head_dim(dim_head) { + int64_t inner_dim = heads * dim_head; + blocks["to_q"] = std::make_shared(query_dim, inner_dim, false); + blocks["to_k"] = std::make_shared(query_dim, inner_dim, false); + blocks["to_v"] = std::make_shared(query_dim, inner_dim, 
false); + blocks["norm_q"] = std::make_shared(dim_head, eps); + blocks["norm_k"] = std::make_shared(dim_head, eps); + blocks["to_out.0"] = std::make_shared(inner_dim, query_dim, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pe, + ggml_tensor* attention_mask = nullptr) { + // x: [N, S, hidden_size] + // pe: [S, head_dim/2, 2, 2], generated in image-token-first order. + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); + auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); + auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); + auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); + auto to_out_0 = std::dynamic_pointer_cast(blocks["to_out.0"]); + + int64_t S = x->ne[1]; + int64_t N = x->ne[2]; + + auto q = to_q->forward(ctx, x); + auto k = to_k->forward(ctx, x); + auto v = to_v->forward(ctx, x); + + q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, S, N); // [N, S, heads, head_dim] + k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, S, N); // [N, S, heads, head_dim] + v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, S, N); // [N, S, heads, head_dim] + + q = norm_q->forward(ctx, q); + k = norm_k->forward(ctx, k); + + q = apply_rotary_emb(ctx->ggml_ctx, q, pe); + k = apply_rotary_emb(ctx->ggml_ctx, k, pe); + + q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, heads, S, head_dim] + q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); + + k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, heads, S, head_dim] + k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); + + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled); // [N, S, hidden_size] + x = to_out_0->forward(ctx, x); + return x; + } + }; + + struct ErnieImageFeedForward : public 
GGMLBlock { + public: + ErnieImageFeedForward(int64_t hidden_size, int64_t ffn_hidden_size) { + blocks["gate_proj"] = std::make_shared(hidden_size, ffn_hidden_size, false); + blocks["up_proj"] = std::make_shared(hidden_size, ffn_hidden_size, false); + blocks["linear_fc2"] = std::make_shared(ffn_hidden_size, hidden_size, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto gate_proj = std::dynamic_pointer_cast(blocks["gate_proj"]); + auto up_proj = std::dynamic_pointer_cast(blocks["up_proj"]); + auto linear_fc2 = std::dynamic_pointer_cast(blocks["linear_fc2"]); + + auto gate = gate_proj->forward(ctx, x); + gate = ggml_ext_gelu(ctx->ggml_ctx, gate); + x = up_proj->forward(ctx, x); + x = ggml_mul(ctx->ggml_ctx, x, gate); + x = linear_fc2->forward(ctx, x); + return x; + } + }; + + struct ErnieImageSharedAdaLNBlock : public GGMLBlock { + public: + ErnieImageSharedAdaLNBlock(int64_t hidden_size, + int64_t num_heads, + int64_t ffn_hidden_size, + float eps = 1e-6f) { + blocks["adaLN_sa_ln"] = std::make_shared(hidden_size, eps); + blocks["self_attention"] = std::make_shared(hidden_size, + num_heads, + hidden_size / num_heads, + eps); + blocks["adaLN_mlp_ln"] = std::make_shared(hidden_size, eps); + blocks["mlp"] = std::make_shared(hidden_size, ffn_hidden_size); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pe, + const std::vector& temb, + ggml_tensor* attention_mask = nullptr) { + // x: [N, image_tokens + text_tokens, hidden_size] + auto adaLN_sa_ln = std::dynamic_pointer_cast(blocks["adaLN_sa_ln"]); + auto self_attention = std::dynamic_pointer_cast(blocks["self_attention"]); + auto adaLN_mlp_ln = std::dynamic_pointer_cast(blocks["adaLN_mlp_ln"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + auto shift_msa = temb[0]; + auto scale_msa = temb[1]; + auto gate_msa = temb[2]; + auto shift_mlp = temb[3]; + auto scale_mlp = temb[4]; + auto gate_mlp = temb[5]; + + auto residual = x; + x = 
adaLN_sa_ln->forward(ctx, x); + x = Flux::modulate(ctx->ggml_ctx, x, shift_msa, scale_msa, true); + auto attn_out = self_attention->forward(ctx, x, pe, attention_mask); + x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + + residual = x; + x = adaLN_mlp_ln->forward(ctx, x); + x = Flux::modulate(ctx->ggml_ctx, x, shift_mlp, scale_mlp, true); + x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, mlp->forward(ctx, x), gate_mlp)); + return x; + } + }; + + struct ErnieImageAdaLNContinuous : public GGMLBlock { + public: + ErnieImageAdaLNContinuous(int64_t hidden_size, float eps = 1e-6f) { + blocks["norm"] = std::make_shared(hidden_size, eps, false); + blocks["linear"] = std::make_shared(hidden_size, hidden_size * 2, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* conditioning) { + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + + auto mods = ggml_ext_chunk(ctx->ggml_ctx, linear->forward(ctx, conditioning), 2, 0); + auto scale = mods[0]; + auto shift = mods[1]; + + x = norm->forward(ctx, x); + x = Flux::modulate(ctx->ggml_ctx, x, shift, scale); + return x; + } + }; + + struct ErnieImageParams { + int64_t hidden_size = 4096; + int64_t num_heads = 32; + int64_t num_layers = 36; + int64_t ffn_hidden_size = 12288; + int64_t in_channels = 128; + int64_t out_channels = 128; + int patch_size = 1; + int64_t text_in_dim = 3072; + int theta = 256; + std::vector axes_dim = {32, 48, 48}; + int axes_dim_sum = 128; + float eps = 1e-6f; + }; + + class ErnieImageModel : public GGMLBlock { + public: + ErnieImageParams params; + + ErnieImageModel() = default; + ErnieImageModel(ErnieImageParams params) + : params(params) { + blocks["x_embedder.proj"] = std::make_shared(params.in_channels, + params.hidden_size, + std::pair{params.patch_size, params.patch_size}, + std::pair{params.patch_size, params.patch_size}, + std::pair{0, 0}, + 
std::pair{1, 1}, + true); + if (params.text_in_dim != params.hidden_size) { + blocks["text_proj"] = std::make_shared(params.text_in_dim, params.hidden_size, false); + } + blocks["time_embedding"] = std::make_shared(params.hidden_size, params.hidden_size); + blocks["adaLN_modulation.1"] = std::make_shared(params.hidden_size, 6 * params.hidden_size, true); + + for (int i = 0; i < params.num_layers; i++) { + blocks["layers." + std::to_string(i)] = std::make_shared(params.hidden_size, + params.num_heads, + params.ffn_hidden_size, + params.eps); + } + + blocks["final_norm"] = std::make_shared(params.hidden_size, params.eps); + blocks["final_linear"] = std::make_shared(params.hidden_size, + params.patch_size * params.patch_size * params.out_channels, + true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep, + ggml_tensor* context, + ggml_tensor* pe) { + // x: [N, C, H, W] + // context: [N, text_tokens, 3072] + // pe: [image_tokens + text_tokens, head_dim/2, 2, 2] + GGML_ASSERT(context != nullptr); + GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t Hp = H / params.patch_size; + int64_t Wp = W / params.patch_size; + int64_t n_img = Hp * Wp; + int64_t N = x->ne[3]; + + auto x_embedder_proj = std::dynamic_pointer_cast(blocks["x_embedder.proj"]); + auto time_embedding = std::dynamic_pointer_cast(blocks["time_embedding"]); + auto adaLN_mod = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + auto final_norm = std::dynamic_pointer_cast(blocks["final_norm"]); + auto final_linear = std::dynamic_pointer_cast(blocks["final_linear"]); + + auto img = x_embedder_proj->forward(ctx, x); // [N, hidden_size, Hp, Wp] + img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], N); // [N, hidden_size, image_tokens] + img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, 
image_tokens, hidden_size] + + auto txt = context; + auto text_proj = std::dynamic_pointer_cast(blocks["text_proj"]); + if (text_proj) { + txt = text_proj->forward(ctx, txt); + } + + auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1); // [N, image_tokens + text_tokens, hidden_size] + + auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast(params.hidden_size)); + auto c = time_embedding->forward(ctx, sample); // [N, hidden_size] + + auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 6 * hidden_size] + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, mod_params, 6, 0); + std::vector temb; + temb.reserve(6); + for (auto chunk : chunks) { + temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1])); // [N, 1, hidden_size] + } + + for (int i = 0; i < params.num_layers; i++) { + auto layer = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + hidden_states = layer->forward(ctx, hidden_states, pe, temb); + } + + hidden_states = final_norm->forward(ctx, hidden_states, c); + hidden_states = final_linear->forward(ctx, hidden_states); // [N, image_tokens, p*p*out_channels] + auto patches = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, 0, n_img); // [N, image_tokens, hidden_size] + + auto out = DiT::unpatchify(ctx->ggml_ctx, + patches, + Hp, + Wp, + params.patch_size, + params.patch_size, + false); // [N, out_channels, H, W] + return out; + } + }; + + struct ErnieImageRunner : public GGMLRunner { + ErnieImageParams ernie_params; + ErnieImageModel ernie_image; + std::vector pe_vec; + + ErnieImageRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") + : GGMLRunner(backend, offload_params_to_cpu) { + ernie_params.num_layers = 0; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (ends_with(name, 
"x_embedder.proj.weight") && tensor_storage.n_dims == 4) { + ernie_params.patch_size = static_cast(tensor_storage.ne[0]); + ernie_params.in_channels = tensor_storage.ne[2]; + ernie_params.hidden_size = tensor_storage.ne[3]; + } else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) { + ernie_params.text_in_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) { + int64_t head_dim = tensor_storage.ne[0]; + ernie_params.num_heads = ernie_params.hidden_size / head_dim; + } else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) { + ernie_params.ffn_hidden_size = tensor_storage.ne[1]; + } else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) { + int64_t out_dim = tensor_storage.ne[1]; + ernie_params.out_channels = out_dim / ernie_params.patch_size / ernie_params.patch_size; + } + + size_t pos = name.find("layers."); + if (pos != std::string::npos) { + std::string layer_name = name.substr(pos); + auto items = split_string(layer_name, '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > ernie_params.num_layers) { + ernie_params.num_layers = block_index + 1; + } + } + } + } + if (ernie_params.num_layers == 0) { + ernie_params.num_layers = 36; + } + ernie_params.axes_dim_sum = 0; + for (int axis_dim : ernie_params.axes_dim) { + ernie_params.axes_dim_sum += axis_dim; + } + + LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64 + ", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64, + ernie_params.num_layers, + ernie_params.hidden_size, + ernie_params.num_heads, + ernie_params.ffn_hidden_size, + ernie_params.in_channels, + ernie_params.out_channels); + + ernie_image = ErnieImageModel(ernie_params); + ernie_image.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "ernie_image"; + } + + void 
get_param_tensors(std::map& tensors, const std::string prefix) { + ernie_image.get_param_tensors(tensors, prefix); + } + + ggml_cgraph* build_graph(const sd::Tensor& x_tensor, + const sd::Tensor& timesteps_tensor, + const sd::Tensor& context_tensor) { + ggml_cgraph* gf = new_graph_custom(ERNIE_IMAGE_GRAPH_SIZE); + ggml_tensor* x = make_input(x_tensor); + ggml_tensor* timesteps = make_input(timesteps_tensor); + GGML_ASSERT(x->ne[3] == 1); + GGML_ASSERT(!context_tensor.empty()); + ggml_tensor* context = make_input(context_tensor); + + pe_vec = Rope::gen_ernie_image_pe(static_cast(x->ne[1]), + static_cast(x->ne[0]), + ernie_params.patch_size, + static_cast(x->ne[3]), + static_cast(context->ne[1]), + ernie_params.theta, + circular_y_enabled, + circular_x_enabled, + ernie_params.axes_dim); + int pos_len = static_cast(pe_vec.size() / ernie_params.axes_dim_sum / 2); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2); + set_backend_tensor_data(pe, pe_vec.data()); + + auto runner_ctx = get_context(); + ggml_tensor* out = ernie_image.forward(&runner_ctx, x, timesteps, context, pe); + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x, timesteps, context); + }; + return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), x.dim()); + } + }; +} // namespace ErnieImage + +#endif // __SD_ERNIE_IMAGE_HPP__ diff --git a/src/llm.hpp b/src/llm.hpp index 9eacdb905..4afaa3ba6 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -28,6 +28,7 @@ namespace LLM { QWEN2_5_VL, QWEN3, MISTRAL_SMALL_3_2, + MINISTRAL_3_3B, ARCH_COUNT, }; @@ -35,6 +36,7 @@ namespace LLM { "qwen2.5vl", "qwen3", "mistral_small3.2", + "ministral3.3b", }; struct LLMVisionParams { @@ -419,6 +421,9 @@ namespace LLM { if (arch == LLMArch::MISTRAL_SMALL_3_2) { q 
= ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + } else if (arch == LLMArch::MINISTRAL_3_3B) { + q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); } else if (arch == LLMArch::QWEN3) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); @@ -634,7 +639,7 @@ namespace LLM { bool enable_vision_ = false) : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) { params.arch = arch; - if (arch == LLMArch::MISTRAL_SMALL_3_2) { + if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { params.head_dim = 128; params.num_heads = 32; params.num_kv_heads = 8; @@ -746,7 +751,7 @@ namespace LLM { } int64_t n_tokens = input_ids->ne[0]; - if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) { + if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) { input_pos_vec.resize(n_tokens); for (int i = 0; i < n_tokens; ++i) { input_pos_vec[i] = i; @@ -982,7 +987,7 @@ namespace LLM { const std::string prefix = "", bool enable_vision = false) : model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) { - if (arch == LLMArch::MISTRAL_SMALL_3_2) { + if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { tokenizer = std::make_shared(); } else { tokenizer = std::make_shared(); diff --git 
a/src/model.cpp b/src/model.cpp index 1639c161f..3479a0bea 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -12,8 +13,11 @@ #include #include -#include "gguf_reader.hpp" #include "model.h" +#include "model_io/gguf_io.h" +#include "model_io/safetensors_io.h" +#include "model_io/torch_legacy_io.h" +#include "model_io/torch_zip_io.h" #include "stable-diffusion.h" #include "util.h" @@ -21,6 +25,7 @@ #include "ggml-backend.h" #include "ggml-cpu.h" #include "ggml.h" +#include "zip.h" #include "name_conversion.h" #include "stable-diffusion.h" @@ -37,40 +42,6 @@ #include "ggml-opencl.h" #endif -#define ST_HEADER_SIZE_LEN 8 - -uint64_t read_u64(uint8_t* buffer) { - // little endian - uint64_t value = 0; - value |= static_cast(buffer[7]) << 56; - value |= static_cast(buffer[6]) << 48; - value |= static_cast(buffer[5]) << 40; - value |= static_cast(buffer[4]) << 32; - value |= static_cast(buffer[3]) << 24; - value |= static_cast(buffer[2]) << 16; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return value; -} - -int32_t read_int(uint8_t* buffer) { - // little endian - int value = 0; - value |= buffer[3] << 24; - value |= buffer[2] << 16; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - -uint16_t read_short(uint8_t* buffer) { - // little endian - uint16_t value = 0; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - /*================================================= Preprocess ==================================================*/ const char* unused_tensors[] = { @@ -110,7 +81,7 @@ const char* unused_tensors[] = { "first_stage_model.bn.", }; -bool is_unused_tensor(std::string name) { +bool is_unused_tensor(const std::string& name) { for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) { if (starts_with(name, unused_tensors[i])) { return true; @@ -250,79 +221,6 @@ void ModelLoader::add_tensor_storage(const 
TensorStorage& tensor_storage) { tensor_storage_map[tensor_storage.name] = tensor_storage; } -bool is_zip_file(const std::string& file_path) { - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - return false; - } - zip_close(zip); - return true; -} - -bool is_gguf_file(const std::string& file_path) { - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - return false; - } - - char magic[4]; - - file.read(magic, sizeof(magic)); - if (!file) { - return false; - } - for (uint32_t i = 0; i < sizeof(magic); i++) { - if (magic[i] != GGUF_MAGIC[i]) { - return false; - } - } - - return true; -} - -bool is_safetensors_file(const std::string& file_path) { - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - return false; - } - - // get file size - file.seekg(0, file.end); - size_t file_size_ = file.tellg(); - file.seekg(0, file.beg); - - // read header size - if (file_size_ <= ST_HEADER_SIZE_LEN) { - return false; - } - - uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; - file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); - if (!file) { - return false; - } - - size_t header_size_ = read_u64(header_size_buf); - if (header_size_ >= file_size_ || header_size_ <= 2) { - return false; - } - - // read header - std::vector header_buf; - header_buf.resize(header_size_ + 1); - header_buf[header_size_] = '\0'; - file.read(header_buf.data(), header_size_); - if (!file) { - return false; - } - try { - nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); - } catch (const std::exception&) { - return false; - } - return true; -} - bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) { if (is_directory(file_path)) { LOG_INFO("load %s using diffusers format", file_path.c_str()); @@ -333,9 +231,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string } else if (is_safetensors_file(file_path)) { LOG_INFO("load %s using safetensors format", 
file_path.c_str()); return init_from_safetensors_file(file_path, prefix); - } else if (is_zip_file(file_path)) { - LOG_INFO("load %s using checkpoint format", file_path.c_str()); - return init_from_ckpt_file(file_path, prefix); + } else if (is_torch_zip_file(file_path)) { + LOG_INFO("load %s using torch zip format", file_path.c_str()); + return init_from_torch_zip_file(file_path, prefix); + } else if (init_from_torch_legacy_file(file_path, prefix)) { + LOG_INFO("load %s using torch legacy format", file_path.c_str()); + return true; } else { if (file_exists(file_path)) { LOG_WARN("unknown format %s", file_path.c_str()); @@ -375,242 +276,121 @@ bool ModelLoader::init_from_file_and_convert_name(const std::string& file_path, bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s'", file_path.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - - gguf_context* ctx_gguf_ = nullptr; - ggml_context* ctx_meta_ = nullptr; - ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_}); - if (!ctx_gguf_) { - LOG_ERROR("failed to open '%s' with gguf_init_from_file. 
Try to open it with GGUFReader.", file_path.c_str()); - GGUFReader gguf_reader; - if (!gguf_reader.load(file_path)) { - LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str()); - return false; - } - - size_t data_offset = gguf_reader.data_offset(); - for (const auto& gguf_tensor_info : gguf_reader.tensors()) { - std::string name = gguf_tensor_info.name; - if (!starts_with(name, prefix)) { - name = prefix + name; - } - - TensorStorage tensor_storage( - name, - gguf_tensor_info.type, - gguf_tensor_info.shape.data(), - static_cast(gguf_tensor_info.shape.size()), - file_index, - data_offset + gguf_tensor_info.offset); - - // LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str()); - - add_tensor_storage(tensor_storage); - } - - return true; + std::vector tensor_storages; + std::string error; + if (!read_gguf_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; } - int n_tensors = static_cast(gguf_get_n_tensors(ctx_gguf_)); - - size_t total_size = 0; - size_t data_offset = gguf_get_data_offset(ctx_gguf_); - for (int i = 0; i < n_tensors; i++) { - std::string name = gguf_get_tensor_name(ctx_gguf_, i); - ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str()); - size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i); + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - // LOG_DEBUG("%s", name.c_str()); + for (auto& tensor_storage : tensor_storages) { + // LOG_DEBUG("%s", tensor_storage.name.c_str()); - if (!starts_with(name, prefix)) { - name = prefix + name; + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } - - TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset); - - GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes()); + tensor_storage.file_index = file_index; add_tensor_storage(tensor_storage); } - gguf_free(ctx_gguf_); - ggml_free(ctx_meta_); - 
return true; } /*================================================= SafeTensorsModelLoader ==================================================*/ -ggml_type str_to_ggml_type(const std::string& dtype) { - ggml_type ttype = GGML_TYPE_COUNT; - if (dtype == "F16") { - ttype = GGML_TYPE_F16; - } else if (dtype == "BF16") { - ttype = GGML_TYPE_BF16; - } else if (dtype == "F32") { - ttype = GGML_TYPE_F32; - } else if (dtype == "F64") { - ttype = GGML_TYPE_F32; - } else if (dtype == "F8_E4M3") { - ttype = GGML_TYPE_F16; - } else if (dtype == "F8_E5M2") { - ttype = GGML_TYPE_F16; - } else if (dtype == "I64") { - ttype = GGML_TYPE_I32; - } - return ttype; -} - -// https://huggingface.co/docs/safetensors/index bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - file_paths_.pop_back(); + + std::vector tensor_storages; + std::string error; + if (!read_safetensors_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); return false; } - // get file size - file.seekg(0, file.end); - size_t file_size_ = file.tellg(); - file.seekg(0, file.beg); + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - // read header size - if (file_size_ <= ST_HEADER_SIZE_LEN) { - LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } - uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; - file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); - if (!file) { - LOG_ERROR("read safetensors header size failed: '%s'", file_path.c_str()); - return false; - } + if 
(!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; + } + tensor_storage.file_index = file_index; - size_t header_size_ = read_u64(header_size_buf); - if (header_size_ >= file_size_) { - LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } + add_tensor_storage(tensor_storage); - // read header - std::vector header_buf; - header_buf.resize(header_size_ + 1); - header_buf[header_size_] = '\0'; - file.read(header_buf.data(), header_size_); - if (!file) { - LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; + // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); } - nlohmann::json header_; - try { - header_ = nlohmann::json::parse(header_buf.data()); - } catch (const std::exception&) { - LOG_ERROR("parsing safetensors header failed", file_path.c_str()); - file_paths_.pop_back(); - return false; - } + return true; +} - for (auto& item : header_.items()) { - std::string name = item.key(); - nlohmann::json tensor_info = item.value(); - // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str()); +/*================================================= TorchLegacyModelLoader ==================================================*/ - if (name == "__metadata__") { - continue; - } +bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from torch legacy '%s'", file_path.c_str()); - if (is_unused_tensor(name)) { - continue; + std::vector tensor_storages; + std::string error; + if (!read_torch_legacy_file(file_path, tensor_storages, &error)) { + if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) { + LOG_WARN("%s", error.c_str()); } + return false; + } - std::string dtype = tensor_info["dtype"]; - nlohmann::json shape = tensor_info["shape"]; + file_paths_.push_back(file_path); + size_t file_index = 
file_paths_.size() - 1; - if (dtype == "U8") { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { continue; } - size_t begin = tensor_info["data_offsets"][0].get(); - size_t end = tensor_info["data_offsets"][1].get(); - - ggml_type type = str_to_ggml_type(dtype); - if (type == GGML_TYPE_COUNT) { - LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str()); - return false; - } - - if (shape.size() > SD_MAX_DIMS) { - LOG_ERROR("invalid tensor '%s'", name.c_str()); - return false; + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } + tensor_storage.file_index = file_index; - int n_dims = (int)shape.size(); - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - for (int i = 0; i < n_dims; i++) { - ne[i] = shape[i].get(); - } + add_tensor_storage(tensor_storage); + } - if (n_dims == 5) { - n_dims = 4; - ne[0] = ne[0] * ne[1]; - ne[1] = ne[2]; - ne[2] = ne[3]; - ne[3] = ne[4]; - } + return true; +} - // ggml_n_dims returns 1 for scalars - if (n_dims == 0) { - n_dims = 1; - } +/*================================================= TorchZipModelLoader ==================================================*/ - if (!starts_with(name, prefix)) { - name = prefix + name; - } +bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from '%s'", file_path.c_str()); - TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); - tensor_storage.reverse_ne(); + std::vector tensor_storages; + std::string error; + if (!read_torch_zip_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } - size_t tensor_data_size = end - begin; + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - bool tensor_size_ok; - if (dtype == "F8_E4M3") { - tensor_storage.is_f8_e4m3 = true; - // f8 -> f16 - tensor_size_ok = 
(tensor_storage.nbytes() == tensor_data_size * 2); - } else if (dtype == "F8_E5M2") { - tensor_storage.is_f8_e5m2 = true; - // f8 -> f16 - tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); - } else if (dtype == "F64") { - tensor_storage.is_f64 = true; - // f64 -> f32 - tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); - } else if (dtype == "I64") { - tensor_storage.is_i64 = true; - // i64 -> i32 - tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); - } else { - tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size); - } - if (!tensor_size_ok) { - LOG_ERROR("size mismatch for tensor '%s' (%s)\n", name.c_str(), dtype.c_str()); - return false; + for (auto& tensor_storage : tensor_storages) { + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } + tensor_storage.file_index = file_index; add_tensor_storage(tensor_storage); - // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); + // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); } return true; @@ -642,367 +422,6 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s return true; } -/*================================================= CkptModelLoader ==================================================*/ - -// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 -// 0: \x80 PROTO 2 -// 2: } EMPTY_DICT -// 3: q BINPUT 0 -// 5: ( MARK -// 6: X BINUNICODE 'epoch' -// 16: q BINPUT 1 -// 18: K BININT1 6 -// 20: X BINUNICODE 'global_step' -// 36: q BINPUT 2 -// 38: J BININT 470000 -// 43: X BINUNICODE 'pytorch-lightning_version' -// 73: q BINPUT 3 -// 75: X BINUNICODE '1.4.2' -// 85: q BINPUT 4 -// 87: X BINUNICODE 'state_dict' -// 102: q BINPUT 5 -// 104: } EMPTY_DICT -// 105: q BINPUT 6 -// 107: ( MARK -// 108: X BINUNICODE 'betas' -// 118: q BINPUT 7 -// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' -// 153: q BINPUT 8 -// 155: ( MARK -// 156: ( MARK 
-// 157: X BINUNICODE 'storage' -// 169: q BINPUT 9 -// 171: c GLOBAL 'torch FloatStorage' -// 191: q BINPUT 10 -// 193: X BINUNICODE '0' -// 199: q BINPUT 11 -// 201: X BINUNICODE 'cpu' -// 209: q BINPUT 12 -// 211: M BININT2 1000 -// 214: t TUPLE (MARK at 156) -// 215: q BINPUT 13 -// 217: Q BINPERSID -// 218: K BININT1 0 -// 220: M BININT2 1000 -// ............................... -// 3201: q BINPUT 250 -// 3203: R REDUCE -// 3204: q BINPUT 251 -// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' -// 3264: q BINPUT 252 -// 3266: h BINGET 8 -// 3268: ( MARK -// 3269: ( MARK -// 3270: h BINGET 9 -// 3272: h BINGET 10 -// 3274: X BINUNICODE '30' -// 3281: q BINPUT 253 -// 3283: h BINGET 12 -// 3285: J BININT 102400 -// 3290: t TUPLE (MARK at 3269) -// 3291: q BINPUT 254 -// 3293: Q BINPERSID -// 3294: K BININT1 0 -// 3296: ( MARK -// 3297: M BININT2 320 -// 3300: M BININT2 320 -// 3303: K BININT1 1 -// 3305: K BININT1 1 -// 3307: t TUPLE (MARK at 3296) -// 3308: q BINPUT 255 -// 3310: ( MARK -// 3311: M BININT2 320 -// 3314: K BININT1 1 -// 3316: K BININT1 1 -// 3318: K BININT1 1 -// 3320: t TUPLE (MARK at 3310) -// 3321: r LONG_BINPUT 256 -// 3326: \x89 NEWFALSE -// 3327: h BINGET 16 -// 3329: ) EMPTY_TUPLE -// 3330: R REDUCE -// 3331: r LONG_BINPUT 257 -// 3336: t TUPLE (MARK at 3268) -// 3337: r LONG_BINPUT 258 -// 3342: R REDUCE -// 3343: r LONG_BINPUT 259 -// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' -// 3404: r LONG_BINPUT 260 -// 3409: h BINGET 8 -// 3411: ( MARK -// 3412: ( MARK -// 3413: h BINGET 9 -// 3415: h BINGET 10 -// 3417: X BINUNICODE '31' - -struct PickleTensorReader { - enum ReadPhase { - READ_NAME, - READ_DATA, - CHECK_SIZE, - READ_DIMENS - }; - ReadPhase phase = READ_NAME; - size_t entry_size = 0; - int32_t nelements = 0; - - TensorStorage tensor_storage; - - static ggml_type global_type; // all pickle_tensors data type - static bool read_global_type; - - bool read_int_value(uint32_t 
value) { - if (phase == CHECK_SIZE) { - if (entry_size == value * ggml_type_size(tensor_storage.type)) { - nelements = value; - phase = READ_DIMENS; - return true; - } else { - phase = READ_NAME; - } - } else if (phase == READ_DIMENS) { - if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens - phase = READ_NAME; - tensor_storage.n_dims = 0; - } - if (nelements % value == 0) { - tensor_storage.ne[tensor_storage.n_dims] = value; - tensor_storage.n_dims++; - } - } - return false; - } - - void read_global(const std::string& str) { - if (str == "FloatStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F32; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F32; - } else if (str == "HalfStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F16; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F16; - } - } - - void read_string(const std::string& str, zip_t* zip, std::string dir) { - if (str == "storage") { - read_global_type = true; - } else if (str != "state_dict") { - if (phase == READ_DATA) { - std::string entry_name = dir + "data/" + std::string(str); - - size_t i, n = zip_entries_total(zip); - for (i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - if (name == entry_name) { - tensor_storage.index_in_zip = (int)i; - entry_size = zip_entry_size(zip); - zip_entry_close(zip); - break; - } - } - zip_entry_close(zip); - } - - phase = entry_size > 0 ? 
CHECK_SIZE : READ_NAME; - } - if (!read_global_type && phase == READ_NAME) { - tensor_storage.name = str; - phase = READ_DATA; - tensor_storage.type = global_type; - } - } - } -}; - -ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type -bool PickleTensorReader::read_global_type = false; - -int find_char(uint8_t* buffer, int len, char c) { - for (int pos = 0; pos < len; pos++) { - if (buffer[pos] == c) { - return pos; - } - } - return -1; -} - -#define MAX_STRING_BUFFER 512 - -bool ModelLoader::parse_data_pkl(uint8_t* buffer, - size_t buffer_size, - zip_t* zip, - std::string dir, - size_t file_index, - const std::string prefix) { - uint8_t* buffer_end = buffer + buffer_size; - if (buffer[0] == 0x80) { // proto - if (buffer[1] != 2) { - LOG_ERROR("Unsupported protocol\n"); - return false; - } - buffer += 2; // 0x80 and version - char string_buffer[MAX_STRING_BUFFER]; - bool finish = false; - PickleTensorReader reader; - // read pickle binary file - while (!finish && buffer < buffer_end) { - uint8_t opcode = *buffer; - buffer++; - // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 - // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 - switch (opcode) { - case '}': // EMPTY_DICT = b'}' # push empty dict - break; - case ']': // EMPTY_LIST = b']' # push empty list - break; - // skip unused sections - case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg - case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg - case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack - buffer++; - break; - case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg - buffer += 4; - break; - case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame - buffer += 8; - break; - case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo - break; - case '(': // MARK = b'(' # push special markobject on stack - break; - case 'K': // BININT1 = b'K' # push 1-byte unsigned int - { - uint8_t value = *buffer; 
- if (reader.read_int_value(value)) { - buffer++; - } - buffer++; - } break; - case 'M': // BININT2 = b'M' # push 2-byte unsigned int - { - uint16_t value = read_short(buffer); - if (reader.read_int_value(value)) { - buffer++; - } - buffer += 2; - } break; - case 'J': // BININT = b'J' # push four-byte signed int - { - const int32_t value = read_int(buffer); - if (reader.read_int_value(value)) { - buffer++; // skip tuple after read num_elements - } - buffer += 4; - } break; - case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument - { - const int32_t len = read_int(buffer); - buffer += 4; - memset(string_buffer, 0, MAX_STRING_BUFFER); - if (len > MAX_STRING_BUFFER) { - LOG_WARN("tensor name very large"); - } - memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1)); - buffer += len; - reader.read_string(string_buffer, zip, dir); - } break; - case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes - { - const int8_t len = *buffer; - buffer++; - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len; - // printf("String: '%s'\n", string_buffer); - } break; - case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args - { - int len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - buffer += len + 1; - len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len + 1; - reader.read_global(string_buffer); - } break; - case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items - case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top - case 't': // TUPLE = b't' # build tuple from topmost stack items - if (reader.phase == PickleTensorReader::READ_DIMENS) { - reader.tensor_storage.reverse_ne(); - reader.tensor_storage.file_index = file_index; - // if(strcmp(prefix.c_str(), "scarlett") == 0) - // printf(" ZIP got tensor %s 
\n ", reader.tensor_storage.name.c_str()); - std::string name = reader.tensor_storage.name; - if (!starts_with(name, prefix)) { - name = prefix + name; - } - reader.tensor_storage.name = name; - add_tensor_storage(reader.tensor_storage); - - // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); - // reset - reader = PickleTensorReader(); - } - break; - case '.': // STOP = b'.' # every pickle ends with STOP - finish = true; - break; - default: - break; - } - } - } - return true; -} - -bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) { - LOG_DEBUG("init from '%s'", file_path.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - return false; - } - int n = (int)zip_entries_total(zip); - for (int i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - size_t pos = name.find("data.pkl"); - if (pos != std::string::npos) { - std::string dir = name.substr(0, pos); - printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); - void* pkl_data = nullptr; - size_t pkl_size; - zip_entry_read(zip, &pkl_data, &pkl_size); - - // LOG_DEBUG("%lld", pkl_size); - - parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, file_index, prefix); - - free(pkl_data); - } - } - zip_entry_close(zip); - } - zip_close(zip); - return true; -} - SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; @@ -1019,64 +438,66 @@ SDVersion ModelLoader::get_sd_version() { bool has_middle_block_1 = false; bool has_output_block_311 = false; bool has_output_block_71 = false; + bool has_attn_1024 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { - if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { - is_flux = true; - 
} - if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { - return VERSION_CHROMA_RADIANCE; - } - if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { - return VERSION_SD3; - } - if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { - return VERSION_QWEN_IMAGE; - } - if (tensor_storage.name.find("llm_adapter.blocks.0.cross_attn.q_proj.weight") != std::string::npos) { - return VERSION_ANIMA; - } - if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { - is_flux2 = true; - } - if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) { - has_single_block_47 = true; - } - if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) { - return VERSION_OVIS_IMAGE; - } - if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { - return VERSION_Z_IMAGE; - } - if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { - is_wan = true; - } - if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) { - patch_embedding_channels = tensor_storage.ne[3]; - } - if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) { - has_img_emb = true; - } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || - tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { - is_unet = true; - if (has_multiple_encoders) { - is_xl = true; - } - } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || - tensor_storage.name.find("cond_stage_model.1") != std::string::npos || - tensor_storage.name.find("te.1") != std::string::npos) { - has_multiple_encoders = true; - if 
(is_unet) { - is_xl = true; - } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + } + if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { + return VERSION_CHROMA_RADIANCE; + } + if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { + return VERSION_SD3; + } + if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { + return VERSION_QWEN_IMAGE; + } + if (tensor_storage.name.find("llm_adapter.blocks.0.cross_attn.q_proj.weight") != std::string::npos) { + return VERSION_ANIMA; + } + if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { + is_flux2 = true; + } + if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) { + has_single_block_47 = true; + } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) { + return VERSION_OVIS_IMAGE; + } + if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { + return VERSION_Z_IMAGE; + } + if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) { + return VERSION_ERNIE_IMAGE; + } + if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { + is_wan = true; + } + if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) { + patch_embedding_channels = tensor_storage.ne[3]; + } + if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) { + has_img_emb = true; + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || + tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { + is_unet = true; + if 
(has_multiple_encoders) { + is_xl = true; } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { - return VERSION_SVD; + } + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || + tensor_storage.name.find("cond_stage_model.1") != std::string::npos || + tensor_storage.name.find("te.1") != std::string::npos) { + has_multiple_encoders = true; + if (is_unet) { + is_xl = true; } } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { + return VERSION_SVD; + } if (tensor_storage.name.find("model.diffusion_model.middle_block.1.") != std::string::npos || tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; @@ -1088,6 +509,10 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos || tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) { has_output_block_71 = true; + if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight") != std::string::npos) { + if (tensor_storage.ne[0] == 1024) + has_attn_1024 = true; + } } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || @@ -1161,7 +586,7 @@ SDVersion ModelLoader::get_sd_version() { } if (!has_middle_block_1) { if (!has_output_block_71) { - return VERSION_SDXS; + return VERSION_SDXS_512_DS; } return VERSION_SD1_TINY_UNET; } @@ -1171,7 +596,7 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SD2_INPAINT; } if (!has_middle_block_1) { - return VERSION_SD2_TINY_UNET; + return has_attn_1024 ? 
VERSION_SDXS_09 : VERSION_SD2_TINY_UNET; } return VERSION_SD2; } @@ -1262,8 +687,8 @@ std::map ModelLoader::get_vae_wtype_stat() { return wtype_stat; } -static std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { - std::vector> result; +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules) { + TensorTypeRules result; for (const auto& item : split_string(tensor_type_rules, ',')) { if (item.size() == 0) continue; @@ -1696,76 +1121,6 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage return false; } -bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) { - auto backend = ggml_backend_cpu_init(); - size_t mem_size = 1 * 1024 * 1024; // for padding - mem_size += tensor_storage_map.size() * ggml_tensor_overhead(); - mem_size += get_params_mem_size(backend, type); - LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f); - ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false}); - - gguf_context* gguf_ctx = gguf_init_empty(); - - auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str); - - std::mutex tensor_mutex; - auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { - const std::string& name = tensor_storage.name; - ggml_type tensor_type = tensor_storage.type; - ggml_type dst_type = type; - - for (const auto& tensor_type_rule : tensor_type_rules) { - std::regex pattern(tensor_type_rule.first); - if (std::regex_search(name, pattern)) { - dst_type = tensor_type_rule.second; - break; - } - } - - if (tensor_should_be_converted(tensor_storage, dst_type)) { - tensor_type = dst_type; - } - - std::lock_guard lock(tensor_mutex); - ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); - if (tensor == nullptr) { - LOG_ERROR("ggml_new_tensor failed"); - return false; - } - ggml_set_name(tensor, 
name.c_str()); - - // LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(), - // ggml_nbytes(tensor), ggml_type_name(tensor_type), - // tensor_storage.n_dims, - // tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3], - // tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); - - if (!tensor->data) { - GGML_ASSERT(ggml_nelements(tensor) == 0); - // avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors - LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); - tensor->data = ggml_get_mem_buffer(ggml_ctx); - } - - *dst_tensor = tensor; - - gguf_add_tensor(gguf_ctx, tensor); - - return true; - }; - - bool success = load_tensors(on_new_tensor_cb); - ggml_backend_free(backend); - LOG_INFO("load tensors done"); - LOG_INFO("trying to save tensors to %s", file_path.c_str()); - if (success) { - gguf_write_to_file(gguf_ctx, file_path.c_str(), false); - } - ggml_free(ggml_ctx); - gguf_free(gguf_ctx); - return success; -} - int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) { size_t alignment = 128; if (backend != nullptr) { @@ -1785,29 +1140,3 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) return mem_size; } - -bool convert(const char* input_path, - const char* vae_path, - const char* output_path, - sd_type_t output_type, - const char* tensor_type_rules, - bool convert_name) { - ModelLoader model_loader; - - if (!model_loader.init_from_file(input_path)) { - LOG_ERROR("init model loader from file failed: '%s'", input_path); - return false; - } - - if (vae_path != nullptr && strlen(vae_path) > 0) { - if (!model_loader.init_from_file(vae_path, "vae.")) { - LOG_ERROR("init model loader from file failed: '%s'", vae_path); - return false; - } - } - if (convert_name) { - model_loader.convert_tensors_name(); - } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, 
tensor_type_rules); - return success; -} diff --git a/src/model.h b/src/model.h index 3af35eb7e..65bc6c367 100644 --- a/src/model.h +++ b/src/model.h @@ -5,20 +5,13 @@ #include #include #include -#include #include -#include -#include #include #include "ggml-backend.h" #include "ggml.h" -#include "gguf.h" -#include "json.hpp" +#include "model_io/tensor_storage.h" #include "ordered_map.hpp" -#include "zip.h" - -#define SD_MAX_DIMS 5 enum SDVersion { VERSION_SD1, @@ -28,7 +21,8 @@ enum SDVersion { VERSION_SD2, VERSION_SD2_INPAINT, VERSION_SD2_TINY_UNET, - VERSION_SDXS, + VERSION_SDXS_512_DS, + VERSION_SDXS_09, VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, @@ -50,18 +44,19 @@ enum SDVersion { VERSION_FLUX2_KLEIN, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_ERNIE_IMAGE, VERSION_COUNT, }; static inline bool sd_version_is_sd1(SDVersion version) { - if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS_512_DS) { return true; } return false; } static inline bool sd_version_is_sd2(SDVersion version) { - if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) { + if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_09) { return true; } return false; @@ -137,6 +132,20 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_ernie_image(SDVersion version) { + if (version == VERSION_ERNIE_IMAGE) { + return true; + } + return false; +} + +static inline bool sd_version_uses_flux2_vae(SDVersion version) { + if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) { + return true; + } + return false; +} + static inline bool 
sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -155,7 +164,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_ernie_image(version)) { return true; } return false; @@ -178,116 +188,10 @@ enum PMVersion { PM_VERSION_2, }; -struct TensorStorage { - std::string name; - ggml_type type = GGML_TYPE_F32; - ggml_type expected_type = GGML_TYPE_COUNT; - bool is_f8_e4m3 = false; - bool is_f8_e5m2 = false; - bool is_f64 = false; - bool is_i64 = false; - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - int n_dims = 0; - - size_t file_index = 0; - int index_in_zip = -1; // >= means stored in a zip file - uint64_t offset = 0; // offset in file - - TensorStorage() = default; - - TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) - : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { - for (int i = 0; i < n_dims; i++) { - this->ne[i] = ne[i]; - } - } - - int64_t nelements() const { - int64_t n = 1; - for (int i = 0; i < SD_MAX_DIMS; i++) { - n *= ne[i]; - } - return n; - } - - int64_t nbytes() const { - return nelements() * ggml_type_size(type) / ggml_blck_size(type); - } - - int64_t nbytes_to_read() const { - if (is_f8_e4m3 || is_f8_e5m2) { - return nbytes() / 2; - } else if (is_f64 || is_i64) { - return nbytes() * 2; - } else { - return nbytes(); - } - } - - void unsqueeze() { - if (n_dims == 2) { - n_dims = 4; - ne[3] = ne[1]; - ne[2] = ne[0]; - ne[1] = 1; - ne[0] = 1; - } - } - - std::vector chunk(size_t n) { - std::vector chunks; - uint64_t chunk_size = nbytes_to_read() / n; - // printf("%d/%d\n", chunk_size, nbytes_to_read()); - reverse_ne(); - for (size_t i = 0; i < n; i++) { - TensorStorage chunk_i = *this; - 
chunk_i.ne[0] = ne[0] / n; - chunk_i.offset = offset + i * chunk_size; - chunk_i.reverse_ne(); - chunks.push_back(chunk_i); - } - reverse_ne(); - return chunks; - } - - void reverse_ne() { - int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - for (int i = 0; i < n_dims; i++) { - new_ne[i] = ne[n_dims - 1 - i]; - } - for (int i = 0; i < n_dims; i++) { - ne[i] = new_ne[i]; - } - } - - std::string to_string() const { - std::stringstream ss; - const char* type_name = ggml_type_name(type); - if (is_f8_e4m3) { - type_name = "f8_e4m3"; - } else if (is_f8_e5m2) { - type_name = "f8_e5m2"; - } else if (is_f64) { - type_name = "f64"; - } else if (is_i64) { - type_name = "i64"; - } - ss << name << " | " << type_name << " | "; - ss << n_dims << " ["; - for (int i = 0; i < SD_MAX_DIMS; i++) { - ss << ne[i]; - if (i != SD_MAX_DIMS - 1) { - ss << ", "; - } - } - ss << "]"; - return ss.str(); - } -}; - -typedef std::function on_new_tensor_cb_t; - typedef OrderedMap String2TensorStorage; +using TensorTypeRules = std::vector>; + +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); class ModelLoader { protected: @@ -297,16 +201,10 @@ class ModelLoader { void add_tensor_storage(const TensorStorage& tensor_storage); - bool parse_data_pkl(uint8_t* buffer, - size_t buffer_size, - zip_t* zip, - std::string dir, - size_t file_index, - const std::string prefix); - bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: @@ -336,7 +234,6 @@ class ModelLoader { return 
names; } - bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/src/model_io/binary_io.h b/src/model_io/binary_io.h new file mode 100644 index 000000000..9093eeaf9 --- /dev/null +++ b/src/model_io/binary_io.h @@ -0,0 +1,57 @@ +#ifndef __SD_MODEL_IO_BINARY_IO_H__ +#define __SD_MODEL_IO_BINARY_IO_H__ + +#include +#include + +namespace model_io { + + inline int32_t read_int(const uint8_t* buffer) { + uint32_t value = 0; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return static_cast(value); + } + + inline uint16_t read_short(const uint8_t* buffer) { + uint16_t value = 0; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline uint64_t read_u64(const uint8_t* buffer) { + uint64_t value = 0; + value |= static_cast(buffer[7]) << 56; + value |= static_cast(buffer[6]) << 48; + value |= static_cast(buffer[5]) << 40; + value |= static_cast(buffer[4]) << 32; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline void write_u64(std::ostream& stream, uint64_t value) { + uint8_t buffer[8]; + for (int i = 0; i < 8; ++i) { + buffer[i] = static_cast((value >> (8 * i)) & 0xFF); + } + stream.write((const char*)buffer, sizeof(buffer)); + } + + inline int find_char(const uint8_t* buffer, int len, char c) { + for (int pos = 0; pos < len; pos++) { + if (buffer[pos] == (uint8_t)c) { + return pos; + } + } + return -1; + } + +} // namespace model_io + +#endif // __SD_MODEL_IO_BINARY_IO_H__ diff --git a/src/model_io/gguf_io.cpp 
b/src/model_io/gguf_io.cpp new file mode 100644 index 000000000..378694d8e --- /dev/null +++ b/src/model_io/gguf_io.cpp @@ -0,0 +1,123 @@ +#include "gguf_io.h" + +#include +#include +#include +#include + +#include "gguf.h" +#include "gguf_reader_ext.h" +#include "util.h" + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_gguf_file(const std::string& file_path) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + char magic[4]; + + file.read(magic, sizeof(magic)); + if (!file) { + return false; + } + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + return false; + } + } + + return true; +} + +bool read_gguf_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + tensor_storages.clear(); + + gguf_context* ctx_gguf_ = nullptr; + ggml_context* ctx_meta_ = nullptr; + + ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_}); + if (!ctx_gguf_) { + GGUFReader gguf_reader; + if (!gguf_reader.load(file_path)) { + set_error(error, "failed to open '" + file_path + "' with GGUFReader"); + return false; + } + + size_t data_offset = gguf_reader.data_offset(); + for (const auto& gguf_tensor_info : gguf_reader.tensors()) { + TensorStorage tensor_storage( + gguf_tensor_info.name, + gguf_tensor_info.type, + gguf_tensor_info.shape.data(), + static_cast(gguf_tensor_info.shape.size()), + 0, + data_offset + gguf_tensor_info.offset); + + tensor_storages.push_back(tensor_storage); + } + + return true; + } + + int n_tensors = static_cast(gguf_get_n_tensors(ctx_gguf_)); + + size_t data_offset = gguf_get_data_offset(ctx_gguf_); + for (int i = 0; i < n_tensors; i++) { + std::string name = gguf_get_tensor_name(ctx_gguf_, i); + ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str()); + size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i); + + 
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset); + + if (ggml_nbytes(dummy) != tensor_storage.nbytes()) { + gguf_free(ctx_gguf_); + ggml_free(ctx_meta_); + set_error(error, "size mismatch for tensor '" + name + "'"); + return false; + } + + tensor_storages.push_back(tensor_storage); + } + + gguf_free(ctx_gguf_); + ggml_free(ctx_meta_); + + return true; +} + +bool write_gguf_file(const std::string& file_path, + const std::vector& tensors, + std::string* error) { + gguf_context* gguf_ctx = gguf_init_empty(); + if (gguf_ctx == nullptr) { + set_error(error, "gguf_init_empty failed"); + return false; + } + + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + if (tensor == nullptr) { + set_error(error, "null tensor cannot be written to GGUF"); + gguf_free(gguf_ctx); + return false; + } + gguf_add_tensor(gguf_ctx, tensor); + } + + LOG_INFO("trying to save tensors to %s", file_path.c_str()); + bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false); + if (!success) { + set_error(error, "failed to write GGUF file '" + file_path + "'"); + } + gguf_free(gguf_ctx); + return success; +} diff --git a/src/model_io/gguf_io.h b/src/model_io/gguf_io.h new file mode 100644 index 000000000..81c981145 --- /dev/null +++ b/src/model_io/gguf_io.h @@ -0,0 +1,17 @@ +#ifndef __SD_MODEL_IO_GGUF_IO_H__ +#define __SD_MODEL_IO_GGUF_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_gguf_file(const std::string& file_path); +bool read_gguf_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); +bool write_gguf_file(const std::string& file_path, + const std::vector& tensors, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_GGUF_IO_H__ diff --git a/src/gguf_reader.hpp b/src/model_io/gguf_reader_ext.h similarity index 98% rename from src/gguf_reader.hpp rename to src/model_io/gguf_reader_ext.h index 9a2ceebcf..95f0027fc 
100644 --- a/src/gguf_reader.hpp +++ b/src/model_io/gguf_reader_ext.h @@ -1,5 +1,5 @@ -#ifndef __GGUF_READER_HPP__ -#define __GGUF_READER_HPP__ +#ifndef __SD_MODEL_IO_GGUF_READER_EXT_H__ +#define __SD_MODEL_IO_GGUF_READER_EXT_H__ #include #include @@ -231,4 +231,4 @@ class GGUFReader { size_t data_offset() const { return data_offset_; } }; -#endif // __GGUF_READER_HPP__ +#endif // __SD_MODEL_IO_GGUF_READER_EXT_H__ diff --git a/src/model_io/pickle_io.cpp b/src/model_io/pickle_io.cpp new file mode 100644 index 000000000..3a978178a --- /dev/null +++ b/src/model_io/pickle_io.cpp @@ -0,0 +1,1064 @@ +#include "pickle_io.h" + +#include +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "util.h" + +// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 +// 0: \x80 PROTO 2 +// 2: } EMPTY_DICT +// 3: q BINPUT 0 +// 5: ( MARK +// 6: X BINUNICODE 'epoch' +// 16: q BINPUT 1 +// 18: K BININT1 6 +// 20: X BINUNICODE 'global_step' +// 36: q BINPUT 2 +// 38: J BININT 470000 +// 43: X BINUNICODE 'pytorch-lightning_version' +// 73: q BINPUT 3 +// 75: X BINUNICODE '1.4.2' +// 85: q BINPUT 4 +// 87: X BINUNICODE 'state_dict' +// 102: q BINPUT 5 +// 104: } EMPTY_DICT +// 105: q BINPUT 6 +// 107: ( MARK +// 108: X BINUNICODE 'betas' +// 118: q BINPUT 7 +// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' +// 153: q BINPUT 8 +// 155: ( MARK +// 156: ( MARK +// 157: X BINUNICODE 'storage' +// 169: q BINPUT 9 +// 171: c GLOBAL 'torch FloatStorage' +// 191: q BINPUT 10 +// 193: X BINUNICODE '0' +// 199: q BINPUT 11 +// 201: X BINUNICODE 'cpu' +// 209: q BINPUT 12 +// 211: M BININT2 1000 +// 214: t TUPLE (MARK at 156) +// 215: q BINPUT 13 +// 217: Q BINPERSID +// 218: K BININT1 0 +// 220: M BININT2 1000 +// ............................... 
+// 3201: q BINPUT 250 +// 3203: R REDUCE +// 3204: q BINPUT 251 +// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' +// 3264: q BINPUT 252 +// 3266: h BINGET 8 +// 3268: ( MARK +// 3269: ( MARK +// 3270: h BINGET 9 +// 3272: h BINGET 10 +// 3274: X BINUNICODE '30' +// 3281: q BINPUT 253 +// 3283: h BINGET 12 +// 3285: J BININT 102400 +// 3290: t TUPLE (MARK at 3269) +// 3291: q BINPUT 254 +// 3293: Q BINPERSID +// 3294: K BININT1 0 +// 3296: ( MARK +// 3297: M BININT2 320 +// 3300: M BININT2 320 +// 3303: K BININT1 1 +// 3305: K BININT1 1 +// 3307: t TUPLE (MARK at 3296) +// 3308: q BINPUT 255 +// 3310: ( MARK +// 3311: M BININT2 320 +// 3314: K BININT1 1 +// 3316: K BININT1 1 +// 3318: K BININT1 1 +// 3320: t TUPLE (MARK at 3310) +// 3321: r LONG_BINPUT 256 +// 3326: \x89 NEWFALSE +// 3327: h BINGET 16 +// 3329: ) EMPTY_TUPLE +// 3330: R REDUCE +// 3331: r LONG_BINPUT 257 +// 3336: t TUPLE (MARK at 3268) +// 3337: r LONG_BINPUT 258 +// 3342: R REDUCE +// 3343: r LONG_BINPUT 259 +// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' +// 3404: r LONG_BINPUT 260 +// 3409: h BINGET 8 +// 3411: ( MARK +// 3412: ( MARK +// 3413: h BINGET 9 +// 3415: h BINGET 10 +// 3417: X BINUNICODE '31' +// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 +// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 + +using model_io::find_char; +using model_io::read_int; +using model_io::read_short; +using model_io::read_u64; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) { + const uint8_t* p = buffer; + const uint8_t* end = buffer + buffer_size; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': // STOP = b'.' 
# every pickle ends with STOP + *object_size = (size_t)(p - buffer); + return true; + case 0x80: // PROTO = b'\x80' # protocol version indicator + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + case 'C': // SHORT_BINBYTES = b'C' # push bytes; length < 256 + case 0x82: // EXT1 = b'\x82' # extension code, 1-byte arg + p += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + case 0x83: // EXT2 = b'\x83' # extension code, 2-byte arg + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + case 'j': // LONG_BINGET = b'j' # read memo index, 4-byte arg + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + case 0x84: // EXT4 = b'\x84' # extension code, 4-byte arg + p += 4; + break; + case 'I': // INT = b'I' # push decimal integer line + case 'L': // LONG = b'L' # push decimal long integer line + case 'F': // FLOAT = b'F' # push decimal float line + case 'S': // STRING = b'S' # push quoted string line + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + p += 8; + break; + case 0x8A: // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 0x8B: { // LONG4 = b'\x8b' # push long integer; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = 
(uint32_t)read_int(p); + p += 4 + n; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t n = read_u64(p); + p += 8; + if (n > (uint64_t)(end - p)) { + return false; + } + p += n; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 'P': { // PERSID = b'P' # persistent id, newline-terminated + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + p += 8; + break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case '}': // EMPTY_DICT = b'}' # push empty dict + case ']': // EMPTY_LIST = b']' # push empty list + case '(': // MARK = b'(' # push markobject + case 't': // TUPLE = b't' # build tuple from mark + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: // TUPLE3 = b'\x87' # build 3-tuple from stack + case ')': // EMPTY_TUPLE = b')' # push empty tuple + case 'l': // LIST = b'l' # build list from mark + case 'Q': // BINPERSID = b'Q' # persistent id from stack + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + case 0x88: // NEWTRUE = b'\x88' # push True + case 0x89: // NEWFALSE = b'\x89' # push False + case 'R': // REDUCE = b'R' # apply callable to args + case 'u': // SETITEMS = b'u' # add mark-delimited items to dict + case 's': // 
SETITEM = b's' # add key/value to dict + case 'e': // APPENDS = b'e' # extend list with mark-delimited items + case 'a': // APPEND = b'a' # append item to list + case 'b': // BUILD = b'b' # build object state + case 0x81: // NEWOBJ = b'\x81' # build object via __new__ + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + case 0x90: // ADDITEMS = b'\x90' # add mark-delimited items to set + case 0x91: // FROZENSET = b'\x91' # build frozenset from mark + case 0x92: // NEWOBJ_EX = b'\x92' # build object with kwargs + case 0x93: // STACK_GLOBAL = b'\x93' # build global from module/name strings + case 0x97: // NEXT_BUFFER = b'\x97' # out-of-band buffer marker + case 0x98: // READONLY_BUFFER = b'\x98' # mark buffer readonly + case 'N': // NONE = b'N' # push None + case '0': // POP = b'0' # discard top stack item + case '1': // POP_MARK = b'1' # discard stack through topmost mark + case '2': // DUP = b'2' # duplicate top stack item + case 'o': // OBJ = b'o' # build class instance from mark + break; + case 'i': { // INST = b'i' # build class instance from module/name + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + default: + return false; + } + if (p > end) { + return false; + } + } + + return false; +} + +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) { + static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19}; + + if (buffer_size < 5 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + if (opcode != 0x8A || pos >= buffer_size) { + return false; + } + + uint8_t len = buffer[pos++]; + if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) { + return false; + } + + if (memcmp(buffer + pos, torch_magic_bytes, 
sizeof(torch_magic_bytes)) != 0) { + return false; + } + pos += len; + + return pos < buffer_size && buffer[pos] == '.'; +} + +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) { + if (buffer_size < 4 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + switch (opcode) { + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (pos + 1 >= buffer_size) { + return false; + } + *value = buffer[pos]; + pos += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (pos + 2 >= buffer_size) { + return false; + } + *value = read_short(buffer + pos); + pos += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (pos + 4 >= buffer_size) { + return false; + } + *value = (uint32_t)read_int(buffer + pos); + pos += 4; + break; + default: + return false; + } + + return pos < buffer_size && buffer[pos] == '.'; +} + +struct PickleStorageInfo { + std::string key; + ggml_type type = GGML_TYPE_COUNT; + bool is_f64 = false; + bool is_i64 = false; + uint64_t raw_element_nbytes = 0; + uint64_t nbytes = 0; +}; + +struct PickleTensorInfo { + TensorStorage tensor_storage; + int stride_n_dims = 0; + int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1}; +}; + +struct PickleValue { + enum Kind { + MARK, + NONE, + BOOL, + INT, + STRING, + GLOBAL, + TUPLE, + LIST, + DICT, + ORDERED_DICT, + STORAGE, + TENSOR, + }; + + Kind kind = NONE; + int64_t int_value = 0; + bool bool_value = false; + std::string str_value; + std::vector items; + std::vector> dict_items; + PickleStorageInfo storage; + PickleTensorInfo tensor; +}; + +static PickleValue make_mark_value() { + PickleValue value; + value.kind = PickleValue::MARK; + return value; +} + +static PickleValue make_none_value() { + PickleValue value; + value.kind = PickleValue::NONE; + return value; +} + +static PickleValue make_bool_value(bool b) { + PickleValue value; + value.kind = 
PickleValue::BOOL; + value.bool_value = b; + return value; +} + +static PickleValue make_int_value(int64_t x) { + PickleValue value; + value.kind = PickleValue::INT; + value.int_value = x; + return value; +} + +static PickleValue make_string_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::STRING; + value.str_value = s; + return value; +} + +static PickleValue make_global_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::GLOBAL; + value.str_value = s; + return value; +} + +static PickleValue make_tuple_value(std::vector items) { + PickleValue value; + value.kind = PickleValue::TUPLE; + value.items = std::move(items); + return value; +} + +static PickleValue make_list_value() { + PickleValue value; + value.kind = PickleValue::LIST; + return value; +} + +static PickleValue make_dict_value(bool ordered) { + PickleValue value; + value.kind = ordered ? PickleValue::ORDERED_DICT : PickleValue::DICT; + return value; +} + +static PickleValue make_storage_value(const PickleStorageInfo& storage) { + PickleValue value; + value.kind = PickleValue::STORAGE; + value.storage = storage; + return value; +} + +static PickleValue make_tensor_value(const PickleTensorInfo& tensor) { + PickleValue value; + value.kind = PickleValue::TENSOR; + value.tensor = tensor; + return value; +} + +static std::string pickle_value_to_string(const PickleValue& value) { + if (value.kind == PickleValue::STRING) { + return value.str_value; + } + if (value.kind == PickleValue::INT) { + return std::to_string(value.int_value); + } + return ""; +} + +static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) { + if (global_name == "torch.FloatStorage") { + storage->type = GGML_TYPE_F32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.DoubleStorage") { + storage->type = GGML_TYPE_F32; + storage->is_f64 = true; + storage->raw_element_nbytes = 8; + return true; + } + if (global_name == 
"torch.HalfStorage") { + storage->type = GGML_TYPE_F16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.BFloat16Storage") { + storage->type = GGML_TYPE_BF16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.IntStorage") { + storage->type = GGML_TYPE_I32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.LongStorage") { + storage->type = GGML_TYPE_I32; + storage->is_i64 = true; + storage->raw_element_nbytes = 8; + return true; + } + return false; +} + +static bool tensor_is_contiguous(const PickleTensorInfo& tensor) { + if (tensor.tensor_storage.nelements() == 0) { + return true; + } + if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) { + return false; + } + + int64_t expected_stride = 1; + for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) { + if (tensor.stride[i] != expected_stride) { + return false; + } + expected_stride *= tensor.tensor_storage.ne[i]; + } + return true; +} + +static void collect_tensors_from_pickle_value(const PickleValue& value, + std::vector& tensor_storages) { + if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) { + return; + } + + for (const auto& item : value.dict_items) { + if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) { + TensorStorage tensor_storage = item.second.tensor.tensor_storage; + tensor_storage.name = item.first.str_value; + tensor_storage.reverse_ne(); + tensor_storages.push_back(tensor_storage); + } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) { + collect_tensors_from_pickle_value(item.second, tensor_storages); + } + } +} + +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error) { + if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) { + set_error(error, 
"unsupported torch pickle protocol"); + return false; + } + + const uint8_t* p = buffer + 2; + const uint8_t* end = buffer + buffer_size; + std::vector stack; + std::unordered_map memo; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': { // STOP = b'.' # every pickle ends with STOP + if (stack.empty()) { + set_error(error, "empty torch pickle stack"); + return false; + } + size_t old_tensor_count = tensor_storages.size(); + collect_tensors_from_pickle_value(stack.back(), tensor_storages); + if (tensor_storages.size() == old_tensor_count) { + set_error(error, "torch pickle does not contain a supported state_dict"); + return false; + } + return true; + } + case '}': // EMPTY_DICT = b'}' # push empty dict + stack.push_back(make_dict_value(false)); + break; + case ']': // EMPTY_LIST = b']' # push empty list + stack.push_back(make_list_value()); + break; + case 'l': { // LIST = b'l' # build list from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + set_error(error, "torch pickle list without mark"); + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + PickleValue list_value = make_list_value(); + list_value.items = std::move(items); + stack.push_back(std::move(list_value)); + } break; + case '(': // MARK = b'(' # push markobject + stack.push_back(make_mark_value()); + break; + case ')': // EMPTY_TUPLE = b')' # push empty tuple + stack.push_back(make_tuple_value({})); + break; + case 'N': // NONE = b'N' # push None + stack.push_back(make_none_value()); + break; + case 0x88: // NEWTRUE = b'\x88' # push True + stack.push_back(make_bool_value(true)); + break; + case 0x89: // NEWFALSE = b'\x89' # push False + stack.push_back(make_bool_value(false)); + break; + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (p >= end) { + return false; + } + 
stack.push_back(make_int_value(*p++)); + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (p + 2 > end) { + return false; + } + stack.push_back(make_int_value(read_short(p))); + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (p + 4 > end) { + return false; + } + stack.push_back(make_int_value(read_int(p))); + p += 4; + break; + case 'I': { // INT = b'I' # push decimal integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s == "01") { + stack.push_back(make_bool_value(true)); + } else if (s == "00") { + stack.push_back(make_bool_value(false)); + } else { + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } + } break; + case 'L': { // LONG = b'L' # push decimal long integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (!s.empty() && s.back() == 'L') { + s.pop_back(); + } + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } break; + case 'F': { // FLOAT = b'F' # push decimal float line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + stack.push_back(make_none_value()); + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + if (p + 8 > end) { + return false; + } + p += 8; + stack.push_back(make_none_value()); + break; + case 0x8A: { // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + uint8_t n = *p++; + if (p + n > end || n > 8) { + return false; + } + int64_t value = 0; + for (uint8_t i = 0; i < n; ++i) { + value |= (int64_t)p[i] << (i * 8); + } + p += n; + stack.push_back(make_int_value(value)); + } break; + case 'C': { // SHORT_BINBYTES = b'C' # push bytes; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + 
return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t len = read_u64(p); + p += 8; + if (len > (uint64_t)(end - p)) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, (size_t)len))); + p += len; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: { // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'S': { // STRING = b'S' # push quoted string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) { + s = s.substr(1, s.size() - 2); + } + stack.push_back(make_string_value(s)); + } break; + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = 
find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len + 1; + } break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string module((const char*)p, len); + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string name((const char*)p, len); + p += len + 1; + stack.push_back(make_global_value(module + "." + name)); + } break; + case 0x93: { // STACK_GLOBAL = b'\x93' # build global from module/name strings + if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING || + stack.back().kind != PickleValue::STRING) { + return false; + } + std::string name = stack.back().str_value; + stack.pop_back(); + std::string module = stack.back().str_value; + stack.pop_back(); + stack.push_back(make_global_value(module + "." + name)); + } break; + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + if (p >= end || !memo.count(*p)) { + return false; + } + stack.push_back(memo[*p++]); + break; + case 'j': { // LONG_BINGET = b'j' # read memo index, 4-byte arg + if (p + 4 > end) { + return false; + } + int32_t memo_idx = read_int(p); + if (!memo.count(memo_idx)) { + return false; + } + stack.push_back(memo[memo_idx]); + p += 4; + } break; + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + if (p >= end || stack.empty()) { + return false; + } + memo[*p++] = stack.back(); + break; + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + if (p + 4 > end || stack.empty()) { + return false; + } + memo[read_int(p)] = stack.back(); + p += 4; + break; + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + if (stack.empty()) { + return false; + } + memo[(int32_t)memo.size()] = stack.back(); + break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + if (p + 8 
> end) { + return false; + } + p += 8; + break; + case '0': // POP = b'0' # discard top stack item + if (stack.empty()) { + return false; + } + stack.pop_back(); + break; + case '1': { // POP_MARK = b'1' # discard stack through topmost mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case '2': // DUP = b'2' # duplicate top stack item + if (stack.empty()) { + return false; + } + stack.push_back(stack.back()); + break; + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + stack.push_back(make_list_value()); + break; + case 0x90: { // ADDITEMS = b'\x90' # add mark-delimited items to set + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& set_value = stack[mark_idx - 1]; + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 0x91: { // FROZENSET = b'\x91' # build frozenset from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + PickleValue set_value = make_list_value(); + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(std::move(set_value)); + } break; + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: { // TUPLE3 = b'\x87' # build 3-tuple from stack + int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 
2 : 3); + if ((int)stack.size() < tuple_size) { + return false; + } + std::vector items(stack.end() - tuple_size, stack.end()); + stack.erase(stack.end() - tuple_size, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 't': { // TUPLE = b't' # build tuple from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 'Q': { // BINPERSID = b'Q' # persistent id from stack + if (stack.empty()) { + return false; + } + PickleValue pid = stack.back(); + stack.pop_back(); + if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING || + pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT || + pid.items[0].str_value != "storage") { + return false; + } + + PickleStorageInfo storage; + storage.key = pickle_value_to_string(pid.items[2]); + if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) { + return false; + } + storage.nbytes = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes; + storage_nbytes[storage.key] = storage.nbytes; + stack.push_back(make_storage_value(storage)); + } break; + case 'R': { // REDUCE = b'R' # apply callable to args + if (stack.size() < 2) { + return false; + } + PickleValue args = stack.back(); + stack.pop_back(); + PickleValue callable = stack.back(); + stack.pop_back(); + if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) { + stack.push_back(make_none_value()); + break; + } + + if (callable.str_value == "collections.OrderedDict" && args.items.empty()) { + stack.push_back(make_dict_value(true)); + break; + } + + if ((callable.str_value == "torch._utils._rebuild_tensor_v2" 
|| callable.str_value == "torch._utils._rebuild_tensor") && + args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE && + args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE && + args.items[3].kind == PickleValue::TUPLE) { + PickleTensorInfo tensor; + tensor.tensor_storage.type = args.items[0].storage.type; + tensor.tensor_storage.is_f64 = args.items[0].storage.is_f64; + tensor.tensor_storage.is_i64 = args.items[0].storage.is_i64; + tensor.tensor_storage.storage_key = args.items[0].storage.key; + tensor.tensor_storage.offset = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes; + + for (const auto& item : args.items[2].items) { + if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value; + } + + for (const auto& item : args.items[3].items) { + if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.stride[tensor.stride_n_dims++] = item.int_value; + } + + if (!tensor_is_contiguous(tensor)) { + return false; + } + stack.push_back(make_tensor_value(tensor)); + break; + } + + // Non-tensor checkpoint metadata can use REDUCE for arbitrary + // Python objects. Do not execute it; keep stack shape only. 
+ stack.push_back(make_none_value()); + break; + } + case 'b': // BUILD = b'b' # build object state + if (stack.size() < 2) { + return false; + } + stack.pop_back(); + break; + case 'u': { // SETITEMS = b'u' # add mark-delimited items to dict + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0) { + return false; + } + PickleValue& dict = stack[mark_idx - 1]; + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) { + dict.dict_items.emplace_back(stack[i], stack[i + 1]); + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 's': { // SETITEM = b's' # add key/value to dict + if (stack.size() < 3) { + return false; + } + PickleValue value = stack.back(); + stack.pop_back(); + PickleValue key = stack.back(); + stack.pop_back(); + PickleValue& dict = stack.back(); + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + dict.dict_items.emplace_back(key, value); + } break; + case 'e': { // APPENDS = b'e' # extend list with mark-delimited items + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& list_value = stack[mark_idx - 1]; + list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 'a': { // APPEND = b'a' # append item to list + if (stack.size() < 2) { + return false; + } + PickleValue item = stack.back(); + stack.pop_back(); + if (stack.back().kind != PickleValue::LIST) { + return false; + } + stack.back().items.push_back(item); + } break; + default: + set_error(error, + "unsupported torch pickle opcode 0x" + 
sd_format("%02X", opcode) + + " at offset " + std::to_string((p - buffer) - 1)); + return false; + } + } + + set_error(error, "unterminated torch state_dict pickle"); + return false; +} diff --git a/src/model_io/pickle_io.h b/src/model_io/pickle_io.h new file mode 100644 index 000000000..6a3db37b9 --- /dev/null +++ b/src/model_io/pickle_io.h @@ -0,0 +1,21 @@ +#ifndef __SD_MODEL_IO_PICKLE_IO_H__ +#define __SD_MODEL_IO_PICKLE_IO_H__ + +#include +#include +#include +#include +#include + +#include "tensor_storage.h" + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size); +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size); +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value); +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_PICKLE_IO_H__ diff --git a/src/model_io/safetensors_io.cpp b/src/model_io/safetensors_io.cpp new file mode 100644 index 000000000..889352218 --- /dev/null +++ b/src/model_io/safetensors_io.cpp @@ -0,0 +1,316 @@ +#include "safetensors_io.h" + +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "json.hpp" +#include "util.h" + +static constexpr size_t ST_HEADER_SIZE_LEN = 8; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_safetensors_file(const std::string& file_path) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + // get file size + file.seekg(0, file.end); + size_t file_size_ = file.tellg(); + file.seekg(0, file.beg); + + // read header size + if (file_size_ <= ST_HEADER_SIZE_LEN) { + return false; + } + + uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; + file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); + 
if (!file) { + return false; + } + + size_t header_size_ = model_io::read_u64(header_size_buf); + if (header_size_ >= file_size_ || header_size_ <= 2) { + return false; + } + + // read header + std::vector header_buf; + header_buf.resize(header_size_ + 1); + header_buf[header_size_] = '\0'; + file.read(header_buf.data(), header_size_); + if (!file) { + return false; + } + try { + nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); + } catch (const std::exception&) { + return false; + } + return true; +} + +static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) { + ggml_type ttype = GGML_TYPE_COUNT; + if (dtype == "F16") { + ttype = GGML_TYPE_F16; + } else if (dtype == "BF16") { + ttype = GGML_TYPE_BF16; + } else if (dtype == "F32") { + ttype = GGML_TYPE_F32; + } else if (dtype == "F64") { + ttype = GGML_TYPE_F32; + } else if (dtype == "F8_E4M3") { + ttype = GGML_TYPE_F16; + } else if (dtype == "F8_E5M2") { + ttype = GGML_TYPE_F16; + } else if (dtype == "I32") { + ttype = GGML_TYPE_I32; + } else if (dtype == "I64") { + ttype = GGML_TYPE_I32; + } + return ttype; +} + +// https://huggingface.co/docs/safetensors/index +bool read_safetensors_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + // get file size + file.seekg(0, file.end); + size_t file_size_ = file.tellg(); + file.seekg(0, file.beg); + + // read header size + if (file_size_ <= ST_HEADER_SIZE_LEN) { + set_error(error, "invalid safetensor file '" + file_path + "'"); + return false; + } + + uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; + file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); + if (!file) { + set_error(error, "read safetensors header size failed: '" + file_path + "'"); + return false; + } + + size_t header_size_ = model_io::read_u64(header_size_buf); + if 
(header_size_ >= file_size_) { + set_error(error, "invalid safetensor file '" + file_path + "'"); + return false; + } + + // read header + std::vector header_buf; + header_buf.resize(header_size_ + 1); + header_buf[header_size_] = '\0'; + file.read(header_buf.data(), header_size_); + if (!file) { + set_error(error, "read safetensors header failed: '" + file_path + "'"); + return false; + } + + nlohmann::json header_; + try { + header_ = nlohmann::json::parse(header_buf.data()); + } catch (const std::exception&) { + set_error(error, "parsing safetensors header failed: '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + for (auto& item : header_.items()) { + std::string name = item.key(); + nlohmann::json tensor_info = item.value(); + // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str()); + + if (name == "__metadata__") { + continue; + } + + std::string dtype = tensor_info["dtype"]; + nlohmann::json shape = tensor_info["shape"]; + + if (dtype == "U8") { + continue; + } + + size_t begin = tensor_info["data_offsets"][0].get(); + size_t end = tensor_info["data_offsets"][1].get(); + + ggml_type type = safetensors_dtype_to_ggml_type(dtype); + if (type == GGML_TYPE_COUNT) { + set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')"); + return false; + } + + if (shape.size() > SD_MAX_DIMS) { + set_error(error, "invalid tensor '" + name + "'"); + return false; + } + + int n_dims = (int)shape.size(); + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + for (int i = 0; i < n_dims; i++) { + ne[i] = shape[i].get(); + } + + if (n_dims == 5) { + n_dims = 4; + ne[0] = ne[0] * ne[1]; + ne[1] = ne[2]; + ne[2] = ne[3]; + ne[3] = ne[4]; + } + + // ggml_n_dims returns 1 for scalars + if (n_dims == 0) { + n_dims = 1; + } + + TensorStorage tensor_storage(name, type, ne, n_dims, 0, ST_HEADER_SIZE_LEN + header_size_ + begin); + tensor_storage.reverse_ne(); + + size_t tensor_data_size = end - begin; + + bool tensor_size_ok; + if (dtype == 
"F8_E4M3") { + tensor_storage.is_f8_e4m3 = true; + // f8 -> f16 + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E5M2") { + tensor_storage.is_f8_e5m2 = true; + // f8 -> f16 + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F64") { + tensor_storage.is_f64 = true; + // f64 -> f32 + tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); + } else if (dtype == "I64") { + tensor_storage.is_i64 = true; + // i64 -> i32 + tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); + } else { + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size); + } + if (!tensor_size_ok) { + set_error(error, "size mismatch for tensor '" + name + "' (" + dtype + ")"); + return false; + } + + tensor_storages.push_back(tensor_storage); + + // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); + } + + return true; +} + +static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) { + switch (type) { + case GGML_TYPE_F16: + *dtype = "F16"; + return true; + case GGML_TYPE_BF16: + *dtype = "BF16"; + return true; + case GGML_TYPE_F32: + *dtype = "F32"; + return true; + case GGML_TYPE_I32: + *dtype = "I32"; + return true; + default: + return false; + } +} + +bool write_safetensors_file(const std::string& file_path, + const std::vector& tensors, + std::string* error) { + nlohmann::ordered_json header = nlohmann::ordered_json::object(); + + uint64_t data_offset = 0; + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + if (tensor == nullptr) { + set_error(error, "null tensor cannot be written to safetensors"); + return false; + } + + const std::string name = ggml_get_name(tensor); + std::string dtype; + if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) { + set_error(error, + "unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) + + "' for tensor '" + name + "'"); 
+ return false; + } + + const uint64_t tensor_nbytes = ggml_nbytes(tensor); + + nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object(); + json_tensor_info["dtype"] = dtype; + + nlohmann::ordered_json shape = nlohmann::ordered_json::array(); + for (int i = 0; i < write_tensor.n_dims; ++i) { + shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]); + } + json_tensor_info["shape"] = shape; + + nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array(); + data_offsets.push_back(data_offset); + data_offsets.push_back(data_offset + tensor_nbytes); + json_tensor_info["data_offsets"] = data_offsets; + + header[name] = json_tensor_info; + data_offset += tensor_nbytes; + } + + const std::string header_str = header.dump(); + + std::ofstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "' for writing"); + return false; + } + + LOG_INFO("trying to save tensors to %s", file_path.c_str()); + model_io::write_u64(file, header_str.size()); + file.write(header_str.data(), header_str.size()); + if (!file) { + set_error(error, "failed to write safetensors header to '" + file_path + "'"); + return false; + } + + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + const std::string name = ggml_get_name(tensor); + const size_t tensor_nbytes = ggml_nbytes(tensor); + file.write((const char*)tensor->data, tensor_nbytes); + if (!file) { + set_error(error, + "failed to write tensor '" + name + "' to '" + file_path + "'"); + return false; + } + } + + return true; +} diff --git a/src/model_io/safetensors_io.h b/src/model_io/safetensors_io.h new file mode 100644 index 000000000..08a1bc1f3 --- /dev/null +++ b/src/model_io/safetensors_io.h @@ -0,0 +1,17 @@ +#ifndef __SD_MODEL_IO_SAFETENSORS_IO_H__ +#define __SD_MODEL_IO_SAFETENSORS_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_safetensors_file(const std::string& 
file_path); +bool read_safetensors_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); +bool write_safetensors_file(const std::string& file_path, + const std::vector& tensors, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__ diff --git a/src/model_io/tensor_storage.h b/src/model_io/tensor_storage.h new file mode 100644 index 000000000..c0cf079c5 --- /dev/null +++ b/src/model_io/tensor_storage.h @@ -0,0 +1,132 @@ +#ifndef __SD_TENSOR_STORAGE_H__ +#define __SD_TENSOR_STORAGE_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include "ggml.h" + +#define SD_MAX_DIMS 5 + +struct TensorStorage { + std::string name; + ggml_type type = GGML_TYPE_F32; + ggml_type expected_type = GGML_TYPE_COUNT; + bool is_f8_e4m3 = false; + bool is_f8_e5m2 = false; + bool is_f64 = false; + bool is_i64 = false; + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int n_dims = 0; + + std::string storage_key; + size_t file_index = 0; + int index_in_zip = -1; // >= means stored in a zip file + uint64_t offset = 0; // offset in file + + TensorStorage() = default; + + TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) + : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { + for (int i = 0; i < n_dims; i++) { + this->ne[i] = ne[i]; + } + } + + int64_t nelements() const { + int64_t n = 1; + for (int i = 0; i < SD_MAX_DIMS; i++) { + n *= ne[i]; + } + return n; + } + + int64_t nbytes() const { + return nelements() * ggml_type_size(type) / ggml_blck_size(type); + } + + int64_t nbytes_to_read() const { + if (is_f8_e4m3 || is_f8_e5m2) { + return nbytes() / 2; + } else if (is_f64 || is_i64) { + return nbytes() * 2; + } else { + return nbytes(); + } + } + + void unsqueeze() { + if (n_dims == 2) { + n_dims = 4; + ne[3] = ne[1]; + ne[2] = ne[0]; + ne[1] = 1; + ne[0] = 1; + } + } + + std::vector 
chunk(size_t n) { + std::vector chunks; + uint64_t chunk_size = nbytes_to_read() / n; + // printf("%d/%d\n", chunk_size, nbytes_to_read()); + reverse_ne(); + for (size_t i = 0; i < n; i++) { + TensorStorage chunk_i = *this; + chunk_i.ne[0] = ne[0] / n; + chunk_i.offset = offset + i * chunk_size; + chunk_i.reverse_ne(); + chunks.push_back(chunk_i); + } + reverse_ne(); + return chunks; + } + + void reverse_ne() { + int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + for (int i = 0; i < n_dims; i++) { + new_ne[i] = ne[n_dims - 1 - i]; + } + for (int i = 0; i < n_dims; i++) { + ne[i] = new_ne[i]; + } + } + + std::string to_string() const { + std::stringstream ss; + const char* type_name = ggml_type_name(type); + if (is_f8_e4m3) { + type_name = "f8_e4m3"; + } else if (is_f8_e5m2) { + type_name = "f8_e5m2"; + } else if (is_f64) { + type_name = "f64"; + } else if (is_i64) { + type_name = "i64"; + } + ss << name << " | " << type_name << " | "; + ss << n_dims << " ["; + for (int i = 0; i < SD_MAX_DIMS; i++) { + ss << ne[i]; + if (i != SD_MAX_DIMS - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } +}; + +struct TensorWriteInfo { + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int n_dims = 0; + ggml_tensor* tensor = nullptr; +}; + +typedef std::function on_new_tensor_cb_t; + +#endif // __SD_TENSOR_STORAGE_H__ diff --git a/src/model_io/torch_legacy_io.cpp b/src/model_io/torch_legacy_io.cpp new file mode 100644 index 000000000..816547252 --- /dev/null +++ b/src/model_io/torch_legacy_io.cpp @@ -0,0 +1,252 @@ +#include "torch_legacy_io.h" + +#include +#include +#include +#include +#include +#include + +#include "pickle_io.h" +#include "util.h" + +// torch.save format background: +// +// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by +// default. +// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive +// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder. 
+// - The old format can still be produced explicitly with: +// torch.save(obj, path, _use_new_zipfile_serialization=False) +// +// Whether obj is a state_dict or a whole nn.Module does not change the outer +// container format selected by torch.save. It changes the pickled object inside: +// +// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a +// restricted subset of this layout because tensor metadata and raw storages +// can be recovered without executing pickle callables. +// - whole module/checkpoint object: arbitrary Python object graph. This may +// require importing user classes and executing pickle GLOBAL/REDUCE rebuild +// logic, so it is intentionally not supported here. +// +// Legacy non-zip PyTorch files are not a single pickle object: +// +// 1. pickle object: PyTorch legacy magic number +// 2. pickle object: legacy protocol version, expected to be 1001 +// 3. pickle object: sys_info metadata, ignored by this reader +// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp +// 5. pickle object: serialized storage key list, skipped here +// 6. 
raw storage data payloads +// - PyTorch writes storages after the pickles, ordered by storage key +// - each storage has an 8-byte legacy storage header followed by raw bytes +static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +static std::string bytes_to_hex(const std::vector& bytes) { + static const char* hex = "0123456789ABCDEF"; + std::string result; + result.reserve(bytes.size() * 3); + for (size_t i = 0; i < bytes.size(); ++i) { + if (i > 0) { + result.push_back('-'); + } + result.push_back(hex[(bytes[i] >> 4) & 0x0F]); + result.push_back(hex[bytes[i] & 0x0F]); + } + return result; +} + +static bool is_probably_tar_file(const std::vector& header) { + return header.size() >= 262 && + header[257] == 'u' && + header[258] == 's' && + header[259] == 't' && + header[260] == 'a' && + header[261] == 'r'; +} + +static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector& buffer) { + if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) { + return ""; + } + if (buffer.empty()) { + return "unsupported PyTorch file '" + file_path + "': empty file"; + } + + size_t short_len = std::min(buffer.size(), 32); + std::vector short_header(buffer.begin(), buffer.begin() + short_len); + const bool raw_pickle = buffer[0] == 0x80; + const bool tar_file = is_probably_tar_file(buffer); + + std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " + + bytes_to_hex(short_header) + + ", raw_pickle=" + (raw_pickle ? "true" : "false") + + ", tar=" + (tar_file ? 
"true" : "false"); + if (raw_pickle) { + message += "; raw pickle did not match the restricted state_dict layouts currently supported"; + } else if (tar_file) { + message += "; legacy tar PyTorch checkpoints are not supported yet"; + } + return message; +} + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + file.seekg(0, file.end); + size_t file_size = (size_t)file.tellg(); + file.seekg(0, file.beg); + if (file_size == 0) { + set_error(error, "empty file '" + file_path + "'"); + return false; + } + + std::vector buffer(file_size); + file.read((char*)buffer.data(), file_size); + if (!file) { + set_error(error, "failed to read '" + file_path + "'"); + return false; + } + + auto finalize_tensor_offsets = [&](size_t storage_data_offset, + const std::unordered_map& legacy_storage_map) -> bool { + if (storage_data_offset > file_size) { + return false; + } + + std::vector storage_keys; + storage_keys.reserve(legacy_storage_map.size()); + for (const auto& [storage_key, _] : legacy_storage_map) { + storage_keys.push_back(storage_key); + } + std::sort(storage_keys.begin(), storage_keys.end()); + + std::unordered_map storage_offsets; + uint64_t current_offset = storage_data_offset; + for (const auto& storage_key : storage_keys) { + auto it = legacy_storage_map.find(storage_key); + if (it == legacy_storage_map.end()) { + return false; + } + if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) { + return false; + } + storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE; + current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second; + } + + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.storage_key.empty()) { + continue; + } + + auto it_offset = storage_offsets.find(tensor_storage.storage_key); + 
auto it_size = legacy_storage_map.find(tensor_storage.storage_key); + if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) { + return false; + } + + uint64_t base_offset = it_offset->second; + uint64_t storage_nbytes = it_size->second; + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > storage_nbytes) { + return false; + } + + tensor_storage.offset = base_offset + tensor_storage.offset; + tensor_storage.storage_key.clear(); + } + + return true; + }; + + auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool { + tensor_storages.clear(); + std::unordered_map legacy_storage_map; + if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset, + state_dict_size, + tensor_storages, + legacy_storage_map, + error)) { + return false; + } + + size_t offset_after_state_dict = state_dict_offset + state_dict_size; + size_t storage_keys_size = 0; + if (!skip_pickle_object(buffer.data() + offset_after_state_dict, + buffer.size() - offset_after_state_dict, + &storage_keys_size)) { + return false; + } + + *storage_data_offset = offset_after_state_dict + storage_keys_size; + return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map); + }; + + size_t object_size_1 = 0; + size_t offset = 0; + + if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) && + pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) { + offset += object_size_1; + + size_t object_size_2 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + uint32_t protocol_version = 0; + if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += 
object_size_2; + + size_t object_size_3 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += object_size_3; + + size_t state_dict_size = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + + size_t storage_data_offset = 0; + if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) { + return true; + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; + } + + size_t state_dict_size = 0; + if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) { + size_t storage_data_offset = 0; + if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) { + return true; + } + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; +} diff --git a/src/model_io/torch_legacy_io.h b/src/model_io/torch_legacy_io.h new file mode 100644 index 000000000..6680e02a1 --- /dev/null +++ b/src/model_io/torch_legacy_io.h @@ -0,0 +1,13 @@ +#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__ +#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__ diff --git a/src/model_io/torch_zip_io.cpp b/src/model_io/torch_zip_io.cpp new file mode 100644 index 000000000..9eaf6c53a --- /dev/null +++ b/src/model_io/torch_zip_io.cpp @@ -0,0 +1,140 @@ +#include "torch_zip_io.h" + +#include +#include +#include +#include +#include + +#include "pickle_io.h" + +#include "zip.h" + +static void set_error(std::string* error, const std::string& message) { + if 
(error != nullptr) { + *error = message; + } +} + +bool is_torch_zip_file(const std::string& file_path) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + return false; + } + zip_close(zip); + return true; +} + +static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) { + size_t n = zip_entries_total(zip); + for (size_t i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + if (name == entry_name) { + *index = (int)i; + *size = zip_entry_size(zip); + zip_entry_close(zip); + return true; + } + zip_entry_close(zip); + } + return false; +} + +static bool parse_zip_data_pkl(const uint8_t* buffer, + size_t buffer_size, + zip_t* zip, + const std::string& dir, + std::vector& tensor_storages, + std::string* error) { + std::vector parsed_tensors; + std::unordered_map storage_nbytes; + if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) { + if (error != nullptr && error->empty()) { + *error = "failed to parse torch zip pickle metadata"; + } + return false; + } + + for (auto& tensor_storage : parsed_tensors) { + if (tensor_storage.storage_key.empty()) { + set_error(error, "tensor '" + tensor_storage.name + "' has no storage key"); + return false; + } + + const std::string entry_name = dir + "data/" + tensor_storage.storage_key; + int zip_index = -1; + uint64_t entry_size = 0; + if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) { + set_error(error, "storage entry '" + entry_name + "' was not found"); + return false; + } + + auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key); + if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) { + set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata"); + return false; + } + + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > entry_size) { 
+ set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'"); + return false; + } + + tensor_storage.index_in_zip = zip_index; + tensor_storage.storage_key.clear(); + tensor_storages.push_back(tensor_storage); + } + + return true; +} + +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + bool success = true; + bool found_data_pkl = false; + int n = (int)zip_entries_total(zip); + for (int i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + size_t pos = name.find("data.pkl"); + if (pos != std::string::npos) { + found_data_pkl = true; + std::string dir = name.substr(0, pos); + void* pkl_data = nullptr; + size_t pkl_size = 0; + zip_entry_read(zip, &pkl_data, &pkl_size); + + if (pkl_data == nullptr || pkl_size == 0) { + set_error(error, "failed to read '" + name + "' from '" + file_path + "'"); + success = false; + } else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) { + success = false; + } + + free(pkl_data); + } + zip_entry_close(zip); + + if (!success) { + break; + } + } + + if (success && !found_data_pkl) { + set_error(error, "data.pkl was not found in '" + file_path + "'"); + success = false; + } + + zip_close(zip); + return success; +} diff --git a/src/model_io/torch_zip_io.h b/src/model_io/torch_zip_io.h new file mode 100644 index 000000000..54fb099a7 --- /dev/null +++ b/src/model_io/torch_zip_io.h @@ -0,0 +1,14 @@ +#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__ +#define __SD_MODEL_IO_TORCH_ZIP_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_torch_zip_file(const std::string& file_path); +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + 
std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__ diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index d5d5e052c..618c7f6e9 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -1120,7 +1120,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) { for (const auto& prefix : first_stage_model_prefix_vec) { if (starts_with(name, prefix)) { name = convert_first_stage_model_name(name.substr(prefix.size()), prefix); - if (version == VERSION_SDXS) { + if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { name = "tae." + name; } else { name = prefix + name; diff --git a/src/rope.hpp b/src/rope.hpp index db577f5d3..f84fac885 100644 --- a/src/rope.hpp +++ b/src/rope.hpp @@ -7,6 +7,11 @@ #include "ggml_extend.hpp" namespace Rope { + enum class EmbedNDLayout { + Matrix, + ErnieImage, + }; + template __STATIC_INLINE__ std::vector linspace(T start, T end, int num) { std::vector result(num); @@ -169,7 +174,8 @@ namespace Rope { int bs, const std::vector& axis_thetas, const std::vector& axes_dim, - const std::vector>& wrap_dims = {}) { + const std::vector>& wrap_dims = {}, + EmbedNDLayout layout = EmbedNDLayout::Matrix) { std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size() / bs; size_t num_axes = axes_dim.size(); @@ -204,6 +210,24 @@ namespace Rope { offset += rope_emb[0].size(); } + if (layout == EmbedNDLayout::ErnieImage) { + int head_dim = emb_dim * 2; + std::vector ernie_emb(bs * pos_len * head_dim * 2, 0.0f); + for (size_t pos_idx = 0; pos_idx < bs * pos_len; ++pos_idx) { + for (int i = 0; i < emb_dim; ++i) { + float cos_val = emb[pos_idx][4 * i]; + float sin_val = emb[pos_idx][4 * i + 2]; + size_t cos_offset = pos_idx * head_dim + 2 * i; + size_t sin_offset = bs * pos_len * head_dim + cos_offset; + ernie_emb[cos_offset] = cos_val; + ernie_emb[cos_offset + 1] = cos_val; + ernie_emb[sin_offset] = sin_val; + ernie_emb[sin_offset + 1] = sin_val; + } + } + return ernie_emb; 
+ } + return flatten(emb); } @@ -211,9 +235,10 @@ namespace Rope { int bs, float theta, const std::vector& axes_dim, - const std::vector>& wrap_dims = {}) { + const std::vector>& wrap_dims = {}, + EmbedNDLayout layout = EmbedNDLayout::Matrix) { std::vector axis_thetas(axes_dim.size(), theta); - return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims); + return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout); } __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, @@ -437,6 +462,74 @@ namespace Rope { return embed_nd(ids, bs, static_cast(theta), axes_dim, wrap_dims); } + __STATIC_INLINE__ std::vector> gen_ernie_image_ids(int h, + int w, + int patch_size, + int bs, + int context_len) { + int h_len = h / patch_size; + int w_len = w / patch_size; + + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0f)); + std::vector h_ids = linspace(0.f, static_cast(h_len - 1), h_len); + std::vector w_ids = linspace(0.f, static_cast(w_len - 1), w_len); + for (int i = 0; i < h_len; ++i) { + for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][0] = static_cast(context_len); + img_ids[i * w_len + j][1] = h_ids[i]; + img_ids[i * w_len + j][2] = w_ids[j]; + } + } + + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3, 0.0f)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < static_cast(img_ids.size()); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + } + } + + std::vector> txt_ids(bs * context_len, std::vector(3, 0.0f)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < context_len; ++j) { + txt_ids[i * context_len + j][0] = static_cast(j); + } + } + + return concat_ids(img_ids_repeated, txt_ids, bs); + } + + __STATIC_INLINE__ std::vector gen_ernie_image_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + int theta, + bool circular_h, + bool circular_w, + const std::vector& axes_dim) { + std::vector> ids = gen_ernie_image_ids(h, w, patch_size, bs, context_len); + std::vector> wrap_dims; + if 
((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int h_len = h / patch_size; + int w_len = w / patch_size; + if (h_len > 0 && w_len > 0) { + size_t pos_len = ids.size() / bs; + wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); + const size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][token_i] = w_len; + } + } + } + } + + return embed_nd(ids, bs, static_cast(theta), axes_dim, wrap_dims, EmbedNDLayout::ErnieImage); + } + __STATIC_INLINE__ std::vector> gen_vid_ids(int t, int h, int w, diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 683a07d53..60e793fa7 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -30,7 +30,8 @@ const char* model_version_to_str[] = { "SD 2.x", "SD 2.x Inpaint", "SD 2.x Tiny UNet", - "SDXS", + "SDXS (512-DS)", + "SDXS (09)", "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", @@ -52,6 +53,7 @@ const char* model_version_to_str[] = { "Flux.2 klein", "Z-Image", "Ovis Image", + "Ernie Image", }; const char* sampling_methods_str[] = { @@ -69,6 +71,7 @@ const char* sampling_methods_str[] = { "TCD", "Res Multistep", "Res 2s", + "ER-SDE", }; /*================================================== Helper Functions ================================================*/ @@ -413,7 +416,7 @@ class StableDiffusionGGML { } bool tae_preview_only = sd_ctx_params->tae_preview_only; - if (version == VERSION_SDXS) { + if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { tae_preview_only = false; use_tae = true; } @@ -551,6 +554,15 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + } else if (sd_version_is_ernie_image(version)) { + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version); + diffusion_model = std::make_shared(backend, + 
offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model"); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -819,6 +831,10 @@ class StableDiffusionGGML { if (version == VERSION_SVD) { ignore_tensors.insert("conditioner.embedders.3"); } + if (sd_version_is_ernie_image(version)) { + ignore_tensors.insert("text_encoders.llm.vision_tower."); + ignore_tensors.insert("text_encoders.llm.multi_modal_projector."); + } bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap); if (!success) { LOG_ERROR("load tensors from model loader failed"); @@ -922,10 +938,13 @@ class StableDiffusionGGML { sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || + sd_version_is_ernie_image(version) || sd_version_is_z_image(version)) { pred_type = FLOW_PRED; if (sd_version_is_wan(version)) { default_flow_shift = 5.f; + } else if (sd_version_is_ernie_image(version)) { + default_flow_shift = 4.f; } else { default_flow_shift = 3.f; } @@ -1395,7 +1414,7 @@ class StableDiffusionGGML { uint32_t dim = is_video ? 
static_cast(latents.shape()[3]) : static_cast(latents.shape()[2]); if (dim == 128) { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; patch_sz = 2; @@ -1612,7 +1631,7 @@ class StableDiffusionGGML { denoiser.get(), sigmas); size_t steps = sigmas.size() - 1; - bool has_skiplayer = slg_scale != 0.0f && !skip_layers.empty(); + bool has_skiplayer = (slg_scale != 0.0f || guidance.slg.uncond) && !skip_layers.empty(); if (has_skiplayer && !sd_version_is_dit(version)) { has_skiplayer = false; LOG_WARN("SLG is incompatible with this model type"); @@ -1625,6 +1644,14 @@ class StableDiffusionGGML { sd::Tensor denoised = x_t; SamplePreviewContext preview = prepare_sample_preview_context(); + sd::Tensor apg_momentum_buffer; + std::vector apg_cfg_norm; + bool apg_enabled = false; + if (guidance.apg.eta != 1.f || guidance.apg.momentum != 0.f || guidance.apg.norm_threshold != 0.f) { + LOG_INFO("APG enabled: eta=%g, momentum=%g, norm_threshold=%g", guidance.apg.eta, guidance.apg.momentum, guidance.apg.norm_threshold); + apg_enabled = true; + } + auto denoise = [&](const sd::Tensor& x, float sigma, int step) -> sd::Tensor { if (step == 1 || step == -1) { pretty_progress(0, (int)steps, 0); @@ -1726,6 +1753,10 @@ class StableDiffusionGGML { } } + bool is_skiplayer_step = has_skiplayer && + step > (int)(guidance.slg.layer_start * static_cast(sigmas.size())) && + step < (int)(guidance.slg.layer_end * static_cast(sigmas.size())); + if (!uncond.empty()) { if (!step_cache.is_step_skipped()) { compute_sample_controls(control_image, @@ -1734,7 +1765,12 @@ class StableDiffusionGGML { uncond, &controls); } - uncond_out = run_condition(uncond); + const std::vector* uncond_skip = nullptr; + if (is_skiplayer_step && guidance.slg.uncond) { + LOG_DEBUG("Skipping layers at uncond step %d\n", step); + uncond_skip = &skip_layers; + } + uncond_out = run_condition(uncond, nullptr, 
uncond_skip); if (uncond_out.empty()) { return {}; } @@ -1746,10 +1782,7 @@ class StableDiffusionGGML { return {}; } } - bool is_skiplayer_step = has_skiplayer && - step > (int)(guidance.slg.layer_start * static_cast(sigmas.size())) && - step < (int)(guidance.slg.layer_end * static_cast(sigmas.size())); - if (is_skiplayer_step) { + if (is_skiplayer_step && slg_scale != 0.0f) { LOG_DEBUG("Skipping layers at step %d\n", step); if (!step_cache.is_step_skipped()) { skip_cond_out = run_condition(cond, @@ -1763,16 +1796,87 @@ class StableDiffusionGGML { GGML_ASSERT(!cond_out.empty()); sd::Tensor latent_result = cond_out; - if (!uncond_out.empty()) { - if (!img_cond_out.empty()) { - latent_result = uncond_out + - img_cfg_scale * (img_cond_out - uncond_out) + - cfg_scale * (cond_out - img_cond_out); + + if (!uncond_out.empty() || !img_cond_out.empty()) { + + sd::Tensor delta; + if (img_cond_out.empty()) { + // classic CFG (img_cfg_scale == cfg_scale != 1) + delta = cond_out - uncond_out; + } else if (cfg_scale == 1.f) { + // Weird guidance (important: use img_cfg_scale instead of cfg_scale in the final formula) + delta = img_cond_out - uncond_out; + } else if (uncond_out.empty()) { + // pure img CFG (img_cfg_scale == 1, cfg_scale !=1) + delta = cond_out - img_cond_out; } else { - latent_result = uncond_out + cfg_scale * (cond_out - uncond_out); + // 2-conditioning CFG (img_cfg_scale != cfg_scale != 1) + // apply APG to the outer direction + delta = cond_out - img_cond_out; + } + + // APG: https://arxiv.org/pdf/2410.02416 + + if (guidance.apg.momentum != 0.f) { + if (!apg_momentum_buffer.empty()) { + delta += guidance.apg.momentum * apg_momentum_buffer; + } + apg_momentum_buffer = delta; + } + + float diff_norm = 0.f; + if (apg_enabled) { + float cfg_norm = 0.f; + for (int64_t i = 0; i < delta.numel(); ++i) { + cfg_norm += delta[i] * delta[i]; + } + cfg_norm = sqrtf(cfg_norm); + apg_cfg_norm.push_back(cfg_norm); + if (guidance.apg.norm_threshold > 0) { + diff_norm = 
cfg_norm; + } + } + + float apg_scale_factor = 1.f; + if (diff_norm > 0) { + if (guidance.apg.norm_threshold_smoothing <= 0) { + apg_scale_factor = std::min(1.0f, guidance.apg.norm_threshold / diff_norm); + } else { + // Experimental: smooth saturate + float x = guidance.apg.norm_threshold / diff_norm; + apg_scale_factor = x / std::pow(1 + std::pow(x, 1.0 / guidance.apg.norm_threshold_smoothing), + guidance.apg.norm_threshold_smoothing); + } + delta *= apg_scale_factor; + } + + if (guidance.apg.eta != 1.0f) { + float cond_norm_sq = 0.f; + float dot = 0.f; + for (int64_t i = 0; i < delta.numel(); ++i) { + cond_norm_sq += cond_out[i] * cond_out[i]; + dot += cond_out[i] * delta[i]; + } + // pre-normalize (avoids one square root and ne_elements extra divs) + dot /= cond_norm_sq; + + for (int64_t i = 0; i < delta.numel(); ++i) { + float apg_parallel = dot * cond_out[i]; + float apg_orthogonal = delta[i] - apg_parallel; + // tweak deltas + delta[i] = apg_orthogonal + guidance.apg.eta * apg_parallel; + } + } + + if (img_cond_out.empty()) { + latent_result = uncond_out + cfg_scale * delta; + } else if (cfg_scale == 1.f) { + latent_result += (img_cfg_scale - 1.f) * delta; + } else if (uncond_out.empty()) { + latent_result = img_cond_out + cfg_scale * delta; + } else { + latent_result = uncond_out + img_cfg_scale * (img_cond_out - uncond_out) + cfg_scale * delta; } - } else if (!img_cond_out.empty()) { - latent_result = img_cond_out + cfg_scale * (cond_out - img_cond_out); } if (is_skiplayer_step && !skip_cond_out.empty()) { @@ -1804,6 +1908,12 @@ class StableDiffusionGGML { } return {}; } + if (!apg_cfg_norm.empty()) { + std::sort(apg_cfg_norm.begin(), apg_cfg_norm.end()); + size_t mid = apg_cfg_norm.size() / 2; + float median = (mid % 2 == 0 ? 
(apg_cfg_norm[mid - 1] + apg_cfg_norm[mid]) / 2.0f : apg_cfg_norm[mid]); + LOG_DEBUG("CFG Delta norm: [%g, %g], median=%g", apg_cfg_norm[0], apg_cfg_norm[apg_cfg_norm.size()-1], median); + } auto x0 = std::move(x0_opt); sd_sample::log_sample_cache_summary(cache_runtime, steps); @@ -1844,7 +1954,7 @@ class StableDiffusionGGML { latent_channel = 48; } else if (version == VERSION_CHROMA_RADIANCE) { latent_channel = 3; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_uses_flux2_vae(version)) { latent_channel = 128; } else { latent_channel = 16; @@ -1975,6 +2085,7 @@ const char* sample_method_to_str[] = { "tcd", "res_multistep", "res_2s", + "er_sde", }; const char* sd_sample_method_name(enum sample_method_t sample_method) { @@ -2224,6 +2335,9 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) { sample_params->guidance.txt_cfg = 7.0f; sample_params->guidance.img_cfg = INFINITY; sample_params->guidance.distilled_guidance = 3.5f; + sample_params->guidance.apg.eta = INFINITY; + sample_params->guidance.apg.momentum = INFINITY; + sample_params->guidance.apg.norm_threshold = INFINITY; sample_params->guidance.slg.layer_count = 0; sample_params->guidance.slg.layer_start = 0.01f; sample_params->guidance.slg.layer_end = 0.2f; @@ -2370,6 +2484,14 @@ struct sd_ctx_t { StableDiffusionGGML* sd = nullptr; }; +static bool sd_version_supports_video_generation(SDVersion version) { + return version == VERSION_SVD || sd_version_is_wan(version); +} + +static bool sd_version_supports_image_generation(SDVersion version) { + return !sd_version_supports_video_generation(version); +} + sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == nullptr) { @@ -2399,6 +2521,20 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) { + if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { + return false; + } + return 
sd_version_supports_image_generation(sd_ctx->sd->version); +} + +SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx) { + if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { + return false; + } + return sd_version_supports_video_generation(sd_ctx->sd->version); +} + enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_version_is_dit(sd_ctx->sd->version)) { @@ -2415,8 +2551,10 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me return EXPONENTIAL_SCHEDULER; } } - if (sample_method == LCM_SAMPLE_METHOD) { + if (sample_method == LCM_SAMPLE_METHOD || sample_method == TCD_SAMPLE_METHOD) { return LCM_SCHEDULER; + } else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) { + return SIMPLE_SCHEDULER; } return DISCRETE_SCHEDULER; } @@ -2457,6 +2595,7 @@ static float resolve_eta(sd_ctx_t* sd_ctx, return 0.0f; case EULER_A_SAMPLE_METHOD: case DPMPP2S_A_SAMPLE_METHOD: + case ER_SDE_SAMPLE_METHOD: return 1.0f; default:; } @@ -2588,6 +2727,20 @@ struct GenerationRequest { LOG_WARN("%scfg value out of expected range may produce unexpected results", prefix); } } + if (!std::isfinite(guidance->apg.eta)) { + guidance->apg.eta = 1.f; + } + if (!std::isfinite(guidance->apg.momentum)) { + guidance->apg.momentum = 0.f; + } + if (!std::isfinite(guidance->apg.norm_threshold)) { + guidance->apg.norm_threshold = 0.f; + } + guidance->apg.norm_threshold = std::max(0.f, guidance->apg.norm_threshold); + if (!std::isfinite(guidance->apg.norm_threshold_smoothing)) { + guidance->apg.norm_threshold_smoothing = 0.f; + } + guidance->apg.norm_threshold_smoothing = std::max(0.f, guidance->apg.norm_threshold_smoothing); } void resolve(sd_ctx_t* sd_ctx) { @@ -2674,6 +2827,19 @@ struct SamplePlan { sd_ctx->sd->get_image_seq_len(request->height, request->width), scheduler, sd_ctx->sd->version); + + { + std::ostringstream oss; + oss << std::setprecision(6); + oss << "| " << 
sd_scheduler_name(sample_params->scheduler) << " | " << sigmas[0] << " |"; + for (size_t i = 1; i < sigmas.size(); ++i) { + oss << sigmas[i] << " |"; + //if (i != sigmas.size() - 1) { + // oss << ","; + //} + } + LOG_DEBUG("Sigma schedule: \n%s", oss.str().c_str()); + } } eta = resolve_eta(sd_ctx, eta, sample_method); diff --git a/src/unet.hpp b/src/unet.hpp index 63e23eb93..2a24f14ed 100644 --- a/src/unet.hpp +++ b/src/unet.hpp @@ -217,11 +217,11 @@ class UnetModelBlock : public GGMLBlock { } else if (sd_version_is_unet_edit(version)) { in_channels = 8; } - if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) { + if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { num_res_blocks = 1; channel_mult = {1, 2, 4}; tiny_unet = true; - if (version == VERSION_SDXS) { + if (version == VERSION_SDXS_512_DS) { attention_resolutions = {4, 2}; // here just like SDXL } } @@ -264,6 +264,10 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SVD) { return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection); } else { + if (version == VERSION_SDXS_09 && n_head == 5) { + n_head = 1; // to carry a special case of sdxs_09 into CrossAttentionLayer, + d_head = 320; // works as long the product remains equal (5*64 == 1*320) + } return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection); } }; diff --git a/src/vae.hpp b/src/vae.hpp index 22be8867a..dc69535e8 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -69,7 +69,7 @@ struct VAE : public GGMLRunner { int scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { scale_factor = 16; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_uses_flux2_vae(version)) { scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { scale_factor = 1;