From f0b2a09591de5cf276bbfa7ce06a3d35f508da84 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Wed, 19 Nov 2025 13:25:58 -0800 Subject: [PATCH] Add Flux benchmark This adds a benchmark for the Flux image generation pipeline. Specifically, it only benchmarks the diffusion transformer (and omits the text encoder and vae, which don't take up much time for the e2e generation in Flux). Needs https://github.com/pytorch/pytorch/pull/168176 to run in pytorch repo: ``` python ./benchmarks/dynamo/torchbench.py --accuracy --inference --backend=inductor --only flux python ./benchmarks/dynamo/torchbench.py --performance --inference --backend=inductor --only flux ``` --- torchbenchmark/models/flux/__init__.py | 72 ++++++++++++++++++++++++ torchbenchmark/models/flux/install.py | 27 +++++++++ torchbenchmark/models/flux/metadata.yaml | 10 ++++ 3 files changed, 109 insertions(+) create mode 100644 torchbenchmark/models/flux/__init__.py create mode 100644 torchbenchmark/models/flux/install.py create mode 100644 torchbenchmark/models/flux/metadata.yaml diff --git a/torchbenchmark/models/flux/__init__.py b/torchbenchmark/models/flux/__init__.py new file mode 100644 index 0000000000..eb302c2a68 --- /dev/null +++ b/torchbenchmark/models/flux/__init__.py @@ -0,0 +1,72 @@ +import torch +from torchbenchmark.tasks import COMPUTER_VISION +from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin +from torchbenchmark.util.model import BenchmarkModel + +from .install import load_model_checkpoint + + +class Model(BenchmarkModel, HuggingFaceAuthMixin): + task = COMPUTER_VISION.GENERATION + + DEFAULT_TRAIN_BSIZE = 1 + DEFAULT_EVAL_BSIZE = 1 + ALLOW_CUSTOMIZE_BSIZE = False + # Skip deepcopy because it will oom on A100 40GB + DEEPCOPY = False + # Default eval precision on CUDA device is fp16 + DEFAULT_EVAL_CUDA_PRECISION = "fp16" + + def __init__(self, test, device, batch_size=None, extra_args=[]): + HuggingFaceAuthMixin.__init__(self) + super().__init__( + test=test, device=device, batch_size=batch_size, extra_args=extra_args + ) + self.pipe = load_model_checkpoint() + self.example_inputs = { + "prompt": "A cat holding a sign that says hello world", + "height": 1024, + "width": 1024, + "guidance_scale": 3.5, + "num_inference_steps": 50, + "max_sequence_length": 512, + "generator": torch.Generator("cpu").manual_seed(0), + } + self.pipe.to(self.device) + + def enable_fp16(self): + # This model uses fp16 by default + # Make this function no-op. + pass + + def get_module(self): + # A common configuration: + # - resolution = 1024x1024 + # - maximum sequence length = 512 + # + # The easiest way to get these metadata is probably to run the pipeline + # with the example inputs, and then breakpoint at the transformer module + # forward and print out the input tensor metadata. + inputs = { + "hidden_states": torch.randn(1, 4096, 64, device=self.device, dtype=torch.bfloat16), + "encoder_hidden_states": torch.randn(1, 512, 4096, device=self.device, dtype=torch.bfloat16), + "pooled_projections": torch.randn(1, 768, device=self.device, dtype=torch.bfloat16), + "img_ids": torch.ones(1, 512, 3, device=self.device, dtype=torch.bfloat16), + "txt_ids": torch.ones(1, 4096, 3, device=self.device, dtype=torch.bfloat16), + "timestep": torch.tensor([1.0], device=self.device, dtype=torch.bfloat16), + "guidance": torch.tensor([1.0], device=self.device, dtype=torch.bfloat16), + } + + return self.pipe.transformer, inputs + + def set_module(self, mod): + self.pipe.transformer = mod + + def train(self): + raise NotImplementedError( + "Train test is not implemented for the stable diffusion model." + ) + + def eval(self): + image = self.pipe(**self.example_inputs) + return (image,) diff --git a/torchbenchmark/models/flux/install.py b/torchbenchmark/models/flux/install.py new file mode 100644 index 0000000000..56a315ac14 --- /dev/null +++ b/torchbenchmark/models/flux/install.py @@ -0,0 +1,27 @@ +import os +import warnings + +import torch +from torchbenchmark.util.framework.diffusers import install_diffusers + +MODEL_NAME = "black-forest-labs/FLUX.1-dev" + + +def load_model_checkpoint(): + from diffusers import FluxPipeline + + pipe = FluxPipeline.from_pretrained( + MODEL_NAME, torch_dtype=torch.bfloat16, safety_checker=None + ) + + return pipe + + +if __name__ == "__main__": + install_diffusers() + if not "HUGGING_FACE_HUB_TOKEN" in os.environ: + warnings.warn( + "Make sure to set `HUGGINGFACE_HUB_TOKEN` so you can download weights" + ) + else: + load_model_checkpoint() diff --git a/torchbenchmark/models/flux/metadata.yaml b/torchbenchmark/models/flux/metadata.yaml new file mode 100644 index 0000000000..4a03e1edcb --- /dev/null +++ b/torchbenchmark/models/flux/metadata.yaml @@ -0,0 +1,10 @@ +devices: + NVIDIA A100-SXM4-40GB: + eval_batch_size: 32 +eval_benchmark: false +eval_deterministic: false +eval_nograd: true +train_benchmark: false +train_deterministic: false +not_implemented: +- device: cpu