From e12c606beb1e601b0afc57d02ff193b12d5a5700 Mon Sep 17 00:00:00 2001 From: Alexei Chetroi Date: Mon, 20 May 2019 21:51:38 -0400 Subject: [PATCH 1/2] Fix float type deserialization. --- tests/test_types.py | 14 ++++++++++---- zigpy/types/basic.py | 4 ++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_types.py b/tests/test_types.py index f2260adaf..86be0738d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,15 +13,21 @@ def test_int_too_short(): def test_single(): - v = t.Single(1.25) + value = 1.25 + extra = b'ab12!' + v = t.Single(value) ser = v.serialize() - assert t.Single.deserialize(ser) == (1.25, b'') + assert t.Single.deserialize(ser) == (value, b'') + assert t.Single.deserialize(ser + extra) == (value, extra) def test_double(): - v = t.Double(1.25) + value = 1.25 + extra = b'ab12!' + v = t.Double(value) ser = v.serialize() - assert t.Double.deserialize(ser) == (1.25, b'') + assert t.Double.deserialize(ser) == (value, b'') + assert t.Double.deserialize(ser + extra) == (value, extra) def test_lvbytes(): diff --git a/zigpy/types/basic.py b/zigpy/types/basic.py index e64358bdf..4e3486feb 100644 --- a/zigpy/types/basic.py +++ b/zigpy/types/basic.py @@ -130,7 +130,7 @@ def serialize(self): @classmethod def deserialize(cls, data): - return struct.unpack(' Date: Mon, 20 May 2019 22:16:22 -0400 Subject: [PATCH 2/2] Refactor float types. Raise if buffer size is incorrect. --- tests/test_types.py | 6 ++++++ zigpy/types/basic.py | 17 +++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/test_types.py b/tests/test_types.py index 86be0738d..faecb82fe 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -20,6 +20,9 @@ def test_single(): assert t.Single.deserialize(ser) == (value, b'') assert t.Single.deserialize(ser + extra) == (value, extra) + with pytest.raises(ValueError): + t.Double.deserialize(ser[1:]) + def test_double(): value = 1.25 @@ -29,6 +32,9 @@ def test_double(): assert t.Double.deserialize(ser) == (value, b'') assert t.Double.deserialize(ser + extra) == (value, extra) + with pytest.raises(ValueError): + t.Double.deserialize(ser[1:]) + def test_lvbytes(): d, r = t.LVBytes.deserialize(b'\x0412345') diff --git a/zigpy/types/basic.py b/zigpy/types/basic.py index 4e3486feb..187e2cdf8 100644 --- a/zigpy/types/basic.py +++ b/zigpy/types/basic.py @@ -125,21 +125,22 @@ class bitmap64(uint64_t): # noqa: N801 class Single(float): + _fmt = '