Merge pull request #23 from pomonam/docu
Update documentation
pomonam committed Jun 22, 2024
2 parents ccfd2ad + 3c0b52f commit d43a06c
Showing 2 changed files with 75 additions and 26 deletions.
DOCUMENTATION.md (27 additions & 17 deletions)
**Define a Task.**
To compute influence scores, you need to define a [`Task`](https://github.com/pomonam/kronfluence/blob/main/kronfluence/task.py) class.
This class contains information about the trained model and how influence scores will be computed:
(1) how to compute the training loss; (2) how to compute the measurable quantity (f(θ) in the [paper](https://arxiv.org/abs/2308.03296); see **Equation 5**);
(3) which modules to use for influence function computations; and (4) whether the model used an [attention mask](https://huggingface.co/docs/transformers/en/glossary#attention-mask).
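
A rough, hypothetical sketch of such a class for a classification model is shown below. The method names and signatures follow my reading of the `Task` interface and should be checked against `task.py`; the batch format and the `sample` flag used here are assumptions.

```python
import torch
import torch.nn.functional as F

from kronfluence.task import Task  # assumed import path


class ClassificationTask(Task):
    def compute_train_loss(self, batch, model, sample=False) -> torch.Tensor:
        # (1) The training loss, summed over the batch.
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            return F.cross_entropy(logits, labels, reduction="sum")
        # When sampling is requested (true Fisher), draw labels from the model's predictions.
        with torch.no_grad():
            sampled_labels = torch.multinomial(
                torch.softmax(logits.detach(), dim=-1), num_samples=1
            ).squeeze(-1)
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch, model) -> torch.Tensor:
        # (2) The measurable quantity f(θ); here simply the loss on the actual labels.
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels, reduction="sum")
```

Items (3) and (4), selecting the tracked modules and providing an attention mask, are covered by optional methods on the same class; see `task.py` for the exact interface.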

```python
from typing import Any, Dict, List, Optional, Union
# ...
```

For example, directly accessing a module's parameters inside `forward` prevents Kronfluence from tracking the module:

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.linear.weight @ x + self.linear.bias  # This does not work 😞
```
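
By contrast, a minimal sketch of the pattern that does work is shown below: call the `nn.Linear` module itself so that Kronfluence's hooks can observe its inputs and outputs. The layer sizes are made up for illustration.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Invoke the module directly instead of using its weight and bias by hand.
        return self.linear(x)  # This works 😊
```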

> [!WARNING]
> The default arguments assume the module is used only once during the forward pass.
> If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
> `shared_parameters_exist=True` in both `FactorArguments` and `ScoreArguments`.
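
For example (a sketch; the `kronfluence.arguments` import path is an assumption), the flag would be enabled in both argument classes:

```python
from kronfluence.arguments import FactorArguments, ScoreArguments  # assumed import path

factor_args = FactorArguments(shared_parameters_exist=True)
score_args = ScoreArguments(shared_parameters_exist=True)
```
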
**Why are there so many arguments?**
Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments`
have many parameters to support these use cases. However, for most standard applications, the default argument values should work well.
```python
factor_args = FactorArguments(
    # ...
    use_empirical_fisher=False,
    distributed_sync_steps=1000,
    amp_dtype=None,
    compile_mode=None,

    # Settings for covariance matrix fitting.
    covariance_max_examples=100_000,
    # ...

    # Settings for Lambda matrix fitting.
    # ...
    lambda_module_partition_size=1,
    lambda_iterative_aggregate=False,
    cached_activation_cpu_offload=False,
    shared_parameters_exist=False,
    per_sample_gradient_dtype=torch.float32,
    lambda_dtype=torch.float32,
)
```

You can change:
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using the actual labels from the batch)
instead of the true Fisher (using labels sampled from the model's predictions). It is recommended to keep this `False`.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `compile_mode`: Selects the mode for [torch compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). Disables torch compile if set to `None`.
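
For instance, a sketch that keeps the true Fisher while running the fitting passes under bfloat16 AMP (the dtype choice is illustrative, and the import path is assumed):

```python
import torch

from kronfluence.arguments import FactorArguments  # assumed import path

factor_args = FactorArguments(
    use_empirical_fisher=False,   # true Fisher: labels sampled from the model's predictions
    amp_dtype=torch.bfloat16,     # run factor fitting under automatic mixed precision
)
```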

### Fitting Covariance Matrices

```python
analyzer.fit_covariance_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_factor")
```

This step corresponds to **Equation 16** in the paper. You can tune:
- `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. Setting it to `None`,
Kronfluence computes covariance matrices for all data points.
- `covariance_data_partition_size`: Number of data partitions to use for computing covariance matrices.
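
For example, a sketch that caps and partitions the covariance computation (the cap and partition count are arbitrary, and `analyzer` and `train_dataset` are reused from the snippets above):

```python
from kronfluence.arguments import FactorArguments  # assumed import path

factor_args = FactorArguments(
    covariance_max_examples=50_000,     # use at most 50k data points
    covariance_data_partition_size=4,   # process the data in 4 partitions
)
analyzer.fit_covariance_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
```
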
### Performing Eigendecomposition

```python
analyzer.perform_eigendecomposition(factors_name="initial_factor", factor_args=factor_args)
eigen_factors = analyzer.load_eigendecomposition(factors_name="initial_factor")
```

This corresponds to **Equation 18** in the paper. You can tune:
- `eigendecomposition_dtype`: `dtype` for performing eigendecomposition. You can also use `torch.float32`,
but `torch.float64` is strongly recommended.

### Fitting Lambda Matrices

```python
analyzer.fit_lambda_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")
```

This corresponds to **Equation 20** in the paper. You can tune:
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
- `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices.
- `cached_activation_cpu_offload`: Computing the per-sample-gradient requires saving the intermediate activations in memory.
You can set `cached_activation_cpu_offload=True` to cache these activations on the CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
This is helpful for reducing peak memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `shared_parameters_exist`: Specifies whether shared parameters exist in the forward pass (i.e., the same module is used multiple times).
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16`
or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
or `torch.float16`.
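
To make the memory-saving options concrete, here is a sketch; whether you need these settings depends on your model and hardware, and the import path is assumed:

```python
import torch

from kronfluence.arguments import FactorArguments  # assumed import path

factor_args = FactorArguments(
    lambda_iterative_aggregate=True,      # aggregate with for-loops to reduce peak memory
    cached_activation_cpu_offload=True,   # keep cached activations on the CPU
    per_sample_gradient_dtype=torch.bfloat16,
    lambda_dtype=torch.bfloat16,
)
analyzer.fit_lambda_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
```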


**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`. (Try out `lambda_iterative_aggregate=True` first.)
```python
score_args = ScoreArguments(
    # ...
    cached_activation_cpu_offload=False,
    distributed_sync_steps=1000,
    amp_dtype=None,
    compile_mode=None,

    # More functionalities to compute influence scores.
    data_partition_size=1,
    module_partition_size=1,
    per_module_score=False,
    per_token_score=False,
    use_measurement_for_self_influence=False,

    # Configuration for query batching.
    # ...
)
```
You can change:
- `damping`: Damping factor used when computing the inverse Hessian-vector products; defaults to a heuristic of
`(0.1 x mean eigenvalues)` if `None`, as done in [this paper](https://arxiv.org/abs/2308.03296).
- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `compile_mode`: Selects the mode for [torch compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). Disables torch compile if set to `None`.
- `data_partition_size`: Number of data partitions for computing influence scores.
- `module_partition_size`: Number of module partitions for computing influence scores.
- `per_module_score`: Whether to return per-module influence scores. Instead of summing influences across
all modules, this keeps track of intermediate module-wise scores.
- `per_token_score`: Whether to return per-token influence scores. Instead of summing influence scores across
all tokens, this keeps track of influence scores for each token. Note that this is only supported for Transformer-based models (language modeling).
- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float64`.
- `num_query_gradient_aggregations`: Number of query gradients to aggregate over. For example, when `num_query_gradient_aggregations=2` with
`query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
but `torch.float32` is recommended.
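
To make the query-batching options concrete, here is a sketch (the rank, aggregation count, and dtype are arbitrary choices, and the import path is assumed):

```python
import torch

from kronfluence.arguments import ScoreArguments  # assumed import path

score_args = ScoreArguments(
    query_gradient_rank=32,              # low-rank approximation of the preconditioned query gradient
    num_query_gradient_aggregations=2,   # hold two batches of query gradients in memory at once
    score_dtype=torch.bfloat16,
)
```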

### Computing Influence Scores

To compute pairwise influence scores (**Equation 5** in the paper), you can run:

```python
# Computing pairwise influence scores.
analyzer.compute_pairwise_scores(
    scores_name="pairwise",
    factors_name="ekfac",
    score_args=score_args,
    # ...
)
scores = analyzer.load_pairwise_scores(scores_name="pairwise")
```

To compute self-influence scores (see **Section 5.4** from [this paper](https://arxiv.org/pdf/1703.04730.pdf)), you can run:

```python
# Computing self-influence scores.
analyzer.compute_self_scores(
    scores_name="self",
    factors_name="ekfac",
    score_args=score_args,
    # ...
)
scores = analyzer.load_self_scores(scores_name="self")
```

By default, self-influence score computations only use the loss function for gradient calculations.
In this case, the method returns a vector of size `len(train_dataset)`, where each value corresponds
to `g_l^T ⋅ H^{-1} ⋅ g_l`. Here, `g_l` denotes the gradient of the loss function with respect to the model parameters,
and `H^{-1}` represents the inverse Hessian matrix. If you want to use the measurement function instead of the loss function
for self-influence calculations, set `use_measurement_for_self_influence=True`. In this case, each value in the returned
vector will correspond to `g_m^T ⋅ H^{-1} ⋅ g_l`, where `g_m` is the gradient of the measurement function with respect to the model parameters.
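
For instance, a sketch that switches self-influence computation to the measurement function (the `scores_name` is made up, and the dataset and batch-size arguments are omitted here, as in the snippet above):

```python
from kronfluence.arguments import ScoreArguments  # assumed import path

score_args = ScoreArguments(use_measurement_for_self_influence=True)
analyzer.compute_self_scores(scores_name="self_measurement", factors_name="ekfac", score_args=score_args)
scores = analyzer.load_self_scores(scores_name="self_measurement")
```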

**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
2. Try setting `cached_activation_cpu_offload=True`.

## References

4. [Understanding Black-box Predictions via Influence Functions](https://arxiv.org/abs/1703.04730). Pang Wei Koh, Percy Liang. ICML, 2017.
5. [Optimizing Neural Networks with Kronecker-factored Approximate Curvature](https://arxiv.org/abs/1503.05671). James Martens, Roger Grosse. Tech Report, 2015.
6. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018.
7. [Training Data Attribution via Approximate Unrolled Differentiation](https://arxiv.org/abs/2405.12186). Juhan Bae, Wu Lin, Jonathan Lorraine, Roger Grosse. Preprint, 2024.
README.md (48 additions & 9 deletions)
---

> **Kronfluence** is a PyTorch package designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296), *Studying Large Language Model Generalization with Influence Functions*.
---

> [!WARNING]
```python
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    # ...
)
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```

Kronfluence supports various PyTorch features. The following table summarizes the supported features:

<div align="center">

| Feature | Supported |
|-----------------------------------------------------------------------------------------------------------------------------|:---------:|
| [Distributed Data Parallel (DDP)](https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html) | ✅ |
| [Automatic Mixed Precision (AMP)](https://pytorch.org/docs/stable/amp.html) | ✅ |
| [Torch Compile](https://pytorch.org/docs/stable/generated/torch.compile.html) | ✅ |
| [Gradient Checkpointing](https://pytorch.org/docs/stable/checkpoint.html) | ✅ |
| [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/docs/stable/fsdp.html) | ✅ |

</div>

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence.

## LogIX

While Kronfluence supports influence function computations on large-scale models like `Meta-Llama-3-8B-Instruct`, for those
interested in running influence analysis on even larger models or with a large number of query data points, our
project [LogIX](https://github.com/logix-project/logix) may be worth exploring. It integrates with frameworks like
[HuggingFace Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)
and is also compatible with many PyTorch features (DDP & FSDP & [DeepSpeed](https://github.com/microsoft/DeepSpeed)).

## Contributing

Contributions are welcome! To get started, please review our [Code of Conduct](https://github.com/pomonam/kronfluence/blob/main/CODE_OF_CONDUCT.md). For bug fixes, please submit a pull request.
```bash
cd kronfluence
pip install -e ."[dev]"
```

### Style Testing

To maintain code quality and consistency, we run ruff and linting tests on pull requests. Before submitting a
pull request, please ensure that your code adheres to our formatting and linting guidelines. The following commands will
modify your code. It is recommended to create a Git commit before running them to easily revert any unintended changes.

Sort import orderings using [isort](https://pycqa.github.io/isort/):

```bash
isort kronfluence
```

Format code using [ruff](https://docs.astral.sh/ruff/):

```bash
ruff format kronfluence
```

To view all [pylint](https://www.pylint.org/) complaints, run the following command:

```bash
pylint kronfluence
```

Please address any reported issues before submitting your PR.

## Acknowledgements

I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, ...

## License

This software is released under the Apache 2.0 License, as detailed in the [LICENSE](https://github.com/pomonam/kronfluence/blob/main/LICENSE) file.