Skip to content

Commit

Permalink
Prefer custom encoder over defaults if specified.
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksunder committed Dec 24, 2020
1 parent 5614b94 commit fffd897
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
18 changes: 9 additions & 9 deletions fastapi/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ def jsonable_encoder(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True,
) -> Any:
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if include is not None and not isinstance(include, set):
include = set(include)
if exclude is not None and not isinstance(exclude, set):
Expand Down Expand Up @@ -115,14 +123,6 @@ def jsonable_encoder(
)
return encoded_list

if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder(obj)

if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
Expand Down
15 changes: 15 additions & 0 deletions tests/test_jsonable_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,21 @@ class MyModel(BaseModel):
assert encoded_instance["dt_field"] == instance.dt_field.isoformat()


def test_custom_enum_encoders():
def custom_enum_encoder(v):
return v.value.lower()

class MyEnum(Enum):
ENUM_VAL_1 = "ENUM_VAL_1"

instance = MyEnum.ENUM_VAL_1

encoded_instance = jsonable_encoder(
instance, custom_encoder={MyEnum: custom_enum_encoder}
)
assert encoded_instance == custom_enum_encoder(instance)


def test_encode_model_with_path(model_with_path):
if isinstance(model_with_path.path, PureWindowsPath):
expected = "\\foo\\bar"
Expand Down

0 comments on commit fffd897

Please sign in to comment.