
Add _apply_fn_to_data in AOBaseClass #2349

@drisspg

Description


Summary

This pattern is very common across tensor subclasses, so it can be implemented once, generically, on the base class.

The only time this would need to change is when a subclass must spoof its actual size. That is uncommon: NJT (nested jagged tensor) is the only case I can think of.

```python
from typing import Callable


def _apply_fn_to_data(self, fn: Callable):
    """Applies a fn to all tensor components stored on this class"""
    tensor_names, ctx = self.__tensor_flatten__()

    # Apply the function to each tensor component
    new_tensors = {}
    for name in tensor_names:
        new_tensors[name] = fn(getattr(self, name))

    # Rebuild an instance of the same subclass from the transformed components
    return self.__class__.__tensor_unflatten__(
        new_tensors,
        ctx,
        None,  # outer_size parameter
        None,  # outer_stride parameter
    )
```
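A minimal sketch of how the helper could be shared: the method is defined once on a base class, and any wrapper subclass that implements `__tensor_flatten__` / `__tensor_unflatten__` inherits it. `ToyBaseTensor` and `ToyInt8Tensor` are hypothetical names used only for illustration (they are not torchao APIs), and the sketch assumes the subclass stores its data in inner tensors and reports `float32` as its outer dtype.

```python
from typing import Callable

import torch


class ToyBaseTensor(torch.Tensor):
    """Hypothetical stand-in for a shared base class such as AOBaseClass."""

    def _apply_fn_to_data(self, fn: Callable):
        """Applies fn to every inner tensor named by __tensor_flatten__."""
        tensor_names, ctx = self.__tensor_flatten__()
        new_tensors = {name: fn(getattr(self, name)) for name in tensor_names}
        # outer_size / outer_stride are None: this toy subclass rebuilds its
        # outer shape from the new inner tensors.
        return self.__class__.__tensor_unflatten__(new_tensors, ctx, None, None)


class ToyInt8Tensor(ToyBaseTensor):
    """Hypothetical subclass: int8 values plus a per-tensor float scale."""

    @staticmethod
    def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
        # Wrapper subclass: only outer metadata here, data lives in inner tensors.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=torch.float32, device=int_data.device
        )

    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self):
        # Names of the inner tensor attributes, and no extra context.
        return ["int_data", "scale"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        return cls(inner_tensors["int_data"], inner_tensors["scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # A real subclass would handle ops here; this toy version supports none.
        raise NotImplementedError(f"{func} is not implemented for {cls.__name__}")

    def dequantize(self) -> torch.Tensor:
        return self.int_data.to(torch.float32) * self.scale


x = ToyInt8Tensor(torch.tensor([1, -2, 3], dtype=torch.int8), torch.tensor(0.5))
# y is a new ToyInt8Tensor whose inner tensors are clones of x's.
y = x._apply_fn_to_data(lambda t: t.clone())
```

Because the subclass is rebuilt through `__tensor_unflatten__` with `outer_size=None`, its shape comes from the new `int_data`. A subclass that must report a different outer size (the NJT case from the summary) would override the helper instead of inheriting it.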

Activity

Commit 721e970, which references this issue, was added on Jun 12, 2025.

Metadata

Assignees: no one assigned
Type: no type
Projects: no projects
Milestone: no milestone
Relationships: none yet
Participants: @drisspg

        Add _apply_fn_to_data in AOBaseClass · Issue #2349 · pytorch/ao