Commit b42939e (1 parent: 9e6fd30)
Showing 3 changed files with 200 additions and 124 deletions.
The first changed file updates the text-generation example: the hard-coded internal checkpoint path is replaced by an empty `name` for the user to fill in, and the short prompt is swapped for a longer one.

```diff
@@ -1,11 +1,12 @@
 import os
-os.environ["TOKENIZERS_PARALLELISM"] = "false"  # or "true"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import mmfreelm
 from transformers import AutoModelForCausalLM, AutoTokenizer
-name = '/vol2/matmulfreellm/hgrn_bit_1.3B_100B_realbit'
+# Change here to our open-sourced model
+name = ''
 tokenizer = AutoTokenizer.from_pretrained(name)
 model = AutoModelForCausalLM.from_pretrained(name).cuda().half()
-input_prompt = "I am "
+input_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
 input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
 outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
 print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```
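Since the committed example deliberately leaves `name` empty, here is a minimal sketch of running it with a checkpoint filled in. The identifier below is a hypothetical placeholder, not one of the repository's published model names; the assumption that importing `mmfreelm` makes the model loadable through the `Auto*` classes follows the usage shown in the diff itself.

```python
# Minimal sketch, assuming a hypothetical checkpoint id.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import mmfreelm  # noqa: F401 -- importing makes the model classes loadable via Auto*
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "your-org/your-checkpoint"  # hypothetical placeholder; substitute a real model
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()  # fp16 weights on GPU

prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
# top_p=0.4 samples from the smallest token set covering 40% of the probability mass;
# temperature=0.6 sharpens the distribution toward greedy decoding.
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```

Note that `max_length=32` counts the prompt tokens as well, so a long prompt leaves only a few tokens for the sampled continuation.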
The second changed file adds the package's `setup.py` as a new file (the hunk is `@@ -0,0 +1,134 @@`, i.e. all additions, shown here without the leading `+` markers):

```python
import ast
import os
import re
import subprocess
import warnings
from pathlib import Path

import torch
from packaging.version import Version, parse
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME

with open('README.md') as f:
    long_description = f.read()

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = 'mmfreelm'

# FORCE_BUILD: force a fresh build locally, instead of attempting to find prebuilt wheels
FORCE_BUILD = os.getenv('FLA_FORCE_BUILD', "FALSE") == 'TRUE'
# SKIP_CUDA_BUILD: allow CI to use a simple `python setup.py sdist` run to copy over
# raw files, without any CUDA compilation
SKIP_CUDA_BUILD = os.getenv('FLA_SKIP_CUDA_BUILD', "TRUE") == 'TRUE'
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv('FLA_FORCE_CXX11_ABI', "FALSE") == 'TRUE'


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because the user could be downloading prebuilt wheels,
    # so nvcc won't be necessary in that case
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    return nvcc_extra_args + ["--threads", "4"]


ext_modules = []
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none('mmfreelm')
    # Check if CUDA 11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FLA is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc": append_nvcc_threads(
            [
                "-O3",
                "-std=c++17",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_HALF2_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
            ]
            + generator_flag
            + cc_flag
        ),
    }


def get_package_version():
    with open(Path(this_dir) / 'mmfreelm' / '__init__.py') as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    return ast.literal_eval(version_match.group(1))


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    description='Implementation for Matmul-free LM',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='',
    author_email='',
    url='',
    packages=find_packages(),
    license='MIT',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    python_requires='>=3.7',
    install_requires=[
        'triton',
        'transformers',
        'einops',
        'ninja'
    ]
)
```
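For reference, here is a minimal sketch of what `get_cuda_bare_metal_version()` extracts from `nvcc -V` output and how the architecture flags above are gated on it. The sample string is illustrative, not captured from a real toolchain:

```python
from packaging.version import parse

# Illustrative nvcc -V output (the real function runs `<cuda_dir>/bin/nvcc -V`).
sample = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 11.8, V11.8.89\n"
)
tokens = sample.split()
release_idx = tokens.index("release") + 1                      # token right after "release"
bare_metal_version = parse(tokens[release_idx].split(",")[0])  # "11.8," -> "11.8"

# Mirrors the gating in setup.py: CUDA >= 11.6 is required, sm_90 only on >= 11.8.
print(bare_metal_version >= parse("11.6"))  # True: no RuntimeError raised
print(bare_metal_version >= parse("11.8"))  # True: compute_90/sm_90 flags would be added
```

With the default `FLA_SKIP_CUDA_BUILD=TRUE`, the entire `if not SKIP_CUDA_BUILD:` branch is skipped and the package installs without compiling any CUDA extension.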