Skip to content

Commit

Permalink
Closes #86: Discard ApplicationContext post-apply
Browse files Browse the repository at this point in the history
  • Loading branch information
carbonleakage authored and shaypal5 committed Feb 22, 2022
1 parent 9d04fdd commit a2119d4
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -104,3 +104,6 @@ tags
# don't like pipfiles
Pipfile
Pipfile.lock

# Temporary testing file
notebooks/debug*.py
59 changes: 38 additions & 21 deletions pdpipe/core.py
Expand Up @@ -810,7 +810,8 @@ def _transform(self, df, verbose):
raise NotImplementedError

def _post_transform_lock(self):
self.application_context.lock()
# Application context is discarded after pipeline application
self.application_context = None
self.fit_context.lock()

def apply(
Expand All @@ -819,7 +820,8 @@ def apply(
exraise: Optional[bool] = None,
verbose: Optional[bool] = False,
time: Optional[bool] = False,
context: Optional[dict] = {},
fit_context: Optional[dict] = {},
application_context: Optional[dict] = {},
):
"""Applies this pipeline stage to the given dataframe.
Expand Down Expand Up @@ -860,15 +862,16 @@ def apply(
exraise=exraise,
verbose=verbose,
time=time,
context=context,
application_context=application_context,
)
return res
res = self.fit_transform(
X=df,
exraise=exraise,
verbose=verbose,
time=time,
context=context,
fit_context=fit_context,
application_context=application_context,
)
return res

Expand All @@ -878,12 +881,13 @@ def __timed_fit_transform(
y: Optional[Iterable] = None,
exraise: Optional[bool] = None,
verbose: Optional[bool] = False,
context: Optional[dict] = {},
fit_context: Optional[dict] = {},
application_context: Optional[dict] = {},
):
self.fit_context = PdpApplicationContext()
self.fit_context.update(context)
self.fit_context.update(fit_context)
self.application_context = PdpApplicationContext()
self.application_context.update(context)
self.application_context.update(application_context)
inter_x = X
times = []
prev = time.time()
Expand Down Expand Up @@ -918,7 +922,8 @@ def fit_transform(
exraise: Optional[bool] = None,
verbose: Optional[bool] = False,
time: Optional[bool] = False,
context: Optional[dict] = {},
fit_context: Optional[dict] = {},
application_context: Optional[dict] = {},
):
"""Fits this pipeline and transforms the input dataframe.
Expand All @@ -942,7 +947,10 @@ def fit_transform(
time : bool, default False
If True, per-stage application time is measured and reported when
pipeline application is done.
context : dict, optional
fit_context : dict, option
Context for the entire pipeline, is retained after the pipeline
application is completed.
application_context : dict, optional
Context to add to the application context of this call. Can map
str keys to arbitrary object values to be used by pipeline stages
during this pipeline application.
Expand All @@ -954,12 +962,15 @@ def fit_transform(
"""
if time:
return self.__timed_fit_transform(
X=X, y=y, exraise=exraise, verbose=verbose, context=context)
X=X, y=y, exraise=exraise,
verbose=verbose,
fit_context=fit_context,
application_context=application_context)
inter_x = X
self.application_context = PdpApplicationContext()
self.application_context.update(context)
self.application_context.update(application_context)
self.fit_context = PdpApplicationContext()
self.fit_context.update(context)
self.fit_context.update(fit_context)
for i, stage in enumerate(self._stages):
try:
stage.fit_context = self.fit_context
Expand All @@ -985,7 +996,8 @@ def fit(
exraise: Optional[bool] = None,
verbose: Optional[bool] = False,
time: Optional[bool] = False,
context: Optional[dict] = {},
fit_context: Optional[dict] = {},
application_context: Optional[dict] = {},
):
"""Fits this pipeline without transforming the input dataframe.
Expand All @@ -1009,7 +1021,10 @@ def fit(
time : bool, default False
If True, per-stage application time is measured and reported when
pipeline application is done.
context : dict, optional
fit_context : dict, option
Context for the entire pipeline, is retained after the pipeline
application is completed.
application_context : dict, optional
Context to add to the application context of this call. Can map
str keys to arbitrary object values to be used by pipeline stages
during this pipeline application.
Expand All @@ -1025,7 +1040,8 @@ def fit(
exraise=exraise,
verbose=verbose,
time=time,
context=context,
fit_context=fit_context,
application_context=application_context,
)
return X

Expand All @@ -1035,13 +1051,13 @@ def __timed_transform(
y: Optional[Iterable[float]] = None,
exraise: Optional[bool] = None,
verbose: Optional[bool] = None,
context: Optional[dict] = {},
application_context: Optional[dict] = {},
) -> pandas.DataFrame:
inter_x = X
times = []
prev = time.time()
self.application_context = PdpApplicationContext()
self.application_context.update(context)
self.application_context.update(application_context)
for i, stage in enumerate(self._stages):
try:
stage.fit_context = self.fit_context
Expand Down Expand Up @@ -1073,7 +1089,7 @@ def transform(
exraise: Optional[bool] = None,
verbose: Optional[bool] = None,
time: Optional[bool] = False,
context: Optional[dict] = {},
application_context: Optional[dict] = {},
) -> pandas.DataFrame:
"""Transforms the given dataframe without fitting this pipeline.
Expand All @@ -1100,7 +1116,7 @@ def transform(
time : bool, default False
If True, per-stage application time is measured and reported when
pipeline application is done.
context : dict, optional
application_context : dict, optional
Context to add to the application context of this call. Can map
str keys to arbitrary object values to be used by pipeline stages
during this pipeline application.
Expand All @@ -1117,10 +1133,11 @@ def transform(
" unfitted!").format(stage))
if time:
return self.__timed_transform(
X=X, y=y, exraise=exraise, verbose=verbose, context=context)
X=X, y=y, exraise=exraise, verbose=verbose,
application_context=application_context)
inter_df = X
self.application_context = PdpApplicationContext()
self.application_context.update(context)
self.application_context.update(application_context)
for i, stage in enumerate(self._stages):
try:
stage.application_context = self.application_context
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_app_context.py
Expand Up @@ -132,7 +132,7 @@ def test_application_context_injection():
assert len(pipeline) == 2
df = _test_df()
val = randint(840, 921)
res_df = pipeline.apply(df, verbose=True, context={'a': val})
res_df = pipeline.apply(df, verbose=True, fit_context={'a': val})
assert 'num1' in res_df.columns
assert 'num1+val' in res_df.columns
assert 'num2' in res_df.columns
Expand Down

0 comments on commit a2119d4

Please sign in to comment.