Skip to content

Commit

Permalink
add nvidia nim llm support (#13176)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed May 7, 2024
1 parent 2d3b718 commit 4431e99
Show file tree
Hide file tree
Showing 21 changed files with 1,843 additions and 18 deletions.
862 changes: 862 additions & 0 deletions docs/docs/examples/llm/nvidia.ipynb

Large diffs are not rendered by default.

153 changes: 153 additions & 0 deletions llama-index-integrations/llms/llama-index-llms-nvidia/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Jetbrains
.idea
modules/
*.swp

# VsCode
.vscode

# pipenv
Pipfile
Pipfile.lock

# pyright
pyrightconfig.json
3 changes: 3 additions & 0 deletions llama-index-integrations/llms/llama-index-llms-nvidia/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Pants BUILD target: expose this package's Poetry-declared dependencies
# as requirements usable by other Pants targets.
poetry_requirements(
    name="poetry",
)
17 changes: 17 additions & 0 deletions llama-index-integrations/llms/llama-index-llms-nvidia/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Developer convenience targets for the llama-index-llms-nvidia package.
# `make help` lists all targets; the `## ...` trailing comments feed it.
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
37 changes: 37 additions & 0 deletions llama-index-integrations/llms/llama-index-llms-nvidia/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# NVIDIA's LLM connector

Install the connector,

```shell
pip install llama-index-llms-nvidia
```

With this connector, you'll be able to connect to and generate from compatible models available as hosted [NVIDIA NIMs](https://ai.nvidia.com), such as:

- Google's [gemma-7b](https://build.nvidia.com/google/gemma-7b)
- Mistral AI's [mistral-7b-instruct-v0.2](https://build.nvidia.com/mistralai/mistral-7b-instruct-v2)
- And more!

_First_, get a free API key. Go to https://build.nvidia.com, select a model, click "Get API Key".
Store this key in your environment as `NVIDIA_API_KEY`.

_Then_, try it out.

```python
from llama_index.llms.nvidia import NVIDIA
from llama_index.core.llms import ChatMessage, MessageRole

llm = NVIDIA()

messages = [
ChatMessage(
role=MessageRole.SYSTEM, content=("You are a helpful assistant.")
),
ChatMessage(
role=MessageRole.USER,
content=("What are the most popular house pets in North America?"),
),
]

llm.chat(messages)
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Public entry point for the NVIDIA LLM connector."""

from llama_index.llms.nvidia.base import NVIDIA

__all__ = ["NVIDIA"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import (
Any,
Optional,
List,
Literal,
)

from llama_index.core.bridge.pydantic import PrivateAttr, BaseModel
from llama_index.core.base.llms.generic_utils import (
get_from_param_or_env,
)

from llama_index.llms.nvidia.utils import API_CATALOG_MODELS

from llama_index.llms.openai_like import OpenAILike

DEFAULT_MODEL = "meta/llama3-8b-instruct"
BASE_URL = "https://integrate.api.nvidia.com/v1/"


class Model(BaseModel):
    """A model descriptor, as returned by ``NVIDIA.available_models``."""

    id: str


class NVIDIA(OpenAILike):
    """NVIDIA's API Catalog Connector.

    By default ("nvidia" mode) the connector talks to NVIDIA's hosted
    API Catalog at ``BASE_URL`` using an OpenAI-compatible chat API.
    Call :meth:`mode` with ``"nim"`` to target a self-hosted NVIDIA NIM
    endpoint instead.
    """

    # Current operating mode: "nvidia" (hosted API Catalog) or "nim"
    # (self-hosted NIM endpoint). Changed only via mode().
    _mode: str = PrivateAttr("nvidia")

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        nvidia_api_key: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create a connector to the hosted NVIDIA API Catalog.

        Args:
            model: Catalog model name, e.g. ``"meta/llama3-8b-instruct"``.
            nvidia_api_key: API key; takes precedence over ``api_key``.
            api_key: Alternate spelling of the API key parameter.
            **kwargs: Forwarded to ``OpenAILike``.
        """
        # Resolve the key from the parameters or the NVIDIA_API_KEY
        # environment variable. A sentinel default is used so construction
        # never fails outright; authentication errors surface on the first
        # request instead.
        api_key = get_from_param_or_env(
            "api_key",
            nvidia_api_key or api_key,
            "NVIDIA_API_KEY",
            "NO_API_KEY_PROVIDED",
        )

        super().__init__(
            model=model,
            api_key=api_key,
            api_base=BASE_URL,
            is_chat_model=True,
            default_headers={"User-Agent": "llama-index-llms-nvidia"},
            **kwargs,
        )

    @property
    def available_models(self) -> List[Model]:
        """Models usable in the current mode.

        In "nvidia" mode this is the static API Catalog list; in "nim"
        mode the endpoint itself is queried for the models it serves.
        """
        ids = API_CATALOG_MODELS.keys()
        if self._mode == "nim":
            ids = [model.id for model in self._get_client().models.list()]
        return [Model(id=name) for name in ids]

    @classmethod
    def class_name(cls) -> str:
        return "NVIDIA"

    def mode(
        self,
        mode: Optional[Literal["nvidia", "nim"]] = "nvidia",
        *,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> "NVIDIA":
        """
        Change the mode.

        There are two modes, "nvidia" and "nim". The "nvidia" mode is the
        default mode and is used to interact with hosted NIMs. The "nim"
        mode is used to interact with NVIDIA NIM endpoints, which are
        typically hosted on-premises.

        For the "nvidia" mode, the "api_key" parameter is available to
        specify your API key. If not specified, the NVIDIA_API_KEY
        environment variable will be used.

        For the "nim" mode, the "base_url" parameter is required and the
        "model" parameter may be necessary. Set base_url to the url of your
        local NIM endpoint. For instance, "https://localhost:9999/v1".
        Additionally, the "model" parameter must be set to the name of the
        model inside the NIM.

        Returns:
            self, to allow fluent chaining.

        Raises:
            ValueError: if ``mode`` is not "nvidia" or "nim", or if
                ``base_url`` is missing in "nim" mode.
        """
        # Reject invalid modes up front; previously any value (including
        # None) was silently stored in _mode, leaving the instance in a
        # broken state for available_models and subsequent calls.
        if mode not in ("nvidia", "nim"):
            raise ValueError(f"mode must be 'nvidia' or 'nim', got {mode!r}")
        if mode == "nim":
            if not base_url:
                raise ValueError("base_url is required for nim mode")
        if mode == "nvidia":
            # Unlike __init__, require a real key here (no sentinel):
            # explicitly switching into hosted mode without credentials is
            # almost certainly a user error.
            api_key = get_from_param_or_env(
                "api_key",
                api_key,
                "NVIDIA_API_KEY",
            )
            base_url = base_url or BASE_URL

        self._mode = mode
        if base_url:
            self.api_base = base_url
        if model:
            self.model = model
        if api_key:
            self.api_key = api_key

        return self
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Dict, Optional

# Context-window sizes (in tokens) for the models available in NVIDIA's
# hosted API Catalog, keyed by catalog model name.
API_CATALOG_MODELS: Dict[str, int] = {
    "mistralai/mistral-7b-instruct-v0.2": 16384,
    "mistralai/mixtral-8x7b-instruct-v0.1": 16384,
    "mistralai/mixtral-8x22b-instruct-v0.1": 32768,
    "mistralai/mistral-large": 16384,
    "google/gemma-7b": 4096,
    "google/gemma-2b": 4096,
    "google/codegemma-7b": 4096,
    "meta/llama2-70b": 1024,
    "meta/codellama-70b": 1024,
    "meta/llama3-8b-instruct": 6000,
    "meta/llama3-70b-instruct": 6000,
}


def catalog_modelname_to_contextsize(modelname: str) -> Optional[int]:
    """Return the context-window size for *modelname*, or None if unknown."""
    # dict.get already returns None for missing keys; the explicit
    # `, None` default in the original was redundant.
    return API_CATALOG_MODELS.get(modelname)

0 comments on commit 4431e99

Please sign in to comment.