From 203b395fe0d09d8c30c37462c46192b96e6dac00 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 16 Nov 2023 10:14:31 -0600 Subject: [PATCH] Fix validation of `Literal` from JSON keys when used as `dict` key (#1075) Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- src/validators/literal.rs | 16 +++++++++++++--- tests/test.rs | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/validators/literal.rs b/src/validators/literal.rs index c9a846695..686920cca 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -9,7 +9,7 @@ use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type}; use crate::errors::{ErrorType, ValError, ValResult}; -use crate::input::Input; +use crate::input::{Input, ValidationMatch}; use crate::py_gc::PyGcTraverse; use crate::tools::SchemaDict; @@ -116,8 +116,18 @@ impl LiteralLookup { } } if let Some(expected_strings) = &self.expected_str { - // dbg!(expected_strings); - if let Ok(either_str) = input.exact_str() { + let validation_result = if input.is_python() { + input.exact_str() + } else { + // Strings coming from JSON are treated as "strict" but not "exact" for reasons + // of parsing types like UUID; see the implementation of `validate_str` for Json + // inputs for justification. We might change that eventually, but for now we need + // to work around this when loading from JSON + // V3 TODO: revisit making this "exact" for JSON inputs + input.validate_str(true, false).map(ValidationMatch::into_inner) + }; + + if let Ok(either_str) = validation_result { let cow = either_str.as_cow()?; if let Some(id) = expected_strings.get(cow.as_ref()) { return Ok(Some((input, &self.values[*id]))); diff --git a/tests/test.rs b/tests/test.rs index 348520435..4b0918d40 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use _pydantic_core::SchemaSerializer; + use _pydantic_core::{SchemaSerializer, SchemaValidator}; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -86,4 +86,35 @@ a = A() assert_eq!(serialized, b"{\"b\":\"b\"}"); }); } + + #[test] + fn test_literal_schema() { + Python::with_gil(|py| { + let code = r#" +schema = { + "type": "dict", + "keys_schema": { + "type": "literal", + "expected": ["a", "b"], + }, + "values_schema": { + "type": "str", + }, + "strict": False, +} +json_input = '{"a": "something"}' + "#; + let locals = PyDict::new(py); + py.run(code, None, Some(locals)).unwrap(); + let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap(); + let json_input: &PyAny = locals.get_item("json_input").unwrap().unwrap().extract().unwrap(); + let binding = SchemaValidator::py_new(py, schema, None) + .unwrap() + .validate_json(py, json_input, None, None, None) + .unwrap(); + let validation_result: &PyAny = binding.extract(py).unwrap(); + let repr = format!("{}", validation_result.repr().unwrap()); + assert_eq!(repr, "{'a': 'something'}"); + }); + } }