Skip to content

Commit

Permalink
Merge pull request #21 from pomonam/documentation
Browse files Browse the repository at this point in the history
Improvement documentation
  • Loading branch information
pomonam committed Jun 19, 2024
2 parents 1464e19 + 0601611 commit ccfd2ad
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 44 deletions.
92 changes: 58 additions & 34 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# Kronfluence: Technical Documentation & FAQs

For a detailed description of the methodology, please refer to the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.
For a detailed description of the methodology, please refer to the [**paper**](https://arxiv.org/abs/2308.03296), *Studying Large Language Model Generalization with Influence Functions*.

## Requirements

Kronfluence has been tested and is compatible with the following versions of [PyTorch](https://pytorch.org/):
- PyTorch 2.1 or higher
- Python 3.9 or higher
- Python: Version 3.9 or later
- PyTorch: Version 2.1 or later

## Supported Modules & Strategies

Kronfluence offers support for:
- Computing influence functions on selected PyTorch modules. Currently, we support `nn.Linear` and `nn.Conv2d`.
- Computing influence functions with several Hessian approximation strategies, including `identity`, `diagonal`, `KFAC`, and `EKFAC`.
- Computing influence functions with several Hessian approximation strategies, including `identity`, `diagonal`, `kfac`, and `ekfac`.
- Computing pairwise and self-influence (with and without measurement) scores.

> [!NOTE]
Expand All @@ -22,12 +22,12 @@ Kronfluence offers support for:

## Step-by-Step Guide

See [UCI Regression example](https://github.com/pomonam/kronfluence/blob/main/examples/uci/) for the complete workflow and
interactive tutorial.
See [UCI Regression example](https://github.com/pomonam/kronfluence/blob/main/examples/uci/) for the complete workflow and an interactive tutorial.

**Prepare Your Model and Dataset.**
Before computing influence scores, you need to prepare the trained model and dataset. You can use any frameworks to
train the model (e.g., [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) or [HuggingFace Trainer](https://huggingface.co/docs/transformers/main_classes/trainer)).
train the model (e.g., [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) or [HuggingFace Trainer](https://huggingface.co/docs/transformers/main_classes/trainer)); you just need to prepare the final model parameters.

```python
...
# Get the model with the trained parameters.
Expand Down Expand Up @@ -58,15 +58,13 @@ class YourTask(Task):
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
# This will be used for computing the training gradient.
# TODO: Complete this method.

def compute_measurement(
self,
batch: Any,
model: nn.Module,
) -> torch.Tensor:
# This will be used for computing the measurable quantity.
# TODO: Complete this method.

def tracked_modules(self) -> Optional[List[str]]:
Expand All @@ -75,7 +73,7 @@ class YourTask(Task):

def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
# TODO: [Optional] Complete this method.
return None # No attention mask is used.
return None # Attention mask not used.
```

**Prepare Your Model for Influence Computations.**
Expand All @@ -95,11 +93,11 @@ If you have specified specific module names in `Task.tracked_modules`, `TrackedM

**\[Optional\] Create a DDP and FSDP Module.**
After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) or
[FullyShardedDataParallel (FSDP)](https://pytorch.org/docs/stable/fsdp.html) module or even use `torch.compile`.
[FullyShardedDataParallel (FSDP)](https://pytorch.org/docs/stable/fsdp.html) module.

**Set up the Analyzer and Fit Factors.**
Initialize the `Analyzer` and execute `fit_all_factors` to compute all factors that aim to approximate the Hessian
(or Gauss-Newton Hessian). The computed factors will be stored on disk.
Initialize the `Analyzer` and run `fit_all_factors` to compute all factors that aim to approximate the Hessian
(Gauss-Newton Hessian). The computed factors will be stored on disk.

```python
from kronfluence.analyzer import Analyzer
Expand Down Expand Up @@ -138,7 +136,7 @@ You can organize all factors and scores for the specific model with `factors_nam

**What should I do if my model does not have any nn.Linear or nn.Conv2d modules?**
Currently, the implementation does not support influence computations for modules other than `nn.Linear` or `nn.Conv2d`.
Try rewriting the model so that it uses supported modules (as done for the `conv1d` module in [GPT-2 example](https://github.com/pomonam/kronfluence/tree/documentation/examples/wikitext)).
Try rewriting the model so that it uses supported modules (as done for the `conv1d` module in the [GPT-2 example](https://github.com/pomonam/kronfluence/tree/documentation/examples/wikitext)).
Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module.
If there are specific modules you would like to see supported, please submit an issue.

Expand All @@ -147,8 +145,11 @@ We recommend using all supported modules for influence computations. However, if
on subset of the modules (e.g., influence computations only on MLP layers for transformer or influence computation only on the last layer),
inspect `model.named_modules()` to determine what modules to use. You can specify the list of module names you want to analyze.

> [!TIP]
> `Analyzer.get_module_summary(model)` can be helpful in figuring out what modules to include.
> [!NOTE]
> If the embedding layer for transformers are defined with `nn.Linear`, you must write
> If the embedding layer for transformers are defined with `nn.Linear`, you must write your own
> `task.tracked_modules` to avoid influence computations embedding matrices.
**How should I implement Task.compute_train_loss?**
Expand All @@ -158,14 +159,14 @@ the summed loss (over batches and tokens).
**How should I implement Task.compute_measurement?**
It depends on the analysis you would like to perform. Influence functions approximate the [effect of downweighting/upweighting
a training data point on the query's measurable quantity](https://arxiv.org/abs/2209.05364). You can use the loss, [margin](https://arxiv.org/abs/2303.14186) (for classification),
or [conditional log-likelihood](https://arxiv.org/abs/2308.03296) (for language modeling).
or [conditional log-likelihood](https://arxiv.org/abs/2308.03296) (for language modeling). Note that many influence functions implementation, by default, uses the loss.

**I encounter TrackedModuleNotFoundError when using DDP or FSDP.**
Make sure to call `prepare_model` before wrapping your model with DDP or FSDP. Calling `prepare_model` on DDP modules can
cause `TrackedModuleNotFoundError`.

**My model uses supported modules, but influence scores are not computed.**
Kronfluence uses module hooks to compute factors and influence scores. For these to be tracked and computed,
Kronfluence uses [module hooks](https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html) to compute factors and influence scores. For these to be tracked and computed,
the model's forward pass should directly call the module.

```python
Expand All @@ -179,6 +180,11 @@ def forward(x: torch.Tensor) -> torch.Tensor:
x = self.linear.weight @ x + self.linear.bias # This does not work 😞
```

**Why are there so many arguments?**
Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments`
have many parameters to support these use cases. However, for most standard applications, the default argument values
should suffice. Feel free to use the default settings unless you have specific requirements that necessitate customization.

**I get X error when fitting factors/computing scores.**
Please feel free to contact me by [filing an issue](https://github.com/pomonam/kronfluence/issues) or [through email](mailto:jbae@cs.toronto.edu).

Expand All @@ -195,6 +201,7 @@ factor_args = FactorArguments(
use_empirical_fisher=False,
distributed_sync_steps=1000,
amp_dtype=None,
compile_mode=None,

# Settings for covariance matrix fitting.
covariance_max_examples=100_000,
Expand Down Expand Up @@ -224,11 +231,13 @@ You can change:
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch)
instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `compile_model`: Selects the mode for [torch compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). Disables torch compile if set to `None`.

### Fitting Covariance Matrices

`kfac` and `ekfac` require computing the uncentered activation and pre-activation pseudo-gradient covariance matrices.
To fit covariance matrices, you can use `analyzer.fit_covariance_matrices`.

```python
# Fitting covariance matrices.
analyzer.fit_covariance_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
Expand All @@ -243,11 +252,11 @@ Kronfluence computes covariance matrices for all data points.
For example, when `covariance_data_partition_size = 2`, the dataset is split into 2 chunks and covariance matrices
are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate
covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU
can compute covariance matrices on some partitioned data (You can specify `target_data_partitions` in the parameter).
can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter).
- `covariance_module_partition_size`: Number of module partitions to use for computing covariance matrices.
For example, when `covariance_module_partition_size = 2`, the module is split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
covariance matrices cannot fit into memory). However, this will do multiple iterations over the dataset and can be slow.
covariance matrices cannot fit into GPU memory). However, this will require multiple iterations over the dataset and can be slow.
- `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16`
or `torch.float16`.
- `gradient_covariance_dtype`: `dtype` for computing pre-activation pseudo-gradient covariance matrices. You can also use `torch.bfloat16`
Expand All @@ -263,15 +272,15 @@ or `torch.float16`.
After computing the covariance matrices, `kfac` and `ekfac` require performing eigendecomposition.

```python
# Performing Eigendecomposition.
# Performing eigendecomposition.
analyzer.perform_eigendecomposition(factors_name="initial_factor", factor_args=factor_args)
# Loading Eigendecomposition results.
# Loading eigendecomposition results (e.g., eigenvectors and eigenvalues).
eigen_factors = analyzer.load_eigendecomposition(factors_name="initial_factor")
```

This corresponds to Equation 18 in the paper. You can tune:
- `eigendecomposition_dtype`: `dtype` for performing eigendecomposition. You can also use `torch.float32`,
but `torch.float64` is recommended.
but `torch.float64` is strongly recommended.

### Fitting Lambda Matrices

Expand All @@ -289,16 +298,16 @@ This corresponds to Equation 20 in the paper. You can tune:
- `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices.
- `cached_activation_cpu_offload`: Computing the per-sample-gradient requires saving the intermediate activation in memory.
You can set `cached_activation_cpu_offload=True` to cache these activations in CPU.
- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loop instead of batched matrix multiplications.
You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
This is helpful for reducing peak memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
or `torch.float16`.


**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_batch_size` when fitting Lambda matrices.
2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`.
2. Try setting `lambda_iterative_aggregate=True` or `cached_activation_cpu_offload=True`. (Try out `lambda_iterative_aggregate=True` first.)
3. Try using lower precision for `lambda_dtype`.
4. Try using `lambda_module_partition_size > 1`.

Expand All @@ -311,7 +320,8 @@ different eigenvectors when performing eigendecomposition.

**How should I select the batch size?**
You can use the largest possible batch size that does not result in OOM. Typically, the batch size for fitting Lambda
matrices should be smaller than that used for fitting covariance matrices.
matrices should be smaller than that used for fitting covariance matrices. Furthermore, note that you should be getting similar results, regardless
of what batch size you use (different from training neural networks).

---

Expand All @@ -326,37 +336,43 @@ score_args = ScoreArguments(
cached_activation_cpu_offload=False,
distributed_sync_steps=1000,
amp_dtype=None,
compile_mode=None,

# More functionalities to compute influence scores.
data_partition_size=1,
module_partition_size=1,
per_module_score=False,

per_token_score=False,
use_measurement_for_self_influence=False,

# Configuration for query batching.
query_gradient_rank=None,
query_gradient_svd_dtype=torch.float32,
num_query_gradient_aggregations=1,
use_measurement_for_self_influence=False,

# Configuration for dtype.
score_dtype=torch.float32,
per_sample_gradient_dtype=torch.float32,
precondition_dtype=torch.float32,
)
```

- `damping`: A damping factor for the damped inverse Hessian-vector product (iHVP). Uses a heuristic based on mean eigenvalues
(0.1 x mean eigenvalues) if None.
`(0.1 x mean eigenvalues)` if `None`, as done in [this paper](https://arxiv.org/abs/2308.03296).
- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `compile_model`: Selects the mode for [torch compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). Disables torch compile if set to `None`.
- `data_partition_size`: Number of data partitions for computing influence scores.
- `module_partition_size`: Number of module partitions for computing influence scores.
- `per_module_score`: Whether to return a per-module influence scores. Instead of summing over influences across
all modules, this will keep track of intermediate module-wise scores.

- `per_token_score`: Whether to return a per-token influence scores. Instead of summing over influence scores across
all tokens, this will keep track of influence scores for each token. Note that this is only supported for Transformer-based models (language modeling).
- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the query gradient; see Section 3.2.2). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float64`.
- `num_query_gradient_aggregations`: Number of query gradients to aggregate over.
- `num_query_gradient_aggregations`: Number of query gradients to aggregate over. For example, when `num_query_gradient_aggregations = 2` with
`query_batch_size = 16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.

- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
Expand All @@ -382,7 +398,7 @@ analyzer.compute_self_scores(scores_name="self", factors_name="ekfac", score_arg
scores = analyzer.load_self_scores(scores_name="self")
```

**Dealing with OOMs** Here are some steps to fix Out of Memory (OOM) errors.
**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
1. Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
2. Try setting `cached_activation_cpu_offload=True`.
3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
Expand All @@ -392,6 +408,13 @@ batching is only supported for computing pairwise influence scores, not self-inf

### FAQs

**How should I choose a damping term?**
When setting the damping term, both `1e-08` and `None` are reasonable choices. The optimal value may depend on your
specific workload. Another heuristic, suggested in [this paper](https://arxiv.org/abs/2405.12186), is to use `10 * learning_rate * num_iterations` when the model
was trained using SGD with a momentum of 0.9. In practice, I have observed that the damping term does not significantly
affect the final results as long as it is not too large (e.g., `1e-01`). Feel free to experiment with different values within a
reasonable range to find what works best for your use case.

**Influence scores are very large in magnitude.**
Ideally, influence scores need to be divided by the total number of training data points. However, the code does
not normalize the scores. If you would like, you can divide the scores with the total number of data points (or tokens) used to
Expand All @@ -404,4 +427,5 @@ train the model.
3. [TRAK: Attributing Model Behavior at Scale](https://arxiv.org/abs/2303.14186). Sung Min Park, Kristian Georgiev, Andrew Ilyas, Guillaume Leclerc, Aleksander Madry. ICML, 2023.
4. [Understanding Black-box Predictions via Influence Functions](https://arxiv.org/abs/1703.04730). Pang Wei Koh, Percy Liang. ICML, 2017.
5. [Optimizing Neural Networks with Kronecker-factored Approximate Curvature](https://arxiv.org/abs/1503.05671). James Martens, Roger Grosse. Tech Report, 2015.
5. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018.
5. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018.
6. [Training Data Attribution via Approximate Unrolled Differentiation](https://arxiv.org/abs/2405.12186). Juhan Bae, Wu Lin, Jonathan Lorraine, Roger Grosse. Preprint, 2024.
Loading

0 comments on commit ccfd2ad

Please sign in to comment.