Skip to content

Commit

Permalink
Localize types
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed May 6, 2024
1 parent 353cebb commit 4f8433d
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 16 deletions.
28 changes: 22 additions & 6 deletions docs/reference/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions:


| Category | Type | Import | Description |
|:--------:|:----:|:-------|:------------|
| Zip code | US | `outlines.types.ZipCode` | Generate US Zip(+4) codes |
| Phone number | US | `outlines.types.PhoneNumber` | Generate valid US phone numbers |
| ISBN | 10 & 13 | `outlines.types.ISBN` | There is no guarantee that the [check digit][wiki-isbn] will be correct |
| Airport | IATA | `outlines.types.airports.IATA` | Valid [airport IATA codes][wiki-airport-iata] |
| Country | alpha-2 code | `outlines.types.airports.Alpha2` | Valid [country alpha-2 codes][wiki-country-alpha-2] |
Expand All @@ -15,18 +12,37 @@ Outlines provides custom Pydantic types so you can focus on your use case rather
| | name | `outlines.types.countries.Name` | Valid country names |
| | flag | `outlines.types.countries.Flag` | Valid flag emojis |

Some types require localization. We currently only support US types, but please don't hesitate to create localized versions of the different types and open a Pull Request. Localized types are specified using `types.locale` in the following way:

```python
from outlines import types

types.locale("us").ZipCode
types.locale("us").PhoneNumber
```

Here are the localized types that are currently available:

| Category | Locale | Import | Description |
|:--------:|:----:|:-------|:------------|
| Zip code | US | `ZipCode` | Generate US Zip(+4) codes |
| Phone number | US | `PhoneNumber` | Generate valid US phone numbers |


You can use these types in Pydantic schemas for JSON-structured generation:

```python
from pydantic import BaseModel

from outlines import models, generate, types

# Specify the locale for types
locale = types.locale("us")

class Client(BaseModel):
name: str
phone_number: types.PhoneNumber
zip_code: types.ZipCode
phone_number: locale.PhoneNumber
zip_code: locale.ZipCode


model = models.transformers("mistralai/Mistral-7B-v0.1")
Expand All @@ -47,7 +63,7 @@ from outlines import models, generate, types


model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = generate.format(model, types.PhoneNumber)
generator = generate.format(model, types.locale("us").PhoneNumber)
result = generator(
"Return a US Phone number: "
)
Expand Down
3 changes: 1 addition & 2 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import airports, countries
from .isbn import ISBN
from .phone_numbers import PhoneNumber
from .zip_codes import ZipCode
from .locales import locale
21 changes: 21 additions & 0 deletions outlines/types/locales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from outlines.types.phone_numbers import USPhoneNumber
from outlines.types.zip_codes import USZipCode


@dataclass
class US:
ZipCode = USZipCode
PhoneNumber = USPhoneNumber


def locale(locale_str: str):
locales = {"us": US}

if locale_str not in locales:
raise NotImplementedError(
f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for you locale and open a Pull Request."
)

return locales[locale_str]
2 changes: 1 addition & 1 deletion outlines/types/phone_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}"


PhoneNumber = Annotated[
USPhoneNumber = Annotated[
str,
WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}),
]
2 changes: 1 addition & 1 deletion outlines/types/zip_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
US_ZIP_CODE = r"\d{5}(?:-\d{4})?"


ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})]
USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})]
33 changes: 27 additions & 6 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
@pytest.mark.parametrize(
"custom_type,test_string,should_match",
[
(types.PhoneNumber, "12", False),
(types.PhoneNumber, "(123) 123-1234", True),
(types.PhoneNumber, "123-123-1234", True),
(types.ZipCode, "12", False),
(types.ZipCode, "12345", True),
(types.ZipCode, "12345-1234", True),
(types.phone_numbers.USPhoneNumber, "12", False),
(types.phone_numbers.USPhoneNumber, "(123) 123-1234", True),
(types.phone_numbers.USPhoneNumber, "123-123-1234", True),
(types.zip_codes.USZipCode, "12", False),
(types.zip_codes.USZipCode, "12345", True),
(types.zip_codes.USZipCode, "12345-1234", True),
(types.ISBN, "ISBN 0-1-2-3-4-5", False),
(types.ISBN, "ISBN 978-0-596-52068-7", True),
# (types.ISBN, "ISBN 978-0-596-52068-1", True), wrong check digit
Expand All @@ -42,6 +42,27 @@ class Model(BaseModel):
assert does_match is should_match


def test_locale_not_implemented():
with pytest.raises(NotImplementedError):
types.locale("fr")


@pytest.mark.parametrize(
"locale_str,base_types,locale_types",
[
(
"us",
["ZipCode", "PhoneNumber"],
[types.zip_codes.USZipCode, types.phone_numbers.USPhoneNumber],
)
],
)
def test_locale(locale_str, base_types, locale_types):
for base_type, locale_type in zip(base_types, locale_types):
type = getattr(types.locale(locale_str), base_type)
assert type == locale_type


@pytest.mark.parametrize(
"custom_type,test_string,should_match",
[
Expand Down

0 comments on commit 4f8433d

Please sign in to comment.