Mamba-2 code release
tridao committed Jun 3, 2024
1 parent c59255a commit 60dadf2
Showing 21 changed files with 6,707 additions and 118 deletions.
58 changes: 49 additions & 9 deletions README.md
@@ -4,6 +4,9 @@
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
> Albert Gu*, Tri Dao*\
> Paper: https://arxiv.org/abs/2312.00752
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\
> Tri Dao*, Albert Gu*\
> Paper: https://arxiv.org/abs/2405.21060
## About

@@ -43,7 +46,7 @@ The main module of this repository is the Mamba architecture block wrapping the
Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).

Usage:
``` python
import torch
from mamba_ssm import Mamba

@@ -60,6 +63,24 @@ y = model(x)
assert y.shape == x.shape
```

The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).

A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py).

The usage is similar to Mamba(-1):
``` python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
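
As a rough sanity check on the parameter-count comment above (a sketch; the exact total is slightly higher because of the conv, norm, and SSM-specific parameters):
``` python
# Continues from the snippet above: compare the actual parameter count with
# the rough 3 * expand * d_model^2 estimate (3 * 2 * 256^2 = 393,216 here).
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params} parameters vs. rough estimate {3 * 2 * dim**2}")
```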

### Mamba Language Model

Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
@@ -70,12 +91,12 @@ This is an example of how to integrate Mamba into an end-to-end neural network.
This example is used in the generation scripts below.
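
As a minimal sketch (assuming the `MambaConfig` dataclass in `mamba_ssm.models.config_mamba` with `d_model`, `n_layer`, and `vocab_size` fields; the tiny config below is purely illustrative), the language model can be built and run on a batch of token ids:
``` python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Illustrative toy config; pretrained checkpoints ship their own config.
config = MambaConfig(d_model=256, n_layer=4, vocab_size=50277)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)

input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")
logits = model(input_ids).logits  # (batch, seqlen, vocab size rounded up for padding)
```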



## Pretrained Models

Pretrained models are uploaded to
[Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
(trained on 600B tokens on the SlimPajama dataset).
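
For example (a sketch that pairs these checkpoints with the `EleutherAI/gpt-neox-20b` tokenizer, as the generation benchmark script below does), a pretrained model can be loaded and queried for next-token logits:
``` python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device, dtype = "cuda", torch.float16
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba2-130m", device=device, dtype=dtype)

input_ids = tokenizer("Mamba is a state space model that", return_tensors="pt").input_ids.to(device)
next_token_logits = model(input_ids).logits[0, -1]
print(tokenizer.decode(next_token_logits.argmax().item()))
```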


@@ -106,17 +127,24 @@ library.

1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`.
2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
```

To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
```

To run evaluations on Mamba-2 models, simply replace the model names:
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
```

Note that the result of each task might differ from reported values by 0.1-0.3 points due to noise in the evaluation process.

## Inference
@@ -132,16 +160,21 @@ Other configurable options include the top-p (nucleus sampling) probability, and

To test generation latency (e.g. batch size = 1) with different sampling strategies:

``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
```

To test generation throughput with random prompts (e.g. large batch size):
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128

With Mamba-2, you just need to change the model name:
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
```
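
The same generation path can also be driven from Python; a sketch assuming `MambaLMHeadModel.generate()` accepts keyword arguments mirroring the script's flags (`max_length`, `temperature`, `top_k`, `top_p`):
``` python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device, dtype = "cuda", torch.float16
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba2-2.7b", device=device, dtype=dtype)

prompt = "My cat wrote all this CUDA code for a new language model and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# Sampling keyword arguments below are assumed to mirror the benchmark script's CLI flags.
out = model.generate(input_ids=input_ids, max_length=input_ids.shape[1] + 100,
                     temperature=0.7, top_k=0, top_p=0.9)
print(tokenizer.decode(out[0]))
```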


@@ -164,12 +197,19 @@ that is specific to the training framework.

## Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:
```
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}
```
2 changes: 1 addition & 1 deletion benchmarks/benchmark_generation_mamba_simple.py
@@ -32,7 +32,7 @@
dtype = torch.float16

print(f"Loading model {args.model_name}")
is_mamba = args.model_name.startswith("state-spaces/mamba-")
is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp")
if is_mamba:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
3 changes: 2 additions & 1 deletion mamba_ssm/__init__.py
@@ -1,5 +1,6 @@
__version__ = "1.2.2"
__version__ = "2.0.0"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
Empty file.
144 changes: 144 additions & 0 deletions mamba_ssm/distributed/distributed_utils.py
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
return input_, handle
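

# Illustrative usage sketch (assumes an initialized process group): with
# async_op=True the raw collectives return a work handle, so the communication
# can be overlapped with independent compute before waiting on it, e.g.
#
#   out, handle = all_gather_raw(x, process_group, async_op=True)
#   ...  # independent compute to overlap with the all-gather
#   handle.wait()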


class AllGatherFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""

@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_gather_raw(input_, process_group)
return output

@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""

@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = reduce_scatter_raw(input_, process_group)
return output

@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""

@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_reduce_raw(input_, process_group)
return output

@staticmethod
def backward(ctx, grad_output: Tensor):
return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
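

# Illustrative usage sketch (not from the original file; assumes an initialized
# process group, e.g. under torchrun): in sequence parallelism each rank holds a
# shard of the sequence along dim 0. all_gather assembles the full sequence for a
# layer's forward pass, and its autograd backward reduce-scatters the gradient
# back to the local shard.
class SequenceParallelLinearSketch(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        self.process_group = process_group

    def forward(self, x_shard: Tensor) -> Tensor:
        # (local_seqlen, batch, dim) -> (world_size * local_seqlen, batch, dim)
        x_full = all_gather(x_shard, self.process_group)
        return self.linear(x_full)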


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_shared = {
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
}
    for _, p in sorted(params_shared.items()):
with torch.no_grad():
# Broadcast needs src to be global rank, not group rank
torch.distributed.broadcast(
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
)


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
params_seqparallel = {
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
}
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
if grads:
with torch.no_grad():
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
"""Get the dim for the local rank derived from splitting dim on world_size processes.
The split may not be even across the world_size processes.
"""
multiple = dim // multiple_of
div = multiple // world_size
mod = multiple % world_size
local_multiple = div + int(local_rank < mod)
return local_multiple * multiple_of
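

# Worked example: splitting dim = 1280 across 3 ranks while keeping each shard a
# multiple of 64. multiple = 20, div = 6, mod = 2, so ranks 0 and 1 receive
# 7 * 64 = 448 and rank 2 receives 6 * 64 = 384 (448 + 448 + 384 = 1280).
if __name__ == "__main__":
    sizes = [
        get_dim_for_local_rank(1280, world_size=3, local_rank=r, multiple_of=64)
        for r in range(3)
    ]
    assert sizes == [448, 448, 384] and sum(sizes) == 1280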
