From 52908bac916b0b1365d979caa0ed7280d07ca0d3 Mon Sep 17 00:00:00 2001
From: Maxim Melichov <80150303+maxmelichov@users.noreply.github.com>
Date: Thu, 21 Aug 2025 12:29:40 +0300
Subject: [PATCH 01/10] Update mixer_seq_simple.py
---
mamba_ssm/models/mixer_seq_simple.py | 149 ++++++++++++++++++++++++++-
1 file changed, 145 insertions(+), 4 deletions(-)
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index fae2257a9..0db8846fe 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -16,7 +16,7 @@
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block
+from mamba_ssm.modules.block import Block, DiffBlock, DiffBlockPaper
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
@@ -25,8 +25,104 @@
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+def create_diff_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg=None,
+ attn_layer_idx=None,
+ attn_cfg=None,
+ norm_epsilon=1e-5,
+ rms_norm=False,
+ residual_in_fp32=False,
+ fused_add_norm=False,
+ layer_idx=None,
+ device=None,
+ dtype=None):
+ if ssm_cfg is None:
+ ssm_cfg = {}
+ if attn_layer_idx is None:
+ attn_layer_idx = []
+ if attn_cfg is None:
+ attn_cfg = {}
+ factory_kwargs = {"device": device, "dtype": dtype}
+ if layer_idx % 4 != 2:  # layers with layer_idx % 4 == 2 become differential (DiffBlockPaper) blocks; all others stay vanilla
+ if layer_idx not in attn_layer_idx:
+ # Create a copy of the config to modify
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+ mixer_cls = partial(
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
+ layer_idx=layer_idx,
+ **ssm_cfg,
+ **factory_kwargs
+ )
+ else:
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+ norm_cls = partial(
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+ )
+ if d_intermediate == 0:
+ mlp_cls = nn.Identity
+ else:
+ mlp_cls = partial(
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
+ )
+ block = Block(
+ d_model,
+ mixer_cls,
+ mlp_cls,
+ norm_cls=norm_cls,
+ fused_add_norm=fused_add_norm,
+ residual_in_fp32=residual_in_fp32,
+ )
+ block.layer_idx = layer_idx
+ else:
+ if layer_idx not in attn_layer_idx:
+ # Create a copy of the config to modify
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+ mixer_cls1 = partial(
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
+ layer_idx=layer_idx,
+ **ssm_cfg,
+ **factory_kwargs
+ )
+ mixer_cls2 = partial(
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
+ layer_idx=layer_idx,
+ **ssm_cfg,
+ **factory_kwargs
+ )
+ else:
+ mixer_cls1 = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+ mixer_cls2 = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+ norm_cls = partial(
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+ )
+ if d_intermediate == 0:
+ mlp_cls = nn.Identity
+ else:
+ mlp_cls = partial(
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
+ )
+ block = DiffBlockPaper(
+ d_model,
+ mixer_cls1,
+ mixer_cls2,
+ mlp_cls,
+ norm_cls=norm_cls,
+ fused_add_norm=fused_add_norm,
+ residual_in_fp32=residual_in_fp32,
+ layer_idx=layer_idx,
+ )
+ block.layer_idx = layer_idx
+ return block
-def create_block(
+def create_regular_block(
d_model,
d_intermediate,
ssm_cfg=None,
@@ -38,8 +134,7 @@ def create_block(
fused_add_norm=False,
layer_idx=None,
device=None,
- dtype=None,
-):
+ dtype=None,):
if ssm_cfg is None:
ssm_cfg = {}
if attn_layer_idx is None:
@@ -81,6 +176,52 @@ def create_block(
block.layer_idx = layer_idx
return block
+def create_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg=None,
+ attn_layer_idx=None,
+ attn_cfg=None,
+ norm_epsilon=1e-5,
+ rms_norm=False,
+ residual_in_fp32=False,
+ fused_add_norm=False,
+ layer_idx=None,
+ device=None,
+ dtype=None,
+ mamba_type="DiffMamba",
+):
+ if mamba_type == "DiffMamba":
+ return create_diff_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg,
+ attn_layer_idx,
+ attn_cfg,
+ norm_epsilon,
+ rms_norm,
+ residual_in_fp32,
+ fused_add_norm,
+ layer_idx,
+ device,
+ dtype
+ )
+ else:
+ return create_regular_block(
+ d_model,
+ d_intermediate,
+ ssm_cfg,
+ attn_layer_idx,
+ attn_cfg,
+ norm_epsilon,
+ rms_norm,
+ residual_in_fp32,
+ fused_add_norm,
+ layer_idx,
+ device,
+ dtype
+ )
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
From 4ca04479c658ed94cd37dfd06a0e3225a2640371 Mon Sep 17 00:00:00 2001
From: Maxim Melichov <80150303+maxmelichov@users.noreply.github.com>
Date: Thu, 21 Aug 2025 12:30:23 +0300
Subject: [PATCH 02/10] Update block.py
---
mamba_ssm/modules/block.py | 235 ++++++++++++++++++++++++++++++++++++-
1 file changed, 234 insertions(+), 1 deletion(-)
diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 1bd968a0b..607103e9d 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -1,8 +1,11 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
-from typing import Optional
+from typing import Optional, Tuple, Type, Callable
import torch
from torch import nn, Tensor
+import torch.nn.functional as F
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
@@ -89,3 +92,233 @@ def forward(
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+
+class DiffBlockPaper(nn.Module):
+ """
+ Diff-Mamba block: Add -> Norm -> (mixer1 || mixer2) in parallel -> differential combine (y1 - λ·y2) -> Norm -> scale by (1 - λ_init)
+ -> (optional MLP sublayer like vanilla Block)
+
+ Returns (hidden_states, residual) with the SAME contract as mamba_ssm.modules.block.Block:
+ - If no MLP: residual is the pre-norm Add sum, hidden_states is the sublayer output (no add here).
+ - If MLP: we do residual += hidden_states before norm2+MLP, as in vanilla.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mixer_cls1: Callable[[int], nn.Module],
+ mixer_cls2: Callable[[int], nn.Module],
+ mlp_cls: Callable[[int], nn.Module],
+ *,
+ norm_cls: Callable[[int], nn.Module] = nn.LayerNorm,
+ fused_add_norm: bool = False,
+ residual_in_fp32: bool = False,
+ layer_idx: int = 0,
+ use_postscale: bool = False, # reserved for optional extra (1 - lambda_init) scaling; not consulted in forward
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
+ super().__init__()
+ self.d_model = dim
+ self.layer_idx = layer_idx
+ self.fused_add_norm = fused_add_norm
+ self.residual_in_fp32 = residual_in_fp32
+ self.use_postscale = bool(use_postscale)
+
+ # Prenorm for Add->Norm input
+ self.norm = norm_cls(dim)
+
+ # Two parallel mixers
+ self.mixer1 = mixer_cls1(dim)
+ self.mixer2 = mixer_cls2(dim)
+
+ # Norm applied to the differential combination (post-subtraction norm)
+ self.subln = norm_cls(dim)
+
+ # Per-layer scalar λ = σ(sum(λ̄)) + λ_init, with λ_init growing with depth
+ self.lambda_init = 0.8 - 0.6 * torch.exp(torch.tensor(-0.3 * self.layer_idx))
+ self.lambda_q1 = nn.Parameter(torch.randn(self.d_model))
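+ # Illustrative λ_init values from the schedule above: layer 0 -> 0.2, layer 4 -> ~0.62, deep layers -> ~0.8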
+
+ # Optional second sublayer (MLP) mirrors vanilla Block
+ if mlp_cls is nn.Identity:
+ self.mlp = None
+ else:
+ self.norm2 = norm_cls(dim)
+ self.mlp = mlp_cls(dim)
+
+ if self.fused_add_norm:
+ assert layer_norm_fn is not None, "fused_add_norm=True requires Triton layer_norm_fn"
+ assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)) if RMSNorm is not None else isinstance(self.norm, nn.LayerNorm)
+
+ # -------- cache helper (lets each mixer see its own cache view) --------
+ class _SwapCache:
+ def __init__(self, ip, idx: int, view):
+ self.ip, self.idx, self.view, self.orig = ip, idx, view, None
+ def __enter__(self):
+ if self.ip is not None:
+ self.orig = self.ip.key_value_memory_dict.get(self.idx, None)
+ self.ip.key_value_memory_dict[self.idx] = self.view
+ def __exit__(self, exc_type, exc, tb):
+ if self.ip is not None:
+ if self.orig is None:
+ self.ip.key_value_memory_dict.pop(self.idx, None)
+ else:
+ self.ip.key_value_memory_dict[self.idx] = self.orig
+
+ def _run_mixers(self, x: Tensor, inference_params=None, **mixer_kwargs) -> Tuple[Tensor, Tensor]:
+ if inference_params is None:
+ y1 = self.mixer1(x, inference_params=None, **mixer_kwargs)
+ y2 = self.mixer2(x, inference_params=None, **mixer_kwargs)
+ return y1, y2
+
+ slot = inference_params.key_value_memory_dict.get(self.layer_idx, None)
+ if isinstance(slot, tuple) and len(slot) == 2:
+ c1, c2 = slot
+ with DiffBlock._SwapCache(inference_params, self.layer_idx, c1):
+ y1 = self.mixer1(x, inference_params=inference_params, **mixer_kwargs)
+ with DiffBlock._SwapCache(inference_params, self.layer_idx, c2):
+ y2 = self.mixer2(x, inference_params=inference_params, **mixer_kwargs)
+ else:
+ y1 = self.mixer1(x, inference_params=inference_params, **mixer_kwargs)
+ y2 = self.mixer2(x, inference_params=inference_params, **mixer_kwargs)
+ return y1, y2
+
+ @staticmethod
+ def _to_norm_dtype(norm: nn.Module, x: Tensor) -> Tensor:
+ w = getattr(norm, "weight", None)
+ return x.to(w.dtype) if isinstance(w, torch.Tensor) else x
+
+ def forward(
+ self,
+ hidden_states: Tensor,
+ residual: Optional[Tensor] = None,
+ inference_params=None,
+ **mixer_kwargs,
+ ) -> Tuple[Tensor, Tensor]:
+ # ---- Add -> Norm (prenorm) ----
+ if not self.fused_add_norm:
+ residual = (hidden_states + residual) if residual is not None else hidden_states
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ else:
+ hidden_states, residual = layer_norm_fn(
+ hidden_states,
+ self.norm.weight,
+ self.norm.bias,
+ residual=residual,
+ prenorm=True,
+ residual_in_fp32=self.residual_in_fp32,
+ eps=self.norm.eps,
+ is_rms_norm=isinstance(self.norm, RMSNorm)
+ )
+
+ # ---- Scalar λ per layer ----
+ lambda_q1 = torch.sum(self.lambda_q1, dim=-1).float()
+ lambda_full = torch.sigmoid(lambda_q1) + self.lambda_init
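+ # e.g. (illustrative) at layer 0 with sum(λ̄) = 0: λ = σ(0) + 0.2 = 0.7, and the (1 - λ_init) scale below is 0.8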
+
+ # ---- Parallel mixers ----
+ y1, y2 = self._run_mixers(hidden_states, inference_params, **mixer_kwargs)
+
+ # ---- Differential combine -> out proj -> post-sub norm ----
+ attn = y1 - lambda_full * y2
+ attn = self.subln(attn)
+
+ # First sublayer output
+ hidden_states = attn * (1.0 - self.lambda_init)
+
+
+ # ---- Optional MLP sublayer (mirrors vanilla Block) ----
+ if self.mlp is not None:
+ if not self.fused_add_norm:
+ residual = hidden_states + residual
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+ if self.residual_in_fp32:
+ residual = residual.to(torch.float32)
+ else:
+ hidden_states, residual = layer_norm_fn(
+ hidden_states,
+ self.norm2.weight,
+ self.norm2.bias,
+ residual=residual,
+ prenorm=True,
+ residual_in_fp32=self.residual_in_fp32,
+ eps=self.norm2.eps,
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
+ )
+ hidden_states = self.mlp(hidden_states)
+ else:
+ residual = hidden_states + residual
+
+ return hidden_states, residual
+
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
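+ # Returns a (cache1, cache2) tuple so each mixer keeps its own decoding state;
+ # _run_mixers later swaps each view into key_value_memory_dict[self.layer_idx] via _SwapCache.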
+ cache1 = getattr(self.mixer1, "allocate_inference_cache", None)
+ cache2 = getattr(self.mixer2, "allocate_inference_cache", None)
+ c1 = cache1(batch_size, max_seqlen, dtype=dtype, **kwargs) if callable(cache1) else None
+ c2 = cache2(batch_size, max_seqlen, dtype=dtype, **kwargs) if callable(cache2) else None
+ return (c1, c2)
+
+ @classmethod
+ def from_pretrained_block(
+ cls,
+ block: nn.Module,
+ mixer_cls: Optional[Callable[[int], nn.Module]] = None,
+ mlp_cls: Optional[Callable[[int], nn.Module]] = None,
+ norm_cls: Optional[Callable[[int], nn.Module]] = None,
+ fused_add_norm: Optional[bool] = None,
+ residual_in_fp32: Optional[bool] = None,
+ use_postscale: bool = False,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> "DiffBlock":
+ """Build a DiffBlock from a vanilla Block and copy weights into both mixers."""
+ src_mixer = getattr(block, "mixer", None)
+ src_mlp = getattr(block, "mlp", None)
+ src_norm = getattr(block, "norm", None)
+
+ mixer_cls = mixer_cls or (src_mixer.__class__)
+ mlp_cls = mlp_cls or (src_mlp.__class__ if src_mlp is not None else nn.Identity)
+ norm_cls = norm_cls or (src_norm.__class__ if src_norm is not None else nn.LayerNorm)
+
+ fused_add_norm = fused_add_norm if fused_add_norm is not None else getattr(block, "fused_add_norm", False)
+ residual_in_fp32 = residual_in_fp32 if residual_in_fp32 is not None else getattr(block, "residual_in_fp32", False)
+
+ newb = cls(
+ dim=block.d_model,
+ mixer_cls1=mixer_cls,
+ mixer_cls2=mixer_cls,
+ mlp_cls=mlp_cls,
+ norm_cls=norm_cls,
+ fused_add_norm=fused_add_norm,
+ residual_in_fp32=residual_in_fp32,
+ layer_idx=getattr(block, "layer_idx", 0),
+ use_postscale=use_postscale,
+ device=device,
+ dtype=dtype,
+ )
+
+ # copy prenorm
+ if src_norm is not None:
+ newb.norm.load_state_dict(src_norm.state_dict(), strict=False)
+ # seed the post-subtraction norm with the same stats
+ newb.subln.load_state_dict(newb.norm.state_dict(), strict=False)
+
+ # copy mixer weights into both mixers
+ if src_mixer is not None:
+ st = src_mixer.state_dict()
+ newb.mixer1.load_state_dict(st, strict=False)
+ newb.mixer2.load_state_dict(st, strict=False)
+
+ # copy mlp & norm2 if present
+ if src_mlp is not None and newb.mlp is not None:
+ newb.mlp.load_state_dict(src_mlp.state_dict(), strict=False)
+ if hasattr(block, "norm2") and hasattr(newb, "norm2"):
+ newb.norm2.load_state_dict(block.norm2.state_dict(), strict=False)
+
+ return newb
From ee3fd0918f7258dbd316c357763e304855bf4c50 Mon Sep 17 00:00:00 2001
From: Maxim Melichov <80150303+maxmelichov@users.noreply.github.com>
Date: Thu, 21 Aug 2025 12:31:52 +0300
Subject: [PATCH 03/10] Update README.md
---
README.md | 201 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 201 insertions(+)
diff --git a/README.md b/README.md
index bf3b76a93..92dbb9cd1 100755
--- a/README.md
+++ b/README.md
@@ -1,3 +1,204 @@
+# Differential Mamba
+
+
+
+Nadav Schneider,
+Itamar Zimerman,
+Eliya Nachmani
+
+
+
+This repository contains the official PyTorch implementation of the Differential Mamba paper.
+We also provide training code, evaluation code, and model checkpoints to reproduce the results in the paper, including all the baselines.
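+
+At its core, each Diff-Mamba block runs two Mamba mixers in parallel and subtracts one from the other with a learned per-layer scalar λ, i.e. y = Norm(y1 - λ·y2)·(1 - λ_init). The snippet below is a minimal, illustrative sketch of that combination step only; the actual implementation is ```DiffBlockPaper``` in ```mamba_ssm/modules/block.py```, and the names here are placeholders:
+
+```
+import torch
+
+def diff_combine(y1, y2, lambda_bar, lambda_init, norm):
+    # per-layer scalar: lambda = sigmoid(sum(lambda_bar)) + lambda_init
+    lam = torch.sigmoid(lambda_bar.sum()) + lambda_init
+    return norm(y1 - lam * y2) * (1.0 - lambda_init)
+```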
+
+
+
+
+
+
+# Setup
+## Clone Project
+```
+git clone https://github.com/nadavsc/Diff-Mamba.git
+cd Diff-Mamba
+```
+
+## Create Environment
+To set up our environment, please run:
+```
+conda env create -f environment.yml
+conda activate diffmamba
+```
+Note: this should include all the necessary packages to run all the training and evaluation scripts. Nonetheless, make sure the additional requirements are satisfied:
+
+
+Mamba Installation:
+```
+pip install causal-conv1d==1.5.0
+pip install mamba-ssm==2.2.4
+```
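+
+Once installed, a differential block can be built directly from the patched ```mamba_ssm``` package. The following is a minimal sketch with toy sizes (illustrative only; it assumes a CUDA-enabled ```mamba-ssm``` build and that the patched package is importable as ```mamba_ssm```):
+
+```
+import torch
+from mamba_ssm.models.mixer_seq_simple import create_block
+
+# layer_idx % 4 == 2 selects the differential (DiffBlockPaper) path; other indices build a vanilla Block
+block = create_block(d_model=64, d_intermediate=0, layer_idx=2, mamba_type="DiffMamba").cuda()
+x = torch.randn(2, 16, 64, device="cuda")   # (batch, seqlen, d_model)
+hidden, residual = block(x)                 # same (hidden, residual) contract as the vanilla Block
+```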
+
+## Additional Requirements - Language Modeling
+
+Install the requirements in: https://github.com/state-spaces/s4
+
+To train or evaluate on the language modeling task, first download the data using the following scripts:
+```
+python language_modeling/src/data/datasets/get_wt103.py
+bash language_modeling/src/data/transformer-xl/enwik8/get_enwik8.sh
+bash language_modeling/src/data/transformer-xl/text8/get_text8.sh
+```
+Then, move the resulting datasets into the language_modeling/data directory.
+
+## Additional Requirements - Retrieval
+
+Install the requirements in: https://github.com/booydar/babilong
+
+To fine-tune on PG19, please make sure to download the dataset according to the instructions at [deepmind/pg19](https://huggingface.co/datasets/deepmind/pg19) or use the Huggingface dataset version.
+
+## Additional Requirements - Tuned-Lens
+
+Install the requirements in: https://github.com/AlignmentResearch/tuned-lens
+
+Make sure to download The Pile validation set to train the lens.
+Place the .json or .txt file in the tuned-lens/data directory.
+
+
+
+# Experiments
+## Language Modeling
+Run ```cd language_modeling```.
+Then, run the following:
+```
+python train.py experiment=lm/diffmamba2-text8 trainer.devices=[0] model.dropout=0.5 loader.l_max=512 train.seed=0 trainer.accumulate_grad_batches=1 loader.batch_size=50 model.n_layers=12 model.d_model=1024 trainer.max_epochs=40 trainer.precision=32
+```
+
+```trainer.devices```: selects the GPUs for training. [0] uses cuda:0 and [2] uses cuda:2; [0, 2] uses cuda:0 and cuda:2 with DDP training, while 2 picks the first two available GPUs (cuda:0 and cuda:1).
+
+```loader.l_max```: the maximum sequence length (context window) for training
+
+```model.n_layers```: sets the model depth (number of layers)
+
+```optimizer.lr```: overrides the learning rate; otherwise the default is used
+
+```trainer.max_epochs```: number of training epochs
+
+```loader.batch_size```: the batch size
+
+```model.dropout```: the dropout rate of the model
+
+```trainer.seed```: sets the training seed
+
+```accumulate_grad_batches```: use when GPU memory is insufficient for the desired batch size
+
+
+## Retrieval
+
+
+Run ```cd retrieval```.
+To evaluate the models, make sure to save the model checkpoints in the Diff-Mamba/outputs directory.
+
+### Finetune PG19
+To finetune Mamba on PG19 run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs/mamba2-370m-pg19-finetune
+```
+To finetune Diff-Mamba on PG19 run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs
+```
+
+### Finetune BABILong
+To finetune Mamba on BABILong run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/mamba2-370m-pg19-finetune --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/mamba2-370m-needle-finetune
+```
+To finetune Diff-Mamba on BABILong run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/diffmamba2-370m-pg19-finetune --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/diffmamba2-370m-needle-finetune
+```
+
+```--nproc_per_node```: number of GPUs to use for DDP training
+
+```--grad_accum_steps```: increases the effective batch size under memory limitations
+
+```--diffmamba```: flag that must be set when training Diff-Mamba
+
+```--model_id```: the pretrained Mamba model to load from Hugging Face
+
+### Evaluate
+
+To evaluate a model on the different tasks and context lengths run:
+
+```
+bash scripts/run_activation-beacon-diffmamba2-370m-needle-finetune-seed0_no_instruct.sh
+```
+or
+```
+bash scripts/run_activation-beacon-diffmamba2-370m_pg19-finetune_no_instruct.sh
+```
+Results will be saved in the directory scripts/babilong_evals.
+
+### Plot
+To plot the scores, simply run:
+```
+python plot.py --model_name diffmamba2-370m-needle-finetune-seed0 --results_folder scripts/babilong_evals/diffmamba2-370m-needle-finetune-seed0
+```
+To plot the relative percentage run:
+```
+python plot_compare.py --model_name diffmamba2-370m-needle-finetune --ratio
+```
+The plot will be saved in scripts/babilong_evals. Use the ```--ratio``` flag for the relative percentage plot, or omit it for the original scores plot.
+
+## Tuned-Lens
+
+
+Run ```cd tuned-lens```.
+### Training Lens
+Then, to train a lens for Mamba, run:
+```
+python -m tuned_lens train --model.name ../../../outputs/mamba2-370m-pg19-finetune --data.name data/valid.txt --per_gpu_batch_size=1 --ssm --output my_lenses/mamba2-370m-pg19-finetune
+```
+To train a lens for Diff-Mamba, specify the correct path to the model and change the output directory accordingly.
+To train the lens in a distributed fashion, change ```--per_gpu_batch_size``` to the number of available GPUs.
+
+### Evaluate
+To evaluate run:
+```
+python test_babilong_0k.py --ckpt_path ../../../outputs/mamba2-370m-needle-finetune
+```
+Add the ```--diffmamba``` flag if using Diff-Mamba.
+
+You can stop the test early by using the ```--num_examples``` flag. The compatible lens will be loaded from the my_lenses directory.
+
+### Plot
+To plot the results run:
+```
+python plot_tuned_lens.py --diff_results_path results/diffmamba2-370m-needle-finetune-lens_eval.txt --mamba_results_path results/mamba2-370m-needle-finetune-lens_eval.txt
+```
+Use ```--log``` to create a log scale plot and ```--start-layer``` and ```--end-layer``` to choose specific layers to plot.
+
+## Acknowledgements
+
+All model implementations are based on [Mamba](https://github.com/state-spaces/mamba). Training and evaluation for the language modeling experiments are based on the [S4](https://github.com/state-spaces/s4) repository. Evaluation on BABILong is based on the [BABILong](https://github.com/booydar/babilong) repo, and measuring the signal-to-noise ratio across layers is based on [tuned-lens](https://github.com/AlignmentResearch/tuned-lens).
+
+## Citation
+
+If you use this code, please consider citing the following:
+
+```
+@misc{schneider2025differentialmamba,
+ title={Differential Mamba},
+ author={Nadav Schneider and Itamar Zimerman and Eliya Nachmani},
+ year={2025},
+ eprint={2507.06204},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2507.06204},
+}
+```
+
+
# Mamba

From 421c9dbf1ddf4608417299d15b130a206a867077 Mon Sep 17 00:00:00 2001
From: Maxim Melichov <80150303+maxmelichov@users.noreply.github.com>
Date: Sat, 23 Aug 2025 23:15:47 +0300
Subject: [PATCH 04/10] Update block.py
---
mamba_ssm/modules/block.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 607103e9d..54493dbd5 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -175,9 +175,9 @@ def _run_mixers(self, x: Tensor, inference_params=None, **mixer_kwargs) -> Tuple
slot = inference_params.key_value_memory_dict.get(self.layer_idx, None)
if isinstance(slot, tuple) and len(slot) == 2:
c1, c2 = slot
- with DiffBlock._SwapCache(inference_params, self.layer_idx, c1):
+ with DiffBlockPaper._SwapCache(inference_params, self.layer_idx, c1):
y1 = self.mixer1(x, inference_params=inference_params, **mixer_kwargs)
- with DiffBlock._SwapCache(inference_params, self.layer_idx, c2):
+ with DiffBlockPaper._SwapCache(inference_params, self.layer_idx, c2):
y2 = self.mixer2(x, inference_params=inference_params, **mixer_kwargs)
else:
y1 = self.mixer1(x, inference_params=inference_params, **mixer_kwargs)
@@ -273,7 +273,7 @@ def from_pretrained_block(
use_postscale: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
- ) -> "DiffBlock":
+ ) -> "DiffBlockPaper":
"""Build a DiffBlock from a vanilla Block and copy weights into both mixers."""
src_mixer = getattr(block, "mixer", None)
src_mlp = getattr(block, "mlp", None)
From 9e29bd6def8567b28a2ef613a63d48274c6d5170 Mon Sep 17 00:00:00 2001
From: Maxim Melichov <80150303+maxmelichov@users.noreply.github.com>
Date: Sat, 23 Aug 2025 23:57:07 +0300
Subject: [PATCH 05/10] Remove unused DiffBlock import
---
mamba_ssm/models/mixer_seq_simple.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 0db8846fe..2552c9f69 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -16,7 +16,7 @@
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block, DiffBlock, DiffBlockPaper
+from mamba_ssm.modules.block import Block, DiffBlockPaper
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
From 7d6b9a6d08f290e15c8479d34a7d61cf261a6d1b Mon Sep 17 00:00:00 2001
From: maxmelichov
Date: Mon, 25 Aug 2025 10:35:51 +0300
Subject: [PATCH 06/10] Remove MANIFEST.in and usage.md; update license in
pyproject.toml and setup.py to BSD; adjust CUDA architecture flags in
setup.py; update version to 2.2.4 in __init__.py; modify import paths in
various files to reflect new package structure.
---
.github/workflows/publish.yaml | 24 ++-
MANIFEST.in | 3 -
README.md | 201 ----------------------
mamba_ssm/__init__.py | 2 +-
mamba_ssm/models/mixer_seq_simple.py | 18 +-
mamba_ssm/modules/block.py | 2 +-
mamba_ssm/ops/selective_scan_interface.py | 16 +-
mamba_ssm/ops/triton/ssd_chunk_scan.py | 15 +-
mamba_ssm/ops/triton/ssd_chunk_state.py | 27 +--
mamba_ssm/ops/triton/ssd_combined.py | 39 ++---
pyproject.toml | 3 +-
setup.py | 27 ++-
usage.md | 43 -----
13 files changed, 71 insertions(+), 349 deletions(-)
delete mode 100644 MANIFEST.in
delete mode 100644 usage.md
diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index a0b09fa78..192f95626 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -40,12 +40,12 @@ jobs:
strategy:
fail-fast: false
matrix:
- # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
+ # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
- os: [ubuntu-22.04]
+ os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
- torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1']
- cuda-version: ['11.8.0', '12.9.1']
+ torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
+ cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -53,6 +53,16 @@ jobs:
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+ # Pytorch < 2.2 does not support Python 3.12
+ - torch-version: '2.1.2'
+ python-version: '3.12'
+ # Pytorch < 2.5 does not support Python 3.13
+ - torch-version: '2.1.2'
+ python-version: '3.13'
+ - torch-version: '2.2.2'
+ python-version: '3.13'
+ - torch-version: '2.3.1'
+ python-version: '3.13'
- torch-version: '2.4.0'
python-version: '3.13'
@@ -89,7 +99,7 @@ jobs:
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
- uses: Jimver/cuda-toolkit@v0.2.26
+ uses: Jimver/cuda-toolkit@v0.2.19
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
@@ -111,8 +121,8 @@ jobs:
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
- minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118}[env['MATRIX_TORCH_VERSION']]; \
- maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128}[env['MATRIX_TORCH_VERSION']]; \
+ minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+ maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 318499d57..000000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,3 +0,0 @@
-recursive-include csrc *
-recursive-include csrc *
-README.md
diff --git a/README.md b/README.md
index 92dbb9cd1..bf3b76a93 100755
--- a/README.md
+++ b/README.md
@@ -1,204 +1,3 @@
-# Differential Mamba
-
-
-
-Nadav Schneider,
-Itamar Zimerman,
-Eliya Nachmani
-
-
-
-This repository contains the official PyTorch implementation of Differential Mamba paper.
-We also provide training code, evaluation code, and model checkpoints to reproduce the results in the paper, including all the baselines.
-
-
-
-
-
-
-# Setup
-## Clone Project
-```
-git clone https://github.com/nadavsc/Diff-Mamba.git
-cd Diff-Mamba
-```
-
-## Create Environment
-To set up our environment, please run:
-```
-conda env create -f environment.yml
-conda activate diffmamba
-```
-Note: this should include all the necessary packages to run all the training and evaluation scripts. Nonetheless, make sure the additional requirements are satisfied:
-
-
-Mamba Installation:
-```
-pip install causal-conv1d==1.5.0
-pip install mamba-ssm==2.2.4
-```
-
-## Additional Requirements - Language Modeling
-
-Install the requirements in: https://github.com/state-spaces/s4
-
-In order to train/evaluate the Language Modeling task, first, download the data. This can be done using the following scripts:
-```
-python language_modeling/src/data/datasets/get_wt103.py
-bash language_modeling/src/data/transformer-xl/enwik8/get_enwik8.sh
-bash language_modeling/src/data/transformer-xl/text8/get_text8.sh
-```
-Then, move the resulting datasets into language_modeling/data directory.
-
-## Additional Requirements - Retrieval
-
-Install the requirements in: https://github.com/booydar/babilong
-
-To fine-tune on PG19, please make sure to download the dataset according to the instructions at [deepmind/pg19](https://huggingface.co/datasets/deepmind/pg19) or use the Huggingface dataset version.
-
-## Additional Requirements - Tuned-Lens
-
-Install the requirements in: https://github.com/AlignmentResearch/tuned-lens
-
-Make sure to download The-Pile validation set to train the lens.
-Locate the .json or .txt file in the directory tuned-lens/data.
-
-
-
-# Experiments
-## Language Modeling
-Run cd language_modeling.
-Then, run the following:
-```
-python train.py experiment=lm/diffmamba2-text8 trainer.devices=[0] model.dropout=0.5 loader.l_max=512 train.seed=0 trainer.accumulate_grad_batches=1 loader.batch_size=50 model.n_layers=12 model.d_model=1024 trainer.max_epochs=40 trainer.precision=32
-```
-
-```trainer.devices```: used to determine the GPUs for training. [0] use cuda:0 while [2] use cuda:2. [0, 2] will use cuda:0 and cuda:2 with DDP training, while 2 will choose the first two gpus available (cuda:0 and cuda:1).
-
-```loader.l_max```: the max length or context window for the current training
-
-```model.n_layers```: determine the model size
-
-```optimizer.lr```: to change the learning rate, otherwise, use the default
-
-```trainer.max_epochs```: number of epochs
-
-```loader.batch_size```: represent the batch size
-
-```model.dropout```: the dropout of the current model
-
-```trainer.seed```: responsible of the training seed
-
-```accumulate_grad_batches```: can be used if the memory in the GPU is not sufficient for the required batch size
-
-
-## Retrieval
-
-
-Run cd retrieval.
-To evaluate the models, make sure to save the models checkpoints in the Diff-Mamba/outputs directory.
-
-### Finetune PG19
-To finetune Mamba on PG19 run:
-```
-torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs/mamba2-370m-pg19-finetune
-```
-To finetune Diff-Mamba on PG19 run:
-```
-torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs
-```
-
-### Finetune BABILong
-To finetune Mamba on BABILong run:
-```
-torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/mamba2-370m-pg19-finetune --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/mamba2-370m-needle-finetune
-```
-To finetune Diff-Mamba on BABILong run:
-```
-torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/diffmamba2-370m-pg19-finetune --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/diffmamba2-370m-needle-finetune
-```
-
-```--nproc_per_node```: choose number of GPUs for DDP training
-
-```--grad_accum_steps```: this variable is used to increase effective batch size under memory limitations
-
-```--diffmamba```: this is a flag that has to be chosen when training Diff-Mamba
-
-```--model_id```: this is the mamba pretrained model loaded from Huggingface
-
-### Evaluate
-
-To evaluate a model on the different tasks and context lengths run:
-
-```
-bash scripts/run_activation-beacon-diffmamba2-370m-needle-finetune-seed0_no_instruct.sh
-```
-or
-```
-bash scripts/run_activation-beacon-diffmamba2-370m_pg19-finetune_no_instruct.sh
-```
-Results will be saved in the directory scripts/babilong_evals.
-
-### Plot
-To plot the scores, simply run:
-```
-python plot.py --model_name diffmamba2-370m-needle-finetune-seed0 --results_folder scripts/babilong_evals/diffmamba2-370m-needle-finetune-seed0
-```
-To plot the relative percentage run:
-```
-python plot_compare.py --model_name diffmamba2-370m-needle-finetune --ratio
-```
-The plot will be saved in scripts/babilong_evals. Use the flag ```--ratio``` for the relative precentage plot or omit it for the original scores plot
-
-## Tuned-Lens
-
-
-Run cd tuned-lens.
-### Training Lens
-Then to train lens for mamba, run:
-```
-python -m tuned_lens train --model.name ../../../outputs/mamba2-370m-pg19-finetune --data.name data/valid.txt --per_gpu_batch_size=1 --ssm --output my_lenses/mamba2-370m-pg19-finetune
-```
-To train diffmamba, specify the correct path to the model and change the required output directory.
-To train the lens in a distributed fashion, change ```--per_gpu_batch_size``` to the number of available GPUs.
-
-### Evaluate
-To evaluate run:
-```
-python test_babilong_0k.py --ckpt_path ../../../outputs/mamba2-370m-needle-finetune
-```
-add ```--diffmamba``` flag if using Diff-Mamba.
-
-You can stop the test early when using the flag ```--num_examples```. The compatible lens will be loaded from the my_lenses directory.
-
-### Plot
-To plot the results run:
-```
-python plot_tuned_lens.py --diff_results_path results/diffmamba2-370m-needle-finetune-lens_eval.txt --mamba_results_path results/mamba2-370m-needle-finetune-lens_eval.txt
-```
-Use ```--log``` to create a log scale plot and ```--start-layer``` and ```--end-layer``` to choose specific layers to plot.
-
-## Acknowledgements
-
-All model implementations are based on [Mamba](https://github.com/state-spaces/mamba). Training and evaluation for the language modeling experiments are based on [S4](https://github.com/state-spaces/s4) repository. Evaluation on BABILong is based on [BABILong](https://github.com/booydar/babilong) repo, and measuring signal-to-noise ratio through the layers is based on [tuned-lens](https://github.com/AlignmentResearch/tuned-lens).
-
-## Citation
-
-If you use this code, please consider citing the following:
-
-```
-@misc{schneider2025differentialmamba,
- title={Differential Mamba},
- author={Nadav Schneider and Itamar Zimerman and Eliya Nachmani},
- year={2025},
- eprint={2507.06204},
- archivePrefix={arXiv},
- primaryClass={cs.LG},
- url={https://arxiv.org/abs/2507.06204},
-}
-```
-
-
# Mamba

diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py
index 6280931e4..ac4f6e311 100644
--- a/mamba_ssm/__init__.py
+++ b/mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.2.5"
+__version__ = "2.2.4"
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 2552c9f69..7a22a2b71 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -11,17 +11,17 @@
import torch
import torch.nn as nn
-from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.modules.mamba_simple import Mamba
-from mamba_ssm.modules.mamba2 import Mamba2
-from mamba_ssm.modules.mha import MHA
-from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block, DiffBlockPaper
-from mamba_ssm.utils.generation import GenerationMixin
-from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+from DiffMamba.mamba_ssm.models.config_mamba import MambaConfig
+from DiffMamba.mamba_ssm.modules.mamba_simple import Mamba
+from DiffMamba.mamba_ssm.modules.mamba2 import Mamba2
+from DiffMamba.mamba_ssm.modules.mha import MHA
+from DiffMamba.mamba_ssm.modules.mlp import GatedMLP
+from DiffMamba.mamba_ssm.modules.block import Block, DiffBlockPaper
+from DiffMamba.mamba_ssm.utils.generation import GenerationMixin
+from DiffMamba.mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
try:
- from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
+ from DiffMamba.mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 54493dbd5..1a7e06340 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -7,7 +7,7 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
-from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
+from DiffMamba.mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
class Block(nn.Module):
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index a41f1359c..c51ec40dc 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -8,12 +8,10 @@
try:
from causal_conv1d import causal_conv1d_fn
- from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+ import causal_conv1d_cuda
except ImportError:
causal_conv1d_fn = None
- causal_conv1d_fwd_function = None
- causal_conv1d_bwd_function = None
- causal_conv1d_update_function = None
+ causal_conv1d_cuda = None
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
@@ -192,7 +190,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
"""
xz: (batch, dim, seqlen)
"""
- assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
assert checkpoint_lvl in [0, 1]
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
@@ -208,7 +206,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
- conv1d_out = causal_conv1d_fwd_function(
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
x, conv1d_weight, conv1d_bias, None, None, None, True
)
# We're being very careful here about the layout, to avoid extra transposes.
@@ -281,7 +279,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
@custom_bwd
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
- assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors
L = xz.shape[-1]
@@ -291,7 +289,7 @@ def backward(ctx, dout):
if dout.stride(-1) != 1:
dout = dout.contiguous()
if ctx.checkpoint_lvl == 1:
- conv1d_out = causal_conv1d_fwd_function(
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
x, conv1d_weight, conv1d_bias, None, None, None, True
)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
@@ -357,7 +355,7 @@ def backward(ctx, dout):
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
- dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
)
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py
index 959078061..fa5b813a2 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_scan.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -132,8 +132,7 @@ def _chunk_scan_fwd_kernel(
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
- # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
- cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
+ cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
@@ -680,8 +679,7 @@ def _chunk_scan_bwd_dx_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
- # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
- cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
+ cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -818,8 +816,7 @@ def _chunk_scan_bwd_dcb_kernel(
dcb *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
- # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
- dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
+ dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
if HAS_DDA_CS:
tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
ddA_cs = dcb * cb
@@ -1011,8 +1008,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
acc *= dt_n
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
- # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
- acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
+ acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
mask = offs_m[:, None] >= offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
acc = tl.cumsum(acc, axis=1)
@@ -1138,8 +1134,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
acc *= cb
dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
- # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
- acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
+ acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
acc = tl.where(mask, acc, 0.0)
rowsum_new = rowsum + tl.sum(acc, axis=1)
diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py
index 50838d055..bb49c9a96 100644
--- a/mamba_ssm/ops/triton/ssd_chunk_state.py
+++ b/mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -141,7 +141,7 @@ def _chunk_cumsum_bwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt_presoftplus = dt
- dt = tl.where(dt <= 20.0, softplus(dt), dt)
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
clamp_mask = (dt < dt_min) | (dt > dt_max)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
@@ -229,11 +229,9 @@ def _chunk_state_fwd_kernel(
seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
+ scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
else:
- # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
- scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
+ scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
@@ -334,8 +332,7 @@ def _chunk_state_bwd_dx_kernel(
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
- # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
- acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]
+ acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
@@ -437,11 +434,9 @@ def _chunk_state_bwd_db_kernel(
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp(dA_cs_last - dA_cs_m)
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
+ scale = tl.exp(dA_cs_last - dA_cs_m)
else:
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
db *= (scale * dt_m)[:, None]
if HAS_DDA_CS:
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
@@ -554,13 +549,11 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp(dA_cs_last - dA_cs_m)
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
+ scale = tl.exp(dA_cs_last - dA_cs_m)
else:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
acc *= scale[:, None]
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
@@ -641,10 +634,8 @@ def _chunk_state_varlen_kernel(
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
- # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
- # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
- tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
+ tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py
index bbf4ecf84..58a6e04a9 100644
--- a/mamba_ssm/ops/triton/ssd_combined.py
+++ b/mamba_ssm/ops/triton/ssd_combined.py
@@ -20,12 +20,9 @@
try:
from causal_conv1d import causal_conv1d_fn
- from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
+ import causal_conv1d_cuda
except ImportError:
- causal_conv1d_fn = None
- causal_conv1d_fwd_function = None
- causal_conv1d_bwd_function = None
- causal_conv1d_update_function = None
+ causal_conv1d_fn, causal_conv1d_cuda = None, None
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
@@ -50,13 +47,6 @@ def init_to_zero(names):
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
-def rearrange_and_update_stride(tensor, pattern=None, dim=2):
- # ensure tensor.stride(dim) is a multiple of eight after rearranging according to pattern,
- # if not call contiguous(), rearrange only if pattern is not None
- tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
- return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
-
-
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
@@ -130,13 +120,11 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
if not HAS_SEQ_IDX:
- # scale = tl.exp(dA_cs_last - dA_cs_m)
- scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
+ scale = tl.exp(dA_cs_last - dA_cs_m)
else:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
- # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
- scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
# Unexpected mma -> mma layout conversion
@@ -182,8 +170,7 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
- # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
- cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
+ cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
@@ -789,7 +776,7 @@ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
xBC_conv = rearrange(
- causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
"b d s -> b s d"
)
@@ -863,8 +850,8 @@ def backward(ctx, dout, *args):
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
# Recompute x, B, C
xBC_conv = rearrange(
- causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
- conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
+ conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
"b d s -> b s d"
)
x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
@@ -913,14 +900,10 @@ def backward(ctx, dout, *args):
else:
doutproj_weight, doutproj_bias = None, None
dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
- dxBC_given_update, dweight, dbias, *_ = causal_conv1d_bwd_function(
- rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
- rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
+ dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+ rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
+ rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
)
- if dxBC_given.stride() != dxBC_given_update.stride():
- dxBC_given.copy_(dxBC_given_update)
- else:
- dxBC_given = dxBC_given_update
dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
diff --git a/pyproject.toml b/pyproject.toml
index 5831fe66e..ab6315c33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,12 +12,11 @@ license = { file = "LICENSE" } # Include a LICENSE file in your repo
keywords = ["cuda", "pytorch", "state-space model"]
classifiers = [
"Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache Software License",
+ "License :: OSI Approved :: BSD License",
"Operating System :: Unix"
]
dependencies = [
"torch",
- "triton",
"ninja",
"einops",
"transformers",
diff --git a/setup.py b/setup.py
index 32d62ed90..3ee91645f 100755
--- a/setup.py
+++ b/setup.py
@@ -172,29 +172,22 @@ def append_nvcc_threads(nvcc_extra_args):
"Note: make sure nvcc has a supported version by running nvcc -V."
)
- if bare_metal_version <= Version("12.9"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_53,code=sm_53")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_62,code=sm_62")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_70,code=sm_70")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_72,code=sm_72")
cc_flag.append("-gencode")
- cc_flag.append("arch=compute_75,code=sm_75")
+ cc_flag.append("arch=compute_53,code=sm_53")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_62,code=sm_62")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_70,code=sm_70")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_72,code=sm_72")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_87,code=sm_87")
+
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
- if bare_metal_version >= Version("12.8"):
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_100,code=sm_100")
- cc_flag.append("-gencode")
- cc_flag.append("arch=compute_120,code=sm_120")
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
@@ -363,7 +356,7 @@ def run(self):
url="https://github.com/state-spaces/mamba",
classifiers=[
"Programming Language :: Python :: 3",
- "License :: OSI Approved :: Apache Software License",
+ "License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
ext_modules=ext_modules,
@@ -378,7 +371,7 @@ def run(self):
"packaging",
"ninja",
"einops",
- "triton",
+ # "triton",
"transformers",
# "causal_conv1d>=1.4.0",
],
diff --git a/usage.md b/usage.md
deleted file mode 100644
index 1b588ce2c..000000000
--- a/usage.md
+++ /dev/null
@@ -1,43 +0,0 @@
-# Mamba adoption
-
-We've been very happy to see Mamba being adopted by many organizations
-and research labs to speed up their training / inference.
-This page contains a partial list of places where Mamba is being used.
-If you'd like to add links to your organization / product / codebase, please open a
-PR or email us. We'd very much like to hear from you!
-
-## Large language models and multi-modal models
-
-- [Tencent's Hunyuan-TurboS (560B)](https://arxiv.org/abs/2505.15431)
-
-- [Nvidia Nemotron-H (8B, 47B, 56B)](https://research.nvidia.com/labs/adlr/nemotronh/)
-
-- [AI21 Jamba (398B)](https://www.ai21.com/blog/announcing-jamba-model-family/)
-
-- [TII Falcon-H1 (34B)](https://falconllm.tii.ae/falcon-h1.html)
-
-- [IBM Bamba (9B)](https://research.ibm.com/blog/bamba-ssm-transformer-model)
-
-- [Mistral's Codestral (7B)](https://mistral.ai/news/codestral-mamba)
-
-- [Nvidia Mamba-2 Hybrid (8B)](https://arxiv.org/abs/2406.07887)
-
-- [Microsoft Samba (4B)](https://arxiv.org/abs/2406.07522v1)
-
-- [TII Falcon-Mamba (7B)](https://falconllm.tii.ae/tii-releases-first-sslm-with-falcon-mamba-7b.html)
-
-## Inference frameworks
-
-- vLLM
-
-- Nvidia's TensorRT-LLM
-
-## Hardware
-
-- Nvidia GPUs
-
-- [AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/mamba/README.html)
-
-- [AWS Trainium 2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/fused_mamba.html)
-
-
From 4ea3e20d0d51928b48b3e9299dcb7eabf3f78185 Mon Sep 17 00:00:00 2001
From: maxmelichov
Date: Mon, 25 Aug 2025 10:41:24 +0300
Subject: [PATCH 07/10] Update .gitignore to exclude virtual environment
directory and enhance README.md with comprehensive setup instructions,
experiment details, and citation information for Differential Mamba.
---
.gitignore | 3 +-
README.md | 200 +++++++++++++++++++++++++++++++++++++++
figures/LensLogScale.PNG | Bin 0 -> 215578 bytes
figures/babilong.PNG | Bin 0 -> 84244 bytes
figures/diffmamba.PNG | Bin 0 -> 96338 bytes
5 files changed, 202 insertions(+), 1 deletion(-)
create mode 100644 figures/LensLogScale.PNG
create mode 100644 figures/babilong.PNG
create mode 100644 figures/diffmamba.PNG
diff --git a/.gitignore b/.gitignore
index dbde1b117..be627d6c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
build/
**.so
*.hip
-*_hip.*
\ No newline at end of file
+*_hip.*
+venv/
\ No newline at end of file
diff --git a/README.md b/README.md
index bf3b76a93..f21a90b65 100755
--- a/README.md
+++ b/README.md
@@ -1,3 +1,203 @@
+# Differential Mamba
+
+
+
+Nadav Schneider,
+Itamar Zimerman,
+Eliya Nachmani
+
+
+
+This repository contains the official PyTorch implementation of the Differential Mamba paper.
+We also provide training code, evaluation code, and model checkpoints to reproduce the results in the paper, including all the baselines.
+
+
+
+
+
+
+# Setup
+## Clone Project
+```
+git clone https://github.com/nadavsc/Diff-Mamba.git
+cd Diff-Mamba
+```
+
+## Create Environment
+To set up our environment, please run:
+```
+conda env create -f environment.yml
+conda activate diffmamba
+```
+Note: this should include all the packages needed to run the training and evaluation scripts. Nonetheless, make sure the following additional requirements are satisfied:
+
+
+Mamba Installation:
+```
+pip install causal-conv1d==1.5.0
+pip install mamba-ssm==2.2.4
+```
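+
+As a quick sanity check that the kernels were built correctly, the following minimal sketch (assuming a CUDA-capable GPU; the layer sizes below are arbitrary) should run without errors:
+```python
+import torch
+from mamba_ssm import Mamba
+
+# Instantiate a single Mamba block and push a dummy batch through it.
+layer = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).to("cuda")
+x = torch.randn(2, 128, 64, device="cuda")  # (batch, seq_len, d_model)
+y = layer(x)
+print(y.shape)  # expected: torch.Size([2, 128, 64])
+```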
+
+## Additional Requirements - Language Modeling
+
+Install the requirements listed in: https://github.com/state-spaces/s4
+
+To train or evaluate on the language modeling task, first download the data. This can be done using the following scripts:
+```
+python language_modeling/src/data/datasets/get_wt103.py
+bash language_modeling/src/data/transformer-xl/enwik8/get_enwik8.sh
+bash language_modeling/src/data/transformer-xl/text8/get_text8.sh
+```
+Then, move the resulting datasets into the language_modeling/data directory.
+
+## Additional Requirements - Retrieval
+
+Install the requirements listed in: https://github.com/booydar/babilong
+
+To fine-tune on PG19, please make sure to download the dataset according to the instructions at [deepmind/pg19](https://huggingface.co/datasets/deepmind/pg19), or use the Hugging Face datasets version.
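+
+If you go the Hugging Face route, a minimal sketch along these lines should fetch the data (the dataset identifier comes from the page linked above; the split name and the `text` field are assumptions about its schema):
+```python
+from datasets import load_dataset
+
+# Download / cache the PG19 validation split from the Hugging Face Hub.
+pg19 = load_dataset("deepmind/pg19", split="validation")
+print(len(pg19))              # number of books in the split
+print(pg19[0]["text"][:200])  # first 200 characters of the first book
+```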
+
+## Additional Requirements - Tuned-Lens
+
+Install the requirements listed in: https://github.com/AlignmentResearch/tuned-lens
+
+Make sure to download The Pile validation set to train the lens.
+Place the .json or .txt file in the tuned-lens/data directory.
+
+
+
+# Experiments
+## Language Modeling
+Run ```cd language_modeling```.
+Then, run the following:
+```
+python train.py experiment=lm/diffmamba2-text8 trainer.devices=[0] model.dropout=0.5 loader.l_max=512 train.seed=0 trainer.accumulate_grad_batches=1 loader.batch_size=50 model.n_layers=12 model.d_model=1024 trainer.max_epochs=40 trainer.precision=32
+```
+
+```trainer.devices```: determines the GPUs used for training. [0] uses cuda:0 while [2] uses cuda:2; [0, 2] will use cuda:0 and cuda:2 with DDP training, while 2 will pick the first two available GPUs (cuda:0 and cuda:1).
+
+```loader.l_max```: the maximum sequence length (context window) used for training
+
+```model.n_layers```: determines the number of layers, i.e. the model size
+
+```optimizer.lr```: overrides the learning rate; omit it to use the default
+
+```trainer.max_epochs```: the number of training epochs
+
+```loader.batch_size```: the batch size
+
+```model.dropout```: the model's dropout rate
+
+```trainer.seed```: sets the training seed
+
+```accumulate_grad_batches```: can be used when GPU memory is insufficient for the desired batch size
+
+
+## Retrieval
+
+
+Run ```cd retrieval```.
+To evaluate the models, make sure to save the model checkpoints in the Diff-Mamba/outputs directory.
+
+### Finetune PG19
+To finetune Mamba on PG19 run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs/mamba2-370m-pg19-finetune
+```
+To finetune Diff-Mamba on PG19 run:
+```
+torchrun --nproc_per_node=4 finetune_pg19.py --model_id=AntonV/mamba2-370m-hf --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=12 --max_steps=4000 --weight_decay=0.1 --warmup=400 --save_steps=500 --eval_steps=500 --output_dir=./outputs
+```
+
+### Finetune BABILong
+To finetune Mamba on BABILong run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/mamba2-370m-pg19-finetune --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/mamba2-370m-needle-finetune
+```
+To finetune Diff-Mamba on BABILong run:
+```
+torchrun --nproc_per_node=1 finetune_needle.py --ckpt_path=./outputs/diffmamba2-370m-pg19-finetune --diffmamba --lr=3e-4 --batch_size=6 --grad_accum_steps=1 --max_steps=500 --weight_decay=0.1 --warmup=50 --save_steps=100 --eval_steps=100 --seed=0 --output_dir=./outputs/diffmamba2-370m-needle-finetune
+```
+
+```--nproc_per_node```: the number of GPUs used for DDP training
+
+```--grad_accum_steps```: increases the effective batch size under memory limitations (see the small example after this list)
+
+```--diffmamba```: a flag that must be set when training Diff-Mamba
+
+```--model_id```: the pretrained Mamba model loaded from Hugging Face
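+
+The effective batch size per optimizer step is the per-device batch size times the gradient-accumulation steps times the number of processes. A small illustration of the arithmetic (plain Python, not part of the project scripts):
+```python
+def effective_batch_size(per_device_batch, grad_accum_steps, nproc_per_node):
+    # Total examples contributing to one optimizer step across all DDP workers.
+    return per_device_batch * grad_accum_steps * nproc_per_node
+
+# PG19 finetuning command above: --batch_size=6, --grad_accum_steps=12, --nproc_per_node=4
+print(effective_batch_size(6, 12, 4))  # 288
+# BABILong finetuning command above: --batch_size=6, --grad_accum_steps=1, --nproc_per_node=1
+print(effective_batch_size(6, 1, 1))   # 6
+```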
+
+### Evaluate
+
+To evaluate a model on the different tasks and context lengths, run:
+
+```
+bash scripts/run_activation-beacon-diffmamba2-370m-needle-finetune-seed0_no_instruct.sh
+```
+or
+```
+bash scripts/run_activation-beacon-diffmamba2-370m_pg19-finetune_no_instruct.sh
+```
+Results will be saved in the directory scripts/babilong_evals.
+
+### Plot
+To plot the scores, simply run:
+```
+python plot.py --model_name diffmamba2-370m-needle-finetune-seed0 --results_folder scripts/babilong_evals/diffmamba2-370m-needle-finetune-seed0
+```
+To plot the relative percentages, run:
+```
+python plot_compare.py --model_name diffmamba2-370m-needle-finetune --ratio
+```
+The plot will be saved in scripts/babilong_evals. Use the ```--ratio``` flag for the relative percentage plot, or omit it for the original scores plot.
+
+## Tuned-Lens
+
+
+Run ```cd tuned-lens```.
+### Training Lens
+Then, to train a lens for Mamba, run:
+```
+python -m tuned_lens train --model.name ../../../outputs/mamba2-370m-pg19-finetune --data.name data/valid.txt --per_gpu_batch_size=1 --ssm --output my_lenses/mamba2-370m-pg19-finetune
+```
+To train a lens for Diff-Mamba, specify the correct path to the model and change the output directory accordingly.
+To train the lens in a distributed fashion, change ```--per_gpu_batch_size``` to the number of available GPUs.
+
+### Evaluate
+To evaluate run:
+```
+python test_babilong_0k.py --ckpt_path ../../../outputs/mamba2-370m-needle-finetune
+```
+Add the ```--diffmamba``` flag if using Diff-Mamba.
+
+You can stop the test early by using the ```--num_examples``` flag. The compatible lens will be loaded from the my_lenses directory.
+
+### Plot
+To plot the results run:
+```
+python plot_tuned_lens.py --diff_results_path results/diffmamba2-370m-needle-finetune-lens_eval.txt --mamba_results_path results/mamba2-370m-needle-finetune-lens_eval.txt
+```
+Use ```--log``` to create a log-scale plot, and ```--start-layer``` and ```--end-layer``` to choose specific layers to plot.
+
+## Acknowledgements
+
+All model implementations are based on [Mamba](https://github.com/state-spaces/mamba). Training and evaluation for the language modeling experiments are based on the [S4](https://github.com/state-spaces/s4) repository. Evaluation on BABILong is based on the [BABILong](https://github.com/booydar/babilong) repo, and measuring the signal-to-noise ratio through the layers is based on [tuned-lens](https://github.com/AlignmentResearch/tuned-lens).
+
+## Citation
+
+If you use this code, please consider citing the following:
+
+```
+@misc{schneider2025differentialmamba,
+ title={Differential Mamba},
+ author={Nadav Schneider and Itamar Zimerman and Eliya Nachmani},
+ year={2025},
+ eprint={2507.06204},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2507.06204},
+}
+```
+
# Mamba

diff --git a/figures/LensLogScale.PNG b/figures/LensLogScale.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..3d714a49102683d7ae781b7151d25d1a016d890c
GIT binary patch
literal 215578