Skip to content

Commit

Permalink
Remove outdated version checks (#888)
Browse files Browse the repository at this point in the history
Remove version checks that are no longer required because we don't
support the min versions anymore.

None of the version checks we do is actually required anymore. After
removing them, we thus no longer depend on distutils's LooseVersion.

Minimum scikit-learn version has been bumped to 0.22 to make this
change possible.
  • Loading branch information
BenjaminBossan committed Aug 26, 2022
1 parent 787b859 commit 1491d5a
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch

### Changed
- The minimum required scikit-learn version has been bumped to 0.22.0
- Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12))
- It is now possible to call `np.asarray` with `SliceDataset`s (#858)
- Add integration for Huggingface tokenizers; use `skorch.hf.HuggingfaceTokenizer` to train a Huggingface tokenizer on your custom data; use `skorch.hf.HuggingfacePretrainedTokenizer` to load a pre-trained Huggingface tokenizer
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.13.3
scikit-learn>=0.19.1
scikit-learn>=0.22.0
scipy>=1.1.0
tabulate>=0.7.7
tqdm>=4.14.0
7 changes: 1 addition & 6 deletions skorch/callbacks/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,16 @@

from contextlib import contextmanager
from contextlib import suppress
from distutils.version import LooseVersion
from functools import partial
import warnings

import numpy as np
import sklearn
from sklearn.metrics import make_scorer, check_scoring

if LooseVersion(sklearn.__version__) >= '0.22':
from sklearn.metrics._scorer import _BaseScorer
else:
from sklearn.metrics.scorer import _BaseScorer

from skorch.callbacks import Callback
from skorch.dataset import unpack_data
from sklearn.metrics._scorer import _BaseScorer
from skorch.utils import data_from_dataset
from skorch.utils import is_skorch_dataset
from skorch.utils import to_numpy
Expand Down
9 changes: 0 additions & 9 deletions skorch/tests/callbacks/test_lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Tests for lr_scheduler.py"""
from distutils.version import LooseVersion
from unittest.mock import Mock

import numpy as np
Expand Down Expand Up @@ -191,10 +190,6 @@ def test_lr_scheduler_set_params(self, classifier_module, classifier_data):
@pytest.mark.parametrize('policy,kwargs', [
(StepLR, {'gamma': 0.9, 'step_size': 1})
])
@pytest.mark.skipif(
LooseVersion(torch.__version__) < '1.4',
reason="Feature isn't supported with this torch version."
)
def test_lr_scheduler_record_epoch_step(self,
classifier_module,
classifier_data,
Expand All @@ -212,10 +207,6 @@ def test_lr_scheduler_record_epoch_step(self,
net.fit(*classifier_data)
assert np.all(net.history[:, 'event_lr'] == lrs)

@pytest.mark.skipif(
LooseVersion(torch.__version__) < '1.4',
reason="Feature isn't supported with this torch version."
)
def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data):
X, y = classifier_data
batch_size = 128
Expand Down
13 changes: 4 additions & 9 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

import copy
from distutils.version import LooseVersion
from functools import partial
import os
from pathlib import Path
Expand Down Expand Up @@ -1784,23 +1783,19 @@ def test_repr_fitted_works(self, net_cls, module_cls, data):
expected = """<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
module_=MLPModule(
(nonlin): PReLU(num_parameters=1)
(output_nonlin): Softmax()
(output_nonlin): Softmax(dim=-1)
(sequential): Sequential(
(0): Linear(in_features=20, out_features=11, bias=True)
(1): PReLU(num_parameters=1)
(2): Dropout(p=0.5)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=11, out_features=11, bias=True)
(4): PReLU(num_parameters=1)
(5): Dropout(p=0.5)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=11, out_features=2, bias=True)
(7): Softmax()
(7): Softmax(dim=-1)
)
),
)"""
if LooseVersion(torch.__version__) >= '1.2':
expected = expected.replace("Softmax()", "Softmax(dim=-1)")
expected = expected.replace("Dropout(p=0.5)",
"Dropout(p=0.5, inplace=False)")
assert result == expected

def test_fit_params_passed_to_module(self, net_cls, data):
Expand Down
8 changes: 0 additions & 8 deletions skorch/tests/test_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import copy
import pickle
import re
from distutils.version import LooseVersion

import numpy as np
import pytest
Expand All @@ -19,13 +18,6 @@

gpytorch = pytest.importorskip('gpytorch')

# extract pytorch version without possible '+something' suffix
pytorch_version, _, _ = torch.__version__.partition('+')
# Keep up to date with the gpytorch's supported versions:
# https://github.com/cornellius-gp/gpytorch#installation
if LooseVersion(pytorch_version) < '1.9':
pytest.skip("gpytorch does not support PyTorch versions < 1.9", allow_module_level=True)


def get_batch_size(dist):
"""Return the shape of the distribution
Expand Down
6 changes: 1 addition & 5 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from collections.abc import Mapping, Sequence
from contextlib import contextmanager
from distutils.version import LooseVersion
from enum import Enum
from functools import partial
from itertools import tee
Expand All @@ -17,6 +16,7 @@
from scipy import sparse
import sklearn
from sklearn.exceptions import NotFittedError
from sklearn.utils import _safe_indexing as safe_indexing
from sklearn.utils.validation import check_is_fitted as sk_check_is_fitted
import torch
from torch.nn import BCELoss
Expand All @@ -28,10 +28,6 @@
from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError

if LooseVersion(sklearn.__version__) >= '0.22.0':
from sklearn.utils import _safe_indexing as safe_indexing
else:
from sklearn.utils import safe_indexing

GPYTORCH_INSTALLED = False
try:
Expand Down

0 comments on commit 1491d5a

Please sign in to comment.