
add flake8-comprehensions #689

Merged
merged 6 commits on May 26, 2022
Changes from all commits
2 changes: 1 addition & 1 deletion .flake8
@@ -1,5 +1,5 @@
[flake8]
-ignore = F, E203, W605, E501, W503, D100, D104
+ignore = F, E203, W605, E501, W503, D100, D104, C408
Collaborator

Why did you decide to ignore C408?

Contributor Author

Because in cases like this:

trainer_kwargs = dict(
    logger=tslogger.pl_loggers,
    max_epochs=self.max_epochs,
    gpus=self.gpus,
    checkpoint_callback=False,
    gradient_clip_val=self.gradient_clip_val,
)

In my opinion, the dict(key=value) style is easier to read than {"key": value} for calls like this.
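
For context: C408 is the flake8-comprehensions check "Unnecessary dict call - rewrite as a literal". A minimal sketch of the two equivalent spellings involved (the variable names are illustrative, not from the codebase):

# Both lines build the same dict; C408 flags the first form.
kwargs_style = dict(max_epochs=10, gpus=0)     # dict() call, flagged by C408
literal_style = {"max_epochs": 10, "gpus": 0}  # the literal C408 suggests instead
assert kwargs_style == literal_style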

Contributor Author

What do you think?

max-line-length = 121
max-complexity = 18
docstring-convention=numpy
4 changes: 2 additions & 2 deletions etna/analysis/plotters.py
@@ -56,12 +56,12 @@ def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optio
intersection_quantiles_set = set.intersection(
*[_get_existing_quantiles(forecast) for forecast in forecast_results.values()]
)
-intersection_quantiles = sorted(list(intersection_quantiles_set))
+intersection_quantiles = sorted(intersection_quantiles_set)

if quantiles is None:
selected_quantiles = intersection_quantiles
else:
-selected_quantiles = sorted(list(set(quantiles) & intersection_quantiles_set))
+selected_quantiles = sorted(set(quantiles) & intersection_quantiles_set)
non_existent = set(quantiles) - intersection_quantiles_set
if non_existent:
warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.")
4 changes: 2 additions & 2 deletions etna/datasets/tsdataset.py
@@ -324,7 +324,7 @@ def _check_known_future(

if isinstance(known_future, str):
if known_future == "all":
-return sorted(list(exog_columns))
+return sorted(exog_columns)
else:
raise ValueError("The only possible literal is 'all'")
else:
@@ -335,7 +335,7 @@ def _check_known_future(
f"{known_future_unique.difference(exog_columns)}"
)
else:
-return sorted(list(known_future_unique))
+return sorted(known_future_unique)

@staticmethod
def _check_regressors(df: pd.DataFrame, df_regressors: pd.DataFrame):
2 changes: 1 addition & 1 deletion etna/ensembles/base.py
@@ -17,7 +17,7 @@ def _validate_pipeline_number(pipelines: List[BasePipeline]):
@staticmethod
def _get_horizon(pipelines: List[BasePipeline]) -> int:
"""Get ensemble's horizon."""
-horizons = set([pipeline.horizon for pipeline in pipelines])
+horizons = {pipeline.horizon for pipeline in pipelines}
if len(horizons) > 1:
raise ValueError("All the pipelines should have the same horizon.")
return horizons.pop()
2 changes: 1 addition & 1 deletion etna/transforms/feature_selection/base.py
@@ -18,7 +18,7 @@ def __init__(self, features_to_use: Union[List[str], Literal["all"]] = "all"):

def _get_features_to_use(self, df: pd.DataFrame) -> List[str]:
"""Get list of features from the dataframe to perform the selection on."""
-features = set(df.columns.get_level_values("feature")) - set(["target"])
+features = set(df.columns.get_level_values("feature")) - {"target"}
if self.features_to_use != "all":
features = features.intersection(self.features_to_use)
if sorted(features) != sorted(self.features_to_use):
2 changes: 1 addition & 1 deletion etna/transforms/missing_values/resample.py
@@ -56,7 +56,7 @@ def _get_folds(self, df: pd.DataFrame) -> List[int]:
in_column_start_index = in_column_index[0]
left_tie_len = len(df[:in_column_start_index]) - 1
right_tie_len = len(df[in_column_start_index:])
-folds_for_left_tie = [fold for fold in range(n_folds_per_gap - left_tie_len, n_folds_per_gap)]
+folds_for_left_tie = list(range(n_folds_per_gap - left_tie_len, n_folds_per_gap))
folds_for_right_tie = [fold for _ in range(n_periods) for fold in range(n_folds_per_gap)][:right_tie_len]
return folds_for_left_tie + folds_for_right_tie

2 changes: 1 addition & 1 deletion etna/transforms/utils.py
@@ -5,4 +5,4 @@
def match_target_quantiles(features: Set[str]) -> Set[str]:
"""Find quantiles in dataframe columns."""
pattern = re.compile("target_\d+\.\d+$")
-return set(i for i in list(features) if pattern.match(i) is not None)
+return {i for i in list(features) if pattern.match(i) is not None}
230 changes: 123 additions & 107 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
@@ -46,6 +46,7 @@ ruptures = "1.1.5"
numba = ">=0.53.1,<0.56.0"
seaborn = "^0.11.1"
statsmodels = ">=0.12,<0.14"
+pmdarima = ">=1.8.0"
dill = "^0.3.4"
toml = "^0.10.2"
loguru = "^0.5.3"
@@ -89,6 +90,7 @@ isort = {version = "^5.8.0", optional = true}
flake8 = {version = "^3.9.2", optional = true}
pep8-naming = {version = "^0.12.1", optional = true}
flake8-bugbear = {version = "^22.4.25", optional = true}
+flake8-comprehensions = {version = "^3.9.0", optional = true}
flake8-docstrings = {version = "^1.6.0", optional = true}
mypy = {version = "^0.910", optional = true}
types-PyYAML = {version = "^6.0.0", optional = true}
@@ -100,7 +102,6 @@ ipywidgets = {version = "^7.6.5", optional = true}

jupyter = {version = "*", optional = true}
nbconvert = {version = "*", optional = true}
-pmdarima = ">=1.8.0"


[tool.poetry.extras]
@@ -113,7 +114,7 @@ release = ["click", "semver"]
docs = ["Sphinx", "numpydoc", "sphinx-rtd-theme", "nbsphinx", "sphinx-mathjax-offline", "myst-parser", "GitPython"]
tests = ["pytest-cov", "coverage", "pytest"]
jupyter = ["jupyter", "nbconvert"]
style = ["black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear"]
style = ["black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear", "flake8-comprehensions"]

all = [
"prophet",
@@ -128,7 +129,7 @@ all-dev = [
"click", "semver",
"Sphinx", "numpydoc", "sphinx-rtd-theme", "nbsphinx", "sphinx-mathjax-offline", "myst-parser", "GitPython",
"pytest-cov", "coverage", "pytest",
"black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear",
"black", "isort", "flake8", "pep8-naming", "flake8-docstrings", "mypy", "types-PyYAML", "codespell", "flake8-bugbear", "flake8-comprehensions",
"click", "semver",
"jupyter", "nbconvert"
]
2 changes: 1 addition & 1 deletion tests/test_analysis/test_eda_utils.py
@@ -142,7 +142,7 @@ def test_cross_corr_with_full_nans(a, b, normed, expected_result):
pd.date_range(start="2020-01-03", periods=40, freq="D"),
"month",
["2020-Jan"] * 29 + ["2020-Feb"] * 11,
-[i for i in range(3, 32)] + [i for i in range(1, 12)],
+list(range(3, 32)) + list(range(1, 12)),
[str(i) for i in range(3, 32)] + [str(i) for i in range(1, 12)],
),
(
@@ -35,7 +35,7 @@ def test_get_anomalies_prediction_interval_interface(outliers_tsds, model, in_co
"""Test that `get_anomalies_prediction_interval` produces correct columns."""
anomalies = get_anomalies_prediction_interval(outliers_tsds, model=model, interval_width=0.95, in_column=in_column)
assert isinstance(anomalies, dict)
-assert sorted(list(anomalies.keys())) == sorted(outliers_tsds.segments)
+assert sorted(anomalies.keys()) == sorted(outliers_tsds.segments)
for segment in anomalies.keys():
assert isinstance(anomalies[segment], list)
for date in anomalies[segment]:
2 changes: 1 addition & 1 deletion tests/test_analysis/test_outliers/test_density_outliers.py
@@ -16,7 +16,7 @@ def simple_window() -> np.array:

def test_const_ts(const_ts_anomal):
anomal = get_anomalies_density(const_ts_anomal)
assert set(["segment_0", "segment_1"]) == set(anomal.keys())
assert {"segment_0", "segment_1"} == set(anomal.keys())
for seg in anomal.keys():
assert len(anomal[seg]) == 0

4 changes: 2 additions & 2 deletions tests/test_analysis/test_outliers/test_median_outliers.py
@@ -6,7 +6,7 @@

def test_const_ts(const_ts_anomal):
anomal = get_anomalies_median(const_ts_anomal)
assert set(["segment_0", "segment_1"]) == set(anomal.keys())
assert {"segment_0", "segment_1"} == set(anomal.keys())
for seg in anomal.keys():
assert len(anomal[seg]) == 0

@@ -34,7 +34,7 @@ def test_median_outliers(window_size, alpha, right_anomal, outliers_tsds):
def test_interface_correct_args(true_params, outliers_tsds):
d = get_anomalies_median(ts=outliers_tsds, window_size=10, alpha=2)
assert isinstance(d, dict)
-assert sorted(list(d.keys())) == sorted(true_params)
+assert sorted(d.keys()) == sorted(true_params)
for i in d.keys():
for j in d[i]:
assert isinstance(j, np.datetime64)
4 changes: 2 additions & 2 deletions tests/test_pipeline/test_pipeline.py
@@ -317,14 +317,14 @@ def test_get_fold_info_interface_daily(catboost_pipeline: Pipeline, big_daily_ex
"""Check that Pipeline.backtest returns info dataframe in correct format."""
_, _, info_df = catboost_pipeline.backtest(ts=big_daily_example_tsdf, metrics=DEFAULT_METRICS)
expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"]
-assert expected_columns == list(sorted(info_df.columns))
+assert expected_columns == sorted(info_df.columns)


def test_get_fold_info_interface_hours(catboost_pipeline: Pipeline, example_tsdf: TSDataset):
"""Check that Pipeline.backtest returns info dataframe in correct format with non-daily seasonality."""
_, _, info_df = catboost_pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS)
expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"]
-assert expected_columns == list(sorted(info_df.columns))
+assert expected_columns == sorted(info_df.columns)


@pytest.mark.long
@@ -184,7 +184,7 @@ def test_naming_ohe_encoder(two_df_with_new_values):
ohe.fit(df1)
segments = ["segment_0", "segment_1"]
target = ["target", "targets_0", "targets_1", "targets_2", "regressor_0"]
-assert set([(i, j) for i in segments for j in target]) == set(ohe.transform(df2).columns.values)
+assert {(i, j) for i in segments for j in target} == set(ohe.transform(df2).columns.values)


@pytest.mark.parametrize(
@@ -268,4 +268,4 @@ def test_mrmr_right_regressors(relevance_table, ts_with_regressors):
for column in df_selected.columns.get_level_values("feature"):
if column.startswith("regressor"):
selected_regressors.add(column)
-assert set(selected_regressors) == set(["regressor_useful_0", "regressor_useful_1", "regressor_useful_2"])
+assert set(selected_regressors) == {"regressor_useful_0", "regressor_useful_1", "regressor_useful_2"}
2 changes: 1 addition & 1 deletion tests/test_transforms/test_missing_values/conftest.py
@@ -24,7 +24,7 @@ def date_range(request) -> pd.DatetimeIndex:
def all_date_present_df(date_range: pd.Series) -> pd.DataFrame:
"""Create pd.DataFrame that contains some target on given range of dates without gaps."""
df = pd.DataFrame({"timestamp": date_range})
df["target"] = [i for i in range(len(df))]
df["target"] = list(range(len(df)))
df.set_index("timestamp", inplace=True)
return df

@@ -177,7 +177,7 @@ def test_interface_correct_args_out_column(true_params: List[str], train_df: pd.
true_params = [f"{out_column}_{param}" for param in true_params]
for seg in result.columns.get_level_values(0).unique():
tmp_df = result[seg]
-assert sorted(list(tmp_df.columns)) == sorted(true_params + ["target"])
+assert sorted(tmp_df.columns) == sorted(true_params + ["target"])
for param in true_params:
assert tmp_df[param].dtype == "category"

@@ -128,7 +128,7 @@ def test_interface_out_column(true_params: List[str], train_df: pd.DataFrame):
true_params = [f"{out_column}_{param}" for param in true_params]
for seg in result.columns.get_level_values(0).unique():
tmp_df = result[seg]
-assert sorted(list(tmp_df.columns)) == sorted(true_params + ["target"])
+assert sorted(tmp_df.columns) == sorted(true_params + ["target"])
for param in true_params:
assert tmp_df[param].dtype == "category"
