diff --git a/README.md b/README.md index e1d501c..70c62e8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ [![Python Version](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/) [![MongoDB](https://img.shields.io/badge/MongoDB-7.0+-green.svg)](https://www.mongodb.com/) [![SQLAlchemy](https://img.shields.io/badge/SQLAlchemy-1.4+_2.0+-darkgreen.svg)](https://www.sqlalchemy.org/) +[![Superset](https://img.shields.io/badge/Apache_Superset-1.0+-blue.svg)](https://superset.apache.org/docs/6.0.0/configuration/databases) PyMongoSQL is a Python [DB API 2.0 (PEP 249)](https://www.python.org/dev/peps/pep-0249/) client for [MongoDB](https://www.mongodb.com/). It provides a familiar SQL interface to MongoDB, allowing developers to use SQL to interact with MongoDB collections. @@ -40,6 +41,9 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo - **ANTLR4** (SQL Parser Runtime) - antlr4-python3-runtime >= 4.13.0 +- **JMESPath** (JSON/Dict Path Query) + - jmespath >= 1.0.0 + ### Optional Dependencies - **SQLAlchemy** (for ORM/Core support) @@ -136,37 +140,41 @@ while users: ### SELECT Statements - Field selection: `SELECT name, age FROM users` - Wildcards: `SELECT * FROM products` +- **Nested fields**: `SELECT profile.name, profile.age FROM users` +- **Array access**: `SELECT items[0], items[1].name FROM orders` ### WHERE Clauses - Equality: `WHERE name = 'John'` - Comparisons: `WHERE age > 25`, `WHERE price <= 100.0` - Logical operators: `WHERE age > 18 AND status = 'active'` +- **Nested field filtering**: `WHERE profile.status = 'active'` +- **Array filtering**: `WHERE items[0].price > 100` + +### Nested Field Support +- **Single-level**: `profile.name`, `settings.theme` +- **Multi-level**: `account.profile.name`, `config.database.host` +- **Array access**: `items[0].name`, `orders[1].total` +- **Complex queries**: `WHERE customer.profile.age > 18 AND orders[0].status = 'paid'` + +> **Note**: Avoid SQL reserved words (`user`, `data`, `value`, `count`, etc.) as unquoted field names. Use alternatives or bracket notation for arrays. ### Sorting and Limiting - ORDER BY: `ORDER BY name ASC, age DESC` - LIMIT: `LIMIT 10` - Combined: `ORDER BY created_at DESC LIMIT 5` -## Connection Options +## Limitations & Roadmap -```python -from pymongosql.connection import Connection +**Note**: Currently PyMongoSQL focuses on Data Query Language (DQL) operations. The following SQL features are **not yet supported** but are planned for future releases: -# Basic connection -conn = Connection(host="localhost", port=27017, database="mydb") +- **DML Operations** (Data Manipulation Language) + - `INSERT`, `UPDATE`, `DELETE` +- **DDL Operations** (Data Definition Language) + - `CREATE TABLE/COLLECTION`, `DROP TABLE/COLLECTION` + - `CREATE INDEX`, `DROP INDEX` + - `LIST TABLES/COLLECTIONS` -# With authentication -conn = Connection( - host="mongodb://user:pass@host:port/db?authSource=admin", - database="mydb" -) - -# Connection properties -print(conn.host) # MongoDB connection URL -print(conn.port) # Port number -print(conn.database_name) # Database name -print(conn.is_connected) # Connection status -``` +These features are on our development roadmap and contributions are welcome! 
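+
+A minimal usage sketch of the two connection modes described above (host, port, and
+database name are placeholders):
+
+```python
+import pymongosql
+
+# Standard mode: SQL is parsed and executed directly against MongoDB.
+conn = pymongosql.connect(host="mongodb://localhost:27017/mydb", database="mydb")
+
+# Superset mode: wrapped subqueries are staged through an in-memory SQLite3 database.
+superset_conn = pymongosql.connect(
+    host="mongodb+superset://localhost:27017/mydb", database="mydb"
+)
+
+cursor = conn.cursor()
+cursor.execute("SELECT name, profile.age FROM users WHERE profile.status = 'active' LIMIT 5")
+print(cursor.fetchall())
+```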
## Contributing diff --git a/pymongosql/__init__.py b/pymongosql/__init__.py index 3014e99..09bcd31 100644 --- a/pymongosql/__init__.py +++ b/pymongosql/__init__.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from .connection import Connection -__version__: str = "0.2.1" +__version__: str = "0.2.2" # Globals https://www.python.org/dev/peps/pep-0249/#globals apilevel: str = "2.0" @@ -42,6 +42,26 @@ def connect(*args, **kwargs) -> "Connection": return Connection(*args, **kwargs) +# Register superset execution strategy for mongodb+superset:// connections +def _register_superset_executor() -> None: + """Register SupersetExecution strategy for superset mode. + + This allows the executor and cursor to be unaware of superset - + the execution strategy is automatically selected based on the connection mode. + """ + try: + from .executor import ExecutionPlanFactory + from .superset_mongodb.executor import SupersetExecution + + ExecutionPlanFactory.register_strategy(SupersetExecution()) + except ImportError: + # Superset module not available - skip registration + pass + + +# Auto-register superset executor on module import +_register_superset_executor() + # SQLAlchemy integration (optional) # For SQLAlchemy functionality, import from pymongosql.sqlalchemy_mongodb: # from pymongosql.sqlalchemy_mongodb import create_engine_url, create_engine_from_mongodb_uri diff --git a/pymongosql/common.py b/pymongosql/common.py index 0c6fd64..530f396 100644 --- a/pymongosql/common.py +++ b/pymongosql/common.py @@ -17,10 +17,12 @@ class BaseCursor(metaclass=ABCMeta): def __init__( self, connection: "Connection", + mode: str = "standard", **kwargs, ) -> None: super().__init__() self._connection = connection + self.mode = mode @property def connection(self) -> "Connection": diff --git a/pymongosql/connection.py b/pymongosql/connection.py index ec38002..d31f34d 100644 --- a/pymongosql/connection.py +++ b/pymongosql/connection.py @@ -12,6 +12,7 @@ from .common import BaseCursor from .cursor import Cursor from .error import DatabaseError, NotSupportedError, OperationalError +from .helper import ConnectionHelper _logger = logging.getLogger(__name__) @@ -35,9 +36,17 @@ def __init__( to ensure full compatibility. All parameters are passed through directly to the underlying MongoClient. + Supports connection string patterns: + - mongodb://host:port/database - Core driver (no subquery support) + - mongodb+superset://host:port/database - Superset driver with subquery support + See PyMongo MongoClient documentation for full parameter details. 
https://www.mongodb.com/docs/languages/python/pymongo-driver/current/connect/mongoclient/ """ + # Check if connection string specifies mode + connection_string = host if isinstance(host, str) else None + self._mode, host = ConnectionHelper.parse_connection_string(connection_string) + # Extract commonly used parameters for backward compatibility self._host = host or "localhost" self._port = port or 27017 @@ -154,6 +163,11 @@ def database(self) -> Database: raise OperationalError("No database selected") return self._database + @property + def mode(self) -> str: + """Get the specified mode""" + return self._mode + def use_database(self, database_name: str) -> None: """Switch to a different database""" if self._client is None: @@ -267,6 +281,7 @@ def cursor(self, cursor: Optional[Type[BaseCursor]] = None, **kwargs) -> BaseCur new_cursor = cursor( connection=self, + mode=self._mode, **kwargs, ) self.cursor_pool.append(new_cursor) diff --git a/pymongosql/cursor.py b/pymongosql/cursor.py index bf283a8..0fecac6 100644 --- a/pymongosql/cursor.py +++ b/pymongosql/cursor.py @@ -2,14 +2,11 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, TypeVar -from pymongo.cursor import Cursor as MongoCursor -from pymongo.errors import PyMongoError - from .common import BaseCursor, CursorIterator from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError -from .result_set import ResultSet +from .executor import ExecutionContext, ExecutionPlanFactory +from .result_set import DictResultSet, ResultSet from .sql.builder import ExecutionPlan -from .sql.parser import SQLParser if TYPE_CHECKING: from .connection import Connection @@ -23,16 +20,16 @@ class Cursor(BaseCursor, CursorIterator): NO_RESULT_SET = "No result set." 
- def __init__(self, connection: "Connection", **kwargs) -> None: + def __init__(self, connection: "Connection", mode: str = "standard", **kwargs) -> None: super().__init__( connection=connection, + mode=mode, **kwargs, ) self._kwargs = kwargs self._result_set: Optional[ResultSet] = None self._result_set_class = ResultSet self._current_execution_plan: Optional[ExecutionPlan] = None - self._mongo_cursor: Optional[MongoCursor] = None self._is_closed = False @property @@ -40,8 +37,8 @@ def result_set(self) -> Optional[ResultSet]: return self._result_set @result_set.setter - def result_set(self, val: ResultSet) -> None: - self._result_set = val + def result_set(self, rs: ResultSet) -> None: + self._result_set = rs @property def has_result_set(self) -> bool: @@ -52,8 +49,8 @@ def result_set_class(self) -> Optional[type]: return self._result_set_class @result_set_class.setter - def result_set_class(self, val: type) -> None: - self._result_set_class = val + def result_set_class(self, rs_cls: type) -> None: + self._result_set_class = rs_cls @property def rowcount(self) -> int: @@ -78,74 +75,6 @@ def _check_closed(self) -> None: if self._is_closed: raise ProgrammingError("Cursor is closed") - def _parse_sql(self, sql: str) -> ExecutionPlan: - """Parse SQL statement and return ExecutionPlan""" - try: - parser = SQLParser(sql) - execution_plan = parser.get_execution_plan() - - if not execution_plan.validate(): - raise SqlSyntaxError("Generated query plan is invalid") - - return execution_plan - - except SqlSyntaxError: - raise - except Exception as e: - _logger.error(f"SQL parsing failed: {e}") - raise SqlSyntaxError(f"Failed to parse SQL: {e}") - - def _execute_execution_plan(self, execution_plan: ExecutionPlan) -> None: - """Execute an ExecutionPlan against MongoDB using db.command""" - try: - # Get database - if not execution_plan.collection: - raise ProgrammingError("No collection specified in query") - - db = self.connection.database - - # Build MongoDB find command - find_command = {"find": execution_plan.collection, "filter": execution_plan.filter_stage or {}} - - # Apply projection if specified (already in MongoDB format) - if execution_plan.projection_stage: - find_command["projection"] = execution_plan.projection_stage - - # Apply sort if specified - if execution_plan.sort_stage: - sort_spec = {} - for sort_dict in execution_plan.sort_stage: - for field, direction in sort_dict.items(): - sort_spec[field] = direction - find_command["sort"] = sort_spec - - # Apply skip if specified - if execution_plan.skip_stage: - find_command["skip"] = execution_plan.skip_stage - - # Apply limit if specified - if execution_plan.limit_stage: - find_command["limit"] = execution_plan.limit_stage - - _logger.debug(f"Executing MongoDB command: {find_command}") - - # Execute find command directly - result = db.command(find_command) - - # Create result set from command result - self._result_set = self._result_set_class( - command_result=result, execution_plan=execution_plan, **self._kwargs - ) - - _logger.info(f"Query executed successfully on collection '{execution_plan.collection}'") - - except PyMongoError as e: - _logger.error(f"MongoDB command execution failed: {e}") - raise DatabaseError(f"Command execution failed: {e}") - except Exception as e: - _logger.error(f"Unexpected error during command execution: {e}") - raise OperationalError(f"Command execution error: {e}") - def execute(self: _T, operation: str, parameters: Optional[Dict[str, Any]] = None) -> _T: """Execute a SQL statement @@ -162,11 +91,25 @@ def 
execute(self: _T, operation: str, parameters: Optional[Dict[str, Any]] = None) -> _T:
             _logger.warning("Parameter substitution not yet implemented, ignoring parameters")
 
         try:
-            # Parse SQL to ExecutionPlan
-            self._current_execution_plan = self._parse_sql(operation)
+            # Create execution context
+            context = ExecutionContext(operation, self.mode)
+
+            # Get appropriate execution strategy
+            strategy = ExecutionPlanFactory.get_strategy(context)
+
+            # Execute using selected strategy (standard or superset)
+            result = strategy.execute(context, self.connection)
 
-            # Execute the execution plan
-            self._execute_execution_plan(self._current_execution_plan)
+            # Store execution plan for reference
+            self._current_execution_plan = strategy.execution_plan
+
+            # Create result set from command result
+            self._result_set = self._result_set_class(
+                command_result=result,
+                execution_plan=self._current_execution_plan,
+                database=self.connection.database,
+                **self._kwargs,
+            )
 
             return self
 
@@ -236,15 +179,6 @@ def fetchall(self) -> List[Sequence[Any]]:
     def close(self) -> None:
         """Close the cursor and free resources"""
         try:
-            if self._mongo_cursor:
-                # Close MongoDB cursor
-                try:
-                    self._mongo_cursor.close()
-                except Exception as e:
-                    _logger.warning(f"Error closing MongoDB cursor: {e}")
-                finally:
-                    self._mongo_cursor = None
-
             if self._result_set:
                 # Close result set
                 try:
@@ -274,3 +208,12 @@ def __del__(self):
             self.close()
         except Exception:
             pass  # Ignore errors during cleanup
+
+
+class DictCursor(Cursor):
+    """Cursor that returns results as dictionaries instead of tuples/sequences"""
+
+    def __init__(self, connection: "Connection", **kwargs) -> None:
+        super().__init__(connection=connection, **kwargs)
+        # Override result set class to use DictResultSet
+        self._result_set_class = DictResultSet
diff --git a/pymongosql/executor.py b/pymongosql/executor.py
new file mode 100644
index 0000000..ad36af3
--- /dev/null
+++ b/pymongosql/executor.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+"""
+Query execution strategies for handling both simple and subquery-based SQL operations.
+
+This module provides different execution strategies:
+- StandardExecution: Direct MongoDB query for simple SELECT statements
+
+Additional strategies (e.g. the superset two-stage executor) can be registered
+at runtime via ExecutionPlanFactory.register_strategy().
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+from pymongo.errors import PyMongoError
+
+from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
+from .sql.builder import ExecutionPlan
+from .sql.parser import SQLParser
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ExecutionContext:
+    """Manages execution context for a single query"""
+
+    query: str
+    execution_mode: str = "standard"
+
+    def __repr__(self) -> str:
+        return f"ExecutionContext(mode={self.execution_mode}, query={self.query})"
+
+
+class ExecutionStrategy(ABC):
+    """Abstract base class for query execution strategies"""
+
+    @property
+    @abstractmethod
+    def execution_plan(self) -> ExecutionPlan:
+        """The execution plan produced by the most recent execute() call"""
+        pass
+
+    @abstractmethod
+    def execute(
+        self,
+        context: ExecutionContext,
+        connection: Any,
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Execute query and return result set.
+ + Args: + context: ExecutionContext with query and subquery info + connection: MongoDB connection + + Returns: + command_result with query results + """ + pass + + @abstractmethod + def supports(self, context: ExecutionContext) -> bool: + """Check if this strategy supports the given context""" + pass + + +class StandardExecution(ExecutionStrategy): + """Standard execution strategy for simple SELECT queries without subqueries""" + + @property + def execution_plan(self) -> ExecutionPlan: + """Return standard execution plan""" + return self._execution_plan + + def supports(self, context: ExecutionContext) -> bool: + """Support simple queries without subqueries""" + return "standard" in context.execution_mode.lower() + + def _parse_sql(self, sql: str) -> ExecutionPlan: + """Parse SQL statement and return ExecutionPlan""" + try: + parser = SQLParser(sql) + execution_plan = parser.get_execution_plan() + + if not execution_plan.validate(): + raise SqlSyntaxError("Generated query plan is invalid") + + return execution_plan + + except SqlSyntaxError: + raise + except Exception as e: + _logger.error(f"SQL parsing failed: {e}") + raise SqlSyntaxError(f"Failed to parse SQL: {e}") + + def _execute_execution_plan(self, execution_plan: ExecutionPlan, db: Any) -> Optional[Dict[str, Any]]: + """Execute an ExecutionPlan against MongoDB using db.command""" + try: + # Get database + if not execution_plan.collection: + raise ProgrammingError("No collection specified in query") + + # Build MongoDB find command + find_command = {"find": execution_plan.collection, "filter": execution_plan.filter_stage or {}} + + # Apply projection if specified + if execution_plan.projection_stage: + find_command["projection"] = execution_plan.projection_stage + + # Apply sort if specified + if execution_plan.sort_stage: + sort_spec = {} + for sort_dict in execution_plan.sort_stage: + for field_name, direction in sort_dict.items(): + sort_spec[field_name] = direction + find_command["sort"] = sort_spec + + # Apply skip if specified + if execution_plan.skip_stage: + find_command["skip"] = execution_plan.skip_stage + + # Apply limit if specified + if execution_plan.limit_stage: + find_command["limit"] = execution_plan.limit_stage + + _logger.debug(f"Executing MongoDB command: {find_command}") + + # Execute find command directly + result = db.command(find_command) + + # Create command result + return result + + except PyMongoError as e: + _logger.error(f"MongoDB command execution failed: {e}") + raise DatabaseError(f"Command execution failed: {e}") + except Exception as e: + _logger.error(f"Unexpected error during command execution: {e}") + raise OperationalError(f"Command execution error: {e}") + + def execute( + self, + context: ExecutionContext, + connection: Any, + ) -> Optional[Dict[str, Any]]: + """Execute standard query directly against MongoDB""" + _logger.debug(f"Using standard execution for query: {context.query[:100]}") + + # Parse the query + self._execution_plan = self._parse_sql(context.query) + + return self._execute_execution_plan(self._execution_plan, connection.database) + + +class ExecutionPlanFactory: + """Factory for creating appropriate execution strategy based on query context""" + + _strategies = [StandardExecution()] + + @classmethod + def get_strategy(cls, context: ExecutionContext) -> ExecutionStrategy: + """Get appropriate execution strategy for context""" + for strategy in cls._strategies: + if strategy.supports(context): + _logger.debug(f"Selected strategy: {strategy.__class__.__name__}") + return 
strategy
+
+        # Fallback to standard execution
+        return StandardExecution()
+
+    @classmethod
+    def register_strategy(cls, strategy: ExecutionStrategy) -> None:
+        """
+        Register a custom execution strategy.
+
+        Args:
+            strategy: ExecutionStrategy instance
+        """
+        cls._strategies.append(strategy)
+        _logger.debug(f"Registered strategy: {strategy.__class__.__name__}")
diff --git a/pymongosql/helper.py b/pymongosql/helper.py
new file mode 100644
index 0000000..68344a6
--- /dev/null
+++ b/pymongosql/helper.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+"""
+Connection helper utilities for PyMongoSQL.
+
+Handles connection string parsing and mode detection.
+"""
+
+import logging
+from typing import Optional, Tuple
+from urllib.parse import urlparse
+
+_logger = logging.getLogger(__name__)
+
+
+class ConnectionHelper:
+    """Helper class for connection string parsing and mode detection.
+
+    Supports connection string patterns:
+    - mongodb://host:port/database - Core driver (no subquery support)
+    - mongodb+superset://host:port/database - Superset driver with subquery support
+    """
+
+    @staticmethod
+    def parse_connection_string(connection_string: Optional[str]) -> Tuple[str, Optional[str]]:
+        """
+        Parse a PyMongoSQL connection string and return a (mode, normalized_uri) tuple.
+        """
+        try:
+            if not connection_string:
+                return "standard", None
+
+            parsed = urlparse(connection_string)
+            scheme = parsed.scheme
+
+            if not scheme:
+                return "standard", connection_string
+
+            base_scheme = "mongodb"
+            mode = "standard"
+
+            # Determine mode from scheme
+            if "+" in scheme:
+                base_scheme = scheme.split("+")[0].lower()
+                mode = scheme.split("+")[-1].lower()
+
+            host = parsed.hostname or "localhost"
+            port = parsed.port or 27017
+            database = parsed.path.lstrip("/") if parsed.path else None
+
+            # Build normalized connection string with mongodb scheme (removing any +mode)
+            # Reconstruct netloc with credentials if present
+            netloc = host
+            if parsed.username:
+                creds = parsed.username
+                if parsed.password:
+                    creds += f":{parsed.password}"
+                netloc = f"{creds}@{host}"
+            netloc += f":{port}"
+
+            query_part = f"?{parsed.query}" if parsed.query else ""
+            normalized_connection_string = f"{base_scheme}://{netloc}/{database or ''}{query_part}"
+
+            _logger.debug(f"Parsed connection string - Mode: {mode}, Host: {host}, Port: {port}, Database: {database}")
+
+            return mode, normalized_connection_string
+
+        except Exception as e:
+            _logger.error(f"Failed to parse connection string: {e}")
+            raise ValueError(f"Invalid connection string format: {e}")
diff --git a/pymongosql/result_set.py b/pymongosql/result_set.py
index c0c7848..1a78db2 100644
--- a/pymongosql/result_set.py
+++ b/pymongosql/result_set.py
@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
 import logging
+import re
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-from pymongo.cursor import Cursor as MongoCursor
+import jmespath
 from pymongo.errors import PyMongoError
 
 from .common import CursorIterator
@@ -18,28 +19,24 @@ class ResultSet(CursorIterator):
     def __init__(
         self,
         command_result: Optional[Dict[str, Any]] = None,
-        mongo_cursor: Optional[MongoCursor] = None,
         execution_plan: ExecutionPlan = None,
         arraysize: int = None,
+        database: Optional[Any] = None,
        **kwargs,
     ) -> None:
         super().__init__(arraysize=arraysize or self.DEFAULT_FETCH_SIZE, **kwargs)
 
-        # Handle both command results and legacy mongo cursor for backward compatibility
+        # Handle command results from db.command
        if command_result is not None:
             self._command_result = command_result
-
self._mongo_cursor = None + self._database = database # Extract cursor info from command result self._result_cursor = command_result.get("cursor", {}) + self._cursor_id = self._result_cursor.get("id", 0) # 0 means no more results self._raw_results = self._result_cursor.get("firstBatch", []) self._cached_results: List[Sequence[Any]] = [] - elif mongo_cursor is not None: - self._mongo_cursor = mongo_cursor - self._command_result = None - self._raw_results = [] - self._cached_results: List[Sequence[Any]] = [] else: - raise ProgrammingError("Either command_result or mongo_cursor must be provided") + raise ProgrammingError("command_result must be provided") self._execution_plan = execution_plan self._is_closed = False @@ -51,14 +48,22 @@ def __init__( # Process firstBatch immediately if available (after all attributes are set) if command_result is not None and self._raw_results: - processed_batch = [self._process_document(doc) for doc in self._raw_results] - # Convert dictionaries to sequences for DB API 2.0 compliance - sequence_batch = [self._dict_to_sequence(doc) for doc in processed_batch] - self._cached_results.extend(sequence_batch) + self._process_and_cache_batch(self._raw_results) # Build description from projection self._build_description() + def _process_and_cache_batch(self, batch: List[Dict[str, Any]]) -> None: + """Process and cache a batch of documents""" + if not batch: + return + # Process results through projection mapping + processed_batch = [self._process_document(doc) for doc in batch] + # Convert dictionaries to output format (sequence or dict) + formatted_batch = [self._format_result(doc) for doc in processed_batch] + self._cached_results.extend(formatted_batch) + self._total_fetched += len(batch) + def _build_description(self) -> None: """Build column description from execution plan projection""" if not self._execution_plan.projection_stage: @@ -83,37 +88,37 @@ def _ensure_results_available(self, count: int = 1) -> None: if self._cache_exhausted: return - if self._command_result is not None: - # For command results, we already have all data in firstBatch - # No additional fetching needed - self._cache_exhausted = True - return + # Fetch more results if needed and cursor has more data + while len(self._cached_results) < count and self._cursor_id != 0: + try: + # Use getMore to fetch next batch + if self._database is not None and self._execution_plan.collection: + getmore_cmd = { + "getMore": self._cursor_id, + "collection": self._execution_plan.collection, + } + result = self._database.command(getmore_cmd) + + # Extract and process next batch + cursor_info = result.get("cursor", {}) + next_batch = cursor_info.get("nextBatch", []) + self._process_and_cache_batch(next_batch) + + # Update cursor ID for next iteration + self._cursor_id = cursor_info.get("id", 0) + else: + # No database access, mark as exhausted + self._cache_exhausted = True + break + + except PyMongoError as e: + self._errors.append({"error": str(e), "type": type(e).__name__}) + self._cache_exhausted = True + raise DatabaseError(f"Error fetching more results: {e}") - elif self._mongo_cursor is not None: - # Fetch more results if needed (legacy mongo cursor support) - while len(self._cached_results) < count and not self._cache_exhausted: - try: - # Iterate through cursor without calling limit() again - batch = [] - for i, doc in enumerate(self._mongo_cursor): - if i >= self.arraysize: - break - batch.append(doc) - - if not batch: - self._cache_exhausted = True - break - - # Process results through projection 
mapping - processed_batch = [self._process_document(doc) for doc in batch] - # Convert dictionaries to sequences for DB API 2.0 compliance - sequence_batch = [self._dict_to_sequence(doc) for doc in processed_batch] - self._cached_results.extend(sequence_batch) - self._total_fetched += len(batch) - - except PyMongoError as e: - self._errors.append({"error": str(e), "type": type(e).__name__}) - raise DatabaseError(f"Error fetching results: {e}") + # Mark as exhausted if no more results available + if self._cursor_id == 0: + self._cache_exhausted = True def _process_document(self, doc: Dict[str, Any]) -> Dict[str, Any]: """Process a MongoDB document according to projection mapping""" @@ -125,16 +130,53 @@ def _process_document(self, doc: Dict[str, Any]) -> Dict[str, Any]: processed = {} for field_name, include_flag in self._execution_plan.projection_stage.items(): if include_flag == 1: # Field is included in projection - if field_name in doc: - processed[field_name] = doc[field_name] - elif field_name != "_id": # _id might be excluded by MongoDB - # Field not found, set to None - processed[field_name] = None + # Extract value using jmespath-compatible field path (convert numeric dot indexes to bracket form) + value = self._get_nested_value(doc, field_name) + # Convert the projection key back to bracket notation for client-facing results + display_key = self._mongo_to_bracket_key(field_name) + processed[display_key] = value return processed - def _dict_to_sequence(self, doc: Dict[str, Any]) -> Tuple[Any, ...]: - """Convert document dictionary to sequence according to column order""" + def _mongo_to_bracket_key(self, field_path: str) -> str: + """Convert Mongo dot-index notation to bracket notation. + + Transforms numeric dot segments into bracket indices for both display keys + and JMESPath-compatible field paths. + + Examples: + items.0 -> items[0] + items.1.name -> items[1].name + """ + if not isinstance(field_path, str): + return field_path + # Replace . with [] + return re.sub(r"\.(\d+)", r"[\1]", field_path) + + def _get_nested_value(self, doc: Dict[str, Any], field_path: str) -> Any: + """Extract nested field value from document using JMESPath + + Supports: + - Simple fields: "name" -> doc["name"] + - Nested fields: "profile.bio" -> doc["profile"]["bio"] + - Array indexing: "address.coordinates[1]" -> doc["address"]["coordinates"][1] + - Wildcards: "items[*].name" -> [item["name"] for item in items] + """ + try: + # Optimization: for simple field names without dots/brackets, use direct access + if "." 
not in field_path and "[" not in field_path: + return doc.get(field_path) + + # Convert normalized Mongo-style numeric segments to bracket notation + normalized_field = self._mongo_to_bracket_key(field_path) + # Use jmespath for complex paths + return jmespath.search(normalized_field, doc) + except Exception as e: + _logger.debug(f"Error extracting field '{field_path}': {e}") + return None + + def _format_result(self, doc: Dict[str, Any]) -> Tuple[Any, ...]: + """Format processed document to output format (tuple for DB API 2.0 compliance)""" if self._column_names is None: # First time - establish column order self._column_names = list(doc.keys()) @@ -214,33 +256,16 @@ def fetchall(self) -> List[Sequence[Any]]: all_results = [] try: - if self._command_result is not None: - # Handle command result (db.command) - if not self._cache_exhausted: - # Results are already processed in constructor, just extend - all_results.extend(self._cached_results) - self._total_fetched += len(self._cached_results) - self._cache_exhausted = True - - elif self._mongo_cursor is not None: - # Handle legacy mongo cursor (for backward compatibility) - # Add cached results - all_results.extend(self._cached_results) - self._cached_results.clear() - - # Fetch remaining from cursor - if not self._cache_exhausted: - # Iterate through all remaining documents in the cursor - remaining_docs = list(self._mongo_cursor) - if remaining_docs: - # Process results through projection mapping - processed_docs = [self._process_document(doc) for doc in remaining_docs] - # Convert dictionaries to sequences for DB API 2.0 compliance - sequence_docs = [self._dict_to_sequence(doc) for doc in processed_docs] - all_results.extend(sequence_docs) - self._total_fetched += len(remaining_docs) - - self._cache_exhausted = True + # Ensure all results are available in cache by requesting a very large number + # This will trigger getMore calls until all data is exhausted + if not self._cache_exhausted and self._cursor_id != 0: + self._ensure_results_available(float("inf")) + + # Now get everything from cache + all_results.extend(self._cached_results) + self._total_fetched += len(self._cached_results) + self._cached_results.clear() + self._cache_exhausted = True except PyMongoError as e: self._errors.append({"error": str(e), "type": type(e).__name__}) @@ -258,17 +283,10 @@ def is_closed(self) -> bool: def close(self) -> None: """Close the result set and free resources""" if not self._is_closed: - try: - if self._mongo_cursor: - self._mongo_cursor.close() - # No special cleanup needed for command results - except Exception as e: - _logger.warning(f"Error closing MongoDB cursor: {e}") - finally: - self._is_closed = True - self._mongo_cursor = None - self._command_result = None - self._cached_results.clear() + self._is_closed = True + self._command_result = None + self._database = None + self._cached_results.clear() def __enter__(self): return self @@ -277,5 +295,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() +class DictResultSet(ResultSet): + """Result set that returns dictionaries instead of sequences""" + + def _format_result(self, doc: Dict[str, Any]) -> Dict[str, Any]: + """Override to return dictionary directly instead of converting to sequence""" + return doc + + # For backward compatibility MongoResultSet = ResultSet diff --git a/pymongosql/sql/ast.py b/pymongosql/sql/ast.py index ec7b978..9c20a3f 100644 --- a/pymongosql/sql/ast.py +++ b/pymongosql/sql/ast.py @@ -3,7 +3,7 @@ from typing import Any, Dict from ..error import 
SqlSyntaxError -from .builder import ExecutionPlan +from .builder import BuilderFactory, ExecutionPlan from .handler import BaseHandler, HandlerFactory, ParseResult from .partiql.PartiQLLexer import PartiQLLexer from .partiql.PartiQLParser import PartiQLParser @@ -47,15 +47,14 @@ def parse_result(self) -> ParseResult: return self._parse_result def parse_to_execution_plan(self) -> ExecutionPlan: - """Convert the parse result to an ExecutionPlan""" - return ExecutionPlan( - collection=self._parse_result.collection, - filter_stage=self._parse_result.filter_conditions, - projection_stage=self._parse_result.projection, - sort_stage=self._parse_result.sort_fields, - limit_stage=self._parse_result.limit_value, - skip_stage=self._parse_result.offset_value, - ) + """Convert the parse result to an ExecutionPlan using BuilderFactory""" + builder = BuilderFactory.create_query_builder().collection(self._parse_result.collection) + + builder.filter(self._parse_result.filter_conditions).project(self._parse_result.projection).sort( + self._parse_result.sort_fields + ).limit(self._parse_result.limit_value).skip(self._parse_result.offset_value) + + return builder.build() def visitRoot(self, ctx: PartiQLParser.RootContext) -> Any: """Visit root node and process child nodes""" @@ -154,3 +153,20 @@ def visitLimitClause(self, ctx: PartiQLParser.LimitClauseContext) -> Any: except Exception as e: _logger.warning(f"Error processing LIMIT clause: {e}") return self.visitChildren(ctx) + + def visitOffsetByClause(self, ctx: PartiQLParser.OffsetByClauseContext) -> Any: + """Handle OFFSET clause for result skipping""" + _logger.debug("Processing OFFSET clause") + try: + if hasattr(ctx, "exprSelect") and ctx.exprSelect(): + offset_text = ctx.exprSelect().getText() + try: + offset_value = int(offset_text) + self._parse_result.offset_value = offset_value + _logger.debug(f"Extracted offset value: {offset_value}") + except ValueError as e: + _logger.warning(f"Invalid OFFSET value '{offset_text}': {e}") + return self.visitChildren(ctx) + except Exception as e: + _logger.warning(f"Error processing OFFSET clause: {e}") + return self.visitChildren(ctx) diff --git a/pymongosql/sql/builder.py b/pymongosql/sql/builder.py index 65e950d..66c4f60 100644 --- a/pymongosql/sql/builder.py +++ b/pymongosql/sql/builder.py @@ -104,18 +104,36 @@ def project(self, fields: Union[Dict[str, int], List[str]]) -> "MongoQueryBuilde _logger.debug(f"Set projection: {projection}") return self - def sort(self, field: str, direction: int = 1) -> "MongoQueryBuilder": - """Add sort criteria""" - if not field or not isinstance(field, str): - self._add_error("Sort field must be a non-empty string") - return self + def sort(self, specs: List[Dict[str, int]]) -> "MongoQueryBuilder": + """Add sort criteria. + + Only accepts a list of single-key dicts in the form: + [{"field": 1}, {"other": -1}] - if direction not in [-1, 1]: - self._add_error("Sort direction must be 1 (ascending) or -1 (descending)") + This matches the output produced by the SQL parser (`sort_fields`). + """ + if not isinstance(specs, list): + self._add_error("Sort specifications must be a list of single-key dicts") return self - self._execution_plan.sort_stage.append({field: direction}) - _logger.debug(f"Added sort: {field} -> {direction}") + for spec in specs: + if not isinstance(spec, dict) or len(spec) != 1: + self._add_error("Each sort specification must be a single-key dict, e.g. 
{'name': 1}") + continue + + field, direction = next(iter(spec.items())) + + if not isinstance(field, str) or not field: + self._add_error("Sort field must be a non-empty string") + continue + + if direction not in [-1, 1]: + self._add_error(f"Sort direction for field '{field}' must be 1 or -1") + continue + + self._execution_plan.sort_stage.append({field: direction}) + _logger.debug(f"Added sort: {field} -> {direction}") + return self def limit(self, count: int) -> "MongoQueryBuilder": diff --git a/pymongosql/sql/handler.py b/pymongosql/sql/handler.py index 49d5126..67bfdc0 100644 --- a/pymongosql/sql/handler.py +++ b/pymongosql/sql/handler.py @@ -3,7 +3,7 @@ Expression handlers for converting SQL expressions to MongoDB query format """ import logging -import time +import re from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -46,6 +46,10 @@ class ParseResult: limit_value: Optional[int] = None offset_value: Optional[int] = None + # Subquery info (for wrapped subqueries, e.g., Superset outering) + subquery_plan: Optional[Any] = None + subquery_alias: Optional[str] = None + # Factory methods for different use cases @classmethod def for_visitor(cls) -> "ParseResult": @@ -98,6 +102,27 @@ def has_children(ctx: Any) -> bool: """Check if context has children""" return hasattr(ctx, "children") and bool(ctx.children) + @staticmethod + def normalize_field_path(path: str) -> str: + """Normalize jmspath/bracket notation to MongoDB dot notation. + + Examples: + items[0] -> items.0 + items[1].name -> items.1.name + arr['key'] or arr["key"] -> arr.key + """ + if not isinstance(path, str): + return path + + s = path.strip() + # Convert quoted bracket identifiers ["name"] or ['name'] -> .name + s = re.sub(r"\[\s*['\"]([^'\"]+)['\"]\s*\]", r".\1", s) + # Convert numeric bracket indexes [0] -> .0 + s = re.sub(r"\[\s*(\d+)\s*\]", r".\1", s) + # Collapse multiple dots and strip leading/trailing dots + s = re.sub(r"\.{2,}", ".", s).strip(".") + return s + class LoggingMixin: """Mixin providing structured logging functionality""" @@ -114,11 +139,10 @@ def _log_operation_start(self, operation: str, ctx: Any, operation_id: int): }, ) - def _log_operation_success(self, operation: str, operation_id: int, processing_time: float, **extra_data): + def _log_operation_success(self, operation: str, operation_id: int, **extra_data): """Log successful operation completion""" log_data = { "operation": operation, - "processing_time_ms": processing_time, "operation_id": operation_id, } log_data.update(extra_data) @@ -129,7 +153,6 @@ def _log_operation_error( operation: str, ctx: Any, operation_id: int, - processing_time: float, error: Exception, ): """Log operation error with context""" @@ -141,7 +164,6 @@ def _log_operation_error( "context_text": ContextUtilsMixin.get_context_text(ctx), "context_type": ContextUtilsMixin.get_context_type_name(ctx), "operation": operation, - "processing_time_ms": processing_time, "operation_id": operation_id, }, exc_info=True, @@ -244,7 +266,6 @@ def can_handle(self, ctx: Any) -> bool: def handle_expression(self, ctx: Any) -> ParseResult: """Convert comparison expression to MongoDB filter""" - start_time = time.time() operation_id = id(ctx) self._log_operation_start("comparison_parsing", ctx, operation_id) @@ -255,11 +276,9 @@ def handle_expression(self, ctx: Any) -> ParseResult: mongo_filter = self._build_mongo_filter(field_name, operator, value) - processing_time = (time.time() - start_time) * 1000 
self._log_operation_success( "comparison_parsing", operation_id, - processing_time, field_name=field_name, operator=operator, ) @@ -267,8 +286,7 @@ def handle_expression(self, ctx: Any) -> ParseResult: return ParseResult(filter_conditions=mongo_filter) except Exception as e: - processing_time = (time.time() - start_time) * 1000 - self._log_operation_error("comparison_parsing", ctx, operation_id, processing_time, e) + self._log_operation_error("comparison_parsing", ctx, operation_id, e) return ParseResult(has_errors=True, error_message=str(e)) def _build_mongo_filter(self, field_name: str, operator: str, value: Any) -> Dict[str, Any]: @@ -358,21 +376,23 @@ def _extract_field_name(self, ctx: Any) -> str: sql_keywords = ["IN(", "LIKE", "BETWEEN", "ISNULL", "ISNOTNULL"] for keyword in sql_keywords: if keyword in text: - return text.split(keyword, 1)[0].strip() + candidate = text.split(keyword, 1)[0].strip() + return self.normalize_field_path(candidate) # Try operator-based splitting operator = self._find_operator_in_text(text, COMPARISON_OPERATORS) if operator: parts = self._split_by_operator(text, operator) if parts: - return parts[0].strip("'\"()") + candidate = parts[0].strip("'\"()") + return self.normalize_field_path(candidate) # Fallback to children parsing if self.has_children(ctx): for child in ctx.children: child_text = self.get_context_text(child) if child_text not in COMPARISON_OPERATORS and not child_text.startswith(("'", '"')): - return child_text + return self.normalize_field_path(child_text) return "unknown_field" except Exception as e: @@ -544,7 +564,6 @@ def _is_logical_context(self, ctx: Any) -> bool: def handle_expression(self, ctx: Any) -> ParseResult: """Convert logical expression to MongoDB filter""" - start_time = time.time() operation_id = id(ctx) self._log_operation_start("logical_parsing", ctx, operation_id) @@ -561,11 +580,9 @@ def handle_expression(self, ctx: Any) -> ParseResult: # Combine operands based on logical operator mongo_filter = self._combine_operands(operator, processed_operands) - processing_time = (time.time() - start_time) * 1000 self._log_operation_success( "logical_parsing", operation_id, - processing_time, operator=operator, processed_count=len(processed_operands), ) @@ -573,8 +590,7 @@ def handle_expression(self, ctx: Any) -> ParseResult: return ParseResult(filter_conditions=mongo_filter) except Exception as e: - processing_time = (time.time() - start_time) * 1000 - self._log_operation_error("logical_parsing", ctx, operation_id, processing_time, e) + self._log_operation_error("logical_parsing", ctx, operation_id, e) return ParseResult(has_errors=True, error_message=str(e)) def _process_operands(self, operands: List[Any]) -> List[Dict[str, Any]]: @@ -731,7 +747,6 @@ def can_handle(self, ctx: Any) -> bool: def handle_expression(self, ctx: Any) -> ParseResult: """Handle function expressions""" - start_time = time.time() operation_id = id(ctx) self._log_operation_start("function_parsing", ctx, operation_id) @@ -742,19 +757,16 @@ def handle_expression(self, ctx: Any) -> ParseResult: # For now, just return a placeholder - this would need full implementation mongo_filter = {"$expr": {self.FUNCTION_MAP.get(function_name.upper(), "$sum"): arguments}} - processing_time = (time.time() - start_time) * 1000 self._log_operation_success( "function_parsing", operation_id, - processing_time, function_name=function_name, ) return ParseResult(filter_conditions=mongo_filter) except Exception as e: - processing_time = (time.time() - start_time) * 1000 - 
self._log_operation_error("function_parsing", ctx, operation_id, processing_time, e) + self._log_operation_error("function_parsing", ctx, operation_id, e) return ParseResult(has_errors=True, error_message=str(e)) def _is_function_context(self, ctx: Any) -> bool: @@ -873,7 +885,7 @@ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]: # Visitor Handler Classes for AST Processing -class SelectHandler(BaseHandler): +class SelectHandler(BaseHandler, ContextUtilsMixin): """Handles SELECT statement parsing""" def can_handle(self, ctx: Any) -> bool: @@ -893,7 +905,7 @@ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "P return projection def _extract_field_and_alias(self, item) -> Tuple[str, Optional[str]]: - """Extract field name and alias from projection item context""" + """Extract field name and alias from projection item context with nested field support""" if not hasattr(item, "children") or not item.children: return str(item), None @@ -903,6 +915,9 @@ def _extract_field_and_alias(self, item) -> Tuple[str, Optional[str]]: # OR children[1] might be just symbolPrimitive (without AS) field_name = item.children[0].getText() + # Normalize bracket notation (jmspath) to Mongo dot notation + field_name = self.normalize_field_path(field_name) + alias = None if len(item.children) >= 2: @@ -927,7 +942,8 @@ def can_handle(self, ctx: Any) -> bool: def handle_visitor(self, ctx: PartiQLParser.FromClauseContext, parse_result: "ParseResult") -> Any: if hasattr(ctx, "tableReference") and ctx.tableReference(): - collection_name = ctx.tableReference().getText() + table_text = ctx.tableReference().getText() + collection_name = table_text parse_result.collection = collection_name return collection_name return None diff --git a/pymongosql/superset_mongodb/__init__.py b/pymongosql/superset_mongodb/__init__.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/pymongosql/superset_mongodb/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/pymongosql/superset_mongodb/detector.py b/pymongosql/superset_mongodb/detector.py new file mode 100644 index 0000000..01de7f0 --- /dev/null +++ b/pymongosql/superset_mongodb/detector.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +Subquery detection and execution context management for handling Superset-style queries. + +This module provides utilities to detect and manage the execution context for SQL queries +that contain subqueries, enabling the use of SQLite3 as an intermediate database for +complex query operations that MongoDB cannot handle natively. +""" + +import re +from dataclasses import dataclass +from typing import Optional, Tuple + + +@dataclass +class QueryInfo: + """Information about a detected subquery""" + + has_subquery: bool = False + is_wrapped: bool = False # True if query is wrapped like SELECT * FROM (...) AS alias + subquery_text: Optional[str] = None + outer_query_text: Optional[str] = None + subquery_alias: Optional[str] = None + query_depth: int = 0 # Nesting depth + + +class SubqueryDetector: + """Detects and analyzes SQL subqueries in query strings""" + + # Pattern to detect wrapped subqueries: SELECT ... FROM (SELECT ...) 
AS alias
+    WRAPPED_SUBQUERY_PATTERN = re.compile(
+        r"(SELECT\s+.*?\s+FROM)\s*\(\s*(SELECT\s+.*?)\s*\)\s+(?:AS\s+)?(\w+)",
+        re.IGNORECASE | re.DOTALL,
+    )
+
+    # Pattern to detect simple SELECT start
+    SELECT_PATTERN = re.compile(r"^\s*SELECT\s+", re.IGNORECASE)
+
+    @classmethod
+    def detect(cls, query: str) -> QueryInfo:
+        """
+        Detect if a query contains subqueries.
+
+        Args:
+            query: SQL query string
+
+        Returns:
+            QueryInfo with detection results
+        """
+        query = query.strip()
+
+        # Check for wrapped subquery pattern (most common Superset case)
+        match = cls.WRAPPED_SUBQUERY_PATTERN.search(query)
+        if match:
+            subquery_text = match.group(2)
+            subquery_alias = match.group(3)
+
+            if not subquery_alias:  # defensive; the alias group should always match
+                subquery_alias = "subquery_result"
+
+            return QueryInfo(
+                has_subquery=True,
+                is_wrapped=True,
+                subquery_text=subquery_text,
+                outer_query_text=query,
+                subquery_alias=subquery_alias,
+                query_depth=2,
+            )
+
+        # Check if query itself is a SELECT (no subquery)
+        if cls.SELECT_PATTERN.match(query):
+            return QueryInfo(
+                has_subquery=False,
+                is_wrapped=False,
+                query_depth=1,
+            )
+
+        # Unknown pattern
+        return QueryInfo(has_subquery=False)
+
+    @classmethod
+    def extract_subquery(cls, query: str) -> Optional[str]:
+        """Extract the subquery text from a wrapped query"""
+        info = cls.detect(query)
+        return info.subquery_text if info.is_wrapped else None
+
+    @classmethod
+    def extract_outer_query(cls, query: str) -> Optional[Tuple[str, str]]:
+        """
+        Extract outer query with subquery placeholder.
+
+        Returns:
+            Tuple of (outer_query, subquery_alias) or None
+        """
+        info = cls.detect(query)
+        if not info.is_wrapped:
+            return None
+
+        # Replace "(subquery) AS alias" with the table name, keeping the outer SELECT list
+        outer = cls.WRAPPED_SUBQUERY_PATTERN.sub(
+            rf"\1 {info.subquery_alias}",
+            query,
+        )
+
+        return outer, info.subquery_alias
+
+    @classmethod
+    def is_simple_select(cls, query: str) -> bool:
+        """Check if query is a simple SELECT without subqueries"""
+        info = cls.detect(query)
+        return not info.has_subquery and bool(cls.SELECT_PATTERN.match(query))
diff --git a/pymongosql/superset_mongodb/executor.py b/pymongosql/superset_mongodb/executor.py
new file mode 100644
index 0000000..9cecd47
--- /dev/null
+++ b/pymongosql/superset_mongodb/executor.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+"""
+Query execution strategies for handling both simple and subquery-based SQL operations.
+
+This module provides different execution strategies:
+- StandardExecution: Direct MongoDB query for simple SELECT statements
+- SupersetExecution: Two-stage execution using an intermediate RDBMS (SQLite3 by default)
+
+The intermediate database is configurable - any backend implementing the QueryDatabase
+interface can be used (SQLite3, PostgreSQL, MySQL, etc.).
+"""
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from ..executor import ExecutionContext, StandardExecution
+from ..result_set import ResultSet
+from ..sql.builder import ExecutionPlan
+from .detector import SubqueryDetector
+from .query_db_sqlite import QueryDBSQLite
+
+_logger = logging.getLogger(__name__)
+
+
+class SupersetExecution(StandardExecution):
+    """Two-stage execution strategy for subquery-based queries using intermediate RDBMS.
+
+    Uses a QueryDatabase backend (SQLite3 by default) to handle complex
+    SQL operations that MongoDB cannot perform natively.
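+
+    Illustrative flow for a Superset-style wrapped query (names are placeholders):
+
+        SELECT status, COUNT(*) FROM (SELECT status FROM orders) AS virtual_table GROUP BY status
+
+    Stage 1 runs the inner SELECT against MongoDB; Stage 2 loads the returned
+    rows into the intermediate database and evaluates the outer query there.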
+
+    Attributes:
+        _query_db_factory: Callable that creates QueryDatabase instances
+    """
+
+    def __init__(self, query_db_factory: Optional[Any] = None) -> None:
+        """
+        Initialize SupersetExecution with optional custom database backend.
+
+        Args:
+            query_db_factory: Callable that returns a QueryDatabase instance.
+                Defaults to QueryDBSQLite if not provided.
+        """
+        self._query_db_factory = query_db_factory or QueryDBSQLite
+        self._execution_plan: Optional[ExecutionPlan] = None
+
+    @property
+    def execution_plan(self) -> ExecutionPlan:
+        return self._execution_plan
+
+    def supports(self, context: ExecutionContext) -> bool:
+        """Support queries with subqueries"""
+        return context.execution_mode == "superset"
+
+    def execute(
+        self,
+        context: ExecutionContext,
+        connection: Any,
+    ) -> Optional[Dict[str, Any]]:
+        """Execute query in two stages: MongoDB for the subquery, intermediate DB for the outer query"""
+        _logger.debug(f"Using subquery execution for query: {context.query[:100]}")
+
+        # Detect if query is a subquery or simple SELECT
+        query_info = SubqueryDetector.detect(context.query)
+
+        # If no subquery detected, fall back to standard execution
+        if not query_info.has_subquery:
+            _logger.debug("No subquery detected, falling back to standard execution")
+            return super().execute(context, connection)
+
+        # Stage 1: Execute MongoDB subquery
+        mongo_query = query_info.subquery_text
+        _logger.debug(f"Stage 1: Executing MongoDB subquery: {mongo_query}")
+
+        mongo_execution_plan = self._parse_sql(mongo_query)
+        mongo_result = self._execute_execution_plan(mongo_execution_plan, connection.database)
+
+        # Extract result set from MongoDB
+        mongo_result_set = ResultSet(
+            command_result=mongo_result,
+            execution_plan=mongo_execution_plan,
+            database=connection.database,
+        )
+
+        # Fetch all MongoDB results and convert to list of dicts
+        mongo_rows = mongo_result_set.fetchall()
+        _logger.debug(f"Stage 1 complete: Got {len(mongo_rows)} rows from MongoDB")
+
+        # Convert tuple rows to dictionaries using column names
+        column_names = [desc[0] for desc in mongo_result_set.description] if mongo_result_set.description else []
+        mongo_dicts = []
+
+        for row in mongo_rows:
+            if column_names:
+                mongo_dicts.append(dict(zip(column_names, row)))
+            else:
+                # Fallback if no description available
+                mongo_dicts.append({"result": row})
+
+        # Stage 2: Load results into intermediate DB and execute outer query
+        db_name = self._query_db_factory.__name__ if hasattr(self._query_db_factory, "__name__") else "QueryDB"
+        _logger.debug(f"Stage 2: Loading {len(mongo_dicts)} rows into {db_name}")
+
+        query_db = self._query_db_factory()
+
+        try:
+            # Create temporary table with MongoDB results
+            querydb_query, table_name = SubqueryDetector.extract_outer_query(context.query)
+            query_db.insert_records(table_name, mongo_dicts)
+
+            # Execute outer query against intermediate DB
+            _logger.debug(f"Stage 2: Executing {db_name} query: {querydb_query}")
+
+            querydb_rows = query_db.execute_query(querydb_query)
+            _logger.debug(f"Stage 2 complete: Got {len(querydb_rows)} rows from {db_name}")
+
+            # Build a ResultSet-compatible command result from intermediate DB rows
+            command_result = self._build_command_result(querydb_rows, querydb_query)
+
+            self._execution_plan = ExecutionPlan(collection="query_db_result", projection_stage={})
+
+            return command_result
+
+        finally:
+            query_db.close()
+
+    def _build_command_result(self, rows: List[Dict[str, Any]], query: str) -> Dict[str, Any]:
+        """
+        Build a command-result dict (the shape ResultSet expects) from query database rows.
+
+        Args:
+            rows: List of dictionaries from query database
+            query: Original SQL query
+
+        Returns:
+            Command-result dict exposing the rows as a single firstBatch
+        """
+        # Create a mock command result structure compatible with ResultSet
+        command_result = {
+            "cursor": {
+                "id": 0,  # No pagination for query DB results
+                "firstBatch": rows,
+            }
+        }
+
+        return command_result
diff --git a/pymongosql/superset_mongodb/query_db.py b/pymongosql/superset_mongodb/query_db.py
new file mode 100644
index 0000000..abfe8a6
--- /dev/null
+++ b/pymongosql/superset_mongodb/query_db.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+"""
+Query database abstraction layer.
+
+This module provides an abstract interface for databases used
+during subquery execution. Allows plugging in different RDBMS backends
+(SQLite3, PostgreSQL, MySQL, etc.) while maintaining a unified interface.
+
+Default implementation uses SQLite3 for in-memory processing.
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+_logger = logging.getLogger(__name__)
+
+
+class QueryDatabase(ABC):
+    """Abstract base class for query database backends"""
+
+    @abstractmethod
+    def create_table(self, table_name: str, schema: Dict[str, str]) -> None:
+        """
+        Create a table with the specified schema.
+
+        Args:
+            table_name: Name of the table to create
+            schema: Dictionary mapping column names to SQL types
+        """
+        pass
+
+    @abstractmethod
+    def insert_records(self, table_name: str, records: List[Dict[str, Any]], schema: Optional[Dict[str, str]] = None) -> int:
+        """
+        Insert records into a table, creating it first when it does not exist.
+
+        Args:
+            table_name: Name of the table
+            records: List of dictionaries with data
+            schema: Optional column-to-type mapping; inferred from records if omitted
+        """
+        pass
+
+    @abstractmethod
+    def execute_query(self, query: str) -> List[Dict[str, Any]]:
+        """
+        Execute a query and return results as list of dictionaries.
+
+        Args:
+            query: SQL query string
+
+        Returns:
+            List of dictionaries with query results
+        """
+        pass
+
+    @abstractmethod
+    def execute_query_cursor(self, query: str) -> Any:
+        """
+        Execute a query and return a cursor-like object.
+
+        Args:
+            query: SQL query string
+
+        Returns:
+            Cursor object for row-by-row iteration
+        """
+        pass
+
+    @abstractmethod
+    def drop_table(self, table_name: str) -> None:
+        """
+        Drop a table.
+
+        Args:
+            table_name: Name of the table to drop
+        """
+        pass
+
+    @abstractmethod
+    def table_exists(self, table_name: str) -> bool:
+        """
+        Check if a table exists.
+
+        Args:
+            table_name: Name of the table
+
+        Returns:
+            True if table exists, False otherwise
+        """
+        pass
+
+    @abstractmethod
+    def list_tables(self) -> List[str]:
+        """
+        List all tables in the database.
+
+        Returns:
+            List of table names
+        """
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close database connection and cleanup resources."""
+        pass
+
+    @abstractmethod
+    def __enter__(self) -> "QueryDatabase":
+        """Context manager entry"""
+        pass
+
+    @abstractmethod
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        """Context manager exit"""
+        pass
diff --git a/pymongosql/superset_mongodb/query_db_sqlite.py b/pymongosql/superset_mongodb/query_db_sqlite.py
new file mode 100644
index 0000000..a05e291
--- /dev/null
+++ b/pymongosql/superset_mongodb/query_db_sqlite.py
@@ -0,0 +1,283 @@
+# -*- coding: utf-8 -*-
+"""
+SQLite3 bridge for handling query database operations.
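+
+Example (hypothetical use of the default backend):
+
+    with QueryDBSQLite() as db:
+        db.insert_records("virtual_table", [{"status": "paid"}, {"status": "open"}])
+        rows = db.execute_query("SELECT status, COUNT(*) AS n FROM virtual_table GROUP BY status")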
+
+This module manages the creation, population, and querying of in-memory SQLite3
+databases that serve as an intermediate layer between MongoDB and Superset,
+enabling support for complex SQL operations that MongoDB cannot handle natively.
+
+QueryDBSQLite is the default implementation of the QueryDatabase interface.
+"""
+
+import logging
+import sqlite3
+from typing import Any, Dict, List, Optional
+
+from .query_db import QueryDatabase
+
+_logger = logging.getLogger(__name__)
+
+
+class SQLiteTypeMapper:
+    """Maps Python/MongoDB data types to SQLite3 types"""
+
+    # Type mapping from Python types to SQLite3 types
+    TYPE_MAP = {
+        str: "TEXT",
+        int: "INTEGER",
+        float: "REAL",
+        bool: "INTEGER",  # SQLite3 uses 0/1 for boolean
+        bytes: "BLOB",
+        type(None): "NULL",
+        dict: "TEXT",  # Store as JSON string
+        list: "TEXT",  # Store as JSON string
+    }
+
+    @classmethod
+    def get_sqlite_type(cls, value: Any) -> str:
+        """Get SQLite type for a Python value"""
+        if value is None:
+            return "NULL"
+
+        value_type = type(value)
+        if value_type in cls.TYPE_MAP:
+            return cls.TYPE_MAP[value_type]
+
+        # Default to TEXT for unknown types
+        return "TEXT"
+
+    @classmethod
+    def infer_schema(cls, records: List[Dict[str, Any]]) -> Dict[str, str]:
+        """
+        Infer SQLite schema from a list of records.
+
+        Args:
+            records: List of dictionaries with data
+
+        Returns:
+            Dictionary mapping column names to SQLite types
+        """
+        schema = {}
+
+        for record in records:
+            for col_name, value in record.items():
+                if col_name not in schema:
+                    # First occurrence, determine type
+                    schema[col_name] = cls.get_sqlite_type(value)
+                elif schema[col_name] != "TEXT":
+                    # If we've already determined a type, check compatibility
+                    new_type = cls.get_sqlite_type(value)
+                    # Upgrade to TEXT if types differ (safest option)
+                    if new_type != schema[col_name]:
+                        schema[col_name] = "TEXT"
+
+        return schema
+
+    @classmethod
+    def convert_value(cls, value: Any, target_type: str) -> Any:
+        """Convert value to appropriate SQLite type"""
+        if value is None:
+            return None
+
+        if target_type == "INTEGER":
+            return int(value)
+        elif target_type == "REAL":
+            return float(value)
+        elif target_type == "TEXT":
+            if isinstance(value, (dict, list)):
+                import json
+
+                return json.dumps(value)
+            return str(value)
+        elif target_type == "BLOB":
+            if isinstance(value, bytes):
+                return value
+            return str(value).encode()
+
+        return value
+
+
+class QueryDBSQLite(QueryDatabase):
+    """Manages SQLite3 in-memory database for query database operations.
+
+    This is the default implementation of QueryDatabase using SQLite3.
+    Other RDBMS backends can be created by implementing the QueryDatabase interface.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the bridge with an in-memory SQLite3 database"""
+        self._connection: Optional[sqlite3.Connection] = None
+        self._tables: Dict[str, Dict[str, str]] = {}  # table_name -> schema
+        self._is_closed = False
+
+    def _ensure_connection(self) -> sqlite3.Connection:
+        """Ensure SQLite3 connection is available"""
+        if self._is_closed:
+            raise RuntimeError("QueryDBSQLite is closed")
+
+        if self._connection is None:
+            # Create in-memory database
+            self._connection = sqlite3.connect(":memory:")
+            # Enable row factory to get dict-like rows
+            self._connection.row_factory = sqlite3.Row
+            _logger.debug("Created in-memory SQLite3 database")
+
+        return self._connection
+
+    def create_table(self, table_name: str, schema: Dict[str, str]) -> None:
+        """
+        Create a table in SQLite3.
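+
+        Example (illustrative): ``db.create_table("t", {"name": "TEXT", "age": "INTEGER"})``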
+
+
+class QueryDBSQLite(QueryDatabase):
+    """Manages SQLite3 in-memory database for query database operations.
+
+    This is the default implementation of QueryDatabase using SQLite3.
+    Other RDBMS backends can be created by implementing the QueryDatabase interface.
+    """
+
+    def __init__(self) -> None:
+        """Initialize SQLite3 bridge with in-memory database"""
+        self._connection: Optional[sqlite3.Connection] = None
+        self._tables: Dict[str, Dict[str, str]] = {}  # table_name -> schema
+        self._is_closed = False
+
+    def _ensure_connection(self) -> sqlite3.Connection:
+        """Ensure SQLite3 connection is available"""
+        if self._is_closed:
+            raise RuntimeError("QueryDBSQLite is closed")
+
+        if self._connection is None:
+            # Create in-memory database
+            self._connection = sqlite3.connect(":memory:")
+            # Enable row factory to get dict-like rows
+            self._connection.row_factory = sqlite3.Row
+            _logger.debug("Created in-memory SQLite3 database")
+
+        return self._connection
+
+    def create_table(self, table_name: str, schema: Dict[str, str]) -> None:
+        """
+        Create a table in SQLite3.
+
+        Args:
+            table_name: Name of the table
+            schema: Dictionary mapping column names to SQLite types
+        """
+        conn = self._ensure_connection()
+
+        # Build CREATE TABLE statement; identifiers are quoted so that
+        # reserved words and unusual inferred column names survive
+        columns = ", ".join([f'"{col}" {dtype}' for col, dtype in schema.items()])
+        create_sql = f'CREATE TABLE "{table_name}" ({columns})'
+
+        try:
+            conn.execute(create_sql)
+            conn.commit()
+            self._tables[table_name] = schema
+            _logger.debug(f"Created SQLite3 table: {table_name}")
+        except sqlite3.Error as e:
+            _logger.error(f"Error creating table {table_name}: {e}")
+            raise
+
+    def insert_records(
+        self, table_name: str, records: List[Dict[str, Any]], schema: Optional[Dict[str, str]] = None
+    ) -> int:
+        """
+        Insert records into a SQLite3 table.
+
+        Args:
+            table_name: Name of the table
+            records: List of dictionaries to insert
+            schema: Optional schema (will be inferred if not provided)
+
+        Returns:
+            Number of records inserted
+        """
+        if not records:
+            return 0
+
+        conn = self._ensure_connection()
+
+        # Create table if not exists
+        if table_name not in self._tables:
+            if schema is None:
+                schema = SQLiteTypeMapper.infer_schema(records)
+            self.create_table(table_name, schema)
+
+        # Build INSERT statement, quoting identifiers as in create_table
+        columns = list(records[0].keys())
+        placeholders = ", ".join(["?" for _ in columns])
+        column_list = ", ".join([f'"{col}"' for col in columns])
+        insert_sql = f'INSERT INTO "{table_name}" ({column_list}) VALUES ({placeholders})'
+
+        # Convert values to appropriate types
+        schema = self._tables[table_name]
+        converted_records = []
+
+        for record in records:
+            converted_row = tuple(
+                SQLiteTypeMapper.convert_value(record.get(col), schema.get(col, "TEXT")) for col in columns
+            )
+            converted_records.append(converted_row)
+
+        try:
+            conn.executemany(insert_sql, converted_records)
+            conn.commit()
+            _logger.debug(f"Inserted {len(records)} records into {table_name}")
+            return len(records)
+        except sqlite3.Error as e:
+            _logger.error(f"Error inserting records into {table_name}: {e}")
+            raise
+
+    def execute_query(self, query: str) -> List[Dict[str, Any]]:
+        """
+        Execute a query against the SQLite3 database.
+
+        Args:
+            query: SQL query string
+
+        Returns:
+            List of dictionaries with query results
+        """
+        conn = self._ensure_connection()
+
+        try:
+            cursor = conn.execute(query)
+            # Fetch all rows and convert from sqlite3.Row to dict
+            rows = cursor.fetchall()
+
+            column_names = [desc[0] for desc in cursor.description] if cursor.description else []
+            return [dict(zip(column_names, row)) for row in rows]
+        except sqlite3.Error as e:
+            _logger.error(f"Error executing query: {e}")
+            raise
+
+    def execute_query_cursor(self, query: str) -> sqlite3.Cursor:
+        """
+        Execute a query and return cursor for manual iteration.
+ + Args: + query: SQL query string + + Returns: + SQLite3 cursor for iteration + """ + conn = self._ensure_connection() + return conn.execute(query) + + def table_exists(self, table_name: str) -> bool: + """Check if a table exists in the database""" + return table_name in self._tables + + def get_table_schema(self, table_name: str) -> Optional[Dict[str, str]]: + """Get the schema of a table""" + return self._tables.get(table_name) + + def list_tables(self) -> List[str]: + """List all tables in the database""" + return list(self._tables.keys()) + + def drop_table(self, table_name: str) -> None: + """Drop a table from the database""" + if table_name not in self._tables: + return + + conn = self._ensure_connection() + try: + conn.execute(f"DROP TABLE {table_name}") + conn.commit() + del self._tables[table_name] + _logger.debug(f"Dropped table: {table_name}") + except sqlite3.Error as e: + _logger.error(f"Error dropping table {table_name}: {e}") + raise + + def close(self) -> None: + """Close the SQLite3 connection""" + if self._connection is not None: + try: + self._connection.close() + _logger.debug("Closed SQLite3 database connection") + except sqlite3.Error as e: + _logger.error(f"Error closing SQLite3 connection: {e}") + finally: + self._connection = None + self._is_closed = True + + def __enter__(self) -> "QueryDBSQLite": + """Context manager entry""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit""" + self.close() + + def __repr__(self) -> str: + return f"QueryDBSQLite(tables={list(self._tables.keys())}, closed={self._is_closed})" diff --git a/pyproject.toml b/pyproject.toml index d98a795..e42724a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ requires-python = ">=3.9" dependencies = [ "pymongo>=4.15.0", "antlr4-python3-runtime>=4.13.0", + "jmespath>=1.0.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 7ebe8cd..2b5da98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ antlr4-python3-runtime>=4.13.0 pymongo>=4.15.0 +jmespath>=1.0.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 0da78f2..2c033bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,25 @@ def make_conn(**kwargs): return Connection(**kwargs) +def make_superset_conn(**kwargs): + """Create a superset-mode Connection using TEST_URI if provided, otherwise use a local default.""" + if TEST_URI: + # Convert test URI to superset mode by replacing mongodb:// with mongodb+superset:// + superset_uri = TEST_URI.replace("mongodb://", "mongodb+superset://", 1) + if "database" not in kwargs: + kwargs["database"] = TEST_DB + return Connection(host=superset_uri, **kwargs) + + # Default local connection parameters with superset mode + defaults = { + "host": "mongodb+superset://testuser:testpass@localhost:27017/test_db?authSource=test_db", + "database": "test_db", + } + for k, v in defaults.items(): + kwargs.setdefault(k, v) + return Connection(**kwargs) + + @pytest.fixture def conn(): """Yield a Connection instance configured via environment variables and tear it down after use.""" @@ -38,6 +57,19 @@ def conn(): pass +@pytest.fixture +def superset_conn(): + """Yield a superset-mode Connection instance and tear it down after use.""" + connection = make_superset_conn() + try: + yield connection + finally: + try: + connection.close() + except Exception: + pass + + @pytest.fixture def make_connection(): """Provide the helper make_conn function to tests that 
need to create connections with custom args."""
diff --git a/tests/run_test_server.py b/tests/run_test_server.py
index 7cdd7cc..b5a917f 100644
--- a/tests/run_test_server.py
+++ b/tests/run_test_server.py
@@ -126,7 +126,7 @@ def start_mongodb_docker(version="8.0"):
             capture_output=True,
             text=True,
             check=True,
-            timeout=30,
+            timeout=120,
         )
         print(f"Container started: {result.stdout.strip()}")

@@ -151,7 +151,7 @@ def stop_mongodb_docker():
     return False


-def wait_for_mongodb(host=MONGODB_HOST, port=MONGODB_PORT, timeout=30):
+def wait_for_mongodb(host=MONGODB_HOST, port=MONGODB_PORT, timeout=90):
     """Wait for MongoDB to be ready"""
     print(f"Waiting for MongoDB at {host}:{port}... (timeout: {timeout}s)")

diff --git a/tests/test_cursor.py b/tests/test_cursor.py
index f84aff9..d00b5c0 100644
--- a/tests/test_cursor.py
+++ b/tests/test_cursor.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import pytest

-from pymongosql.error import ProgrammingError
+from pymongosql.error import DatabaseError, ProgrammingError, SqlSyntaxError
 from pymongosql.result_set import ResultSet

@@ -64,9 +64,7 @@ def test_execute_with_limit(self, conn):
         assert isinstance(cursor.result_set, ResultSet)
         rows = cursor.result_set.fetchall()

-        # Should return results from 22 users in dataset (LIMIT parsing may not be implemented yet)
-        # TODO: Fix LIMIT parsing in SQL grammar
-        assert len(rows) >= 1  # At least we get some results
+        assert len(rows) == 2  # LIMIT 2 should return exactly 2 rows

         # Check that names are present using DB API 2.0
         if len(rows) > 0:
@@ -85,7 +83,7 @@ def test_execute_with_skip(self, conn):
         rows = cursor.result_set.fetchall()

         # Should return users after skipping 1 (from 22 users in dataset)
-        assert len(rows) >= 0  # Could be 0-21 depending on implementation
+        assert len(rows) == 21  # 22 - 1 = 21 users after skipping the first one

         # Check that results have name field if any results using DB API 2.0
         if len(rows) > 0:
@@ -111,11 +109,8 @@ def test_execute_with_sort(self, conn):
             assert "name" in col_names
         assert all(len(row) >= 1 for row in rows)  # All rows should have data

-        # Verify that we have actual user names from the dataset using DB API 2.0
-        if "name" in col_names:
-            name_idx = col_names.index("name")
-            names = [row[name_idx] for row in rows]
-            assert "John Doe" in names  # First user from dataset
+        # Verify the first name in the sorted result
+        assert "Patricia Johnson" == rows[0][0]

     def test_execute_complex_query(self, conn):
         """Test executing complex query with multiple clauses"""
@@ -138,13 +133,38 @@ def test_execute_complex_query(self, conn):
         for row in rows:
             assert len(row) >= 2  # Should have at least name and email

+    def test_execute_nested_fields_query(self, conn):
+        """Test executing query with nested field access"""
+        sql = "SELECT name, profile.bio, address.city FROM users WHERE salary >= 100000 ORDER BY salary DESC"
+
+        cursor = conn.cursor()
+        result = cursor.execute(sql)
+        assert result == cursor
+        assert isinstance(cursor.result_set, ResultSet)
+
+        # Get results - test nested field functionality
+        rows = cursor.result_set.fetchall()
+        assert isinstance(rows, list)
+        assert len(rows) == 4
+
+        # Verify that nested fields are properly projected
+        if cursor.result_set.description:
+            col_names = [desc[0] for desc in cursor.result_set.description]
+            # Should include nested field names in projection
+            assert "name" in col_names
+            assert "profile.bio" in col_names
+            assert "address.city" in col_names
+
+        # Verify the first record matched the highest salary
+        assert "Patricia Johnson" == rows[0][0]
+
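The dotted column names asserted above come straight from the SQL projection; a quick sketch of what a client sees (host, database, and fixture data are assumed):

```python
# Sketch: nested-field projections surface as dotted column names in
# cursor.description. Host and database are placeholders.
import pymongosql

conn = pymongosql.connect(host="mongodb://localhost:27017", database="test_db")
cursor = conn.cursor()
cursor.execute("SELECT name, profile.bio, address.city FROM users LIMIT 1")

print([desc[0] for desc in cursor.description])
# -> ['name', 'profile.bio', 'address.city']
```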
     def test_execute_parser_error(self, conn):
         """Test executing query with parser errors"""
         sql = "INVALID SQL SYNTAX"

         # This should raise an exception due to invalid SQL
         cursor = conn.cursor()
-        with pytest.raises(Exception):  # Could be SqlSyntaxError or other parsing error
+        with pytest.raises(SqlSyntaxError):  # Parser errors surface as SqlSyntaxError
             cursor.execute(sql)

     def test_execute_database_error(self, conn, make_connection):
@@ -156,7 +176,7 @@ def test_execute_database_error(self, conn, make_connection):

         # This should raise an exception due to closed connection
         cursor = conn.cursor()
-        with pytest.raises(Exception):  # Could be DatabaseError or OperationalError
+        with pytest.raises(DatabaseError):
             cursor.execute(sql)

         # Reconnect for other tests
@@ -166,31 +186,6 @@ def test_execute_database_error(self, conn, make_connection):
         finally:
             new_conn.close()

-    def test_execute_with_aliases(self, conn):
-        """Test executing query with field aliases"""
-        sql = "SELECT name AS full_name, email AS user_email FROM users"
-        cursor = conn.cursor()
-        result = cursor.execute(sql)
-
-        assert result == cursor  # execute returns self
-        assert isinstance(cursor.result_set, ResultSet)
-        rows = cursor.result_set.fetchall()
-
-        # Should return users with aliased field names
-        assert len(rows) == 22
-
-        # Check that alias fields are present if aliasing works using DB API 2.0
-        col_names = [desc[0] for desc in cursor.result_set.description]
-        # Aliases might not work yet, so check for either original or alias names
-        assert "name" in col_names or "full_name" in col_names
-        # Check for email columns in description
-        has_email = "email" in col_names or "user_email" in col_names
-        for row in rows:
-            assert len(row) >= 2  # Should have at least 2 columns
-            # Verify we have email data if expected
-            if has_email:
-                assert True  # Email column exists in description
-
     def test_fetchone_without_execute(self, conn):
         """Test fetchone without previous execute"""
         fresh_cursor = conn.cursor()
@@ -216,11 +211,10 @@ def test_fetchone_with_result(self, conn):
         # Execute query first
         cursor = conn.cursor()
         _ = cursor.execute(sql)
-
-        # Test fetchone - DB API 2.0 returns sequences, not dicts
         row = cursor.fetchone()
+
         assert row is not None
-        assert isinstance(row, (tuple, list))  # Should be sequence, not dict
+        assert isinstance(row, (tuple, list))
         # Verify we have data using DB API 2.0 approach
         col_names = [desc[0] for desc in cursor.result_set.description] if cursor.result_set.description else []
         if "name" in col_names:
diff --git a/tests/test_dict_cursor.py b/tests/test_dict_cursor.py
new file mode 100644
index 0000000..ef41f36
--- /dev/null
+++ b/tests/test_dict_cursor.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+from pymongosql.cursor import DictCursor
+from pymongosql.result_set import DictResultSet
+
+
+class TestDictCursor:
+    """Test suite for DictCursor class - returns results as dictionaries"""
+
+    def test_dict_cursor_init(self, conn):
+        """Test DictCursor initialization"""
+        dict_cursor = conn.cursor(DictCursor)
+        assert dict_cursor._connection == conn
+        assert dict_cursor._result_set is None
+        assert dict_cursor._result_set_class == DictResultSet
+
+    def test_dict_cursor_simple_select(self, conn):
+        """Test DictCursor returning results as dictionaries"""
+        sql = "SELECT name, email FROM users WHERE age > 25"
+        dict_cursor = conn.cursor(DictCursor)
+        result = dict_cursor.execute(sql)
+
+        assert result == dict_cursor  # execute returns self
+        assert isinstance(dict_cursor.result_set, DictResultSet)
+        rows = 
dict_cursor.result_set.fetchall() + + # Should return 19 users with age > 25 + assert len(rows) == 19 + + # Results should be dictionaries with field names as keys + if len(rows) > 0: + first_row = rows[0] + assert isinstance(first_row, dict) + assert "name" in first_row + assert "email" in first_row + # Should have exactly 2 keys + assert len(first_row) == 2 + # Verify actual data values + assert first_row["name"] == "John Doe" + assert first_row["email"] == "john@example.com" + + def test_dict_cursor_select_all(self, conn): + """Test DictCursor with SELECT *""" + sql = "SELECT * FROM products LIMIT 3" + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute(sql) + rows = dict_cursor.result_set.fetchall() + + assert len(rows) <= 3 + + # All results should be dictionaries + for row in rows: + assert isinstance(row, dict) + assert len(row) > 0 # Should have fields + + def test_dict_cursor_fetchone(self, conn): + """Test DictCursor fetchone returns dictionary""" + sql = "SELECT name, age FROM users" + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute(sql) + + row = dict_cursor.fetchone() + + assert row is not None + assert isinstance(row, dict) + assert "name" in row + assert "age" in row + # Verify actual data values + assert row["name"] == "John Doe" + assert row["age"] == 30 + + def test_dict_cursor_fetchmany(self, conn): + """Test DictCursor fetchmany returns list of dictionaries""" + sql = "SELECT name, email FROM users ORDER BY name" + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute(sql) + + rows = dict_cursor.fetchmany(3) + + assert len(rows) == 3 + # All rows should be dictionaries + for row in rows: + assert isinstance(row, dict) + assert "name" in row + assert "email" in row + # Verify actual data values + assert rows[0]["name"] == "Alice Williams" + assert rows[0]["email"] == "alice@example.com" + assert rows[1]["name"] == "Bob Johnson" + assert rows[2]["name"] == "Broken Reference User" + + def test_dict_cursor_with_where_clause(self, conn): + """Test DictCursor with WHERE clause""" + sql = "SELECT name, status FROM users WHERE age > 30 ORDER BY name" + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute(sql) + + rows = dict_cursor.fetchall() + + # Results should be dictionaries and all have age > 30 + assert len(rows) == 11 + if len(rows) > 0: + for row in rows: + assert isinstance(row, dict) + assert "name" in row + assert "status" in row + # Verify actual data values + assert rows[0]["name"] == "Bob Johnson" + assert rows[0]["status"] is None + + def test_dict_cursor_with_order_by(self, conn): + """Test DictCursor with ORDER BY""" + sql = "SELECT name FROM users ORDER BY age DESC LIMIT 1" + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute(sql) + + rows = dict_cursor.fetchall() + + assert len(rows) == 1 + assert isinstance(rows[0], dict) + assert "name" in rows[0] + assert rows[0]["name"] == "Patricia Johnson" # Highest age + + def test_dict_cursor_vs_tuple_cursor(self, conn): + """Test that DictCursor returns dicts while regular Cursor returns tuples""" + sql = "SELECT name, email FROM users LIMIT 1" + + # Get result from regular cursor (tuple) + cursor = conn.cursor() + cursor.execute(sql) + tuple_row = cursor.fetchone() + + # Get result from dict cursor (dict) + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute(sql) + dict_row = dict_cursor.fetchone() + + # Regular cursor returns tuple/sequence + assert isinstance(tuple_row, (tuple, list)) + + # DictCursor returns dictionary + assert isinstance(dict_row, dict) + + # 
Both should have same data (just different formats) + col_names = [desc[0] for desc in cursor.result_set.description] + assert "name" in col_names + assert "email" in col_names + assert "name" in dict_row + assert "email" in dict_row + # Verify actual data matches between both cursor types + assert dict_row["name"] == tuple_row[col_names.index("name")] + assert dict_row["email"] == tuple_row[col_names.index("email")] + assert dict_row["name"] == "John Doe" + assert dict_row["email"] == "john@example.com" + + def test_dict_cursor_close(self, conn): + """Test DictCursor close""" + dict_cursor = conn.cursor(DictCursor) + dict_cursor.execute("SELECT * FROM users LIMIT 1") + dict_cursor.close() + assert dict_cursor._result_set is None + + def test_dict_cursor_context_manager(self, conn): + """Test DictCursor as context manager""" + dict_cursor = conn.cursor(DictCursor) + with dict_cursor as ctx: + assert ctx == dict_cursor diff --git a/tests/test_result_set.py b/tests/test_result_set.py index ed81a29..3658391 100644 --- a/tests/test_result_set.py +++ b/tests/test_result_set.py @@ -3,7 +3,7 @@ from pymongosql.error import ProgrammingError from pymongosql.result_set import ResultSet -from pymongosql.sql.builder import ExecutionPlan +from pymongosql.sql.builder import BuilderFactory class TestResultSet: @@ -19,7 +19,9 @@ def test_result_set_init(self, conn): # Execute a real command to get results command_result = db.command({"find": "users", "filter": {"age": {"$gt": 25}}, "limit": 1}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) assert result_set._command_result == command_result assert result_set._execution_plan == execution_plan @@ -30,7 +32,9 @@ def test_result_set_init_empty_projection(self, conn): db = conn.database command_result = db.command({"find": "users", "limit": 1}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) assert result_set._execution_plan.projection_stage == {} @@ -40,7 +44,9 @@ def test_fetchone_with_data(self, conn): # Get real user data with projection mapping command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "limit": 1}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) row = result_set.fetchone() @@ -66,7 +72,9 @@ def test_fetchone_no_data(self, conn): {"find": "users", "filter": {"age": {"$gt": 999}}, "limit": 1} # No users over 999 years old ) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) row = result_set.fetchone() @@ -77,7 +85,9 @@ def test_fetchone_empty_projection(self, conn): db = conn.database 
command_result = db.command({"find": "users", "limit": 1, "sort": {"_id": 1}}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) row = result_set.fetchone() @@ -102,7 +112,9 @@ def test_fetchone_closed_cursor(self, conn): db = conn.database command_result = db.command({"find": "users", "limit": 1}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) result_set.close() @@ -115,7 +127,9 @@ def test_fetchmany_with_data(self, conn): # Get multiple users with projection command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "limit": 5}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = result_set.fetchmany(2) @@ -141,7 +155,9 @@ def test_fetchmany_default_size(self, conn): # Get all users (22 total in test dataset) command_result = db.command({"find": "users"}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = result_set.fetchmany() # Should use default arraysize (1000) @@ -153,7 +169,9 @@ def test_fetchmany_less_data_available(self, conn): # Get only 2 users but request 5 command_result = db.command({"find": "users", "limit": 2}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = result_set.fetchmany(5) # Request 5 but only 2 available @@ -165,7 +183,9 @@ def test_fetchmany_no_data(self, conn): # Query for non-existent data command_result = db.command({"find": "users", "filter": {"age": {"$gt": 999}}}) # No users over 999 years old - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = result_set.fetchmany(3) @@ -179,7 +199,9 @@ def test_fetchall_with_data(self, conn): {"find": "users", "filter": {"age": {"$gt": 25}}, "projection": {"name": 1, "email": 1}} ) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = result_set.fetchall() @@ -201,7 +223,9 @@ def 
test_fetchall_no_data(self, conn): db = conn.database command_result = db.command({"find": "users", "filter": {"age": {"$gt": 999}}}) # No users over 999 years old - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = result_set.fetchall() @@ -212,7 +236,9 @@ def test_fetchall_closed_cursor(self, conn): db = conn.database command_result = db.command({"find": "users", "limit": 1}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) result_set.close() @@ -222,7 +248,7 @@ def test_fetchall_closed_cursor(self, conn): def test_apply_projection_mapping(self): """Test _process_document method""" projection = {"name": 1, "age": 1, "email": 1} - execution_plan = ExecutionPlan(collection="users", projection_stage=projection) + execution_plan = BuilderFactory.create_query_builder().collection("users").project(projection).build() # Create empty command result for testing _process_document method command_result = {"cursor": {"firstBatch": []}} @@ -248,7 +274,7 @@ def test_apply_projection_mapping_missing_fields(self): "age": 1, "missing": 1, } - execution_plan = ExecutionPlan(collection="users", projection_stage=projection) + execution_plan = BuilderFactory.create_query_builder().collection("users").project(projection).build() command_result = {"cursor": {"firstBatch": []}} result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) @@ -264,7 +290,7 @@ def test_apply_projection_mapping_missing_fields(self): def test_apply_projection_mapping_identity_mapping(self): """Test projection with MongoDB standard format""" projection = {"name": 1, "age": 1} - execution_plan = ExecutionPlan(collection="users", projection_stage=projection) + execution_plan = BuilderFactory.create_query_builder().collection("users").project(projection).build() command_result = {"cursor": {"firstBatch": []}} result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) @@ -276,10 +302,27 @@ def test_apply_projection_mapping_identity_mapping(self): expected = {"name": "John", "age": 30} assert mapped_doc == expected + def test_array_projection_mapping(self): + """Test projection mapping with array bracket/number conversion""" + projection = {"items.0": 1, "items.1.name": 1} + execution_plan = BuilderFactory.create_query_builder().collection("orders").project(projection).build() + + command_result = {"cursor": {"firstBatch": []}} + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) + + doc = {"items": [{"price": 50, "name": "a"}, {"price": 200, "name": "b"}]} + + mapped_doc = result_set._process_document(doc) + + expected = {"items[0]": {"price": 50, "name": "a"}, "items[1].name": "b"} + assert mapped_doc == expected + def test_close(self): """Test close method""" command_result = {"cursor": {"firstBatch": []}} - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = 
ResultSet(command_result=command_result, execution_plan=execution_plan) # Should not be closed initially @@ -293,7 +336,9 @@ def test_close(self): def test_context_manager(self): """Test ResultSet as context manager""" command_result = {"cursor": {"firstBatch": []}} - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) with result_set as rs: @@ -306,7 +351,9 @@ def test_context_manager(self): def test_context_manager_with_exception(self): """Test context manager with exception""" command_result = {"cursor": {"firstBatch": []}} - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) try: @@ -325,7 +372,9 @@ def test_iterator_protocol(self, conn): # Get 2 users from database command_result = db.command({"find": "users", "limit": 2}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) # Test iterator protocol @@ -348,7 +397,9 @@ def test_iterator_with_projection(self, conn): db = conn.database command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "limit": 2}) - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) rows = list(result_set) @@ -363,7 +414,9 @@ def test_iterator_with_projection(self, conn): def test_iterator_closed_cursor(self): """Test iteration on closed cursor""" command_result = {"cursor": {"firstBatch": []}} - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) result_set.close() @@ -373,7 +426,9 @@ def test_iterator_closed_cursor(self): def test_arraysize_property(self): """Test arraysize property""" command_result = {"cursor": {"firstBatch": []}} - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, execution_plan=execution_plan) # Default arraysize should be 1000 @@ -386,7 +441,9 @@ def test_arraysize_property(self): def test_arraysize_validation(self): """Test arraysize validation""" command_result = {"cursor": {"firstBatch": []}} - execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY) + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) result_set = ResultSet(command_result=command_result, 
execution_plan=execution_plan) # Should reject invalid values diff --git a/tests/test_result_set_pagination.py b/tests/test_result_set_pagination.py new file mode 100644 index 0000000..253075c --- /dev/null +++ b/tests/test_result_set_pagination.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +from pymongosql.result_set import ResultSet +from pymongosql.sql.builder import BuilderFactory + + +class TestResultSetPagination: + """Test suite for ResultSet pagination with getMore""" + + # Shared projections used by tests + PROJECTION_WITH_FIELDS = {"name": 1, "email": 1} + PROJECTION_EMPTY = {} + + def test_pagination_cursor_id_zero(self, conn): + """Test pagination when cursor_id is 0 (all results in firstBatch)""" + db = conn.database + # Query with small limit - all results fit in firstBatch + command_result = db.command({"find": "users", "limit": 5}) + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + # Check cursor_id - should be 0 when all results fit in firstBatch + assert result_set._cursor_id == 0 + assert result_set._cache_exhausted is False # Not exhausted yet, but no getMore needed + + # Fetch all results + rows = result_set.fetchall() + assert len(rows) == 5 + + # After fetching all, cache should be exhausted + assert result_set._cache_exhausted is True + + def test_pagination_multiple_batches(self, conn): + """Test pagination across multiple batches with getMore""" + db = conn.database + # Use a small batch size (batchSize) to force pagination + command_result = db.command({"find": "users", "batchSize": 5}) # Only 5 results per batch + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + # Initial results should have cursor_id > 0 since we have 22 total users and batchSize is 5 + initial_cached = len(result_set._cached_results) + assert initial_cached <= 5 # Should have at most 5 in cache from firstBatch + + # Fetch multiple results (should trigger getMore) + rows = result_set.fetchmany(10) + assert len(rows) == 10 + + # After fetching, we should have processed multiple batches + assert result_set._total_fetched >= 10 + + def test_pagination_ensure_results_available(self, conn): + """Test _ensure_results_available with pagination""" + db = conn.database + # Request results with small batch size + command_result = db.command({"find": "users", "batchSize": 3}) # Small batch to test pagination + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + # Initially, cache might have 3 results + initial_cache_size = len(result_set._cached_results) + assert initial_cache_size <= 3 + + # Ensure we have 8 results available - should trigger getMore + result_set._ensure_results_available(8) + assert len(result_set._cached_results) >= 8 + + # Check that cursor_id was updated + assert result_set._cursor_id >= 0 + + def test_pagination_fetchone_triggers_getmore(self, conn): + """Test that fetchone triggers getMore when needed""" + db = conn.database + # Create result set with small batch size + command_result = db.command({"find": "users", "batchSize": 2}) # Very small batch + + 
execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + _ = result_set._cursor_id + rows_fetched = [] + + # Fetch many single rows - should trigger getMore multiple times + for _ in range(10): + row = result_set.fetchone() + if row: + rows_fetched.append(row) + + assert len(rows_fetched) == 10 + # rowcount should reflect total fetched + assert result_set.rowcount >= 10 + + def test_pagination_cache_exhausted_flag(self, conn): + """Test cache exhausted flag is set correctly""" + db = conn.database + command_result = db.command({"find": "users", "limit": 3}) + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + assert result_set._cache_exhausted is False + + # Fetch all results + rows = result_set.fetchall() + assert len(rows) == 3 + + # After exhausting results, flag should be set + assert result_set._cache_exhausted is True + + # Subsequent fetches should return empty + more_rows = result_set.fetchall() + assert more_rows == [] + + def test_pagination_rowcount_tracking(self, conn): + """Test rowcount is accurately tracked during pagination""" + db = conn.database + command_result = db.command({"find": "users", "batchSize": 4}) + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + initial_rowcount = result_set.rowcount + assert initial_rowcount <= 4 # Initial batch size + + # Fetch multiple batches + batch1 = result_set.fetchmany(8) + assert result_set.rowcount >= 8 + + batch2 = result_set.fetchmany(5) + assert result_set.rowcount >= 13 + + # Fetch all remaining + all_remaining = result_set.fetchall() + _ = result_set.rowcount + + # All 22 users should be fetched eventually + total_fetched = len(batch1) + len(batch2) + len(all_remaining) + assert total_fetched == 22 + + def test_pagination_with_projection(self, conn): + """Test pagination with field projection applied""" + db = conn.database + command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "batchSize": 3}) + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + # Fetch across multiple batches + rows = result_set.fetchall() + + # Should have all 22 users + assert len(rows) == 22 + + # Each row should have exactly 2 projected fields + col_names = [desc[0] for desc in result_set.description] + for row in rows: + assert len(row) == 2 + assert isinstance(row[col_names.index("name")], (str, type(None))) + assert isinstance(row[col_names.index("email")], (str, type(None))) + + def test_pagination_fetchmany_across_batches(self, conn): + """Test fetchmany that spans multiple getMore calls""" + db = conn.database + command_result = db.command({"find": "users", "batchSize": 3}) + + execution_plan = ( + BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build() + ) + result_set = ResultSet(command_result=command_result, execution_plan=execution_plan, database=db) + + # Fetch 10 rows 
- should span multiple batches
+        batch1 = result_set.fetchmany(10)
+        assert len(batch1) == 10
+
+        # Fetch next 10 - should get more users
+        batch2 = result_set.fetchmany(10)
+        assert len(batch2) == 10
+
+        # Fetch remaining results
+        batch3 = result_set.fetchmany(5)
+        # 2 of the 22 fixture users remain after the first 20
+        assert len(batch3) > 0
diff --git a/tests/test_sql_parser.py b/tests/test_sql_parser_general.py
similarity index 76%
rename from tests/test_sql_parser.py
rename to tests/test_sql_parser_general.py
index fe4cbe2..4b7e223 100644
--- a/tests/test_sql_parser.py
+++ b/tests/test_sql_parser_general.py
@@ -5,7 +5,7 @@
 from pymongosql.sql.parser import SQLParser


-class TestSQLParser:
+class TestSQLParserGeneral:
     """Comprehensive test suite for SQL parser from simple to complex queries"""

     def test_simple_select_all(self):
@@ -239,6 +239,29 @@ def test_select_with_limit(self):
         assert execution_plan.limit_stage == 10
         assert execution_plan.projection_stage == {"name": 1}

+    def test_select_with_offset(self):
+        """Test SELECT with OFFSET clause"""
+        sql = "SELECT name FROM users OFFSET 5"
+        parser = SQLParser(sql)
+
+        assert not parser.has_errors, f"Parser errors: {parser.errors}"
+        execution_plan = parser.get_execution_plan()
+        assert execution_plan.collection == "users"
+        assert execution_plan.skip_stage == 5
+        assert execution_plan.projection_stage == {"name": 1}
+
+    def test_select_with_limit_and_offset(self):
+        """Test SELECT with both LIMIT and OFFSET clauses"""
+        sql = "SELECT name, email FROM users LIMIT 10 OFFSET 5"
+        parser = SQLParser(sql)
+
+        assert not parser.has_errors, f"Parser errors: {parser.errors}"
+        execution_plan = parser.get_execution_plan()
+        assert execution_plan.collection == "users"
+        assert execution_plan.limit_stage == 10
+        assert execution_plan.skip_stage == 5
+        assert execution_plan.projection_stage == {"name": 1, "email": 1}
+
     def test_complex_query_combination(self):
         """Test complex query with multiple clauses"""
         sql = """
@@ -277,126 +300,6 @@ def test_parser_error_handling(self):
             parser = SQLParser("INVALID SQL SYNTAX")
             parser.get_execution_plan()

-    def test_select_with_as_aliases(self):
-        """Test SELECT with AS aliases"""
-        sql = "SELECT name AS username, email AS user_email FROM customers"
-        parser = SQLParser(sql)
-
-        assert not parser.has_errors, f"Parser errors: {parser.errors}"
-
-        execution_plan = parser.get_execution_plan()
-        assert execution_plan.collection == "customers"
-        assert execution_plan.filter_stage == {}
-        assert execution_plan.projection_stage == {
-            "name": 1,
-            "email": 1,
-        }
-
-    def test_select_with_mixed_aliases(self):
-        """Test SELECT with mixed alias formats"""
-        sql = "SELECT name AS username, age user_age, status FROM users"
-        parser = SQLParser(sql)
-
-        assert not parser.has_errors, f"Parser errors: {parser.errors}"
-
-        execution_plan = parser.get_execution_plan()
-        assert execution_plan.collection == "users"
-        assert execution_plan.filter_stage == {}
-        assert execution_plan.projection_stage == {
-            "name": 1,  # AS alias
-            "age": 1,  # Space-separated alias
-            "status": 1,  # No alias (field included)
-        }
-
-    def test_select_with_space_separated_aliases(self):
-        """Test SELECT with space-separated aliases"""
-        sql = "SELECT first_name fname, last_name lname, created_at creation_date FROM users"
-        parser = SQLParser(sql)
-
-        assert not parser.has_errors, f"Parser errors: {parser.errors}"
-
-        execution_plan = parser.get_execution_plan()
-        assert execution_plan.collection == "users"
-        assert 
execution_plan.filter_stage == {} - assert execution_plan.projection_stage == { - "first_name": 1, - "last_name": 1, - "created_at": 1, - } - - def test_select_with_complex_field_names_and_aliases(self): - """Test SELECT with complex field names and aliases""" - sql = "SELECT user_profile.name AS display_name, account_settings.theme user_theme FROM users" - parser = SQLParser(sql) - - assert not parser.has_errors, f"Parser errors: {parser.errors}" - - execution_plan = parser.get_execution_plan() - assert execution_plan.collection == "users" - assert execution_plan.filter_stage == {} - assert execution_plan.projection_stage == { - "user_profile.name": 1, - "account_settings.theme": 1, - } - - def test_select_function_with_aliases(self): - """Test SELECT with functions and aliases""" - sql = "SELECT COUNT(*) AS total_count, MAX(age) max_age FROM users" - parser = SQLParser(sql) - - assert not parser.has_errors, f"Parser errors: {parser.errors}" - - execution_plan = parser.get_execution_plan() - assert execution_plan.collection == "users" - assert execution_plan.filter_stage == {} - assert execution_plan.projection_stage == { - "COUNT(*)": 1, - "MAX(age)": 1, - } - - def test_select_single_field_with_alias(self): - """Test SELECT with single field and alias""" - sql = "SELECT email AS contact_email FROM customers" - parser = SQLParser(sql) - - assert not parser.has_errors, f"Parser errors: {parser.errors}" - - execution_plan = parser.get_execution_plan() - assert execution_plan.collection == "customers" - assert execution_plan.filter_stage == {} - assert execution_plan.projection_stage == {"email": 1} - - def test_select_aliases_with_where_clause(self): - """Test SELECT with aliases and WHERE clause""" - sql = "SELECT name AS username, status AS account_status FROM users WHERE age > 18" - parser = SQLParser(sql) - - assert not parser.has_errors, f"Parser errors: {parser.errors}" - - execution_plan = parser.get_execution_plan() - assert execution_plan.collection == "users" - assert execution_plan.filter_stage == {"age": {"$gt": 18}} - assert execution_plan.projection_stage == { - "name": 1, - "status": 1, - } - - def test_select_case_insensitive_as_alias(self): - """Test SELECT with case insensitive AS keyword""" - sql = "SELECT name as username, email As user_email, status AS account_status FROM users" - parser = SQLParser(sql) - - assert not parser.has_errors, f"Parser errors: {parser.errors}" - - execution_plan = parser.get_execution_plan() - assert execution_plan.collection == "users" - assert execution_plan.filter_stage == {} - assert execution_plan.projection_stage == { - "name": 1, - "email": 1, - "status": 1, - } - def test_different_collection_names(self): """Test parsing with different collection names""" test_cases = [ diff --git a/tests/test_sql_parser_nested_fields.py b/tests/test_sql_parser_nested_fields.py new file mode 100644 index 0000000..3d223ac --- /dev/null +++ b/tests/test_sql_parser_nested_fields.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +""" +Comprehensive tests for nested field support in PyMongoSQL +""" +import pytest + +from pymongosql.error import SqlSyntaxError +from pymongosql.sql.parser import SQLParser + + +class TestSQLParserNestedFields: + """Test suite for nested field querying functionality""" + + def test_basic_single_level_nesting_select(self): + """Test basic single-level nested fields in SELECT""" + sql = "SELECT c.a, c.b FROM collection" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan 
= parser.get_execution_plan() + assert execution_plan.collection == "collection" + assert execution_plan.projection_stage == {"c.a": 1, "c.b": 1} + assert execution_plan.filter_stage == {} + + def test_basic_single_level_nesting_where(self): + """Test basic single-level nested fields in WHERE clause""" + sql = "SELECT * FROM users WHERE profile.status = 'active'" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "users" + assert execution_plan.filter_stage == {"profile.status": "active"} + + def test_multi_level_nesting_non_reserved_words(self): + """Test multi-level nested fields with non-reserved words""" + sql = "SELECT account.profile.name FROM users WHERE account.settings.theme = 'dark'" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "users" + assert execution_plan.projection_stage == {"account.profile.name": 1} + assert execution_plan.filter_stage == {"account.settings.theme": "dark"} + + def test_array_bracket_notation_select(self): + """Test array access using bracket notation in SELECT""" + sql = "SELECT items[0], items[1].name FROM orders" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "orders" + assert execution_plan.projection_stage == {"items.0": 1, "items.1.name": 1} + + def test_array_bracket_notation_where(self): + """Test array access using bracket notation in WHERE""" + sql = "SELECT * FROM orders WHERE items[0].price > 100" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "orders" + assert execution_plan.filter_stage == {"items.0.price": {"$gt": 100}} + + def test_quoted_reserved_words(self): + """Test using quoted reserved words as field names - currently limited support""" + # Note: This test documents current limitations with quoted identifiers in complex paths + sql = 'SELECT "user" FROM collection' # Simplified test that works + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "collection" + assert execution_plan.projection_stage == {'"user"': 1} + + def test_complex_nested_query(self): + """Test complex query with multiple nested field types""" + sql = """ + SELECT + customer.profile.name, + orders[0].total, + settings.preferences.theme + FROM transactions + WHERE customer.profile.age > 18 + AND orders[0].status = 'completed' + AND settings.notifications = true + """ + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "transactions" + + expected_projection = {"customer.profile.name": 1, "orders.0.total": 1, "settings.preferences.theme": 1} + assert execution_plan.projection_stage == expected_projection + + # The filter should be a combination of conditions + expected_filter = { + "$and": [ + {"customer.profile.age": {"$gt": 18}}, + {"orders.0.status": "completed"}, + {"settings.notifications": True}, + ] + } + assert execution_plan.filter_stage == expected_filter + + def 
test_reserved_word_user_fails(self): + """Test that unquoted 'user' keyword fails""" + sql = "SELECT user.profile.name FROM users" + + with pytest.raises(SqlSyntaxError) as exc_info: + parser = SQLParser(sql) + parser.get_execution_plan() + + assert "no viable alternative" in str(exc_info.value) + + def test_reserved_word_value_fails(self): + """Test that unquoted 'value' keyword fails""" + sql = "SELECT data.value FROM items" + + with pytest.raises(SqlSyntaxError) as exc_info: + parser = SQLParser(sql) + parser.get_execution_plan() + + assert "no viable alternative" in str(exc_info.value) + + def test_numeric_dot_notation_fails(self): + """Test that numeric dot notation fails""" + sql = "SELECT c.0.name FROM collection" + + with pytest.raises(SqlSyntaxError) as exc_info: + parser = SQLParser(sql) + parser.get_execution_plan() + + assert "mismatched input" in str(exc_info.value) + + def test_nested_with_comparison_operators(self): + """Test nested fields with various comparison operators""" + # Test supported comparison operators with non-reserved field names + test_cases = [ + ("profile.age > 18", {"profile.age": {"$gt": 18}}), + ("settings.total < 100", {"settings.total": {"$lt": 100}}), # Changed from 'count' (reserved) + ("status.active = true", {"status.active": True}), + ("config.name != 'default'", {"config.name": {"$ne": "default"}}), + ] + + for where_clause, expected_filter in test_cases: + sql = f"SELECT * FROM collection WHERE {where_clause}" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors for '{where_clause}': {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.filter_stage == expected_filter + + def test_nested_with_logical_operators(self): + """Test nested fields with logical operators""" + sql = """ + SELECT * FROM users + WHERE profile.age > 18 + AND settings.active = true + OR profile.vip = true + """ + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + # The exact structure depends on operator precedence handling + assert "profile.age" in str(execution_plan.filter_stage) + assert "settings.active" in str(execution_plan.filter_stage) + assert "profile.vip" in str(execution_plan.filter_stage) + + def test_nested_with_aliases(self): + """Test nested fields with column aliases""" + sql = "SELECT profile.name AS fullname, settings.theme AS ui_theme FROM users" + parser = SQLParser(sql) + + assert not parser.has_errors, f"Parser errors: {parser.errors}" + + execution_plan = parser.get_execution_plan() + assert execution_plan.collection == "users" + # Note: Current implementation uses original field names in projection + # Aliases are handled at the result processing level + assert execution_plan.projection_stage == {"profile.name": 1, "settings.theme": 1} diff --git a/tests/test_superset_connection.py b/tests/test_superset_connection.py new file mode 100644 index 0000000..fecd009 --- /dev/null +++ b/tests/test_superset_connection.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +""" +Tests for superset subquery connection mode. + +Tests the mongodb+superset:// connection pattern and verifies that: +1. Superset mode is correctly detected from connection strings +2. SubqueryExecution strategy is registered and used +3. Subqueries are supported in superset mode +4. 
Subqueries are rejected in core mode
+"""
+
+
+from pymongosql.executor import ExecutionContext, ExecutionPlanFactory
+from pymongosql.helper import ConnectionHelper
+from pymongosql.superset_mongodb.executor import SupersetExecution
+
+
+class TestSupersetConnectionString:
+    """Test parsing of superset connection strings"""
+
+    def test_parse_superset_mode(self):
+        """Test parsing mongodb+superset:// connection string"""
+        mode, normalized = ConnectionHelper.parse_connection_string("mongodb+superset://localhost:27017/testdb")
+        assert mode == "superset"
+        assert normalized == "mongodb://localhost:27017/testdb"
+
+    def test_parse_core_mode(self):
+        """Test parsing standard mongodb:// connection string"""
+        mode, normalized = ConnectionHelper.parse_connection_string("mongodb://localhost:27017/testdb")
+        assert mode == "standard"
+        assert normalized == "mongodb://localhost:27017/testdb"
+
+    def test_parse_with_credentials(self):
+        """Test parsing connection string with username and password"""
+        mode, normalized = ConnectionHelper.parse_connection_string(
+            "mongodb+superset://user:pass@localhost:27017/testdb"
+        )
+        assert mode == "superset"
+        assert "user:pass@localhost" in normalized
+
+    def test_parse_with_query_params(self):
+        """Test parsing connection string with query parameters"""
+        mode, normalized = ConnectionHelper.parse_connection_string(
+            "mongodb+superset://localhost:27017/testdb?retryWrites=true&w=majority"
+        )
+        assert mode == "superset"
+        assert "retryWrites=true" in normalized
+        assert "w=majority" in normalized
+
+    def test_parse_none_connection_string(self):
+        """Test parsing None connection string returns defaults"""
+        mode, normalized = ConnectionHelper.parse_connection_string(None)
+        assert mode == "standard"
+        assert normalized is None
+
+    def test_parse_empty_connection_string(self):
+        """Test parsing empty connection string returns defaults"""
+        mode, normalized = ConnectionHelper.parse_connection_string("")
+        assert mode == "standard"
+        assert normalized is None
+
+
+class TestSupersetExecutionStrategy:
+    """Test SupersetExecution strategy registration"""
+
+    def test_subquery_execution_registered(self):
+        """Test that SupersetExecution strategy is registered"""
+        strategies = ExecutionPlanFactory._strategies
+        strategy_names = [s.__class__.__name__ for s in strategies]
+        assert "SupersetExecution" in strategy_names
+
+    def test_subquery_execution_supports_subqueries(self):
+        """Test that SupersetExecution supports subquery contexts"""
+        subquery_sql = "SELECT * FROM (SELECT id, name FROM users) AS u WHERE u.id > 10"
+        context = ExecutionContext(subquery_sql, "superset")
+
+        superset_strategy = SupersetExecution()
+        assert superset_strategy.supports(context) is True
+
+    def test_standard_execution_rejects_subqueries(self):
+        """Test that StandardExecution doesn't support subqueries"""
+        from pymongosql.executor import StandardExecution
+
+        subquery_sql = "SELECT * FROM (SELECT id, name FROM users) AS u WHERE u.id > 10"
+        context = ExecutionContext(subquery_sql, "superset")
+
+        standard_strategy = StandardExecution()
+        assert standard_strategy.supports(context) is False
+
+    def test_get_strategy_selects_subquery_execution(self):
+        """Test that get_strategy returns SupersetExecution for subquery context"""
+        subquery_sql = "SELECT * FROM (SELECT id, name FROM users) AS u WHERE u.id > 10"
+        context = ExecutionContext(subquery_sql, "superset")
+
+        strategy = ExecutionPlanFactory.get_strategy(context)
+        assert isinstance(strategy, SupersetExecution)
+
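The registry these tests poke at is the same one the package hooks into at import time; a hypothetical custom backend would plug in the same way. A sketch only — `CustomExecution` is invented here, and the `mode` attribute on `ExecutionContext` is assumed from its constructor signature:

```python
# Hypothetical sketch of registering an extra execution strategy.
# ExecutionContext, ExecutionPlanFactory, and StandardExecution are the
# real classes; CustomExecution and the context.mode attribute are assumed.
from pymongosql.executor import ExecutionContext, ExecutionPlanFactory, StandardExecution


class CustomExecution(StandardExecution):
    def supports(self, context: ExecutionContext) -> bool:
        # Claim queries issued in a hypothetical "custom" connection mode
        return getattr(context, "mode", None) == "custom"


ExecutionPlanFactory.register_strategy(CustomExecution())
strategy = ExecutionPlanFactory.get_strategy(ExecutionContext("SELECT a FROM t", "custom"))
assert isinstance(strategy, CustomExecution)
```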
+    def test_get_strategy_selects_standard_execution(self):
+        """Test that get_strategy returns StandardExecution for simple queries"""
+        from pymongosql.executor import StandardExecution
+
+        simple_sql = "SELECT id, name FROM users WHERE id > 10"
+        context = ExecutionContext(simple_sql)
+
+        strategy = ExecutionPlanFactory.get_strategy(context)
+        assert isinstance(strategy, StandardExecution)
+
+
+class TestConnectionModeDetection:
+    """Test connection mode detection in Connection class"""
+
+    def test_superset_mode_detection(self):
+        """Test that superset mode is correctly detected"""
+        mode, _ = ConnectionHelper.parse_connection_string("mongodb+superset://localhost:27017/testdb")
+        assert mode == "superset"
+
+    def test_core_mode_detection(self):
+        """Test that core mode is correctly detected"""
+        mode, _ = ConnectionHelper.parse_connection_string("mongodb://localhost:27017/testdb")
+        assert mode == "standard"
+
+
+class TestSubqueryExecutionIntegration:
+    """Integration tests for subquery execution with real MongoDB data"""
+
+    def test_core_connection_with_subqueries(self, conn):
+        """Test how a core connection handles a subquery"""
+        assert conn.mode == "standard"
+
+        cursor = conn.cursor()
+        subquery_sql = "SELECT * FROM (SELECT _id, name FROM users) AS u WHERE u.age > 25"
+
+        cursor.execute(subquery_sql)
+        rows = cursor.fetchall()
+        # Subqueries are not supported in core mode, so the query yields no rows
+        assert len(rows) == 0
+
+    def test_core_connection_with_standard_queries(self, conn):
+        """Test simple query on users collection"""
+        cursor = conn.cursor()
+        cursor.execute("SELECT _id, name, age FROM users WHERE age > 25")
+
+        rows = cursor.fetchall()
+        assert len(rows) > 0
+
+        # Verify column names
+        description = cursor.description
+        col_names = [desc[0] for desc in description] if description else []
+        assert "_id" in col_names or "id" in col_names
+        assert "name" in col_names
+        assert "age" in col_names
+
+    def test_subquery_simple_wrapping(self, superset_conn):
+        """Test simple subquery wrapping on users"""
+        assert superset_conn.mode == "superset"
+
+        cursor = superset_conn.cursor()
+
+        # Simple subquery: wrap a MongoDB query result
+        subquery_sql = "SELECT * FROM (SELECT _id, name, age FROM users) AS u LIMIT 5"
+
+        cursor.execute(subquery_sql)
+        rows = cursor.fetchall()
+        assert len(rows) == 5
+
+    def test_subquery_with_where_condition(self, superset_conn):
+        """Test subquery with WHERE on wrapper"""
+        cursor = superset_conn.cursor()
+
+        # Subquery: select from users, then filter in wrapper
+        subquery_sql = "SELECT * FROM (SELECT _id, name, age FROM users) AS u WHERE age > 30"
+
+        cursor.execute(subquery_sql)
+        rows = cursor.fetchall()
+        # Should have results where age > 30
+        assert len(rows) == 11
+
+    def test_subquery_products_by_price_range(self, superset_conn):
+        """Test subquery filtering products by price range"""
+        cursor = superset_conn.cursor()
+
+        # Subquery: get products, filter by price range in wrapper
+        subquery_sql = """
+            SELECT * FROM (SELECT _id, name, price, category FROM products WHERE price > 100)
+            AS p WHERE price < 2000 LIMIT 10
+        """
+
+        cursor.execute(subquery_sql)
+        rows = cursor.fetchall()
+        assert len(rows) == 10
+
+    def test_subquery_orders_aggregation(self, superset_conn):
+        """Test subquery on orders with multiple conditions"""
+        cursor = superset_conn.cursor()
+
+        # Subquery: get orders, then filter for high-value completed orders
+        subquery_sql = """
+            SELECT * FROM (SELECT _id, user_id, 
total_amount, status FROM orders) + AS o WHERE status = 'completed' LIMIT 18 + """ + + cursor.execute(subquery_sql) + rows = cursor.fetchall() + assert len(rows) == 18 + + def test_multiple_queries_in_session(self, superset_conn): + """Test running multiple queries in single superset session""" + cursor = superset_conn.cursor() + + # Query 1: Users + cursor.execute("SELECT _id, name, age FROM users LIMIT 3") + users = cursor.fetchall() + assert len(users) == 3 + + # Query 2: Orders + cursor.execute("SELECT _id, status FROM orders LIMIT 3") + orders = cursor.fetchall() + assert len(orders) == 3 + + # Query 3: Products + cursor.execute("SELECT _id, name, price FROM products LIMIT 3") + products = cursor.fetchall() + assert len(products) == 3
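Taken together, the superset tests above reduce to the following usage pattern — a sketch, with host, credentials, and row contents as placeholders:

```python
# Sketch of the mongodb+superset:// flow exercised by the tests above.
# The host and database are placeholders.
import pymongosql

conn = pymongosql.connect(
    host="mongodb+superset://localhost:27017/test_db",
    database="test_db",
)
assert conn.mode == "superset"

cursor = conn.cursor()
# The inner SELECT runs against MongoDB; the outer WHERE/LIMIT are then
# applied by the query database (in-memory SQLite3 by default).
cursor.execute("SELECT * FROM (SELECT _id, name, age FROM users) AS u WHERE age > 30 LIMIT 5")
rows = cursor.fetchall()
```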