diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index a0b09fa78..192f95626 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -40,12 +40,12 @@ jobs: strategy: fail-fast: false matrix: - # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-22.04] + os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1'] - cuda-version: ['11.8.0', '12.9.1'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] + cuda-version: ['11.8.0', '12.3.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -53,6 +53,16 @@ jobs: cxx11_abi: ['FALSE', 'TRUE'] exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # Pytorch < 2.2 does not support Python 3.12 + - torch-version: '2.1.2' + python-version: '3.12' + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: '2.1.2' + python-version: '3.13' + - torch-version: '2.2.2' + python-version: '3.13' + - torch-version: '2.3.1' + python-version: '3.13' - torch-version: '2.4.0' python-version: '3.13' @@ -89,7 +99,7 @@ jobs: - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.26 + uses: Jimver/cuda-toolkit@v0.2.19 id: cuda-toolkit with: cuda: ${{ matrix.cuda-version }} @@ -111,8 +121,8 @@ jobs: # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then diff --git a/.gitignore b/.gitignore index dbde1b117..be627d6c2 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ build/ **.so *.hip -*_hip.* \ No newline at end of file +*_hip.* +venv/ \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 318499d57..000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include csrc * -recursive-include csrc * -README.md diff --git a/README.md b/README.md index bf3b76a93..a1219441e 100755 --- a/README.md +++ b/README.md @@ -1,3 +1,207 @@ +# Differential Mamba + +

+
+Nadav Schneider,
+Itamar Zimerman,
+Eliya Nachmani
+
+This repository contains the official PyTorch implementation of the Differential Mamba paper.
+We also provide training code, evaluation code, and model checkpoints to reproduce the results in the paper, including all the baselines.
+

+
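+At a high level, each Diff-Mamba block runs two Mamba mixers in parallel and combines their outputs with a learned per-layer scalar λ, as implemented by the `DiffBlockPaper` module in `mamba_ssm/modules/block.py`. The snippet below is an illustrative sketch of just that combine step (not the full block), with names following that module:
+```
+import torch
+
+def diff_combine(y1, y2, lambda_q1, layer_idx, subln):
+    # lambda_init schedule used by DiffBlockPaper
+    lambda_init = 0.8 - 0.6 * torch.exp(torch.tensor(-0.3 * layer_idx))
+    # per-layer scalar lambda = sigmoid(sum(lambda_q1)) + lambda_init
+    lam = torch.sigmoid(lambda_q1.sum(dim=-1).float()) + lambda_init
+    out = subln(y1 - lam * y2)        # differential subtraction + post-subtraction norm
+    return out * (1.0 - lambda_init)  # rescale, as in the block's first sublayer
+```
+Here `y1` and `y2` are the outputs of the two mixers, `lambda_q1` is the learned parameter vector, and `subln` stands for the post-subtraction norm that the block builds from its `norm_cls`.
+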
+
+# Setup
+## Clone Project
+```
+git clone https://github.com/maxmelichov/DiffMamba # This version uses mamba-ssm==2.2.4
+cd DiffMamba
+```
+
+## Create Environment
+Use a virtual environment (recommended). Create and activate one, then upgrade pip:
+```
+python3 -m venv .venv
+# Activate the environment
+source .venv/bin/activate
+python -m pip install --upgrade pip
+```
+If you already have an active virtual environment, you can skip these steps.
+
+Mamba Installation:
+```
+pip install causal-conv1d==1.5.0.post8
+pip install flash-attn==2.7.4.post1
+# make sure you are in the right directory (you should be in DiffMamba)
+pip install .
+```
+
+## Additional Requirements - Language Modeling
+
+Install the requirements in: https://github.com/state-spaces/s4
+
+To train or evaluate on the language modeling task, first download the data. This can be done with the following scripts:
+```
+python language_modeling/src/data/datasets/get_wt103.py
+bash language_modeling/src/data/transformer-xl/enwik8/get_enwik8.sh
+bash language_modeling/src/data/transformer-xl/text8/get_text8.sh
+```
+Then, move the resulting datasets into the language_modeling/data directory.
+
+## Additional Requirements - Retrieval
+
+Install the requirements in: https://github.com/booydar/babilong
+
+To fine-tune on PG19, please make sure to download the dataset according to the instructions at [deepmind/pg19](https://huggingface.co/datasets/deepmind/pg19) or use the Huggingface dataset version.
+
+## Additional Requirements - Tuned-Lens
+
+Install the requirements in: https://github.com/AlignmentResearch/tuned-lens
+
+Make sure to download The Pile validation set to train the lens.
+Place the .json or .txt file in the tuned-lens/data directory.
+
+# Experiments
+## Language Modeling
+Run `cd language_modeling`.
+Then, run the following:
+```
+python train.py experiment=lm/diffmamba2-text8 trainer.devices=[0] model.dropout=0.5 loader.l_max=512 train.seed=0 trainer.accumulate_grad_batches=1 loader.batch_size=50 model.n_layers=12 model.d_model=1024 trainer.max_epochs=40 trainer.precision=32
+```
+
+```trainer.devices```: selects the GPUs for training. [0] uses cuda:0 and [2] uses cuda:2; [0, 2] uses cuda:0 and cuda:2 with DDP training, while 2 picks the first two available GPUs (cuda:0 and cuda:1).
+
+```loader.l_max```: the maximum sequence length (context window) for training
+
+```model.n_layers```: the number of layers, which determines the model size
+
+```optimizer.lr```: overrides the learning rate; otherwise the default is used
+
+```trainer.max_epochs```: the number of training epochs
+
+```loader.batch_size```: the batch size
+
+```model.dropout```: the dropout rate of the model
+
+```train.seed```: sets the training seed
+
+```trainer.accumulate_grad_batches```: use gradient accumulation when GPU memory is insufficient for the required batch size
+
+## Retrieval
+
+Run `cd retrieval`.
+To evaluate the models, make sure to save the model checkpoints in the Diff-Mamba/outputs directory.
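+For reference, the effective batch size of the fine-tuning commands below is batch_size × grad_accum_steps × nproc_per_node; for the PG19 runs this works out to 6 × 12 × 4 = 288 sequences per optimizer step (assuming the scripts do not rescale the batch internally).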
+
+### Finetune PG19
+To finetune Mamba on PG19, run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs/mamba2-370m-pg19-finetune
+```
+To finetune Diff-Mamba on PG19, run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs
+```
+
+### Finetune BABILong
+To finetune Mamba on BABILong, run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/mamba2-370m-pg19-finetune --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/mamba2-370m-needle-finetune
+```
+To finetune Diff-Mamba on BABILong, run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/diffmamba2-370m-pg19-finetune --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/diffmamba2-370m-needle-finetune
+```
+
+```--nproc_per_node```: the number of GPUs to use for DDP training
+
+```--grad_accum_steps```: gradient accumulation steps, used to increase the effective batch size under memory limitations
+
+```--diffmamba```: flag that must be set when training Diff-Mamba
+
+```--model_id```: the pretrained Mamba model to load from Huggingface
+
+### Evaluate
+
+To evaluate a model on the different tasks and context lengths, run:
+
+```
+bash scripts/run_activation-beacon-diffmamba2-370m-needle-finetune-seed0_no_instruct.sh
+```
+or
+```
+bash scripts/run_activation-beacon-diffmamba2-370m_pg19-finetune_no_instruct.sh
+```
+Results will be saved in the scripts/babilong_evals directory.
+
+### Plot
+To plot the scores, simply run:
+```
+python plot.py --model_name diffmamba2-370m-needle-finetune-seed0 --results_folder scripts/babilong_evals/diffmamba2-370m-needle-finetune-seed0
+```
+To plot the relative percentage, run:
+```
+python plot_compare.py --model_name diffmamba2-370m-needle-finetune --ratio
+```
+The plot will be saved in scripts/babilong_evals. Use the ```--ratio``` flag for the relative percentage plot, or omit it for the original scores plot.
+
+## Tuned-Lens
+
+Run `cd tuned-lens`.
+### Training Lens
+Then, to train a lens for Mamba, run:
+```
+python -m tuned_lens train --model.name ../../../outputs/mamba2-370m-pg19-finetune --data.name data/valid.txt --per_gpu_batch_size=1 --ssm --output my_lenses/mamba2-370m-pg19-finetune
+```
+To train a lens for Diff-Mamba, specify the correct path to the model and change the output directory accordingly.
+To train the lens in a distributed fashion, change ```--per_gpu_batch_size``` to the number of available GPUs.
+
+### Evaluate
+To evaluate, run:
+```
+python test_babilong_0k.py --ckpt_path ../../../outputs/mamba2-370m-needle-finetune
+```
+Add the ```--diffmamba``` flag if using Diff-Mamba.
+
+You can stop the test early by using the ```--num_examples``` flag. The compatible lens will be loaded from the my_lenses directory.
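+For example, an illustrative Diff-Mamba evaluation that stops early might look like the following (the ```--num_examples``` value is arbitrary, and the checkpoint path is assumed to follow the naming used above):
+```
+python test_babilong_0k.py --ckpt_path ../../../outputs/diffmamba2-370m-needle-finetune --diffmamba --num_examples 100
+```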
+ +### Plot +To plot the results run: +``` +python plot_tuned_lens.py --diff_results_path results/diffmamba2-370m-needle-finetune-lens_eval.txt --mamba_results_path results/mamba2-370m-needle-finetune-lens_eval.txt +``` +Use ```--log``` to create a log scale plot and ```--start-layer``` and ```--end-layer``` to choose specific layers to plot. + +## Acknowledgements + +All model implementations are based on [Mamba](https://github.com/state-spaces/mamba). Training and evaluation for the language modeling experiments are based on [S4](https://github.com/state-spaces/s4) repository. Evaluation on BABILong is based on [BABILong](https://github.com/booydar/babilong) repo, and measuring signal-to-noise ratio through the layers is based on [tuned-lens](https://github.com/AlignmentResearch/tuned-lens). + +## Citation + +If you use this code, please consider citing the following: + +``` +@misc{schneider2025differentialmamba, + title={Differential Mamba}, + author={Nadav Schneider and Itamar Zimerman and Eliya Nachmani}, + year={2025}, + eprint={2507.06204}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2507.06204}, +} +``` + # Mamba ![Mamba](assets/selection.png "Selective State Space") diff --git a/figures/LensLogScale.PNG b/figures/LensLogScale.PNG new file mode 100644 index 000000000..3d714a491 Binary files /dev/null and b/figures/LensLogScale.PNG differ diff --git a/figures/babilong.PNG b/figures/babilong.PNG new file mode 100644 index 000000000..f54b15f17 Binary files /dev/null and b/figures/babilong.PNG differ diff --git a/figures/diffmamba.PNG b/figures/diffmamba.PNG new file mode 100644 index 000000000..e297080c7 Binary files /dev/null and b/figures/diffmamba.PNG differ diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index 6280931e4..ac4f6e311 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.5" +__version__ = "2.2.4" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index fae2257a9..7a22a2b71 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -11,22 +11,118 @@ import torch import torch.nn as nn -from mamba_ssm.models.config_mamba import MambaConfig -from mamba_ssm.modules.mamba_simple import Mamba -from mamba_ssm.modules.mamba2 import Mamba2 -from mamba_ssm.modules.mha import MHA -from mamba_ssm.modules.mlp import GatedMLP -from mamba_ssm.modules.block import Block -from mamba_ssm.utils.generation import GenerationMixin -from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf +from DiffMamba.mamba_ssm.models.config_mamba import MambaConfig +from DiffMamba.mamba_ssm.modules.mamba_simple import Mamba +from DiffMamba.mamba_ssm.modules.mamba2 import Mamba2 +from DiffMamba.mamba_ssm.modules.mha import MHA +from DiffMamba.mamba_ssm.modules.mlp import GatedMLP +from DiffMamba.mamba_ssm.modules.block import Block, DiffBlockPaper +from DiffMamba.mamba_ssm.utils.generation import GenerationMixin +from DiffMamba.mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf try: - from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn + from DiffMamba.mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +def create_diff_block( + d_model, + d_intermediate, + 
ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None): + if ssm_cfg is None: + ssm_cfg = {} + if attn_layer_idx is None: + attn_layer_idx = [] + if attn_cfg is None: + attn_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + if layer_idx % 4 != 2: # layer_idx % 4 != 1 and layer_idx % 4 != 3: + if layer_idx not in attn_layer_idx: + # Create a copy of the config to modify + ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {} + ssm_layer = ssm_cfg.pop("layer", "Mamba1") + if ssm_layer not in ["Mamba1", "Mamba2"]: + raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2") + mixer_cls = partial( + Mamba2 if ssm_layer == "Mamba2" else Mamba, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs + ) + else: + mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + if d_intermediate == 0: + mlp_cls = nn.Identity + else: + mlp_cls = partial( + GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs + ) + block = Block( + d_model, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + else: + if layer_idx not in attn_layer_idx: + # Create a copy of the config to modify + ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {} + ssm_layer = ssm_cfg.pop("layer", "Mamba1") + if ssm_layer not in ["Mamba1", "Mamba2"]: + raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2") + mixer_cls1 = partial( + Mamba2 if ssm_layer == "Mamba2" else Mamba, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs + ) + mixer_cls2 = partial( + Mamba2 if ssm_layer == "Mamba2" else Mamba, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs + ) + else: + mixer_cls1 = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) + mixer_cls2 = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + if d_intermediate == 0: + mlp_cls = nn.Identity + else: + mlp_cls = partial( + GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs + ) + block = DiffBlockPaper( + d_model, + mixer_cls1, + mixer_cls2, + mlp_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + layer_idx=layer_idx, + ) + block.layer_idx = layer_idx + return block -def create_block( +def create_regular_block( d_model, d_intermediate, ssm_cfg=None, @@ -38,8 +134,7 @@ def create_block( fused_add_norm=False, layer_idx=None, device=None, - dtype=None, -): + dtype=None,): if ssm_cfg is None: ssm_cfg = {} if attn_layer_idx is None: @@ -81,6 +176,52 @@ def create_block( block.layer_idx = layer_idx return block +def create_block( + d_model, + d_intermediate, + ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, + mamba_type="DiffMamba", +): + if mamba_type == "DiffMamba": + return create_diff_block( + d_model, + d_intermediate, + ssm_cfg, + attn_layer_idx, + attn_cfg, + norm_epsilon, + rms_norm, + residual_in_fp32, + fused_add_norm, + layer_idx, + device, + dtype + ) + else: + return 
create_regular_block( + d_model, + d_intermediate, + ssm_cfg, + attn_layer_idx, + attn_cfg, + norm_epsilon, + rms_norm, + residual_in_fp32, + fused_add_norm, + layer_idx, + device, + dtype + ) + # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 def _init_weights( diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py index 1bd968a0b..05b3408eb 100644 --- a/mamba_ssm/modules/block.py +++ b/mamba_ssm/modules/block.py @@ -1,10 +1,13 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. -from typing import Optional +from typing import Optional, Tuple, Type, Callable import torch from torch import nn, Tensor +import torch.nn.functional as F +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" -from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn +from DiffMamba.mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn class Block(nn.Module): @@ -89,3 +92,269 @@ def forward( def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +class DiffBlockPaper(nn.Module): + """ + Diff-Mamba block: Add->Norm -> (mixer1 || mixer2) -> Norm each -> subtract with λ -> Linear -> Norm + -> (optional MLP sublayer like vanilla Block) + + Returns (hidden_states, residual) with the SAME contract as mamba_ssm.modules.block.Block: + - If no MLP: residual is the pre-norm Add sum, hidden_states is the sublayer output (no add here). + - If MLP: we do residual += hidden_states before norm2+MLP, as in vanilla. + """ + + def __init__( + self, + dim: int, + mixer_cls1: Callable[[int], nn.Module], + mixer_cls2: Callable[[int], nn.Module], + mlp_cls: Callable[[int], nn.Module], + *, + norm_cls: Callable[[int], nn.Module] = nn.LayerNorm, + fused_add_norm: bool = False, + residual_in_fp32: bool = False, + layer_idx: int = 0, + use_postscale: bool = False, # optional extra scaling by (1 - lambda_init) + lambda_init: Optional[float] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.d_model = dim + self.layer_idx = layer_idx + self.fused_add_norm = fused_add_norm + self.residual_in_fp32 = residual_in_fp32 + self.use_postscale = bool(use_postscale) + + # Prenorm for Add->Norm input + self.norm = norm_cls(dim) + + # Two parallel mixers + self.mixer1 = mixer_cls1(dim) + self.mixer2 = mixer_cls2(dim) + + # Post-mixer norms (separate for each branch) and post-sub norm + self.subln = norm_cls(dim) + + # Per-layer scalar λ (σ(λ̄)+λ_init), λ̄ initialized very negative -> small λ + self.lambda_init = 0.8 - 0.6 * torch.exp(torch.tensor(-0.3 * self.layer_idx)) + self.lambda_q1 = nn.Parameter(torch.randn(self.d_model)) + + # Optional second sublayer (MLP) mirrors vanilla Block + if mlp_cls is nn.Identity: + self.mlp = None + else: + self.norm2 = norm_cls(dim) + self.mlp = mlp_cls(dim) + + if self.fused_add_norm: + assert layer_norm_fn is not None, "fused_add_norm=True requires Triton layer_norm_fn" + assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)) if RMSNorm is not None else isinstance(self.norm, nn.LayerNorm) + + # -------- cache helper (lets each mixer see its own cache view) -------- + class _SwapCache: + def __init__(self, ip, idx: int, view): + self.ip, self.idx, self.view, self.orig = ip, idx, view, None + def __enter__(self): + if self.ip is not None: + self.orig = self.ip.key_value_memory_dict.get(self.idx, None) 
+ self.ip.key_value_memory_dict[self.idx] = self.view + def __exit__(self, exc_type, exc, tb): + if self.ip is not None: + if self.orig is None: + self.ip.key_value_memory_dict.pop(self.idx, None) + else: + self.ip.key_value_memory_dict[self.idx] = self.orig + + def _run_mixers(self, x: Tensor, inference_params=None, **mixer_kwargs) -> Tuple[Tensor, Tensor]: + # No caching provided: straightforward parallel calls + if inference_params is None: + y1 = self.mixer1(x, inference_params=None, **mixer_kwargs) + y2 = self.mixer2(x, inference_params=None, **mixer_kwargs) + return y1, y2 + + # Ensure persistent, SEPARATE cache slots for each branch + ip = inference_params + main_key = self.layer_idx + # Use a hidden negative key for branch 2 to avoid collisions with real layer indices + branch2_key = -(self.layer_idx + 1) + + # Ensure branch 1 cache exists at main_key + slot1 = ip.key_value_memory_dict.get(main_key, None) + if slot1 is None: + try: + bs = x.shape[0] + max_seqlen = getattr(ip, "max_seqlen", None) + dtype = x.dtype + if max_seqlen is not None: + ip.key_value_memory_dict[main_key] = self.mixer1.allocate_inference_cache(bs, max_seqlen, dtype=dtype) + slot1 = ip.key_value_memory_dict.get(main_key, None) + except Exception: + slot1 = None + + # Ensure branch 2 cache exists at branch2_key + slot2 = ip.key_value_memory_dict.get(branch2_key, None) + if slot2 is None: + try: + bs = x.shape[0] + max_seqlen = getattr(ip, "max_seqlen", None) + dtype = x.dtype + if max_seqlen is not None: + ip.key_value_memory_dict[branch2_key] = self.mixer2.allocate_inference_cache(bs, max_seqlen, dtype=dtype) + slot2 = ip.key_value_memory_dict.get(branch2_key, None) + except Exception: + slot2 = None + + # Run branches against their own caches; swap branch2 into the main key during its call + y1 = self.mixer1(x, inference_params=ip, **mixer_kwargs) + if slot2 is not None: + with DiffBlockPaper._SwapCache(ip, main_key, slot2): + y2 = self.mixer2(x, inference_params=ip, **mixer_kwargs) + else: + y2 = self.mixer2(x, inference_params=ip, **mixer_kwargs) + return y1, y2 + + @staticmethod + def _to_norm_dtype(norm: nn.Module, x: Tensor) -> Tensor: + w = getattr(norm, "weight", None) + return x.to(w.dtype) if isinstance(w, torch.Tensor) else x + + def forward( + self, + hidden_states: Tensor, + residual: Optional[Tensor] = None, + inference_params=None, + **mixer_kwargs, + ) -> Tuple[Tensor, Tensor]: + # ---- Add -> Norm (prenorm) ---- + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + is_rms_norm=isinstance(self.norm, RMSNorm) + ) + + # ---- Scalar λ per layer ---- + lambda_q1 = torch.sum(self.lambda_q1, dim=-1).float() + lambda_full = torch.sigmoid(lambda_q1) + self.lambda_init + + # ---- Parallel mixers ---- + y1, y2 = self._run_mixers(hidden_states, inference_params, **mixer_kwargs) + + # ---- Differential combine -> out proj -> post-sub norm ---- + attn = y1 - lambda_full * y2 + attn = self.subln(attn) + + # First sublayer output + hidden_states = attn * (1.0 - self.lambda_init) + + + # ---- Optional MLP sublayer (mirrors vanilla Block) ---- + if self.mlp is not None: + if not self.fused_add_norm: + residual = 
hidden_states + residual + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm2.eps, + is_rms_norm=isinstance(self.norm2, RMSNorm) + ) + hidden_states = self.mlp(hidden_states) + else: + residual = hidden_states + residual + + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # Share a single cache object between both mixers + cache_fn = getattr(self.mixer1, "allocate_inference_cache", None) + shared = cache_fn(batch_size, max_seqlen, dtype=dtype, **kwargs) if callable(cache_fn) else None + return shared + + @classmethod + def from_pretrained_block( + cls, + block: nn.Module, + mixer_cls: Optional[Callable[[int], nn.Module]] = None, + mlp_cls: Optional[Callable[[int], nn.Module]] = None, + norm_cls: Optional[Callable[[int], nn.Module]] = None, + fused_add_norm: Optional[bool] = None, + residual_in_fp32: Optional[bool] = None, + lambda_init: Optional[float] = None, + use_postscale: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "DiffBlockPaper": + """Build a DiffBlock from a vanilla Block and copy weights into both mixers.""" + src_mixer = getattr(block, "mixer", None) + src_mlp = getattr(block, "mlp", None) + src_norm = getattr(block, "norm", None) + + mixer_cls = mixer_cls or (src_mixer.__class__) + mlp_cls = mlp_cls or (src_mlp.__class__ if src_mlp is not None else nn.Identity) + norm_cls = norm_cls or (src_norm.__class__ if src_norm is not None else nn.LayerNorm) + + fused_add_norm = fused_add_norm if fused_add_norm is not None else getattr(block, "fused_add_norm", False) + residual_in_fp32 = residual_in_fp32 if residual_in_fp32 is not None else getattr(block, "residual_in_fp32", False) + + # Try to infer dimension robustly if block doesn't have d_model + dim = getattr(block, "d_model", None) + if dim is None: + if hasattr(block, "norm") and hasattr(block.norm, "weight"): + dim = block.norm.weight.shape[0] + else: + raise ValueError("Cannot infer dim from the provided block") + + newb = cls( + dim=dim, + mixer_cls1=mixer_cls, + mixer_cls2=mixer_cls, + mlp_cls=mlp_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + layer_idx=getattr(block, "layer_idx", 0), + lambda_init=lambda_init, + use_postscale=use_postscale, + device=device, + dtype=dtype, + ) + + # copy prenorm + if src_norm is not None: + newb.norm.load_state_dict(src_norm.state_dict(), strict=False) + # seed sublayer norm with same stats + newb.subln.load_state_dict(newb.norm.state_dict(), strict=False) + + # copy mixer weights into both mixers + if src_mixer is not None: + st = src_mixer.state_dict() + newb.mixer1.load_state_dict(st, strict=False) + newb.mixer2.load_state_dict(st, strict=False) + + # copy mlp & norm2 if present + if src_mlp is not None and newb.mlp is not None: + newb.mlp.load_state_dict(src_mlp.state_dict(), strict=False) + if hasattr(block, "norm2") and hasattr(newb, "norm2"): + newb.norm2.load_state_dict(block.norm2.state_dict(), strict=False) + + return newb diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359c..c51ec40dc 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ 
b/mamba_ssm/ops/selective_scan_interface.py @@ -8,12 +8,10 @@ try: from causal_conv1d import causal_conv1d_fn - from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + import causal_conv1d_cuda except ImportError: causal_conv1d_fn = None - causal_conv1d_fwd_function = None - causal_conv1d_bwd_function = None - causal_conv1d_update_function = None + causal_conv1d_cuda = None from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd @@ -192,7 +190,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh """ xz: (batch, dim, seqlen) """ - assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." assert checkpoint_lvl in [0, 1] L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] @@ -208,7 +206,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_fwd_function( + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( x, conv1d_weight, conv1d_bias, None, None, None, True ) # We're being very careful here about the layout, to avoid extra transposes. @@ -281,7 +279,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) - assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors L = xz.shape[-1] @@ -291,7 +289,7 @@ def backward(ctx, dout): if dout.stride(-1) != 1: dout = dout.contiguous() if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_fwd_function( + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( x, conv1d_weight, conv1d_bias, None, None, None, True ) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), @@ -357,7 +355,7 @@ def backward(ctx, dout): dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). - dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function( + dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 959078061..fa5b813a2 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -132,8 +132,7 @@ def _chunk_scan_fwd_kernel( dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. 
- # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) - cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0)) + cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: @@ -680,8 +679,7 @@ def _chunk_scan_bwd_dx_kernel( cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. @@ -818,8 +816,7 @@ def _chunk_scan_bwd_dcb_kernel( dcb *= dt_n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) - # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) if HAS_DDA_CS: tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") ddA_cs = dcb * cb @@ -1011,8 +1008,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( acc *= dt_n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) - # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) mask = offs_m[:, None] >= offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0) acc = tl.cumsum(acc, axis=1) @@ -1138,8 +1134,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel( cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) acc *= cb dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) - # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) - acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0) rowsum_new = rowsum + tl.sum(acc, axis=1) diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index 50838d055..bb49c9a96 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel( dt += dt_bias[:, None] if DT_SOFTPLUS: dt_presoftplus = dt - dt = tl.where(dt <= 20.0, softplus(dt), dt) + dt = tl.where(dt <= 20.0, softplus(dt), ddt) clamp_mask = (dt < dt_min) | (dt > dt_max) # As of Triton 2.2.0, tl.clamp is not available yet # dt = tl.clamp(dt, dt_min, dt_max) @@ -229,11 +229,9 @@ def _chunk_state_fwd_kernel( seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < 
chunk_size_limit - k, other=-1) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: - # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k - scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k + scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k else: - # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) - scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0) + scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -334,8 +332,7 @@ def _chunk_state_bwd_dx_kernel( dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] - acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None] + acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None] x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) @@ -437,11 +434,9 @@ def _chunk_state_bwd_db_kernel( dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: - # scale = tl.exp(dA_cs_last - dA_cs_m) - scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) + scale = tl.exp(dA_cs_last - dA_cs_m) else: - # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) db *= (scale * dt_m)[:, None] if HAS_DDA_CS: # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. 
the exclusive reverse cumsum @@ -554,13 +549,11 @@ def _chunk_state_bwd_ddAcs_stable_kernel( dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) if not HAS_SEQ_IDX: - # scale = tl.exp(dA_cs_last - dA_cs_m) - scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) + scale = tl.exp(dA_cs_last - dA_cs_m) else: seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) acc *= scale[:, None] x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) @@ -641,10 +634,8 @@ def _chunk_state_varlen_kernel( b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0) + tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index bbf4ecf84..58a6e04a9 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -20,12 +20,9 @@ try: from causal_conv1d import causal_conv1d_fn - from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + import causal_conv1d_cuda except ImportError: - causal_conv1d_fn = None - causal_conv1d_fwd_function = None - causal_conv1d_bwd_function = None - causal_conv1d_update_function = None + causal_conv1d_fn, causal_conv1d_cuda = None, None from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd @@ -50,13 +47,6 @@ def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] -def rearrange_and_update_stride(tensor, pattern=None, dim=2): - # ensure tensor.stride(dim) is a multiple of eight after rearranging according to pattern, - # if not call contiguous(), rearrange only if pattern is not None - tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor - return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged - - @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), @@ -130,13 +120,11 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) if not HAS_SEQ_IDX: - # scale = tl.exp(dA_cs_last - 
dA_cs_m) - scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) + scale = tl.exp(dA_cs_last - dA_cs_m) else: seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) - # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) - scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 # However, we're getting error with the Triton compiler 2.1.0 for that code path: # Unexpected mma -> mma layout conversion @@ -182,8 +170,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) - # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. @@ -789,7 +776,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) seq_idx = seq_idx.contiguous() if seq_idx is not None else None xBC_conv = rearrange( - causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"), + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), "b d s -> b s d" ) @@ -863,8 +850,8 @@ def backward(ctx, dout, *args): zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) # Recompute x, B, C xBC_conv = rearrange( - causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"), - conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), + causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"), + conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), "b d s -> b s d" ) x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) @@ -913,14 +900,10 @@ def backward(ctx, dout, *args): else: doutproj_weight, doutproj_bias = None, None dxBC_given = rearrange(dxBC_given, "b s d -> b d s") - dxBC_given_update, dweight, dbias, *_ = causal_conv1d_bwd_function( - rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, - rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"] + dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, + rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"] ) - if dxBC_given.stride() != dxBC_given_update.stride(): - dxBC_given.copy_(dxBC_given_update) - 
else: - dxBC_given = dxBC_given_update dxBC_given = rearrange(dxBC_given, "b d s -> b s d") return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None diff --git a/pyproject.toml b/pyproject.toml index 5831fe66e..ab6315c33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,12 +12,11 @@ license = { file = "LICENSE" } # Include a LICENSE file in your repo keywords = ["cuda", "pytorch", "state-space model"] classifiers = [ "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", + "License :: OSI Approved :: BSD License", "Operating System :: Unix" ] dependencies = [ "torch", - "triton", "ninja", "einops", "transformers", diff --git a/setup.py b/setup.py index 32d62ed90..3ee91645f 100755 --- a/setup.py +++ b/setup.py @@ -172,29 +172,22 @@ def append_nvcc_threads(nvcc_extra_args): "Note: make sure nvcc has a supported version by running nvcc -V." ) - if bare_metal_version <= Version("12.9"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_53,code=sm_53") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_62,code=sm_62") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_70,code=sm_70") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_72,code=sm_72") cc_flag.append("-gencode") - cc_flag.append("arch=compute_75,code=sm_75") + cc_flag.append("arch=compute_53,code=sm_53") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_62,code=sm_62") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_70,code=sm_70") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_72,code=sm_72") cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("-gencode") cc_flag.append("arch=compute_87,code=sm_87") + if bare_metal_version >= Version("11.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") - if bare_metal_version >= Version("12.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_100,code=sm_100") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_120,code=sm_120") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as @@ -363,7 +356,7 @@ def run(self): url="https://github.com/state-spaces/mamba", classifiers=[ "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", + "License :: OSI Approved :: BSD License", "Operating System :: Unix", ], ext_modules=ext_modules, @@ -378,7 +371,7 @@ def run(self): "packaging", "ninja", "einops", - "triton", + # "triton", "transformers", # "causal_conv1d>=1.4.0", ], diff --git a/usage.md b/usage.md deleted file mode 100644 index 1b588ce2c..000000000 --- a/usage.md +++ /dev/null @@ -1,43 +0,0 @@ -# Mamba adoption - -We've been very happy to see Mamba being adopted by many organizations -and research labs to speed up their training / inference. -This page contains a partial list of places where Mamba is being used. -If you'd like to add links to your organization / product / codebase, please open a -PR or email us. We'd very much like to hear from you! 
- -## Large language models and multi-modal models - -- [Tencent's Hunyuan-TurboS (560B)](https://arxiv.org/abs/2505.15431) - -- [Nvidia Nemotron-H (8B, 47B, 56B)](https://research.nvidia.com/labs/adlr/nemotronh/) - -- [AI21 Jamba (398B)](https://www.ai21.com/blog/announcing-jamba-model-family/) - -- [TII Falcon-H1 (34B)](https://falconllm.tii.ae/falcon-h1.html) - -- [IBM Bamba (9B)](https://research.ibm.com/blog/bamba-ssm-transformer-model) - -- [Mistral's Codestral (7B)](https://mistral.ai/news/codestral-mamba) - -- [Nvidia Mamba-2 Hybrid (8B)](https://arxiv.org/abs/2406.07887) - -- [Microsoft Samba (4B)](https://arxiv.org/abs/2406.07522v1) - -- [TII Falcon-Mamba (7B)](https://falconllm.tii.ae/tii-releases-first-sslm-with-falcon-mamba-7b.html) - -## Inference frameworks - -- vLLM - -- Nvidia's TensorRT-LLM - -## Hardware - -- Nvidia GPUs - -- [AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/mamba/README.html) - -- [AWS Trainium 2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/fused_mamba.html) - -