Skip to content

Commit

Permalink
MAINT Replace setup_module by pytest fixtures (#28475)
Browse files Browse the repository at this point in the history
  • Loading branch information
lesteve committed Feb 22, 2024
1 parent 458d7a7 commit 93f6ce9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 72 deletions.
88 changes: 38 additions & 50 deletions sklearn/datasets/tests/test_lfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
joblib, successive runs will be fast (less than 200ms).
"""

import os
import random
import shutil
import tempfile
from functools import partial

import numpy as np
Expand All @@ -21,10 +18,6 @@
from sklearn.datasets.tests.test_common import check_return_X_y
from sklearn.utils._testing import assert_array_equal

SCIKIT_LEARN_DATA = None
SCIKIT_LEARN_EMPTY_DATA = None
LFW_HOME = None

FAKE_NAMES = [
"Abdelatif_Smith",
"Abhati_Kepler",
Expand All @@ -36,44 +29,46 @@
]


def setup_module():
"""Test fixture run once and common to all tests of this module"""
Image = pytest.importorskip("PIL.Image")
@pytest.fixture(scope="module")
def mock_empty_data_home(tmp_path_factory):
data_dir = tmp_path_factory.mktemp("scikit_learn_empty_test")

global SCIKIT_LEARN_DATA, SCIKIT_LEARN_EMPTY_DATA, LFW_HOME
yield data_dir

SCIKIT_LEARN_DATA = tempfile.mkdtemp(prefix="scikit_learn_lfw_test_")
LFW_HOME = os.path.join(SCIKIT_LEARN_DATA, "lfw_home")

SCIKIT_LEARN_EMPTY_DATA = tempfile.mkdtemp(prefix="scikit_learn_empty_test_")
@pytest.fixture(scope="module")
def mock_data_home(tmp_path_factory):
"""Test fixture run once and common to all tests of this module"""
Image = pytest.importorskip("PIL.Image")

if not os.path.exists(LFW_HOME):
os.makedirs(LFW_HOME)
data_dir = tmp_path_factory.mktemp("scikit_learn_lfw_test")
lfw_home = data_dir / "lfw_home"
lfw_home.mkdir(parents=True, exist_ok=True)

random_state = random.Random(42)
np_rng = np.random.RandomState(42)

# generate some random jpeg files for each person
counts = {}
for name in FAKE_NAMES:
folder_name = os.path.join(LFW_HOME, "lfw_funneled", name)
if not os.path.exists(folder_name):
os.makedirs(folder_name)
folder_name = lfw_home / "lfw_funneled" / name
folder_name.mkdir(parents=True, exist_ok=True)

n_faces = np_rng.randint(1, 5)
counts[name] = n_faces
for i in range(n_faces):
file_path = os.path.join(folder_name, name + "_%04d.jpg" % i)
file_path = folder_name / (name + "_%04d.jpg" % i)
uniface = np_rng.randint(0, 255, size=(250, 250, 3))
img = Image.fromarray(uniface.astype(np.uint8))
img.save(file_path)

# add some random file pollution to test robustness
with open(os.path.join(LFW_HOME, "lfw_funneled", ".test.swp"), "wb") as f:
f.write(b"Text file to be ignored by the dataset loader.")
(lfw_home / "lfw_funneled" / ".test.swp").write_bytes(
b"Text file to be ignored by the dataset loader."
)

# generate some pairing metadata files using the same format as LFW
with open(os.path.join(LFW_HOME, "pairsDevTrain.txt"), "wb") as f:
with open(lfw_home / "pairsDevTrain.txt", "wb") as f:
f.write(b"10\n")
more_than_two = [name for name, count in counts.items() if count >= 2]
for i in range(5):
Expand All @@ -92,29 +87,22 @@ def setup_module():
).encode()
)

with open(os.path.join(LFW_HOME, "pairsDevTest.txt"), "wb") as f:
f.write(b"Fake place holder that won't be tested")

with open(os.path.join(LFW_HOME, "pairs.txt"), "wb") as f:
f.write(b"Fake place holder that won't be tested")

(lfw_home / "pairsDevTest.txt").write_bytes(
b"Fake place holder that won't be tested"
)
(lfw_home / "pairs.txt").write_bytes(b"Fake place holder that won't be tested")

def teardown_module():
"""Test fixture (clean up) run once after all tests of this module"""
if os.path.isdir(SCIKIT_LEARN_DATA):
shutil.rmtree(SCIKIT_LEARN_DATA)
if os.path.isdir(SCIKIT_LEARN_EMPTY_DATA):
shutil.rmtree(SCIKIT_LEARN_EMPTY_DATA)
yield data_dir


def test_load_empty_lfw_people():
def test_load_empty_lfw_people(mock_empty_data_home):
with pytest.raises(OSError):
fetch_lfw_people(data_home=SCIKIT_LEARN_EMPTY_DATA, download_if_missing=False)
fetch_lfw_people(data_home=mock_empty_data_home, download_if_missing=False)


def test_load_fake_lfw_people():
def test_load_fake_lfw_people(mock_data_home):
lfw_people = fetch_lfw_people(
data_home=SCIKIT_LEARN_DATA, min_faces_per_person=3, download_if_missing=False
data_home=mock_data_home, min_faces_per_person=3, download_if_missing=False
)

# The data is cropped around the center as a rectangular bounding box
Expand All @@ -132,7 +120,7 @@ def test_load_fake_lfw_people():
# It is possible to ask for the original data without any cropping or color
# conversion and with no limit on the number of pictures per person
lfw_people = fetch_lfw_people(
data_home=SCIKIT_LEARN_DATA,
data_home=mock_data_home,
resize=None,
slice_=None,
color=True,
Expand Down Expand Up @@ -161,7 +149,7 @@ def test_load_fake_lfw_people():
# test return_X_y option
fetch_func = partial(
fetch_lfw_people,
data_home=SCIKIT_LEARN_DATA,
data_home=mock_data_home,
resize=None,
slice_=None,
color=True,
Expand All @@ -170,23 +158,23 @@ def test_load_fake_lfw_people():
check_return_X_y(lfw_people, fetch_func)


def test_load_fake_lfw_people_too_restrictive():
def test_load_fake_lfw_people_too_restrictive(mock_data_home):
with pytest.raises(ValueError):
fetch_lfw_people(
data_home=SCIKIT_LEARN_DATA,
data_home=mock_data_home,
min_faces_per_person=100,
download_if_missing=False,
)


def test_load_empty_lfw_pairs():
def test_load_empty_lfw_pairs(mock_empty_data_home):
with pytest.raises(OSError):
fetch_lfw_pairs(data_home=SCIKIT_LEARN_EMPTY_DATA, download_if_missing=False)
fetch_lfw_pairs(data_home=mock_empty_data_home, download_if_missing=False)


def test_load_fake_lfw_pairs():
def test_load_fake_lfw_pairs(mock_data_home):
lfw_pairs_train = fetch_lfw_pairs(
data_home=SCIKIT_LEARN_DATA, download_if_missing=False
data_home=mock_data_home, download_if_missing=False
)

# The data is cropped around the center as a rectangular bounding box
Expand All @@ -203,7 +191,7 @@ def test_load_fake_lfw_pairs():
# It is possible to ask for the original data without any cropping or color
# conversion
lfw_pairs_train = fetch_lfw_pairs(
data_home=SCIKIT_LEARN_DATA,
data_home=mock_data_home,
resize=None,
slice_=None,
color=True,
Expand All @@ -218,7 +206,7 @@ def test_load_fake_lfw_pairs():
assert lfw_pairs_train.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:")


def test_fetch_lfw_people_internal_cropping():
def test_fetch_lfw_people_internal_cropping(mock_data_home):
"""Check that we properly crop the images.
Non-regression test for:
Expand All @@ -229,7 +217,7 @@ def test_fetch_lfw_people_internal_cropping():
# pre-allocated based on `slice_` parameter.
slice_ = (slice(70, 195), slice(78, 172))
lfw = fetch_lfw_people(
data_home=SCIKIT_LEARN_DATA,
data_home=mock_data_home,
min_faces_per_person=3,
download_if_missing=False,
resize=None,
Expand Down
31 changes: 9 additions & 22 deletions sklearn/metrics/tests/test_score_objects.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import numbers
import os
import pickle
import shutil
import tempfile
from copy import deepcopy
from functools import partial
from unittest.mock import Mock
Expand Down Expand Up @@ -166,28 +163,17 @@ def _make_estimators(X_train, y_train, y_ml_train):
)


X_mm, y_mm, y_ml_mm = None, None, None
ESTIMATORS = None
TEMP_FOLDER = None


def setup_module():
# Create some memory mapped data
global X_mm, y_mm, y_ml_mm, TEMP_FOLDER, ESTIMATORS
TEMP_FOLDER = tempfile.mkdtemp(prefix="sklearn_test_score_objects_")
@pytest.fixture(scope="module")
def memmap_data_and_estimators(tmp_path_factory):
temp_folder = tmp_path_factory.mktemp("sklearn_test_score_objects")
X, y = make_classification(n_samples=30, n_features=5, random_state=0)
_, y_ml = make_multilabel_classification(n_samples=X.shape[0], random_state=0)
filename = os.path.join(TEMP_FOLDER, "test_data.pkl")
filename = temp_folder / "test_data.pkl"
joblib.dump((X, y, y_ml), filename)
X_mm, y_mm, y_ml_mm = joblib.load(filename, mmap_mode="r")
ESTIMATORS = _make_estimators(X_mm, y_mm, y_ml_mm)

estimators = _make_estimators(X_mm, y_mm, y_ml_mm)

def teardown_module():
global X_mm, y_mm, y_ml_mm, TEMP_FOLDER, ESTIMATORS
# GC closes the mmap file descriptors
X_mm, y_mm, y_ml_mm, ESTIMATORS = None, None, None, None
shutil.rmtree(TEMP_FOLDER)
yield X_mm, y_mm, y_ml_mm, estimators


class EstimatorWithFit(BaseEstimator):
Expand Down Expand Up @@ -688,10 +674,11 @@ def test_regression_scorer_sample_weight():


@pytest.mark.parametrize("name", get_scorer_names())
def test_scorer_memmap_input(name):
def test_scorer_memmap_input(name, memmap_data_and_estimators):
# Non-regression test for #6147: some score functions would
# return singleton memmap when computed on memmap data instead of scalar
# float values.
X_mm, y_mm, y_ml_mm, estimators = memmap_data_and_estimators

if name in REQUIRE_POSITIVE_Y_SCORERS:
y_mm_1 = _require_positive_y(y_mm)
Expand All @@ -701,7 +688,7 @@ def test_scorer_memmap_input(name):

# UndefinedMetricWarning for P / R scores
with ignore_warnings():
scorer, estimator = get_scorer(name), ESTIMATORS[name]
scorer, estimator = get_scorer(name), estimators[name]
if name in MULTILABEL_ONLY_SCORERS:
score = scorer(estimator, X_mm, y_ml_mm_1)
else:
Expand Down

0 comments on commit 93f6ce9

Please sign in to comment.