Skip to content

Commit

Permalink
Update with Matmulfree LM
Browse files Browse the repository at this point in the history
  • Loading branch information
ruijie-zhu committed May 21, 2024
1 parent 9e6fd30 commit b42939e
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 124 deletions.
183 changes: 62 additions & 121 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
</div>
<h2 align="center">MatMul-Free LLM</h2>
<h5 align="center"> If you like our project, please give us a star ⭐ on GitHub for the latest update. </h5>
<h5 align="center"> This repo is adapted from <a href="https://github.com/sustcsonglin/flash-linear-attention">flash-linear-attention</a>. </h5>

<h5 align="center">

Expand All @@ -13,9 +14,7 @@ The following requirements should be satisfied
- [Triton](https://github.com/openai/triton) >=2.2
- [einops](https://einops.rocks/)

As `fla` is actively developed now, no released packages are provided at this time.
If you do need to use `fla` ops/modules and contemplate further explorations, an alternative way is to install the package from source
```sh
```sh
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention
```
or manage `fla` with submodules
Expand All @@ -24,140 +23,82 @@ git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rd
ln -s 3rdparty/flash-linear-attention/mmfreelm mmfreelm
```

> [!CAUTION]
> If you're not working with Triton v2.2 or its nightly release, it's important to be aware of potential issues with the `FusedChunk` implementation, detailed in this [issue](https://github.com/openai/triton/issues/2852).
You can run the test `python tests/test_fused_chunk.py` to check if your version is affected by similar compiler problems.
While we offer some fixes for Triton<=2.1, be aware that these may result in reduced performance.
>
> For both Triton 2.2 and earlier versions (up to 2.1), you can reliably use the `Chunk` version (with hidden states materialized into HBMs).
> After careful optimization, this version generally delivers high performance in most scenarios.
# Usage

## Token Mixing

We provide "token mixing" linear attention layers in `fla.layers` for you to use.
You can replace the standard multihead attention layer in your model with other linear attention layers.
Example usage is as follows:

```py
>>> import torch
>>> from mmfreelm.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])
```
## Model

We provide the implementations of models that are compatible with 🤗 Transformers library.
Here's an example of how to initialize a GLA model from the default configs in `fla`:
This is a huggingface-compatible library; you can use the following command to initialize the model with the huggingface `AutoModel`:

```py
>>> from mmfreelm.models import GLAConfig
>>> from transformers import AutoModel
>>> config = GLAConfig()
>>> config
GLAConfig
{
"attn_mode": "fused_chunk",
"bos_token_id": 1,
"clamp_min": null,
"conv_size": 4,
"eos_token_id": 2,
"expand_k": 0.5,
"expand_v": 1,
"fuse_cross_entropy": true,
"fuse_norm": true,
"hidden_act": "swish",
"hidden_ratio": 4,
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": null,
"max_position_embeddings": 2048,
"model_type": "gla",
"num_heads": 4,
"num_hidden_layers": 24,
"rms_norm_eps": 1e-06,
"share_conv_kernel": true,
"tie_word_embeddings": false,
"transformers_version": "4.39.1",
"use_cache": true,
"use_gk": true,
"use_gv": false,
"use_short_conv": false,
"vocab_size": 32000
}

>>> AutoModel.from_config(config)
GLAModel(
(embed_tokens): Embedding(32000, 2048)
(layers): ModuleList(
    (0-23): 24 x GLABlock(
(attn_norm): RMSNorm()
(attn): GatedLinearAttention(
(gate_fn): SiLU()
(q_proj): Linear(in_features=2048, out_features=1024, bias=False)
(k_proj): Linear(in_features=2048, out_features=1024, bias=False)
(v_proj): Linear(in_features=2048, out_features=2048, bias=False)
(g_proj): Linear(in_features=2048, out_features=2048, bias=False)
(gk_proj): Sequential(
(0): Linear(in_features=2048, out_features=16, bias=False)
(1): Linear(in_features=16, out_features=1024, bias=True)
)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(g_norm_swish_gate): FusedRMSNormSwishGate()
)
(mlp_norm): RMSNorm()
(mlp): GLAMLP(
(gate_proj): Linear(in_features=2048, out_features=11264, bias=False)
(down_proj): Linear(in_features=5632, out_features=2048, bias=False)
(act_fn): SiLU()
)
)
)
(norm): RMSNorm()
```py
>>> from mmfreelm.models import HGRNBitConfig
>>>
>>> from transformers import AutoModel
>>> config = HGRNBitConfig()
>>> AutoModel.from_config(config)
HGRNBitModel(
(embeddings): Embedding(32000, 2048)
(layers): ModuleList(
(0): HGRNBitBlock(
(attn_norm): RMSNorm(2048, eps=1e-06)
(attn): HGRNBitAttention(
(i_proj): FusedBitLinear(
in_features=2048, out_features=2048, bias=False
(norm): RMSNorm(2048, eps=1e-08)
)
(f_proj): FusedBitLinear(
in_features=2048, out_features=2048, bias=False
(norm): RMSNorm(2048, eps=1e-08)
)
(g_proj): FusedBitLinear(
in_features=2048, out_features=2048, bias=False
(norm): RMSNorm(2048, eps=1e-08)
)
(g_norm): FusedRMSNormSwishGate()
(o_proj): FusedBitLinear(
in_features=2048, out_features=2048, bias=False
(norm): RMSNorm(2048, eps=1e-08)
)
)
(mlp_norm): RMSNorm(2048, eps=1e-06)
(mlp): HGRNBitMLP(
(gate_proj): FusedBitLinear(
in_features=2048, out_features=11264, bias=False
(norm): RMSNorm(2048, eps=1e-08)
)
(down_proj): FusedBitLinear(
in_features=5632, out_features=2048, bias=False
(norm): RMSNorm(5632, eps=1e-08)
)
(act_fn): SiLU()
)
)

)
>>>

```

## Generation

Upon successfully pretraining a model, it becomes accessible for generating text using the 🤗 text generation APIs.
In the following, we give a generation example:
In the following, we give a generation example in `generate.py`:

```py
>>> import mmfreelm
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> name = 'mmfreelm-hub/gla-340M-15B'
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> model = AutoModelForCausalLM.from_pretrained(name).cuda()
>>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
>>> outputs = model.generate(input_ids, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
```

We also provide a simple script [here](benchmarks/benchmark_generation.py) for benchmarking the generation speed.
Simply run it by:
```sh
$ python -m benchmarks.benchmark_generation \
--path 'mmfreelm-hub/gla-340M-15B' \
--repetition_penalty 2. \
--prompt="Hello everyone, I'm Songlin Yang"

Prompt:
Hello everyone, I'm Songlin Yang
Generated:
Hello everyone, I'm Songlin Yang.
I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have

Prompt length: 10, generation length: 64
Total prompt processing + decoding time: 4593ms
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import mmfreelm
from transformers import AutoModelForCausalLM, AutoTokenizer
#Change here to our open-sourced model
name = ''
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()
input_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```

All of the pretrained models currently available can be found in [`fla-hub`](https://huggingface.co/fla-hub).
Expand Down
7 changes: 4 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # 或者 "true"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import mmfreelm
from transformers import AutoModelForCausalLM, AutoTokenizer
name = '/vol2/matmulfreellm/hgrn_bit_1.3B_100B_realbit'
#Change here to our open-sourced model
name = ''
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()
input_prompt = "I am "
input_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
134 changes: 134 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import ast
import os
import re
import subprocess
import warnings
from pathlib import Path

import torch
from packaging.version import Version, parse
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME

# The PyPI long description is taken verbatim from the README.
with open('README.md') as f:
    long_description = f.read()

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = 'mmfreelm'


def _env_flag(name, default="FALSE"):
    """Return True iff the environment variable *name* is set to 'TRUE'."""
    return os.getenv(name, default) == 'TRUE'


# FORCE_BUILD: force a fresh build locally, instead of attempting to find prebuilt wheels
FORCE_BUILD = _env_flag('FLA_FORCE_BUILD')
# SKIP_CUDA_BUILD: allow CI to use a simple `python setup.py sdist` run to copy over raw files,
# without any cuda compilation
SKIP_CUDA_BUILD = _env_flag('FLA_SKIP_CUDA_BUILD', "TRUE")
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = _env_flag('FLA_FORCE_CXX11_ABI')

def get_cuda_bare_metal_version(cuda_dir):
    """Run ``nvcc -V`` under *cuda_dir* and return (raw output, parsed version).

    The version is the token following "release" in nvcc's banner, e.g.
    "release 11.8," -> Version("11.8").
    """
    nvcc_text = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    tokens = nvcc_text.split()
    # nvcc prints e.g. "... release 11.8, V11.8.89"; strip the trailing comma.
    release_token = tokens[tokens.index("release") + 1]
    return nvcc_text, parse(release_token.split(",")[0])


def check_if_cuda_home_none(global_option: str) -> None:
    """Emit a warning when ``CUDA_HOME`` (and hence nvcc) was not located.

    Does nothing when a CUDA toolkit is available.
    """
    if CUDA_HOME is None:
        # warn instead of error because user could be downloading prebuilt wheels,
        # so nvcc won't be necessary in that case.
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )


def append_nvcc_threads(nvcc_extra_args):
    """Return a copy of *nvcc_extra_args* with nvcc's parallel-build flag appended."""
    return [*nvcc_extra_args, "--threads", "4"]


# CUDA extension modules; stays empty unless FLA_SKIP_CUDA_BUILD is 'FALSE'
# (it defaults to 'TRUE' above, so by default no CUDA code is compiled).
ext_modules = []
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    # NOTE(review): TORCH_MAJOR/TORCH_MINOR are computed but not referenced
    # anywhere visible in this file — possibly leftover from upstream setup.py.
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    # Warns (does not fail) when nvcc is missing; see helper above.
    check_if_cuda_home_none('mmfreelm')
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FLA is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
        # Always target Ampere (sm_80) when a toolkit is present.
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")
    # NOTE(review): this second guard repeats `CUDA_HOME is not None`;
    # bare_metal_version is only bound inside the first guard, which is safe
    # here only because both guards are identical — consider merging them.
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8"):
            # Hopper (sm_90) requires CUDA 11.8+.
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # NOTE(review): extra_compile_args is not consumed by the setup() call in
    # this file — presumably intended for CUDAExtension entries appended to
    # ext_modules; verify against upstream before relying on it.
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc": append_nvcc_threads(
            [
                "-O3",
                "-std=c++17",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_HALF2_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
            ]
            + generator_flag
            + cc_flag
        ),
    }


def get_package_version():
    """Read the package version from ``mmfreelm/__init__.py``.

    Looks for a top-level ``__version__ = ...`` assignment and evaluates its
    right-hand side as a Python literal.

    Returns:
        The version value (typically a str).

    Raises:
        RuntimeError: if no ``__version__`` assignment is found.
        FileNotFoundError: if ``mmfreelm/__init__.py`` does not exist.
    """
    init_path = Path(this_dir) / 'mmfreelm' / '__init__.py'
    with open(init_path) as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with a clear message instead of the opaque AttributeError that
        # version_match.group(1) would raise on a missing match.
        raise RuntimeError(f"Unable to find __version__ in {init_path}")
    return ast.literal_eval(version_match.group(1))


# Package metadata and install configuration; long_description, PACKAGE_NAME
# and get_package_version() are defined earlier in this file.
setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    description='Implementation for Matmul-free LM',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='',
    author_email='',
    url='',
    packages=find_packages(),
    license='MIT',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    python_requires='>=3.7',
    # Runtime dependencies only; CUDA toolchain requirements are handled by
    # the optional extension-build logic above.
    install_requires=[
        'triton',
        'transformers',
        'einops',
        'ninja'
    ]
)

0 comments on commit b42939e

Please sign in to comment.