Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type hints to ImageMorph #7708

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ exclude = [
'^src/PIL/DdsImagePlugin.py$',
'^src/PIL/FpxImagePlugin.py$',
'^src/PIL/Image.py$',
'^src/PIL/ImageMorph.py$',
'^src/PIL/ImageQt.py$',
'^src/PIL/ImageShow.py$',
'^src/PIL/ImImagePlugin.py$',
Expand Down
50 changes: 30 additions & 20 deletions src/PIL/ImageMorph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ class LutBuilder:

"""

def __init__(self, patterns=None, op_name=None):
def __init__(
self, patterns: list[str] | None = None, op_name: str | None = None
) -> None:
if patterns is not None:
self.patterns = patterns
else:
self.patterns = []
self.lut = None
self.lut: bytearray | None = None
if op_name is not None:
known_patterns = {
"corner": ["1:(... ... ...)->0", "4:(00. 01. ...)->1"],
Expand All @@ -87,25 +89,27 @@ def __init__(self, patterns=None, op_name=None):

self.patterns = known_patterns[op_name]

def add_patterns(self, patterns):
def add_patterns(self, patterns: list[str]) -> None:
self.patterns += patterns

def build_default_lut(self):
def build_default_lut(self) -> None:
symbols = [0, 1]
m = 1 << 4 # pos of current pixel
self.lut = bytearray(symbols[(i & m) > 0] for i in range(LUT_SIZE))

def get_lut(self):
def get_lut(self) -> bytearray | None:
return self.lut

def _string_permute(self, pattern, permutation):
def _string_permute(self, pattern: str, permutation: list[int]) -> str:
"""string_permute takes a pattern and a permutation and returns the
string permuted according to the permutation list.
"""
assert len(permutation) == 9
return "".join(pattern[p] for p in permutation)

def _pattern_permute(self, basic_pattern, options, basic_result):
def _pattern_permute(
self, basic_pattern: str, options: str, basic_result: int
) -> list[tuple[str, int]]:
"""pattern_permute takes a basic pattern and its result and clones
the pattern according to the modifications described in the $options
parameter. It returns a list of all cloned patterns."""
Expand Down Expand Up @@ -135,12 +139,13 @@ def _pattern_permute(self, basic_pattern, options, basic_result):

return patterns

def build_lut(self):
def build_lut(self) -> bytearray:
"""Compile all patterns into a morphology lut.

TBD :Build based on (file) morphlut:modify_lut
"""
self.build_default_lut()
assert self.lut is not None
patterns = []

# Parse and create symmetries of the patterns strings
Expand All @@ -159,10 +164,10 @@ def build_lut(self):
patterns += self._pattern_permute(pattern, options, result)

# compile the patterns into regular expressions for speed
for i, pattern in enumerate(patterns):
compiled_patterns = []
for pattern in patterns:
p = pattern[0].replace(".", "X").replace("X", "[01]")
p = re.compile(p)
patterns[i] = (p, pattern[1])
compiled_patterns.append((re.compile(p), pattern[1]))

# Step through table and find patterns that match.
# Note that all the patterns are searched. The last one
Expand All @@ -172,8 +177,8 @@ def build_lut(self):
bitpattern = bin(i)[2:]
bitpattern = ("0" * (9 - len(bitpattern)) + bitpattern)[::-1]

for p, r in patterns:
if p.match(bitpattern):
for pattern, r in compiled_patterns:
if pattern.match(bitpattern):
self.lut[i] = [0, 1][r]

return self.lut
Expand All @@ -182,15 +187,20 @@ def build_lut(self):
class MorphOp:
"""A class for binary morphological operators"""

def __init__(self, lut=None, op_name=None, patterns=None):
def __init__(
self,
lut: bytearray | None = None,
op_name: str | None = None,
patterns: list[str] | None = None,
) -> None:
"""Create a binary morphological operator"""
self.lut = lut
if op_name is not None:
self.lut = LutBuilder(op_name=op_name).build_lut()
elif patterns is not None:
self.lut = LutBuilder(patterns=patterns).build_lut()

def apply(self, image):
def apply(self, image: Image.Image):
"""Run a single morphological operation on an image

Returns a tuple of the number of changed pixels and the
Expand All @@ -206,7 +216,7 @@ def apply(self, image):
count = _imagingmorph.apply(bytes(self.lut), image.im.id, outimage.im.id)
return count, outimage

def match(self, image):
def match(self, image: Image.Image):
"""Get a list of coordinates matching the morphological operation on
an image.

Expand All @@ -221,7 +231,7 @@ def match(self, image):
raise ValueError(msg)
return _imagingmorph.match(bytes(self.lut), image.im.id)

def get_on_pixels(self, image):
def get_on_pixels(self, image: Image.Image):
"""Get a list of all turned on pixels in a binary image

Returns a list of tuples of (x,y) coordinates
Expand All @@ -232,7 +242,7 @@ def get_on_pixels(self, image):
raise ValueError(msg)
return _imagingmorph.get_on_pixels(image.im.id)

def load_lut(self, filename):
def load_lut(self, filename: str) -> None:
"""Load an operator from an mrl file"""
with open(filename, "rb") as f:
self.lut = bytearray(f.read())
Expand All @@ -242,14 +252,14 @@ def load_lut(self, filename):
msg = "Wrong size operator file!"
raise Exception(msg)

def save_lut(self, filename):
def save_lut(self, filename: str) -> None:
"""Save an operator to an mrl file"""
if self.lut is None:
msg = "No operator loaded"
raise Exception(msg)
with open(filename, "wb") as f:
f.write(self.lut)

def set_lut(self, lut):
def set_lut(self, lut: bytearray | None) -> None:
"""Set the lut from an external source"""
self.lut = lut
5 changes: 5 additions & 0 deletions src/PIL/_imagingmorph.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from typing import Any

def __getattr__(name: str) -> Any: ...