From 5e8885a17cd6bb8b5b6e37d9569548cbeacc4fc8 Mon Sep 17 00:00:00 2001 From: Alexey Timin Date: Fri, 28 Jul 2023 11:03:37 +0200 Subject: [PATCH] Redesign API (#3) * redesign library * clean code * fix pylint * add methods for scalars * fix examples * fix pylint * fix example * fix clang build * fix casting * cpp-17 * fix clang build * print types if they missmatch * use real type * force type * update CHANGELOG --- .gitignore | 1 + CHANGELOG.md | 1 + CMakeLists.txt | 1 + README.md | 18 +- conan/test_package/CMakeLists.txt | 2 +- conan/test_package/src/example.cc | 15 +- drift_bytes/bytes.h | 334 ++++++++++++++++++++++++++--- pyproject.toml | 1 + python/src/drift_bytes/__init__.py | 2 +- python/src/drift_bytes/bytes.py | 322 +++++++++++++-------------- python/src/main.cc | 271 +++++++++++++---------- python/tests/__init__.py | 0 python/tests/test_buffers.py | 17 ++ python/tests/test_bytes.py | 22 -- python/tests/test_bytes_scalars.py | 128 ----------- python/tests/test_bytes_vectors.py | 110 ---------- python/tests/test_variant.py | 40 ++++ tests/bytes_test.cc | 130 ++++------- 18 files changed, 744 insertions(+), 671 deletions(-) create mode 100644 python/tests/__init__.py create mode 100644 python/tests/test_buffers.py delete mode 100644 python/tests/test_bytes.py delete mode 100644 python/tests/test_bytes_scalars.py delete mode 100644 python/tests/test_bytes_vectors.py create mode 100644 python/tests/test_variant.py diff --git a/.gitignore b/.gitignore index 5cbe593..511c53a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ dist python/CMakeFiles python/*.cmake python/src/drift_bytes.egg-info +*.so diff --git a/CHANGELOG.md b/CHANGELOG.md index fb41d16..cdbfb5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Use `set_` prefix for setters, [PR-2](https://github.com/panda-official/DriftBytes/pull/2) +- Redesign API, [PR-3](https://github.com/panda-official/DriftBytes/pull/3) ## 0.1.0 - 2023-07-20 diff --git a/CMakeLists.txt b/CMakeLists.txt index aa3cfbf..27e169c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ set_target_properties( ${TARGET_NAME} PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF ) +target_compile_features(${TARGET_NAME} INTERFACE cxx_std_17) target_include_directories( ${TARGET_NAME} diff --git a/README.md b/README.md index 2fdb542..9d849cc 100644 --- a/README.md +++ b/README.md @@ -32,16 +32,22 @@ which is suitable for non-floating point data. #include -using drift_bytes::Bytes; +using drift_bytes::InputBuffer; +using drift_bytes::OutputBuffer; +using drift_bytes::Variant; int main() { - uint8_t val{42}; - auto bytes = Bytes(); - bytes.set_scalar(val); - auto new_val = bytes.set_scalar(); + Variant some_value{42}; - std::cout << new_val << std::endl; + OutputBuffer buffer; + buffer.push_back(some_value); + + InputBuffer input(buffer.str()); + Variant new_val = input.pop(); + + std::cout << new_val << std::endl; } + ``` ## Bulding diff --git a/conan/test_package/CMakeLists.txt b/conan/test_package/CMakeLists.txt index 54aced7..cf81dd7 100644 --- a/conan/test_package/CMakeLists.txt +++ b/conan/test_package/CMakeLists.txt @@ -4,5 +4,5 @@ project(WaveletBufferTest CXX) find_package(drift_bytes CONFIG REQUIRED) add_executable(example src/example.cc) -target_compile_features(example PUBLIC cxx_std_20) +target_compile_features(example PUBLIC cxx_std_17) target_link_libraries(example drift_bytes::drift_bytes) diff --git a/conan/test_package/src/example.cc b/conan/test_package/src/example.cc index 3571f1e..e471e06 100644 --- a/conan/test_package/src/example.cc +++ b/conan/test_package/src/example.cc @@ -2,13 +2,18 @@ #include -using drift_bytes::Bytes; +using drift_bytes::InputBuffer; +using drift_bytes::OutputBuffer; +using drift_bytes::Variant; int main() { - uint8_t val{42}; - auto bytes = Bytes(); - bytes.set_scalar(val); - auto new_val = bytes.scalar(); + Variant some_value{42}; + + OutputBuffer buffer; + buffer.push_back(some_value); + + InputBuffer input(buffer.str()); + Variant new_val = input.pop(); std::cout << new_val << std::endl; } diff --git a/drift_bytes/bytes.h b/drift_bytes/bytes.h index 7aace29..90409d3 100644 --- a/drift_bytes/bytes.h +++ b/drift_bytes/bytes.h @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include #include @@ -19,58 +21,322 @@ namespace drift_bytes { -/** - * Serializes and deserializes variables. - */ -class Bytes { +const uint8_t kVersion = 0; + +enum Type : uint8_t { + kBool = 0, + kInt8 = 1, + kUInt8 = 2, + kInt16 = 3, + kUInt16 = 4, + kInt32 = 5, + kUInt32 = 6, + kInt64 = 7, + kUInt64 = 8, + kFloat32 = 9, + kFloat64 = 10, + kString = 11, +}; + +static const std::vector kSupportedType = { + "bool", "int8", "uint8", "int16", "uint16", "int32", + "uint32", "int64", "uint64", "float32", "float64", "string"}; + +using Shape = std::vector; + +using VarArray = std::vector< + std::variant>; + +class Variant { public: - using Shape = std::vector; + Variant(Shape shape, VarArray data) + : type_(), shape_(std::move(shape)), data_(std::move(data)) { + if (data_.empty()) { + throw std::out_of_range("Data is empty"); + } - Bytes() = default; - explicit Bytes(std::string &&bytes) { buffer_ << bytes; } + if (!shape_.empty()) { + bool match = + (std::accumulate(shape_.begin(), shape_.end(), 1, + std::multiplies()) == data_.size()); + if (!match) { + throw std::out_of_range("Shape and data size do not match"); + } + } else { + throw std::out_of_range("Shape is empty"); + } - std::string str() const { return buffer_.str(); } + type_ = static_cast(data_[0].index()); + } template - T scalar() { - cereal::PortableBinaryInputArchive archive(buffer_); - T t; - archive(t); - return t; + explicit Variant(T value) : type_(), shape_(), data_() { + shape_ = {1}; + data_ = {value}; + type_ = static_cast(data_[0].index()); } - template - std::vector vec() { - cereal::PortableBinaryInputArchive archive(buffer_); - std::vector t; - archive(t); - return t; + template + operator T() const { + if (type_ != static_cast(std::variant().index())) { + throw std::runtime_error("Type mismatch"); + } + + if (shape_ != Shape{1}) { + throw std::runtime_error("Looks like it is a vector"); + } + + return std::get(data_[0]); } - template - std::vector> mat() { + [[nodiscard]] Type type() const { return type_; } + [[nodiscard]] const Shape &shape() const { return shape_; } + [[nodiscard]] const VarArray &data() const { return data_; } + + bool operator==(const Variant &rhs) const { + return type_ == rhs.type_ && shape_ == rhs.shape_ && data_ == rhs.data_; + } + + bool operator!=(const Variant &rhs) const { return !(rhs == *this); } + + friend std::ostream &operator<<(std::ostream &os, const Variant &variant) { + os << "Variant(type:" << kSupportedType[variant.type_] << ", shape:{"; + for (auto &dim : variant.shape_) { + os << dim << ","; + } + os << "}, data:{"; + + for (auto &value : variant.data_) { + switch (variant.type_) { + case kBool: { + os << std::get(value) << ", "; + break; + } + case kInt8: { + os << std::get(value) << ","; + break; + } + case kUInt8: { + os << std::get(value) << ","; + break; + } + case kInt16: { + os << std::get(value) << ","; + break; + } + case kUInt16: { + os << std::get(value) << ","; + break; + } + case kInt32: { + os << std::get(value) << ","; + break; + } + case kUInt32: { + os << std::get(value) << ","; + break; + } + case kInt64: { + os << std::get(value) << ","; + break; + } + case kUInt64: { + os << std::get(value) << ","; + break; + } + case kFloat32: { + os << std::get(value) << ","; + break; + } + case kFloat64: { + os << std::get(value) << ","; + break; + } + case kString: { + os << std::get(value) << ","; + break; + } + } + } + + os << "})"; + return os; + } + + private: + Type type_; + Shape shape_; + VarArray data_; +}; + +class InputBuffer { + public: + explicit InputBuffer(std::string &&bytes) { + buffer_ << bytes; cereal::PortableBinaryInputArchive archive(buffer_); - std::vector> t; - archive(t); - return t; + uint8_t version; + archive(version); + + if (version != kVersion) { + throw std::runtime_error("Version mismatch"); + } } - template - void set_scalar(const T &t) { - cereal::PortableBinaryOutputArchive archive(buffer_); - archive(t); + std::string str() const { return buffer_.str(); } + + Variant pop() { + cereal::PortableBinaryInputArchive archive(buffer_); + Type type; + Shape shape; + archive(type, shape); + + auto size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + VarArray data(size); + for (auto &value : data) { + switch (type) { + case kBool: { + bool bool_value; + archive(bool_value); + value = bool_value; + break; + } + case kInt8: { + int8_t int8_value; + archive(int8_value); + value = int8_value; + break; + } + case kUInt8: { + uint8_t uint8_value; + archive(uint8_value); + value = uint8_value; + break; + } + + case kInt16: { + int16_t int16_value; + archive(int16_value); + value = int16_value; + break; + } + case kUInt16: { + uint16_t uint16_value; + archive(uint16_value); + value = uint16_value; + break; + } + case kInt32: { + int32_t int32_value; + archive(int32_value); + value = int32_value; + break; + } + case kUInt32: { + uint32_t uint32_value; + archive(uint32_value); + value = uint32_value; + break; + } + case kInt64: { + int64_t int64_value; + archive(int64_value); + value = int64_value; + break; + } + case kUInt64: { + uint64_t uint64_value; + archive(uint64_value); + value = uint64_value; + break; + } + case kFloat32: { + float float_value; + archive(float_value); + value = float_value; + break; + } + case kFloat64: { + double double_value; + archive(double_value); + value = double_value; + break; + } + case kString: { + std::string string_value; + archive(string_value); + value = string_value; + break; + } + default: + throw std::runtime_error("Unknown type"); + } + } + + return {shape, data}; } - template - void set_vec(const std::vector &t) { + bool empty() const { return buffer_.rdbuf()->in_avail() == 0; } + + private: + std::stringstream buffer_; +}; + +class OutputBuffer { + public: + OutputBuffer() : buffer_() { cereal::PortableBinaryOutputArchive archive(buffer_); - archive(t); + archive(kVersion); } - template - void set_mat(const std::vector> &t) { + std::string str() const { return buffer_.str(); } + + void push_back(const Variant &variant) { cereal::PortableBinaryOutputArchive archive(buffer_); - archive(t); + archive(variant.type(), variant.shape()); + for (const auto &value : variant.data()) { + switch (variant.type()) { + case kBool: + archive(std::get(value)); + break; + case kInt8: + archive(std::get(value)); + break; + case kUInt8: + archive(std::get(value)); + break; + case kInt16: + archive(std::get(value)); + break; + case kUInt16: + archive(std::get(value)); + break; + case kInt32: + archive(std::get(value)); + break; + case kUInt32: + archive(std::get(value)); + break; + case kInt64: + archive(std::get(value)); + break; + case kUInt64: + archive(std::get(value)); + break; + case kFloat32: + archive(std::get(value)); + break; + case kFloat64: + archive(std::get(value)); + break; + case kString: + archive(std::get(value)); + break; + default: + throw std::runtime_error("Unknown type"); + } + } } private: diff --git a/pyproject.toml b/pyproject.toml index 4f26292..a8d0e93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,3 +23,4 @@ environment-pass = ["VERSION_SUFFIX"] [tool.pylint.MASTER] good-names = "b" +extension-pkg-allow-list = ["pybind11", "drift_bytes._drift_bytes"] diff --git a/python/src/drift_bytes/__init__.py b/python/src/drift_bytes/__init__.py index 17b9088..93abe5c 100644 --- a/python/src/drift_bytes/__init__.py +++ b/python/src/drift_bytes/__init__.py @@ -1,2 +1,2 @@ """Bytes""" -from .bytes import Bytes +from .bytes import Variant, InputBuffer, OutputBuffer diff --git a/python/src/drift_bytes/bytes.py b/python/src/drift_bytes/bytes.py index 8803722..c4c3f66 100644 --- a/python/src/drift_bytes/bytes.py +++ b/python/src/drift_bytes/bytes.py @@ -1,165 +1,175 @@ # pylint: disable=missing-docstring, too-many-public-methods, useless-super-delegation """Bindings for the C++ implementation of the Bytes class.""" -from typing import List +from typing import List, Union, Optional import drift_bytes._drift_bytes as impl # pylint: disable=import-error, no-name-in-module -class Bytes(impl.Bytes): - """Bytes class""" - +class Variant: + TYPES = impl.supported_types() # pylint: disable=c-extension-no-member + SUPPORTED_TYPES = Union[ + bool, + int, + float, + str, + List[bool], + List[int], + List[float], + List[str], + ] + + def __init__( + self, + value: SUPPORTED_TYPES, + kind: Optional[str] = None, + ): + """Create Variant object from value + + Args: + kind (str): Type of value can be: bool, uint8, int8, + uint16, int16, uint32, int32, uint64, int64, float32, float64, string + value (Union[bool, int, float, str, List[bool], List[int], List[float], + List[str]]): Value to be stored in Variant object + """ + + if isinstance(value, impl.Variant): + # for internal use only to pop from InputBuffer + self._variant = value + self._shape = value.shape() + return + + type_error = TypeError( + f"Unsupported type: {kind}. Must be one of: {self.TYPES}" + ) + + type = self._find_type(kind, type_error, value) + + if not isinstance(value, list): + value = [value] + + self._shape = [len(value)] + self._make_variant(type, self._shape, value) + + def _find_type(self, kind, type_error, value): # pylint: disable=too-many-branches + if kind is None: + if isinstance(value, bool): + kind = "bool" + elif isinstance(value, int): + kind = "int64" + elif isinstance(value, float): + kind = "float64" + elif isinstance(value, str): + kind = "string" + elif isinstance(value, list): + if len(value) == 0: + raise ValueError("Empty list cannot be converted to Variant") + if isinstance(value[0], bool): + kind = "bool" + elif isinstance(value[0], int): + kind = "int64" + elif isinstance(value[0], float): + kind = "float64" + elif isinstance(value[0], str): + kind = "string" + else: + raise type_error + elif isinstance(value, Variant): + kind = value.type + else: + raise type_error + if kind not in self.TYPES: + raise type_error + return kind + + def _make_variant(self, kind, shape, value): + if kind == "bool": + self._variant = impl.Variant.from_bools(shape, value) + elif kind == "uint8": + self._variant = impl.Variant.from_int8s(shape, value) + elif kind == "int8": + self._variant = impl.Variant.from_int8s(shape, value) + elif kind == "uint16": + self._variant = impl.Variant.from_uint16s(shape, value) + elif kind == "int16": + self._variant = impl.Variant.from_int16s(shape, value) + elif kind == "uint32": + self._variant = impl.Variant.from_uint32s(shape, value) + elif kind == "int32": + self._variant = impl.Variant.from_int32s(shape, value) + elif kind == "uint64": + self._variant = impl.Variant.from_uint64s(shape, value) + elif kind == "int64": + self._variant = impl.Variant.from_int64s(shape, value) + elif kind == "float32": + self._variant = impl.Variant.from_float32s(shape, value) + elif kind == "float64": + self._variant = impl.Variant.from_float64s(shape, value) + elif kind == "string": + self._variant = impl.Variant.from_strings(shape, value) + + @property + def type(self) -> str: + """Get type""" + return self._variant.type() + + @property + def shape(self) -> List[int]: + """Get shape""" + return self._variant.shape() + + @property + def value(self) -> SUPPORTED_TYPES: # pylint: disable=too-many-branches + """Get value""" + if self.type == "bool": + ary = self._variant.to_bools() + elif self.type == "uint8": + ary = self._variant.to_uint8s() + elif self.type == "int8": + ary = self._variant.to_int8s() + elif self.type == "uint16": + ary = self._variant.to_uint16s() + elif self.type == "int16": + ary = self._variant.to_int16s() + elif self.type == "uint32": + ary = self._variant.to_uint32s() + elif self.type == "int32": + ary = self._variant.to_int32s() + elif self.type == "uint64": + ary = self._variant.to_uint64s() + elif self.type == "int64": + ary = self._variant.to_int64s() + elif self.type == "float32": + ary = self._variant.to_float32s() + elif self.type == "float64": + ary = self._variant.to_float64s() + elif self.type == "string": + ary = self._variant.to_strings() + else: + raise TypeError(f"Unsupported type: {self.type}") + + if self.shape == [1]: + return ary[0] + + return ary + + +class InputBuffer: + def __init__(self, buffer: bytes): + self._buffer = impl.InputBuffer.from_bytes(buffer) + + def pop(self) -> Variant: + return Variant(self._buffer.pop()) + + def empty(self) -> bool: + return self._buffer.empty() + + +class OutputBuffer: def __init__(self): - super().__init__() - - @classmethod - def from_bytes(cls, data: bytes) -> "Bytes": - """Create Bytes object from bytes""" - return impl.Bytes.from_bytes(data) - - def to_bytes(self) -> bytes: - """Serialize Bytes object to bytes""" - return super().to_bytes() - - def get_bool(self) -> bool: - return super().get_bool() - - def set_bool(self, value: bool) -> None: - super().set_bool(value) - - def get_int8(self) -> int: - return super().get_int8() - - def set_int8(self, value: int) -> None: - super().set_int8(value) - - def get_int16(self) -> int: - return super().get_int16() - - def set_int16(self, value: int) -> None: - super().set_int16(value) - - def get_int32(self) -> int: - return super().get_int32() - - def set_int32(self, value: int) -> None: - super().set_int32(value) - - def get_int64(self) -> int: - return super().get_int64() - - def set_int64(self, value: int) -> None: - super().set_int64(value) - - def get_uint8(self) -> int: - return super().get_uint8() - - def set_uint8(self, value: int) -> None: - super().set_uint8(value) - - def get_uint16(self) -> int: - return super().get_uint16() - - def set_uint16(self, value: int) -> None: - super().set_uint16(value) - - def get_uint32(self) -> int: - return super().get_uint32() - - def set_uint32(self, value: int) -> None: - super().set_uint32(value) - - def get_uint64(self) -> int: - return super().get_uint64() - - def set_uint64(self, value: int) -> None: - super().set_uint64(value) - - def get_float32(self) -> float: - return super().get_float32() - - def set_float32(self, value: float) -> None: - super().set_float32(value) - - def get_float64(self) -> float: - return super().get_float64() - - def set_float64(self, value: float) -> None: - super().set_float64(value) - - def set_string(self, value: str) -> None: - super().set_string(value) - - def get_string(self) -> str: - return super().get_string() - - def get_bool_vec(self) -> List[bool]: - return super().get_bool_vec() - - def set_bool_vec(self, value: List[bool]) -> None: - super().set_bool_vec(value) - - def get_int8_vec(self) -> List[int]: - return super().get_int8_vec() - - def set_int8_vec(self, value: List[int]) -> None: - super().set_int8_vec(value) - - def get_int16_vec(self) -> List[int]: - return super().get_int16_vec() - - def set_int16_vec(self, value: List[int]) -> None: - super().set_int16_vec(value) - - def get_int32_vec(self) -> List[int]: - return super().get_int32_vec() - - def set_int32_vec(self, value: List[int]) -> None: - super().set_int32_vec(value) - - def get_int64_vec(self) -> List[int]: - return super().get_int64_vec() - - def set_int64_vec(self, value: List[int]) -> None: - super().set_int64_vec(value) - - def get_uint8_vec(self) -> List[int]: - return super().get_uint8_vec() - - def set_uint8_vec(self, value: List[int]) -> None: - super().set_uint8_vec(value) - - def get_uint16_vec(self) -> List[int]: - return super().get_uint16_vec() - - def set_uint16_vec(self, value: List[int]) -> None: - super().set_uint16_vec(value) - - def get_uint32_vec(self) -> List[int]: - return super().get_uint32_vec() - - def set_uint32_vec(self, value: List[int]) -> None: - super().set_uint32_vec(value) - - def get_uint64_vec(self) -> List[int]: - return super().get_uint64_vec() - - def set_uint64_vec(self, value: List[int]) -> None: - super().set_uint64_vec(value) - - def get_float32_vec(self) -> List[float]: - return super().get_float32_vec() - - def set_float32_vec(self, value: List[float]) -> None: - super().set_float32_vec(value) - - def get_float64_vec(self) -> List[float]: - return super().get_float64_vec() - - def set_float64_vec(self, value: List[float]) -> None: - super().set_float64_vec(value) + self._buffer = impl.OutputBuffer() - def get_string_vec(self) -> List[str]: - return super().get_string_vec() + def push(self, value: Variant): + self._buffer.push(value._variant) # pylint: disable=protected-access - def set_string_vec(self, value: List[str]) -> None: - super().set_string_vec(value) + def bytes(self): + return self._buffer.bytes() diff --git a/python/src/main.cc b/python/src/main.cc index 408e0a3..f38ea5c 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -3,138 +3,171 @@ #include #include +#include + namespace py = pybind11; -using drift_bytes::Bytes; +using drift_bytes::InputBuffer; +using drift_bytes::kSupportedType; +using drift_bytes::OutputBuffer; +using drift_bytes::Shape; +using drift_bytes::Type; +using drift_bytes::VarArray; +using drift_bytes::Variant; + +template +std::vector make_array(const Variant &variant) { + std::vector ary(std::accumulate(variant.shape().begin(), + variant.shape().end(), 1, + std::multiplies())); + for (int i = 0; i < ary.size(); ++i) { + try { + ary[i] = std::get(variant.data()[i]); + } catch (const std::bad_variant_access &) { + throw std::runtime_error( + "Type mismatch: " + kSupportedType[variant.type()] + + " != " + kSupportedType[std::variant().index()]); + } + } + return ary; +} PYBIND11_MODULE(_drift_bytes, m) { - auto klass = py::class_(m, "Bytes"); - klass.def(py::init<>()) + m.def("supported_types", + []() -> std::vector { return kSupportedType; }); + + auto variant = py::class_(m, "Variant"); + variant + .def_static("from_bools", + [](Shape shape, std::vector array) -> Variant { + VarArray var_array(array.size()); + for (int i = 0; i < array.size(); ++i) { + var_array[i] = static_cast( + array[i]); // We do it because apple-clang converts + // bool to int + } + + return {std::move(shape), std::move(var_array)}; + }) + .def_static( + "from_int8s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_uint8s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_int16s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_uint16s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_int32s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) .def_static( - "from_bytes", - [](py::bytes bytes) { return drift_bytes::Bytes(std::move(bytes)); }) - .def("to_bytes", - [](drift_bytes::Bytes &bytes) { return py::bytes(bytes.str()); }) - // Scalar types - .def("get_bool", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_bool", - [](drift_bytes::Bytes &bytes, bool val) { bytes.set_scalar(val); }) - .def("get_int8", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_int8", - [](drift_bytes::Bytes &bytes, int8_t val) { bytes.set_scalar(val); }) - .def("get_int16", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_int16", [](drift_bytes::Bytes &bytes, - int16_t val) { bytes.set_scalar(val); }) - .def("get_int32", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_int32", [](drift_bytes::Bytes &bytes, - int32_t val) { bytes.set_scalar(val); }) - .def("get_int64", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_int64", [](drift_bytes::Bytes &bytes, - int64_t val) { bytes.set_scalar(val); }) - .def("get_uint8", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_uint8", [](drift_bytes::Bytes &bytes, - uint8_t val) { bytes.set_scalar(val); }) - .def("get_uint16", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_uint16", [](drift_bytes::Bytes &bytes, - uint16_t val) { bytes.set_scalar(val); }) - .def("get_uint32", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_uint32", [](drift_bytes::Bytes &bytes, - uint32_t val) { bytes.set_scalar(val); }) - .def("get_uint64", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_uint64", [](drift_bytes::Bytes &bytes, - uint64_t val) { bytes.set_scalar(val); }) - .def("get_float32", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_float32", - [](drift_bytes::Bytes &bytes, float val) { bytes.set_scalar(val); }) - .def("get_float64", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_float64", - [](drift_bytes::Bytes &bytes, double val) { bytes.set_scalar(val); }) - .def( - "get_string", - [](drift_bytes::Bytes &bytes) { return bytes.scalar(); }) - .def("set_string", [](drift_bytes::Bytes &bytes, - const std::string &val) { bytes.set_scalar(val); }) - // Vector types - .def("get_bool_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_bool_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + "from_uint32s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_int64s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_uint64s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_float32s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_float64s", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def_static( + "from_strings", + [](Shape shape, std::vector array) -> Variant { + return {std::move(shape), VarArray(array.begin(), array.end())}; + }) + .def("to_bools", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_int8_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_int8_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_int8s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_int16_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_int16_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_uint8s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_int32_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_int32_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_int16s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_int64_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_int64_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_uint16s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_uint8_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_uint8_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_int32s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_uint16_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_uint16_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_uint32s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_uint32_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_uint32_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_int64s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_uint64_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_uint64_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_uint64s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_float32_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_float32_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_float32s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_float64_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_float64_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); + .def("to_float64s", + [](Variant &variant) -> std::vector { + return make_array(variant); }) - .def("get_string_vec", - [](drift_bytes::Bytes &bytes) { return bytes.vec(); }) - .def("set_string_vec", - [](drift_bytes::Bytes &bytes, const std::vector &val) { - bytes.set_vec(val); - }); + .def("to_strings", + [](Variant &variant) -> std::vector { + return make_array(variant); + }) + .def("type", + [](Variant &variant) -> std::string { + return kSupportedType.at(variant.type()); + }) + .def("shape", [](Variant &variant) -> Shape { return variant.shape(); }); + + auto input_buffer = py::class_(m, "InputBuffer"); + input_buffer + .def_static("from_bytes", + [](py::bytes bytes) { return InputBuffer(std::move(bytes)); }) + .def("pop", [](InputBuffer &buffer) -> Variant { return buffer.pop(); }) + .def("empty", [](InputBuffer &buffer) -> bool { return buffer.empty(); }); + + auto output_buffer = py::class_(m, "OutputBuffer"); + output_buffer.def(py::init()) + .def("push", [](OutputBuffer &buffer, + const Variant &variant) { buffer.push_back(variant); }) + .def("bytes", + [](OutputBuffer &buffer) -> py::bytes { return buffer.str(); }); } diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/tests/test_buffers.py b/python/tests/test_buffers.py new file mode 100644 index 0000000..dc28df6 --- /dev/null +++ b/python/tests/test_buffers.py @@ -0,0 +1,17 @@ +"""Test the InputBuffer and OutputBuffer classes.""" +from drift_bytes import Variant, InputBuffer, OutputBuffer + + +def test_input_output(): + """Should push and pop""" + out_buf = OutputBuffer() + out_buf.push(Variant([1, 2, 3, 4, 5, 6])) + + in_buf = InputBuffer(out_buf.bytes()) + + var = in_buf.pop() + assert var.type == "int64" + assert var.shape == [6] + assert var.value == [1, 2, 3, 4, 5, 6] + + assert in_buf.empty() diff --git a/python/tests/test_bytes.py b/python/tests/test_bytes.py deleted file mode 100644 index c223e1b..0000000 --- a/python/tests/test_bytes.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Test for Bytes""" -import pytest - -from drift_bytes import Bytes - - -def test__bad_encoding(): - """Test bad encoding""" - b = Bytes() - b.set_bool(True) - - with pytest.raises(RuntimeError): - b.get_float32() - - -def test_serialization(): - """Test serialization""" - b = Bytes() - b.set_int8(42) - - b = Bytes.from_bytes(b.to_bytes()) - assert b.get_int8() == 42 diff --git a/python/tests/test_bytes_scalars.py b/python/tests/test_bytes_scalars.py deleted file mode 100644 index 0400b07..0000000 --- a/python/tests/test_bytes_scalars.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Test for scalar types""" -from drift_bytes import Bytes - - -def test__bool(): - """Should be able to set and get bools""" - b = Bytes() - b.set_bool(True) - assert b.get_bool() - - b.set_bool(False) - assert not b.get_bool() - - -def test__int8(): - """Should be able to set and get int8s""" - b = Bytes() - b.set_int8(0) - assert b.get_int8() == 0 - - b.set_int8(127) - assert b.get_int8() == 127 - - b.set_int8(-128) - assert b.get_int8() == -128 - - -def test__int16(): - """Should be able to set and get int16s""" - b = Bytes() - b.set_int16(0) - assert b.get_int16() == 0 - - b.set_int16(32767) - assert b.get_int16() == 32767 - - b.set_int16(-32768) - assert b.get_int16() == -32768 - - -def test__int32(): - """Should be able to set and get int32s""" - b = Bytes() - b.set_int32(0) - assert b.get_int32() == 0 - - b.set_int32(2147483647) - assert b.get_int32() == 2147483647 - - b.set_int32(-2147483648) - assert b.get_int32() == -2147483648 - - -def test__int64(): - """Should be able to set and get int64s""" - b = Bytes() - b.set_int64(0) - assert b.get_int64() == 0 - - b.set_int64(9223372036854775807) - assert b.get_int64() == 9223372036854775807 - - b.set_int64(-9223372036854775808) - assert b.get_int64() == -9223372036854775808 - - -def test__uint8(): - """Should be able to set and get uint8s""" - b = Bytes() - b.set_uint8(0) - assert b.get_uint8() == 0 - - b.set_uint8(255) - assert b.get_uint8() == 255 - - -def test__uint16(): - """Should be able to set and get uint16s""" - b = Bytes() - b.set_uint16(0) - assert b.get_uint16() == 0 - - b.set_uint16(65535) - assert b.get_uint16() == 65535 - - -def test__uint32(): - """Should be able to set and get uint32s""" - b = Bytes() - b.set_uint32(0) - assert b.get_uint32() == 0 - - b.set_uint32(4294967295) - assert b.get_uint32() == 4294967295 - - -def test__uint64(): - """Should be able to set and get uint64s""" - b = Bytes() - b.set_uint64(0) - assert b.get_uint64() == 0 - - b.set_uint64(18446744073709551615) - assert b.get_uint64() == 18446744073709551615 - - -def test__float32(): - """Should be able to set and get float32s""" - b = Bytes() - b.set_float32(1.125) - - assert b.get_float32() == 1.125 - - -def test__float64(): - """Should be able to set and get float64s""" - b = Bytes() - b.set_float64(1.123456) - - assert b.get_float64() == 1.123456 - - -def test__string(): - """Should be able to set and get strings""" - b = Bytes() - b.set_string("Hello World, ÄÖÜäöüß") - - assert b.get_string() == "Hello World, ÄÖÜäöüß" diff --git a/python/tests/test_bytes_vectors.py b/python/tests/test_bytes_vectors.py deleted file mode 100644 index bf6b316..0000000 --- a/python/tests/test_bytes_vectors.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Test vectors in Bytes class""" -from drift_bytes import Bytes - - -def test__bool_vec(): - """Test bool vectors""" - b = Bytes() - b.set_bool_vec([True, False, True]) - assert b.get_bool_vec() == [True, False, True] - b.set_bool_vec([False, True, False]) - assert b.get_bool_vec() == [False, True, False] - - -def test__int8_vec(): - """Test int8 vectors""" - b = Bytes() - b.set_int8_vec([0, 127, -128]) - assert b.get_int8_vec() == [0, 127, -128] - b.set_int8_vec([127, -128, 0]) - assert b.get_int8_vec() == [127, -128, 0] - - -def test__int16_vec(): - """Test int16 vectors""" - b = Bytes() - b.set_int16_vec([0, 32767, -32768]) - assert b.get_int16_vec() == [0, 32767, -32768] - b.set_int16_vec([32767, -32768, 0]) - assert b.get_int16_vec() == [32767, -32768, 0] - - -def test__int32_vec(): - """Test int32 vectors""" - b = Bytes() - b.set_int32_vec([0, 2147483647, -2147483648]) - assert b.get_int32_vec() == [0, 2147483647, -2147483648] - b.set_int32_vec([2147483647, -2147483648, 0]) - assert b.get_int32_vec() == [2147483647, -2147483648, 0] - - -def test__int64_vec(): - """Test int64 vectors""" - b = Bytes() - b.set_int64_vec([0, 9223372036854775807, -9223372036854775808]) - assert b.get_int64_vec() == [0, 9223372036854775807, -9223372036854775808] - b.set_int64_vec([9223372036854775807, -9223372036854775808, 0]) - assert b.get_int64_vec() == [9223372036854775807, -9223372036854775808, 0] - - -def test__uint8_vec(): - """Test uint8 vectors""" - b = Bytes() - b.set_uint8_vec([0, 255]) - assert b.get_uint8_vec() == [0, 255] - b.set_uint8_vec([255, 0]) - assert b.get_uint8_vec() == [255, 0] - - -def test__uint16_vec(): - """Test uint16 vectors""" - b = Bytes() - b.set_uint16_vec([0, 65535]) - assert b.get_uint16_vec() == [0, 65535] - b.set_uint16_vec([65535, 0]) - assert b.get_uint16_vec() == [65535, 0] - - -def test__uint32_vec(): - """Test uint32 vectors""" - b = Bytes() - b.set_uint32_vec([0, 4294967295]) - assert b.get_uint32_vec() == [0, 4294967295] - b.set_uint32_vec([4294967295, 0]) - assert b.get_uint32_vec() == [4294967295, 0] - - -def test__uint64_vec(): - """Test uint64 vectors""" - b = Bytes() - b.set_uint64_vec([0, 18446744073709551615]) - assert b.get_uint64_vec() == [0, 18446744073709551615] - b.set_uint64_vec([18446744073709551615, 0]) - assert b.get_uint64_vec() == [18446744073709551615, 0] - - -def test__float32_vec(): - """Test float32 vectors""" - b = Bytes() - b.set_float32_vec([0.0, 1.0, -1.0]) - assert b.get_float32_vec() == [0.0, 1.0, -1.0] - b.set_float32_vec([1.0, -1.0, 0.0]) - assert b.get_float32_vec() == [1.0, -1.0, 0.0] - - -def test__float64_vec(): - """Test float64 vectors""" - b = Bytes() - b.set_float64_vec([0.0, 1.0, -1.0]) - assert b.get_float64_vec() == [0.0, 1.0, -1.0] - b.set_float64_vec([1.0, -1.0, 0.0]) - assert b.get_float64_vec() == [1.0, -1.0, 0.0] - - -def test__string_vec(): - """Test string vectors""" - b = Bytes() - b.set_string_vec(["", "a", "abc"]) - assert b.get_string_vec() == ["", "a", "abc"] - b.set_string_vec(["abc", "a", ""]) - assert b.get_string_vec() == ["abc", "a", ""] diff --git a/python/tests/test_variant.py b/python/tests/test_variant.py new file mode 100644 index 0000000..56060a4 --- /dev/null +++ b/python/tests/test_variant.py @@ -0,0 +1,40 @@ +"""Variant tests""" +import pytest +from drift_bytes import Variant + + +@pytest.mark.parametrize( + "kind, value", + [ + ("bool", True), + ("int32", 1), + ("float32", 1.0), + ("string", "1"), + ("uint64", [1, 2, 3, 4]), + ], +) +def test_init(kind, value): + """Test Variant init""" + var = Variant(value, kind) + + assert var.type == kind + assert var.shape == [len(value)] if isinstance(value, list) else [1] + assert var.value == value + + +@pytest.mark.parametrize( + "kind, value", + [ + ("bool", True), + ("int64", 1), + ("float64", 1.0), + ("string", "1"), + ("int64", [1, 2, 3, 4]), + ], +) +def test_init_suggested_type(kind, value): + """Should suggest type""" + var = Variant(value) + assert var.type == kind + assert var.shape == [len(value)] if isinstance(value, list) else [1] + assert var.value == value diff --git a/tests/bytes_test.cc b/tests/bytes_test.cc index 95eed1a..ed6f975 100644 --- a/tests/bytes_test.cc +++ b/tests/bytes_test.cc @@ -8,103 +8,55 @@ #include "catch2/generators/catch_generators.hpp" -using drift_bytes::Bytes; - -TEST_CASE("Scalars") { - auto val = GENERATE( - true, std::numeric_limits::max(), - std::numeric_limits::max(), - std::numeric_limits::max(), - std::numeric_limits::max(), std::numeric_limits::max(), - std::numeric_limits::max(), std::numeric_limits::max(), - std::numeric_limits::max(), std::numeric_limits::max(), - std::numeric_limits::max()); - - CAPTURE(val); - - auto bytes = drift_bytes::Bytes(); - bytes.set_scalar(val); - - decltype(val) new_val; - new_val = bytes.scalar(); - - REQUIRE(new_val == val); -} - -TEST_CASE("Strings") { - std::string val = - GENERATE("Hello", "World", "Hello World", "Hello World!", "äöü"); - - CAPTURE(val); - - auto bytes = drift_bytes::Bytes(); - bytes.set_scalar(val); - - auto new_val = bytes.scalar(); - - REQUIRE(new_val == val); +using drift_bytes::InputBuffer; +using drift_bytes::OutputBuffer; +using drift_bytes::Shape; +using drift_bytes::Type; +using drift_bytes::Variant; + +TEST_CASE("Full test") { + Variant var1({1, 3}, {1, 2, 3}); + Variant var2 = + GENERATE(Variant({2}, {true, false}), Variant({2}, {1.0, 2.0}), + Variant({2}, {"Hello", "World"}), Variant({3}, {1l, 2l, 3l}), + Variant({3}, {1ul, 2ul, 3ul}), Variant({3}, {1.0f, 2.0f, 3.0f})); + + OutputBuffer out; + out.push_back(var1); + out.push_back(var2); + + InputBuffer in(out.str()); + + REQUIRE(in.pop() == var1); + REQUIRE_FALSE(in.empty()); + + REQUIRE(in.pop() == var2); + REQUIRE(in.empty()); } -TEST_CASE("Vectors") { - std::vector val = GENERATE(std::vector{1, 2, 3, 4, 5}, - std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9}); - - CAPTURE(val); - - auto bytes = drift_bytes::Bytes(); - bytes.set_vec(val); +TEST_CASE("Variant: Test scalars") { + auto value = GENERATE(true, uint8_t{9}, int8_t{-9}, uint16_t{9}, int16_t{-9}, + uint32_t{9}, int32_t{-9}, uint64_t{9}, int64_t{-9}, + float{9.0}, double{9.0}); - auto new_val = bytes.vec(); + Variant var{value}; - REQUIRE(new_val == val); + REQUIRE(var.shape() == Shape{1}); + REQUIRE(value == decltype(value)(var)); } -TEST_CASE("Matrices") { - std::vector> val = GENERATE( - std::vector>{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}}, - std::vector>{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}); - - CAPTURE(val); - - auto bytes = drift_bytes::Bytes(); - bytes.set_vec(val); +TEST_CASE("Variant: Test strings") { + std::string value = "Hello World, ÄÖÜß should be UTF-8"; + Variant var{value}; - auto new_val = bytes.mat(); - - REQUIRE(new_val == val); -} - -TEST_CASE("Mixed data") { - int a; - std::vector fvec = {1.0, 2.0, 3.0}; - std::string s = "Hello World!"; - std::vector> mat = {{1, 2, 3}, {4, 5, 6}}; - - auto bytes = drift_bytes::Bytes(); - bytes.set_scalar(a); - bytes.set_vec(fvec); - bytes.set_scalar(s); - bytes.set_mat(mat); - - auto new_a = bytes.scalar(); - auto new_fvec = bytes.vec(); - auto new_s = bytes.scalar(); - auto new_mat = bytes.mat(); - - REQUIRE(new_a == a); - REQUIRE(new_fvec == fvec); - REQUIRE(new_s == s); - REQUIRE(new_mat == mat); + REQUIRE(var.shape() == Shape{1}); + REQUIRE(value == static_cast(var)); } -TEST_CASE("Serialization") { - std::vector> mat = {{1, 2, 3}, {4, 5, 6}}; - - auto bytes = drift_bytes::Bytes(); - bytes.set_mat(mat); - - bytes = drift_bytes::Bytes(bytes.str()); +TEST_CASE("Variant: test stream") { + Variant var({1, 3}, {1, 2, 3}); + std::stringstream ss; + ss << var; - auto new_mat = bytes.mat(); - REQUIRE(new_mat == mat); + REQUIRE(ss.str() == "Variant(type:int32, shape:{1,3,}, data:{1,2,3,})"); }