Skip to content

Commit

Permalink
Merge branch 'main' of github.com:pappasam/toml-sort
Browse files Browse the repository at this point in the history
  • Loading branch information
pappasam committed Mar 31, 2022
2 parents d3092a2 + dac037a commit a210ab3
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 1 deletion.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ Options:
--help Show this message and exit.
```

## Configuration file

toml-sort can also be configured by using the `pyproject.toml` file.
If the file exists and has a `tool.tomlsort` section, the configuration is used.
If both command line arguments and the configuration are used, the options are merged.
In the case of conflicts, the command line option is used.

In short, the names are the same as on the command line (and have the same meaning),
but `-` is replaced with `_`.
Please note, that only the below options are supported:

```toml
[tool.tomlsort]
all = true
in_place = true
no_header = true
check = true
ignore_case = true
```

## Example

The following example shows the input, and output, from the CLI with default options.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import shutil
import subprocess
from typing import List, NamedTuple, Optional
from unittest import mock

import pytest

from toml_sort import cli

PATH_EXAMPLES = "tests/examples"

# NOTE: weird.toml currently exposes what I interpret to be some buggy
Expand Down Expand Up @@ -161,3 +164,36 @@ def test_multiple_files_and_errors(options):
]
result = capture(["toml-sort"] + options + paths)
assert result.returncode == 1, result.stdout


def test_load_config_file_read():
"""Test no error if pyproject.toml cannot be read."""
with mock.patch("toml_sort.cli.open", side_effect=OSError):
assert not cli.load_config_file()


@pytest.mark.parametrize(
"toml,expected",
[
("", {}),
("[tool.other]\nfoo=2", {}),
("[tool.tomlsort]", {}),
("[tool.tomlsort]\nall=true", {"all": True}),
],
)
def test_load_config_file(toml, expected):
"""Test load_config_file."""
open_mock = mock.mock_open(read_data=toml)
with mock.patch("toml_sort.cli.open", open_mock):
assert cli.load_config_file() == expected


@pytest.mark.parametrize(
"toml", ["[tool.tomlsort]\nunknown=2", "[tool.tomlsort]\nall=42"]
)
def test_load_config_file_invalid(toml):
"""Test error if pyproject.toml is not valid."""
open_mock = mock.mock_open(read_data=toml)
with mock.patch("toml_sort.cli.open", open_mock):
with pytest.raises(SystemExit):
cli.load_config_file()
44 changes: 43 additions & 1 deletion toml_sort/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import argparse
import sys
from argparse import ArgumentParser
from typing import List, Optional
from typing import Any, Dict, List, Optional, Type

import tomlkit

from .tomlsort import TomlSort

Expand Down Expand Up @@ -59,6 +61,45 @@ def write_file(path: str, content: str) -> None:
fileobj.write(content)


def validate_and_copy(
data: Dict[str, Any], target: Dict[str, Any], key: str, type_: Type
) -> None:
"""Validate a configuration key."""
if key not in data:
return
if not isinstance(data[key], type_):
printerr(f"Value of tool.tomlsort.{key} should be of type {type_}.")
sys.exit(1)
target[key] = data.pop(key)


def load_config_file() -> Dict[str, Any]:
"""Load the configuration from pyproject.toml."""
try:
with open("pyproject.toml", encoding="utf-8") as file:
content = file.read()
except OSError:
return {}

document = tomlkit.parse(content)
tool_section = document.get("tool", tomlkit.document())
toml_sort_section = tool_section.get("tomlsort", tomlkit.document())
config = dict(toml_sort_section)

clean_config: Dict[str, Any] = {}
validate_and_copy(config, clean_config, "all", bool)
validate_and_copy(config, clean_config, "in_place", bool)
validate_and_copy(config, clean_config, "no_header", bool)
validate_and_copy(config, clean_config, "check", bool)
validate_and_copy(config, clean_config, "ignore_case", bool)

if config:
printerr(f"Unexpected configuration values: {config}")
sys.exit(1)

return clean_config


def get_parser() -> ArgumentParser:
"""Get the argument parser."""
parser = ArgumentParser(
Expand Down Expand Up @@ -139,6 +180,7 @@ def get_parser() -> ArgumentParser:
type=str,
nargs="*",
)
parser.set_defaults(**load_config_file())
return parser


Expand Down

0 comments on commit a210ab3

Please sign in to comment.