Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,42 @@

This aims to be a simpler implementation of the [original repo](https://github.com/microsoft/Samba).

## Installation

> [!TIP]
> While the `pip install` command _should_ install all deps and the package, in practice some of the more CUDA-heavy deps are better installed separately from source. See section below for more details.

```bash
git clone https://github.com/pszemraj/samba-pytorch.git
cd samba-pytorch
pip install -e .
```

### Installing custom kernel packages first

After installing `torch`, `xformers`, and `flash-attn`, you may want to install `mamba-ssm`, `causal-conv1d`, and `fla` from source:

```bash
pip install --upgrade pip ninja
pip install git+https://github.com/state-spaces/mamba.git --no-build-isolation
pip install git+https://github.com/Dao-AILab/causal-conv1d.git --no-build-isolation
pip install git+https://github.com/sustcsonglin/flash-linear-attention@98c176e --no-build-isolation
```

Then, clone this repo and run commands as above.

## Usage

A basic example of creating a random model from a named config:

```python
from samba_pytorch import Config, GPT
cfg = Config.from_name('Samba_421M_1k_window')
print(cfg)
model = GPT(cfg)
model
```

## Repo structure

```text
Expand Down
83 changes: 43 additions & 40 deletions samba_pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

# Copyright Lightning AI. Licensed under the Apache License 2.0,
# see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE

import warnings
from dataclasses import dataclass
from typing import Any, Literal, Optional, Type

import torch
from typing_extensions import Self

import samba_pytorch.samba
from samba_pytorch.utils import find_multiple


Expand Down Expand Up @@ -101,8 +100,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:

@property
def mlp_class(self) -> Type:
from samba_pytorch import samba
# `self._mlp_class` cannot be the type to keep the config json serializable
return getattr(samba_pytorch.samba, self._mlp_class)
return getattr(samba, self._mlp_class)

@property
def norm_class(self) -> Type:
Expand All @@ -112,9 +112,12 @@ def norm_class(self) -> Type:

return RMSNorm
elif self._norm_class == "FusedRMSNorm":
from samba_pytorch.modules.rmsnorm import FusedRMSNorm
warnings.warn(
"FusedRMSNorm has been removed, using standard torch RMSNorm instead"
)
from samba_pytorch.modules.rmsnorm import RMSNorm

return FusedRMSNorm
return RMSNorm
return getattr(torch.nn, self._norm_class)


Expand All @@ -133,7 +136,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -150,7 +153,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -168,7 +171,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
full_per_layer=2,
_mlp_class="LLaMAMLP",
Expand All @@ -187,7 +190,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -206,7 +209,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -225,7 +228,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -244,7 +247,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -263,7 +266,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -280,7 +283,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -298,7 +301,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -316,7 +319,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -335,7 +338,7 @@ def norm_class(self) -> Type:
parallel_residual=True,
shared_attention_norm=True,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -354,7 +357,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -373,7 +376,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -393,7 +396,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -412,7 +415,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -431,7 +434,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -450,7 +453,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -469,7 +472,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -489,7 +492,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -510,7 +513,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -531,7 +534,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -552,7 +555,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
Expand All @@ -573,7 +576,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -592,7 +595,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -612,7 +615,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -632,7 +635,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
Expand All @@ -653,7 +656,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -673,7 +676,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -693,7 +696,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -712,7 +715,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -731,7 +734,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -750,7 +753,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -769,7 +772,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
Expand All @@ -787,7 +790,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=8192,
Expand Down
Loading