In [1]:
import jax

In [3]:
devices = jax.devices()
devices

[CudaDevice(id=0)]

In [4]:
devices[0].device_kind

'NVIDIA GeForce RTX 3060'

In [5]:
for device in devices:
    print(device.device_kind.lower())

nvidia geforce rtx 3060


In [17]:
import os
os.environ["OPENBLAS_NUM_THREADS"] = "1" # numpyに許可する並列化を制限

import jax
import jax.numpy as jnp
import time

# 大きな行列を作成
size = 20000
x = jnp.ones((size, size)) * 0.01

# 計算の時間を測定する関数
def measure_time(func):
    start_time = time.perf_counter()
    result = func()
    end_time = time.perf_counter()
    return result, end_time - start_time

# CPUで計算
def cpu_calculation():
    ret = jax.device_put(x, device=jax.devices("cpu")[0]) @ x
    ret.block_until_ready()
    return ret

# GPUで計算
def gpu_calculation():
    ret = jax.device_put(x, device=jax.devices("gpu")[0]) @ x
    ret.block_until_ready()
    return ret

# GPUでの計算時間
ret_gpu, gpu_time = measure_time(gpu_calculation)
print(f"GPU calculation time: {gpu_time:.6f} seconds")

# CPUでの計算時間
ret_cpu, cpu_time = measure_time(cpu_calculation)
print(f"CPU calculation time: {cpu_time:.6f} seconds")

GPU calculation time: 1.578235 seconds
CPU calculation time: 42.021768 seconds


In [11]:
ret_gpu, ret_cpu

(Array([[0.5001807, 0.5001807, 0.5001807, ..., 0.5001807, 0.5001807,
         0.5001807],
        [0.5001807, 0.5001807, 0.5001807, ..., 0.5001807, 0.5001807,
         0.5001807],
        [0.5001807, 0.5001807, 0.5001807, ..., 0.5001807, 0.5001807,
         0.5001807],
        ...,
        [0.5001807, 0.5001807, 0.5001807, ..., 0.5001807, 0.5001807,
         0.5001807],
        [0.5001807, 0.5001807, 0.5001807, ..., 0.5001807, 0.5001807,
         0.5001807],
        [0.5001807, 0.5001807, 0.5001807, ..., 0.5001807, 0.5001807,
         0.5001807]], dtype=float32),
 Array([[0.49999917, 0.49999917, 0.49999917, ..., 0.49999982, 0.49999982,
         0.49999982],
        [0.49999917, 0.49999917, 0.49999917, ..., 0.49999982, 0.49999982,
         0.49999982],
        [0.49999917, 0.49999917, 0.49999917, ..., 0.49999982, 0.49999982,
         0.49999982],
        ...,
        [0.49999917, 0.49999917, 0.49999917, ..., 0.49999982, 0.49999982,
         0.49999982],
        [0.49999982, 0.49999982, 