
speeding up sdxl and fixes #2

Merged
merged 1 commit into from Dec 14, 2023
8 changes: 5 additions & 3 deletions run.sh
@@ -4,7 +4,9 @@ python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --change_comp_config && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --change_comp_config --enable_fused_projections && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --do_quant && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --do_quant "int8dynamic" && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --do_quant "int8weightonly" && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --do_quant "int4weightonly" && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --do_quant --change_comp_config && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections --do_quant && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections --do_quant --change_comp_config && python prepare_plot.py --final_csv_filename collated_results_peft.csv
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections --do_quant "int8dynamic" && \
python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections --do_quant "int8dynamic" --change_comp_config && python prepare_plot.py --final_csv_filename collated_results_peft.csv
18 changes: 12 additions & 6 deletions run_benchmark.py
@@ -43,20 +43,26 @@ def main(args) -> dict:
time=time,
memory=memory,
)
return data_dict
img = pipeline(
prompt=PROMPT,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
).images[0]


return data_dict, img


if __name__ == "__main__":
parser = create_parser()
args = parser.parse_args()
print(args)

if not args.compile_unet:
args.compile_mode = "NA"

data_dict = main(args)
data_dict, img = main(args)

name = (
CKPT_ID.replace("/", "_")
+ f"fp16@{not args.no_fp16}-sdpa@{not args.no_sdpa}-bs@{args.batch_size}-fuse@{args.enable_fused_projections}-upcast_vae@{args.upcast_vae}-steps@{args.num_inference_steps}-unet@{args.compile_unet}-vae@{args.compile_vae}-mode@{args.compile_mode}-change_comp_config@{args.change_comp_config}-do_quant@{args.do_quant}.csv"
+ f"bf16@{not args.no_bf16}-sdpa@{not args.no_sdpa}-bs@{args.batch_size}-fuse@{args.enable_fused_projections}-upcast_vae@{args.upcast_vae}-steps@{args.num_inference_steps}-unet@{args.compile_unet}-vae@{args.compile_vae}-mode@{args.compile_mode}-change_comp_config@{args.change_comp_config}-do_quant@{args.do_quant}-tag@{args.tag}.csv"
)
img.save(f"{name}.jpeg")
write_to_csv(name, data_dict)
13 changes: 8 additions & 5 deletions utils/benchmarking_utils.py
@@ -10,7 +10,7 @@
BENCHMARK_FIELDS = [
"pipeline_cls",
"ckpt_id",
"fp16",
"bf16",
"sdpa",
"fused_qkv_projections",
"upcast_vae",
@@ -24,14 +24,15 @@
"time (secs)",
"memory (gbs)",
"actual_gpu_memory (gbs)",
"tag"
]
TOTAL_GPU_MEMORY = torch.cuda.get_device_properties(0).total_memory / (1024**3)


def create_parser():
"""Creates CLI args parser."""
parser = argparse.ArgumentParser()
parser.add_argument("--no_fp16", action="store_true")
parser.add_argument("--no_bf16", action="store_true")
parser.add_argument("--no_sdpa", action="store_true")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=30)
@@ -43,7 +44,8 @@ def create_parser():
"--compile_mode", type=str, default="reduce-overhead", choices=["reduce-overhead", "max-autotune"]
)
parser.add_argument("--change_comp_config", action="store_true")
parser.add_argument("--do_quant", action="store_true")
parser.add_argument("--do_quant", type=str, default=None)
parser.add_argument("--tag", type=str, default="")
return parser


@@ -75,8 +77,8 @@ def generate_csv_dict(
data_dict = {
"pipeline_cls": pipeline_cls,
"ckpt_id": ckpt,
"fp16": args.no_fp16,
"sdpa": args.no_sdpa,
"bf16": not args.no_bf16,
"sdpa": not args.no_sdpa,
"fused_qkv_projections": args.enable_fused_projections,
"upcast_vae": args.upcast_vae,
"batch_size": args.batch_size,
@@ -89,6 +91,7 @@
"time (secs)": time,
"memory (gbs)": memory,
"actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
"tag": args.tag,
}
return data_dict

76 changes: 42 additions & 34 deletions utils/pipeline_utils.py
@@ -1,41 +1,29 @@
import torch

from torchao.quantization import change_linear_weights_to_int8_woqtensors, change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int4_woqtensors, swap_conv2d_1x1_to_linear
from diffusers import AutoencoderKL, DiffusionPipeline

def dynamic_quant_filter_fn(mod, *args):
return isinstance(mod, torch.nn.Linear) and mod.in_features > 16 and not (mod.in_features, mod.out_features) in [
(320, 640),
(320, 1280),
(2816, 1280),
(1280, 640),
(1280, 320),
(512, 512),
(512, 1536),
(2048, 2560),
(2048, 1280),
]


CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
PROMPT = "ghibli style, a fantasy landscape with castles"


def apply_dynamic_quant_fn(m):
"""Applies weight-only and dynamic quantization in a selective manner."""
from torchao.quantization.dynamic_quant import DynamicallyPerAxisQuantizedLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.weight_only import WeightOnlyInt8QuantLinear

def from_float(mod):
if hasattr(mod, "lora_layer"):
assert mod.lora_layer is None
# if mod.weight.size(1) == 1280 and mod.weight.size(0) == 1280:
# return WeightOnlyInt8QuantLinear.from_float(mod)
# if mod.weight.size(1) == 640 and mod.weight.size(0) == 640:
# return WeightOnlyInt8QuantLinear.from_float(mod)
if mod.weight.size(1) == 5120 and mod.weight.size(0) == 1280:
return DynamicallyPerAxisQuantizedLinear.from_float(mod)
# if mod.weight.size(1) == 2560 and mod.weight.size(0) == 640:
# return DynamicallyPerAxisQuantizedLinear.from_float(mod)
return mod

_replace_with_custom_fn_if_matches_filter(
m,
from_float,
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)

# torch._inductor.config.fx_graph_cache = True # speeds up recompile, may reduce performance

def load_pipeline(args):
"""Loads the SDXL pipeline."""
dtype = torch.float32 if args.no_fp16 else torch.float16
dtype = torch.float32 if args.no_bf16 else torch.bfloat16
print(f"Using dtype: {dtype}")
pipe = DiffusionPipeline.from_pretrained(CKPT_ID, torch_dtype=dtype, use_safetensors=True)

@@ -59,36 +47,56 @@ def load_pipeline(args):
if args.compile_unet:
pipe.unet.to(memory_format=torch.channels_last)
print("Compile UNet")

swap_conv2d_1x1_to_linear(pipe.unet)
Member:
How much does it help? Should we make it configurable?

Contributor Author:
Not a lot, to be honest. The shapes looked like they would be good candidates for quantization, but the improvements turned out to be relatively minor. Note that I've altered the APIs so they can be called on individual modules if needed, so you could call this on the individual conv modules we'd like to target.

Contributor Author:
Also, the code sets torch._inductor.config.conv_1x1_as_mm = True, so it is already doing this, just at a point after we apply quantization. It's unclear whether the two ways of rewriting conv1x1 -> linear differ in performance on a shape-by-shape basis, but overall the benchmark results show the impact is relatively minor.
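
To make the per-module usage mentioned above concrete, the following is a minimal sketch (not part of this PR) of calling the torchao conversions on one targeted UNet block rather than the whole model. The mid_block target, the CUDA device, and the simplified filter function are illustrative assumptions; it relies on the same torchao calls imported in utils/pipeline_utils.py.

# Minimal sketch: selective conversion of a single submodule.
import torch
from diffusers import DiffusionPipeline
from torchao.quantization import (
    change_linear_weights_to_int8_dqtensors,
    swap_conv2d_1x1_to_linear,
)

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to("cuda")

# Rewrite 1x1 convs as linears only inside the block we want to target
# (mid_block is an illustrative choice).
swap_conv2d_1x1_to_linear(pipe.unet.mid_block)

# Dynamically quantize the linears in that same block, skipping very small layers.
def filter_fn(mod, *args):
    return isinstance(mod, torch.nn.Linear) and mod.in_features > 16

change_linear_weights_to_int8_dqtensors(pipe.unet.mid_block, filter_fn)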

if args.compile_mode == "max-autotune" and args.change_comp_config:
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

if args.do_quant:
print("Apply quantization to UNet")
apply_dynamic_quant_fn(pipe.unet)
if args.do_quant == "int4weightonly":
change_linear_weights_to_int4_woqtensors(pipe.unet)
elif args.do_quant == "int8weightonly":
change_linear_weights_to_int8_woqtensors(pipe.unet)
elif args.do_quant == "int8dynamic":
change_linear_weights_to_int8_dqtensors(pipe.unet, dynamic_quant_filter_fn)
else:
raise ValueError(f"Unknown do_quant value: {args.do_quant}.")
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

if args.compile_mode == "max-autotune":
pipe.unet = torch.compile(pipe.unet, mode=args.compile_mode)
pipe.unet = torch.compile(pipe.unet, mode=args.compile_mode, fullgraph=True)
else:
pipe.unet = torch.compile(pipe.unet, mode=args.compile_mode, fullgraph=True)

if args.compile_vae:
pipe.vae.to(memory_format=torch.channels_last)
print("Compile VAE")

swap_conv2d_1x1_to_linear(pipe.vae)
if args.compile_mode == "max-autotune" and args.change_comp_config:
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

if args.do_quant:
print("Apply quantization to VAE")
apply_dynamic_quant_fn(pipe.vae)
if args.do_quant == "int4weightonly":
change_linear_weights_to_int4_woqtensors(pipe.vae)
elif args.do_quant == "int8weightonly":
change_linear_weights_to_int8_woqtensors(pipe.vae)
elif args.do_quant == "int8dynamic":
change_linear_weights_to_int8_dqtensors(pipe.vae, dynamic_quant_filter_fn)
else:
raise ValueError(f"Unknown do_quant value: {args.do_quant}.")
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

if args.compile_mode == "max-autotune":
pipe.vae.decode = torch.compile(pipe.vae.decode, mode=args.compile_mode)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode=args.compile_mode, fullgraph=True)
else:
pipe.vae.decode = torch.compile(pipe.vae.decode, mode=args.compile_mode, fullgraph=True)
