Showing 21 changed files with 1,843 additions and 18 deletions.
153 changes: 153 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/.gitignore
```
llama_index/_static
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Jetbrains
.idea
modules/
*.swp

# VsCode
.vscode

# pipenv
Pipfile
Pipfile.lock

# pyright
pyrightconfig.json
```
3 changes: 3 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/BUILD

```python
poetry_requirements(
    name="poetry",
)
```
17 changes: 17 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/Makefile
```makefile
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
```
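Assuming `make` and the package's dev dependencies are installed, the targets above are invoked from the package directory, e.g.:

```shell
make help   # list all documented targets
make lint   # run the pre-commit linters
make test   # run the pytest suite
```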
37 changes: 37 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/README.md
# NVIDIA's LLM connector

Install the connector:

```shell
pip install llama-index-llms-nvidia
```

With this connector, you'll be able to connect to and generate from compatible models available as hosted [NVIDIA NIMs](https://ai.nvidia.com), such as:

- Google's [gemma-7b](https://build.nvidia.com/google/gemma-7b)
- Mistral AI's [mistral-7b-instruct-v0.2](https://build.nvidia.com/mistralai/mistral-7b-instruct-v2)
- And more!

_First_, get a free API key. Go to https://build.nvidia.com, select a model, and click "Get API Key".
Store this key in your environment as `NVIDIA_API_KEY`.
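For example, in a POSIX shell (the key value below is a placeholder):

```shell
export NVIDIA_API_KEY="nvapi-..."
```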
_Then_, try it out.

```python
from llama_index.llms.nvidia import NVIDIA
from llama_index.core.llms import ChatMessage, MessageRole

llm = NVIDIA()

messages = [
    ChatMessage(
        role=MessageRole.SYSTEM, content="You are a helpful assistant."
    ),
    ChatMessage(
        role=MessageRole.USER,
        content="What are the most popular house pets in North America?",
    ),
]

llm.chat(messages)
```
1 change: 1 addition & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/BUILD
```python
python_sources()
```
3 changes: 3 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/__init__.py
```python
from llama_index.llms.nvidia.base import NVIDIA

__all__ = ["NVIDIA"]
```
107 changes: 107 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/base.py
```python
from typing import (
    Any,
    Optional,
    List,
    Literal,
)

from llama_index.core.bridge.pydantic import PrivateAttr, BaseModel
from llama_index.core.base.llms.generic_utils import (
    get_from_param_or_env,
)

from llama_index.llms.nvidia.utils import API_CATALOG_MODELS

from llama_index.llms.openai_like import OpenAILike

DEFAULT_MODEL = "meta/llama3-8b-instruct"
BASE_URL = "https://integrate.api.nvidia.com/v1/"


class Model(BaseModel):
    id: str


class NVIDIA(OpenAILike):
    """NVIDIA's API Catalog Connector."""

    _mode: str = PrivateAttr("nvidia")

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        nvidia_api_key: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        api_key = get_from_param_or_env(
            "api_key",
            nvidia_api_key or api_key,
            "NVIDIA_API_KEY",
            "NO_API_KEY_PROVIDED",
        )

        super().__init__(
            model=model,
            api_key=api_key,
            api_base=BASE_URL,
            is_chat_model=True,
            default_headers={"User-Agent": "llama-index-llms-nvidia"},
            **kwargs,
        )

    @property
    def available_models(self) -> List[Model]:
        ids = API_CATALOG_MODELS.keys()
        if self._mode == "nim":
            ids = [model.id for model in self._get_client().models.list()]
        return [Model(id=name) for name in ids]

    @classmethod
    def class_name(cls) -> str:
        return "NVIDIA"

    def mode(
        self,
        mode: Optional[Literal["nvidia", "nim"]] = "nvidia",
        *,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> "NVIDIA":
        """
        Change the mode.

        There are two modes, "nvidia" and "nim". The "nvidia" mode is the
        default and is used to interact with hosted NIMs. The "nim" mode is
        used to interact with NVIDIA NIM endpoints, which are typically
        hosted on-premises.

        For the "nvidia" mode, the "api_key" parameter is available to
        specify your API key. If not specified, the NVIDIA_API_KEY
        environment variable will be used.

        For the "nim" mode, the "base_url" parameter is required and the
        "model" parameter may be necessary. Set base_url to the url of your
        local NIM endpoint. For instance, "https://localhost:9999/v1".
        Additionally, the "model" parameter must be set to the name of the
        model inside the NIM.
        """
        if mode == "nim":
            if not base_url:
                raise ValueError("base_url is required for nim mode")
        if mode == "nvidia":
            api_key = get_from_param_or_env(
                "api_key",
                api_key,
                "NVIDIA_API_KEY",
            )
            base_url = base_url or BASE_URL

        self._mode = mode
        if base_url:
            self.api_base = base_url
        if model:
            self.model = model
        if api_key:
            self.api_key = api_key

        return self
```
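A minimal usage sketch of the mode switch described in the docstring above; the endpoint URL and model name are placeholders for whatever your local NIM actually serves:

```python
from llama_index.llms.nvidia import NVIDIA

# Default "nvidia" mode: talks to the hosted API Catalog, reading the
# key from NVIDIA_API_KEY if one is not passed explicitly. In this mode
# available_models is just the static catalog table (no network call).
llm = NVIDIA()
print([m.id for m in llm.available_models])

# "nim" mode: point the same connector at a locally hosted NIM.
# base_url is required; the model name must match the model the NIM
# serves (both values below are placeholders).
local_llm = NVIDIA().mode(
    "nim",
    base_url="https://localhost:9999/v1",
    model="meta/llama3-8b-instruct",
)
```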
19 changes: 19 additions & 0 deletions
llama-index-integrations/llms/llama-index-llms-nvidia/llama_index/llms/nvidia/utils.py
```python
from typing import Dict, Optional

API_CATALOG_MODELS: Dict[str, int] = {
    "mistralai/mistral-7b-instruct-v0.2": 16384,
    "mistralai/mixtral-8x7b-instruct-v0.1": 16384,
    "mistralai/mixtral-8x22b-instruct-v0.1": 32768,
    "mistralai/mistral-large": 16384,
    "google/gemma-7b": 4096,
    "google/gemma-2b": 4096,
    "google/codegemma-7b": 4096,
    "meta/llama2-70b": 1024,
    "meta/codellama-70b": 1024,
    "meta/llama3-8b-instruct": 6000,
    "meta/llama3-70b-instruct": 6000,
}


def catalog_modelname_to_contextsize(modelname: str) -> Optional[int]:
    return API_CATALOG_MODELS.get(modelname, None)
```
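For illustration, the helper is a plain dictionary lookup over the table above:

```python
from llama_index.llms.nvidia.utils import catalog_modelname_to_contextsize

# Known catalog model: returns its context window size.
print(catalog_modelname_to_contextsize("google/gemma-7b"))  # 4096

# Unknown model: returns None rather than raising.
print(catalog_modelname_to_contextsize("not/a-model"))  # None
```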