diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py index dd1b653..c98f20d 100644 --- a/voluptuous/schema_builder.py +++ b/voluptuous/schema_builder.py @@ -776,9 +776,10 @@ def key_literal(key): result[key] = value # recompile and send old object + result_cls = type(self) result_required = (required if required is not None else self.required) result_extra = (extra if extra is not None else self.extra) - return Schema(result, required=result_required, extra=result_extra) + return result_cls(result, required=result_required, extra=result_extra) def _compile_scalar(schema): diff --git a/voluptuous/tests/tests.py b/voluptuous/tests/tests.py index f847bc9..779d0c4 100644 --- a/voluptuous/tests/tests.py +++ b/voluptuous/tests/tests.py @@ -372,6 +372,7 @@ def test_schema_extend(): assert extended.schema == {'a': int, 'b': str} assert extended.required == base.required assert extended.extra == base.extra + assert isinstance(extended, Schema) def test_schema_extend_overrides(): @@ -411,6 +412,20 @@ def test_subschema_extension(): assert_equal(extended.schema, {'a': {'b': str, 'c': float, 'e': int}, 'd': str}) +def test_schema_extend_handles_schema_subclass(): + """Verify that Schema.extend handles a subclass of Schema""" + class S(Schema): + pass + + base = S({Required('a'): int}) + extension = {Optional('b'): str} + extended = base.extend(extension) + + expected_schema = {Required('a'): int, Optional('b'): str} + assert extended.schema == expected_schema + assert isinstance(extended, S) + + def test_equality(): assert_equal(Schema('foo'), Schema('foo'))