diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index a0b09fa78..192f95626 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -40,12 +40,12 @@ jobs:
strategy:
fail-fast: false
matrix:
- # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
+ # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
- os: [ubuntu-22.04]
+ os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
- torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1']
- cuda-version: ['11.8.0', '12.9.1']
+ torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
+ cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -53,6 +53,16 @@ jobs:
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+ # Pytorch < 2.2 does not support Python 3.12
+ - torch-version: '2.1.2'
+ python-version: '3.12'
+ # Pytorch < 2.5 does not support Python 3.13
+ - torch-version: '2.1.2'
+ python-version: '3.13'
+ - torch-version: '2.2.2'
+ python-version: '3.13'
+ - torch-version: '2.3.1'
+ python-version: '3.13'
- torch-version: '2.4.0'
python-version: '3.13'
@@ -89,7 +99,7 @@ jobs:
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
- uses: Jimver/cuda-toolkit@v0.2.26
+ uses: Jimver/cuda-toolkit@v0.2.19
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
@@ -111,8 +121,8 @@ jobs:
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
- minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \
- maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \
+ minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+ maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
diff --git a/.gitignore b/.gitignore
index dbde1b117..be627d6c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
build/
**.so
*.hip
-*_hip.*
\ No newline at end of file
+*_hip.*
+venv/
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 318499d57..000000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,3 +0,0 @@
-recursive-include csrc *
-recursive-include csrc *
-README.md
diff --git a/README.md b/README.md
index bf3b76a93..a1219441e 100755
--- a/README.md
+++ b/README.md
@@ -1,3 +1,207 @@
+# Differential Mamba
+
+
+
+Nadav Schneider,
+Itamar Zimerman,
+Eliya Nachmani
+
+
+
+This repository contains the official PyTorch implementation of the Differential Mamba paper.
+We also provide training code, evaluation code, and model checkpoints to reproduce the results in the paper, including all the baselines.
+
+
+
+
+
+
+# Setup
+## Clone Project
+```
+git clone https://github.com/maxmelichov/DiffMamba # This version is using mamba-ssm==2.2.4
+cd DiffMamba
+```
+
+## Create Environment
+Use a virtual environment (recommended). Create and activate one, then upgrade pip:
+```
+python3 -m venv .venv
+# How to activate
+source .venv/bin/activate
+python -m pip install --upgrade pip
+```
+If you already have an active virtual environment, you can skip these steps.
+
+
+Mamba Installation:
+```
+pip install causal-conv1d==1.5.0.post8
+pip install flash-attn==2.7.4.post1
+# make sure you are in the right directory (you should be in DiffMamba)
+pip install .
+```
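+
+To quickly verify the installation, check that the core packages import correctly (a minimal sanity check, assuming the CUDA extensions built without errors):
+```
+python -c "import torch, causal_conv1d, flash_attn, mamba_ssm; print(mamba_ssm.__version__, torch.cuda.is_available())"
+```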
+
+## Additional Requirements - Language Modeling
+
+Install the requirements in: https://github.com/state-spaces/s4
+
+To train or evaluate on the language modeling task, first download the data. This can be done using the following scripts:
+```
+python language_modeling/src/data/datasets/get_wt103.py
+bash language_modeling/src/data/transformer-xl/enwik8/get_enwik8.sh
+bash language_modeling/src/data/transformer-xl/text8/get_text8.sh
+```
+Then, move the resulting datasets into the language_modeling/data directory.
+
+## Additional Requirements - Retrieval
+
+Install the requirements in: https://github.com/booydar/babilong
+
+To fine-tune on PG19, please make sure to download the dataset according to the instructions at [deepmind/pg19](https://huggingface.co/datasets/deepmind/pg19) or use the Hugging Face dataset version.
+
+## Additional Requirements - Tuned-Lens
+
+Install the requirements in: https://github.com/AlignmentResearch/tuned-lens
+
+Make sure to download The Pile validation set to train the lens.
+Place the .json or .txt file in the tuned-lens/data directory.
+
+
+
+# Experiments
+## Language Modeling
+Run ```cd language_modeling```.
+Then, run the following:
+```
+python train.py experiment=lm/diffmamba2-text8 trainer.devices=[0] model.dropout=0.5 loader.l_max=512 train.seed=0 trainer.accumulate_grad_batches=1 loader.batch_size=50 model.n_layers=12 model.d_model=1024 trainer.max_epochs=40 trainer.precision=32
+```
+
+```trainer.devices```: selects the GPUs used for training. [0] uses cuda:0 and [2] uses cuda:2; [0, 2] uses cuda:0 and cuda:2 with DDP training, while 2 picks the first two available GPUs (cuda:0 and cuda:1).
+
+```loader.l_max```: the maximum sequence length (context window) used for training
+
+```model.n_layers```: sets the number of layers, which determines the model size
+
+```optimizer.lr```: overrides the learning rate; omit it to use the default
+
+```trainer.max_epochs```: number of epochs
+
+```loader.batch_size```: sets the batch size
+
+```model.dropout```: the dropout rate of the model
+
+```train.seed```: sets the training seed
+
+```trainer.accumulate_grad_batches```: accumulates gradients over multiple batches when GPU memory cannot fit the required batch size (see the combined example below)
+
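+For example, a longer-context run on two GPUs that compensates for memory with gradient accumulation could look like the following (the values are illustrative only):
+```
+python train.py experiment=lm/diffmamba2-text8 trainer.devices=[0,1] model.dropout=0.5 loader.l_max=1024 loader.batch_size=25 trainer.accumulate_grad_batches=2 model.n_layers=12 model.d_model=1024 optimizer.lr=1e-3 trainer.max_epochs=40 trainer.precision=32
+```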
+
+## Retrieval
+
+
+Run ```cd retrieval```.
+To evaluate the models, make sure to save the model checkpoints in the Diff-Mamba/outputs directory.
+
+### Finetune PG19
+To finetune Mamba on PG19 run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs/mamba2-370m-pg19-finetune
+```
+To finetune Diff-Mamba on PG19 run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs
+```
+
+### Finetune BABILong
+To finetune Mamba on BABILong run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/mamba2-370m-pg19-finetune --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/mamba2-370m-needle-finetune
+```
+To finetune Diff-Mamba on BABILong run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/diffmamba2-370m-pg19-finetune --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/diffmamba2-370m-needle-finetune
+```
+
+```--nproc_per_node```: sets the number of GPUs for DDP training
+
+```--grad_accum_steps```: increases the effective batch size under memory limitations (see the single-GPU example below)
+
+```--diffmamba```: a flag that must be set when training Diff-Mamba
+
+```--model_id```: the pretrained Mamba model loaded from Hugging Face
+
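+For example, to keep the effective batch size of the 4-GPU PG19 run (4 GPUs x batch 6 x 12 accumulation steps = 288) on a single GPU, scale up gradient accumulation accordingly (an illustrative command, not used in the paper):
+```
+torchrun --nproc_per_node=1 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=48 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs/diffmamba2-370m-pg19-finetune
+```
+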
+### Evaluate
+
+To evaluate a model on the different tasks and context lengths, run:
+
+```
+bash scripts/run_activation-beacon-diffmamba2-370m-needle-finetune-seed0_no_instruct.sh
+```
+or
+```
+bash scripts/run_activation-beacon-diffmamba2-370m_pg19-finetune_no_instruct.sh
+```
+Results will be saved in the directory scripts/babilong_evals.
+
+### Plot
+To plot the scores, simply run:
+```
+python plot.py --model_name diffmamba2-370m-needle-finetune-seed0 --results_folder scripts/babilong_evals/diffmamba2-370m-needle-finetune-seed0
+```
+To plot the relative percentage, run:
+```
+python plot_compare.py --model_name diffmamba2-370m-needle-finetune --ratio
+```
+The plot will be saved in scripts/babilong_evals. Use the ```--ratio``` flag for the relative percentage plot, or omit it for the original scores plot.
+
+## Tuned-Lens
+
+
+Run ```cd tuned-lens```.
+### Training Lens
+Then, to train a lens for Mamba, run:
+```
+python -m tuned_lens train --model.name ../../../outputs/mamba2-370m-pg19-finetune --data.name data/valid.txt --per_gpu_batch_size=1 --ssm --output my_lenses/mamba2-370m-pg19-finetune
+```
+To train a lens for Diff-Mamba, specify the correct path to the model and change the output directory accordingly, as in the example below.
+To train the lens in a distributed fashion, change ```--per_gpu_batch_size``` to the number of available GPUs.
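+
+For example, a Diff-Mamba lens could be trained with (paths are illustrative and should point to your Diff-Mamba checkpoint and desired output directory):
+```
+python -m tuned_lens train --model.name ../../../outputs/diffmamba2-370m-pg19-finetune --data.name data/valid.txt --per_gpu_batch_size=1 --ssm --output my_lenses/diffmamba2-370m-pg19-finetune
+```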
+
+### Evaluate
+To evaluate run:
+```
+python test_babilong_0k.py --ckpt_path ../../../outputs/mamba2-370m-needle-finetune
+```
+Add the ```--diffmamba``` flag if using Diff-Mamba.
+
+You can stop the test early by using the ```--num_examples``` flag. The corresponding lens will be loaded from the my_lenses directory.
+
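+For instance (the value passed to ```--num_examples``` is arbitrary here):
+```
+python test_babilong_0k.py --ckpt_path ../../../outputs/diffmamba2-370m-needle-finetune --diffmamba --num_examples 100
+```
+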
+### Plot
+To plot the results, run:
+```
+python plot_tuned_lens.py --diff_results_path results/diffmamba2-370m-needle-finetune-lens_eval.txt --mamba_results_path results/mamba2-370m-needle-finetune-lens_eval.txt
+```
+Use ```--log``` to create a log scale plot and ```--start-layer``` and ```--end-layer``` to choose specific layers to plot.
+
+## Acknowledgements
+
+All model implementations are based on [Mamba](https://github.com/state-spaces/mamba). Training and evaluation for the language modeling experiments are based on the [S4](https://github.com/state-spaces/s4) repository. Evaluation on BABILong is based on the [BABILong](https://github.com/booydar/babilong) repo, and measuring the signal-to-noise ratio through the layers is based on [tuned-lens](https://github.com/AlignmentResearch/tuned-lens).
+
+## Citation
+
+If you use this code, please consider citing the following:
+
+```
+@misc{schneider2025differentialmamba,
+ title={Differential Mamba},
+ author={Nadav Schneider and Itamar Zimerman and Eliya Nachmani},
+ year={2025},
+ eprint={2507.06204},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2507.06204},
+}
+```
+
# Mamba

diff --git a/figures/LensLogScale.PNG b/figures/LensLogScale.PNG
new file mode 100644
index 000000000..3d714a491
Binary files /dev/null and b/figures/LensLogScale.PNG differ
diff --git a/figures/babilong.PNG b/figures/babilong.PNG
new file mode 100644
index 000000000..f54b15f17
Binary files /dev/null and b/figures/babilong.PNG differ
diff --git a/figures/diffmamba.PNG b/figures/diffmamba.PNG
new file mode 100644
index 000000000..e297080c7
Binary files /dev/null and b/figures/diffmamba.PNG differ
diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py
index 6280931e4..ac4f6e311 100644
--- a/mamba_ssm/__init__.py
+++ b/mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.2.5"
+__version__ = "2.2.4"
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index fae2257a9..7a22a2b71 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -11,22 +11,118 @@
import torch
import torch.nn as nn
-from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.modules.mamba_simple import Mamba
-from mamba_ssm.modules.mamba2 import Mamba2
-from mamba_ssm.modules.mha import MHA
-from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block
-from mamba_ssm.utils.generation import GenerationMixin
-from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+from DiffMamba.mamba_ssm.models.config_mamba import MambaConfig
+from DiffMamba.mamba_ssm.modules.mamba_simple import Mamba
+from DiffMamba.mamba_ssm.modules.mamba2 import Mamba2
+from DiffMamba.mamba_ssm.modules.mha import MHA
+from DiffMamba.mamba_ssm.modules.mlp import GatedMLP
+from DiffMamba.mamba_ssm.modules.block import Block, DiffBlockPaper
+from DiffMamba.mamba_ssm.utils.generation import GenerationMixin
+from DiffMamba.mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
try:
- from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
+ from DiffMamba.mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+def create_diff_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg=None,
+ attn_layer_idx=None,
+ attn_cfg=None,
+ norm_epsilon=1e-5,
+ rms_norm=False,
+ residual_in_fp32=False,
+ fused_add_norm=False,
+ layer_idx=None,
+ device=None,
+ dtype=None):
+ if ssm_cfg is None:
+ ssm_cfg = {}
+ if attn_layer_idx is None:
+ attn_layer_idx = []
+ if attn_cfg is None:
+ attn_cfg = {}
+ factory_kwargs = {"device": device, "dtype": dtype}
+ if layer_idx % 4 != 2: # layer_idx % 4 != 1 and layer_idx % 4 != 3:
+ if layer_idx not in attn_layer_idx:
+ # Create a copy of the config to modify
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+ mixer_cls = partial(
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
+ layer_idx=layer_idx,
+ **ssm_cfg,
+ **factory_kwargs
+ )
+ else:
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+ norm_cls = partial(
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+ )
+ if d_intermediate == 0:
+ mlp_cls = nn.Identity
+ else:
+ mlp_cls = partial(
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
+ )
+ block = Block(
+ d_model,
+ mixer_cls,
+ mlp_cls,
+ norm_cls=norm_cls,
+ fused_add_norm=fused_add_norm,
+ residual_in_fp32=residual_in_fp32,
+ )
+ block.layer_idx = layer_idx
+ else:
+ if layer_idx not in attn_layer_idx:
+ # Create a copy of the config to modify
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+ mixer_cls1 = partial(
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
+ layer_idx=layer_idx,
+ **ssm_cfg,
+ **factory_kwargs
+ )
+ mixer_cls2 = partial(
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
+ layer_idx=layer_idx,
+ **ssm_cfg,
+ **factory_kwargs
+ )
+ else:
+ mixer_cls1 = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+ mixer_cls2 = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+ norm_cls = partial(
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+ )
+ if d_intermediate == 0:
+ mlp_cls = nn.Identity
+ else:
+ mlp_cls = partial(
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
+ )
+ block = DiffBlockPaper(
+ d_model,
+ mixer_cls1,
+ mixer_cls2,
+ mlp_cls,
+ norm_cls=norm_cls,
+ fused_add_norm=fused_add_norm,
+ residual_in_fp32=residual_in_fp32,
+ layer_idx=layer_idx,
+ )
+ block.layer_idx = layer_idx
+ return block
-def create_block(
+def create_regular_block(
d_model,
d_intermediate,
ssm_cfg=None,
@@ -38,8 +134,7 @@ def create_block(
fused_add_norm=False,
layer_idx=None,
device=None,
- dtype=None,
-):
+ dtype=None,):
if ssm_cfg is None:
ssm_cfg = {}
if attn_layer_idx is None:
@@ -81,6 +176,52 @@ def create_block(
block.layer_idx = layer_idx
return block
+def create_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg=None,
+ attn_layer_idx=None,
+ attn_cfg=None,
+ norm_epsilon=1e-5,
+ rms_norm=False,
+ residual_in_fp32=False,
+ fused_add_norm=False,
+ layer_idx=None,
+ device=None,
+ dtype=None,
+ mamba_type="DiffMamba",
+):
+ if mamba_type == "DiffMamba":
+ return create_diff_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg,
+ attn_layer_idx,
+ attn_cfg,
+ norm_epsilon,
+ rms_norm,
+ residual_in_fp32,
+ fused_add_norm,
+ layer_idx,
+ device,
+ dtype
+ )
+ else:
+ return create_regular_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg,
+ attn_layer_idx,
+ attn_cfg,
+ norm_epsilon,
+ rms_norm,
+ residual_in_fp32,
+ fused_add_norm,
+ layer_idx,
+ device,
+ dtype
+ )
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 1bd968a0b..05b3408eb 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -1,10 +1,13 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
-from typing import Optional
+from typing import Optional, Tuple, Type, Callable
import torch
from torch import nn, Tensor
+import torch.nn.functional as F
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
-from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
+from DiffMamba.mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
class Block(nn.Module):
@@ -89,3 +92,269 @@ def forward(
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+
+class DiffBlockPaper(nn.Module):
+ """
+ Diff-Mamba block: Add->Norm -> (mixer1 || mixer2) -> Norm each -> subtract with λ -> Linear -> Norm
+ -> (optional MLP sublayer like vanilla Block)
+
+ Returns (hidden_states, residual) with the SAME contract as mamba_ssm.modules.block.Block:
+ - If no MLP: residual is the pre-norm Add sum, hidden_states is the sublayer output (no add here).
+ - If MLP: we do residual += hidden_states before norm2+MLP, as in vanilla.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mixer_cls1: Callable[[int], nn.Module],
+ mixer_cls2: Callable[[int], nn.Module],
+ mlp_cls: Callable[[int], nn.Module],
+ *,
+ norm_cls: Callable[[int], nn.Module] = nn.LayerNorm,
+ fused_add_norm: bool = False,
+ residual_in_fp32: bool = False,
+ layer_idx: int = 0,
+ use_postscale: bool = False, # optional extra scaling by (1 - lambda_init)
+ lambda_init: Optional[float] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
+ super().__init__()
+ self.d_model = dim
+ self.layer_idx = layer_idx
+ self.fused_add_norm = fused_add_norm
+ self.residual_in_fp32 = residual_in_fp32
+ self.use_postscale = bool(use_postscale)
+
+ # Prenorm for Add->Norm input
+ self.norm = norm_cls(dim)
+
+ # Two parallel mixers
+ self.mixer1 = mixer_cls1(dim)
+ self.mixer2 = mixer_cls2(dim)
+
+ # Post-mixer norms (separate for each branch) and post-sub norm
+ self.subln = norm_cls(dim)
+
+        # Per-layer scalar λ = σ(sum(λ_q1)) + λ_init, with λ_init = 0.8 - 0.6*exp(-0.3*layer_idx)
+ self.lambda_init = 0.8 - 0.6 * torch.exp(torch.tensor(-0.3 * self.layer_idx))
+ self.lambda_q1 = nn.Parameter(torch.randn(self.d_model))
+
+ # Optional second sublayer (MLP) mirrors vanilla Block
+ if mlp_cls is nn.Identity:
+ self.mlp = None
+ else:
+ self.norm2 = norm_cls(dim)
+ self.mlp = mlp_cls(dim)
+
+ if self.fused_add_norm:
+ assert layer_norm_fn is not None, "fused_add_norm=True requires Triton layer_norm_fn"
+ assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)) if RMSNorm is not None else isinstance(self.norm, nn.LayerNorm)
+
+ # -------- cache helper (lets each mixer see its own cache view) --------
+ class _SwapCache:
+ def __init__(self, ip, idx: int, view):
+ self.ip, self.idx, self.view, self.orig = ip, idx, view, None
+ def __enter__(self):
+ if self.ip is not None:
+ self.orig = self.ip.key_value_memory_dict.get(self.idx, None)
+ self.ip.key_value_memory_dict[self.idx] = self.view
+ def __exit__(self, exc_type, exc, tb):
+ if self.ip is not None:
+ if self.orig is None:
+ self.ip.key_value_memory_dict.pop(self.idx, None)
+ else:
+ self.ip.key_value_memory_dict[self.idx] = self.orig
+
+ def _run_mixers(self, x: Tensor, inference_params=None, **mixer_kwargs) -> Tuple[Tensor, Tensor]:
+ # No caching provided: straightforward parallel calls
+ if inference_params is None:
+ y1 = self.mixer1(x, inference_params=None, **mixer_kwargs)
+ y2 = self.mixer2(x, inference_params=None, **mixer_kwargs)
+ return y1, y2
+
+ # Ensure persistent, SEPARATE cache slots for each branch
+ ip = inference_params
+ main_key = self.layer_idx
+ # Use a hidden negative key for branch 2 to avoid collisions with real layer indices
+ branch2_key = -(self.layer_idx + 1)
+
+ # Ensure branch 1 cache exists at main_key
+ slot1 = ip.key_value_memory_dict.get(main_key, None)
+ if slot1 is None:
+ try:
+ bs = x.shape[0]
+ max_seqlen = getattr(ip, "max_seqlen", None)
+ dtype = x.dtype
+ if max_seqlen is not None:
+ ip.key_value_memory_dict[main_key] = self.mixer1.allocate_inference_cache(bs, max_seqlen, dtype=dtype)
+ slot1 = ip.key_value_memory_dict.get(main_key, None)
+ except Exception:
+ slot1 = None
+
+ # Ensure branch 2 cache exists at branch2_key
+ slot2 = ip.key_value_memory_dict.get(branch2_key, None)
+ if slot2 is None:
+ try:
+ bs = x.shape[0]
+ max_seqlen = getattr(ip, "max_seqlen", None)
+ dtype = x.dtype
+ if max_seqlen is not None:
+ ip.key_value_memory_dict[branch2_key] = self.mixer2.allocate_inference_cache(bs, max_seqlen, dtype=dtype)
+ slot2 = ip.key_value_memory_dict.get(branch2_key, None)
+ except Exception:
+ slot2 = None
+
+ # Run branches against their own caches; swap branch2 into the main key during its call
+ y1 = self.mixer1(x, inference_params=ip, **mixer_kwargs)
+ if slot2 is not None:
+ with DiffBlockPaper._SwapCache(ip, main_key, slot2):
+ y2 = self.mixer2(x, inference_params=ip, **mixer_kwargs)
+ else:
+ y2 = self.mixer2(x, inference_params=ip, **mixer_kwargs)
+ return y1, y2
+
+ @staticmethod
+ def _to_norm_dtype(norm: nn.Module, x: Tensor) -> Tensor:
+ w = getattr(norm, "weight", None)
+ return x.to(w.dtype) if isinstance(w, torch.Tensor) else x
+
+ def forward(
+ self,
+ hidden_states: Tensor,
+ residual: Optional[Tensor] = None,
+ inference_params=None,
+ **mixer_kwargs,
+ ) -> Tuple[Tensor, Tensor]:
+ # ---- Add -> Norm (prenorm) ----
+ if not self.fused_add_norm:
+ residual = (hidden_states + residual) if residual is not None else hidden_states
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ else:
+ hidden_states, residual = layer_norm_fn(
+ hidden_states,
+ self.norm.weight,
+ self.norm.bias,
+ residual=residual,
+ prenorm=True,
+ residual_in_fp32=self.residual_in_fp32,
+ eps=self.norm.eps,
+ is_rms_norm=isinstance(self.norm, RMSNorm)
+ )
+
+ # ---- Scalar λ per layer ----
+ lambda_q1 = torch.sum(self.lambda_q1, dim=-1).float()
+ lambda_full = torch.sigmoid(lambda_q1) + self.lambda_init
+
+ # ---- Parallel mixers ----
+ y1, y2 = self._run_mixers(hidden_states, inference_params, **mixer_kwargs)
+
+ # ---- Differential combine -> out proj -> post-sub norm ----
+ attn = y1 - lambda_full * y2
+ attn = self.subln(attn)
+
+        # First sublayer output, rescaled by (1 - λ_init) as in the Differential Transformer
+ hidden_states = attn * (1.0 - self.lambda_init)
+
+
+ # ---- Optional MLP sublayer (mirrors vanilla Block) ----
+ if self.mlp is not None:
+ if not self.fused_add_norm:
+ residual = hidden_states + residual
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ else:
+ hidden_states, residual = layer_norm_fn(
+ hidden_states,
+ self.norm2.weight,
+ self.norm2.bias,
+ residual=residual,
+ prenorm=True,
+ residual_in_fp32=self.residual_in_fp32,
+ eps=self.norm2.eps,
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
+ )
+ hidden_states = self.mlp(hidden_states)
+ else:
+ residual = hidden_states + residual
+
+ return hidden_states, residual
+
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+ # Share a single cache object between both mixers
+ cache_fn = getattr(self.mixer1, "allocate_inference_cache", None)
+ shared = cache_fn(batch_size, max_seqlen, dtype=dtype, **kwargs) if callable(cache_fn) else None
+ return shared
+
+ @classmethod
+ def from_pretrained_block(
+ cls,
+ block: nn.Module,
+ mixer_cls: Optional[Callable[[int], nn.Module]] = None,
+ mlp_cls: Optional[Callable[[int], nn.Module]] = None,
+ norm_cls: Optional[Callable[[int], nn.Module]] = None,
+ fused_add_norm: Optional[bool] = None,
+ residual_in_fp32: Optional[bool] = None,
+ lambda_init: Optional[float] = None,
+ use_postscale: bool = False,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> "DiffBlockPaper":
+ """Build a DiffBlock from a vanilla Block and copy weights into both mixers."""
+ src_mixer = getattr(block, "mixer", None)
+ src_mlp = getattr(block, "mlp", None)
+ src_norm = getattr(block, "norm", None)
+
+ mixer_cls = mixer_cls or (src_mixer.__class__)
+ mlp_cls = mlp_cls or (src_mlp.__class__ if src_mlp is not None else nn.Identity)
+ norm_cls = norm_cls or (src_norm.__class__ if src_norm is not None else nn.LayerNorm)
+
+ fused_add_norm = fused_add_norm if fused_add_norm is not None else getattr(block, "fused_add_norm", False)
+ residual_in_fp32 = residual_in_fp32 if residual_in_fp32 is not None else getattr(block, "residual_in_fp32", False)
+
+ # Try to infer dimension robustly if block doesn't have d_model
+ dim = getattr(block, "d_model", None)
+ if dim is None:
+ if hasattr(block, "norm") and hasattr(block.norm, "weight"):
+ dim = block.norm.weight.shape[0]
+ else:
+ raise ValueError("Cannot infer dim from the provided block")
+
+ newb = cls(
+ dim=dim,
+ mixer_cls1=mixer_cls,
+ mixer_cls2=mixer_cls,
+ mlp_cls=mlp_cls,
+ norm_cls=norm_cls,
+ fused_add_norm=fused_add_norm,
+ residual_in_fp32=residual_in_fp32,
+ layer_idx=getattr(block, "layer_idx", 0),
+ lambda_init=lambda_init,
+ use_postscale=use_postscale,
+ device=device,
+ dtype=dtype,
+ )
+
+ # copy prenorm
+ if src_norm is not None:
+ newb.norm.load_state_dict(src_norm.state_dict(), strict=False)
+ # seed sublayer norm with same stats
+ newb.subln.load_state_dict(newb.norm.state_dict(), strict=False)
+
+ # copy mixer weights into both mixers
+ if src_mixer is not None:
+ st = src_mixer.state_dict()
+ newb.mixer1.load_state_dict(st, strict=False)
+ newb.mixer2.load_state_dict(st, strict=False)
+
+ # copy mlp & norm2 if present
+ if src_mlp is not None and newb.mlp is not None:
+ newb.mlp.load_state_dict(src_mlp.state_dict(), strict=False)
+ if hasattr(block, "norm2") and hasattr(newb, "norm2"):
+ newb.norm2.load_state_dict(block.norm2.state_dict(), strict=False)
+
+ return newb
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index a41f1359c..c51ec40dc 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -8,12 +8,10 @@
try:
from causal_conv1d import causal_conv1d_fn
- from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+ import causal_conv1d_cuda
except ImportError:
causal_conv1d_fn = None
- causal_conv1d_fwd_function = None
- causal_conv1d_bwd_function = None
- causal_conv1d_update_function = None
+ causal_conv1d_cuda = None
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
@@ -192,7 +190,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
"""
xz: (batch, dim, seqlen)
"""
- assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
assert checkpoint_lvl in [0, 1]
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
@@ -208,7 +206,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
- conv1d_out = causal_conv1d_fwd_function(
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
x, conv1d_weight, conv1d_bias, None, None, None, True
)
# We're being very careful here about the layout, to avoid extra transposes.
@@ -281,7 +279,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
@custom_bwd
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
- assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors
L = xz.shape[-1]
@@ -291,7 +289,7 @@ def backward(ctx, dout):
if dout.stride(-1) != 1:
dout = dout.contiguous()
if ctx.checkpoint_lvl == 1:
- conv1d_out = causal_conv1d_fwd_function(
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
x, conv1d_weight, conv1d_bias, None, None, None, True
)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
@@ -357,7 +355,7 @@ def backward(ctx, dout):
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
)
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py
index 959078061..fa5b813a2 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_scan.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -132,8 +132,7 @@ def _chunk_scan_fwd_kernel(
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
- # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
- cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
+ cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
@@ -680,8 +679,7 @@ def _chunk_scan_bwd_dx_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
- # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
- cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
+ cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -818,8 +816,7 @@ def _chunk_scan_bwd_dcb_kernel(
dcb *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
- # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
- dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
+ dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
if HAS_DDA_CS:
tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
ddA_cs = dcb * cb
@@ -1011,8 +1008,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
acc *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
- # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
- acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
+ acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
mask = offs_m[:, None] >= offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
acc = tl.cumsum(acc, axis=1)
@@ -1138,8 +1134,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
acc *= cb
dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
- # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
- acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
+ acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
rowsum_new = rowsum + tl.sum(acc, axis=1)
diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py
index 50838d055..bb49c9a96 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_state.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt_presoftplus = dt
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
+ dt = tl.where(dt <= 20.0, softplus(dt), ddt)
clamp_mask = (dt < dt_min) | (dt > dt_max)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
@@ -229,11 +229,9 @@ def _chunk_state_fwd_kernel(
seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
+ scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
else:
- # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
- scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
+ scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
@@ -334,8 +332,7 @@ def _chunk_state_bwd_dx_kernel(
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
- # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
- acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]
+ acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
@@ -437,11 +434,9 @@ def _chunk_state_bwd_db_kernel(
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp(dA_cs_last - dA_cs_m)
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
+ scale = tl.exp(dA_cs_last - dA_cs_m)
else:
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
db *= (scale * dt_m)[:, None]
if HAS_DDA_CS:
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
@@ -554,13 +549,11 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp(dA_cs_last - dA_cs_m)
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
+ scale = tl.exp(dA_cs_last - dA_cs_m)
else:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
acc *= scale[:, None]
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
@@ -641,10 +634,8 @@ def _chunk_state_varlen_kernel(
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
- # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
- # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
- tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
+ tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py
index bbf4ecf84..58a6e04a9 100644
--- a/mamba_ssm/ops/triton/ssd_combined.py
+++ b/mamba_ssm/ops/triton/ssd_combined.py
@@ -20,12 +20,9 @@
try:
from causal_conv1d import causal_conv1d_fn
- from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+ import causal_conv1d_cuda
except ImportError:
- causal_conv1d_fn = None
- causal_conv1d_fwd_function = None
- causal_conv1d_bwd_function = None
- causal_conv1d_update_function = None
+ causal_conv1d_fn, causal_conv1d_cuda = None, None
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
@@ -50,13 +47,6 @@ def init_to_zero(names):
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
-def rearrange_and_update_stride(tensor, pattern=None, dim=2):
- # ensure tensor.stride(dim) is a multiple of eight after rearranging according to pattern,
- # if not call contiguous(), rearrange only if pattern is not None
- tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
- return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
-
-
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
@@ -130,13 +120,11 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp(dA_cs_last - dA_cs_m)
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
+ scale = tl.exp(dA_cs_last - dA_cs_m)
else:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
# Unexpected mma -> mma layout conversion
@@ -182,8 +170,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
- # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
- cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
+ cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -789,7 +776,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
xBC_conv = rearrange(
- causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
"b d s -> b s d"
)
@@ -863,8 +850,8 @@ def backward(ctx, dout, *args):
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
# Recompute x, B, C
xBC_conv = rearrange(
- causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
- conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+ conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
"b d s -> b s d"
)
x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
@@ -913,14 +900,10 @@ def backward(ctx, dout, *args):
else:
doutproj_weight, doutproj_bias = None, None
dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
- dxBC_given_update, dweight, dbias, *_ = causal_conv1d_bwd_function(
- rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
- rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
+ dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+ rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
+ rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
)
- if dxBC_given.stride() != dxBC_given_update.stride():
- dxBC_given.copy_(dxBC_given_update)
- else:
- dxBC_given = dxBC_given_update
dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
diff --git a/pyproject.toml b/pyproject.toml
index 5831fe66e..ab6315c33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,12 +12,11 @@ license = { file = "LICENSE" } # Include a LICENSE file in your repo
keywords = ["cuda", "pytorch", "state-space model"]
classifiers = [
"Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache Software License",
+ "License :: OSI Approved :: BSD License",
"Operating System :: Unix"
]
dependencies = [
"torch",
- "triton",
"ninja",
"einops",
"transformers",
diff --git a/setup.py b/setup.py
index 32d62ed90..3ee91645f 100755
--- a/setup.py
+++ b/setup.py
@@ -172,29 +172,22 @@ def append_nvcc_threads(nvcc_extra_args):
"Note: make sure nvcc has a supported version by running nvcc -V."
)
- if bare_metal_version <= Version("12.9"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_53,code=sm_53")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_62,code=sm_62")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_70,code=sm_70")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_72,code=sm_72")
cc_flag.append("-gencode")
- cc_flag.append("arch=compute_75,code=sm_75")
+ cc_flag.append("arch=compute_53,code=sm_53")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_62,code=sm_62")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_70,code=sm_70")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_72,code=sm_72")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_87,code=sm_87")
+
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
- if bare_metal_version >= Version("12.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_100,code=sm_100")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_120,code=sm_120")
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
@@ -363,7 +356,7 @@ def run(self):
url="https://github.com/state-spaces/mamba",
classifiers=[
"Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache Software License",
+ "License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
ext_modules=ext_modules,
@@ -378,7 +371,7 @@ def run(self):
"packaging",
"ninja",
"einops",
- "triton",
+ # "triton",
"transformers",
# "causal_conv1d>=1.4.0",
],
diff --git a/usage.md b/usage.md
deleted file mode 100644
index 1b588ce2c..000000000
--- a/usage.md
+++ /dev/null
@@ -1,43 +0,0 @@
-# Mamba adoption
-
-We've been very happy to see Mamba being adopted by many organizations
-and research labs to speed up their training / inference.
-This page contains a partial list of places where Mamba is being used.
-If you'd like to add links to your organization / product / codebase, please open a
-PR or email us. We'd very much like to hear from you!
-
-## Large language models and multi-modal models
-
-- [Tencent's Hunyuan-TurboS (560B)](https://arxiv.org/abs/2505.15431)
-
-- [Nvidia Nemotron-H (8B, 47B, 56B)](https://research.nvidia.com/labs/adlr/nemotronh/)
-
-- [AI21 Jamba (398B)](https://www.ai21.com/blog/announcing-jamba-model-family/)
-
-- [TII Falcon-H1 (34B)](https://falconllm.tii.ae/falcon-h1.html)
-
-- [IBM Bamba (9B)](https://research.ibm.com/blog/bamba-ssm-transformer-model)
-
-- [Mistral's Codestral (7B)](https://mistral.ai/news/codestral-mamba)
-
-- [Nvidia Mamba-2 Hybrid (8B)](https://arxiv.org/abs/2406.07887)
-
-- [Microsoft Samba (4B)](https://arxiv.org/abs/2406.07522v1)
-
-- [TII Falcon-Mamba (7B)](https://falconllm.tii.ae/tii-releases-first-sslm-with-falcon-mamba-7b.html)
-
-## Inference frameworks
-
-- vLLM
-
-- Nvidia's TensorRT-LLM
-
-## Hardware
-
-- Nvidia GPUs
-
-- [AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/mamba/README.html)
-
-- [AWS Trainium 2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/fused_mamba.html)
-
-