diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index 51a8ff43c..e5bdcc936 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -56,7 +56,7 @@ def format_schema(table_name: str, columns: list[ColumnMeta]) -> str: for col in columns: lines.append(f"- {col.name} ({col.sql_type})") - if col.kind in ("numeric", "date"): + if col.kind in ("numeric", "date") and col.min_val is not None and col.max_val is not None: lines.append(f" Range: {col.min_val} to {col.max_val}") elif col.categories: cats = ", ".join(f"'{v}'" for v in col.categories) @@ -453,9 +453,9 @@ def __init__( if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") - # Store original column names for validation - columns_info = inspector.get_columns(table_name) - self._colnames = [col["name"] for col in columns_info] + # Store column info for schema generation + self._columns_info = inspector.get_columns(table_name) + self._colnames = [col["name"] for col in self._columns_info] def get_db_type(self) -> str: """ @@ -466,140 +466,138 @@ def get_db_type(self) -> str: """ return self._engine.dialect.name.upper().replace(" SQL", "") - def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 + def get_schema(self, *, categorical_threshold: int) -> str: """ Generate schema information from database table. - Returns: + Parameters + ---------- + categorical_threshold + Maximum number of unique values for a text column to be considered + categorical + + Returns + ------- + : String describing the schema """ - inspector = inspect(self._engine) - columns = inspector.get_columns(self.table_name) + columns = [ + self._make_column_meta(col["name"], col["type"]) + for col in self._columns_info + ] + self._add_column_stats(columns, categorical_threshold) + return format_schema(self.table_name, columns) - schema = [f"Table: {self.table_name}", "Columns:"] + @staticmethod + def _make_column_meta(name: str, sa_type: sqltypes.TypeEngine) -> ColumnMeta: + """Create ColumnMeta from SQLAlchemy type.""" + kind: Literal["numeric", "text", "date", "other"] - # Build a single query to get all column statistics - select_parts = [] - numeric_columns = [] - text_columns = [] + if isinstance(sa_type, (sqltypes.Integer, sqltypes.BigInteger, sqltypes.SmallInteger)): + kind = "numeric" + sql_type = "INTEGER" + elif isinstance(sa_type, sqltypes.Float): + kind = "numeric" + sql_type = "FLOAT" + elif isinstance(sa_type, sqltypes.Numeric): + kind = "numeric" + sql_type = "NUMERIC" + elif isinstance(sa_type, (sqltypes.String, sqltypes.Text, sqltypes.Enum)): + kind = "text" + sql_type = "TEXT" + elif isinstance(sa_type, sqltypes.Date): + kind = "date" + sql_type = "DATE" + elif isinstance(sa_type, sqltypes.DateTime): + kind = "date" + sql_type = "TIMESTAMP" + elif isinstance(sa_type, sqltypes.Time): + kind = "date" + sql_type = "TIME" + elif isinstance(sa_type, sqltypes.Boolean): + kind = "other" + sql_type = "BOOLEAN" + else: + kind = "other" + sql_type = sa_type.__class__.__name__.upper() + return ColumnMeta(name=name, sql_type=sql_type, kind=kind) + + def _add_column_stats( + self, + columns: list[ColumnMeta], + categorical_threshold: int, + ) -> None: + """Add min/max/categories to column metadata using SQL queries.""" + # Build aggregate expressions for stats query + select_parts = [] for col in columns: - col_name = col["name"] - - # Check if column is numeric - if isinstance( - col["type"], - ( - sqltypes.Integer, - sqltypes.Numeric, - sqltypes.Float, - sqltypes.Date, - sqltypes.Time, - sqltypes.DateTime, - sqltypes.BigInteger, - sqltypes.SmallInteger, - ), - ): - numeric_columns.append(col_name) - select_parts.extend( - [ - f"MIN({col_name}) as {col_name}__min", - f"MAX({col_name}) as {col_name}__max", - ], - ) + if col.kind in ("numeric", "date"): + select_parts.append(f"MIN({col.name}) as {col.name}__min") + select_parts.append(f"MAX({col.name}) as {col.name}__max") + elif col.kind == "text": + select_parts.append(f"COUNT(DISTINCT {col.name}) as {col.name}__nunique") - # Check if column is text/string - elif isinstance( - col["type"], - (sqltypes.String, sqltypes.Text, sqltypes.Enum), - ): - text_columns.append(col_name) - select_parts.append( - f"COUNT(DISTINCT {col_name}) as {col_name}__distinct_count", - ) + if not select_parts: + return - # Execute single query to get all statistics - column_stats = {} - if select_parts: - try: - stats_query = text( - f"SELECT {', '.join(select_parts)} FROM {self.table_name}", - ) - with self._get_connection() as conn: - result = conn.execute(stats_query).fetchone() - if result: - # Convert result to dict for easier access - column_stats = dict(zip(result._fields, result, strict=False)) - except Exception: # noqa: S110 - pass # Fall back to no statistics if query fails - - # Get categorical values for text columns that are below threshold - categorical_values = {} - text_cols_to_query = [] - for col_name in text_columns: - distinct_count_key = f"{col_name}__distinct_count" - if ( - distinct_count_key in column_stats - and column_stats[distinct_count_key] - and column_stats[distinct_count_key] <= categorical_threshold - ): - text_cols_to_query.append(col_name) - - # Get categorical values in a single query if needed - if text_cols_to_query: - try: - # Build UNION query for all categorical columns - union_parts = [ - f"SELECT '{col_name}' as column_name, {col_name} as value " - f"FROM {self.table_name} WHERE {col_name} IS NOT NULL " - f"GROUP BY {col_name}" - for col_name in text_cols_to_query - ] - - if union_parts: - categorical_query = text(" UNION ALL ".join(union_parts)) - with self._get_connection() as conn: - results = conn.execute(categorical_query).fetchall() - for row in results: - col_name, value = row - if col_name not in categorical_values: - categorical_values[col_name] = [] - categorical_values[col_name].append(str(value)) - except Exception: # noqa: S110 - pass # Skip categorical values if query fails - - # Build schema description using collected statistics + # Execute stats query + stats = {} + try: + stats_query = text(f"SELECT {', '.join(select_parts)} FROM {self.table_name}") + with self._get_connection() as conn: + result = conn.execute(stats_query).fetchone() + if result: + stats = dict(zip(result._fields, result, strict=False)) + except Exception: + return # Fall back to no statistics if query fails + + # Populate min/max for numeric/date columns for col in columns: - col_name = col["name"] - sql_type = self._get_sql_type_name(col["type"]) - column_info = [f"- {col_name} ({sql_type})"] - - # Add range info for numeric columns - if col_name in numeric_columns: - min_key = f"{col_name}__min" - max_key = f"{col_name}__max" - if ( - min_key in column_stats - and max_key in column_stats - and column_stats[min_key] is not None - and column_stats[max_key] is not None - ): - column_info.append( - f" Range: {column_stats[min_key]} to {column_stats[max_key]}", - ) + if col.kind in ("numeric", "date"): + col.min_val = stats.get(f"{col.name}__min") + col.max_val = stats.get(f"{col.name}__max") - # Add categorical values for text columns - elif col_name in categorical_values: - values = categorical_values[col_name] - # Remove duplicates and sort - unique_values = sorted(set(values)) - values_str = ", ".join([f"'{v}'" for v in unique_values]) - column_info.append(f" Categorical values: {values_str}") + # Find text columns that qualify as categorical + categorical_cols = [ + col for col in columns + if col.kind == "text" + and (nunique := stats.get(f"{col.name}__nunique")) + and nunique <= categorical_threshold + ] - schema.extend(column_info) + if not categorical_cols: + return - return "\n".join(schema) + # Fetch categorical values in a single UNION query + self._fetch_categorical_values(categorical_cols) + + def _fetch_categorical_values(self, columns: list[ColumnMeta]) -> None: + """Fetch unique values for categorical columns.""" + union_parts = [ + f"SELECT '{col.name}' as col_name, {col.name} as value " + f"FROM {self.table_name} WHERE {col.name} IS NOT NULL " + f"GROUP BY {col.name}" + for col in columns + ] + + try: + query = text(" UNION ALL ".join(union_parts)) + with self._get_connection() as conn: + results = conn.execute(query).fetchall() + + # Group values by column + values_by_col: dict[str, list[str]] = {} + for col_name, value in results: + values_by_col.setdefault(col_name, []).append(str(value)) + + # Assign to columns + for col in columns: + if col.name in values_by_col: + col.categories = sorted(set(values_by_col[col.name])) + except Exception: # noqa: S110 + pass # Skip categorical values if query fails def execute_query(self, query: str) -> nw.DataFrame: """ @@ -691,27 +689,6 @@ def get_data(self) -> nw.DataFrame: """ return self.execute_query(f"SELECT * FROM {self.table_name}") - def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911 - """Convert SQLAlchemy type to SQL type name.""" - if isinstance(type_, sqltypes.Integer): - return "INTEGER" - elif isinstance(type_, sqltypes.Float): - return "FLOAT" - elif isinstance(type_, sqltypes.Numeric): - return "NUMERIC" - elif isinstance(type_, sqltypes.Boolean): - return "BOOLEAN" - elif isinstance(type_, sqltypes.DateTime): - return "TIMESTAMP" - elif isinstance(type_, sqltypes.Date): - return "DATE" - elif isinstance(type_, sqltypes.Time): - return "TIME" - elif isinstance(type_, (sqltypes.String, sqltypes.Text)): - return "TEXT" - else: - return type_.__class__.__name__.upper() - def _get_connection(self) -> Connection: """Get a connection to use for queries.""" return self._engine.connect()