Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP attempt to update packages but GPU stopped working #29

Merged
merged 1 commit into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions nix/chex.nix
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
, pytestCheckHook
, toolz
, cloudpickle
, typing-extensions
}:

buildPythonPackage rec {
Expand All @@ -30,6 +31,7 @@ buildPythonPackage rec {
jax
numpy
toolz
typing-extensions
];

postPatch = ''
Expand Down
9 changes: 5 additions & 4 deletions nix/flax.nix
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

buildPythonPackage rec {
name = "flax";
format = "pyproject";
src = fetchFromGitHub {
owner = "google";
repo = "flax";
rev = "v0.6.5";
hash = "sha256-Vv68BK83gTIKj0r9x+twdhqmRYziD0vxQCdHkYSeTak=";
rev = "v0.7.2";
hash = "sha256-Zj2xwtUBYrr0lwSjKn8bLHiBtKB0ZUFif7byHoGSZvg=";
};
propagatedBuildInputs = [
jax
Expand All @@ -25,8 +26,8 @@ buildPythonPackage rec {
matplotlib
];
postPatch = ''
sed -i '/tensorstore/d' setup.py
sed -i '/tensorstore/d' pyproject.toml
sed -i '/orbax/d' pyproject.toml
'';
doCheck = false;
pythonRemoveDeps = [ "orbax" ];
}
108 changes: 0 additions & 108 deletions nix/jax.nix
Original file line number Diff line number Diff line change
@@ -1,108 +0,0 @@
{ lib
, absl-py
, blas
, buildPythonPackage
, etils
, fetchFromGitHub
, lapack
, matplotlib
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, typing-extensions
, jaxlib
}:

let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.10";
format = "setuptools";

disabled = pythonOlder "3.7";

src = fetchFromGitHub {
owner = "google";
repo = pname;
rev = "jax-v${version}";
hash = "sha256-USdEVEcZ28YHDJQDzWzpWdBQQimi27xe5Quc9dESoXw=";
};

# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
# CPU wheel is packaged.
propagatedBuildInputs = [
absl-py
etils
numpy
opt-einsum
scipy
typing-extensions
jaxlib
] ++ etils.optional-dependencies.epath;

checkInputs = [
jaxlib
matplotlib
pytestCheckHook
pytest-xdist
];

# high parallelism will result in the tests getting stuck
dontUsePytestXdist = true;

# NOTE: Don't run the tests in the expiremental directory as they require flax
# which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
# Not a big deal, this is how the JAX docs suggest running the test suite
# anyhow.
pytestFlagsArray = [
"--numprocesses=4"
"-W ignore::DeprecationWarning"
"tests/"
];

disabledTests = [
# Exceeds tolerance when the machine is busy
"test_custom_linear_solve_aux"
] ++ lib.optionals usingMKL [
# See
# * https://github.com/google/jax/issues/9705
# * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
# * https://github.com/NixOS/nixpkgs/issues/161960
"test_custom_linear_solve_cholesky"
"test_custom_root_with_aux"
"testEigvalsGrad_shape"
];

doCheck = false; # Disable running checks during the build process


# See https://github.com/google/jax/issues/11722. This is a temporary fix in
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
disabledTestPaths = [
"tests/api_test.py"
"tests/core_test.py"
"tests/lax_numpy_indexing_test.py"
"tests/lax_numpy_test.py"
"tests/nn_test.py"
"tests/random_test.py"
"tests/sparse_test.py"
];

# As of 0.3.22, `import jax` does not work without jaxlib being installed.
pythonImportsCheck = [ ];

meta = with lib; {
description = "Differentiate, compile, and transform Numpy code";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
};
}
2 changes: 2 additions & 0 deletions nix/jaxopt.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
, fetchFromGitHub
, matplotlib
, scikit-learn
, typing-extensions
}:

buildPythonPackage rec {
Expand All @@ -27,6 +28,7 @@ buildPythonPackage rec {
jax
matplotlib
scikit-learn
typing-extensions
];

doCheck = true;
Expand Down
8 changes: 5 additions & 3 deletions nix/orbax.nix
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
, numpy
, pyyaml
, tensorflow
, importlib-resources
#, tensorstore
}:

buildPythonPackage rec {
name = "orbax";
name = "orbax-checkpoint";
src = fetchFromGitHub {
owner = "google";
repo = "orbax";
rev = "v0.1.6";
hash = "sha256-Vkqt2ovTan6bQJI4Il06hG0NlYmt60to4ue4U9qG9HY=";
rev = "v0.1.7";
hash = "sha256-Zk9hbvSA82jt0wLR7AZWEmHDA4A1+9t0ezf74FYkqe0=";
};
format = "pyproject";

Expand All @@ -36,6 +37,7 @@ buildPythonPackage rec {
numpy
pyyaml
tensorflow
importlib-resources
# tensorstore
];

Expand Down
Loading