Skip to content

Commit

Permalink
Update docstring for functional form of IterDataPipes
Browse files Browse the repository at this point in the history
Copy the docstring from IterDataPipe classes to their functional form. Xref pytorch/data#792.
  • Loading branch information
weiji14 committed May 2, 2023
1 parent c6c9258 commit 603172a
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions torch/utils/data/datapipes/datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def __getattr__(self, attribute_name):
if attribute_name in _iter_deprecated_functional_names:
kwargs = _iter_deprecated_functional_names[attribute_name]
_deprecation_warning(**kwargs)
function = functools.partial(IterDataPipe.functions[attribute_name], self)
f = IterDataPipe.functions[attribute_name]
function = functools.partial(f, self)
functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
return function
else:
raise AttributeError("'{0}' object has no attribute '{1}".format(self.__class__.__name__, attribute_name))
Expand All @@ -144,7 +146,12 @@ def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):

return result_pipe

function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
function = functools.partial(
class_function, cls_to_register, enable_df_api_tracing
)
functools.update_wrapper(
wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
)
cls.functions[function_name] = function

def __getstate__(self):
Expand Down Expand Up @@ -253,7 +260,9 @@ def __getattr__(self, attribute_name):
if attribute_name in _map_deprecated_functional_names:
kwargs = _map_deprecated_functional_names[attribute_name]
_deprecation_warning(**kwargs)
function = functools.partial(MapDataPipe.functions[attribute_name], self)
f = MapDataPipe.functions[attribute_name]
function = functools.partial(f, self)
functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
return function
else:
raise AttributeError("'{0}' object has no attribute '{1}".format(self.__class__.__name__, attribute_name))
Expand All @@ -272,6 +281,9 @@ def class_function(cls, source_dp, *args, **kwargs):
return result_pipe

function = functools.partial(class_function, cls_to_register)
functools.update_wrapper(
wrapper=unction, wrapped=cls_to_register, assigned=("__doc__",)
)
cls.functions[function_name] = function

def __getstate__(self):
Expand Down

0 comments on commit 603172a

Please sign in to comment.