Skip to content

Commit

Permalink
Add IATA code type
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed May 2, 2024
1 parent ec664fe commit 1053e7d
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/reference/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Outlines provides custom Pydantic types so you can focus on your use case rather
- Using `outlines.types.ZipCode` will generate valid US Zip(+4) codes.
- Using `outlines.types.PhoneNumber` will generate valid US phone numbers.
- Using `outlines.types.ISBN` will generate ISBNs. Note that there is no guarantee that the [check digit](https://en.wikipedia.org/wiki/ISBN#Check_digits) will be correct.
- Using `outlines.types.airports.IATA` will generate valid airport IATA codes.

You can use these types in Pydantic schemas for JSON-structured generation:

Expand Down
20 changes: 14 additions & 6 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Protocol, Tuple, Type, Union
from enum import EnumMeta
from typing import Any, Protocol, Tuple, Type

from typing_extensions import _AnnotatedAlias, get_args

Expand All @@ -12,9 +13,7 @@


class FormatFunction(Protocol):
def __call__(
self, sequence: str
) -> Union[int, float, bool, datetime.date, datetime.time, datetime.datetime]:
def __call__(self, sequence: str) -> Any:
...


Expand All @@ -24,8 +23,17 @@ def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
json_schema = get_args(python_type)[1].json_schema
type_class = get_args(python_type)[0]

regex_str = json_schema["pattern"]
format_fn = lambda x: type_class(x)
custom_regex_str = json_schema["pattern"]

def custom_format_fn(sequence: str) -> Any:
return type_class(sequence)

return custom_regex_str, custom_format_fn

if isinstance(python_type, EnumMeta):
values = python_type.__members__.keys()
regex_str = "(" + "|".join(values) + ")"
format_fn = lambda x: str(x)

return regex_str, format_fn

Expand Down
1 change: 1 addition & 0 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .isbn import ISBN
from .phone_numbers import PhoneNumber
from .zip_codes import ZipCode
from . import airports
16 changes: 16 additions & 0 deletions outlines/types/airports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Generate valid airport codes."""
from enum import Enum

try:
    from pyairports.airports import AIRPORT_LIST
except ImportError as exc:
    # Re-raise with installation instructions, keeping the original failure
    # attached as the explicit cause.
    raise ImportError(
        'The `airports` module requires "pyairports" to be installed. You can install it with "pip install pyairports"'
    ) from exc


# (name, value) pairs for the enum, built from the field at index 3 of every
# airport record; records where that field is the empty string are dropped.
# The set removes duplicate codes. Sorting then fixes the order: iteration
# order of a set of strings changes between interpreter runs (hash
# randomization), which would otherwise make the enum's member order -- and
# anything derived from it -- non-reproducible.
AIRPORT_IATA_LIST = sorted(
    {(airport[3], airport[3]) for airport in AIRPORT_LIST if airport[3] != ""}
)

# Enum whose member names and values are both the airport code strings.
IATA = Enum("Airport", AIRPORT_IATA_LIST)  # type:ignore
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ module = [
"vllm.*",
"uvicorn.*",
"fastapi.*",
"pyairports.*",
]
ignore_missing_imports = true

Expand Down
27 changes: 26 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
(types.ISBN, "0-596-52068-9", True),
],
)
def test_phone_number(custom_type, test_string, should_match):
def test_type_regex(custom_type, test_string, should_match):
class Model(BaseModel):
attr: custom_type

Expand All @@ -40,3 +40,28 @@ class Model(BaseModel):
assert isinstance(format_fn(1), str)
does_match = re.match(regex_str, test_string) is not None
assert does_match is should_match


@pytest.mark.parametrize(
    "custom_type,test_string,should_match",
    [
        (types.airports.IATA, "CDG", True),
        (types.airports.IATA, "XXX", False),
    ],
)
def test_type_enum(custom_type, test_string, should_match):
    """Check an enum-based custom type in both the JSON schema and the regex.

    The candidate string must appear in (or be absent from) the ``enum``
    list of the generated Pydantic schema, and the regex returned by
    ``python_types_to_regex`` must agree with that expectation.
    """

    class Model(BaseModel):
        attr: custom_type

    # The enum definition is stored in the schema under the enum class name.
    enum_values = Model.model_json_schema()["$defs"][custom_type.__name__]["enum"]
    assert isinstance(enum_values, list)
    assert (test_string in enum_values) is should_match

    regex_str, format_fn = python_types_to_regex(custom_type)
    # The formatting function must return a string even for non-string input.
    assert isinstance(format_fn(1), str)
    assert (re.match(regex_str, test_string) is not None) is should_match

0 comments on commit 1053e7d

Please sign in to comment.