
[P0] enable use_fast option in the alignable to hyper boost training speed in case intervention locations (for position+subspace) are fixed in a batch #33

Closed
frankaging opened this issue Jan 3, 2024 · 2 comments
frankaging commented Jan 3, 2024

Description:
Currently, the library favors flexibility in its inputs and assumes a small training batch size when the intervention is trainable. For instance, each example in the batch may have different intervention locations as well as different intervention subspaces, allowing more flexible configurations.

This flexibility is not desirable when the batch size is large and the intervention location does not change within a batch. Suppose we have a simple NN that solves (a+b)*c, and we want to localize (a+b) with DAS and a fixed dimensionality of 16; the intervention location then stays the same for every example. However, the current code performs the intervention at the example level, not at the batch level. See,

for batch_i, locations in enumerate(unit_locations):
    tensor_input[
        batch_i, locations, start_index:end_index
    ] = replacing_tensor_input[batch_i]

this can be,

tensor_input[
    :, locations, start_index:end_index
] = replacing_tensor_input[:]
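As a sanity check, the two paths are equivalent whenever the locations are shared across the batch. A minimal NumPy sketch (the library itself operates on PyTorch tensors; all shapes and names here are illustrative):

```python
import numpy as np

batch, seq, hidden = 4, 8, 16
start_index, end_index = 0, hidden

rng = np.random.default_rng(0)
tensor_input = rng.standard_normal((batch, seq, hidden))
replacing_tensor_input = rng.standard_normal((batch, 2, hidden))

# Intervention locations shared by every example in the batch.
locations = [1, 3]

# Slow path: per-example loop (what the current code does).
slow = tensor_input.copy()
for batch_i in range(batch):
    slow[batch_i, locations, start_index:end_index] = replacing_tensor_input[batch_i]

# Fast path: one batched assignment.
fast = tensor_input.copy()
fast[:, locations, start_index:end_index] = replacing_tensor_input[:]

assert np.array_equal(slow, fast)
```

The batched assignment removes a Python-level loop over the batch dimension, which is where the speedup comes from.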

Similarly, the subspace intervention,

if subspaces is not None:
    for example_i in range(len(subspaces)):
        # render subspace as column indices
        sel_subspace_indices = []
        for subspace in subspaces[example_i]:
            sel_subspace_indices.extend(
                [
                    i for i in range(
                        subspace_partition[subspace][0], 
                        subspace_partition[subspace][1]
                    )
                ])
        if mode == "interchange":
            base[example_i, ..., sel_subspace_indices] = \
                source[example_i, ..., sel_subspace_indices]
        elif mode == "add":
            base[example_i, ..., sel_subspace_indices] += \
                source[example_i, ..., sel_subspace_indices]
        elif mode == "subtract":
            base[example_i, ..., sel_subspace_indices] -= \
                source[example_i, ..., sel_subspace_indices]

can be,

if subspaces is not None:
    if subspace_partition is None:
        sel_subspace_indices = subspaces[0]
    else:
        sel_subspace_indices = []
        for subspace in subspaces[0]:
            sel_subspace_indices.extend(
                [
                    i for i in range(
                        subspace_partition[subspace][0], 
                        subspace_partition[subspace][1]
                    )
                ])
    if mode == "interchange":
        base[..., sel_subspace_indices] = \
            source[..., sel_subspace_indices]
    elif mode == "add":
        base[..., sel_subspace_indices] += \
            source[..., sel_subspace_indices]
    elif mode == "subtract":
        base[..., sel_subspace_indices] -= \
            source[..., sel_subspace_indices]
else:
    base[..., :interchange_dim] = source[..., :interchange_dim]
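The key simplification is that the column indices are rendered once from subspaces[0] and then applied with a single batched fancy-index assignment. A minimal NumPy sketch with a hypothetical two-way partition (all values illustrative):

```python
import numpy as np

# Hypothetical partition: subspace k owns hidden dims [start, end).
subspace_partition = [[0, 8], [8, 16]]
# use_fast assumes every example selects the same subspaces,
# so only subspaces[0] is consulted.
subspaces = [[1], [1], [1]]

# Render the shared subspace as column indices, once per batch.
sel_subspace_indices = []
for subspace in subspaces[0]:
    sel_subspace_indices.extend(
        range(subspace_partition[subspace][0], subspace_partition[subspace][1])
    )

base = np.zeros((3, 16))
source = np.ones((3, 16))
# One batched fancy-index assignment ("interchange" mode), no per-example loop.
base[..., sel_subspace_indices] = source[..., sel_subspace_indices]
```

With this partition, subspace 1 renders to columns 8 through 15, and only those columns of base are overwritten.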

We should add a use_fast flag to the alignable config, and do a validation check that fails fast during the forward call.

This issue tracks the use_fast effort for position-based as well as subspace-based interventions. It does not yet cover head-based or head+position-based interventions; those will be covered in a separate PR.

Testing Done:

  • added 4 additional integration tests
  • log:
In case multiple location tags are passed only the first one will be considered
testing stream: value_output with a single position
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
.
----------------------------------------------------------------------
Ran 18 tests in 30.117s

OK
@frankaging frankaging self-assigned this Jan 3, 2024
@frankaging frankaging changed the title [P0] enable use_fast option in the alignable to hyper boost training speed in case intervention locations are fixed in a batch [P0] enable use_fast option in the alignable to hyper boost training speed in case intervention locations (for position+subspace) are fixed in a batch Jan 4, 2024
@frankaging (Collaborator, Author) commented:

Instead of a validation check, we will throw a warning, since validating the inputs takes time, which works against the motivation of being fast.
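A minimal sketch of that trade-off, with a hypothetical helper name (fast_intervene is not the library's API): the fast path only warns and trusts the caller, while the flexible path loops per example.

```python
import warnings
import numpy as np

def fast_intervene(tensor_input, replacing_tensor_input, locations, use_fast=True):
    """Hypothetical helper: warn instead of validating that every example
    in the batch shares the same intervention locations."""
    if use_fast:
        # Warning, not validation: checking each batch element would cost
        # time on every forward call, defeating the point of use_fast.
        warnings.warn(
            "Detected use_fast=True means the intervention location "
            "will be static within a batch."
        )
        # `locations` is one shared list, applied to the whole batch at once.
        tensor_input[:, locations] = replacing_tensor_input
    else:
        # `locations` is per-example here; honor it with a loop.
        for batch_i in range(tensor_input.shape[0]):
            tensor_input[batch_i, locations[batch_i]] = replacing_tensor_input[batch_i]
    return tensor_input

# Usage of the fast path on dummy data.
base = np.zeros((2, 4, 3))
src = np.ones((2, 1, 3))
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    out = fast_intervene(base, src, [2])
```

The warning text mirrors the one in the test log above; if the caller's locations actually vary within the batch, the fast path silently applies the wrong ones, which is the accepted cost of skipping validation.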

@frankaging (Collaborator, Author) commented:

commit link: frankaging/align-transformers@63bd767
