Skip to content
This repository has been archived by the owner on Aug 15, 2021. It is now read-only.

Allow integer map keys to accept string-encoded integers #207

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,160 @@ where
}
}

// A wrapper around `visitor` to convert strings to integers. An version of this
// struct is created for each integer type `$numtype` (e.g. u16). Any call to
// `visit_str()` for this visitor will attempt to parse it to `$numtype` and pass
// call `visit_$numtype()` on the wrapped visitor. Calls to `visit_$numtype()` are
// passed through directly. All other visits result in an error.
macro_rules! int_or_string_visitor {
($visitor_name:ident, $visit:ident, $numtype:ty) => {
struct $visitor_name<'de, V: de::Visitor<'de>> {
inner_visitor: V,
_phantom_data: std::marker::PhantomData<&'de ()>,
}

impl<'de, V: de::Visitor<'de>> de::Visitor<'_> for $visitor_name<'de, V> {
type Value = V::Value;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "an integer or string-encoded integer key")
}

fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
let num = v.parse().map_err(|e| E::custom(e))?;
self.inner_visitor.$visit(num)
}

fn $visit<E>(self, v: $numtype) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
self.inner_visitor.$visit(v)
}
}
};
}

int_or_string_visitor!(IntOrStringVisitorI8, visit_i8, i8);
int_or_string_visitor!(IntOrStringVisitorI16, visit_i16, i16);
int_or_string_visitor!(IntOrStringVisitorI32, visit_i32, i32);
int_or_string_visitor!(IntOrStringVisitorI64, visit_i64, i64);
int_or_string_visitor!(IntOrStringVisitorU8, visit_u8, u8);
int_or_string_visitor!(IntOrStringVisitorU16, visit_u16, u16);
int_or_string_visitor!(IntOrStringVisitorU32, visit_u32, u32);
int_or_string_visitor!(IntOrStringVisitorU64, visit_u64, u64);

serde::serde_if_integer128! {
int_or_string_visitor!(IntOrStringVisitorI128, visit_i128, i128);
int_or_string_visitor!(IntOrStringVisitorU128, visit_u128, u128);
}

// A type that can deserialize strings of integers (e.g. "4") to integers.
// This is necessary for compatibility with serde_json.
struct MapKey<'a, R> {
de: &'a mut Deserializer<R>,
}

macro_rules! deserialize_integer_key {
($visitor_name:ident, $method:ident) => {
fn $method<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
// Deserialize the next value, which should be either an integer
// of the type corresponding to $visitor_name, (e.g. u16), or it
// should be a string that can be parsed to that type.
self.de.deserialize_any($visitor_name {
inner_visitor: visitor,
_phantom_data: Default::default(),
})
}
};
}

impl<'de, 'a, R> de::Deserializer<'de> for MapKey<'a, R>
where
R: Read<'de>,
{
type Error = Error;

#[inline]
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.deserialize_any(visitor)
}

deserialize_integer_key!(IntOrStringVisitorI8, deserialize_i8);
deserialize_integer_key!(IntOrStringVisitorI16, deserialize_i16);
deserialize_integer_key!(IntOrStringVisitorI32, deserialize_i32);
deserialize_integer_key!(IntOrStringVisitorI64, deserialize_i64);
deserialize_integer_key!(IntOrStringVisitorU8, deserialize_u8);
deserialize_integer_key!(IntOrStringVisitorU16, deserialize_u16);
deserialize_integer_key!(IntOrStringVisitorU32, deserialize_u32);
deserialize_integer_key!(IntOrStringVisitorU64, deserialize_u64);

serde::serde_if_integer128! {
deserialize_integer_key!(IntOrStringVisitorI128, deserialize_i128);
deserialize_integer_key!(IntOrStringVisitorU128, deserialize_u128);
}

#[inline]
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
// Map keys cannot be null.
visitor.visit_some(self)
}

#[inline]
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}

#[inline]
fn deserialize_enum<V>(
self,
name: &'static str,
variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.deserialize_enum(name, variants, visitor)
}

#[inline]
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.deserialize_bytes(visitor)
}

#[inline]
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
self.de.deserialize_bytes(visitor)
}

serde::forward_to_deserialize_any! {
bool f32 f64 char str string unit unit_struct seq tuple tuple_struct map
struct identifier ignored_any
}
}

struct MapAccess<'a, R> {
de: &'a mut Deserializer<R>,
len: &'a mut usize,
Expand Down Expand Up @@ -1004,7 +1158,8 @@ where
_ => {}
};

let value = seed.deserialize(&mut *self.de)?;
let value = seed.deserialize(MapKey { de: &mut *self.de })?;

Ok(Some(value))
}

Expand Down
73 changes: 73 additions & 0 deletions tests/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,4 +744,77 @@ mod std_tests {
let err = serde_cbor::from_slice::<serde_cbor::Value>(&input).expect_err("recursion limit");
assert!(err.is_syntax());
}

#[test]
fn test_int_as_string_map_keys() {
use std::collections::HashMap;

// Given a map with keys that are strings, but happen to be strings of integers
// e.g. "4", try to deserialize it into a HashMap<i32, ...>. This should
// work to have compatibility with serde_json.
let mut input = HashMap::<String, i32>::new();
input.insert("1".to_string(), 1);
input.insert("12345".to_string(), 2);
input.insert("-2".to_string(), 3);
let buf = to_vec(&input).unwrap();

let deserialized = from_slice::<HashMap<i16, i32>>(&buf).unwrap();

assert_eq!(deserialized.len(), 3);
assert_eq!(deserialized.get(&1), Some(&1));
assert_eq!(deserialized.get(&12345), Some(&2));
assert_eq!(deserialized.get(&-2), Some(&3));
}

#[test]
fn test_int_as_string_map_keys_unsigned_err() {
use std::collections::HashMap;

let mut input = HashMap::<String, i32>::new();
input.insert("1".to_string(), 1);
input.insert("-2".to_string(), 3);
let buf = to_vec(&input).unwrap();

// Should fail because key is negative.
from_slice::<HashMap<u16, i32>>(&buf).expect_err("");
}

#[test]
fn test_int_as_string_map_keys_out_of_range() {
use std::collections::HashMap;

let mut input = HashMap::<String, i32>::new();
input.insert("1".to_string(), 1);
input.insert("12345".to_string(), 2);
let buf = to_vec(&input).unwrap();

// Should fail because key is out of range.
from_slice::<HashMap<u8, i32>>(&buf).expect_err("");
}

#[test]
fn test_int_as_string_map_keys_empty() {
use std::collections::HashMap;

let mut input = HashMap::<String, i32>::new();
input.insert("1".to_string(), 1);
input.insert("".to_string(), 2);
let buf = to_vec(&input).unwrap();

// Should fail because key is empty
from_slice::<HashMap<u8, i32>>(&buf).expect_err("");
}

#[test]
fn test_int_as_string_map_keys_invalid() {
use std::collections::HashMap;

let mut input = HashMap::<String, i32>::new();
input.insert("1".to_string(), 1);
input.insert("12.5".to_string(), 2);
let buf = to_vec(&input).unwrap();

// Should fail because key is not an integer
from_slice::<HashMap<u8, i32>>(&buf).expect_err("");
}
}