Fix BERT benchmark for 2 gcd (facebookincubator#6)
* fixed handling of batch_size > 1

* load the prebuilt .so file for benchmark runs
zjing14 committed Oct 23, 2022
1 parent 3b6a195 commit 19f94f2
Showing 5 changed files with 54 additions and 33 deletions.
3rdparty/composable_kernel (2 changes: 1 addition & 1 deletion)

examples/03_bert/benchmark_ait.py (62 changes: 40 additions & 22 deletions)
@@ -184,6 +184,7 @@ def compile_module(
     use_fp16_acc: bool,
     encoders_only: bool,
     pt_model: torch.nn.Module,
+    benchmark: bool
 ) -> None:
     model_name = f"BERT_{activation}_{batch_size}_{seq_length}"
     target = detect_target(use_fp16_acc=use_fp16_acc)
@@ -207,7 +208,10 @@

     params = map_pt_params(model, pt_model, batch_size, seq_length)

-    mod = compile_model(y, target, "./tmp", model_name)
+    if benchmark:
+        mod = Model(os.path.join("./tmp", model_name, "test.so"))
+    else:
+        mod = compile_model(y, target, "./tmp", model_name)

     for k, v in params.items():
         mod.set_constant_with_tensor(k, v)
@@ -267,30 +271,44 @@ def compile_and_benchmark(
     pt_model.eval()
     hidden_size = pt_model.config.hidden_size

-    if batch_size < 1:
-        batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
-    else:
-        batch_sizes = [batch_size]
-
-    if seq_length < 1:
-        seq_lengths = (
-            [64, 128, 384, 512, 1024, 4096] if encoders_only else [64, 128, 384, 512]
-        )
-    else:
-        seq_lengths = [seq_length]
-
-    for seq_length in seq_lengths:
-        for bs in batch_sizes:
-            mod = compile_module(
-                bs,
-                seq_length,
-                hidden_size,
-                activation,
-                use_fp16_acc,
-                encoders_only,
-                pt_model,
-            )
-            benchmark(bs, seq_length, hidden_size, mod, graph_mode, encoders_only)
+    if batch_size >= 1 and seq_length >= 1:
+        mod = compile_module(
+            batch_size,
+            seq_length,
+            hidden_size,
+            activation,
+            use_fp16_acc,
+            encoders_only,
+            pt_model,
+            1,
+        )
+        benchmark(batch_size, seq_length, hidden_size, mod, graph_mode, encoders_only)
+    else:
+        if batch_size < 1:
+            batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
+        else:
+            batch_sizes = [batch_size]
+        if seq_length < 1:
+            seq_lengths = (
+                [64, 128, 384, 512, 1024, 4096] if encoders_only else [64, 128, 384, 512]
+            )
+        else:
+            seq_lengths = [seq_length]
+        for sq in seq_lengths:
+            for bs in batch_sizes:
+                mod = compile_module(
+                    bs,
+                    sq,
+                    hidden_size,
+                    activation,
+                    use_fp16_acc,
+                    encoders_only,
+                    pt_model,
+                    0,
+                )
+                benchmark(bs, seq_length, hidden_size, mod, graph_mode, encoders_only)


 if __name__ == "__main__":
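
For context, the hunk above splits module creation into a compile path and a load path. Below is a minimal sketch of that split, assuming Model is importable from aitemplate.compiler alongside compile_model as in other AITemplate examples; the helper name is illustrative, and the ./tmp layout and test.so filename follow the convention already used in benchmark_ait.py.

    import os

    from aitemplate.compiler import Model, compile_model


    def get_bert_module(y, target, model_name: str, benchmark: bool):
        """Load a previously built module when benchmarking, otherwise compile it."""
        if benchmark:
            # Reuse the shared object produced by an earlier compile run,
            # so a benchmark process does not have to recompile the model.
            return Model(os.path.join("./tmp", model_name, "test.so"))
        # Build the module from the traced graph output y and cache it under ./tmp.
        return compile_model(y, target, "./tmp", model_name)

The benchmark flag presumably lets the two concurrent per-GCD runs reuse a shared object built earlier instead of racing to recompile the same model.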
examples/03_bert/benchmark_mi250.sh (8 changes: 4 additions & 4 deletions)
@@ -1,11 +1,11 @@
 #!/bin/bash

 #profile
-HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py
+#HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py

 #1GCD
-HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1
+#HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1 --seq_length $2

 #2GCD
-HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1 &
-HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size $1 && fg
+HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size $1 --seq-length $2 &
+HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size $1 --seq-length $2 && fg
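
The 2GCD section above runs one benchmark process per GCD by pinning each process to a device with HIP_VISIBLE_DEVICES. A hedged Python equivalent of that launch is sketched below; the flag names are taken from the script, while the batch size and sequence length values are placeholders.

    import os
    import subprocess
    import sys


    def run_two_gcd(batch_size: int, seq_length: int) -> None:
        """Launch one benchmark_ait.py process per GCD and wait for both."""
        procs = []
        for gcd in ("0", "1"):
            env = dict(os.environ, HIP_VISIBLE_DEVICES=gcd)  # pin this process to one GCD
            procs.append(
                subprocess.Popen(
                    [
                        sys.executable,
                        "benchmark_ait.py",
                        "--batch-size", str(batch_size),
                        "--seq-length", str(seq_length),
                    ],
                    env=env,
                )
            )
        for p in procs:
            p.wait()


    if __name__ == "__main__":
        run_two_gcd(batch_size=1, seq_length=384)

With the script itself, the same run is simply bash benchmark_mi250.sh <batch_size> <seq_length>.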
examples/05_stable_diffusion/benchmark.py (6 changes: 3 additions & 3 deletions)
@@ -68,7 +68,7 @@ def benchmark_unet(

     latent_model_input_pt = torch.randn(batch_size, 4, hh, ww).cuda().half()
     text_embeddings_pt = torch.randn(batch_size, 64, 768).cuda().half()
-    timesteps_pt = torch.Tensor([1, 1]).cuda().half()
+    timesteps_pt = torch.Tensor([1] * batch_size).cuda().half()

     with autocast("cuda"):
         pt_ys = pt_mod(
@@ -148,7 +148,7 @@ def benchmark_clip(

     tokenizer = CLIPTokenizer.from_pretrained(version)
     text_input = tokenizer(
-        ["a photo of an astronaut riding a horse on mars"],
+        ["a photo of an astronaut riding a horse on mars"] * batch_size,
         padding="max_length",
         max_length=seqlen,
         truncation=True,
@@ -278,7 +278,7 @@ def benchmark_vae(batch_size=1, height=64, width=64, benchmark_pt=False, verify=
 @click.option("--verify", type=bool, default=False, help="verify correctness")
 @click.option("--benchmark-pt", type=bool, default=False, help="run pt benchmark")
 def benchmark_diffusers(token, batch_size, verify, benchmark_pt):
-    assert batch_size == 1, "batch size must be 1 for submodule verification"
+    #assert batch_size == 1, "batch size must be 1 for submodule verification"
     logging.getLogger().setLevel(logging.INFO)
     np.random.seed(0)
     torch.manual_seed(4896)
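
Taken together, the benchmark.py edits size every hard-coded single-sample input by batch_size. A small illustrative sketch of that pattern, with shapes taken from the diff above (4 latent channels, 64 text tokens, 768-dim embeddings); the helper name is made up, and the .cuda() calls are omitted so the sketch also runs on CPU.

    import torch


    def make_batched_inputs(batch_size: int, prompt: str, hh: int = 64, ww: int = 64):
        """Build UNet/CLIP benchmark inputs whose leading dimension matches batch_size."""
        latents = torch.randn(batch_size, 4, hh, ww).half()
        text_embeddings = torch.randn(batch_size, 64, 768).half()
        # One timestep value per sample instead of the old fixed [1, 1].
        timesteps = torch.Tensor([1] * batch_size).half()
        prompts = [prompt] * batch_size  # tokenizer input for the CLIP benchmark
        return latents, text_embeddings, timesteps, prompts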
examples/05_stable_diffusion/demo.py (9 changes: 6 additions & 3 deletions)
@@ -21,22 +21,25 @@

 @click.command()
 @click.option("--token", default="", help="access token")
+@click.option("--batch-size", default=1, help="batch size")
 @click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt")
 @click.option(
     "--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark"
 )
-def run(token, prompt, benchmark):
+def run(token, batch_size, prompt, benchmark):
     pipe = StableDiffusionAITPipeline.from_pretrained(
         "CompVis/stable-diffusion-v1-4",
         revision="fp16",
         torch_dtype=torch.float16,
         use_auth_token=token,
     ).to("cuda")

+    prompts = [prompt] * batch_size
+
     with torch.autocast("cuda"):
-        image = pipe(prompt).images[0]
+        image = pipe(prompts).images[0]
     if benchmark:
-        t = benchmark_torch_function(10, pipe, prompt)
+        t = benchmark_torch_function(10, pipe, prompts)
         print(f"sd e2e: {t} ms")

     image.save("example_ait.png")
