[inductor][addmm] remove inp(unexpanded) path #161208
Closed
# why
- simplifies the code a lot
- unnecessary for performance (see below)
- matches other mm family kernels in logic now

# what
- if we're not autotuning (only ATen), we use a flexible layout
- in any case, we use inp_expanded (the expanded bias view) for the aten kernel

# performance analysis

## results

This is on an H100.

```
================================================================================
BENCHMARK SUMMARY (MERGED)
================================================================================
Config  Dim Type   (M, N, K)                DType     Bias    Runtime (inp_expanded) ms  Runtime (inp) ms
----------------------------------------------------------------------------------------------------------------------------
0       large_M    (65536, 8192, 8192)      float16   full    14.531                     14.562
1       large_M    (65536, 8192, 8192)      float16   row     14.682                     14.675
2       large_M    (65536, 8192, 8192)      float16   column  14.754                     14.740
3       large_M    (65536, 8192, 8192)      bfloat16  full    15.172                     15.148
4       large_M    (65536, 8192, 8192)      bfloat16  row     15.072                     15.085
5       large_M    (65536, 8192, 8192)      bfloat16  column  15.082                     15.114
6       large_M    (65536, 8192, 8192)      float32   full    185.726                    186.308
7       large_M    (65536, 8192, 8192)      float32   row     185.042                    185.864
8       large_M    (65536, 8192, 8192)      float32   column  185.221                    185.991
9       large_M    (32768, 4096, 4096)      float16   full    2.025                      2.036
10      large_M    (32768, 4096, 4096)      float16   row     2.029                      2.033
11      large_M    (32768, 4096, 4096)      float16   column  2.036                      2.047
12      large_M    (32768, 4096, 4096)      bfloat16  full    1.966                      1.981
13      large_M    (32768, 4096, 4096)      bfloat16  row     1.963                      1.979
14      large_M    (32768, 4096, 4096)      bfloat16  column  1.973                      1.981
15      large_M    (32768, 4096, 4096)      float32   full    24.096                     24.180
16      large_M    (32768, 4096, 4096)      float32   row     23.951                     24.033
17      large_M    (32768, 4096, 4096)      float32   column  23.996                     24.061
18      large_M    (16384, 2048, 2048)      float16   full    0.297                      0.298
19      large_M    (16384, 2048, 2048)      float16   row     0.298                      0.299
20      large_M    (16384, 2048, 2048)      float16   column  0.301                      0.300
21      large_M    (16384, 2048, 2048)      bfloat16  full    0.293                      0.293
22      large_M    (16384, 2048, 2048)      bfloat16  row     0.290                      0.291
23      large_M    (16384, 2048, 2048)      bfloat16  column  0.292                      0.293
24      large_M    (16384, 2048, 2048)      float32   full    3.077                      3.073
25      large_M    (16384, 2048, 2048)      float32   row     3.034                      3.033
26      large_M    (16384, 2048, 2048)      float32   column  3.040                      3.038
27      large_K    (8192, 8192, 65536)      float16   full    14.278                     14.297
28      large_K    (8192, 8192, 65536)      float16   row     14.325                     14.283
29      large_K    (8192, 8192, 65536)      float16   column  14.179                     14.302
30      large_K    (8192, 8192, 65536)      bfloat16  full    13.616                     13.628
31      large_K    (8192, 8192, 65536)      bfloat16  row     13.584                     13.642
32      large_K    (8192, 8192, 65536)      bfloat16  column  13.594                     13.694
33      large_K    (8192, 8192, 65536)      float32   full    175.933                    176.153
34      large_K    (8192, 8192, 65536)      float32   row     175.504                    175.877
35      large_K    (8192, 8192, 65536)      float32   column  175.432                    175.992
36      large_K    (4096, 4096, 32768)      float16   full    1.726                      1.724
37      large_K    (4096, 4096, 32768)      float16   row     1.731                      1.735
38      large_K    (4096, 4096, 32768)      float16   column  1.733                      1.737
39      large_K    (4096, 4096, 32768)      bfloat16  full    1.662                      1.658
40      large_K    (4096, 4096, 32768)      bfloat16  row     1.664                      1.655
41      large_K    (4096, 4096, 32768)      bfloat16  column  1.660                      1.667
42      large_K    (4096, 4096, 32768)      float32   full    22.263                     22.305
43      large_K    (4096, 4096, 32768)      float32   row     22.257                     22.322
44      large_K    (4096, 4096, 32768)      float32   column  22.247                     22.337
45      large_K    (2048, 2048, 16384)      float16   full    0.236                      0.236
46      large_K    (2048, 2048, 16384)      float16   row     0.238                      0.239
47      large_K    (2048, 2048, 16384)      float16   column  0.238                      0.239
48      large_K    (2048, 2048, 16384)      bfloat16  full    0.219                      0.219
49      large_K    (2048, 2048, 16384)      bfloat16  row     0.221                      0.222
50      large_K    (2048, 2048, 16384)      bfloat16  column  0.222                      0.222
51      large_K    (2048, 2048, 16384)      float32   full    2.786                      2.789
52      large_K    (2048, 2048, 16384)      float32   row     2.790                      2.782
53      large_K    (2048, 2048, 16384)      float32   column  2.791                      2.791
54      large_N    (8192, 65536, 8192)      float16   full    14.692                     14.723
55      large_N    (8192, 65536, 8192)      float16   row     14.721                     14.637
56      large_N    (8192, 65536, 8192)      float16   column  14.743                     14.737
57      large_N    (8192, 65536, 8192)      bfloat16  full    15.156                     15.128
58      large_N    (8192, 65536, 8192)      bfloat16  row     15.152                     15.124
59      large_N    (8192, 65536, 8192)      bfloat16  column  15.112                     15.090
60      large_N    (8192, 65536, 8192)      float32   full    179.127                    179.313
61      large_N    (8192, 65536, 8192)      float32   row     178.445                    178.961
62      large_N    (8192, 65536, 8192)      float32   column  178.693                    178.694
63      large_N    (4096, 32768, 4096)      float16   full    2.035                      2.037
64      large_N    (4096, 32768, 4096)      float16   row     2.042                      2.037
65      large_N    (4096, 32768, 4096)      float16   column  2.053                      2.046
66      large_N    (4096, 32768, 4096)      bfloat16  full    1.992                      1.997
67      large_N    (4096, 32768, 4096)      bfloat16  row     1.997                      1.987
68      large_N    (4096, 32768, 4096)      bfloat16  column  2.005                      2.001
69      large_N    (4096, 32768, 4096)      float32   full    23.126                     23.077
70      large_N    (4096, 32768, 4096)      float32   row     23.002                     22.956
71      large_N    (4096, 32768, 4096)      float32   column  23.012                     22.969
72      large_N    (2048, 16384, 2048)      float16   full    0.314                      0.314
73      large_N    (2048, 16384, 2048)      float16   row     0.311                      0.311
74      large_N    (2048, 16384, 2048)      float16   column  0.314                      0.314
75      large_N    (2048, 16384, 2048)      bfloat16  full    0.306                      0.305
76      large_N    (2048, 16384, 2048)      bfloat16  row     0.302                      0.302
77      large_N    (2048, 16384, 2048)      bfloat16  column  0.305                      0.303
78      large_N    (2048, 16384, 2048)      float32   full    2.975                      2.971
79      large_N    (2048, 16384, 2048)      float32   row     2.927                      2.925
80      large_N    (2048, 16384, 2048)      float32   column  2.934                      2.936
81      large_all  (16384, 16384, 16384)    float16   full    14.062                     14.096
82      large_all  (16384, 16384, 16384)    float16   row     14.058                     14.078
83      large_all  (16384, 16384, 16384)    float16   column  14.107                     14.120
84      large_all  (16384, 16384, 16384)    bfloat16  full    13.504                     13.460
85      large_all  (16384, 16384, 16384)    bfloat16  row     13.495                     13.499
86      large_all  (16384, 16384, 16384)    bfloat16  column  13.509                     13.461
87      large_all  (16384, 16384, 16384)    float32   full    177.279                    177.242
88      large_all  (16384, 16384, 16384)    float32   row     176.896                    176.651
89      large_all  (16384, 16384, 16384)    float32   column  176.830                    176.451
```

## script

```
"""
Torch addmm benchmarking script covering different input configurations.

Tests 90 different combinations of:
- Input types: large M, large K, large N, large everything
- Data types: float16, bfloat16, float32
- Bias types: full, row, column
"""

import itertools

import torch
import torch.utils.benchmark as benchmark


def create_test_configurations():
    """Create 90 different test configurations for torch.addmm"""
    # Define dimension configurations
    # Large M: many rows in input matrix
    # Large K: many columns in input/rows in weight
    # Large N: many columns in weight matrix
    # Large everything: all dimensions large
    dim_configs = [
        # Large M configurations
        {"M": 65536, "K": 8192, "N": 8192, "type": "large_M"},
        {"M": 32768, "K": 4096, "N": 4096, "type": "large_M"},
        {"M": 16384, "K": 2048, "N": 2048, "type": "large_M"},
        # Large K configurations
        {"M": 8192, "K": 65536, "N": 8192, "type": "large_K"},
        {"M": 4096, "K": 32768, "N": 4096, "type": "large_K"},
        {"M": 2048, "K": 16384, "N": 2048, "type": "large_K"},
        # Large N configurations
        {"M": 8192, "K": 8192, "N": 65536, "type": "large_N"},
        {"M": 4096, "K": 4096, "N": 32768, "type": "large_N"},
        {"M": 2048, "K": 2048, "N": 16384, "type": "large_N"},
        # Large everything configurations
        {"M": 16384, "K": 16384, "N": 16384, "type": "large_all"},
    ]

    # Data types to test
    dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Bias configurations
    bias_configs = ["full", "row", "column"]

    # Generate all combinations
    configurations = []
    config_id = 0
    for dim_config, dtype, bias_type in itertools.product(
        dim_configs, dtypes, bias_configs
    ):
        config = {
            "id": config_id,
            "M": dim_config["M"],
            "K": dim_config["K"],
            "N": dim_config["N"],
            "dim_type": dim_config["type"],
            "dtype": dtype,
            "bias_type": bias_type,
        }
        configurations.append(config)
        config_id += 1

    return configurations


def create_tensors(config):
    """Create input tensors for a given configuration"""
    M, K, N = config["M"], config["K"], config["N"]
    dtype = config["dtype"]
    bias_type = config["bias_type"]

    # Create input tensor (M x K)
    input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

    # Create weight tensor (K x N)
    weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda")

    # Create bias tensor based on bias type
    if bias_type == "full":
        bias_tensor = torch.randn(M, N, dtype=dtype, device="cuda")
    elif bias_type == "row":
        bias_tensor = torch.randn(M, 1, dtype=dtype, device="cuda")
    elif bias_type == "column":
        bias_tensor = torch.randn(1, N, dtype=dtype, device="cuda")

    return input_tensor, weight_tensor, bias_tensor


def benchmark_addmm(config, use_compile=False):
    """Benchmark torch.addmm for a given configuration"""
    input_tensor, weight_tensor, bias_tensor = create_tensors(config)

    # Define the operation
    def addmm_op():
        return torch.addmm(bias_tensor, input_tensor, weight_tensor)

    # Optionally compile the operation
    if use_compile:
        addmm_op = torch.compile(addmm_op)
        # Warmup for compiled version
        for _ in range(3):
            _ = addmm_op()
        torch.cuda.synchronize()

    # Benchmark using torch.utils.benchmark
    timer = benchmark.Timer(
        stmt="addmm_op()",
        globals={"addmm_op": addmm_op},
        description=f"addmm_{config['dim_type']}_{config['dtype']}_{config['bias_type']}_compiled_{use_compile}",
    )

    # Run benchmark
    measurement = timer.blocked_autorange(min_run_time=1.0)

    return measurement


def print_tensor_info(config, input_tensor, weight_tensor, bias_tensor):
    """Print information about input tensors"""
    print(f"\nConfiguration {config['id']}:")
    print(f"  Dimension type: {config['dim_type']}")
    print(f"  Data type: {config['dtype']}")
    print(f"  Bias type: {config['bias_type']}")
    print(f"  Input tensor shape: {input_tensor.shape} ({input_tensor.dtype})")
    print(f"  Weight tensor shape: {weight_tensor.shape} ({weight_tensor.dtype})")
    print(f"  Bias tensor shape: {bias_tensor.shape} ({bias_tensor.dtype})")


def main():
    """Main benchmarking function"""
    print("Torch addmm Benchmarking Script (Compiled Only)")
    print("=" * 50)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires CUDA.")
        return

    print(f"Using device: {torch.cuda.get_device_name()}")
    print(f"PyTorch version: {torch.__version__}")

    # Create test configurations
    configurations = create_test_configurations()
    print(f"\nTesting {len(configurations)} configurations...")

    results = []

    for config in configurations:
        print(f"\n{'='*60}")

        # Create tensors and print info
        input_tensor, weight_tensor, bias_tensor = create_tensors(config)
        print_tensor_info(config, input_tensor, weight_tensor, bias_tensor)

        # Benchmark with compilation only
        print("\nBenchmarking with torch.compile:")
        try:
            measurement = benchmark_addmm(config, use_compile=True)
            runtime_ms = measurement.mean * 1000
            print(f"  Runtime: {runtime_ms:.3f} ms")
            results.append(
                {
                    "config": config,
                    "runtime_ms": runtime_ms,
                }
            )
        except Exception as e:
            print(f"  Error: {e}")

        # Clear cache
        torch.cuda.empty_cache()

    # Print summary
    print(f"\n\n{'='*80}")
    print("BENCHMARK SUMMARY")
    print(f"{'='*80}")
    print(
        f"{'Config':<8} {'Dim Type':<12} {'(M, N, K)':<20} {'DType':<12} {'Bias':<8} {'Runtime (ms)':<15}"
    )
    print("-" * 90)

    for result in results:
        config = result["config"]
        dtype_str = str(config["dtype"]).split(".")[-1]
        dimensions_str = f"({config['M']}, {config['N']}, {config['K']})"
        print(
            f"{config['id']:<8} {config['dim_type']:<12} {dimensions_str:<20} {dtype_str:<12} {config['bias_type']:<8} {result['runtime_ms']:<15.3f}"
        )


if __name__ == "__main__":
    main()
```

[ghstack-poisoned]
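For readers not familiar with the two paths named in the title: `inp` is the bias argument as passed in (possibly broadcastable, not yet at the output shape), while `inp_expanded` is its broadcast view at the full `(M, N)` output shape that is handed to the ATen kernel. The snippet below is a minimal, hypothetical sketch of that distinction using only public PyTorch APIs; it is not the inductor lowering itself, and the shapes and variable names are made up for illustration.

```
# Minimal sketch (not the inductor code path): "inp" vs "inp_expanded" for addmm.
import torch

M, K, N = 256, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.randn(M, K, device=device)
b = torch.randn(K, N, device=device)
bias = torch.randn(N, device=device)        # "inp": unexpanded, broadcastable bias

# "inp_expanded": a broadcast view at the full output shape (M, N).
# expand() only creates a stride-0 view, so no data is copied here.
bias_expanded = bias.expand(M, N)

out_unexpanded = torch.addmm(bias, a, b)          # ATen broadcasts the bias itself
out_expanded = torch.addmm(bias_expanded, a, b)   # bias already has the output shape

# Both variants compute the same result.
torch.testing.assert_close(out_unexpanded, out_expanded)
print(out_expanded.shape)  # torch.Size([256, 64])
```

Because the expansion is only a view, feeding the expanded bias to the ATen kernel adds no copy, which is consistent with the table above showing essentially identical runtimes for the `inp_expanded` and `inp` variants.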
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161208
Note: Links to docs will display an error until the docs builds have been completed.
❌ 8 New Failures as of commit bb9ebbf with merge base 38a492d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
coconutruben added a commit that referenced this pull request on Aug 21, 2025
coconutruben added a commit that referenced this pull request on Aug 21, 2025
@jataylo could you take a look as well to make sure this does not regress AMD performance?
eellison approved these changes on Aug 22, 2025