Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sidemantic/adapters/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _parse_view(self, file_path: Path) -> Model | None:
dimensions = []
primary_key = "id" # default

for dim_name, dim_def in view.get("dimensions", {}).items():
for dim_name, dim_def in (view.get("dimensions") or {}).items():
if dim_def is None:
dim_def = {}

Expand All @@ -124,7 +124,7 @@ def _parse_view(self, file_path: Path) -> Model | None:

# Parse measures
metrics = []
for measure_name, measure_def in view.get("measures", {}).items():
for measure_name, measure_def in (view.get("measures") or {}).items():
if measure_def is None:
measure_def = {}

Expand Down
7 changes: 7 additions & 0 deletions sidemantic/adapters/rill.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,16 @@ def parse(self, source: str | Path) -> SemanticGraph:

Returns:
SemanticGraph containing the parsed models

Raises:
FileNotFoundError: If the source path does not exist
"""
source_path = Path(source)

# Check if path exists first - fail loudly on configuration errors
if not source_path.exists():
raise FileNotFoundError(f"Path does not exist: {source_path}")

graph = SemanticGraph()
if source_path.is_file():
model = self._parse_file(source_path)
Expand Down
21 changes: 20 additions & 1 deletion sidemantic/core/table_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
similar to LookML table calculations or Excel formulas.
"""

import ast
import re
from typing import Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator


class TableCalculation(BaseModel):
Expand Down Expand Up @@ -51,5 +53,22 @@ class TableCalculation(BaseModel):
None, description="Window size for moving average (e.g., 7 for 7-day moving average)"
)

# Percentile value (0-1)
percentile: float | None = Field(
None, description="Percentile value between 0 and 1 (e.g., 0.5 for median, 0.95 for p95)"
)

@model_validator(mode="after")
def validate_formula_expression(self) -> "TableCalculation":
"""Validate formula expression syntax at creation time."""
if self.type == "formula" and self.expression is not None:
# Replace field references with placeholder numbers for syntax validation
test_expr = re.sub(r"\$\{[^}]+\}", "1", self.expression)
try:
ast.parse(test_expr, mode="eval")
except SyntaxError as e:
raise ValueError(f"Invalid formula expression syntax: {self.expression!r}") from e
return self

def __hash__(self) -> int:
return hash(self.name)
21 changes: 19 additions & 2 deletions sidemantic/core/time_intelligence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

TimeComparisonType = Literal[
"yoy", # Year over year
Expand Down Expand Up @@ -37,14 +37,31 @@ class TimeComparison(BaseModel):
"percent_change", description="How to calculate the comparison"
)

@field_validator("offset")
@classmethod
def validate_offset_not_zero(cls, v: int | None) -> int | None:
"""Validate that offset is not zero.

Zero offset would mean comparing a period to itself, which doesn't
make practical sense for time comparisons. Users should explicitly
get an error rather than having their input silently changed.
"""
if v == 0:
raise ValueError(
"offset cannot be 0. Time comparisons require a non-zero offset "
"to compare against a different time period. Use offset >= 1 for "
"past comparisons or offset <= -1 for future comparisons."
)
return v

@property
def offset_interval(self) -> tuple[int, str]:
"""Get the offset interval for this comparison.

Returns:
(amount, unit) tuple for SQL INTERVAL
"""
if self.offset and self.offset_unit:
if self.offset is not None and self.offset_unit is not None:
return (self.offset, self.offset_unit)

# Default offsets for standard comparisons
Expand Down
36 changes: 36 additions & 0 deletions sidemantic/db/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
"""Base database adapter interface."""

import re
from abc import ABC, abstractmethod
from typing import Any

# Pattern for valid SQL identifiers: starts with letter or underscore,
# followed by letters, digits, or underscores. Also allows dots for
# qualified names (schema.table).
_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$")


def validate_identifier(value: str, name: str = "identifier") -> str:
"""Validate that a value is a safe SQL identifier.

Prevents SQL injection by ensuring identifiers only contain safe characters.
Allows: letters, digits, underscores, and dots (for qualified names).
Must start with a letter or underscore.

Args:
value: The identifier value to validate
name: Human-readable name for error messages (e.g., "table name", "schema")

Returns:
The validated identifier (unchanged if valid)

Raises:
ValueError: If the identifier contains invalid characters
"""
if not value:
raise ValueError(f"Invalid {name}: cannot be empty")

if not _IDENTIFIER_PATTERN.match(value):
raise ValueError(
f"Invalid {name}: '{value}'. "
f"Identifiers must start with a letter or underscore and contain only "
f"letters, digits, underscores, and dots."
)

return value


class BaseDatabaseAdapter(ABC):
"""Abstract base class for database adapters.
Expand Down
6 changes: 5 additions & 1 deletion sidemantic/db/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from sidemantic.db.base import BaseDatabaseAdapter
from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier


class BigQueryResult:
Expand Down Expand Up @@ -142,9 +142,13 @@ def get_tables(self) -> list[dict]:

def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
"""Get column information for a table."""
# Validate identifiers for consistency and defense in depth
# (BigQuery API handles these, but validation catches bad input early)
validate_identifier(table_name, "table name")
schema = schema or self.dataset_id
if not schema:
raise ValueError("schema (dataset_id) required for get_columns")
validate_identifier(schema, "schema")

table_ref = self.client.dataset(schema).table(table_name)
table = self.client.get_table(table_ref)
Expand Down
11 changes: 10 additions & 1 deletion sidemantic/db/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse

from sidemantic.db.base import BaseDatabaseAdapter
from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier


class DatabricksResult:
Expand Down Expand Up @@ -133,8 +133,12 @@ def fetch_record_batch(self, result: DatabricksResult) -> Any:
def get_tables(self) -> list[dict]:
"""List all tables in the catalog/schema."""
if self.schema:
# Validate schema to prevent SQL injection
validate_identifier(self.schema, "schema")
sql = f"SHOW TABLES IN {self.schema}"
elif self.catalog:
# Validate catalog to prevent SQL injection
validate_identifier(self.catalog, "catalog")
sql = f"SHOW TABLES IN {self.catalog}"
else:
sql = "SHOW TABLES"
Expand All @@ -145,7 +149,12 @@ def get_tables(self) -> list[dict]:

def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
"""Get column information for a table."""
# Validate identifiers to prevent SQL injection
validate_identifier(table_name, "table name")
schema = schema or self.schema
if schema:
validate_identifier(schema, "schema")

table_ref = f"{schema}.{table_name}" if schema else table_name

sql = f"DESCRIBE {table_ref}"
Expand Down
9 changes: 7 additions & 2 deletions sidemantic/db/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import duckdb

from sidemantic.db.base import BaseDatabaseAdapter
from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier


class DuckDBAdapter(BaseDatabaseAdapter):
Expand Down Expand Up @@ -51,7 +51,12 @@ def get_tables(self) -> list[dict]:

def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
"""Get columns for a table."""
schema_filter = f"AND table_schema = '{schema}'" if schema else ""
# Validate identifiers to prevent SQL injection
validate_identifier(table_name, "table name")
if schema:
validate_identifier(schema, "schema")

schema_filter = f"AND schema_name = '{schema}'" if schema else ""
result = self.conn.execute(
f"""
SELECT column_name, data_type
Expand Down
7 changes: 6 additions & 1 deletion sidemantic/db/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any
from urllib.parse import parse_qs, urlparse

from sidemantic.db.base import BaseDatabaseAdapter
from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier


class PostgresResult:
Expand Down Expand Up @@ -131,6 +131,11 @@ def get_tables(self) -> list[dict]:

def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
"""Get columns for a table."""
# Validate identifiers to prevent SQL injection
validate_identifier(table_name, "table name")
if schema:
validate_identifier(schema, "schema")

schema_filter = f"AND table_schema = '{schema}'" if schema else ""
result = self.execute(
f"""
Expand Down
9 changes: 8 additions & 1 deletion sidemantic/db/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse

from sidemantic.db.base import BaseDatabaseAdapter
from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier


class SnowflakeResult:
Expand Down Expand Up @@ -140,6 +140,8 @@ def fetch_record_batch(self, result: SnowflakeResult) -> Any:
def get_tables(self) -> list[dict]:
"""List all tables in the database/schema."""
if self.schema:
# Validate schema to prevent SQL injection
validate_identifier(self.schema, "schema")
sql = f"""
SELECT table_name, table_schema as schema
FROM information_schema.tables
Expand All @@ -165,7 +167,12 @@ def get_tables(self) -> list[dict]:

def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
"""Get column information for a table."""
# Validate identifiers to prevent SQL injection
validate_identifier(table_name, "table name")
schema = schema or self.schema
if schema:
validate_identifier(schema, "schema")

schema_filter = f"AND table_schema = '{schema}'" if schema else ""

sql = f"""
Expand Down
9 changes: 8 additions & 1 deletion sidemantic/db/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse

from sidemantic.db.base import BaseDatabaseAdapter
from sidemantic.db.base import BaseDatabaseAdapter, validate_identifier


class SparkResult:
Expand Down Expand Up @@ -126,14 +126,21 @@ def fetch_record_batch(self, result: SparkResult) -> Any:

def get_tables(self) -> list[dict]:
"""List all tables in the database."""
# Validate database to prevent SQL injection
validate_identifier(self.database, "database")
sql = f"SHOW TABLES IN {self.database}"
result = self.execute(sql)
rows = result.fetchall()
return [{"table_name": row[1], "schema": row[0]} for row in rows]

def get_columns(self, table_name: str, schema: str | None = None) -> list[dict]:
"""Get column information for a table."""
# Validate identifiers to prevent SQL injection
validate_identifier(table_name, "table name")
schema = schema or self.database
if schema:
validate_identifier(schema, "schema")

table_ref = f"{schema}.{table_name}" if schema else table_name

sql = f"DESCRIBE {table_ref}"
Expand Down
Loading