In [None]:
#| default_exp models

# models

**Set up**  
I love working with Pydantic models, but I'd like to add a few tweaks to make them even better to work with in Jupyter Notebooks

**The Goal**  
Update how Pydantic models are displayed in Jupyter Lab and Jupyter Notebook to make them more natural to work with in an IPython environment

**The Result**  
Subclass Pydantic's BaseModel and add `_repr_html_` and `_repr_json_` methods


In [None]:
#| exporti 

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict
import logging
from json2html import json2html
from humble_chuck.delegation import delegates
from typing import *
from pydantic_settings import BaseSettings as PydanticBaseSettings
from pydantic_settings import (
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    YamlConfigSettingsSource
)
from pydantic import create_model,Field
from pydantic.fields import FieldInfo
import yaml
from pathlib import Path

## Displaying Objects in IPython

If Jupyter encounters an error while calling an object's display methods, it will try the next available option for displaying that object. For example, in Jupyter Lab the default representation is JSON, but if there is an error displaying the JSON it will fall back on displaying with HTML. If the error persists, it will fall back on the default
`__repr__` or `__str__` methods. 

So for this exercise, we'll create a function that tries to dump a Pydantic model, but if anything goes wrong it will just issue a warning and pass. That way, if there is an issue with our custom display, we'll just get Pydantic's default display mechanism. Note that we're using `delegates` here, so the kwargs you'll see in the docs are not included in the signature. 

In [None]:
#| exports

@delegates(PydanticBaseModel.model_dump)
def model_dump_for_display(
    model:PydanticBaseModel, #The model to be displayed
    **kwargs
):
    """Best-effort `model.model_dump()` for notebook display hooks.

    The dump is always performed with `mode='json'`, overriding any `mode`
    supplied by the caller. If the dump raises, the exception is logged as a
    warning and `None` is returned so the caller can fall back to the default
    representation.

    Delegates kwargs to PydanticBaseModel.model_dump
    """
    # Build the final keyword set with the JSON mode forced on.
    dump_kwargs = {**kwargs, 'mode': 'json'}
    try:
        dumped = model.model_dump(**dump_kwargs)
    except Exception as e:
        logging.warning(e)
        return None
    return dumped

In [None]:
#| exports 

class DisplayMixin:
    """Mixin providing Jupyter display hooks (`_repr_json_`, `_repr_html_`).

    Expects to be mixed into a class that exposes `model_config` and can be
    passed to `model_dump_for_display`. Extra keyword arguments for the dump
    are read from the optional `repr_kwargs` entry of `model_config`.

    Both hooks return `None` when the dump fails, so the display machinery
    can fall through to another representation.
    """

    def _dump_for_display(self):
        """Return the display dump of this model, or `None` if dumping failed."""
        # Copy so the forced `mode` never leaks into the shared config dict.
        dump_kwargs = dict(self.model_config.get('repr_kwargs') or {})
        # Set `mode` inside the dict rather than as a separate keyword:
        # passing both raised "got multiple values for keyword argument"
        # whenever `repr_kwargs` itself contained a `mode` entry.
        dump_kwargs['mode'] = 'json'
        return model_dump_for_display(self, **dump_kwargs)

    def _repr_json_(self):
        """JSON display hook: the dumped model, or `None` on failure."""
        return self._dump_for_display()

    def _repr_html_(self):
        """HTML display hook: the dump rendered as an HTML table, or `None` on failure."""
        dumped = self._dump_for_display()
        if dumped is None:
            # Don't hand `None` to json2html: that would yield a non-None
            # result and suppress the fallback to the default representation.
            return None
        return json2html.convert(dumped)

::: {.callout-note}
Because we are subclassing Pydantic's BaseModel, the docs shown here are taken from the parent class. 
:::

In [None]:
#| exports

# Pydantic's BaseModel combined with DisplayMixin, so instances pick up the
# `_repr_json_` / `_repr_html_` display hooks defined on the mixin.
class BaseModel(PydanticBaseModel,DisplayMixin):
    # Intentionally empty (and without a docstring of its own): all behavior
    # comes from the two base classes.
    pass


In [None]:
from pydantic import Field,AliasGenerator
from typing import *
from datetime import date

### Example

Note that you can customize how objects get displayed in the model_config. Here we'll choose to display the object with aliases instead of field names.

In [None]:
def _title_case_alias(field_name: str) -> str:
    """Build a display label from a field name, e.g. `start_date` -> `Start Date`."""
    return field_name.title().replace('_',' ')


class Project(BaseModel):
    """Model for capturing details about a construction project"""

    # `repr_kwargs` is read by the display hooks and forwarded to model_dump,
    # so the notebook rendering uses the serialization aliases built above.
    model_config = ConfigDict(
        alias_generator=AliasGenerator(serialization_alias=_title_case_alias),
        repr_kwargs={'by_alias':True},
    )

    # Required fields use `...`; optional ones default to None.
    project_name: str = Field(
        ..., description="Name of the construction project"
    )
    start_date: date = Field(
        ..., description="Date when the project started"
    )
    end_date: Optional[date] = Field(
        None, description="Date when the project ended"
    )
    description: Optional[str] = Field(
        None, description="Short description of the project"
    )
    is_active: bool = Field(
        ..., description="Indicates if the project is currently active"
    )
    budget: Optional[Dict[str, float]] = Field(
        None, description="Budget with different risk assessments"
    )
    employees: List[Dict[str, str]] = Field(
        ..., description="List of employees working on the project"
    )
    technologies_used: List[str] = Field(
        ..., description="List of technologies used in the project"
    )
    


In the docs, you'll see this example represented as HTML. In Jupyter Lab it gets displayed as interactive, collapsible JSON.

In [None]:
# Creating an instance of the model
# Every required field is supplied explicitly; `end_date` is passed as None.
example_project = Project(
    project_name="Highway Bridge Construction",
    start_date=date(2024, 1, 15),
    end_date=None,
    description="A large-scale project focused on building a new highway bridge.",
    is_active=True,
    budget={"conservative": 5_000_000, "base_line": 6_500_000, "worst_case": 8_000_000},
    # NOTE(review): the key "roll" looks like a typo for "role"; it is left
    # as-is because it is runtime data and appears in the rendered output.
    employees=[
        {"name": "Alice Johnson", "roll": "Project Manager"},
        {"name": "Bob Smith", "roll": "Engineer"},
        {"name": "Clara Davis", "roll": "Site Supervisor"}
    ],
    technologies_used=["AutoCAD", "Revit", "MS Project"]
)

# Bare last expression: the notebook renders it through the model's display hooks.
example_project

name,roll
Alice Johnson,Project Manager
Bob Smith,Engineer
Clara Davis,Site Supervisor
Project Name,Highway Bridge Construction
Start Date,2024-01-15
End Date,
Description,A large-scale project focused on building a new highway bridge.
Is Active,True
Budget,conservative5000000.0base_line6500000.0worst_case8000000.0
Employees,namerollAlice JohnsonProject ManagerBob SmithEngineerClara DavisSite Supervisor

0,1
conservative,5000000.0
base_line,6500000.0
worst_case,8000000.0

name,roll
Alice Johnson,Project Manager
Bob Smith,Engineer
Clara Davis,Site Supervisor


## Settings

I like to have all my settings in the same place instead of maintaining various .env files all over the place. So I've customized Pydantic's BaseSettings to read values from a central yaml file.

In [None]:
#|export

def read_yaml_key(file_path: Union[str, Path, None], target_key: str) -> dict:
    """
    Reads values from a specific key in a YAML file and returns them as a dictionary.

    Always returns a dictionary. An empty one is returned when `file_path` is
    empty/None, the file does not exist, the file is empty, the top level of
    the document is not a mapping, `target_key` is absent, or the value
    stored under `target_key` is not a mapping (e.g. a key with no value).

    :param file_path: Path to the YAML file (string or `Path`).
    :param target_key: The key whose values need to be extracted.
    :return: A dictionary containing the values for the specified key.
    """
    import yaml  # Ensure PyYAML is installed and imported

    if not file_path:
        return {}

    try:
        # Explicit encoding so the result does not depend on the platform default.
        with open(file_path, 'r', encoding='utf-8') as yaml_file:
            yaml_content = yaml.safe_load(yaml_file)
    except FileNotFoundError:
        return {}

    # An empty file loads as None; a document may also be a list or scalar.
    if not isinstance(yaml_content, dict):
        return {}

    # `key:` with no value loads as None; honor the dict return contract so
    # callers can safely call `.get` on the result.
    section = yaml_content.get(target_key)
    return section if isinstance(section, dict) else {}

In [None]:

# Demo: read the section stored under the "eg_db_" key of the example config.
file_path = "example_data/example_config.yaml"
target_key = "eg_db_"
result = read_yaml_key(file_path, target_key)
# An empty dict is printed when the file or key is missing (see read_yaml_key).
print(result)

In [None]:
#|exporti 

class YMLSettingsSource(PydanticBaseSettingsSource):
    """
    A settings source that loads values from one section of a YAML file.

    The file location is taken from the `yml_settings_path` entry of the
    settings class's config, and the section is the top-level key equal to
    the configured `env_prefix`. Inside that section, values are looked up by
    field name.
    """

    # Section contents held only while `__call__` is running, so the file is
    # read once per settings build instead of once per field.
    _section_cache: Optional[Dict[str, Any]] = None

    def _read_section(self) -> Dict[str, Any]:
        """Read the configured section from the YAML file; always returns a dict."""
        section = read_yaml_key(
            self.config.get('yml_settings_path'),
            self.config.get('env_prefix')
        )
        # Guard against a null / non-mapping section so `.get` below is safe.
        return section if isinstance(section, dict) else {}

    def get_field_value(
        self, field: FieldInfo, field_name: str
    ) -> Tuple[Any, str, bool]:
        """Return `(value, key, value_is_complex)` for `field_name`.

        The value is `None` when the section has no entry for the field. The
        key is the field name itself, and `value_is_complex` is always False.
        """
        file_content = self._section_cache
        if file_content is None:
            # Called outside `__call__`: read the file fresh.
            file_content = self._read_section()

        field_value = file_content.get(field_name)
        return field_value, field_name, False

    def prepare_field_value(
        self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool
    ) -> Any:
        """Return `value` unchanged; values are used as loaded from the file."""
        return value

    def __call__(self) -> Dict[str, Any]:
        """Collect every field that has a non-None value in the YAML section."""
        d: Dict[str, Any] = {}

        self._section_cache = self._read_section()
        try:
            for field_name, field in self.settings_cls.model_fields.items():
                field_value, field_key, value_is_complex = self.get_field_value(
                    field, field_name
                )
                field_value = self.prepare_field_value(
                    field_name, field, field_value, value_is_complex
                )
                # Leave absent fields out so other sources / defaults apply.
                if field_value is not None:
                    d[field_key] = field_value
        finally:
            # Drop the cache so a later call re-reads the file.
            self._section_cache = None

        return d

In [None]:
#| export 

class BaseSettings(PydanticBaseSettings,DisplayMixin):
    # Default location of the central YAML settings file; subclasses can
    # point `yml_settings_path` somewhere else in their own model_config.
    model_config = SettingsConfigDict(
        yml_settings_path = Path.home() / ".humble-chuck-settings.yml"
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[PydanticBaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources, with the YAML source placed between
        the dotenv source and the file-secret source."""
        yml_settings = YMLSettingsSource(settings_cls)
        ordered_sources = (
            init_settings,
            env_settings,
            dotenv_settings,
            yml_settings,
            file_secret_settings,
        )
        return ordered_sources

In [None]:
class ExampleSettings(BaseSettings):
    # `env_prefix` is also used as the section key looked up in the YAML file.
    model_config = {
        'yml_settings_path': 'example_data/example_config.yaml',
        'env_prefix': 'eg_db_',
    }

    user: str
    password: str

In [None]:
ExampleSettings()

In [None]:
import os

In [None]:
# An environment variable overrides the value stored in the YAML file.
os.environ['eg_db_user'] = 'arnold'
try:
    assert ExampleSettings().user == 'arnold'
finally:
    # Always remove the variable, even if the assertion fails, so it cannot
    # leak into later cells or a re-run of this one.
    os.environ.pop('eg_db_user', None)

# With the variable gone, the value falls back to the YAML file.
assert ExampleSettings().user == 'harold'

## Generic Data Model

In [None]:
#|export

# Type variable for the element type held in `DataModel.data`.
DataModelT = TypeVar('DataModelT')

In [None]:

class DataModel(BaseModel,Generic[DataModelT],DisplayMixin):
    """
    A Generic Data Model. The data attribute contains a list of objects of an arbitrary type. It is intended for use with a Pydantic model. 
    
    Supports rich __repr__ displays in HTML and Javascript for use in Jupyter Notebook and Lab, respectively. 
    """
    # The docstring above is intentionally left untouched: `_repr_html_` below
    # renders the schema's `description`, which Pydantic derives from class
    # docstrings.
    data: List[DataModelT] = []

    @delegates(BaseModel.model_dump)
    def to_dataframe(self,**kwargs):
        """Turn `data` into a DataFrame, one row per item.

        Each item is serialized with its own `model_dump(**kwargs)`, so the
        kwargs control how rows are produced (e.g. `by_alias=True`).
        """
        # Local import: the method must not depend on a later notebook cell
        # having imported pandas into the global namespace.
        import pandas as pd

        return pd.DataFrame([item.model_dump(**kwargs) for item in self.data])

    @classmethod
    def display_html_schema(cls):
        """Return the model's serialization JSON schema rendered as an HTML table."""
        # `HTML` was referenced without ever being imported; import it where
        # it is used so the method works in a fresh kernel.
        from IPython.display import HTML

        return HTML(
            json2html.convert(
                cls.model_json_schema(
                    mode='serialization'
                )
            )
        )

    def _repr_html_(self):
        """HTML display hook: schema title/description, the non-`data` fields,
        and the first rows of the data as a DataFrame.

        Returns `None` (after logging a warning) if anything goes wrong, so
        the notebook can fall back to another representation.
        """
        from html import escape

        try:
            df_html = self.to_dataframe().head()._repr_html_()
            schema = self.model_json_schema()
            # Escape interpolated values so characters such as `<` or `&` in
            # a description or field value cannot break the markup.
            html_fields = [
                f"<header><b>{schema_field}</b>: {escape(str(schema.get(schema_field)))}\n</header>"
                for schema_field in ['title','description']
            ]
            # Read `model_fields` from the class; instance access is
            # deprecated in recent Pydantic releases.
            for field in type(self).model_fields:
                if field != 'data':
                    html_fields.append(
                        f'<header><b>{field}</b>: {escape(str(getattr(self,field)))}</header>'
                    )
            return ''.join(
                html_fields + ['<header><b>DataFrame</b>: </header>', df_html]
            )
        except Exception as e:
            logging.warning(e)
            return None

### Example DataModel

In [None]:
import datetime as dt
import pandas as pd
from pydantic import Field,ConfigDict,BeforeValidator
from typing import Annotated,Optional

In [None]:

# Download the CSV export of the dataset (requires network access).
url = "https://data.cityofnewyork.us/api/views/c3uy-2p5r/rows.csv?accessType=DOWNLOAD"
air_quality_df = pd.read_csv(url)
# Preview only the first rows rather than rendering the whole frame.
air_quality_df.head()


Create a model to represent a row of data:

In [None]:
def _parse_us_date(value):
    """Parse `MM/DD/YYYY` strings into a `datetime.date`.

    Anything else is returned unchanged so Pydantic's own date validation can
    accept or reject it: values that are already dates, or strings in another
    format such as ISO. This lets a dumped model be validated again.
    """
    if isinstance(value, str):
        try:
            return dt.datetime.strptime(value, "%m/%d/%Y").date()
        except ValueError:
            return value
    return value


class AirQuality(BaseModel):
    """An air quality measurement from the City of New York"""
    model_config = ConfigDict(
        # Numeric ids in the source are declared as `str` fields below.
        coerce_numbers_to_str=True,
        # Default alias: `geo_type_name` -> `Geo Type Name`.
        alias_generator = lambda x: x.replace('_',' ').title(),

    )

    # Explicit aliases cover names the generator's title-casing cannot
    # produce (e.g. "ID" in capitals, or an underscore kept in the name).
    unique_id: str = Field(alias='Unique ID')
    indicator_id: str = Field(alias='Indicator ID')
    name: str
    measure: str
    measure_info: str
    geo_type_name: str
    geo_join_id: str = Field(alias='Geo Join ID')
    geo_place_name: str
    time_period: str
    start_date: Annotated[
        dt.date,
        BeforeValidator(_parse_us_date)
    ] = Field(alias='Start_Date')
    data_value: float
    

In [None]:
# Validate only the first row. `head(1)` avoids converting the entire frame
# to a list of dicts just to pick out one record.
AirQuality.model_validate(
    air_quality_df.head(1).to_dict('records')[0]
)

In [None]:
# Single definition of the dataset URL, shared by the field default and the
# constructor default so the two cannot drift apart.
_AIR_QUALITY_CSV_URL = "https://data.cityofnewyork.us/api/views/c3uy-2p5r/rows.csv?accessType=DOWNLOAD"


class AirQualityData(DataModel[AirQuality]):
    """Air quality measurements from the City of New York."""
    source: str = _AIR_QUALITY_CSV_URL
    accessed_at: dt.datetime = Field(default_factory=dt.datetime.now)

    def __init__(
        self,
        source:str = _AIR_QUALITY_CSV_URL,
        **kwargs
    ):
        """Build the model, downloading the CSV at `source` unless `data` is given.

        Extra keyword arguments (e.g. `data`, `accessed_at`) are passed on to
        the model constructor, so an instance can also be built from values
        already in hand without touching the network.
        """
        if 'data' not in kwargs:
            # No rows supplied: read them from `source`, one dict per CSV row.
            kwargs['data'] = pd.read_csv(source).to_dict('records')
        super().__init__(source=source, **kwargs)

In [None]:
AirQualityData()