diff --git a/Cargo.lock b/Cargo.lock index 7721e33fa..b57e0b682 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -224,7 +224,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "0.7.0" +version = "0.7.1" dependencies = [ "ahash", "enum_dispatch", diff --git a/Cargo.toml b/Cargo.toml index b4a84a7b9..4bd145a04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "0.7.0" +version = "0.7.1" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" diff --git a/pydantic_core/__init__.py b/pydantic_core/__init__.py index c90d87490..b45f96343 100644 --- a/pydantic_core/__init__.py +++ b/pydantic_core/__init__.py @@ -1,4 +1,5 @@ from ._pydantic_core import ( + MultiHostUrl, PydanticCustomError, PydanticKnownError, PydanticOmit, @@ -16,6 +17,7 @@ 'CoreSchema', 'SchemaValidator', 'Url', + 'MultiHostUrl', 'SchemaError', 'ValidationError', 'PydanticCustomError', diff --git a/pydantic_core/_pydantic_core.pyi b/pydantic_core/_pydantic_core.pyi index 2f0db34d4..b6fc71021 100644 --- a/pydantic_core/_pydantic_core.pyi +++ b/pydantic_core/_pydantic_core.pyi @@ -14,6 +14,7 @@ __all__ = ( 'build_profile', 'SchemaValidator', 'Url', + 'MultiHostUrl', 'SchemaError', 'ValidationError', 'PydanticCustomError', @@ -44,19 +45,35 @@ class Url: username: 'str | None' password: 'str | None' host: 'str | None' - host_type: Literal['domain', 'punycode_domain', 'ipv4', 'ipv6', None] port: 'int | None' path: 'str | None' query: 'str | None' fragment: 'str | None' - def __init__(self, raw_url: str) -> None: ... def unicode_host(self) -> 'str | None': ... def query_params(self) -> 'list[tuple[str, str]]': ... def unicode_string(self) -> str: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... +class MultiHostHost(TypedDict): + username: 'str | None' + password: 'str | None' + host: str + port: 'int | None' + +class MultiHostUrl: + scheme: str + path: 'str | None' + query: 'str | None' + fragment: 'str | None' + + def hosts(self) -> 'list[MultiHostHost]': ... + def query_params(self) -> 'list[tuple[str, str]]': ... + def unicode_string(self) -> str: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + class SchemaError(Exception): pass diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 7630b4e06..386d1bc16 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -1023,26 +1023,77 @@ def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, ext class UrlSchema(TypedDict, total=False): type: Required[Literal['url']] - host_required: bool # default False max_length: int allowed_schemes: List[str] + host_required: bool # default False + default_host: str + default_port: int + default_path: str + strict: bool ref: str extra: Any def url_schema( *, - host_required: bool | None = None, max_length: int | None = None, allowed_schemes: list[str] | None = None, + host_required: bool | None = None, + default_host: str | None = None, + default_port: int | None = None, + default_path: str | None = None, + strict: bool | None = None, ref: str | None = None, extra: Any = None, ) -> UrlSchema: return dict_not_none( type='url', + max_length=max_length, + allowed_schemes=allowed_schemes, host_required=host_required, + default_host=default_host, + default_port=default_port, + default_path=default_path, + strict=strict, + ref=ref, + extra=extra, + ) + + +class MultiHostUrlSchema(TypedDict, total=False): + type: Required[Literal['multi-host-url']] + max_length: int + allowed_schemes: List[str] + host_required: bool # default False + default_host: str + default_port: int + default_path: str + strict: bool + ref: str + extra: Any + + +def multi_host_url_schema( + *, + max_length: int | None = None, + allowed_schemes: list[str] | None = None, + host_required: bool | None = None, + default_host: str | None = None, + default_port: int | None = None, + default_path: str | None = None, + strict: bool | None = None, + ref: str | None = None, + extra: Any = None, +) -> MultiHostUrlSchema: + return dict_not_none( + type='multi-host-url', max_length=max_length, allowed_schemes=allowed_schemes, + host_required=host_required, + default_host=default_host, + default_port=default_port, + default_path=default_path, + strict=strict, ref=ref, extra=extra, ) @@ -1087,6 +1138,7 @@ def url_schema( CustomErrorSchema, JsonSchema, UrlSchema, + MultiHostUrlSchema, ] # used in _pydantic_core.pyi::PydanticKnownError @@ -1171,6 +1223,5 @@ def url_schema( 'url_parsing', 'url_syntax_violation', 'url_too_long', - 'url_schema', - 'url_host_required', + 'url_scheme', ] diff --git a/src/errors/types.rs b/src/errors/types.rs index 56ed4038d..7a300ef29 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -334,12 +334,10 @@ pub enum ErrorType { UrlTooLong { max_length: usize, }, - #[strum(message = "URL schema should be {expected_schemas}")] - UrlSchema { - expected_schemas: String, + #[strum(message = "URL scheme should be {expected_schemes}")] + UrlScheme { + expected_schemes: String, }, - #[strum(message = "URL host required")] - UrlHostRequired, } macro_rules! render { @@ -475,7 +473,7 @@ impl ErrorType { Self::UrlParsing { .. } => extract_context!(UrlParsing, ctx, error: String), Self::UrlSyntaxViolation { .. } => extract_context!(Cow::Owned, UrlSyntaxViolation, ctx, error: String), Self::UrlTooLong { .. } => extract_context!(UrlTooLong, ctx, max_length: usize), - Self::UrlSchema { .. } => extract_context!(UrlSchema, ctx, expected_schemas: String), + Self::UrlScheme { .. } => extract_context!(UrlScheme, ctx, expected_schemes: String), _ => { if ctx.is_some() { py_err!(PyTypeError; "'{}' errors do not require context", value) @@ -566,7 +564,7 @@ impl ErrorType { Self::UrlParsing { error } => render!(self, error), Self::UrlSyntaxViolation { error } => render!(self, error), Self::UrlTooLong { max_length } => to_string_render!(self, max_length), - Self::UrlSchema { expected_schemas } => render!(self, expected_schemas), + Self::UrlScheme { expected_schemes } => render!(self, expected_schemes), _ => Ok(self.message_template().to_string()), } } @@ -619,7 +617,7 @@ impl ErrorType { Self::UrlParsing { error } => py_dict!(py, error), Self::UrlSyntaxViolation { error } => py_dict!(py, error), Self::UrlTooLong { max_length } => py_dict!(py, max_length), - Self::UrlSchema { expected_schemas } => py_dict!(py, expected_schemas), + Self::UrlScheme { expected_schemes } => py_dict!(py, expected_schemes), _ => Ok(None), } } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 02264b4eb..d33a2611c 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -4,7 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyString, PyType}; use crate::errors::{InputValue, LocItem, ValResult}; -use crate::PyUrl; +use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; use super::return_enums::{EitherBytes, EitherString}; @@ -59,6 +59,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { None } + fn input_as_multi_host_url(&self) -> Option { + None + } + fn callable(&self) -> bool { false } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 246f6fdcb..fa1f2dfe2 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -13,7 +13,7 @@ use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo}; use crate::errors::{py_err_string, ErrorType, InputValue, LocItem, ValError, ValLineError, ValResult}; -use crate::PyUrl; +use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime, @@ -113,6 +113,10 @@ impl<'a> Input<'a> for PyAny { self.extract::().ok() } + fn input_as_multi_host_url(&self) -> Option { + self.extract::().ok() + } + fn callable(&self) -> bool { self.is_callable() } diff --git a/src/lib.rs b/src/lib.rs index ccd3d893d..f37095d10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ mod url; mod validators; // required for benchmarks -pub use self::url::PyUrl; +pub use self::url::{PyMultiHostUrl, PyUrl}; pub use build_tools::SchemaError; pub use errors::{list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, ValidationError}; pub use validators::SchemaValidator; @@ -44,6 +44,7 @@ fn _pydantic_core(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; Ok(()) } diff --git a/src/url.rs b/src/url.rs index 69a81a537..970e5e82b 100644 --- a/src/url.rs +++ b/src/url.rs @@ -1,46 +1,23 @@ use idna::punycode::decode_to_string; use pyo3::prelude::*; +use pyo3::types::PyDict; +use url::Url; #[pyclass(name = "Url", module = "pydantic_core._pydantic_core")] #[derive(Clone)] #[cfg_attr(debug_assertions, derive(Debug))] pub struct PyUrl { - lib_url: url::Url, + lib_url: Url, } -static PUNYCODE_PREFIX: &str = "xn--"; - impl PyUrl { - pub fn new(lib_url: url::Url) -> Self { + pub fn new(lib_url: Url) -> Self { Self { lib_url } } - pub fn into_url(self) -> url::Url { + pub fn into_url(self) -> Url { self.lib_url } - - fn decode_punycode(&self, domain: &str) -> Option { - let mut result = String::with_capacity(domain.len()); - for chunk in domain.split('.') { - if let Some(stripped) = chunk.strip_prefix(PUNYCODE_PREFIX) { - result.push_str(&decode_to_string(stripped)?); - } else { - result.push_str(chunk); - } - result.push('.'); - } - result.pop(); - Some(result) - } - - fn is_punnycode_domain(&self, domain: &str) -> bool { - self.is_special() && domain.split('.').any(|part| part.starts_with(PUNYCODE_PREFIX)) - } - - // based on https://github.com/servo/rust-url/blob/1c1e406874b3d2aa6f36c5d2f3a5c2ea74af9efb/url/src/parser.rs#L161-L167 - fn is_special(&self) -> bool { - matches!(self.lib_url.scheme(), "http" | "https" | "ws" | "wss" | "ftp" | "file") - } } #[pymethods] @@ -71,25 +48,14 @@ impl PyUrl { // string representation of the host, with punycode decoded when appropriate pub fn unicode_host(&self) -> Option { match self.lib_url.host() { - Some(url::Host::Domain(domain)) if self.is_punnycode_domain(domain) => self.decode_punycode(domain), + Some(url::Host::Domain(domain)) if is_punnycode_domain(&self.lib_url, domain) => decode_punycode(domain), _ => self.lib_url.host_str().map(|h| h.to_string()), } } - #[getter] - pub fn host_type(&self) -> Option<&'static str> { - match self.lib_url.host() { - Some(url::Host::Domain(domain)) if self.is_punnycode_domain(domain) => Some("punycode_domain"), - Some(url::Host::Domain(_)) => Some("domain"), - Some(url::Host::Ipv4(_)) => Some("ipv4"), - Some(url::Host::Ipv6(_)) => Some("ipv6"), - None => None, - } - } - #[getter] pub fn port(&self) -> Option { - self.lib_url.port() + self.lib_url.port_or_known_default() } #[getter] @@ -121,22 +87,7 @@ impl PyUrl { // string representation of the URL, with punycode decoded when appropriate pub fn unicode_string(&self) -> String { - let s = self.lib_url.to_string(); - - match self.lib_url.host() { - Some(url::Host::Domain(domain)) if self.is_punnycode_domain(domain) => { - // we know here that we have a punycode domain, so we simply replace the first instance - // of the punycode domain with the decoded domain - // this is ugly, but since `slice()`, `host_start` and `host_end` are all private to `Url`, - // we have no better option, since the `schema` has to be `https`, `http` etc, (see `is_special` above), - // we can safely assume that the first match for the domain, is the domain - match self.decode_punycode(domain) { - Some(decoded) => s.replacen(domain, &decoded, 1), - None => s, - } - } - _ => s, - } + unicode_url(&self.lib_url) } pub fn __str__(&self) -> &str { @@ -147,3 +98,184 @@ impl PyUrl { format!("Url('{}')", self.lib_url) } } + +#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core")] +#[derive(Clone)] +#[cfg_attr(debug_assertions, derive(Debug))] +pub struct PyMultiHostUrl { + ref_url: PyUrl, + extra_urls: Option>, +} + +impl PyMultiHostUrl { + pub fn new(ref_url: Url, extra_urls: Option>) -> Self { + Self { + ref_url: PyUrl::new(ref_url), + extra_urls, + } + } + + pub fn mut_lib_url(&mut self) -> &mut Url { + &mut self.ref_url.lib_url + } +} + +#[pymethods] +impl PyMultiHostUrl { + #[getter] + pub fn scheme(&self) -> &str { + self.ref_url.scheme() + } + + pub fn hosts<'s, 'py>(&'s self, py: Python<'py>) -> PyResult> { + if let Some(extra_urls) = &self.extra_urls { + let mut hosts = Vec::with_capacity(extra_urls.len() + 1); + for url in extra_urls { + hosts.push(host_to_dict(py, url)?); + } + hosts.push(host_to_dict(py, &self.ref_url.lib_url)?); + Ok(hosts) + } else if self.ref_url.lib_url.has_host() { + Ok(vec![host_to_dict(py, &self.ref_url.lib_url)?]) + } else { + Ok(vec![]) + } + } + + #[getter] + pub fn path(&self) -> Option<&str> { + self.ref_url.path() + } + + #[getter] + pub fn query(&self) -> Option<&str> { + self.ref_url.query() + } + + pub fn query_params(&self, py: Python) -> PyObject { + self.ref_url.query_params(py) + } + + #[getter] + pub fn fragment(&self) -> Option<&str> { + self.ref_url.fragment() + } + + // string representation of the URL, with punycode decoded when appropriate + pub fn unicode_string(&self) -> String { + if let Some(extra_urls) = &self.extra_urls { + let schema = self.ref_url.lib_url.scheme(); + let host_offset = schema.len() + 3; + + let mut full_url = self.ref_url.unicode_string(); + full_url.insert(host_offset, ','); + + // special urls will have had a trailing slash added, non-special urls will not + // hence we need to remove the last char if the schema is special + #[allow(clippy::bool_to_int_with_if)] + let sub = if schema_is_special(schema) { 1 } else { 0 }; + + let hosts = extra_urls + .iter() + .map(|url| { + let str = unicode_url(url); + str[host_offset..str.len() - sub].to_string() + }) + .collect::>() + .join(","); + full_url.insert_str(host_offset, &hosts); + full_url + } else { + self.ref_url.unicode_string() + } + } + + pub fn __str__(&self) -> String { + if let Some(extra_urls) = &self.extra_urls { + let schema = self.ref_url.lib_url.scheme(); + let host_offset = schema.len() + 3; + + let mut full_url = self.ref_url.lib_url.to_string(); + full_url.insert(host_offset, ','); + + // special urls will have had a trailing slash added, non-special urls will not + // hence we need to remove the last char if the schema is special + #[allow(clippy::bool_to_int_with_if)] + let sub = if schema_is_special(schema) { 1 } else { 0 }; + + let hosts = extra_urls + .iter() + .map(|url| { + let str = url.as_str(); + &str[host_offset..str.len() - sub] + }) + .collect::>() + .join(","); + full_url.insert_str(host_offset, &hosts); + full_url + } else { + self.ref_url.__str__().to_string() + } + } + + pub fn __repr__(&self) -> String { + format!("Url('{}')", self.__str__()) + } +} + +fn host_to_dict<'a, 'b>(py: Python<'a>, lib_url: &'b Url) -> PyResult<&'a PyDict> { + let dict = PyDict::new(py); + dict.set_item( + "username", + match lib_url.username() { + "" => py.None(), + user => user.into_py(py), + }, + )?; + dict.set_item("password", lib_url.password())?; + dict.set_item("host", lib_url.host_str())?; + dict.set_item("port", lib_url.port_or_known_default())?; + + Ok(dict) +} + +fn unicode_url(lib_url: &Url) -> String { + let mut s = lib_url.to_string(); + + match lib_url.host() { + Some(url::Host::Domain(domain)) if is_punnycode_domain(lib_url, domain) => { + if let Some(decoded) = decode_punycode(domain) { + // replace the range containing the punycode domain with the decoded domain + let start = lib_url.scheme().len() + 3; + s.replace_range(start..start + domain.len(), &decoded) + } + s + } + _ => s, + } +} + +fn decode_punycode(domain: &str) -> Option { + let mut result = String::with_capacity(domain.len()); + for chunk in domain.split('.') { + if let Some(stripped) = chunk.strip_prefix(PUNYCODE_PREFIX) { + result.push_str(&decode_to_string(stripped)?); + } else { + result.push_str(chunk); + } + result.push('.'); + } + result.pop(); + Some(result) +} + +static PUNYCODE_PREFIX: &str = "xn--"; + +fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool { + schema_is_special(lib_url.scheme()) && domain.split('.').any(|part| part.starts_with(PUNYCODE_PREFIX)) +} + +// based on https://github.com/servo/rust-url/blob/1c1e406874b3d2aa6f36c5d2f3a5c2ea74af9efb/url/src/parser.rs#L161-L167 +pub fn schema_is_special(schema: &str) -> bool { + matches!(schema, "http" | "https" | "ws" | "wss" | "ftp" | "file") +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 52e49687f..b62531a6b 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -386,8 +386,9 @@ pub fn build_validator<'a>( custom_error::CustomErrorValidator, // json data json::JsonValidator, - // url type + // url types url::UrlValidator, + url::MultiHostUrlValidator, ) } @@ -507,8 +508,9 @@ pub enum CombinedValidator { CustomError(custom_error::CustomErrorValidator), // json data Json(json::JsonValidator), - // url type + // url types Url(url::UrlValidator), + MultiHostUrl(url::MultiHostUrlValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/src/validators/url.rs b/src/validators/url.rs index f4c70fe0f..5b5408a76 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -1,28 +1,34 @@ use std::cell::RefCell; +use std::iter::Peekable; +use std::str::Chars; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use ahash::AHashSet; -use url::{SyntaxViolation, Url}; +use url::{ParseError, SyntaxViolation, Url}; use crate::build_tools::{is_strict, py_err, SchemaDict}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::recursion_guard::RecursionGuard; -use crate::PyUrl; +use crate::url::{schema_is_special, PyMultiHostUrl, PyUrl}; use super::literal::expected_repr_name; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +type AllowedSchemas = Option<(AHashSet, String)>; + #[derive(Debug, Clone)] pub struct UrlValidator { strict: bool, - host_required: bool, max_length: Option, - allowed_schemes: Option>, - expected_repr: Option, + allowed_schemes: AllowedSchemas, + host_required: bool, + default_host: Option, + default_port: Option, + default_path: Option, name: String, } @@ -34,32 +40,16 @@ impl BuildValidator for UrlValidator { config: Option<&PyDict>, _build_context: &mut BuildContext, ) -> PyResult { - let (allowed_schemes, expected_repr, name): (Option>, Option, String) = - match schema.get_as::<&PyList>(intern!(schema.py(), "allowed_schemes"))? { - Some(list) => { - if list.is_empty() { - return py_err!(r#""allowed_schemes" should have length > 0"#); - } - - let mut expected: AHashSet = AHashSet::new(); - let mut repr_args = Vec::new(); - for item in list.iter() { - let str = item.extract()?; - repr_args.push(format!("'{str}'")); - expected.insert(str); - } - let (repr, name) = expected_repr_name(repr_args, "literal"); - (Some(expected), Some(repr), name) - } - None => (None, None, Self::EXPECTED_TYPE.to_string()), - }; + let (allowed_schemes, name) = get_allowed_schemas(schema, Self::EXPECTED_TYPE)?; Ok(Self { strict: is_strict(schema, config)?, - host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false), max_length: schema.get_as(intern!(schema.py(), "max_length"))?, + host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false), + default_host: schema.get_as(intern!(schema.py(), "default_host"))?, + default_port: schema.get_as(intern!(schema.py(), "default_port"))?, + default_path: schema.get_as(intern!(schema.py(), "default_path"))?, allowed_schemes, - expected_repr, name, } .into()) @@ -75,18 +65,25 @@ impl Validator for UrlValidator { _slots: &'data [CombinedValidator], _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let lib_url = self.get_url(input, extra.strict.unwrap_or(self.strict))?; + let mut lib_url = self.get_url(input, extra.strict.unwrap_or(self.strict))?; - if let Some(ref allowed_schemes) = self.allowed_schemes { + if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(lib_url.scheme()) { - let expected_schemas = self.expected_repr.as_ref().unwrap().clone(); - return Err(ValError::new(ErrorType::UrlSchema { expected_schemas }, input)); + let expected_schemes = expected_schemes_repr.clone(); + return Err(ValError::new(ErrorType::UrlScheme { expected_schemes }, input)); } } - if self.host_required && !lib_url.has_host() { - return Err(ValError::new(ErrorType::UrlHostRequired, input)); + + match check_sub_defaults( + &mut lib_url, + self.host_required, + &self.default_host, + self.default_port, + &self.default_path, + ) { + Ok(()) => Ok(PyUrl::new(lib_url).into_py(py)), + Err(error_type) => return Err(ValError::new(error_type, input)), } - Ok(PyUrl::new(lib_url).into_py(py)) } fn get_name(&self) -> &str { @@ -100,63 +97,389 @@ impl UrlValidator { Ok(either_str) => { let cow = either_str.as_cow()?; let url_str = cow.as_ref(); - self.parse_str(url_str, input, strict) + + self.check_length(input, url_str)?; + + parse_url(url_str, input, strict) } Err(_) => { // we don't need to worry about whether the url was parsed in strict mode before, // even if it was, any syntax errors would have been fixed by the first validation - let lib_url = match input.input_as_url() { - Some(url) => url.into_url(), - None => return Err(ValError::new(ErrorType::UrlType, input)), - }; - if let Some(max_length) = self.max_length { - if lib_url.as_str().len() > max_length { - return Err(ValError::new(ErrorType::UrlTooLong { max_length }, input)); - } + if let Some(py_url) = input.input_as_url() { + let lib_url = py_url.into_url(); + self.check_length(input, lib_url.as_str())?; + Ok(lib_url) + } else if let Some(multi_host_url) = input.input_as_multi_host_url() { + let url_str = multi_host_url.__str__(); + self.check_length(input, &url_str)?; + + parse_url(&url_str, input, strict) + } else { + Err(ValError::new(ErrorType::UrlType, input)) } - Ok(lib_url) } } } - fn parse_str<'s, 'url, 'input>( - &'s self, - url_str: &'url str, - input: &'input impl Input<'input>, - strict: bool, - ) -> ValResult<'input, Url> { + fn check_length<'s, 'data>(&self, input: &'data impl Input<'data>, url_str: &str) -> ValResult<'data, ()> { if let Some(max_length) = self.max_length { if url_str.len() > max_length { return Err(ValError::new(ErrorType::UrlTooLong { max_length }, input)); } } + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct MultiHostUrlValidator { + strict: bool, + max_length: Option, + allowed_schemes: AllowedSchemas, + host_required: bool, + default_host: Option, + default_port: Option, + default_path: Option, + name: String, +} + +impl BuildValidator for MultiHostUrlValidator { + const EXPECTED_TYPE: &'static str = "multi-host-url"; + + fn build( + schema: &PyDict, + config: Option<&PyDict>, + _build_context: &mut BuildContext, + ) -> PyResult { + let (allowed_schemes, name) = get_allowed_schemas(schema, Self::EXPECTED_TYPE)?; + + let default_host: Option = schema.get_as(intern!(schema.py(), "default_host"))?; + if let Some(ref default_host) = default_host { + if default_host.contains(',') { + return py_err!("default_host cannot contain a comma, see pydantic-core#326"); + } + } + Ok(Self { + strict: is_strict(schema, config)?, + max_length: schema.get_as(intern!(schema.py(), "max_length"))?, + allowed_schemes, + host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false), + default_host, + default_port: schema.get_as(intern!(schema.py(), "default_port"))?, + default_path: schema.get_as(intern!(schema.py(), "default_path"))?, + name, + } + .into()) + } +} + +impl Validator for MultiHostUrlValidator { + fn validate<'s, 'data>( + &'s self, + py: Python<'data>, + input: &'data impl Input<'data>, + extra: &Extra, + _slots: &'data [CombinedValidator], + _recursion_guard: &'s mut RecursionGuard, + ) -> ValResult<'data, PyObject> { + let mut multi_url = self.get_url(input, extra.strict.unwrap_or(self.strict))?; + + if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { + if !allowed_schemes.contains(multi_url.scheme()) { + let expected_schemes = expected_schemes_repr.clone(); + return Err(ValError::new(ErrorType::UrlScheme { expected_schemes }, input)); + } + } + match check_sub_defaults( + multi_url.mut_lib_url(), + self.host_required, + &self.default_host, + self.default_port, + &self.default_path, + ) { + Ok(()) => Ok(multi_url.into_py(py)), + Err(error_type) => return Err(ValError::new(error_type, input)), + } + } + + fn get_name(&self) -> &str { + &self.name + } +} + +impl MultiHostUrlValidator { + fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, PyMultiHostUrl> { + match input.validate_str(strict) { + Ok(either_str) => { + let cow = either_str.as_cow()?; + let url_str = cow.as_ref(); + + self.check_length(input, || url_str.len())?; + + parse_multihost_url(url_str, input, strict) + } + Err(_) => { + // we don't need to worry about whether the url was parsed in strict mode before, + // even if it was, any syntax errors would have been fixed by the first validation + if let Some(multi_url) = input.input_as_multi_host_url() { + self.check_length(input, || multi_url.__str__().len())?; + Ok(multi_url) + } else if let Some(py_url) = input.input_as_url() { + let lib_url = py_url.into_url(); + self.check_length(input, || lib_url.as_str().len())?; + Ok(PyMultiHostUrl::new(lib_url, None)) + } else { + Err(ValError::new(ErrorType::UrlType, input)) + } + } + } + } + + fn check_length<'s, 'data, F>(&self, input: &'data impl Input<'data>, func: F) -> ValResult<'data, ()> + where + F: FnOnce() -> usize, + { + if let Some(max_length) = self.max_length { + if func() > max_length { + return Err(ValError::new(ErrorType::UrlTooLong { max_length }, input)); + } + } + Ok(()) + } +} + +fn parse_multihost_url<'url, 'input>( + url_str: &'url str, + input: &'input impl Input<'input>, + strict: bool, +) -> ValResult<'input, PyMultiHostUrl> { + let mut chars = PositionedPeekable::new(url_str); + // consume whitespace, taken from `with_log` + // https://github.com/servo/rust-url/blob/v2.3.1/url/src/parser.rs#L213-L226 + loop { + let peek = chars.peek(); + if let Some(c) = peek { + match c { + '\t' | '\n' | '\r' => (), + c if c <= &' ' => (), + _ => break, + }; + chars.next(); + } else { + break; + } + } + + macro_rules! parsing_err { + ($parse_error:expr) => { + Err(ValError::new( + ErrorType::UrlParsing { + error: $parse_error.to_string(), + }, + input, + )) + }; + } + + // consume the url schema, taken from `parse_scheme` + // https://github.com/servo/rust-url/blob/v2.3.1/url/src/parser.rs#L387-L411 + let schema_start = chars.position; + while let Some(c) = chars.next() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '+' | '-' | '.' => (), + ':' => break, + _ => return parsing_err!(ParseError::RelativeUrlWithoutBase), + } + } + let schema = url_str[schema_start..chars.position - 1].to_ascii_lowercase(); + + // consume the double slash, or any number of slashes, including backslashes, taken from `parse_with_scheme` + // https://github.com/servo/rust-url/blob/v2.3.1/url/src/parser.rs#L413-L456 + loop { + let peek = chars.peek(); + match peek { + Some(&'/') | Some(&'\\') => { + chars.next(); + } + _ => break, + } + } + let prefix = &url_str[..chars.position]; + + // process host and port, splitting based on `,`, some logic taken from `parse_host` + // https://github.com/servo/rust-url/blob/v2.3.1/url/src/parser.rs#L971-L1026 + let mut hosts: Vec<&str> = Vec::with_capacity(3); + let mut start = chars.position; + while let Some(c) = chars.next() { + match c { + '\\' if schema_is_special(&schema) => break, + '/' | '?' | '#' => break, + ',' => { + // minus 1, not `chars.last_len` because we know that the last char was a `,` with length 1 + let end = chars.position - 1; + if start == end { + return parsing_err!(ParseError::EmptyHost); + } + hosts.push(&url_str[start..end]); + start = chars.position; + } + _ => (), + } + } + // with just one host, for consistent behaviour, we parse the URL the same as with multiple hosts + + let reconstructed_url = format!("{prefix}{}", &url_str[start..]); + let ref_url = parse_url(&reconstructed_url, input, strict)?; + + if hosts.is_empty() { + // if there's no one host (e.g. no `,`), we allow it to be empty to allow for default hosts + Ok(PyMultiHostUrl::new(ref_url, None)) + } else { + // with more than one host, none of them can be empty + if !ref_url.has_host() { + return parsing_err!(ParseError::EmptyHost); + } + let extra_urls: Vec = hosts + .iter() + .map(|host| { + let reconstructed_url = format!("{prefix}{host}"); + parse_url(&reconstructed_url, input, strict) + }) + .collect::>()?; - // if we're in strict mode, we collect consider a syntax violation as an error - if strict { - // we could build a vec of syntax violations and return them all, but that seems like overkill - // and unlike other parser style validators - let vios: RefCell> = RefCell::new(None); - let r = Url::options() - .syntax_violation_callback(Some(&|v| *vios.borrow_mut() = Some(v))) - .parse(url_str); - - match r { - Ok(url) => { - if let Some(vio) = vios.into_inner() { - Err(ValError::new( - ErrorType::UrlSyntaxViolation { - error: vio.description().into(), - }, - input, - )) - } else { - Ok(url) - } + if extra_urls.iter().any(|url| !url.has_host()) { + return parsing_err!(ParseError::EmptyHost); + } + + Ok(PyMultiHostUrl::new(ref_url, Some(extra_urls))) + } +} + +fn parse_url<'url, 'input>( + url_str: &'url str, + input: &'input impl Input<'input>, + strict: bool, +) -> ValResult<'input, Url> { + // if we're in strict mode, we collect consider a syntax violation as an error + if strict { + // we could build a vec of syntax violations and return them all, but that seems like overkill + // and unlike other parser style validators + let vios: RefCell> = RefCell::new(None); + let r = Url::options() + .syntax_violation_callback(Some(&|v| { + match v { + // telling users offer about credentials in URLs doesn't really make sense in this context + SyntaxViolation::EmbeddedCredentials => (), + _ => *vios.borrow_mut() = Some(v), + } + })) + .parse(url_str); + + match r { + Ok(url) => { + if let Some(vio) = vios.into_inner() { + Err(ValError::new( + ErrorType::UrlSyntaxViolation { + error: vio.description().into(), + }, + input, + )) + } else { + Ok(url) } - Err(e) => Err(ValError::new(ErrorType::UrlParsing { error: e.to_string() }, input)), } + Err(e) => Err(ValError::new(ErrorType::UrlParsing { error: e.to_string() }, input)), + } + } else { + Url::parse(url_str).map_err(move |e| ValError::new(ErrorType::UrlParsing { error: e.to_string() }, input)) + } +} + +/// check host_required and substitute `default_host`, `default_port` & `default_path` if they aren't set +fn check_sub_defaults( + lib_url: &mut Url, + host_required: bool, + default_host: &Option, + default_port: Option, + default_path: &Option, +) -> Result<(), ErrorType> { + let map_parse_err = |e: ParseError| ErrorType::UrlParsing { error: e.to_string() }; + if !lib_url.has_host() { + if let Some(ref default_host) = default_host { + lib_url.set_host(Some(default_host)).map_err(map_parse_err)?; + } else if host_required { + return Err(ErrorType::UrlParsing { + error: ParseError::EmptyHost.to_string(), + }); + } + } + if lib_url.port().is_none() { + if let Some(default_port) = default_port { + lib_url + .set_port(Some(default_port)) + .map_err(|_| map_parse_err(ParseError::EmptyHost))?; + } + } + if let Some(ref default_path) = default_path { + let path = lib_url.path(); + if path.is_empty() || path == "/" { + lib_url.set_path(default_path); + } + } + Ok(()) +} + +fn get_allowed_schemas(schema: &PyDict, name: &'static str) -> PyResult<(AllowedSchemas, String)> { + match schema.get_as::<&PyList>(intern!(schema.py(), "allowed_schemes"))? { + Some(list) => { + if list.is_empty() { + return py_err!(r#""allowed_schemes" should have length > 0"#); + } + + let mut expected: AHashSet = AHashSet::new(); + let mut repr_args = Vec::new(); + for item in list.iter() { + let str = item.extract()?; + repr_args.push(format!("'{str}'")); + expected.insert(str); + } + let (repr, name) = expected_repr_name(repr_args, name); + Ok((Some((expected, repr)), name)) + } + None => Ok((None, name.to_string())), + } +} + +#[cfg_attr(debug_assertions, derive(Debug))] +struct PositionedPeekable<'a> { + peekable: Peekable>, + position: usize, + last_len: usize, +} + +impl<'a> PositionedPeekable<'a> { + fn new(input: &'a str) -> Self { + Self { + peekable: input.chars().peekable(), + position: 0, + last_len: 0, + } + } + + fn next(&mut self) -> Option { + let c = self.peekable.next(); + if let Some(c) = c { + self.last_len = c.len_utf8(); + self.position += self.last_len; } else { - Url::parse(url_str).map_err(move |e| ValError::new(ErrorType::UrlParsing { error: e.to_string() }, input)) + // needs to be zero here so if you do `position - last_len` after the last char, you get the + // correct position + self.last_len = 0; } + c + } + + fn peek(&mut self) -> Option<&char> { + self.peekable.peek() } } diff --git a/tests/test_errors.py b/tests/test_errors.py index 80930401c..b8a038850 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -259,8 +259,7 @@ def f(input_value, **kwargs): ('url_parsing', 'Input should be a valid URL, Foobar', {'error': 'Foobar'}), ('url_syntax_violation', 'Input violated strict URL syntax rules, Foobar', {'error': 'Foobar'}), ('url_too_long', 'URL should have at most 42 characters', {'max_length': 42}), - ('url_schema', 'URL schema should be "foo", "bar" or "spam"', {'expected_schemas': '"foo", "bar" or "spam"'}), - ('url_host_required', 'URL host required', None), + ('url_scheme', 'URL scheme should be "foo", "bar" or "spam"', {'expected_schemes': '"foo", "bar" or "spam"'}), ] diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index 23ed47c6f..f921ea38b 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -1,7 +1,9 @@ +import re + import pytest from dirty_equals import HasRepr, IsInstance -from pydantic_core import SchemaValidator, Url, ValidationError, core_schema +from pydantic_core import MultiHostUrl, SchemaError, SchemaValidator, Url, ValidationError, core_schema from ..conftest import Err, PyAndJson @@ -22,8 +24,7 @@ def test_url_ok(py_and_json: PyAndJson): assert url.fragment == 'quux' assert url.username is None assert url.password is None - assert url.port is None - assert url.host_type == 'domain' + assert url.port == 443 @pytest.fixture(scope='module', name='url_validator') @@ -38,16 +39,18 @@ def url_validator_fixture(): 'http://example.com', { 'str()': 'http://example.com/', - 'host_type': 'domain', 'host': 'example.com', 'unicode_host()': 'example.com', 'unicode_string()': 'http://example.com/', }, ), - ('http://exa\nmple.com', {'str()': 'http://example.com/', 'host_type': 'domain', 'host': 'example.com'}), + ('http://exa\nmple.com', {'str()': 'http://example.com/', 'host': 'example.com'}), ('xxx', Err('relative URL without a base')), ('http://', Err('empty host')), ('https://xn---', Err('invalid international domain name')), + ('http://example.com:65535', 'http://example.com:65535/'), + ('http:\\\\example.com', 'http://example.com/'), + ('http:example.com', 'http://example.com/'), ('http://example.com:65536', Err('invalid port number')), ('http://1...1', Err('invalid IPv4 address')), ('https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334[', Err('invalid IPv6 address')), @@ -56,48 +59,55 @@ def url_validator_fixture(): ('http://exam%ple.com', Err('invalid domain character')), ('http:// /', Err('invalid domain character')), ('/more', Err('relative URL without a base')), - ('http://example.com./foobar', {'str()': 'http://example.com./foobar', 'host_type': 'domain'}), + ('http://example.com./foobar', {'str()': 'http://example.com./foobar'}), # works since we're in lax mode - ( - b'http://example.com', - {'str()': 'http://example.com/', 'host_type': 'domain', 'unicode_host()': 'example.com'}, - ), - ('http:/foo', {'str()': 'http://foo/', 'host_type': 'domain'}), - ('http:///foo', {'str()': 'http://foo/', 'host_type': 'domain'}), - ('http://exam_ple.com', {'str()': 'http://exam_ple.com/', 'host_type': 'domain'}), - ('http://exam-ple.com', {'str()': 'http://exam-ple.com/', 'host_type': 'domain'}), - ('http://example-.com', {'str()': 'http://example-.com/', 'host_type': 'domain'}), - ('https://£££.com', {'str()': 'https://xn--9aaa.com/', 'host_type': 'punycode_domain'}), - ('https://foobar.£££.com', {'str()': 'https://foobar.xn--9aaa.com/', 'host_type': 'punycode_domain'}), - ('https://foo.£$.money.com', {'str()': 'https://foo.xn--$-9ba.money.com/', 'host_type': 'punycode_domain'}), - ('https://xn--9aaa.com/', {'str()': 'https://xn--9aaa.com/', 'host_type': 'punycode_domain'}), - ('https://münchen/', {'str()': 'https://xn--mnchen-3ya/', 'host_type': 'punycode_domain'}), - ('http://à.א̈.com', {'str()': 'http://xn--0ca.xn--ssa73l.com/', 'host_type': 'punycode_domain'}), + (b'http://example.com', {'str()': 'http://example.com/', 'unicode_host()': 'example.com'}), + ('http:/foo', {'str()': 'http://foo/'}), + ('http:///foo', {'str()': 'http://foo/'}), + ('http://exam_ple.com', {'str()': 'http://exam_ple.com/'}), + ('http://exam-ple.com', {'str()': 'http://exam-ple.com/'}), + ('http://example-.com', {'str()': 'http://example-.com/'}), + ('https://£££.com', {'str()': 'https://xn--9aaa.com/'}), + ('https://foobar.£££.com', {'str()': 'https://foobar.xn--9aaa.com/'}), + ('https://foo.£$.money.com', {'str()': 'https://foo.xn--$-9ba.money.com/'}), + ('https://xn--9aaa.com/', {'str()': 'https://xn--9aaa.com/'}), + ('https://münchen/', {'str()': 'https://xn--mnchen-3ya/'}), + ('http://à.א̈.com', {'str()': 'http://xn--0ca.xn--ssa73l.com/'}), ('ssh://xn--9aaa.com/', 'ssh://xn--9aaa.com/'), ('ssh://münchen.com/', 'ssh://m%C3%BCnchen.com/'), ('ssh://example/', 'ssh://example/'), ('ssh://£££/', 'ssh://%C2%A3%C2%A3%C2%A3/'), ('ssh://%C2%A3%C2%A3%C2%A3/', 'ssh://%C2%A3%C2%A3%C2%A3/'), - ('ftp://127.0.0.1', {'str()': 'ftp://127.0.0.1/', 'host_type': 'ipv4'}), - ( - 'wss://1.1.1.1', - {'str()': 'wss://1.1.1.1/', 'host_type': 'ipv4', 'host': '1.1.1.1', 'unicode_host()': '1.1.1.1'}, - ), + ('ftp://127.0.0.1', {'str()': 'ftp://127.0.0.1/', 'path': '/'}), + ('wss://1.1.1.1', {'str()': 'wss://1.1.1.1/', 'host': '1.1.1.1', 'unicode_host()': '1.1.1.1'}), + ('snap://[::1]', {'str()': 'snap://[::1]', 'host': '[::1]', 'unicode_host()': '[::1]'}), ( 'ftp://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]', { 'str()': 'ftp://[2001:db8:85a3::8a2e:370:7334]/', - 'host_type': 'ipv6', 'host': '[2001:db8:85a3::8a2e:370:7334]', 'unicode_host()': '[2001:db8:85a3::8a2e:370:7334]', }, ), + ('foobar://127.0.0.1', {'str()': 'foobar://127.0.0.1', 'path': None}), + ( + 'mysql://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]', + {'str()': 'mysql://[2001:db8:85a3::8a2e:370:7334]', 'path': None}, + ), + ( + 'mysql://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]/thing', + {'str()': 'mysql://[2001:db8:85a3::8a2e:370:7334]/thing', 'path': '/thing'}, + ), ('https:/more', {'str()': 'https://more/', 'host': 'more'}), ('https:more', {'str()': 'https://more/', 'host': 'more'}), - ('file:///foobar', {'str()': 'file:///foobar', 'host_type': None, 'host': None, 'unicode_host()': None}), - ('file:///:80', {'str()': 'file:///:80', 'host_type': None}), + ('file:///foobar', {'str()': 'file:///foobar', 'host': None, 'unicode_host()': None}), + ('file:///:80', {'str()': 'file:///:80'}), ('file://:80', Err('invalid domain character')), ('foobar://:80', Err('empty host')), + # with bashslashes + ('file:\\\\foobar\\more', {'str()': 'file://foobar/more', 'host': 'foobar', 'path': '/more'}), + ('http:\\\\foobar\\more', {'str()': 'http://foobar/more', 'host': 'foobar', 'path': '/more'}), + ('mongo:\\\\foobar\\more', {'str()': 'mongo:\\\\foobar\\more', 'host': None, 'path': '\\\\foobar\\more'}), ('mongodb+srv://server.example.com/', 'mongodb+srv://server.example.com/'), ('http://example.com.', {'host': 'example.com.', 'unicode_host()': 'example.com.'}), ('http:/example.com', {'host': 'example.com', 'unicode_host()': 'example.com'}), @@ -110,18 +120,84 @@ def url_validator_fixture(): ('https://£££.com', {'host': 'xn--9aaa.com', 'unicode_host()': '£££.com'}), ('https://£££.com.', {'host': 'xn--9aaa.com.', 'unicode_host()': '£££.com.'}), ('https://xn--9aaa.com/', {'host': 'xn--9aaa.com', 'unicode_host()': '£££.com'}), - ('https://münchen/', {'host': 'xn--mnchen-3ya', 'unicode_host()': 'münchen'}), + ( + 'https://münchen/', + {'host': 'xn--mnchen-3ya', 'unicode_host()': 'münchen', 'unicode_string()': 'https://münchen/'}, + ), ('http://à.א̈.com', {'host': 'xn--0ca.xn--ssa73l.com', 'unicode_host()': 'à.א̈.com'}), ('ftp://xn--0ca.xn--ssa73l.com', {'host': 'xn--0ca.xn--ssa73l.com', 'unicode_host()': 'à.א̈.com'}), ('https://foobar.£££.com/', {'host': 'foobar.xn--9aaa.com', 'unicode_host()': 'foobar.£££.com'}), ('https://£££.com', {'unicode_string()': 'https://£££.com/'}), ('https://xn--9aaa.com/', {'unicode_string()': 'https://£££.com/'}), - ('https://münchen/', {'unicode_string()': 'https://münchen/'}), ('wss://1.1.1.1', {'unicode_string()': 'wss://1.1.1.1/'}), ('file:///foobar', {'unicode_string()': 'file:///foobar'}), + ( + 'postgresql+py-postgresql://user:pass@localhost:5432/app', + { + 'str()': 'postgresql+py-postgresql://user:pass@localhost:5432/app', + 'username': 'user', + 'password': 'pass', + }, + ), + ('https://https/', {'host': 'https', 'unicode_host()': 'https'}), + ('http://user:@example.org', {'str()': 'http://user@example.org/', 'username': 'user', 'password': None}), + ( + 'http://us@er:p[ass@example.org', + {'str()': 'http://us%40er:p%5Bass@example.org/', 'username': 'us%40er', 'password': 'p%5Bass'}, + ), + ( + 'http://us%40er:p%5Bass@example.org', + {'str()': 'http://us%40er:p%5Bass@example.org/', 'username': 'us%40er', 'password': 'p%5Bass'}, + ), + ( + 'http://us[]er:p,ass@example.org', + {'str()': 'http://us%5B%5Der:p,ass@example.org/', 'username': 'us%5B%5Der', 'password': 'p,ass'}, + ), + ('http://%2F:@example.org', {'str()': 'http://%2F@example.org/', 'username': '%2F', 'password': None}), + ('foo://user:@example.org', {'str()': 'foo://user@example.org', 'username': 'user', 'password': None}), + ( + 'foo://us@er:p[ass@example.org', + {'str()': 'foo://us%40er:p%5Bass@example.org', 'username': 'us%40er', 'password': 'p%5Bass'}, + ), + ( + 'foo://us%40er:p%5Bass@example.org', + {'str()': 'foo://us%40er:p%5Bass@example.org', 'username': 'us%40er', 'password': 'p%5Bass'}, + ), + ( + 'foo://us[]er:p,ass@example.org', + {'str()': 'foo://us%5B%5Der:p,ass@example.org', 'username': 'us%5B%5Der', 'password': 'p,ass'}, + ), + ('foo://%2F:@example.org', {'str()': 'foo://%2F@example.org', 'username': '%2F', 'password': None}), + ('HTTP://EXAMPLE.ORG', {'str()': 'http://example.org/'}), + ('HTTP://EXAMPLE.org', {'str()': 'http://example.org/'}), + ('POSTGRES://EXAMPLE.ORG', {'str()': 'postgres://EXAMPLE.ORG'}), + ('https://twitter.com/@handle', {'str()': 'https://twitter.com/@handle', 'path': '/@handle'}), + (' https://www.example.com \n', 'https://www.example.com/'), + # https://www.xudongz.com/blog/2017/idn-phishing/ accepted but converted + ('https://www.аррӏе.com/', 'https://www.xn--80ak6aa92e.com/'), + ('https://exampl£e.org', 'https://xn--example-gia.org/'), + ('https://example.珠宝', 'https://example.xn--pbt977c/'), + ('https://example.vermögensberatung', 'https://example.xn--vermgensberatung-pwb/'), + ('https://example.рф', 'https://example.xn--p1ai/'), + ('https://exampl£e.珠宝', 'https://xn--example-gia.xn--pbt977c/'), + ('ht💣tp://example.org', Err('relative URL without a base')), + ( + 'http://usßer:pasℝs@a💣b.com:123/c?d=e&d=f#g', + { + 'str()': 'http://us%C3%9Fer:pas%E2%84%9Ds@xn--ab-qt72a.com:123/c?d=e&d=f#g', + 'username': 'us%C3%9Fer', + 'password': 'pas%E2%84%9Ds', + 'host': 'xn--ab-qt72a.com', + 'port': 123, + 'path': '/c', + 'query': 'd=e&d=f', + 'query_params()': [('d', 'e'), ('d', 'f')], + 'fragment': 'g', + }, + ), ], ) -def test_url_error(url_validator, url, expected): +def test_url_cases(url_validator, url, expected): if isinstance(expected, Err): with pytest.raises(ValidationError) as exc_info: url_validator.validate_python(url) @@ -131,7 +207,7 @@ def test_url_error(url_validator, url, expected): assert error['ctx']['error'] == expected.message else: output_url = url_validator.validate_python(url) - assert isinstance(output_url, Url) + assert isinstance(output_url, (Url, MultiHostUrl)) if isinstance(expected, str): assert str(output_url) == expected else: @@ -147,6 +223,84 @@ def test_url_error(url_validator, url, expected): assert output_parts == expected +@pytest.mark.parametrize( + 'validator_kwargs,url,expected', + [ + ( + dict(default_port=1234, default_path='/baz'), + 'http://example.org', + {'str()': 'http://example.org:1234/baz', 'path': '/baz'}, + ), + (dict(default_port=1234, default_path='/baz'), 'http://example.org/', 'http://example.org:1234/baz'), + (dict(default_port=1234, default_path='/baz'), 'http://example.org/bang', 'http://example.org:1234/bang'), + (dict(default_port=1234, default_path='/baz'), 'http://example.org:1111', 'http://example.org:1111/baz'), + (dict(default_port=1234, default_path='/baz'), 'foobar://example.org', 'foobar://example.org:1234/baz'), + (dict(default_host='localhost'), 'redis:///foobar', 'redis://localhost/foobar'), + (dict(default_host='localhost'), 'redis://', 'redis://localhost'), + (dict(default_host='localhost', default_path='/baz'), 'redis://', 'redis://localhost/baz'), + (dict(default_host='localhost'), 'redis://xxx/foobar', 'redis://xxx/foobar'), + (dict(host_required=True), 'redis://', Err('empty host')), + ], +) +@pytest.mark.parametrize('validator_type', ['Url', 'MultiHostUrl']) +def test_url_defaults(validator_type, validator_kwargs, url, expected): + if validator_type == 'Url': + schema = core_schema.url_schema(**validator_kwargs) + else: + schema = core_schema.multi_host_url_schema(**validator_kwargs) + s = SchemaValidator(schema) + test_url_cases(s, url, expected) + + +@pytest.mark.parametrize( + 'validator_kwargs,url,expected', + [ + ( + dict(default_port=1234, default_path='/baz'), + 'http://example.org', + {'str()': 'http://example.org:1234/baz', 'host': 'example.org', 'port': 1234, 'path': '/baz'}, + ), + (dict(default_host='localhost'), 'redis://', {'str()': 'redis://localhost', 'host': 'localhost'}), + ], +) +def test_url_defaults_single_url(validator_kwargs, url, expected): + s = SchemaValidator(core_schema.url_schema(**validator_kwargs)) + test_url_cases(s, url, expected) + + +@pytest.mark.parametrize( + 'validator_kwargs,url,expected', + [ + ( + dict(default_port=1234, default_path='/baz'), + 'http://example.org', + { + 'str()': 'http://example.org:1234/baz', + 'hosts()': [{'host': 'example.org', 'password': None, 'port': 1234, 'username': None}], + 'path': '/baz', + }, + ), + ( + dict(default_host='localhost'), + 'redis://', + { + 'str()': 'redis://localhost', + 'hosts()': [{'host': 'localhost', 'password': None, 'port': None, 'username': None}], + }, + ), + ({}, 'redis://', {'str()': 'redis://', 'hosts()': []}), + ], +) +def test_url_defaults_multi_host_url(validator_kwargs, url, expected): + s = SchemaValidator(core_schema.multi_host_url_schema(**validator_kwargs)) + test_url_cases(s, url, expected) + + +def test_multi_host_default_host_no_comma(): + with pytest.raises(SchemaError, match='default_host cannot contain a comma, see pydantic-core#326'): + SchemaValidator(core_schema.multi_host_url_schema(default_host='foo,bar')) + + @pytest.fixture(scope='module', name='strict_url_validator') def strict_url_validator_fixture(): return SchemaValidator(core_schema.url_schema(), {'strict': True}) @@ -155,23 +309,29 @@ def strict_url_validator_fixture(): @pytest.mark.parametrize( 'url,expected', [ - ('http://example.com', {'str()': 'http://example.com/', 'host_type': 'domain', 'host': 'example.com'}), + ('http://example.com', {'str()': 'http://example.com/', 'host': 'example.com'}), ('http://exa\nmple.com', Err('tabs or newlines are ignored in URLs', 'url_syntax_violation')), ('xxx', Err('relative URL without a base', 'url_parsing')), ('http:/foo', Err('expected //', 'url_syntax_violation')), ('http:///foo', Err('expected //', 'url_syntax_violation')), ('http:////foo', Err('expected //', 'url_syntax_violation')), - ('http://exam_ple.com', {'str()': 'http://exam_ple.com/', 'host_type': 'domain'}), + ('http://exam_ple.com', {'str()': 'http://exam_ple.com/'}), ('https:/more', Err('expected //', 'url_syntax_violation')), ('https:more', Err('expected //', 'url_syntax_violation')), - ('file:///foobar', {'str()': 'file:///foobar', 'host_type': None, 'host': None, 'unicode_host()': None}), + ('file:///foobar', {'str()': 'file:///foobar', 'host': None, 'unicode_host()': None}), ('file://:80', Err('invalid domain character', 'url_parsing')), ('file:/xx', Err('expected // after file:', 'url_syntax_violation')), ('foobar://:80', Err('empty host', 'url_parsing')), ('mongodb+srv://server.example.com/', 'mongodb+srv://server.example.com/'), + ('http://user:@example.org', 'http://user@example.org/'), + ('http://us[er:@example.org', Err('non-URL code point', 'url_syntax_violation')), + ('http://us%5Ber:bar@example.org', 'http://us%5Ber:bar@example.org/'), + ('http://user:@example.org', 'http://user@example.org/'), + ('mongodb://us%5Ber:bar@example.org', 'mongodb://us%5Ber:bar@example.org'), + ('mongodb://us@er@example.org', Err('unencoded @ sign in username or password', 'url_syntax_violation')), ], ) -def test_url_error_strict(strict_url_validator, url, expected): +def test_url_error(strict_url_validator, url, expected): if isinstance(expected, Err): with pytest.raises(ValidationError) as exc_info: strict_url_validator.validate_python(url) @@ -197,14 +357,6 @@ def test_url_error_strict(strict_url_validator, url, expected): assert output_parts == expected -def test_host_required(): - v = SchemaValidator(core_schema.url_schema(host_required=True)) - url = v.validate_python('http://example.com') - assert url.host == 'example.com' - with pytest.raises(ValidationError, match=r'URL host required \[type=url_host_required,'): - v.validate_python('unix:/run/foo.socket') - - def test_no_host(url_validator): url = url_validator.validate_python('data:text/plain,Stuff') assert str(url) == 'data:text/plain,Stuff' @@ -246,11 +398,11 @@ def test_allowed_schemes_error(): # insert_assert(exc_info.value.errors()) assert exc_info.value.errors() == [ { - 'type': 'url_schema', + 'type': 'url_scheme', 'loc': (), - 'msg': "URL schema should be 'http' or 'https'", + 'msg': "URL scheme should be 'http' or 'https'", 'input': 'unix:/run/foo.socket', - 'ctx': {'expected_schemas': "'http' or 'https'"}, + 'ctx': {'expected_schemes': "'http' or 'https'"}, } ] @@ -262,11 +414,11 @@ def test_allowed_schemes_errors(): # insert_assert(exc_info.value.errors()) assert exc_info.value.errors() == [ { - 'type': 'url_schema', + 'type': 'url_scheme', 'loc': (), - 'msg': "URL schema should be 'a', 'b' or 'c'", + 'msg': "URL scheme should be 'a', 'b' or 'c'", 'input': 'unix:/run/foo.socket', - 'ctx': {'expected_schemas': "'a', 'b' or 'c'"}, + 'ctx': {'expected_schemes': "'a', 'b' or 'c'"}, } ] @@ -277,13 +429,31 @@ def test_url_query_repeat(url_validator): assert url.query_params() == [('a', '1'), ('a', '2')] -def test_url_to_url(url_validator): +def test_url_to_url(url_validator, multi_host_url_validator): url: Url = url_validator.validate_python('https://example.com') + assert isinstance(url, Url) assert str(url) == 'https://example.com/' + url2 = url_validator.validate_python(url) + assert isinstance(url2, Url) assert str(url2) == 'https://example.com/' assert url is not url2 + multi_url = multi_host_url_validator.validate_python('https://example.com') + assert isinstance(multi_url, MultiHostUrl) + + url3 = url_validator.validate_python(multi_url) + assert isinstance(url3, Url) + assert str(url3) == 'https://example.com/' + + multi_url2 = multi_host_url_validator.validate_python('foobar://x:y@foo,x:y@bar.com') + assert isinstance(multi_url2, MultiHostUrl) + + url4 = url_validator.validate_python(multi_url2) + assert isinstance(url4, Url) + assert str(url4) == 'foobar://x:y%40foo,x%3Ay@bar.com' + assert url4.host == 'bar.com' + def test_url_to_constraint(): v1 = SchemaValidator(core_schema.url_schema()) @@ -311,11 +481,11 @@ def test_url_to_constraint(): v3.validate_python(url) assert exc_info.value.errors() == [ { - 'type': 'url_schema', + 'type': 'url_scheme', 'loc': (), - 'msg': "URL schema should be 'https'", + 'msg': "URL scheme should be 'https'", 'input': IsInstance(Url) & HasRepr("Url('http://example.com/foobar/bar')"), - 'ctx': {'expected_schemas': "'https'"}, + 'ctx': {'expected_schemes': "'https'"}, } ] @@ -358,9 +528,560 @@ def test_username(url_validator, input_value, expected, username, password): assert url.password == password -def test_strict_not_strict(url_validator, strict_url_validator): +def test_strict_not_strict(url_validator, strict_url_validator, multi_host_url_validator): url = url_validator.validate_python('http:/example.com/foobar/bar') assert str(url) == 'http://example.com/foobar/bar' url2 = strict_url_validator.validate_python(url) assert str(url2) == 'http://example.com/foobar/bar' + + multi_url = multi_host_url_validator.validate_python('https://example.com') + assert isinstance(multi_url, MultiHostUrl) + + url3 = strict_url_validator.validate_python(multi_url) + assert isinstance(url3, Url) + assert str(url3) == 'https://example.com/' + + multi_url2 = multi_host_url_validator.validate_python('foobar://x:y@foo,x:y@bar.com') + assert isinstance(multi_url2, MultiHostUrl) + + with pytest.raises(ValidationError, match=r'unencoded @ sign in username or password \[type=url_syntax_violation'): + strict_url_validator.validate_python(multi_url2) + + +def test_multi_host_url_ok_single(py_and_json: PyAndJson): + v = py_and_json(core_schema.multi_host_url_schema()) + url: MultiHostUrl = v.validate_test('https://example.com/foo/bar?a=b') + assert isinstance(url, MultiHostUrl) + assert str(url) == 'https://example.com/foo/bar?a=b' + assert repr(url) == "Url('https://example.com/foo/bar?a=b')" + assert url.scheme == 'https' + assert url.path == '/foo/bar' + assert url.query == 'a=b' + assert url.query_params() == [('a', 'b')] + assert url.fragment is None + # insert_assert(url.hosts()) + assert url.hosts() == [{'username': None, 'password': None, 'host': 'example.com', 'port': 443}] + + url: MultiHostUrl = v.validate_test('postgres://foo:bar@example.com:1234') + assert isinstance(url, MultiHostUrl) + assert str(url) == 'postgres://foo:bar@example.com:1234' + assert url.scheme == 'postgres' + # insert_assert(url.hosts()) + assert url.hosts() == [{'username': 'foo', 'password': 'bar', 'host': 'example.com', 'port': 1234}] + + +def test_multi_host_url_ok_2(py_and_json: PyAndJson): + v = py_and_json(core_schema.multi_host_url_schema()) + url: MultiHostUrl = v.validate_test('https://foo.com,bar.com/path') + assert isinstance(url, MultiHostUrl) + assert str(url) == 'https://foo.com,bar.com/path' + assert url.scheme == 'https' + assert url.path == '/path' + # insert_assert(url.hosts()) + assert url.hosts() == [ + {'username': None, 'password': None, 'host': 'foo.com', 'port': 443}, + {'username': None, 'password': None, 'host': 'bar.com', 'port': 443}, + ] + + +@pytest.fixture(scope='module', name='multi_host_url_validator') +def multi_host_url_validator_fixture(): + return SchemaValidator(core_schema.multi_host_url_schema()) + + +@pytest.mark.parametrize( + 'url,expected', + [ + ( + 'http://example.com', + { + 'str()': 'http://example.com/', + 'hosts()': [{'host': 'example.com', 'password': None, 'port': 80, 'username': None}], + 'unicode_string()': 'http://example.com/', + }, + ), + ( + 'postgres://example.com', + { + 'str()': 'postgres://example.com', + 'scheme': 'postgres', + 'hosts()': [{'host': 'example.com', 'password': None, 'port': None, 'username': None}], + }, + ), + ( + 'mongodb://foo,bar,spam/xxx', + { + 'str()': 'mongodb://foo,bar,spam/xxx', + 'scheme': 'mongodb', + 'hosts()': [ + {'host': 'foo', 'password': None, 'port': None, 'username': None}, + {'host': 'bar', 'password': None, 'port': None, 'username': None}, + {'host': 'spam', 'password': None, 'port': None, 'username': None}, + ], + }, + ), + (' mongodb://foo,bar,spam/xxx ', 'mongodb://foo,bar,spam/xxx'), + (' \n\r\t mongodb://foo,bar,spam/xxx', 'mongodb://foo,bar,spam/xxx'), + ( + 'mongodb+srv://foo,bar,spam/xxx', + { + 'str()': 'mongodb+srv://foo,bar,spam/xxx', + 'scheme': 'mongodb+srv', + 'hosts()': [ + {'host': 'foo', 'password': None, 'port': None, 'username': None}, + {'host': 'bar', 'password': None, 'port': None, 'username': None}, + {'host': 'spam', 'password': None, 'port': None, 'username': None}, + ], + }, + ), + ( + 'https://foo:bar@example.com,fo%20o:bar@example.com', + { + 'str()': 'https://foo:bar@example.com,fo%20o:bar@example.com/', + 'scheme': 'https', + 'hosts()': [ + {'host': 'example.com', 'password': 'bar', 'port': 443, 'username': 'foo'}, + {'host': 'example.com', 'password': 'bar', 'port': 443, 'username': 'fo%20o'}, + ], + }, + ), + ( + 'postgres://foo:bar@example.com,fo%20o:bar@example.com', + { + 'str()': 'postgres://foo:bar@example.com,fo%20o:bar@example.com', + 'scheme': 'postgres', + 'hosts()': [ + {'host': 'example.com', 'password': 'bar', 'port': None, 'username': 'foo'}, + {'host': 'example.com', 'password': 'bar', 'port': None, 'username': 'fo%20o'}, + ], + }, + ), + ('postgres://', {'str()': 'postgres://', 'scheme': 'postgres', 'hosts()': []}), + ('postgres://,', Err('empty host')), + ('postgres://,,', Err('empty host')), + ('postgres://foo,\n,bar', Err('empty host')), + ('postgres://\n,bar', Err('empty host')), + ('postgres://foo,\n', Err('empty host')), + ('postgres://foo,', Err('empty host')), + ('postgres://,foo', Err('empty host')), + ('http://', Err('empty host')), + ('http://,', Err('empty host')), + ('http://,,', Err('empty host')), + ('http://foo,\n,bar', Err('empty host')), + ('http://\n,bar', Err('empty host')), + ('http://foo,\n', Err('empty host')), + ('http://foo,', Err('empty host')), + ('http://,foo', Err('empty host')), + ('http@foobar', Err('relative URL without a base')), + ( + 'mongodb://foo\n,b\nar,\nspam/xxx', + { + 'str()': 'mongodb://foo,bar,spam/xxx', + 'scheme': 'mongodb', + 'hosts()': [ + {'host': 'foo', 'password': None, 'port': None, 'username': None}, + {'host': 'bar', 'password': None, 'port': None, 'username': None}, + {'host': 'spam', 'password': None, 'port': None, 'username': None}, + ], + }, + ), + ( + 'postgres://user:pass@host1.db.net:4321,host2.db.net:6432/app', + { + 'str()': 'postgres://user:pass@host1.db.net:4321,host2.db.net:6432/app', + 'scheme': 'postgres', + 'hosts()': [ + {'host': 'host1.db.net', 'password': 'pass', 'port': 4321, 'username': 'user'}, + {'host': 'host2.db.net', 'password': None, 'port': 6432, 'username': None}, + ], + 'path': '/app', + }, + ), + ( + 'postgresql+py-postgresql://user:pass@localhost:5432/app', + { + 'str()': 'postgresql+py-postgresql://user:pass@localhost:5432/app', + 'hosts()': [{'host': 'localhost', 'password': 'pass', 'port': 5432, 'username': 'user'}], + }, + ), + ('http://foo#bar', 'http://foo/#bar'), + ('mongodb://foo#bar', 'mongodb://foo#bar'), + ('http://foo,bar#spam', 'http://foo,bar/#spam'), + ('mongodb://foo,bar#spam', 'mongodb://foo,bar#spam'), + ('http://foo,bar?x=y', 'http://foo,bar/?x=y'), + ('mongodb://foo,bar?x=y', 'mongodb://foo,bar?x=y'), + ('foo://foo,bar?x=y', 'foo://foo,bar?x=y'), + ( + ( + 'mongodb://mongodb1.example.com:27317,mongodb2.example.com:27017/' + 'mydatabase?replicaSet=mySet&authSource=authDB' + ), + { + 'str()': ( + 'mongodb://mongodb1.example.com:27317,mongodb2.example.com:27017/' + 'mydatabase?replicaSet=mySet&authSource=authDB' + ), + 'hosts()': [ + {'host': 'mongodb1.example.com', 'password': None, 'port': 27317, 'username': None}, + {'host': 'mongodb2.example.com', 'password': None, 'port': 27017, 'username': None}, + ], + 'query_params()': [('replicaSet', 'mySet'), ('authSource', 'authDB')], + }, + ), + # with bashslashes + ( + 'FILE:\\\\foo,bar\\more', + { + 'str()': 'file://foo,bar/more', + 'path': '/more', + 'hosts()': [ + {'host': 'foo', 'password': None, 'port': None, 'username': None}, + {'host': 'bar', 'password': None, 'port': None, 'username': None}, + ], + }, + ), + ( + 'http:\\\\foo,bar\\more', + { + 'str()': 'http://foo,bar/more', + 'path': '/more', + 'hosts()': [ + {'host': 'foo', 'password': None, 'port': 80, 'username': None}, + {'host': 'bar', 'password': None, 'port': 80, 'username': None}, + ], + }, + ), + ('mongo:\\\\foo,bar\\more', Err('empty host')), + ( + 'foobar://foo[]bar:x@y@whatever,foo[]bar:x@y@whichever', + { + 'str()': 'foobar://foo%5B%5Dbar:x%40y@whatever,foo%5B%5Dbar:x%40y@whichever', + 'hosts()': [ + {'host': 'whatever', 'password': 'x%40y', 'port': None, 'username': 'foo%5B%5Dbar'}, + {'host': 'whichever', 'password': 'x%40y', 'port': None, 'username': 'foo%5B%5Dbar'}, + ], + }, + ), + ( + 'foobar://foo%2Cbar:x@y@whatever,snap', + { + 'str()': 'foobar://foo%2Cbar:x%40y@whatever,snap', + 'hosts()': [ + {'host': 'whatever', 'password': 'x%40y', 'port': None, 'username': 'foo%2Cbar'}, + {'host': 'snap', 'password': None, 'port': None, 'username': None}, + ], + }, + ), + ( + 'mongodb://x:y@[::1],1.1.1.1:888/xxx', + { + 'str()': 'mongodb://x:y@[::1],1.1.1.1:888/xxx', + 'scheme': 'mongodb', + 'hosts()': [ + {'host': '[::1]', 'password': 'y', 'port': None, 'username': 'x'}, + {'host': '1.1.1.1', 'password': None, 'port': 888, 'username': None}, + ], + }, + ), + ( + 'http://foo.co.uk,bar.spam.things.com', + { + 'str()': 'http://foo.co.uk,bar.spam.things.com/', + 'hosts()': [ + {'host': 'foo.co.uk', 'password': None, 'port': 80, 'username': None}, + {'host': 'bar.spam.things.com', 'password': None, 'port': 80, 'username': None}, + ], + }, + ), + ('ht💣tp://example.com', Err('relative URL without a base')), + # punycode ß + ( + 'http://£££.com', + { + 'str()': 'http://xn--9aaa.com/', + 'hosts()': [{'host': 'xn--9aaa.com', 'password': None, 'port': 80, 'username': None}], + 'unicode_string()': 'http://£££.com/', + }, + ), + ( + 'http://£££.co.uk,münchen.com/foo?bar=baz#qux', + { + 'str()': 'http://xn--9aaa.co.uk,xn--mnchen-3ya.com/foo?bar=baz#qux', + 'hosts()': [ + {'host': 'xn--9aaa.co.uk', 'password': None, 'port': 80, 'username': None}, + {'host': 'xn--mnchen-3ya.com', 'password': None, 'port': 80, 'username': None}, + ], + 'unicode_string()': 'http://£££.co.uk,münchen.com/foo?bar=baz#qux', + }, + ), + ( + 'postgres://£££.co.uk,münchen.com/foo?bar=baz#qux', + { + 'str()': 'postgres://%C2%A3%C2%A3%C2%A3.co.uk,m%C3%BCnchen.com/foo?bar=baz#qux', + 'hosts()': [ + {'host': '%C2%A3%C2%A3%C2%A3.co.uk', 'password': None, 'port': None, 'username': None}, + {'host': 'm%C3%BCnchen.com', 'password': None, 'port': None, 'username': None}, + ], + 'unicode_string()': 'postgres://%C2%A3%C2%A3%C2%A3.co.uk,m%C3%BCnchen.com/foo?bar=baz#qux', + }, + ), + ], +) +def test_multi_url_cases(multi_host_url_validator, url, expected): + if isinstance(expected, Err): + with pytest.raises(ValidationError) as exc_info: + multi_host_url_validator.validate_python(url) + assert exc_info.value.error_count() == 1 + error = exc_info.value.errors()[0] + assert error['type'] == 'url_parsing' + assert error['ctx']['error'] == expected.message + else: + output_url = multi_host_url_validator.validate_python(url) + assert isinstance(output_url, MultiHostUrl) + if isinstance(expected, str): + assert str(output_url) == expected + else: + assert isinstance(expected, dict) + output_parts = {} + for key in expected: + if key == 'str()': + output_parts[key] = str(output_url) + elif key.endswith('()'): + output_parts[key] = getattr(output_url, key[:-2])() + else: + output_parts[key] = getattr(output_url, key) + # debug(output_parts) + assert output_parts == expected + + +@pytest.fixture(scope='module', name='strict_multi_host_url_validator') +def strict_multi_host_url_validator_fixture(): + return SchemaValidator(core_schema.multi_host_url_schema(strict=True)) + + +@pytest.mark.parametrize( + 'url,expected', + [ + ('http://example.com', 'http://example.com/'), + ( + ' mongodb://foo,bar,spam/xxx ', + Err('leading or trailing control or space character are ignored in URLs', 'url_syntax_violation'), + ), + ( + ' \n\r\t mongodb://foo,bar,spam/xxx', + Err('leading or trailing control or space character are ignored in URLs', 'url_syntax_violation'), + ), + # with bashslashes + ('file:\\\\foo,bar\\more', Err('backslash', 'url_syntax_violation')), + ('http:\\\\foo,bar\\more', Err('backslash', 'url_syntax_violation')), + ('mongo:\\\\foo,bar\\more', Err('non-URL code point', 'url_syntax_violation')), + ('foobar://foo[]bar:x@y@whatever,foo[]bar:x@y@whichever', Err('non-URL code point', 'url_syntax_violation')), + ( + 'foobar://foo%2Cbar:x@y@whatever,snap', + Err('unencoded @ sign in username or password', 'url_syntax_violation'), + ), + ('foobar://foo%2Cbar:x%40y@whatever,snap', 'foobar://foo%2Cbar:x%40y@whatever,snap'), + ], +) +def test_multi_url_cases_strict(strict_multi_host_url_validator, url, expected): + if isinstance(expected, Err): + with pytest.raises(ValidationError) as exc_info: + strict_multi_host_url_validator.validate_python(url) + assert exc_info.value.error_count() == 1 + error = exc_info.value.errors()[0] + assert error['type'] == expected.errors + assert error['ctx']['error'] == expected.message + else: + output_url = strict_multi_host_url_validator.validate_python(url) + assert isinstance(output_url, MultiHostUrl) + if isinstance(expected, str): + assert str(output_url) == expected + else: + assert isinstance(expected, dict) + output_parts = {} + for key in expected: + if key == 'str()': + output_parts[key] = str(output_url) + elif key.endswith('()'): + output_parts[key] = getattr(output_url, key[:-2])() + else: + output_parts[key] = getattr(output_url, key) + assert output_parts == expected + + +def test_url_to_multi_url(url_validator, multi_host_url_validator): + url: Url = url_validator.validate_python('https://example.com') + assert isinstance(url, Url) + assert str(url) == 'https://example.com/' + + url2 = multi_host_url_validator.validate_python(url) + assert isinstance(url2, MultiHostUrl) + assert str(url2) == 'https://example.com/' + assert url is not url2 + + url3 = multi_host_url_validator.validate_python(url2) + assert isinstance(url3, MultiHostUrl) + assert str(url3) == 'https://example.com/' + assert url2 is not url3 + + +def test_multi_wrong_type(multi_host_url_validator): + assert str(multi_host_url_validator.validate_python('http://example.com')) == 'http://example.com/' + with pytest.raises(ValidationError, match=r'URL input should be a string or URL \[type=url_type,'): + multi_host_url_validator.validate_python(42) + + +def test_multi_allowed_schemas(): + v = SchemaValidator(core_schema.multi_host_url_schema(allowed_schemes=['http', 'foo'])) + assert str(v.validate_python('http://example.com')) == 'http://example.com/' + assert str(v.validate_python('foo://example.com')) == 'foo://example.com' + with pytest.raises(ValidationError, match=r"URL scheme should be 'http' or 'foo' \[type=url_scheme,"): + v.validate_python('https://example.com') + + +def test_multi_max_length(url_validator): + v = SchemaValidator(core_schema.multi_host_url_schema(max_length=25)) + assert str(v.validate_python('http://example.com')) == 'http://example.com/' + with pytest.raises(ValidationError, match=r'URL should have at most 25 characters \[type=url_too_long,'): + v.validate_python('https://example.com/this-is-too-long') + + url = v.validate_python('http://example.com') + assert str(v.validate_python(url)) == 'http://example.com/' + + simple_url = url_validator.validate_python('http://example.com') + assert isinstance(simple_url, Url) + assert str(v.validate_python(simple_url)) == 'http://example.com/' + + long_simple_url = url_validator.validate_python('http://example.com/this-is-too-long') + with pytest.raises(ValidationError, match=r'URL should have at most 25 characters \[type=url_too_long,'): + v.validate_python(long_simple_url) + + +def test_zero_schemas(): + with pytest.raises(SchemaError, match='"allowed_schemes" should have length > 0'): + SchemaValidator(core_schema.multi_host_url_schema(allowed_schemes=[])) + + +@pytest.mark.parametrize( + 'url,expected', + [ + # urlparse doesn't follow RFC 3986 Section 3.2 + ( + 'http://google.com#@evil.com/', + dict( + scheme='http', + host='google.com', + # path='', CHANGED + path='/', + fragment='@evil.com/', + ), + ), + # CVE-2016-5699 + ( + 'http://127.0.0.1%0d%0aConnection%3a%20keep-alive', + # dict(scheme='http', host='127.0.0.1%0d%0aconnection%3a%20keep-alive'), CHANGED + Err('Input should be a valid URL, invalid domain character [type=url_parsing,'), + ), + # NodeJS unicode -> double dot + ('http://google.com/\uff2e\uff2e/abc', dict(scheme='http', host='google.com', path='/%EF%BC%AE%EF%BC%AE/abc')), + # Scheme without :// + ( + "javascript:a='@google.com:12345/';alert(0)", + dict(scheme='javascript', path="a='@google.com:12345/';alert(0)"), + ), + ( + '//google.com/a/b/c', + # dict(host='google.com', path='/a/b/c'), + Err('Input should be a valid URL, relative URL without a base [type=url_parsing,'), + ), + # International URLs + ( + 'http://ヒ:キ@ヒ.abc.ニ/ヒ?キ#ワ', + dict( + scheme='http', + host='xn--pdk.abc.xn--idk', + auth='%E3%83%92:%E3%82%AD', + path='/%E3%83%92', + query='%E3%82%AD', + fragment='%E3%83%AF', + ), + ), + # Injected headers (CVE-2016-5699, CVE-2019-9740, CVE-2019-9947) + ( + '10.251.0.83:7777?a=1 HTTP/1.1\r\nX-injected: header', + # dict( CHANGED + # host='10.251.0.83', + # port=7777, + # path='', + # query='a=1%20HTTP/1.1%0D%0AX-injected:%20header', + # ), + Err('Input should be a valid URL, relative URL without a base [type=url_parsing,'), + ), + # ADDED, similar to the above with scheme added + ( + 'http://10.251.0.83:7777?a=1 HTTP/1.1\r\nX-injected: header', + dict( + host='10.251.0.83', + port=7777, + path='/', + # query='a=1%20HTTP/1.1%0D%0AX-injected:%20header', CHANGED + query='a=1%20HTTP/1.1X-injected:%20header', + ), + ), + ( + 'http://127.0.0.1:6379?\r\nSET test failure12\r\n:8080/test/?test=a', + dict( + scheme='http', + host='127.0.0.1', + port=6379, + # path='', + path='/', + # query='%0D%0ASET%20test%20failure12%0D%0A:8080/test/?test=a', CHANGED + query='SET%20test%20failure12:8080/test/?test=a', + ), + ), + # See https://bugs.xdavidhu.me/google/2020/03/08/the-unexpected-google-wide-domain-check-bypass/ + ( + 'https://user:pass@xdavidhu.me\\test.corp.google.com:8080/path/to/something?param=value#hash', + dict( + scheme='https', + auth='user:pass', + host='xdavidhu.me', + # path='/%5Ctest.corp.google.com:8080/path/to/something', CHANGED + path='/test.corp.google.com:8080/path/to/something', + query='param=value', + fragment='hash', + ), + ), + # # Tons of '@' causing backtracking + ( + 'https://' + ('@' * 10000) + '[', + # False, CHANGED + Err('Input should be a valid URL, invalid IPv6 address [type=url_parsing,'), + ), + ( + 'https://user:' + ('@' * 10000) + 'example.com', + dict(scheme='https', auth='user:' + ('%40' * 9999), host='example.com'), + ), + ], +) +def test_url_vulnerabilities(url_validator, url, expected): + """ + Test cases from + https://github.com/urllib3/urllib3/blob/7ef7444fd0fc22a825be6624af85343cefa36fef/test/test_util.py#L422 + """ + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + url_validator.validate_python(url) + else: + output_url = url_validator.validate_python(url) + assert isinstance(output_url, Url) + output_parts = {} + for key in expected: + # one tweak required to match urllib3 logic + if key == 'auth': + output_parts[key] = f'{output_url.username}:{output_url.password}' + else: + output_parts[key] = getattr(output_url, key) + assert output_parts == expected