Skip to content

Commit

Permalink
Fix incorrect subclass check for secretstr (#6730)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexVndnblcke committed Jul 18, 2023
1 parent c28c0c6 commit 66251f8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pydantic/types.py
Expand Up @@ -728,7 +728,7 @@ class _SecretFieldValidator:
inner_schema: CoreSchema = _dataclasses.field(init=False)

def validate(self, value: _SecretField[SecretType] | SecretType, _: core_schema.ValidationInfo) -> Any:
error_prefix: Literal['string', 'bytes'] = 'string' if self.field_type is SecretStr else 'bytes'
error_prefix: Literal['string', 'bytes'] = 'string' if issubclass(self.field_type, SecretStr) else 'bytes'
if self.min_length is not None and len(value) < self.min_length:
short_kind: core_schema.ErrorType = f'{error_prefix}_too_short' # type: ignore[assignment]
raise PydanticKnownError(short_kind, {'min_length': self.min_length})
Expand Down Expand Up @@ -771,8 +771,8 @@ def __get_pydantic_json_schema__(
def __get_pydantic_core_schema__(
self, source: type[Any], handler: _annotated_handlers.GetCoreSchemaHandler
) -> core_schema.CoreSchema:
self.inner_schema = handler(str if self.field_type is SecretStr else bytes)
error_kind = 'string_type' if self.field_type is SecretStr else 'bytes_type'
self.inner_schema = handler(str if issubclass(self.field_type, SecretStr) else bytes)
error_kind = 'string_type' if issubclass(self.field_type, SecretStr) else 'bytes_type'
return core_schema.general_after_validator_function(
self.validate,
core_schema.union_schema(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_types.py
Expand Up @@ -3894,6 +3894,39 @@ class Foobar(BaseModel):
assert f.empty_password.get_secret_value() == ''


def test_secretstr_subclass():
class DecryptableStr(SecretStr):
"""
Simulate a SecretStr with decryption capabilities.
"""

def decrypt_value(self) -> str:
return f'MOCK DECRYPTED {self.get_secret_value()}'

class Foobar(BaseModel):
password: DecryptableStr
empty_password: SecretStr

# Initialize the model.
f = Foobar(password='1234', empty_password='')

# Assert correct types.
assert f.password.__class__.__name__ == 'DecryptableStr'
assert f.empty_password.__class__.__name__ == 'SecretStr'

# Assert str and repr are correct.
assert str(f.password) == '**********'
assert str(f.empty_password) == ''
assert repr(f.password) == "DecryptableStr('**********')"
assert repr(f.empty_password) == "SecretStr('')"
assert len(f.password) == 4
assert len(f.empty_password) == 0

# Assert retrieval of secret value is correct
assert f.password.get_secret_value() == '1234'
assert f.empty_password.get_secret_value() == ''


def test_secretstr_equality():
assert SecretStr('abc') == SecretStr('abc')
assert SecretStr('123') != SecretStr('321')
Expand Down

0 comments on commit 66251f8

Please sign in to comment.