Skip to content

Commit

Permalink
Ensure compatibility with torch>=1.13 torchvision>=0.14 (#3155)
Browse files Browse the repository at this point in the history
* Bump dependency versions to torch>1.13 torchvision>=0.14

* Avoid setuptools>=60

* Replace Tensor.lu() -> torch.linalg.lu_factor()

* Fix test

* Revert "Bump dependency versions to torch>1.13 torchvision>=0.14"

This reverts commit 90ec206.

* Add comment on lap
  • Loading branch information
fritzo committed Nov 22, 2022
1 parent 8b7e564 commit 44267ff
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
pip install flake8 black isort>=5.0 mypy nbstripout nbformat
- name: Lint
run: |
Expand All @@ -51,7 +51,7 @@ jobs:
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build graphviz
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Expand Down Expand Up @@ -79,7 +79,7 @@ jobs:
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Expand Down Expand Up @@ -113,7 +113,7 @@ jobs:
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Expand Down Expand Up @@ -147,7 +147,7 @@ jobs:
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Expand Down Expand Up @@ -179,7 +179,7 @@ jobs:
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Expand Down Expand Up @@ -211,7 +211,7 @@ jobs:
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
sudo apt-get update
sudo apt-get install gcc-8 g++-8 ninja-build
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ opt_einsum>=2.3.2
pyro-api>=0.1.1
tqdm>=4.36
funsor[torch]
setuptools<60
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(self, channels=3, permutation=None):
W, _ = torch.linalg.qr(torch.randn(channels, channels))

# Construct the partially pivoted LU-form and the pivots
LU, pivots = W.lu()
LU, pivots = torch.linalg.lu_factor(W)

# Convert the pivots into the permutation matrix
if permutation is None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"scikit-learn",
"seaborn>=0.11.0",
"wget",
"lap",
"lap", # Requires setuptools<60
# 'biopython>=1.54',
# 'scanpy>=1.4', # Requires HDF5
# 'scvi>=0.6', # Requires loopy and other fragile packages
Expand Down
16 changes: 7 additions & 9 deletions tests/optim/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,25 @@ def optim_params(param_name):
elif param_name == free_param:
return {"lr": 0.01}

def get_steps(adam):
state = adam.get_state()["loc_q"]["state"]
return int(list(state.values())[0]["step"])

adam = optim.Adam(optim_params)
adam2 = optim.Adam(optim_params)
svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO())

svi.step()
adam_initial_step_count = list(adam.get_state()["loc_q"]["state"].items())[0][
1
]["step"]
adam_initial_step_count = get_steps(adam)
with TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "optimizer_state.pt")
adam.save(filename)
svi.step()
adam_final_step_count = list(adam.get_state()["loc_q"]["state"].items())[0][
1
]["step"]
adam_final_step_count = get_steps(adam)
adam2.load(filename)
svi2.step()
adam2_step_count_after_load_and_step = list(
adam2.get_state()["loc_q"]["state"].items()
)[0][1]["step"]
adam2_step_count_after_load_and_step = get_steps(adam2)

assert adam_initial_step_count == 1
assert adam_final_step_count == 2
Expand Down

0 comments on commit 44267ff

Please sign in to comment.