
Merge pull request #35 from softmax1/surgery
Surgery
christopher-w-murphy committed Sep 5, 2023
2 parents 660bebe + 0ee69a7 commit 91faff3
Showing 11 changed files with 581 additions and 16 deletions.
89 changes: 88 additions & 1 deletion README.md
@@ -12,9 +12,11 @@ In the spirit of the flash attention paper, further gains can be made by conside
- `flash_attention_n_triton`: recommended for non-integer values of _n_ when a GPU is available, uses Triton
- `slow_attention_n`: flexible, torch-based implementation

🧠 **Perform statistical analyses**: Compute summary statistics for both the weights and activations of your model.
🧠 **Run statistical analyses**: Compute summary statistics for both the weights and activations of your model.
The activation stats are computed online as the model is training.

🔥 **Perform "surgery" on existing models**: Take a pretrained model with softmax_0 in its attention mechanism and "operate" on it to replace softmax_0 with softmax_n.

## Install
Simple installation
```bash
$ pip install flash-attention-softmax-n
```
@@ -25,6 +27,10 @@
Optionally install the Triton implementation
```bash
$ pip install flash-attention-softmax-n[triton]
$ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
Optionally install the surgery subpackage for converting pretrained models to softmax_n
```bash
$ pip install flash-attention-softmax-n[surgery]
```

## Usage

@@ -151,4 +157,85 @@ print(activations_statistics['...attention.output...']['kurtosis'])
```python
print(weight_statistics['...attention.output...']['kurtosis'])

save_results({'activations': activations_statistics, 'weights': weight_statistics}, 'my-gpt4')
```

### Surgery
"Operate" on pretrained models to generalize them to softmax_n.
Based on MosaicML's [composer](https://github.com/mosaicml/composer).

Functional API: add one line of code to your script.
```python
import transformers

from flash_attention_softmax_n.surgery import apply_attention_softmax_n


model = transformers.AutoModel.from_pretrained('bert-base-uncased')
apply_attention_softmax_n(model=model, softmax_n_param=1.)
...
```

Object-oriented API for use with the MosaicML composer trainer.
```python
import composer
import transformers

from flash_attention_softmax_n.surgery import AttentionSoftmaxN


model = transformers.AutoModel.from_pretrained('bert-base-uncased')
trainer = composer.trainer.Trainer(
model=model,
algorithms=[AttentionSoftmaxN(softmax_n_param=1.)]
)
...
```

Add your model to the registry!
(Currently, only BERT and RoBERTa without flash attention are available by default.)
As an example, use `policy_registry` to replace slow_attention_0 in `MyModel` with flash_attention_n.
After registration, convert the model with `apply_attention_softmax_n`.
```python
import types

import torch

from flash_attention_softmax_n import slow_attention_n, flash_attention_n
from flash_attention_softmax_n.surgery import apply_attention_softmax_n
from flash_attention_softmax_n.surgery.surgery_functions import policy_registry


class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.attn = SlowAttention()

def forward(self, q, k, v):
        return self.attn(q, k, v)


class SlowAttention(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, q, k, v):
return slow_attention_n(q, k, v, softmax_n_param=0.)


@policy_registry.register(SlowAttention)
def slow_attention_converter(module: torch.nn.Module, module_index: int, softmax_n_param: float) -> torch.nn.Module:
assert isinstance(module, SlowAttention)
del module_index # unused
module.n = softmax_n_param
setattr(module, 'forward', types.MethodType(forward, module))
return module


def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
return flash_attention_n(q, k, v, softmax_n_param=int(self.n))


if __name__ == '__main__':
model = MyModel()
apply_attention_softmax_n(model=model, softmax_n_param=1.) # will log a warning if the model isn't registered
```
8 changes: 3 additions & 5 deletions flash_attention_softmax_n/__init__.py
@@ -1,14 +1,12 @@
-from warnings import warn
-
from einops._torch_specific import allow_ops_in_compiled_graph

from flash_attention_softmax_n.core.flash_attn import flash_attention_n
from flash_attention_softmax_n.core.functional import softmax_n, slow_attention_n
try:
    from flash_attention_softmax_n.core.flash_attn_triton import flash_attention_n_triton
-except ModuleNotFoundError as e:
-    warn(f'The Triton flash attention implementation, `flash_attention_n_triton`, is not available. {e}.')
-    flash_attention_n_triton = None
+    TRITON_INSTALLED = True
+except ModuleNotFoundError:
+    TRITON_INSTALLED = False


allow_ops_in_compiled_graph()
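
With this change, downstream code can branch on the new `TRITON_INSTALLED` flag instead of checking whether `flash_attention_n_triton` is `None`. A minimal usage sketch (the fallback choice below is illustrative, not part of the library):

```python
from flash_attention_softmax_n import TRITON_INSTALLED, flash_attention_n

if TRITON_INSTALLED:
    # Only importable when the optional Triton extra is installed.
    from flash_attention_softmax_n import flash_attention_n_triton
    attention_fn = flash_attention_n_triton
else:
    # Fall back to the pure-PyTorch flash attention implementation.
    attention_fn = flash_attention_n
```
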
5 changes: 5 additions & 0 deletions flash_attention_softmax_n/surgery/__init__.py
@@ -0,0 +1,5 @@
try:
from flash_attention_softmax_n.surgery.attention_softmax_n import AttentionSoftmaxN, apply_attention_softmax_n
SURGERY_INSTALLED = True
except ModuleNotFoundError:
SURGERY_INSTALLED = False
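
Likewise, `SURGERY_INSTALLED` lets callers check whether the optional surgery dependencies (e.g. composer) are available. A small sketch of how it might be used (the error message is illustrative):

```python
from flash_attention_softmax_n.surgery import SURGERY_INSTALLED

if SURGERY_INSTALLED:
    from flash_attention_softmax_n.surgery import apply_attention_softmax_n
else:
    raise ImportError('run `pip install flash-attention-softmax-n[surgery]` to enable model surgery')
```
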
108 changes: 108 additions & 0 deletions flash_attention_softmax_n/surgery/attention_softmax_n.py
@@ -0,0 +1,108 @@
from __future__ import annotations

import logging
from typing import Optional, Sequence, Union

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import module_surgery
from torch.nn import Module
from torch.optim import Optimizer

from flash_attention_softmax_n.surgery.surgery_functions import policy_registry

log = logging.getLogger(__name__)

__all__ = ['AttentionSoftmaxN', 'apply_attention_softmax_n']


def apply_attention_softmax_n(
model: Module,
softmax_n_param: float,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None
) -> None:
"""
Replaces the forward method of SelfAttention with a version that uses softmax_n.
Example:
```python
import transformers
from flash_attention_softmax_n.surgery import apply_attention_softmax_n
model = transformers.AutoModel.from_pretrained('bert-base-uncased')
apply_attention_softmax_n(model=model, softmax_n_param=1.)
```
:param model: Model to transform.
:param softmax_n_param: The value of n.
    :param optimizers: Existing optimizers that are bound to `model.parameters()`. Omit this parameter if optimizers will be constructed after calling this function.
"""
def as_replacement_function(surgery_function):

def replacement_function(module: Module, module_index: int):
return surgery_function(module, module_index, softmax_n_param=softmax_n_param)

return replacement_function

policies = {
module_class: as_replacement_function(attention_softmax_n_surgery_function)
for module_class, attention_softmax_n_surgery_function in policy_registry.items()
}

replaced_pairs = module_surgery.replace_module_classes(model, optimizers=optimizers, policies=policies)

count = len(replaced_pairs)
if count == 0:
supported_modules = ''.join(sorted(['\n\t' + c.__module__ + '.' + c.__name__ for c in policy_registry.keys()]))
log.warning(f'AttentionSoftmaxN had no effect on the model! Support for AttentionSoftmaxN surgery '
f'is currently limited to the following classes: {supported_modules}')
else:
log.info(f'{count} instances of AttentionSoftmaxN added')


class AttentionSoftmaxN(Algorithm):
"""
Object that applies attention softmax_n in a Mosaic trainer.
Example:
```python
import composer
import transformers
from flash_attention_softmax_n.surgery import AttentionSoftmaxN
model = transformers.AutoModel.from_pretrained('bert-base-uncased')
trainer = composer.trainer.Trainer(
model=model,
algorithms=[AttentionSoftmaxN(softmax_n_param=1.)]
)
```
"""
def __init__(self, softmax_n_param: float) -> None:
self.softmax_n_param = softmax_n_param
self._applied = False

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'

@staticmethod
def required_on_load() -> bool:
return True

def match(self, event: Event, state: State) -> bool:
del state # unused
return event == Event.INIT and not self._applied

def apply(self, event: Event, state: State, logger: Logger) -> None:
del event, logger # unused
apply_attention_softmax_n(
state.model,
softmax_n_param=self.softmax_n_param,
optimizers=state.optimizers,
)
self._applied = True
7 changes: 7 additions & 0 deletions flash_attention_softmax_n/surgery/surgery_functions/__init__.py
@@ -0,0 +1,7 @@
try:
from flash_attention_softmax_n.surgery.surgery_functions import _bert
from flash_attention_softmax_n.surgery.surgery_functions.utils import policy_registry

__all__ = ['policy_registry']
except ModuleNotFoundError:
__all__ = []
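
The `policy_registry` consumed above is defined in `surgery_functions/utils.py`, which is not part of this excerpt. A minimal sketch of the registration pattern it supports, written as a hypothetical stand-in (not the library's actual implementation), assuming a dict-like registry keyed by module class:

```python
from typing import Callable, Dict, Type

from torch.nn import Module

class PolicyRegistry(Dict[Type[Module], Callable]):
    def register(self, *module_classes: Type[Module]) -> Callable:
        """Decorator that maps each given module class to the decorated converter."""
        def wrapper(converter: Callable) -> Callable:
            for module_class in module_classes:
                self[module_class] = converter
            return converter
        return wrapper

policy_registry = PolicyRegistry()
```

This matches how `attention_softmax_n.py` iterates `policy_registry.items()` to build its surgery policies and how `_bert.py` registers converters with `@policy_registry.register(...)`.
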
121 changes: 121 additions & 0 deletions flash_attention_softmax_n/surgery/surgery_functions/_bert.py
@@ -0,0 +1,121 @@
from math import sqrt
from types import MethodType
from typing import Optional, Tuple

from torch import Tensor, FloatTensor, cat, matmul, tensor, long, arange, einsum
from torch.nn import Module
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers.models.roberta.modeling_roberta import RobertaSelfAttention

from flash_attention_softmax_n import softmax_n
from flash_attention_softmax_n.surgery.surgery_functions.utils import policy_registry


@policy_registry.register(BertSelfAttention, RobertaSelfAttention)
def bert_attention_converter(module: Module, module_index: int, softmax_n_param: float) -> Module:
"""Adds AttentionSoftmaxN to Bert-style SelfAttention."""
assert isinstance(module, (BertSelfAttention, RobertaSelfAttention))
del module_index # unused
module.n = softmax_n_param
setattr(module, 'forward', MethodType(forward, module))
return module


def forward(
self,
hidden_states: Tensor,
attention_mask: Optional[FloatTensor] = None,
head_mask: Optional[FloatTensor] = None,
encoder_hidden_states: Optional[FloatTensor] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[Tensor]:
mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = cat([past_key_value[0], key_layer], dim=2)
value_layer = cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Additional calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Additional calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bidirectional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = tensor(key_length - 1, dtype=long, device=hidden_states.device).view(-1, 1)
else:
position_ids_l = arange(query_length, dtype=long, device=hidden_states.device).view(-1, 1)
position_ids_r = arange(key_length, dtype=long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / sqrt(self.attention_head_size)
if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in the RobertaModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = softmax_n(attention_scores, n=self.n, dim=-1) # *** modified by CWM ***

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
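
The only functional change relative to the upstream Hugging Face attention forward is the `softmax_n` call flagged above. For intuition, here is a minimal, non-fused reference for what `softmax_n` computes, assuming the usual softmax_n definition in which `n` is added to the denominator (a sketch only; use the library's `softmax_n` in practice):

```python
import torch

def softmax_n_reference(x: torch.Tensor, n: float, dim: int = -1) -> torch.Tensor:
    # Shift by the row max for numerical stability, exactly as a plain softmax would.
    x_max = x.max(dim=dim, keepdim=True).values
    exp_shifted = (x - x_max).exp()
    # softmax_n: exp(x_i) / (n + sum_j exp(x_j)); after the shift, the extra
    # n in the denominator becomes n * exp(-x_max).
    denom = exp_shifted.sum(dim=dim, keepdim=True) + n * (-x_max).exp()
    return exp_shifted / denom
```

With n = 0 this reduces to the standard softmax, which is why "operating" on a pretrained softmax_0 model is a drop-in change.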
