diff --git a/Cargo.lock b/Cargo.lock index d5646e293d..4f93e243f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2207,18 +2207,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -2737,7 +2737,9 @@ dependencies = [ "log", "maybe-async", "prost", + "rand 0.8.5", "regex", + "serde", "serde_json", "serial_test", "smol", diff --git a/rust/BUILD b/rust/BUILD index 4a393d7516..aa1d10c459 100644 --- a/rust/BUILD +++ b/rust/BUILD @@ -37,6 +37,7 @@ typedb_driver_deps = [ "@crates//:itertools", "@crates//:log", "@crates//:prost", + "@crates//:serde", "@crates//:tokio", "@crates//:tokio-stream", "@crates//:tonic", @@ -70,7 +71,10 @@ rust_library( rust_test( name = "typedb_driver_unit_tests", crate = ":typedb_driver", - deps = ["@crates//:serde_json"], + deps = [ + "@crates//:rand", + "@crates//:serde_json", + ], ) assemble_crate( diff --git a/rust/Cargo.toml b/rust/Cargo.toml index dd1990d4b3..4d4c66c059 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -16,6 +16,11 @@ [dev-dependencies] + [dev-dependencies.rand] + features = ["alloc", "default", "getrandom", "libc", "rand_chacha", "small_rng", "std", "std_rng"] + version = "0.8.5" + default-features = false + [dev-dependencies.smol] features = [] version = "1.3.0" @@ -69,6 +74,11 @@ version = "0.4.27" default-features = false + [dependencies.serde] + features = ["alloc", "default", "derive", "rc", "serde_derive", "std"] + version = "1.0.219" + default-features = false + [dependencies.tokio-stream] features = ["default", "net", "time"] version = "0.1.17" diff --git a/rust/src/answer/json.rs b/rust/src/answer/json.rs index ce28296b07..263f6014f3 100644 --- a/rust/src/answer/json.rs +++ b/rust/src/answer/json.rs @@ -21,9 +21,16 @@ use std::{ borrow::Cow, collections::HashMap, fmt::{self, Write}, + iter, }; -#[derive(Clone, Debug)] +use itertools::Itertools; +use serde::{ + ser::{SerializeMap, SerializeSeq}, + Deserialize, Serialize, +}; + +#[derive(Clone, Debug, PartialEq)] pub enum JSON { Object(HashMap, JSON>), Array(Vec), @@ -112,9 +119,154 @@ fn write_escaped_string(string: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result write!(f, r#""{}""#, unsafe { String::from_utf8_unchecked(buf) }) } +impl Serialize for JSON { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Object(object) => { + let mut map = serializer.serialize_map(Some(object.len()))?; + for (key, value) in object { + map.serialize_entry(key, value)?; + } + map.end() + } + Self::Array(array) => { + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for item in array { + seq.serialize_element(item)?; + } + seq.end() + } + Self::String(string) => serializer.serialize_str(string), + &Self::Number(number) => serializer.serialize_f64(number), + &Self::Boolean(boolean) => serializer.serialize_bool(boolean), + Self::Null => serializer.serialize_unit(), + } + } +} + +impl<'de> Deserialize<'de> for JSON { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = JSON; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a valid JSON value") + } + + fn visit_bool(self, value: bool) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Boolean(value)) + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Number(value as f64)) + } + + fn visit_i128(self, value: i128) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Number(value as f64)) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Number(value as f64)) + } + + fn visit_u128(self, value: u128) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Number(value as f64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Number(value)) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + Ok(JSON::String(Cow::Owned(value.to_owned()))) + } + + fn visit_string(self, value: String) -> Result + where + E: serde::de::Error, + { + Ok(JSON::String(Cow::Owned(value))) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Null) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + JSON::deserialize(deserializer) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(JSON::Null) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + Ok(JSON::Array(iter::from_fn(|| seq.next_element().transpose()).try_collect()?)) + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + Ok(JSON::Object(iter::from_fn(|| map.next_entry().transpose()).try_collect()?)) + } + } + + deserializer.deserialize_any(Visitor) + } +} + #[cfg(test)] mod test { - use std::borrow::Cow; + use std::{borrow::Cow, collections::HashMap, iter}; + + use rand::{ + distributions::{DistString, Distribution, Standard, WeightedIndex}, + rngs::ThreadRng, + thread_rng, Rng, + }; + use serde_json::json; use super::JSON; @@ -126,4 +278,62 @@ mod test { let json_string = JSON::String(Cow::Owned(string)); assert_eq!(serde_json::to_string(&serde_json_value).unwrap(), json_string.to_string()); } + + fn sample_json() -> JSON { + JSON::Object(HashMap::from([ + ("array".into(), JSON::Array(vec![JSON::Boolean(true), JSON::String("string".into())])), + ("number".into(), JSON::Number(123.4)), + ])) + } + + #[test] + fn serialize() { + let ser = serde_json::to_value(sample_json()).unwrap(); + let value = json!( { "array": [true, "string"], "number": 123.4 }); + assert_eq!(ser, value); + } + + #[test] + fn deserialize() { + let deser: JSON = serde_json::from_str(r#"{ "array": [true, "string"], "number": 123.4 }"#).unwrap(); + let json = sample_json(); + assert_eq!(deser, json); + } + + fn random_string(rng: &mut impl Rng) -> String { + let len = rng.gen_range(0..64); + Standard.sample_string(rng, len) + } + + fn random_json(rng: &mut R) -> JSON { + let weights = [1, 1, 3, 3, 3, 3]; + let generators: &[fn(&mut R) -> JSON] = &[ + |rng| { + let len = rng.gen_range(0..12); + JSON::Object( + iter::from_fn(|| Some((Cow::Owned(random_string(rng)), random_json(rng)))).take(len).collect(), + ) + }, + |rng| { + let len = rng.gen_range(0..12); + JSON::Array(iter::from_fn(|| Some(random_json(rng))).take(len).collect()) + }, + |rng| JSON::String(Cow::Owned(random_string(rng))), + |rng| JSON::Number(rng.gen()), + |rng| JSON::Boolean(rng.gen()), + |_| JSON::Null, + ]; + let dist = WeightedIndex::new(&weights).unwrap(); + generators[dist.sample(rng)](rng) + } + + #[test] + fn serde_roundtrip() { + let mut rng = thread_rng(); + for _ in 0..1000 { + let json = random_json(&mut rng); + let deser = serde_json::from_value(serde_json::to_value(&json).unwrap()).unwrap(); + assert_eq!(json, deser); + } + } }