
Support different-length pos/neg prompts for FLUX.1-schnell variants like Chroma #11120

Open · wants to merge 1 commit into main
Conversation


@josephrocca josephrocca commented Mar 20, 2025

What does this PR do?

Context:

Chroma is a large-scale Apache 2.0 fine-tune of FLUX.1 Schnell. It is currently one of the top trending text-to-image models, and has been for several days now:

image

Someone recently asked about diffusers support:

I've currently got it working in diffusers:

but as you can see from the comments at the top of that script, it requires a couple of changes to diffusers source code for it to work out of the box.

Changes:

One such change is due to Chroma requiring masking/truncation of prompts (all but the final padding token).
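For concreteness, a rough sketch of what that truncation looks like on the encoding side (illustrative only, not code from this PR; the Schnell repo is just a convenient source for the T5 tokenizer/encoder here, and Chroma's exact rule may differ in detail):

import torch
from transformers import T5EncoderModel, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
text_encoder = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)

def encode_truncated(prompt, max_sequence_length=512):
    inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeds = text_encoder(inputs.input_ids)[0]
    keep = int(inputs.attention_mask.sum()) + 1  # real tokens plus one padding token
    return embeds[:, :keep]

pos_embeds = encode_truncated("a courtroom sketch of a bored human judge")
neg_embeds = encode_truncated("low quality")
print(pos_embeds.shape, neg_embeds.shape)  # different sequence lengths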

Currently, diffusers requires that the positive and negative prompts have the same length, since it assumes the full 512 T5 tokens will be used for both.

So check_inputs blocks this, and if we remove that check, we get this error:

  File "/opt/conda/lib/python3.11/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 904, in __call__
    neg_noise_pred = self.transformer(
                     ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 522, in forward
    encoder_hidden_states, hidden_states = block(
                                           ^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 180, in forward
    attention_outputs = self.attn(
                        ^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/diffusers/models/attention_processor.py", line 588, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/diffusers/models/attention_processor.py", line 2318, in __call__
    query = apply_rotary_emb(query, image_rotary_emb)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/diffusers/models/embeddings.py", line 1208, in apply_rotary_emb
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
           ~~~~~~~~~~^~~~~

So we need to pass the negative prompt's text IDs into the negative forward pass, instead of passing the positive prompt's text IDs into both.
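For illustration, the kind of change involved in the true-CFG branch of pipeline_flux.py looks roughly like this (a sketch, not the exact diff in this PR; negative_text_ids here stands for the text IDs returned by encode_prompt when encoding the negative prompt):

# use the text IDs from encoding the negative prompt for the negative
# forward pass, instead of reusing the positive prompt's text_ids
neg_noise_pred = self.transformer(
    hidden_states=latents,
    timestep=timestep / 1000,
    guidance=guidance,
    pooled_projections=negative_pooled_prompt_embeds,
    encoder_hidden_states=negative_prompt_embeds,
    txt_ids=negative_text_ids,  # was: txt_ids=text_ids
    img_ids=latent_image_ids,
    joint_attention_kwargs=self.joint_attention_kwargs,
    return_dict=False,
)[0]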

Who can review?

@yiyixuxu @sayakpaul @DN6

@sayakpaul (Member)

Thanks for your PR!

Before we get to reviewing the PR, could you please provide some side-by-side results with Schnell and Chroma on the same inputs (including the seeds)?

@sayakpaul requested a review from @yiyixuxu on March 20, 2025 at 06:39
@asomoza (Member) commented Mar 20, 2025

There's some more context in this issue: #11010

P.S.: I tested V13

@josephrocca (Author) commented Mar 20, 2025

could you please provide some side-by-side results with Schnell and Chroma on the same inputs (including the seeds)?

Oh, sure thing, see links below for some comparison grids - but some quick notes:

  • If you're wondering how different/trained/diverged Chroma is from Schnell, these pics may be useful.
  • If you're looking to understand how "high quality" or "aesthetic" Chroma is, then these will not be particularly useful, since Chroma is still training and has so far been trained almost entirely on 512×512px images, with no post-training preference tuning. I think it's not ready to be compared on aesthetics or fine details at this point.
  • I used ChatGPT 4.5 to generate the prompts used for these: https://chatgpt.com/share/67dc3cb0-fbb4-8007-a661-0184968418ad

Image Grids:

Also, if you skim the above grids, some of the images from Schnell look quite "clean" and coherent - that's definitely an advantage Schnell currently has - but in some cases Chroma should arguably win based on the style specified in the prompt ("courtroom sketch"):

image

Compared to Chroma:

image (1)

And Chroma with aesthetic keywords to try to emulate aesthetic tuning that Schnell has:

image (2)

You can see that although Schnell's is cleaner (and arguably slightly more coherent, though sample size is a bit small here), Chroma is definitely more faithful to the style specified in the prompt.

Also note that, as with other models that haven't had CFG baked in, you can get entirely different 'vibes' by tweaking Chroma's CFG scale - above I've used 5, with 20 steps (lodestone is currently doing some small-scale experiments as a precursor to a few-step LoRA for Chroma).
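(For reference, those settings map onto a Flux-style pipeline call roughly like this - an illustrative sketch only, with pipe and the embedding variables assumed to be prepared as in the script linked above:)

image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    true_cfg_scale=5.0,      # the "CFG" value mentioned above
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]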

My experience so far with testing Chroma is that it has a lot more "soul" than Schnell and Dev - it's quite fun to play with.

@nitinmukesh
Awesome. I looked at some of the outputs; the Anubis one is amazing, among many others.

@josephrocca (Author) commented Mar 20, 2025

Side note: Playing with the official Chroma ComfyUI workflow just now with v15, I noticed that there are some potential differences in quality/coherence compared to my diffusers code which generated the above images - e.g. notice the alignment to the "bored human judge" in this seed=0 image with Chroma, which was less evident in the above examples:

So please take the above example images with a pinch of salt - Chroma quality may be better than what I've shown here. It could be due to quantization, or subtleties of ComfyUI's sampling. I'd need bigger sample sizes to know whether the ComfyUI outputs are actually better, but I'm going to sleep now :)

@hlky (Member) left a comment

Thanks @josephrocca

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@josephrocca (Author)

I'm not sure about the conventions in diffusers, but since prompt truncation is equivalent to prompt masking, I wonder whether it'd be worth also/instead supporting masking for flux?
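A quick standalone sanity check of the key/value side of that equivalence (illustrative, plain PyTorch, independent of the diffusers code below):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, seq, keep, dim = 4, 16, 5, 8
q = torch.randn(1, heads, seq, dim)
k = torch.randn(1, heads, seq, dim)
v = torch.randn(1, heads, seq, dim)

# boolean mask: True = attend; only the first `keep` key positions are "real" tokens
mask = torch.zeros(1, 1, seq, seq, dtype=torch.bool)
mask[..., :keep] = True

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
truncated = F.scaled_dot_product_attention(q, k[:, :, :keep, :], v[:, :, :keep, :])

print(torch.allclose(masked, truncated, atol=1e-5))  # True (up to numerical noise)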

This is working code, inserted here in transformer_flux.py:

# Hook for an optional T5 attention mask passed in via joint_attention_kwargs.
if joint_attention_kwargs is not None and "encoder_attention_mask" in joint_attention_kwargs and joint_attention_kwargs["encoder_attention_mask"] is not None:
    encoder_attention_mask = joint_attention_kwargs.pop("encoder_attention_mask")
    max_seq_length = encoder_hidden_states.shape[1]
    seq_length = encoder_attention_mask.sum(dim=-1)
    batch_size = encoder_attention_mask.shape[0]
    # Chroma-style masking: keep the real tokens plus one padding token per sample.
    encoder_attention_mask_with_padding = encoder_attention_mask.clone()
    for i in range(batch_size):
        current_seq_len = int(seq_length[i].item())
        if current_seq_len < max_seq_length:
            available_padding = max_seq_length - current_seq_len
            tokens_to_unmask = min(1, available_padding)  # unmask one of the padding tokens
            encoder_attention_mask_with_padding[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
    # Image tokens are always attended to, so append an all-ones mask for them.
    attention_mask = torch.cat(
        [
            encoder_attention_mask_with_padding,
            torch.ones([hidden_states.shape[0], hidden_states.shape[1]], device=encoder_attention_mask.device),
        ],
        dim=1,
    )
    # Outer product to build a 2D (query x key) mask; note this matmul collapses the
    # batch dimension, so it assumes batch size 1 (or identical masks across the batch).
    attention_mask = attention_mask.float().T @ attention_mask.float()
    # Expand to [batch, heads, seq, seq] and store it where the attention processor
    # will pick it up on this and subsequent forward passes.
    attention_mask = (
        attention_mask[None, None, ...]
        .repeat(encoder_hidden_states.shape[0], self.config.num_attention_heads, 1, 1)
        .int()
        .bool()
    )
    joint_attention_kwargs["attention_mask"] = attention_mask
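On the calling side, the mask could then be supplied roughly like this (a sketch; pipe is assumed to be an already-loaded Flux-style pipeline with the patch above applied):

prompt = "a courtroom sketch of a bored human judge"
text_inputs = pipe.tokenizer_2(  # the T5 tokenizer
    prompt,
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)

image = pipe(
    prompt=prompt,
    num_inference_steps=20,
    joint_attention_kwargs={"encoder_attention_mask": text_inputs.attention_mask.to(pipe.device)},
).images[0]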

@bghira (Contributor) commented Mar 23, 2025

Can you demonstrate an example where zeroing the end of the prompt is equivalent to attention masking, where the softmax scores for the padding sequence are near -infinity?
