Skip to content

Commit

Permalink
Add stbt.set_global_ocr_corrections: Default corrections for all ocr …
Browse files Browse the repository at this point in the history
…calls
  • Loading branch information
drothlis committed Apr 6, 2020
1 parent 89d889f commit 2475ee9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
42 changes: 37 additions & 5 deletions _stbt/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,14 @@ def ocr(frame=None, region=Region.ALL,
corrections={'bad': 'good',
re.compile(r'[oO]'): '0'}
Plain strings are replaced first (in the order they are specified),
followed by regular expressions (in the order they are specified).
The default value for this parameter can be set with
`stbt.set_global_ocr_corrections`. If global corrections have been set
*and* this ``corrections`` parameter is specified, the corrections in
this parameter are applied first.
| Added in v28: The ``upsample`` and ``text_color`` parameters.
| Added in v29: The ``text_color_threshold`` parameter.
| Added in v30: The ``engine`` parameter and support for Tesseract v4.
Expand Down Expand Up @@ -287,9 +295,7 @@ def ocr(frame=None, region=Region.ALL,
tesseract_user_patterns, tesseract_user_words, upsample, text_color,
text_color_threshold, engine, char_whitelist, imglog)
text = text.strip().translate(_ocr_transtab)

if corrections is not None:
text = apply_ocr_corrections(text, corrections)
text = apply_ocr_corrections(text, corrections)

debug(u"OCR in region %s read '%s'." % (region, text))
_log_ocr_image_debug(imglog, text)
Expand Down Expand Up @@ -390,12 +396,21 @@ def match_text(text, frame=None, region=Region.ALL,
PatternType = type(re.compile(""))


def apply_ocr_corrections(text, corrections):
def apply_ocr_corrections(text, corrections=None):
"""Applies the same corrections as `stbt.ocr`'s ``corrections`` parameter.
This is also available as a separate function, so that you can use it to
post-process old test artifacts using new corrections.
post-process old test artifacts using new corrections. See also
`stbt.set_global_ocr_corrections`.
"""
if corrections:
text = _apply_ocr_corrections(text, corrections)
if global_ocr_corrections:
text = _apply_ocr_corrections(text, global_ocr_corrections)
return text


def _apply_ocr_corrections(text, corrections):
# Match plain strings at word boundaries:
pattern = "|".join(r"\b(" + re.escape(k) + r")\b"
for k in corrections
Expand All @@ -410,6 +425,23 @@ def apply_ocr_corrections(text, corrections):
return text


global_ocr_corrections = {}


def set_global_ocr_corrections(corrections):
    """Install the default OCR corrections used by every call to `stbt.ocr`
    and `stbt.apply_ocr_corrections`.

    ``corrections`` takes the same form as the ``corrections`` parameter of
    `stbt.ocr`; see that parameter's documentation for details. The value
    passed here replaces any global corrections that were set previously.

    We recommend calling this function from ``tests/__init__.py`` so that the
    corrections are in place before any test script is executed.
    """
    global global_ocr_corrections
    # Log the new value first so the debug output shows what is being set.
    message = "Initialising global ocr corrections to: %r" % (corrections,)
    debug(message)
    global_ocr_corrections = corrections


_memoise_tesseract_version = None


Expand Down
2 changes: 2 additions & 0 deletions stbt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
ocr,
OcrEngine,
OcrMode,
set_global_ocr_corrections,
TextMatchResult)
from _stbt.precondition import (
as_precondition,
Expand Down Expand Up @@ -113,6 +114,7 @@
"press_until_match",
"Region",
"save_frame",
"set_global_ocr_corrections",
"TextMatchResult",
"TransitionStatus",
"UITestError",
Expand Down
9 changes: 9 additions & 0 deletions tests/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ def test_corrections(corrections, expected):
assert expected == stbt.ocr(frame=f, mode=stbt.OcrMode.SINGLE_WORD,
corrections=corrections)

try:
stbt.set_global_ocr_corrections({'OO': '11'})
if expected == "OO":
expected = "11"
assert expected == stbt.ocr(frame=f, mode=stbt.OcrMode.SINGLE_WORD,
corrections=corrections)
finally:
stbt.set_global_ocr_corrections({})


@requires_tesseract
@pytest.mark.parametrize("words", [
Expand Down

0 comments on commit 2475ee9

Please sign in to comment.