From 51064f3dd4082a48399150ec9c78ac219116db51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 2 May 2024 11:48:35 +0200 Subject: [PATCH] Add country types --- docs/reference/types.md | 23 +++++++++++++++++++---- outlines/fsm/types.py | 8 +++++--- outlines/types/__init__.py | 2 +- outlines/types/countries.py | 24 ++++++++++++++++++++++++ pyproject.toml | 3 +++ tests/test_types.py | 11 ++++++++++- 6 files changed, 62 insertions(+), 9 deletions(-) create mode 100644 outlines/types/countries.py diff --git a/docs/reference/types.md b/docs/reference/types.md index fe5c51499..645249263 100644 --- a/docs/reference/types.md +++ b/docs/reference/types.md @@ -2,10 +2,18 @@ Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions: -- Using `outlines.types.ZipCode` will generate valid US Zip(+4) codes. -- Using `outlines.types.PhoneNumber` will generate valid US phone numbers. -- Using `outlines.types.ISBN` will generate ISBNs. Note that there is no guarantee that the [check digit](https://en.wikipedia.org/wiki/ISBN#Check_digits) will be correct. -- Using `outlines.types.airports.IATA` will generate valid airport IATA codes. + +| Category | Type | Import | Description | +|:--------:|:----:|:-------|:------------| +| Zip code | US | `outlines.types.ZipCode` | Generate US Zip(+4) codes | +| Phone number | US | `outlines.types.PhoneNumber` | Generate valid US phone numbers | +| ISBN | 10 & 13 | `outlines.types.ISBN` | There is no guarantee that the [check digit][wiki-isbn] will be correct | +| Airport | IATA | `outlines.types.airports.IATA` | Valid [airport IATA codes][wiki-airport-iata] | +| Country | alpha-2 code | `outlines.types.airports.Alpha2` | Valid [country alpha-2 codes][wiki-country-alpha-2] | +| | alpha-3 code | `outlines.types.countries.Alpha3` | Valid [country alpha-3 codes][wiki-country-alpha-3] | +| | numeric code | `outlines.types.countries.Numeric` | Valid [country numeric codes][wiki-country-numeric] | +| | name | `outlines.types.countries.Name` | Valid country names | +| | flag | `outlines.types.countries.Flag` | Valid flag emojis | You can use these types in Pydantic schemas for JSON-structured generation: @@ -49,3 +57,10 @@ print(result) We plan on adding many more custom types. If you have found yourself writing regular expressions to generate fields of a given type, or if you could benefit from more specific types don't hesite to [submit a PR](https://github.com/outlines-dev/outlines/pulls) or [open an issue](https://github.com/outlines-dev/outlines/issues/new/choose). + + +[wiki-isbn]: https://en.wikipedia.org/wiki/ISBN#Check_digits +[wiki-airport-iata]: https://en.wikipedia.org/wiki/IATA_airport_code +[wiki-country-alpha-2]: https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2 +[wiki-country-alpha-3]: https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3 +[wiki-country-numeric]: https://en.wikipedia.org/wiki/ISO_3166-1_numeric diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index cddcd163f..5695dee07 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -32,10 +32,12 @@ def custom_format_fn(sequence: str) -> Any: if isinstance(python_type, EnumMeta): values = python_type.__members__.keys() - regex_str = "(" + "|".join(values) + ")" - format_fn = lambda x: str(x) + enum_regex_str: str = "(" + "|".join(values) + ")" - return regex_str, format_fn + def enum_format_fn(sequence: str) -> str: + return str(sequence) + + return enum_regex_str, enum_format_fn if python_type == float: diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 7af7f296f..266d3a68e 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,4 +1,4 @@ +from . import airports, countries from .isbn import ISBN from .phone_numbers import PhoneNumber from .zip_codes import ZipCode -from . import airports diff --git a/outlines/types/countries.py b/outlines/types/countries.py new file mode 100644 index 000000000..888443dc6 --- /dev/null +++ b/outlines/types/countries.py @@ -0,0 +1,24 @@ +"""Generate valid country codes and names.""" +from enum import Enum + +try: + import pycountry +except ImportError: + raise ImportError( + 'The `countries` module requires "pycountry" to be installed. You can install it with "pip install pycountry"' + ) + +ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries] +Alpha2 = Enum("Alpha_2", ALPHA_2_CODE) # type:ignore + +ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries] +Alpha3 = Enum("Alpha_2", ALPHA_3_CODE) # type:ignore + +NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries] +Numeric = Enum("Numeric_code", NUMERIC_CODE) # type:ignore + +NAME = [(country.name, country.name) for country in pycountry.countries] +Name = Enum("Name", NAME) # type:ignore + +FLAG = [(country.flag, country.flag) for country in pycountry.countries] +Flag = Enum("Flag", FLAG) # type:ignore diff --git a/pyproject.toml b/pyproject.toml index c4d29142d..0b6ae6352 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,8 @@ test = [ "vllm", "torch", "transformers", + "pycountry", + "pyairports", ] serve = [ "vllm>=0.3.0", @@ -126,6 +128,7 @@ module = [ "vllm.*", "uvicorn.*", "fastapi.*", + "pycountry.*", "pyairports.*", ] ignore_missing_imports = true diff --git a/tests/test_types.py b/tests/test_types.py index 8143d3d6d..2391ccc18 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -47,10 +47,19 @@ class Model(BaseModel): [ (types.airports.IATA, "CDG", True), (types.airports.IATA, "XXX", False), + (types.countries.Alpha2, "FR", True), + (types.countries.Alpha2, "XX", False), + (types.countries.Alpha3, "UKR", True), + (types.countries.Alpha3, "XXX", False), + (types.countries.Numeric, "004", True), + (types.countries.Numeric, "900", False), + (types.countries.Name, "Ukraine", True), + (types.countries.Name, "Wonderland", False), + (types.countries.Flag, "πŸ‡ΏπŸ‡Ό", True), + (types.countries.Flag, "πŸ€—", False), ], ) def test_type_enum(custom_type, test_string, should_match): - type_name = custom_type.__name__ class Model(BaseModel):