# In [None]:
import torch
import time

# Function to set the number of CPU threads
def set_cpu_thread_cap(num_threads):
    torch.set_num_threads(num_threads)
    torch.set_num_interop_threads(num_threads)
    print(f"Set max CPU threads to: {num_threads}")

def probe_device_dtypes(device, cpu_threads=30):
    """Dynamically probe the device for supported dtypes.

    For each candidate dtype, two small random square matrices are generated
    on the CPU, copied to ``device`` and multiplied with ``torch.matmul``. A
    dtype counts as supported when that sequence (plus a device synchronize
    on CUDA/XPU) completes without raising.

    Args:
        device: Device string such as "cpu", "cuda:0" or "xpu:0" (a
            ``torch.device`` is also accepted).
        cpu_threads: Thread cap passed to ``set_cpu_thread_cap`` when the
            device is the CPU. Defaults to 30; pass ``None`` to leave
            PyTorch's thread settings untouched.

    Returns:
        A tuple ``(supported, unsupported)`` of lists of ``torch.dtype``.
        Together they cover every candidate dtype, each list in candidate
        order.
    """
    device_type = torch.device(device).type

    # If running on CPU, cap the number of threads before doing any work.
    if device_type == "cpu" and cpu_threads is not None:
        set_cpu_thread_cap(cpu_threads)

    all_dtypes = [
        torch.float64, torch.float32, torch.float16, torch.bfloat16,
        torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
        torch.bool, torch.complex64, torch.complex128
    ]
    int_dtypes = (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
    complex_dtypes = (torch.complex64, torch.complex128)

    def random_matrix(dtype, size):
        # Pick a generator that accepts this dtype, build on CPU, then copy.
        shape = (size, size)
        if dtype in int_dtypes:
            host = torch.randint(0, 100, shape, dtype=dtype)
        elif dtype == torch.bool:
            host = torch.randint(0, 2, shape, dtype=dtype)
        elif dtype in complex_dtypes:
            host = torch.randn(*shape, dtype=dtype)
        else:
            host = torch.rand(*shape, dtype=dtype)
        return host.to(device)

    supported_dtypes = []
    matrix_size = 64  # Small size for quick probing

    for dtype in all_dtypes:
        try:
            matrix_a = random_matrix(dtype, matrix_size)
            matrix_b = random_matrix(dtype, matrix_size)
            torch.matmul(matrix_a, matrix_b)
            # Synchronize the probed device itself (a bare synchronize() only
            # waits on the current device) so asynchronous kernel errors are
            # raised inside this try.
            if device_type == "cuda":
                torch.cuda.synchronize(device)
            elif device_type == "xpu":
                torch.xpu.synchronize(device)
        except (RuntimeError, TypeError, AttributeError, NotImplementedError):
            continue
        supported_dtypes.append(dtype)

    unsupported_dtypes = [dt for dt in all_dtypes if dt not in supported_dtypes]
    return supported_dtypes, unsupported_dtypes

def run_benchmark(device, dtype, matrix_size=4096, epochs=10, warmup_runs=2):
    """Run benchmark for a specific dtype on the given device.

    Generates two random ``matrix_size x matrix_size`` operands of ``dtype``
    on the CPU and copies them to ``device``. It then performs
    ``warmup_runs`` untimed matmuls, followed by ``epochs`` timed matmuls,
    synchronizing the device after each one.

    Args:
        device: Device string such as "cpu", "cuda:0" or "xpu:0".
        dtype: ``torch.dtype`` of the operands.
        matrix_size: Side length of the square operands.
        epochs: Number of timed matmuls.
        warmup_runs: Number of untimed matmuls executed first.

    Returns:
        ``(total_time, gops)``. ``total_time`` is elapsed seconds for the
        timed loop. ``gops`` is billions of operations per second, where one
        matmul is counted as ``matrix_size ** 3`` operations (one per
        multiply-accumulate, i.e. half the usual 2*n^3 FLOP count).
        ``gops`` is 0.0 if the measured time is not positive.
    """
    device_type = torch.device(device).type
    dtype_name = str(dtype).split('.')[-1]

    def synchronize():
        # Wait for queued kernels on *this* device. A bare synchronize() only
        # waits on the current device, which under-times e.g. "cuda:1".
        if device_type == "cuda":
            torch.cuda.synchronize(device)
        elif device_type == "xpu":
            torch.xpu.synchronize(device)

    shape = (matrix_size, matrix_size)
    if dtype in [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64]:
        matrix_a = torch.randint(0, 100, shape, dtype=dtype).to(device)
        matrix_b = torch.randint(0, 100, shape, dtype=dtype).to(device)
    elif dtype == torch.bool:
        matrix_a = torch.randint(0, 2, shape, dtype=dtype).to(device)
        matrix_b = torch.randint(0, 2, shape, dtype=dtype).to(device)
    elif dtype in [torch.complex64, torch.complex128]:
        matrix_a = torch.randn(*shape, dtype=dtype).to(device)
        matrix_b = torch.randn(*shape, dtype=dtype).to(device)
    else:
        matrix_a = torch.rand(*shape, dtype=dtype).to(device)
        matrix_b = torch.rand(*shape, dtype=dtype).to(device)

    # Warmup runs (untimed); drain them before the clock starts.
    for _ in range(warmup_runs):
        torch.matmul(matrix_a, matrix_b)
    synchronize()

    # Main benchmark. perf_counter is monotonic and high-resolution, unlike
    # time.time(), which can jump with wall-clock adjustments.
    start_time = time.perf_counter()
    for epoch in range(epochs):
        torch.matmul(matrix_a, matrix_b)
        synchronize()
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            print(f"Completed epoch {epoch + 1}/{epochs} for {dtype_name}")
    end_time = time.perf_counter()

    total_time = end_time - start_time
    operations = matrix_size ** 3 * epochs

    # Handle zero or near-zero time to avoid division by zero
    if total_time <= 0:
        print(f"Warning: Total time for {dtype_name} was zero or negative, setting GOPS to 0")
        gops = 0.0
    else:
        gops = (operations / total_time) / 1e9

    return total_time, gops

def run_multi_dtype_benchmark(device_str="cuda:0", matrix_size=4096, epochs=10, warmup_runs=2, runs=1):
    """Benchmark ``torch.matmul`` for every dtype the target device supports.

    Uses ``device_str`` when it names a CUDA or XPU device whose backend
    reports as available; otherwise the CPU is used. The function first
    probes which dtypes can be multiplied on the chosen device. It then runs
    ``run_benchmark`` for each supported dtype, ``runs`` times, and prints
    per-run results, averages and a supported/unsupported summary.

    Args:
        device_str: Requested device, e.g. "cuda:0", "xpu:0" or "cpu".
        matrix_size: Side length of the square operands.
        epochs: Timed matmuls per dtype per run.
        warmup_runs: Untimed matmuls per dtype before each timed loop.
        runs: Number of times the whole dtype sweep is repeated; must be >= 1.

    Returns:
        Dict mapping each supported ``torch.dtype`` to
        ``{"times": [...], "gops_list": [...]}`` with one entry per run.

    Raises:
        ValueError: If ``runs`` is less than 1. Previously this surfaced as a
            ZeroDivisionError only after probing, when averaging results.
    """
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    def dtype_name(dtype):
        # "torch.float32" -> "float32"
        return str(dtype).split('.')[-1]

    # Determine device
    if "cuda" in device_str and torch.cuda.is_available():
        device = device_str
        device_name = torch.cuda.get_device_name(device)
    elif "xpu" in device_str and hasattr(torch, "xpu") and torch.xpu.is_available():
        device = device_str
        device_name = torch.xpu.get_device_name(device)
    else:
        device = "cpu"
        device_name = "CPU"

    # Display initial info
    print(f"Running on: {device} ({device_name})")
    print(f"Matrix size: {matrix_size}x{matrix_size}")
    print(f"Epochs: {epochs}")
    print(f"Warmup runs: {warmup_runs}")
    print(f"Number of runs: {runs}\n")

    # Probe supported dtypes
    print("Probing device for supported dtypes...")
    supported_dtypes, unsupported_dtypes = probe_device_dtypes(device)
    print(f"Supported dtypes: {[dtype_name(dt) for dt in supported_dtypes]}\n")

    # Store results across all runs
    all_results = {dtype: {"times": [], "gops_list": []} for dtype in supported_dtypes}

    # Run benchmarks multiple times
    for run in range(runs):
        print(f"\nRun {run + 1}/{runs}")
        print("-" * 50)

        for dtype in supported_dtypes:
            name = dtype_name(dtype)
            print(f"Running benchmark for {name}")
            total_time, gops = run_benchmark(device, dtype, matrix_size, epochs, warmup_runs)
            all_results[dtype]["times"].append(total_time)
            all_results[dtype]["gops_list"].append(gops)

            # Print individual run stats
            print(f"{name:<12} Total Time: {total_time:<6.2f}s  GOPS: {gops:<6.2f}")

    # Calculate and print averages (runs >= 1 is guaranteed above).
    print("\nAverage Benchmark Results Across All Runs:")
    print("-" * 50)
    print(f"{'Dtype':<12} {'Avg Total Time (s)':<20} {'Avg GOPS':<10}")
    print("-" * 50)

    for dtype in supported_dtypes:
        avg_time = sum(all_results[dtype]["times"]) / runs
        avg_gops = sum(all_results[dtype]["gops_list"]) / runs
        print(f"{dtype_name(dtype):<12} {avg_time:<20.2f} {avg_gops:<10.2f}")

    print("-" * 50)

    # Print supported and unsupported dtypes
    print("\nSummary:")
    print(f"Supported dtypes: {[dtype_name(dt) for dt in supported_dtypes]}")
    print(f"Unsupported dtypes: {[dtype_name(dt) for dt in unsupported_dtypes]}")

    return all_results

# Example usage in notebook: benchmark every dtype the chosen device supports.
# Adjust these settings, then re-run the cell.
device = "cuda:0"  # Change to "xpu:0" for Intel XPU or "cpu" as needed
matrix_size = 4096  # operands are matrix_size x matrix_size
epochs = 10  # timed matmuls per dtype, per run
warmup_runs = 10  # untimed matmuls per dtype before timing starts
runs = 3  # Number of times to repeat the full benchmark

results = run_multi_dtype_benchmark(
    device_str=device, matrix_size=matrix_size, epochs=epochs,
    warmup_runs=warmup_runs, runs=runs,
)