Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: Python lint

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
lint_langchain:
name: Lint LangChain - Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./langchain
strategy:
matrix:
python-version: ['3.9', '3.13']
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "langchain/uv.lock"
- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version-file: "langchain/pyproject.toml"
- name: Restore uv cache
uses: actions/cache@v4
with:
path: /tmp/.uv-cache
key: uv-langchain-${{ hashFiles('langchain/uv.lock') }}
restore-keys: |
uv-langchain-${{ hashFiles('langchain/uv.lock') }}
uv-${{ runner.os }}
- name: Install the project
run: uv sync --dev
- name: Run ruff format check
run: uv run ruff format --check
- name: Run ruff check
run: uv run ruff check
# - name: Run mypy
# run: uv run mypy .
- name: Minimize uv cache
run: uv cache prune --ci
50 changes: 50 additions & 0 deletions .github/workflows/python_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: Python tests

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
build-langchain:
name: LangChain Unit Tests - Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./langchain
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "langchain/uv.lock"
- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Restore uv cache
uses: actions/cache@v4
with:
path: /tmp/.uv-cache
key: uv-langchain-${{ hashFiles('langchain/uv.lock') }}
restore-keys: |
uv-langchain-${{ hashFiles('langchain/uv.lock') }}
- name: Install the project
run: uv sync --dev
- name: Run unit tests
env:
VECTORIZE_TOKEN: ${{ secrets.VECTORIZE_TOKEN }}
VECTORIZE_ORG: ${{ secrets.VECTORIZE_ORG }}
VECTORIZE_ENV: dev
run: uv run pytest tests -vv
- name: Minimize uv cache
run: uv cache prune --ci
22 changes: 20 additions & 2 deletions langchain/langchain_vectorize/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional

import vectorize_client
from langchain_core.documents import Document
Expand Down Expand Up @@ -40,6 +40,8 @@ class VectorizeRetriever(BaseRetriever):

api_token: str
"""The Vectorize API token."""
environment: Literal["prod", "dev", "local", "staging"] = "prod"
"""The Vectorize API environment."""
organization: Optional[str] = None # noqa: UP007
"""The Vectorize organization ID."""
pipeline_id: Optional[str] = None # noqa: UP007
Expand All @@ -55,7 +57,23 @@ class VectorizeRetriever(BaseRetriever):

@override
def model_post_init(self, /, context: Any) -> None:
api = ApiClient(Configuration(access_token=self.api_token))
header_name = None
header_value = None
if self.environment == "prod":
host = "https://api.vectorize.io/v1"
elif self.environment == "dev":
host = "https://api-dev.vectorize.io/v1"
elif self.environment == "local":
host = "http://localhost:3000/api"
header_name = "x-lambda-api-key"
header_value = self.api_token
else:
host = "https://api-staging.vectorize.io/v1"
api = ApiClient(
Configuration(host=host, access_token=self.api_token, debug=True),
header_name,
header_value,
)
self._pipelines = PipelinesApi(api)

@staticmethod
Expand Down
90 changes: 61 additions & 29 deletions langchain/tests/test_retrievers.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import json
import logging
import os
import time
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import pytest
import urllib3
import vectorize_client as v
from vectorize_client import ApiClient

from langchain_vectorize.retrievers import VectorizeRetriever


@dataclass
class TestContext:
api_client: v.ApiClient
api_token: str
org_id: str


@pytest.fixture(scope="session")
def api_token() -> str:
token = os.getenv("VECTORIZE_TOKEN")
if not token:
msg = "Please set VECTORIZE_TOKEN environment variable"
msg = "Please set the VECTORIZE_TOKEN environment variable"
raise ValueError(msg)
return token

Expand All @@ -31,21 +27,29 @@ def api_token() -> str:
def org_id() -> str:
org = os.getenv("VECTORIZE_ORG")
if not org:
msg = "Please set VECTORIZE_ORG environment variable"
msg = "Please set the VECTORIZE_ORG environment variable"
raise ValueError(msg)
return org


@pytest.fixture(scope="session")
def api_client(api_token: str) -> Iterator[TestContext]:
def environment() -> Literal["prod", "dev", "local", "staging"]:
env = os.getenv("VECTORIZE_ENV", "prod")
if env not in ["prod", "dev", "local", "staging"]:
msg = "Invalid VECTORIZE_ENV environment variable."
raise ValueError(msg)
return env


@pytest.fixture(scope="session")
def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
header_name = None
header_value = None
if env == "prod":
if environment == "prod":
host = "https://api.vectorize.io/v1"
elif env == "dev":
elif environment == "dev":
host = "https://api-dev.vectorize.io/v1"
elif env == "local":
elif environment == "local":
host = "http://localhost:3000/api"
header_name = "x-lambda-api-key"
header_value = api_token
Expand Down Expand Up @@ -87,8 +91,6 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
),
)

import urllib3

http = urllib3.PoolManager()
this_dir = Path(__file__).parent
file_path = this_dir / "research.pdf"
Expand Down Expand Up @@ -137,7 +139,9 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
config={},
),
ai_platform=v.AIPlatformSchema(
id=builtin_ai_platform, type=v.AIPlatformType.VECTORIZE, config=v.AIPlatformConfigSchema()
id=builtin_ai_platform,
type=v.AIPlatformType.VECTORIZE,
config=v.AIPlatformConfigSchema(),
),
pipeline_name="Test pipeline",
schedule=v.ScheduleSchema(type=v.ScheduleSchemaType.MANUAL),
Expand All @@ -154,20 +158,48 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
logging.exception("Failed to delete pipeline %s", pipeline_id)


def test_retrieve_init_args(api_token: str, org_id: str, pipeline_id: str) -> None:
def test_retrieve_init_args(
environment: Literal["prod", "dev", "local", "staging"],
api_token: str,
org_id: str,
pipeline_id: str,
) -> None:
retriever = VectorizeRetriever(
api_token=api_token, organization=org_id, pipeline_id=pipeline_id, num_results=2
)
docs = retriever.invoke(input="What are you?")
assert len(docs) == 2


def test_retrieve_invoke_args(api_token: str, org_id: str, pipeline_id: str) -> None:
retriever = VectorizeRetriever(api_token=api_token)
docs = retriever.invoke(
input="What are you?",
environment=environment,
api_token=api_token,
organization=org_id,
pipeline_id=pipeline_id,
num_results=2,
)
assert len(docs) == 2
start = time.time()
while True:
docs = retriever.invoke(input="What are you?")
if len(docs) == 2:
break
if time.time() - start > 180:
msg = "Docs not retrieved in time"
raise RuntimeError(msg)
time.sleep(1)


def test_retrieve_invoke_args(
environment: Literal["prod", "dev", "local", "staging"],
api_token: str,
org_id: str,
pipeline_id: str,
) -> None:
retriever = VectorizeRetriever(environment=environment, api_token=api_token)
start = time.time()
while True:
docs = retriever.invoke(
input="What are you?",
organization=org_id,
pipeline_id=pipeline_id,
num_results=2,
)
if len(docs) == 2:
break
if time.time() - start > 180:
msg = "Docs not retrieved in time"
raise RuntimeError(msg)
time.sleep(1)
14 changes: 0 additions & 14 deletions langchain/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.