@coconutruben (Contributor) commented Aug 21, 2025

Stack from ghstack (oldest at bottom):

# why

- simplifies the code a lot
- not needed for performance (see the analysis below)
- matches the logic of the other mm-family kernels now

# what

- if we are not autotuning (ATen only), we use a flexible layout
- in either case, we use inp_expanded (the expanded bias view) for the
  ATen kernel (see the sketch just below)
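
To make `inp_expanded` concrete, here is a minimal sketch in plain PyTorch (not the Inductor lowering itself) of what handing the expanded bias view to the ATen kernel means: a broadcastable bias is expanded to the full `(M, N)` output shape as a zero-copy view, and `torch.addmm` gives the same result with either form.

```python
import torch

M, K, N = 64, 32, 48
a = torch.randn(M, K)
b = torch.randn(K, N)
bias = torch.randn(N)  # broadcastable bias, one value per output column

inp = bias                        # the bias as given
inp_expanded = bias.expand(M, N)  # broadcast view with the output shape, no copy

out_inp = torch.addmm(inp, a, b)                # addmm broadcasts the bias itself
out_expanded = torch.addmm(inp_expanded, a, b)  # same math with the pre-expanded view

assert inp_expanded.stride(0) == 0              # expanded dim has stride 0: a view, not a copy
torch.testing.assert_close(out_inp, out_expanded)
```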

# performance analysis

## results

The numbers below were collected on an H100.

```
================================================================================
BENCHMARK SUMMARY (MERGED)
================================================================================
Config   Dim Type     (M, N, K)            DType        Bias     Runtime (inp_expanded) ms   Runtime (inp) ms
----------------------------------------------------------------------------------------------------------------------------
0        large_M      (65536, 8192, 8192)  float16      full     14.531                      14.562
1        large_M      (65536, 8192, 8192)  float16      row      14.682                      14.675
2        large_M      (65536, 8192, 8192)  float16      column   14.754                      14.740
3        large_M      (65536, 8192, 8192)  bfloat16     full     15.172                      15.148
4        large_M      (65536, 8192, 8192)  bfloat16     row      15.072                      15.085
5        large_M      (65536, 8192, 8192)  bfloat16     column   15.082                      15.114
6        large_M      (65536, 8192, 8192)  float32      full     185.726                     186.308
7        large_M      (65536, 8192, 8192)  float32      row      185.042                     185.864
8        large_M      (65536, 8192, 8192)  float32      column   185.221                     185.991
9        large_M      (32768, 4096, 4096)  float16      full     2.025                       2.036
10       large_M      (32768, 4096, 4096)  float16      row      2.029                       2.033
11       large_M      (32768, 4096, 4096)  float16      column   2.036                       2.047
12       large_M      (32768, 4096, 4096)  bfloat16     full     1.966                       1.981
13       large_M      (32768, 4096, 4096)  bfloat16     row      1.963                       1.979
14       large_M      (32768, 4096, 4096)  bfloat16     column   1.973                       1.981
15       large_M      (32768, 4096, 4096)  float32      full     24.096                      24.180
16       large_M      (32768, 4096, 4096)  float32      row      23.951                      24.033
17       large_M      (32768, 4096, 4096)  float32      column   23.996                      24.061
18       large_M      (16384, 2048, 2048)  float16      full     0.297                       0.298
19       large_M      (16384, 2048, 2048)  float16      row      0.298                       0.299
20       large_M      (16384, 2048, 2048)  float16      column   0.301                       0.300
21       large_M      (16384, 2048, 2048)  bfloat16     full     0.293                       0.293
22       large_M      (16384, 2048, 2048)  bfloat16     row      0.290                       0.291
23       large_M      (16384, 2048, 2048)  bfloat16     column   0.292                       0.293
24       large_M      (16384, 2048, 2048)  float32      full     3.077                       3.073
25       large_M      (16384, 2048, 2048)  float32      row      3.034                       3.033
26       large_M      (16384, 2048, 2048)  float32      column   3.040                       3.038
27       large_K      (8192, 8192, 65536)  float16      full     14.278                      14.297
28       large_K      (8192, 8192, 65536)  float16      row      14.325                      14.283
29       large_K      (8192, 8192, 65536)  float16      column   14.179                      14.302
30       large_K      (8192, 8192, 65536)  bfloat16     full     13.616                      13.628
31       large_K      (8192, 8192, 65536)  bfloat16     row      13.584                      13.642
32       large_K      (8192, 8192, 65536)  bfloat16     column   13.594                      13.694
33       large_K      (8192, 8192, 65536)  float32      full     175.933                     176.153
34       large_K      (8192, 8192, 65536)  float32      row      175.504                     175.877
35       large_K      (8192, 8192, 65536)  float32      column   175.432                     175.992
36       large_K      (4096, 4096, 32768)  float16      full     1.726                       1.724
37       large_K      (4096, 4096, 32768)  float16      row      1.731                       1.735
38       large_K      (4096, 4096, 32768)  float16      column   1.733                       1.737
39       large_K      (4096, 4096, 32768)  bfloat16     full     1.662                       1.658
40       large_K      (4096, 4096, 32768)  bfloat16     row      1.664                       1.655
41       large_K      (4096, 4096, 32768)  bfloat16     column   1.660                       1.667
42       large_K      (4096, 4096, 32768)  float32      full     22.263                      22.305
43       large_K      (4096, 4096, 32768)  float32      row      22.257                      22.322
44       large_K      (4096, 4096, 32768)  float32      column   22.247                      22.337
45       large_K      (2048, 2048, 16384)  float16      full     0.236                       0.236
46       large_K      (2048, 2048, 16384)  float16      row      0.238                       0.239
47       large_K      (2048, 2048, 16384)  float16      column   0.238                       0.239
48       large_K      (2048, 2048, 16384)  bfloat16     full     0.219                       0.219
49       large_K      (2048, 2048, 16384)  bfloat16     row      0.221                       0.222
50       large_K      (2048, 2048, 16384)  bfloat16     column   0.222                       0.222
51       large_K      (2048, 2048, 16384)  float32      full     2.786                       2.789
52       large_K      (2048, 2048, 16384)  float32      row      2.790                       2.782
53       large_K      (2048, 2048, 16384)  float32      column   2.791                       2.791
54       large_N      (8192, 65536, 8192)  float16      full     14.692                      14.723
55       large_N      (8192, 65536, 8192)  float16      row      14.721                      14.637
56       large_N      (8192, 65536, 8192)  float16      column   14.743                      14.737
57       large_N      (8192, 65536, 8192)  bfloat16     full     15.156                      15.128
58       large_N      (8192, 65536, 8192)  bfloat16     row      15.152                      15.124
59       large_N      (8192, 65536, 8192)  bfloat16     column   15.112                      15.090
60       large_N      (8192, 65536, 8192)  float32      full     179.127                     179.313
61       large_N      (8192, 65536, 8192)  float32      row      178.445                     178.961
62       large_N      (8192, 65536, 8192)  float32      column   178.693                     178.694
63       large_N      (4096, 32768, 4096)  float16      full     2.035                       2.037
64       large_N      (4096, 32768, 4096)  float16      row      2.042                       2.037
65       large_N      (4096, 32768, 4096)  float16      column   2.053                       2.046
66       large_N      (4096, 32768, 4096)  bfloat16     full     1.992                       1.997
67       large_N      (4096, 32768, 4096)  bfloat16     row      1.997                       1.987
68       large_N      (4096, 32768, 4096)  bfloat16     column   2.005                       2.001
69       large_N      (4096, 32768, 4096)  float32      full     23.126                      23.077
70       large_N      (4096, 32768, 4096)  float32      row      23.002                      22.956
71       large_N      (4096, 32768, 4096)  float32      column   23.012                      22.969
72       large_N      (2048, 16384, 2048)  float16      full     0.314                       0.314
73       large_N      (2048, 16384, 2048)  float16      row      0.311                       0.311
74       large_N      (2048, 16384, 2048)  float16      column   0.314                       0.314
75       large_N      (2048, 16384, 2048)  bfloat16     full     0.306                       0.305
76       large_N      (2048, 16384, 2048)  bfloat16     row      0.302                       0.302
77       large_N      (2048, 16384, 2048)  bfloat16     column   0.305                       0.303
78       large_N      (2048, 16384, 2048)  float32      full     2.975                       2.971
79       large_N      (2048, 16384, 2048)  float32      row      2.927                       2.925
80       large_N      (2048, 16384, 2048)  float32      column   2.934                       2.936
81       large_all    (16384, 16384, 16384) float16      full     14.062                      14.096
82       large_all    (16384, 16384, 16384) float16      row      14.058                      14.078
83       large_all    (16384, 16384, 16384) float16      column   14.107                      14.120
84       large_all    (16384, 16384, 16384) bfloat16     full     13.504                      13.460
85       large_all    (16384, 16384, 16384) bfloat16     row      13.495                      13.499
86       large_all    (16384, 16384, 16384) bfloat16     column   13.509                      13.461
87       large_all    (16384, 16384, 16384) float32      full     177.279                     177.242
88       large_all    (16384, 16384, 16384) float32      row      176.896                     176.651
89       large_all    (16384, 16384, 16384) float32      column   176.830                     176.451
```
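
Across all 90 configurations the two columns track each other to within roughly 1%, i.e. within run-to-run noise, which is the basis for the "not needed for performance" claim above. A small sketch of that comparison (assuming the summary above is saved verbatim to a file such as `results.txt`):

```python
import re

# Scan the merged summary and report the largest relative gap between the
# inp_expanded and inp runtime columns (the last two floats on each data row).
max_rel = 0.0
with open("results.txt") as f:
    for line in f:
        m = re.search(r"(\d+\.\d+)\s+(\d+\.\d+)\s*$", line)
        if not m:
            continue  # skip headers, separators, and anything non-tabular
        expanded_ms, inp_ms = float(m.group(1)), float(m.group(2))
        max_rel = max(max_rel, abs(expanded_ms - inp_ms) / min(expanded_ms, inp_ms))

print(f"largest relative runtime difference: {max_rel * 100:.2f}%")
```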

## script

"""
Torch addmm benchmarking script covering different input configurations.
Tests 90 combinations (10 shape configurations x 3 dtypes x 3 bias layouts) of:
- Input types: large M, large K, large N, large everything
- Data types: float16, bfloat16, float32
- Bias types: full, row, column
"""

import itertools

import torch
import torch.utils.benchmark as benchmark

def create_test_configurations():
    """Create 30 different test configurations for torch.addmm"""

    # Define dimension configurations
    # Large M: many rows in input matrix
    # Large K: many columns in input/rows in weight
    # Large N: many columns in weight matrix
    # Large everything: all dimensions large
    dim_configs = [
        # Large M configurations
        {"M": 65536, "K": 8192, "N": 8192, "type": "large_M"},
        {"M": 32768, "K": 4096, "N": 4096, "type": "large_M"},
        {"M": 16384, "K": 2048, "N": 2048, "type": "large_M"},
        # Large K configurations
        {"M": 8192, "K": 65536, "N": 8192, "type": "large_K"},
        {"M": 4096, "K": 32768, "N": 4096, "type": "large_K"},
        {"M": 2048, "K": 16384, "N": 2048, "type": "large_K"},
        # Large N configurations
        {"M": 8192, "K": 8192, "N": 65536, "type": "large_N"},
        {"M": 4096, "K": 4096, "N": 32768, "type": "large_N"},
        {"M": 2048, "K": 2048, "N": 16384, "type": "large_N"},
        # Large everything configurations
        {"M": 16384, "K": 16384, "N": 16384, "type": "large_all"},
    ]

    # Data types to test
    dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Bias configurations
    bias_configs = ["full", "row", "column"]

    # Generate all combinations
    configurations = []
    config_id = 0

    for dim_config, dtype, bias_type in itertools.product(
        dim_configs, dtypes, bias_configs
    ):
        config = {
            "id": config_id,
            "M": dim_config["M"],
            "K": dim_config["K"],
            "N": dim_config["N"],
            "dim_type": dim_config["type"],
            "dtype": dtype,
            "bias_type": bias_type,
        }
        configurations.append(config)
        config_id += 1

    return configurations

def create_tensors(config):
    """Create input tensors for a given configuration"""
    M, K, N = config["M"], config["K"], config["N"]
    dtype = config["dtype"]
    bias_type = config["bias_type"]

    # Create input tensor (M x K)
    input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

    # Create weight tensor (K x N)
    weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda")

    # Create bias tensor based on bias type
    if bias_type == "full":
        bias_tensor = torch.randn(M, N, dtype=dtype, device="cuda")
    elif bias_type == "row":
        bias_tensor = torch.randn(M, 1, dtype=dtype, device="cuda")
    elif bias_type == "column":
        bias_tensor = torch.randn(1, N, dtype=dtype, device="cuda")

    return input_tensor, weight_tensor, bias_tensor

def benchmark_addmm(config, use_compile=False):
    """Benchmark torch.addmm for a given configuration"""
    input_tensor, weight_tensor, bias_tensor = create_tensors(config)

    # Define the operation
    def addmm_op():
        return torch.addmm(bias_tensor, input_tensor, weight_tensor)

    # Optionally compile the operation
    if use_compile:
        addmm_op = torch.compile(addmm_op)
        # Warmup for compiled version
        for _ in range(3):
            _ = addmm_op()
        torch.cuda.synchronize()

    # Benchmark using torch.utils.benchmark
    timer = benchmark.Timer(
        stmt="addmm_op()",
        globals={"addmm_op": addmm_op},
        description=f"addmm_{config['dim_type']}_{config['dtype']}_{config['bias_type']}_compiled_{use_compile}",
    )

    # Run benchmark
    measurement = timer.blocked_autorange(min_run_time=1.0)

    return measurement

def print_tensor_info(config, input_tensor, weight_tensor, bias_tensor):
    """Print information about input tensors"""
    print(f"\\nConfiguration {config['id']}:")
    print(f"  Dimension type: {config['dim_type']}")
    print(f"  Data type: {config['dtype']}")
    print(f"  Bias type: {config['bias_type']}")
    print(f"  Input tensor shape: {input_tensor.shape} ({input_tensor.dtype})")
    print(f"  Weight tensor shape: {weight_tensor.shape} ({weight_tensor.dtype})")
    print(f"  Bias tensor shape: {bias_tensor.shape} ({bias_tensor.dtype})")

def main():
    """Main benchmarking function"""
    print("Torch addmm Benchmarking Script (Compiled Only)")
    print("=" * 50)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires CUDA.")
        return

    print(f"Using device: {torch.cuda.get_device_name()}")
    print(f"PyTorch version: {torch.__version__}")

    # Create test configurations
    configurations = create_test_configurations()
    print(f"\\nTesting {len(configurations)} configurations...")

    results = []

    for config in configurations:
        print(f"\\n{'='*60}")

        # Create tensors and print info
        input_tensor, weight_tensor, bias_tensor = create_tensors(config)
        print_tensor_info(config, input_tensor, weight_tensor, bias_tensor)

        # Benchmark with compilation only
        print("\\nBenchmarking with torch.compile:")
        try:
            measurement = benchmark_addmm(config, use_compile=True)
            runtime_ms = measurement.mean * 1000
            print(f"  Runtime: {runtime_ms:.3f} ms")
            results.append(
                {
                    "config": config,
                    "runtime_ms": runtime_ms,
                }
            )
        except Exception as e:
            print(f"  Error: {e}")

        # Clear cache
        torch.cuda.empty_cache()

    # Print summary
    print(f"\\n\\n{'='*80}")
    print("BENCHMARK SUMMARY")
    print(f"{'='*80}")

    print(
        f"{'Config':<8} {'Dim Type':<12} {'(M, N, K)':<20} {'DType':<12} {'Bias':<8} {'Runtime (ms)':<15}"
    )
    print("-" * 90)

    for result in results:
        config = result["config"]
        dtype_str = str(config["dtype"]).split(".")[-1]
        dimensions_str = f"({config['M']}, {config['N']}, {config['K']})"
        print(
            f"{config['id']:<8} {config['dim_type']:<12} {dimensions_str:<20} {dtype_str:<12} {config['bias_type']:<8} {result['runtime_ms']:<15.3f}"
        )

if __name__ == "__main__":
    main()
```
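
The script prints a single runtime column; the two-column table above was presumably produced by running it once per variant and merging the summaries by config id. To steer the compiled addmm onto the non-autotuned, ATen-only path the change targets, one option (these Inductor knobs are an assumption about the setup, not something this PR prescribes) is:

```python
import torch._inductor.config as inductor_config

# Assumed configuration: disable GEMM autotuning and restrict backend choices
# to ATen so torch.compile lowers addmm to the ATen fallback discussed above.
inductor_config.max_autotune = False
inductor_config.max_autotune_gemm_backends = "ATEN"
```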

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @Lucaskabela

@pytorch-bot commented Aug 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161208

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures

As of commit bb9ebbf with merge base 38a492d:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

coconutruben added a commit that referenced this pull request Aug 21, 2025
ghstack-source-id: e788401
Pull Request resolved: #161208
@coconutruben requested a review from @eellison on August 21, 2025, 21:45
@coconutruben added the `topic: not user facing` label on Aug 21, 2025
# why

- simplifies the code a lot
- unnecessary for performance (see below)
- matches other mm family kernels in logic now

# what

- if we're not autotuning (only ATen), we use a flexible layout
- in any case, we use inp_expanded (the expanded bias view) for
  the aten kernel

# performance analysis

## results

this is on H100

```
================================================================================
BENCHMARK SUMMARY (MERGED)
================================================================================
Config   Dim Type     (M, N, K)            DType        Bias     Runtime (inp_expanded) ms   Runtime (inp) ms
----------------------------------------------------------------------------------------------------------------------------
0        large_M      (65536, 8192, 8192)  float16      full     14.531                      14.562
1        large_M      (65536, 8192, 8192)  float16      row      14.682                      14.675
2        large_M      (65536, 8192, 8192)  float16      column   14.754                      14.740
3        large_M      (65536, 8192, 8192)  bfloat16     full     15.172                      15.148
4        large_M      (65536, 8192, 8192)  bfloat16     row      15.072                      15.085
5        large_M      (65536, 8192, 8192)  bfloat16     column   15.082                      15.114
6        large_M      (65536, 8192, 8192)  float32      full     185.726                     186.308
7        large_M      (65536, 8192, 8192)  float32      row      185.042                     185.864
8        large_M      (65536, 8192, 8192)  float32      column   185.221                     185.991
9        large_M      (32768, 4096, 4096)  float16      full     2.025                       2.036
10       large_M      (32768, 4096, 4096)  float16      row      2.029                       2.033
11       large_M      (32768, 4096, 4096)  float16      column   2.036                       2.047
12       large_M      (32768, 4096, 4096)  bfloat16     full     1.966                       1.981
13       large_M      (32768, 4096, 4096)  bfloat16     row      1.963                       1.979
14       large_M      (32768, 4096, 4096)  bfloat16     column   1.973                       1.981
15       large_M      (32768, 4096, 4096)  float32      full     24.096                      24.180
16       large_M      (32768, 4096, 4096)  float32      row      23.951                      24.033
17       large_M      (32768, 4096, 4096)  float32      column   23.996                      24.061
18       large_M      (16384, 2048, 2048)  float16      full     0.297                       0.298
19       large_M      (16384, 2048, 2048)  float16      row      0.298                       0.299
20       large_M      (16384, 2048, 2048)  float16      column   0.301                       0.300
21       large_M      (16384, 2048, 2048)  bfloat16     full     0.293                       0.293
22       large_M      (16384, 2048, 2048)  bfloat16     row      0.290                       0.291
23       large_M      (16384, 2048, 2048)  bfloat16     column   0.292                       0.293
24       large_M      (16384, 2048, 2048)  float32      full     3.077                       3.073
25       large_M      (16384, 2048, 2048)  float32      row      3.034                       3.033
26       large_M      (16384, 2048, 2048)  float32      column   3.040                       3.038
27       large_K      (8192, 8192, 65536)  float16      full     14.278                      14.297
28       large_K      (8192, 8192, 65536)  float16      row      14.325                      14.283
29       large_K      (8192, 8192, 65536)  float16      column   14.179                      14.302
30       large_K      (8192, 8192, 65536)  bfloat16     full     13.616                      13.628
31       large_K      (8192, 8192, 65536)  bfloat16     row      13.584                      13.642
32       large_K      (8192, 8192, 65536)  bfloat16     column   13.594                      13.694
33       large_K      (8192, 8192, 65536)  float32      full     175.933                     176.153
34       large_K      (8192, 8192, 65536)  float32      row      175.504                     175.877
35       large_K      (8192, 8192, 65536)  float32      column   175.432                     175.992
36       large_K      (4096, 4096, 32768)  float16      full     1.726                       1.724
37       large_K      (4096, 4096, 32768)  float16      row      1.731                       1.735
38       large_K      (4096, 4096, 32768)  float16      column   1.733                       1.737
39       large_K      (4096, 4096, 32768)  bfloat16     full     1.662                       1.658
40       large_K      (4096, 4096, 32768)  bfloat16     row      1.664                       1.655
41       large_K      (4096, 4096, 32768)  bfloat16     column   1.660                       1.667
42       large_K      (4096, 4096, 32768)  float32      full     22.263                      22.305
43       large_K      (4096, 4096, 32768)  float32      row      22.257                      22.322
44       large_K      (4096, 4096, 32768)  float32      column   22.247                      22.337
45       large_K      (2048, 2048, 16384)  float16      full     0.236                       0.236
46       large_K      (2048, 2048, 16384)  float16      row      0.238                       0.239
47       large_K      (2048, 2048, 16384)  float16      column   0.238                       0.239
48       large_K      (2048, 2048, 16384)  bfloat16     full     0.219                       0.219
49       large_K      (2048, 2048, 16384)  bfloat16     row      0.221                       0.222
50       large_K      (2048, 2048, 16384)  bfloat16     column   0.222                       0.222
51       large_K      (2048, 2048, 16384)  float32      full     2.786                       2.789
52       large_K      (2048, 2048, 16384)  float32      row      2.790                       2.782
53       large_K      (2048, 2048, 16384)  float32      column   2.791                       2.791
54       large_N      (8192, 65536, 8192)  float16      full     14.692                      14.723
55       large_N      (8192, 65536, 8192)  float16      row      14.721                      14.637
56       large_N      (8192, 65536, 8192)  float16      column   14.743                      14.737
57       large_N      (8192, 65536, 8192)  bfloat16     full     15.156                      15.128
58       large_N      (8192, 65536, 8192)  bfloat16     row      15.152                      15.124
59       large_N      (8192, 65536, 8192)  bfloat16     column   15.112                      15.090
60       large_N      (8192, 65536, 8192)  float32      full     179.127                     179.313
61       large_N      (8192, 65536, 8192)  float32      row      178.445                     178.961
62       large_N      (8192, 65536, 8192)  float32      column   178.693                     178.694
63       large_N      (4096, 32768, 4096)  float16      full     2.035                       2.037
64       large_N      (4096, 32768, 4096)  float16      row      2.042                       2.037
65       large_N      (4096, 32768, 4096)  float16      column   2.053                       2.046
66       large_N      (4096, 32768, 4096)  bfloat16     full     1.992                       1.997
67       large_N      (4096, 32768, 4096)  bfloat16     row      1.997                       1.987
68       large_N      (4096, 32768, 4096)  bfloat16     column   2.005                       2.001
69       large_N      (4096, 32768, 4096)  float32      full     23.126                      23.077
70       large_N      (4096, 32768, 4096)  float32      row      23.002                      22.956
71       large_N      (4096, 32768, 4096)  float32      column   23.012                      22.969
72       large_N      (2048, 16384, 2048)  float16      full     0.314                       0.314
73       large_N      (2048, 16384, 2048)  float16      row      0.311                       0.311
74       large_N      (2048, 16384, 2048)  float16      column   0.314                       0.314
75       large_N      (2048, 16384, 2048)  bfloat16     full     0.306                       0.305
76       large_N      (2048, 16384, 2048)  bfloat16     row      0.302                       0.302
77       large_N      (2048, 16384, 2048)  bfloat16     column   0.305                       0.303
78       large_N      (2048, 16384, 2048)  float32      full     2.975                       2.971
79       large_N      (2048, 16384, 2048)  float32      row      2.927                       2.925
80       large_N      (2048, 16384, 2048)  float32      column   2.934                       2.936
81       large_all    (16384, 16384, 16384) float16      full     14.062                      14.096
82       large_all    (16384, 16384, 16384) float16      row      14.058                      14.078
83       large_all    (16384, 16384, 16384) float16      column   14.107                      14.120
84       large_all    (16384, 16384, 16384) bfloat16     full     13.504                      13.460
85       large_all    (16384, 16384, 16384) bfloat16     row      13.495                      13.499
86       large_all    (16384, 16384, 16384) bfloat16     column   13.509                      13.461
87       large_all    (16384, 16384, 16384) float32      full     177.279                     177.242
88       large_all    (16384, 16384, 16384) float32      row      176.896                     176.651
89       large_all    (16384, 16384, 16384) float32      column   176.830                     176.451
```
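
For convenience, a small sketch (not part of the PR) of how the table above can be double-checked: it assumes the merged summary was saved verbatim to a plain-text file (the `summary.txt` name is made up) and reports the largest relative gap between the `inp_expanded` and `inp` columns.

```
# Hypothetical helper, not part of the PR: computes the largest relative
# difference between the last two runtime columns of the merged summary.
# Assumes the table above was saved verbatim to "summary.txt".

def max_relative_delta(path="summary.txt"):
    worst = 0.0
    with open(path) as f:
        for line in f:
            fields = line.split()
            # data rows start with an integer config id; headers/rules do not
            if not fields or not fields[0].isdigit():
                continue
            expanded_ms, inp_ms = float(fields[-2]), float(fields[-1])
            worst = max(worst, abs(expanded_ms - inp_ms) / inp_ms)
    return worst

if __name__ == "__main__":
    print(f"max relative delta: {max_relative_delta():.4%}")
```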

## script

```
"""
Torch addmm benchmarking script covering different input configurations.
Tests 90 different combinations of:
- Input types: large M, large K, large N, large everything
- Data types: float16, bfloat16, float32
- Bias types: full, row, column
"""

import itertools

import torch
import torch.utils.benchmark as benchmark

def create_test_configurations():
    """Create 30 different test configurations for torch.addmm"""

    # Define dimension configurations
    # Large M: many rows in input matrix
    # Large K: many columns in input/rows in weight
    # Large N: many columns in weight matrix
    # Large everything: all dimensions large
    dim_configs = [
        # Large M configurations
        {"M": 65536, "K": 8192, "N": 8192, "type": "large_M"},
        {"M": 32768, "K": 4096, "N": 4096, "type": "large_M"},
        {"M": 16384, "K": 2048, "N": 2048, "type": "large_M"},
        # Large K configurations
        {"M": 8192, "K": 65536, "N": 8192, "type": "large_K"},
        {"M": 4096, "K": 32768, "N": 4096, "type": "large_K"},
        {"M": 2048, "K": 16384, "N": 2048, "type": "large_K"},
        # Large N configurations
        {"M": 8192, "K": 8192, "N": 65536, "type": "large_N"},
        {"M": 4096, "K": 4096, "N": 32768, "type": "large_N"},
        {"M": 2048, "K": 2048, "N": 16384, "type": "large_N"},
        # Large everything configurations
        {"M": 16384, "K": 16384, "N": 16384, "type": "large_all"},
    ]

    # Data types to test
    dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Bias configurations
    bias_configs = ["full", "row", "column"]

    # Generate all combinations
    configurations = []
    config_id = 0

    for dim_config, dtype, bias_type in itertools.product(
        dim_configs, dtypes, bias_configs
    ):
        config = {
            "id": config_id,
            "M": dim_config["M"],
            "K": dim_config["K"],
            "N": dim_config["N"],
            "dim_type": dim_config["type"],
            "dtype": dtype,
            "bias_type": bias_type,
        }
        configurations.append(config)
        config_id += 1

    return configurations

def create_tensors(config):
    """Create input tensors for a given configuration"""
    M, K, N = config["M"], config["K"], config["N"]
    dtype = config["dtype"]
    bias_type = config["bias_type"]

    # Create input tensor (M x K)
    input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

    # Create weight tensor (K x N)
    weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda")

    # Create bias tensor based on bias type
    if bias_type == "full":
        bias_tensor = torch.randn(M, N, dtype=dtype, device="cuda")
    elif bias_type == "row":
        bias_tensor = torch.randn(M, 1, dtype=dtype, device="cuda")
    elif bias_type == "column":
        bias_tensor = torch.randn(1, N, dtype=dtype, device="cuda")

    return input_tensor, weight_tensor, bias_tensor

def benchmark_addmm(config, use_compile=False):
    """Benchmark torch.addmm for a given configuration"""
    input_tensor, weight_tensor, bias_tensor = create_tensors(config)

    # Define the operation
    def addmm_op():
        return torch.addmm(bias_tensor, input_tensor, weight_tensor)

    # Optionally compile the operation
    if use_compile:
        addmm_op = torch.compile(addmm_op)
        # Warmup for compiled version
        for _ in range(3):
            _ = addmm_op()
        torch.cuda.synchronize()

    # Benchmark using torch.utils.benchmark
    timer = benchmark.Timer(
        stmt="addmm_op()",
        globals={"addmm_op": addmm_op},
        description=f"addmm_{config['dim_type']}_{config['dtype']}_{config['bias_type']}_compiled_{use_compile}",
    )

    # Run benchmark
    measurement = timer.blocked_autorange(min_run_time=1.0)

    return measurement

def print_tensor_info(config, input_tensor, weight_tensor, bias_tensor):
    """Print information about input tensors"""
    print(f"\\nConfiguration {config['id']}:")
    print(f"  Dimension type: {config['dim_type']}")
    print(f"  Data type: {config['dtype']}")
    print(f"  Bias type: {config['bias_type']}")
    print(f"  Input tensor shape: {input_tensor.shape} ({input_tensor.dtype})")
    print(f"  Weight tensor shape: {weight_tensor.shape} ({weight_tensor.dtype})")
    print(f"  Bias tensor shape: {bias_tensor.shape} ({bias_tensor.dtype})")

def main():
    """Main benchmarking function"""
    print("Torch addmm Benchmarking Script (Compiled Only)")
    print("=" * 50)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires CUDA.")
        return

    print(f"Using device: {torch.cuda.get_device_name()}")
    print(f"PyTorch version: {torch.__version__}")

    # Create test configurations
    configurations = create_test_configurations()
    print(f"\\nTesting {len(configurations)} configurations...")

    results = []

    for config in configurations:
        print(f"\\n{'='*60}")

        # Create tensors and print info
        input_tensor, weight_tensor, bias_tensor = create_tensors(config)
        print_tensor_info(config, input_tensor, weight_tensor, bias_tensor)

        # Benchmark with compilation only
        print("\\nBenchmarking with torch.compile:")
        try:
            measurement = benchmark_addmm(config, use_compile=True)
            runtime_ms = measurement.mean * 1000
            print(f"  Runtime: {runtime_ms:.3f} ms")
            results.append(
                {
                    "config": config,
                    "runtime_ms": runtime_ms,
                }
            )
        except Exception as e:
            print(f"  Error: {e}")

        # Clear cache
        torch.cuda.empty_cache()

    # Print summary
    print(f"\\n\\n{'='*80}")
    print("BENCHMARK SUMMARY")
    print(f"{'='*80}")

    print(
        f"{'Config':<8} {'Dim Type':<12} {'(M, N, K)':<20} {'DType':<12} {'Bias':<8} {'Runtime (ms)':<15}"
    )
    print("-" * 90)

    for result in results:
        config = result["config"]
        dtype_str = str(config["dtype"]).split(".")[-1]
        dimensions_str = f"({config['M']}, {config['N']}, {config['K']})"
        print(
            f"{config['id']:<8} {config['dim_type']:<12} {dimensions_str:<20} {dtype_str:<12} {config['bias_type']:<8} {result['runtime_ms']:<15.3f}"
        )

if __name__ == "__main__":
    main()
```
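
As a standalone illustration (not taken from the inductor change itself, shapes are arbitrary): the row and column bias shapes used in the script broadcast inside torch.addmm, and an explicitly expanded (M, N) view of the same bias gives the same numerical result.

```
import torch

# Standalone sanity check (not part of the benchmark script above): the
# row/column bias shapes broadcast inside torch.addmm, and an explicitly
# expanded (M, N) view of the same bias yields the same result. Shapes are
# illustrative only.
M, K, N = 64, 32, 48
mat1 = torch.randn(M, K)
mat2 = torch.randn(K, N)
for bias_shape in [(M, N), (M, 1), (1, N)]:
    bias = torch.randn(*bias_shape)
    out_broadcast = torch.addmm(bias, mat1, mat2)              # broadcast bias
    out_expanded = torch.addmm(bias.expand(M, N), mat1, mat2)  # expanded view
    assert torch.allclose(out_broadcast, out_expanded)
print("broadcast and expanded-view biases match")
```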

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
coconutruben added a commit that referenced this pull request Aug 21, 2025
ghstack-source-id: a544f1d
Pull Request resolved: #161208
@coconutruben
Contributor Author

@jataylo could you take a look as well to make sure this does not regress AMD performance?

@coconutruben coconutruben requested a review from jataylo August 21, 2025 22:11
# why

- simplifies the code a lot
- unnecessary for performance (see below)
- matches other mm family kernels in logic now

# what

- if we're not autotuning (only ATen), we use a flexible layout
- in any case, we use inp_expanded (the expanded bias view) for
  the aten kernel

# performance analysis

## results

this is on H100

```
================================================================================
BENCHMARK SUMMARY (MERGED)
================================================================================
Config   Dim Type     (M, N, K)            DType        Bias     Runtime (inp_expanded) ms   Runtime (inp) ms
----------------------------------------------------------------------------------------------------------------------------
0        large_M      (65536, 8192, 8192)  float16      full     14.531                      14.562
1        large_M      (65536, 8192, 8192)  float16      row      14.682                      14.675
2        large_M      (65536, 8192, 8192)  float16      column   14.754                      14.740
3        large_M      (65536, 8192, 8192)  bfloat16     full     15.172                      15.148
4        large_M      (65536, 8192, 8192)  bfloat16     row      15.072                      15.085
5        large_M      (65536, 8192, 8192)  bfloat16     column   15.082                      15.114
6        large_M      (65536, 8192, 8192)  float32      full     185.726                     186.308
7        large_M      (65536, 8192, 8192)  float32      row      185.042                     185.864
8        large_M      (65536, 8192, 8192)  float32      column   185.221                     185.991
9        large_M      (32768, 4096, 4096)  float16      full     2.025                       2.036
10       large_M      (32768, 4096, 4096)  float16      row      2.029                       2.033
11       large_M      (32768, 4096, 4096)  float16      column   2.036                       2.047
12       large_M      (32768, 4096, 4096)  bfloat16     full     1.966                       1.981
13       large_M      (32768, 4096, 4096)  bfloat16     row      1.963                       1.979
14       large_M      (32768, 4096, 4096)  bfloat16     column   1.973                       1.981
15       large_M      (32768, 4096, 4096)  float32      full     24.096                      24.180
16       large_M      (32768, 4096, 4096)  float32      row      23.951                      24.033
17       large_M      (32768, 4096, 4096)  float32      column   23.996                      24.061
18       large_M      (16384, 2048, 2048)  float16      full     0.297                       0.298
19       large_M      (16384, 2048, 2048)  float16      row      0.298                       0.299
20       large_M      (16384, 2048, 2048)  float16      column   0.301                       0.300
21       large_M      (16384, 2048, 2048)  bfloat16     full     0.293                       0.293
22       large_M      (16384, 2048, 2048)  bfloat16     row      0.290                       0.291
23       large_M      (16384, 2048, 2048)  bfloat16     column   0.292                       0.293
24       large_M      (16384, 2048, 2048)  float32      full     3.077                       3.073
25       large_M      (16384, 2048, 2048)  float32      row      3.034                       3.033
26       large_M      (16384, 2048, 2048)  float32      column   3.040                       3.038
27       large_K      (8192, 8192, 65536)  float16      full     14.278                      14.297
28       large_K      (8192, 8192, 65536)  float16      row      14.325                      14.283
29       large_K      (8192, 8192, 65536)  float16      column   14.179                      14.302
30       large_K      (8192, 8192, 65536)  bfloat16     full     13.616                      13.628
31       large_K      (8192, 8192, 65536)  bfloat16     row      13.584                      13.642
32       large_K      (8192, 8192, 65536)  bfloat16     column   13.594                      13.694
33       large_K      (8192, 8192, 65536)  float32      full     175.933                     176.153
34       large_K      (8192, 8192, 65536)  float32      row      175.504                     175.877
35       large_K      (8192, 8192, 65536)  float32      column   175.432                     175.992
36       large_K      (4096, 4096, 32768)  float16      full     1.726                       1.724
37       large_K      (4096, 4096, 32768)  float16      row      1.731                       1.735
38       large_K      (4096, 4096, 32768)  float16      column   1.733                       1.737
39       large_K      (4096, 4096, 32768)  bfloat16     full     1.662                       1.658
40       large_K      (4096, 4096, 32768)  bfloat16     row      1.664                       1.655
41       large_K      (4096, 4096, 32768)  bfloat16     column   1.660                       1.667
42       large_K      (4096, 4096, 32768)  float32      full     22.263                      22.305
43       large_K      (4096, 4096, 32768)  float32      row      22.257                      22.322
44       large_K      (4096, 4096, 32768)  float32      column   22.247                      22.337
45       large_K      (2048, 2048, 16384)  float16      full     0.236                       0.236
46       large_K      (2048, 2048, 16384)  float16      row      0.238                       0.239
47       large_K      (2048, 2048, 16384)  float16      column   0.238                       0.239
48       large_K      (2048, 2048, 16384)  bfloat16     full     0.219                       0.219
49       large_K      (2048, 2048, 16384)  bfloat16     row      0.221                       0.222
50       large_K      (2048, 2048, 16384)  bfloat16     column   0.222                       0.222
51       large_K      (2048, 2048, 16384)  float32      full     2.786                       2.789
52       large_K      (2048, 2048, 16384)  float32      row      2.790                       2.782
53       large_K      (2048, 2048, 16384)  float32      column   2.791                       2.791
54       large_N      (8192, 65536, 8192)  float16      full     14.692                      14.723
55       large_N      (8192, 65536, 8192)  float16      row      14.721                      14.637
56       large_N      (8192, 65536, 8192)  float16      column   14.743                      14.737
57       large_N      (8192, 65536, 8192)  bfloat16     full     15.156                      15.128
58       large_N      (8192, 65536, 8192)  bfloat16     row      15.152                      15.124
59       large_N      (8192, 65536, 8192)  bfloat16     column   15.112                      15.090
60       large_N      (8192, 65536, 8192)  float32      full     179.127                     179.313
61       large_N      (8192, 65536, 8192)  float32      row      178.445                     178.961
62       large_N      (8192, 65536, 8192)  float32      column   178.693                     178.694
63       large_N      (4096, 32768, 4096)  float16      full     2.035                       2.037
64       large_N      (4096, 32768, 4096)  float16      row      2.042                       2.037
65       large_N      (4096, 32768, 4096)  float16      column   2.053                       2.046
66       large_N      (4096, 32768, 4096)  bfloat16     full     1.992                       1.997
67       large_N      (4096, 32768, 4096)  bfloat16     row      1.997                       1.987
68       large_N      (4096, 32768, 4096)  bfloat16     column   2.005                       2.001
69       large_N      (4096, 32768, 4096)  float32      full     23.126                      23.077
70       large_N      (4096, 32768, 4096)  float32      row      23.002                      22.956
71       large_N      (4096, 32768, 4096)  float32      column   23.012                      22.969
72       large_N      (2048, 16384, 2048)  float16      full     0.314                       0.314
73       large_N      (2048, 16384, 2048)  float16      row      0.311                       0.311
74       large_N      (2048, 16384, 2048)  float16      column   0.314                       0.314
75       large_N      (2048, 16384, 2048)  bfloat16     full     0.306                       0.305
76       large_N      (2048, 16384, 2048)  bfloat16     row      0.302                       0.302
77       large_N      (2048, 16384, 2048)  bfloat16     column   0.305                       0.303
78       large_N      (2048, 16384, 2048)  float32      full     2.975                       2.971
79       large_N      (2048, 16384, 2048)  float32      row      2.927                       2.925
80       large_N      (2048, 16384, 2048)  float32      column   2.934                       2.936
81       large_all    (16384, 16384, 16384) float16      full     14.062                      14.096
82       large_all    (16384, 16384, 16384) float16      row      14.058                      14.078
83       large_all    (16384, 16384, 16384) float16      column   14.107                      14.120
84       large_all    (16384, 16384, 16384) bfloat16     full     13.504                      13.460
85       large_all    (16384, 16384, 16384) bfloat16     row      13.495                      13.499
86       large_all    (16384, 16384, 16384) bfloat16     column   13.509                      13.461
87       large_all    (16384, 16384, 16384) float32      full     177.279                     177.242
88       large_all    (16384, 16384, 16384) float32      row      176.896                     176.651
89       large_all    (16384, 16384, 16384) float32      column   176.830                     176.451
```

## script

```
"""
Torch addmm benchmarking script covering different input configurations.
Tests 30 different combinations of:
- Input types: large M, large K, large N, large everything
- Data types: float16, bfloat16, float32
- Bias types: full, row, column
"""

import itertools

import torch
import torch.utils.benchmark as benchmark

def create_test_configurations():
    """Create 30 different test configurations for torch.addmm"""

    # Define dimension configurations
    # Large M: many rows in input matrix
    # Large K: many columns in input/rows in weight
    # Large N: many columns in weight matrix
    # Large everything: all dimensions large
    dim_configs = [
        # Large M configurations
        {"M": 65536, "K": 8192, "N": 8192, "type": "large_M"},
        {"M": 32768, "K": 4096, "N": 4096, "type": "large_M"},
        {"M": 16384, "K": 2048, "N": 2048, "type": "large_M"},
        # Large K configurations
        {"M": 8192, "K": 65536, "N": 8192, "type": "large_K"},
        {"M": 4096, "K": 32768, "N": 4096, "type": "large_K"},
        {"M": 2048, "K": 16384, "N": 2048, "type": "large_K"},
        # Large N configurations
        {"M": 8192, "K": 8192, "N": 65536, "type": "large_N"},
        {"M": 4096, "K": 4096, "N": 32768, "type": "large_N"},
        {"M": 2048, "K": 2048, "N": 16384, "type": "large_N"},
        # Large everything configurations
        {"M": 16384, "K": 16384, "N": 16384, "type": "large_all"},
    ]

    # Data types to test
    dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Bias configurations
    bias_configs = ["full", "row", "column"]

    # Generate all combinations
    configurations = []
    config_id = 0

    for dim_config, dtype, bias_type in itertools.product(
        dim_configs, dtypes, bias_configs
    ):
        config = {
            "id": config_id,
            "M": dim_config["M"],
            "K": dim_config["K"],
            "N": dim_config["N"],
            "dim_type": dim_config["type"],
            "dtype": dtype,
            "bias_type": bias_type,
        }
        configurations.append(config)
        config_id += 1

    return configurations

def create_tensors(config):
    """Create input tensors for a given configuration"""
    M, K, N = config["M"], config["K"], config["N"]
    dtype = config["dtype"]
    bias_type = config["bias_type"]

    # Create input tensor (M x K)
    input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

    # Create weight tensor (K x N)
    weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda")

    # Create bias tensor based on bias type
    if bias_type == "full":
        bias_tensor = torch.randn(M, N, dtype=dtype, device="cuda")
    elif bias_type == "row":
        bias_tensor = torch.randn(M, 1, dtype=dtype, device="cuda")
    elif bias_type == "column":
        bias_tensor = torch.randn(1, N, dtype=dtype, device="cuda")

    return input_tensor, weight_tensor, bias_tensor

def benchmark_addmm(config, use_compile=False):
    """Benchmark torch.addmm for a given configuration"""
    input_tensor, weight_tensor, bias_tensor = create_tensors(config)

    # Define the operation
    def addmm_op():
        return torch.addmm(bias_tensor, input_tensor, weight_tensor)

    # Optionally compile the operation
    if use_compile:
        addmm_op = torch.compile(addmm_op)
        # Warmup for compiled version
        for _ in range(3):
            _ = addmm_op()
        torch.cuda.synchronize()

    # Benchmark using torch.utils.benchmark
    timer = benchmark.Timer(
        stmt="addmm_op()",
        globals={"addmm_op": addmm_op},
        description=f"addmm_{config['dim_type']}_{config['dtype']}_{config['bias_type']}_compiled_{use_compile}",
    )

    # Run benchmark
    measurement = timer.blocked_autorange(min_run_time=1.0)

    return measurement

def print_tensor_info(config, input_tensor, weight_tensor, bias_tensor):
    """Print information about input tensors"""
    print(f"\\nConfiguration {config['id']}:")
    print(f"  Dimension type: {config['dim_type']}")
    print(f"  Data type: {config['dtype']}")
    print(f"  Bias type: {config['bias_type']}")
    print(f"  Input tensor shape: {input_tensor.shape} ({input_tensor.dtype})")
    print(f"  Weight tensor shape: {weight_tensor.shape} ({weight_tensor.dtype})")
    print(f"  Bias tensor shape: {bias_tensor.shape} ({bias_tensor.dtype})")

def main():
    """Main benchmarking function"""
    print("Torch addmm Benchmarking Script (Compiled Only)")
    print("=" * 50)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires CUDA.")
        return

    print(f"Using device: {torch.cuda.get_device_name()}")
    print(f"PyTorch version: {torch.__version__}")

    # Create test configurations
    configurations = create_test_configurations()
    print(f"\\nTesting {len(configurations)} configurations...")

    results = []

    for config in configurations:
        print(f"\\n{'='*60}")

        # Create tensors and print info
        input_tensor, weight_tensor, bias_tensor = create_tensors(config)
        print_tensor_info(config, input_tensor, weight_tensor, bias_tensor)

        # Benchmark with compilation only
        print("\\nBenchmarking with torch.compile:")
        try:
            measurement = benchmark_addmm(config, use_compile=True)
            runtime_ms = measurement.mean * 1000
            print(f"  Runtime: {runtime_ms:.3f} ms")
            results.append(
                {
                    "config": config,
                    "runtime_ms": runtime_ms,
                }
            )
        except Exception as e:
            print(f"  Error: {e}")

        # Clear cache
        torch.cuda.empty_cache()

    # Print summary
    print(f"\\n\\n{'='*80}")
    print("BENCHMARK SUMMARY")
    print(f"{'='*80}")

    print(
        f"{'Config':<8} {'Dim Type':<12} {'(M, N, K)':<20} {'DType':<12} {'Bias':<8} {'Runtime (ms)':<15}"
    )
    print("-" * 90)

    for result in results:
        config = result["config"]
        dtype_str = str(config["dtype"]).split(".")[-1]
        dimensions_str = f"({config['M']}, {config['N']}, {config['K']})"
        print(
            f"{config['id']:<8} {config['dim_type']:<12} {dimensions_str:<20} {dtype_str:<12} {config['bias_type']:<8} {result['runtime_ms']:<15.3f}"
        )

if __name__ == "__main__":
    main()
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
# why

- simplifies the code a lot
- unnecessary for performance (see below)
- matches other mm family kernels in logic now

# what

- if we're not autotuning (only ATen), we use a flexible layout
- in any case, we use inp_expanded (the expanded bias view) for
  the aten kernel

# performance analysis

## results

this is on H100

```
================================================================================
BENCHMARK SUMMARY (MERGED)
================================================================================
Config   Dim Type     (M, N, K)            DType        Bias     Runtime (inp_expanded) ms   Runtime (inp) ms
----------------------------------------------------------------------------------------------------------------------------
0        large_M      (65536, 8192, 8192)  float16      full     14.531                      14.562
1        large_M      (65536, 8192, 8192)  float16      row      14.682                      14.675
2        large_M      (65536, 8192, 8192)  float16      column   14.754                      14.740
3        large_M      (65536, 8192, 8192)  bfloat16     full     15.172                      15.148
4        large_M      (65536, 8192, 8192)  bfloat16     row      15.072                      15.085
5        large_M      (65536, 8192, 8192)  bfloat16     column   15.082                      15.114
6        large_M      (65536, 8192, 8192)  float32      full     185.726                     186.308
7        large_M      (65536, 8192, 8192)  float32      row      185.042                     185.864
8        large_M      (65536, 8192, 8192)  float32      column   185.221                     185.991
9        large_M      (32768, 4096, 4096)  float16      full     2.025                       2.036
10       large_M      (32768, 4096, 4096)  float16      row      2.029                       2.033
11       large_M      (32768, 4096, 4096)  float16      column   2.036                       2.047
12       large_M      (32768, 4096, 4096)  bfloat16     full     1.966                       1.981
13       large_M      (32768, 4096, 4096)  bfloat16     row      1.963                       1.979
14       large_M      (32768, 4096, 4096)  bfloat16     column   1.973                       1.981
15       large_M      (32768, 4096, 4096)  float32      full     24.096                      24.180
16       large_M      (32768, 4096, 4096)  float32      row      23.951                      24.033
17       large_M      (32768, 4096, 4096)  float32      column   23.996                      24.061
18       large_M      (16384, 2048, 2048)  float16      full     0.297                       0.298
19       large_M      (16384, 2048, 2048)  float16      row      0.298                       0.299
20       large_M      (16384, 2048, 2048)  float16      column   0.301                       0.300
21       large_M      (16384, 2048, 2048)  bfloat16     full     0.293                       0.293
22       large_M      (16384, 2048, 2048)  bfloat16     row      0.290                       0.291
23       large_M      (16384, 2048, 2048)  bfloat16     column   0.292                       0.293
24       large_M      (16384, 2048, 2048)  float32      full     3.077                       3.073
25       large_M      (16384, 2048, 2048)  float32      row      3.034                       3.033
26       large_M      (16384, 2048, 2048)  float32      column   3.040                       3.038
27       large_K      (8192, 8192, 65536)  float16      full     14.278                      14.297
28       large_K      (8192, 8192, 65536)  float16      row      14.325                      14.283
29       large_K      (8192, 8192, 65536)  float16      column   14.179                      14.302
30       large_K      (8192, 8192, 65536)  bfloat16     full     13.616                      13.628
31       large_K      (8192, 8192, 65536)  bfloat16     row      13.584                      13.642
32       large_K      (8192, 8192, 65536)  bfloat16     column   13.594                      13.694
33       large_K      (8192, 8192, 65536)  float32      full     175.933                     176.153
34       large_K      (8192, 8192, 65536)  float32      row      175.504                     175.877
35       large_K      (8192, 8192, 65536)  float32      column   175.432                     175.992
36       large_K      (4096, 4096, 32768)  float16      full     1.726                       1.724
37       large_K      (4096, 4096, 32768)  float16      row      1.731                       1.735
38       large_K      (4096, 4096, 32768)  float16      column   1.733                       1.737
39       large_K      (4096, 4096, 32768)  bfloat16     full     1.662                       1.658
40       large_K      (4096, 4096, 32768)  bfloat16     row      1.664                       1.655
41       large_K      (4096, 4096, 32768)  bfloat16     column   1.660                       1.667
42       large_K      (4096, 4096, 32768)  float32      full     22.263                      22.305
43       large_K      (4096, 4096, 32768)  float32      row      22.257                      22.322
44       large_K      (4096, 4096, 32768)  float32      column   22.247                      22.337
45       large_K      (2048, 2048, 16384)  float16      full     0.236                       0.236
46       large_K      (2048, 2048, 16384)  float16      row      0.238                       0.239
47       large_K      (2048, 2048, 16384)  float16      column   0.238                       0.239
48       large_K      (2048, 2048, 16384)  bfloat16     full     0.219                       0.219
49       large_K      (2048, 2048, 16384)  bfloat16     row      0.221                       0.222
50       large_K      (2048, 2048, 16384)  bfloat16     column   0.222                       0.222
51       large_K      (2048, 2048, 16384)  float32      full     2.786                       2.789
52       large_K      (2048, 2048, 16384)  float32      row      2.790                       2.782
53       large_K      (2048, 2048, 16384)  float32      column   2.791                       2.791
54       large_N      (8192, 65536, 8192)  float16      full     14.692                      14.723
55       large_N      (8192, 65536, 8192)  float16      row      14.721                      14.637
56       large_N      (8192, 65536, 8192)  float16      column   14.743                      14.737
57       large_N      (8192, 65536, 8192)  bfloat16     full     15.156                      15.128
58       large_N      (8192, 65536, 8192)  bfloat16     row      15.152                      15.124
59       large_N      (8192, 65536, 8192)  bfloat16     column   15.112                      15.090
60       large_N      (8192, 65536, 8192)  float32      full     179.127                     179.313
61       large_N      (8192, 65536, 8192)  float32      row      178.445                     178.961
62       large_N      (8192, 65536, 8192)  float32      column   178.693                     178.694
63       large_N      (4096, 32768, 4096)  float16      full     2.035                       2.037
64       large_N      (4096, 32768, 4096)  float16      row      2.042                       2.037
65       large_N      (4096, 32768, 4096)  float16      column   2.053                       2.046
66       large_N      (4096, 32768, 4096)  bfloat16     full     1.992                       1.997
67       large_N      (4096, 32768, 4096)  bfloat16     row      1.997                       1.987
68       large_N      (4096, 32768, 4096)  bfloat16     column   2.005                       2.001
69       large_N      (4096, 32768, 4096)  float32      full     23.126                      23.077
70       large_N      (4096, 32768, 4096)  float32      row      23.002                      22.956
71       large_N      (4096, 32768, 4096)  float32      column   23.012                      22.969
72       large_N      (2048, 16384, 2048)  float16      full     0.314                       0.314
73       large_N      (2048, 16384, 2048)  float16      row      0.311                       0.311
74       large_N      (2048, 16384, 2048)  float16      column   0.314                       0.314
75       large_N      (2048, 16384, 2048)  bfloat16     full     0.306                       0.305
76       large_N      (2048, 16384, 2048)  bfloat16     row      0.302                       0.302
77       large_N      (2048, 16384, 2048)  bfloat16     column   0.305                       0.303
78       large_N      (2048, 16384, 2048)  float32      full     2.975                       2.971
79       large_N      (2048, 16384, 2048)  float32      row      2.927                       2.925
80       large_N      (2048, 16384, 2048)  float32      column   2.934                       2.936
81       large_all    (16384, 16384, 16384) float16      full     14.062                      14.096
82       large_all    (16384, 16384, 16384) float16      row      14.058                      14.078
83       large_all    (16384, 16384, 16384) float16      column   14.107                      14.120
84       large_all    (16384, 16384, 16384) bfloat16     full     13.504                      13.460
85       large_all    (16384, 16384, 16384) bfloat16     row      13.495                      13.499
86       large_all    (16384, 16384, 16384) bfloat16     column   13.509                      13.461
87       large_all    (16384, 16384, 16384) float32      full     177.279                     177.242
88       large_all    (16384, 16384, 16384) float32      row      176.896                     176.651
89       large_all    (16384, 16384, 16384) float32      column   176.830                     176.451
```

## script

```
"""
Torch addmm benchmarking script covering different input configurations.
Tests 30 different combinations of:
- Input types: large M, large K, large N, large everything
- Data types: float16, bfloat16, float32
- Bias types: full, row, column
"""

import itertools

import torch
import torch.utils.benchmark as benchmark

def create_test_configurations():
    """Create 30 different test configurations for torch.addmm"""

    # Define dimension configurations
    # Large M: many rows in input matrix
    # Large K: many columns in input/rows in weight
    # Large N: many columns in weight matrix
    # Large everything: all dimensions large
    dim_configs = [
        # Large M configurations
        {"M": 65536, "K": 8192, "N": 8192, "type": "large_M"},
        {"M": 32768, "K": 4096, "N": 4096, "type": "large_M"},
        {"M": 16384, "K": 2048, "N": 2048, "type": "large_M"},
        # Large K configurations
        {"M": 8192, "K": 65536, "N": 8192, "type": "large_K"},
        {"M": 4096, "K": 32768, "N": 4096, "type": "large_K"},
        {"M": 2048, "K": 16384, "N": 2048, "type": "large_K"},
        # Large N configurations
        {"M": 8192, "K": 8192, "N": 65536, "type": "large_N"},
        {"M": 4096, "K": 4096, "N": 32768, "type": "large_N"},
        {"M": 2048, "K": 2048, "N": 16384, "type": "large_N"},
        # Large everything configurations
        {"M": 16384, "K": 16384, "N": 16384, "type": "large_all"},
    ]

    # Data types to test
    dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Bias configurations
    bias_configs = ["full", "row", "column"]

    # Generate all combinations
    configurations = []
    config_id = 0

    for dim_config, dtype, bias_type in itertools.product(
        dim_configs, dtypes, bias_configs
    ):
        config = {
            "id": config_id,
            "M": dim_config["M"],
            "K": dim_config["K"],
            "N": dim_config["N"],
            "dim_type": dim_config["type"],
            "dtype": dtype,
            "bias_type": bias_type,
        }
        configurations.append(config)
        config_id += 1

    return configurations

def create_tensors(config):
    """Create input tensors for a given configuration"""
    M, K, N = config["M"], config["K"], config["N"]
    dtype = config["dtype"]
    bias_type = config["bias_type"]

    # Create input tensor (M x K)
    input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

    # Create weight tensor (K x N)
    weight_tensor = torch.randn(K, N, dtype=dtype, device="cuda")

    # Create bias tensor based on bias type
    if bias_type == "full":
        bias_tensor = torch.randn(M, N, dtype=dtype, device="cuda")
    elif bias_type == "row":
        bias_tensor = torch.randn(M, 1, dtype=dtype, device="cuda")
    elif bias_type == "column":
        bias_tensor = torch.randn(1, N, dtype=dtype, device="cuda")

    return input_tensor, weight_tensor, bias_tensor

def benchmark_addmm(config, use_compile=False):
    """Benchmark torch.addmm for a given configuration"""
    input_tensor, weight_tensor, bias_tensor = create_tensors(config)

    # Define the operation
    def addmm_op():
        return torch.addmm(bias_tensor, input_tensor, weight_tensor)

    # Optionally compile the operation
    if use_compile:
        addmm_op = torch.compile(addmm_op)
        # Warmup for compiled version
        for _ in range(3):
            _ = addmm_op()
        torch.cuda.synchronize()

    # Benchmark using torch.utils.benchmark
    timer = benchmark.Timer(
        stmt="addmm_op()",
        globals={"addmm_op": addmm_op},
        description=f"addmm_{config['dim_type']}_{config['dtype']}_{config['bias_type']}_compiled_{use_compile}",
    )

    # Run benchmark
    measurement = timer.blocked_autorange(min_run_time=1.0)

    return measurement

def print_tensor_info(config, input_tensor, weight_tensor, bias_tensor):
    """Print information about input tensors"""
    print(f"\\nConfiguration {config['id']}:")
    print(f"  Dimension type: {config['dim_type']}")
    print(f"  Data type: {config['dtype']}")
    print(f"  Bias type: {config['bias_type']}")
    print(f"  Input tensor shape: {input_tensor.shape} ({input_tensor.dtype})")
    print(f"  Weight tensor shape: {weight_tensor.shape} ({weight_tensor.dtype})")
    print(f"  Bias tensor shape: {bias_tensor.shape} ({bias_tensor.dtype})")

def main():
    """Main benchmarking function"""
    print("Torch addmm Benchmarking Script (Compiled Only)")
    print("=" * 50)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available. This benchmark requires CUDA.")
        return

    print(f"Using device: {torch.cuda.get_device_name()}")
    print(f"PyTorch version: {torch.__version__}")

    # Create test configurations
    configurations = create_test_configurations()
    print(f"\\nTesting {len(configurations)} configurations...")

    results = []

    for config in configurations:
        print(f"\\n{'='*60}")

        # Create tensors and print info
        input_tensor, weight_tensor, bias_tensor = create_tensors(config)
        print_tensor_info(config, input_tensor, weight_tensor, bias_tensor)

        # Benchmark with compilation only
        print("\\nBenchmarking with torch.compile:")
        try:
            measurement = benchmark_addmm(config, use_compile=True)
            runtime_ms = measurement.mean * 1000
            print(f"  Runtime: {runtime_ms:.3f} ms")
            results.append(
                {
                    "config": config,
                    "runtime_ms": runtime_ms,
                }
            )
        except Exception as e:
            print(f"  Error: {e}")

        # Clear cache
        torch.cuda.empty_cache()

    # Print summary
    print(f"\\n\\n{'='*80}")
    print("BENCHMARK SUMMARY")
    print(f"{'='*80}")

    print(
        f"{'Config':<8} {'Dim Type':<12} {'(M, N, K)':<20} {'DType':<12} {'Bias':<8} {'Runtime (ms)':<15}"
    )
    print("-" * 90)

    for result in results:
        config = result["config"]
        dtype_str = str(config["dtype"]).split(".")[-1]
        dimensions_str = f"({config['M']}, {config['N']}, {config['K']})"
        print(
            f"{config['id']:<8} {config['dim_type']:<12} {dimensions_str:<20} {dtype_str:<12} {config['bias_type']:<8} {result['runtime_ms']:<15.3f}"
        )

if __name__ == "__main__":
    main()
```
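
For reference, here is a minimal standalone sketch (not the Inductor code itself; the tensor shapes and variable names are illustrative) of what the expanded bias view means for the `row`/`column` cases benchmarked above: `addmm` broadcasts the bias either way, and the expanded view is only a stride trick over the same storage, which is consistent with the near-identical runtime columns in the table.

```
import torch

# Column-bias case from the table above; assumes a CUDA device
M, N, K = 2048, 16384, 2048
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
inp = torch.randn(1, N, device="cuda", dtype=torch.float16)  # "column" bias

# Expanded bias view: same storage, shape (M, N), stride (0, 1)
inp_expanded = inp.expand(M, N)

out_expanded = torch.addmm(inp_expanded, a, b)  # bias passed as the expanded view
out_plain = torch.addmm(inp, a, b)              # bias broadcast by addmm itself
torch.testing.assert_close(out_expanded, out_plain)
```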

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@github-actions github-actions bot deleted the gh/coconutruben/36/head branch October 12, 2025 02:13