Skip to content

Commit

Permalink
Remove fault tolerant mode from LF/SF
Browse files Browse the repository at this point in the history
  • Loading branch information
henryre committed Oct 7, 2019
1 parent 9a86617 commit e592620
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 85 deletions.
27 changes: 2 additions & 25 deletions snorkel/labeling/lf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class LabelingFunction:
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run on data points before LF execution
fault_tolerant
Output ``-1`` if LF execution fails?
Raises
------
Expand All @@ -39,8 +37,6 @@ class LabelingFunction:
----------
name
See above
fault_tolerant
See above
"""

def __init__(
Expand All @@ -49,10 +45,8 @@ def __init__(
f: Callable[..., int],
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
) -> None:
self.name = name
self.fault_tolerant = fault_tolerant
self._f = f
self._resources = resources or {}
self._pre = pre or []
Expand All @@ -67,9 +61,7 @@ def _preprocess_data_point(self, x: DataPoint) -> DataPoint:
def __call__(self, x: DataPoint) -> int:
"""Label data point.
Runs all preprocessors, then passes to LF. If an exception
is encountered and the LF is in fault tolerant mode,
the LF abstains from voting.
Runs all preprocessors, then passes preprocessed data point to LF.
Parameters
----------
Expand All @@ -82,11 +74,6 @@ def __call__(self, x: DataPoint) -> int:
Label for data point
"""
x = self._preprocess_data_point(x)
if self.fault_tolerant:
try:
return self._f(x, **self._resources)
except Exception:
return -1
return self._f(x, **self._resources)

def __repr__(self) -> str:
Expand All @@ -105,8 +92,6 @@ class labeling_function:
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run on data points before LF execution
fault_tolerant
Output ``-1`` if LF execution fails?
Examples
--------
Expand All @@ -132,14 +117,12 @@ def __init__(
name: Optional[str] = None,
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
) -> None:
if callable(name):
raise ValueError("Looks like this decorator is missing parentheses!")
self.name = name
self.resources = resources
self.pre = pre
self.fault_tolerant = fault_tolerant

def __call__(self, f: Callable[..., int]) -> LabelingFunction:
"""Wrap a function to create a ``LabelingFunction``.
Expand All @@ -155,10 +138,4 @@ def __call__(self, f: Callable[..., int]) -> LabelingFunction:
New ``LabelingFunction`` executing logic in wrapped function
"""
name = self.name or f.__name__
return LabelingFunction(
name=name,
f=f,
resources=self.resources,
pre=self.pre,
fault_tolerant=self.fault_tolerant,
)
return LabelingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
19 changes: 2 additions & 17 deletions snorkel/labeling/lf/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def __init__(
f: Callable[..., int],
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
text_field: str = "text",
doc_field: str = "doc",
language: str = EN_CORE_WEB_SM,
Expand All @@ -83,13 +82,7 @@ def __init__(
self._create_or_check_preprocessor(
text_field, doc_field, language, disable, pre or [], memoize
)
super().__init__(
name,
f,
resources=resources,
pre=[self._nlp_config.nlp],
fault_tolerant=fault_tolerant,
)
super().__init__(name, f, resources=resources, pre=[self._nlp_config.nlp])


class NLPLabelingFunction(BaseNLPLabelingFunction):
Expand Down Expand Up @@ -123,8 +116,6 @@ class NLPLabelingFunction(BaseNLPLabelingFunction):
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run before SpacyPreprocessor is executed
fault_tolerant
Output -1 if LF execution fails?
text_field
Name of data point text field to input
doc_field
Expand Down Expand Up @@ -161,8 +152,6 @@ class NLPLabelingFunction(BaseNLPLabelingFunction):
----------
name
See above
fault_tolerant
See above
"""

@classmethod
Expand All @@ -182,14 +171,13 @@ def __init__(
name: Optional[str] = None,
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
text_field: str = "text",
doc_field: str = "doc",
language: str = EN_CORE_WEB_SM,
disable: Optional[List[str]] = None,
memoize: bool = True,
) -> None:
super().__init__(name, resources, pre, fault_tolerant)
super().__init__(name, resources, pre)
self.text_field = text_field
self.doc_field = doc_field
self.language = language
Expand Down Expand Up @@ -217,7 +205,6 @@ def __call__(self, f: Callable[..., int]) -> BaseNLPLabelingFunction:
f=f,
resources=self.resources,
pre=self.pre,
fault_tolerant=self.fault_tolerant,
text_field=self.text_field,
doc_field=self.doc_field,
language=self.language,
Expand All @@ -237,8 +224,6 @@ class nlp_labeling_function(base_nlp_labeling_function):
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run before SpacyPreprocessor is executed
fault_tolerant
Output -1 if LF execution fails?
text_field
Name of data point text field to input
doc_field
Expand Down
6 changes: 0 additions & 6 deletions snorkel/labeling/lf/nlp_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ class SparkNLPLabelingFunction(BaseNLPLabelingFunction):
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run before SpacyPreprocessor is executed
fault_tolerant
Output -1 if LF execution fails?
text_field
Name of data point text field to input
doc_field
Expand All @@ -48,8 +46,6 @@ class SparkNLPLabelingFunction(BaseNLPLabelingFunction):
----------
name
See above
fault_tolerant
See above
"""

@classmethod
Expand All @@ -72,8 +68,6 @@ class spark_nlp_labeling_function(base_nlp_labeling_function):
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run before SpacyPreprocessor is executed
fault_tolerant
Output -1 if LF execution fails?
text_field
Name of data point text field to input
doc_field
Expand Down
12 changes: 1 addition & 11 deletions snorkel/slicing/sf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ class slicing_function:
Slicing resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run on data points before SF execution
fault_tolerant
Output ``-1`` if SF execution fails?
Examples
--------
Expand All @@ -51,14 +49,12 @@ def __init__(
name: Optional[str] = None,
resources: Optional[Mapping[str, Any]] = None,
pre: Optional[List[BasePreprocessor]] = None,
fault_tolerant: bool = False,
) -> None:
if callable(name):
raise ValueError("Looks like this decorator is missing parentheses!")
self.name = name
self.resources = resources
self.pre = pre
self.fault_tolerant = fault_tolerant

def __call__(self, f: Callable[..., int]) -> SlicingFunction:
"""Wrap a function to create a ``SlicingFunction``.
Expand All @@ -74,10 +70,4 @@ def __call__(self, f: Callable[..., int]) -> SlicingFunction:
New ``SlicingFunction`` executing logic in wrapped function
"""
name = self.name or f.__name__
return SlicingFunction(
name=name,
f=f,
resources=self.resources,
pre=self.pre,
fault_tolerant=self.fault_tolerant,
)
return SlicingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
4 changes: 0 additions & 4 deletions snorkel/slicing/sf/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class NLPSlicingFunction(BaseNLPLabelingFunction):
Labeling resources passed in to ``f`` via ``kwargs``
pre
Preprocessors to run before SpacyPreprocessor is executed
fault_tolerant
Output -1 if LF execution fails?
text_field
Name of data point text field to input
doc_field
Expand Down Expand Up @@ -75,8 +73,6 @@ class NLPSlicingFunction(BaseNLPLabelingFunction):
----------
name
See above
fault_tolerant
See above
"""

@classmethod
Expand Down
27 changes: 5 additions & 22 deletions test/labeling/lf/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,14 @@ def _run_lf(self, lf: LabelingFunction) -> None:
self.assertEqual(lf(x_43), 0)
self.assertEqual(lf(x_19), -1)

def _run_lf_raise(self, lf: LabelingFunction) -> None:
x_none = SimpleNamespace(num=None)
with self.assertRaises(TypeError):
lf(x_none)

def _run_lf_no_raise(self, lf: LabelingFunction) -> None:
x_none = SimpleNamespace(num=None)
self.assertEqual(lf(x_none), -1)

def test_labeling_function(self) -> None:
lf = LabelingFunction(name="my_lf", f=f)
self._run_lf(lf)
self._run_lf_raise(lf)

def test_labeling_function_fault_tolerant(self) -> None:
lf = LabelingFunction(name="my_lf", f=f, fault_tolerant=True)
self._run_lf(lf)
self._run_lf_no_raise(lf)

def test_labeling_function_resources(self) -> None:
db = [3, 6, 43]
lf = LabelingFunction(name="my_lf", f=g, resources=dict(db=db))
self._run_lf(lf)
self._run_lf_no_raise(lf)

def test_labeling_function_preprocessor(self) -> None:
lf = LabelingFunction(name="my_lf", f=f, pre=[square, square])
Expand All @@ -79,7 +63,6 @@ def test_labeling_function_serialize(self) -> None:
lf = LabelingFunction(name="my_lf", f=g, resources=dict(db=db))
lf_load = pickle.loads(pickle.dumps(lf))
self._run_lf(lf_load)
self._run_lf_no_raise(lf_load)

def test_labeling_function_decorator(self) -> None:
@labeling_function()
Expand All @@ -89,17 +72,17 @@ def lf(x: DataPoint) -> int:
self.assertIsInstance(lf, LabelingFunction)
self.assertEqual(lf.name, "lf")
self._run_lf(lf)
self._run_lf_raise(lf)

def test_labeling_function_decorator_args(self) -> None:
@labeling_function(name="my_lf", fault_tolerant=True)
def lf(x: DataPoint) -> int:
return 0 if x.num > 42 else -1
db = [3, 6, 43 ** 2]

@labeling_function(name="my_lf", resources=dict(db=db), pre=[square])
def lf(x: DataPoint, db: List[int]) -> int:
return 0 if x.num in db else -1

self.assertIsInstance(lf, LabelingFunction)
self.assertEqual(lf.name, "my_lf")
self._run_lf(lf)
self._run_lf_no_raise(lf)

def test_labeling_function_decorator_no_parens(self) -> None:
with self.assertRaisesRegex(ValueError, "missing parentheses"):
Expand Down

0 comments on commit e592620

Please sign in to comment.