Skip to content

Commit

Permalink
WIP attempt to update packages but GPU stopped working
Browse files Browse the repository at this point in the history
  • Loading branch information
tartavull committed Aug 20, 2023
1 parent d0516ba commit 851b743
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 145 deletions.
78 changes: 48 additions & 30 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions nix/chex.nix
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
, pytestCheckHook
, toolz
, cloudpickle
, typing-extensions
}:

buildPythonPackage rec {
Expand All @@ -30,6 +31,7 @@ buildPythonPackage rec {
jax
numpy
toolz
typing-extensions
];

postPatch = ''
Expand Down
9 changes: 5 additions & 4 deletions nix/flax.nix
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

buildPythonPackage rec {
name = "flax";
format = "pyproject";
src = fetchFromGitHub {
owner = "google";
repo = "flax";
rev = "v0.6.5";
hash = "sha256-Vv68BK83gTIKj0r9x+twdhqmRYziD0vxQCdHkYSeTak=";
rev = "v0.7.2";
hash = "sha256-Zj2xwtUBYrr0lwSjKn8bLHiBtKB0ZUFif7byHoGSZvg=";
};
propagatedBuildInputs = [
jax
Expand All @@ -25,8 +26,8 @@ buildPythonPackage rec {
matplotlib
];
postPatch = ''
sed -i '/tensorstore/d' setup.py
sed -i '/tensorstore/d' pyproject.toml
sed -i '/orbax/d' pyproject.toml
'';
doCheck = false;
pythonRemoveDeps = [ "orbax" ];
}
108 changes: 0 additions & 108 deletions nix/jax.nix
Original file line number Diff line number Diff line change
@@ -1,108 +0,0 @@
{ lib
, absl-py
, blas
, buildPythonPackage
, etils
, fetchFromGitHub
, lapack
, matplotlib
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, typing-extensions
, jaxlib
}:

let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.10";
format = "setuptools";

disabled = pythonOlder "3.7";

src = fetchFromGitHub {
owner = "google";
repo = pname;
rev = "jax-v${version}";
hash = "sha256-USdEVEcZ28YHDJQDzWzpWdBQQimi27xe5Quc9dESoXw=";
};

# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
# CPU wheel is packaged.
propagatedBuildInputs = [
absl-py
etils
numpy
opt-einsum
scipy
typing-extensions
jaxlib
] ++ etils.optional-dependencies.epath;

checkInputs = [
jaxlib
matplotlib
pytestCheckHook
pytest-xdist
];

# high parallelism will result in the tests getting stuck
dontUsePytestXdist = true;

# NOTE: Don't run the tests in the expiremental directory as they require flax
# which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
# Not a big deal, this is how the JAX docs suggest running the test suite
# anyhow.
pytestFlagsArray = [
"--numprocesses=4"
"-W ignore::DeprecationWarning"
"tests/"
];

disabledTests = [
# Exceeds tolerance when the machine is busy
"test_custom_linear_solve_aux"
] ++ lib.optionals usingMKL [
# See
# * https://github.com/google/jax/issues/9705
# * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
# * https://github.com/NixOS/nixpkgs/issues/161960
"test_custom_linear_solve_cholesky"
"test_custom_root_with_aux"
"testEigvalsGrad_shape"
];

doCheck = false; # Disable running checks during the build process


# See https://github.com/google/jax/issues/11722. This is a temporary fix in
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
disabledTestPaths = [
"tests/api_test.py"
"tests/core_test.py"
"tests/lax_numpy_indexing_test.py"
"tests/lax_numpy_test.py"
"tests/nn_test.py"
"tests/random_test.py"
"tests/sparse_test.py"
];

# As of 0.3.22, `import jax` does not work without jaxlib being installed.
pythonImportsCheck = [ ];

meta = with lib; {
description = "Differentiate, compile, and transform Numpy code";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
};
}
2 changes: 2 additions & 0 deletions nix/jaxopt.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
, fetchFromGitHub
, matplotlib
, scikit-learn
, typing-extensions
}:

buildPythonPackage rec {
Expand All @@ -27,6 +28,7 @@ buildPythonPackage rec {
jax
matplotlib
scikit-learn
typing-extensions
];

doCheck = true;
Expand Down
8 changes: 5 additions & 3 deletions nix/orbax.nix
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
, numpy
, pyyaml
, tensorflow
, importlib-resources
#, tensorstore
}:

buildPythonPackage rec {
name = "orbax";
name = "orbax-checkpoint";
src = fetchFromGitHub {
owner = "google";
repo = "orbax";
rev = "v0.1.6";
hash = "sha256-Vkqt2ovTan6bQJI4Il06hG0NlYmt60to4ue4U9qG9HY=";
rev = "v0.1.7";
hash = "sha256-Zk9hbvSA82jt0wLR7AZWEmHDA4A1+9t0ezf74FYkqe0=";
};
format = "pyproject";

Expand All @@ -36,6 +37,7 @@ buildPythonPackage rec {
numpy
pyyaml
tensorflow
importlib-resources
# tensorstore
];

Expand Down

0 comments on commit 851b743

Please sign in to comment.