In [1]:
import time

import torch

# Matrix multiplication benchmark
#
# Configuration: prefer the Apple-silicon GPU (MPS) and fall back to the CPU so
# the cell still runs on machines where MPS is not available.
device = "mps" if torch.backends.mps.is_available() else "cpu"
sizes = [1024, 2048, 4096, 8192]
n_iters = 10


def _synchronize(device):
    """Block until all work queued on ``device`` has finished.

    GPU backends (MPS, CUDA) execute kernels asynchronously, so the host
    clock must only be read after the device has drained its queue.
    CPU ops run synchronously and need no barrier.
    """
    device_type = torch.device(device).type
    if device_type == "mps":
        torch.mps.synchronize()
    elif device_type == "cuda":
        torch.cuda.synchronize()


def benchmark_matmul(size, device, n_iters=10, n_warmup=3):
    """Measure the throughput of a square ``size`` x ``size`` matmul.

    Args:
        size: Edge length of the two random square operand matrices.
        device: Torch device (string or ``torch.device``) to run on.
        n_iters: Number of timed multiplications; must be >= 1.
        n_warmup: Number of untimed multiplications run first so one-off
            costs (kernel compilation, allocator growth) are excluded.

    Returns:
        Achieved throughput in TFLOPS as a float, counting ``2 * size**3``
        floating-point operations per multiplication (one multiply and one
        add per inner-product term).

    Raises:
        ValueError: If ``n_iters`` is smaller than 1.
    """
    if n_iters < 1:
        raise ValueError(f"n_iters must be >= 1, got {n_iters}")

    a = torch.randn(size, size, device=device)
    b = torch.randn(size, size, device=device)

    # Warmup
    for _ in range(n_warmup):
        torch.matmul(a, b)
    _synchronize(device)

    # Benchmark: perf_counter is monotonic and high-resolution, unlike
    # time.time(), which can jump with system clock adjustments.
    start = time.perf_counter()
    for _ in range(n_iters):
        torch.matmul(a, b)
    # Wait for the queued kernels to finish before stopping the clock.
    _synchronize(device)
    elapsed = time.perf_counter() - start

    return (2 * size ** 3 * n_iters) / elapsed / 1e12


# Run the sweep only when executed as a notebook cell or script (where
# __name__ is "__main__"), so the helpers above can be imported without
# triggering the benchmark.
if __name__ == "__main__":
    for size in sizes:
        tflops = benchmark_matmul(size, device, n_iters=n_iters)
        print(f"{size}x{size} matmul: {tflops:.2f} TFLOPS")

1024x1024 matmul: 3.83 TFLOPS
2048x2048 matmul: 9.11 TFLOPS
4096x4096 matmul: 13.31 TFLOPS
8192x8192 matmul: 13.03 TFLOPS
