Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pin jaxlib to use cuda120 build with cuda-nvcc and update README note #549

Merged
merged 8 commits into from
Jul 1, 2024

Conversation

weiji14
Copy link
Member

@weiji14 weiji14 commented May 21, 2024

Adding a explicit pin on jaxlib=*=cuda120* to pick up the cuda120 build version, as the conda-lock solver was picking up the cpu build of jaxlib instead of the cuda build in #514. See:

- name: jaxlib
version: 0.4.23
manager: conda
platform: linux-64
dependencies:
libabseil: '>=20230802.1,<20230803.0a0'
libgcc-ng: '>=12'
libgrpc: '>=1.59.3,<1.60.0a0'
libstdcxx-ng: '>=12'
libzlib: '>=1.2.13,<1.3.0a0'
ml_dtypes: '>=0.2.0'
numpy: '>=1.23.5,<2.0a0'
openssl: '>=3.2.1,<4.0a0'
python: '>=3.11,<3.12.0a0'
python_abi: 3.11.*
scipy: '>=1.9'
url: https://conda.anaconda.org/conda-forge/linux-64/jaxlib-0.4.23-cpu_py311hc0fb0b9_0.conda

Edit: the cuda120 build is actually picked up automatically now as of the 2024.06.02 tag, but setting the cuda120 pin still to be sure that this doesn't break in the future.

Note that jaxlib-0.4.23-cuda120py* has an explicit runtime dependency on cuda-nvcc since conda-forge/jaxlib-feedstock#241, so this should mean users won't have to install cuda-nvcc explicitly anymore.

The cuda-nvcc workaround in the main README.md file has also been removed, in place of a message recommending users to use ml-notebook>=2024.06.02.

Fixes #438

The conda-lock solver was picking up the cpu build of jaxlib instead of the cuda build. Adding a explicit pin here to pick up the cuda120 build version.
@weiji14 weiji14 self-assigned this May 21, 2024
@pangeo-bot
Copy link
Collaborator

/condalock
Automatically locking new conda environment, building, and testing images...

Copy link
Contributor

Binder 👈 Try on Mybinder.org!

@weiji14
Copy link
Member Author

weiji14 commented May 21, 2024

Pulling in jaxlib with the cuda120 build doesn't work yet, see https://github.com/pangeo-data/pangeo-docker-images/actions/runs/9181646795/job/25248830043#step:4:56:

INFO:conda_lock.conda_lock:Using virtual packages from virtual-packages.yml
Locking dependencies for ['linux-64']...
INFO:conda_lock.conda_solver:linux-64 using specs ['cuda-version >=12.0', 'flax >=0.8.0', 'jax', 'jaxlib >=0.4.23 cuda120*', 'jupyterlab-nvdashboard', 'keras-cv', 'tensorflow >=2.15.0 cuda120*', 'adlfs', 'argopy', 'awscli', 'black', 'boto3', 'bottleneck', 'cartopy', 'cdsapi', 'cfgrib', 'cf_xarray', 'ciso', 'cmocean', 'dask-ml', 'datashader', 'descartes', 'earthaccess', 'eofs', 'erddapy', 'esmpy', 'fastjmd95', 'flox', 'fsspec', 'gcm_filters', 'gcsfs', 'gh', 'gh-scoped-creds', 'geocube', 'geopandas', 'geopy', 'geoviews-core', 'git-lfs', 'gsw', 'h5netcdf', 'h5py', 'holoviews', 'hvplot', 'intake', 'intake-esm', 'intake-geopandas', 'intake-stac', 'intake-xarray', 'ipdb', 'ipykernel', 'ipyleaflet', 'ipytree', 'ipywidgets', 'jupyterlab_code_formatter', 'jupyterlab-git', 'jupyterlab-lsp', 'jupyterlab-myst', 'jupyter-panel-proxy', 'jupyter-resource-usage', 'kerchunk', 'line_profiler', 'lxml', 'lz4', 'matplotlib-base', 'memory_profiler', 'metpy', 'nb_conda_kernels', 'nbstripout', 'nc-time-axis', 'netcdf4', 'numbagg', 'numcodecs', 'numpy', 'numpy_groupies', 'odc-stac', 'pandas', 'panel', 'parcels', 'param', 'pop-tools', 'pyarrow', 'pycamhd', 'pydap', 'pystac', 'pystac-client', 'python-blosc', 'python-gist', 'python-graphviz', 'python-lsp-ruff', 'python-xxhash', 'rasterio', 'rechunker', 'rio-cogeo', 'rioxarray', 'ruff', 's3fs', 'satpy', 'scikit-image', 'scikit-learn', 'scipy', 'seaborn', 'sparse', 'snakeviz', 'stackstac', 'tiledb-py', 'timezonefinder', 'watermark', 'xarray', 'xarrayutils', 'xarray-datatree', 'xarray_leaflet', 'xarray-spatial', 'xbatcher', 'xcape', 'xclim', 'xesmf', 'xgboost', 'xgcm', 'xhistogram', 'xmip', 'xmitgcm', 'xpublish', 'xrft', 'xskillscore', 'xxhash', 'zarr', 'python 3.11.*', 'pangeo-notebook 2024.05.20.*', 'pip']
Failed to parse json, Expecting value: line 1 column 1 (char 0)
Could not lock the environment for platform linux-64
Could not solve for environment specs
The following packages are incompatible
├─ jaxlib >=0.4.23 cuda120* is installable with the potential options
│  ├─ jaxlib 0.4.23 would require
│  │  └─ libabseil >=20240116.1,<20240117.0a0 , which can be installed;
│  ├─ jaxlib 0.4.23 would require
│  │  └─ libabseil >=20240116.2,<20240117.0a0 , which can be installed;
│  └─ jaxlib 0.4.23 would require
│     └─ python_abi 3.12.* *_cp312, which requires
│        └─ python 3.12.* *_cpython, which can be installed;
├─ python 3.11**  is not installable because it conflicts with any installable versions previously reported;
└─ tensorflow >=2.15.0 cuda120* is not installable because it requires
   └─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
      └─ libabseil >=20230802.1,<20230803.0a0 , which conflicts with any installable versions previously reported.
{
    "success": false
}

Need to wait for newer version of tensorflow on conda-forge to use libabseil>=20240116, wait for conda-forge/tensorflow-feedstock#372 or conda-forge/tensorflow-feedstock#385

@weiji14 weiji14 mentioned this pull request May 21, 2024
README.md Outdated
@@ -54,7 +54,7 @@ The primary use of these Docker images is running on Pangeo Cloud deployments wi
* Since 2020.10.16, [mamba](https://github.com/mamba-org/mamba) is installed into the base-image and conda-lock environment and is used by default to solve for a compatible environment (see #146)
* For a simple list of packages for a given image, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/blob/2020.10.08/pangeo-notebook/packages.txt
* To compare changes between two images, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/compare/2020.10.03..2020.10.08
* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs, you should install `conda install -c nvidia cuda-nvcc==11.6.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs or newer, you should install `conda install -c nvidia cuda-nvcc==12.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be ok to remove this note, once we have cuda-nvcc pulled in as a dependency of jaxlib. Ideally, we'll add a unit test to https://github.com/pangeo-data/pangeo-docker-images/blob/master/tests/test_ml-notebook.py (maybe using the snippet from #387 (comment)) to ensure that jax works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the uninitiated, I'm not even sure what XLA is :) In light of that, could add a sentence here to give context for why this matters (or just link https://openxla.org/xla).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I didn't know what the XLA acronym stands for either 😆 I've added a link to that XLA page in commit a5fb3e9, though users shouldn't need to dig into this too much since we're using XLA-enabled builds of Tensorflow by default now.

Regression test to ensure that JaX and cuda-nvcc are installed and compatible on GPU devices.
Since we're using a version of GitHub Actions without a GPU, the JaX random number generator test on GPU cannot be properly tested, so wrapping the check in a try-except block. Also tidied up some import statements and docstrings in the test file.
Since `cuda-nvcc` is installed with jaxlib's cuda120 builds, the workaround to conda install cuda-nvcc should not be needed anymore.
@weiji14 weiji14 changed the title Pin jaxlib to use cuda120 build with cuda-nvcc dependency Pin jaxlib to use cuda120 build with cuda-nvcc and update README note Jun 7, 2024
Comment on lines +54 to +62
# Test running on GPU (need to run locally)
try:
gpu_device = jax.devices("gpu")[0]
with jax.default_device(gpu_device):
key = random.key(seed=24)
x = random.normal(key=key)
np.testing.assert_allclose(x, -1.168644)
except RuntimeError: # Unknown backend: 'gpu' requested
logging.log(level=logging.INFO, msg="JAX was not tested on a GPU device")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work on GitHub Actions without a GPU, but I ran this locally using:

mamba create -n ml-notebook --file https://raw.githubusercontent.com/pangeo-data/pangeo-docker-images/2024.06.02/ml-notebook/conda-linux-64.lock
mamba activate ml-notebook
pytest --verbose tests/test_ml-notebook.py

The tests passed on my computer with an NVIDIA GPU, so should be ok I think.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@weiji14 weiji14 marked this pull request as ready for review June 7, 2024 03:52
Copy link
Member

@scottyhq scottyhq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the notes @weiji14, sorry I overlooked approving this until now. merge when you're happy with it!

@@ -8,6 +8,7 @@ dependencies:
- cuda-version>=12.0
- flax>=0.8.0
- jax
- jaxlib>=0.4.23=cuda120*
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pin on using a cuda120* build of jaxlib might actually need to go away if we're planning on adding multi-arch builds of ml-notebook (see #399 (review) for context), but will deal with this later since the same pin is present on Tensorflow.

@weiji14 weiji14 merged commit 09649eb into master Jul 1, 2024
5 checks passed
@weiji14 weiji14 deleted the jaxlib-cuda-build branch July 1, 2024 01:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

cuda-nvcc missing again
3 participants