Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 120 additions & 143 deletions pkg-py/src/querychat/_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def format_schema(table_name: str, columns: list[ColumnMeta]) -> str:
for col in columns:
lines.append(f"- {col.name} ({col.sql_type})")

if col.kind in ("numeric", "date"):
if col.kind in ("numeric", "date") and col.min_val is not None and col.max_val is not None:
lines.append(f" Range: {col.min_val} to {col.max_val}")
elif col.categories:
cats = ", ".join(f"'{v}'" for v in col.categories)
Expand Down Expand Up @@ -453,9 +453,9 @@ def __init__(
if not inspector.has_table(table_name):
raise ValueError(f"Table '{table_name}' not found in database")

# Store original column names for validation
columns_info = inspector.get_columns(table_name)
self._colnames = [col["name"] for col in columns_info]
# Store column info for schema generation
self._columns_info = inspector.get_columns(table_name)
self._colnames = [col["name"] for col in self._columns_info]
Comment thread
cpsievert marked this conversation as resolved.

def get_db_type(self) -> str:
"""
Expand All @@ -466,140 +466,138 @@ def get_db_type(self) -> str:
"""
return self._engine.dialect.name.upper().replace(" SQL", "")

def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912
def get_schema(self, *, categorical_threshold: int) -> str:
"""
Generate schema information from database table.

Returns:
Parameters
----------
categorical_threshold
Maximum number of unique values for a text column to be considered
categorical

Returns
-------
:
String describing the schema

"""
inspector = inspect(self._engine)
columns = inspector.get_columns(self.table_name)
columns = [
self._make_column_meta(col["name"], col["type"])
for col in self._columns_info
]
self._add_column_stats(columns, categorical_threshold)
return format_schema(self.table_name, columns)

schema = [f"Table: {self.table_name}", "Columns:"]
@staticmethod
def _make_column_meta(name: str, sa_type: sqltypes.TypeEngine) -> ColumnMeta:
"""Create ColumnMeta from SQLAlchemy type."""
kind: Literal["numeric", "text", "date", "other"]

# Build a single query to get all column statistics
select_parts = []
numeric_columns = []
text_columns = []
if isinstance(sa_type, (sqltypes.Integer, sqltypes.BigInteger, sqltypes.SmallInteger)):
kind = "numeric"
sql_type = "INTEGER"
elif isinstance(sa_type, sqltypes.Float):
kind = "numeric"
sql_type = "FLOAT"
elif isinstance(sa_type, sqltypes.Numeric):
kind = "numeric"
sql_type = "NUMERIC"
elif isinstance(sa_type, (sqltypes.String, sqltypes.Text, sqltypes.Enum)):
kind = "text"
sql_type = "TEXT"
elif isinstance(sa_type, sqltypes.Date):
kind = "date"
sql_type = "DATE"
elif isinstance(sa_type, sqltypes.DateTime):
kind = "date"
sql_type = "TIMESTAMP"
elif isinstance(sa_type, sqltypes.Time):
kind = "date"
sql_type = "TIME"
elif isinstance(sa_type, sqltypes.Boolean):
kind = "other"
sql_type = "BOOLEAN"
else:
kind = "other"
sql_type = sa_type.__class__.__name__.upper()

return ColumnMeta(name=name, sql_type=sql_type, kind=kind)

def _add_column_stats(
self,
columns: list[ColumnMeta],
categorical_threshold: int,
) -> None:
"""Add min/max/categories to column metadata using SQL queries."""
# Build aggregate expressions for stats query
select_parts = []
for col in columns:
col_name = col["name"]

# Check if column is numeric
if isinstance(
col["type"],
(
sqltypes.Integer,
sqltypes.Numeric,
sqltypes.Float,
sqltypes.Date,
sqltypes.Time,
sqltypes.DateTime,
sqltypes.BigInteger,
sqltypes.SmallInteger,
),
):
numeric_columns.append(col_name)
select_parts.extend(
[
f"MIN({col_name}) as {col_name}__min",
f"MAX({col_name}) as {col_name}__max",
],
)
if col.kind in ("numeric", "date"):
select_parts.append(f"MIN({col.name}) as {col.name}__min")
select_parts.append(f"MAX({col.name}) as {col.name}__max")
elif col.kind == "text":
select_parts.append(f"COUNT(DISTINCT {col.name}) as {col.name}__nunique")

# Check if column is text/string
elif isinstance(
col["type"],
(sqltypes.String, sqltypes.Text, sqltypes.Enum),
):
text_columns.append(col_name)
select_parts.append(
f"COUNT(DISTINCT {col_name}) as {col_name}__distinct_count",
)
if not select_parts:
return

# Execute single query to get all statistics
column_stats = {}
if select_parts:
try:
stats_query = text(
f"SELECT {', '.join(select_parts)} FROM {self.table_name}",
)
with self._get_connection() as conn:
result = conn.execute(stats_query).fetchone()
if result:
# Convert result to dict for easier access
column_stats = dict(zip(result._fields, result, strict=False))
except Exception: # noqa: S110
pass # Fall back to no statistics if query fails

# Get categorical values for text columns that are below threshold
categorical_values = {}
text_cols_to_query = []
for col_name in text_columns:
distinct_count_key = f"{col_name}__distinct_count"
if (
distinct_count_key in column_stats
and column_stats[distinct_count_key]
and column_stats[distinct_count_key] <= categorical_threshold
):
text_cols_to_query.append(col_name)

# Get categorical values in a single query if needed
if text_cols_to_query:
try:
# Build UNION query for all categorical columns
union_parts = [
f"SELECT '{col_name}' as column_name, {col_name} as value "
f"FROM {self.table_name} WHERE {col_name} IS NOT NULL "
f"GROUP BY {col_name}"
for col_name in text_cols_to_query
]

if union_parts:
categorical_query = text(" UNION ALL ".join(union_parts))
with self._get_connection() as conn:
results = conn.execute(categorical_query).fetchall()
for row in results:
col_name, value = row
if col_name not in categorical_values:
categorical_values[col_name] = []
categorical_values[col_name].append(str(value))
except Exception: # noqa: S110
pass # Skip categorical values if query fails

# Build schema description using collected statistics
# Execute stats query
stats = {}
try:
stats_query = text(f"SELECT {', '.join(select_parts)} FROM {self.table_name}")
with self._get_connection() as conn:
result = conn.execute(stats_query).fetchone()
if result:
stats = dict(zip(result._fields, result, strict=False))
except Exception:
return # Fall back to no statistics if query fails

# Populate min/max for numeric/date columns
for col in columns:
col_name = col["name"]
sql_type = self._get_sql_type_name(col["type"])
column_info = [f"- {col_name} ({sql_type})"]

# Add range info for numeric columns
if col_name in numeric_columns:
min_key = f"{col_name}__min"
max_key = f"{col_name}__max"
if (
min_key in column_stats
and max_key in column_stats
and column_stats[min_key] is not None
and column_stats[max_key] is not None
):
column_info.append(
f" Range: {column_stats[min_key]} to {column_stats[max_key]}",
)
if col.kind in ("numeric", "date"):
col.min_val = stats.get(f"{col.name}__min")
col.max_val = stats.get(f"{col.name}__max")

# Add categorical values for text columns
elif col_name in categorical_values:
values = categorical_values[col_name]
# Remove duplicates and sort
unique_values = sorted(set(values))
values_str = ", ".join([f"'{v}'" for v in unique_values])
column_info.append(f" Categorical values: {values_str}")
# Find text columns that qualify as categorical
categorical_cols = [
col for col in columns
if col.kind == "text"
and (nunique := stats.get(f"{col.name}__nunique"))
Comment thread
cpsievert marked this conversation as resolved.
and nunique <= categorical_threshold
]

schema.extend(column_info)
if not categorical_cols:
return

return "\n".join(schema)
# Fetch categorical values in a single UNION query
self._fetch_categorical_values(categorical_cols)

def _fetch_categorical_values(self, columns: list[ColumnMeta]) -> None:
"""Fetch unique values for categorical columns."""
union_parts = [
f"SELECT '{col.name}' as col_name, {col.name} as value "
f"FROM {self.table_name} WHERE {col.name} IS NOT NULL "
f"GROUP BY {col.name}"
for col in columns
]

try:
query = text(" UNION ALL ".join(union_parts))
with self._get_connection() as conn:
results = conn.execute(query).fetchall()

# Group values by column
values_by_col: dict[str, list[str]] = {}
for col_name, value in results:
values_by_col.setdefault(col_name, []).append(str(value))

# Assign to columns
for col in columns:
if col.name in values_by_col:
col.categories = sorted(set(values_by_col[col.name]))
except Exception: # noqa: S110
pass # Skip categorical values if query fails

def execute_query(self, query: str) -> nw.DataFrame:
"""
Expand Down Expand Up @@ -691,27 +689,6 @@ def get_data(self) -> nw.DataFrame:
"""
return self.execute_query(f"SELECT * FROM {self.table_name}")

def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911
"""Convert SQLAlchemy type to SQL type name."""
if isinstance(type_, sqltypes.Integer):
return "INTEGER"
elif isinstance(type_, sqltypes.Float):
return "FLOAT"
elif isinstance(type_, sqltypes.Numeric):
return "NUMERIC"
elif isinstance(type_, sqltypes.Boolean):
return "BOOLEAN"
elif isinstance(type_, sqltypes.DateTime):
return "TIMESTAMP"
elif isinstance(type_, sqltypes.Date):
return "DATE"
elif isinstance(type_, sqltypes.Time):
return "TIME"
elif isinstance(type_, (sqltypes.String, sqltypes.Text)):
return "TEXT"
else:
return type_.__class__.__name__.upper()

def _get_connection(self) -> Connection:
"""Get a connection to use for queries."""
return self._engine.connect()
Expand Down