Skip to content

Commit

Permalink
Test gpytorch only with torch version 1.9 and up (#827)
Browse files Browse the repository at this point in the history
Also, remove some obsolete excluded versions from CI.
  • Loading branch information
BenjaminBossan committed Dec 30, 2021
1 parent f6730f4 commit cf6615b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@ jobs:
strategy:
matrix:
python_version: ['3.6', '3.7', '3.8', '3.9']
torch_version: ['1.7.1+cpu', '1.8.1+cpu', '1.9.1+cpu', '1.10.0+cpu']
torch_version: ['1.7.1+cpu', '1.8.1+cpu', '1.9.1+cpu', '1.10.1+cpu']
os: [ubuntu-latest]
exclude:
- python_version: '3.9'
torch_version: '1.5.1+cpu'
- python_version: '3.9'
torch_version: '1.6.0+cpu'
- python_version: '3.9'
torch_version: '1.7.1+cpu'

Expand Down
7 changes: 3 additions & 4 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,10 +1756,9 @@ def test_repr_initialized_works(self, net_cls, module_cls):
)
),
)"""
if LooseVersion(torch.__version__) >= '1.2':
expected = expected.replace("Softmax()", "Softmax(dim=-1)")
expected = expected.replace("Dropout(p=0.5)",
"Dropout(p=0.5, inplace=False)")
expected = expected.replace("Softmax()", "Softmax(dim=-1)")
expected = expected.replace("Dropout(p=0.5)",
"Dropout(p=0.5, inplace=False)")
assert result == expected

def test_repr_fitted_works(self, net_cls, module_cls, data):
Expand Down
6 changes: 4 additions & 2 deletions skorch/tests/test_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

# extract pytorch version without possible '+something' suffix
pytorch_version, _, _ = torch.__version__.partition('+')
if LooseVersion(pytorch_version) == '1.7.1':
pytest.skip("gpytorch does not work with PyTorch 1.7.1", allow_module_level=True)
# Keep up to date with the gpytorch's supported versions:
# https://github.com/cornellius-gp/gpytorch#installation
if LooseVersion(pytorch_version) < '1.9':
pytest.skip("gpytorch does not support PyTorch versions < 1.9", allow_module_level=True)


def get_batch_size(dist):
Expand Down

0 comments on commit cf6615b

Please sign in to comment.