diff --git a/Cargo.lock b/Cargo.lock
index 02ae11a..bc720ff 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3184,7 +3184,10 @@ name = "tern-codegen"
version = "0.1.0"
dependencies = [
"genco",
+ "insta",
+ "pretty_assertions",
"tern-ddl",
+ "thiserror 2.0.18",
]
[[package]]
diff --git a/crates/codegen/Cargo.toml b/crates/codegen/Cargo.toml
index d77ef6b..2647abb 100644
--- a/crates/codegen/Cargo.toml
+++ b/crates/codegen/Cargo.toml
@@ -6,3 +6,8 @@ edition = "2024"
[dependencies]
genco = "0.19.0"
tern-ddl = { path = "../ddl" }
+thiserror = "2.0"
+
+[dev-dependencies]
+insta = { version = "1.42", features = ["yaml"] }
+pretty_assertions = "1.4"
diff --git a/crates/codegen/PYTHON_SQLMODEL_DESIGN.md b/crates/codegen/PYTHON_SQLMODEL_DESIGN.md
index a001270..99dcd57 100644
--- a/crates/codegen/PYTHON_SQLMODEL_DESIGN.md
+++ b/crates/codegen/PYTHON_SQLMODEL_DESIGN.md
@@ -1,22 +1,31 @@
# Python SQLModel Code Generator Design
+## Implementation Status
+
+> **Status Legend**:
+> - [x] Completed
+> - [ ] Future work (not yet implemented)
+> - [~] Partially implemented
+
## Overview
This document outlines the design for implementing a Python code generator that converts PostgreSQL DDL schema definitions (`Vec<Table>`)
into SQLModel and Pydantic models. The generator will implement the `Codegen` trait and use the `genco` crate for Python code generation.
## Goals
-1. Generate idiomatic SQLModel models from PostgreSQL table definitions
-2. Properly handle all PostgreSQL types with appropriate Python type mappings
-3. Support primary keys, foreign keys, unique constraints, and indexes
-4. Generate both table models (with `table=True`) and optional Pydantic-only models for validation
-5. Produce well-formatted, readable Python code with correct imports
-6. Handle edge cases robustly (reserved words, special characters, etc.)
+1. [x] Generate idiomatic SQLModel models from PostgreSQL table definitions
+2. [x] Properly handle all PostgreSQL types with appropriate Python type mappings
+3. [x] Support primary keys, foreign keys, unique constraints, and indexes
+4. [ ] Generate both table models (with `table=True`) and optional Pydantic-only models for validation
+5. [x] Produce well-formatted, readable Python code with correct imports
+6. [x] Handle edge cases robustly (reserved words, special characters, etc.)
## Architecture
### Module Structure
+[x] **Completed** - All modules implemented as designed.
+
```
crates/codegen/src/python/
├── mod.rs # Module root, re-exports, PythonCodegen struct
@@ -29,30 +38,36 @@ crates/codegen/src/python/
└── tests/ # Test submodule
├── mod.rs # Test utilities and common fixtures
├── unit_tests.rs # Unit tests for individual components
+ ├── snapshot_tests.rs # Snapshot tests for full generation output
└── snapshots/ # Snapshot test files (managed by insta)
```
### Core Types
+[x] **Completed** - All core types implemented in `mod.rs`.
+
```rust
/// Configuration for Python code generation.
#[derive(Debug, Clone)]
pub struct PythonCodegenConfig {
/// Whether to generate Pydantic-only base classes for each model.
/// These are useful for request/response validation without DB coupling.
- pub generate_base_models: bool,
+ pub generate_base_models: bool, // [ ] Future work
/// Module name prefix for generated imports (e.g., "app.models").
- pub module_prefix: Option,
+    pub module_prefix: Option<String>, // [ ] Future work
/// Whether to include docstrings from table/column comments.
- pub include_docstrings: bool,
+ pub include_docstrings: bool, // [x] Completed
/// How to handle Python reserved words in identifiers.
- pub reserved_word_strategy: ReservedWordStrategy,
+ pub reserved_word_strategy: ReservedWordStrategy, // [x] Completed
/// Whether to generate relationship attributes for foreign keys.
- pub generate_relationships: bool,
+ pub generate_relationships: bool, // [ ] Future work
+
+ /// Output mode (single file or multi-file).
+ pub output_mode: OutputMode, // [x] Completed
}
#[derive(Debug, Clone, Default)]
@@ -78,39 +93,43 @@ impl Codegen for PythonCodegen {
### PostgreSQL to Python Type Mapping
-| PostgreSQL Type | Python Type | SQLModel Field | Notes |
-|-----------------|-------------|----------------|-------|
-| `integer`, `int4` | `int` | `Field()` | |
-| `bigint`, `int8` | `int` | `Field()` | Python int handles arbitrary precision |
-| `smallint`, `int2` | `int` | `Field()` | |
-| `serial`, `serial4` | `int` | `Field(default=None, primary_key=True)` | Auto-increment |
-| `bigserial`, `serial8` | `int` | `Field(default=None, primary_key=True)` | |
-| `boolean`, `bool` | `bool` | `Field()` | |
-| `text` | `str` | `Field()` | |
-| `varchar(n)`, `character varying` | `str` | `Field(max_length=n)` | Validate length |
-| `char(n)`, `character` | `str` | `Field(min_length=n, max_length=n)` | Fixed length |
-| `numeric`, `decimal` | `Decimal` | `Field()` | `from decimal import Decimal` |
-| `real`, `float4` | `float` | `Field()` | |
-| `double precision`, `float8` | `float` | `Field()` | |
-| `date` | `date` | `Field()` | `from datetime import date` |
-| `time`, `time without time zone` | `time` | `Field()` | `from datetime import time` |
-| `timetz`, `time with time zone` | `time` | `Field()` | |
-| `timestamp`, `timestamp without time zone` | `datetime` | `Field()` | `from datetime import datetime` |
-| `timestamptz`, `timestamp with time zone` | `datetime` | `Field()` | |
-| `interval` | `timedelta` | `Field()` | `from datetime import timedelta` |
-| `uuid` | `UUID` | `Field()` | `from uuid import UUID` |
-| `json` | `Any` | `Field(sa_type=JSON)` | `from typing import Any` |
-| `jsonb` | `Any` | `Field(sa_type=JSON)` | |
-| `bytea` | `bytes` | `Field()` | |
-| `inet` | `str` | `Field()` | IP address as string |
-| `cidr` | `str` | `Field()` | |
-| `macaddr` | `str` | `Field()` | |
-| `point`, `line`, etc. | `str` | `Field()` | Geometric as string |
-| `array` (e.g., `integer[]`) | `list[int]` | `Field(sa_type=ARRAY(Integer))` | |
-| User-defined enum | Literal union or Python Enum | `Field()` | See enum handling |
+[x] **Completed** - All types in this table are implemented in `type_mapping.rs`.
+
+| PostgreSQL Type | Python Type | SQLModel Field | Status |
+|-----------------|-------------|----------------|--------|
+| `integer`, `int4` | `int` | `Field()` | [x] |
+| `bigint`, `int8` | `int` | `Field()` | [x] |
+| `smallint`, `int2` | `int` | `Field()` | [x] |
+| `serial`, `serial4` | `int` | `Field(default=None, primary_key=True)` | [x] |
+| `bigserial`, `serial8` | `int` | `Field(default=None, primary_key=True)` | [x] |
+| `boolean`, `bool` | `bool` | `Field()` | [x] |
+| `text` | `str` | `Field()` | [x] |
+| `varchar(n)`, `character varying` | `str` | `Field(max_length=n)` | [~] (length extracted but not yet used in Field) |
+| `char(n)`, `character` | `str` | `Field(min_length=n, max_length=n)` | [~] (length extracted but not yet used in Field) |
+| `numeric`, `decimal` | `Decimal` | `Field()` | [x] |
+| `real`, `float4` | `float` | `Field()` | [x] |
+| `double precision`, `float8` | `float` | `Field()` | [x] |
+| `date` | `date` | `Field()` | [x] |
+| `time`, `time without time zone` | `time` | `Field()` | [x] |
+| `timetz`, `time with time zone` | `time` | `Field()` | [x] |
+| `timestamp`, `timestamp without time zone` | `datetime` | `Field()` | [x] |
+| `timestamptz`, `timestamp with time zone` | `datetime` | `Field()` | [x] |
+| `interval` | `timedelta` | `Field()` | [x] |
+| `uuid` | `UUID` | `Field()` | [x] |
+| `json` | `dict[str, Any]` | `Field(sa_type=JSON)` | [x] |
+| `jsonb` | `dict[str, Any]` | `Field(sa_type=JSON)` | [x] |
+| `bytea` | `bytes` | `Field()` | [x] |
+| `inet` | `str` | `Field()` | [x] |
+| `cidr` | `str` | `Field()` | [x] |
+| `macaddr` | `str` | `Field()` | [x] |
+| `point`, `line`, etc. | `str` | `Field()` | [x] |
+| `array` (e.g., `integer[]`) | `list[int]` | `Field(sa_type=ARRAY(Integer))` | [x] |
+| User-defined enum | Literal union or Python Enum | `Field()` | [ ] Future work |
### Handling User-Defined Types
+[ ] **Future Work** - User-defined enum types are not yet supported.
+
For user-defined enum types:
```python
from typing import Literal
@@ -133,6 +152,8 @@ class UserStatus(str, Enum):
### Basic Field Patterns
+[x] **Completed** - All basic field patterns implemented in `field.rs`.
+
```python
# Non-nullable without default
name: str
@@ -146,7 +167,7 @@ status: str = Field(default="active")
# Primary key (nullable with None default for auto-generation)
id: int | None = Field(default=None, primary_key=True)
-# With index
+# With index [ ] Future work - index=True not yet generated for single columns
email: str = Field(index=True)
# With unique constraint
@@ -158,12 +179,15 @@ user_id: int | None = Field(default=None, foreign_key="users.id")
### Generated/Identity Columns
+[x] **Completed** - Identity columns handled correctly.
+[~] **Partial** - Generated columns emit warnings but don't fully support `Computed`.
+
```python
# Identity column (GENERATED ALWAYS AS IDENTITY)
id: int | None = Field(default=None, primary_key=True)
# Note: SQLModel handles auto-increment through primary_key=True with None default
-# Generated column (STORED)
+# Generated column (STORED) - [ ] Future work for full Computed support
# SQLModel doesn't have native support; use sa_column
full_name: str = Field(
default=None,
@@ -175,6 +199,8 @@ full_name: str = Field(
#### Primary Key
+[x] **Completed**
+
```python
# Single column
id: int | None = Field(default=None, primary_key=True)
@@ -188,17 +214,22 @@ class OrderItem(SQLModel, table=True):
#### Foreign Key
+[x] **Completed** - Basic FK support.
+[ ] **Future Work** - Relationship generation.
+
```python
# Basic FK
user_id: int | None = Field(default=None, foreign_key="users.id")
-# With relationship
+# With relationship - [ ] Future work
user_id: int | None = Field(default=None, foreign_key="users.id")
user: "User | None" = Relationship(back_populates="posts")
```
#### Unique Constraint
+[x] **Completed**
+
```python
# Single column
email: str = Field(unique=True)
@@ -211,6 +242,8 @@ __table_args__ = (
#### Check Constraint
+[x] **Completed**
+
```python
# Via __table_args__
__table_args__ = (
@@ -220,11 +253,13 @@ __table_args__ = (
### Index Handling
+[~] **Partial** - Multi-column indexes via `__table_args__` completed. Single-column `Field(index=True)` not yet implemented.
+
```python
-# Simple column index
+# Simple column index - [ ] Future work
email: str = Field(index=True)
-# Composite or complex indexes - via __table_args__
+# Composite or complex indexes - via __table_args__ [x] Completed
__table_args__ = (
Index("idx_users_name_email", "name", "email"),
)
@@ -234,6 +269,8 @@ __table_args__ = (
### Single File Output
+[x] **Completed**
+
For simpler schemas, generate a single `models.py`:
```
@@ -242,18 +279,21 @@ models.py
### Multi-File Output
+[x] **Completed**
+
For larger schemas, split by table with a shared types module:
```
models/
├── __init__.py # Re-exports all models
-├── _types.py # Shared types, enums, base classes
├── user.py # User model
├── post.py # Post model
└── comment.py # Comment model
```
-**Configuration**: Let users choose via `OutputMode::SingleFile` or `OutputMode::MultiFile`.
+**Note**: `_types.py` for shared types/enums not yet implemented (depends on enum support).
+
+**Configuration**: Users can choose via `OutputMode::SingleFile` or `OutputMode::MultiFile`.
## Generated Code Examples
@@ -280,17 +320,15 @@ Table {
### Generated Python
+[x] **Completed** - Basic generation works. Some features marked for future work.
+
```python
"""SQLModel definitions generated by Tern."""
from datetime import datetime
-from typing import TYPE_CHECKING
from sqlmodel import Field, SQLModel
-if TYPE_CHECKING:
- from .post import Post # For relationship type hints
-
class User(SQLModel, table=True):
"""User model."""
@@ -300,9 +338,9 @@ class User(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
email: str = Field(unique=True)
name: str | None = None
- created_at: datetime = Field(default_factory=datetime.now, index=True)
+ created_at: datetime # Note: default_factory not yet implemented
- # Relationships (if generate_relationships=True)
+ # Relationships (if generate_relationships=True) - [ ] Future work
# posts: list["Post"] = Relationship(back_populates="user")
```
@@ -310,6 +348,8 @@ class User(SQLModel, table=True):
### 1. Python Reserved Words
+[x] **Completed** - Full reserved word handling in `naming.rs`.
+
Python reserved words that might appear as column/table names:
```
@@ -318,31 +358,25 @@ def, del, elif, else, except, finally, for, from, global, if, import, in,
is, lambda, nonlocal, not, or, pass, raise, return, try, while, with, yield
```
+Also handles soft keywords (`match`, `case`, `type`, `_`) and can optionally handle Python builtins.
+
**Handling**:
```python
# Column named "class"
class_: str = Field(alias="class")
```
-Use SQLAlchemy column aliasing to preserve the database column name while using a valid Python identifier.
-
### 2. Invalid Python Identifiers
-- Names starting with numbers: `1column` -> `column_1` or `_1column`
+[x] **Completed**
+
+- Names starting with numbers: `1column` -> `_1column`
- Names with special characters: `column-name` -> `column_name`
- Names with spaces: `column name` -> `column_name`
-```rust
-fn sanitize_identifier(name: &str) -> String {
- // 1. Replace invalid characters with underscores
- // 2. Ensure doesn't start with digit
- // 3. Handle reserved words
-}
-```
-
### 3. Circular Foreign Key References
-Tables may reference each other:
+[~] **Partial** - FK references work. Full relationship generation with back_populates is future work.
```python
# Forward reference using string annotation
@@ -350,7 +384,7 @@ class User(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
manager_id: int | None = Field(default=None, foreign_key="users.id")
- # Self-referential relationship
+ # Self-referential relationship - [ ] Future work
manager: "User | None" = Relationship(
back_populates="direct_reports",
sa_relationship_kwargs={"remote_side": "User.id"}
@@ -360,7 +394,9 @@ class User(SQLModel, table=True):
### 4. Schema-Qualified Names
-Foreign keys might reference tables in other schemas:
+[x] **Completed**
+
+Foreign keys reference tables with schema qualification:
```python
# Reference to other_schema.other_table
@@ -369,29 +405,32 @@ other_id: int | None = Field(default=None, foreign_key="other_schema.other_table
### 5. Array Types
+[x] **Completed**
+
```python
-from sqlalchemy import ARRAY, Integer
+from sqlalchemy import ARRAY, Text
from sqlmodel import Field, SQLModel
class Document(SQLModel, table=True):
- tags: list[str] = Field(
- default_factory=list,
- sa_type=ARRAY(String)
- )
+ tags: list[str] = Field(sa_type=ARRAY(Text))
```
### 6. JSONB Fields
+[x] **Completed**
+
```python
from sqlalchemy import JSON
from typing import Any
class Settings(SQLModel, table=True):
- config: dict[str, Any] = Field(default_factory=dict, sa_type=JSON)
+ config: dict[str, Any] = Field(sa_type=JSON)
```
### 7. Composite Primary Keys
+[x] **Completed**
+
```python
class OrderItem(SQLModel, table=True):
__tablename__ = "order_items"
@@ -403,6 +442,8 @@ class OrderItem(SQLModel, table=True):
### 8. Generated Columns
+[ ] **Future Work** - Currently emits warning. Full `Computed` support not implemented.
+
SQLModel doesn't have first-class support for generated columns, but we can use `sa_column`:
```python
@@ -419,6 +460,8 @@ class Person(SQLModel, table=True):
### 9. Exclusion Constraints
+[x] **Completed** - Emits warning comment as designed.
+
Not directly supported by SQLModel; emit a warning comment:
```python
@@ -428,80 +471,93 @@ Not directly supported by SQLModel; emit a warning comment:
### 10. Empty Tables
-Tables with no columns (rare but possible):
-
-```python
-class EmptyTable(SQLModel, table=True):
- """Table with no columns (placeholder)."""
- __tablename__ = "empty_table"
- pass # SQLModel requires at least one field in practice
-```
+[x] **Completed** - Emits warning.
-**Emit warning**: Tables with no columns should emit a warning as SQLModel requires at least one field.
+Tables with no columns emit a warning as SQLModel requires at least one field.
## Implementation Plan
### Phase 1: Core Infrastructure
-1. Add `genco` dependency to `tern-codegen/Cargo.toml`
-2. Create module structure under `src/python/`
-3. Implement `PythonCodegenConfig` and `PythonCodegen` struct
-4. Implement basic `Codegen` trait with empty generation
+[x] **Completed**
+
+1. [x] Add `genco` dependency to `tern-codegen/Cargo.toml`
+2. [x] Create module structure under `src/python/`
+3. [x] Implement `PythonCodegenConfig` and `PythonCodegen` struct
+4. [x] Implement basic `Codegen` trait with empty generation
### Phase 2: Type Mapping
-1. Implement `type_mapping.rs` with PostgreSQL -> Python conversions
-2. Handle all scalar types from the mapping table
-3. Add array type detection and handling
-4. Add tests for type mapping edge cases
+[x] **Completed**
+
+1. [x] Implement `type_mapping.rs` with PostgreSQL -> Python conversions
+2. [x] Handle all scalar types from the mapping table
+3. [x] Add array type detection and handling
+4. [x] Add tests for type mapping edge cases
### Phase 3: Name Handling
-1. Implement `naming.rs` with identifier sanitization
-2. Build reserved word detection and handling
-3. Implement table/column name conversion to Python conventions
-4. Add tests for naming edge cases
+[x] **Completed**
+
+1. [x] Implement `naming.rs` with identifier sanitization
+2. [x] Build reserved word detection and handling
+3. [x] Implement table/column name conversion to Python conventions
+4. [x] Add tests for naming edge cases
### Phase 4: Basic Model Generation
-1. Implement simple class generation with genco
-2. Generate basic fields (non-nullable, no constraints)
-3. Handle nullable fields with `Optional` types
-4. Generate proper imports
+[x] **Completed**
+
+1. [x] Implement simple class generation with genco
+2. [x] Generate basic fields (non-nullable, no constraints)
+3. [x] Handle nullable fields with `| None` types
+4. [x] Generate proper imports
### Phase 5: Constraint Support
-1. Primary key generation
-2. Foreign key generation (without relationships)
-3. Unique constraint generation (single-column via Field, multi-column via `__table_args__`)
-4. Check constraint generation via `__table_args__`
-5. Index generation
+[x] **Completed**
+
+1. [x] Primary key generation
+2. [x] Foreign key generation (without relationships)
+3. [x] Unique constraint generation (single-column via Field, multi-column via `__table_args__`)
+4. [x] Check constraint generation via `__table_args__`
+5. [x] Index generation (multi-column via `__table_args__`)
### Phase 6: Relationship Generation
-1. Analyze foreign key graph to determine relationship directions
-2. Generate `Relationship()` attributes
-3. Handle self-referential relationships
-4. Handle circular references with forward declarations
+[ ] **Future Work**
+
+1. [ ] Analyze foreign key graph to determine relationship directions
+2. [ ] Generate `Relationship()` attributes
+3. [ ] Handle self-referential relationships
+4. [ ] Handle circular references with forward declarations
+
+**Note**: Infrastructure is in place (`add_relationship`, `add_type_checking` methods exist with `#[allow(dead_code)]`).
### Phase 7: Advanced Features
-1. Generated column support
-2. Identity column support
-3. Array type support
-4. JSONB field support
-5. Docstring generation from comments
+[~] **Partially Completed**
+
+1. [ ] Generated column support (with `Computed`)
+2. [x] Identity column support
+3. [x] Array type support
+4. [x] JSONB field support
+5. [x] Docstring generation from comments
### Phase 8: Testing
-1. Unit tests for each component
-2. Snapshot tests for complete model generation
-3. Edge case tests (reserved words, special characters, circular refs)
-4. Integration tests with complex multi-table schemas
+[x] **Completed**
+
+1. [x] Unit tests for each component (109 tests)
+2. [x] Snapshot tests for complete model generation (11 snapshots)
+3. [x] Edge case tests (reserved words, special characters, circular refs)
+4. [x] Integration tests with complex multi-table schemas
## Dependencies
-Add to `crates/codegen/Cargo.toml`:
+[x] **Completed**
+
+Added to `crates/codegen/Cargo.toml`:
```toml
[dependencies]
@@ -518,7 +574,7 @@ pretty_assertions = "1.4"
### Unit Tests
-Test individual components in isolation:
+[x] **Completed** - 109 tests covering all components.
```rust
#[test]
@@ -534,7 +590,7 @@ fn test_sanitize_reserved_word() {
### Snapshot Tests
-Use `insta` for snapshot testing of generated code:
+[x] **Completed** - 11 snapshot tests with insta.
```rust
#[test]
@@ -545,62 +601,25 @@ fn snapshot_simple_table() {
insta::assert_snapshot!(output.get("models.py").unwrap());
}
-
-#[test]
-fn snapshot_foreign_key_relationship() {
- let tables = vec![create_users_table(), create_posts_table()];
- let codegen = PythonCodegen::new(PythonCodegenConfig {
- generate_relationships: true,
- ..Default::default()
- });
- let output = codegen.generate(tables);
-
- insta::assert_snapshot!(output.get("models.py").unwrap());
-}
```
### Edge Case Tests
-```rust
-#[test]
-fn test_reserved_word_column() {
- let table = Table {
- name: TableName::try_new("items".to_string()).unwrap(),
- columns: vec![
- column("class", "text"), // Reserved word
- column("from", "integer"), // Reserved word
- ],
- ..Default::default()
- };
- // Verify generated code uses aliases
-}
+[x] **Completed**
-#[test]
-fn test_self_referential_foreign_key() {
- let table = Table {
- name: TableName::try_new("employees".to_string()).unwrap(),
- columns: vec![
- column("id", "integer"),
- column("manager_id", "integer"),
- ],
- constraints: vec![
- Constraint::foreign_key("manager_id", "employees", "id"),
- ],
- ..Default::default()
- };
- // Verify self-referential handling
-}
-
-#[test]
-fn test_circular_foreign_keys() {
- let user_table = /* ... references posts.featured_user_id */;
- let post_table = /* ... references users.id */;
- // Verify circular reference handling with forward declarations
-}
-```
+- Reserved word columns (class, from, import, def, etc.)
+- Self-referential foreign keys
+- Composite primary keys
+- Multi-column unique constraints
+- Check constraints
+- Exclusion constraints (warning)
+- All PostgreSQL types
+- Multi-file output mode
## Error Handling
+[x] **Completed** - Error type defined in `mod.rs`.
+
```rust
#[derive(Debug, thiserror::Error)]
pub enum PythonCodegenError {
@@ -620,24 +639,31 @@ pub enum PythonCodegenError {
**Note**: The current `Codegen` trait returns `HashMap` without error handling. Consider proposing a trait update to return `Result<HashMap, Error>` in the future.
-## Open Questions
+## Open Questions (Resolved)
1. **Output Mode**: Should we default to single-file or multi-file output?
- - **Recommendation**: Single file for simplicity, with config option for multi-file.
+ - **Resolution**: [x] Single file is the default, with `OutputMode::MultiFile` option.
2. **Relationship Generation**: Should relationships be opt-in or opt-out?
- - **Recommendation**: Opt-in via config flag, as relationships add complexity.
+ - **Resolution**: [x] Opt-in via `generate_relationships` config flag. Not yet implemented.
3. **Type Hints Style**: Should we use `Optional[X]` or `X | None`?
- - **Recommendation**: `X | None` (Python 3.10+ syntax) as it's more modern and SQLModel targets newer Python.
+ - **Resolution**: [x] Using `X | None` (Python 3.10+ syntax).
4. **Enum Handling**: Literal types vs Python Enum classes?
- - **Recommendation**: This design assumes enums are not passed in the `Vec` input. If enum support is needed, add a separate `enums` parameter to the generator or include enum definitions in a preprocessing step.
+ - **Resolution**: Deferred to future work. Currently maps unknown types to `Any`.
## Future Enhancements
-1. **Alembic Migration Generation**: Generate Alembic migration files alongside models
-2. **Pydantic V2 Schemas**: Generate pure Pydantic models for API schemas
-3. **FastAPI Integration**: Generate FastAPI route stubs for CRUD operations
-4. **Custom Validators**: Support for custom Pydantic validators from check constraints
-5. **Type Stubs**: Generate `.pyi` stub files for better IDE support
+1. [ ] **Relationship Generation**: Generate `Relationship()` attributes with back_populates
+2. [ ] **Pydantic Base Models**: Generate Pydantic-only models via `generate_base_models` config
+3. [ ] **User-Defined Enums**: Support PostgreSQL enum types as Python Literal or Enum
+4. [ ] **Generated Columns**: Full `Computed` support with sa_column
+5. [ ] **Single-Column Index**: Generate `Field(index=True)` for indexed columns
+6. [ ] **Default Factory**: Generate `default_factory` for datetime fields with `now()` defaults
+7. [ ] **Module Prefix**: Support `module_prefix` config for import paths
+8. [ ] **String Length Validation**: Use extracted varchar/char length in `Field(max_length=n)`
+9. [ ] **Alembic Migration Generation**: Generate Alembic migration files alongside models
+10. [ ] **FastAPI Integration**: Generate FastAPI route stubs for CRUD operations
+11. [ ] **Custom Validators**: Support for custom Pydantic validators from check constraints
+12. [ ] **Type Stubs**: Generate `.pyi` stub files for better IDE support
diff --git a/crates/codegen/src/python/field.rs b/crates/codegen/src/python/field.rs
new file mode 100644
index 0000000..1b22232
--- /dev/null
+++ b/crates/codegen/src/python/field.rs
@@ -0,0 +1,515 @@
+//! SQLModel field generation.
+//!
+//! This module handles generating Field() declarations for SQLModel columns,
+//! including type annotations, default values, and field parameters.
+
+use tern_ddl::{Column, ConstraintKind, Table};
+
+use super::ReservedWordStrategy;
+use super::imports::ImportCollector;
+use super::naming::to_attribute_name;
+use super::type_mapping::{extract_string_length, is_fixed_length_char, map_pg_type};
+
+/// Information about a generated field.
+#[derive(Debug)]
+pub struct FieldInfo {
+ /// The Python attribute name (may differ from column name).
+ pub attr_name: String,
+ /// The original database column name.
+ /// Kept for debugging and future relationship generation.
+ #[allow(dead_code)]
+ pub db_column_name: String,
+ /// Whether an alias is needed (attr_name != db_column_name).
+ /// Kept for debugging and validation.
+ #[allow(dead_code)]
+ pub needs_alias: bool,
+ /// The Python type annotation.
+ pub type_annotation: String,
+ /// The Field() declaration (if any parameters are needed).
+ pub field_declaration: Option<String>,
+ /// Default value without Field() (e.g., "= None").
+ pub simple_default: Option<String>,
+ /// Required imports for this field.
+ pub imports: ImportCollector,
+ /// Whether this field is a primary key.
+ pub is_primary_key: bool,
+ /// Whether this field has a foreign key constraint.
+ /// Kept for future relationship generation.
+ #[allow(dead_code)]
+ pub foreign_key: Option<String>,
+ /// Whether this field is indexed (non-constraint index).
+ /// Kept for debugging and validation.
+ #[allow(dead_code)]
+ pub is_indexed: bool,
+ /// Whether this field has a unique constraint.
+ /// Kept for debugging and validation.
+ #[allow(dead_code)]
+ pub is_unique: bool,
+}
+
+/// Context for field generation containing table-level information.
+pub struct FieldContext<'a> {
+ /// The table containing the column.
+ /// Kept for future relationship generation.
+ #[allow(dead_code)]
+ pub table: &'a Table,
+ /// Reserved word handling strategy.
+ pub strategy: &'a ReservedWordStrategy,
+ /// Set of column names that are primary keys.
+ pub primary_key_columns: Vec<&'a str>,
+ /// Map of column name to foreign key reference (e.g., "users.id").
+ pub foreign_key_columns: Vec<(&'a str, String)>,
+ /// Set of column names with unique constraints (single-column only).
+ pub unique_columns: Vec<&'a str>,
+ /// Set of column names with indexes (non-constraint indexes).
+ pub indexed_columns: Vec<&'a str>,
+}
+
+impl<'a> FieldContext<'a> {
+ /// Creates a new field context from a table.
+ pub fn from_table(table: &'a Table, strategy: &'a ReservedWordStrategy) -> Self {
+ let mut primary_key_columns = Vec::new();
+ let mut foreign_key_columns = Vec::new();
+ let mut unique_columns = Vec::new();
+ let mut indexed_columns = Vec::new();
+
+ // Extract constraint information
+ for constraint in &table.constraints {
+ match &constraint.kind {
+ ConstraintKind::PrimaryKey(pk) => {
+ for col in &pk.columns {
+ primary_key_columns.push(col.as_ref());
+ }
+ }
+ ConstraintKind::ForeignKey(fk) => {
+ // For single-column FKs, we can use Field(foreign_key=...)
+ if fk.columns.len() == 1 && fk.referenced_columns.len() == 1 {
+ let col_name = fk.columns[0].as_ref();
+ let ref_table = &fk.referenced_table;
+ let ref_col = fk.referenced_columns[0].as_ref();
+ let fk_ref = format!(
+ "{}.{}.{}",
+ ref_table.schema.as_ref(),
+ ref_table.name.as_ref(),
+ ref_col
+ );
+ foreign_key_columns.push((col_name, fk_ref));
+ }
+ }
+ ConstraintKind::Unique(unique) => {
+ // Single-column unique constraints can use Field(unique=True)
+ if unique.columns.len() == 1 {
+ unique_columns.push(unique.columns[0].as_ref());
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Extract index information (non-constraint indexes)
+ for index in &table.indexes {
+ if !index.is_constraint_index && index.columns.len() == 1 {
+ if let Some(col) = &index.columns[0].column {
+ indexed_columns.push(col.as_ref());
+ }
+ }
+ }
+
+ Self {
+ table,
+ strategy,
+ primary_key_columns,
+ foreign_key_columns,
+ unique_columns,
+ indexed_columns,
+ }
+ }
+
+ /// Checks if a column is a primary key.
+ pub fn is_primary_key(&self, column_name: &str) -> bool {
+ self.primary_key_columns.contains(&column_name)
+ }
+
+ /// Gets the foreign key reference for a column, if any.
+ pub fn get_foreign_key(&self, column_name: &str) -> Option<&str> {
+ self.foreign_key_columns
+ .iter()
+ .find(|(col, _)| *col == column_name)
+ .map(|(_, fk)| fk.as_str())
+ }
+
+ /// Checks if a column has a unique constraint.
+ pub fn is_unique(&self, column_name: &str) -> bool {
+ self.unique_columns.contains(&column_name)
+ }
+
+ /// Checks if a column is indexed.
+ pub fn is_indexed(&self, column_name: &str) -> bool {
+ self.indexed_columns.contains(&column_name)
+ }
+}
+
+/// Generates field information for a column.
+pub fn generate_field(column: &Column, ctx: &FieldContext<'_>) -> FieldInfo {
+ let column_name = column.name.as_ref();
+
+ // Convert column name to Python attribute name
+ let (attr_name, needs_alias) = to_attribute_name(column_name, ctx.strategy);
+
+ // Map PostgreSQL type to Python type
+ let py_type = map_pg_type(&column.type_info);
+
+ // Collect imports
+ let mut imports = ImportCollector::new();
+ imports.add_all(&py_type.imports);
+ imports.add_all(&py_type.sa_imports);
+
+ // Check constraints for this column
+ let is_primary_key = ctx.is_primary_key(column_name);
+ let foreign_key = ctx.get_foreign_key(column_name).map(|s| s.to_string());
+ let is_unique = ctx.is_unique(column_name);
+ let is_indexed = ctx.is_indexed(column_name);
+
+ // Determine if we're dealing with an auto-generated primary key (identity or serial)
+ let is_auto_pk = is_primary_key
+ && (column.identity.is_some() || is_serial_type(&column.type_info.name.as_ref()));
+
+ // Build field parameters
+ let mut field_params = Vec::new();
+
+ // Handle nullability and primary key
+ let type_annotation = if is_auto_pk {
+ // Auto-generated PKs: type is Optional with default=None
+ field_params.push("default=None".to_string());
+ format!("{} | None", py_type.annotation)
+ } else if column.is_nullable {
+ format!("{} | None", py_type.annotation)
+ } else {
+ py_type.annotation.clone()
+ };
+
+ // Primary key
+ if is_primary_key {
+ field_params.push("primary_key=True".to_string());
+ }
+
+ // Foreign key
+ if let Some(ref fk) = foreign_key {
+ field_params.push(format!("foreign_key=\"{}\"", fk));
+ }
+
+ // Unique constraint
+ if is_unique && !is_primary_key {
+ field_params.push("unique=True".to_string());
+ }
+
+ // Index
+ if is_indexed && !is_primary_key && !is_unique {
+ field_params.push("index=True".to_string());
+ }
+
+ // String length constraints
+ if let Some(length) = extract_string_length(&column.type_info.formatted) {
+ if is_fixed_length_char(column.type_info.name.as_ref()) {
+ // Fixed-length char: both min and max
+ field_params.push(format!("min_length={length}"));
+ field_params.push(format!("max_length={length}"));
+ } else {
+ // Variable-length varchar: only max
+ field_params.push(format!("max_length={length}"));
+ }
+ }
+
+ // SQLAlchemy type (for JSON, ARRAY, etc.)
+ if let Some(ref sa_type) = py_type.sa_type {
+ field_params.push(format!("sa_type={sa_type}"));
+ }
+
+ // Alias for reserved words or sanitized names
+ if needs_alias {
+ field_params.push(format!("alias=\"{}\"", column_name));
+ }
+
+ // Handle default values for nullable non-PK fields
+ let simple_default =
+ if !is_auto_pk && column.is_nullable && column.default.is_none() && field_params.is_empty()
+ {
+ Some("None".to_string())
+ } else {
+ None
+ };
+
+ // Determine if we need a Field() declaration
+ let field_declaration = if !field_params.is_empty() {
+ imports.add_field();
+ Some(format!("Field({})", field_params.join(", ")))
+ } else {
+ None
+ };
+
+ FieldInfo {
+ attr_name,
+ db_column_name: column_name.to_string(),
+ needs_alias,
+ type_annotation,
+ field_declaration,
+ simple_default,
+ imports,
+ is_primary_key,
+ foreign_key,
+ is_indexed,
+ is_unique,
+ }
+}
+
+/// Checks if a type name represents a serial (auto-increment) type.
+fn is_serial_type(type_name: &str) -> bool {
+ matches!(
+ type_name,
+ "serial" | "serial2" | "serial4" | "serial8" | "smallserial" | "bigserial"
+ )
+}
+
+/// Formats a field as a Python class attribute line.
+pub fn format_field_line(field: &FieldInfo) -> String {
+ let type_ann = &field.type_annotation;
+ let name = &field.attr_name;
+
+ if let Some(ref field_decl) = field.field_declaration {
+ format!(" {name}: {type_ann} = {field_decl}")
+ } else if let Some(ref default) = field.simple_default {
+ format!(" {name}: {type_ann} = {default}")
+ } else {
+ format!(" {name}: {type_ann}")
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use tern_ddl::types::{ForeignKeyAction, QualifiedCollationName, QualifiedName};
+ use tern_ddl::{
+ ColumnName, Constraint, ConstraintName, ForeignKeyConstraint, IndexName, Oid,
+ PrimaryKeyConstraint, SchemaName, TableKind, TableName, TypeInfo, TypeName,
+ UniqueConstraint,
+ };
+
+ fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
+ Column {
+ name: ColumnName::try_new(name.to_string()).unwrap(),
+ position: 1,
+ type_info: TypeInfo {
+ name: TypeName::try_new(type_name.to_string()).unwrap(),
+ schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ formatted: type_name.to_string(),
+ is_array: false,
+ },
+ is_nullable,
+ default: None,
+ generated: None,
+ identity: None,
+ collation: QualifiedCollationName::new(
+ SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ tern_ddl::CollationName::try_new("default".to_string()).unwrap(),
+ ),
+ comment: None,
+ }
+ }
+
+ fn make_table(name: &str, columns: Vec<Column>, constraints: Vec<Constraint>) -> Table {
+ Table {
+ oid: Oid::new(1),
+ name: TableName::try_new(name.to_string()).unwrap(),
+ kind: TableKind::Regular,
+ columns,
+ constraints,
+ indexes: vec![],
+ comment: None,
+ }
+ }
+
+ #[test]
+ fn test_simple_field() {
+ let column = make_column("name", "text", false);
+ let table = make_table("users", vec![column.clone()], vec![]);
+ let strategy = ReservedWordStrategy::AppendUnderscore;
+ let ctx = FieldContext::from_table(&table, &strategy);
+
+ let field = generate_field(&column, &ctx);
+
+ assert_eq!(field.attr_name, "name");
+ assert_eq!(field.type_annotation, "str");
+ assert!(!field.needs_alias);
+ assert!(field.field_declaration.is_none());
+ }
+
+ #[test]
+ fn test_nullable_field() {
+ let column = make_column("bio", "text", true);
+ let table = make_table("users", vec![column.clone()], vec![]);
+ let strategy = ReservedWordStrategy::AppendUnderscore;
+ let ctx = FieldContext::from_table(&table, &strategy);
+
+ let field = generate_field(&column, &ctx);
+
+ assert_eq!(field.type_annotation, "str | None");
+ assert_eq!(field.simple_default, Some("None".to_string()));
+ }
+
+ #[test]
+ fn test_primary_key_field() {
+ let column = make_column("id", "int4", false);
+ let pk = Constraint {
+ name: ConstraintName::try_new("users_pkey".to_string()).unwrap(),
+ kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
+ columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
+ index_name: IndexName::try_new("users_pkey".to_string()).unwrap(),
+ }),
+ comment: None,
+ };
+ let table = make_table("users", vec![column.clone()], vec![pk]);
+ let strategy = ReservedWordStrategy::AppendUnderscore;
+ let ctx = FieldContext::from_table(&table, &strategy);
+
+ let field = generate_field(&column, &ctx);
+
+ assert!(field.is_primary_key);
+ assert!(field.field_declaration.is_some());
+ let decl = field.field_declaration.unwrap();
+ assert!(decl.contains("primary_key=True"));
+ }
+
+ #[test]
+ fn test_reserved_word_field() {
+ let column = make_column("class", "text", false);
+ let table = make_table("items", vec![column.clone()], vec![]);
+ let strategy = ReservedWordStrategy::AppendUnderscore;
+ let ctx = FieldContext::from_table(&table, &strategy);
+
+ let field = generate_field(&column, &ctx);
+
+ assert_eq!(field.attr_name, "class_");
+ assert!(field.needs_alias);
+ assert!(field.field_declaration.is_some());
+ let decl = field.field_declaration.unwrap();
+ assert!(decl.contains("alias=\"class\""));
+ }
+
+ #[test]
+ fn test_foreign_key_field() {
+ let column = make_column("user_id", "int4", true);
+ let fk = Constraint {
+ name: ConstraintName::try_new("posts_user_id_fkey".to_string()).unwrap(),
+ kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
+ columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()],
+ referenced_table: QualifiedName::new(
+ SchemaName::try_new("public".to_string()).unwrap(),
+ TableName::try_new("users".to_string()).unwrap(),
+ ),
+ referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
+ on_delete: ForeignKeyAction::NoAction,
+ on_update: ForeignKeyAction::NoAction,
+ is_deferrable: false,
+ is_initially_deferred: false,
+ }),
+ comment: None,
+ };
+ let table = make_table("posts", vec![column.clone()], vec![fk]);
+ let strategy = ReservedWordStrategy::AppendUnderscore;
+ let ctx = FieldContext::from_table(&table, &strategy);
+
+ let field = generate_field(&column, &ctx);
+
+ assert!(field.foreign_key.is_some());
+ assert!(field.field_declaration.is_some());
+ let decl = field.field_declaration.unwrap();
+ assert!(decl.contains("foreign_key="));
+ }
+
+ #[test]
+ fn test_unique_field() {
+ let column = make_column("email", "text", false);
+ let unique = Constraint {
+ name: ConstraintName::try_new("users_email_key".to_string()).unwrap(),
+ kind: ConstraintKind::Unique(UniqueConstraint {
+ columns: vec![ColumnName::try_new("email".to_string()).unwrap()],
+ index_name: IndexName::try_new("users_email_key".to_string()).unwrap(),
+ nulls_not_distinct: false,
+ }),
+ comment: None,
+ };
+ let table = make_table("users", vec![column.clone()], vec![unique]);
+ let strategy = ReservedWordStrategy::AppendUnderscore;
+ let ctx = FieldContext::from_table(&table, &strategy);
+
+ let field = generate_field(&column, &ctx);
+
+ assert!(field.is_unique);
+ assert!(field.field_declaration.is_some());
+ let decl = field.field_declaration.unwrap();
+ assert!(decl.contains("unique=True"));
+ }
+
+ #[test]
+ fn test_format_field_line() {
+ let field = FieldInfo {
+ attr_name: "name".to_string(),
+ db_column_name: "name".to_string(),
+ needs_alias: false,
+ type_annotation: "str".to_string(),
+ field_declaration: None,
+ simple_default: None,
+ imports: ImportCollector::new(),
+ is_primary_key: false,
+ foreign_key: None,
+ is_indexed: false,
+ is_unique: false,
+ };
+
+ let line = format_field_line(&field);
+ assert_eq!(line, " name: str");
+ }
+
+ #[test]
+ fn test_format_field_line_with_default() {
+ let field = FieldInfo {
+ attr_name: "bio".to_string(),
+ db_column_name: "bio".to_string(),
+ needs_alias: false,
+ type_annotation: "str | None".to_string(),
+ field_declaration: None,
+ simple_default: Some("None".to_string()),
+ imports: ImportCollector::new(),
+ is_primary_key: false,
+ foreign_key: None,
+ is_indexed: false,
+ is_unique: false,
+ };
+
+ let line = format_field_line(&field);
+ assert_eq!(line, " bio: str | None = None");
+ }
+
+ #[test]
+ fn test_format_field_line_with_field_declaration() {
+ let field = FieldInfo {
+ attr_name: "id".to_string(),
+ db_column_name: "id".to_string(),
+ needs_alias: false,
+ type_annotation: "int | None".to_string(),
+ field_declaration: Some("Field(default=None, primary_key=True)".to_string()),
+ simple_default: None,
+ imports: ImportCollector::new(),
+ is_primary_key: true,
+ foreign_key: None,
+ is_indexed: false,
+ is_unique: false,
+ };
+
+ let line = format_field_line(&field);
+ assert_eq!(
+ line,
+ " id: int | None = Field(default=None, primary_key=True)"
+ );
+ }
+}
diff --git a/crates/codegen/src/python/generator.rs b/crates/codegen/src/python/generator.rs
new file mode 100644
index 0000000..afe750a
--- /dev/null
+++ b/crates/codegen/src/python/generator.rs
@@ -0,0 +1,407 @@
+//! Python SQLModel code generator implementation.
+//!
+//! This module provides the main `PythonCodegen` struct that implements the `Codegen` trait
+//! for generating Python SQLModel models from PostgreSQL table definitions.
+
+use std::collections::HashMap;
+
+use tern_ddl::Table;
+
+use super::imports::ImportCollector;
+use super::model::{ModelInfo, format_model, generate_model};
+use super::naming::to_module_name;
+use super::{OutputMode, PythonCodegenConfig};
+use crate::Codegen;
+
+/// Python SQLModel code generator.
+///
+/// Generates Python SQLModel model definitions from PostgreSQL table schemas.
+///
+/// # Example
+///
+/// ```ignore
+/// use tern_codegen::python::{PythonCodegen, PythonCodegenConfig};
+/// use tern_codegen::Codegen;
+///
+/// let codegen = PythonCodegen::new(PythonCodegenConfig::default());
+/// let tables = vec![/* ... */];
+/// let output = codegen.generate(tables);
+///
+/// // output["models.py"] contains the generated code
+/// ```
+#[derive(Debug, Clone)]
+pub struct PythonCodegen {
+ config: PythonCodegenConfig,
+}
+
+impl PythonCodegen {
+ /// Creates a new Python code generator with the given configuration.
+ pub fn new(config: PythonCodegenConfig) -> Self {
+ Self { config }
+ }
+
+ /// Creates a new Python code generator with default configuration.
+ pub fn with_defaults() -> Self {
+ Self::new(PythonCodegenConfig::default())
+ }
+
+ /// Generates a single models.py file containing all models.
+ fn generate_single_file(&self, tables: Vec<Table>) -> HashMap<String, String> {
+ let mut output = HashMap::new();
+
+ if tables.is_empty() {
+ output.insert("models.py".to_string(), generate_empty_models_file());
+ return output;
+ }
+
+ // Generate all models
+ let models: Vec<ModelInfo> = tables
+ .iter()
+ .map(|t| generate_model(t, &self.config))
+ .collect();
+
+ // Collect all imports
+ let mut imports = ImportCollector::new();
+ for model in &models {
+ imports.merge(&model.imports);
+ }
+
+ // Build the file content
+ let mut content = Vec::new();
+
+ // Module docstring
+ content.push(MODULE_DOCSTRING.to_string());
+ content.push(String::new());
+
+ // Imports
+ let import_block = imports.generate();
+ if !import_block.is_empty() {
+ content.push(import_block);
+ content.push(String::new());
+ }
+
+ // Models
+ for (i, model) in models.iter().enumerate() {
+ if i > 0 {
+ content.push(String::new());
+ content.push(String::new());
+ }
+ content.push(format_model(model));
+ }
+
+ // Ensure file ends with newline
+ content.push(String::new());
+
+ output.insert("models.py".to_string(), content.join("\n"));
+ output
+ }
+
+ /// Generates multiple files, one per model, with a shared types module.
+ fn generate_multi_file(&self, tables: Vec<Table>) -> HashMap<String, String> {
+ let mut output = HashMap::new();
+
+ if tables.is_empty() {
+ output.insert("__init__.py".to_string(), generate_empty_init_file());
+ return output;
+ }
+
+ // Generate all models
+ let models: Vec<ModelInfo> = tables
+ .iter()
+ .map(|t| generate_model(t, &self.config))
+ .collect();
+
+ // Generate individual model files
+ let mut all_class_names = Vec::new();
+ let mut all_module_names = Vec::new();
+
+ for model in &models {
+ let module_name = to_module_name(&model.table_name);
+ let filename = format!("{module_name}.py");
+
+ // Build imports for this file
+ let mut imports = ImportCollector::new();
+ imports.merge(&model.imports);
+
+ // Build file content
+ let mut content = Vec::new();
+ content.push(format!(
+ "\"\"\"SQLModel definition for {}.\"\"\"\n",
+ model.class_name
+ ));
+
+ let import_block = imports.generate();
+ if !import_block.is_empty() {
+ content.push(import_block);
+ content.push(String::new());
+ }
+
+ content.push(format_model(model));
+ content.push(String::new());
+
+ output.insert(filename, content.join("\n"));
+
+ all_class_names.push(model.class_name.clone());
+ all_module_names.push(module_name);
+ }
+
+ // Generate __init__.py with re-exports
+ let init_content = generate_init_file(&all_module_names, &all_class_names);
+ output.insert("__init__.py".to_string(), init_content);
+
+ output
+ }
+}
+
+impl Default for PythonCodegen {
+ fn default() -> Self {
+ Self::with_defaults()
+ }
+}
+
+impl Codegen for PythonCodegen {
+ fn generate(&self, tables: Vec<Table>) -> HashMap<String, String> {
+ match self.config.output_mode {
+ OutputMode::SingleFile => self.generate_single_file(tables),
+ OutputMode::MultiFile => self.generate_multi_file(tables),
+ }
+ }
+}
+
+/// Module docstring for generated files.
+const MODULE_DOCSTRING: &str = "\"\"\"SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+\"\"\"";
+
+/// Generates an empty models.py file.
+fn generate_empty_models_file() -> String {
+ format!(
+ "{}\n\nfrom sqlmodel import SQLModel\n\n# No tables to generate\n",
+ MODULE_DOCSTRING
+ )
+}
+
+/// Generates an empty __init__.py file.
+fn generate_empty_init_file() -> String {
+ "\"\"\"SQLModel definitions generated by Tern.\"\"\"\n\n# No models to export\n".to_string()
+}
+
+/// Generates the __init__.py file with re-exports.
+fn generate_init_file(module_names: &[String], class_names: &[String]) -> String {
+ let mut content = Vec::new();
+ content.push("\"\"\"SQLModel definitions generated by Tern.\n".to_string());
+ content.push("This file was automatically generated. Do not edit manually.".to_string());
+ content.push("\"\"\"".to_string());
+ content.push(String::new());
+
+ // Import statements
+ for (module, class) in module_names.iter().zip(class_names.iter()) {
+ content.push(format!("from .{module} import {class}"));
+ }
+
+ content.push(String::new());
+
+ // __all__ list
+ let all_list: Vec<String> = class_names.iter().map(|c| format!("\"{c}\"")).collect();
+ if all_list.len() <= 3 {
+ content.push(format!("__all__ = [{}]", all_list.join(", ")));
+ } else {
+ content.push("__all__ = [".to_string());
+ for item in &all_list {
+ content.push(format!(" {item},"));
+ }
+ content.push("]".to_string());
+ }
+
+ content.push(String::new());
+ content.join("\n")
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use tern_ddl::types::QualifiedCollationName;
+ use tern_ddl::{
+ Column, ColumnName, Constraint, ConstraintKind, ConstraintName, IndexName, Oid,
+ PrimaryKeyConstraint, SchemaName, TableKind, TableName, TypeInfo, TypeName,
+ };
+
+ fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
+ Column {
+ name: ColumnName::try_new(name.to_string()).unwrap(),
+ position: 1,
+ type_info: TypeInfo {
+ name: TypeName::try_new(type_name.to_string()).unwrap(),
+ schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ formatted: type_name.to_string(),
+ is_array: false,
+ },
+ is_nullable,
+ default: None,
+ generated: None,
+ identity: None,
+ collation: QualifiedCollationName::new(
+ SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ tern_ddl::CollationName::try_new("default".to_string()).unwrap(),
+ ),
+ comment: None,
+ }
+ }
+
+ fn make_table_with_pk(name: &str, columns: Vec<Column>, pk_columns: &[&str]) -> Table {
+ let pk = Constraint {
+ name: ConstraintName::try_new(format!("{}_pkey", name)).unwrap(),
+ kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
+ columns: pk_columns
+ .iter()
+ .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+ .collect(),
+ index_name: IndexName::try_new(format!("{}_pkey", name)).unwrap(),
+ }),
+ comment: None,
+ };
+
+ Table {
+ oid: Oid::new(1),
+ name: TableName::try_new(name.to_string()).unwrap(),
+ kind: TableKind::Regular,
+ columns,
+ constraints: vec![pk],
+ indexes: vec![],
+ comment: None,
+ }
+ }
+
+ #[test]
+ fn test_generate_empty_tables() {
+ let codegen = PythonCodegen::with_defaults();
+ let output = codegen.generate(vec![]);
+
+ assert!(output.contains_key("models.py"));
+ let content = &output["models.py"];
+ assert!(content.contains("SQLModel"));
+ assert!(content.contains("No tables to generate"));
+ }
+
+ #[test]
+ fn test_generate_single_table() {
+ let columns = vec![
+ make_column("id", "int4", false),
+ make_column("name", "text", false),
+ make_column("email", "text", false),
+ ];
+ let table = make_table_with_pk("users", columns, &["id"]);
+
+ let codegen = PythonCodegen::with_defaults();
+ let output = codegen.generate(vec![table]);
+
+ assert!(output.contains_key("models.py"));
+ let content = &output["models.py"];
+ assert!(content.contains("class User(SQLModel, table=True):"));
+ assert!(content.contains("from sqlmodel import"));
+ assert!(content.contains("__tablename__ = \"users\""));
+ }
+
+ #[test]
+ fn test_generate_multiple_tables() {
+ let user_columns = vec![
+ make_column("id", "int4", false),
+ make_column("name", "text", false),
+ ];
+ let user_table = make_table_with_pk("users", user_columns, &["id"]);
+
+ let post_columns = vec![
+ make_column("id", "int4", false),
+ make_column("title", "text", false),
+ make_column("user_id", "int4", true),
+ ];
+ let post_table = make_table_with_pk("posts", post_columns, &["id"]);
+
+ let codegen = PythonCodegen::with_defaults();
+ let output = codegen.generate(vec![user_table, post_table]);
+
+ assert!(output.contains_key("models.py"));
+ let content = &output["models.py"];
+ assert!(content.contains("class User(SQLModel, table=True):"));
+ assert!(content.contains("class Post(SQLModel, table=True):"));
+ }
+
+ #[test]
+ fn test_generate_multi_file_mode() {
+ let user_columns = vec![
+ make_column("id", "int4", false),
+ make_column("name", "text", false),
+ ];
+ let user_table = make_table_with_pk("users", user_columns, &["id"]);
+
+ let post_columns = vec![
+ make_column("id", "int4", false),
+ make_column("title", "text", false),
+ ];
+ let post_table = make_table_with_pk("posts", post_columns, &["id"]);
+
+ let config = PythonCodegenConfig {
+ output_mode: OutputMode::MultiFile,
+ ..Default::default()
+ };
+ let codegen = PythonCodegen::new(config);
+ let output = codegen.generate(vec![user_table, post_table]);
+
+ assert!(output.contains_key("__init__.py"));
+ assert!(output.contains_key("user.py"));
+ assert!(output.contains_key("post.py"));
+
+ // Check __init__.py content
+ let init = &output["__init__.py"];
+ assert!(init.contains("from .user import User"));
+ assert!(init.contains("from .post import Post"));
+ assert!(init.contains("__all__"));
+
+ // Check individual files
+ let user_file = &output["user.py"];
+ assert!(user_file.contains("class User(SQLModel, table=True):"));
+
+ let post_file = &output["post.py"];
+ assert!(post_file.contains("class Post(SQLModel, table=True):"));
+ }
+
+ #[test]
+ fn test_generate_with_datetime_import() {
+ let columns = vec![
+ make_column("id", "int4", false),
+ make_column("created_at", "timestamptz", false),
+ ];
+ let table = make_table_with_pk("events", columns, &["id"]);
+
+ let codegen = PythonCodegen::with_defaults();
+ let output = codegen.generate(vec![table]);
+
+ let content = &output["models.py"];
+ assert!(content.contains("from datetime import datetime"));
+ assert!(content.contains("created_at: datetime"));
+ }
+
+ #[test]
+ fn test_generate_with_uuid_import() {
+ let columns = vec![
+ make_column("id", "uuid", false),
+ make_column("name", "text", false),
+ ];
+ let table = make_table_with_pk("items", columns, &["id"]);
+
+ let codegen = PythonCodegen::with_defaults();
+ let output = codegen.generate(vec![table]);
+
+ let content = &output["models.py"];
+ assert!(content.contains("from uuid import UUID"));
+ assert!(content.contains("id:") || content.contains("id :"));
+ }
+
+ #[test]
+ fn test_default_codegen_implements_trait() {
+ let codegen = PythonCodegen::default();
+ let output = codegen.generate(vec![]);
+ assert!(output.contains_key("models.py"));
+ }
+}
diff --git a/crates/codegen/src/python/imports.rs b/crates/codegen/src/python/imports.rs
new file mode 100644
index 0000000..a5d2777
--- /dev/null
+++ b/crates/codegen/src/python/imports.rs
@@ -0,0 +1,330 @@
+//! Python import management.
+//!
+//! This module handles collecting and organizing Python imports for generated code.
+
+use std::collections::{BTreeMap, BTreeSet};
+
+use super::type_mapping::PythonImport;
+
/// Collects and organizes Python imports for code generation.
///
/// Imports are grouped into stdlib / third-party / TYPE_CHECKING buckets,
/// each mapping a module name to the set of names imported from it.
#[derive(Debug, Default)]
pub struct ImportCollector {
    /// Standard library imports grouped by module.
    /// Using BTreeMap/BTreeSet for deterministic ordering.
    stdlib: BTreeMap<String, BTreeSet<String>>,
    /// Third-party imports (sqlmodel, sqlalchemy, pydantic).
    third_party: BTreeMap<String, BTreeSet<String>>,
    /// TYPE_CHECKING imports (for forward references).
    type_checking: BTreeMap<String, BTreeSet<String>>,
}
+
impl ImportCollector {
    /// Creates a new empty import collector.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a Python import to the collector.
    ///
    /// The import is routed into the stdlib or third-party bucket based on
    /// its module name; duplicate names are absorbed by the underlying set.
    pub fn add(&mut self, import: &PythonImport) {
        let category = categorize_module(&import.module);
        let map = match category {
            ModuleCategory::Stdlib => &mut self.stdlib,
            ModuleCategory::ThirdParty => &mut self.third_party,
        };

        map.entry(import.module.clone())
            .or_default()
            .insert(import.name.clone());
    }

    /// Adds multiple imports to the collector.
    pub fn add_all(&mut self, imports: &[PythonImport]) {
        for import in imports {
            self.add(import);
        }
    }

    /// Adds a TYPE_CHECKING import (for forward references).
    /// Used for relationship generation to avoid circular imports.
    #[allow(dead_code)]
    pub fn add_type_checking(&mut self, module: &str, name: &str) {
        self.type_checking
            .entry(module.to_string())
            .or_default()
            .insert(name.to_string());
    }

    /// Adds the core SQLModel import.
    pub fn add_sqlmodel(&mut self) {
        self.add(&PythonImport::new("sqlmodel", "SQLModel"));
    }

    /// Adds the SQLModel Field import.
    pub fn add_field(&mut self) {
        self.add(&PythonImport::new("sqlmodel", "Field"));
    }

    /// Adds the SQLModel Relationship import.
    /// Used when generate_relationships config option is enabled.
    #[allow(dead_code)]
    pub fn add_relationship(&mut self) {
        self.add(&PythonImport::new("sqlmodel", "Relationship"));
    }

    /// Checks if there are any TYPE_CHECKING imports.
    /// Used for relationship generation to determine if TYPE_CHECKING block is needed.
    #[allow(dead_code)]
    pub fn has_type_checking(&self) -> bool {
        !self.type_checking.is_empty()
    }

    /// Generates the import statements as a formatted string.
    ///
    /// Output order: stdlib imports, blank line, third-party imports, blank
    /// line, then an `if TYPE_CHECKING:` block when present. BTreeMap
    /// iteration keeps the whole output deterministic.
    pub fn generate(&self) -> String {
        let mut lines = Vec::new();

        // Standard library imports
        if !self.stdlib.is_empty() {
            for (module, names) in &self.stdlib {
                lines.push(format_import_line(module, names));
            }
        }

        // Third-party imports (with blank line separator if needed)
        if !self.third_party.is_empty() {
            if !self.stdlib.is_empty() {
                lines.push(String::new());
            }
            for (module, names) in &self.third_party {
                lines.push(format_import_line(module, names));
            }
        }

        // TYPE_CHECKING block
        if !self.type_checking.is_empty() {
            if !lines.is_empty() {
                lines.push(String::new());
            }
            lines.push("if TYPE_CHECKING:".to_string());
            for (module, names) in &self.type_checking {
                let import_line = format_import_line(module, names);
                lines.push(format!("    {import_line}"));
            }
            // Ensure typing.TYPE_CHECKING is imported
            self.ensure_type_checking_imported(&mut lines);
        }

        lines.join("\n")
    }

    /// Ensures TYPE_CHECKING is imported from typing if we have a TYPE_CHECKING block.
    ///
    /// NOTE(review): this patches the *already formatted* lines in place, so
    /// it must run after the typing import line has been emitted. If no
    /// `from typing import` line exists it silently does nothing — the caller
    /// is expected to add the TYPE_CHECKING import explicitly in that case.
    fn ensure_type_checking_imported(&self, lines: &mut [String]) {
        // Check if TYPE_CHECKING is already in typing imports
        if let Some(typing_names) = self.stdlib.get("typing") {
            if typing_names.contains("TYPE_CHECKING") {
                return;
            }
        }

        // Find the typing import line and add TYPE_CHECKING
        for line in lines.iter_mut() {
            if line.starts_with("from typing import ") {
                // Add TYPE_CHECKING to the existing import
                if !line.contains("TYPE_CHECKING") {
                    let insert_pos = "from typing import ".len();
                    line.insert_str(insert_pos, "TYPE_CHECKING, ");
                }
                return;
            }
        }

        // No typing import found, need to add one at the beginning
        // This case should be handled by the caller adding TYPE_CHECKING import explicitly
    }

    /// Merges another ImportCollector into this one.
    ///
    /// Union semantics per bucket: modules and names from `other` are added;
    /// nothing is ever removed.
    pub fn merge(&mut self, other: &ImportCollector) {
        for (module, names) in &other.stdlib {
            self.stdlib
                .entry(module.clone())
                .or_default()
                .extend(names.iter().cloned());
        }
        for (module, names) in &other.third_party {
            self.third_party
                .entry(module.clone())
                .or_default()
                .extend(names.iter().cloned());
        }
        for (module, names) in &other.type_checking {
            self.type_checking
                .entry(module.clone())
                .or_default()
                .extend(names.iter().cloned());
        }
    }
}
+
/// Module category for import grouping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModuleCategory {
    Stdlib,
    ThirdParty,
}

/// Categorizes a module as stdlib or third-party.
///
/// A module is third-party when it *is* one of the known third-party roots or
/// is a dotted submodule of one (e.g. `sqlalchemy.types`). A bare prefix
/// match would misclassify unrelated names such as `sqlmodels_local`.
fn categorize_module(module: &str) -> ModuleCategory {
    // Known third-party module roots used by generated code.
    const THIRD_PARTY: &[&str] = &["sqlmodel", "sqlalchemy", "pydantic", "fastapi", "starlette"];

    let is_third_party = THIRD_PARTY.iter().any(|&root| {
        module == root
            || module
                .strip_prefix(root)
                .is_some_and(|rest| rest.starts_with('.'))
    });

    if is_third_party {
        ModuleCategory::ThirdParty
    } else {
        ModuleCategory::Stdlib
    }
}
+
/// Formats a single `from <module> import <names>` line.
///
/// Falls back to the parenthesized multi-line form when the single-line form
/// would exceed 88 columns (the common Python formatter line length).
fn format_import_line(module: &str, names: &BTreeSet<String>) -> String {
    // BTreeSet iterates in sorted order, so the names are already
    // alphabetical and deterministic — no extra sort needed.
    let names_str: Vec<&str> = names.iter().map(|s| s.as_str()).collect();

    if names_str.len() == 1 {
        return format!("from {module} import {}", names_str[0]);
    }

    let single_line = format!("from {module} import {}", names_str.join(", "));
    if single_line.len() <= 88 {
        single_line
    } else {
        // Multi-line format: one name per line, trailing commas.
        let mut lines = vec![format!("from {module} import (")];
        for name in names_str {
            lines.push(format!("    {name},"));
        }
        lines.push(")".to_string());
        lines.join("\n")
    }
}
+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_simple_import() {
        let mut collector = ImportCollector::new();
        collector.add(&PythonImport::new("datetime", "datetime"));

        let output = collector.generate();
        assert!(output.contains("from datetime import datetime"));
    }

    #[test]
    fn test_add_multiple_from_same_module() {
        // Names from the same module merge into a single import statement.
        let mut collector = ImportCollector::new();
        collector.add(&PythonImport::new("datetime", "datetime"));
        collector.add(&PythonImport::new("datetime", "date"));
        collector.add(&PythonImport::new("datetime", "timedelta"));

        let output = collector.generate();
        assert!(output.contains("from datetime import"));
        assert!(output.contains("date"));
        assert!(output.contains("datetime"));
        assert!(output.contains("timedelta"));
    }

    #[test]
    fn test_stdlib_before_third_party() {
        // PEP 8 grouping: stdlib imports must precede third-party imports.
        let mut collector = ImportCollector::new();
        collector.add(&PythonImport::new("sqlmodel", "SQLModel"));
        collector.add(&PythonImport::new("datetime", "datetime"));

        let output = collector.generate();
        let datetime_pos = output.find("datetime").unwrap();
        let sqlmodel_pos = output.find("sqlmodel").unwrap();
        assert!(datetime_pos < sqlmodel_pos);
    }

    #[test]
    fn test_type_checking_block() {
        // Forward references go inside an `if TYPE_CHECKING:` block.
        let mut collector = ImportCollector::new();
        collector.add(&PythonImport::new("typing", "TYPE_CHECKING"));
        collector.add_type_checking(".user", "User");

        let output = collector.generate();
        assert!(output.contains("if TYPE_CHECKING:"));
        assert!(output.contains("from .user import User"));
    }

    #[test]
    fn test_sqlmodel_imports() {
        let mut collector = ImportCollector::new();
        collector.add_sqlmodel();
        collector.add_field();

        let output = collector.generate();
        assert!(output.contains("from sqlmodel import"));
        assert!(output.contains("Field"));
        assert!(output.contains("SQLModel"));
    }

    #[test]
    fn test_merge_collectors() {
        // merge() is a union: both collectors' imports survive.
        let mut collector1 = ImportCollector::new();
        collector1.add(&PythonImport::new("datetime", "datetime"));

        let mut collector2 = ImportCollector::new();
        collector2.add(&PythonImport::new("datetime", "date"));
        collector2.add(&PythonImport::new("sqlmodel", "SQLModel"));

        collector1.merge(&collector2);
        let output = collector1.generate();

        assert!(output.contains("date"));
        assert!(output.contains("datetime"));
        assert!(output.contains("SQLModel"));
    }

    #[test]
    fn test_deterministic_output() {
        // Run multiple times to verify ordering is consistent
        // (BTreeMap/BTreeSet guarantee this; the loop guards regressions).
        for _ in 0..5 {
            let mut collector = ImportCollector::new();
            collector.add(&PythonImport::new("uuid", "UUID"));
            collector.add(&PythonImport::new("datetime", "datetime"));
            collector.add(&PythonImport::new("typing", "Any"));
            collector.add(&PythonImport::new("sqlmodel", "SQLModel"));
            collector.add(&PythonImport::new("sqlmodel", "Field"));

            let output = collector.generate();

            // Check order: stdlib (datetime, typing, uuid) then third-party (sqlmodel)
            let datetime_pos = output.find("from datetime").unwrap();
            let typing_pos = output.find("from typing").unwrap();
            let uuid_pos = output.find("from uuid").unwrap();
            let sqlmodel_pos = output.find("from sqlmodel").unwrap();

            assert!(datetime_pos < typing_pos);
            assert!(typing_pos < uuid_pos);
            assert!(uuid_pos < sqlmodel_pos);
        }
    }

    #[test]
    fn test_categorize_module() {
        assert_eq!(categorize_module("datetime"), ModuleCategory::Stdlib);
        assert_eq!(categorize_module("typing"), ModuleCategory::Stdlib);
        assert_eq!(categorize_module("uuid"), ModuleCategory::Stdlib);
        assert_eq!(categorize_module("sqlmodel"), ModuleCategory::ThirdParty);
        assert_eq!(categorize_module("sqlalchemy"), ModuleCategory::ThirdParty);
        assert_eq!(
            categorize_module("sqlalchemy.types"),
            ModuleCategory::ThirdParty
        );
    }
}
diff --git a/crates/codegen/src/python/mod.rs b/crates/codegen/src/python/mod.rs
index 06c7d62..b94a2b2 100644
--- a/crates/codegen/src/python/mod.rs
+++ b/crates/codegen/src/python/mod.rs
@@ -1 +1,128 @@
-//! Python code generation module.
+//! Python SQLModel code generation module.
+//!
+//! This module provides code generation for Python SQLModel models from PostgreSQL
+//! table definitions. It supports:
+//!
+//! - Idiomatic SQLModel models with proper type mappings
+//! - Primary keys, foreign keys, unique constraints, and indexes
+//! - Optional relationship generation for foreign keys
+//! - Generated and identity columns
+//! - Proper handling of Python reserved words and invalid identifiers
+//!
+//! # Example
+//!
+//! ```ignore
+//! use tern_codegen::{Codegen, python::{PythonCodegen, PythonCodegenConfig}};
+//!
+//! let codegen = PythonCodegen::new(PythonCodegenConfig::default());
+//! let output = codegen.generate(tables);
+//! // output contains "models.py" with SQLModel classes
+//! ```
+
+mod field;
+mod generator;
+mod imports;
+mod model;
+mod naming;
+mod type_mapping;
+
+#[cfg(test)]
+mod tests;
+
+pub use generator::PythonCodegen;
+
+use std::fmt;
+
+/// Configuration for Python SQLModel code generation.
+#[derive(Debug, Clone)]
+pub struct PythonCodegenConfig {
+ /// Whether to generate Pydantic-only base classes for each model.
+ /// These are useful for request/response validation without DB coupling.
+ pub generate_base_models: bool,
+
+ /// Module name prefix for generated imports (e.g., "app.models").
+ pub module_prefix: Option,
+
+ /// Whether to include docstrings from table/column comments.
+ pub include_docstrings: bool,
+
+ /// How to handle Python reserved words in identifiers.
+ pub reserved_word_strategy: ReservedWordStrategy,
+
+ /// Whether to generate relationship attributes for foreign keys.
+ pub generate_relationships: bool,
+
+ /// Output mode: single file or multiple files.
+ pub output_mode: OutputMode,
+}
+
+impl Default for PythonCodegenConfig {
+ fn default() -> Self {
+ Self {
+ generate_base_models: false,
+ module_prefix: None,
+ include_docstrings: true,
+ reserved_word_strategy: ReservedWordStrategy::default(),
+ generate_relationships: false,
+ output_mode: OutputMode::default(),
+ }
+ }
+}
+
/// Strategy for handling Python reserved words in identifiers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ReservedWordStrategy {
    /// Append an underscore: `class` -> `class_`
    #[default]
    AppendUnderscore,
    /// Prepend with prefix: `class` -> `field_class`
    PrependPrefix(String),
}

impl fmt::Display for ReservedWordStrategy {
    /// Renders the strategy as a short machine-friendly token.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::AppendUnderscore => f.write_str("append_underscore"),
            Self::PrependPrefix(prefix) => write!(f, "prepend_prefix({prefix})"),
        }
    }
}
+
/// Output mode for generated code.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum OutputMode {
    /// Generate all models in a single `models.py` file.
    #[default]
    SingleFile,
    /// Generate separate files for each model with a shared types module.
    MultiFile,
}

impl fmt::Display for OutputMode {
    /// Renders the mode as a short machine-friendly token.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::SingleFile => "single_file",
            Self::MultiFile => "multi_file",
        };
        f.write_str(label)
    }
}
+
/// Errors that can occur during Python code generation.
///
/// `thiserror::Error` derives `std::error::Error` and `Display` from the
/// `#[error]` attributes below.
#[derive(Debug, thiserror::Error)]
pub enum PythonCodegenError {
    /// An unsupported PostgreSQL type was encountered.
    #[error("unsupported PostgreSQL type: {0}")]
    UnsupportedType(String),

    /// A table has no columns.
    #[error("table has no columns: {0}")]
    EmptyTable(String),

    /// An identifier is invalid after sanitization.
    #[error("invalid identifier after sanitization: {0}")]
    InvalidIdentifier(String),

    /// Code generation failed.
    #[error("code generation failed: {0}")]
    GenerationError(String),
}
diff --git a/crates/codegen/src/python/model.rs b/crates/codegen/src/python/model.rs
new file mode 100644
index 0000000..e2539fb
--- /dev/null
+++ b/crates/codegen/src/python/model.rs
@@ -0,0 +1,490 @@
+//! SQLModel class generation.
+//!
+//! This module handles generating complete SQLModel class definitions from table schemas.
+
+use tern_ddl::{ConstraintKind, Table};
+
+use super::PythonCodegenConfig;
+use super::field::{FieldContext, FieldInfo, format_field_line, generate_field};
+use super::imports::ImportCollector;
+use super::naming::to_class_name;
+
+/// Information about a generated SQLModel class.
+#[derive(Debug)]
+pub struct ModelInfo {
+ /// The Python class name.
+ pub class_name: String,
+ /// The database table name.
+ pub table_name: String,
+ /// Generated field information.
+ pub fields: Vec,
+ /// Required imports.
+ pub imports: ImportCollector,
+ /// Table args entries (for composite constraints, indexes, etc.).
+ pub table_args: Vec,
+ /// Docstring for the class (from table comment).
+ pub docstring: Option,
+ /// Whether this table has any columns.
+ pub has_columns: bool,
+ /// Warning messages for unsupported features.
+ pub warnings: Vec,
+}
+
+/// Generates model information for a table.
+pub fn generate_model(table: &Table, config: &PythonCodegenConfig) -> ModelInfo {
+ let table_name = table.name.as_ref().to_string();
+ let class_name = to_class_name(&table_name);
+
+ let mut imports = ImportCollector::new();
+ let mut table_args = Vec::new();
+ let mut warnings = Vec::new();
+
+ // Add base SQLModel import
+ imports.add_sqlmodel();
+
+ // Generate fields
+ let ctx = FieldContext::from_table(table, &config.reserved_word_strategy);
+ let mut fields: Vec = table
+ .columns
+ .iter()
+ .map(|col| {
+ let field = generate_field(col, &ctx);
+ imports.merge(&field.imports);
+ field
+ })
+ .collect();
+
+ // Sort fields: primary keys first, then required fields, then optional fields
+ fields.sort_by(|a, b| {
+ // Primary keys come first
+ match (a.is_primary_key, b.is_primary_key) {
+ (true, false) => std::cmp::Ordering::Less,
+ (false, true) => std::cmp::Ordering::Greater,
+ _ => {
+ // Then non-nullable before nullable (by checking if type contains "| None")
+ let a_nullable = a.type_annotation.contains("| None");
+ let b_nullable = b.type_annotation.contains("| None");
+ match (a_nullable, b_nullable) {
+ (false, true) => std::cmp::Ordering::Less,
+ (true, false) => std::cmp::Ordering::Greater,
+ _ => std::cmp::Ordering::Equal,
+ }
+ }
+ }
+ });
+
+ // Generate __table_args__ for composite constraints and indexes
+ generate_table_args(table, &mut table_args, &mut imports, &mut warnings);
+
+ // Generate docstring from table comment
+ let docstring = if config.include_docstrings {
+ table.comment.as_ref().map(|c| c.as_ref().to_string())
+ } else {
+ None
+ };
+
+ // Check for empty table
+ let has_columns = !table.columns.is_empty();
+ if !has_columns {
+ warnings.push(format!(
+ "Table '{}' has no columns. SQLModel requires at least one field.",
+ table_name
+ ));
+ }
+
+ ModelInfo {
+ class_name,
+ table_name,
+ fields,
+ imports,
+ table_args,
+ docstring,
+ has_columns,
+ warnings,
+ }
+}
+
+/// Generates __table_args__ entries for composite constraints and indexes.
+fn generate_table_args(
+ table: &Table,
+ table_args: &mut Vec,
+ imports: &mut ImportCollector,
+ warnings: &mut Vec,
+) {
+ // Check for composite primary key (already handled in fields via primary_key=True on each column)
+ // But if there's only one PK column with identity, it's handled differently
+
+ // Composite unique constraints
+ for constraint in &table.constraints {
+ match &constraint.kind {
+ ConstraintKind::Unique(unique) if unique.columns.len() > 1 => {
+ let columns: Vec = unique
+ .columns
+ .iter()
+ .map(|c| format!("\"{}\"", c.as_ref()))
+ .collect();
+ let constraint_name = constraint.name.as_ref();
+ table_args.push(format!(
+ "UniqueConstraint({}, name=\"{}\")",
+ columns.join(", "),
+ constraint_name
+ ));
+ imports.add(&super::type_mapping::PythonImport::new(
+ "sqlalchemy",
+ "UniqueConstraint",
+ ));
+ }
+ ConstraintKind::Check(check) => {
+ let expr = check.expression.as_ref();
+ let constraint_name = constraint.name.as_ref();
+ table_args.push(format!(
+ "CheckConstraint(\"{}\", name=\"{}\")",
+ escape_python_string(expr),
+ constraint_name
+ ));
+ imports.add(&super::type_mapping::PythonImport::new(
+ "sqlalchemy",
+ "CheckConstraint",
+ ));
+ }
+ ConstraintKind::Exclusion(excl) => {
+ // Exclusion constraints are not supported by SQLModel
+ let constraint_name = constraint.name.as_ref();
+ let elements: Vec = excl
+ .elements
+ .iter()
+ .map(|e| format!("{} WITH {}", e.expression.as_ref(), e.operator))
+ .collect();
+ warnings.push(format!(
+ "Exclusion constraint '{}' not supported by SQLModel: EXCLUDE USING {} ({})",
+ constraint_name,
+ excl.index_method.as_str(),
+ elements.join(", ")
+ ));
+ }
+ _ => {}
+ }
+ }
+
+ // Composite indexes (non-constraint indexes with multiple columns)
+ for index in &table.indexes {
+ if !index.is_constraint_index && index.columns.len() > 1 {
+ let columns: Vec = index
+ .columns
+ .iter()
+ .filter_map(|ic| ic.column.as_ref().map(|c| format!("\"{}\"", c.as_ref())))
+ .collect();
+
+ if !columns.is_empty() {
+ let index_name = index.name.as_ref();
+ table_args.push(format!("Index(\"{}\", {})", index_name, columns.join(", ")));
+ imports.add(&super::type_mapping::PythonImport::new(
+ "sqlalchemy",
+ "Index",
+ ));
+ }
+ }
+ }
+}
+
/// Escapes a string for use in a double-quoted Python string literal.
///
/// Handles backslash, double quote, newline, carriage return, and tab; all
/// other characters pass through unchanged.
fn escape_python_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            other => escaped.push(other),
        }
    }
    escaped
}
+
+/// Formats a complete SQLModel class definition.
+pub fn format_model(model: &ModelInfo) -> String {
+ let mut lines = Vec::new();
+
+ // Warning comments
+ for warning in &model.warnings {
+ lines.push(format!("# WARNING: {}", warning));
+ }
+
+ // Class definition
+ lines.push(format!("class {}(SQLModel, table=True):", model.class_name));
+
+ // Docstring
+ if let Some(ref doc) = model.docstring {
+ lines.push(format!(" \"\"\"{}\"\"\"", escape_python_string(doc)));
+ lines.push(String::new());
+ }
+
+ // __tablename__
+ lines.push(format!(" __tablename__ = \"{}\"", model.table_name));
+
+ // __table_args__
+ if !model.table_args.is_empty() {
+ if model.table_args.len() == 1 {
+ lines.push(format!(" __table_args__ = ({},)", model.table_args[0]));
+ } else {
+ lines.push(" __table_args__ = (".to_string());
+ for arg in &model.table_args {
+ lines.push(format!(" {},", arg));
+ }
+ lines.push(" )".to_string());
+ }
+ }
+
+ lines.push(String::new());
+
+ // Fields
+ if model.has_columns {
+ for field in &model.fields {
+ lines.push(format_field_line(field));
+ }
+ } else {
+ lines.push(" pass # No columns defined".to_string());
+ }
+
+ lines.join("\n")
+}
+
#[cfg(test)]
mod tests {
    use super::*;
    use tern_ddl::types::{QualifiedCollationName, SqlExpr};
    use tern_ddl::{
        CheckConstraint, Column, ColumnName, Constraint, ConstraintName, IndexName, Oid,
        PrimaryKeyConstraint, SchemaName, TableKind, TableName, TypeInfo, TypeName,
        UniqueConstraint,
    };

    /// Builds a minimal catalog column with the given name, type, and nullability.
    fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
        Column {
            name: ColumnName::try_new(name.to_string()).unwrap(),
            position: 1,
            type_info: TypeInfo {
                name: TypeName::try_new(type_name.to_string()).unwrap(),
                schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
                formatted: type_name.to_string(),
                is_array: false,
            },
            is_nullable,
            default: None,
            generated: None,
            identity: None,
            collation: QualifiedCollationName::new(
                SchemaName::try_new("pg_catalog".to_string()).unwrap(),
                tern_ddl::CollationName::try_new("default".to_string()).unwrap(),
            ),
            comment: None,
        }
    }

    /// Builds a table with a primary-key constraint over `pk_columns`.
    fn make_table_with_pk(name: &str, columns: Vec<Column>, pk_columns: &[&str]) -> Table {
        let pk = Constraint {
            name: ConstraintName::try_new(format!("{}_pkey", name)).unwrap(),
            kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
                columns: pk_columns
                    .iter()
                    .map(|c| ColumnName::try_new(c.to_string()).unwrap())
                    .collect(),
                index_name: IndexName::try_new(format!("{}_pkey", name)).unwrap(),
            }),
            comment: None,
        };

        Table {
            oid: Oid::new(1),
            name: TableName::try_new(name.to_string()).unwrap(),
            kind: TableKind::Regular,
            columns,
            constraints: vec![pk],
            indexes: vec![],
            comment: None,
        }
    }

    #[test]
    fn test_generate_simple_model() {
        let columns = vec![
            make_column("id", "int4", false),
            make_column("name", "text", false),
            make_column("email", "text", false),
        ];
        let table = make_table_with_pk("users", columns, &["id"]);
        let config = PythonCodegenConfig::default();

        let model = generate_model(&table, &config);

        assert_eq!(model.class_name, "User");
        assert_eq!(model.table_name, "users");
        assert_eq!(model.fields.len(), 3);
        assert!(model.has_columns);
        assert!(model.warnings.is_empty());
    }

    #[test]
    fn test_generate_model_with_nullable() {
        let columns = vec![
            make_column("id", "int4", false),
            make_column("name", "text", false),
            make_column("bio", "text", true),
        ];
        let table = make_table_with_pk("users", columns, &["id"]);
        let config = PythonCodegenConfig::default();

        let model = generate_model(&table, &config);

        // Fields should be sorted: PK first, then non-nullable, then nullable
        assert!(model.fields[0].is_primary_key);
        assert!(!model.fields[1].type_annotation.contains("| None"));
        assert!(model.fields[2].type_annotation.contains("| None"));
    }

    #[test]
    fn test_generate_model_with_composite_unique() {
        let columns = vec![
            make_column("id", "int4", false),
            make_column("email", "text", false),
            make_column("tenant_id", "int4", false),
        ];
        let pk = Constraint {
            name: ConstraintName::try_new("users_pkey".to_string()).unwrap(),
            kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
                columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                index_name: IndexName::try_new("users_pkey".to_string()).unwrap(),
            }),
            comment: None,
        };
        let unique = Constraint {
            name: ConstraintName::try_new("users_email_tenant_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![
                    ColumnName::try_new("email".to_string()).unwrap(),
                    ColumnName::try_new("tenant_id".to_string()).unwrap(),
                ],
                index_name: IndexName::try_new("users_email_tenant_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        };
        let table = Table {
            oid: Oid::new(1),
            name: TableName::try_new("users".to_string()).unwrap(),
            kind: TableKind::Regular,
            columns,
            constraints: vec![pk, unique],
            indexes: vec![],
            comment: None,
        };
        let config = PythonCodegenConfig::default();

        let model = generate_model(&table, &config);

        assert!(!model.table_args.is_empty());
        assert!(model.table_args[0].contains("UniqueConstraint"));
        assert!(model.table_args[0].contains("email"));
        assert!(model.table_args[0].contains("tenant_id"));
    }

    #[test]
    fn test_generate_model_with_check_constraint() {
        let columns = vec![
            make_column("id", "int4", false),
            make_column("price", "numeric", false),
        ];
        let pk = Constraint {
            name: ConstraintName::try_new("products_pkey".to_string()).unwrap(),
            kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
                columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                index_name: IndexName::try_new("products_pkey".to_string()).unwrap(),
            }),
            comment: None,
        };
        let check = Constraint {
            name: ConstraintName::try_new("products_price_positive".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("price > 0".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        };
        let table = Table {
            oid: Oid::new(1),
            name: TableName::try_new("products".to_string()).unwrap(),
            kind: TableKind::Regular,
            columns,
            constraints: vec![pk, check],
            indexes: vec![],
            comment: None,
        };
        let config = PythonCodegenConfig::default();

        let model = generate_model(&table, &config);

        assert!(!model.table_args.is_empty());
        assert!(model.table_args[0].contains("CheckConstraint"));
        assert!(model.table_args[0].contains("price > 0"));
    }

    #[test]
    fn test_format_model() {
        let columns = vec![
            make_column("id", "int4", false),
            make_column("name", "text", false),
        ];
        let table = make_table_with_pk("users", columns, &["id"]);
        let config = PythonCodegenConfig::default();

        let model = generate_model(&table, &config);
        let output = format_model(&model);

        assert!(output.contains("class User(SQLModel, table=True):"));
        assert!(output.contains("__tablename__ = \"users\""));
        assert!(output.contains("id:"));
        assert!(output.contains("name:"));
    }

    #[test]
    fn test_format_model_with_docstring() {
        let columns = vec![make_column("id", "int4", false)];
        let mut table = make_table_with_pk("users", columns, &["id"]);
        table.comment = Some(tern_ddl::types::Comment::new(
            "User accounts table".to_string(),
        ));

        let config = PythonCodegenConfig {
            include_docstrings: true,
            ..Default::default()
        };

        let model = generate_model(&table, &config);
        let output = format_model(&model);

        assert!(output.contains("\"\"\"User accounts table\"\"\""));
    }

    #[test]
    fn test_empty_table_warning() {
        let table = Table {
            oid: Oid::new(1),
            name: TableName::try_new("empty".to_string()).unwrap(),
            kind: TableKind::Regular,
            columns: vec![],
            constraints: vec![],
            indexes: vec![],
            comment: None,
        };
        let config = PythonCodegenConfig::default();

        let model = generate_model(&table, &config);

        assert!(!model.has_columns);
        assert!(!model.warnings.is_empty());
        assert!(model.warnings[0].contains("no columns"));
    }

    #[test]
    fn test_escape_python_string() {
        assert_eq!(escape_python_string("hello"), "hello");
        assert_eq!(escape_python_string("say \"hi\""), "say \\\"hi\\\"");
        assert_eq!(escape_python_string("line1\nline2"), "line1\\nline2");
        assert_eq!(escape_python_string("path\\to\\file"), "path\\\\to\\\\file");
    }
}
diff --git a/crates/codegen/src/python/naming.rs b/crates/codegen/src/python/naming.rs
new file mode 100644
index 0000000..0679e1b
--- /dev/null
+++ b/crates/codegen/src/python/naming.rs
@@ -0,0 +1,513 @@
+//! Python identifier naming and sanitization.
+//!
+//! This module handles conversion of PostgreSQL identifiers to valid Python identifiers,
+//! including handling of reserved words, invalid characters, and naming conventions.
+
+use super::ReservedWordStrategy;
+
/// Python reserved words (keywords).
///
/// These cannot be used as identifiers in Python without modification.
/// Kept in ascending ASCII order; intended to mirror CPython's
/// `keyword.kwlist` — TODO confirm against the targeted Python version
/// whenever new keywords are added upstream.
const PYTHON_KEYWORDS: &[&str] = &[
    "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue",
    "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import",
    "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while",
    "with", "yield",
];
+
/// Python soft keywords (context-dependent keywords): `match`, `case`, and `_`
/// since Python 3.10, `type` since Python 3.12. They are only reserved in
/// specific contexts, but we treat them as reserved to stay forward-safe.
const PYTHON_SOFT_KEYWORDS: &[&str] = &["match", "case", "_", "type"];
+
/// Python built-in names that should be avoided to prevent shadowing.
/// Covers the common callables/types from the `builtins` module; not
/// necessarily exhaustive.
/// Kept for future use to warn about shadowing built-in names.
#[allow(dead_code)]
const PYTHON_BUILTINS: &[&str] = &[
    "abs", "all", "any", "ascii", "bin", "bool", "breakpoint", "bytearray", "bytes",
    "callable", "chr", "classmethod", "compile", "complex", "delattr", "dict", "dir",
    "divmod", "enumerate", "eval", "exec", "filter", "float", "format", "frozenset",
    "getattr", "globals", "hasattr", "hash", "help", "hex", "id", "input", "int",
    "isinstance", "issubclass", "iter", "len", "list", "locals", "map", "max",
    "memoryview", "min", "next", "object", "oct", "open", "ord", "pow", "print",
    "property", "range", "repr", "reversed", "round", "set", "setattr", "slice",
    "sorted", "staticmethod", "str", "sum", "super", "tuple", "type", "vars", "zip",
];
+
+/// Checks if a name is a Python keyword.
+pub fn is_python_keyword(name: &str) -> bool {
+ PYTHON_KEYWORDS.contains(&name)
+}
+
+/// Checks if a name is a Python soft keyword.
+pub fn is_python_soft_keyword(name: &str) -> bool {
+ PYTHON_SOFT_KEYWORDS.contains(&name)
+}
+
+/// Checks if a name is a Python builtin.
+/// Kept for future use to warn about shadowing built-in names.
+#[allow(dead_code)]
+pub fn is_python_builtin(name: &str) -> bool {
+ PYTHON_BUILTINS.contains(&name)
+}
+
+/// Checks if a name is a reserved word (keyword or soft keyword).
+pub fn is_reserved_word(name: &str) -> bool {
+ is_python_keyword(name) || is_python_soft_keyword(name)
+}
+
/// Sanitizes a database identifier for use as a Python identifier.
///
/// Passes run in this order (order matters — each pass feeds the next):
/// 1. Replaces invalid characters with underscores, prefixing a leading digit.
/// 2. Collapses runs of consecutive underscores into one.
/// 3. Trims trailing underscores introduced by sanitization (but keeps ones
///    present in the original name).
/// 4. Applies the reserved-word strategy if the result is a keyword.
///
/// Returns `(sanitized_name, needs_alias)` where `needs_alias` is true if
/// the name was modified and needs a Field(alias="original_name") declaration.
pub fn sanitize_identifier(name: &str, strategy: &ReservedWordStrategy) -> (String, bool) {
    // An empty source name is unrepresentable in Python; use a placeholder.
    if name.is_empty() {
        return ("_empty".to_string(), true);
    }

    let mut sanitized = String::with_capacity(name.len() + 1);
    let mut needs_alias = false;

    // Pass 1: character-by-character replacement of invalid characters.
    for (i, c) in name.chars().enumerate() {
        if i == 0 {
            // First character must be a letter or underscore
            if c.is_ascii_alphabetic() || c == '_' {
                sanitized.push(c);
            } else if c.is_ascii_digit() {
                // Prefix with underscore if starts with digit (keep the digit)
                sanitized.push('_');
                sanitized.push(c);
                needs_alias = true;
            } else {
                // Replace invalid first character with underscore
                sanitized.push('_');
                needs_alias = true;
            }
        } else if c.is_ascii_alphanumeric() || c == '_' {
            sanitized.push(c);
        } else {
            // Replace invalid characters with underscore
            sanitized.push('_');
            needs_alias = true;
        }
    }

    // Pass 2: collapse multiple consecutive underscores. Dropping a duplicate
    // changes the spelling, so it forces an alias.
    let mut collapsed = String::with_capacity(sanitized.len());
    let mut prev_underscore = false;
    for c in sanitized.chars() {
        if c == '_' {
            if !prev_underscore {
                collapsed.push(c);
            } else {
                needs_alias = true;
            }
            prev_underscore = true;
        } else {
            collapsed.push(c);
            prev_underscore = false;
        }
    }
    sanitized = collapsed;

    // Pass 3: remove trailing underscores that sanitization introduced.
    // Underscores the original name really ended with are preserved
    // (checked against `name`, not `sanitized`).
    while sanitized.len() > 1 && sanitized.ends_with('_') && !name.ends_with('_') {
        sanitized.pop();
        needs_alias = true;
    }

    // Pass 4: reserved words get renamed per the configured strategy.
    if is_reserved_word(&sanitized) {
        needs_alias = true;
        sanitized = apply_reserved_word_strategy(&sanitized, strategy);
    }

    // Final validation — everything reduced away (e.g. name was all invalid
    // characters): fall back to a generic placeholder.
    if sanitized.is_empty() || sanitized == "_" {
        return ("_field".to_string(), true);
    }

    (sanitized, needs_alias)
}
+
+/// Applies the reserved word strategy to transform a reserved word.
+fn apply_reserved_word_strategy(name: &str, strategy: &ReservedWordStrategy) -> String {
+ match strategy {
+ ReservedWordStrategy::AppendUnderscore => format!("{name}_"),
+ ReservedWordStrategy::PrependPrefix(prefix) => format!("{prefix}{name}"),
+ }
+}
+
/// Converts a table name to a Python class name (PascalCase, singularized).
///
/// Examples:
/// - "users" -> "User"
/// - "user_accounts" -> "UserAccount"
/// - "categories" -> "Category"
///
/// Singularization is deliberately naive: `"ies"` becomes `"y"`, and a bare
/// trailing `'s'` is stripped unless the name ends in `"ss"`, `"us"`, or
/// `"is"` (so "status" and "address" are preserved).
pub fn to_class_name(table_name: &str) -> String {
    // "ies" -> "y" must be checked before the bare trailing-'s' rule.
    if table_name.ends_with("ies") && table_name.len() > 3 {
        return to_pascal_case(&table_name[..table_name.len() - 3]) + "y";
    }

    let singular = if table_name.ends_with('s')
        && !table_name.ends_with("ss")
        && !table_name.ends_with("us")
        && !table_name.ends_with("is")
    {
        // Simple plural: strip trailing 's'
        &table_name[..table_name.len() - 1]
    } else {
        table_name
    };

    // Previously this function duplicated the PascalCase loop inline;
    // delegate to the shared helper instead.
    let result = to_pascal_case(singular);

    // Handle edge case where nothing survives conversion (e.g. name was "s").
    if result.is_empty() {
        "Model".to_string()
    } else {
        result
    }
}

/// Converts a string to PascalCase, treating '_', '-', and ' ' as word
/// separators; all other characters are lowercased except word starts.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        if c == '_' || c == '-' || c == ' ' {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c.to_ascii_lowercase());
        }
    }

    result
}
+
/// Converts a table name to a Python module name (snake_case, singular).
///
/// Examples:
/// - "users" -> "user"
/// - "user_accounts" -> "user_account"
/// - "OrderItems" -> "order_item"
/// - "categories" -> "category"
pub fn to_module_name(table_name: &str) -> String {
    let mut result = String::with_capacity(table_name.len());
    let mut prev_was_upper = false;

    // CamelCase -> snake_case; '-' and ' ' become '_'. Runs of uppercase
    // (e.g. acronyms) get a single leading underscore.
    for (i, c) in table_name.chars().enumerate() {
        if c.is_ascii_uppercase() {
            if i > 0 && !prev_was_upper {
                result.push('_');
            }
            result.push(c.to_ascii_lowercase());
            prev_was_upper = true;
        } else if c == '-' || c == ' ' {
            result.push('_');
            prev_was_upper = false;
        } else {
            result.push(c);
            prev_was_upper = false;
        }
    }

    // Singularize. BUG FIX: the "ies" -> "y" rule must run BEFORE the bare
    // trailing-'s' rule — every "ies" word also matches the trailing-'s'
    // test, so the old ordering turned "categories" into "categorie".
    if result.ends_with("ies") && result.len() > 3 {
        result.truncate(result.len() - 3);
        result.push('y');
    } else if result.ends_with('s')
        && !result.ends_with("ss")
        && !result.ends_with("us")
        && !result.ends_with("is")
        && result.len() > 1
    {
        // Simple plural: strip trailing 's'
        result.pop();
    }

    result
}
+
+/// Converts a column name to a Python attribute name (snake_case).
+///
+/// PostgreSQL column names are typically already in snake_case, but this
+/// handles edge cases like mixed case or special characters.
+pub fn to_attribute_name(column_name: &str, strategy: &ReservedWordStrategy) -> (String, bool) {
+ // First sanitize the identifier
+ let (sanitized, needs_alias) = sanitize_identifier(column_name, strategy);
+
+ // Convert to snake_case if needed (handle camelCase input)
+ let mut result = String::with_capacity(sanitized.len() + 4);
+ let mut prev_was_upper = false;
+ let mut prev_was_underscore = true; // Treat start as after underscore
+
+ for c in sanitized.chars() {
+ if c.is_ascii_uppercase() {
+ if !prev_was_upper && !prev_was_underscore {
+ result.push('_');
+ }
+ result.push(c.to_ascii_lowercase());
+ prev_was_upper = true;
+ prev_was_underscore = false;
+ } else if c == '_' {
+ if !prev_was_underscore {
+ result.push(c);
+ }
+ prev_was_upper = false;
+ prev_was_underscore = true;
+ } else {
+ result.push(c);
+ prev_was_upper = false;
+ prev_was_underscore = false;
+ }
+ }
+
+ // Check if conversion changed the name
+ let conversion_changed = result != sanitized;
+
+ (result, needs_alias || conversion_changed)
+}
+
#[cfg(test)]
mod tests {
    use super::*;

    // --- keyword / reserved-word predicates ---------------------------------

    #[test]
    fn test_is_python_keyword() {
        assert!(is_python_keyword("class"));
        assert!(is_python_keyword("def"));
        assert!(is_python_keyword("if"));
        assert!(is_python_keyword("True"));
        assert!(is_python_keyword("None"));
        assert!(!is_python_keyword("user"));
        assert!(!is_python_keyword("name"));
    }

    #[test]
    fn test_is_python_soft_keyword() {
        assert!(is_python_soft_keyword("match"));
        assert!(is_python_soft_keyword("case"));
        // Hard keywords are not soft keywords.
        assert!(!is_python_soft_keyword("class"));
    }

    #[test]
    fn test_is_reserved_word() {
        assert!(is_reserved_word("class"));
        assert!(is_reserved_word("match"));
        assert!(!is_reserved_word("user"));
    }

    // --- sanitize_identifier -----------------------------------------------

    #[test]
    fn test_sanitize_identifier_valid() {
        // Already-valid snake_case names pass through untouched, no alias.
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("user_id", &strategy);
        assert_eq!(name, "user_id");
        assert!(!needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_reserved_word() {
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("class", &strategy);
        assert_eq!(name, "class_");
        assert!(needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_reserved_word_prefix() {
        let strategy = ReservedWordStrategy::PrependPrefix("field_".to_string());
        let (name, needs_alias) = sanitize_identifier("class", &strategy);
        assert_eq!(name, "field_class");
        assert!(needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_starts_with_digit() {
        // Leading digits are kept but prefixed with an underscore.
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("1column", &strategy);
        assert_eq!(name, "_1column");
        assert!(needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_special_characters() {
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("column-name", &strategy);
        assert_eq!(name, "column_name");
        assert!(needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_spaces() {
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("column name", &strategy);
        assert_eq!(name, "column_name");
        assert!(needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_multiple_underscores() {
        // Runs of underscores collapse to one, which forces an alias.
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("column__name", &strategy);
        assert_eq!(name, "column_name");
        assert!(needs_alias);
    }

    #[test]
    fn test_sanitize_identifier_empty() {
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = sanitize_identifier("", &strategy);
        assert_eq!(name, "_empty");
        assert!(needs_alias);
    }

    // --- name conversions ---------------------------------------------------

    #[test]
    fn test_to_class_name_simple() {
        assert_eq!(to_class_name("users"), "User");
        assert_eq!(to_class_name("user"), "User");
    }

    #[test]
    fn test_to_class_name_compound() {
        assert_eq!(to_class_name("user_accounts"), "UserAccount");
        assert_eq!(to_class_name("order_items"), "OrderItem");
    }

    #[test]
    fn test_to_class_name_categories() {
        // "ies" plural singularizes to "y".
        assert_eq!(to_class_name("categories"), "Category");
    }

    #[test]
    fn test_to_class_name_preserves_non_plural() {
        // "us"/"ss" endings are not treated as plurals.
        assert_eq!(to_class_name("status"), "Status");
        assert_eq!(to_class_name("address"), "Address");
    }

    #[test]
    fn test_to_module_name_simple() {
        assert_eq!(to_module_name("users"), "user");
        assert_eq!(to_module_name("User"), "user");
    }

    #[test]
    fn test_to_module_name_compound() {
        assert_eq!(to_module_name("user_accounts"), "user_account");
        assert_eq!(to_module_name("OrderItems"), "order_item");
    }

    #[test]
    fn test_to_attribute_name_simple() {
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = to_attribute_name("user_id", &strategy);
        assert_eq!(name, "user_id");
        assert!(!needs_alias);
    }

    #[test]
    fn test_to_attribute_name_camel_case() {
        // camelCase input is folded to snake_case and therefore needs an alias.
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = to_attribute_name("userId", &strategy);
        assert_eq!(name, "user_id");
        assert!(needs_alias);
    }

    #[test]
    fn test_to_attribute_name_reserved() {
        let strategy = ReservedWordStrategy::AppendUnderscore;
        let (name, needs_alias) = to_attribute_name("from", &strategy);
        assert_eq!(name, "from_");
        assert!(needs_alias);
    }

    #[test]
    fn test_python_builtins() {
        assert!(is_python_builtin("list"));
        assert!(is_python_builtin("dict"));
        assert!(is_python_builtin("str"));
        assert!(is_python_builtin("int"));
        assert!(!is_python_builtin("user"));
    }
}
diff --git a/crates/codegen/src/python/tests/mod.rs b/crates/codegen/src/python/tests/mod.rs
new file mode 100644
index 0000000..c084c98
--- /dev/null
+++ b/crates/codegen/src/python/tests/mod.rs
@@ -0,0 +1,7 @@
+//! Tests for Python SQLModel code generation.
+//!
+//! This module contains comprehensive tests including unit tests, snapshot tests,
+//! and edge case tests for the Python code generator.
+
+mod snapshot_tests;
+mod unit_tests;
diff --git a/crates/codegen/src/python/tests/snapshot_tests.rs b/crates/codegen/src/python/tests/snapshot_tests.rs
new file mode 100644
index 0000000..2483886
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshot_tests.rs
@@ -0,0 +1,577 @@
+//! Snapshot tests for Python code generation.
+//!
+//! These tests use insta to capture and verify the full output of the code generator,
+//! making it easy to review changes to the generated code format.
+
+use crate::Codegen;
+use crate::python::{OutputMode, PythonCodegen, PythonCodegenConfig};
+
+use tern_ddl::types::{
+ Comment, ForeignKeyAction, IndexMethod, QualifiedCollationName, QualifiedName, SqlExpr,
+};
+use tern_ddl::{
+ CheckConstraint, CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName,
+ ExclusionConstraint, ExclusionElement, ForeignKeyConstraint, IdentityKind, IndexName, Oid,
+ PrimaryKeyConstraint, SchemaName, Table, TableKind, TableName, TypeInfo, TypeName,
+ UniqueConstraint,
+};
+
+// =============================================================================
+// Test Helpers
+// =============================================================================
+
+fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo {
+ TypeInfo {
+ name: TypeName::try_new(name.to_string()).unwrap(),
+ schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ formatted: formatted.to_string(),
+ is_array,
+ }
+}
+
+fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
+ Column {
+ name: ColumnName::try_new(name.to_string()).unwrap(),
+ position: 1,
+ type_info: make_type_info(type_name, type_name, false),
+ is_nullable,
+ default: None,
+ generated: None,
+ identity: None,
+ collation: QualifiedCollationName::new(
+ SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ CollationName::try_new("default".to_string()).unwrap(),
+ ),
+ comment: None,
+ }
+}
+
+fn make_pk_constraint(table_name: &str, columns: &[&str]) -> Constraint {
+ Constraint {
+ name: ConstraintName::try_new(format!("{}_pkey", table_name)).unwrap(),
+ kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
+ columns: columns
+ .iter()
+ .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+ .collect(),
+ index_name: IndexName::try_new(format!("{}_pkey", table_name)).unwrap(),
+ }),
+ comment: None,
+ }
+}
+
+fn make_table(name: &str, columns: Vec, constraints: Vec) -> Table {
+ Table {
+ oid: Oid::new(1),
+ name: TableName::try_new(name.to_string()).unwrap(),
+ kind: TableKind::Regular,
+ columns,
+ constraints,
+ indexes: vec![],
+ comment: None,
+ }
+}
+
+// =============================================================================
+// Snapshot Tests
+// =============================================================================
+
#[test]
fn snapshot_simple_users_table() {
    // A typical small table: identity PK, unique email, nullable name, a
    // timestamptz column, and a table comment that should become a docstring.
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("email", "text", false),
        make_column("name", "text", true),
        {
            let mut col = make_column("created_at", "timestamptz", false);
            col.type_info = make_type_info("timestamptz", "timestamp with time zone", false);
            col
        },
    ];

    let constraints = vec![
        make_pk_constraint("users", &["id"]),
        Constraint {
            name: ConstraintName::try_new("users_email_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![ColumnName::try_new("email".to_string()).unwrap()],
                index_name: IndexName::try_new("users_email_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
    ];

    let mut table = make_table("users", columns, constraints);
    table.comment = Some(Comment::new(
        "User accounts for the application.".to_string(),
    ));

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_blog_schema() {
    // Three related tables (users <- posts <- comments) exercising foreign
    // keys, multiple unique constraints, and cross-table generation order.

    // Users table
    let user_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("username", "text", false),
        make_column("email", "text", false),
        make_column("bio", "text", true),
        make_column("created_at", "timestamptz", false),
    ];
    let user_constraints = vec![
        make_pk_constraint("users", &["id"]),
        Constraint {
            name: ConstraintName::try_new("users_username_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![ColumnName::try_new("username".to_string()).unwrap()],
                index_name: IndexName::try_new("users_username_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("users_email_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![ColumnName::try_new("email".to_string()).unwrap()],
                index_name: IndexName::try_new("users_email_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
    ];
    let users = make_table("users", user_columns, user_constraints);

    // Posts table: FK to users(id) with ON DELETE CASCADE
    let post_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("title", "text", false),
        make_column("content", "text", false),
        make_column("author_id", "int4", false),
        make_column("published_at", "timestamptz", true),
        make_column("created_at", "timestamptz", false),
    ];
    let post_constraints = vec![
        make_pk_constraint("posts", &["id"]),
        Constraint {
            name: ConstraintName::try_new("posts_author_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("author_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("users".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::Cascade,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
    ];
    let posts = make_table("posts", post_columns, post_constraints);

    // Comments table: FKs to both posts(id) and users(id)
    let comment_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("content", "text", false),
        make_column("post_id", "int4", false),
        make_column("author_id", "int4", false),
        make_column("created_at", "timestamptz", false),
    ];
    let comment_constraints = vec![
        make_pk_constraint("comments", &["id"]),
        Constraint {
            name: ConstraintName::try_new("comments_post_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("post_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("posts".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::Cascade,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("comments_author_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("author_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("users".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::Cascade,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
    ];
    let comments = make_table("comments", comment_columns, comment_constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![users, posts, comments]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_all_postgres_types() {
    // One column per supported PostgreSQL type family, to pin the full
    // PG-type -> Python-type mapping in a single snapshot.
    let columns = vec![
        // Integer types
        make_column("col_int2", "int2", false),
        make_column("col_int4", "int4", false),
        make_column("col_int8", "int8", false),
        // Float types
        make_column("col_float4", "float4", false),
        make_column("col_float8", "float8", false),
        // Decimal (with precision/scale in the formatted name)
        {
            let mut col = make_column("col_numeric", "numeric", false);
            col.type_info = make_type_info("numeric", "numeric(10,2)", false);
            col
        },
        // Boolean
        make_column("col_bool", "bool", false),
        // Text types
        make_column("col_text", "text", false),
        {
            let mut col = make_column("col_varchar", "varchar", false);
            col.type_info = make_type_info("varchar", "character varying(255)", false);
            col
        },
        // Date/time types
        make_column("col_date", "date", false),
        make_column("col_time", "time", false),
        make_column("col_timestamp", "timestamp", false),
        make_column("col_timestamptz", "timestamptz", false),
        make_column("col_interval", "interval", false),
        // UUID
        make_column("col_uuid", "uuid", false),
        // JSON
        make_column("col_json", "json", true),
        make_column("col_jsonb", "jsonb", true),
        // Binary
        make_column("col_bytea", "bytea", true),
        // Network types
        make_column("col_inet", "inet", true),
        // Array types (is_array = true, element type in `name`)
        {
            let mut col = make_column("col_int_array", "int4", false);
            col.type_info = make_type_info("int4", "integer[]", true);
            col
        },
        {
            let mut col = make_column("col_text_array", "text", true);
            col.type_info = make_type_info("text", "text[]", true);
            col
        },
    ];

    // Add a simple PK as the first column
    let mut all_columns = vec![{
        let mut col = make_column("id", "int4", false);
        col.identity = Some(IdentityKind::Always);
        col
    }];
    all_columns.extend(columns);

    let constraints = vec![make_pk_constraint("all_types", &["id"])];
    let table = make_table("all_types", all_columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_reserved_words_table() {
    // Columns named after Python keywords and soft keywords — pins how the
    // reserved-word strategy renames attributes and emits aliases.
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("class", "text", false),
        make_column("from", "text", false),
        make_column("import", "text", true),
        make_column("def", "int4", false),
        make_column("return", "bool", false),
        make_column("yield", "text", true),
        make_column("async", "bool", false),
        make_column("await", "text", true),
        make_column("match", "text", true),
        make_column("case", "int4", true),
    ];

    let constraints = vec![make_pk_constraint("reserved_words", &["id"])];
    let table = make_table("reserved_words", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_complex_constraints() {
    // Composite unique constraint plus several CHECK constraints — pins how
    // non-column-level constraints are rendered (or surfaced as comments).
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("email", "text", false),
        make_column("tenant_id", "int4", false),
        make_column("status", "text", false),
        make_column("price", "numeric", false),
        make_column("quantity", "int4", false),
    ];

    let constraints = vec![
        make_pk_constraint("products", &["id"]),
        // Composite unique over (email, tenant_id)
        Constraint {
            name: ConstraintName::try_new("products_email_tenant_key".to_string()).unwrap(),
            kind: ConstraintKind::Unique(UniqueConstraint {
                columns: vec![
                    ColumnName::try_new("email".to_string()).unwrap(),
                    ColumnName::try_new("tenant_id".to_string()).unwrap(),
                ],
                index_name: IndexName::try_new("products_email_tenant_key".to_string()).unwrap(),
                nulls_not_distinct: false,
            }),
            comment: None,
        },
        // Check constraints
        Constraint {
            name: ConstraintName::try_new("products_price_positive".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("price > 0".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("products_quantity_non_negative".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("quantity >= 0".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        },
        Constraint {
            name: ConstraintName::try_new("products_status_valid".to_string()).unwrap(),
            kind: ConstraintKind::Check(CheckConstraint {
                expression: SqlExpr::new("status IN ('draft', 'active', 'archived')".to_string()),
                is_no_inherit: false,
            }),
            comment: None,
        },
    ];

    let table = make_table("products", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_self_referential_table() {
    // A table whose FK points back at itself (employees.manager_id -> id),
    // with ON DELETE SET NULL — a case that often trips up relationship codegen.
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("name", "text", false),
        make_column("manager_id", "int4", true),
    ];

    let constraints = vec![
        make_pk_constraint("employees", &["id"]),
        Constraint {
            name: ConstraintName::try_new("employees_manager_id_fkey".to_string()).unwrap(),
            kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                columns: vec![ColumnName::try_new("manager_id".to_string()).unwrap()],
                referenced_table: QualifiedName::new(
                    SchemaName::try_new("public".to_string()).unwrap(),
                    TableName::try_new("employees".to_string()).unwrap(),
                ),
                referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                on_delete: ForeignKeyAction::SetNull,
                on_update: ForeignKeyAction::NoAction,
                is_deferrable: false,
                is_initially_deferred: false,
            }),
            comment: None,
        },
    ];

    let table = make_table("employees", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_multi_file_output() {
    // `OutputMode::MultiFile` should emit one module per table plus an
    // `__init__.py` that re-exports every model.
    let user_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("name", "text", false),
        make_column("email", "text", false),
    ];
    let users = make_table(
        "users",
        user_columns,
        vec![make_pk_constraint("users", &["id"])],
    );

    let post_columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("title", "text", false),
        make_column("user_id", "int4", false),
    ];
    let posts = make_table(
        "posts",
        post_columns,
        vec![
            make_pk_constraint("posts", &["id"]),
            Constraint {
                name: ConstraintName::try_new("posts_user_id_fkey".to_string()).unwrap(),
                kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
                    columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()],
                    referenced_table: QualifiedName::new(
                        SchemaName::try_new("public".to_string()).unwrap(),
                        TableName::try_new("users".to_string()).unwrap(),
                    ),
                    referenced_columns: vec![ColumnName::try_new("id".to_string()).unwrap()],
                    on_delete: ForeignKeyAction::Cascade,
                    on_update: ForeignKeyAction::NoAction,
                    is_deferrable: false,
                    is_initially_deferred: false,
                }),
                comment: None,
            },
        ],
    );

    let config = PythonCodegenConfig {
        output_mode: OutputMode::MultiFile,
        ..Default::default()
    };
    let codegen = PythonCodegen::new(config);
    let output = codegen.generate(vec![users, posts]);

    // Snapshot each file separately (named snapshots keep them distinct).
    insta::assert_snapshot!("multi_file_init", output.get("__init__.py").unwrap());
    insta::assert_snapshot!("multi_file_user", output.get("user.py").unwrap());
    insta::assert_snapshot!("multi_file_post", output.get("post.py").unwrap());
}
+
#[test]
fn snapshot_exclusion_constraint_warning() {
    // GiST exclusion constraints have no SQLModel equivalent; this pins
    // whatever the generator emits for them (per the test name, a warning).
    let columns = vec![
        {
            let mut col = make_column("id", "int4", false);
            col.identity = Some(IdentityKind::Always);
            col
        },
        make_column("room_id", "int4", false),
        make_column("start_time", "timestamptz", false),
        make_column("end_time", "timestamptz", false),
    ];

    let constraints = vec![
        make_pk_constraint("meetings", &["id"]),
        Constraint {
            name: ConstraintName::try_new("meetings_no_overlap".to_string()).unwrap(),
            kind: ConstraintKind::Exclusion(ExclusionConstraint {
                elements: vec![
                    ExclusionElement {
                        expression: SqlExpr::new("room_id".to_string()),
                        operator: "=".to_string(),
                    },
                    ExclusionElement {
                        expression: SqlExpr::new("tsrange(start_time, end_time)".to_string()),
                        operator: "&&".to_string(),
                    },
                ],
                index_method: IndexMethod::Gist,
                index_name: IndexName::try_new("meetings_no_overlap_idx".to_string()).unwrap(),
                predicate: None,
            }),
            comment: None,
        },
    ];

    let table = make_table("meetings", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
+
#[test]
fn snapshot_composite_primary_key() {
    // Join-table style composite PK (order_id, product_id) with no identity
    // column — both key fields must be marked primary_key in the output.
    let columns = vec![
        make_column("order_id", "int4", false),
        make_column("product_id", "int4", false),
        make_column("quantity", "int4", false),
        make_column("unit_price", "numeric", false),
    ];

    let constraints = vec![make_pk_constraint(
        "order_items",
        &["order_id", "product_id"],
    )];

    let table = make_table("order_items", columns, constraints);

    let codegen = PythonCodegen::with_defaults();
    let output = codegen.generate(vec![table]);

    insta::assert_snapshot!(output.get("models.py").unwrap());
}
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_init.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_init.snap
new file mode 100644
index 0000000..eb8ed14
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_init.snap
@@ -0,0 +1,14 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 508
+expression: "output.get(\"__init__.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from .user import User
+from .post import Post
+
+__all__ = ["User", "Post"]
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_post.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_post.snap
new file mode 100644
index 0000000..d0c2f65
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_post.snap
@@ -0,0 +1,15 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 510
+expression: "output.get(\"post.py\").unwrap()"
+---
+"""SQLModel definition for Post."""
+
+from sqlmodel import Field, SQLModel
+
+class Post(SQLModel, table=True):
+ __tablename__ = "posts"
+
+ id: int | None = Field(default=None, primary_key=True)
+ title: str
+ user_id: int = Field(foreign_key="public.users.id")
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_user.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_user.snap
new file mode 100644
index 0000000..d37902a
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__multi_file_user.snap
@@ -0,0 +1,15 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 509
+expression: "output.get(\"user.py\").unwrap()"
+---
+"""SQLModel definition for User."""
+
+from sqlmodel import Field, SQLModel
+
+class User(SQLModel, table=True):
+ __tablename__ = "users"
+
+ id: int | None = Field(default=None, primary_key=True)
+ name: str
+ email: str
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_all_postgres_types.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_all_postgres_types.snap
new file mode 100644
index 0000000..efe8aed
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_all_postgres_types.snap
@@ -0,0 +1,43 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 313
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from typing import Any
+from uuid import UUID
+
+from sqlalchemy import ARRAY, Integer, JSON, Text
+from sqlmodel import Field, SQLModel
+
+class AllType(SQLModel, table=True):
+ __tablename__ = "all_types"
+
+ id: int | None = Field(default=None, primary_key=True)
+ col_int2: int
+ col_int4: int
+ col_int8: int
+ col_float4: float
+ col_float8: float
+ col_numeric: Decimal
+ col_bool: bool
+ col_text: str
+ col_varchar: str = Field(max_length=255)
+ col_date: date
+ col_time: time
+ col_timestamp: datetime
+ col_timestamptz: datetime
+ col_interval: timedelta
+ col_uuid: UUID
+ col_int_array: list[int] = Field(sa_type=ARRAY(Integer))
+ col_json: dict[str, Any] | None = Field(sa_type=JSON)
+ col_jsonb: dict[str, Any] | None = Field(sa_type=JSON)
+ col_bytea: bytes | None = None
+ col_inet: str | None = None
+ col_text_array: list[str] | None = Field(sa_type=ARRAY(Text))
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_blog_schema.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_blog_schema.snap
new file mode 100644
index 0000000..5f8154f
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_blog_schema.snap
@@ -0,0 +1,43 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 243
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from datetime import datetime
+
+from sqlmodel import Field, SQLModel
+
+class User(SQLModel, table=True):
+ __tablename__ = "users"
+
+ id: int | None = Field(default=None, primary_key=True)
+ username: str = Field(unique=True)
+ email: str = Field(unique=True)
+ created_at: datetime
+ bio: str | None = None
+
+
+class Post(SQLModel, table=True):
+ __tablename__ = "posts"
+
+ id: int | None = Field(default=None, primary_key=True)
+ title: str
+ content: str
+ author_id: int = Field(foreign_key="public.users.id")
+ created_at: datetime
+ published_at: datetime | None = None
+
+
+class Comment(SQLModel, table=True):
+ __tablename__ = "comments"
+
+ id: int | None = Field(default=None, primary_key=True)
+ content: str
+ post_id: int = Field(foreign_key="public.posts.id")
+ author_id: int = Field(foreign_key="public.users.id")
+ created_at: datetime
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_complex_constraints.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_complex_constraints.snap
new file mode 100644
index 0000000..9d58c48
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_complex_constraints.snap
@@ -0,0 +1,30 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 407
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from decimal import Decimal
+
+from sqlalchemy import CheckConstraint, UniqueConstraint
+from sqlmodel import Field, SQLModel
+
+class Product(SQLModel, table=True):
+ __tablename__ = "products"
+ __table_args__ = (
+ UniqueConstraint("email", "tenant_id", name="products_email_tenant_key"),
+ CheckConstraint("price > 0", name="products_price_positive"),
+ CheckConstraint("quantity >= 0", name="products_quantity_non_negative"),
+ CheckConstraint("status IN ('draft', 'active', 'archived')", name="products_status_valid"),
+ )
+
+ id: int | None = Field(default=None, primary_key=True)
+ email: str
+ tenant_id: int
+ status: str
+ price: Decimal
+ quantity: int
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_composite_primary_key.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_composite_primary_key.snap
new file mode 100644
index 0000000..748285d
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_composite_primary_key.snap
@@ -0,0 +1,21 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 576
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from decimal import Decimal
+
+from sqlmodel import Field, SQLModel
+
+class OrderItem(SQLModel, table=True):
+ __tablename__ = "order_items"
+
+ order_id: int = Field(primary_key=True)
+ product_id: int = Field(primary_key=True)
+ quantity: int
+ unit_price: Decimal
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_exclusion_constraint_warning.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_exclusion_constraint_warning.snap
new file mode 100644
index 0000000..e14feb2
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_exclusion_constraint_warning.snap
@@ -0,0 +1,22 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 554
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from datetime import datetime
+
+from sqlmodel import Field, SQLModel
+
+# WARNING: Exclusion constraint 'meetings_no_overlap' not supported by SQLModel: EXCLUDE USING gist (room_id WITH =, tsrange(start_time, end_time) WITH &&)
+class Meeting(SQLModel, table=True):
+ __tablename__ = "meetings"
+
+ id: int | None = Field(default=None, primary_key=True)
+ room_id: int
+ start_time: datetime
+ end_time: datetime
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_reserved_words_table.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_reserved_words_table.snap
new file mode 100644
index 0000000..0a579cf
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_reserved_words_table.snap
@@ -0,0 +1,26 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 342
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from sqlmodel import Field, SQLModel
+
+class ReservedWord(SQLModel, table=True):
+ __tablename__ = "reserved_words"
+
+ id: int | None = Field(default=None, primary_key=True)
+ class_: str = Field(alias="class")
+ from_: str = Field(alias="from")
+ def_: int = Field(alias="def")
+ return_: bool = Field(alias="return")
+ async_: bool = Field(alias="async")
+ import_: str | None = Field(alias="import")
+ yield_: str | None = Field(alias="yield")
+ await_: str | None = Field(alias="await")
+ match_: str | None = Field(alias="match")
+ case_: int | None = Field(alias="case")
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_self_referential_table.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_self_referential_table.snap
new file mode 100644
index 0000000..5e89dec
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_self_referential_table.snap
@@ -0,0 +1,18 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 447
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from sqlmodel import Field, SQLModel
+
+class Employee(SQLModel, table=True):
+ __tablename__ = "employees"
+
+ id: int | None = Field(default=None, primary_key=True)
+ name: str
+ manager_id: int | None = Field(foreign_key="public.employees.id")
diff --git a/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_simple_users_table.snap b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_simple_users_table.snap
new file mode 100644
index 0000000..6b3359e
--- /dev/null
+++ b/crates/codegen/src/python/tests/snapshots/tern_codegen__python__tests__snapshot_tests__snapshot_simple_users_table.snap
@@ -0,0 +1,23 @@
+---
+source: crates/codegen/src/python/tests/snapshot_tests.rs
+assertion_line: 117
+expression: "output.get(\"models.py\").unwrap()"
+---
+"""SQLModel definitions generated by Tern.
+
+This file was automatically generated. Do not edit manually.
+"""
+
+from datetime import datetime
+
+from sqlmodel import Field, SQLModel
+
+class User(SQLModel, table=True):
+ """User accounts for the application."""
+
+ __tablename__ = "users"
+
+ id: int | None = Field(default=None, primary_key=True)
+ email: str = Field(unique=True)
+ created_at: datetime
+ name: str | None = None
diff --git a/crates/codegen/src/python/tests/unit_tests.rs b/crates/codegen/src/python/tests/unit_tests.rs
new file mode 100644
index 0000000..38bcd3f
--- /dev/null
+++ b/crates/codegen/src/python/tests/unit_tests.rs
@@ -0,0 +1,707 @@
+//! Unit tests for Python code generation components.
+
+use crate::Codegen;
+use crate::python::naming::{
+ is_python_keyword, is_reserved_word, sanitize_identifier, to_class_name, to_module_name,
+};
+use crate::python::type_mapping::{
+ extract_numeric_precision, extract_string_length, is_fixed_length_char, map_pg_type,
+};
+use crate::python::{OutputMode, PythonCodegenConfig, ReservedWordStrategy};
+
+use tern_ddl::types::{ForeignKeyAction, QualifiedCollationName, QualifiedName, SqlExpr};
+use tern_ddl::{
+ CheckConstraint, CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName,
+ ForeignKeyConstraint, IdentityKind, IndexName, Oid, PrimaryKeyConstraint, SchemaName, Table,
+ TableKind, TableName, TypeInfo, TypeName, UniqueConstraint,
+};
+
+// =============================================================================
+// Test Helpers
+// =============================================================================
+
+/// Builds a `TypeInfo` in the `pg_catalog` schema for test fixtures.
+///
+/// `formatted` is the SQL-level spelling of the type (e.g. "varchar(255)"
+/// or "integer[]"); `is_array` flags array types.
+fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo {
+    TypeInfo {
+        name: TypeName::try_new(name.to_string()).unwrap(),
+        schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+        formatted: formatted.to_string(),
+        is_array,
+    }
+}
+
+/// Builds a plain, default-collated, non-array column for tests.
+///
+/// The type's `formatted` string is set equal to its name and `is_array`
+/// is always `false`; use `make_type_info` directly for array columns.
+fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column {
+    Column {
+        name: ColumnName::try_new(name.to_string()).unwrap(),
+        // NOTE(review): position is hard-coded to 1 for every column —
+        // presumably the generator ignores ordinal position; confirm.
+        position: 1,
+        type_info: make_type_info(type_name, type_name, false),
+        is_nullable,
+        default: None,
+        generated: None,
+        identity: None,
+        collation: QualifiedCollationName::new(
+            SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+            CollationName::try_new("default".to_string()).unwrap(),
+        ),
+        comment: None,
+    }
+}
+
+/// Like `make_column`, but NOT NULL and carrying the given identity kind.
+fn make_column_with_identity(name: &str, type_name: &str, identity: IdentityKind) -> Column {
+    Column {
+        identity: Some(identity),
+        ..make_column(name, type_name, false)
+    }
+}
+
+/// Assembles a regular `Table` from pre-built columns and constraints.
+///
+/// The OID is a fixed dummy value; indexes and the table comment are empty.
+fn make_table(name: &str, columns: Vec<Column>, constraints: Vec<Constraint>) -> Table {
+    Table {
+        oid: Oid::new(1),
+        name: TableName::try_new(name.to_string()).unwrap(),
+        kind: TableKind::Regular,
+        columns,
+        constraints,
+        indexes: vec![],
+        comment: None,
+    }
+}
+
+/// Builds a PRIMARY KEY constraint over `columns`, named `{table}_pkey`
+/// with a matching backing-index name (PostgreSQL's default naming).
+fn make_pk_constraint(table_name: &str, columns: &[&str]) -> Constraint {
+    Constraint {
+        name: ConstraintName::try_new(format!("{}_pkey", table_name)).unwrap(),
+        kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint {
+            columns: columns
+                .iter()
+                .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+                .collect(),
+            index_name: IndexName::try_new(format!("{}_pkey", table_name)).unwrap(),
+        }),
+        comment: None,
+    }
+}
+
+/// Builds a FOREIGN KEY constraint from `columns` to
+/// `public.{ref_table}({ref_columns})`.
+///
+/// Uses NO ACTION for both ON DELETE and ON UPDATE and is neither
+/// deferrable nor initially deferred.
+fn make_fk_constraint(
+    name: &str,
+    columns: &[&str],
+    ref_table: &str,
+    ref_columns: &[&str],
+) -> Constraint {
+    Constraint {
+        name: ConstraintName::try_new(name.to_string()).unwrap(),
+        kind: ConstraintKind::ForeignKey(ForeignKeyConstraint {
+            columns: columns
+                .iter()
+                .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+                .collect(),
+            referenced_table: QualifiedName::new(
+                SchemaName::try_new("public".to_string()).unwrap(),
+                TableName::try_new(ref_table.to_string()).unwrap(),
+            ),
+            referenced_columns: ref_columns
+                .iter()
+                .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+                .collect(),
+            on_delete: ForeignKeyAction::NoAction,
+            on_update: ForeignKeyAction::NoAction,
+            is_deferrable: false,
+            is_initially_deferred: false,
+        }),
+        comment: None,
+    }
+}
+
+/// Builds a UNIQUE constraint over `columns` whose backing index shares the
+/// constraint name; NULLs are treated as distinct (the PostgreSQL default).
+fn make_unique_constraint(name: &str, columns: &[&str]) -> Constraint {
+    Constraint {
+        name: ConstraintName::try_new(name.to_string()).unwrap(),
+        kind: ConstraintKind::Unique(UniqueConstraint {
+            columns: columns
+                .iter()
+                .map(|c| ColumnName::try_new(c.to_string()).unwrap())
+                .collect(),
+            index_name: IndexName::try_new(name.to_string()).unwrap(),
+            nulls_not_distinct: false,
+        }),
+        comment: None,
+    }
+}
+
+/// Builds a CHECK constraint with the given raw SQL `expression`
+/// (not NO INHERIT).
+fn make_check_constraint(name: &str, expression: &str) -> Constraint {
+    Constraint {
+        name: ConstraintName::try_new(name.to_string()).unwrap(),
+        kind: ConstraintKind::Check(CheckConstraint {
+            expression: SqlExpr::new(expression.to_string()),
+            is_no_inherit: false,
+        }),
+        comment: None,
+    }
+}
+
+// =============================================================================
+// Type Mapping Tests
+// =============================================================================
+
+// All PostgreSQL integer aliases collapse to Python `int`.
+#[test]
+fn test_all_integer_types() {
+    for type_name in ["int2", "int4", "int8", "smallint", "integer", "bigint"] {
+        let ty = make_type_info(type_name, type_name, false);
+        let py = map_pg_type(&ty);
+        assert_eq!(py.annotation, "int", "Failed for type: {}", type_name);
+    }
+}
+
+// serial* pseudo-types map to `int` like their underlying integer types.
+#[test]
+fn test_serial_types() {
+    for type_name in [
+        "serial",
+        "serial2",
+        "serial4",
+        "serial8",
+        "smallserial",
+        "bigserial",
+    ] {
+        let ty = make_type_info(type_name, type_name, false);
+        let py = map_pg_type(&ty);
+        assert_eq!(py.annotation, "int", "Failed for type: {}", type_name);
+    }
+}
+
+// Floating-point types map to `float`.
+#[test]
+fn test_all_float_types() {
+    for type_name in ["float4", "float8", "real", "double precision"] {
+        let ty = make_type_info(type_name, type_name, false);
+        let py = map_pg_type(&ty);
+        assert_eq!(py.annotation, "float", "Failed for type: {}", type_name);
+    }
+}
+
+// Character/text types map to `str`.
+#[test]
+fn test_all_text_types() {
+    for type_name in ["text", "varchar", "char", "character", "bpchar", "name"] {
+        let ty = make_type_info(type_name, type_name, false);
+        let py = map_pg_type(&ty);
+        assert_eq!(py.annotation, "str", "Failed for type: {}", type_name);
+    }
+}
+
+// Temporal types map onto the Python `datetime` module's types.
+#[test]
+fn test_all_datetime_types() {
+    let cases = [
+        ("date", "date"),
+        ("time", "time"),
+        ("timetz", "time"),
+        ("timestamp", "datetime"),
+        ("timestamptz", "datetime"),
+        ("interval", "timedelta"),
+    ];
+
+    for (pg_type, expected) in cases {
+        let ty = make_type_info(pg_type, pg_type, false);
+        let py = map_pg_type(&ty);
+        assert_eq!(
+            py.annotation, expected,
+            "Failed for type: {} -> expected {}",
+            pg_type, expected
+        );
+    }
+}
+
+// json/jsonb require an explicit SQLAlchemy `JSON` sa_type plus its import.
+#[test]
+fn test_json_has_sa_type() {
+    for type_name in ["json", "jsonb"] {
+        let ty = make_type_info(type_name, type_name, false);
+        let py = map_pg_type(&ty);
+        assert_eq!(py.sa_type, Some("JSON".to_string()));
+        assert!(py.sa_imports.iter().any(|i| i.name == "JSON"));
+    }
+}
+
+// Array columns become `list[...]` annotations backed by SQLAlchemy ARRAY.
+#[test]
+fn test_array_types() {
+    let cases = [
+        ("int4", "integer[]", "list[int]"),
+        ("text", "text[]", "list[str]"),
+        ("uuid", "uuid[]", "list[UUID]"),
+    ];
+
+    for (type_name, formatted, expected) in cases {
+        let ty = make_type_info(type_name, formatted, true);
+        let py = map_pg_type(&ty);
+        assert_eq!(py.annotation, expected);
+        assert!(py.sa_type.is_some());
+        assert!(py.sa_type.as_ref().unwrap().contains("ARRAY"));
+    }
+}
+
+// =============================================================================
+// Naming Tests
+// =============================================================================
+
+// Every Python 3 keyword must be recognized as both a keyword and reserved.
+#[test]
+fn test_all_python_keywords() {
+    let keywords = [
+        "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
+        "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
+        "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return",
+        "try", "while", "with", "yield",
+    ];
+
+    for kw in keywords {
+        assert!(is_python_keyword(kw), "Expected '{}' to be a keyword", kw);
+        assert!(is_reserved_word(kw), "Expected '{}' to be reserved", kw);
+    }
+}
+
+// Pluralized table names are singularized by a simple suffix rule, so
+// irregular plurals intentionally produce approximate singulars.
+#[test]
+fn test_class_name_pluralization() {
+    let cases = [
+        ("users", "User"),
+        ("categories", "Category"),
+        ("companies", "Company"),
+        ("addresses", "Addresse"), // 'addresses' -> singular is 'addresse' with simple rule
+        ("status", "Status"),      // ends in 's' but 'us' ending preserved
+        ("analyses", "Analyse"),   // ends in 's' but 'is' ending requires special handling
+        ("classes", "Classe"),     // double 's' preserved
+    ];
+
+    for (input, expected) in cases {
+        let result = to_class_name(input);
+        assert_eq!(
+            result, expected,
+            "to_class_name('{}') = '{}', expected '{}'",
+            input, result, expected
+        );
+    }
+}
+
+// snake_case table names become singular PascalCase class names.
+#[test]
+fn test_class_name_snake_to_pascal() {
+    let cases = [
+        ("user_accounts", "UserAccount"),
+        ("order_items", "OrderItem"),
+        ("api_keys", "ApiKey"),
+        ("http_requests", "HttpRequest"),
+    ];
+
+    for (input, expected) in cases {
+        let result = to_class_name(input);
+        assert_eq!(result, expected);
+    }
+}
+
+// Module names are singular snake_case derived from the table name.
+#[test]
+fn test_module_name_conversion() {
+    let cases = [
+        ("users", "user"),
+        ("UserAccounts", "user_account"),
+        ("order_items", "order_item"),
+        ("HTTPRequests", "httprequest"), // Consecutive caps become lowercase
+    ];
+
+    for (input, expected) in cases {
+        let result = to_module_name(input);
+        assert_eq!(result, expected);
+    }
+}
+
+// Non-identifier characters are replaced with '_' and flag an alias,
+// since the sanitized name no longer matches the database column.
+#[test]
+fn test_sanitize_special_characters() {
+    let strategy = ReservedWordStrategy::AppendUnderscore;
+
+    let cases = [
+        ("column-name", "column_name", true),
+        ("column name", "column_name", true),
+        ("column.name", "column_name", true),
+        ("column@name", "column_name", true),
+        ("column#1", "column_1", true),
+    ];
+
+    for (input, expected, needs_alias) in cases {
+        let (result, alias) = sanitize_identifier(input, &strategy);
+        assert_eq!(result, expected, "sanitize('{}') = '{}'", input, result);
+        assert_eq!(alias, needs_alias);
+    }
+}
+
+// Identifiers may not start with a digit; a leading '_' is prepended.
+#[test]
+fn test_sanitize_leading_digit() {
+    let strategy = ReservedWordStrategy::AppendUnderscore;
+    let (result, needs_alias) = sanitize_identifier("1column", &strategy);
+    assert_eq!(result, "_1column");
+    assert!(needs_alias);
+}
+
+// Reserved words are escaped per the configured strategy: either a
+// trailing underscore or a caller-supplied prefix.
+#[test]
+fn test_reserved_word_strategies() {
+    let input = "class";
+
+    let (result, _) = sanitize_identifier(input, &ReservedWordStrategy::AppendUnderscore);
+    assert_eq!(result, "class_");
+
+    let (result, _) = sanitize_identifier(
+        input,
+        &ReservedWordStrategy::PrependPrefix("col_".to_string()),
+    );
+    assert_eq!(result, "col_class");
+}
+
+// =============================================================================
+// Full Generation Tests
+// =============================================================================
+
+// End-to-end smoke test: a basic table yields a SQLModel class with the
+// expected imports, field annotations, and constraint keyword arguments.
+#[test]
+fn test_generate_simple_users_table() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("name", "text", false),
+        make_column("email", "text", false),
+        make_column("bio", "text", true),
+    ];
+    let constraints = vec![
+        make_pk_constraint("users", &["id"]),
+        make_unique_constraint("users_email_key", &["email"]),
+    ];
+    let table = make_table("users", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    // Check class definition
+    assert!(content.contains("class User(SQLModel, table=True):"));
+    assert!(content.contains("__tablename__ = \"users\""));
+
+    // Check imports
+    assert!(content.contains("from sqlmodel import"));
+    assert!(content.contains("SQLModel"));
+    assert!(content.contains("Field"));
+
+    // Check fields
+    assert!(content.contains("id:"));
+    assert!(content.contains("name: str"));
+    assert!(content.contains("email: str"));
+    assert!(content.contains("bio: str | None"));
+
+    // Check constraints
+    assert!(content.contains("primary_key=True"));
+    assert!(content.contains("unique=True"));
+}
+
+// A foreign key column should carry a `foreign_key=` Field argument.
+#[test]
+fn test_generate_table_with_foreign_key() {
+    use crate::python::PythonCodegen;
+
+    let user_columns = vec![
+        make_column("id", "int4", false),
+        make_column("name", "text", false),
+    ];
+    let user_constraints = vec![make_pk_constraint("users", &["id"])];
+    let users = make_table("users", user_columns, user_constraints);
+
+    let post_columns = vec![
+        make_column("id", "int4", false),
+        make_column("title", "text", false),
+        make_column("user_id", "int4", true),
+    ];
+    let post_constraints = vec![
+        make_pk_constraint("posts", &["id"]),
+        make_fk_constraint("posts_user_id_fkey", &["user_id"], "users", &["id"]),
+    ];
+    let posts = make_table("posts", post_columns, post_constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![users, posts]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("class User(SQLModel, table=True):"));
+    assert!(content.contains("class Post(SQLModel, table=True):"));
+    assert!(content.contains("foreign_key="));
+}
+
+// CHECK constraints are emitted with their raw SQL expressions.
+#[test]
+fn test_generate_table_with_check_constraint() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("price", "numeric", false),
+        make_column("quantity", "int4", false),
+    ];
+    let constraints = vec![
+        make_pk_constraint("products", &["id"]),
+        make_check_constraint("products_price_positive", "price > 0"),
+        make_check_constraint("products_quantity_positive", "quantity >= 0"),
+    ];
+    let table = make_table("products", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("CheckConstraint"));
+    assert!(content.contains("price > 0"));
+    assert!(content.contains("quantity >= 0"));
+}
+
+// Multi-column UNIQUE constraints cannot be a per-field `unique=True`;
+// they must go into `__table_args__` as a `UniqueConstraint`.
+#[test]
+fn test_generate_table_with_composite_unique() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("email", "text", false),
+        make_column("tenant_id", "int4", false),
+    ];
+    let constraints = vec![
+        make_pk_constraint("users", &["id"]),
+        make_unique_constraint("users_email_tenant_key", &["email", "tenant_id"]),
+    ];
+    let table = make_table("users", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("__table_args__"));
+    assert!(content.contains("UniqueConstraint"));
+    assert!(content.contains("\"email\""));
+    assert!(content.contains("\"tenant_id\""));
+}
+
+// Identity (GENERATED ... AS IDENTITY) PK columns are database-assigned,
+// so the Python field is Optional with default=None.
+#[test]
+fn test_generate_table_with_identity_column() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column_with_identity("id", "int4", IdentityKind::Always),
+        make_column("name", "text", false),
+    ];
+    let constraints = vec![make_pk_constraint("users", &["id"])];
+    let table = make_table("users", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    // Identity column should be Optional with default=None
+    assert!(content.contains("id: int | None"));
+    assert!(content.contains("default=None"));
+    assert!(content.contains("primary_key=True"));
+}
+
+// Columns named after Python keywords get a sanitized attribute name and
+// an `alias=` pointing back to the real column name.
+#[test]
+fn test_generate_reserved_word_column() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("class", "text", false),
+        make_column("from", "text", true),
+    ];
+    let constraints = vec![make_pk_constraint("items", &["id"])];
+    let table = make_table("items", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    // Reserved words should be aliased
+    assert!(content.contains("class_:"));
+    assert!(content.contains("alias=\"class\""));
+    assert!(content.contains("from_:"));
+    assert!(content.contains("alias=\"from\""));
+}
+
+// MultiFile mode writes one module per table plus an `__init__.py` that
+// re-exports every model and declares `__all__`.
+#[test]
+fn test_generate_multi_file_output() {
+    use crate::python::PythonCodegen;
+
+    let user_columns = vec![
+        make_column("id", "int4", false),
+        make_column("name", "text", false),
+    ];
+    let users = make_table(
+        "users",
+        user_columns,
+        vec![make_pk_constraint("users", &["id"])],
+    );
+
+    let post_columns = vec![
+        make_column("id", "int4", false),
+        make_column("title", "text", false),
+    ];
+    let posts = make_table(
+        "posts",
+        post_columns,
+        vec![make_pk_constraint("posts", &["id"])],
+    );
+
+    let config = PythonCodegenConfig {
+        output_mode: OutputMode::MultiFile,
+        ..Default::default()
+    };
+    let codegen = PythonCodegen::new(config);
+    let output = codegen.generate(vec![users, posts]);
+
+    // Check files exist
+    assert!(output.contains_key("__init__.py"));
+    assert!(output.contains_key("user.py"));
+    assert!(output.contains_key("post.py"));
+
+    // Check __init__.py
+    let init = &output["__init__.py"];
+    assert!(init.contains("from .user import User"));
+    assert!(init.contains("from .post import Post"));
+    assert!(init.contains("__all__"));
+
+    // Check individual files
+    assert!(output["user.py"].contains("class User(SQLModel, table=True):"));
+    assert!(output["post.py"].contains("class Post(SQLModel, table=True):"));
+}
+
+// Temporal columns pull their annotations from the `datetime` module and
+// trigger the corresponding import line.
+#[test]
+fn test_generate_with_datetime_types() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("created_at", "timestamptz", false),
+        make_column("updated_at", "timestamptz", true),
+        make_column("birth_date", "date", true),
+    ];
+    let constraints = vec![make_pk_constraint("events", &["id"])];
+    let table = make_table("events", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("from datetime import"));
+    assert!(content.contains("datetime"));
+    assert!(content.contains("date"));
+    assert!(content.contains("created_at: datetime"));
+    assert!(content.contains("birth_date: date | None"));
+}
+
+// uuid columns require `from uuid import UUID`.
+#[test]
+fn test_generate_with_uuid_type() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "uuid", false),
+        make_column("name", "text", false),
+    ];
+    let constraints = vec![make_pk_constraint("items", &["id"])];
+    let table = make_table("items", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("from uuid import UUID"));
+}
+
+// numeric columns map to `Decimal` and import it.
+#[test]
+fn test_generate_with_decimal_type() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("price", "numeric", false),
+        make_column("discount", "numeric", true),
+    ];
+    let constraints = vec![make_pk_constraint("products", &["id"])];
+    let table = make_table("products", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("from decimal import Decimal"));
+    assert!(content.contains("price: Decimal"));
+    assert!(content.contains("discount: Decimal | None"));
+}
+
+// json/jsonb columns annotate as `dict[str, Any]` and pass SQLAlchemy's
+// JSON via `sa_type=`, pulling in both imports.
+#[test]
+fn test_generate_with_json_type() {
+    use crate::python::PythonCodegen;
+
+    let columns = vec![
+        make_column("id", "int4", false),
+        make_column("config", "jsonb", false),
+        make_column("metadata", "json", true),
+    ];
+    let constraints = vec![make_pk_constraint("settings", &["id"])];
+    let table = make_table("settings", columns, constraints);
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![table]);
+
+    let content = &output["models.py"];
+
+    assert!(content.contains("from typing import Any"));
+    assert!(content.contains("from sqlalchemy import JSON"));
+    assert!(content.contains("dict[str, Any]"));
+    assert!(content.contains("sa_type=JSON"));
+}
+
+// An empty schema still produces models.py with a placeholder comment.
+#[test]
+fn test_generate_empty_tables() {
+    use crate::python::PythonCodegen;
+
+    let codegen = PythonCodegen::with_defaults();
+    let output = codegen.generate(vec![]);
+
+    assert!(output.contains_key("models.py"));
+    assert!(output["models.py"].contains("No tables to generate"));
+}
+
+// Pins the Default impl of the config so changes to defaults are deliberate.
+#[test]
+fn test_config_default_values() {
+    let config = PythonCodegenConfig::default();
+
+    assert!(!config.generate_base_models);
+    assert!(config.module_prefix.is_none());
+    assert!(config.include_docstrings);
+    assert!(!config.generate_relationships);
+    assert_eq!(config.output_mode, OutputMode::SingleFile);
+    assert_eq!(
+        config.reserved_word_strategy,
+        ReservedWordStrategy::AppendUnderscore
+    );
+}
+
+// Length is parsed from varchar/char declarations; unbounded types yield None.
+#[test]
+fn test_string_length_extraction() {
+    assert_eq!(extract_string_length("character varying(100)"), Some(100));
+    assert_eq!(extract_string_length("varchar(50)"), Some(50));
+    assert_eq!(extract_string_length("character(10)"), Some(10));
+    assert_eq!(extract_string_length("char(5)"), Some(5));
+    assert_eq!(extract_string_length("text"), None);
+    assert_eq!(extract_string_length("integer"), None);
+}
+
+// numeric(p,s) parses both values; numeric(p) implies scale 0.
+#[test]
+fn test_numeric_precision_extraction() {
+    assert_eq!(extract_numeric_precision("numeric(10,2)"), Some((10, 2)));
+    assert_eq!(extract_numeric_precision("decimal(15,4)"), Some((15, 4)));
+    assert_eq!(extract_numeric_precision("numeric(8)"), Some((8, 0)));
+    assert_eq!(extract_numeric_precision("numeric"), None);
+}
+
+// Only blank-padded character types count as fixed-length.
+#[test]
+fn test_fixed_length_char_detection() {
+    assert!(is_fixed_length_char("char"));
+    assert!(is_fixed_length_char("character"));
+    assert!(is_fixed_length_char("bpchar"));
+    assert!(!is_fixed_length_char("varchar"));
+    assert!(!is_fixed_length_char("text"));
+}
diff --git a/crates/codegen/src/python/type_mapping.rs b/crates/codegen/src/python/type_mapping.rs
new file mode 100644
index 0000000..d7ea075
--- /dev/null
+++ b/crates/codegen/src/python/type_mapping.rs
@@ -0,0 +1,518 @@
+//! PostgreSQL to Python type mapping.
+//!
+//! This module handles the conversion of PostgreSQL types to their Python equivalents
+//! for use in SQLModel field definitions.
+
+use tern_ddl::TypeInfo;
+
+/// Represents a Python type with its import requirements.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PythonType {
+ /// The Python type annotation (e.g., "int", "str", "datetime").
+ pub annotation: String,
+ /// Required imports for this type (module path, name).
+ pub imports: Vec,
+ /// SQLAlchemy type for `sa_type` parameter if needed (e.g., for JSON, ARRAY).
+ pub sa_type: Option,
+ /// SQLAlchemy imports required for sa_type.
+ pub sa_imports: Vec,
+}
+
/// A Python import statement.
///
/// Fix: the `<String>` generic arguments on `new` were stripped in transit
/// (`impl Into,`); restored so the constructor accepts both `&str` and `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PythonImport {
    /// The module to import from (e.g., "datetime", "typing", "uuid").
    pub module: String,
    /// The name to import (e.g., "datetime", "Any", "UUID").
    pub name: String,
}

impl PythonImport {
    /// Creates a new Python import from a module path and an imported name.
    pub fn new(module: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            module: module.into(),
            name: name.into(),
        }
    }

    /// Creates an import from the datetime module.
    pub fn datetime(name: &str) -> Self {
        Self::new("datetime", name)
    }

    /// Creates an import from the typing module.
    pub fn typing(name: &str) -> Self {
        Self::new("typing", name)
    }

    /// Creates the `decimal.Decimal` import.
    pub fn decimal() -> Self {
        Self::new("decimal", "Decimal")
    }

    /// Creates the `uuid.UUID` import.
    pub fn uuid() -> Self {
        Self::new("uuid", "UUID")
    }

    /// Creates an import from SQLAlchemy.
    pub fn sqlalchemy(name: &str) -> Self {
        Self::new("sqlalchemy", name)
    }
}
+
+impl PythonType {
+ /// Creates a simple Python type with no imports.
+ pub fn simple(annotation: impl Into) -> Self {
+ Self {
+ annotation: annotation.into(),
+ imports: Vec::new(),
+ sa_type: None,
+ sa_imports: Vec::new(),
+ }
+ }
+
+ /// Creates a Python type with imports.
+ pub fn with_imports(annotation: impl Into, imports: Vec) -> Self {
+ Self {
+ annotation: annotation.into(),
+ imports,
+ sa_type: None,
+ sa_imports: Vec::new(),
+ }
+ }
+
+ /// Creates a Python type with an SQLAlchemy type.
+ pub fn with_sa_type(
+ annotation: impl Into,
+ imports: Vec,
+ sa_type: impl Into,
+ sa_imports: Vec,
+ ) -> Self {
+ Self {
+ annotation: annotation.into(),
+ imports,
+ sa_type: Some(sa_type.into()),
+ sa_imports,
+ }
+ }
+}
+
+/// Maps a PostgreSQL type to a Python type.
+///
+/// This function handles both the raw type name and the formatted type with modifiers.
+/// The `formatted` field is preferred as it contains the full type specification.
+pub fn map_pg_type(type_info: &TypeInfo) -> PythonType {
+ let type_name = type_info.name.as_ref();
+ let formatted = &type_info.formatted;
+
+ // Handle array types first
+ if type_info.is_array {
+ return map_array_type(type_name, formatted);
+ }
+
+ // Map based on type name (canonical PostgreSQL type names)
+ match type_name {
+ // Integer types
+ "int2" | "smallint" => PythonType::simple("int"),
+ "int4" | "integer" | "int" => PythonType::simple("int"),
+ "int8" | "bigint" => PythonType::simple("int"),
+ "serial" | "serial4" => PythonType::simple("int"),
+ "bigserial" | "serial8" => PythonType::simple("int"),
+ "smallserial" | "serial2" => PythonType::simple("int"),
+
+ // Floating point types
+ "float4" | "real" => PythonType::simple("float"),
+ "float8" | "double precision" => PythonType::simple("float"),
+
+ // Numeric/decimal types
+ "numeric" | "decimal" => PythonType::with_imports("Decimal", vec![PythonImport::decimal()]),
+
+ // Boolean
+ "bool" | "boolean" => PythonType::simple("bool"),
+
+ // Text types
+ "text" => PythonType::simple("str"),
+ "varchar" | "character varying" => PythonType::simple("str"),
+ "char" | "character" | "bpchar" => PythonType::simple("str"),
+ "name" => PythonType::simple("str"),
+
+ // Date/time types
+ "date" => PythonType::with_imports("date", vec![PythonImport::datetime("date")]),
+ "time" | "time without time zone" => {
+ PythonType::with_imports("time", vec![PythonImport::datetime("time")])
+ }
+ "timetz" | "time with time zone" => {
+ PythonType::with_imports("time", vec![PythonImport::datetime("time")])
+ }
+ "timestamp" | "timestamp without time zone" => {
+ PythonType::with_imports("datetime", vec![PythonImport::datetime("datetime")])
+ }
+ "timestamptz" | "timestamp with time zone" => {
+ PythonType::with_imports("datetime", vec![PythonImport::datetime("datetime")])
+ }
+ "interval" => {
+ PythonType::with_imports("timedelta", vec![PythonImport::datetime("timedelta")])
+ }
+
+ // UUID
+ "uuid" => PythonType::with_imports("UUID", vec![PythonImport::uuid()]),
+
+ // JSON types
+ "json" | "jsonb" => PythonType::with_sa_type(
+ "dict[str, Any]",
+ vec![PythonImport::typing("Any")],
+ "JSON",
+ vec![PythonImport::sqlalchemy("JSON")],
+ ),
+
+ // Binary
+ "bytea" => PythonType::simple("bytes"),
+
+ // Network types
+ "inet" | "cidr" | "macaddr" | "macaddr8" => PythonType::simple("str"),
+
+ // Geometric types (stored as strings)
+ "point" | "line" | "lseg" | "box" | "path" | "polygon" | "circle" => {
+ PythonType::simple("str")
+ }
+
+ // Bit strings
+ "bit" | "varbit" | "bit varying" => PythonType::simple("str"),
+
+ // Money (stored as string to preserve formatting)
+ "money" => PythonType::simple("str"),
+
+ // XML
+ "xml" => PythonType::simple("str"),
+
+ // OID types (internal PostgreSQL types)
+ "oid" | "regproc" | "regprocedure" | "regoper" | "regoperator" | "regclass" | "regtype"
+ | "regrole" | "regnamespace" | "regconfig" | "regdictionary" => PythonType::simple("int"),
+
+ // Range types
+ "int4range" | "int8range" | "numrange" | "tsrange" | "tstzrange" | "daterange" => {
+ // Range types are complex; represent as string for now
+ PythonType::simple("str")
+ }
+
+ // TSVector/TSQuery (full text search)
+ "tsvector" | "tsquery" => PythonType::simple("str"),
+
+ // Default fallback - use Any for unknown types
+ _ => {
+ // Check if it looks like a user-defined type (enum, composite, etc.)
+ // For now, map unknown types to Any
+ PythonType::with_imports("Any", vec![PythonImport::typing("Any")])
+ }
+ }
+}
+
+/// Maps a PostgreSQL array type to a Python list type.
+fn map_array_type(element_type_name: &str, formatted: &str) -> PythonType {
+ // Get the element type first
+ let element_type = map_pg_type(&TypeInfo {
+ name: tern_ddl::TypeName::try_new(element_type_name.to_string())
+ .unwrap_or_else(|_| tern_ddl::TypeName::try_new("text".to_string()).unwrap()),
+ schema: tern_ddl::SchemaName::try_new("pg_catalog".to_string()).unwrap(),
+ formatted: strip_array_suffix(formatted),
+ is_array: false,
+ });
+
+ // Determine the SQLAlchemy array element type
+ let sa_element_type = match element_type_name {
+ "int2" | "smallint" => "SmallInteger",
+ "int4" | "integer" | "int" => "Integer",
+ "int8" | "bigint" => "BigInteger",
+ "float4" | "real" => "Float",
+ "float8" | "double precision" => "Float",
+ "numeric" | "decimal" => "Numeric",
+ "bool" | "boolean" => "Boolean",
+ "text" => "Text",
+ "varchar" | "character varying" => "String",
+ "char" | "character" | "bpchar" => "String",
+ "uuid" => "UUID",
+ "timestamp" | "timestamp without time zone" => "DateTime",
+ "timestamptz" | "timestamp with time zone" => "DateTime",
+ "date" => "Date",
+ "time" | "time without time zone" => "Time",
+ "json" | "jsonb" => "JSON",
+ _ => "String", // Default to String for unknown types
+ };
+
+ let annotation = format!("list[{}]", element_type.annotation);
+ let sa_type = format!("ARRAY({sa_element_type})");
+
+ let mut imports = element_type.imports;
+ let mut sa_imports = vec![PythonImport::sqlalchemy("ARRAY")];
+
+ // Add the element type import for SQLAlchemy
+ // All ARRAY element types require their corresponding SQLAlchemy type import
+ sa_imports.push(PythonImport::sqlalchemy(sa_element_type));
+
+ // Merge any existing SA imports from element type
+ imports.extend(element_type.sa_imports);
+
+ PythonType {
+ annotation,
+ imports,
+ sa_type: Some(sa_type),
+ sa_imports,
+ }
+}
+
/// Strips the array suffix (e.g., "[]") from a formatted type string.
///
/// Removes every trailing "[]" pair (so multi-dimensional notations like
/// "integer[][]" reduce to "integer"), then any trailing " ARRAY" markers,
/// matching `trim_end_matches` semantics.
fn strip_array_suffix(formatted: &str) -> String {
    let mut base = formatted;
    while let Some(stripped) = base.strip_suffix("[]") {
        base = stripped;
    }
    while let Some(stripped) = base.strip_suffix(" ARRAY") {
        base = stripped;
    }
    base.to_string()
}
+
/// Extracts the varchar/char length constraint from a formatted type.
///
/// Returns `Some(length)` for types like "character varying(255)" or
/// "character(10)", and `None` for unparameterized types such as "text".
///
/// Fix: the return type's generic argument was lost in transit (bare
/// `Option`); restored as `Option<u32>` to match the `u32` precision/scale
/// used by `extract_numeric_precision`.
pub fn extract_string_length(formatted: &str) -> Option<u32> {
    // Match patterns like "character varying(255)" or "character(10)" or "varchar(100)"
    let formatted_lower = formatted.to_lowercase();

    if formatted_lower.starts_with("character varying(")
        || formatted_lower.starts_with("varchar(")
        || formatted_lower.starts_with("character(")
        || formatted_lower.starts_with("char(")
    {
        // Byte positions are identical in the lowered and original strings,
        // so indexing the original here is safe.
        if let Some(start) = formatted.find('(') {
            if let Some(end) = formatted.find(')') {
                let len_str = &formatted[start + 1..end];
                return len_str.parse().ok();
            }
        }
    }

    None
}
+
/// Extracts numeric precision and scale from a formatted type.
///
/// Returns `Some((precision, scale))` for types like "numeric(10,2)"; a bare
/// precision such as "numeric(8)" yields a scale of 0.
/// Kept for future use to add Decimal precision validation in Field().
#[allow(dead_code)]
pub fn extract_numeric_precision(formatted: &str) -> Option<(u32, u32)> {
    let lower = formatted.to_lowercase();
    if !(lower.starts_with("numeric(") || lower.starts_with("decimal(")) {
        return None;
    }

    // Pull out the parenthesized parameter list, e.g. "10,2" or "8".
    let open = formatted.find('(')?;
    let close = formatted.find(')')?;
    let params = &formatted[open + 1..close];

    let mut pieces = params.split(',');
    let precision: u32 = pieces.next()?.trim().parse().ok()?;
    match (pieces.next(), pieces.next()) {
        // "numeric(p,s)"
        (Some(scale), None) => Some((precision, scale.trim().parse().ok()?)),
        // "numeric(p)" — scale defaults to 0.
        (None, _) => Some((precision, 0)),
        // Three or more parameters is malformed.
        (Some(_), Some(_)) => None,
    }
}
+
/// Checks if a type is a fixed-length character type (char/character).
pub fn is_fixed_length_char(type_name: &str) -> bool {
    // bpchar is the internal name for blank-padded char.
    ["char", "character", "bpchar"].contains(&type_name)
}
+
#[cfg(test)]
mod tests {
    use super::*;
    use tern_ddl::{SchemaName, TypeName};

    /// Builds a minimal `TypeInfo` in the `pg_catalog` schema for mapping tests.
    fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo {
        TypeInfo {
            name: TypeName::try_new(name.to_string()).unwrap(),
            schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(),
            formatted: formatted.to_string(),
            is_array,
        }
    }

    // --- Scalar type mappings ---------------------------------------------

    #[test]
    fn test_integer_types() {
        // All integer widths collapse to Python's `int` with no imports.
        let type_info = make_type_info("int4", "integer", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "int");
        assert!(py_type.imports.is_empty());

        let type_info = make_type_info("int8", "bigint", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "int");

        let type_info = make_type_info("int2", "smallint", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "int");
    }

    #[test]
    fn test_float_types() {
        let type_info = make_type_info("float8", "double precision", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "float");

        let type_info = make_type_info("float4", "real", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "float");
    }

    #[test]
    fn test_numeric_type() {
        // numeric maps to Decimal and must pull in `from decimal import Decimal`.
        let type_info = make_type_info("numeric", "numeric(10,2)", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "Decimal");
        assert_eq!(py_type.imports.len(), 1);
        assert_eq!(py_type.imports[0].module, "decimal");
        assert_eq!(py_type.imports[0].name, "Decimal");
    }

    #[test]
    fn test_boolean_type() {
        let type_info = make_type_info("bool", "boolean", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "bool");
    }

    #[test]
    fn test_text_types() {
        // text and varchar both become plain `str` (length handled elsewhere).
        let type_info = make_type_info("text", "text", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "str");

        let type_info = make_type_info("varchar", "character varying(255)", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "str");
    }

    #[test]
    fn test_datetime_types() {
        // Both timestamp variants map to `datetime`; timezone handling is not
        // distinguished at the annotation level.
        let type_info = make_type_info("timestamp", "timestamp without time zone", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "datetime");
        assert_eq!(py_type.imports.len(), 1);
        assert_eq!(py_type.imports[0].module, "datetime");
        assert_eq!(py_type.imports[0].name, "datetime");

        let type_info = make_type_info("timestamptz", "timestamp with time zone", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "datetime");

        let type_info = make_type_info("date", "date", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "date");
        assert_eq!(py_type.imports[0].name, "date");

        let type_info = make_type_info("time", "time without time zone", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "time");
    }

    #[test]
    fn test_uuid_type() {
        let type_info = make_type_info("uuid", "uuid", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "UUID");
        assert_eq!(py_type.imports.len(), 1);
        assert_eq!(py_type.imports[0].module, "uuid");
        assert_eq!(py_type.imports[0].name, "UUID");
    }

    #[test]
    fn test_json_types() {
        // JSON columns need both a typing.Any import and an sa_type of JSON.
        let type_info = make_type_info("jsonb", "jsonb", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "dict[str, Any]");
        assert!(py_type.imports.iter().any(|i| i.name == "Any"));
        assert_eq!(py_type.sa_type, Some("JSON".to_string()));
    }

    #[test]
    fn test_bytea_type() {
        let type_info = make_type_info("bytea", "bytea", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "bytes");
    }

    // --- Array type mappings ----------------------------------------------

    #[test]
    fn test_array_type() {
        // Arrays wrap the element annotation in list[...] and set an ARRAY sa_type.
        let type_info = make_type_info("int4", "integer[]", true);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "list[int]");
        assert!(py_type.sa_type.is_some());
        assert!(py_type.sa_type.as_ref().unwrap().contains("ARRAY"));
    }

    #[test]
    fn test_text_array_type() {
        let type_info = make_type_info("text", "text[]", true);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "list[str]");
        // Text arrays use ARRAY(Text) - Text is the correct SQLAlchemy type for pg text
        assert!(py_type.sa_type.as_ref().unwrap().contains("ARRAY(Text)"));
    }

    // --- Helper functions --------------------------------------------------

    #[test]
    fn test_extract_string_length() {
        assert_eq!(extract_string_length("character varying(255)"), Some(255));
        assert_eq!(extract_string_length("varchar(100)"), Some(100));
        assert_eq!(extract_string_length("character(10)"), Some(10));
        assert_eq!(extract_string_length("char(5)"), Some(5));
        assert_eq!(extract_string_length("text"), None);
    }

    #[test]
    fn test_extract_numeric_precision() {
        assert_eq!(extract_numeric_precision("numeric(10,2)"), Some((10, 2)));
        assert_eq!(extract_numeric_precision("numeric(5)"), Some((5, 0)));
        assert_eq!(extract_numeric_precision("decimal(8,3)"), Some((8, 3)));
        assert_eq!(extract_numeric_precision("integer"), None);
    }

    #[test]
    fn test_is_fixed_length_char() {
        assert!(is_fixed_length_char("char"));
        assert!(is_fixed_length_char("character"));
        assert!(is_fixed_length_char("bpchar"));
        assert!(!is_fixed_length_char("varchar"));
        assert!(!is_fixed_length_char("text"));
    }

    // --- Fallback and string-mapped types -----------------------------------

    #[test]
    fn test_network_types() {
        // Network address types are stored as plain strings.
        let type_info = make_type_info("inet", "inet", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "str");

        let type_info = make_type_info("cidr", "cidr", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "str");
    }

    #[test]
    fn test_interval_type() {
        let type_info = make_type_info("interval", "interval", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "timedelta");
        assert_eq!(py_type.imports[0].module, "datetime");
        assert_eq!(py_type.imports[0].name, "timedelta");
    }

    #[test]
    fn test_unknown_type_fallback() {
        // Unrecognized (user-defined) types fall back to typing.Any.
        let type_info = make_type_info("my_custom_type", "my_custom_type", false);
        let py_type = map_pg_type(&type_info);
        assert_eq!(py_type.annotation, "Any");
        assert!(py_type.imports.iter().any(|i| i.name == "Any"));
    }
}