diff --git a/pre_commit_hooks/file_contents_sorter.py b/pre_commit_hooks/file_contents_sorter.py index 4c1c7479..ebbcd206 100644 --- a/pre_commit_hooks/file_contents_sorter.py +++ b/pre_commit_hooks/file_contents_sorter.py @@ -13,6 +13,7 @@ from typing import Any from typing import Callable from typing import IO +from typing import Iterable from typing import Optional from typing import Sequence @@ -23,12 +24,16 @@ def sort_file_contents( f: IO[bytes], key: Optional[Callable[[bytes], Any]], + *, + unique: bool = False, ) -> int: before = list(f) - after = sorted( - (line.strip(b'\n\r') for line in before if line.strip()), - key=key, + lines: Iterable[bytes] = ( + line.rstrip(b'\n\r') for line in before if line.strip() ) + if unique: + lines = set(lines) + after = sorted(lines, key=key) before_string = b''.join(before) after_string = b'\n'.join(after) + b'\n' @@ -52,13 +57,20 @@ def main(argv: Optional[Sequence[str]] = None) -> int: default=None, help='fold lower case to upper case characters', ) + parser.add_argument( + '--unique', + action='store_true', + help='ensure each line is unique', + ) args = parser.parse_args(argv) retv = PASS for arg in args.filenames: with open(arg, 'rb+') as file_obj: - ret_for_file = sort_file_contents(file_obj, key=args.ignore_case) + ret_for_file = sort_file_contents( + file_obj, key=args.ignore_case, unique=args.unique, + ) if ret_for_file: print(f'Sorting {arg}') diff --git a/tests/file_contents_sorter_test.py b/tests/file_contents_sorter_test.py index 9ebb021a..15f11342 100644 --- a/tests/file_contents_sorter_test.py +++ b/tests/file_contents_sorter_test.py @@ -45,6 +45,36 @@ FAIL, b'fee\nFie\nFoe\nfum\n', ), + ( + b'Fie\nFoe\nfee\nfee\nfum\n', + ['--ignore-case'], + FAIL, + b'fee\nfee\nFie\nFoe\nfum\n', + ), + ( + b'Fie\nFoe\nfee\nfum\n', + ['--unique'], + PASS, + b'Fie\nFoe\nfee\nfum\n', + ), + ( + b'Fie\nFie\nFoe\nfee\nfum\n', + ['--unique'], + FAIL, + b'Fie\nFoe\nfee\nfum\n', + ), + ( + b'fee\nFie\nFoe\nfum\n', + ['--unique', '--ignore-case'], + PASS, + b'fee\nFie\nFoe\nfum\n', + ), + ( + b'fee\nfee\nFie\nFoe\nfum\n', + ['--unique', '--ignore-case'], + FAIL, + b'fee\nFie\nFoe\nfum\n', + ), ), ) def test_integration(input_s, argv, expected_retval, output, tmpdir):