diff --git a/Cargo.lock b/Cargo.lock index 02ae11a..bc720ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3184,7 +3184,10 @@ name = "tern-codegen" version = "0.1.0" dependencies = [ "genco", + "insta", + "pretty_assertions", "tern-ddl", + "thiserror 2.0.18", ] [[package]] diff --git a/crates/codegen/Cargo.toml b/crates/codegen/Cargo.toml index d77ef6b..2647abb 100644 --- a/crates/codegen/Cargo.toml +++ b/crates/codegen/Cargo.toml @@ -6,3 +6,8 @@ edition = "2024" [dependencies] genco = "0.19.0" tern-ddl = { path = "../ddl" } +thiserror = "2.0" + +[dev-dependencies] +insta = { version = "1.42", features = ["yaml"] } +pretty_assertions = "1.4" diff --git a/crates/codegen/PYTHON_SQLMODEL_DESIGN.md b/crates/codegen/PYTHON_SQLMODEL_DESIGN.md index a001270..99dcd57 100644 --- a/crates/codegen/PYTHON_SQLMODEL_DESIGN.md +++ b/crates/codegen/PYTHON_SQLMODEL_DESIGN.md @@ -1,22 +1,31 @@ # Python SQLModel Code Generator Design +## Implementation Status + +> **Status Legend**: +> - [x] Completed +> - [ ] Future work (not yet implemented) +> - [~] Partially implemented + ## Overview This document outlines the design for implementing a Python code generator that converts PostgreSQL DDL schema definitions (`Vec`) into SQLModel and Pydantic models. The generator will implement the `Codegen` trait and use the `genco` crate for Python code generation. ## Goals -1. Generate idiomatic SQLModel models from PostgreSQL table definitions -2. Properly handle all PostgreSQL types with appropriate Python type mappings -3. Support primary keys, foreign keys, unique constraints, and indexes -4. Generate both table models (with `table=True`) and optional Pydantic-only models for validation -5. Produce well-formatted, readable Python code with correct imports -6. Handle edge cases robustly (reserved words, special characters, etc.) +1. [x] Generate idiomatic SQLModel models from PostgreSQL table definitions +2. [x] Properly handle all PostgreSQL types with appropriate Python type mappings +3. 
[x] Support primary keys, foreign keys, unique constraints, and indexes +4. [ ] Generate both table models (with `table=True`) and optional Pydantic-only models for validation +5. [x] Produce well-formatted, readable Python code with correct imports +6. [x] Handle edge cases robustly (reserved words, special characters, etc.) ## Architecture ### Module Structure +[x] **Completed** - All modules implemented as designed. + ``` crates/codegen/src/python/ ├── mod.rs # Module root, re-exports, PythonCodegen struct @@ -29,30 +38,36 @@ crates/codegen/src/python/ └── tests/ # Test submodule ├── mod.rs # Test utilities and common fixtures ├── unit_tests.rs # Unit tests for individual components + ├── snapshot_tests.rs # Snapshot tests for full generation output └── snapshots/ # Snapshot test files (managed by insta) ``` ### Core Types +[x] **Completed** - All core types implemented in `mod.rs`. + ```rust /// Configuration for Python code generation. #[derive(Debug, Clone)] pub struct PythonCodegenConfig { /// Whether to generate Pydantic-only base classes for each model. /// These are useful for request/response validation without DB coupling. - pub generate_base_models: bool, + pub generate_base_models: bool, // [ ] Future work /// Module name prefix for generated imports (e.g., "app.models"). - pub module_prefix: Option, + pub module_prefix: Option, // [ ] Future work /// Whether to include docstrings from table/column comments. - pub include_docstrings: bool, + pub include_docstrings: bool, // [x] Completed /// How to handle Python reserved words in identifiers. - pub reserved_word_strategy: ReservedWordStrategy, + pub reserved_word_strategy: ReservedWordStrategy, // [x] Completed /// Whether to generate relationship attributes for foreign keys. - pub generate_relationships: bool, + pub generate_relationships: bool, // [ ] Future work + + /// Output mode (single file or multi-file). 
+ pub output_mode: OutputMode, // [x] Completed } #[derive(Debug, Clone, Default)] @@ -78,39 +93,43 @@ impl Codegen for PythonCodegen { ### PostgreSQL to Python Type Mapping -| PostgreSQL Type | Python Type | SQLModel Field | Notes | -|-----------------|-------------|----------------|-------| -| `integer`, `int4` | `int` | `Field()` | | -| `bigint`, `int8` | `int` | `Field()` | Python int handles arbitrary precision | -| `smallint`, `int2` | `int` | `Field()` | | -| `serial`, `serial4` | `int` | `Field(default=None, primary_key=True)` | Auto-increment | -| `bigserial`, `serial8` | `int` | `Field(default=None, primary_key=True)` | | -| `boolean`, `bool` | `bool` | `Field()` | | -| `text` | `str` | `Field()` | | -| `varchar(n)`, `character varying` | `str` | `Field(max_length=n)` | Validate length | -| `char(n)`, `character` | `str` | `Field(min_length=n, max_length=n)` | Fixed length | -| `numeric`, `decimal` | `Decimal` | `Field()` | `from decimal import Decimal` | -| `real`, `float4` | `float` | `Field()` | | -| `double precision`, `float8` | `float` | `Field()` | | -| `date` | `date` | `Field()` | `from datetime import date` | -| `time`, `time without time zone` | `time` | `Field()` | `from datetime import time` | -| `timetz`, `time with time zone` | `time` | `Field()` | | -| `timestamp`, `timestamp without time zone` | `datetime` | `Field()` | `from datetime import datetime` | -| `timestamptz`, `timestamp with time zone` | `datetime` | `Field()` | | -| `interval` | `timedelta` | `Field()` | `from datetime import timedelta` | -| `uuid` | `UUID` | `Field()` | `from uuid import UUID` | -| `json` | `Any` | `Field(sa_type=JSON)` | `from typing import Any` | -| `jsonb` | `Any` | `Field(sa_type=JSON)` | | -| `bytea` | `bytes` | `Field()` | | -| `inet` | `str` | `Field()` | IP address as string | -| `cidr` | `str` | `Field()` | | -| `macaddr` | `str` | `Field()` | | -| `point`, `line`, etc. 
| `str` | `Field()` | Geometric as string | -| `array` (e.g., `integer[]`) | `list[int]` | `Field(sa_type=ARRAY(Integer))` | | -| User-defined enum | Literal union or Python Enum | `Field()` | See enum handling | +[x] **Completed** - All types in this table are implemented in `type_mapping.rs`. + +| PostgreSQL Type | Python Type | SQLModel Field | Status | +|-----------------|-------------|----------------|--------| +| `integer`, `int4` | `int` | `Field()` | [x] | +| `bigint`, `int8` | `int` | `Field()` | [x] | +| `smallint`, `int2` | `int` | `Field()` | [x] | +| `serial`, `serial4` | `int` | `Field(default=None, primary_key=True)` | [x] | +| `bigserial`, `serial8` | `int` | `Field(default=None, primary_key=True)` | [x] | +| `boolean`, `bool` | `bool` | `Field()` | [x] | +| `text` | `str` | `Field()` | [x] | +| `varchar(n)`, `character varying` | `str` | `Field(max_length=n)` | [x] (length extracted but not yet used in Field) | +| `char(n)`, `character` | `str` | `Field(min_length=n, max_length=n)` | [x] (length extracted but not yet used in Field) | +| `numeric`, `decimal` | `Decimal` | `Field()` | [x] | +| `real`, `float4` | `float` | `Field()` | [x] | +| `double precision`, `float8` | `float` | `Field()` | [x] | +| `date` | `date` | `Field()` | [x] | +| `time`, `time without time zone` | `time` | `Field()` | [x] | +| `timetz`, `time with time zone` | `time` | `Field()` | [x] | +| `timestamp`, `timestamp without time zone` | `datetime` | `Field()` | [x] | +| `timestamptz`, `timestamp with time zone` | `datetime` | `Field()` | [x] | +| `interval` | `timedelta` | `Field()` | [x] | +| `uuid` | `UUID` | `Field()` | [x] | +| `json` | `dict[str, Any]` | `Field(sa_type=JSON)` | [x] | +| `jsonb` | `dict[str, Any]` | `Field(sa_type=JSON)` | [x] | +| `bytea` | `bytes` | `Field()` | [x] | +| `inet` | `str` | `Field()` | [x] | +| `cidr` | `str` | `Field()` | [x] | +| `macaddr` | `str` | `Field()` | [x] | +| `point`, `line`, etc. 
| `str` | `Field()` | [x] | +| `array` (e.g., `integer[]`) | `list[int]` | `Field(sa_type=ARRAY(Integer))` | [x] | +| User-defined enum | Literal union or Python Enum | `Field()` | [ ] Future work | ### Handling User-Defined Types +[ ] **Future Work** - User-defined enum types are not yet supported. + For user-defined enum types: ```python from typing import Literal @@ -133,6 +152,8 @@ class UserStatus(str, Enum): ### Basic Field Patterns +[x] **Completed** - All basic field patterns implemented in `field.rs`. + ```python # Non-nullable without default name: str @@ -146,7 +167,7 @@ status: str = Field(default="active") # Primary key (nullable with None default for auto-generation) id: int | None = Field(default=None, primary_key=True) -# With index +# With index [ ] Future work - index=True not yet generated for single columns email: str = Field(index=True) # With unique constraint @@ -158,12 +179,15 @@ user_id: int | None = Field(default=None, foreign_key="users.id") ### Generated/Identity Columns +[x] **Completed** - Identity columns handled correctly. +[~] **Partial** - Generated columns emit warnings but don't fully support `Computed`. + ```python # Identity column (GENERATED ALWAYS AS IDENTITY) id: int | None = Field(default=None, primary_key=True) # Note: SQLModel handles auto-increment through primary_key=True with None default -# Generated column (STORED) +# Generated column (STORED) - [ ] Future work for full Computed support # SQLModel doesn't have native support; use sa_column full_name: str = Field( default=None, @@ -175,6 +199,8 @@ full_name: str = Field( #### Primary Key +[x] **Completed** + ```python # Single column id: int | None = Field(default=None, primary_key=True) @@ -188,17 +214,22 @@ class OrderItem(SQLModel, table=True): #### Foreign Key +[x] **Completed** - Basic FK support. +[ ] **Future Work** - Relationship generation. 
+ ```python # Basic FK user_id: int | None = Field(default=None, foreign_key="users.id") -# With relationship +# With relationship - [ ] Future work user_id: int | None = Field(default=None, foreign_key="users.id") user: "User | None" = Relationship(back_populates="posts") ``` #### Unique Constraint +[x] **Completed** + ```python # Single column email: str = Field(unique=True) @@ -211,6 +242,8 @@ __table_args__ = ( #### Check Constraint +[x] **Completed** + ```python # Via __table_args__ __table_args__ = ( @@ -220,11 +253,13 @@ __table_args__ = ( ### Index Handling +[~] **Partial** - Multi-column indexes via `__table_args__` completed. Single-column `Field(index=True)` not yet implemented. + ```python -# Simple column index +# Simple column index - [ ] Future work email: str = Field(index=True) -# Composite or complex indexes - via __table_args__ +# Composite or complex indexes - via __table_args__ [x] Completed __table_args__ = ( Index("idx_users_name_email", "name", "email"), ) @@ -234,6 +269,8 @@ __table_args__ = ( ### Single File Output +[x] **Completed** + For simpler schemas, generate a single `models.py`: ``` @@ -242,18 +279,21 @@ models.py ### Multi-File Output +[x] **Completed** + For larger schemas, split by table with a shared types module: ``` models/ ├── __init__.py # Re-exports all models -├── _types.py # Shared types, enums, base classes ├── user.py # User model ├── post.py # Post model └── comment.py # Comment model ``` -**Configuration**: Let users choose via `OutputMode::SingleFile` or `OutputMode::MultiFile`. +**Note**: `_types.py` for shared types/enums not yet implemented (depends on enum support). + +**Configuration**: Users can choose via `OutputMode::SingleFile` or `OutputMode::MultiFile`. ## Generated Code Examples @@ -280,17 +320,15 @@ Table { ### Generated Python +[x] **Completed** - Basic generation works. Some features marked for future work. 
+ ```python """SQLModel definitions generated by Tern.""" from datetime import datetime -from typing import TYPE_CHECKING from sqlmodel import Field, SQLModel -if TYPE_CHECKING: - from .post import Post # For relationship type hints - class User(SQLModel, table=True): """User model.""" @@ -300,9 +338,9 @@ class User(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) email: str = Field(unique=True) name: str | None = None - created_at: datetime = Field(default_factory=datetime.now, index=True) + created_at: datetime # Note: default_factory not yet implemented - # Relationships (if generate_relationships=True) + # Relationships (if generate_relationships=True) - [ ] Future work # posts: list["Post"] = Relationship(back_populates="user") ``` @@ -310,6 +348,8 @@ class User(SQLModel, table=True): ### 1. Python Reserved Words +[x] **Completed** - Full reserved word handling in `naming.rs`. + Python reserved words that might appear as column/table names: ``` @@ -318,31 +358,25 @@ def, del, elif, else, except, finally, for, from, global, if, import, in, is, lambda, nonlocal, not, or, pass, raise, return, try, while, with, yield ``` +Also handles soft keywords (`match`, `case`, `type`, `_`) and can optionally handle Python builtins. + **Handling**: ```python # Column named "class" class_: str = Field(alias="class") ``` -Use SQLAlchemy column aliasing to preserve the database column name while using a valid Python identifier. - ### 2. Invalid Python Identifiers -- Names starting with numbers: `1column` -> `column_1` or `_1column` +[x] **Completed** + +- Names starting with numbers: `1column` -> `_1column` - Names with special characters: `column-name` -> `column_name` - Names with spaces: `column name` -> `column_name` -```rust -fn sanitize_identifier(name: &str) -> String { - // 1. Replace invalid characters with underscores - // 2. Ensure doesn't start with digit - // 3. Handle reserved words -} -``` - ### 3. 
Circular Foreign Key References -Tables may reference each other: +[~] **Partial** - FK references work. Full relationship generation with back_populates is future work. ```python # Forward reference using string annotation @@ -350,7 +384,7 @@ class User(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) manager_id: int | None = Field(default=None, foreign_key="users.id") - # Self-referential relationship + # Self-referential relationship - [ ] Future work manager: "User | None" = Relationship( back_populates="direct_reports", sa_relationship_kwargs={"remote_side": "User.id"} @@ -360,7 +394,9 @@ class User(SQLModel, table=True): ### 4. Schema-Qualified Names -Foreign keys might reference tables in other schemas: +[x] **Completed** + +Foreign keys reference tables with schema qualification: ```python # Reference to other_schema.other_table @@ -369,29 +405,32 @@ other_id: int | None = Field(default=None, foreign_key="other_schema.other_table ### 5. Array Types +[x] **Completed** + ```python -from sqlalchemy import ARRAY, Integer +from sqlalchemy import ARRAY, Text from sqlmodel import Field, SQLModel class Document(SQLModel, table=True): - tags: list[str] = Field( - default_factory=list, - sa_type=ARRAY(String) - ) + tags: list[str] = Field(sa_type=ARRAY(Text)) ``` ### 6. JSONB Fields +[x] **Completed** + ```python from sqlalchemy import JSON from typing import Any class Settings(SQLModel, table=True): - config: dict[str, Any] = Field(default_factory=dict, sa_type=JSON) + config: dict[str, Any] = Field(sa_type=JSON) ``` ### 7. Composite Primary Keys +[x] **Completed** + ```python class OrderItem(SQLModel, table=True): __tablename__ = "order_items" @@ -403,6 +442,8 @@ class OrderItem(SQLModel, table=True): ### 8. Generated Columns +[ ] **Future Work** - Currently emits warning. Full `Computed` support not implemented. 
+ SQLModel doesn't have first-class support for generated columns, but we can use `sa_column`: ```python @@ -419,6 +460,8 @@ class Person(SQLModel, table=True): ### 9. Exclusion Constraints +[x] **Completed** - Emits warning comment as designed. + Not directly supported by SQLModel; emit a warning comment: ```python @@ -428,80 +471,93 @@ Not directly supported by SQLModel; emit a warning comment: ### 10. Empty Tables -Tables with no columns (rare but possible): - -```python -class EmptyTable(SQLModel, table=True): - """Table with no columns (placeholder).""" - __tablename__ = "empty_table" - pass # SQLModel requires at least one field in practice -``` +[x] **Completed** - Emits warning. -**Emit warning**: Tables with no columns should emit a warning as SQLModel requires at least one field. +Tables with no columns emit a warning as SQLModel requires at least one field. ## Implementation Plan ### Phase 1: Core Infrastructure -1. Add `genco` dependency to `tern-codegen/Cargo.toml` -2. Create module structure under `src/python/` -3. Implement `PythonCodegenConfig` and `PythonCodegen` struct -4. Implement basic `Codegen` trait with empty generation +[x] **Completed** + +1. [x] Add `genco` dependency to `tern-codegen/Cargo.toml` +2. [x] Create module structure under `src/python/` +3. [x] Implement `PythonCodegenConfig` and `PythonCodegen` struct +4. [x] Implement basic `Codegen` trait with empty generation ### Phase 2: Type Mapping -1. Implement `type_mapping.rs` with PostgreSQL -> Python conversions -2. Handle all scalar types from the mapping table -3. Add array type detection and handling -4. Add tests for type mapping edge cases +[x] **Completed** + +1. [x] Implement `type_mapping.rs` with PostgreSQL -> Python conversions +2. [x] Handle all scalar types from the mapping table +3. [x] Add array type detection and handling +4. [x] Add tests for type mapping edge cases ### Phase 3: Name Handling -1. Implement `naming.rs` with identifier sanitization -2. 
Build reserved word detection and handling -3. Implement table/column name conversion to Python conventions -4. Add tests for naming edge cases +[x] **Completed** + +1. [x] Implement `naming.rs` with identifier sanitization +2. [x] Build reserved word detection and handling +3. [x] Implement table/column name conversion to Python conventions +4. [x] Add tests for naming edge cases ### Phase 4: Basic Model Generation -1. Implement simple class generation with genco -2. Generate basic fields (non-nullable, no constraints) -3. Handle nullable fields with `Optional` types -4. Generate proper imports +[x] **Completed** + +1. [x] Implement simple class generation with genco +2. [x] Generate basic fields (non-nullable, no constraints) +3. [x] Handle nullable fields with `| None` types +4. [x] Generate proper imports ### Phase 5: Constraint Support -1. Primary key generation -2. Foreign key generation (without relationships) -3. Unique constraint generation (single-column via Field, multi-column via `__table_args__`) -4. Check constraint generation via `__table_args__` -5. Index generation +[x] **Completed** + +1. [x] Primary key generation +2. [x] Foreign key generation (without relationships) +3. [x] Unique constraint generation (single-column via Field, multi-column via `__table_args__`) +4. [x] Check constraint generation via `__table_args__` +5. [x] Index generation (multi-column via `__table_args__`) ### Phase 6: Relationship Generation -1. Analyze foreign key graph to determine relationship directions -2. Generate `Relationship()` attributes -3. Handle self-referential relationships -4. Handle circular references with forward declarations +[ ] **Future Work** + +1. [ ] Analyze foreign key graph to determine relationship directions +2. [ ] Generate `Relationship()` attributes +3. [ ] Handle self-referential relationships +4. 
[ ] Handle circular references with forward declarations + +**Note**: Infrastructure is in place (`add_relationship`, `add_type_checking` methods exist with `#[allow(dead_code)]`). ### Phase 7: Advanced Features -1. Generated column support -2. Identity column support -3. Array type support -4. JSONB field support -5. Docstring generation from comments +[~] **Partially Completed** + +1. [ ] Generated column support (with `Computed`) +2. [x] Identity column support +3. [x] Array type support +4. [x] JSONB field support +5. [x] Docstring generation from comments ### Phase 8: Testing -1. Unit tests for each component -2. Snapshot tests for complete model generation -3. Edge case tests (reserved words, special characters, circular refs) -4. Integration tests with complex multi-table schemas +[x] **Completed** + +1. [x] Unit tests for each component (109 tests) +2. [x] Snapshot tests for complete model generation (11 snapshots) +3. [x] Edge case tests (reserved words, special characters, circular refs) +4. [x] Integration tests with complex multi-table schemas ## Dependencies -Add to `crates/codegen/Cargo.toml`: +[x] **Completed** + +Added to `crates/codegen/Cargo.toml`: ```toml [dependencies] @@ -518,7 +574,7 @@ pretty_assertions = "1.4" ### Unit Tests -Test individual components in isolation: +[x] **Completed** - 109 tests covering all components. ```rust #[test] @@ -534,7 +590,7 @@ fn test_sanitize_reserved_word() { ### Snapshot Tests -Use `insta` for snapshot testing of generated code: +[x] **Completed** - 11 snapshot tests with insta. 
```rust #[test] @@ -545,62 +601,25 @@ fn snapshot_simple_table() { insta::assert_snapshot!(output.get("models.py").unwrap()); } - -#[test] -fn snapshot_foreign_key_relationship() { - let tables = vec![create_users_table(), create_posts_table()]; - let codegen = PythonCodegen::new(PythonCodegenConfig { - generate_relationships: true, - ..Default::default() - }); - let output = codegen.generate(tables); - - insta::assert_snapshot!(output.get("models.py").unwrap()); -} ``` ### Edge Case Tests -```rust -#[test] -fn test_reserved_word_column() { - let table = Table { - name: TableName::try_new("items".to_string()).unwrap(), - columns: vec![ - column("class", "text"), // Reserved word - column("from", "integer"), // Reserved word - ], - ..Default::default() - }; - // Verify generated code uses aliases -} +[x] **Completed** -#[test] -fn test_self_referential_foreign_key() { - let table = Table { - name: TableName::try_new("employees".to_string()).unwrap(), - columns: vec![ - column("id", "integer"), - column("manager_id", "integer"), - ], - constraints: vec![ - Constraint::foreign_key("manager_id", "employees", "id"), - ], - ..Default::default() - }; - // Verify self-referential handling -} - -#[test] -fn test_circular_foreign_keys() { - let user_table = /* ... references posts.featured_user_id */; - let post_table = /* ... references users.id */; - // Verify circular reference handling with forward declarations -} -``` +- Reserved word columns (class, from, import, def, etc.) +- Self-referential foreign keys +- Composite primary keys +- Multi-column unique constraints +- Check constraints +- Exclusion constraints (warning) +- All PostgreSQL types +- Multi-file output mode ## Error Handling +[x] **Completed** - Error type defined in `mod.rs`. + ```rust #[derive(Debug, thiserror::Error)] pub enum PythonCodegenError { @@ -620,24 +639,31 @@ pub enum PythonCodegenError { **Note**: The current `Codegen` trait returns `HashMap` without error handling. 
Consider proposing a trait update to return `Result, Error>` in the future. -## Open Questions +## Open Questions (Resolved) 1. **Output Mode**: Should we default to single-file or multi-file output? - - **Recommendation**: Single file for simplicity, with config option for multi-file. + - **Resolution**: [x] Single file is the default, with `OutputMode::MultiFile` option. 2. **Relationship Generation**: Should relationships be opt-in or opt-out? - - **Recommendation**: Opt-in via config flag, as relationships add complexity. + - **Resolution**: [x] Opt-in via `generate_relationships` config flag. Not yet implemented. 3. **Type Hints Style**: Should we use `Optional[X]` or `X | None`? - - **Recommendation**: `X | None` (Python 3.10+ syntax) as it's more modern and SQLModel targets newer Python. + - **Resolution**: [x] Using `X | None` (Python 3.10+ syntax). 4. **Enum Handling**: Literal types vs Python Enum classes? - - **Recommendation**: This design assumes enums are not passed in the `Vec
` input. If enum support is needed, add a separate `enums` parameter to the generator or include enum definitions in a preprocessing step. + - **Resolution**: Deferred to future work. Currently maps unknown types to `Any`. ## Future Enhancements -1. **Alembic Migration Generation**: Generate Alembic migration files alongside models -2. **Pydantic V2 Schemas**: Generate pure Pydantic models for API schemas -3. **FastAPI Integration**: Generate FastAPI route stubs for CRUD operations -4. **Custom Validators**: Support for custom Pydantic validators from check constraints -5. **Type Stubs**: Generate `.pyi` stub files for better IDE support +1. [ ] **Relationship Generation**: Generate `Relationship()` attributes with back_populates +2. [ ] **Pydantic Base Models**: Generate Pydantic-only models via `generate_base_models` config +3. [ ] **User-Defined Enums**: Support PostgreSQL enum types as Python Literal or Enum +4. [ ] **Generated Columns**: Full `Computed` support with sa_column +5. [ ] **Single-Column Index**: Generate `Field(index=True)` for indexed columns +6. [ ] **Default Factory**: Generate `default_factory` for datetime fields with `now()` defaults +7. [ ] **Module Prefix**: Support `module_prefix` config for import paths +8. [ ] **String Length Validation**: Use extracted varchar/char length in `Field(max_length=n)` +9. [ ] **Alembic Migration Generation**: Generate Alembic migration files alongside models +10. [ ] **FastAPI Integration**: Generate FastAPI route stubs for CRUD operations +11. [ ] **Custom Validators**: Support for custom Pydantic validators from check constraints +12. [ ] **Type Stubs**: Generate `.pyi` stub files for better IDE support diff --git a/crates/codegen/src/python/field.rs b/crates/codegen/src/python/field.rs new file mode 100644 index 0000000..1b22232 --- /dev/null +++ b/crates/codegen/src/python/field.rs @@ -0,0 +1,515 @@ +//! SQLModel field generation. +//! +//! 
This module handles generating Field() declarations for SQLModel columns, +//! including type annotations, default values, and field parameters. + +use tern_ddl::{Column, ConstraintKind, Table}; + +use super::ReservedWordStrategy; +use super::imports::ImportCollector; +use super::naming::to_attribute_name; +use super::type_mapping::{extract_string_length, is_fixed_length_char, map_pg_type}; + +/// Information about a generated field. +#[derive(Debug)] +pub struct FieldInfo { + /// The Python attribute name (may differ from column name). + pub attr_name: String, + /// The original database column name. + /// Kept for debugging and future relationship generation. + #[allow(dead_code)] + pub db_column_name: String, + /// Whether an alias is needed (attr_name != db_column_name). + /// Kept for debugging and validation. + #[allow(dead_code)] + pub needs_alias: bool, + /// The Python type annotation. + pub type_annotation: String, + /// The Field() declaration (if any parameters are needed). + pub field_declaration: Option, + /// Default value without Field() (e.g., "= None"). + pub simple_default: Option, + /// Required imports for this field. + pub imports: ImportCollector, + /// Whether this field is a primary key. + pub is_primary_key: bool, + /// Whether this field has a foreign key constraint. + /// Kept for future relationship generation. + #[allow(dead_code)] + pub foreign_key: Option, + /// Whether this field is indexed (non-constraint index). + /// Kept for debugging and validation. + #[allow(dead_code)] + pub is_indexed: bool, + /// Whether this field has a unique constraint. + /// Kept for debugging and validation. + #[allow(dead_code)] + pub is_unique: bool, +} + +/// Context for field generation containing table-level information. +pub struct FieldContext<'a> { + /// The table containing the column. + /// Kept for future relationship generation. + #[allow(dead_code)] + pub table: &'a Table, + /// Reserved word handling strategy. 
+ pub strategy: &'a ReservedWordStrategy, + /// Set of column names that are primary keys. + pub primary_key_columns: Vec<&'a str>, + /// Map of column name to foreign key reference (e.g., "users.id"). + pub foreign_key_columns: Vec<(&'a str, String)>, + /// Set of column names with unique constraints (single-column only). + pub unique_columns: Vec<&'a str>, + /// Set of column names with indexes (non-constraint indexes). + pub indexed_columns: Vec<&'a str>, +} + +impl<'a> FieldContext<'a> { + /// Creates a new field context from a table. + pub fn from_table(table: &'a Table, strategy: &'a ReservedWordStrategy) -> Self { + let mut primary_key_columns = Vec::new(); + let mut foreign_key_columns = Vec::new(); + let mut unique_columns = Vec::new(); + let mut indexed_columns = Vec::new(); + + // Extract constraint information + for constraint in &table.constraints { + match &constraint.kind { + ConstraintKind::PrimaryKey(pk) => { + for col in &pk.columns { + primary_key_columns.push(col.as_ref()); + } + } + ConstraintKind::ForeignKey(fk) => { + // For single-column FKs, we can use Field(foreign_key=...) 
+ if fk.columns.len() == 1 && fk.referenced_columns.len() == 1 { + let col_name = fk.columns[0].as_ref(); + let ref_table = &fk.referenced_table; + let ref_col = fk.referenced_columns[0].as_ref(); + let fk_ref = format!( + "{}.{}.{}", + ref_table.schema.as_ref(), + ref_table.name.as_ref(), + ref_col + ); + foreign_key_columns.push((col_name, fk_ref)); + } + } + ConstraintKind::Unique(unique) => { + // Single-column unique constraints can use Field(unique=True) + if unique.columns.len() == 1 { + unique_columns.push(unique.columns[0].as_ref()); + } + } + _ => {} + } + } + + // Extract index information (non-constraint indexes) + for index in &table.indexes { + if !index.is_constraint_index && index.columns.len() == 1 { + if let Some(col) = &index.columns[0].column { + indexed_columns.push(col.as_ref()); + } + } + } + + Self { + table, + strategy, + primary_key_columns, + foreign_key_columns, + unique_columns, + indexed_columns, + } + } + + /// Checks if a column is a primary key. + pub fn is_primary_key(&self, column_name: &str) -> bool { + self.primary_key_columns.contains(&column_name) + } + + /// Gets the foreign key reference for a column, if any. + pub fn get_foreign_key(&self, column_name: &str) -> Option<&str> { + self.foreign_key_columns + .iter() + .find(|(col, _)| *col == column_name) + .map(|(_, fk)| fk.as_str()) + } + + /// Checks if a column has a unique constraint. + pub fn is_unique(&self, column_name: &str) -> bool { + self.unique_columns.contains(&column_name) + } + + /// Checks if a column is indexed. + pub fn is_indexed(&self, column_name: &str) -> bool { + self.indexed_columns.contains(&column_name) + } +} + +/// Generates field information for a column. 
+pub fn generate_field(column: &Column, ctx: &FieldContext<'_>) -> FieldInfo { + let column_name = column.name.as_ref(); + + // Convert column name to Python attribute name + let (attr_name, needs_alias) = to_attribute_name(column_name, ctx.strategy); + + // Map PostgreSQL type to Python type + let py_type = map_pg_type(&column.type_info); + + // Collect imports + let mut imports = ImportCollector::new(); + imports.add_all(&py_type.imports); + imports.add_all(&py_type.sa_imports); + + // Check constraints for this column + let is_primary_key = ctx.is_primary_key(column_name); + let foreign_key = ctx.get_foreign_key(column_name).map(|s| s.to_string()); + let is_unique = ctx.is_unique(column_name); + let is_indexed = ctx.is_indexed(column_name); + + // Determine if we're dealing with an auto-generated primary key (identity or serial) + let is_auto_pk = is_primary_key + && (column.identity.is_some() || is_serial_type(&column.type_info.name.as_ref())); + + // Build field parameters + let mut field_params = Vec::new(); + + // Handle nullability and primary key + let type_annotation = if is_auto_pk { + // Auto-generated PKs: type is Optional with default=None + field_params.push("default=None".to_string()); + format!("{} | None", py_type.annotation) + } else if column.is_nullable { + format!("{} | None", py_type.annotation) + } else { + py_type.annotation.clone() + }; + + // Primary key + if is_primary_key { + field_params.push("primary_key=True".to_string()); + } + + // Foreign key + if let Some(ref fk) = foreign_key { + field_params.push(format!("foreign_key=\"{}\"", fk)); + } + + // Unique constraint + if is_unique && !is_primary_key { + field_params.push("unique=True".to_string()); + } + + // Index + if is_indexed && !is_primary_key && !is_unique { + field_params.push("index=True".to_string()); + } + + // String length constraints + if let Some(length) = extract_string_length(&column.type_info.formatted) { + if is_fixed_length_char(column.type_info.name.as_ref()) { + 
// Fixed-length char: both min and max + field_params.push(format!("min_length={length}")); + field_params.push(format!("max_length={length}")); + } else { + // Variable-length varchar: only max + field_params.push(format!("max_length={length}")); + } + } + + // SQLAlchemy type (for JSON, ARRAY, etc.) + if let Some(ref sa_type) = py_type.sa_type { + field_params.push(format!("sa_type={sa_type}")); + } + + // Alias for reserved words or sanitized names + if needs_alias { + field_params.push(format!("alias=\"{}\"", column_name)); + } + + // Handle default values for nullable non-PK fields + let simple_default = + if !is_auto_pk && column.is_nullable && column.default.is_none() && field_params.is_empty() + { + Some("None".to_string()) + } else { + None + }; + + // Determine if we need a Field() declaration + let field_declaration = if !field_params.is_empty() { + imports.add_field(); + Some(format!("Field({})", field_params.join(", "))) + } else { + None + }; + + FieldInfo { + attr_name, + db_column_name: column_name.to_string(), + needs_alias, + type_annotation, + field_declaration, + simple_default, + imports, + is_primary_key, + foreign_key, + is_indexed, + is_unique, + } +} + +/// Checks if a type name represents a serial (auto-increment) type. +fn is_serial_type(type_name: &str) -> bool { + matches!( + type_name, + "serial" | "serial2" | "serial4" | "serial8" | "smallserial" | "bigserial" + ) +} + +/// Formats a field as a Python class attribute line. 
+pub fn format_field_line(field: &FieldInfo) -> String {
+    let type_ann = &field.type_annotation;
+    let name = &field.attr_name;
+
+    if let Some(ref field_decl) = field.field_declaration {
+        format!(" {name}: {type_ann} = {field_decl}")
+    } else if let Some(ref default) = field.simple_default {
+        format!(" {name}: {type_ann} = {default}")
+    } else {
+        format!(" {name}: {type_ann}")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tern_ddl::types::{ForeignKeyAction, QualifiedCollationName, QualifiedName};
+    use tern_ddl::{
+        ColumnName, Constraint, ConstraintName, ForeignKeyConstraint, IndexName, Oid,
+        PrimaryKeyConstraint, SchemaName, TableKind, TableName, TypeInfo, TypeName,
+        UniqueConstraint,
+    };
+
+    fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
+        Column {
+            name: ColumnName::try_new(name.to_string()).unwrap(),
+            position: 1,
+            type_info: TypeInfo {
+                name: TypeName::try_new(type_name.to_string()).unwrap(),
+                schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+                formatted: type_name.to_string(),
+                is_array: false,
+            },
+            is_nullable,
+            default: None,
+            generated: None,
+            identity: None,
+            collation: QualifiedCollationName::new(
+                SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+                tern_ddl::CollationName::try_new("default".to_string()).unwrap(),
+            ),
+            comment: None,
+        }
+    }
+
+    fn make_table(name: &str, columns: Vec<Column>, constraints: Vec<Constraint>) -> Table {
+        Table {
+            oid: Oid::new(1),
+            name: TableName::try_new(name.to_string()).unwrap(),
+            kind: TableKind::Regular,
+            columns,
+            constraints,
+            indexes: vec![],
+            comment: None,
+        }
+    }
+
+    #[test]
+    fn test_simple_field() {
+        let column = make_column("name", "text", false);
+        let table = make_table("users", vec![column.clone()], vec![]);
+        let strategy = ReservedWordStrategy::AppendUnderscore;
+        let ctx = FieldContext::from_table(&table, &strategy);
+
+        let field = generate_field(&column, &ctx);
+
+        assert_eq!(field.attr_name, "name");
+        
assert_eq!(field.type_annotation, "str"); + assert!(!field.needs_alias); + assert!(field.field_declaration.is_none()); + } + + #[test] + fn test_nullable_field() { + let column = make_column("bio", "text", true); + let table = make_table("users", vec![column.clone()], vec![]); + let strategy = ReservedWordStrategy::AppendUnderscore; + let ctx = FieldContext::from_table(&table, &strategy); + + let field = generate_field(&column, &ctx); + + assert_eq!(field.type_annotation, "str | None"); + assert_eq!(field.simple_default, Some("None".to_string())); + } + + #[test] + fn test_primary_key_field() { + let column = make_column("id", "int4", false); + let pk = Constraint { + name: ConstraintName::try_new("users_pkey".to_string()).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + index_name: IndexName::try_new("users_pkey".to_string()).unwrap(), + }), + comment: None, + }; + let table = make_table("users", vec![column.clone()], vec![pk]); + let strategy = ReservedWordStrategy::AppendUnderscore; + let ctx = FieldContext::from_table(&table, &strategy); + + let field = generate_field(&column, &ctx); + + assert!(field.is_primary_key); + assert!(field.field_declaration.is_some()); + let decl = field.field_declaration.unwrap(); + assert!(decl.contains("primary_key=True")); + } + + #[test] + fn test_reserved_word_field() { + let column = make_column("class", "text", false); + let table = make_table("items", vec![column.clone()], vec![]); + let strategy = ReservedWordStrategy::AppendUnderscore; + let ctx = FieldContext::from_table(&table, &strategy); + + let field = generate_field(&column, &ctx); + + assert_eq!(field.attr_name, "class_"); + assert!(field.needs_alias); + assert!(field.field_declaration.is_some()); + let decl = field.field_declaration.unwrap(); + assert!(decl.contains("alias=\"class\"")); + } + + #[test] + fn test_foreign_key_field() { + let column = make_column("user_id", "int4", 
true); + let fk = Constraint { + name: ConstraintName::try_new("posts_user_id_fkey".to_string()).unwrap(), + kind: ConstraintKind::ForeignKey(ForeignKeyConstraint { + columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()], + referenced_table: QualifiedName::new( + SchemaName::try_new("public".to_string()).unwrap(), + TableName::try_new("users".to_string()).unwrap(), + ), + referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + on_delete: ForeignKeyAction::NoAction, + on_update: ForeignKeyAction::NoAction, + is_deferrable: false, + is_initially_deferred: false, + }), + comment: None, + }; + let table = make_table("posts", vec![column.clone()], vec![fk]); + let strategy = ReservedWordStrategy::AppendUnderscore; + let ctx = FieldContext::from_table(&table, &strategy); + + let field = generate_field(&column, &ctx); + + assert!(field.foreign_key.is_some()); + assert!(field.field_declaration.is_some()); + let decl = field.field_declaration.unwrap(); + assert!(decl.contains("foreign_key=")); + } + + #[test] + fn test_unique_field() { + let column = make_column("email", "text", false); + let unique = Constraint { + name: ConstraintName::try_new("users_email_key".to_string()).unwrap(), + kind: ConstraintKind::Unique(UniqueConstraint { + columns: vec![ColumnName::try_new("email".to_string()).unwrap()], + index_name: IndexName::try_new("users_email_key".to_string()).unwrap(), + nulls_not_distinct: false, + }), + comment: None, + }; + let table = make_table("users", vec![column.clone()], vec![unique]); + let strategy = ReservedWordStrategy::AppendUnderscore; + let ctx = FieldContext::from_table(&table, &strategy); + + let field = generate_field(&column, &ctx); + + assert!(field.is_unique); + assert!(field.field_declaration.is_some()); + let decl = field.field_declaration.unwrap(); + assert!(decl.contains("unique=True")); + } + + #[test] + fn test_format_field_line() { + let field = FieldInfo { + attr_name: "name".to_string(), + db_column_name: 
"name".to_string(), + needs_alias: false, + type_annotation: "str".to_string(), + field_declaration: None, + simple_default: None, + imports: ImportCollector::new(), + is_primary_key: false, + foreign_key: None, + is_indexed: false, + is_unique: false, + }; + + let line = format_field_line(&field); + assert_eq!(line, " name: str"); + } + + #[test] + fn test_format_field_line_with_default() { + let field = FieldInfo { + attr_name: "bio".to_string(), + db_column_name: "bio".to_string(), + needs_alias: false, + type_annotation: "str | None".to_string(), + field_declaration: None, + simple_default: Some("None".to_string()), + imports: ImportCollector::new(), + is_primary_key: false, + foreign_key: None, + is_indexed: false, + is_unique: false, + }; + + let line = format_field_line(&field); + assert_eq!(line, " bio: str | None = None"); + } + + #[test] + fn test_format_field_line_with_field_declaration() { + let field = FieldInfo { + attr_name: "id".to_string(), + db_column_name: "id".to_string(), + needs_alias: false, + type_annotation: "int | None".to_string(), + field_declaration: Some("Field(default=None, primary_key=True)".to_string()), + simple_default: None, + imports: ImportCollector::new(), + is_primary_key: true, + foreign_key: None, + is_indexed: false, + is_unique: false, + }; + + let line = format_field_line(&field); + assert_eq!( + line, + " id: int | None = Field(default=None, primary_key=True)" + ); + } +} diff --git a/crates/codegen/src/python/generator.rs b/crates/codegen/src/python/generator.rs new file mode 100644 index 0000000..afe750a --- /dev/null +++ b/crates/codegen/src/python/generator.rs @@ -0,0 +1,407 @@ +//! Python SQLModel code generator implementation. +//! +//! This module provides the main `PythonCodegen` struct that implements the `Codegen` trait +//! for generating Python SQLModel models from PostgreSQL table definitions. 
+
+use std::collections::HashMap;
+
+use tern_ddl::Table;
+
+use super::imports::ImportCollector;
+use super::model::{ModelInfo, format_model, generate_model};
+use super::naming::to_module_name;
+use super::{OutputMode, PythonCodegenConfig};
+use crate::Codegen;
+
+/// Python SQLModel code generator.
+///
+/// Generates Python SQLModel model definitions from PostgreSQL table schemas.
+///
+/// # Example
+///
+/// ```ignore
+/// use tern_codegen::python::{PythonCodegen, PythonCodegenConfig};
+/// use tern_codegen::Codegen;
+///
+/// let codegen = PythonCodegen::new(PythonCodegenConfig::default());
+/// let tables = vec![/* ... */];
+/// let output = codegen.generate(tables);
+///
+/// // output["models.py"] contains the generated code
+/// ```
+#[derive(Debug, Clone)]
+pub struct PythonCodegen {
+    config: PythonCodegenConfig,
+}
+
+impl PythonCodegen {
+    /// Creates a new Python code generator with the given configuration.
+    pub fn new(config: PythonCodegenConfig) -> Self {
+        Self { config }
+    }
+
+    /// Creates a new Python code generator with default configuration.
+    pub fn with_defaults() -> Self {
+        Self::new(PythonCodegenConfig::default())
+    }
+
+    /// Generates a single models.py file containing all models.
+    fn generate_single_file(&self, tables: Vec<Table>) -> HashMap<String, String> {
+        let mut output = HashMap::new();
+
+        if tables.is_empty() {
+            output.insert("models.py".to_string(), generate_empty_models_file());
+            return output;
+        }
+
+        // Generate all models
+        let models: Vec<ModelInfo> = tables
+            .iter()
+            .map(|t| generate_model(t, &self.config))
+            .collect();
+
+        // Collect all imports
+        let mut imports = ImportCollector::new();
+        for model in &models {
+            imports.merge(&model.imports);
+        }
+
+        // Build the file content
+        let mut content = Vec::new();
+
+        // Module docstring
+        content.push(MODULE_DOCSTRING.to_string());
+        content.push(String::new());
+
+        // Imports
+        let import_block = imports.generate();
+        if !import_block.is_empty() {
+            content.push(import_block);
+            content.push(String::new());
+        }
+
+        // Models
+        for (i, model) in models.iter().enumerate() {
+            if i > 0 {
+                content.push(String::new());
+                content.push(String::new());
+            }
+            content.push(format_model(model));
+        }
+
+        // Ensure file ends with newline
+        content.push(String::new());
+
+        output.insert("models.py".to_string(), content.join("\n"));
+        output
+    }
+
+    /// Generates multiple files, one per model, with a shared types module.
+    fn generate_multi_file(&self, tables: Vec<Table>
) -> HashMap<String, String> {
+        let mut output = HashMap::new();
+
+        if tables.is_empty() {
+            output.insert("__init__.py".to_string(), generate_empty_init_file());
+            return output;
+        }
+
+        // Generate all models
+        let models: Vec<ModelInfo> = tables
+            .iter()
+            .map(|t| generate_model(t, &self.config))
+            .collect();
+
+        // Generate individual model files
+        let mut all_class_names = Vec::new();
+        let mut all_module_names = Vec::new();
+
+        for model in &models {
+            let module_name = to_module_name(&model.table_name);
+            let filename = format!("{module_name}.py");
+
+            // Build imports for this file
+            let mut imports = ImportCollector::new();
+            imports.merge(&model.imports);
+
+            // Build file content
+            let mut content = Vec::new();
+            content.push(format!(
+                "\"\"\"SQLModel definition for {}.\"\"\"\n",
+                model.class_name
+            ));
+
+            let import_block = imports.generate();
+            if !import_block.is_empty() {
+                content.push(import_block);
+                content.push(String::new());
+            }
+
+            content.push(format_model(model));
+            content.push(String::new());
+
+            output.insert(filename, content.join("\n"));
+
+            all_class_names.push(model.class_name.clone());
+            all_module_names.push(module_name);
+        }
+
+        // Generate __init__.py with re-exports
+        let init_content = generate_init_file(&all_module_names, &all_class_names);
+        output.insert("__init__.py".to_string(), init_content);
+
+        output
+    }
+}
+
+impl Default for PythonCodegen {
+    fn default() -> Self {
+        Self::with_defaults()
+    }
+}
+
+impl Codegen for PythonCodegen {
+    fn generate(&self, tables: Vec<Table>) -> HashMap<String, String> {
+        match self.config.output_mode {
+            OutputMode::SingleFile => self.generate_single_file(tables),
+            OutputMode::MultiFile => self.generate_multi_file(tables),
+        }
+    }
+}
+
+/// Module docstring for generated files.
+const MODULE_DOCSTRING: &str = "\"\"\"SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+\"\"\"";
+
+/// Generates an empty models.py file.
+fn generate_empty_models_file() -> String {
+    format!(
+        "{}\n\nfrom sqlmodel import SQLModel\n\n# No tables to generate\n",
+        MODULE_DOCSTRING
+    )
+}
+
+/// Generates an empty __init__.py file.
+fn generate_empty_init_file() -> String {
+    "\"\"\"SQLModel definitions generated by Tern.\"\"\"\n\n# No models to export\n".to_string()
+}
+
+/// Generates the __init__.py file with re-exports.
+fn generate_init_file(module_names: &[String], class_names: &[String]) -> String {
+    let mut content = Vec::new();
+    content.push("\"\"\"SQLModel definitions generated by Tern.\n".to_string());
+    content.push("This file was automatically generated. 
Do not edit manually.".to_string());
+    content.push("\"\"\"".to_string());
+    content.push(String::new());
+
+    // Import statements
+    for (module, class) in module_names.iter().zip(class_names.iter()) {
+        content.push(format!("from .{module} import {class}"));
+    }
+
+    content.push(String::new());
+
+    // __all__ list
+    let all_list: Vec<String> = class_names.iter().map(|c| format!("\"{c}\"")).collect();
+    if all_list.len() <= 3 {
+        content.push(format!("__all__ = [{}]", all_list.join(", ")));
+    } else {
+        content.push("__all__ = [".to_string());
+        for item in &all_list {
+            content.push(format!(" {item},"));
+        }
+        content.push("]".to_string());
+    }
+
+    content.push(String::new());
+    content.join("\n")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tern_ddl::types::QualifiedCollationName;
+    use tern_ddl::{
+        Column, ColumnName, Constraint, ConstraintKind, ConstraintName, IndexName, Oid,
+        PrimaryKeyConstraint, SchemaName, TableKind, TableName, TypeInfo, TypeName,
+    };
+
+    fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
+        Column {
+            name: ColumnName::try_new(name.to_string()).unwrap(),
+            position: 1,
+            type_info: TypeInfo {
+                name: TypeName::try_new(type_name.to_string()).unwrap(),
+                schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+                formatted: type_name.to_string(),
+                is_array: false,
+            },
+            is_nullable,
+            default: None,
+            generated: None,
+            identity: None,
+            collation: QualifiedCollationName::new(
+                SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+                tern_ddl::CollationName::try_new("default".to_string()).unwrap(),
+            ),
+            comment: None,
+        }
+    }
+
+    fn make_table_with_pk(name: &str, columns: Vec<Column>, pk_columns: &[&str]) -> Table {
+        let pk = Constraint {
+            name: ConstraintName::try_new(format!("{}_pkey", name)).unwrap(),
+            kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
+                columns: pk_columns
+                    .iter()
+                    .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+                    .collect(),
+                index_name: 
IndexName::try_new(format!("{}_pkey", name)).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } + } + + #[test] + fn test_generate_empty_tables() { + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![]); + + assert!(output.contains_key("models.py")); + let content = &output["models.py"]; + assert!(content.contains("SQLModel")); + assert!(content.contains("No tables to generate")); + } + + #[test] + fn test_generate_single_table() { + let columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + make_column("email", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + assert!(output.contains_key("models.py")); + let content = &output["models.py"]; + assert!(content.contains("class User(SQLModel, table=True):")); + assert!(content.contains("from sqlmodel import")); + assert!(content.contains("__tablename__ = \"users\"")); + } + + #[test] + fn test_generate_multiple_tables() { + let user_columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column("id", "int4", false), + make_column("title", "text", false), + make_column("user_id", "int4", true), + ]; + let post_table = make_table_with_pk("posts", post_columns, &["id"]); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![user_table, post_table]); + + assert!(output.contains_key("models.py")); + let content = &output["models.py"]; + assert!(content.contains("class User(SQLModel, table=True):")); + assert!(content.contains("class Post(SQLModel, table=True):")); + } + + 
#[test] + fn test_generate_multi_file_mode() { + let user_columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column("id", "int4", false), + make_column("title", "text", false), + ]; + let post_table = make_table_with_pk("posts", post_columns, &["id"]); + + let config = PythonCodegenConfig { + output_mode: OutputMode::MultiFile, + ..Default::default() + }; + let codegen = PythonCodegen::new(config); + let output = codegen.generate(vec![user_table, post_table]); + + assert!(output.contains_key("__init__.py")); + assert!(output.contains_key("user.py")); + assert!(output.contains_key("post.py")); + + // Check __init__.py content + let init = &output["__init__.py"]; + assert!(init.contains("from .user import User")); + assert!(init.contains("from .post import Post")); + assert!(init.contains("__all__")); + + // Check individual files + let user_file = &output["user.py"]; + assert!(user_file.contains("class User(SQLModel, table=True):")); + + let post_file = &output["post.py"]; + assert!(post_file.contains("class Post(SQLModel, table=True):")); + } + + #[test] + fn test_generate_with_datetime_import() { + let columns = vec![ + make_column("id", "int4", false), + make_column("created_at", "timestamptz", false), + ]; + let table = make_table_with_pk("events", columns, &["id"]); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + assert!(content.contains("from datetime import datetime")); + assert!(content.contains("created_at: datetime")); + } + + #[test] + fn test_generate_with_uuid_import() { + let columns = vec![ + make_column("id", "uuid", false), + make_column("name", "text", false), + ]; + let table = make_table_with_pk("items", columns, &["id"]); + + let codegen = PythonCodegen::with_defaults(); + let output = 
codegen.generate(vec![table]);
+
+        let content = &output["models.py"];
+        assert!(content.contains("from uuid import UUID"));
+        assert!(content.contains("id:") || content.contains("id :"));
+    }
+
+    #[test]
+    fn test_default_codegen_implements_trait() {
+        let codegen = PythonCodegen::default();
+        let output = codegen.generate(vec![]);
+        assert!(output.contains_key("models.py"));
+    }
+}
diff --git a/crates/codegen/src/python/imports.rs b/crates/codegen/src/python/imports.rs
new file mode 100644
index 0000000..a5d2777
--- /dev/null
+++ b/crates/codegen/src/python/imports.rs
@@ -0,0 +1,330 @@
+//! Python import management.
+//!
+//! This module handles collecting and organizing Python imports for generated code.
+
+use std::collections::{BTreeMap, BTreeSet};
+
+use super::type_mapping::PythonImport;
+
+/// Collects and organizes Python imports for code generation.
+#[derive(Debug, Default)]
+pub struct ImportCollector {
+    /// Standard library imports grouped by module.
+    /// Using BTreeMap/BTreeSet for deterministic ordering.
+    stdlib: BTreeMap<String, BTreeSet<String>>,
+    /// Third-party imports (sqlmodel, sqlalchemy, pydantic).
+    third_party: BTreeMap<String, BTreeSet<String>>,
+    /// TYPE_CHECKING imports (for forward references).
+    type_checking: BTreeMap<String, BTreeSet<String>>,
+}
+
+impl ImportCollector {
+    /// Creates a new empty import collector.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Adds a Python import to the collector.
+    pub fn add(&mut self, import: &PythonImport) {
+        let category = categorize_module(&import.module);
+        let map = match category {
+            ModuleCategory::Stdlib => &mut self.stdlib,
+            ModuleCategory::ThirdParty => &mut self.third_party,
+        };
+
+        map.entry(import.module.clone())
+            .or_default()
+            .insert(import.name.clone());
+    }
+
+    /// Adds multiple imports to the collector.
+    pub fn add_all(&mut self, imports: &[PythonImport]) {
+        for import in imports {
+            self.add(import);
+        }
+    }
+
+    /// Adds a TYPE_CHECKING import (for forward references).
+ /// Used for relationship generation to avoid circular imports. + #[allow(dead_code)] + pub fn add_type_checking(&mut self, module: &str, name: &str) { + self.type_checking + .entry(module.to_string()) + .or_default() + .insert(name.to_string()); + } + + /// Adds the core SQLModel import. + pub fn add_sqlmodel(&mut self) { + self.add(&PythonImport::new("sqlmodel", "SQLModel")); + } + + /// Adds the SQLModel Field import. + pub fn add_field(&mut self) { + self.add(&PythonImport::new("sqlmodel", "Field")); + } + + /// Adds the SQLModel Relationship import. + /// Used when generate_relationships config option is enabled. + #[allow(dead_code)] + pub fn add_relationship(&mut self) { + self.add(&PythonImport::new("sqlmodel", "Relationship")); + } + + /// Checks if there are any TYPE_CHECKING imports. + /// Used for relationship generation to determine if TYPE_CHECKING block is needed. + #[allow(dead_code)] + pub fn has_type_checking(&self) -> bool { + !self.type_checking.is_empty() + } + + /// Generates the import statements as a formatted string. 
+ pub fn generate(&self) -> String { + let mut lines = Vec::new(); + + // Standard library imports + if !self.stdlib.is_empty() { + for (module, names) in &self.stdlib { + lines.push(format_import_line(module, names)); + } + } + + // Third-party imports (with blank line separator if needed) + if !self.third_party.is_empty() { + if !self.stdlib.is_empty() { + lines.push(String::new()); + } + for (module, names) in &self.third_party { + lines.push(format_import_line(module, names)); + } + } + + // TYPE_CHECKING block + if !self.type_checking.is_empty() { + if !lines.is_empty() { + lines.push(String::new()); + } + lines.push("if TYPE_CHECKING:".to_string()); + for (module, names) in &self.type_checking { + let import_line = format_import_line(module, names); + lines.push(format!(" {import_line}")); + } + // Ensure typing.TYPE_CHECKING is imported + self.ensure_type_checking_imported(&mut lines); + } + + lines.join("\n") + } + + /// Ensures TYPE_CHECKING is imported from typing if we have a TYPE_CHECKING block. + fn ensure_type_checking_imported(&self, lines: &mut [String]) { + // Check if TYPE_CHECKING is already in typing imports + if let Some(typing_names) = self.stdlib.get("typing") { + if typing_names.contains("TYPE_CHECKING") { + return; + } + } + + // Find the typing import line and add TYPE_CHECKING + for line in lines.iter_mut() { + if line.starts_with("from typing import ") { + // Add TYPE_CHECKING to the existing import + if !line.contains("TYPE_CHECKING") { + let insert_pos = "from typing import ".len(); + line.insert_str(insert_pos, "TYPE_CHECKING, "); + } + return; + } + } + + // No typing import found, need to add one at the beginning + // This case should be handled by the caller adding TYPE_CHECKING import explicitly + } + + /// Merges another ImportCollector into this one. 
+ pub fn merge(&mut self, other: &ImportCollector) { + for (module, names) in &other.stdlib { + self.stdlib + .entry(module.clone()) + .or_default() + .extend(names.iter().cloned()); + } + for (module, names) in &other.third_party { + self.third_party + .entry(module.clone()) + .or_default() + .extend(names.iter().cloned()); + } + for (module, names) in &other.type_checking { + self.type_checking + .entry(module.clone()) + .or_default() + .extend(names.iter().cloned()); + } + } +} + +/// Module category for import grouping. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModuleCategory { + Stdlib, + ThirdParty, +} + +/// Categorizes a module as stdlib or third-party. +fn categorize_module(module: &str) -> ModuleCategory { + // Known third-party modules + const THIRD_PARTY: &[&str] = &["sqlmodel", "sqlalchemy", "pydantic", "fastapi", "starlette"]; + + if THIRD_PARTY.iter().any(|&m| module.starts_with(m)) { + ModuleCategory::ThirdParty + } else { + ModuleCategory::Stdlib + } +} + +/// Formats a single import line. 
+fn format_import_line(module: &str, names: &BTreeSet<String>) -> String {
+    let names_str: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
+
+    if names_str.len() == 1 {
+        format!("from {module} import {}", names_str[0])
+    } else {
+        // Sort for deterministic output
+        let mut sorted_names = names_str;
+        sorted_names.sort();
+
+        // Check if we need multi-line format
+        let single_line = format!("from {module} import {}", sorted_names.join(", "));
+        if single_line.len() <= 88 {
+            // PEP 8 line length
+            single_line
+        } else {
+            // Multi-line format
+            let mut lines = vec![format!("from {module} import (")];
+            for name in sorted_names {
+                lines.push(format!(" {name},"));
+            }
+            lines.push(")".to_string());
+            lines.join("\n")
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_add_simple_import() {
+        let mut collector = ImportCollector::new();
+        collector.add(&PythonImport::new("datetime", "datetime"));
+
+        let output = collector.generate();
+        assert!(output.contains("from datetime import datetime"));
+    }
+
+    #[test]
+    fn test_add_multiple_from_same_module() {
+        let mut collector = ImportCollector::new();
+        collector.add(&PythonImport::new("datetime", "datetime"));
+        collector.add(&PythonImport::new("datetime", "date"));
+        collector.add(&PythonImport::new("datetime", "timedelta"));
+
+        let output = collector.generate();
+        assert!(output.contains("from datetime import"));
+        assert!(output.contains("date"));
+        assert!(output.contains("datetime"));
+        assert!(output.contains("timedelta"));
+    }
+
+    #[test]
+    fn test_stdlib_before_third_party() {
+        let mut collector = ImportCollector::new();
+        collector.add(&PythonImport::new("sqlmodel", "SQLModel"));
+        collector.add(&PythonImport::new("datetime", "datetime"));
+
+        let output = collector.generate();
+        let datetime_pos = output.find("datetime").unwrap();
+        let sqlmodel_pos = output.find("sqlmodel").unwrap();
+        assert!(datetime_pos < sqlmodel_pos);
+    }
+
+    #[test]
+    fn test_type_checking_block() {
+        let 
mut collector = ImportCollector::new(); + collector.add(&PythonImport::new("typing", "TYPE_CHECKING")); + collector.add_type_checking(".user", "User"); + + let output = collector.generate(); + assert!(output.contains("if TYPE_CHECKING:")); + assert!(output.contains("from .user import User")); + } + + #[test] + fn test_sqlmodel_imports() { + let mut collector = ImportCollector::new(); + collector.add_sqlmodel(); + collector.add_field(); + + let output = collector.generate(); + assert!(output.contains("from sqlmodel import")); + assert!(output.contains("Field")); + assert!(output.contains("SQLModel")); + } + + #[test] + fn test_merge_collectors() { + let mut collector1 = ImportCollector::new(); + collector1.add(&PythonImport::new("datetime", "datetime")); + + let mut collector2 = ImportCollector::new(); + collector2.add(&PythonImport::new("datetime", "date")); + collector2.add(&PythonImport::new("sqlmodel", "SQLModel")); + + collector1.merge(&collector2); + let output = collector1.generate(); + + assert!(output.contains("date")); + assert!(output.contains("datetime")); + assert!(output.contains("SQLModel")); + } + + #[test] + fn test_deterministic_output() { + // Run multiple times to verify ordering is consistent + for _ in 0..5 { + let mut collector = ImportCollector::new(); + collector.add(&PythonImport::new("uuid", "UUID")); + collector.add(&PythonImport::new("datetime", "datetime")); + collector.add(&PythonImport::new("typing", "Any")); + collector.add(&PythonImport::new("sqlmodel", "SQLModel")); + collector.add(&PythonImport::new("sqlmodel", "Field")); + + let output = collector.generate(); + + // Check order: stdlib (datetime, typing, uuid) then third-party (sqlmodel) + let datetime_pos = output.find("from datetime").unwrap(); + let typing_pos = output.find("from typing").unwrap(); + let uuid_pos = output.find("from uuid").unwrap(); + let sqlmodel_pos = output.find("from sqlmodel").unwrap(); + + assert!(datetime_pos < typing_pos); + assert!(typing_pos < 
uuid_pos); + assert!(uuid_pos < sqlmodel_pos); + } + } + + #[test] + fn test_categorize_module() { + assert_eq!(categorize_module("datetime"), ModuleCategory::Stdlib); + assert_eq!(categorize_module("typing"), ModuleCategory::Stdlib); + assert_eq!(categorize_module("uuid"), ModuleCategory::Stdlib); + assert_eq!(categorize_module("sqlmodel"), ModuleCategory::ThirdParty); + assert_eq!(categorize_module("sqlalchemy"), ModuleCategory::ThirdParty); + assert_eq!( + categorize_module("sqlalchemy.types"), + ModuleCategory::ThirdParty + ); + } +} diff --git a/crates/codegen/src/python/mod.rs b/crates/codegen/src/python/mod.rs index 06c7d62..b94a2b2 100644 --- a/crates/codegen/src/python/mod.rs +++ b/crates/codegen/src/python/mod.rs @@ -1 +1,128 @@ -//! Python code generation module. +//! Python SQLModel code generation module. +//! +//! This module provides code generation for Python SQLModel models from PostgreSQL +//! table definitions. It supports: +//! +//! - Idiomatic SQLModel models with proper type mappings +//! - Primary keys, foreign keys, unique constraints, and indexes +//! - Optional relationship generation for foreign keys +//! - Generated and identity columns +//! - Proper handling of Python reserved words and invalid identifiers +//! +//! # Example +//! +//! ```ignore +//! use tern_codegen::{Codegen, python::PythonCodegen}; +//! +//! let codegen = PythonCodegen::new(PythonCodegenConfig::default()); +//! let output = codegen.generate(tables); +//! // output contains "models.py" with SQLModel classes +//! ``` + +mod field; +mod generator; +mod imports; +mod model; +mod naming; +mod type_mapping; + +#[cfg(test)] +mod tests; + +pub use generator::PythonCodegen; + +use std::fmt; + +/// Configuration for Python SQLModel code generation. +#[derive(Debug, Clone)] +pub struct PythonCodegenConfig { + /// Whether to generate Pydantic-only base classes for each model. + /// These are useful for request/response validation without DB coupling. 
+ pub generate_base_models: bool, + + /// Module name prefix for generated imports (e.g., "app.models"). + pub module_prefix: Option, + + /// Whether to include docstrings from table/column comments. + pub include_docstrings: bool, + + /// How to handle Python reserved words in identifiers. + pub reserved_word_strategy: ReservedWordStrategy, + + /// Whether to generate relationship attributes for foreign keys. + pub generate_relationships: bool, + + /// Output mode: single file or multiple files. + pub output_mode: OutputMode, +} + +impl Default for PythonCodegenConfig { + fn default() -> Self { + Self { + generate_base_models: false, + module_prefix: None, + include_docstrings: true, + reserved_word_strategy: ReservedWordStrategy::default(), + generate_relationships: false, + output_mode: OutputMode::default(), + } + } +} + +/// Strategy for handling Python reserved words in identifiers. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum ReservedWordStrategy { + /// Append an underscore: `class` -> `class_` + #[default] + AppendUnderscore, + /// Prepend with prefix: `class` -> `field_class` + PrependPrefix(String), +} + +impl fmt::Display for ReservedWordStrategy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AppendUnderscore => write!(f, "append_underscore"), + Self::PrependPrefix(prefix) => write!(f, "prepend_prefix({prefix})"), + } + } +} + +/// Output mode for generated code. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum OutputMode { + /// Generate all models in a single `models.py` file. + #[default] + SingleFile, + /// Generate separate files for each model with a shared types module. + MultiFile, +} + +impl fmt::Display for OutputMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::SingleFile => write!(f, "single_file"), + Self::MultiFile => write!(f, "multi_file"), + } + } +} + +/// Errors that can occur during Python code generation. 
+#[derive(Debug, thiserror::Error)] +pub enum PythonCodegenError { + /// An unsupported PostgreSQL type was encountered. + #[error("unsupported PostgreSQL type: {0}")] + UnsupportedType(String), + + /// A table has no columns. + #[error("table has no columns: {0}")] + EmptyTable(String), + + /// An identifier is invalid after sanitization. + #[error("invalid identifier after sanitization: {0}")] + InvalidIdentifier(String), + + /// Code generation failed. + #[error("code generation failed: {0}")] + GenerationError(String), +} diff --git a/crates/codegen/src/python/model.rs b/crates/codegen/src/python/model.rs new file mode 100644 index 0000000..e2539fb --- /dev/null +++ b/crates/codegen/src/python/model.rs @@ -0,0 +1,490 @@ +//! SQLModel class generation. +//! +//! This module handles generating complete SQLModel class definitions from table schemas. + +use tern_ddl::{ConstraintKind, Table}; + +use super::PythonCodegenConfig; +use super::field::{FieldContext, FieldInfo, format_field_line, generate_field}; +use super::imports::ImportCollector; +use super::naming::to_class_name; + +/// Information about a generated SQLModel class. +#[derive(Debug)] +pub struct ModelInfo { + /// The Python class name. + pub class_name: String, + /// The database table name. + pub table_name: String, + /// Generated field information. + pub fields: Vec, + /// Required imports. + pub imports: ImportCollector, + /// Table args entries (for composite constraints, indexes, etc.). + pub table_args: Vec, + /// Docstring for the class (from table comment). + pub docstring: Option, + /// Whether this table has any columns. + pub has_columns: bool, + /// Warning messages for unsupported features. + pub warnings: Vec, +} + +/// Generates model information for a table. 
+pub fn generate_model(table: &Table, config: &PythonCodegenConfig) -> ModelInfo { + let table_name = table.name.as_ref().to_string(); + let class_name = to_class_name(&table_name); + + let mut imports = ImportCollector::new(); + let mut table_args = Vec::new(); + let mut warnings = Vec::new(); + + // Add base SQLModel import + imports.add_sqlmodel(); + + // Generate fields + let ctx = FieldContext::from_table(table, &config.reserved_word_strategy); + let mut fields: Vec = table + .columns + .iter() + .map(|col| { + let field = generate_field(col, &ctx); + imports.merge(&field.imports); + field + }) + .collect(); + + // Sort fields: primary keys first, then required fields, then optional fields + fields.sort_by(|a, b| { + // Primary keys come first + match (a.is_primary_key, b.is_primary_key) { + (true, false) => std::cmp::Ordering::Less, + (false, true) => std::cmp::Ordering::Greater, + _ => { + // Then non-nullable before nullable (by checking if type contains "| None") + let a_nullable = a.type_annotation.contains("| None"); + let b_nullable = b.type_annotation.contains("| None"); + match (a_nullable, b_nullable) { + (false, true) => std::cmp::Ordering::Less, + (true, false) => std::cmp::Ordering::Greater, + _ => std::cmp::Ordering::Equal, + } + } + } + }); + + // Generate __table_args__ for composite constraints and indexes + generate_table_args(table, &mut table_args, &mut imports, &mut warnings); + + // Generate docstring from table comment + let docstring = if config.include_docstrings { + table.comment.as_ref().map(|c| c.as_ref().to_string()) + } else { + None + }; + + // Check for empty table + let has_columns = !table.columns.is_empty(); + if !has_columns { + warnings.push(format!( + "Table '{}' has no columns. 
SQLModel requires at least one field.", + table_name + )); + } + + ModelInfo { + class_name, + table_name, + fields, + imports, + table_args, + docstring, + has_columns, + warnings, + } +} + +/// Generates __table_args__ entries for composite constraints and indexes. +fn generate_table_args( + table: &Table, + table_args: &mut Vec, + imports: &mut ImportCollector, + warnings: &mut Vec, +) { + // Check for composite primary key (already handled in fields via primary_key=True on each column) + // But if there's only one PK column with identity, it's handled differently + + // Composite unique constraints + for constraint in &table.constraints { + match &constraint.kind { + ConstraintKind::Unique(unique) if unique.columns.len() > 1 => { + let columns: Vec = unique + .columns + .iter() + .map(|c| format!("\"{}\"", c.as_ref())) + .collect(); + let constraint_name = constraint.name.as_ref(); + table_args.push(format!( + "UniqueConstraint({}, name=\"{}\")", + columns.join(", "), + constraint_name + )); + imports.add(&super::type_mapping::PythonImport::new( + "sqlalchemy", + "UniqueConstraint", + )); + } + ConstraintKind::Check(check) => { + let expr = check.expression.as_ref(); + let constraint_name = constraint.name.as_ref(); + table_args.push(format!( + "CheckConstraint(\"{}\", name=\"{}\")", + escape_python_string(expr), + constraint_name + )); + imports.add(&super::type_mapping::PythonImport::new( + "sqlalchemy", + "CheckConstraint", + )); + } + ConstraintKind::Exclusion(excl) => { + // Exclusion constraints are not supported by SQLModel + let constraint_name = constraint.name.as_ref(); + let elements: Vec = excl + .elements + .iter() + .map(|e| format!("{} WITH {}", e.expression.as_ref(), e.operator)) + .collect(); + warnings.push(format!( + "Exclusion constraint '{}' not supported by SQLModel: EXCLUDE USING {} ({})", + constraint_name, + excl.index_method.as_str(), + elements.join(", ") + )); + } + _ => {} + } + } + + // Composite indexes (non-constraint indexes 
with multiple columns) + for index in &table.indexes { + if !index.is_constraint_index && index.columns.len() > 1 { + let columns: Vec = index + .columns + .iter() + .filter_map(|ic| ic.column.as_ref().map(|c| format!("\"{}\"", c.as_ref()))) + .collect(); + + if !columns.is_empty() { + let index_name = index.name.as_ref(); + table_args.push(format!("Index(\"{}\", {})", index_name, columns.join(", "))); + imports.add(&super::type_mapping::PythonImport::new( + "sqlalchemy", + "Index", + )); + } + } + } +} + +/// Escapes a string for use in a Python string literal. +fn escape_python_string(s: &str) -> String { + s.replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('\n', "\\n") + .replace('\r', "\\r") + .replace('\t', "\\t") +} + +/// Formats a complete SQLModel class definition. +pub fn format_model(model: &ModelInfo) -> String { + let mut lines = Vec::new(); + + // Warning comments + for warning in &model.warnings { + lines.push(format!("# WARNING: {}", warning)); + } + + // Class definition + lines.push(format!("class {}(SQLModel, table=True):", model.class_name)); + + // Docstring + if let Some(ref doc) = model.docstring { + lines.push(format!(" \"\"\"{}\"\"\"", escape_python_string(doc))); + lines.push(String::new()); + } + + // __tablename__ + lines.push(format!(" __tablename__ = \"{}\"", model.table_name)); + + // __table_args__ + if !model.table_args.is_empty() { + if model.table_args.len() == 1 { + lines.push(format!(" __table_args__ = ({},)", model.table_args[0])); + } else { + lines.push(" __table_args__ = (".to_string()); + for arg in &model.table_args { + lines.push(format!(" {},", arg)); + } + lines.push(" )".to_string()); + } + } + + lines.push(String::new()); + + // Fields + if model.has_columns { + for field in &model.fields { + lines.push(format_field_line(field)); + } + } else { + lines.push(" pass # No columns defined".to_string()); + } + + lines.join("\n") +} + +#[cfg(test)] +mod tests { + use super::*; + use 
tern_ddl::types::{QualifiedCollationName, SqlExpr}; + use tern_ddl::{ + CheckConstraint, Column, ColumnName, Constraint, ConstraintName, IndexName, Oid, + PrimaryKeyConstraint, SchemaName, TableKind, TableName, TypeInfo, TypeName, + UniqueConstraint, + }; + + fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable, + default: None, + generated: None, + identity: None, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + tern_ddl::CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } + } + + fn make_table_with_pk(name: &str, columns: Vec, pk_columns: &[&str]) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{}_pkey", name)).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{}_pkey", name)).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } + } + + #[test] + fn test_generate_simple_model() { + let columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + make_column("email", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let config = PythonCodegenConfig::default(); + + let model = generate_model(&table, &config); + + assert_eq!(model.class_name, "User"); + assert_eq!(model.table_name, "users"); + assert_eq!(model.fields.len(), 3); + 
assert!(model.has_columns); + assert!(model.warnings.is_empty()); + } + + #[test] + fn test_generate_model_with_nullable() { + let columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + make_column("bio", "text", true), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let config = PythonCodegenConfig::default(); + + let model = generate_model(&table, &config); + + // Fields should be sorted: PK first, then non-nullable, then nullable + assert!(model.fields[0].is_primary_key); + assert!(!model.fields[1].type_annotation.contains("| None")); + assert!(model.fields[2].type_annotation.contains("| None")); + } + + #[test] + fn test_generate_model_with_composite_unique() { + let columns = vec![ + make_column("id", "int4", false), + make_column("email", "text", false), + make_column("tenant_id", "int4", false), + ]; + let pk = Constraint { + name: ConstraintName::try_new("users_pkey".to_string()).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + index_name: IndexName::try_new("users_pkey".to_string()).unwrap(), + }), + comment: None, + }; + let unique = Constraint { + name: ConstraintName::try_new("users_email_tenant_key".to_string()).unwrap(), + kind: ConstraintKind::Unique(UniqueConstraint { + columns: vec![ + ColumnName::try_new("email".to_string()).unwrap(), + ColumnName::try_new("tenant_id".to_string()).unwrap(), + ], + index_name: IndexName::try_new("users_email_tenant_key".to_string()).unwrap(), + nulls_not_distinct: false, + }), + comment: None, + }; + let table = Table { + oid: Oid::new(1), + name: TableName::try_new("users".to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk, unique], + indexes: vec![], + comment: None, + }; + let config = PythonCodegenConfig::default(); + + let model = generate_model(&table, &config); + + assert!(!model.table_args.is_empty()); + 
assert!(model.table_args[0].contains("UniqueConstraint")); + assert!(model.table_args[0].contains("email")); + assert!(model.table_args[0].contains("tenant_id")); + } + + #[test] + fn test_generate_model_with_check_constraint() { + let columns = vec![ + make_column("id", "int4", false), + make_column("price", "numeric", false), + ]; + let pk = Constraint { + name: ConstraintName::try_new("products_pkey".to_string()).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + index_name: IndexName::try_new("products_pkey".to_string()).unwrap(), + }), + comment: None, + }; + let check = Constraint { + name: ConstraintName::try_new("products_price_positive".to_string()).unwrap(), + kind: ConstraintKind::Check(CheckConstraint { + expression: SqlExpr::new("price > 0".to_string()), + is_no_inherit: false, + }), + comment: None, + }; + let table = Table { + oid: Oid::new(1), + name: TableName::try_new("products".to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk, check], + indexes: vec![], + comment: None, + }; + let config = PythonCodegenConfig::default(); + + let model = generate_model(&table, &config); + + assert!(!model.table_args.is_empty()); + assert!(model.table_args[0].contains("CheckConstraint")); + assert!(model.table_args[0].contains("price > 0")); + } + + #[test] + fn test_format_model() { + let columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let config = PythonCodegenConfig::default(); + + let model = generate_model(&table, &config); + let output = format_model(&model); + + assert!(output.contains("class User(SQLModel, table=True):")); + assert!(output.contains("__tablename__ = \"users\"")); + assert!(output.contains("id:")); + assert!(output.contains("name:")); + } + + #[test] + fn test_format_model_with_docstring() { + let columns = 
vec![make_column("id", "int4", false)]; + let mut table = make_table_with_pk("users", columns, &["id"]); + table.comment = Some(tern_ddl::types::Comment::new( + "User accounts table".to_string(), + )); + + let config = PythonCodegenConfig { + include_docstrings: true, + ..Default::default() + }; + + let model = generate_model(&table, &config); + let output = format_model(&model); + + assert!(output.contains("\"\"\"User accounts table\"\"\"")); + } + + #[test] + fn test_empty_table_warning() { + let table = Table { + oid: Oid::new(1), + name: TableName::try_new("empty".to_string()).unwrap(), + kind: TableKind::Regular, + columns: vec![], + constraints: vec![], + indexes: vec![], + comment: None, + }; + let config = PythonCodegenConfig::default(); + + let model = generate_model(&table, &config); + + assert!(!model.has_columns); + assert!(!model.warnings.is_empty()); + assert!(model.warnings[0].contains("no columns")); + } + + #[test] + fn test_escape_python_string() { + assert_eq!(escape_python_string("hello"), "hello"); + assert_eq!(escape_python_string("say \"hi\""), "say \\\"hi\\\""); + assert_eq!(escape_python_string("line1\nline2"), "line1\\nline2"); + assert_eq!(escape_python_string("path\\to\\file"), "path\\\\to\\\\file"); + } +} diff --git a/crates/codegen/src/python/naming.rs b/crates/codegen/src/python/naming.rs new file mode 100644 index 0000000..0679e1b --- /dev/null +++ b/crates/codegen/src/python/naming.rs @@ -0,0 +1,513 @@ +//! Python identifier naming and sanitization. +//! +//! This module handles conversion of PostgreSQL identifiers to valid Python identifiers, +//! including handling of reserved words, invalid characters, and naming conventions. + +use super::ReservedWordStrategy; + +/// Python reserved words (keywords). +/// +/// These cannot be used as identifiers in Python without modification. 
+const PYTHON_KEYWORDS: &[&str] = &[ + "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue", + "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import", + "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while", + "with", "yield", +]; + +/// Python soft keywords (context-dependent keywords introduced in Python 3.10+). +const PYTHON_SOFT_KEYWORDS: &[&str] = &["match", "case", "_", "type"]; + +/// Python built-in names that should be avoided to prevent shadowing. +/// Kept for future use to warn about shadowing built-in names. +#[allow(dead_code)] +const PYTHON_BUILTINS: &[&str] = &[ + "abs", + "all", + "any", + "ascii", + "bin", + "bool", + "breakpoint", + "bytearray", + "bytes", + "callable", + "chr", + "classmethod", + "compile", + "complex", + "delattr", + "dict", + "dir", + "divmod", + "enumerate", + "eval", + "exec", + "filter", + "float", + "format", + "frozenset", + "getattr", + "globals", + "hasattr", + "hash", + "help", + "hex", + "id", + "input", + "int", + "isinstance", + "issubclass", + "iter", + "len", + "list", + "locals", + "map", + "max", + "memoryview", + "min", + "next", + "object", + "oct", + "open", + "ord", + "pow", + "print", + "property", + "range", + "repr", + "reversed", + "round", + "set", + "setattr", + "slice", + "sorted", + "staticmethod", + "str", + "sum", + "super", + "tuple", + "type", + "vars", + "zip", +]; + +/// Checks if a name is a Python keyword. +pub fn is_python_keyword(name: &str) -> bool { + PYTHON_KEYWORDS.contains(&name) +} + +/// Checks if a name is a Python soft keyword. +pub fn is_python_soft_keyword(name: &str) -> bool { + PYTHON_SOFT_KEYWORDS.contains(&name) +} + +/// Checks if a name is a Python builtin. +/// Kept for future use to warn about shadowing built-in names. 
+#[allow(dead_code)] +pub fn is_python_builtin(name: &str) -> bool { + PYTHON_BUILTINS.contains(&name) +} + +/// Checks if a name is a reserved word (keyword or soft keyword). +pub fn is_reserved_word(name: &str) -> bool { + is_python_keyword(name) || is_python_soft_keyword(name) +} + +/// Sanitizes a database identifier for use as a Python identifier. +/// +/// This function: +/// 1. Replaces invalid characters with underscores +/// 2. Ensures the name doesn't start with a digit +/// 3. Handles reserved words according to the strategy +/// 4. Returns the sanitized name and whether aliasing is needed +/// +/// Returns `(sanitized_name, needs_alias)` where `needs_alias` is true if +/// the name was modified and needs a Field(alias="original_name") declaration. +pub fn sanitize_identifier(name: &str, strategy: &ReservedWordStrategy) -> (String, bool) { + if name.is_empty() { + return ("_empty".to_string(), true); + } + + let mut sanitized = String::with_capacity(name.len() + 1); + let mut needs_alias = false; + + // Process each character + for (i, c) in name.chars().enumerate() { + if i == 0 { + // First character must be a letter or underscore + if c.is_ascii_alphabetic() || c == '_' { + sanitized.push(c); + } else if c.is_ascii_digit() { + // Prefix with underscore if starts with digit + sanitized.push('_'); + sanitized.push(c); + needs_alias = true; + } else { + // Replace invalid first character with underscore + sanitized.push('_'); + needs_alias = true; + } + } else if c.is_ascii_alphanumeric() || c == '_' { + sanitized.push(c); + } else { + // Replace invalid characters with underscore + sanitized.push('_'); + needs_alias = true; + } + } + + // Collapse multiple consecutive underscores + let mut collapsed = String::with_capacity(sanitized.len()); + let mut prev_underscore = false; + for c in sanitized.chars() { + if c == '_' { + if !prev_underscore { + collapsed.push(c); + } else { + needs_alias = true; + } + prev_underscore = true; + } else { + 
collapsed.push(c); + prev_underscore = false; + } + } + sanitized = collapsed; + + // Remove trailing underscores (unless it's the only character) + while sanitized.len() > 1 && sanitized.ends_with('_') && !name.ends_with('_') { + sanitized.pop(); + needs_alias = true; + } + + // Handle reserved words + if is_reserved_word(&sanitized) { + needs_alias = true; + sanitized = apply_reserved_word_strategy(&sanitized, strategy); + } + + // Final validation - ensure we have a valid identifier + if sanitized.is_empty() || sanitized == "_" { + return ("_field".to_string(), true); + } + + (sanitized, needs_alias) +} + +/// Applies the reserved word strategy to transform a reserved word. +fn apply_reserved_word_strategy(name: &str, strategy: &ReservedWordStrategy) -> String { + match strategy { + ReservedWordStrategy::AppendUnderscore => format!("{name}_"), + ReservedWordStrategy::PrependPrefix(prefix) => format!("{prefix}{name}"), + } +} + +/// Converts a table name to a Python class name (PascalCase). 
///
/// # Examples
///
/// - "users" -> "User"
/// - "user_accounts" -> "UserAccount"
/// - "order_items" -> "OrderItem"
pub fn to_class_name(table_name: &str) -> String {
    // Irregular plural first: "ies" -> "y" (e.g. "categories" -> "Category").
    if table_name.ends_with("ies") && table_name.len() > 3 {
        return to_pascal_case(&table_name[..table_name.len() - 3]) + "y";
    }

    // Simple plural: strip one trailing 's', but leave "ss"/"us"/"is"
    // endings alone ("address", "status", "analysis" are not plurals).
    let singular = if table_name.ends_with('s')
        && !table_name.ends_with("ss")
        && !table_name.ends_with("us")
        && !table_name.ends_with("is")
    {
        &table_name[..table_name.len() - 1]
    } else {
        table_name
    };

    // Delegate the case conversion instead of duplicating the PascalCase
    // loop inline (the original re-implemented `to_pascal_case` verbatim).
    let class_name = to_pascal_case(singular);
    if class_name.is_empty() {
        // Degenerate input: empty name or separators only.
        "Model".to_string()
    } else {
        class_name
    }
}

/// Converts a string to PascalCase.
///
/// `_`, `-`, and spaces are treated as word separators; the first letter of
/// each word is uppercased and all other characters are lowercased.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut at_word_start = true;

    for c in s.chars() {
        if matches!(c, '_' | '-' | ' ') {
            at_word_start = true;
        } else if at_word_start {
            out.push(c.to_ascii_uppercase());
            at_word_start = false;
        } else {
            out.push(c.to_ascii_lowercase());
        }
    }

    out
}

// (summary doc for `to_module_name`, which follows:)
// Converts a table name to a Python module name (snake_case, singular).
///
/// # Examples
///
/// - "users" -> "user"
/// - "user_accounts" -> "user_account"
/// - "OrderItems" -> "order_item"
/// - "categories" -> "category"
pub fn to_module_name(table_name: &str) -> String {
    // Normalize to snake_case: lowercase everything, inserting '_' before the
    // first uppercase letter of each CamelCase hump; map '-'/' ' to '_'.
    let mut result = String::with_capacity(table_name.len() + 2);
    let mut prev_was_upper = false;

    for (i, c) in table_name.chars().enumerate() {
        if c.is_ascii_uppercase() {
            if i > 0 && !prev_was_upper {
                result.push('_');
            }
            result.push(c.to_ascii_lowercase());
            prev_was_upper = true;
        } else if c == '-' || c == ' ' {
            result.push('_');
            prev_was_upper = false;
        } else {
            result.push(c);
            prev_was_upper = false;
        }
    }

    // Singularize. The "ies" -> "y" case must be checked BEFORE the plain
    // trailing-'s' case: every word ending in "ies" also ends in 's' (and not
    // in "ss"/"us"/"is"), so with the original ordering the "ies" branch was
    // unreachable and "categories" became "categorie". Checking "ies" first
    // also mirrors the pluralization rules used by `to_class_name`.
    if result.ends_with("ies") && result.len() > 3 {
        result.truncate(result.len() - 3);
        result.push('y');
    } else if result.ends_with('s')
        && !result.ends_with("ss")
        && !result.ends_with("us")
        && !result.ends_with("is")
        && result.len() > 1
    {
        result.pop();
    }

    result
}

// (summary doc for `to_attribute_name`, which follows:)
// Converts a column name to a Python attribute name (snake_case).
// PostgreSQL column names are typically already in snake_case, but this
// handles edge cases like mixed case or special characters.
+pub fn to_attribute_name(column_name: &str, strategy: &ReservedWordStrategy) -> (String, bool) { + // First sanitize the identifier + let (sanitized, needs_alias) = sanitize_identifier(column_name, strategy); + + // Convert to snake_case if needed (handle camelCase input) + let mut result = String::with_capacity(sanitized.len() + 4); + let mut prev_was_upper = false; + let mut prev_was_underscore = true; // Treat start as after underscore + + for c in sanitized.chars() { + if c.is_ascii_uppercase() { + if !prev_was_upper && !prev_was_underscore { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + prev_was_upper = true; + prev_was_underscore = false; + } else if c == '_' { + if !prev_was_underscore { + result.push(c); + } + prev_was_upper = false; + prev_was_underscore = true; + } else { + result.push(c); + prev_was_upper = false; + prev_was_underscore = false; + } + } + + // Check if conversion changed the name + let conversion_changed = result != sanitized; + + (result, needs_alias || conversion_changed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_python_keyword() { + assert!(is_python_keyword("class")); + assert!(is_python_keyword("def")); + assert!(is_python_keyword("if")); + assert!(is_python_keyword("True")); + assert!(is_python_keyword("None")); + assert!(!is_python_keyword("user")); + assert!(!is_python_keyword("name")); + } + + #[test] + fn test_is_python_soft_keyword() { + assert!(is_python_soft_keyword("match")); + assert!(is_python_soft_keyword("case")); + assert!(!is_python_soft_keyword("class")); + } + + #[test] + fn test_is_reserved_word() { + assert!(is_reserved_word("class")); + assert!(is_reserved_word("match")); + assert!(!is_reserved_word("user")); + } + + #[test] + fn test_sanitize_identifier_valid() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("user_id", &strategy); + assert_eq!(name, "user_id"); + assert!(!needs_alias); + } + + #[test] + 
fn test_sanitize_identifier_reserved_word() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("class", &strategy); + assert_eq!(name, "class_"); + assert!(needs_alias); + } + + #[test] + fn test_sanitize_identifier_reserved_word_prefix() { + let strategy = ReservedWordStrategy::PrependPrefix("field_".to_string()); + let (name, needs_alias) = sanitize_identifier("class", &strategy); + assert_eq!(name, "field_class"); + assert!(needs_alias); + } + + #[test] + fn test_sanitize_identifier_starts_with_digit() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("1column", &strategy); + assert_eq!(name, "_1column"); + assert!(needs_alias); + } + + #[test] + fn test_sanitize_identifier_special_characters() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("column-name", &strategy); + assert_eq!(name, "column_name"); + assert!(needs_alias); + } + + #[test] + fn test_sanitize_identifier_spaces() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("column name", &strategy); + assert_eq!(name, "column_name"); + assert!(needs_alias); + } + + #[test] + fn test_sanitize_identifier_multiple_underscores() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("column__name", &strategy); + assert_eq!(name, "column_name"); + assert!(needs_alias); + } + + #[test] + fn test_sanitize_identifier_empty() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = sanitize_identifier("", &strategy); + assert_eq!(name, "_empty"); + assert!(needs_alias); + } + + #[test] + fn test_to_class_name_simple() { + assert_eq!(to_class_name("users"), "User"); + assert_eq!(to_class_name("user"), "User"); + } + + #[test] + fn test_to_class_name_compound() { + assert_eq!(to_class_name("user_accounts"), 
"UserAccount"); + assert_eq!(to_class_name("order_items"), "OrderItem"); + } + + #[test] + fn test_to_class_name_categories() { + assert_eq!(to_class_name("categories"), "Category"); + } + + #[test] + fn test_to_class_name_preserves_non_plural() { + assert_eq!(to_class_name("status"), "Status"); + assert_eq!(to_class_name("address"), "Address"); + } + + #[test] + fn test_to_module_name_simple() { + assert_eq!(to_module_name("users"), "user"); + assert_eq!(to_module_name("User"), "user"); + } + + #[test] + fn test_to_module_name_compound() { + assert_eq!(to_module_name("user_accounts"), "user_account"); + assert_eq!(to_module_name("OrderItems"), "order_item"); + } + + #[test] + fn test_to_attribute_name_simple() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = to_attribute_name("user_id", &strategy); + assert_eq!(name, "user_id"); + assert!(!needs_alias); + } + + #[test] + fn test_to_attribute_name_camel_case() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = to_attribute_name("userId", &strategy); + assert_eq!(name, "user_id"); + assert!(needs_alias); + } + + #[test] + fn test_to_attribute_name_reserved() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (name, needs_alias) = to_attribute_name("from", &strategy); + assert_eq!(name, "from_"); + assert!(needs_alias); + } + + #[test] + fn test_python_builtins() { + assert!(is_python_builtin("list")); + assert!(is_python_builtin("dict")); + assert!(is_python_builtin("str")); + assert!(is_python_builtin("int")); + assert!(!is_python_builtin("user")); + } +} diff --git a/crates/codegen/src/python/tests/mod.rs b/crates/codegen/src/python/tests/mod.rs new file mode 100644 index 0000000..c084c98 --- /dev/null +++ b/crates/codegen/src/python/tests/mod.rs @@ -0,0 +1,7 @@ +//! Tests for Python SQLModel code generation. +//! +//! This module contains comprehensive tests including unit tests, snapshot tests, +//! 
and edge case tests for the Python code generator.

mod snapshot_tests;
mod unit_tests;
diff --git a/crates/codegen/src/python/tests/snapshot_tests.rs b/crates/codegen/src/python/tests/snapshot_tests.rs new file mode 100644 index 0000000..2483886 --- /dev/null +++ b/crates/codegen/src/python/tests/snapshot_tests.rs @@ -0,0 +1,577 @@
//! Snapshot tests for Python code generation.
//!
//! These tests use insta to capture and verify the full output of the code generator,
//! making it easy to review changes to the generated code format.

use crate::Codegen;
use crate::python::{OutputMode, PythonCodegen, PythonCodegenConfig};

use tern_ddl::types::{
    Comment, ForeignKeyAction, IndexMethod, QualifiedCollationName, QualifiedName, SqlExpr,
};
use tern_ddl::{
    CheckConstraint, CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName,
    ExclusionConstraint, ExclusionElement, ForeignKeyConstraint, IdentityKind, IndexName, Oid,
    PrimaryKeyConstraint, SchemaName, Table, TableKind, TableName, TypeInfo, TypeName,
    UniqueConstraint,
};

// =============================================================================
// Test Helpers
// =============================================================================

// Builds a TypeInfo in the pg_catalog schema with the given formatted type
// string (e.g. "character varying(255)") and array flag.
fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo {
    TypeInfo {
        name: TypeName::try_new(name.to_string()).unwrap(),
        schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
        formatted: formatted.to_string(),
        is_array,
    }
}

// Builds a plain column: no default, not generated, no identity, default
// collation, no comment. Tests mutate the returned value for special cases.
fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
    Column {
        name: ColumnName::try_new(name.to_string()).unwrap(),
        position: 1,
        type_info: make_type_info(type_name, type_name, false),
        is_nullable,
        default: None,
        generated: None,
        identity: None,
        collation: QualifiedCollationName::new(
            SchemaName::try_new("pg_catalog".to_string()).unwrap(),
            CollationName::try_new("default".to_string()).unwrap(),
        ),
        comment: None,
    }
}

// Builds a PRIMARY KEY constraint named "<table>_pkey" over the given columns.
fn make_pk_constraint(table_name: &str, columns: &[&str]) -> Constraint {
    Constraint {
        name: ConstraintName::try_new(format!("{}_pkey", table_name)).unwrap(),
        kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
            columns: columns
                .iter()
                .map(|c| ColumnName::try_new(c.to_string()).unwrap())
                .collect(),
            index_name: IndexName::try_new(format!("{}_pkey", table_name)).unwrap(),
        }),
        comment: None,
    }
}

// Builds a regular table with the given columns/constraints and no indexes.
// NOTE(review): the generic parameters of this signature appear stripped by
// extraction; presumably `columns: Vec<Column>, constraints: Vec<Constraint>`
// -- confirm against the repository.
fn make_table(name: &str, columns: Vec, constraints: Vec) -> Table {
    Table {
        oid: Oid::new(1),
        name: TableName::try_new(name.to_string()).unwrap(),
        kind: TableKind::Regular,
        columns,
        constraints,
        indexes: vec![],
        comment: None,
    }
}

// =============================================================================
// Snapshot Tests
// =============================================================================

// Baseline: identity PK, unique column, nullable column, table comment.
#[test]
fn snapshot_simple_users_table() {
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("email", "text", false),
        make_column("name", "text", true),
        {
            let mut col = make_column("created_at", "timestamptz", false);
            col.type_info = make_type_info("timestamptz", "timestamp with time zone", false);
            col
        },
    ];

    let constraints = vec![
        make_pk_constraint("users", &["id"]),
        Constraint {
            name: ConstraintName::try_new("users_email_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![ColumnName::try_new("email".to_string()).unwrap()],
                index_name: IndexName::try_new("users_email_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
    ];

    let mut table = make_table("users", columns, constraints);
    table.comment = Some(Comment::new(
        "User accounts for the application.".to_string(),
    ));

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// Multi-table schema with cross-table foreign keys (users <- posts <- comments).
#[test]
fn snapshot_blog_schema() {
    // Users table
    let user_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("username", "text", false),
        make_column("email", "text", false),
        make_column("bio", "text", true),
        make_column("created_at", "timestamptz", false),
    ];
    let user_constraints = vec![
        make_pk_constraint("users", &["id"]),
        Constraint {
            name: ConstraintName::try_new("users_username_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![ColumnName::try_new("username".to_string()).unwrap()],
                index_name: IndexName::try_new("users_username_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("users_email_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![ColumnName::try_new("email".to_string()).unwrap()],
                index_name: IndexName::try_new("users_email_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
    ];
    let users = make_table("users", user_columns, user_constraints);

    // Posts table
    let post_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("title", "text", false),
        make_column("content", "text", false),
        make_column("author_id", "int4", false),
        make_column("published_at", "timestamptz", true),
        make_column("created_at", "timestamptz", false),
    ];
    let post_constraints = vec![
        make_pk_constraint("posts", &["id"]),
        Constraint {
            name: ConstraintName::try_new("posts_author_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("author_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("users".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::Cascade,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
    ];
    let posts = make_table("posts", post_columns, post_constraints);

    // Comments table
    let comment_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("content", "text", false),
        make_column("post_id", "int4", false),
        make_column("author_id", "int4", false),
        make_column("created_at", "timestamptz", false),
    ];
    let comment_constraints = vec![
        make_pk_constraint("comments", &["id"]),
        Constraint {
            name: ConstraintName::try_new("comments_post_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("post_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("posts".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::Cascade,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("comments_author_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("author_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("users".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::Cascade,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
    ];
    let comments = make_table("comments", comment_columns, comment_constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![users, posts, comments]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// Exercises the full type-mapping table (ints, floats, numeric, bool, text,
// varchar with length, date/time, uuid, json(b), bytea, inet, arrays).
#[test]
fn snapshot_all_postgres_types() {
    let columns = vec![
        // Integer types
        make_column("col_int2", "int2", false),
        make_column("col_int4", "int4", false),
        make_column("col_int8", "int8", false),
        // Float types
        make_column("col_float4", "float4", false),
        make_column("col_float8", "float8", false),
        // Decimal
        {
            let mut col = make_column("col_numeric", "numeric", false);
            col.type_info = make_type_info("numeric", "numeric(10,2)", false);
            col
        },
        // Boolean
        make_column("col_bool", "bool", false),
        // Text types
        make_column("col_text", "text", false),
        {
            let mut col = make_column("col_varchar", "varchar", false);
            col.type_info = make_type_info("varchar", "character varying(255)", false);
            col
        },
        // Date/time types
        make_column("col_date", "date", false),
        make_column("col_time", "time", false),
        make_column("col_timestamp", "timestamp", false),
        make_column("col_timestamptz", "timestamptz", false),
        make_column("col_interval", "interval", false),
        // UUID
        make_column("col_uuid", "uuid", false),
        // JSON
        make_column("col_json", "json", true),
        make_column("col_jsonb", "jsonb", true),
        // Binary
        make_column("col_bytea", "bytea", true),
        // Network types
        make_column("col_inet", "inet", true),
        // Array types
        {
            let mut col = make_column("col_int_array", "int4", false);
            col.type_info = make_type_info("int4", "integer[]", true);
            col
        },
        {
            let mut col = make_column("col_text_array", "text", true);
            col.type_info = make_type_info("text", "text[]", true);
            col
        },
    ];

    // Add a simple PK
    let mut all_columns = vec![{
        let mut col = make_column("id", "int4", false);
        col.identity = Some(IdentityKind::Always);
        col
    }];
    all_columns.extend(columns);

    let constraints = vec![make_pk_constraint("all_types", &["id"])];
    let table = make_table("all_types", all_columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// Every column name is a Python (soft) keyword; verifies the alias mechanism.
#[test]
fn snapshot_reserved_words_table() {
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("class", "text", false),
        make_column("from", "text", false),
        make_column("import", "text", true),
        make_column("def", "int4", false),
        make_column("return", "bool", false),
        make_column("yield", "text", true),
        make_column("async", "bool", false),
        make_column("await", "text", true),
        make_column("match", "text", true),
        make_column("case", "int4", true),
    ];

    let constraints = vec![make_pk_constraint("reserved_words", &["id"])];
    let table = make_table("reserved_words", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// Composite unique + multiple CHECK constraints -> __table_args__ generation.
#[test]
fn snapshot_complex_constraints() {
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("email", "text", false),
        make_column("tenant_id", "int4", false),
        make_column("status", "text", false),
        make_column("price", "numeric", false),
        make_column("quantity", "int4", false),
    ];

    let constraints = vec![
        make_pk_constraint("products", &["id"]),
        // Composite unique
        Constraint {
            name: ConstraintName::try_new("products_email_tenant_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![
                    ColumnName::try_new("email".to_string()).unwrap(),
                    ColumnName::try_new("tenant_id".to_string()).unwrap(),
                ],
                index_name: IndexName::try_new("products_email_tenant_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
        // Check constraints
        Constraint {
            name: ConstraintName::try_new("products_price_positive".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("price > 0".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("products_quantity_non_negative".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("quantity >= 0".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("products_status_valid".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("status IN ('draft', 'active', 'archived')".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        },
    ];

    let table = make_table("products", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// FK pointing back at the same table (employees.manager_id -> employees.id).
#[test]
fn snapshot_self_referential_table() {
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("name", "text", false),
        make_column("manager_id", "int4", true),
    ];

    let constraints = vec![
        make_pk_constraint("employees", &["id"]),
        Constraint {
            name: ConstraintName::try_new("employees_manager_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("manager_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("employees".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::SetNull,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
    ];

    let table = make_table("employees", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// OutputMode::MultiFile: one module per table plus an __init__.py re-export.
#[test]
fn snapshot_multi_file_output() {
    let user_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("name", "text", false),
        make_column("email", "text", false),
    ];
    let users = make_table(
        "users",
        user_columns,
        vec![make_pk_constraint("users", &["id"])],
    );

    let post_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("title", "text", false),
        make_column("user_id", "int4", false),
    ];
    let posts = make_table(
        "posts",
        post_columns,
        vec![
            make_pk_constraint("posts", &["id"]),
            Constraint {
                name: ConstraintName::try_new("posts_user_id_fkey".to_string()).unwrap(),
                kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                    columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()],
                    referenced_table: QualifiedName::new(
                        SchemaName::try_new("public".to_string()).unwrap(),
                        TableName::try_new("users".to_string()).unwrap(),
                    ),
                    referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                    on_delete: ForeignKeyAction::Cascade,
                    on_update: ForeignKeyAction::NoAction,
                    is_deferrable: false,
                    is_initially_deferred: false,
                }),
                comment: None,
            },
        ],
    );

    let config = PythonCodegenConfig {
        output_mode: OutputMode::MultiFile,
        ..Default::default()
    };
    let codegen = PythonCodegen::new(config);
    let output = codegen.generate(vec![users, posts]);

    // Snapshot each file separately
    insta::assert_snapshot!("multi_file_init", output.get("__init__.py").unwrap());
    insta::assert_snapshot!("multi_file_user", output.get("user.py").unwrap());
    insta::assert_snapshot!("multi_file_post", output.get("post.py").unwrap());
}

// EXCLUDE constraints have no SQLModel equivalent; the generator should emit
// a warning comment instead of silently dropping them.
#[test]
fn snapshot_exclusion_constraint_warning() {
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("room_id", "int4", false),
        make_column("start_time", "timestamptz", false),
        make_column("end_time", "timestamptz", false),
    ];

    let constraints = vec![
        make_pk_constraint("meetings", &["id"]),
        Constraint {
            name: ConstraintName::try_new("meetings_no_overlap".to_string()).unwrap(),
            kind: ConstraintKind::Exclusion(ExclusionConstraint {
                elements: vec![
                    ExclusionElement {
                        expression: SqlExpr::new("room_id".to_string()),
                        operator: "=".to_string(),
                    },
                    ExclusionElement {
                        expression: SqlExpr::new("tsrange(start_time, end_time)".to_string()),
                        operator: "&&".to_string(),
                    },
                ],
                index_method: IndexMethod::Gist,
                index_name: IndexName::try_new("meetings_no_overlap_idx".to_string()).unwrap(),
                predicate: None,
            }),
            comment: None,
        },
    ];

    let table = make_table("meetings", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}

// Composite PK: no identity column; both key parts marked primary_key=True.
#[test]
fn snapshot_composite_primary_key() {
    let columns = vec![
        make_column("order_id", "int4", false),
        make_column("product_id", "int4", false),
        make_column("quantity", "int4", false),
        make_column("unit_price", "numeric", false),
    ];

    let constraints = vec![make_pk_constraint(
        "order_items",
        &["order_id", "product_id"],
    )];

    let table = make_table("order_items", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_init.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_init.snap new file mode 100644 index 0000000..eb8ed14 --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_init.snap @@ -0,0 +1,14 @@
---
source: crates/codegen/src/python/tests/snapshot_tests.rs
assertion_line: 508
expression: "output.get(\"__init__.py\").unwrap()"
---
"""SQLModel definitions generated by Tern.

This file was automatically generated. Do not edit manually.
"""

from .user import User
from .post import Post

__all__ = ["User", "Post"]
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_post.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_post.snap new file mode 100644 index 0000000..d0c2f65 --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_post.snap @@ -0,0 +1,15 @@
---
source: crates/codegen/src/python/tests/snapshot_tests.rs
assertion_line: 510
expression: "output.get(\"post.py\").unwrap()"
---
"""SQLModel definition for Post."""

from sqlmodel import Field, SQLModel

class Post(SQLModel, table=True):
    __tablename__ = "posts"

    id: int | None = Field(default=None, primary_key=True)
    title: str
    user_id: int = Field(foreign_key="public.users.id")
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_user.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_user.snap new file mode 100644 index 0000000..d37902a --- /dev/null +++ 
b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_user.snap @@ -0,0 +1,15 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 509 +expression: "output.get(\"user.py\").unwrap()" +--- +"""SQLModel definition for User.""" + +from sqlmodel import Field, SQLModel + +class User(SQLModel, table=True): + __tablename__ = "users" + + id: int | None = Field(default=None, primary_key=True) + name: str + email: str diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_all_postgres_types.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_all_postgres_types.snap new file mode 100644 index 0000000..efe8aed --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_all_postgres_types.snap @@ -0,0 +1,43 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 313 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import Any +from uuid import UUID + +from sqlalchemy import ARRAY, Integer, JSON, Text +from sqlmodel import Field, SQLModel + +class AllType(SQLModel, table=True): + __tablename__ = "all_types" + + id: int | None = Field(default=None, primary_key=True) + col_int2: int + col_int4: int + col_int8: int + col_float4: float + col_float8: float + col_numeric: Decimal + col_bool: bool + col_text: str + col_varchar: str = Field(max_length=255) + col_date: date + col_time: time + col_timestamp: datetime + col_timestamptz: datetime + col_interval: timedelta + col_uuid: UUID + col_int_array: list[int] = Field(sa_type=ARRAY(Integer)) + col_json: dict[str, Any] | None = Field(sa_type=JSON) + col_jsonb: dict[str, Any] | None = Field(sa_type=JSON) + col_bytea: bytes | None = None + col_inet: str | None = None + col_text_array: list[str] | None = Field(sa_type=ARRAY(Text)) diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_blog_schema.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_blog_schema.snap new file mode 100644 index 0000000..5f8154f --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_blog_schema.snap @@ -0,0 +1,43 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 243 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from datetime import datetime + +from sqlmodel import Field, SQLModel + +class User(SQLModel, table=True): + __tablename__ = "users" + + id: int | None = Field(default=None, primary_key=True) + username: str = Field(unique=True) + email: str = Field(unique=True) + created_at: datetime + bio: str | None = None + + +class Post(SQLModel, table=True): + __tablename__ = "posts" + + id: int | None = Field(default=None, primary_key=True) + title: str + content: str + author_id: int = Field(foreign_key="public.users.id") + created_at: datetime + published_at: datetime | None = None + + +class Comment(SQLModel, table=True): + __tablename__ = "comments" + + id: int | None = Field(default=None, primary_key=True) + content: str + post_id: int = Field(foreign_key="public.posts.id") + author_id: int = Field(foreign_key="public.users.id") + created_at: datetime diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_complex_constraints.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_complex_constraints.snap new file mode 100644 index 0000000..9d58c48 --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_complex_constraints.snap @@ -0,0 +1,30 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 407 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from decimal import Decimal + +from sqlalchemy import CheckConstraint, UniqueConstraint +from sqlmodel import Field, SQLModel + +class Product(SQLModel, table=True): + __tablename__ = "products" + __table_args__ = ( + UniqueConstraint("email", "tenant_id", name="products_email_tenant_key"), + CheckConstraint("price > 0", name="products_price_positive"), + CheckConstraint("quantity >= 0", name="products_quantity_non_negative"), + CheckConstraint("status IN ('draft', 'active', 'archived')", name="products_status_valid"), + ) + + id: int | None = Field(default=None, primary_key=True) + email: str + tenant_id: int + status: str + price: Decimal + quantity: int diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_composite_primary_key.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_composite_primary_key.snap new file mode 100644 index 0000000..748285d --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_composite_primary_key.snap @@ -0,0 +1,21 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 576 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from decimal import Decimal + +from sqlmodel import Field, SQLModel + +class OrderItem(SQLModel, table=True): + __tablename__ = "order_items" + + order_id: int = Field(primary_key=True) + product_id: int = Field(primary_key=True) + quantity: int + unit_price: Decimal diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_exclusion_constraint_warning.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_exclusion_constraint_warning.snap new file mode 100644 index 0000000..e14feb2 --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_exclusion_constraint_warning.snap @@ -0,0 +1,22 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 554 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from datetime import datetime + +from sqlmodel import Field, SQLModel + +# WARNING: Exclusion constraint 'meetings_no_overlap' not supported by SQLModel: EXCLUDE USING gist (room_id WITH =, tsrange(start_time, end_time) WITH &&) +class Meeting(SQLModel, table=True): + __tablename__ = "meetings" + + id: int | None = Field(default=None, primary_key=True) + room_id: int + start_time: datetime + end_time: datetime diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_reserved_words_table.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_reserved_words_table.snap new file mode 100644 index 0000000..0a579cf --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_reserved_words_table.snap @@ -0,0 +1,26 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 342 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from sqlmodel import Field, SQLModel + +class ReservedWord(SQLModel, table=True): + __tablename__ = "reserved_words" + + id: int | None = Field(default=None, primary_key=True) + class_: str = Field(alias="class") + from_: str = Field(alias="from") + def_: int = Field(alias="def") + return_: bool = Field(alias="return") + async_: bool = Field(alias="async") + import_: str | None = Field(alias="import") + yield_: str | None = Field(alias="yield") + await_: str | None = Field(alias="await") + match_: str | None = Field(alias="match") + case_: int | None = Field(alias="case") diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_self_referential_table.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_self_referential_table.snap new file mode 100644 index 0000000..5e89dec --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_self_referential_table.snap @@ -0,0 +1,18 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 447 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. 
+""" + +from sqlmodel import Field, SQLModel + +class Employee(SQLModel, table=True): + __tablename__ = "employees" + + id: int | None = Field(default=None, primary_key=True) + name: str + manager_id: int | None = Field(foreign_key="public.employees.id") diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_simple_users_table.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_simple_users_table.snap new file mode 100644 index 0000000..6b3359e --- /dev/null +++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_simple_users_table.snap @@ -0,0 +1,23 @@ +--- +source: crates/codegen/src/python/tests/snapshot_tests.rs +assertion_line: 117 +expression: "output.get(\"models.py\").unwrap()" +--- +"""SQLModel definitions generated by Tern. + +This file was automatically generated. Do not edit manually. +""" + +from datetime import datetime + +from sqlmodel import Field, SQLModel + +class User(SQLModel, table=True): + """User accounts for the application.""" + + __tablename__ = "users" + + id: int | None = Field(default=None, primary_key=True) + email: str = Field(unique=True) + created_at: datetime + name: str | None = None diff --git a/crates/codegen/src/python/tests/unit_tests.rs b/crates/codegen/src/python/tests/unit_tests.rs new file mode 100644 index 0000000..38bcd3f --- /dev/null +++ b/crates/codegen/src/python/tests/unit_tests.rs @@ -0,0 +1,707 @@ +//! Unit tests for Python code generation components. 
+ +use crate::Codegen; +use crate::python::naming::{ + is_python_keyword, is_reserved_word, sanitize_identifier, to_class_name, to_module_name, +}; +use crate::python::type_mapping::{ + extract_numeric_precision, extract_string_length, is_fixed_length_char, map_pg_type, +}; +use crate::python::{OutputMode, PythonCodegenConfig, ReservedWordStrategy}; + +use tern_ddl::types::{ForeignKeyAction, QualifiedCollationName, QualifiedName, SqlExpr}; +use tern_ddl::{ + CheckConstraint, CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName, + ForeignKeyConstraint, IdentityKind, IndexName, Oid, PrimaryKeyConstraint, SchemaName, Table, + TableKind, TableName, TypeInfo, TypeName, UniqueConstraint, +}; + +// ============================================================================= +// Test Helpers +// ============================================================================= + +fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo { + TypeInfo { + name: TypeName::try_new(name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: formatted.to_string(), + is_array, + } +} + +fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: make_type_info(type_name, type_name, false), + is_nullable, + default: None, + generated: None, + identity: None, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } +} + +fn make_column_with_identity(name: &str, type_name: &str, identity: IdentityKind) -> Column { + let mut col = make_column(name, type_name, false); + col.identity = Some(identity); + col +} + +fn make_table(name: &str, columns: Vec, constraints: Vec) -> Table { + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + 
kind: TableKind::Regular, + columns, + constraints, + indexes: vec![], + comment: None, + } +} + +fn make_pk_constraint(table_name: &str, columns: &[&str]) -> Constraint { + Constraint { + name: ConstraintName::try_new(format!("{}_pkey", table_name)).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{}_pkey", table_name)).unwrap(), + }), + comment: None, + } +} + +fn make_fk_constraint( + name: &str, + columns: &[&str], + ref_table: &str, + ref_columns: &[&str], +) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::ForeignKey(ForeignKeyConstraint { + columns: columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + referenced_table: QualifiedName::new( + SchemaName::try_new("public".to_string()).unwrap(), + TableName::try_new(ref_table.to_string()).unwrap(), + ), + referenced_columns: ref_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + on_delete: ForeignKeyAction::NoAction, + on_update: ForeignKeyAction::NoAction, + is_deferrable: false, + is_initially_deferred: false, + }), + comment: None, + } +} + +fn make_unique_constraint(name: &str, columns: &[&str]) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::Unique(UniqueConstraint { + columns: columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(name.to_string()).unwrap(), + nulls_not_distinct: false, + }), + comment: None, + } +} + +fn make_check_constraint(name: &str, expression: &str) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::Check(CheckConstraint { + expression: SqlExpr::new(expression.to_string()), + is_no_inherit: 
false, + }), + comment: None, + } +} + +// ============================================================================= +// Type Mapping Tests +// ============================================================================= + +#[test] +fn test_all_integer_types() { + for type_name in ["int2", "int4", "int8", "smallint", "integer", "bigint"] { + let ty = make_type_info(type_name, type_name, false); + let py = map_pg_type(&ty); + assert_eq!(py.annotation, "int", "Failed for type: {}", type_name); + } +} + +#[test] +fn test_serial_types() { + for type_name in [ + "serial", + "serial2", + "serial4", + "serial8", + "smallserial", + "bigserial", + ] { + let ty = make_type_info(type_name, type_name, false); + let py = map_pg_type(&ty); + assert_eq!(py.annotation, "int", "Failed for type: {}", type_name); + } +} + +#[test] +fn test_all_float_types() { + for type_name in ["float4", "float8", "real", "double precision"] { + let ty = make_type_info(type_name, type_name, false); + let py = map_pg_type(&ty); + assert_eq!(py.annotation, "float", "Failed for type: {}", type_name); + } +} + +#[test] +fn test_all_text_types() { + for type_name in ["text", "varchar", "char", "character", "bpchar", "name"] { + let ty = make_type_info(type_name, type_name, false); + let py = map_pg_type(&ty); + assert_eq!(py.annotation, "str", "Failed for type: {}", type_name); + } +} + +#[test] +fn test_all_datetime_types() { + let cases = [ + ("date", "date"), + ("time", "time"), + ("timetz", "time"), + ("timestamp", "datetime"), + ("timestamptz", "datetime"), + ("interval", "timedelta"), + ]; + + for (pg_type, expected) in cases { + let ty = make_type_info(pg_type, pg_type, false); + let py = map_pg_type(&ty); + assert_eq!( + py.annotation, expected, + "Failed for type: {} -> expected {}", + pg_type, expected + ); + } +} + +#[test] +fn test_json_has_sa_type() { + for type_name in ["json", "jsonb"] { + let ty = make_type_info(type_name, type_name, false); + let py = map_pg_type(&ty); + 
assert_eq!(py.sa_type, Some("JSON".to_string())); + assert!(py.sa_imports.iter().any(|i| i.name == "JSON")); + } +} + +#[test] +fn test_array_types() { + let cases = [ + ("int4", "integer[]", "list[int]"), + ("text", "text[]", "list[str]"), + ("uuid", "uuid[]", "list[UUID]"), + ]; + + for (type_name, formatted, expected) in cases { + let ty = make_type_info(type_name, formatted, true); + let py = map_pg_type(&ty); + assert_eq!(py.annotation, expected); + assert!(py.sa_type.is_some()); + assert!(py.sa_type.as_ref().unwrap().contains("ARRAY")); + } +} + +// ============================================================================= +// Naming Tests +// ============================================================================= + +#[test] +fn test_all_python_keywords() { + let keywords = [ + "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", + "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global", + "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", + "try", "while", "with", "yield", + ]; + + for kw in keywords { + assert!(is_python_keyword(kw), "Expected '{}' to be a keyword", kw); + assert!(is_reserved_word(kw), "Expected '{}' to be reserved", kw); + } +} + +#[test] +fn test_class_name_pluralization() { + let cases = [ + ("users", "User"), + ("categories", "Category"), + ("companies", "Company"), + ("addresses", "Addresse"), // 'addresses' -> singular is 'addresse' with simple rule + ("status", "Status"), // ends in 's' but 'us' ending preserved + ("analyses", "Analyse"), // ends in 's' but 'is' ending requires special handling + ("classes", "Classe"), // double 's' preserved + ]; + + for (input, expected) in cases { + let result = to_class_name(input); + assert_eq!( + result, expected, + "to_class_name('{}') = '{}', expected '{}'", + input, result, expected + ); + } +} + +#[test] +fn test_class_name_snake_to_pascal() { + let cases = [ + ("user_accounts", 
"UserAccount"), + ("order_items", "OrderItem"), + ("api_keys", "ApiKey"), + ("http_requests", "HttpRequest"), + ]; + + for (input, expected) in cases { + let result = to_class_name(input); + assert_eq!(result, expected); + } +} + +#[test] +fn test_module_name_conversion() { + let cases = [ + ("users", "user"), + ("UserAccounts", "user_account"), + ("order_items", "order_item"), + ("HTTPRequests", "httprequest"), // Consecutive caps become lowercase + ]; + + for (input, expected) in cases { + let result = to_module_name(input); + assert_eq!(result, expected); + } +} + +#[test] +fn test_sanitize_special_characters() { + let strategy = ReservedWordStrategy::AppendUnderscore; + + let cases = [ + ("column-name", "column_name", true), + ("column name", "column_name", true), + ("column.name", "column_name", true), + ("column@name", "column_name", true), + ("column#1", "column_1", true), + ]; + + for (input, expected, needs_alias) in cases { + let (result, alias) = sanitize_identifier(input, &strategy); + assert_eq!(result, expected, "sanitize('{}') = '{}'", input, result); + assert_eq!(alias, needs_alias); + } +} + +#[test] +fn test_sanitize_leading_digit() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let (result, needs_alias) = sanitize_identifier("1column", &strategy); + assert_eq!(result, "_1column"); + assert!(needs_alias); +} + +#[test] +fn test_reserved_word_strategies() { + let input = "class"; + + let (result, _) = sanitize_identifier(input, &ReservedWordStrategy::AppendUnderscore); + assert_eq!(result, "class_"); + + let (result, _) = sanitize_identifier( + input, + &ReservedWordStrategy::PrependPrefix("col_".to_string()), + ); + assert_eq!(result, "col_class"); +} + +// ============================================================================= +// Full Generation Tests +// ============================================================================= + +#[test] +fn test_generate_simple_users_table() { + use crate::python::PythonCodegen; + + let 
columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + make_column("email", "text", false), + make_column("bio", "text", true), + ]; + let constraints = vec![ + make_pk_constraint("users", &["id"]), + make_unique_constraint("users_email_key", &["email"]), + ]; + let table = make_table("users", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + // Check class definition + assert!(content.contains("class User(SQLModel, table=True):")); + assert!(content.contains("__tablename__ = \"users\"")); + + // Check imports + assert!(content.contains("from sqlmodel import")); + assert!(content.contains("SQLModel")); + assert!(content.contains("Field")); + + // Check fields + assert!(content.contains("id:")); + assert!(content.contains("name: str")); + assert!(content.contains("email: str")); + assert!(content.contains("bio: str | None")); + + // Check constraints + assert!(content.contains("primary_key=True")); + assert!(content.contains("unique=True")); +} + +#[test] +fn test_generate_table_with_foreign_key() { + use crate::python::PythonCodegen; + + let user_columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + ]; + let user_constraints = vec![make_pk_constraint("users", &["id"])]; + let users = make_table("users", user_columns, user_constraints); + + let post_columns = vec![ + make_column("id", "int4", false), + make_column("title", "text", false), + make_column("user_id", "int4", true), + ]; + let post_constraints = vec![ + make_pk_constraint("posts", &["id"]), + make_fk_constraint("posts_user_id_fkey", &["user_id"], "users", &["id"]), + ]; + let posts = make_table("posts", post_columns, post_constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![users, posts]); + + let content = &output["models.py"]; + + assert!(content.contains("class 
User(SQLModel, table=True):")); + assert!(content.contains("class Post(SQLModel, table=True):")); + assert!(content.contains("foreign_key=")); +} + +#[test] +fn test_generate_table_with_check_constraint() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "int4", false), + make_column("price", "numeric", false), + make_column("quantity", "int4", false), + ]; + let constraints = vec![ + make_pk_constraint("products", &["id"]), + make_check_constraint("products_price_positive", "price > 0"), + make_check_constraint("products_quantity_positive", "quantity >= 0"), + ]; + let table = make_table("products", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + assert!(content.contains("CheckConstraint")); + assert!(content.contains("price > 0")); + assert!(content.contains("quantity >= 0")); +} + +#[test] +fn test_generate_table_with_composite_unique() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "int4", false), + make_column("email", "text", false), + make_column("tenant_id", "int4", false), + ]; + let constraints = vec![ + make_pk_constraint("users", &["id"]), + make_unique_constraint("users_email_tenant_key", &["email", "tenant_id"]), + ]; + let table = make_table("users", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + assert!(content.contains("__table_args__")); + assert!(content.contains("UniqueConstraint")); + assert!(content.contains("\"email\"")); + assert!(content.contains("\"tenant_id\"")); +} + +#[test] +fn test_generate_table_with_identity_column() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let constraints = vec![make_pk_constraint("users", 
&["id"])]; + let table = make_table("users", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + // Identity column should be Optional with default=None + assert!(content.contains("id: int | None")); + assert!(content.contains("default=None")); + assert!(content.contains("primary_key=True")); +} + +#[test] +fn test_generate_reserved_word_column() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "int4", false), + make_column("class", "text", false), + make_column("from", "text", true), + ]; + let constraints = vec![make_pk_constraint("items", &["id"])]; + let table = make_table("items", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + // Reserved words should be aliased + assert!(content.contains("class_:")); + assert!(content.contains("alias=\"class\"")); + assert!(content.contains("from_:")); + assert!(content.contains("alias=\"from\"")); +} + +#[test] +fn test_generate_multi_file_output() { + use crate::python::PythonCodegen; + + let user_columns = vec![ + make_column("id", "int4", false), + make_column("name", "text", false), + ]; + let users = make_table( + "users", + user_columns, + vec![make_pk_constraint("users", &["id"])], + ); + + let post_columns = vec![ + make_column("id", "int4", false), + make_column("title", "text", false), + ]; + let posts = make_table( + "posts", + post_columns, + vec![make_pk_constraint("posts", &["id"])], + ); + + let config = PythonCodegenConfig { + output_mode: OutputMode::MultiFile, + ..Default::default() + }; + let codegen = PythonCodegen::new(config); + let output = codegen.generate(vec![users, posts]); + + // Check files exist + assert!(output.contains_key("__init__.py")); + assert!(output.contains_key("user.py")); + assert!(output.contains_key("post.py")); + + // Check 
__init__.py + let init = &output["__init__.py"]; + assert!(init.contains("from .user import User")); + assert!(init.contains("from .post import Post")); + assert!(init.contains("__all__")); + + // Check individual files + assert!(output["user.py"].contains("class User(SQLModel, table=True):")); + assert!(output["post.py"].contains("class Post(SQLModel, table=True):")); +} + +#[test] +fn test_generate_with_datetime_types() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "int4", false), + make_column("created_at", "timestamptz", false), + make_column("updated_at", "timestamptz", true), + make_column("birth_date", "date", true), + ]; + let constraints = vec![make_pk_constraint("events", &["id"])]; + let table = make_table("events", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + assert!(content.contains("from datetime import")); + assert!(content.contains("datetime")); + assert!(content.contains("date")); + assert!(content.contains("created_at: datetime")); + assert!(content.contains("birth_date: date | None")); +} + +#[test] +fn test_generate_with_uuid_type() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "uuid", false), + make_column("name", "text", false), + ]; + let constraints = vec![make_pk_constraint("items", &["id"])]; + let table = make_table("items", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + assert!(content.contains("from uuid import UUID")); +} + +#[test] +fn test_generate_with_decimal_type() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "int4", false), + make_column("price", "numeric", false), + make_column("discount", "numeric", true), + ]; + let constraints = vec![make_pk_constraint("products", &["id"])]; + let table = 
make_table("products", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + assert!(content.contains("from decimal import Decimal")); + assert!(content.contains("price: Decimal")); + assert!(content.contains("discount: Decimal | None")); +} + +#[test] +fn test_generate_with_json_type() { + use crate::python::PythonCodegen; + + let columns = vec![ + make_column("id", "int4", false), + make_column("config", "jsonb", false), + make_column("metadata", "json", true), + ]; + let constraints = vec![make_pk_constraint("settings", &["id"])]; + let table = make_table("settings", columns, constraints); + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let content = &output["models.py"]; + + assert!(content.contains("from typing import Any")); + assert!(content.contains("from sqlalchemy import JSON")); + assert!(content.contains("dict[str, Any]")); + assert!(content.contains("sa_type=JSON")); +} + +#[test] +fn test_generate_empty_tables() { + use crate::python::PythonCodegen; + + let codegen = PythonCodegen::with_defaults(); + let output = codegen.generate(vec![]); + + assert!(output.contains_key("models.py")); + assert!(output["models.py"].contains("No tables to generate")); +} + +#[test] +fn test_config_default_values() { + let config = PythonCodegenConfig::default(); + + assert!(!config.generate_base_models); + assert!(config.module_prefix.is_none()); + assert!(config.include_docstrings); + assert!(!config.generate_relationships); + assert_eq!(config.output_mode, OutputMode::SingleFile); + assert_eq!( + config.reserved_word_strategy, + ReservedWordStrategy::AppendUnderscore + ); +} + +#[test] +fn test_string_length_extraction() { + assert_eq!(extract_string_length("character varying(100)"), Some(100)); + assert_eq!(extract_string_length("varchar(50)"), Some(50)); + assert_eq!(extract_string_length("character(10)"), 
Some(10)); + assert_eq!(extract_string_length("char(5)"), Some(5)); + assert_eq!(extract_string_length("text"), None); + assert_eq!(extract_string_length("integer"), None); +} + +#[test] +fn test_numeric_precision_extraction() { + assert_eq!(extract_numeric_precision("numeric(10,2)"), Some((10, 2))); + assert_eq!(extract_numeric_precision("decimal(15,4)"), Some((15, 4))); + assert_eq!(extract_numeric_precision("numeric(8)"), Some((8, 0))); + assert_eq!(extract_numeric_precision("numeric"), None); +} + +#[test] +fn test_fixed_length_char_detection() { + assert!(is_fixed_length_char("char")); + assert!(is_fixed_length_char("character")); + assert!(is_fixed_length_char("bpchar")); + assert!(!is_fixed_length_char("varchar")); + assert!(!is_fixed_length_char("text")); +} diff --git a/crates/codegen/src/python/type_mapping.rs b/crates/codegen/src/python/type_mapping.rs new file mode 100644 index 0000000..d7ea075 --- /dev/null +++ b/crates/codegen/src/python/type_mapping.rs @@ -0,0 +1,518 @@ +//! PostgreSQL to Python type mapping. +//! +//! This module handles the conversion of PostgreSQL types to their Python equivalents +//! for use in SQLModel field definitions. + +use tern_ddl::TypeInfo; + +/// Represents a Python type with its import requirements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PythonType { + /// The Python type annotation (e.g., "int", "str", "datetime"). + pub annotation: String, + /// Required imports for this type (module path, name). + pub imports: Vec, + /// SQLAlchemy type for `sa_type` parameter if needed (e.g., for JSON, ARRAY). + pub sa_type: Option, + /// SQLAlchemy imports required for sa_type. + pub sa_imports: Vec, +} + +/// A Python import statement. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct PythonImport { + /// The module to import from (e.g., "datetime", "typing", "uuid"). + pub module: String, + /// The name to import (e.g., "datetime", "Any", "UUID"). 
+ pub name: String, +} + +impl PythonImport { + /// Creates a new Python import. + pub fn new(module: impl Into, name: impl Into) -> Self { + Self { + module: module.into(), + name: name.into(), + } + } + + /// Creates an import from the datetime module. + pub fn datetime(name: &str) -> Self { + Self::new("datetime", name) + } + + /// Creates an import from the typing module. + pub fn typing(name: &str) -> Self { + Self::new("typing", name) + } + + /// Creates an import from the decimal module. + pub fn decimal() -> Self { + Self::new("decimal", "Decimal") + } + + /// Creates an import from the uuid module. + pub fn uuid() -> Self { + Self::new("uuid", "UUID") + } + + /// Creates an import from SQLAlchemy. + pub fn sqlalchemy(name: &str) -> Self { + Self::new("sqlalchemy", name) + } +} + +impl PythonType { + /// Creates a simple Python type with no imports. + pub fn simple(annotation: impl Into) -> Self { + Self { + annotation: annotation.into(), + imports: Vec::new(), + sa_type: None, + sa_imports: Vec::new(), + } + } + + /// Creates a Python type with imports. + pub fn with_imports(annotation: impl Into, imports: Vec) -> Self { + Self { + annotation: annotation.into(), + imports, + sa_type: None, + sa_imports: Vec::new(), + } + } + + /// Creates a Python type with an SQLAlchemy type. + pub fn with_sa_type( + annotation: impl Into, + imports: Vec, + sa_type: impl Into, + sa_imports: Vec, + ) -> Self { + Self { + annotation: annotation.into(), + imports, + sa_type: Some(sa_type.into()), + sa_imports, + } + } +} + +/// Maps a PostgreSQL type to a Python type. +/// +/// This function handles both the raw type name and the formatted type with modifiers. +/// The `formatted` field is preferred as it contains the full type specification. 
+pub fn map_pg_type(type_info: &TypeInfo) -> PythonType { + let type_name = type_info.name.as_ref(); + let formatted = &type_info.formatted; + + // Handle array types first + if type_info.is_array { + return map_array_type(type_name, formatted); + } + + // Map based on type name (canonical PostgreSQL type names) + match type_name { + // Integer types + "int2" | "smallint" => PythonType::simple("int"), + "int4" | "integer" | "int" => PythonType::simple("int"), + "int8" | "bigint" => PythonType::simple("int"), + "serial" | "serial4" => PythonType::simple("int"), + "bigserial" | "serial8" => PythonType::simple("int"), + "smallserial" | "serial2" => PythonType::simple("int"), + + // Floating point types + "float4" | "real" => PythonType::simple("float"), + "float8" | "double precision" => PythonType::simple("float"), + + // Numeric/decimal types + "numeric" | "decimal" => PythonType::with_imports("Decimal", vec![PythonImport::decimal()]), + + // Boolean + "bool" | "boolean" => PythonType::simple("bool"), + + // Text types + "text" => PythonType::simple("str"), + "varchar" | "character varying" => PythonType::simple("str"), + "char" | "character" | "bpchar" => PythonType::simple("str"), + "name" => PythonType::simple("str"), + + // Date/time types + "date" => PythonType::with_imports("date", vec![PythonImport::datetime("date")]), + "time" | "time without time zone" => { + PythonType::with_imports("time", vec![PythonImport::datetime("time")]) + } + "timetz" | "time with time zone" => { + PythonType::with_imports("time", vec![PythonImport::datetime("time")]) + } + "timestamp" | "timestamp without time zone" => { + PythonType::with_imports("datetime", vec![PythonImport::datetime("datetime")]) + } + "timestamptz" | "timestamp with time zone" => { + PythonType::with_imports("datetime", vec![PythonImport::datetime("datetime")]) + } + "interval" => { + PythonType::with_imports("timedelta", vec![PythonImport::datetime("timedelta")]) + } + + // UUID + "uuid" => 
PythonType::with_imports("UUID", vec![PythonImport::uuid()]), + + // JSON types + "json" | "jsonb" => PythonType::with_sa_type( + "dict[str, Any]", + vec![PythonImport::typing("Any")], + "JSON", + vec![PythonImport::sqlalchemy("JSON")], + ), + + // Binary + "bytea" => PythonType::simple("bytes"), + + // Network types + "inet" | "cidr" | "macaddr" | "macaddr8" => PythonType::simple("str"), + + // Geometric types (stored as strings) + "point" | "line" | "lseg" | "box" | "path" | "polygon" | "circle" => { + PythonType::simple("str") + } + + // Bit strings + "bit" | "varbit" | "bit varying" => PythonType::simple("str"), + + // Money (stored as string to preserve formatting) + "money" => PythonType::simple("str"), + + // XML + "xml" => PythonType::simple("str"), + + // OID types (internal PostgreSQL types) + "oid" | "regproc" | "regprocedure" | "regoper" | "regoperator" | "regclass" | "regtype" + | "regrole" | "regnamespace" | "regconfig" | "regdictionary" => PythonType::simple("int"), + + // Range types + "int4range" | "int8range" | "numrange" | "tsrange" | "tstzrange" | "daterange" => { + // Range types are complex; represent as string for now + PythonType::simple("str") + } + + // TSVector/TSQuery (full text search) + "tsvector" | "tsquery" => PythonType::simple("str"), + + // Default fallback - use Any for unknown types + _ => { + // Check if it looks like a user-defined type (enum, composite, etc.) + // For now, map unknown types to Any + PythonType::with_imports("Any", vec![PythonImport::typing("Any")]) + } + } +} + +/// Maps a PostgreSQL array type to a Python list type. 
+fn map_array_type(element_type_name: &str, formatted: &str) -> PythonType { + // Get the element type first + let element_type = map_pg_type(&TypeInfo { + name: tern_ddl::TypeName::try_new(element_type_name.to_string()) + .unwrap_or_else(|_| tern_ddl::TypeName::try_new("text".to_string()).unwrap()), + schema: tern_ddl::SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: strip_array_suffix(formatted), + is_array: false, + }); + + // Determine the SQLAlchemy array element type + let sa_element_type = match element_type_name { + "int2" | "smallint" => "SmallInteger", + "int4" | "integer" | "int" => "Integer", + "int8" | "bigint" => "BigInteger", + "float4" | "real" => "Float", + "float8" | "double precision" => "Float", + "numeric" | "decimal" => "Numeric", + "bool" | "boolean" => "Boolean", + "text" => "Text", + "varchar" | "character varying" => "String", + "char" | "character" | "bpchar" => "String", + "uuid" => "UUID", + "timestamp" | "timestamp without time zone" => "DateTime", + "timestamptz" | "timestamp with time zone" => "DateTime", + "date" => "Date", + "time" | "time without time zone" => "Time", + "json" | "jsonb" => "JSON", + _ => "String", // Default to String for unknown types + }; + + let annotation = format!("list[{}]", element_type.annotation); + let sa_type = format!("ARRAY({sa_element_type})"); + + let mut imports = element_type.imports; + let mut sa_imports = vec![PythonImport::sqlalchemy("ARRAY")]; + + // Add the element type import for SQLAlchemy + // All ARRAY element types require their corresponding SQLAlchemy type import + sa_imports.push(PythonImport::sqlalchemy(sa_element_type)); + + // Merge any existing SA imports from element type + imports.extend(element_type.sa_imports); + + PythonType { + annotation, + imports, + sa_type: Some(sa_type), + sa_imports, + } +} + +/// Strips the array suffix (e.g., "[]") from a formatted type string. 
+fn strip_array_suffix(formatted: &str) -> String { + formatted + .trim_end_matches("[]") + .trim_end_matches(" ARRAY") + .to_string() +} + +/// Extracts the varchar/char length constraint from a formatted type. +/// +/// Returns `Some(length)` for types like "character varying(255)" or "character(10)". +pub fn extract_string_length(formatted: &str) -> Option { + // Match patterns like "character varying(255)" or "character(10)" or "varchar(100)" + let formatted_lower = formatted.to_lowercase(); + + if formatted_lower.starts_with("character varying(") + || formatted_lower.starts_with("varchar(") + || formatted_lower.starts_with("character(") + || formatted_lower.starts_with("char(") + { + if let Some(start) = formatted.find('(') { + if let Some(end) = formatted.find(')') { + let len_str = &formatted[start + 1..end]; + return len_str.parse().ok(); + } + } + } + + None +} + +/// Extracts numeric precision and scale from a formatted type. +/// +/// Returns `Some((precision, scale))` for types like "numeric(10,2)". +/// Kept for future use to add Decimal precision validation in Field(). +#[allow(dead_code)] +pub fn extract_numeric_precision(formatted: &str) -> Option<(u32, u32)> { + let formatted_lower = formatted.to_lowercase(); + + if formatted_lower.starts_with("numeric(") || formatted_lower.starts_with("decimal(") { + if let Some(start) = formatted.find('(') { + if let Some(end) = formatted.find(')') { + let params = &formatted[start + 1..end]; + let parts: Vec<&str> = params.split(',').collect(); + if parts.len() == 2 { + let precision: u32 = parts[0].trim().parse().ok()?; + let scale: u32 = parts[1].trim().parse().ok()?; + return Some((precision, scale)); + } else if parts.len() == 1 { + let precision: u32 = parts[0].trim().parse().ok()?; + return Some((precision, 0)); + } + } + } + } + + None +} + +/// Checks if a type is a fixed-length character type (char/character). 
/// Returns `true` when `type_name` is one of PostgreSQL's fixed-length
/// character types (`char`, `character`, or the internal `bpchar` alias).
pub fn is_fixed_length_char(type_name: &str) -> bool {
    matches!(type_name, "char" | "character" | "bpchar")
}

#[cfg(test)]
mod tests {
    use super::*;
    use tern_ddl::{SchemaName, TypeName};

    /// Builds a `TypeInfo` fixture in the `pg_catalog` schema.
    fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo {
        TypeInfo {
            name: TypeName::try_new(name.to_string()).unwrap(),
            schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
            formatted: formatted.to_string(),
            is_array,
        }
    }

    /// Convenience: map a scalar (non-array) type.
    fn map_scalar(name: &str, formatted: &str) -> PythonType {
        map_pg_type(&make_type_info(name, formatted, false))
    }

    #[test]
    fn test_integer_types() {
        let py_type = map_scalar("int4", "integer");
        assert_eq!(py_type.annotation, "int");
        assert!(py_type.imports.is_empty());

        assert_eq!(map_scalar("int8", "bigint").annotation, "int");
        assert_eq!(map_scalar("int2", "smallint").annotation, "int");
    }

    #[test]
    fn test_float_types() {
        assert_eq!(map_scalar("float8", "double precision").annotation, "float");
        assert_eq!(map_scalar("float4", "real").annotation, "float");
    }

    #[test]
    fn test_numeric_type() {
        let py_type = map_scalar("numeric", "numeric(10,2)");
        assert_eq!(py_type.annotation, "Decimal");
        assert_eq!(py_type.imports.len(), 1);
        assert_eq!(py_type.imports[0].module, "decimal");
        assert_eq!(py_type.imports[0].name, "Decimal");
    }

    #[test]
    fn test_boolean_type() {
        assert_eq!(map_scalar("bool", "boolean").annotation, "bool");
    }

    #[test]
    fn test_text_types() {
        assert_eq!(map_scalar("text", "text").annotation, "str");
        assert_eq!(
            map_scalar("varchar", "character varying(255)").annotation,
            "str"
        );
    }

    #[test]
    fn test_datetime_types() {
        let py_type = map_scalar("timestamp", "timestamp without time zone");
        assert_eq!(py_type.annotation, "datetime");
        assert_eq!(py_type.imports.len(), 1);
        assert_eq!(py_type.imports[0].module, "datetime");
        assert_eq!(py_type.imports[0].name, "datetime");

        assert_eq!(
            map_scalar("timestamptz", "timestamp with time zone").annotation,
            "datetime"
        );

        let py_type = map_scalar("date", "date");
        assert_eq!(py_type.annotation, "date");
        assert_eq!(py_type.imports[0].name, "date");

        assert_eq!(
            map_scalar("time", "time without time zone").annotation,
            "time"
        );
    }

    #[test]
    fn test_uuid_type() {
        let py_type = map_scalar("uuid", "uuid");
        assert_eq!(py_type.annotation, "UUID");
        assert_eq!(py_type.imports.len(), 1);
        assert_eq!(py_type.imports[0].module, "uuid");
        assert_eq!(py_type.imports[0].name, "UUID");
    }

    #[test]
    fn test_json_types() {
        let py_type = map_scalar("jsonb", "jsonb");
        assert_eq!(py_type.annotation, "dict[str, Any]");
        assert!(py_type.imports.iter().any(|i| i.name == "Any"));
        assert_eq!(py_type.sa_type, Some("JSON".to_string()));
    }

    #[test]
    fn test_bytea_type() {
        assert_eq!(map_scalar("bytea", "bytea").annotation, "bytes");
    }

    #[test]
    fn test_array_type() {
        let py_type = map_pg_type(&make_type_info("int4", "integer[]", true));
        assert_eq!(py_type.annotation, "list[int]");
        assert!(py_type.sa_type.is_some());
        assert!(py_type.sa_type.as_ref().unwrap().contains("ARRAY"));
    }

    #[test]
    fn test_text_array_type() {
        let py_type = map_pg_type(&make_type_info("text", "text[]", true));
        assert_eq!(py_type.annotation, "list[str]");
        // Text arrays use ARRAY(Text) - Text is the correct SQLAlchemy type for pg text
        assert!(py_type.sa_type.as_ref().unwrap().contains("ARRAY(Text)"));
    }

    #[test]
    fn test_extract_string_length() {
        assert_eq!(extract_string_length("character varying(255)"), Some(255));
        assert_eq!(extract_string_length("varchar(100)"), Some(100));
        assert_eq!(extract_string_length("character(10)"), Some(10));
        assert_eq!(extract_string_length("char(5)"), Some(5));
        assert_eq!(extract_string_length("text"), None);
    }

    #[test]
    fn test_extract_numeric_precision() {
        assert_eq!(extract_numeric_precision("numeric(10,2)"), Some((10, 2)));
        assert_eq!(extract_numeric_precision("numeric(5)"), Some((5, 0)));
        assert_eq!(extract_numeric_precision("decimal(8,3)"), Some((8, 3)));
        assert_eq!(extract_numeric_precision("integer"), None);
    }

    #[test]
    fn test_is_fixed_length_char() {
        for name in ["char", "character", "bpchar"] {
            assert!(is_fixed_length_char(name));
        }
        for name in ["varchar", "text"] {
            assert!(!is_fixed_length_char(name));
        }
    }

    #[test]
    fn test_network_types() {
        assert_eq!(map_scalar("inet", "inet").annotation, "str");
        assert_eq!(map_scalar("cidr", "cidr").annotation, "str");
    }

    #[test]
    fn test_interval_type() {
        let py_type = map_scalar("interval", "interval");
        assert_eq!(py_type.annotation, "timedelta");
        assert_eq!(py_type.imports[0].module, "datetime");
        assert_eq!(py_type.imports[0].name, "timedelta");
    }

    #[test]
    fn test_unknown_type_fallback() {
        let py_type = map_scalar("my_custom_type", "my_custom_type");
        assert_eq!(py_type.annotation, "Any");
        assert!(py_type.imports.iter().any(|i| i.name == "Any"));
    }
}