6 changes: 4 additions & 2 deletions README.md
@@ -37,7 +37,8 @@ API and command-line option may change frequently.***
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [Chroma](./docs/chroma.md)
- [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
@@ -118,7 +119,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe

- [SD1.x/SD2.x/SDXL](./docs/sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Chroma](./docs/chroma.md)
- [🔥Qwen Image](./docs/qwen_image.md)
Binary file added assets/flux2/example.png
146 changes: 98 additions & 48 deletions conditioner.hpp
@@ -2,7 +2,7 @@
#define __CONDITIONER_HPP__

#include "clip.hpp"
#include "qwenvl.hpp"
#include "llm.hpp"
#include "t5.hpp"

struct SDCondition {
@@ -1623,61 +1623,74 @@ struct T5CLIPEmbedder : public Conditioner {
}
};

struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
Qwen::Qwen2Tokenizer tokenizer;
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;

Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool enable_vision = false) {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
offload_params_to_cpu,
tensor_storage_map,
"text_encoders.qwen2vl",
enable_vision);
struct LLMEmbedder : public Conditioner {
SDVersion version;
std::shared_ptr<LLM::BPETokenizer> tokenizer;
std::shared_ptr<LLM::LLMRunner> llm;

LLMEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_QWEN_IMAGE,
const std::string prefix = "",
bool enable_vision = false)
: version(version) {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_zimage(version)) {
arch = LLM::LLMArch::QWEN3_4B;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
tokenizer = std::make_shared<LLM::MistralTokenizer>();
} else {
tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
}
llm = std::make_shared<LLM::LLMRunner>(arch,
backend,
offload_params_to_cpu,
tensor_storage_map,
"text_encoders.llm",
enable_vision);
}
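// A single LLMEmbedder now backs every LLM-conditioned model: the SDVersion
// selects the architecture (Qwen2.5-VL for Qwen Image, Mistral Small 3.2 for
// FLUX.2, Qwen3-4B for Z-Image) and the matching tokenizer above.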

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
llm->get_param_tensors(tensors, "text_encoders.llm");
}

void alloc_params_buffer() override {
qwenvl->alloc_params_buffer();
llm->alloc_params_buffer();
}

void free_params_buffer() override {
qwenvl->free_params_buffer();
llm->free_params_buffer();
}

size_t get_params_buffer_size() override {
size_t buffer_size = 0;
buffer_size += qwenvl->get_params_buffer_size();
buffer_size += llm->get_params_buffer_size();
return buffer_size;
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (qwenvl) {
qwenvl->set_weight_adapter(adapter);
if (llm) {
llm->set_weight_adapter(adapter);
}
}

std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
size_t system_prompt_length = 0,
bool padding = false) {
std::pair<int, int> attn_range,
size_t max_length = 0,
bool padding = false) {
std::vector<std::pair<std::string, float>> parsed_attention;
if (system_prompt_length > 0) {
parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
parsed_attention.insert(parsed_attention.end(),
new_parsed_attention.begin(),
new_parsed_attention.end());
} else {
parsed_attention = parse_prompt_attention(text);
}

parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
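// Only text inside [attn_range.first, attn_range.second) goes through
// parse_prompt_attention(); the chat-template prefix and suffix always keep
// weight 1.0. Illustrative example: with the Z-Image template
// "<|im_start|>user\n" + text, attn_range = {17, 17 + text.size()}, so
// "(cat:1.3)" in the user text yields weight 1.3 for the "cat" tokens while
// the template tokens stay at 1.0.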
{
std::stringstream ss;
ss << "[";
@@ -1693,12 +1706,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}

tokenizer.pad_tokens(tokens, weights, max_length, padding);
tokenizer->pad_tokens(tokens, weights, max_length, padding);

// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
@@ -1713,9 +1726,10 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
const ConditionerParams& conditioner_params) override {
std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
size_t system_prompt_length = 0;
std::pair<int, int> prompt_attn_range;
int prompt_template_encode_start_idx = 34;
if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
std::set<int> out_layers;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6;
@@ -1727,7 +1741,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
int height = image.height;
int width = image.width;
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
@@ -1757,7 +1771,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
resized_image.data = nullptr;

ggml_tensor* image_embed = nullptr;
qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + image_embed->ne[1] + 6;

@@ -1771,17 +1785,41 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
}

prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";

system_prompt_length = prompt.size();

prompt += img_prompt;

prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();

prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
prompt_template_encode_start_idx = 0;
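// FLUX.2 conditioning reads intermediate decoder hidden states rather than
// only the final layer; these layer indices are forwarded to llm->compute()
// below.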
out_layers = {10, 20, 30};

prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";

prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();

prompt += "[/INST]";
} else {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
prompt_template_encode_start_idx = 34;
if (sd_version_is_zimage(version)) {
prompt_template_encode_start_idx = 0;
prompt = "<|im_start|>user\n";
} else {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
}

prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();

prompt += "<|im_end|>\n<|im_start|>assistant\n";
}

auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);

@@ -1790,11 +1828,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

qwenvl->compute(n_threads,
input_ids,
image_embeds,
&hidden_states,
work_ctx);
llm->compute(n_threads,
input_ids,
image_embeds,
out_layers,
&hidden_states,
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_ext_tensor_mean(tensor);
@@ -1813,14 +1852,25 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

int64_t zero_pad_len = 0;
if (sd_version_is_flux2(version)) {
int64_t min_length = 512;
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
}
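// FLUX.2 zero-pads the prompt embedding to at least 512 tokens: the copy
// loop below writes zeros for positions past the original hidden_states
// length.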

ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32,
hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx,
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
hidden_states->ne[2]);

ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
float value = 0.f;
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
}
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});

93 changes: 89 additions & 4 deletions denoiser.hpp
@@ -356,7 +356,7 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;

virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
virtual std::vector<float> get_sigmas(uint32_t n, int /*image_seq_len*/, scheduler_t scheduler_type, SDVersion version) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) {
@@ -582,10 +582,14 @@ struct FluxFlowDenoiser : public Denoiser {
set_parameters(shift);
}

void set_parameters(float shift = 1.15f) {
void set_shift(float shift) {
this->shift = shift;
for (int i = 1; i < TIMESTEPS + 1; i++) {
sigmas[i - 1] = t_to_sigma(i / TIMESTEPS * TIMESTEPS);
}

void set_parameters(float shift) {
set_shift(shift);
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(i);
}
}
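// set_shift() is factored out so Flux2FlowDenoiser (below) can retune the
// shift per generation from the image sequence length; set_parameters()
// additionally rebuilds the sigmas table.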

@@ -627,6 +631,87 @@
}
};

struct Flux2FlowDenoiser : public FluxFlowDenoiser {
Flux2FlowDenoiser() = default;

float compute_empirical_mu(uint32_t n, int image_seq_len) {
const float a1 = 8.73809524e-05f;
const float b1 = 1.89833333f;
const float a2 = 0.00016927f;
const float b2 = 0.45666666f;

if (image_seq_len > 4300) {
float mu = a2 * image_seq_len + b2;
return mu;
}

float m_200 = a2 * image_seq_len + b2;
float m_10 = a1 * image_seq_len + b1;

float a = (m_200 - m_10) / 190.0f;
float b = m_200 - 200.0f * a;
float mu = a * n + b;

return mu;
}
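// Illustrative check (assumed values): with image_seq_len = 4096 and
// n = 20 steps, m_10 ~= 2.256 and m_200 ~= 1.150, giving a ~= -0.00582,
// b ~= 2.314 and mu ~= 2.20. Above 4300 tokens the long-sequence line
// a2 * image_seq_len + b2 is used directly.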

std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version) override {
float mu = compute_empirical_mu(n, image_seq_len);
LOG_DEBUG("Flux2FlowDenoiser: set shift to %.3f", mu);
set_shift(mu);
return Denoiser::get_sigmas(n, image_seq_len, scheduler_type, version);
}
};

// Z-Image flow matching denoiser
struct ZImageFlowDenoiser : public Denoiser {
float sigmas[TIMESTEPS];
float shift = 3.0f;

ZImageFlowDenoiser(float shift = 3.0f) {
this->shift = shift;
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(i);
}
}

float sigma_min() override {
return sigmas[0];
}

float sigma_max() override {
return sigmas[TIMESTEPS - 1];
}

float sigma_to_t(float sigma) override {
return 1.0f - sigma;
}

float t_to_sigma(float t) override {
float sigma_raw = (t + 1) / TIMESTEPS;
return shift * sigma_raw / (1.0f + (shift - 1.0f) * sigma_raw);
}
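// Standard flow-matching time shift: with s = (t + 1) / TIMESTEPS,
// sigma = shift * s / (1 + (shift - 1) * s); shift = 3 biases sampling
// toward high noise (e.g. s = 0.5 maps to sigma = 0.75).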

std::vector<float> get_scalings(float sigma) override {
float c_skip = 1.0f;
float c_out = sigma;
float c_in = 1.0f;
return {c_skip, c_out, c_in};
}

ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
ggml_ext_tensor_scale_inplace(noise, sigma);
ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
ggml_ext_tensor_add_inplace(latent, noise);
return latent;
}
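// Flow-matching forward process: latent becomes
// sigma * noise + (1 - sigma) * x0, computed in place;
// inverse_noise_scaling() below removes the remaining (1 - sigma) factor
// at the end of sampling.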

ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
return latent;
}
};

typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;

// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t