In [1]:
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

In [2]:
# Draft citation record: a cited value together with where it was found.
# NOTE(review): a different `Citation` (page/lines only) is declared in a later
# cell and replaces this one once that cell runs — confirm which is intended.
class Citation(TypedDict):
    value: float  # presumably the extracted number being cited — TODO confirm
    page: int  # presumably the page the value was found on — TODO confirm
    lines: list[int]  # presumably line numbers on that page — TODO confirm
    bboxes: list[dict]  # the dict keys are not specified by this annotation


# Single-field model used to inspect the JSON schema pydantic generates for a
# float field that carries a description.
class EndingBalance(BaseModel):
    ending_balance: float = Field(description="The ending balance of the account")

In [3]:
# Bare last expression: the generated JSON schema becomes the cell's output.
EndingBalance.model_json_schema()

{'properties': {'ending_balance': {'description': 'The ending balance of the account',
   'title': 'Ending Balance',
   'type': 'number'}},
 'required': ['ending_balance'],
 'title': 'EndingBalance',
 'type': 'object'}

In [6]:
from pprint import pprint

from pydantic import BaseModel, create_model


# Baseline model; the statements that follow derive a variant of it with a
# different `id` type via `create_model`.
class User(BaseModel):
    id: int
    name: str


# Schema before the override: `id` is reported as an integer.
pprint(User.model_json_schema())
# Derive a new model from `User`, redeclaring only `id`; `name` is inherited
# through `__base__`. The `...` default keeps `id` required.
ModifiedUser = create_model(
    "ModifiedUser",
    __base__=User,
    id=(str, ...),  # change type
)
# Schema after the override: `id` is now a string, `name` is unchanged.
pprint(ModifiedUser.model_json_schema())

{'properties': {'id': {'title': 'Id', 'type': 'integer'},
                'name': {'title': 'Name', 'type': 'string'}},
 'required': ['id', 'name'],
 'title': 'User',
 'type': 'object'}
{'properties': {'id': {'title': 'Id', 'type': 'string'},
                'name': {'title': 'Name', 'type': 'string'}},
 'required': ['id', 'name'],
 'title': 'ModifiedUser',
 'type': 'object'}


In [3]:
from copy import deepcopy
from pprint import pprint
from typing import Any, get_args, get_origin

from pydantic import BaseModel, create_model
from typing_extensions import TypedDict


# Original Citation structure
class Citation(TypedDict):
    page: int  # presumably the page the cited text is on — TODO confirm
    lines: list[int]  # presumably line numbers on that page — TODO confirm


# Extended Citation with bounding boxes
class CitationWithBBox(TypedDict):
    page: int  # same field as in Citation
    lines: list[int]  # same field as in Citation
    # NOTE(review): the annotation accepts any str -> float mapping; the
    # x/y/width/height key set named below is a convention, not something the
    # type enforces.
    bounding_boxes: list[dict[str, float]]  # list of {x, y, width, height}


# Processor determines citation type
class Processor:
    # Annotation that models use for a citation field before bounding boxes
    # are added.
    citation_type = list[Citation]
    # Replacement annotation whose items additionally carry bounding boxes.
    citation_type_with_bbox = list[CitationWithBBox]


def is_citation_type(field_type: Any, citation_type: Any) -> bool:
    """
    Decide whether a field annotation counts as the given citation annotation.

    Two annotations match when they compare equal. Failing that, they match
    only if both are ``list[...]`` aliases whose first type argument is either
    equal or carries the same ``__name__`` (which lets two separately declared
    TypedDicts with the same name be treated as the same citation type).

    Args:
        field_type: The annotation taken from the field being inspected
        citation_type: The annotation to match against (e.g., processor.citation_type)

    Returns:
        True when the annotations match under the rules above, otherwise False
    """
    if field_type == citation_type:
        return True

    # Past plain equality, only list[...] vs list[...] can still match.
    if get_origin(field_type) is not list or get_origin(citation_type) is not list:
        return False

    field_params = get_args(field_type)
    citation_params = get_args(citation_type)
    if not field_params or not citation_params:
        return False

    field_item = field_params[0]
    citation_item = citation_params[0]
    if field_item == citation_item:
        return True

    # Last resort: compare the item types by name.
    if not hasattr(field_item, "__name__"):
        return False
    if not hasattr(citation_item, "__name__"):
        return False
    if field_item.__name__ == citation_item.__name__:
        return True
    return False


def find_citation_fields(
    model: type[BaseModel], citation_type: Any, prefix: str = ""
) -> list[str]:
    """
    Recursively find all fields with citation type in a Pydantic model.

    A field is reported when ``is_citation_type`` accepts its annotation.
    Fields annotated with a ``BaseModel`` subclass, or with
    ``list[<BaseModel subclass>]``, are descended into. Any other annotation
    that is not itself a citation match is skipped without descending.
    Fields reached through a list carry no index marker in their path.

    NOTE(review): there is no visited-set, so a model that (directly or
    indirectly) contains itself would recurse without bound.

    Args:
        model: The Pydantic model to inspect
        citation_type: The citation type to search for (e.g., processor.citation_type)
        prefix: Current field path prefix for nested models

    Returns:
        List of field paths (e.g., ['my_data_citation', 'nested.citation_field'])
    """

    def _is_model(candidate: Any) -> bool:
        # On Python 3.9/3.10 `isinstance(list[str], type)` is True, and passing
        # such an alias to `issubclass` raises TypeError. Requiring that the
        # candidate has no generic origin filters those aliases out first.
        return (
            isinstance(candidate, type)
            and get_origin(candidate) is None
            and issubclass(candidate, BaseModel)
        )

    citation_fields: list[str] = []

    for field_name, field_info in model.model_fields.items():
        field_type = field_info.annotation
        current_path = f"{prefix}.{field_name}" if prefix else field_name

        # Check if this field is a citation type
        if is_citation_type(field_type, citation_type):
            citation_fields.append(current_path)

        # Check if this field is a nested BaseModel
        elif _is_model(field_type):
            citation_fields.extend(
                find_citation_fields(field_type, citation_type, current_path)
            )

        # Check if it's a list of BaseModel
        elif get_origin(field_type) is list:
            args = get_args(field_type)
            if args and _is_model(args[0]):
                citation_fields.extend(
                    find_citation_fields(args[0], citation_type, current_path)
                )

    return citation_fields


def modify_model_citations(
    model: type[BaseModel], original_citation_type: Any, new_citation_type: Any
) -> type[BaseModel]:
    """
    Recursively modify a Pydantic model to replace citation type with new citation type.

    Every field of ``model`` is re-declared on a freshly created model:

    * a field whose annotation matches ``original_citation_type`` (as judged
      by ``is_citation_type``) is re-annotated with ``new_citation_type``;
    * a field annotated with a ``BaseModel`` subclass, or with
      ``list[<BaseModel subclass>]``, has that nested model rewritten
      recursively;
    * every other field keeps its annotation.

    Each field is re-declared with a copy of its original ``FieldInfo``, so
    its default (including an explicit ``None``), ``default_factory``,
    description, alias and constraints are preserved.

    The returned class is named ``<model name>WithBBox`` and derives directly
    from ``BaseModel``, so only the fields are carried over from ``model``.

    NOTE(review): there is no visited-set, so a model that (directly or
    indirectly) contains itself would recurse without bound.

    Args:
        model: The original Pydantic model
        original_citation_type: The citation type to replace (e.g., processor.citation_type)
        new_citation_type: The new citation type (e.g., processor.citation_type_with_bbox)

    Returns:
        A new model class with updated citation types
    """

    def _is_model(candidate: Any) -> bool:
        # On Python 3.9/3.10 `isinstance(list[str], type)` is True, and passing
        # such an alias to `issubclass` raises TypeError. Requiring that the
        # candidate has no generic origin filters those aliases out first.
        return (
            isinstance(candidate, type)
            and get_origin(candidate) is None
            and issubclass(candidate, BaseModel)
        )

    new_fields: dict[str, Any] = {}

    for field_name, field_info in model.model_fields.items():
        field_type = field_info.annotation
        new_type = field_type  # kept as-is unless one of the branches below applies

        # If it's a citation field, update the type
        if is_citation_type(field_type, original_citation_type):
            new_type = new_citation_type

        # If it's a nested BaseModel, recursively modify it
        elif _is_model(field_type):
            new_type = modify_model_citations(
                field_type, original_citation_type, new_citation_type
            )

        # If it's a list of BaseModel, modify the inner model
        elif get_origin(field_type) is list:
            args = get_args(field_type)
            if args and _is_model(args[0]):
                modified_nested = modify_model_citations(
                    args[0], original_citation_type, new_citation_type
                )
                new_type = list[modified_nested]

        # Hand pydantic a private copy of the FieldInfo rather than a bare
        # default: an `is not None` test on the default would make
        # `field: X | None = None` required, and a bare default cannot carry
        # default_factory or other field metadata. The copy keeps the source
        # model's own FieldInfo from being modified while the new model is built.
        new_fields[field_name] = (new_type, deepcopy(field_info))

    # Create new model with modified fields
    return create_model(f"{model.__name__}WithBBox", __base__=BaseModel, **new_fields)

In [4]:
# Example usage: this instance supplies the citation annotations used by the
# example models and by the find/modify calls in this cell.
processor = Processor()


# Inner example model: one plain field plus one citation field, so the
# nested (dotted-path) case is exercised.
class NestedData(BaseModel):
    nested_field: str
    nested_citation: processor.citation_type


# Outer example model: a top-level citation field, a nested model, and a list
# of plain strings that is not a citation field.
class MyData(BaseModel):
    my_data: str
    my_data_citation: processor.citation_type
    nested: NestedData
    items: list[str]


# Find all citation fields using the processor's citation type
# (paths of nested fields are dotted, e.g. 'nested.nested_citation').
print("Citation fields found:")
citation_fields = find_citation_fields(MyData, processor.citation_type)
pprint(citation_fields)
print()

# Modify the model using processor's types
# First print the untouched schema for comparison.
print("Original model schema:")
pprint(MyData.model_json_schema())
print()

# Build the variant whose citation fields use the bounding-box annotation.
ModifiedMyData = modify_model_citations(
    MyData,
    original_citation_type=processor.citation_type,
    new_citation_type=processor.citation_type_with_bbox,
)

print("Modified model schema:")
pprint(ModifiedMyData.model_json_schema())
print()

# Test with actual data
# Constructing the instance is the check itself; `original_data` is not read
# again in the visible part of this cell.
original_data = MyData(
    my_data="test",
    my_data_citation=[{"page": 1, "lines": [1, 2, 3]}],
    nested=NestedData(
        nested_field="nested", nested_citation=[{"page": 2, "lines": [4, 5]}]
    ),
    items=["a", "b"],
)

# Create modified data with bounding boxes
# `nested` is passed as a plain dict rather than a NestedData instance because
# the modified model's `nested` field is annotated with a newly created class.
modified_data = ModifiedMyData(
    my_data="test",
    my_data_citation=[
        {
            "page": 1,
            "lines": [1, 2, 3],
            "bounding_boxes": [
                {"x": 10.0, "y": 20.0, "width": 100.0, "height": 15.0},
                {"x": 10.0, "y": 35.0, "width": 100.0, "height": 15.0},
                {"x": 10.0, "y": 50.0, "width": 100.0, "height": 15.0},
            ],
        }
    ],
    nested={
        "nested_field": "nested",
        "nested_citation": [
            {
                "page": 2,
                "lines": [4, 5],
                "bounding_boxes": [
                    {"x": 15.0, "y": 65.0, "width": 120.0, "height": 15.0},
                    {"x": 15.0, "y": 80.0, "width": 120.0, "height": 15.0},
                ],
            }
        ],
    },
    items=["a", "b"],
)

print("Modified data instance:")
pprint(modified_data.model_dump())


# Example with different processor configuration
# The three prints draw a 60-character '=' banner around the section title.
print("\n" + "=" * 60)
print("Example with custom processor configuration:")
print("=" * 60 + "\n")


# Alternative citation shape with different keys than `Citation`, used to show
# that the helpers take the citation type as a parameter.
class CustomCitation(TypedDict):
    document_id: str  # presumably identifies the source document — TODO confirm
    paragraph: int  # presumably a paragraph number in that document — TODO confirm


# `CustomCitation` plus a `bounding_boxes` list.
class CustomCitationWithBBox(TypedDict):
    document_id: str  # same field as in CustomCitation
    paragraph: int  # same field as in CustomCitation
    bounding_boxes: list[dict[str, float]]  # box key names are not fixed by this annotation


# Exposes the same two attribute names as `Processor`, bound to the custom
# citation TypedDicts.
class CustomProcessor:
    # Annotation used by models before bounding boxes are added.
    citation_type = list[CustomCitation]
    # Replacement annotation whose items also carry bounding boxes.
    citation_type_with_bbox = list[CustomCitationWithBBox]


# Instance whose citation annotations drive the custom example.
custom_processor = CustomProcessor()


# Minimal model with a single citation field of the custom type.
class CustomData(BaseModel):
    content: str
    source: custom_processor.citation_type


# Same find-then-modify flow as in the first example, driven by the custom
# processor's citation types.
print("Custom citation fields found:")
custom_citation_fields = find_citation_fields(
    CustomData, custom_processor.citation_type
)
pprint(custom_citation_fields)

# Swap the custom citation annotation for its bounding-box variant.
ModifiedCustomData = modify_model_citations(
    CustomData,
    original_citation_type=custom_processor.citation_type,
    new_citation_type=custom_processor.citation_type_with_bbox,
)

print("\nModified custom model schema:")
pprint(ModifiedCustomData.model_json_schema())

Citation fields found:
['my_data_citation', 'nested.nested_citation']

Original model schema:
{'$defs': {'Citation': {'properties': {'lines': {'items': {'type': 'integer'},
                                                 'title': 'Lines',
                                                 'type': 'array'},
                                       'page': {'title': 'Page',
                                                'type': 'integer'}},
                        'required': ['page', 'lines'],
                        'title': 'Citation',
                        'type': 'object'},
           'NestedData': {'properties': {'nested_citation': {'items': {'$ref': '#/$defs/Citation'},
                                                             'title': 'Nested '
                                                                      'Citation',
                                                             'type': 'array'},
                                         'nested_field': {'title': 'Nested 