Skip to content

Commit

Permalink
Implement ability to set entity resolver for serde Deserializer
Browse files Browse the repository at this point in the history
Co-authored-by: Mingun <alexander_sergey@mail.ru>
  • Loading branch information
pigeonhands and Mingun committed Apr 8, 2023
1 parent b2ceae7 commit 9b5e0e9
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 39 deletions.
4 changes: 4 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

### New Features

- [#581]: Allow `Deserializer` to set `quick_xml::de::EntityResolver` for
resolving unknown entities that would otherwise cause the parser to return
an [`EscapeError::UnrecognizedSymbol`] error.

### Bug Fixes

### Misc Changes
Expand Down
47 changes: 30 additions & 17 deletions src/de/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::{
de::key::QNameDeserializer,
de::resolver::EntityResolver,
de::simple_type::SimpleTypeDeserializer,
de::{str2bool, DeEvent, Deserializer, XmlRead, TEXT_KEY, VALUE_KEY},
encoding::Decoder,
Expand Down Expand Up @@ -165,13 +166,14 @@ enum ValueSource {
///
/// - `'a` lifetime represents a parent deserializer, which could own the data
/// buffer.
pub(crate) struct MapAccess<'de, 'a, R>
pub(crate) struct MapAccess<'de, 'a, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Tag -- owner of attributes
start: BytesStart<'de>,
de: &'a mut Deserializer<'de, R>,
de: &'a mut Deserializer<'de, R, E>,
/// State of the iterator over attributes. Contains the next position in the
/// inner `start` slice, from which next attribute should be parsed.
iter: IterState,
Expand All @@ -190,13 +192,14 @@ where
has_value_field: bool,
}

impl<'de, 'a, R> MapAccess<'de, 'a, R>
impl<'de, 'a, R, E> MapAccess<'de, 'a, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Create a new MapAccess
pub fn new(
de: &'a mut Deserializer<'de, R>,
de: &'a mut Deserializer<'de, R, E>,
start: BytesStart<'de>,
fields: &'static [&'static str],
) -> Result<Self, DeError> {
Expand All @@ -211,9 +214,10 @@ where
}
}

impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'de, 'a, R>
impl<'de, 'a, R, E> de::MapAccess<'de> for MapAccess<'de, 'a, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;

Expand Down Expand Up @@ -369,13 +373,14 @@ macro_rules! forward {
/// A deserializer for a value of a map or struct. This deserializer processes
/// events for primitive types and sequences slightly differently than
/// a [`Deserializer`] does.
struct MapValueDeserializer<'de, 'a, 'm, R>
struct MapValueDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Access to the map that created this deserializer. Gives access to the
/// context, such as the list of fields, that the current map knows about.
map: &'m mut MapAccess<'de, 'a, R>,
map: &'m mut MapAccess<'de, 'a, R, E>,
/// Determines whether [`Deserializer::read_string_impl()`] should expand the
/// second level of tags or not.
///
Expand Down Expand Up @@ -453,9 +458,10 @@ where
allow_start: bool,
}

impl<'de, 'a, 'm, R> MapValueDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> MapValueDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Returns the next string as the concatenated content of consecutive [`Text`]
/// and [`CData`] events, used inside [`deserialize_primitives!()`].
Expand All @@ -468,9 +474,10 @@ where
}
}

impl<'de, 'a, 'm, R> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;

Expand Down Expand Up @@ -629,13 +636,14 @@ impl<'de> TagFilter<'de> {
///
/// [`Text`]: crate::events::Event::Text
/// [`CData`]: crate::events::Event::CData
struct MapValueSeqAccess<'de, 'a, 'm, R>
struct MapValueSeqAccess<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Accessor to the map that created this accessor and to a deserializer for
/// the sequence items.
map: &'m mut MapAccess<'de, 'a, R>,
map: &'m mut MapAccess<'de, 'a, R, E>,
/// Filter that determines whether a tag is a part of this sequence.
///
/// When feature `overlapped-lists` is not activated, iteration will stop
Expand All @@ -653,18 +661,20 @@ where
}

#[cfg(feature = "overlapped-lists")]
impl<'de, 'a, 'm, R> Drop for MapValueSeqAccess<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> Drop for MapValueSeqAccess<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
fn drop(&mut self) {
self.map.de.start_replay(self.checkpoint);
}
}

impl<'de, 'a, 'm, R> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;

Expand Down Expand Up @@ -705,18 +715,20 @@ where
////////////////////////////////////////////////////////////////////////////////////////////////////

/// A deserializer for a single item of a sequence.
struct SeqItemDeserializer<'de, 'a, 'm, R>
struct SeqItemDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Access to the map that created this deserializer. Gives access to the
/// context, such as the list of fields, that the current map knows about.
map: &'m mut MapAccess<'de, 'a, R>,
map: &'m mut MapAccess<'de, 'a, R, E>,
}

impl<'de, 'a, 'm, R> SeqItemDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> SeqItemDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Returns the next string as the concatenated content of consecutive [`Text`]
/// and [`CData`] events, used inside [`deserialize_primitives!()`].
Expand All @@ -729,9 +741,10 @@ where
}
}

impl<'de, 'a, 'm, R> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R>
impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;

Expand Down
90 changes: 78 additions & 12 deletions src/de/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1833,10 +1833,13 @@ macro_rules! deserialize_option {

mod key;
mod map;
mod resolver;
mod simple_type;
mod var;

pub use crate::errors::serialize::DeError;
pub use resolver::{EntityResolver, NoEntityResolver};

use crate::{
encoding::Decoder,
errors::Error,
Expand Down Expand Up @@ -1935,6 +1938,8 @@ pub enum PayloadEvent<'a> {
Text(BytesText<'a>),
/// Unescaped character data stored in `<![CDATA[...]]>`.
CData(BytesCData<'a>),
/// Document type definition data (DTD) stored in `<!DOCTYPE ...>`.
DocType(BytesText<'a>),
/// End of XML document.
Eof,
}
Expand All @@ -1948,6 +1953,7 @@ impl<'a> PayloadEvent<'a> {
PayloadEvent::End(e) => PayloadEvent::End(e.into_owned()),
PayloadEvent::Text(e) => PayloadEvent::Text(e.into_owned()),
PayloadEvent::CData(e) => PayloadEvent::CData(e.into_owned()),
PayloadEvent::DocType(e) => PayloadEvent::DocType(e.into_owned()),
PayloadEvent::Eof => PayloadEvent::Eof,
}
}
Expand All @@ -1956,23 +1962,40 @@ impl<'a> PayloadEvent<'a> {
/// An intermediate reader that consumes [`PayloadEvent`]s and produces final [`DeEvent`]s.
/// [`PayloadEvent::Text`] events that are followed by any event except
/// [`PayloadEvent::Text`] or [`PayloadEvent::CData`] are trimmed from the end.
struct XmlReader<'i, R: XmlRead<'i>> {
struct XmlReader<'i, R: XmlRead<'i>, E: EntityResolver = NoEntityResolver> {
/// A source of low-level XML events
reader: R,
/// Intermediate event that could be returned by the next call to `next()`.
/// If it is a `Text` event, then leading spaces have already been trimmed, but
/// trailing spaces have not. Before the event is returned, trimming of
/// the trailing spaces may be necessary.
lookahead: Result<PayloadEvent<'i>, DeError>,

/// Used to resolve unknown entities that would otherwise cause the parser
/// to return an [`EscapeError::UnrecognizedSymbol`] error.
///
/// [`EscapeError::UnrecognizedSymbol`]: crate::escape::EscapeError::UnrecognizedSymbol
entity_resolver: E,
}

impl<'i, R: XmlRead<'i>> XmlReader<'i, R> {
fn new(mut reader: R) -> Self {
impl<'i, R: XmlRead<'i>, E: EntityResolver> XmlReader<'i, R, E> {
fn new(reader: R) -> Self
where
E: Default,
{
Self::with_resolver(reader, E::default())
}

fn with_resolver(mut reader: R, entity_resolver: E) -> Self {
// Lookahead by one event immediately, so we do not need to check in the
// loop if we need lookahead or not
let lookahead = reader.next();

Self { reader, lookahead }
Self {
reader,
lookahead,
entity_resolver,
}
}

/// Read next event and put it in lookahead, return the current lookahead
Expand Down Expand Up @@ -2028,7 +2051,7 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> {
if self.need_trim_end() {
e.inplace_trim_end();
}
Ok(e.unescape()?)
Ok(e.unescape_with(|entity| self.entity_resolver.resolve(entity))?)
}
PayloadEvent::CData(e) => Ok(e.decode()?),

Expand All @@ -2047,9 +2070,15 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> {
if self.need_trim_end() && e.inplace_trim_end() {
continue;
}
self.drain_text(e.unescape()?)
self.drain_text(e.unescape_with(|entity| self.entity_resolver.resolve(entity))?)
}
PayloadEvent::CData(e) => self.drain_text(e.decode()?),
PayloadEvent::DocType(e) => {
self.entity_resolver
.capture(e)
.map_err(|err| DeError::Custom(format!("cannot parse DTD: {}", err)))?;
continue;
}
PayloadEvent::Eof => Ok(DeEvent::Eof),
};
}
Expand Down Expand Up @@ -2166,12 +2195,12 @@ where
////////////////////////////////////////////////////////////////////////////////////////////////////

/// A structure that deserializes XML into Rust values.
pub struct Deserializer<'de, R>
pub struct Deserializer<'de, R, E: EntityResolver = NoEntityResolver>
where
R: XmlRead<'de>,
{
/// An XML reader that streams events into this deserializer
reader: XmlReader<'de, R>,
reader: XmlReader<'de, R, E>,

/// When deserializing sequences, sometimes we have to skip unwanted events.
/// Those events should be stored and then replayed. This is a replay buffer,
Expand Down Expand Up @@ -2226,7 +2255,13 @@ where
peek: None,
}
}
}

impl<'de, R, E> Deserializer<'de, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
/// Set the maximum number of events that could be skipped during deserialization
/// of sequences.
///
Expand Down Expand Up @@ -2556,20 +2591,49 @@ where
/// instead, because it will borrow instead of copy. If you have `&[u8]` which
/// is known to represent UTF-8, you can decode it first before using [`from_str`].
pub fn from_reader(reader: R) -> Self {
Self::with_resolver(reader, NoEntityResolver)
}
}

impl<'de, R, E> Deserializer<'de, IoReader<R>, E>
where
R: BufRead,
E: EntityResolver,
{
/// Create a new deserializer that will copy data from the specified reader
/// into an internal buffer and will use the specified entity resolver.
///
/// If you already have a string, use [`Self::from_str`] instead, because it
/// will borrow instead of copy. If you have `&[u8]` which is known to
/// represent UTF-8, you can decode it first before using [`from_str`].
pub fn with_resolver(reader: R, entity_resolver: E) -> Self {
let mut reader = Reader::from_reader(reader);
reader.expand_empty_elements(true).check_end_names(true);

Self::new(IoReader {
let io_reader = IoReader {
reader,
start_trimmer: StartTrimmer::default(),
buf: Vec::new(),
})
};

Self {
reader: XmlReader::with_resolver(io_reader, entity_resolver),

#[cfg(feature = "overlapped-lists")]
read: VecDeque::new(),
#[cfg(feature = "overlapped-lists")]
write: VecDeque::new(),
#[cfg(feature = "overlapped-lists")]
limit: None,

#[cfg(not(feature = "overlapped-lists"))]
peek: None,
}
}
}

impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<'de, R>
impl<'de, 'a, R, E> de::Deserializer<'de> for &'a mut Deserializer<'de, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;

Expand Down Expand Up @@ -2705,9 +2769,10 @@ where
///
/// Technically, multiple top-level elements violate the XML rule of only one
/// top-level element, but we consider this as several concatenated XML documents.
impl<'de, 'a, R> SeqAccess<'de> for &'a mut Deserializer<'de, R>
impl<'de, 'a, R, E> SeqAccess<'de> for &'a mut Deserializer<'de, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;

Expand Down Expand Up @@ -2743,6 +2808,7 @@ impl StartTrimmer {
#[inline(always)]
fn trim<'a>(&mut self, event: Event<'a>) -> Option<PayloadEvent<'a>> {
let (event, trim_next_event) = match event {
Event::DocType(e) => (PayloadEvent::DocType(e), false),
Event::Start(e) => (PayloadEvent::Start(e), true),
Event::End(e) => (PayloadEvent::End(e), true),
Event::Eof => (PayloadEvent::Eof, true),
Expand Down

0 comments on commit 9b5e0e9

Please sign in to comment.