In [1]:
# Install the local PyTensor checkout two directories up (../..) in editable mode (-e);
# --no-deps skips installing its dependencies.
%pip install -e ../.. --no-deps

Obtaining file:///Users/carlostrujillo/Documents/GitHub/pytensor
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
[?25hBuilding wheels for collected packages: pytensor
  Building editable for pytensor (pyproject.toml) ... [?25ldone
[?25h  Created wheel for pytensor: filename=pytensor-2.31.7+80.g06ccf91ba.dirty-0.editable-cp312-cp312-macosx_11_0_arm64.whl size=7323 sha256=c09587a5f3141d49000666d2817c5a01436f13ff5a19aa3deda20f647660afee
  Stored in directory: /private/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/pip-ephem-wheel-cache-i00nb67k/wheels/52/f6/4c/e6784e2203d5405c94db1d544248730e598e4397674416af05
Successfully built pytensor
Installing collected packages: pytensor
  Attempting uninstall: pytensor
    Found existing installation: pytensor 2.31.7+80.g06ccf91ba.dirty
    Unins

In [1]:
import time
import numpy as np
import jax
import jax.numpy as jnp

import pytensor
import pytensor.tensor as pt
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.graph import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker


In [2]:
# Disable JAX's 64-bit mode so it stays in float32, matching the float32
# inputs the benchmarks feed to both backends.
jax.config.update("jax_enable_x64", False)

# PyTensor + JAX backend: compile-mode name and a rewrite query for the "jax" tag.
pytensor_jax_mode = "JAX"
jax_optimizer = RewriteDatabaseQuery(include=["jax"], exclude=[])

# The MLX backend is optional. MLX_AVAILABLE records whether both the PyTensor
# MLX linker and the mlx package could be imported; the MLX mode settings are
# only defined when they could.
try:
    from pytensor.link.mlx import MLXLinker
    import mlx.core as mx
except ImportError:
    MLX_AVAILABLE = False
else:
    pytensor_mlx_mode = "MLX"
    mlx_optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=[])
    MLX_AVAILABLE = True

def timer_jax(func, N=1000):
    """Wrap ``func`` so that each call runs it ``N`` times and reports wall-clock timings.

    JAX dispatches work asynchronously. Each timed repetition therefore waits on
    ``block_until_ready()`` for the result (or for each element of a list/tuple
    result) before the clock stops.

    Args:
        func: Callable to benchmark.
        N: Number of timed repetitions per wrapper call. Must be >= 1.

    Returns:
        A wrapper taking the same arguments as ``func``. It returns
        ``(last_result, mean_seconds, std_seconds)``.

    Raises:
        ValueError: If ``N`` < 1, because there would be no result or timings to report.
    """
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")

    def _block(result):
        # Wait for pending device work. Outputs without block_until_ready (e.g. numpy) are left alone.
        if hasattr(result, 'block_until_ready'):
            result.block_until_ready()
        elif isinstance(result, (list, tuple)):
            for item in result:
                if hasattr(item, 'block_until_ready'):
                    item.block_until_ready()

    def wrapper(*args, **kwargs):
        times = []
        for _ in range(N):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            _block(result)
            times.append(time.perf_counter() - start)
        return result, np.mean(times), np.std(times)

    return wrapper

def timer_mlx(func, N=1000):
    """Wrap ``func`` so that each call runs it ``N`` times and reports wall-clock timings.

    MLX needs ``mx.eval()`` to force computation. When MLX is available, each
    timed repetition therefore evaluates the result (or every element of a
    list/tuple result) before the clock stops.

    Args:
        func: Callable to benchmark.
        N: Number of timed repetitions per wrapper call. Must be >= 1.

    Returns:
        A wrapper taking the same arguments as ``func``. It returns
        ``(last_result, mean_seconds, std_seconds)``.

    Raises:
        ValueError: If ``N`` < 1, because there would be no result or timings to report.
    """
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")

    def _force(result):
        # Without mx.eval the timing would not include the actual computation.
        if not MLX_AVAILABLE:
            return
        if isinstance(result, (list, tuple)):
            mx.eval(*result)
        else:
            mx.eval(result)

    def wrapper(*args, **kwargs):
        times = []
        for _ in range(N):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            _force(result)
            times.append(time.perf_counter() - start)
        return result, np.mean(times), np.std(times)

    return wrapper

def _format_speedup(jax_mean, mlx_mean):
    """Format MLX's relative speed versus JAX as a signed percentage of the JAX time."""
    # Positive = MLX is faster, negative = MLX is slower; 'N/A' when there is no usable MLX timing.
    if mlx_mean == float('inf') or mlx_mean <= 0 or jax_mean <= 0:
        return 'N/A'
    return f'{(jax_mean - mlx_mean) / jax_mean * 100:+.1f}%'


def _time_both_backends(inputs, output, args, N, tag):
    """Compile ``output`` for JAX and MLX, warm each up once, then time ``N`` calls of each.

    Returns ``(jax_mean, jax_std, mlx_mean, mlx_std)``. The MLX pair is ``(inf, 0)``
    when MLX is unavailable or compiling/running it raised. JAX errors propagate.
    """
    f_jax = function(inputs, output, mode=pytensor_jax_mode, trust_input=True)
    # Synchronized warm-up so JIT compilation is excluded from the timed runs.
    timer_jax(f_jax, N=1)(*args)
    _, jax_mean, jax_std = timer_jax(f_jax, N=N)(*args)

    mlx_mean, mlx_std = float('inf'), 0
    # pytensor_mlx_mode only exists when MLX imported successfully.
    if MLX_AVAILABLE:
        try:
            f_mlx = function(inputs, output, mode=pytensor_mlx_mode, trust_input=True)
            # Evaluated warm-up (timer_mlx forces results) so compilation is not timed.
            timer_mlx(f_mlx, N=1)(*args)
            _, mlx_mean, mlx_std = timer_mlx(f_mlx, N=N)(*args)
        except Exception as e:
            print(f"MLX {tag} error: {e}")
            mlx_mean, mlx_std = float('inf'), 0
    return jax_mean, jax_std, mlx_mean, mlx_std


def _result_row(size, operation, jax_mean, jax_std, mlx_mean, mlx_std):
    """Build one results-table row with pre-formatted timing strings."""
    mlx_ok = mlx_mean != float('inf')
    return {
        'Size': f'{size}x{size}',
        'Operation': operation,
        'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',
        'PyTensor+JAX Std (s)': f'{jax_std:.6f}',
        'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_ok else 'Error',
        'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_ok else 'N/A',
        'MLX Performance': _format_speedup(jax_mean, mlx_mean),
    }


def run_benchmark(N=1000, sizes=(2, 4, 1080, 2080, 3080)):
    """Benchmark PyTensor's JAX backend against its MLX backend.

    For each square matrix size, three graphs are compiled for both backends:
    a matrix-product chain ``(A @ B) @ C``, an element-wise ``sin(A) + cos(B)``,
    and ``A + B.T``. Each compiled function gets one synchronized warm-up call,
    followed by ``N`` timed calls.

    Args:
        N: Number of timed repetitions per (size, operation) pair.
        sizes: Square matrix sizes to test. The default is the original list.

    Returns:
        pandas.DataFrame with one row per (size, operation). Timings are
        pre-formatted strings. The MLX columns read 'Error'/'N/A' when MLX is
        unavailable or failed. 'MLX Performance' is
        ``(jax_mean - mlx_mean) / jax_mean`` as a signed percentage
        (positive = MLX faster).
    """
    import pandas as pd

    print(f"Running benchmarks with N={N} repetitions per test...")

    # The symbolic graphs do not depend on the matrix size, so build them once.
    pt_A = pt.matrix('A', dtype='float32')
    pt_B = pt.matrix('B', dtype='float32')
    pt_C = pt.matrix('C', dtype='float32')
    # (table label, tag used in error messages, graph inputs, graph output)
    cases = [
        ('Matrix Chain (A @ B @ C)', 'matmul', [pt_A, pt_B, pt_C], pt.dot(pt.dot(pt_A, pt_B), pt_C)),
        ('Element-wise (sin(A) + cos(B))', 'elemwise', [pt_A, pt_B], pt.sin(pt_A) + pt.cos(pt_B)),
        ('Broadcasting (A + B.T)', 'broadcast', [pt_A, pt_B], pt_A + pt_B.T),
    ]

    results = []
    for size in sizes:
        print(f"Testing {size}x{size} matrices...")

        # Re-seed per size so each size's matrices are reproducible on their own.
        np.random.seed(42)
        A = np.random.randn(size, size).astype(np.float32)
        B = np.random.randn(size, size).astype(np.float32)
        C = np.random.randn(size, size).astype(np.float32)
        values = {pt_A: A, pt_B: B, pt_C: C}

        for operation, tag, inputs, output in cases:
            args = [values[var] for var in inputs]
            timings = _time_both_backends(inputs, output, args, N, tag)
            results.append(_result_row(size, operation, *timings))

    return pd.DataFrame(results)

def main(N=1000):
    """Gather environment details, then run the PyTensor JAX-vs-MLX benchmark.

    Args:
        N: Timed repetitions per test, forwarded to ``run_benchmark``.

    Returns:
        ``(info_df, results_df)``. ``info_df`` is a one-row DataFrame with
        library versions, MLX availability, a platform label and ``N`` (plus the
        MLX version when MLX is importable). ``results_df`` is the DataFrame
        returned by ``run_benchmark``.
    """
    import pandas as pd

    mlx_present = MLX_AVAILABLE
    system_info = {
        'JAX version': jax.__version__,
        'PyTensor version': pytensor.__version__,
        'MLX Available': 'Yes' if mlx_present else 'No',
        # This label is inferred from MLX availability, not detected from the OS.
        'Platform': 'Apple Silicon' if mlx_present else 'Generic',
        'Repetitions (N)': N,
    }
    if mlx_present:
        system_info['MLX version'] = mx.__version__

    info_df = pd.DataFrame([system_info])
    results_df = run_benchmark(N=N)
    return info_df, results_df


In [3]:
# Number of timed repetitions per (size, operation) pair.
iteration = 150
# Bind the environment table to a real name rather than `_`. Assigning `_` would
# shadow IPython's last-output variable and throw the table away.
system_info_df, results = main(N=iteration)
system_info_df

Running benchmarks with N=150 repetitions per test...
Testing 2x2 matrices...
Testing 4x4 matrices...
Testing 1080x1080 matrices...
Testing 2080x2080 matrices...
Testing 3080x3080 matrices...


In [4]:
# Plain-text table (no index) so the full-width comparison stays readable in the output.
print(f"\nBenchmark Results over {iteration} repetitions:", results.to_string(index=False), sep="\n")


Benchmark Results over 150 repetitions:
     Size                      Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance
      2x2       Matrix Chain (A @ B @ C)              0.000009             0.000002              0.000305             0.000299        -3213.5%
      2x2 Element-wise (sin(A) + cos(B))              0.000007             0.000002              0.000352             0.003757        -5078.0%
      2x2         Broadcasting (A + B.T)              0.000007             0.000001              0.000188             0.000153        -2721.1%
      4x4       Matrix Chain (A @ B @ C)              0.000009             0.000001              0.000209             0.000063        -2126.2%
      4x4 Element-wise (sin(A) + cos(B))              0.000007             0.000001              0.000180             0.000066        -2449.5%
      4x4         Broadcasting (A + B.T)              0.000007             0.000003              0.00

In [None]:
# # Additional timing analysis - separate compilation vs execution time
# if MLX_AVAILABLE:
#     print("\n=== Detailed MLX Timing Analysis ===")
    
#     # Test with medium-sized matrix
#     np.random.seed(42)
#     A = np.random.randn(512, 512).astype(np.float32)
#     B = np.random.randn(512, 512).astype(np.float32)
#     C = np.random.randn(512, 512).astype(np.float32)
    
#     # Create PyTensor function (compilation time)
#     start = time.perf_counter()
#     pt_A = pt.matrix('A', dtype='float32')
#     pt_B = pt.matrix('B', dtype='float32')
#     pt_C = pt.matrix('C', dtype='float32')
#     result_expr = pt_A @ pt_B @ pt_C
#     f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)
#     compilation_time = time.perf_counter() - start
    
#     # First execution (may include additional compilation/optimization)
#     start = time.perf_counter()
#     result = f_mlx(A, B, C)
#     mx.eval(result)  # Force evaluation
#     first_exec_time = time.perf_counter() - start
    
#     # Subsequent executions (should be faster)
#     exec_times = []
#     for _ in range(1000):
#         start = time.perf_counter()
#         result = f_mlx(A, B, C)
#         mx.eval(result)
#         exec_times.append(time.perf_counter() - start)
    
#     avg_exec_time = np.mean(exec_times)
#     std_exec_time = np.std(exec_times)
    
#     print(f"Compilation time: {compilation_time:.4f}s")
#     print(f"First execution: {first_exec_time:.4f}s")
#     print(f"Average execution ({len(exec_times)} runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s")
#     print(f"Individual execution times: {[f'{t:.4f}' for t in exec_times]}")
