Add Lowering for FlexAttention Backwards #125515

Closed
benchmarks/transformer/score_mod.py: 115 additions, 56 deletions
@@ -3,7 +3,7 @@
 from collections import defaultdict
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, List
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -29,58 +29,64 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->

 @dataclass(frozen=True)
 class ExperimentConfig:
-    batch_size: int
-    num_heads: int
-    q_seq_len: int
-    k_seq_len: int
-    head_dim: int
+    shape: Tuple[int]
     score_mod: Callable
     dtype: torch.dtype
+    calculate_bwd_time: bool
 
+    def __post_init__(self):
+        assert len(self.shape) == 4, "Shape must be of length 4"
+
     def asdict(self):
-        return asdict(self)
+        # Convert the dataclass instance to a dictionary
+        d = asdict(self)
+        # Remove the 'calculate_bwd_time' key
+        d.pop("calculate_bwd_time", None)
+        return d
 
 
 @dataclass(frozen=True)
-class ExperimentResults:
+class Times:
     eager_time: float
     compiled_time: float
 
-    def get_entries(self) -> List:
-        return [
-            f"{self.eager_time:2f}",
-            f"{self.compiled_time:2f}",
-        ]
+
+@dataclass(frozen=True)
+class ExperimentResults:
+    fwd_times: Times
+    bwd_times: Optional[Times]
 
 
 @dataclass(frozen=True)
 class Experiment:
     config: ExperimentConfig
     results: ExperimentResults
 
-    def get_entries(self) -> List:
-        return self.config.get_entries() + self.results.get_entries()
-
     def asdict(self):
-        dict1 = asdict(self.config)
+        dict1 = self.config.asdict()
         dict2 = asdict(self.results)
         return {**dict1, **dict2}


 def generate_inputs(
-    batch_size,
-    num_heads,
-    q_sequence_length,
-    kv_sequence_length,
-    head_dim,
-    dtype,
-    device,
+    batch_size: int,
+    num_heads: int,
+    q_sequence_length: int,
+    kv_sequence_length: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    requires_grad: bool,
 ):
     q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
     kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
 
-    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
-    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
+    make_q = partial(
+        torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    make_kv = partial(
+        torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
+    )
     query = (
         make_q()
         .view(batch_size, q_sequence_length, num_heads, head_dim)
@@ -101,14 +107,16 @@ def generate_inputs(

 def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
+    batch_size, num_heads, q_seq_len, head_dim = config.shape
     query, key, value = generate_inputs(
-        config.batch_size,
-        config.num_heads,
-        config.q_seq_len,
-        config.k_seq_len,
-        config.head_dim,
+        batch_size,
+        num_heads,
+        q_seq_len,
+        q_seq_len,
+        head_dim,
         config.dtype,
         device,
+        requires_grad=config.calculate_bwd_time,
     )
 
     def eager_sdpa(query, key, value, _):
@@ -125,23 +133,47 @@ def eager_sdpa(query, key, value, _):
         compiled_sdpa, query, key, value, score_mod
     )

-    return ExperimentResults(
-        eager_time=forward_eager_time,
-        compiled_time=forward_compiled_time,
-    )
+    if config.calculate_bwd_time:
+        out_eager = eager_sdpa(query, key, value, score_mod)
+        dOut = torch.randn_like(out_eager)
+        backward_eager_time = benchmark_torch_function_in_microseconds(
+            out_eager.backward, dOut, retain_graph=True
+        )
+
+        out_compile = compiled_sdpa(query, key, value, score_mod)
+        dOut = torch.randn_like(out_eager)
+        backward_compile_time = benchmark_torch_function_in_microseconds(
+            out_compile.backward, dOut, retain_graph=True
+        )
+
+        return ExperimentResults(
+            fwd_times=Times(forward_eager_time, forward_compiled_time),
+            bwd_times=Times(backward_eager_time, backward_compile_time),
+        )
+    else:
+        return ExperimentResults(
+            fwd_times=Times(forward_eager_time, forward_compiled_time),
+            bwd_times=None,
+        )


-def calculate_speedup(results: ExperimentResults) -> float:
-    return results.eager_time / results.compiled_time
+def calculate_speedup(results: ExperimentResults, type: str) -> float:
+    if type == "fwd":
+        return results.fwd_times.eager_time / results.fwd_times.compiled_time
+    elif type == "bwd":
+        assert results.bwd_times is not None
+        return results.bwd_times.eager_time / results.bwd_times.compiled_time
+    else:
+        raise ValueError(f"Invalid type {type}")


 def get_func_name(func):
     return func.__name__.split("<locals>.")[-1].split(" at ")[0]


-def get_average_speedups(results: List[Experiment]):
+def get_average_speedups(results: List[Experiment], type: str):
     # Calculate speedups
-    speedups = [calculate_speedup(r.results) for r in results]
+    speedups = [calculate_speedup(r.results, type) for r in results]
 
     # Find indices of max and min speedups
     max_speedup_index = np.argmax(speedups)
@@ -177,20 +209,39 @@ def print_results(results: List[Experiment]):
     table_data = defaultdict(list)
     for experiment in results:
         for key, value in experiment.asdict().items():
-            if key == "eager_time" or key == "compiled_time":
-                value = float(value)
-            table_data[key].append(value)
+            if key == "fwd_times":
+                for name, time in value.items():
+                    table_data[f"fwd_{name}"].append(float(time))
+            elif key == "bwd_times":
+                if experiment.config.calculate_bwd_time:
+                    for name, time in value.items():
+                        table_data[f"bwd_{name}"].append(float(time))
+            else:
+                table_data[key].append(value)
 
     # Calculate speedups
-    speedups = [calculate_speedup(r.results) for r in results]
-    table_data["speedup"] = speedups
+    fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
+    table_data["fwd_speedup"] = fwd_speedups
+    if results[0].config.calculate_bwd_time:
+        bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
+        table_data["bwd_speedup"] = bwd_speedups
 
     table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
     print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
-    average_data = get_average_speedups(results)
+    print("\n")
+    print("FWD Speedups".center(125, "="))
+    print("\n")
+    average_data = get_average_speedups(results, type="fwd")
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
+    if results[0].config.calculate_bwd_time:
+        print("\n")
+        print("BWD Speedups".center(125, "="))
+        print("\n")
+        average_data = get_average_speedups(results, type="bwd")
+        print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))


 def generate_score_mods() -> List[Callable]:
     def noop(score, b, h, m, n):
Expand All @@ -208,8 +259,8 @@ def head_bias(score, b, h, m, n):
     return [noop, causal_mask, relative_bias, head_bias]


-def generate_experiment_configs() -> List[ExperimentConfig]:
-    batch_sizes = [1, 8, 16]
+def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
+    batch_sizes = [2, 8, 16]
     num_heads = [16]
     q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
     head_dims = [64, 128, 256]
@@ -228,41 +279,49 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     ) in itertools.product(
         batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
     ):
+        assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
         all_configs.append(
             ExperimentConfig(
-                batch_size=bsz,
-                num_heads=n_heads,
-                q_seq_len=q_seq_len,
-                k_seq_len=kv_seq_len,
-                head_dim=head_dim,
+                shape=(bsz, n_heads, q_seq_len, head_dim),
                 score_mod=score_mod,
                 dtype=dtype,
+                calculate_bwd_time=calculate_bwd,
             )
         )
 
     return all_configs


-def main(dynamic=False):
+def main(dynamic: bool, calculate_bwd: bool):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
-    for config in tqdm(generate_experiment_configs()):
-        results.append(Experiment(config, run_single_experiment(config)))
+    for config in tqdm(generate_experiment_configs(calculate_bwd)):
+        results.append(
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
+        )

     print_results(results)


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    # Set up the argument parser
+    parser = argparse.ArgumentParser(
+        description="Run sweep over sizes and score mods for flex attention"
+    )
     parser.add_argument(
         "--dynamic",
         action="store_true",
         help="Runs a dynamic shapes version of compiled flex attention.",
     )
+    parser.add_argument(
+        "--calculate-bwd", action="store_true", help="Calculate backward pass times"
+    )
 
+    # Parse arguments
     args = parser.parse_args()
-    main(args.dynamic)
+    main(args.dynamic, args.calculate_bwd)
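
With these changes, the sweep can be run as `python benchmarks/transformer/score_mod.py --calculate-bwd`, optionally with `--dynamic`. For context, below is a minimal sketch of the forward-plus-backward pattern that `run_single_experiment` times when `--calculate-bwd` is set. The `_flex_attention` import path is an assumption based on the private entry point PyTorch exposed around the time of this PR; it may differ in other versions.

```python
# Minimal sketch of the fwd + bwd timing pattern used by run_single_experiment.
# ASSUMPTION: flex attention is importable from this private module path;
# adjust the import to match your PyTorch version if it has moved.
import torch
from torch.nn.attention._flex_attention import _flex_attention


def causal_mask(score, b, h, m, n):
    # Example score_mod: keep the score where query index m >= key index n,
    # otherwise mask it out.
    return torch.where(m >= n, score, float("-inf"))


B, H, S, D = 2, 16, 512, 64  # one (batch, heads, seq_len, head_dim) config


def make_input():
    return torch.rand(
        B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
    )


query, key, value = make_input(), make_input(), make_input()
compiled_sdpa = torch.compile(_flex_attention)

out = compiled_sdpa(query, key, value, causal_mask)  # forward pass
dOut = torch.randn_like(out)
out.backward(dOut, retain_graph=True)  # the backward path this PR lowers
```

The benchmark itself wraps both calls in `benchmark_torch_function_in_microseconds` rather than running them once; `retain_graph=True` keeps the autograd graph alive so the same forward result can be back-propagated repeatedly during timing.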