## PyTorch


In [1]:
import torch


def create_torch_tensors(device, shape=(10000, 10000)):
    """Create two random float32 tensors directly on ``device``.

    Args:
        device: Target ``torch.device`` (or anything ``torch.rand`` accepts
            for its ``device`` argument).
        shape: Shape of each tensor. Defaults to ``(10000, 10000)``, the size
            used by the benchmark cells.

    Returns:
        A tuple ``(x, y)`` of two independently sampled tensors.
    """
    # Sample on the target device instead of building the tensors on the host
    # and copying them over: this skips a large temporary host allocation and
    # a host-to-device transfer per tensor.
    x = torch.rand(shape, dtype=torch.float32, device=device)
    y = torch.rand(shape, dtype=torch.float32, device=device)

    return x, y

In [2]:
# CPU baseline: bind the notebook-level `device`, `x` and `y` to CPU tensors
# so the timing cell that follows measures the CPU.
device = torch.device("cpu")
x, y = create_torch_tensors(device)

In [3]:
%%timeit
# Time the element-wise product of whatever `x` and `y` are currently bound to
# (the CPU tensors when this runs right after the CPU setup cell).
x * y

30.7 ms ± 3.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
# Pick the available accelerator: Apple's Metal backend (MPS) first, then CUDA.
# Fail loudly when neither exists so the "GPU" timing below can never silently
# measure something else or die with an obscure backend error.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    raise RuntimeError("No MPS or CUDA device is available for the accelerator benchmark")
x, y = create_torch_tensors(device)

In [5]:
%%timeit
x * y

9.8 ms ± 109 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## TensorFlow


In [6]:
import tensorflow as tf


def create_tf_tensors():
    """Return two independent 10000 x 10000 float32 tensors from ``tf.random.uniform``.

    No device scope is set inside this function, so placement is left to
    TensorFlow.
    """
    dims = (10000, 10000)
    first, second = (tf.random.uniform(dims, dtype=tf.float32) for _ in range(2))
    return first, second


# Rebind the notebook-level `x` and `y` (previously the PyTorch tensors) to
# TensorFlow tensors. No tf.device scope is active here, so placement is left
# to TensorFlow.
x, y = create_tf_tensors()

In [7]:
%%timeit

# Time the element-wise product with the op pinned to the CPU.
# NOTE(review): x and y were created outside this device scope, so this may
# also time a transfer of the operands — confirm where they actually live.
with tf.device("/CPU:0"):
    x * y

34.2 ms ± 906 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%%timeit

# Time the element-wise product with the op pinned to the first GPU.
# NOTE(review): nothing here waits on the result — confirm the GPU work has
# actually finished by the time the timer stops.
with tf.device("/GPU:0"):
    x * y

9.74 ms ± 66.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## JAX


In [10]:
import jax
import jax.numpy as jnp

In [11]:
def create_jax_tensors(shape=(10000, 10000)):
    """Create two random float32 JAX arrays.

    No device is chosen here: the arrays land on whatever JAX's default device
    is when the function is *called* (for example inside a
    ``jax.default_device(...)`` block).

    Args:
        shape: Shape of each array. Defaults to ``(10000, 10000)``, the size
            used by the benchmark cells.

    Returns:
        A tuple ``(x, y)``. The arrays are generated from the fixed keys
        ``PRNGKey(0)`` and ``PRNGKey(1)``, so repeated calls with the same
        shape return the same values.
    """
    x = jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=jnp.float32)
    y = jax.random.uniform(jax.random.PRNGKey(1), shape, dtype=jnp.float32)

    return x, y


# Rebind the notebook-level `x` and `y` to JAX arrays. No device is specified
# here, so JAX's default device is used.
x, y = create_jax_tensors()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M3 Pro

systemMemory: 36.00 GB
maxCacheSize: 13.50 GB



I0000 00:00:1745919114.662170 12956740 service.cc:145] XLA service 0x40ffcb430 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1745919114.662185 12956740 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1745919114.663527 12956740 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1745919114.663538 12956740 mps_client.cc:384] XLA backend will use up to 25643024384 bytes on device 0 for SimpleAllocator.


In [12]:
%%timeit
x * y

9.59 ms ± 15.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# Force JAX onto the CPU for comparison. Only the *call* has to happen inside
# the default_device scope, so the existing create_jax_tensors() is reused
# rather than re-defined with an identical body.
cpu_device = jax.devices("cpu")[0]

with jax.default_device(cpu_device):
    x, y = create_jax_tensors()

In [14]:
%%timeit
x * y

105 ms ± 532 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
