diff --git a/.env.example b/.env.example index 06a80f89b..3af1b716c 100644 --- a/.env.example +++ b/.env.example @@ -71,6 +71,8 @@ OPENAI_VISION_MODEL=qwen/qwen3-vl-32b-instruct # Use a non-thinking model to avoid JSON truncation issues with reasoning models # BAML_LLM_MODEL=deepseek/deepseek-v3.2 EMBEDDING_DIMENSIONS=2560 +# Override embedding dimensions for crawler search (defaults to EMBEDDING_DIMENSIONS) +CRAWLER_EMBEDDING_DIMENSIONS=1536 # ============================================================================ # OPTIONAL: Database Configuration diff --git a/compose.yml b/compose.yml index 27004d355..906146777 100644 --- a/compose.yml +++ b/compose.yml @@ -26,7 +26,7 @@ services: # ============================================================================ - # Tale DB (TimescaleDB) + # Tale DB (ParadeDB — pg_search + pgvector) # ============================================================================ db: # Image from GHCR (used when PULL_POLICY=always) @@ -148,6 +148,10 @@ services: # cpus: '1' # memory: 2G + # Dependencies + depends_on: + - db + # Volume mounts # Persist crawler data (website registry + per-site URL databases) volumes: diff --git a/services/crawler/app/config.py b/services/crawler/app/config.py index 269ce4865..dafe612d2 100644 --- a/services/crawler/app/config.py +++ b/services/crawler/app/config.py @@ -43,6 +43,13 @@ class Settings(BaseSettings): # Concurrency for Vision processing vision_max_concurrent_pages: int = 3 + # Database configuration + database_url: str | None = None + + # Embedding model configuration + openai_embedding_model: str | None = None + embedding_dimensions: int | None = None + model_config = SettingsConfigDict( env_prefix="CRAWLER_", env_file=".env", @@ -76,6 +83,26 @@ def get_fast_model(self) -> str: raise ValueError("OPENAI_FAST_MODEL must be set in environment.") return model + def get_embedding_model(self) -> str: + """Get embedding model from CRAWLER_OPENAI_EMBEDDING_MODEL or OPENAI_EMBEDDING_MODEL.""" + 
def get_embedding_dimensions(self) -> int:
    """Return the embedding vector size used by crawler search.

    Resolution order: the ``embedding_dimensions`` setting (loaded with the
    ``CRAWLER_`` env prefix, i.e. ``CRAWLER_EMBEDDING_DIMENSIONS``), then the
    shared ``EMBEDDING_DIMENSIONS`` environment variable.

    Returns:
        The configured number of dimensions.

    Raises:
        ValueError: If no value is configured, or the configured value is
            not a positive integer.
    """
    dims = self.embedding_dimensions
    source = "CRAWLER_EMBEDDING_DIMENSIONS"
    if dims is None:
        source = "EMBEDDING_DIMENSIONS"
        raw = os.environ.get(source)
        # A blank entry (e.g. `EMBEDDING_DIMENSIONS=` in .env) counts as unset
        # instead of surfacing as an opaque int() parsing error.
        if raw is not None and raw.strip():
            try:
                dims = int(raw)
            except ValueError:
                raise ValueError(f"{source} must be an integer, got {raw!r}.") from None
    if dims is None:
        raise ValueError("EMBEDDING_DIMENSIONS must be set in environment.")
    if dims <= 0:
        # A vector column cannot have zero or negative width; fail at the
        # configuration boundary rather than deep inside indexing.
        raise ValueError(f"{source} must be a positive integer, got {dims}.")
    return dims
This module follows Clean Architecture principles: - main.py: Application setup, configuration, and router registration @@ -27,16 +27,17 @@ crawler_router, docx_router, image_router, + index_router, + pages_router, pdf_router, pptx_router, + search_router, web_router, websites_router, ) from app.services.crawler_service import get_crawler_service from app.services.image_service import get_image_service from app.services.pdf_service import get_pdf_service -from app.services.scheduler import run_scheduler -from app.services.website_store import get_website_store_manager @asynccontextmanager @@ -54,11 +55,36 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.info("Crawler service initialized successfully") except Exception: logger.exception("Failed to initialize crawler service") - # Don't fail startup - allow lazy initialization + + # Initialize PostgreSQL connection pool + search services + from app.services.database import close_pool, init_pool + from app.services.embedding_service import get_embedding_service + from app.services.indexing_service import IndexingService + from app.services.pg_website_store import PgWebsiteStoreManager + from app.services.scheduler import run_scheduler + from app.services.search_service import SearchService + + pool = await init_pool() + pg_store_manager = PgWebsiteStoreManager(pool) + embedding_service = get_embedding_service() + indexing_service = IndexingService(pool, embedding_service) + search_service = SearchService(pool, embedding_service) + + # Wire services into routers + from app.routers.index import set_indexing_service + from app.routers.search import set_search_service + + set_search_service(search_service) + set_indexing_service(indexing_service) + + # Store references for scheduler and other routers + app.state.pg_store_manager = pg_store_manager + app.state.indexing_service = indexing_service + + logger.info("PostgreSQL pool + search services initialized") # Start background scheduler - store_manager = 
@field_validator("domain")
@classmethod
def normalize_domain(cls, v: str) -> str:
    """Reduce the submitted value to a bare hostname.

    Accepts a full URL (``https://docs.example.com/guide``) as well as
    scheme-less input (``Docs.Example.com/guide``); in both cases the
    protocol, port, path, query and surrounding whitespace are dropped.
    ``urlparse(...).hostname`` also lower-cases the result, so both input
    forms normalize to the same value.

    Args:
        v: Raw ``domain`` value from the request body.

    Returns:
        The hostname, or the stripped input when no hostname can be parsed.
    """
    candidate = v.strip()
    # urlparse only recognises a host when an authority marker ("//") is
    # present, so scheme-less input gets one prepended before parsing.
    target = candidate if "://" in candidate else f"//{candidate}"
    return urlparse(target).hostname or candidate
= 0 + has_more: bool = False + + +class PageChunkItem(BaseModel): + """A single chunk from a page.""" + + chunk_index: int + chunk_content: str + + +class PageChunksResponse(BaseModel): + """Response containing all chunks for a specific page.""" + + url: str + domain: str + chunks: list[PageChunkItem] = Field(default_factory=list) + total: int = 0 + + +# ==================== Indexing Models ==================== + + +class IndexPageRequest(BaseModel): + """Request to index a single page.""" + + domain: str = Field(..., description="Website domain") + url: str = Field(..., description="Page URL") + title: str | None = Field(None, description="Page title") + content: str = Field(..., description="Page content to index") + + +class IndexPageResponse(BaseModel): + """Response from indexing a single page.""" + + success: bool + url: str + chunks_indexed: int + status: str + error: str | None = None + + +class IndexWebsiteResponse(BaseModel): + """Response from indexing all pages for a website.""" + + success: bool + domain: str + pages_indexed: int + pages_skipped: int + pages_failed: int + total_chunks: int diff --git a/services/crawler/app/routers/__init__.py b/services/crawler/app/routers/__init__.py index 32a17be9a..4cbfb3ef7 100644 --- a/services/crawler/app/routers/__init__.py +++ b/services/crawler/app/routers/__init__.py @@ -4,6 +4,9 @@ This package contains modular routers following Clean Architecture principles: - crawler: Content fetching and URL check endpoints (/api/v1/urls) - websites: Website registration and URL listing (/api/v1/websites) +- search: Hybrid full-text + vector search (/api/v1/search) +- pages: List indexed pages per website (/api/v1/pages) +- index: Content indexing management (/api/v1/index) - pdf: PDF conversion and parsing (/api/v1/pdf) - image: Image conversion (/api/v1/images) - docx: DOCX document generation and parsing (/api/v1/docx) @@ -14,8 +17,11 @@ from app.routers.crawler import router as crawler_router from app.routers.docx 
import router as docx_router from app.routers.image import router as image_router +from app.routers.index import router as index_router +from app.routers.pages import router as pages_router from app.routers.pdf import router as pdf_router from app.routers.pptx import router as pptx_router +from app.routers.search import router as search_router from app.routers.web import router as web_router from app.routers.websites import router as websites_router @@ -23,8 +29,11 @@ "crawler_router", "docx_router", "image_router", + "index_router", + "pages_router", "pdf_router", "pptx_router", + "search_router", "web_router", "websites_router", ] diff --git a/services/crawler/app/routers/crawler.py b/services/crawler/app/routers/crawler.py index e9143bd3a..51c60933c 100644 --- a/services/crawler/app/routers/crawler.py +++ b/services/crawler/app/routers/crawler.py @@ -4,7 +4,7 @@ from typing import Annotated -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Request from loguru import logger from pydantic import HttpUrl @@ -14,13 +14,12 @@ PageContent, ) from app.services.crawler_service import get_crawler_service -from app.services.website_store import get_website_store_manager router = APIRouter(prefix="/api/v1/urls", tags=["Crawler"]) @router.post("/fetch", response_model=FetchUrlsResponse) -async def fetch_urls(request: FetchUrlsRequest): +async def fetch_urls(request: FetchUrlsRequest, http_request: Request): """ Fetch content from a list of specific URLs. @@ -28,8 +27,8 @@ async def fetch_urls(request: FetchUrlsRequest): falling back to live crawling for cache misses. 
""" try: - store_manager = get_website_store_manager() - cached, urls_to_crawl = store_manager.get_cached_pages(request.urls) + store_manager = http_request.app.state.pg_store_manager + cached, urls_to_crawl = await store_manager.get_cached_pages(request.urls) # Filter cached pages by word_count_threshold threshold = request.word_count_threshold diff --git a/services/crawler/app/routers/index.py b/services/crawler/app/routers/index.py new file mode 100644 index 000000000..46d849a4e --- /dev/null +++ b/services/crawler/app/routers/index.py @@ -0,0 +1,74 @@ +""" +Index Router — Content indexing management endpoints. +""" + +from fastapi import APIRouter, HTTPException +from loguru import logger + +from app.models import IndexPageRequest, IndexPageResponse, IndexWebsiteResponse +from app.services.indexing_service import IndexingService + +router = APIRouter(prefix="/api/v1/index", tags=["Indexing"]) + +_indexing_service: IndexingService | None = None + + +def set_indexing_service(service: IndexingService) -> None: + global _indexing_service + _indexing_service = service + + +def _get_indexing_service() -> IndexingService: + if _indexing_service is None: + raise HTTPException(status_code=503, detail="Indexing service not initialized") + return _indexing_service + + +@router.post("/page", response_model=IndexPageResponse) +async def index_page(request: IndexPageRequest): + """Index a single page (chunk + embed + store).""" + try: + service = _get_indexing_service() + result = await service.index_page( + domain=request.domain, + url=request.url, + title=request.title, + content=request.content, + ) + return IndexPageResponse( + success=result["status"] in ("indexed", "skipped"), + url=result["url"], + chunks_indexed=result["chunks_indexed"], + status=result["status"], + error=result.get("error"), + ) + except HTTPException: + raise + except Exception: + logger.exception(f"Indexing failed for {request.url}") + raise HTTPException(status_code=500, detail="Indexing failed") 
@router.post("/website/{domain}", response_model=IndexWebsiteResponse)
async def index_website(domain: str):
    """Re-index all pages of one website on demand.

    This endpoint is for manual re-indexing only; website status updates are
    handled by the scheduler during automated scans.

    Raises:
        HTTPException: 503 when the indexing service has not been wired in,
            500 for any other failure (details are logged, not returned).
    """
    counter_keys = ("pages_indexed", "pages_skipped", "pages_failed", "total_chunks")
    try:
        summary = await _get_indexing_service().index_website(domain)
        counters = {key: summary[key] for key in counter_keys}
        return IndexWebsiteResponse(success=True, domain=summary["domain"], **counters)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 503 from the service getter).
        raise
    except Exception:
        logger.exception(f"Website indexing failed for {domain}")
        raise HTTPException(status_code=500, detail="Website indexing failed") from None
@router.get("/{domain}", response_model=PageListResponse)
async def list_pages(
    domain: str,
    offset: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    status: str | None = Query(None, description="Filter by status (discovered, active, deleted, failed)"),
    sort: str = Query("last_crawled_at", description="Sort field (last_crawled_at, discovered_at, word_count)"),
):
    """List crawled pages (rows with a content hash) for a website.

    Each page carries the number of chunks stored for it; ``indexed`` is true
    when that count is non-zero.

    Args:
        domain: Website whose pages are listed.
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.
        status: Optional exact-match filter on the page status.
        sort: Sort key; unknown values fall back to ``last_crawled_at``.

    Raises:
        HTTPException: 500 when the query fails (details are logged).
    """
    try:
        pool = get_pool()

        # Whitelist of sortable columns: the value is interpolated into the
        # SQL text, so user input must never reach the query directly.
        sort_columns = {
            "last_crawled_at": "wu.last_crawled_at",
            "discovered_at": "wu.discovered_at",
            "word_count": "wu.word_count",
        }
        sort_col = sort_columns.get(sort, "wu.last_crawled_at")

        # Only fixed fragments and positional placeholders go into the
        # WHERE clause; the values travel separately as bind parameters.
        conditions = ["wu.domain = $1", "wu.content_hash IS NOT NULL"]
        filter_params: list = [domain]
        if status:
            filter_params.append(status)
            conditions.append(f"wu.status = ${len(filter_params)}")
        where_clause = " AND ".join(conditions)
        limit_idx = len(filter_params) + 1

        async with pool.acquire() as conn:
            # The chunk-count subquery is restricted to the requested domain
            # ($1) so it neither aggregates the whole chunks table nor counts
            # chunks that belong to another website. The trailing `wu.url`
            # tiebreaker keeps OFFSET pagination stable when sort values tie
            # or are NULL.
            rows = await conn.fetch(
                f"""SELECT wu.url, wu.title, wu.word_count, wu.status, wu.content_hash,
                           wu.last_crawled_at, wu.discovered_at,
                           COALESCE(c.chunks_count, 0) AS chunks_count
                    FROM website_urls wu
                    LEFT JOIN (
                        SELECT url, COUNT(*) AS chunks_count
                        FROM chunks
                        WHERE domain = $1
                        GROUP BY url
                    ) c ON c.url = wu.url
                    WHERE {where_clause}
                    ORDER BY {sort_col} DESC NULLS LAST, wu.url ASC
                    LIMIT ${limit_idx} OFFSET ${limit_idx + 1}""",
                *filter_params,
                limit,
                offset,
            )

            # Total uses the same filters but no LIMIT/OFFSET parameters.
            total = await conn.fetchval(
                f"SELECT COUNT(*) FROM website_urls wu WHERE {where_clause}",
                *filter_params,
            )

        pages = [
            PageListItem(
                url=r["url"],
                title=r["title"],
                word_count=r["word_count"] or 0,
                status=r["status"],
                content_hash=r["content_hash"],
                last_crawled_at=r["last_crawled_at"].isoformat() if r["last_crawled_at"] else None,
                discovered_at=r["discovered_at"].isoformat() if r["discovered_at"] else None,
                chunks_count=r["chunks_count"],
                indexed=r["chunks_count"] > 0,
            )
            for r in rows
        ]

        return PageListResponse(
            domain=domain,
            pages=pages,
            total=total,
            offset=offset,
            has_more=offset + limit < total,
        )
    except HTTPException:
        raise
    except Exception:
        logger.exception(f"Error listing pages for {domain}")
        raise HTTPException(status_code=500, detail="Failed to list pages") from None
@router.post("", response_model=SearchResponse)
async def search_all(request: SearchRequest):
    """Search across the indexed content of every website.

    Raises:
        HTTPException: 503 when the search service has not been wired in,
            500 for any other failure (details are logged, not returned).
    """
    try:
        hits = await _get_search_service().search(query=request.query, limit=request.limit)

        items = []
        for hit in hits:
            items.append(
                SearchResultItem(
                    url=hit.url,
                    title=hit.title,
                    chunk_content=hit.chunk_content,
                    chunk_index=hit.chunk_index,
                    score=hit.score,
                )
            )

        return SearchResponse(query=request.query, results=items, total=len(hits))
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 503 from the service getter).
        raise
    except Exception:
        logger.exception("Search failed")
        raise HTTPException(status_code=500, detail="Search failed") from None
detail="Search failed") from None diff --git a/services/crawler/app/routers/websites.py b/services/crawler/app/routers/websites.py index f31d76296..a2cea483a 100644 --- a/services/crawler/app/routers/websites.py +++ b/services/crawler/app/routers/websites.py @@ -2,36 +2,160 @@ Websites Router — Website registration and URL listing endpoints. """ -from fastapi import APIRouter, HTTPException, Query +import asyncio +import hashlib +import json +from datetime import UTC, datetime + +from fastapi import APIRouter, HTTPException, Query, Request from loguru import logger -from app.models import RegisterWebsiteRequest, WebsiteUrl, WebsiteUrlsResponse -from app.services.scheduler import trigger_scan -from app.services.website_store import get_website_store_manager +from app.models import RegisterWebsiteRequest, WebsiteInfoResponse, WebsiteUrl, WebsiteUrlsResponse +from app.services.crawler_service import get_crawler_service +from app.services.pg_website_store import PgWebsiteStoreManager +from app.services.scheduler import cancel_scan, trigger_scan +from app.utils.metadata import extract_meta_description router = APIRouter(prefix="/api/v1/websites", tags=["Websites"]) -@router.post("") -async def register_website(request: RegisterWebsiteRequest): +def _get_manager(request: Request) -> PgWebsiteStoreManager: + return request.app.state.pg_store_manager + + +def _format_timestamp(val) -> str | None: + if val is None: + return None + if isinstance(val, datetime): + return val.isoformat() + if isinstance(val, (int, float)): + return datetime.fromtimestamp(val, tz=UTC).isoformat() + return str(val) + + +async def _initialize_website(domain: str, manager: PgWebsiteStoreManager): + """Background task: crawl homepage + discover URLs concurrently.""" + crawler_service = get_crawler_service() + if not crawler_service.initialized: + await crawler_service.initialize() + + site_store = manager.get_site_store(domain) + + async def _crawl_homepage(): + homepage_url = f"https://{domain}/" 
# Strong references to in-flight initialization tasks. The event loop keeps
# only weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_background_init_tasks: set[asyncio.Task] = set()


@router.post("", response_model=WebsiteInfoResponse)
async def register_website(request: RegisterWebsiteRequest, http_request: Request):
    """Register a website and start its initial crawl in the background.

    The registration is stored first; homepage crawling and URL discovery
    then run as a background task, and the scheduler is nudged via
    ``trigger_scan()``. The response is returned immediately with status
    ``"scanning"`` and does not wait for the background work.

    Raises:
        HTTPException: 500 when registration fails (details are logged).
    """
    domain = request.domain
    try:
        manager = _get_manager(http_request)
        await manager.register_website(
            domain=domain,
            scan_interval=request.scan_interval,
        )

        def _on_init_done(t: asyncio.Task) -> None:
            # Release the strong reference once the task has finished.
            _background_init_tasks.discard(t)
            if not t.cancelled() and (exc := t.exception()):
                # Attach the exception so the traceback reaches the log.
                logger.opt(exception=exc).error(f"Website initialization failed for {domain}: {exc}")

        # Fire-and-forget: crawl homepage + discover URLs concurrently.
        task = asyncio.create_task(_initialize_website(domain, manager))
        _background_init_tasks.add(task)
        task.add_done_callback(_on_init_done)
        trigger_scan()

        return WebsiteInfoResponse(
            domain=domain,
            status="scanning",
            scan_interval=request.scan_interval,
        )
    except HTTPException:
        # Same pass-through the sibling handlers in this router use.
        raise
    except Exception:
        logger.exception("Error registering website")
        raise HTTPException(status_code=500, detail="Failed to register website") from None
@router.get("/{domain}/urls", response_model=WebsiteUrlsResponse) async def get_website_urls( domain: str, + http_request: Request, offset: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), status: str | None = Query(None), ): try: - manager = get_website_store_manager() - website = manager.get_website(domain) + manager = _get_manager(http_request) + website = await manager.get_website(domain) + if not website: raise HTTPException(status_code=404, detail=f"Website not found: {domain}") site_store = manager.get_site_store(domain) - urls_data = site_store.get_urls_page(offset=offset, limit=limit, status=status) - total = site_store.get_total_count(status=status) + urls_data = await site_store.get_urls_page(offset=offset, limit=limit, status=status) + total = await site_store.get_total_count(status=status) urls = [ WebsiteUrl( diff --git a/services/crawler/app/services/chunking_service.py b/services/crawler/app/services/chunking_service.py new file mode 100644 index 000000000..682d8607d --- /dev/null +++ b/services/crawler/app/services/chunking_service.py @@ -0,0 +1,59 @@ +""" +Markdown-aware content chunking for search indexing. + +Uses semantic-text-splitter's MarkdownSplitter to split at structural +boundaries (headers, code blocks, tables, paragraphs, sentences) while +respecting a target chunk size. Page title and URL are injected as +metadata prefix into every chunk. 
+""" + +from dataclasses import dataclass + +from semantic_text_splitter import MarkdownSplitter + +CHUNK_SIZE = 2048 +CHUNK_OVERLAP = 200 +MIN_CHUNK_LENGTH = 50 + + +@dataclass +class ContentChunk: + content: str + index: int + + +def _build_prefix(title: str | None, url: str | None) -> str: + parts: list[str] = [] + if title and title.strip(): + parts.append(title.strip()) + if url and url.strip(): + parts.append(url.strip()) + return "\n\n".join(parts) + "\n\n" if parts else "" + + +def chunk_content( + content: str, + title: str | None = None, + url: str | None = None, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = CHUNK_OVERLAP, + min_chunk_length: int = MIN_CHUNK_LENGTH, +) -> list[ContentChunk]: + if not content or not content.strip(): + return [] + + prefix = _build_prefix(title, url) + effective_size = max(chunk_size - len(prefix), min_chunk_length) + + splitter = MarkdownSplitter(effective_size, overlap=chunk_overlap) + raw_chunks = splitter.chunks(content.strip()) + + chunks: list[ContentChunk] = [] + idx = 0 + for raw in raw_chunks: + text = (prefix + raw).strip() + if len(text) >= min_chunk_length: + chunks.append(ContentChunk(content=text, index=idx)) + idx += 1 + + return chunks diff --git a/services/crawler/app/services/database.py b/services/crawler/app/services/database.py new file mode 100644 index 000000000..fa098d472 --- /dev/null +++ b/services/crawler/app/services/database.py @@ -0,0 +1,90 @@ +""" +Async PostgreSQL connection pool using asyncpg. + +Provides a singleton pool tied to FastAPI's lifespan for the tale_search database. 
+""" + +import os + +import asyncpg +from loguru import logger + +from app.config import settings + +_pool: asyncpg.Pool | None = None + + +def _get_database_url() -> str: + if settings.database_url: + return settings.database_url + if url := os.environ.get("DATABASE_URL"): + return url + password = os.environ.get("DB_PASSWORD", "tale_password_change_me") + return f"postgresql://tale:{password}@db:5432/tale_search" + + +async def init_pool() -> asyncpg.Pool: + global _pool + if _pool is not None: + return _pool + + dsn = _get_database_url() + _pool = await asyncpg.create_pool(dsn, min_size=5, max_size=25) + logger.info("PostgreSQL connection pool initialized") + + # Guard against embedding dimension mismatch: if existing data uses a + # different dimension than the current config, refuse to start. + configured_dims = settings.get_embedding_dimensions() + async with _pool.acquire() as conn: + stored_dims = await conn.fetchval( + "SELECT vector_dims(embedding) FROM chunks WHERE embedding IS NOT NULL LIMIT 1" + ) + if stored_dims is not None and stored_dims != configured_dims: + await _pool.close() + _pool = None + raise RuntimeError( + f"Embedding dimension mismatch: database has {stored_dims}d vectors " + f"but CRAWLER_EMBEDDING_DIMENSIONS={configured_dims}. " + f"Re-index existing data or update the config to match." + ) + + # Pin the embedding column to explicit dimensions so HNSW indexes work. + # The column starts as untyped `vector` because dimensions are configurable; + # once we know the configured value we can lock it in. If the column was + # previously pinned to a different dimension (e.g. config changed while the + # table was empty), re-pin it — the mismatch guard above already ensures any + # existing data is compatible. 
+ expected_type = f"vector({int(configured_dims)})" + async with _pool.acquire() as conn: + col_type = await conn.fetchval( + "SELECT format_type(atttypid, atttypmod) " + "FROM pg_attribute " + "WHERE attrelid = 'chunks'::regclass AND attname = 'embedding'" + ) + if col_type != expected_type: + await conn.execute("DROP INDEX IF EXISTS idx_chunks_embedding_hnsw") + await conn.execute(f"ALTER TABLE chunks ALTER COLUMN embedding TYPE vector({int(configured_dims)})") + logger.info(f"Pinned embedding column to vector({configured_dims}) (was {col_type})") + + # Create HNSW index if it doesn't exist yet. + try: + async with _pool.acquire() as conn: + await conn.execute("SELECT create_chunks_hnsw_index()") + except Exception as e: + logger.warning(f"HNSW index creation deferred: {e}") + + return _pool + + +def get_pool() -> asyncpg.Pool: + if _pool is None: + raise RuntimeError("Database pool not initialized. Call init_pool() first.") + return _pool + + +async def close_pool() -> None: + global _pool + if _pool is not None: + await _pool.close() + _pool = None + logger.info("PostgreSQL connection pool closed") diff --git a/services/crawler/app/services/embedding_service.py b/services/crawler/app/services/embedding_service.py new file mode 100644 index 000000000..9d8828557 --- /dev/null +++ b/services/crawler/app/services/embedding_service.py @@ -0,0 +1,83 @@ +""" +OpenAI-compatible embedding generation service. + +Uses the async OpenAI client to generate embeddings via any OpenAI-compatible API. 
MAX_BATCH_SIZE = 2048
MAX_CONCURRENT_REQUESTS = 3
MAX_RETRIES = 3
RETRY_BASE_DELAY = 1.0


class EmbeddingService:
    """Generate embeddings through an OpenAI-compatible async client."""

    def __init__(self, api_key: str, base_url: str | None, model: str, dimensions: int):
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self._model = model
        self._dimensions = dimensions
        # Bounds concurrent embedding requests made through this instance.
        self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

    @property
    def dimensions(self) -> int:
        """Embedding dimensionality requested from the API."""
        return self._dimensions

    async def _embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed one batch, retrying failed requests with exponential backoff.

        Raises:
            ValueError: If the response holds a different number of vectors
                than inputs; pairing them positionally would be wrong.
            Exception: The last request error once ``MAX_RETRIES`` is exhausted.
        """
        async with self._semaphore:
            for attempt in range(MAX_RETRIES):
                try:
                    response = await self._client.embeddings.create(
                        model=self._model,
                        input=batch,
                        dimensions=self._dimensions,
                    )
                except Exception:
                    if attempt == MAX_RETRIES - 1:
                        raise
                    delay = RETRY_BASE_DELAY * (2**attempt)
                    logger.warning(
                        f"Embedding request failed (attempt {attempt + 1}/{MAX_RETRIES}), retrying in {delay}s"
                    )
                    await asyncio.sleep(delay)
                    continue

                vectors = [item.embedding for item in response.data]
                # Callers zip vectors with their inputs by position, so a
                # short or long response must fail loudly.
                if len(vectors) != len(batch):
                    raise ValueError(f"Embedding API returned {len(vectors)} vectors for {len(batch)} inputs")
                return vectors
        raise RuntimeError("unreachable")

    async def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* in batches of ``MAX_BATCH_SIZE``, preserving order."""
        if not texts:
            return []

        all_embeddings: list[list[float]] = []
        for start in range(0, len(texts), MAX_BATCH_SIZE):
            all_embeddings.extend(await self._embed_batch(texts[start : start + MAX_BATCH_SIZE]))
        return all_embeddings

    async def embed_query(self, query: str) -> list[float]:
        """Embed a single query string."""
        result = await self.embed_texts([query])
        return result[0]


_embedding_service: EmbeddingService | None = None


def get_embedding_service() -> EmbeddingService:
    """Return the process-wide service, building it from settings on first use."""
    global _embedding_service
    if _embedding_service is None:
        api_key = settings.get_openai_api_key()
        base_url = settings.get_openai_base_url()
        model = settings.get_embedding_model()
        dimensions = settings.get_embedding_dimensions()
        _embedding_service = EmbeddingService(
            api_key=api_key,
            base_url=base_url,
            model=model,
            dimensions=dimensions,
        )
        logger.info(f"Embedding service: model={model}, dims={dimensions}")
    return _embedding_service
dimensions=settings.get_embedding_dimensions(), + ) + logger.info( + f"Embedding service: model={settings.get_embedding_model()}, dims={settings.get_embedding_dimensions()}" + ) + return _embedding_service diff --git a/services/crawler/app/services/indexing_service.py b/services/crawler/app/services/indexing_service.py new file mode 100644 index 000000000..9f637b680 --- /dev/null +++ b/services/crawler/app/services/indexing_service.py @@ -0,0 +1,143 @@ +""" +Content indexing pipeline: chunk → embed → store in PostgreSQL. +""" + +import asyncio +import hashlib +import logging + +import asyncpg + +from app.services.chunking_service import chunk_content +from app.services.embedding_service import EmbeddingService + +logger = logging.getLogger(__name__) + +INDEXING_CONCURRENCY = 5 + + +def _sha256(content: str) -> str: + return hashlib.sha256(content.encode()).hexdigest() + + +class IndexingService: + def __init__(self, pool: asyncpg.Pool, embedding_service: EmbeddingService): + self._pool = pool + self._embedding = embedding_service + self._hnsw_ensured = False + + async def index_page(self, domain: str, url: str, title: str | None, content: str) -> dict: + content_hash = _sha256(content) + + # Check if already indexed with same hash + async with self._pool.acquire() as conn: + existing_hash = await conn.fetchval("SELECT content_hash FROM chunks WHERE url = $1 LIMIT 1", url) + if existing_hash == content_hash: + return {"url": url, "status": "skipped", "chunks_indexed": 0} + + # Chunk content + chunks = chunk_content(content, title=title, url=url) + if not chunks: + return {"url": url, "status": "empty", "chunks_indexed": 0} + + # Generate embeddings + texts = [c.content for c in chunks] + try: + embeddings = await self._embedding.embed_texts(texts) + except Exception: + logger.exception(f"Embedding failed for {url}") + return {"url": url, "status": "error", "chunks_indexed": 0, "error": "embedding_failed"} + + # Store in DB (ensure website_urls entry exists, delete 
old chunks → insert new) + async with self._pool.acquire() as conn, conn.transaction(): + await conn.execute( + """INSERT INTO website_urls (domain, url, title, content_hash, status, discovered_at, last_crawled_at) + VALUES ($1, $2, $3, $4, 'active', NOW(), NOW()) + ON CONFLICT (domain, url) DO UPDATE SET + title = COALESCE(EXCLUDED.title, website_urls.title), + content_hash = EXCLUDED.content_hash, + last_crawled_at = NOW()""", + domain, + url, + title, + content_hash, + ) + await conn.execute("DELETE FROM chunks WHERE url = $1", url) + await conn.executemany( + """INSERT INTO chunks (domain, url, title, content_hash, chunk_index, chunk_content, embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7::vector)""", + [ + (domain, url, title, content_hash, chunk.index, chunk.content, str(embeddings[i])) + for i, chunk in enumerate(chunks) + ], + ) + + # Ensure HNSW index exists once embeddings are stored + if not self._hnsw_ensured: + try: + async with self._pool.acquire() as conn: + await conn.execute("SELECT create_chunks_hnsw_index()") + self._hnsw_ensured = True + except Exception as e: + logger.warning("HNSW index creation deferred: %s", e) + + logger.info(f"Indexed {len(chunks)} chunks for {url}") + return {"url": url, "status": "indexed", "chunks_indexed": len(chunks)} + + async def index_website(self, domain: str) -> dict: + indexed = 0 + skipped = 0 + failed = 0 + total_chunks = 0 + sem = asyncio.Semaphore(INDEXING_CONCURRENCY) + page_size = 100 + offset = 0 + + while True: + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """SELECT url, title, content FROM website_urls + WHERE domain = $1 AND content IS NOT NULL + ORDER BY id + LIMIT $2 OFFSET $3""", + domain, + page_size, + offset, + ) + + if not rows: + break + + async def _index_one(row: asyncpg.Record) -> dict: + async with sem: + return await self.index_page(domain, row["url"], row["title"], row["content"]) + + results = await asyncio.gather(*[_index_one(row) for row in rows], 
return_exceptions=True) + + for result in results: + if isinstance(result, Exception): + logger.exception(f"Indexing task failed for {domain}: {result}") + failed += 1 + elif result["status"] == "indexed": + indexed += 1 + total_chunks += result["chunks_indexed"] + elif result["status"] == "skipped": + skipped += 1 + else: + failed += 1 + + offset += page_size + + return { + "domain": domain, + "pages_indexed": indexed, + "pages_skipped": skipped, + "pages_failed": failed, + "total_chunks": total_chunks, + } + + async def delete_page_chunks(self, url: str) -> int: + async with self._pool.acquire() as conn: + result = await conn.execute("DELETE FROM chunks WHERE url = $1", url) + count = int(result.split()[-1]) if result else 0 + return count diff --git a/services/crawler/app/services/pg_website_store.py b/services/crawler/app/services/pg_website_store.py new file mode 100644 index 000000000..854801a72 --- /dev/null +++ b/services/crawler/app/services/pg_website_store.py @@ -0,0 +1,350 @@ +""" +Async PostgreSQL-backed website store, replacing the SQLite multi-DB architecture. + +PgWebsiteStore: per-domain URL operations (scoped by domain column). +PgWebsiteStoreManager: website registry + factory for PgWebsiteStore instances. +""" + +import json +import logging +from datetime import UTC, datetime +from urllib.parse import urlparse + +import asyncpg + +logger = logging.getLogger(__name__) + + +class PgWebsiteStore: + """Per-domain URL operations backed by PostgreSQL.""" + + def __init__(self, pool: asyncpg.Pool, domain: str): + self._pool = pool + self._domain = domain + + async def save_discovered_urls(self, urls: list[dict]) -> int: + """Save discovered URLs. 
Returns number of newly inserted URLs (excludes duplicates).""" + if not urls: + return 0 + + async with self._pool.acquire() as conn: + count_before = await conn.fetchval("SELECT COUNT(*) FROM website_urls WHERE domain = $1", self._domain) + await conn.executemany( + """INSERT INTO website_urls (domain, url, discovered_at) + VALUES ($1, $2, NOW()) + ON CONFLICT (domain, url) DO NOTHING""", + [(self._domain, u["url"]) for u in urls], + ) + count_after = await conn.fetchval("SELECT COUNT(*) FROM website_urls WHERE domain = $1", self._domain) + inserted = count_after - count_before + logger.info(f"Saved discovered URLs for {self._domain}: {inserted} new, {count_after} total") + return inserted + + async def get_urls_page(self, offset: int = 0, limit: int = 100, status: str | None = None) -> list[dict]: + async with self._pool.acquire() as conn: + if status: + rows = await conn.fetch( + """SELECT url, content_hash, status, last_crawled_at + FROM website_urls + WHERE domain = $1 AND content_hash IS NOT NULL AND status = $2 + ORDER BY id LIMIT $3 OFFSET $4""", + self._domain, + status, + limit, + offset, + ) + else: + rows = await conn.fetch( + """SELECT url, content_hash, status, last_crawled_at + FROM website_urls + WHERE domain = $1 AND content_hash IS NOT NULL + ORDER BY id LIMIT $2 OFFSET $3""", + self._domain, + limit, + offset, + ) + return [ + { + "url": r["url"], + "content_hash": r["content_hash"], + "status": r["status"], + "last_crawled_at": r["last_crawled_at"].timestamp() if r["last_crawled_at"] else None, + } + for r in rows + ] + + async def get_urls_needing_recrawl(self, limit: int = 20, crawled_before: float | None = None) -> list[str]: + async with self._pool.acquire() as conn: + if crawled_before is not None: + ts = datetime.fromtimestamp(crawled_before, tz=UTC) + rows = await conn.fetch( + """SELECT url FROM website_urls + WHERE domain = $1 AND status != 'deleted' + AND (last_crawled_at IS NULL OR last_crawled_at < $2) + ORDER BY CASE WHEN 
content_hash IS NULL THEN 0 ELSE 1 END, + last_crawled_at ASC NULLS FIRST + LIMIT $3""", + self._domain, + ts, + limit, + ) + else: + rows = await conn.fetch( + """SELECT url FROM website_urls + WHERE domain = $1 AND status != 'deleted' + ORDER BY CASE WHEN content_hash IS NULL THEN 0 ELSE 1 END, + last_crawled_at ASC NULLS FIRST + LIMIT $2""", + self._domain, + limit, + ) + return [r["url"] for r in rows] + + async def increment_fail_count(self, urls: list[str]) -> None: + if not urls: + return + async with self._pool.acquire() as conn: + await conn.executemany( + """UPDATE website_urls + SET fail_count = fail_count + 1, last_crawled_at = NOW() + WHERE domain = $1 AND url = $2""", + [(self._domain, url) for url in urls], + ) + + async def update_content_hashes(self, updates: list[dict]) -> None: + if not updates: + return + async with self._pool.acquire() as conn: + await conn.executemany( + """UPDATE website_urls + SET content_hash = $3, status = $4, last_crawled_at = NOW(), + title = $5, content = $6, word_count = $7, + metadata = $8::jsonb, structured_data = $9::jsonb, + fail_count = 0 + WHERE domain = $1 AND url = $2""", + [ + ( + self._domain, + u["url"], + u["content_hash"], + u.get("status", "active"), + u.get("title"), + u.get("content"), + u.get("word_count"), + json.dumps(u["metadata"]) if u.get("metadata") else None, + json.dumps(u["structured_data"]) if u.get("structured_data") else None, + ) + for u in updates + ], + ) + + async def mark_urls_deleted(self, urls: list[str]) -> None: + if not urls: + return + async with self._pool.acquire() as conn: + await conn.executemany( + "UPDATE website_urls SET status = 'deleted' WHERE domain = $1 AND url = $2", + [(self._domain, url) for url in urls], + ) + + async def get_cache_headers(self, urls: list[str]) -> dict[str, dict]: + if not urls: + return {} + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """SELECT url, etag, last_modified FROM website_urls + WHERE domain = $1 AND url = 
ANY($2) + AND (etag IS NOT NULL OR last_modified IS NOT NULL)""", + self._domain, + urls, + ) + return {r["url"]: {"etag": r["etag"], "last_modified": r["last_modified"]} for r in rows} + + async def update_cache_headers(self, updates: list[dict]) -> None: + if not updates: + return + async with self._pool.acquire() as conn: + await conn.executemany( + "UPDATE website_urls SET etag = $3, last_modified = $4 WHERE domain = $1 AND url = $2", + [(self._domain, u["url"], u.get("etag"), u.get("last_modified")) for u in updates], + ) + + async def touch_crawled_at(self, urls: list[str]) -> None: + if not urls: + return + async with self._pool.acquire() as conn: + await conn.executemany( + "UPDATE website_urls SET last_crawled_at = NOW() WHERE domain = $1 AND url = $2", + [(self._domain, url) for url in urls], + ) + + async def get_total_count(self, status: str | None = None) -> int: + async with self._pool.acquire() as conn: + if status: + return await conn.fetchval( + """SELECT COUNT(*) FROM website_urls + WHERE domain = $1 AND content_hash IS NOT NULL AND status = $2""", + self._domain, + status, + ) + return await conn.fetchval( + "SELECT COUNT(*) FROM website_urls WHERE domain = $1 AND content_hash IS NOT NULL", + self._domain, + ) + + async def get_cached_pages(self, urls: list[str]) -> list[dict]: + if not urls: + return [] + async with self._pool.acquire() as conn: + rows = await conn.fetch( + """SELECT url, title, content, word_count, metadata, structured_data + FROM website_urls + WHERE domain = $1 AND url = ANY($2) AND content IS NOT NULL""", + self._domain, + urls, + ) + return [ + { + "url": r["url"], + "title": r["title"], + "content": r["content"], + "word_count": r["word_count"] or 0, + "metadata": r["metadata"], + "structured_data": r["structured_data"], + } + for r in rows + ] + + +class PgWebsiteStoreManager: + """Website registry + factory for PgWebsiteStore instances.""" + + def __init__(self, pool: asyncpg.Pool): + self._pool = pool + self._stores: 
class PgWebsiteStoreManager:
    """Website registry + factory for PgWebsiteStore instances."""

    def __init__(self, pool: asyncpg.Pool):
        self._pool = pool
        # One PgWebsiteStore per domain, created lazily by get_site_store().
        self._stores: dict[str, PgWebsiteStore] = {}

    async def register_website(self, domain: str, scan_interval: int = 21600) -> dict:
        """Insert the website, or update its scan interval if it already exists."""
        upsert_sql = """INSERT INTO websites (domain, scan_interval, created_at, updated_at)
                   VALUES ($1, $2, NOW(), NOW())
                   ON CONFLICT(domain) DO UPDATE SET
                       scan_interval = EXCLUDED.scan_interval,
                       updated_at = NOW()"""
        async with self._pool.acquire() as conn:
            await conn.execute(upsert_sql, domain, scan_interval)
        logger.info(f"Registered website: {domain} (interval={scan_interval}s)")
        return {"domain": domain, "scan_interval": scan_interval, "status": "idle"}

    async def update_website_metadata(
        self,
        domain: str,
        title: str | None = None,
        description: str | None = None,
        page_count: int | None = None,
    ) -> None:
        """Update title/description/page_count; ``None`` keeps the stored value."""
        update_sql = """UPDATE websites SET
                       title = COALESCE($2, title),
                       description = COALESCE($3, description),
                       page_count = COALESCE($4, page_count),
                       updated_at = NOW()
                   WHERE domain = $1"""
        async with self._pool.acquire() as conn:
            await conn.execute(update_sql, domain, title, description, page_count)

    async def remove_website(self, domain: str) -> bool:
        """Delete the website row; return True when exactly one row was removed."""
        self._stores.pop(domain, None)
        async with self._pool.acquire() as conn:
            # ON DELETE CASCADE on website_urls and chunks handles child row cleanup
            status_tag = await conn.execute("DELETE FROM websites WHERE domain = $1", domain)
        if status_tag != "DELETE 1":
            return False
        logger.info(f"Removed website: {domain}")
        return True

    async def get_due_websites(self) -> list[dict]:
        """Return websites not currently scanning whose scan interval has elapsed (or never scanned)."""
        async with self._pool.acquire() as conn:
            records = await conn.fetch(
                """SELECT domain, status, scan_interval, last_scanned_at, error
                   FROM websites
                   WHERE status != 'scanning'
                     AND (last_scanned_at IS NULL
                          OR last_scanned_at + make_interval(secs => scan_interval) < NOW())"""
            )
        return [dict(record) for record in records]

    async def update_scan_status(self, domain: str, status: str, error: str | None = None) -> None:
        """Set the website's status and error (the error is overwritten, NULL by default)."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                "UPDATE websites SET status = $2, error = $3, updated_at = NOW() WHERE domain = $1",
                domain,
                status,
                error,
            )

    async def update_last_scanned(self, domain: str) -> None:
        """Stamp ``last_scanned_at`` with the database's current time."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                "UPDATE websites SET last_scanned_at = NOW(), updated_at = NOW() WHERE domain = $1",
                domain,
            )

    async def get_website(self, domain: str) -> dict | None:
        """Return the website row plus URL counts, or ``None`` if it is not registered."""
        async with self._pool.acquire() as conn:
            record = await conn.fetchrow(
                """SELECT w.domain, w.title, w.description, w.page_count, w.status,
                          w.scan_interval, w.last_scanned_at, w.error,
                          w.created_at, w.updated_at,
                          COALESCE(u.total, 0) AS total_urls,
                          COALESCE(u.crawled, 0) AS crawled_count
                   FROM websites w
                   LEFT JOIN LATERAL (
                       SELECT COUNT(*) AS total,
                              COUNT(*) FILTER (WHERE content_hash IS NOT NULL) AS crawled
                       FROM website_urls WHERE domain = w.domain
                   ) u ON true
                   WHERE w.domain = $1""",
                domain,
            )
        if not record:
            return None
        return dict(record)

    def get_site_store(self, domain: str) -> PgWebsiteStore:
        """Return the cached per-domain store, creating it on first request."""
        store = self._stores.get(domain)
        if store is None:
            store = PgWebsiteStore(self._pool, domain)
            self._stores[domain] = store
        return store

    async def get_cached_pages(self, urls: list[str]) -> tuple[list[dict], list[str]]:
        """Split *urls* into (cached page records, URLs that still need crawling).

        URLs are grouped by their netloc; URLs of unregistered websites and
        URLs without stored content go to the second list.
        """
        if not urls:
            return [], []

        urls_by_domain: dict[str, list[str]] = {}
        for url in urls:
            host = urlparse(url).netloc
            if host not in urls_by_domain:
                urls_by_domain[host] = []
            urls_by_domain[host].append(url)

        cached: list[dict] = []
        to_crawl: list[str] = []

        for host, host_urls in urls_by_domain.items():
            if not await self.get_website(host):
                to_crawl.extend(host_urls)
                continue

            hits = await self.get_site_store(host).get_cached_pages(host_urls)
            found = {page["url"] for page in hits}
            cached.extend(hits)
            to_crawl.extend(u for u in host_urls if u not in found)

        return cached, to_crawl

    async def close(self) -> None:
        """Drop the cached per-domain stores (the pool itself is not closed here)."""
        self._stores.clear()
        logger.info("PgWebsiteStoreManager closed")
a/services/crawler/app/services/scheduler.py b/services/crawler/app/services/scheduler.py index f95b98d26..25fe051ac 100644 --- a/services/crawler/app/services/scheduler.py +++ b/services/crawler/app/services/scheduler.py @@ -2,8 +2,7 @@ Background scheduler for autonomous website scanning. Periodically checks for websites due for scanning and runs discovery + content -hashing in parallel (bounded by Semaphore). Each website writes to its own -SQLite file so there is zero lock contention between concurrent scans. +hashing in parallel (bounded by Semaphore). """ import asyncio @@ -15,7 +14,9 @@ import httpx from app.services.crawler_service import CrawlerService -from app.services.website_store import WebsiteStore, WebsiteStoreManager +from app.services.indexing_service import IndexingService +from app.services.pg_website_store import PgWebsiteStore, PgWebsiteStoreManager +from app.utils.metadata import extract_meta_description logger = logging.getLogger(__name__) @@ -24,8 +25,10 @@ POLL_INTERVAL = 60 # seconds _HEAD_TIMEOUT = 10 _HEAD_CONCURRENCY = 5 +_HEAD_BATCH_SIZE = 50 _scan_trigger: asyncio.Event | None = None +_cancelled_domains: set[str] = set() def _sha256(content: str) -> str: @@ -38,9 +41,23 @@ def trigger_scan(): _scan_trigger.set() +def cancel_scan(domain: str): + """Mark a domain for scan cancellation.""" + _cancelled_domains.add(domain) + + +def _is_cancelled(domain: str) -> bool: + return domain in _cancelled_domains + + +def _clear_cancelled(domain: str): + _cancelled_domains.discard(domain) + + async def run_scheduler( - store_manager: WebsiteStoreManager, + store_manager: PgWebsiteStoreManager, crawler_service: CrawlerService, + indexing_service: IndexingService | None = None, ): global _scan_trigger _scan_trigger = asyncio.Event() @@ -49,15 +66,22 @@ async def run_scheduler( async def bounded_scan(domain: str): async with sem: - await _scan_website(domain, store_manager, crawler_service) + await _scan_website(domain, store_manager, 
crawler_service, indexing_service) while True: try: - due = store_manager.get_due_websites() + due = await store_manager.get_due_websites() if due: logger.info(f"Scheduler: {len(due)} website(s) due for scanning") tasks = [asyncio.create_task(bounded_scan(w["domain"])) for w in due] - await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*tasks, return_exceptions=True) + for website, result in zip(due, results, strict=True): + if isinstance(result, BaseException): + logger.error(f"Scheduler: scan failed for {website['domain']}: {result}") + try: + await store_manager.update_scan_status(website["domain"], "error", str(result)) + except Exception: + logger.exception(f"Scheduler: failed to update error status for {website['domain']}") except Exception: logger.exception("Scheduler loop error") @@ -71,10 +95,10 @@ async def bounded_scan(domain: str): async def _head_check( urls: list[str], - site_store: WebsiteStore, + site_store: PgWebsiteStore, ) -> tuple[list[str], list[str]]: """Split URLs into (unchanged, needs_crawl) using conditional HEAD requests.""" - stored = site_store.get_cache_headers(urls) + stored = await site_store.get_cache_headers(urls) urls_with_headers = [u for u in urls if u in stored] urls_without_headers = [u for u in urls if u not in stored] @@ -112,19 +136,19 @@ async def check_one(client: httpx.AsyncClient, url: str): async with httpx.AsyncClient( follow_redirects=True, timeout=_HEAD_TIMEOUT, - verify=False, + verify=False, # Intentional: crawling arbitrary external sites that may have invalid certs ) as client: await asyncio.gather(*[check_one(client, u) for u in urls_with_headers]) if header_updates: - site_store.update_cache_headers(header_updates) + await site_store.update_cache_headers(header_updates) return unchanged, needs_crawl async def _seed_cache_headers( urls: list[str], - site_store: WebsiteStore, + site_store: PgWebsiteStore, ) -> None: """Seed etag/last_modified via HEAD for URLs that just 
completed their first crawl.""" sem = asyncio.Semaphore(_HEAD_CONCURRENCY) @@ -144,59 +168,107 @@ async def seed_one(client: httpx.AsyncClient, url: str): async with httpx.AsyncClient( follow_redirects=True, timeout=_HEAD_TIMEOUT, - verify=False, + verify=False, # Intentional: crawling arbitrary external sites that may have invalid certs ) as client: await asyncio.gather(*[seed_one(client, u) for u in urls]) if header_updates: - site_store.update_cache_headers(header_updates) + await site_store.update_cache_headers(header_updates) logger.info(f"Seeded cache headers for {len(header_updates)}/{len(urls)} URLs") +def _is_homepage(url: str, domain: str) -> bool: + """Check if a URL is the homepage (root path) of the domain.""" + from urllib.parse import urlparse + + parsed = urlparse(url) + return parsed.netloc == domain and parsed.path in ("", "/") + + +async def _bulk_head_check( + all_urls: list[str], + site_store: PgWebsiteStore, +) -> tuple[list[str], list[str], set[str]]: + """HEAD check all URLs in batches, return (unchanged, needs_crawl, urls_with_prior_headers).""" + all_unchanged: list[str] = [] + all_needs_crawl: list[str] = [] + all_had_headers: set[str] = set() + + for i in range(0, len(all_urls), _HEAD_BATCH_SIZE): + batch = all_urls[i : i + _HEAD_BATCH_SIZE] + had_headers = await site_store.get_cache_headers(batch) + all_had_headers.update(had_headers) + unchanged, needs_crawl = await _head_check(batch, site_store) + all_unchanged.extend(unchanged) + all_needs_crawl.extend(needs_crawl) + + return all_unchanged, all_needs_crawl, all_had_headers + + async def _scan_website( domain: str, - store_manager: WebsiteStoreManager, + store_manager: PgWebsiteStoreManager, crawler_service: CrawlerService, + indexing_service: IndexingService | None = None, ): + _clear_cancelled(domain) site_store = store_manager.get_site_store(domain) - store_manager.update_scan_status(domain, "scanning") + await store_manager.update_scan_status(domain, "scanning") try: if not 
crawler_service.initialized: await crawler_service.initialize() # Phase 1: Discover new URLs + if _is_cancelled(domain): + logger.info(f"Scan [{domain}]: cancelled before discovery") + await store_manager.update_scan_status(domain, "idle") + return logger.info(f"Scan [{domain}]: Phase 1 — discovering URLs") discovered = await crawler_service.discover_urls(domain=domain, max_urls=-1) - site_store.save_discovered_urls(discovered) + await site_store.save_discovered_urls(discovered) logger.info(f"Scan [{domain}]: discovered {len(discovered)} URLs") - # Phase 2: Crawl URLs in batches and cache content + hashes + # Phase 2: Bulk HEAD check — filter unchanged URLs up front + if _is_cancelled(domain): + logger.info(f"Scan [{domain}]: cancelled before HEAD check") + await store_manager.update_scan_status(domain, "idle") + return scan_start = time.time() + all_urls = await site_store.get_urls_needing_recrawl(limit=10000, crawled_before=scan_start) + if not all_urls: + logger.info(f"Scan [{domain}]: no URLs need recrawling") + await store_manager.update_last_scanned(domain) + await store_manager.update_scan_status(domain, "active") + return + + logger.info(f"Scan [{domain}]: Phase 2 — HEAD checking {len(all_urls)} URLs in batches of {_HEAD_BATCH_SIZE}") + unchanged, needs_crawl, had_headers = await _bulk_head_check(all_urls, site_store) + + if unchanged: + await site_store.touch_crawled_at(unchanged) + logger.info( + f"Scan [{domain}]: HEAD check complete — {len(unchanged)} unchanged, {len(needs_crawl)} need crawling" + ) + + # Phase 3: Crawl changed URLs in batches crawled_total = 0 - skipped_total = 0 - while True: - batch = site_store.get_urls_needing_recrawl(limit=CRAWL_BATCH_SIZE, crawled_before=scan_start) - if not batch: - break - - # Pre-flight: skip URLs unchanged since last crawl (304) - had_headers = site_store.get_cache_headers(batch) - unchanged, to_crawl = await _head_check(batch, site_store) - if unchanged: - site_store.touch_crawled_at(unchanged) - 
skipped_total += len(unchanged) - - if not to_crawl: - continue - + homepage_title: str | None = None + homepage_description: str | None = None + + for i in range(0, len(needs_crawl), CRAWL_BATCH_SIZE): + if _is_cancelled(domain): + logger.info(f"Scan [{domain}]: cancelled during crawl (crawled {crawled_total} so far)") + await store_manager.update_scan_status(domain, "idle") + return + batch = needs_crawl[i : i + CRAWL_BATCH_SIZE] logger.info( - f"Scan [{domain}]: Phase 2 — crawling {len(to_crawl)} URLs " - f"(skipped {len(unchanged)}, total so far: {crawled_total})" + f"Scan [{domain}]: Phase 3 — crawling batch {i // CRAWL_BATCH_SIZE + 1} " + f"({len(batch)} URLs, total so far: {crawled_total})" ) - results = await crawler_service.crawl_urls(urls=to_crawl) + results = await crawler_service.crawl_urls(urls=batch) succeeded_urls = {p["url"] for p in results} - failed_urls = [u for u in to_crawl if u not in succeeded_urls] + failed_urls = [u for u in batch if u not in succeeded_urls] updates = [ { @@ -206,29 +278,60 @@ async def _scan_website( "title": p.get("title"), "content": p["content"], "word_count": p.get("word_count", 0), - "metadata": json.dumps(p.get("metadata")) if p.get("metadata") else None, - "structured_data": json.dumps(p.get("structured_data")) if p.get("structured_data") else None, + "metadata": p.get("metadata"), + "structured_data": p.get("structured_data"), } for p in results ] - site_store.update_content_hashes(updates) + await site_store.update_content_hashes(updates) crawled_total += len(updates) + if homepage_title is None: + for p in results: + if _is_homepage(p["url"], domain): + homepage_title = p.get("title") + sd = p.get("structured_data") + if isinstance(sd, str): + sd = json.loads(sd) + homepage_description = extract_meta_description(sd) + break + + if indexing_service: + for p in results: + if p.get("content"): + try: + await indexing_service.index_page( + domain=domain, + url=p["url"], + title=p.get("title"), + content=p["content"], 
+ ) + except Exception: + logger.exception(f"Indexing failed for {p['url']}") + if failed_urls: logger.warning(f"Scan [{domain}]: {len(failed_urls)} URLs failed in batch") - site_store.increment_fail_count(failed_urls) + await site_store.increment_fail_count(failed_urls) - # Seed cache headers for URLs that had none before first_time = [u for u in succeeded_urls if u not in had_headers] if first_time: await _seed_cache_headers(first_time, site_store) - logger.info(f"Scan [{domain}]: crawled {crawled_total}, skipped {skipped_total} unchanged URLs") + logger.info(f"Scan [{domain}]: crawled {crawled_total}, skipped {len(unchanged)} unchanged URLs") + + # Phase 4: Update website metadata + page_count = await site_store.get_total_count() + await store_manager.update_website_metadata( + domain=domain, + title=homepage_title, + description=homepage_description, + page_count=page_count, + ) - store_manager.update_last_scanned(domain) - store_manager.update_scan_status(domain, "idle") - logger.info(f"Scan [{domain}]: complete") + await store_manager.update_last_scanned(domain) + await store_manager.update_scan_status(domain, "active") + logger.info(f"Scan [{domain}]: complete (pages={page_count})") except Exception as e: logger.exception(f"Scan failed for {domain}") - store_manager.update_scan_status(domain, "error", str(e)) + await store_manager.update_scan_status(domain, "error", str(e)) diff --git a/services/crawler/app/services/search_service.py b/services/crawler/app/services/search_service.py new file mode 100644 index 000000000..4c29f5e7f --- /dev/null +++ b/services/crawler/app/services/search_service.py @@ -0,0 +1,130 @@ +""" +Hybrid search service: BM25 full-text (pg_search) + pgvector similarity with RRF fusion. 
+""" + +import asyncio +import json +import logging +from dataclasses import dataclass + +import asyncpg + +from app.services.embedding_service import EmbeddingService + +logger = logging.getLogger(__name__) + +RRF_K = 60 + + +@dataclass +class SearchResult: + url: str + title: str | None + chunk_content: str + chunk_index: int + score: float + + +class SearchService: + def __init__(self, pool: asyncpg.Pool, embedding_service: EmbeddingService): + self._pool = pool + self._embedding = embedding_service + + async def search( + self, + query: str, + domain: str | None = None, + limit: int = 10, + ) -> list[SearchResult]: + # Generate query embedding and run both searches in parallel + embedding_task = asyncio.create_task(self._embedding.embed_query(query)) + fts_task = asyncio.create_task(self._fts_search(query, domain, limit * 3)) + + query_embedding = await embedding_task + fts_results = await fts_task + vector_results = await self._vector_search(query_embedding, domain, limit * 3) + + return self._merge_rrf([fts_results, vector_results], limit) + + async def _fts_search(self, query: str, domain: str | None, limit: int) -> list[dict]: + async with self._pool.acquire() as conn: + if domain: + rows = await conn.fetch( + """SELECT id, url, title, chunk_content, chunk_index, + paradedb.score(id) AS score + FROM chunks + WHERE id @@@ paradedb.match('chunk_content', $1) AND domain = $2 + ORDER BY score DESC + LIMIT $3""", + query, + domain, + limit, + ) + else: + rows = await conn.fetch( + """SELECT id, url, title, chunk_content, chunk_index, + paradedb.score(id) AS score + FROM chunks + WHERE id @@@ paradedb.match('chunk_content', $1) + ORDER BY score DESC + LIMIT $2""", + query, + limit, + ) + return [dict(r) for r in rows] + + async def _vector_search(self, embedding: list[float], domain: str | None, limit: int) -> list[dict]: + vec_str = json.dumps(embedding) + async with self._pool.acquire() as conn: + if domain: + rows = await conn.fetch( + """SELECT id, url, 
title, chunk_content, chunk_index, + 1 - (embedding <=> $1::vector) AS score + FROM chunks + WHERE domain = $2 AND embedding IS NOT NULL + ORDER BY embedding <=> $1::vector + LIMIT $3""", + vec_str, + domain, + limit, + ) + else: + rows = await conn.fetch( + """SELECT id, url, title, chunk_content, chunk_index, + 1 - (embedding <=> $1::vector) AS score + FROM chunks + WHERE embedding IS NOT NULL + ORDER BY embedding <=> $1::vector + LIMIT $2""", + vec_str, + limit, + ) + return [dict(r) for r in rows] + + @staticmethod + def _merge_rrf(ranked_lists: list[list[dict]], limit: int) -> list[SearchResult]: + scores: dict[int, float] = {} + items: dict[int, dict] = {} + + for ranked in ranked_lists: + for rank, item in enumerate(ranked): + item_id = item["id"] + rrf_score = 1.0 / (RRF_K + rank + 1) + scores[item_id] = scores.get(item_id, 0.0) + rrf_score + items[item_id] = item + + sorted_ids = sorted(scores, key=lambda k: scores[k], reverse=True)[:limit] + + # Normalize scores + max_score = scores[sorted_ids[0]] if sorted_ids else 1.0 + + return [ + SearchResult( + url=items[item_id]["url"], + title=items[item_id].get("title"), + chunk_content=items[item_id]["chunk_content"], + chunk_index=items[item_id]["chunk_index"], + score=scores[item_id] / max_score, + ) + for item_id in sorted_ids + ] diff --git a/services/crawler/app/services/website_store.py b/services/crawler/app/services/website_store.py deleted file mode 100644 index d090fbb7c..000000000 --- a/services/crawler/app/services/website_store.py +++ /dev/null @@ -1,451 +0,0 @@ -""" -Multi-DB SQLite store for website URL registry with content hashing. 
- -Architecture: -- Main DB (data/crawler.db): websites table — registry of all tracked websites -- Per-site DB (data/sites/{domain}.db): website_urls table — URLs + content_hash per site - -Benefits: -- Zero lock contention: each website has its own SQLite file, independent WAL -- Natural concurrency: different websites can be scanned in parallel -- Clean deletion: remove_website = close connection + unlink the .db file -""" - -import json -import logging -import sqlite3 -import time -from pathlib import Path -from urllib.parse import urlparse - -logger = logging.getLogger(__name__) - -_DEFAULT_DATA_DIR = Path(__file__).resolve().parent.parent.parent / "data" - - -def _sanitize_domain(domain: str) -> str: - return domain.replace(".", "_").replace("-", "_") - - -class WebsiteStore: - """Manages one per-site SQLite file with URL registry and content hashes.""" - - def __init__(self, db_path: Path): - self._db_path = db_path - self._db_path.parent.mkdir(parents=True, exist_ok=True) - self._conn: sqlite3.Connection | None = None - self._get_conn() - - def _get_conn(self) -> sqlite3.Connection: - if self._conn is None: - self._conn = sqlite3.connect(str(self._db_path), timeout=30) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.row_factory = sqlite3.Row - self._ensure_schema(self._conn) - return self._conn - - @staticmethod - def _ensure_schema(conn: sqlite3.Connection): - conn.executescript(""" - CREATE TABLE IF NOT EXISTS website_urls ( - url TEXT PRIMARY KEY, - content_hash TEXT, - status TEXT NOT NULL DEFAULT 'discovered', - last_crawled_at REAL, - discovered_at REAL NOT NULL, - title TEXT, - content TEXT, - word_count INTEGER, - metadata TEXT, - structured_data TEXT, - fail_count INTEGER NOT NULL DEFAULT 0 - ); - - CREATE INDEX IF NOT EXISTS idx_crawl_order - ON website_urls(last_crawled_at); - """) - - # Migrate existing databases: add content cache columns if missing - existing = {row[1] for row in 
conn.execute("PRAGMA table_info(website_urls)").fetchall()} - for col, col_type in [ - ("title", "TEXT"), - ("content", "TEXT"), - ("word_count", "INTEGER"), - ("metadata", "TEXT"), - ("structured_data", "TEXT"), - ("fail_count", "INTEGER NOT NULL DEFAULT 0"), - ("etag", "TEXT"), - ("last_modified", "TEXT"), - ]: - if col not in existing: - conn.execute(f"ALTER TABLE website_urls ADD COLUMN {col} {col_type}") - - conn.commit() - - def save_discovered_urls(self, urls: list[dict]) -> int: - if not urls: - return 0 - - now = time.time() - rows = [(u["url"], now) for u in urls] - conn = self._get_conn() - conn.executemany( - "INSERT OR IGNORE INTO website_urls (url, discovered_at) VALUES (?, ?)", - rows, - ) - inserted = conn.total_changes - conn.commit() - return inserted - - def get_urls_page(self, offset: int = 0, limit: int = 100, status: str | None = None) -> list[dict]: - conn = self._get_conn() - if status: - rows = conn.execute( - "SELECT url, content_hash, status, last_crawled_at " - "FROM website_urls WHERE content_hash IS NOT NULL AND status = ? " - "ORDER BY rowid LIMIT ? OFFSET ?", - (status, limit, offset), - ).fetchall() - else: - rows = conn.execute( - "SELECT url, content_hash, status, last_crawled_at " - "FROM website_urls WHERE content_hash IS NOT NULL " - "ORDER BY rowid LIMIT ? OFFSET ?", - (limit, offset), - ).fetchall() - - return [ - { - "url": r["url"], - "content_hash": r["content_hash"], - "status": r["status"], - "last_crawled_at": r["last_crawled_at"], - } - for r in rows - ] - - def get_urls_needing_recrawl(self, limit: int = 20, crawled_before: float | None = None) -> list[str]: - conn = self._get_conn() - if crawled_before is not None: - rows = conn.execute( - "SELECT url FROM website_urls " - "WHERE status != 'deleted' " - "AND (last_crawled_at IS NULL OR last_crawled_at < ?) 
" - "ORDER BY CASE WHEN content_hash IS NULL THEN 0 ELSE 1 END, " - "last_crawled_at ASC NULLS FIRST " - "LIMIT ?", - (crawled_before, limit), - ).fetchall() - else: - rows = conn.execute( - "SELECT url FROM website_urls " - "WHERE status != 'deleted' " - "ORDER BY CASE WHEN content_hash IS NULL THEN 0 ELSE 1 END, " - "last_crawled_at ASC NULLS FIRST " - "LIMIT ?", - (limit,), - ).fetchall() - return [r["url"] for r in rows] - - def increment_fail_count(self, urls: list[str]): - if not urls: - return - - now = time.time() - conn = self._get_conn() - conn.executemany( - "UPDATE website_urls SET fail_count = fail_count + 1, last_crawled_at = ? WHERE url = ?", - [(now, url) for url in urls], - ) - conn.commit() - - def update_content_hashes(self, updates: list[dict]): - if not updates: - return - - now = time.time() - conn = self._get_conn() - conn.executemany( - "UPDATE website_urls " - "SET content_hash = ?, status = ?, last_crawled_at = ?, " - " title = ?, content = ?, word_count = ?, metadata = ?, structured_data = ?, " - " fail_count = 0 " - "WHERE url = ?", - [ - ( - u["content_hash"], - u.get("status", "active"), - now, - u.get("title"), - u.get("content"), - u.get("word_count"), - u.get("metadata"), - u.get("structured_data"), - u["url"], - ) - for u in updates - ], - ) - conn.commit() - - def mark_urls_deleted(self, urls: list[str]): - if not urls: - return - - conn = self._get_conn() - conn.executemany( - "UPDATE website_urls SET status = 'deleted' WHERE url = ?", - [(url,) for url in urls], - ) - conn.commit() - - def get_cache_headers(self, urls: list[str]) -> dict[str, dict]: - """Load stored etag/last_modified for URLs that have at least one header.""" - if not urls: - return {} - - conn = self._get_conn() - placeholders = ",".join("?" 
* len(urls)) - rows = conn.execute( - "SELECT url, etag, last_modified FROM website_urls " - f"WHERE url IN ({placeholders}) AND (etag IS NOT NULL OR last_modified IS NOT NULL)", - urls, - ).fetchall() - - return {r["url"]: {"etag": r["etag"], "last_modified": r["last_modified"]} for r in rows} - - def update_cache_headers(self, updates: list[dict]): - """Batch store etag/last_modified from HEAD responses.""" - if not updates: - return - - conn = self._get_conn() - conn.executemany( - "UPDATE website_urls SET etag = ?, last_modified = ? WHERE url = ?", - [(u.get("etag"), u.get("last_modified"), u["url"]) for u in updates], - ) - conn.commit() - - def touch_crawled_at(self, urls: list[str]): - """Update only last_crawled_at for unchanged URLs (skipped by 304).""" - if not urls: - return - - now = time.time() - conn = self._get_conn() - conn.executemany( - "UPDATE website_urls SET last_crawled_at = ? WHERE url = ?", - [(now, url) for url in urls], - ) - conn.commit() - - def get_total_count(self, status: str | None = None) -> int: - conn = self._get_conn() - if status: - row = conn.execute( - "SELECT COUNT(*) as cnt FROM website_urls WHERE content_hash IS NOT NULL AND status = ?", - (status,), - ).fetchone() - else: - row = conn.execute("SELECT COUNT(*) as cnt FROM website_urls WHERE content_hash IS NOT NULL").fetchone() - return row["cnt"] if row else 0 - - def get_cached_pages(self, urls: list[str]) -> list[dict]: - if not urls: - return [] - - conn = self._get_conn() - placeholders = ",".join("?" 
* len(urls)) - rows = conn.execute( - "SELECT url, title, content, word_count, metadata, structured_data " - f"FROM website_urls WHERE url IN ({placeholders}) AND content IS NOT NULL", - urls, - ).fetchall() - - return [ - { - "url": r["url"], - "title": r["title"], - "content": r["content"], - "word_count": r["word_count"] or 0, - "metadata": json.loads(r["metadata"]) if r["metadata"] else None, - "structured_data": json.loads(r["structured_data"]) if r["structured_data"] else None, - } - for r in rows - ] - - def close(self): - if self._conn: - self._conn.close() - self._conn = None - - -class WebsiteStoreManager: - """Manages main DB (website registry) + per-site WebsiteStore instances.""" - - def __init__(self, data_dir: Path | None = None): - self._data_dir = data_dir or _DEFAULT_DATA_DIR - self._data_dir.mkdir(parents=True, exist_ok=True) - self._main_db_path = self._data_dir / "crawler.db" - self._sites_dir = self._data_dir / "sites" - self._sites_dir.mkdir(parents=True, exist_ok=True) - self._stores: dict[str, WebsiteStore] = {} - self._main_conn: sqlite3.Connection | None = None - self._init_main_db() - - def _get_main_conn(self) -> sqlite3.Connection: - if self._main_conn is None: - self._main_conn = sqlite3.connect(str(self._main_db_path), timeout=30) - self._main_conn.execute("PRAGMA journal_mode=WAL") - self._main_conn.execute("PRAGMA busy_timeout=5000") - self._main_conn.row_factory = sqlite3.Row - return self._main_conn - - def _init_main_db(self): - conn = self._get_main_conn() - conn.executescript(""" - CREATE TABLE IF NOT EXISTS websites ( - domain TEXT PRIMARY KEY, - status TEXT NOT NULL DEFAULT 'idle', - scan_interval INTEGER NOT NULL DEFAULT 21600, - last_scanned_at REAL, - error TEXT, - created_at REAL NOT NULL, - updated_at REAL NOT NULL - ); - """) - conn.commit() - logger.info(f"Website store manager initialized at {self._data_dir}") - - def register_website(self, domain: str, scan_interval: int = 21600) -> dict: - now = time.time() - conn 
= self._get_main_conn() - conn.execute( - """INSERT INTO websites (domain, scan_interval, created_at, updated_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(domain) DO UPDATE SET - scan_interval = excluded.scan_interval, - updated_at = excluded.updated_at""", - (domain, scan_interval, now, now), - ) - conn.commit() - logger.info(f"Registered website: {domain} (interval={scan_interval}s)") - return {"domain": domain, "scan_interval": scan_interval, "status": "idle"} - - def remove_website(self, domain: str) -> bool: - if domain in self._stores: - self._stores[domain].close() - del self._stores[domain] - - db_file = self._sites_dir / f"{_sanitize_domain(domain)}.db" - if db_file.exists(): - db_file.unlink() - wal = db_file.with_suffix(".db-wal") - shm = db_file.with_suffix(".db-shm") - if wal.exists(): - wal.unlink() - if shm.exists(): - shm.unlink() - - conn = self._get_main_conn() - cursor = conn.execute("DELETE FROM websites WHERE domain = ?", (domain,)) - conn.commit() - deleted = cursor.rowcount > 0 - if deleted: - logger.info(f"Removed website: {domain}") - return deleted - - def get_due_websites(self) -> list[dict]: - now = time.time() - conn = self._get_main_conn() - rows = conn.execute( - """SELECT domain, status, scan_interval, last_scanned_at, error - FROM websites - WHERE status != 'scanning' - AND (last_scanned_at IS NULL - OR last_scanned_at + scan_interval < ?)""", - (now,), - ).fetchall() - return [dict(r) for r in rows] - - def update_scan_status(self, domain: str, status: str, error: str | None = None): - now = time.time() - conn = self._get_main_conn() - conn.execute( - "UPDATE websites SET status = ?, error = ?, updated_at = ? WHERE domain = ?", - (status, error, now, domain), - ) - conn.commit() - - def update_last_scanned(self, domain: str): - now = time.time() - conn = self._get_main_conn() - conn.execute( - "UPDATE websites SET last_scanned_at = ?, updated_at = ? 
WHERE domain = ?", - (now, now, domain), - ) - conn.commit() - - def get_website(self, domain: str) -> dict | None: - conn = self._get_main_conn() - row = conn.execute( - "SELECT domain, status, scan_interval, last_scanned_at, error, created_at, updated_at " - "FROM websites WHERE domain = ?", - (domain,), - ).fetchone() - return dict(row) if row else None - - def get_site_store(self, domain: str) -> WebsiteStore: - if domain not in self._stores: - db_path = self._sites_dir / f"{_sanitize_domain(domain)}.db" - self._stores[domain] = WebsiteStore(db_path) - return self._stores[domain] - - def get_cached_pages(self, urls: list[str]) -> tuple[list[dict], list[str]]: - """Return cached page content for URLs with registered websites. - - Returns (cached_pages, urls_needing_crawl). - """ - if not urls: - return [], [] - - by_domain: dict[str, list[str]] = {} - for url in urls: - domain = urlparse(url).netloc - by_domain.setdefault(domain, []).append(url) - - cached: list[dict] = [] - to_crawl: list[str] = [] - - for domain, domain_urls in by_domain.items(): - if not self.get_website(domain): - to_crawl.extend(domain_urls) - continue - - site_store = self.get_site_store(domain) - hits = site_store.get_cached_pages(domain_urls) - hit_urls = {p["url"] for p in hits} - cached.extend(hits) - to_crawl.extend(u for u in domain_urls if u not in hit_urls) - - return cached, to_crawl - - def close_all(self): - for store in self._stores.values(): - store.close() - self._stores.clear() - if self._main_conn: - self._main_conn.close() - self._main_conn = None - logger.info("All website stores closed") - - -_store_manager: WebsiteStoreManager | None = None - - -def get_website_store_manager() -> WebsiteStoreManager: - global _store_manager - if _store_manager is None: - _store_manager = WebsiteStoreManager() - return _store_manager diff --git a/services/crawler/app/utils/metadata.py b/services/crawler/app/utils/metadata.py new file mode 100644 index 000000000..389e8fb21 --- /dev/null 
+++ b/services/crawler/app/utils/metadata.py @@ -0,0 +1,14 @@ +"""Shared metadata extraction utilities.""" + + +def extract_meta_description(structured_data: dict | None) -> str | None: + """Extract meta description from structured data (meta tags or OpenGraph).""" + if not structured_data: + return None + meta = structured_data.get("meta", {}) + if desc := meta.get("description"): + return desc + og = structured_data.get("opengraph", {}) + if desc := og.get("og:description"): + return desc + return None diff --git a/services/crawler/pyproject.toml b/services/crawler/pyproject.toml index 182ee6cbd..c4f7db0c8 100644 --- a/services/crawler/pyproject.toml +++ b/services/crawler/pyproject.toml @@ -18,6 +18,9 @@ dependencies = [ "python-docx==1.2.0", "pymupdf==1.27.1", "openai>=1.0.0", + "asyncpg>=0.30.0", + "tiktoken>=0.9.0", + "semantic-text-splitter>=0.18.0", ] [project.optional-dependencies] diff --git a/services/crawler/tests/test_chunking_service.py b/services/crawler/tests/test_chunking_service.py new file mode 100644 index 000000000..a9e88ac99 --- /dev/null +++ b/services/crawler/tests/test_chunking_service.py @@ -0,0 +1,335 @@ +from app.services.chunking_service import ( + CHUNK_OVERLAP, + CHUNK_SIZE, + MIN_CHUNK_LENGTH, + ContentChunk, + chunk_content, +) + + +class TestChunkContentEmptyInput: + def test_empty_string(self): + assert chunk_content("") == [] + + def test_none_like_empty(self): + assert chunk_content("") == [] + + def test_whitespace_only(self): + assert chunk_content(" \n\n \t ") == [] + + def test_newlines_only(self): + assert chunk_content("\n\n\n") == [] + + +class TestChunkContentSingleChunk: + def test_short_content_returns_one_chunk(self): + text = "Hello world, this is a test of the chunking service module." + result = chunk_content(text) + assert len(result) == 1 + assert result[0].content == text + assert result[0].index == 0 + + def test_content_is_stripped(self): + text = "Hello world, this is a test of the chunking service module." 
+ result = chunk_content(f" {text} \n\n") + assert result[0].content == text + + def test_returns_content_chunk_dataclass(self): + text = "Hello world, this is a test of the chunking service module." + result = chunk_content(text) + assert isinstance(result[0], ContentChunk) + + +class TestChunkContentWithTitle: + BODY = "Some body text here that is long enough to pass the minimum chunk length filter." + + def test_title_prepended_to_single_chunk(self): + result = chunk_content(self.BODY, title="My Title") + assert result[0].content.startswith("My Title\n\n") + assert self.BODY in result[0].content + + def test_none_title_ignored(self): + result = chunk_content(self.BODY, title=None) + assert result[0].content == self.BODY + + def test_empty_title_ignored(self): + result = chunk_content(self.BODY, title="") + assert result[0].content == self.BODY + + def test_whitespace_title_ignored(self): + result = chunk_content(self.BODY, title=" ") + assert result[0].content == self.BODY + + def test_title_is_stripped(self): + result = chunk_content(self.BODY, title=" My Title ") + assert result[0].content.startswith("My Title\n\n") + + def test_title_prepended_to_every_chunk(self): + para = "A" * 100 + content = f"{para}\n\n{para}\n\n{para}" + result = chunk_content(content, title="Title", chunk_size=150, chunk_overlap=20) + for chunk in result: + assert chunk.content.startswith("Title") + + +class TestChunkContentWithUrl: + BODY = "Some body text here that is long enough to pass the minimum chunk length filter." 
+ + def test_url_prepended_to_single_chunk(self): + result = chunk_content(self.BODY, url="https://example.com/page") + assert result[0].content.startswith("https://example.com/page\n\n") + assert self.BODY in result[0].content + + def test_none_url_ignored(self): + result = chunk_content(self.BODY, url=None) + assert result[0].content == self.BODY + + def test_empty_url_ignored(self): + result = chunk_content(self.BODY, url="") + assert result[0].content == self.BODY + + def test_whitespace_url_ignored(self): + result = chunk_content(self.BODY, url=" ") + assert result[0].content == self.BODY + + def test_title_and_url_both_in_prefix(self): + result = chunk_content(self.BODY, title="My Title", url="https://example.com/page") + assert result[0].content.startswith("My Title\n\nhttps://example.com/page\n\n") + assert self.BODY in result[0].content + + def test_url_prepended_to_every_chunk(self): + para = "A" * 100 + content = f"{para}\n\n{para}\n\n{para}" + result = chunk_content(content, url="https://example.com", chunk_size=200, chunk_overlap=20) + for chunk in result: + assert "https://example.com" in chunk.content + + +class TestChunkContentMultipleParagraphs: + def test_two_paragraphs_within_limit_stay_in_one_chunk(self): + p1 = "First paragraph with enough content to be meaningful here." + p2 = "Second paragraph also with enough content to pass filters." + content = f"{p1}\n\n{p2}" + result = chunk_content(content, chunk_size=500) + assert len(result) == 1 + assert p1 in result[0].content + assert p2 in result[0].content + + def test_paragraphs_exceeding_limit_split_into_multiple_chunks(self): + p1 = "A" * 100 + p2 = "B" * 100 + p3 = "C" * 100 + content = f"{p1}\n\n{p2}\n\n{p3}" + result = chunk_content(content, chunk_size=150, chunk_overlap=20) + assert len(result) > 1 + + def test_paragraph_boundaries_preserved(self): + p1 = "First paragraph with enough content to pass the minimum length." + p2 = "Second paragraph also with enough content to pass the filter." 
+ content = f"{p1}\n\n{p2}" + result = chunk_content(content, chunk_size=500) + assert "\n\n" in result[0].content + + +class TestChunkContentOverlap: + def test_multiple_chunks_produced_with_overlap(self): + p1 = "A" * 150 + p2 = "B" * 150 + result = chunk_content(f"{p1}\n\n{p2}", chunk_size=200, chunk_overlap=50) + assert len(result) >= 2 + + def test_overlap_produces_shared_content(self): + sentences = [f"Sentence number {i} with some extra words to fill." for i in range(20)] + content = " ".join(sentences) + result = chunk_content(content, chunk_size=200, chunk_overlap=50) + assert len(result) >= 2 + # With overlap, adjacent chunks should share some content + for i in range(len(result) - 1): + combined_adjacent = result[i].content + result[i + 1].content + assert len(combined_adjacent) > len(result[i].content) + + +class TestChunkContentLargeParagraphSentenceSplitting: + def test_large_paragraph_splits_into_multiple(self): + sentences = [f"This is sentence number {i}." for i in range(50)] + large_para = " ".join(sentences) + result = chunk_content(large_para, chunk_size=200, chunk_overlap=30) + assert len(result) > 1 + + def test_sentences_distributed_across_chunks(self): + sentences = [f"This is a fairly long sentence number {i} here." 
for i in range(30)] + large_para = " ".join(sentences) + result = chunk_content(large_para, chunk_size=200, chunk_overlap=20) + combined = " ".join(c.content for c in result) + for s in sentences: + assert s in combined + + +class TestChunkContentHardSplit: + def test_very_long_text_gets_split(self): + long_text = "A" * 5000 + result = chunk_content(long_text, chunk_size=500, chunk_overlap=50) + assert len(result) > 1 + + def test_hard_split_pieces_cover_original(self): + long_text = "B" * 3000 + result = chunk_content(long_text, chunk_size=500, chunk_overlap=50) + combined = "".join(c.content for c in result) + assert "B" * 500 in combined + + +class TestChunkContentMinChunkLength: + def test_short_content_below_min_filtered_out(self): + result = chunk_content("Hi.", min_chunk_length=100) + assert result == [] + + def test_content_at_min_length_kept(self): + text = "A" * MIN_CHUNK_LENGTH + result = chunk_content(text) + assert len(result) == 1 + + def test_content_just_below_min_length_filtered(self): + text = "A" * (MIN_CHUNK_LENGTH - 1) + result = chunk_content(text) + assert result == [] + + def test_custom_min_chunk_length(self): + result = chunk_content("Short text.", min_chunk_length=5) + assert len(result) == 1 + + def test_custom_high_min_chunk_length_filters(self): + result = chunk_content("Short text.", min_chunk_length=500) + assert result == [] + + +class TestChunkContentCustomParams: + def test_custom_chunk_size(self): + content = "Word " * 200 + result_small = chunk_content(content, chunk_size=100, chunk_overlap=10) + result_large = chunk_content(content, chunk_size=2000, chunk_overlap=10) + assert len(result_small) > len(result_large) + + def test_defaults_match_constants(self): + assert CHUNK_SIZE == 2048 + assert CHUNK_OVERLAP == 200 + assert MIN_CHUNK_LENGTH == 50 + + +class TestChunkContentIndexNumbering: + def test_single_chunk_has_index_zero(self): + text = "Hello world, this is a test of the chunking service module." 
+ result = chunk_content(text) + assert result[0].index == 0 + + def test_multiple_chunks_have_sequential_indexes(self): + paragraphs = [("P" * 100) for _ in range(10)] + content = "\n\n".join(paragraphs) + result = chunk_content(content, chunk_size=150, chunk_overlap=20) + assert len(result) > 1 + for i, chunk in enumerate(result): + assert chunk.index == i + + def test_indexes_are_contiguous(self): + long_text = "X" * 3000 + result = chunk_content(long_text, chunk_size=300, chunk_overlap=30) + indexes = [c.index for c in result] + assert indexes == list(range(len(result))) + + +class TestMarkdownAwareChunking: + def test_splits_at_header_boundaries(self): + content = "## Section One\n\n" + "A" * 300 + "\n\n## Section Two\n\n" + "B" * 300 + result = chunk_content(content, chunk_size=400, chunk_overlap=0) + assert len(result) >= 2 + assert "Section One" in result[0].content + assert "Section Two" in result[-1].content + + def test_header_content_preserved_in_chunks(self): + content = "# Main Title\n\nSome introduction text that is long enough.\n\n## Details\n\nMore details here that pass the filter." + result = chunk_content(content, chunk_size=2048) + combined = "\n".join(c.content for c in result) + assert "Main Title" in combined + assert "Details" in combined + + def test_code_block_kept_intact(self): + code = "```python\ndef hello():\n print('world')\n return 42\n```" + content = f"Some intro text here.\n\n{code}\n\nSome outro text here." + result = chunk_content(content, chunk_size=2048) + combined = "\n".join(c.content for c in result) + assert "def hello():" in combined + assert "return 42" in combined + + def test_table_kept_intact(self): + table = "| Col A | Col B |\n| --- | --- |\n| val1 | val2 |\n| val3 | val4 |" + content = f"Some intro text here.\n\n{table}\n\nSome outro text here." 
+ result = chunk_content(content, chunk_size=2048) + combined = "\n".join(c.content for c in result) + assert "val1" in combined + assert "val4" in combined + + def test_nested_headers_produce_chunks(self): + content = ( + "# Top Level\n\nIntro text here.\n\n## Sub Section\n\nSub text here.\n\n### Deep Section\n\nDeep text here." + ) + result = chunk_content(content, chunk_size=2048) + combined = "\n".join(c.content for c in result) + assert "Top Level" in combined + assert "Sub Section" in combined + assert "Deep Section" in combined + + def test_long_section_gets_sub_split(self): + long_body = " ".join([f"This is sentence number {i} in a very long section." for i in range(50)]) + content = f"## Long Section\n\n{long_body}" + result = chunk_content(content, chunk_size=300, chunk_overlap=30) + assert len(result) > 1 + + def test_short_sections_merged_into_one_chunk(self): + content = "## A\n\nShort text.\n\n## B\n\nAnother short text.\n\n## C\n\nYet another." + result = chunk_content(content, chunk_size=2048) + assert len(result) == 1 + + def test_title_and_url_in_every_chunk_with_headers(self): + content = "## Section One\n\n" + "A" * 300 + "\n\n## Section Two\n\n" + "B" * 300 + result = chunk_content( + content, + title="My Page", + url="https://example.com/page", + chunk_size=400, + chunk_overlap=0, + ) + assert len(result) >= 2 + for chunk in result: + assert "My Page" in chunk.content + assert "https://example.com/page" in chunk.content + + def test_realistic_page(self): + content = ( + "# WiseKey Security Solutions\n\n" + "WiseKey provides cybersecurity solutions for IoT and digital identity.\n\n" + "## Products\n\n" + "Our product line includes secure semiconductors and PKI services.\n\n" + "### WiseKey IoT\n\n" + "The IoT platform secures connected devices with certificate-based auth.\n\n" + "### WiseKey PKI\n\n" + "Public Key Infrastructure for enterprise identity management.\n\n" + "## Partners\n\n" + "We work with leading technology companies 
worldwide.\n\n" + "## Contact\n\n" + "Visit us at wisekey.com for more information." + ) + result = chunk_content(content, title="WiseKey", url="https://wisekey.com") + assert len(result) >= 1 + combined = "\n".join(c.content for c in result) + assert "WiseKey" in combined + assert "IoT" in combined + assert "PKI" in combined + assert "Partners" in combined + + def test_all_content_preserved(self): + sections = [f"## Section {i}\n\nContent for section {i} with enough words." for i in range(5)] + content = "\n\n".join(sections) + result = chunk_content(content, chunk_size=2048) + combined = " ".join(c.content for c in result) + for i in range(5): + assert f"Section {i}" in combined + assert f"Content for section {i}" in combined diff --git a/services/crawler/tests/test_config.py b/services/crawler/tests/test_config.py index 84218c61b..6848a79cf 100644 --- a/services/crawler/tests/test_config.py +++ b/services/crawler/tests/test_config.py @@ -83,3 +83,27 @@ def test_missing_vision_model_raises(self): s = Settings() with pytest.raises(ValueError, match="OPENAI_VISION_MODEL"): s.get_vision_model() + + +class TestGetEmbeddingDimensions: + def test_crawler_prefixed_takes_priority(self): + env = _base_env() + env["CRAWLER_EMBEDDING_DIMENSIONS"] = "768" + env["EMBEDDING_DIMENSIONS"] = "1536" + with patch.dict(os.environ, env, clear=True): + s = Settings() + assert s.get_embedding_dimensions() == 768 + + def test_falls_back_to_generic_env(self): + env = _base_env() + env["EMBEDDING_DIMENSIONS"] = "1536" + with patch.dict(os.environ, env, clear=True): + s = Settings() + assert s.get_embedding_dimensions() == 1536 + + def test_missing_dimensions_raises(self): + env = _base_env() + with patch.dict(os.environ, env, clear=True): + s = Settings() + with pytest.raises(ValueError, match="EMBEDDING_DIMENSIONS"): + s.get_embedding_dimensions() diff --git a/services/crawler/tests/test_database.py b/services/crawler/tests/test_database.py new file mode 100644 index 000000000..cdcbc7ef1 
--- /dev/null +++ b/services/crawler/tests/test_database.py @@ -0,0 +1,140 @@ +"""Tests for database pool initialization, including dimension mismatch guard.""" + +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + +import app.services.database as db_mod + + +@pytest.fixture(autouse=True) +def _reset_pool(): + """Ensure module-level _pool is reset before and after each test.""" + db_mod._pool = None + yield + db_mod._pool = None + + +def _fake_pool(stored_dims: int | None, col_type: str = "vector(1536)"): + """Build a mock asyncpg pool. + + *stored_dims* is returned for the first fetchval (dimension check). + *col_type* is returned for the second fetchval (column type check). + """ + conn = AsyncMock() + conn.fetchval = AsyncMock(side_effect=[stored_dims, col_type]) + conn.execute = AsyncMock() + + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=conn) + ctx.__aexit__ = AsyncMock(return_value=False) + + pool = AsyncMock() + pool.acquire = MagicMock(return_value=ctx) + pool.close = AsyncMock() + return pool + + +class TestDimensionMismatchGuard: + @pytest.mark.asyncio + async def test_raises_on_dimension_mismatch(self): + fake_pool = _fake_pool(stored_dims=3072) + + with ( + patch("app.services.database.asyncpg.create_pool", AsyncMock(return_value=fake_pool)), + patch("app.services.database.settings") as mock_settings, + ): + mock_settings.get_embedding_dimensions.return_value = 1536 + mock_settings.database_url = "postgresql://test:test@localhost/test" + + with pytest.raises(RuntimeError, match="dimension mismatch"): + await db_mod.init_pool() + + assert db_mod._pool is None + + @pytest.mark.asyncio + async def test_passes_when_dimensions_match(self): + fake_pool = _fake_pool(stored_dims=1536, col_type="vector(1536)") + + with ( + patch("app.services.database.asyncpg.create_pool", AsyncMock(return_value=fake_pool)), + patch("app.services.database.settings") as mock_settings, + ): + 
mock_settings.get_embedding_dimensions.return_value = 1536 + mock_settings.database_url = "postgresql://test:test@localhost/test" + + pool = await db_mod.init_pool() + + assert pool is fake_pool + + @pytest.mark.asyncio + async def test_passes_when_no_existing_data(self): + fake_pool = _fake_pool(stored_dims=None, col_type="vector") + + with ( + patch("app.services.database.asyncpg.create_pool", AsyncMock(return_value=fake_pool)), + patch("app.services.database.settings") as mock_settings, + ): + mock_settings.get_embedding_dimensions.return_value = 1536 + mock_settings.database_url = "postgresql://test:test@localhost/test" + + pool = await db_mod.init_pool() + + assert pool is fake_pool + + +class TestEmbeddingColumnPinning: + @pytest.mark.asyncio + async def test_alters_untyped_vector_column(self): + """When column is bare `vector`, init_pool pins it to vector(N).""" + fake_pool = _fake_pool(stored_dims=None, col_type="vector") + + with ( + patch("app.services.database.asyncpg.create_pool", AsyncMock(return_value=fake_pool)), + patch("app.services.database.settings") as mock_settings, + ): + mock_settings.get_embedding_dimensions.return_value = 768 + mock_settings.database_url = "postgresql://test:test@localhost/test" + + await db_mod.init_pool() + + conn = fake_pool.acquire().__aenter__.return_value + execute_calls = [str(c) for c in conn.execute.call_args_list] + assert any("ALTER TABLE" in c and "vector(768)" in c for c in execute_calls) + + @pytest.mark.asyncio + async def test_skips_alter_when_already_typed(self): + """When column already has dimensions, no ALTER is issued.""" + fake_pool = _fake_pool(stored_dims=1536, col_type="vector(1536)") + + with ( + patch("app.services.database.asyncpg.create_pool", AsyncMock(return_value=fake_pool)), + patch("app.services.database.settings") as mock_settings, + ): + mock_settings.get_embedding_dimensions.return_value = 1536 + mock_settings.database_url = "postgresql://test:test@localhost/test" + + await 
db_mod.init_pool() + + conn = fake_pool.acquire().__aenter__.return_value + execute_calls = [str(c) for c in conn.execute.call_args_list] + assert not any("ALTER TABLE" in c for c in execute_calls) + + @pytest.mark.asyncio + async def test_repins_column_when_dimension_changed(self): + """When column is pinned to a different dimension and table is empty, re-pin.""" + fake_pool = _fake_pool(stored_dims=None, col_type="vector(2560)") + + with ( + patch("app.services.database.asyncpg.create_pool", AsyncMock(return_value=fake_pool)), + patch("app.services.database.settings") as mock_settings, + ): + mock_settings.get_embedding_dimensions.return_value = 1536 + mock_settings.database_url = "postgresql://test:test@localhost/test" + + await db_mod.init_pool() + + conn = fake_pool.acquire().__aenter__.return_value + execute_calls = [str(c) for c in conn.execute.call_args_list] + assert any("DROP INDEX" in c and "idx_chunks_embedding_hnsw" in c for c in execute_calls) + assert any("ALTER TABLE" in c and "vector(1536)" in c for c in execute_calls) diff --git a/services/crawler/tests/test_embedding_service.py b/services/crawler/tests/test_embedding_service.py new file mode 100644 index 000000000..634ec1be0 --- /dev/null +++ b/services/crawler/tests/test_embedding_service.py @@ -0,0 +1,164 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.embedding_service import EmbeddingService + + +def make_embedding_response(embeddings: list[list[float]]): + return SimpleNamespace(data=[SimpleNamespace(embedding=e) for e in embeddings]) + + +def create_service(dimensions: int = 1536) -> EmbeddingService: + service = EmbeddingService( + api_key="test-key", + base_url=None, + model="test-model", + dimensions=dimensions, + ) + mock_client = MagicMock() + mock_client.embeddings.create = AsyncMock() + service._client = mock_client + return service + + +class TestDimensionsProperty: + def 
test_returns_configured_dimensions(self): + service = create_service(dimensions=768) + assert service.dimensions == 768 + + def test_returns_default_dimensions(self): + service = create_service() + assert service.dimensions == 1536 + + +class TestEmbedTexts: + async def test_empty_texts_returns_empty_list(self): + service = create_service() + result = await service.embed_texts([]) + assert result == [] + service._client.embeddings.create.assert_not_called() + + async def test_single_text(self): + service = create_service(dimensions=3) + expected = [0.1, 0.2, 0.3] + service._client.embeddings.create.return_value = make_embedding_response([expected]) + + result = await service.embed_texts(["hello"]) + + assert result == [expected] + service._client.embeddings.create.assert_called_once_with( + model="test-model", + input=["hello"], + dimensions=3, + ) + + async def test_multiple_texts_single_batch(self): + service = create_service(dimensions=2) + embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + service._client.embeddings.create.return_value = make_embedding_response(embeddings) + + result = await service.embed_texts(["a", "b", "c"]) + + assert result == embeddings + service._client.embeddings.create.assert_called_once_with( + model="test-model", + input=["a", "b", "c"], + dimensions=2, + ) + + async def test_batching_splits_large_input(self, monkeypatch): + import app.services.embedding_service as module + + monkeypatch.setattr(module, "MAX_BATCH_SIZE", 2) + + service = create_service(dimensions=2) + batch1_embeddings = [[0.1, 0.2], [0.3, 0.4]] + batch2_embeddings = [[0.5, 0.6]] + service._client.embeddings.create.side_effect = [ + make_embedding_response(batch1_embeddings), + make_embedding_response(batch2_embeddings), + ] + + result = await service.embed_texts(["a", "b", "c"]) + + assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + assert service._client.embeddings.create.call_count == 2 + calls = service._client.embeddings.create.call_args_list + assert 
calls[0].kwargs == {"model": "test-model", "input": ["a", "b"], "dimensions": 2} + assert calls[1].kwargs == {"model": "test-model", "input": ["c"], "dimensions": 2} + + async def test_batching_exact_multiple(self, monkeypatch): + import app.services.embedding_service as module + + monkeypatch.setattr(module, "MAX_BATCH_SIZE", 2) + + service = create_service(dimensions=1) + service._client.embeddings.create.side_effect = [ + make_embedding_response([[1.0], [2.0]]), + make_embedding_response([[3.0], [4.0]]), + ] + + result = await service.embed_texts(["a", "b", "c", "d"]) + + assert result == [[1.0], [2.0], [3.0], [4.0]] + assert service._client.embeddings.create.call_count == 2 + + +class TestEmbedQuery: + async def test_returns_single_vector(self): + service = create_service(dimensions=3) + expected = [0.1, 0.2, 0.3] + service._client.embeddings.create.return_value = make_embedding_response([expected]) + + result = await service.embed_query("search term") + + assert result == expected + service._client.embeddings.create.assert_called_once_with( + model="test-model", + input=["search term"], + dimensions=3, + ) + + +class TestRetryBehavior: + @patch("app.services.embedding_service.asyncio.sleep", new_callable=AsyncMock) + async def test_retries_on_first_failure(self, mock_sleep): + service = create_service(dimensions=2) + expected = [[0.1, 0.2]] + service._client.embeddings.create.side_effect = [ + RuntimeError("API error"), + make_embedding_response(expected), + ] + + result = await service.embed_texts(["hello"]) + + assert result == expected + assert service._client.embeddings.create.call_count == 2 + mock_sleep.assert_awaited_once_with(1.0) + + @patch("app.services.embedding_service.asyncio.sleep", new_callable=AsyncMock) + async def test_raises_after_all_retries_exhausted(self, mock_sleep): + service = create_service(dimensions=2) + service._client.embeddings.create.side_effect = [ + RuntimeError("API error 1"), + RuntimeError("API error 2"), + 
RuntimeError("API error 3"), + ] + + with pytest.raises(RuntimeError, match="API error 3"): + await service.embed_texts(["hello"]) + + assert service._client.embeddings.create.call_count == 3 + assert mock_sleep.await_count == 2 + + async def test_no_retry_on_success(self): + service = create_service(dimensions=2) + expected = [[0.1, 0.2]] + service._client.embeddings.create.return_value = make_embedding_response(expected) + + result = await service.embed_texts(["hello"]) + + assert result == expected + assert service._client.embeddings.create.call_count == 1 diff --git a/services/crawler/tests/test_index_router.py b/services/crawler/tests/test_index_router.py new file mode 100644 index 000000000..b21cbcd58 --- /dev/null +++ b/services/crawler/tests/test_index_router.py @@ -0,0 +1,134 @@ +from unittest.mock import AsyncMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.routers.index import router, set_indexing_service + +app = FastAPI() +app.include_router(router) + + +@pytest.fixture +def mock_indexing_service(): + service = AsyncMock() + set_indexing_service(service) + yield service + set_indexing_service(None) + + +class TestIndexPage: + async def test_success(self, mock_indexing_service): + mock_indexing_service.index_page.return_value = { + "url": "https://example.com/page", + "status": "indexed", + "chunks_indexed": 5, + } + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/index/page", + json={ + "domain": "example.com", + "url": "https://example.com/page", + "title": "Test Page", + "content": "Some content to index", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["url"] == "https://example.com/page" + assert data["chunks_indexed"] == 5 + assert data["status"] == "indexed" + mock_indexing_service.index_page.assert_awaited_once_with( + 
domain="example.com", + url="https://example.com/page", + title="Test Page", + content="Some content to index", + ) + + async def test_skipped_page(self, mock_indexing_service): + mock_indexing_service.index_page.return_value = { + "url": "https://example.com/page", + "status": "skipped", + "chunks_indexed": 0, + } + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/index/page", + json={"domain": "example.com", "url": "https://example.com/page", "content": "Same content"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["status"] == "skipped" + + async def test_503_when_service_not_initialized(self): + set_indexing_service(None) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/index/page", + json={"domain": "example.com", "url": "https://example.com/page", "content": "content"}, + ) + + assert response.status_code == 503 + assert response.json()["detail"] == "Indexing service not initialized" + + async def test_500_on_unexpected_error(self, mock_indexing_service): + mock_indexing_service.index_page.side_effect = RuntimeError("db error") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/index/page", + json={"domain": "example.com", "url": "https://example.com/page", "content": "content"}, + ) + + assert response.status_code == 500 + assert response.json()["detail"] == "Indexing failed" + + +class TestIndexWebsite: + async def test_success(self, mock_indexing_service): + mock_indexing_service.index_website.return_value = { + "domain": "example.com", + "pages_indexed": 10, + "pages_skipped": 2, + "pages_failed": 1, + "total_chunks": 50, + } + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + 
response = await client.post("/api/v1/index/website/example.com") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["domain"] == "example.com" + assert data["pages_indexed"] == 10 + assert data["pages_skipped"] == 2 + assert data["pages_failed"] == 1 + assert data["total_chunks"] == 50 + mock_indexing_service.index_website.assert_awaited_once_with("example.com") + + async def test_503_when_service_not_initialized(self): + set_indexing_service(None) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/index/website/example.com") + + assert response.status_code == 503 + + async def test_500_on_unexpected_error(self, mock_indexing_service): + mock_indexing_service.index_website.side_effect = RuntimeError("boom") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/index/website/example.com") + + assert response.status_code == 500 + assert response.json()["detail"] == "Website indexing failed" diff --git a/services/crawler/tests/test_indexing_service.py b/services/crawler/tests/test_indexing_service.py new file mode 100644 index 000000000..04e8f1713 --- /dev/null +++ b/services/crawler/tests/test_indexing_service.py @@ -0,0 +1,169 @@ +import hashlib +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.chunking_service import ContentChunk +from app.services.indexing_service import IndexingService + + +def _sha256(content: str) -> str: + return hashlib.sha256(content.encode()).hexdigest() + + +@pytest.fixture +def mock_conn(): + conn = AsyncMock() + conn.fetchval = AsyncMock(return_value=None) + conn.fetch = AsyncMock(return_value=[]) + conn.execute = AsyncMock(return_value="DELETE 0") + conn.executemany = AsyncMock() + conn.transaction = MagicMock(return_value=AsyncMock(__aenter__=AsyncMock(), 
__aexit__=AsyncMock())) + return conn + + +@pytest.fixture +def mock_pool(mock_conn): + pool = AsyncMock() + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=mock_conn) + ctx.__aexit__ = AsyncMock(return_value=None) + pool.acquire = MagicMock(return_value=ctx) + return pool + + +@pytest.fixture +def mock_embedding(): + service = AsyncMock() + service.embed_texts = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]]) + return service + + +@pytest.fixture +def indexing_service(mock_pool, mock_embedding): + return IndexingService(mock_pool, mock_embedding) + + +class TestIndexPage: + async def test_skips_when_content_hash_matches(self, mock_conn, indexing_service): + content = "some page content" + mock_conn.fetchval = AsyncMock(return_value=_sha256(content)) + + result = await indexing_service.index_page("example.com", "https://example.com/page", "Title", content) + + assert result["status"] == "skipped" + assert result["chunks_indexed"] == 0 + + @patch("app.services.indexing_service.chunk_content", return_value=[]) + async def test_returns_empty_when_no_chunks(self, mock_chunk, indexing_service): + result = await indexing_service.index_page("example.com", "https://example.com/page", "Title", "content") + + assert result["status"] == "empty" + assert result["chunks_indexed"] == 0 + + @patch("app.services.indexing_service.chunk_content") + async def test_returns_error_when_embedding_fails(self, mock_chunk, indexing_service, mock_embedding): + mock_chunk.return_value = [ContentChunk(content="chunk text", index=0)] + mock_embedding.embed_texts = AsyncMock(side_effect=RuntimeError("API down")) + + result = await indexing_service.index_page("example.com", "https://example.com/page", "Title", "content") + + assert result["status"] == "error" + assert result["error"] == "embedding_failed" + assert result["chunks_indexed"] == 0 + + @patch("app.services.indexing_service.chunk_content") + async def test_indexes_successfully(self, mock_chunk, indexing_service, 
mock_embedding): + chunks = [ContentChunk(content="chunk one", index=0), ContentChunk(content="chunk two", index=1)] + mock_chunk.return_value = chunks + mock_embedding.embed_texts = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]]) + + result = await indexing_service.index_page("example.com", "https://example.com/page", "Title", "content") + + assert result["status"] == "indexed" + assert result["chunks_indexed"] == 2 + assert result["url"] == "https://example.com/page" + + @patch("app.services.indexing_service.chunk_content") + async def test_deletes_old_chunks_before_inserting(self, mock_chunk, indexing_service, mock_conn, mock_embedding): + chunks = [ContentChunk(content="chunk", index=0)] + mock_chunk.return_value = chunks + mock_embedding.embed_texts = AsyncMock(return_value=[[0.1, 0.2]]) + + await indexing_service.index_page("example.com", "https://example.com/page", "Title", "content") + + calls = [str(c) for c in mock_conn.execute.call_args_list] + delete_call = next(c for c in calls if "DELETE" in c) + assert "https://example.com/page" in delete_call + mock_conn.executemany.assert_called_once() + + +class TestDeletePageChunks: + async def test_returns_deleted_count(self, indexing_service, mock_conn): + mock_conn.execute = AsyncMock(return_value="DELETE 5") + + count = await indexing_service.delete_page_chunks("https://example.com/page") + + assert count == 5 + + async def test_returns_zero_when_no_rows_deleted(self, indexing_service, mock_conn): + mock_conn.execute = AsyncMock(return_value="DELETE 0") + + count = await indexing_service.delete_page_chunks("https://example.com/page") + + assert count == 0 + + async def test_returns_zero_when_result_is_empty(self, indexing_service, mock_conn): + mock_conn.execute = AsyncMock(return_value="") + + count = await indexing_service.delete_page_chunks("https://example.com/page") + + assert count == 0 + + +class TestIndexWebsite: + async def test_aggregates_results_correctly(self, indexing_service, mock_conn): + 
mock_conn.fetch = AsyncMock( + side_effect=[ + [ + {"url": "https://example.com/a", "title": "Page A", "content": "aaa"}, + {"url": "https://example.com/b", "title": "Page B", "content": "bbb"}, + {"url": "https://example.com/c", "title": "Page C", "content": "ccc"}, + ], + [], + ] + ) + + call_count = 0 + results = [ + {"url": "https://example.com/a", "status": "indexed", "chunks_indexed": 3}, + {"url": "https://example.com/b", "status": "skipped", "chunks_indexed": 0}, + {"url": "https://example.com/c", "status": "error", "chunks_indexed": 0, "error": "embedding_failed"}, + ] + + async def fake_index_page(domain, url, title, content): + nonlocal call_count + result = results[call_count] + call_count += 1 + return result + + indexing_service.index_page = fake_index_page + + result = await indexing_service.index_website("example.com") + + assert result["domain"] == "example.com" + assert result["pages_indexed"] == 1 + assert result["pages_skipped"] == 1 + assert result["pages_failed"] == 1 + assert result["total_chunks"] == 3 + + async def test_returns_zeros_when_no_pages(self, indexing_service, mock_conn): + mock_conn.fetch = AsyncMock(return_value=[]) + + result = await indexing_service.index_website("empty.com") + + assert result["domain"] == "empty.com" + assert result["pages_indexed"] == 0 + assert result["pages_skipped"] == 0 + assert result["pages_failed"] == 0 + assert result["total_chunks"] == 0 diff --git a/services/crawler/tests/test_pages_router.py b/services/crawler/tests/test_pages_router.py new file mode 100644 index 000000000..9d4c58464 --- /dev/null +++ b/services/crawler/tests/test_pages_router.py @@ -0,0 +1,264 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.routers.pages import router + +app = FastAPI() +app.include_router(router) + +_DEFAULT_CRAWLED = datetime(2025, 6, 1, 12, 0, 0, 
tzinfo=timezone.utc) +_DEFAULT_DISCOVERED = datetime(2025, 5, 15, 8, 0, 0, tzinfo=timezone.utc) + + +class FakeRecord(dict): + """Dict subclass mimicking asyncpg Record with r["field"] access.""" + + +def _make_row(**overrides): + defaults = { + "url": "https://example.com/page1", + "title": "Page 1", + "word_count": 500, + "status": "active", + "content_hash": "abc123", + "last_crawled_at": _DEFAULT_CRAWLED, + "discovered_at": _DEFAULT_DISCOVERED, + "chunks_count": 3, + } + defaults.update(overrides) + return FakeRecord(defaults) + + +@pytest.fixture +def mock_pool(): + conn = AsyncMock() + pool = MagicMock() + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=conn) + ctx.__aexit__ = AsyncMock(return_value=False) + pool.acquire.return_value = ctx + with patch("app.routers.pages.get_pool", return_value=pool): + yield conn + + +class TestListPages: + async def test_success(self, mock_pool): + rows = [ + _make_row(url="https://example.com/a", title="Page A", word_count=100, chunks_count=2), + _make_row(url="https://example.com/b", title="Page B", word_count=200, chunks_count=0), + ] + mock_pool.fetch.return_value = rows + mock_pool.fetchval.return_value = 2 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com") + + assert response.status_code == 200 + data = response.json() + assert data["domain"] == "example.com" + assert data["total"] == 2 + assert data["offset"] == 0 + assert data["has_more"] is False + assert len(data["pages"]) == 2 + + page_a = data["pages"][0] + assert page_a["url"] == "https://example.com/a" + assert page_a["title"] == "Page A" + assert page_a["word_count"] == 100 + assert page_a["chunks_count"] == 2 + assert page_a["indexed"] is True + + page_b = data["pages"][1] + assert page_b["url"] == "https://example.com/b" + assert page_b["chunks_count"] == 0 + assert page_b["indexed"] is False + + async def test_empty_result(self, mock_pool): + 
mock_pool.fetch.return_value = [] + mock_pool.fetchval.return_value = 0 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/unknown.com") + + assert response.status_code == 200 + data = response.json() + assert data["domain"] == "unknown.com" + assert data["pages"] == [] + assert data["total"] == 0 + assert data["has_more"] is False + + async def test_status_filter(self, mock_pool): + mock_pool.fetch.return_value = [] + mock_pool.fetchval.return_value = 0 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com?status=active") + + assert response.status_code == 200 + + fetch_call = mock_pool.fetch.call_args + query = fetch_call[0][0] + assert "wu.status = $2" in query + + params = fetch_call[0][1:] + assert params[0] == "example.com" + assert params[1] == "active" + + async def test_has_more_true(self, mock_pool): + mock_pool.fetch.return_value = [_make_row(url="https://example.com/p1")] + mock_pool.fetchval.return_value = 50 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com?offset=0&limit=10") + + assert response.status_code == 200 + data = response.json() + assert data["has_more"] is True + assert data["total"] == 50 + assert data["offset"] == 0 + + async def test_has_more_false_at_end(self, mock_pool): + mock_pool.fetch.return_value = [_make_row(url="https://example.com/p1")] + mock_pool.fetchval.return_value = 50 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com?offset=40&limit=10") + + assert response.status_code == 200 + data = response.json() + assert data["has_more"] is False + + async def test_sort_param(self, mock_pool): + mock_pool.fetch.return_value = [] + 
mock_pool.fetchval.return_value = 0 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com?sort=word_count") + + assert response.status_code == 200 + + query = mock_pool.fetch.call_args[0][0] + assert "ORDER BY wu.word_count DESC" in query + + async def test_invalid_sort_falls_back(self, mock_pool): + mock_pool.fetch.return_value = [] + mock_pool.fetchval.return_value = 0 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com?sort=invalid_field") + + assert response.status_code == 200 + + query = mock_pool.fetch.call_args[0][0] + assert "ORDER BY wu.last_crawled_at DESC" in query + + async def test_pagination_params_passed(self, mock_pool): + mock_pool.fetch.return_value = [] + mock_pool.fetchval.return_value = 0 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com?offset=20&limit=50") + + assert response.status_code == 200 + + fetch_call = mock_pool.fetch.call_args + params = fetch_call[0][1:] + assert 50 in params + assert 20 in params + + async def test_null_timestamps(self, mock_pool): + mock_pool.fetch.return_value = [ + _make_row(url="https://example.com/new", last_crawled_at=None, discovered_at=None), + ] + mock_pool.fetchval.return_value = 1 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com") + + assert response.status_code == 200 + page = response.json()["pages"][0] + assert page["last_crawled_at"] is None + assert page["discovered_at"] is None + + async def test_null_word_count_defaults_to_zero(self, mock_pool): + mock_pool.fetch.return_value = [_make_row(url="https://example.com/empty", word_count=None)] + mock_pool.fetchval.return_value = 1 + + async with 
AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com") + + assert response.status_code == 200 + assert response.json()["pages"][0]["word_count"] == 0 + + async def test_500_on_database_error(self, mock_pool): + mock_pool.fetch.side_effect = RuntimeError("connection lost") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/pages/example.com") + + assert response.status_code == 500 + assert response.json()["detail"] == "Failed to list pages" + + +def _make_chunk_row(**overrides): + defaults = { + "chunk_index": 0, + "chunk_content": "This is chunk content.", + } + defaults.update(overrides) + return FakeRecord(defaults) + + +class TestGetPageChunks: + async def test_success(self, mock_pool): + rows = [ + _make_chunk_row(chunk_index=0, chunk_content="First chunk"), + _make_chunk_row(chunk_index=1, chunk_content="Second chunk"), + ] + mock_pool.fetch.return_value = rows + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get( + "/api/v1/pages/example.com/chunks", + params={"url": "https://example.com/page1"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["url"] == "https://example.com/page1" + assert data["domain"] == "example.com" + assert data["total"] == 2 + assert len(data["chunks"]) == 2 + assert data["chunks"][0]["chunk_index"] == 0 + assert data["chunks"][0]["chunk_content"] == "First chunk" + assert data["chunks"][1]["chunk_index"] == 1 + + async def test_empty_chunks(self, mock_pool): + mock_pool.fetch.return_value = [] + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get( + "/api/v1/pages/example.com/chunks", + params={"url": "https://example.com/no-chunks"}, + ) + + assert response.status_code == 200 + data = 
response.json() + assert data["chunks"] == [] + assert data["total"] == 0 + + async def test_500_on_database_error(self, mock_pool): + mock_pool.fetch.side_effect = RuntimeError("connection lost") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get( + "/api/v1/pages/example.com/chunks", + params={"url": "https://example.com/page1"}, + ) + + assert response.status_code == 500 + assert response.json()["detail"] == "Failed to get page chunks" diff --git a/services/crawler/tests/test_search_router.py b/services/crawler/tests/test_search_router.py new file mode 100644 index 000000000..cc812bf97 --- /dev/null +++ b/services/crawler/tests/test_search_router.py @@ -0,0 +1,129 @@ +from unittest.mock import AsyncMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.routers.search import router, set_search_service +from app.services.search_service import SearchResult + +app = FastAPI() +app.include_router(router) + + +@pytest.fixture +def mock_search_service(): + service = AsyncMock() + set_search_service(service) + yield service + set_search_service(None) + + +class TestSearchAll: + async def test_returns_results(self, mock_search_service): + mock_search_service.search.return_value = [ + SearchResult( + url="https://example.com/page1", title="Page 1", chunk_content="Hello world", chunk_index=0, score=0.95 + ), + SearchResult( + url="https://example.com/page2", + title="Page 2", + chunk_content="Goodbye world", + chunk_index=1, + score=0.80, + ), + ] + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search", json={"query": "hello", "limit": 10}) + + assert response.status_code == 200 + data = response.json() + assert data["query"] == "hello" + assert data["total"] == 2 + assert len(data["results"]) == 2 + assert data["results"][0]["url"] == 
"https://example.com/page1" + assert data["results"][0]["score"] == 0.95 + mock_search_service.search.assert_awaited_once_with(query="hello", limit=10) + + async def test_returns_empty_results(self, mock_search_service): + mock_search_service.search.return_value = [] + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search", json={"query": "nonexistent"}) + + assert response.status_code == 200 + data = response.json() + assert data["query"] == "nonexistent" + assert data["total"] == 0 + assert data["results"] == [] + + async def test_uses_default_limit(self, mock_search_service): + mock_search_service.search.return_value = [] + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + await client.post("/api/v1/search", json={"query": "test"}) + + mock_search_service.search.assert_awaited_once_with(query="test", limit=10) + + async def test_503_when_service_not_initialized(self): + set_search_service(None) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search", json={"query": "test"}) + + assert response.status_code == 503 + assert response.json()["detail"] == "Search service not initialized" + + async def test_500_on_unexpected_error(self, mock_search_service): + mock_search_service.search.side_effect = RuntimeError("db gone") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search", json={"query": "boom"}) + + assert response.status_code == 500 + assert response.json()["detail"] == "Search failed" + + +class TestSearchDomain: + async def test_routes_domain_correctly(self, mock_search_service): + mock_search_service.search.return_value = [ + SearchResult( + url="https://docs.example.com/intro", title="Intro", chunk_content="Welcome", chunk_index=0, score=1.0 + ), + ] + 
+ async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search/docs.example.com", json={"query": "welcome", "limit": 5}) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert data["results"][0]["url"] == "https://docs.example.com/intro" + mock_search_service.search.assert_awaited_once_with(query="welcome", domain="docs.example.com", limit=5) + + async def test_empty_domain_results(self, mock_search_service): + mock_search_service.search.return_value = [] + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search/unknown.com", json={"query": "anything"}) + + assert response.status_code == 200 + assert response.json()["total"] == 0 + + async def test_503_when_service_not_initialized(self): + set_search_service(None) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search/example.com", json={"query": "test"}) + + assert response.status_code == 503 + + async def test_500_on_unexpected_error(self, mock_search_service): + mock_search_service.search.side_effect = RuntimeError("oops") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post("/api/v1/search/example.com", json={"query": "fail"}) + + assert response.status_code == 500 + assert response.json()["detail"] == "Search failed" diff --git a/services/crawler/tests/test_search_service.py b/services/crawler/tests/test_search_service.py new file mode 100644 index 000000000..237899363 --- /dev/null +++ b/services/crawler/tests/test_search_service.py @@ -0,0 +1,176 @@ +"""Tests for SearchService RRF merge logic.""" + +import pytest + +from app.services.search_service import RRF_K, SearchResult, SearchService + + +def _item(id, url="https://example.com", 
title="Title", chunk_content="content", chunk_index=0): + return {"id": id, "url": url, "title": title, "chunk_content": chunk_content, "chunk_index": chunk_index} + + +class TestMergeRrfEmptyInput: + def test_no_ranked_lists(self): + assert SearchService._merge_rrf([], limit=10) == [] + + def test_single_empty_list(self): + assert SearchService._merge_rrf([[]], limit=10) == [] + + def test_multiple_empty_lists(self): + assert SearchService._merge_rrf([[], []], limit=10) == [] + + +class TestMergeRrfSingleList: + def test_single_list_returns_all_items(self): + items = [_item(1, url="https://a.com"), _item(2, url="https://b.com")] + results = SearchService._merge_rrf([items], limit=10) + assert len(results) == 2 + assert results[0].url == "https://a.com" + assert results[1].url == "https://b.com" + + def test_single_list_preserves_rank_order(self): + items = [_item(10), _item(20), _item(30)] + results = SearchService._merge_rrf([items], limit=10) + assert [r.score for r in results] == sorted([r.score for r in results], reverse=True) + + def test_single_item(self): + results = SearchService._merge_rrf([[_item(1, chunk_content="hello")]], limit=10) + assert len(results) == 1 + assert results[0].chunk_content == "hello" + + +class TestMergeRrfOverlappingItems: + def test_overlapping_item_boosted_above_disjoint(self): + list_a = [_item(1), _item(2)] + list_b = [_item(1), _item(3)] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + assert results[0].url == "https://example.com" + ids_by_score = [r for r in results] + assert ids_by_score[0].score > ids_by_score[1].score + + def test_overlapping_item_score_equals_sum_of_rrf(self): + list_a = [_item(1)] + list_b = [_item(1)] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + expected_raw = 2 * (1.0 / (RRF_K + 0 + 1)) + assert len(results) == 1 + # Single item → normalized score is always 1.0 (raw / max where max == raw) + assert results[0].score == pytest.approx(1.0) + # Verify raw RRF 
formula: rank-0 item in 2 lists → 2 * 1/(K+1) + assert expected_raw == pytest.approx(2.0 / (RRF_K + 1)) + + def test_overlapping_at_different_ranks(self): + list_a = [_item(1), _item(2)] + list_b = [_item(2), _item(1)] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + assert len(results) == 2 + assert results[0].score == results[1].score + + +class TestMergeRrfDisjointItems: + def test_disjoint_lists_merged(self): + list_a = [_item(1, url="https://a.com")] + list_b = [_item(2, url="https://b.com")] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + assert len(results) == 2 + urls = {r.url for r in results} + assert urls == {"https://a.com", "https://b.com"} + + def test_disjoint_same_rank_have_equal_scores(self): + list_a = [_item(1)] + list_b = [_item(2)] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + assert results[0].score == results[1].score + + def test_disjoint_different_ranks(self): + list_a = [_item(1), _item(2)] + list_b = [_item(3)] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + rank0_score = 1.0 / (RRF_K + 0 + 1) + rank1_score = 1.0 / (RRF_K + 1 + 1) + top = [r for r in results if r.score == pytest.approx(1.0)] + assert len(top) == 2 + bottom = [r for r in results if r.score < 1.0] + assert len(bottom) == 1 + assert bottom[0].score == pytest.approx(rank1_score / rank0_score) + + +class TestMergeRrfLimitTruncation: + def test_limit_truncates_results(self): + items = [_item(i) for i in range(20)] + results = SearchService._merge_rrf([items], limit=5) + assert len(results) == 5 + + def test_limit_larger_than_items_returns_all(self): + items = [_item(1), _item(2)] + results = SearchService._merge_rrf([items], limit=100) + assert len(results) == 2 + + def test_limit_zero_returns_empty(self): + items = [_item(1), _item(2)] + results = SearchService._merge_rrf([items], limit=0) + assert results == [] + + def test_limit_one_returns_top_result(self): + list_a = [_item(1), _item(2)] + list_b 
= [_item(1), _item(3)] + results = SearchService._merge_rrf([list_a, list_b], limit=1) + assert len(results) == 1 + + +class TestMergeRrfScoreNormalization: + def test_top_result_always_has_score_one(self): + items = [_item(i) for i in range(5)] + results = SearchService._merge_rrf([items], limit=10) + assert results[0].score == pytest.approx(1.0) + + def test_top_result_score_one_with_multiple_lists(self): + list_a = [_item(1), _item(2), _item(3)] + list_b = [_item(4), _item(1), _item(5)] + results = SearchService._merge_rrf([list_a, list_b], limit=10) + assert results[0].score == pytest.approx(1.0) + + def test_scores_are_between_zero_and_one(self): + list_a = [_item(i) for i in range(10)] + list_b = [_item(i + 5) for i in range(10)] + results = SearchService._merge_rrf([list_a, list_b], limit=20) + for r in results: + assert 0.0 < r.score <= 1.0 + + def test_normalized_scores_preserve_relative_order(self): + list_a = [_item(1), _item(2), _item(3)] + results = SearchService._merge_rrf([list_a], limit=10) + scores = [r.score for r in results] + assert scores == sorted(scores, reverse=True) + + +class TestMergeRrfFieldMapping: + def test_fields_mapped_correctly(self): + item = _item(42, url="https://test.dev", title="Test Page", chunk_content="some text", chunk_index=3) + results = SearchService._merge_rrf([[item]], limit=10) + assert len(results) == 1 + r = results[0] + assert r.url == "https://test.dev" + assert r.title == "Test Page" + assert r.chunk_content == "some text" + assert r.chunk_index == 3 + + def test_title_can_be_none(self): + item = {"id": 1, "url": "https://x.com", "title": None, "chunk_content": "c", "chunk_index": 0} + results = SearchService._merge_rrf([[item]], limit=10) + assert results[0].title is None + + def test_missing_title_key_defaults_to_none(self): + item = {"id": 1, "url": "https://x.com", "chunk_content": "c", "chunk_index": 0} + results = SearchService._merge_rrf([[item]], limit=10) + assert results[0].title is None + + def 
test_returns_search_result_instances(self): + results = SearchService._merge_rrf([[_item(1)]], limit=10) + assert isinstance(results[0], SearchResult) + + def test_later_list_overwrites_item_metadata(self): + item_v1 = _item(1, title="Old Title") + item_v2 = _item(1, title="New Title") + results = SearchService._merge_rrf([[item_v1], [item_v2]], limit=10) + assert results[0].title == "New Title" diff --git a/services/crawler/tests/test_website_store.py b/services/crawler/tests/test_website_store.py deleted file mode 100644 index a3849aa8d..000000000 --- a/services/crawler/tests/test_website_store.py +++ /dev/null @@ -1,701 +0,0 @@ -""" -Tests for WebsiteStore and WebsiteStoreManager. - -Uses importlib to load website_store directly, bypassing the app.services -barrel __init__.py which pulls in heavy dependencies (playwright, crawl4ai). -""" - -import importlib.util -import sys -from pathlib import Path - -import pytest - -# Load website_store module directly to avoid app.services.__init__ barrel import -_module_path = Path(__file__).resolve().parent.parent / "app" / "services" / "website_store.py" -_spec = importlib.util.spec_from_file_location("website_store", _module_path) -_mod = importlib.util.module_from_spec(_spec) -sys.modules["website_store"] = _mod -_spec.loader.exec_module(_mod) - -WebsiteStore = _mod.WebsiteStore -WebsiteStoreManager = _mod.WebsiteStoreManager -_sanitize_domain = _mod._sanitize_domain - - -@pytest.fixture -def tmp_data_dir(tmp_path): - return tmp_path / "data" - - -@pytest.fixture -def site_store(tmp_path): - db_path = tmp_path / "sites" / "example_com.db" - store = WebsiteStore(db_path) - yield store - store.close() - - -@pytest.fixture -def manager(tmp_data_dir): - mgr = WebsiteStoreManager(data_dir=tmp_data_dir) - yield mgr - mgr.close_all() - - -class TestSanitizeDomain: - def test_replaces_dots_and_hyphens(self): - assert _sanitize_domain("my-site.example.com") == "my_site_example_com" - - def test_no_special_chars(self): - assert 
_sanitize_domain("localhost") == "localhost" - - -class TestWebsiteStore: - def test_creates_db_file(self, tmp_path): - db_path = tmp_path / "nested" / "dir" / "test.db" - store = WebsiteStore(db_path) - assert db_path.exists() - store.close() - - def test_save_discovered_urls(self, site_store): - urls = [{"url": "https://example.com/a"}, {"url": "https://example.com/b"}] - inserted = site_store.save_discovered_urls(urls) - assert inserted >= 2 - # Discovered URLs have no content_hash yet, so not counted - assert site_store.get_total_count() == 0 - - def test_save_discovered_urls_ignores_duplicates(self, site_store): - urls = [{"url": "https://example.com/a"}] - site_store.save_discovered_urls(urls) - site_store.save_discovered_urls(urls) - assert site_store.get_total_count() == 0 - - def test_save_discovered_urls_empty(self, site_store): - assert site_store.save_discovered_urls([]) == 0 - - def test_get_urls_page_excludes_null_hash(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/a"}]) - assert site_store.get_urls_page() == [] - - def test_get_urls_page_basic(self, site_store): - urls = [{"url": f"https://example.com/{i}"} for i in range(5)] - site_store.save_discovered_urls(urls) - site_store.update_content_hashes( - [{"url": f"https://example.com/{i}", "content_hash": f"h{i}"} for i in range(5)] - ) - - page = site_store.get_urls_page(offset=0, limit=3) - assert len(page) == 3 - assert all("url" in p and "content_hash" in p and "status" in p for p in page) - - def test_get_urls_page_offset(self, site_store): - urls = [{"url": f"https://example.com/{i}"} for i in range(5)] - site_store.save_discovered_urls(urls) - site_store.update_content_hashes( - [{"url": f"https://example.com/{i}", "content_hash": f"h{i}"} for i in range(5)] - ) - - page = site_store.get_urls_page(offset=3, limit=10) - assert len(page) == 2 - - def test_get_urls_page_with_status_filter(self, site_store): - urls = [{"url": "https://example.com/a"}, {"url": 
"https://example.com/b"}] - site_store.save_discovered_urls(urls) - - site_store.update_content_hashes( - [ - {"url": "https://example.com/a", "content_hash": "abc", "status": "active"}, - ] - ) - - active = site_store.get_urls_page(status="active") - assert len(active) == 1 - assert active[0]["url"] == "https://example.com/a" - - # discovered URL has no hash, so excluded - discovered = site_store.get_urls_page(status="discovered") - assert len(discovered) == 0 - - def test_update_content_hashes(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/page"}]) - - site_store.update_content_hashes( - [ - {"url": "https://example.com/page", "content_hash": "sha256abc"}, - ] - ) - - pages = site_store.get_urls_page() - assert len(pages) == 1 - assert pages[0]["content_hash"] == "sha256abc" - assert pages[0]["status"] == "active" - assert pages[0]["last_crawled_at"] is not None - - def test_update_content_hashes_empty(self, site_store): - # Should not raise - site_store.update_content_hashes([]) - - def test_mark_urls_deleted(self, site_store): - site_store.save_discovered_urls( - [ - {"url": "https://example.com/a"}, - {"url": "https://example.com/b"}, - ] - ) - site_store.update_content_hashes( - [ - {"url": "https://example.com/a", "content_hash": "ha", "status": "active"}, - {"url": "https://example.com/b", "content_hash": "hb", "status": "active"}, - ] - ) - - site_store.mark_urls_deleted(["https://example.com/a"]) - - deleted = site_store.get_urls_page(status="deleted") - assert len(deleted) == 1 - assert deleted[0]["url"] == "https://example.com/a" - - active = site_store.get_urls_page(status="active") - assert len(active) == 1 - - def test_mark_urls_deleted_empty(self, site_store): - site_store.mark_urls_deleted([]) - - def test_get_urls_needing_recrawl_prefers_no_hash(self, site_store): - site_store.save_discovered_urls( - [ - {"url": "https://example.com/new"}, - {"url": "https://example.com/old"}, - ] - ) - 
site_store.update_content_hashes( - [ - {"url": "https://example.com/old", "content_hash": "h1"}, - ] - ) - - needing = site_store.get_urls_needing_recrawl(limit=10) - assert len(needing) == 2 - # URL without hash should come first - assert needing[0] == "https://example.com/new" - - def test_get_urls_needing_recrawl_excludes_deleted(self, site_store): - site_store.save_discovered_urls( - [ - {"url": "https://example.com/a"}, - {"url": "https://example.com/b"}, - ] - ) - site_store.mark_urls_deleted(["https://example.com/a"]) - - needing = site_store.get_urls_needing_recrawl(limit=10) - assert len(needing) == 1 - assert needing[0] == "https://example.com/b" - - def test_get_urls_needing_recrawl_respects_limit(self, site_store): - urls = [{"url": f"https://example.com/{i}"} for i in range(10)] - site_store.save_discovered_urls(urls) - - needing = site_store.get_urls_needing_recrawl(limit=3) - assert len(needing) == 3 - - def test_get_urls_needing_recrawl_crawled_before_excludes_recent(self, site_store): - import time - - site_store.save_discovered_urls( - [ - {"url": "https://example.com/a"}, - {"url": "https://example.com/b"}, - ] - ) - cutoff = time.time() - # Crawl one URL after the cutoff - site_store.update_content_hashes([{"url": "https://example.com/a", "content_hash": "h1"}]) - - needing = site_store.get_urls_needing_recrawl(limit=10, crawled_before=cutoff) - assert needing == ["https://example.com/b"] - - def test_get_urls_needing_recrawl_crawled_before_includes_stale(self, site_store): - import time - - site_store.save_discovered_urls( - [ - {"url": "https://example.com/a"}, - {"url": "https://example.com/b"}, - ] - ) - # Crawl both URLs - site_store.update_content_hashes( - [ - {"url": "https://example.com/a", "content_hash": "h1"}, - {"url": "https://example.com/b", "content_hash": "h2"}, - ] - ) - cutoff = time.time() - - # Both were crawled before cutoff, so both should be returned - needing = site_store.get_urls_needing_recrawl(limit=10, 
crawled_before=cutoff) - assert len(needing) == 2 - - def test_increment_fail_count(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/flaky"}]) - - site_store.increment_fail_count(["https://example.com/flaky"]) - conn = site_store._get_conn() - row = conn.execute( - "SELECT fail_count, last_crawled_at FROM website_urls WHERE url = ?", - ("https://example.com/flaky",), - ).fetchone() - assert row["fail_count"] == 1 - assert row["last_crawled_at"] is not None - - def test_increment_fail_count_accumulates(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/flaky"}]) - - site_store.increment_fail_count(["https://example.com/flaky"]) - site_store.increment_fail_count(["https://example.com/flaky"]) - site_store.increment_fail_count(["https://example.com/flaky"]) - - conn = site_store._get_conn() - row = conn.execute( - "SELECT fail_count FROM website_urls WHERE url = ?", - ("https://example.com/flaky",), - ).fetchone() - assert row["fail_count"] == 3 - - def test_increment_fail_count_empty(self, site_store): - site_store.increment_fail_count([]) - - def test_successful_crawl_resets_fail_count(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/flaky"}]) - site_store.increment_fail_count(["https://example.com/flaky"]) - site_store.increment_fail_count(["https://example.com/flaky"]) - - site_store.update_content_hashes([{"url": "https://example.com/flaky", "content_hash": "h1"}]) - - conn = site_store._get_conn() - row = conn.execute( - "SELECT fail_count FROM website_urls WHERE url = ?", - ("https://example.com/flaky",), - ).fetchone() - assert row["fail_count"] == 0 - - def test_increment_fail_count_sets_last_crawled_at_for_session_scoping(self, site_store): - import time - - site_store.save_discovered_urls([{"url": "https://example.com/fail"}]) - cutoff = time.time() - - # Increment fail count (sets last_crawled_at to now, which is after cutoff) - 
site_store.increment_fail_count(["https://example.com/fail"]) - - # URL should no longer appear in this scan session - needing = site_store.get_urls_needing_recrawl(limit=10, crawled_before=cutoff) - assert needing == [] - - def test_get_total_count(self, site_store): - assert site_store.get_total_count() == 0 - - site_store.save_discovered_urls([{"url": "https://example.com/a"}]) - # No hash yet, so count is still 0 - assert site_store.get_total_count() == 0 - - site_store.update_content_hashes([{"url": "https://example.com/a", "content_hash": "h1"}]) - assert site_store.get_total_count() == 1 - - def test_get_total_count_with_status(self, site_store): - site_store.save_discovered_urls( - [ - {"url": "https://example.com/a"}, - {"url": "https://example.com/b"}, - ] - ) - site_store.update_content_hashes( - [ - {"url": "https://example.com/a", "content_hash": "ha", "status": "active"}, - {"url": "https://example.com/b", "content_hash": "hb", "status": "active"}, - ] - ) - site_store.mark_urls_deleted(["https://example.com/a"]) - - assert site_store.get_total_count(status="deleted") == 1 - assert site_store.get_total_count(status="active") == 1 - - def test_update_content_hashes_with_page_data(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/page"}]) - - site_store.update_content_hashes( - [ - { - "url": "https://example.com/page", - "content_hash": "sha256abc", - "status": "active", - "title": "Test Page", - "content": "# Hello World\n\nSome content here.", - "word_count": 5, - "metadata": '{"author": "test"}', - "structured_data": '{"og:title": "Test"}', - }, - ] - ) - - cached = site_store.get_cached_pages(["https://example.com/page"]) - assert len(cached) == 1 - assert cached[0]["url"] == "https://example.com/page" - assert cached[0]["title"] == "Test Page" - assert cached[0]["content"] == "# Hello World\n\nSome content here." 
- assert cached[0]["word_count"] == 5 - assert cached[0]["metadata"] == {"author": "test"} - assert cached[0]["structured_data"] == {"og:title": "Test"} - - def test_update_content_hashes_without_page_data(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/page"}]) - - site_store.update_content_hashes([{"url": "https://example.com/page", "content_hash": "sha256abc"}]) - - cached = site_store.get_cached_pages(["https://example.com/page"]) - assert len(cached) == 0 - - def test_get_cached_pages_hit(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/a"}]) - site_store.update_content_hashes( - [ - { - "url": "https://example.com/a", - "content_hash": "h1", - "content": "Page A content", - "title": "Page A", - "word_count": 3, - }, - ] - ) - - cached = site_store.get_cached_pages(["https://example.com/a"]) - assert len(cached) == 1 - assert cached[0]["content"] == "Page A content" - assert cached[0]["title"] == "Page A" - - def test_get_cached_pages_miss(self, site_store): - cached = site_store.get_cached_pages(["https://example.com/nonexistent"]) - assert len(cached) == 0 - - def test_get_cached_pages_mixed(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/a"}, {"url": "https://example.com/b"}]) - site_store.update_content_hashes( - [ - { - "url": "https://example.com/a", - "content_hash": "h1", - "content": "Page A", - "word_count": 2, - }, - {"url": "https://example.com/b", "content_hash": "h2"}, - ] - ) - - cached = site_store.get_cached_pages( - ["https://example.com/a", "https://example.com/b", "https://example.com/c"] - ) - assert len(cached) == 1 - assert cached[0]["url"] == "https://example.com/a" - - def test_get_cached_pages_empty_input(self, site_store): - assert site_store.get_cached_pages([]) == [] - - def test_get_cached_pages_null_metadata(self, site_store): - site_store.save_discovered_urls([{"url": "https://example.com/a"}]) - 
site_store.update_content_hashes( - [ - { - "url": "https://example.com/a", - "content_hash": "h1", - "content": "Some content", - "word_count": 2, - }, - ] - ) - - cached = site_store.get_cached_pages(["https://example.com/a"]) - assert cached[0]["metadata"] is None - assert cached[0]["structured_data"] is None - - def test_close_and_reopen(self, tmp_path): - db_path = tmp_path / "test.db" - store = WebsiteStore(db_path) - store.save_discovered_urls([{"url": "https://example.com/persist"}]) - store.update_content_hashes([{"url": "https://example.com/persist", "content_hash": "h1"}]) - store.close() - - store2 = WebsiteStore(db_path) - assert store2.get_total_count() == 1 - pages = store2.get_urls_page() - assert pages[0]["url"] == "https://example.com/persist" - store2.close() - - def test_schema_migration_adds_fail_count(self, tmp_path): - db_path = tmp_path / "legacy_fc.db" - import sqlite3 - - conn = sqlite3.connect(str(db_path)) - conn.executescript(""" - CREATE TABLE website_urls ( - url TEXT PRIMARY KEY, - content_hash TEXT, - status TEXT NOT NULL DEFAULT 'discovered', - last_crawled_at REAL, - discovered_at REAL NOT NULL, - title TEXT, - content TEXT, - word_count INTEGER, - metadata TEXT, - structured_data TEXT - ); - """) - conn.execute( - "INSERT INTO website_urls (url, discovered_at) VALUES (?, ?)", - ("https://example.com/old", 1000.0), - ) - conn.commit() - conn.close() - - store = WebsiteStore(db_path) - # fail_count column should exist after migration - c = store._get_conn() - row = c.execute( - "SELECT fail_count FROM website_urls WHERE url = ?", - ("https://example.com/old",), - ).fetchone() - assert row["fail_count"] == 0 - store.close() - - def test_schema_migration_adds_columns(self, tmp_path): - db_path = tmp_path / "legacy.db" - import sqlite3 - - conn = sqlite3.connect(str(db_path)) - conn.executescript(""" - CREATE TABLE website_urls ( - url TEXT PRIMARY KEY, - content_hash TEXT, - status TEXT NOT NULL DEFAULT 'discovered', - 
last_crawled_at REAL, - discovered_at REAL NOT NULL - ); - """) - conn.execute( - "INSERT INTO website_urls (url, discovered_at) VALUES (?, ?)", - ("https://example.com/old", 1000.0), - ) - conn.commit() - conn.close() - - store = WebsiteStore(db_path) - store.update_content_hashes( - [ - { - "url": "https://example.com/old", - "content_hash": "h1", - "content": "Migrated content", - "title": "Old Page", - "word_count": 2, - }, - ] - ) - cached = store.get_cached_pages(["https://example.com/old"]) - assert len(cached) == 1 - assert cached[0]["content"] == "Migrated content" - assert cached[0]["title"] == "Old Page" - store.close() - - -class TestWebsiteStoreManager: - def test_register_website(self, manager): - result = manager.register_website("example.com", scan_interval=3600) - assert result["domain"] == "example.com" - assert result["scan_interval"] == 3600 - - website = manager.get_website("example.com") - assert website is not None - assert website["domain"] == "example.com" - assert website["scan_interval"] == 3600 - assert website["status"] == "idle" - - def test_register_website_upsert(self, manager): - manager.register_website("example.com", scan_interval=3600) - manager.register_website("example.com", scan_interval=7200) - - website = manager.get_website("example.com") - assert website["scan_interval"] == 7200 - - def test_remove_website(self, manager, tmp_data_dir): - manager.register_website("example.com") - site_store = manager.get_site_store("example.com") - site_store.save_discovered_urls([{"url": "https://example.com/a"}]) - - db_file = tmp_data_dir / "sites" / "example_com.db" - assert db_file.exists() - - removed = manager.remove_website("example.com") - assert removed is True - assert not db_file.exists() - assert manager.get_website("example.com") is None - - def test_remove_website_not_found(self, manager): - removed = manager.remove_website("nonexistent.com") - assert removed is False - - def test_get_due_websites_none_scanned(self, manager): 
- manager.register_website("a.com") - manager.register_website("b.com") - - due = manager.get_due_websites() - domains = [w["domain"] for w in due] - assert "a.com" in domains - assert "b.com" in domains - - def test_get_due_websites_excludes_scanning(self, manager): - manager.register_website("a.com") - manager.update_scan_status("a.com", "scanning") - - due = manager.get_due_websites() - assert len(due) == 0 - - def test_get_due_websites_excludes_recently_scanned(self, manager): - manager.register_website("a.com", scan_interval=3600) - manager.update_last_scanned("a.com") - - due = manager.get_due_websites() - assert len(due) == 0 - - def test_update_scan_status(self, manager): - manager.register_website("a.com") - manager.update_scan_status("a.com", "error", error="timeout") - - website = manager.get_website("a.com") - assert website["status"] == "error" - assert website["error"] == "timeout" - - def test_update_scan_status_clears_error(self, manager): - manager.register_website("a.com") - manager.update_scan_status("a.com", "error", error="timeout") - manager.update_scan_status("a.com", "idle") - - website = manager.get_website("a.com") - assert website["status"] == "idle" - assert website["error"] is None - - def test_get_site_store_creates_and_caches(self, manager): - store1 = manager.get_site_store("example.com") - store2 = manager.get_site_store("example.com") - assert store1 is store2 - - def test_get_site_store_different_domains(self, manager): - store_a = manager.get_site_store("a.com") - store_b = manager.get_site_store("b.com") - assert store_a is not store_b - - def test_site_store_isolation(self, manager): - store_a = manager.get_site_store("a.com") - store_b = manager.get_site_store("b.com") - - store_a.save_discovered_urls([{"url": "https://a.com/page"}]) - store_a.update_content_hashes([{"url": "https://a.com/page", "content_hash": "ha"}]) - - store_b.save_discovered_urls([{"url": "https://b.com/page1"}, {"url": "https://b.com/page2"}]) - 
store_b.update_content_hashes( - [ - {"url": "https://b.com/page1", "content_hash": "hb1"}, - {"url": "https://b.com/page2", "content_hash": "hb2"}, - ] - ) - - assert store_a.get_total_count() == 1 - assert store_b.get_total_count() == 2 - - def test_get_website_not_found(self, manager): - assert manager.get_website("nonexistent.com") is None - - def test_close_all(self, manager): - manager.register_website("a.com") - manager.get_site_store("a.com") - - manager.close_all() - # After close_all, internal state should be cleared - assert len(manager._stores) == 0 - assert manager._main_conn is None - - def test_removes_wal_and_shm_files(self, manager, tmp_data_dir): - manager.register_website("example.com") - site_store = manager.get_site_store("example.com") - site_store.save_discovered_urls([{"url": "https://example.com/a"}]) - - db_file = tmp_data_dir / "sites" / "example_com.db" - # WAL files may or may not exist depending on SQLite behavior, - # but remove_website should handle them gracefully - manager.remove_website("example.com") - assert not db_file.exists() - assert not db_file.with_suffix(".db-wal").exists() - assert not db_file.with_suffix(".db-shm").exists() - - def test_get_cached_pages_registered_domain(self, manager): - manager.register_website("example.com") - store = manager.get_site_store("example.com") - store.save_discovered_urls([{"url": "https://example.com/page"}]) - store.update_content_hashes( - [ - { - "url": "https://example.com/page", - "content_hash": "h1", - "content": "Cached content", - "title": "Cached", - "word_count": 2, - }, - ] - ) - - cached, to_crawl = manager.get_cached_pages(["https://example.com/page"]) - assert len(cached) == 1 - assert cached[0]["content"] == "Cached content" - assert len(to_crawl) == 0 - - def test_get_cached_pages_unregistered_domain(self, manager): - cached, to_crawl = manager.get_cached_pages(["https://unknown.com/page"]) - assert len(cached) == 0 - assert to_crawl == ["https://unknown.com/page"] - - 
def test_get_cached_pages_mixed_domains(self, manager): - manager.register_website("a.com") - store = manager.get_site_store("a.com") - store.save_discovered_urls([{"url": "https://a.com/page"}]) - store.update_content_hashes( - [ - { - "url": "https://a.com/page", - "content_hash": "h1", - "content": "Page A", - "word_count": 2, - }, - ] - ) - - cached, to_crawl = manager.get_cached_pages(["https://a.com/page", "https://b.com/other"]) - assert len(cached) == 1 - assert cached[0]["url"] == "https://a.com/page" - assert to_crawl == ["https://b.com/other"] - - def test_get_cached_pages_cache_miss_on_registered_domain(self, manager): - manager.register_website("example.com") - store = manager.get_site_store("example.com") - store.save_discovered_urls([{"url": "https://example.com/a"}]) - # Hash only, no content - store.update_content_hashes([{"url": "https://example.com/a", "content_hash": "h1"}]) - - cached, to_crawl = manager.get_cached_pages(["https://example.com/a"]) - assert len(cached) == 0 - assert to_crawl == ["https://example.com/a"] - - def test_get_cached_pages_empty_input(self, manager): - cached, to_crawl = manager.get_cached_pages([]) - assert cached == [] - assert to_crawl == [] diff --git a/services/crawler/tests/test_websites_router.py b/services/crawler/tests/test_websites_router.py new file mode 100644 index 000000000..2e5919791 --- /dev/null +++ b/services/crawler/tests/test_websites_router.py @@ -0,0 +1,333 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.routers.websites import router + +app = FastAPI() +app.include_router(router) + + +@pytest.fixture +def mock_manager(): + manager = AsyncMock() + manager.get_site_store = MagicMock() + app.state.pg_store_manager = manager + yield manager + del app.state.pg_store_manager + + +def _website_row(domain="example.com", scan_interval=21600, **overrides): + return { + "domain": domain, + 
"title": None, + "description": None, + "page_count": 0, + "total_urls": 0, + "crawled_count": 0, + "status": "idle", + "scan_interval": scan_interval, + "last_scanned_at": None, + "error": None, + "created_at": None, + "updated_at": None, + **overrides, + } + + +class TestRegisterWebsite: + async def test_success(self, mock_manager): + mock_manager.register_website.return_value = { + "domain": "example.com", + "status": "idle", + "scan_interval": 21600, + } + + with ( + patch("app.routers.websites.trigger_scan") as mock_trigger, + patch("app.routers.websites._initialize_website"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/websites", + json={"domain": "example.com", "scan_interval": 21600}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["domain"] == "example.com" + assert data["status"] == "scanning" + assert data["scan_interval"] == 21600 + mock_manager.register_website.assert_awaited_once_with( + domain="example.com", + scan_interval=21600, + ) + mock_trigger.assert_called_once() + + async def test_normalizes_full_url_to_domain(self, mock_manager): + mock_manager.register_website.return_value = { + "domain": "www.wisekey.com", + "status": "idle", + "scan_interval": 21600, + } + mock_manager.get_website.return_value = _website_row(domain="www.wisekey.com") + + with ( + patch("app.routers.websites.trigger_scan"), + patch("app.routers.websites._initialize_website"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/websites", + json={"domain": "https://www.wisekey.com", "scan_interval": 21600}, + ) + + assert response.status_code == 200 + mock_manager.register_website.assert_awaited_once_with( + domain="www.wisekey.com", + scan_interval=21600, + ) + + async def test_uses_default_scan_interval(self, mock_manager): + 
mock_manager.register_website.return_value = { + "domain": "example.com", + "status": "idle", + "scan_interval": 21600, + } + mock_manager.get_website.return_value = _website_row() + + with ( + patch("app.routers.websites.trigger_scan"), + patch("app.routers.websites._initialize_website"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/websites", + json={"domain": "example.com"}, + ) + + assert response.status_code == 200 + mock_manager.register_website.assert_awaited_once_with( + domain="example.com", + scan_interval=21600, + ) + + async def test_returns_scanning_status_immediately(self, mock_manager): + mock_manager.register_website.return_value = { + "domain": "example.com", + "status": "idle", + "scan_interval": 21600, + } + + with ( + patch("app.routers.websites.trigger_scan"), + patch("app.routers.websites._initialize_website"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/websites", + json={"domain": "example.com"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["title"] is None + assert data["page_count"] == 0 + assert data["crawled_count"] == 0 + assert data["status"] == "scanning" + + async def test_500_on_error(self, mock_manager): + mock_manager.register_website.side_effect = RuntimeError("db error") + + with ( + patch("app.routers.websites.trigger_scan"), + patch("app.routers.websites._initialize_website"), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/api/v1/websites", + json={"domain": "example.com"}, + ) + + assert response.status_code == 500 + assert response.json()["detail"] == "Failed to register website" + + +class TestGetWebsiteInfo: + async def test_success(self, mock_manager): + mock_manager.get_website.return_value = { + "domain": 
"example.com", + "title": "Example", + "description": "An example site", + "page_count": 42, + "total_urls": 50, + "crawled_count": 42, + "status": "active", + "scan_interval": 3600, + "last_scanned_at": 1700000000.0, + "error": None, + "created_at": 1699000000.0, + "updated_at": 1700000000.0, + } + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/example.com") + + assert response.status_code == 200 + data = response.json() + assert data["domain"] == "example.com" + assert data["title"] == "Example" + assert data["description"] == "An example site" + assert data["page_count"] == 50 + assert data["crawled_count"] == 42 + assert data["status"] == "active" + assert data["scan_interval"] == 3600 + assert data["last_scanned_at"] is not None + assert data["error"] is None + assert data["created_at"] is not None + assert data["updated_at"] is not None + mock_manager.get_website.assert_awaited_once_with("example.com") + + async def test_404_when_not_found(self, mock_manager): + mock_manager.get_website.return_value = None + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/unknown.com") + + assert response.status_code == 404 + assert response.json()["detail"] == "Website not found: unknown.com" + + async def test_500_on_error(self, mock_manager): + mock_manager.get_website.side_effect = RuntimeError("db error") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/example.com") + + assert response.status_code == 500 + assert response.json()["detail"] == "Failed to get website info" + + +class TestDeregisterWebsite: + async def test_success(self, mock_manager): + mock_manager.remove_website.return_value = True + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: 
+ response = await client.delete("/api/v1/websites/example.com") + + assert response.status_code == 200 + data = response.json() + assert data["domain"] == "example.com" + assert data["deleted"] is True + mock_manager.remove_website.assert_awaited_once_with("example.com") + + async def test_404_when_not_found(self, mock_manager): + mock_manager.remove_website.return_value = False + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.delete("/api/v1/websites/unknown.com") + + assert response.status_code == 404 + assert response.json()["detail"] == "Website not found: unknown.com" + + async def test_500_on_error(self, mock_manager): + mock_manager.remove_website.side_effect = RuntimeError("db error") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.delete("/api/v1/websites/example.com") + + assert response.status_code == 500 + assert response.json()["detail"] == "Failed to deregister website" + + +class TestGetWebsiteUrls: + async def test_success_with_pagination(self, mock_manager): + mock_manager.get_website.return_value = {"domain": "example.com"} + mock_site_store = AsyncMock() + mock_manager.get_site_store.return_value = mock_site_store + mock_site_store.get_urls_page.return_value = [ + { + "url": "https://example.com/page1", + "content_hash": "abc123", + "status": "active", + "last_crawled_at": 1700000000.0, + }, + { + "url": "https://example.com/page2", + "content_hash": "def456", + "status": "active", + "last_crawled_at": 1700001000.0, + }, + ] + mock_site_store.get_total_count.return_value = 50 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/example.com/urls?offset=0&limit=2") + + assert response.status_code == 200 + data = response.json() + assert data["domain"] == "example.com" + assert len(data["urls"]) == 2 + assert 
data["urls"][0]["url"] == "https://example.com/page1" + assert data["urls"][0]["content_hash"] == "abc123" + assert data["urls"][1]["url"] == "https://example.com/page2" + assert data["total"] == 50 + assert data["offset"] == 0 + assert data["has_more"] is True + mock_site_store.get_urls_page.assert_awaited_once_with(offset=0, limit=2, status=None) + mock_site_store.get_total_count.assert_awaited_once_with(status=None) + + async def test_has_more_false_when_at_end(self, mock_manager): + mock_manager.get_website.return_value = {"domain": "example.com"} + mock_site_store = AsyncMock() + mock_manager.get_site_store.return_value = mock_site_store + mock_site_store.get_urls_page.return_value = [ + { + "url": "https://example.com/last", + "content_hash": "xyz", + "status": "active", + "last_crawled_at": None, + }, + ] + mock_site_store.get_total_count.return_value = 1 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/example.com/urls?offset=0&limit=100") + + assert response.status_code == 200 + data = response.json() + assert data["has_more"] is False + assert data["total"] == 1 + + async def test_status_filter(self, mock_manager): + mock_manager.get_website.return_value = {"domain": "example.com"} + mock_site_store = AsyncMock() + mock_manager.get_site_store.return_value = mock_site_store + mock_site_store.get_urls_page.return_value = [] + mock_site_store.get_total_count.return_value = 0 + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/example.com/urls?status=active") + + assert response.status_code == 200 + mock_site_store.get_urls_page.assert_awaited_once_with(offset=0, limit=100, status="active") + mock_site_store.get_total_count.assert_awaited_once_with(status="active") + + async def test_404_when_website_not_found(self, mock_manager): + mock_manager.get_website.return_value = 
None + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/unknown.com/urls") + + assert response.status_code == 404 + assert response.json()["detail"] == "Website not found: unknown.com" + + async def test_500_on_error(self, mock_manager): + mock_manager.get_website.side_effect = RuntimeError("db error") + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.get("/api/v1/websites/example.com/urls") + + assert response.status_code == 500 + assert response.json()["detail"] == "Failed to get website URLs" diff --git a/services/crawler/uv.lock b/services/crawler/uv.lock index 322d8992e..cf243052f 100644 --- a/services/crawler/uv.lock +++ b/services/crawler/uv.lock @@ -198,6 +198,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "asyncpg" +version = "0.31.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/cc/d18065ce2380d80b1bcce927c24a2642efd38918e33fd724bc4bca904877/asyncpg-0.31.0.tar.gz", hash = "sha256:c989386c83940bfbd787180f2b1519415e2d3d6277a70d9d0f0145ac73500735", size = 993667, upload-time = "2025-11-24T23:27:00.812Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/17/cc02bc49bc350623d050fa139e34ea512cd6e020562f2a7312a7bcae4bc9/asyncpg-0.31.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eee690960e8ab85063ba93af2ce128c0f52fd655fdff9fdb1a28df01329f031d", size = 643159, upload-time = "2025-11-24T23:25:36.443Z" }, + { url = 
"https://files.pythonhosted.org/packages/a4/62/4ded7d400a7b651adf06f49ea8f73100cca07c6df012119594d1e3447aa6/asyncpg-0.31.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2657204552b75f8288de08ca60faf4a99a65deef3a71d1467454123205a88fab", size = 638157, upload-time = "2025-11-24T23:25:37.89Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5b/4179538a9a72166a0bf60ad783b1ef16efb7960e4d7b9afe9f77a5551680/asyncpg-0.31.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a429e842a3a4b4ea240ea52d7fe3f82d5149853249306f7ff166cb9948faa46c", size = 2918051, upload-time = "2025-11-24T23:25:39.461Z" }, + { url = "https://files.pythonhosted.org/packages/e6/35/c27719ae0536c5b6e61e4701391ffe435ef59539e9360959240d6e47c8c8/asyncpg-0.31.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0807be46c32c963ae40d329b3a686356e417f674c976c07fa49f1b30303f109", size = 2972640, upload-time = "2025-11-24T23:25:41.512Z" }, + { url = "https://files.pythonhosted.org/packages/43/f4/01ebb9207f29e645a64699b9ce0eefeff8e7a33494e1d29bb53736f7766b/asyncpg-0.31.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e5d5098f63beeae93512ee513d4c0c53dc12e9aa2b7a1af5a81cddf93fe4e4da", size = 2851050, upload-time = "2025-11-24T23:25:43.153Z" }, + { url = "https://files.pythonhosted.org/packages/3e/f4/03ff1426acc87be0f4e8d40fa2bff5c3952bef0080062af9efc2212e3be8/asyncpg-0.31.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37fc6c00a814e18eef51833545d1891cac9aa69140598bb076b4cd29b3e010b9", size = 2962574, upload-time = "2025-11-24T23:25:44.942Z" }, + { url = "https://files.pythonhosted.org/packages/c7/39/cc788dfca3d4060f9d93e67be396ceec458dfc429e26139059e58c2c244d/asyncpg-0.31.0-cp311-cp311-win32.whl", hash = "sha256:5a4af56edf82a701aece93190cc4e094d2df7d33f6e915c222fb09efbb5afc24", size = 521076, upload-time = "2025-11-24T23:25:46.486Z" }, + { url = 
"https://files.pythonhosted.org/packages/28/fc/735af5384c029eb7f1ca60ccb8fa95521dbdaeef788edf4cecfc604c3cab/asyncpg-0.31.0-cp311-cp311-win_amd64.whl", hash = "sha256:480c4befbdf079c14c9ca43c8c5e1fe8b6296c96f1f927158d4f1e750aacc047", size = 584980, upload-time = "2025-11-24T23:25:47.938Z" }, + { url = "https://files.pythonhosted.org/packages/2a/a6/59d0a146e61d20e18db7396583242e32e0f120693b67a8de43f1557033e2/asyncpg-0.31.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b44c31e1efc1c15188ef183f287c728e2046abb1d26af4d20858215d50d91fad", size = 662042, upload-time = "2025-11-24T23:25:49.578Z" }, + { url = "https://files.pythonhosted.org/packages/36/01/ffaa189dcb63a2471720615e60185c3f6327716fdc0fc04334436fbb7c65/asyncpg-0.31.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0c89ccf741c067614c9b5fc7f1fc6f3b61ab05ae4aaa966e6fd6b93097c7d20d", size = 638504, upload-time = "2025-11-24T23:25:51.501Z" }, + { url = "https://files.pythonhosted.org/packages/9f/62/3f699ba45d8bd24c5d65392190d19656d74ff0185f42e19d0bbd973bb371/asyncpg-0.31.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:12b3b2e39dc5470abd5e98c8d3373e4b1d1234d9fbdedf538798b2c13c64460a", size = 3426241, upload-time = "2025-11-24T23:25:53.278Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d1/a867c2150f9c6e7af6462637f613ba67f78a314b00db220cd26ff559d532/asyncpg-0.31.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:aad7a33913fb8bcb5454313377cc330fbb19a0cd5faa7272407d8a0c4257b671", size = 3520321, upload-time = "2025-11-24T23:25:54.982Z" }, + { url = "https://files.pythonhosted.org/packages/7a/1a/cce4c3f246805ecd285a3591222a2611141f1669d002163abef999b60f98/asyncpg-0.31.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3df118d94f46d85b2e434fd62c84cb66d5834d5a890725fe625f498e72e4d5ec", size = 3316685, upload-time = "2025-11-24T23:25:57.43Z" }, + { url = 
"https://files.pythonhosted.org/packages/40/ae/0fc961179e78cc579e138fad6eb580448ecae64908f95b8cb8ee2f241f67/asyncpg-0.31.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bd5b6efff3c17c3202d4b37189969acf8927438a238c6257f66be3c426beba20", size = 3471858, upload-time = "2025-11-24T23:25:59.636Z" }, + { url = "https://files.pythonhosted.org/packages/52/b2/b20e09670be031afa4cbfabd645caece7f85ec62d69c312239de568e058e/asyncpg-0.31.0-cp312-cp312-win32.whl", hash = "sha256:027eaa61361ec735926566f995d959ade4796f6a49d3bde17e5134b9964f9ba8", size = 527852, upload-time = "2025-11-24T23:26:01.084Z" }, + { url = "https://files.pythonhosted.org/packages/b5/f0/f2ed1de154e15b107dc692262395b3c17fc34eafe2a78fc2115931561730/asyncpg-0.31.0-cp312-cp312-win_amd64.whl", hash = "sha256:72d6bdcbc93d608a1158f17932de2321f68b1a967a13e014998db87a72ed3186", size = 597175, upload-time = "2025-11-24T23:26:02.564Z" }, + { url = "https://files.pythonhosted.org/packages/95/11/97b5c2af72a5d0b9bc3fa30cd4b9ce22284a9a943a150fdc768763caf035/asyncpg-0.31.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c204fab1b91e08b0f47e90a75d1b3c62174dab21f670ad6c5d0f243a228f015b", size = 661111, upload-time = "2025-11-24T23:26:04.467Z" }, + { url = "https://files.pythonhosted.org/packages/1b/71/157d611c791a5e2d0423f09f027bd499935f0906e0c2a416ce712ba51ef3/asyncpg-0.31.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:54a64f91839ba59008eccf7aad2e93d6e3de688d796f35803235ea1c4898ae1e", size = 636928, upload-time = "2025-11-24T23:26:05.944Z" }, + { url = "https://files.pythonhosted.org/packages/2e/fc/9e3486fb2bbe69d4a867c0b76d68542650a7ff1574ca40e84c3111bb0c6e/asyncpg-0.31.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0e0822b1038dc7253b337b0f3f676cadc4ac31b126c5d42691c39691962e403", size = 3424067, upload-time = "2025-11-24T23:26:07.957Z" }, + { url = 
"https://files.pythonhosted.org/packages/12/c6/8c9d076f73f07f995013c791e018a1cd5f31823c2a3187fc8581706aa00f/asyncpg-0.31.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bef056aa502ee34204c161c72ca1f3c274917596877f825968368b2c33f585f4", size = 3518156, upload-time = "2025-11-24T23:26:09.591Z" }, + { url = "https://files.pythonhosted.org/packages/ae/3b/60683a0baf50fbc546499cfb53132cb6835b92b529a05f6a81471ab60d0c/asyncpg-0.31.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0bfbcc5b7ffcd9b75ab1558f00db2ae07db9c80637ad1b2469c43df79d7a5ae2", size = 3319636, upload-time = "2025-11-24T23:26:11.168Z" }, + { url = "https://files.pythonhosted.org/packages/50/dc/8487df0f69bd398a61e1792b3cba0e47477f214eff085ba0efa7eac9ce87/asyncpg-0.31.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:22bc525ebbdc24d1261ecbf6f504998244d4e3be1721784b5f64664d61fbe602", size = 3472079, upload-time = "2025-11-24T23:26:13.164Z" }, + { url = "https://files.pythonhosted.org/packages/13/a1/c5bbeeb8531c05c89135cb8b28575ac2fac618bcb60119ee9696c3faf71c/asyncpg-0.31.0-cp313-cp313-win32.whl", hash = "sha256:f890de5e1e4f7e14023619399a471ce4b71f5418cd67a51853b9910fdfa73696", size = 527606, upload-time = "2025-11-24T23:26:14.78Z" }, + { url = "https://files.pythonhosted.org/packages/91/66/b25ccb84a246b470eb943b0107c07edcae51804912b824054b3413995a10/asyncpg-0.31.0-cp313-cp313-win_amd64.whl", hash = "sha256:dc5f2fa9916f292e5c5c8b2ac2813763bcd7f58e130055b4ad8a0531314201ab", size = 596569, upload-time = "2025-11-24T23:26:16.189Z" }, + { url = "https://files.pythonhosted.org/packages/3c/36/e9450d62e84a13aea6580c83a47a437f26c7ca6fa0f0fd40b6670793ea30/asyncpg-0.31.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:f6b56b91bb0ffc328c4e3ed113136cddd9deefdf5f79ab448598b9772831df44", size = 660867, upload-time = "2025-11-24T23:26:17.631Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/4b/1d0a2b33b3102d210439338e1beea616a6122267c0df459ff0265cd5807a/asyncpg-0.31.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:334dec28cf20d7f5bb9e45b39546ddf247f8042a690bff9b9573d00086e69cb5", size = 638349, upload-time = "2025-11-24T23:26:19.689Z" }, + { url = "https://files.pythonhosted.org/packages/41/aa/e7f7ac9a7974f08eff9183e392b2d62516f90412686532d27e196c0f0eeb/asyncpg-0.31.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98cc158c53f46de7bb677fd20c417e264fc02b36d901cc2a43bd6cb0dc6dbfd2", size = 3410428, upload-time = "2025-11-24T23:26:21.275Z" }, + { url = "https://files.pythonhosted.org/packages/6f/de/bf1b60de3dede5c2731e6788617a512bc0ebd9693eac297ee74086f101d7/asyncpg-0.31.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9322b563e2661a52e3cdbc93eed3be7748b289f792e0011cb2720d278b366ce2", size = 3471678, upload-time = "2025-11-24T23:26:23.627Z" }, + { url = "https://files.pythonhosted.org/packages/46/78/fc3ade003e22d8bd53aaf8f75f4be48f0b460fa73738f0391b9c856a9147/asyncpg-0.31.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19857a358fc811d82227449b7ca40afb46e75b33eb8897240c3839dd8b744218", size = 3313505, upload-time = "2025-11-24T23:26:25.235Z" }, + { url = "https://files.pythonhosted.org/packages/bf/e9/73eb8a6789e927816f4705291be21f2225687bfa97321e40cd23055e903a/asyncpg-0.31.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ba5f8886e850882ff2c2ace5732300e99193823e8107e2c53ef01c1ebfa1e85d", size = 3434744, upload-time = "2025-11-24T23:26:26.944Z" }, + { url = "https://files.pythonhosted.org/packages/08/4b/f10b880534413c65c5b5862f79b8e81553a8f364e5238832ad4c0af71b7f/asyncpg-0.31.0-cp314-cp314-win32.whl", hash = "sha256:cea3a0b2a14f95834cee29432e4ddc399b95700eb1d51bbc5bfee8f31fa07b2b", size = 532251, upload-time = "2025-11-24T23:26:28.404Z" }, + { url = 
"https://files.pythonhosted.org/packages/d3/2d/7aa40750b7a19efa5d66e67fc06008ca0f27ba1bd082e457ad82f59aba49/asyncpg-0.31.0-cp314-cp314-win_amd64.whl", hash = "sha256:04d19392716af6b029411a0264d92093b6e5e8285ae97a39957b9a9c14ea72be", size = 604901, upload-time = "2025-11-24T23:26:30.34Z" }, + { url = "https://files.pythonhosted.org/packages/ce/fe/b9dfe349b83b9dee28cc42360d2c86b2cdce4cb551a2c2d27e156bcac84d/asyncpg-0.31.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bdb957706da132e982cc6856bb2f7b740603472b54c3ebc77fe60ea3e57e1bd2", size = 702280, upload-time = "2025-11-24T23:26:32Z" }, + { url = "https://files.pythonhosted.org/packages/6a/81/e6be6e37e560bd91e6c23ea8a6138a04fd057b08cf63d3c5055c98e81c1d/asyncpg-0.31.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6d11b198111a72f47154fa03b85799f9be63701e068b43f84ac25da0bda9cb31", size = 682931, upload-time = "2025-11-24T23:26:33.572Z" }, + { url = "https://files.pythonhosted.org/packages/a6/45/6009040da85a1648dd5bc75b3b0a062081c483e75a1a29041ae63a0bf0dc/asyncpg-0.31.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18c83b03bc0d1b23e6230f5bf8d4f217dc9bc08644ce0502a9d91dc9e634a9c7", size = 3581608, upload-time = "2025-11-24T23:26:35.638Z" }, + { url = "https://files.pythonhosted.org/packages/7e/06/2e3d4d7608b0b2b3adbee0d0bd6a2d29ca0fc4d8a78f8277df04e2d1fd7b/asyncpg-0.31.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e009abc333464ff18b8f6fd146addffd9aaf63e79aa3bb40ab7a4c332d0c5e9e", size = 3498738, upload-time = "2025-11-24T23:26:37.275Z" }, + { url = "https://files.pythonhosted.org/packages/7d/aa/7d75ede780033141c51d83577ea23236ba7d3a23593929b32b49db8ed36e/asyncpg-0.31.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:3b1fbcb0e396a5ca435a8826a87e5c2c2cc0c8c68eb6fadf82168056b0e53a8c", size = 3401026, upload-time = "2025-11-24T23:26:39.423Z" }, + { url = 
"https://files.pythonhosted.org/packages/ba/7a/15e37d45e7f7c94facc1e9148c0e455e8f33c08f0b8a0b1deb2c5171771b/asyncpg-0.31.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8df714dba348efcc162d2adf02d213e5fab1bd9f557e1305633e851a61814a7a", size = 3429426, upload-time = "2025-11-24T23:26:41.032Z" }, + { url = "https://files.pythonhosted.org/packages/13/d5/71437c5f6ae5f307828710efbe62163974e71237d5d46ebd2869ea052d10/asyncpg-0.31.0-cp314-cp314t-win32.whl", hash = "sha256:1b41f1afb1033f2b44f3234993b15096ddc9cd71b21a42dbd87fc6a57b43d65d", size = 614495, upload-time = "2025-11-24T23:26:42.659Z" }, + { url = "https://files.pythonhosted.org/packages/3c/d7/8fb3044eaef08a310acfe23dae9a8e2e07d305edc29a53497e52bc76eca7/asyncpg-0.31.0-cp314-cp314t-win_amd64.whl", hash = "sha256:bd4107bb7cdd0e9e65fae66a62afd3a249663b844fa34d479f6d5b3bef9c04c3", size = 706062, upload-time = "2025-11-24T23:26:44.086Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -2745,6 +2793,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/a5/df8f46ef7da168f1bc52cd86e09a9de5c6f19cc1da04454d51b7d4f43408/scipy-1.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:031121914e295d9791319a1875444d55079885bbae5bdc9c5e0f2ee5f09d34ff", size = 25246266, upload-time = "2026-01-10T21:30:45.923Z" }, ] +[[package]] +name = "semantic-text-splitter" +version = "0.29.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/e9/cc31904f918a0f9f20039a969e191a9b88ecfc7536d5ed822a72640d4022/semantic_text_splitter-0.29.0.tar.gz", hash = "sha256:80a57c689f3521730670a881eccf95b996cb6115ee5c916778a3996971899121", size = 283431, upload-time = "2025-12-30T22:31:58.534Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/53/ba753d570945b076b7a7e80bc49dc0bc43d60ac294b7b77e7769c52e77ac/semantic_text_splitter-0.29.0-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:a5197301c6ddf57ae421af0e723d97774516b8b93f6439d1ad16760c5290e0d5", size 
= 8261920, upload-time = "2025-12-30T22:31:42.696Z" }, + { url = "https://files.pythonhosted.org/packages/76/b2/64b4396ddefcffb550461ea5b335b6e9e74356d744ad5aa93640ab4f9306/semantic_text_splitter-0.29.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:0cd62a7710c9049b9cd6484531379ba1d8c8594a9d63ffef38b65a084aa3c8ac", size = 8268950, upload-time = "2025-12-30T22:31:44.697Z" }, + { url = "https://files.pythonhosted.org/packages/d4/17/d3c8736590b9cd1c83f12e7fb67b63efd8fbf32da9689d1bfd8cee2e6f14/semantic_text_splitter-0.29.0-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:885a9c18e6f87e9280ac8894ed4760b26f718eb2d81656263278640e4abc35b1", size = 8485522, upload-time = "2025-12-30T22:31:46.582Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8b/69772c4f884d99f62f7426bab3a8e75c3e0bcdcc2fd975e926057a813a47/semantic_text_splitter-0.29.0-cp310-abi3-manylinux_2_28_armv7l.whl", hash = "sha256:417db93494f97e83b8dab12c3e565df3357c4c75d6691a3314f4465a6eb57696", size = 8347931, upload-time = "2025-12-30T22:31:48.591Z" }, + { url = "https://files.pythonhosted.org/packages/e3/d6/acaed5cbe7182850fa7daea0fc9a27ce4c43d5ad0999aa3670013ce033f0/semantic_text_splitter-0.29.0-cp310-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:654bb33a9466b77177ea0ae2b382c3564977e49cc92c8a0fb4b320876bdedd77", size = 8817946, upload-time = "2025-12-30T22:31:50.028Z" }, + { url = "https://files.pythonhosted.org/packages/c5/90/a2482e3b1fa846b620e752aedc54629fd63034017542b400f1539277e522/semantic_text_splitter-0.29.0-cp310-abi3-manylinux_2_28_s390x.whl", hash = "sha256:1e011bb178801c0e8780e124edb0bd8e396e7846c60ab60aa5b126c57540556d", size = 8659506, upload-time = "2025-12-30T22:31:52.09Z" }, + { url = "https://files.pythonhosted.org/packages/51/86/2947d6fa18bf8f89aa7096993ece71c927647846569a784a3441987d927b/semantic_text_splitter-0.29.0-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:8d4efcc5f42b9196b1aad165871ad212d7f102847d522f45e73dba712d61d7cb", size = 8492252, upload-time = 
"2025-12-30T22:31:53.724Z" }, + { url = "https://files.pythonhosted.org/packages/64/61/f00fc614fa38d663d146df29342db6a057e3c086d2b7a84b468b5647b023/semantic_text_splitter-0.29.0-cp310-abi3-win32.whl", hash = "sha256:71adb6d2c9ea14ef56065688634a7fa98ea2f0f8768d6099cdb6fbb82d7745fd", size = 7784072, upload-time = "2025-12-30T22:31:55.753Z" }, + { url = "https://files.pythonhosted.org/packages/75/3c/853c39e0adc769cf51dc87335fc929e1fe66036f3f6ece8c8556d5c1ea87/semantic_text_splitter-0.29.0-cp310-abi3-win_amd64.whl", hash = "sha256:ccbf00e2a54c790f117a4890c45f0045862a9c76bcbc23382e5eaf84fb4c8334", size = 8000626, upload-time = "2025-12-30T22:31:57.173Z" }, +] + [[package]] name = "shapely" version = "2.1.2" @@ -2858,6 +2923,7 @@ name = "tale-crawler" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "asyncpg" }, { name = "beautifulsoup4" }, { name = "crawl4ai" }, { name = "fastapi" }, @@ -2872,6 +2938,8 @@ dependencies = [ { name = "python-dotenv" }, { name = "python-multipart" }, { name = "python-pptx" }, + { name = "semantic-text-splitter" }, + { name = "tiktoken" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -2886,6 +2954,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "asyncpg", specifier = ">=0.30.0" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, { name = "crawl4ai", specifier = "==0.8.0" }, { name = "fastapi", specifier = "==0.133.1" }, @@ -2905,6 +2974,8 @@ requires-dist = [ { name = "python-multipart", specifier = "==0.0.20" }, { name = "python-pptx", specifier = "==1.0.2" }, { name = "ruff", marker = "extra == 'dev'", specifier = "==0.15.4" }, + { name = "semantic-text-splitter", specifier = ">=0.18.0" }, + { name = "tiktoken", specifier = ">=0.9.0" }, { name = "uvicorn", extras = ["standard"], specifier = "==0.41.0" }, ] provides-extras = ["dev"] diff --git a/services/db/Dockerfile b/services/db/Dockerfile index 0b4dc40a4..addbbb75a 100644 --- a/services/db/Dockerfile +++ b/services/db/Dockerfile @@ -1,12 
+1,13 @@ -# Dockerfile for Tale DB (TimescaleDB) +# Dockerfile for Tale DB (ParadeDB) # Supports AMD64 and ARM64 architectures +# +# Base: ParadeDB (pg_search BM25 + pgvector) # Version argument - injected by CI from git tag, defaults to 'dev' for local builds ARG VERSION=dev -# Use official TimescaleDB image as base -# TimescaleDB is PostgreSQL with time-series extensions -FROM timescale/timescaledb:2.25.1-pg16 +# ParadeDB includes pg_search (BM25), pgvector, and PostgreSQL 16 +FROM paradedb/paradedb:v0.21.9-pg16 # Re-declare VERSION arg (ARGs don't persist after FROM) ARG VERSION=dev @@ -14,17 +15,22 @@ ARG VERSION=dev # Switch to root for all setup operations USER root -# Install additional tools, create directories, and set up configuration in one layer -# Note: TimescaleDB image is based on Alpine Linux, so we use apk instead of apt-get -RUN apk add --no-cache \ - curl \ - ca-certificates \ +# Install additional tools and create required directories +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ && mkdir -p /docker-entrypoint-initdb.d \ /etc/postgresql/conf.d \ /var/lib/postgresql/backup # Copy initialization scripts +# - /docker-entrypoint-initdb.d/ : PostgreSQL runs these on first init only +# - /etc/postgresql/init-scripts/ : Entrypoint wrapper runs these on every startup COPY services/db/init-scripts/ /docker-entrypoint-initdb.d/ +COPY services/db/init-scripts/ /etc/postgresql/init-scripts/ # Copy custom PostgreSQL configuration to the expected location COPY services/db/postgresql.conf /etc/postgresql/postgresql.conf @@ -36,8 +42,9 @@ RUN chmod +x /usr/local/bin/docker-entrypoint-wrapper.sh \ /etc/postgresql \ /var/lib/postgresql/backup -# Switch to postgres user for runtime security -USER postgres +# Run as root — the standard docker-entrypoint.sh detects root, +# fixes data directory ownership, then re-execs as postgres via gosu. 
+# This handles UID mismatches when switching base images (e.g., Alpine→Debian). # Set environment variables with DB_ prefix defaults # These can be overridden at runtime @@ -53,11 +60,8 @@ ENV TALE_VERSION=${VERSION} \ DB_MAX_CONNECTIONS=100 \ DB_SHARED_BUFFERS=256MB \ DB_EFFECTIVE_CACHE_SIZE=1GB \ - DB_MAINTENANCE_WORK_MEM=64MB \ - DB_WORK_MEM=4MB \ - # TimescaleDB specific - DB_TIMESCALEDB_TELEMETRY=off \ - TIMESCALEDB_TELEMETRY=off \ + DB_MAINTENANCE_WORK_MEM=128MB \ + DB_WORK_MEM=32MB \ # Logging DB_LOG_STATEMENT=none \ DB_LOG_MIN_DURATION_STATEMENT=-1 diff --git a/services/db/docker-entrypoint-wrapper.sh b/services/db/docker-entrypoint-wrapper.sh index 9b636ba90..806f6a510 100644 --- a/services/db/docker-entrypoint-wrapper.sh +++ b/services/db/docker-entrypoint-wrapper.sh @@ -3,7 +3,7 @@ set -e # Tale DB Entrypoint Wrapper # This script maps DB_ prefixed environment variables to PostgreSQL configuration -# and then calls the original TimescaleDB entrypoint +# and then calls the original PostgreSQL entrypoint # ============================================================================ # Map DB_ environment variables to PostgreSQL standard variables @@ -42,12 +42,6 @@ if [ -n "$DB_WORK_MEM" ]; then POSTGRES_ARGS+=("-c" "work_mem=${DB_WORK_MEM}") fi -# TimescaleDB settings -if [ -n "$DB_TIMESCALEDB_TELEMETRY" ]; then - POSTGRES_ARGS+=("-c" "timescaledb.telemetry_level=${DB_TIMESCALEDB_TELEMETRY}") - export TIMESCALEDB_TELEMETRY="${DB_TIMESCALEDB_TELEMETRY}" -fi - # Logging settings if [ -n "$DB_LOG_STATEMENT" ]; then POSTGRES_ARGS+=("-c" "log_statement=${DB_LOG_STATEMENT}") @@ -73,11 +67,37 @@ echo "User: ${POSTGRES_USER}" echo "Max Connections: ${DB_MAX_CONNECTIONS:-100}" echo "Shared Buffers: ${DB_SHARED_BUFFERS:-256MB}" echo "Effective Cache Size: ${DB_EFFECTIVE_CACHE_SIZE:-1GB}" -echo "TimescaleDB Telemetry: ${DB_TIMESCALEDB_TELEMETRY:-off}" echo "==================================================" # 
============================================================================ -# Call the original TimescaleDB/PostgreSQL entrypoint +# Post-start init scripts (idempotent, run on every startup) +# ============================================================================ +# All init scripts use IF NOT EXISTS / CREATE OR REPLACE / DROP IF EXISTS +# so they are safe to re-run. This ensures schema, extensions, and indexes +# converge to the desired state on every container start — not just first init. + +INIT_SCRIPTS_DIR="/etc/postgresql/init-scripts" + +run_init_scripts() { + echo "Running init scripts..." + for script in "$INIT_SCRIPTS_DIR"/*.sql; do + [ -f "$script" ] || continue + echo " $(basename "$script")" + psql -U "$POSTGRES_USER" -d "$POSTGRES_DB" -f "$script" 2>&1 | grep -E "^(ERROR|NOTICE)" || true + done + echo "Init scripts complete." +} + +# Run init scripts in the background after PostgreSQL starts +( + until pg_isready -U "$POSTGRES_USER" -q 2>/dev/null; do + sleep 1 + done + run_init_scripts +) & + +# ============================================================================ +# Call the original PostgreSQL entrypoint # ============================================================================ exec docker-entrypoint.sh "$@" "${POSTGRES_ARGS[@]}" diff --git a/services/db/init-scripts/01-init-extensions.sql b/services/db/init-scripts/01-init-extensions.sql new file mode 100644 index 000000000..064e9eba1 --- /dev/null +++ b/services/db/init-scripts/01-init-extensions.sql @@ -0,0 +1,24 @@ +-- Tale DB: Core extensions and schema setup +-- Idempotent: safe to run on every startup + +-- Remove legacy TimescaleDB extension if present +DROP EXTENSION IF EXISTS timescaledb CASCADE; + +-- Enable core extensions +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; +CREATE EXTENSION IF NOT EXISTS "pg_stat_statements"; +CREATE EXTENSION IF NOT EXISTS "pgcrypto"; + +-- Create tale schema +CREATE SCHEMA IF NOT EXISTS tale; + +-- Set search path +DO $$ +BEGIN + EXECUTE 
format('ALTER DATABASE %I SET search_path TO tale, public', current_database()); +END $$; + +-- Grant permissions +GRANT ALL PRIVILEGES ON SCHEMA tale TO CURRENT_USER; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA tale TO CURRENT_USER; +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA tale TO CURRENT_USER; diff --git a/services/db/init-scripts/01-init-timescaledb.sql b/services/db/init-scripts/01-init-timescaledb.sql deleted file mode 100644 index 42ac7404c..000000000 --- a/services/db/init-scripts/01-init-timescaledb.sql +++ /dev/null @@ -1,85 +0,0 @@ --- Tale DB Initialization Script --- This script sets up the TimescaleDB extension and creates initial schema - --- Enable TimescaleDB extension -CREATE EXTENSION IF NOT EXISTS timescaledb; - --- Enable additional useful extensions -CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; -- UUID generation -CREATE EXTENSION IF NOT EXISTS "pg_stat_statements"; -- Query performance monitoring -CREATE EXTENSION IF NOT EXISTS "pgcrypto"; -- Cryptographic functions - --- Create schema for Tale application -CREATE SCHEMA IF NOT EXISTS tale; - --- Set search path to include tale schema for the current database -DO $$ -BEGIN - EXECUTE format('ALTER DATABASE %I SET search_path TO tale, public', current_database()); -END $$; - --- Grant permissions to the current user (entrypoint runs scripts as POSTGRES_USER) -GRANT ALL PRIVILEGES ON SCHEMA tale TO CURRENT_USER; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA tale TO CURRENT_USER; -GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA tale TO CURRENT_USER; - --- Create a sample time-series table (can be customized based on needs) -CREATE TABLE IF NOT EXISTS tale.metrics ( - time TIMESTAMPTZ NOT NULL, - metric_name TEXT NOT NULL, - value DOUBLE PRECISION, - tags JSONB, - metadata JSONB -); - --- Convert to hypertable for time-series optimization -SELECT create_hypertable('tale.metrics', 'time', if_not_exists => TRUE); - --- Create indexes for common queries -CREATE INDEX IF NOT EXISTS 
idx_metrics_name_time ON tale.metrics (metric_name, time DESC); -CREATE INDEX IF NOT EXISTS idx_metrics_tags ON tale.metrics USING GIN (tags); - --- Create a sample events table -CREATE TABLE IF NOT EXISTS tale.events ( - time TIMESTAMPTZ NOT NULL, - event_type TEXT NOT NULL, - user_id UUID, - session_id UUID, - properties JSONB, - metadata JSONB -); - --- Convert to hypertable -SELECT create_hypertable('tale.events', 'time', if_not_exists => TRUE); - --- Create indexes -CREATE INDEX IF NOT EXISTS idx_events_type_time ON tale.events (event_type, time DESC); -CREATE INDEX IF NOT EXISTS idx_events_user ON tale.events (user_id, time DESC); -CREATE INDEX IF NOT EXISTS idx_events_session ON tale.events (session_id, time DESC); -CREATE INDEX IF NOT EXISTS idx_events_properties ON tale.events USING GIN (properties); - --- Create retention policies (optional - adjust based on needs) --- Automatically drop data older than 90 days --- SELECT add_retention_policy('tale.metrics', INTERVAL '90 days', if_not_exists => TRUE); --- SELECT add_retention_policy('tale.events', INTERVAL '90 days', if_not_exists => TRUE); - --- Create continuous aggregates for common queries (optional) --- Example: hourly metrics rollup --- CREATE MATERIALIZED VIEW tale.metrics_hourly --- WITH (timescaledb.continuous) AS --- SELECT --- time_bucket('1 hour', time) AS bucket, --- metric_name, --- AVG(value) as avg_value, --- MAX(value) as max_value, --- MIN(value) as min_value, --- COUNT(*) as count --- FROM tale.metrics --- GROUP BY bucket, metric_name; - --- Log successful initialization -DO $$ -BEGIN - RAISE NOTICE 'Tale DB initialized successfully with TimescaleDB'; -END $$; - diff --git a/services/db/init-scripts/02-create-convex-database.sql b/services/db/init-scripts/02-create-convex-database.sql index 3ffe87195..2fcb8d064 100644 --- a/services/db/init-scripts/02-create-convex-database.sql +++ b/services/db/init-scripts/02-create-convex-database.sql @@ -1,23 +1,14 @@ --- 
============================================================================ --- Create Convex Self-Hosted Database --- ============================================================================ --- This script creates the database required by Convex self-hosted backend. --- The database name is HARDCODED to tale_platform for safety and consistency. --- ============================================================================ +-- Tale DB: Convex self-hosted database +-- Idempotent: safe to run on every startup --- Create the Convex database (hardcoded name) -CREATE DATABASE tale_platform; +SELECT 'CREATE DATABASE tale_platform' +WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'tale_platform') +\gexec --- Grant privileges to the tale user GRANT ALL PRIVILEGES ON DATABASE tale_platform TO tale; --- Connect to the new database \c tale_platform --- Enable required extensions +DROP EXTENSION IF EXISTS timescaledb CASCADE; CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE EXTENSION IF NOT EXISTS "pg_trgm"; - --- Log completion -\echo 'Convex database created successfully: tale_platform' - diff --git a/services/db/init-scripts/03-create-rag-database.sql b/services/db/init-scripts/03-create-rag-database.sql index b5bc4a1e4..b8664e84f 100644 --- a/services/db/init-scripts/03-create-rag-database.sql +++ b/services/db/init-scripts/03-create-rag-database.sql @@ -1,34 +1,19 @@ --- ============================================================================ --- Create RAG (Cognee) Database --- ============================================================================ --- This script creates the database required by the RAG service (Cognee). --- The database is dedicated to RAG to allow safe full-database resets without --- affecting other services (e.g., Convex uses tale_platform). 
--- ============================================================================ +-- Tale DB: RAG (Cognee) database +-- Idempotent: safe to run on every startup --- Create the RAG database (hardcoded name for safety) -CREATE DATABASE tale_rag; +SELECT 'CREATE DATABASE tale_rag' +WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'tale_rag') +\gexec --- Grant privileges to the tale user GRANT ALL PRIVILEGES ON DATABASE tale_rag TO tale; --- Connect to the new database \c tale_rag --- Enable required extensions for Cognee/PGVector +DROP EXTENSION IF EXISTS timescaledb CASCADE; CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE EXTENSION IF NOT EXISTS "vector"; --- ============================================================================ --- HNSW Index Management for PGVector --- ============================================================================ --- Cognee creates vector tables dynamically (one per collection/dataset). --- This function creates HNSW indexes on vector columns for fast similarity search. --- Without indexes, queries on 200k+ vectors can take 5-15 seconds. --- With HNSW indexes, queries complete in <500ms. 
--- ============================================================================ - --- Function to create HNSW indexes on all vector columns that don't have one +-- Dynamic HNSW index creation for Cognee's vector tables CREATE OR REPLACE FUNCTION create_vector_hnsw_indexes() RETURNS void AS $$ DECLARE @@ -36,50 +21,29 @@ DECLARE index_name TEXT; index_exists BOOLEAN; BEGIN - -- Find all columns with vector type in public schema FOR rec IN - SELECT - c.table_name, - c.column_name, - c.udt_name + SELECT c.table_name, c.column_name FROM information_schema.columns c JOIN information_schema.tables t - ON c.table_name = t.table_name - AND c.table_schema = t.table_schema + ON c.table_name = t.table_name AND c.table_schema = t.table_schema WHERE c.table_schema = 'public' AND c.udt_name = 'vector' AND t.table_type = 'BASE TABLE' LOOP - -- Generate index name index_name := rec.table_name || '_' || rec.column_name || '_hnsw_idx'; - -- Check if index already exists SELECT EXISTS ( SELECT 1 FROM pg_indexes - WHERE schemaname = 'public' - AND indexname = index_name + WHERE schemaname = 'public' AND indexname = index_name ) INTO index_exists; - -- Create index if it doesn't exist IF NOT index_exists THEN - RAISE NOTICE 'Creating HNSW index: % on %.%', - index_name, rec.table_name, rec.column_name; - - -- Use cosine distance operator (most common for embeddings) - -- m=16, ef_construction=64 are good defaults for quality/speed balance EXECUTE format( 'CREATE INDEX %I ON %I USING hnsw (%I vector_cosine_ops) WITH (m = 16, ef_construction = 64)', - index_name, - rec.table_name, - rec.column_name + index_name, rec.table_name, rec.column_name ); - RAISE NOTICE 'Created HNSW index: %', index_name; END IF; END LOOP; END; $$ LANGUAGE plpgsql; - --- Log completion -\echo 'RAG database created successfully: tale_rag' -\echo 'HNSW index function created: SELECT create_vector_hnsw_indexes();' diff --git a/services/db/init-scripts/04-create-search-database.sql 
b/services/db/init-scripts/04-create-search-database.sql new file mode 100644 index 000000000..93b8666ae --- /dev/null +++ b/services/db/init-scripts/04-create-search-database.sql @@ -0,0 +1,124 @@ +-- Tale DB: Crawler search database (pgvector + pg_search BM25) +-- Idempotent: safe to run on every startup + +SELECT 'CREATE DATABASE tale_search' +WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'tale_search') +\gexec + +GRANT ALL PRIVILEGES ON DATABASE tale_search TO tale; + +\c tale_search + +DROP EXTENSION IF EXISTS timescaledb CASCADE; +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; +CREATE EXTENSION IF NOT EXISTS "vector"; +CREATE EXTENSION IF NOT EXISTS "pg_search"; + +-- ============================================================================ +-- Websites +-- ============================================================================ + +CREATE TABLE IF NOT EXISTS websites ( + domain TEXT PRIMARY KEY, + title TEXT, + description TEXT, + page_count INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'idle', + scan_interval INTEGER NOT NULL DEFAULT 21600, + last_scanned_at TIMESTAMPTZ, + error TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_websites_status ON websites(status); +CREATE INDEX IF NOT EXISTS idx_websites_due ON websites(status, last_scanned_at); + +-- ============================================================================ +-- Website URLs +-- ============================================================================ + +CREATE TABLE IF NOT EXISTS website_urls ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + domain TEXT NOT NULL REFERENCES websites(domain) ON DELETE CASCADE, + url TEXT NOT NULL, + content_hash TEXT, + status TEXT NOT NULL DEFAULT 'discovered', + last_crawled_at TIMESTAMPTZ, + discovered_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + title TEXT, + content TEXT, + word_count INTEGER, + metadata JSONB, + 
structured_data JSONB, + fail_count INTEGER NOT NULL DEFAULT 0, + etag TEXT, + last_modified TEXT, + UNIQUE(domain, url) +); + +CREATE INDEX IF NOT EXISTS idx_website_urls_domain ON website_urls(domain); +CREATE INDEX IF NOT EXISTS idx_website_urls_domain_status ON website_urls(domain, status); +CREATE INDEX IF NOT EXISTS idx_website_urls_crawl_order ON website_urls(domain, last_crawled_at NULLS FIRST); + +-- ============================================================================ +-- Chunks (search index) +-- ============================================================================ + +CREATE TABLE IF NOT EXISTS chunks ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + domain TEXT NOT NULL, + url TEXT NOT NULL, + title TEXT, + content_hash TEXT NOT NULL, + chunk_index INTEGER NOT NULL, + chunk_content TEXT NOT NULL, + embedding vector, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(url, chunk_index), + FOREIGN KEY (domain, url) REFERENCES website_urls(domain, url) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_chunks_domain ON chunks(domain); +CREATE INDEX IF NOT EXISTS idx_chunks_url ON chunks(url); +CREATE INDEX IF NOT EXISTS idx_chunks_url_content_hash ON chunks(url, content_hash); + +-- BM25 full-text index (ParadeDB pg_search) +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_chunks_bm25') THEN + CREATE INDEX idx_chunks_bm25 ON chunks + USING bm25 (id, chunk_content) + WITH (key_field='id'); + END IF; +EXCEPTION WHEN OTHERS THEN + RAISE NOTICE 'BM25 index deferred: %', SQLERRM; +END; +$$; + +-- Dynamic HNSW index (vector dimensions are configurable). +-- The embedding column starts as untyped `vector`; the application pins it to +-- an explicit dimension at startup via ALTER TABLE before calling this function. 
+CREATE OR REPLACE FUNCTION create_chunks_hnsw_index() +RETURNS void AS $$ +DECLARE + col_type text; +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'chunks' AND indexname = 'idx_chunks_embedding_hnsw' + ) THEN + -- Verify the column has explicit dimensions (e.g. vector(1536), not bare vector) + SELECT format_type(atttypid, atttypmod) INTO col_type + FROM pg_attribute + WHERE attrelid = 'chunks'::regclass AND attname = 'embedding'; + + IF col_type = 'vector' THEN + RAISE EXCEPTION 'embedding column has no dimensions – pin it with ALTER TABLE first'; + END IF; + + EXECUTE 'CREATE INDEX idx_chunks_embedding_hnsw ON chunks USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64)'; + RAISE NOTICE 'Created HNSW index on chunks.embedding'; + END IF; +END; +$$ LANGUAGE plpgsql; diff --git a/services/db/postgresql.conf b/services/db/postgresql.conf index 1a9e36934..fdff07ab3 100644 --- a/services/db/postgresql.conf +++ b/services/db/postgresql.conf @@ -1,5 +1,5 @@ # Tale DB PostgreSQL Configuration -# Custom configuration for TimescaleDB optimized for Tale platform +# Custom configuration for ParadeDB (pg_search + pgvector) optimized for Tale platform # This file is loaded in addition to the default PostgreSQL configuration # ============================================================================ @@ -18,20 +18,20 @@ listen_addresses = '*' # Shared memory for caching data # Controlled by DB_SHARED_BUFFERS environment variable # Recommended: 25% of system RAM for dedicated DB server -# shared_buffers = 256MB +shared_buffers = 256MB # Memory for maintenance operations (VACUUM, CREATE INDEX, etc.) 
# Controlled by DB_MAINTENANCE_WORK_MEM environment variable -# maintenance_work_mem = 64MB +maintenance_work_mem = 128MB # Memory for query operations (sorts, hash tables) # Controlled by DB_WORK_MEM environment variable -# work_mem = 4MB +work_mem = 32MB # Estimate of memory available for disk caching # Controlled by DB_EFFECTIVE_CACHE_SIZE environment variable # Recommended: 50-75% of system RAM -# effective_cache_size = 1GB +effective_cache_size = 1GB # ============================================================================ # Write-Ahead Log (WAL) Settings @@ -82,21 +82,11 @@ log_line_prefix = '%t [%p]: [%l-1] user=%u,db=%d,app=%a,client=%h ' log_connections = on log_disconnections = on -# ============================================================================ -# TimescaleDB Settings -# ============================================================================ -# Disable telemetry -# Controlled by DB_TIMESCALEDB_TELEMETRY environment variable -# timescaledb.telemetry_level = off - -# TimescaleDB background workers -timescaledb.max_background_workers = 8 - # ============================================================================ # Statistics # ============================================================================ # Enable query statistics collection -shared_preload_libraries = 'timescaledb,pg_stat_statements' +shared_preload_libraries = 'pg_search,pg_stat_statements' # Track query statistics pg_stat_statements.track = all diff --git a/services/platform/app/features/automations/utils/step-icons.tsx b/services/platform/app/features/automations/utils/step-icons.tsx index 0f21e28e1..69cda0215 100644 --- a/services/platform/app/features/automations/utils/step-icons.tsx +++ b/services/platform/app/features/automations/utils/step-icons.tsx @@ -21,7 +21,6 @@ import { CheckCircle, Cloud, Globe, - Layout, GitBranch, Settings, } from 'lucide-react'; @@ -42,7 +41,6 @@ const ACTION_ICON_MAP: Record = { onedrive: Cloud, crawler: Globe, website: Globe, - 
websitePages: Layout, workflow: GitBranch, }; diff --git a/services/platform/app/features/websites/components/website-edit-dialog.tsx b/services/platform/app/features/websites/components/website-edit-dialog.tsx index 4eb1a638e..7a7688312 100644 --- a/services/platform/app/features/websites/components/website-edit-dialog.tsx +++ b/services/platform/app/features/websites/components/website-edit-dialog.tsx @@ -15,7 +15,6 @@ import { useT } from '@/lib/i18n/client'; import { useUpdateWebsite } from '../hooks/mutations'; type FormData = { - domain: string; scanInterval: string; }; @@ -36,10 +35,6 @@ export function EditWebsiteDialog({ const formSchema = useMemo( () => z.object({ - domain: z - .string() - .min(1, tWebsites('validation.domainRequired')) - .url(tWebsites('validation.validUrl')), scanInterval: z .string() .min(1, tWebsites('validation.scanIntervalRequired')), @@ -58,7 +53,6 @@ export function EditWebsiteDialog({ ]; const { - register, handleSubmit, formState: { errors }, reset, @@ -67,7 +61,6 @@ export function EditWebsiteDialog({ } = useForm({ resolver: zodResolver(formSchema), defaultValues: { - domain: website.domain, scanInterval: website.scanInterval, }, }); @@ -77,7 +70,6 @@ export function EditWebsiteDialog({ useEffect(() => { if (website) { reset({ - domain: website.domain, scanInterval: website.scanInterval, }); } @@ -87,7 +79,6 @@ export function EditWebsiteDialog({ updateWebsite( { websiteId: website._id, - domain: data.domain, scanInterval: data.scanInterval, }, { @@ -119,12 +110,9 @@ export function EditWebsiteDialog({ >