In [None]:
import torch
import time
import itertools

def probe_device_dtypes(device):
    """Find out which dtypes can be matrix-multiplied on ``device``.

    For every candidate dtype two small random square matrices are created,
    moved to ``device`` and multiplied. A dtype is reported as supported when
    that sequence finishes without a RuntimeError, TypeError or
    AttributeError.

    Args:
        device: Device string such as "cpu", "cuda:0" or "xpu:0".

    Returns:
        A ``(supported, unsupported)`` pair of dtype lists, each keeping the
        order in which the candidates were probed.
    """
    integer_kinds = (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
    complex_kinds = (torch.complex64, torch.complex128)
    candidates = [
        torch.float64, torch.float32, torch.float16, torch.bfloat16,
        torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
        torch.bool, torch.complex64, torch.complex128
    ]

    side = 64  # Kept small so the probe finishes quickly
    shape = (side, side)
    working = []

    for candidate in candidates:
        try:
            # Each dtype family needs its own random factory.
            if candidate in integer_kinds:
                lhs = torch.randint(0, 100, shape, dtype=candidate).to(device)
                rhs = torch.randint(0, 100, shape, dtype=candidate).to(device)
            elif candidate == torch.bool:
                lhs = torch.randint(0, 2, shape, dtype=candidate).to(device)
                rhs = torch.randint(0, 2, shape, dtype=candidate).to(device)
            elif candidate in complex_kinds:
                lhs = torch.randn(shape, dtype=candidate).to(device)
                rhs = torch.randn(shape, dtype=candidate).to(device)
            else:
                lhs = torch.rand(shape, dtype=candidate).to(device)
                rhs = torch.rand(shape, dtype=candidate).to(device)

            torch.matmul(lhs, rhs)
            # Wait for the device so that a failing kernel is raised here.
            if "cuda" in device:
                torch.cuda.synchronize()
            elif "xpu" in device:
                torch.xpu.synchronize()
        except (RuntimeError, TypeError, AttributeError):
            continue
        else:
            working.append(candidate)

    failing = [candidate for candidate in candidates if candidate not in working]
    return working, failing

def check_amp_support(device, dtype):
    """Tell whether autocast (AMP) can be used for ``dtype`` on ``device``.

    Only float32, float16 and bfloat16 are treated as AMP-capable. The answer
    additionally requires the matching backend to be present: an available
    CUDA or XPU runtime, or ``torch.cpu.amp`` for the exact device "cpu".
    Any other device string yields False.
    """
    eligible = dtype in (torch.float32, torch.float16, torch.bfloat16)

    if "cuda" in device and torch.cuda.is_available():
        return eligible
    if "xpu" in device and hasattr(torch, "xpu") and torch.xpu.is_available():
        return eligible
    if device == "cpu":
        return eligible and hasattr(torch, "cpu") and hasattr(torch.cpu, "amp")
    return False

def check_compile_support(device):
    """Report whether ``torch.compile`` exists and works on ``device``.

    A trivial addition is compiled and executed on a one-element tensor
    placed on ``device``. A RuntimeError, TypeError or AttributeError raised
    anywhere in that sequence is taken to mean "not supported".
    """
    if not hasattr(torch, "compile"):
        return False

    try:
        compiled_add = torch.compile(lambda x, y: x + y)
        probe = torch.ones(1).to(device)
        compiled_add(probe, probe)
        # Make sure the compiled kernel really ran before declaring success.
        if "cuda" in device:
            torch.cuda.synchronize()
        elif "xpu" in device:
            torch.xpu.synchronize()
    except (RuntimeError, TypeError, AttributeError):
        return False
    return True

def _benchmark_operand(dtype, matrix_size, device):
    """Create one random ``matrix_size`` x ``matrix_size`` matrix of ``dtype`` on ``device``."""
    shape = (matrix_size, matrix_size)
    if dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
        tensor = torch.randint(0, 100, shape, dtype=dtype)
    elif dtype == torch.bool:
        tensor = torch.randint(0, 2, shape, dtype=dtype)
    elif dtype in (torch.complex64, torch.complex128):
        tensor = torch.randn(shape, dtype=dtype)
    else:
        tensor = torch.rand(shape, dtype=dtype)
    return tensor.to(device)


def _benchmark_sync(device):
    """Block until queued CUDA/XPU work is finished; do nothing for other devices."""
    if "cuda" in device:
        torch.cuda.synchronize()
    elif "xpu" in device:
        torch.xpu.synchronize()


def run_benchmark(device, dtype, matrix_size=4096, epochs=10, warmup_runs=2, use_amp=False, use_compile=False):
    """Time repeated square matrix multiplications for one dtype.

    Args:
        device: Device string such as "cpu", "cuda:0" or "xpu:0".
        dtype: torch dtype of the two operand matrices.
        matrix_size: Side length of the square operands.
        epochs: Number of timed multiplications.
        warmup_runs: Number of untimed multiplications executed first.
        use_amp: Run the matmul under ``torch.autocast`` when
            ``check_amp_support`` accepts this device/dtype pair. The autocast
            dtype is float16 for CUDA devices and bfloat16 otherwise.
        use_compile: Wrap ``torch.matmul`` with ``torch.compile`` when
            ``check_compile_support`` accepts the device.

    Returns:
        ``(total_time, gops)``: elapsed seconds for the timed loop and the
        throughput in giga-operations per second, where one matmul is counted
        as ``matrix_size ** 3`` operations. ``gops`` is 0.0 when the measured
        time is not positive.
    """
    dtype_name = str(dtype).split('.')[-1]
    matrix_a = _benchmark_operand(dtype, matrix_size, device)
    matrix_b = _benchmark_operand(dtype, matrix_size, device)

    # Compile matmul if requested and supported
    matmul_fn = torch.compile(torch.matmul) if use_compile and check_compile_support(device) else torch.matmul

    # Loop-invariant: decide once so the check does not run inside the timed loop.
    amp_active = use_amp and check_amp_support(device, dtype)

    def multiply():
        if amp_active:
            with torch.autocast(device_type=device.split(":")[0], dtype=torch.float16 if "cuda" in device else torch.bfloat16):
                return matmul_fn(matrix_a, matrix_b)
        return matmul_fn(matrix_a, matrix_b)

    # Warmup runs
    for _ in range(warmup_runs):
        multiply()
    # Drain warmup work so none of it leaks into the measurement.
    _benchmark_sync(device)

    # Main benchmark. perf_counter is monotonic and high resolution, unlike
    # the wall clock, so the interval cannot jump or go negative.
    start_time = time.perf_counter()

    for epoch in range(epochs):
        multiply()
        # Device kernels are asynchronous; wait so each epoch is fully counted.
        _benchmark_sync(device)
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            print(f"Completed epoch {epoch + 1}/{epochs} for {dtype_name}")

    end_time = time.perf_counter()

    total_time = end_time - start_time
    operations = matrix_size ** 3 * epochs

    if total_time <= 0:
        print(f"Warning: Total time for {dtype_name} was zero or negative, setting GOPS to 0")
        gops = 0.0
    else:
        gops = (operations / total_time) / 1e9

    return total_time, gops

def run_multi_dtype_benchmark(device_str="cuda:0", matrix_size=4096, epochs=10, warmup_runs=2, runs=1):
    """Benchmark matmul for every supported dtype under each AMP/compile setting.

    The requested device is used when its backend is available; otherwise the
    benchmark falls back to "cpu". Supported dtypes are discovered with
    ``probe_device_dtypes`` and each one is timed with ``run_benchmark`` for
    every combination of AMP on/off and torch.compile on/off that the device
    supports. Progress, per-run figures and per-configuration averages are
    printed.

    Args:
        device_str: Requested device, e.g. "cuda:0", "xpu:0" or "cpu".
        matrix_size: Side length of the square matrices.
        epochs: Timed multiplications per benchmark.
        warmup_runs: Untimed multiplications per benchmark.
        runs: How many times each configuration is repeated.

    Returns:
        Dict mapping a configuration label (for example
        "AMP:OFF, Compile:OFF") to ``{dtype: {"times": [...],
        "gops_list": [...]}}`` with one entry per run.
    """
    # Determine device
    if "cuda" in device_str and torch.cuda.is_available():
        device = device_str
        device_name = torch.cuda.get_device_name(device)
    elif "xpu" in device_str and hasattr(torch, "xpu") and torch.xpu.is_available():
        device = device_str
        device_name = torch.xpu.get_device_name(device)
    else:
        device = "cpu"
        device_name = "CPU"

    # Check AMP and compile support
    compile_supported = check_compile_support(device)
    amp_candidates = [torch.float32, torch.float16, torch.bfloat16]
    amp_available = any(check_amp_support(device, dt) for dt in amp_candidates)

    # Prefer has_avx512 when the build exposes it; otherwise fall back to the
    # capability string so the answer is not "No" merely because the
    # attribute is missing.
    avx512_support = False
    if device == "cpu" and hasattr(torch.backends, "cpu"):
        cpu_backend = torch.backends.cpu
        if hasattr(cpu_backend, "has_avx512"):
            avx512_support = bool(cpu_backend.has_avx512())
        elif hasattr(cpu_backend, "get_cpu_capability"):
            avx512_support = "AVX512" in str(cpu_backend.get_cpu_capability()).upper()

    # Display initial info
    print(f"Running on: {device} ({device_name})")
    if device == "cpu":
        print(f"AVX512 Support: {'Yes' if avx512_support else 'No'}")
    print(f"Matrix size: {matrix_size}x{matrix_size}")
    print(f"Epochs: {epochs}")
    print(f"Warmup runs: {warmup_runs}")
    print(f"Number of runs: {runs}")
    print(f"torch.compile Supported: {compile_supported}")
    print(f"AMP Supported: {'Partially (float32, float16, bfloat16)' if amp_available else 'No'}\n")

    # Probe supported dtypes
    print("Probing device for supported dtypes...")
    supported_dtypes, unsupported_dtypes = probe_device_dtypes(device)
    print(f"Supported dtypes: {[str(dt).split('.')[-1] for dt in supported_dtypes]}\n")

    # Generate all permutations
    amp_options = [False, True] if any(check_amp_support(device, dt) for dt in supported_dtypes) else [False]
    compile_options = [False, True] if compile_supported else [False]
    permutations = list(itertools.product(amp_options, compile_options))

    # Store results
    all_results = {}

    for amp, comp in permutations:
        config_name = f"AMP:{'ON' if amp else 'OFF'}, Compile:{'ON' if comp else 'OFF'}"
        print(f"\nTesting Configuration: {config_name}")
        print("-" * 70)

        config_results = {dtype: {"times": [], "gops_list": []} for dtype in supported_dtypes}

        for run in range(runs):
            print(f"\nRun {run + 1}/{runs}")
            print("-" * 70)

            for dtype in supported_dtypes:
                dtype_name = str(dtype).split('.')[-1]
                print(f"Running benchmark for {dtype_name}")
                total_time, gops = run_benchmark(device, dtype, matrix_size, epochs, warmup_runs, use_amp=amp, use_compile=comp)
                config_results[dtype]["times"].append(total_time)
                config_results[dtype]["gops_list"].append(gops)

                # str() is required: a bare bool with a width spec is
                # formatted as the integer 1/0 instead of True/False.
                print(f"{dtype_name:<12} Total Time: {total_time:<6.2f}s  GOPS: {gops:<6.2f}  AMP: {str(amp):<5} Compile: {comp}")

        # Calculate and print averages for this configuration
        print(f"\nAverage Results for {config_name}:")
        print("-" * 70)
        print(f"{'Dtype':<12} {'Avg Total Time (s)':<20} {'Avg GOPS':<12} {'AMP':<8} {'Compile':<8}")
        print("-" * 70)

        for dtype in supported_dtypes:
            dtype_name = str(dtype).split('.')[-1]
            times = config_results[dtype]["times"]
            gops_values = config_results[dtype]["gops_list"]
            # Average over the recorded samples so runs=0 does not divide by zero.
            avg_time = sum(times) / len(times) if times else 0.0
            avg_gops = sum(gops_values) / len(gops_values) if gops_values else 0.0
            print(f"{dtype_name:<12} {avg_time:<20.2f} {avg_gops:<12.2f} {str(amp):<8} {str(comp):<8}")

        print("-" * 70)
        all_results[config_name] = config_results

    # Print supported and unsupported dtypes
    print("\nSummary:")
    print(f"Supported dtypes: {[str(dt).split('.')[-1] for dt in supported_dtypes]}")
    print(f"Unsupported dtypes: {[str(dt).split('.')[-1] for dt in unsupported_dtypes]}")

    return all_results

# Notebook driver: define the benchmark settings, then run one full sweep.
device = "cuda:0"  # Options: "cuda:0", "xpu:0", "cpu"
matrix_size, epochs = 4096, 100
warmup_runs, runs = 20, 3

# Arguments are passed positionally in the order
# (device_str, matrix_size, epochs, warmup_runs, runs).
results = run_multi_dtype_benchmark(device, matrix_size, epochs, warmup_runs, runs)