72 changes: 72 additions & 0 deletions torchbenchmark/models/flux/__init__.py
@@ -0,0 +1,72 @@
import torch
from torchbenchmark.tasks import COMPUTER_VISION
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
from torchbenchmark.util.model import BenchmarkModel

from .install import load_model_checkpoint


class Model(BenchmarkModel, HuggingFaceAuthMixin):
    task = COMPUTER_VISION.GENERATION

    DEFAULT_TRAIN_BSIZE = 1
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False
    # Skip deepcopy because it will OOM on an A100 40GB
    DEEPCOPY = False
    # Default eval precision on CUDA devices is fp16
    DEFAULT_EVAL_CUDA_PRECISION = "fp16"

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        HuggingFaceAuthMixin.__init__(self)
        super().__init__(
            test=test, device=device, batch_size=batch_size, extra_args=extra_args
        )
        self.pipe = load_model_checkpoint()
        self.example_inputs = {
            "prompt": "A cat holding a sign that says hello world",
            "height": 1024,
            "width": 1024,
            "guidance_scale": 3.5,
            "num_inference_steps": 50,
            "max_sequence_length": 512,
            "generator": torch.Generator("cpu").manual_seed(0),
        }
        self.pipe.to(self.device)

    def enable_fp16(self):
        # The checkpoint is already loaded in reduced precision (bfloat16),
        # so make this function a no-op.
        pass

    def get_module(self):
        # A common configuration:
        # - resolution = 1024x1024
        # - maximum sequence length = 512
        #
        # The easiest way to get this metadata is probably to run the pipeline
        # with the example inputs, set a breakpoint in the transformer module's
        # forward, and print the input tensor metadata.
        #
        # At this resolution there are 4096 image tokens (64x64 latent patches)
        # and 512 text tokens, so img_ids/txt_ids carry one positional id row
        # per image/text token respectively.
        inputs = {
            "hidden_states": torch.randn(
                1, 4096, 64, device=self.device, dtype=torch.bfloat16
            ),
            "encoder_hidden_states": torch.randn(
                1, 512, 4096, device=self.device, dtype=torch.bfloat16
            ),
            "pooled_projections": torch.randn(
                1, 768, device=self.device, dtype=torch.bfloat16
            ),
            "img_ids": torch.ones(1, 4096, 3, device=self.device, dtype=torch.bfloat16),
            "txt_ids": torch.ones(1, 512, 3, device=self.device, dtype=torch.bfloat16),
            "timestep": torch.tensor([1.0], device=self.device, dtype=torch.bfloat16),
            "guidance": torch.tensor([1.0], device=self.device, dtype=torch.bfloat16),
        }

        return self.pipe.transformer, inputs

    def set_module(self, mod):
        self.pipe.transformer = mod

    def train(self):
        raise NotImplementedError(
            "Train test is not implemented for the Flux model."
        )

    def eval(self):
        image = self.pipe(**self.example_inputs)
        return (image,)
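
The comment in get_module above suggests setting a breakpoint in the transformer's forward to recover the input metadata. As an alternative, here is a minimal sketch (not part of the PR) of a forward pre-hook that prints the shape and dtype of every tensor keyword argument the pipeline passes to the transformer. It assumes a loaded `pipe` and the `example_inputs` dictionary from Model.__init__ above; the hook fires once per denoising step.

import torch


def print_input_metadata(module, args, kwargs):
    # Report the shape/dtype of every tensor keyword argument.
    for name, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            print(f"{name}: shape={tuple(value.shape)}, dtype={value.dtype}")


# `pipe` and `example_inputs` come from the Model above (illustrative names).
handle = pipe.transformer.register_forward_pre_hook(
    print_input_metadata, with_kwargs=True
)
pipe(**example_inputs)  # one generation; prints the transformer inputs each step
handle.remove()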
27 changes: 27 additions & 0 deletions torchbenchmark/models/flux/install.py
@@ -0,0 +1,27 @@
import os
import warnings

import torch
from torchbenchmark.util.framework.diffusers import install_diffusers

MODEL_NAME = "black-forest-labs/FLUX.1-dev"


def load_model_checkpoint():
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16, safety_checker=None
    )

    return pipe


if __name__ == "__main__":
    install_diffusers()
    if "HUGGING_FACE_HUB_TOKEN" not in os.environ:
        warnings.warn(
            "Make sure to set `HUGGING_FACE_HUB_TOKEN` so you can download the weights"
        )
    else:
        load_model_checkpoint()
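
FLUX.1-dev is a gated checkpoint, so the download requires a Hugging Face token. A minimal usage sketch, not part of the PR, showing one way to authenticate programmatically before calling load_model_checkpoint; the environment-variable name mirrors the check in install.py, and the driving script itself is an assumption.

import os

from huggingface_hub import login

from torchbenchmark.models.flux.install import load_model_checkpoint

# Assumes the token is exported as HUGGING_FACE_HUB_TOKEN, matching install.py;
# huggingface_hub also picks up HF_TOKEN automatically.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if token:
    login(token=token)  # registers the token for subsequent gated downloads

pipe = load_model_checkpoint()  # downloads and caches black-forest-labs/FLUX.1-dev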
10 changes: 10 additions & 0 deletions torchbenchmark/models/flux/metadata.yaml
@@ -0,0 +1,10 @@
devices:
  NVIDIA A100-SXM4-40GB:
Contributor Author:
Not sure why it's all A100-40GB in this repo. I got about a 1.3x inference speedup on an A100 80GB.

Contributor (@huydhn, Nov 19, 2025):
For more context, A100 40GB was the SKU we had from AWS at the beginning, when A100 was still new. Nowadays the TorchInductor benchmark runs on H100, so there is no incentive to migrate the A100 CI runners to the 80GB version anymore.

Contributor Author:
So are these annotations really used anywhere?

Contributor:
Strictly speaking, https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml is still running and failing without anyone caring. I think I should submit a PR to stop this, but it's the compiler team's decision.

    eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- device: cpu