diff --git a/.gitignore b/.gitignore index 61692d1..4e187c7 100644 --- a/.gitignore +++ b/.gitignore @@ -135,5 +135,4 @@ dmypy.json data/*.json data/*.jsonl dbs/meilisearch/meili_data -*/*/onnx_model/onnx -*/*/lancedb/*.lance \ No newline at end of file +dbs/lancedb/winemag \ No newline at end of file diff --git a/dbs/lancedb/.env.example b/dbs/lancedb/.env.example index 7f99fc7..19d865c 100644 --- a/dbs/lancedb/.env.example +++ b/dbs/lancedb/.env.example @@ -1,4 +1,4 @@ -LANCEDB_DIR = "lancedb" +LANCEDB_DIR = "winemag" API_PORT = 8006 EMBEDDING_MODEL_CHECKPOINT = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" diff --git a/dbs/lancedb/Dockerfile b/dbs/lancedb/Dockerfile index 2a47f56..2182431 100644 --- a/dbs/lancedb/Dockerfile +++ b/dbs/lancedb/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim-bullseye +FROM python:3.11-slim-bullseye WORKDIR /wine @@ -7,6 +7,7 @@ COPY ./requirements.txt /wine/requirements.txt RUN pip install --no-cache-dir -U pip wheel setuptools RUN pip install --no-cache-dir -r /wine/requirements.txt +COPY ./winemag /wine/winemag COPY ./api /wine/api COPY ./schemas /wine/schemas diff --git a/dbs/lancedb/api/main.py b/dbs/lancedb/api/main.py index 59bcac1..77250d8 100644 --- a/dbs/lancedb/api/main.py +++ b/dbs/lancedb/api/main.py @@ -4,13 +4,12 @@ import lancedb from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from sentence_transformers import SentenceTransformer from api.config import Settings from api.routers.rest import router -model_type = "sbert" - @lru_cache() def get_settings(): @@ -24,9 +23,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: settings = get_settings() model_checkpoint = settings.embedding_model_checkpoint app.model = SentenceTransformer(model_checkpoint) - app.model_type = "sbert" # Define LanceDB client - db = lancedb.connect("./lancedb") + db = lancedb.connect("./winemag") app.table = db.open_table("wines") print("Successfully connected to LanceDB") yield @@ -52,3 +50,11 @@ async def root(): # Attach routes app.include_router(router, prefix="/wine", tags=["wine"]) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:8000"], + allow_methods=["GET"], + allow_headers=["*"], +) diff --git a/dbs/lancedb/docker-compose.yml b/dbs/lancedb/docker-compose.yml index bde7b26..97b743b 100644 --- a/dbs/lancedb/docker-compose.yml +++ b/dbs/lancedb/docker-compose.yml @@ -2,7 +2,6 @@ version: "3.9" services: fastapi: - platform: linux/x86_64 image: lancedb_wine_fastapi:${TAG} build: context: . @@ -12,6 +11,6 @@ services: - .env ports: - ${API_PORT}:8000 - volumes: - - ./:/winex - command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload \ No newline at end of file + environment: + - LANCEDB_CONFIG_DIR=/wine + command: uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload diff --git a/dbs/lancedb/requirements.txt b/dbs/lancedb/requirements.txt index ba8a6ce..1745552 100644 --- a/dbs/lancedb/requirements.txt +++ b/dbs/lancedb/requirements.txt @@ -1,10 +1,10 @@ -lancedb~=0.2.0 +lancedb~=0.3.0 transformers~=4.28.0 sentence-transformers~=2.2.0 pydantic~=2.3.0 pydantic-settings~=2.0.0 python-dotenv>=1.0.0 -fastapi~=0.100.0 +fastapi~=0.104.0 httpx>=0.24.0 aiohttp>=3.8.4 uvicorn>=0.21.0, <1.0.0 diff --git a/dbs/lancedb/scripts/bulk_index_sbert.py b/dbs/lancedb/scripts/bulk_index_sbert.py index 29b862f..5914745 100644 --- a/dbs/lancedb/scripts/bulk_index_sbert.py +++ b/dbs/lancedb/scripts/bulk_index_sbert.py @@ -12,12 +12,13 @@ from codetiming import Timer from dotenv import load_dotenv from lancedb.pydantic import pydantic_to_schema +from sentence_transformers import SentenceTransformer from tqdm import tqdm sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) from api.config import Settings from schemas.wine import LanceModelWine, Wine -from sentence_transformers import SentenceTransformer + load_dotenv() # Custom types @@ -97,7 +98,7 @@ def embed_batches(tbl: str, validated_data: list[JsonBlob]) -> pd.DataFrame: def main(data: list[JsonBlob]) -> None: - DB_NAME = "../lancedb" + DB_NAME = f"../{get_settings().lancedb_dir}" TABLE = "wines" db = lancedb.connect(DB_NAME) @@ -115,7 +116,7 @@ def main(data: list[JsonBlob]) -> None: with Timer(name="Create index", text="Created IVF-PQ index in {:.4f} sec"): # Creating index (choose num partitions as a power of 2 that's closest to len(dataset) // 5000) # In this case, we have 130k datapoints, so the nearest power of 2 is 130000//5000 ~ 32) - tbl.create_index(metric="cosine", num_partitions=32, num_sub_vectors=96) + tbl.create_index(metric="cosine", num_partitions=4, num_sub_vectors=32) if __name__ == "__main__": diff --git a/dbs/qdrant/.env.example b/dbs/qdrant/.env.example index 300cf76..f788749 100644 --- a/dbs/qdrant/.env.example +++ b/dbs/qdrant/.env.example @@ -1,8 +1,8 @@ -QDRANT_VERSION = "v1.3.0" +QDRANT_VERSION = "v1.6.1" QDRANT_PORT = 6333 QDRANT_HOST = "localhost" QDRANT_SERVICE = "qdrant" -API_PORT = 8005 +API_PORT = 8000 EMBEDDING_MODEL_CHECKPOINT = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" ONNX_MODEL_FILENAME = "model_optimized_quantized.onnx" diff --git a/dbs/qdrant/api/main.py b/dbs/qdrant/api/main.py index 17463ad..56fcd27 100644 --- a/dbs/qdrant/api/main.py +++ b/dbs/qdrant/api/main.py @@ -3,6 +3,7 @@ from functools import lru_cache from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from qdrant_client import QdrantClient from sentence_transformers import SentenceTransformer @@ -26,7 +27,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.model = SentenceTransformer(model_checkpoint) app.model_type = "sbert" # Define Qdrant client - app.client = QdrantClient(host=settings.qdrant_service, port=settings.qdrant_port) + app.client = QdrantClient(host=settings.qdrant_service, port=settings.qdrant_port, timeout=None) print("Successfully connected to Qdrant") yield print("Successfully closed Qdrant connection and released resources") @@ -49,5 +50,14 @@ async def root(): } +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], +) + # Attach routes app.include_router(rest.router, prefix="/wine", tags=["wine"]) diff --git a/dbs/qdrant/requirements.txt b/dbs/qdrant/requirements.txt index 399eda5..83dc155 100644 --- a/dbs/qdrant/requirements.txt +++ b/dbs/qdrant/requirements.txt @@ -1,10 +1,10 @@ -qdrant-client~=1.3.0 -transformers~=4.28.0 +qdrant-client~=1.6.0 +transformers~=4.33.0 sentence-transformers~=2.2.0 -pydantic~=2.0.0 +pydantic~=2.4.0 pydantic-settings>=2.0.0 python-dotenv>=1.0.0 -fastapi~=0.100.0 +fastapi~=0.104.0 httpx>=0.24.0 aiohttp>=3.8.4 uvloop>=0.17.0