
Set up JAX sampling with GPUs in PyMC

Thomas Wiecki edited this page Nov 8, 2023 · 1 revision

Set up an environment for JAX sampling with GPU support in PyMC

This guide shows how to set up and run JAX sampling with GPU support in PyMC. The steps are as follows:

1. Install Ubuntu 20.04.4 LTS (Focal Fossa)

At the time of writing, the latest Ubuntu LTS release was 22.04, but I'm a bit conservative, so I decided to install 20.04. I downloaded the 64-bit PC (AMD64) desktop image from here.

I made a bootable USB drive from the Ubuntu desktop .iso image above using Rufus; the video "How to Make Ubuntu 20.04 Bootable USB Drive" walks through this. I assume that you have an NVIDIA GPU in your local machine and know how to install Ubuntu from a bootable USB drive. If not, tutorials are easy to find on YouTube.
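Once Ubuntu is installed, one way to confirm which release you are running is to read /etc/os-release. The Python sketch below is only an illustration and not part of the original setup: parse_os_release is a hypothetical helper, and the parser is simplified (it handles plain and quoted values but ignores shell-style escape sequences).

```python
from pathlib import Path
from typing import Dict


def parse_os_release(text: str) -> Dict[str, str]:
    """Parse KEY=value lines from /etc/os-release (simplified: no escape handling)."""
    info: Dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines, comments and anything that is not an assignment.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Values may be wrapped in single or double quotes.
        info[key] = value.strip("\"'")
    return info


if __name__ == "__main__":
    path = Path("/etc/os-release")
    if path.exists():
        info = parse_os_release(path.read_text())
        print(info.get("PRETTY_NAME"), "| VERSION_ID:", info.get("VERSION_ID"))
```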

2. Install NVIDIA Driver, CUDA 11.4, cuDNN v8.2.4

According to JAX's installation guide, GPU support requires CUDA and cuDNN, so we install those first.

To do that, I followed the "Installation of NVIDIA Drivers, CUDA and cuDNN" section of this guide (kudos to its author, Ashutosh Kumar).

Note that the specific NVIDIA driver version we need may not be offered at this step. Instead, we can download it directly from https://download.nvidia.com/XFree86/Linux-x86_64/. In my case, I downloaded the file NVIDIA-Linux-x86_64-470.82.01.run from https://download.nvidia.com/XFree86/Linux-x86_64/470.82.01/
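The download link follows a simple pattern: a directory named after the driver version containing NVIDIA-Linux-x86_64-<version>.run. The Python sketch below assumes that pattern, generalized from the single 470.82.01 example above, also holds for other versions; driver_runfile_url and download_driver are hypothetical helper names, and download_driver is defined but never called so nothing large is fetched by accident.

```python
import os
import urllib.request

# Base directory from the link above; assumed to hold one subdirectory per driver version.
DRIVER_BASE_URL = "https://download.nvidia.com/XFree86/Linux-x86_64"


def driver_runfile_url(version: str) -> str:
    """Build the runfile URL for a driver version such as "470.82.01" (assumed pattern)."""
    return f"{DRIVER_BASE_URL}/{version}/NVIDIA-Linux-x86_64-{version}.run"


def download_driver(version: str, dest_dir: str = ".") -> str:
    """Download the runfile into dest_dir and return its local path (not called below)."""
    url = driver_runfile_url(version)
    path = os.path.join(dest_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, path)
    return path


if __name__ == "__main__":
    # Only prints the URL; call download_driver() explicitly to fetch the file.
    print(driver_runfile_url("470.82.01"))
```

The downloaded runfile is typically executed as root, for example with sudo sh NVIDIA-Linux-x86_64-470.82.01.run.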

After completing the steps in that guide, we can run the nvidia-smi and nvcc --version commands to verify the installation. In my case, the output looks like this:

[Screenshot (2022-05-29): output of nvidia-smi and nvcc --version]
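The same two checks can be scripted, which is handy when re-verifying a machine later. This is a minimal Python sketch, assuming both tools are on PATH when installed: run_tool and cuda_release are hypothetical helpers, and the regular expression assumes nvcc prints a line such as "Cuda compilation tools, release 11.4, ...".

```python
import re
import shutil
import subprocess
from typing import Optional


def run_tool(*cmd: str) -> Optional[str]:
    """Run a command-line tool and return its stdout, or None if it is not on PATH."""
    if shutil.which(cmd[0]) is None:
        return None
    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    return result.stdout


def cuda_release(nvcc_output: str) -> Optional[str]:
    """Extract the CUDA release (e.g. "11.4") from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


if __name__ == "__main__":
    smi = run_tool("nvidia-smi")
    print(smi if smi is not None else "nvidia-smi not found: is the NVIDIA driver installed?")
    nvcc = run_tool("nvcc", "--version")
    if nvcc is None:
        print("nvcc not found: is the CUDA toolkit's bin directory on PATH?")
    else:
        print("CUDA release:", cuda_release(nvcc))
```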

3. Install JAX with GPU support

Following JAX's installation guide, once CUDA and cuDNN are installed, we can use pip to install JAX with GPU support.

pip install --upgrade pip
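# Installs JAX with the CUDA 12 and cuDNN libraries pulled in as pip wheels.
# Note: wheels only available on Linux.
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Note that these CUDA 12 wheels need an NVIDIA driver that supports CUDA 12 (525.60.13 or newer on Linux), whereas the 470.82.01 driver from step 2 only supports CUDA up to 11.4. Either upgrade the driver, or pick the JAX wheel that matches your CUDA and cuDNN versions from JAX's installation guide.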
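Before running the pip command, it may be worth checking that the installed driver is new enough for CUDA 12. The Python sketch below queries the driver version with nvidia-smi --query-gpu=driver_version --format=csv,noheader and compares it with 525.60.13, the minimum Linux driver NVIDIA lists for CUDA 12.0. The helper names are hypothetical, and only the first GPU's line is read.

```python
import shutil
import subprocess
from typing import Optional, Tuple

# Minimum Linux driver for CUDA 12.0, per NVIDIA's CUDA toolkit/driver compatibility table.
CUDA12_MIN_DRIVER = (525, 60, 13)


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn "470.82.01" into (470, 82, 1) so versions compare numerically."""
    return tuple(int(part) for part in text.strip().split("."))


def installed_driver_version() -> Optional[str]:
    """Return the driver version reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=False,
    )
    lines = result.stdout.strip().splitlines()  # one line per GPU
    return lines[0].strip() if lines else None


def supports_cuda12(driver_version: str) -> bool:
    """True if the driver meets the CUDA 12.0 minimum."""
    return parse_version(driver_version) >= CUDA12_MIN_DRIVER


if __name__ == "__main__":
    version = installed_driver_version()
    if version is None:
        print("Could not query the NVIDIA driver version.")
    else:
        print(f"Driver {version}: CUDA 12 wheels supported = {supports_cuda12(version)}")
```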

Check that the GPU is available to JAX

We can then start IPython or Python and run the following commands to check:

In [1]: import jax
In [2]: jax.default_backend()
Out[2]: 'gpu'
In [3]: jax.devices()
Out[3]: [GpuDevice(id=0, process_index=0)]
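A slightly stronger check is to actually run a computation on the GPU, which also exercises the CUDA libraries. The sketch below assumes JAX is installed as above; jax_gpu_smoke_test is a hypothetical name, and it returns False instead of raising when JAX is missing or has fallen back to another backend.

```python
def jax_gpu_smoke_test(size: int = 1000) -> bool:
    """Return True if JAX's default backend is a GPU and a matmul there gives the expected result."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return False  # JAX is not installed
    if jax.default_backend() != "gpu":
        return False  # JAX fell back to CPU (or another backend)
    x = jnp.ones((size, size))
    # Each entry of x @ x is a sum of `size` ones; block_until_ready() waits for the device.
    y = (x @ x).block_until_ready()
    return bool(y[0, 0] == size)


if __name__ == "__main__":
    print("GPU smoke test passed:", jax_gpu_smoke_test())
```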

That's it: JAX is now installed with GPU support. With the numpyro package also installed (pip install numpyro), we can run JAX-based sampling in PyMC on the GPU via pm.sample(nuts_sampler="numpyro", ...).
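As an end-to-end check, here is a minimal sketch that fits a toy normal model with the numpyro sampler. It assumes PyMC 5 and numpyro are installed; the model, data and helper names are hypothetical. Passing chain_method="vectorized" through nuts_sampler_kwargs (forwarded to PyMC's numpyro backend) is an assumption worth trying on a single GPU, because numpyro's default "parallel" method maps chains onto separate devices and falls back to sequential sampling when there is only one.

```python
import importlib.util
import random
from typing import List


def simulate_data(n: int = 1000, seed: int = 42) -> List[float]:
    """Draw toy observations from Normal(1, 2) with the standard library RNG."""
    rng = random.Random(seed)
    return [rng.gauss(1.0, 2.0) for _ in range(n)]


def fit_with_numpyro(data: List[float]):
    """Fit a normal model with PyMC's numpyro-backed NUTS sampler (JAX, on GPU if available)."""
    import pymc as pm  # imported lazily so this module loads without PyMC

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
        return pm.sample(
            draws=1000,
            tune=1000,
            chains=4,
            nuts_sampler="numpyro",
            # Assumption: vectorize the chains on one device instead of mapping them over devices.
            nuts_sampler_kwargs={"chain_method": "vectorized"},
        )


if __name__ == "__main__":
    if all(importlib.util.find_spec(name) for name in ("pymc", "numpyro")):
        idata = fit_with_numpyro(simulate_data())
        print("posterior mean of mu:", float(idata.posterior["mu"].mean()))
    else:
        print("Install pymc and numpyro to run this example.")
```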