Skip to content

Commit

Permalink
Reject invalid SiLabs upgrade images (#500)
Browse files Browse the repository at this point in the history
* Superficially validate SiLabs upgrade images

* Validate an image CRC32 while parsing its structure

* Explicitly test single-byte corruption
  • Loading branch information
puddly committed Sep 28, 2020
1 parent a0e9b4f commit f9d2f7c
Show file tree
Hide file tree
Showing 4 changed files with 365 additions and 4 deletions.
45 changes: 44 additions & 1 deletion tests/test_ota.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import zigpy.ota
import zigpy.ota.image
import zigpy.ota.provider
import zigpy.ota.validators

from .async_mock import AsyncMock, MagicMock, patch, sentinel

Expand Down Expand Up @@ -42,8 +43,14 @@ def key():
def ota():
app = MagicMock(spec_set=zigpy.application.ControllerApplication)
tradfri = MagicMock(spec_set=zigpy.ota.provider.Trådfri)
validate_ota_image = MagicMock(
spec_set=zigpy.ota.validators.validate_ota_image,
return_value=zigpy.ota.validators.ValidationResult.VALID,
)

with patch("zigpy.ota.provider.Trådfri", tradfri):
return zigpy.ota.OTA(app)
with patch("zigpy.ota.validate_ota_image", validate_ota_image):
yield zigpy.ota.OTA(app)


async def test_ota_initialize(ota):
Expand Down Expand Up @@ -111,6 +118,42 @@ async def test_get_image_new(ota, image, key, image_with_version, monkeypatch):
assert ota.async_event.call_count == 1


async def test_get_image_invalid(ota, image, image_with_version):
corrupted = image_with_version(image.version)

zigpy.ota.validate_ota_image.return_value = (
zigpy.ota.validators.ValidationResult.INVALID
)
ota.async_event = AsyncMock(return_value=[None, corrupted])

assert len(ota._image_cache) == 0
res = await ota.get_ota_image(MANUFACTURER_ID, IMAGE_TYPE)
assert len(ota._image_cache) == 0

assert res is None


@pytest.mark.parametrize("v1", [0, 1])
@pytest.mark.parametrize("v2", [0, 1])
async def test_get_image_invalid_then_valid_versions(v1, v2, ota, image_with_version):
image = image_with_version(100 + v1)
image.header.header_string = b"\x12" * 32

corrupted = image_with_version(100 + v2)
corrupted.header.header_string = b"\x11" * 32

ota.async_event = AsyncMock(return_value=[corrupted, image])
zigpy.ota.validate_ota_image.side_effect = [
zigpy.ota.validators.ValidationResult.INVALID,
zigpy.ota.validators.ValidationResult.VALID,
]

res = await ota.get_ota_image(MANUFACTURER_ID, IMAGE_TYPE)

# The valid image is always picked, even if the versions match
assert res.header.header_string == image.header.header_string


def test_cached_image_expiration(image, monkeypatch):
cached = zigpy.ota.CachedImage.new(image)
assert cached.expired is False
Expand Down
180 changes: 180 additions & 0 deletions tests/test_ota_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import zlib

import pytest

from zigpy.ota import validators
from zigpy.ota.image import ElementTagId, OTAImage, SubElement
from zigpy.ota.validators import ValidationResult


def create_ebl_image(tags):
# All images start with a 140-byte "0x0000" header
tags = [(b"\x00\x00", b"test" * 35)] + tags

assert all([len(tag) == 2 for tag, value in tags])
image = b"".join(tag + len(value).to_bytes(2, "big") + value for tag, value in tags)

# And end with a checksum
image += b"\xFC\x04\x00\x04" + zlib.crc32(image + b"\xFC\x04\x00\x04").to_bytes(
4, "little"
)

if len(image) % 64 != 0:
image += b"\xFF" * (64 - len(image) % 64)

return image


def create_gbl_image(tags):
# All images start with an 8-byte header
tags = [(b"\xEB\x17\xA6\x03", b"\x00\x00\x00\x03\x01\x01\x00\x00")] + tags

assert all([len(tag) == 4 for tag, value in tags])
image = b"".join(
tag + len(value).to_bytes(4, "little") + value for tag, value in tags
)

# And end with a checksum
image += (
b"\xFC\x04\x04\xFC"
+ b"\x04\x00\x00\x00"
+ zlib.crc32(image + b"\xFC\x04\x04\xFC" + b"\x04\x00\x00\x00").to_bytes(
4, "little"
)
)

return image


VALID_EBL_IMAGE = create_ebl_image([(b"ab", b"foo")])
VALID_GBL_IMAGE = create_gbl_image([(b"test", b"foo")])


def create_subelement(tag_id, value):
return SubElement.deserialize(
tag_id.serialize() + len(value).to_bytes(4, "little") + value
)[0]


def test_parse_silabs_ebl():
list(validators.parse_silabs_ebl(VALID_EBL_IMAGE))

image = create_ebl_image([(b"AA", b"test"), (b"BB", b"foo" * 20)])

header, tag1, tag2, checksum = validators.parse_silabs_ebl(image)
assert len(image) % 64 == 0
assert header[0] == b"\x00\x00" and len(header[1]) == 140
assert tag1 == (b"AA", b"test")
assert tag2 == (b"BB", b"foo" * 20)
assert checksum[0] == b"\xFC\x04" and len(checksum[1]) == 4

# Padding needs to be a multiple of 64 bytes
with pytest.raises(AssertionError):
list(validators.parse_silabs_ebl(image[:-1]))

with pytest.raises(AssertionError):
list(validators.parse_silabs_ebl(image + b"\xFF"))

# Corrupted images are detected
corrupted_image = image.replace(b"foo", b"goo", 1)
assert image != corrupted_image

with pytest.raises(AssertionError):
list(validators.parse_silabs_ebl(corrupted_image))


def test_parse_silabs_gbl():
list(validators.parse_silabs_gbl(VALID_GBL_IMAGE))

image = create_gbl_image([(b"AAAA", b"test"), (b"BBBB", b"foo" * 20)])

header, tag1, tag2, checksum = validators.parse_silabs_gbl(image)
assert header[0] == b"\xEB\x17\xA6\x03" and len(header[1]) == 8
assert tag1 == (b"AAAA", b"test")
assert tag2 == (b"BBBB", b"foo" * 20)
assert checksum[0] == b"\xFC\x04\x04\xFC" and len(checksum[1]) == 4

# No padding is allowed
with pytest.raises(AssertionError):
list(validators.parse_silabs_gbl(image + b"\xFF"))

# Corrupted images are detected
corrupted_image = image.replace(b"foo", b"goo", 1)
assert image != corrupted_image

with pytest.raises(AssertionError):
list(validators.parse_silabs_gbl(corrupted_image))


def test_validate_firmware():
assert validators.validate_firmware(VALID_EBL_IMAGE) == ValidationResult.VALID
assert (
validators.validate_firmware(VALID_EBL_IMAGE[:-1]) == ValidationResult.INVALID
)
assert (
validators.validate_firmware(VALID_EBL_IMAGE + b"\xFF")
== ValidationResult.INVALID
)

assert validators.validate_firmware(VALID_GBL_IMAGE) == ValidationResult.VALID
assert (
validators.validate_firmware(VALID_GBL_IMAGE[:-1]) == ValidationResult.INVALID
)

assert validators.validate_firmware(b"UNKNOWN") == ValidationResult.UNKNOWN


def test_validate_ota_image_simple_valid():
image = OTAImage()
image.subelements = [
create_subelement(ElementTagId.UPGRADE_IMAGE, VALID_EBL_IMAGE),
]

assert validators.validate_ota_image(image) == ValidationResult.VALID


def test_validate_ota_image_complex_valid():
image = OTAImage()
image.subelements = [
create_subelement(ElementTagId.ECDSA_SIGNATURE, b"asd"),
create_subelement(ElementTagId.UPGRADE_IMAGE, VALID_EBL_IMAGE),
create_subelement(ElementTagId.UPGRADE_IMAGE, VALID_GBL_IMAGE),
create_subelement(ElementTagId.ECDSA_SIGNING_CERTIFICATE, b"foo"),
]

assert validators.validate_ota_image(image) == ValidationResult.VALID


def test_validate_ota_image_invalid():
image = OTAImage()
image.subelements = [
create_subelement(ElementTagId.UPGRADE_IMAGE, VALID_EBL_IMAGE[:-1]),
]

assert validators.validate_ota_image(image) == ValidationResult.INVALID


def test_validate_ota_image_mixed_invalid():
image = OTAImage()
image.subelements = [
create_subelement(ElementTagId.UPGRADE_IMAGE, b"unknown"),
create_subelement(ElementTagId.UPGRADE_IMAGE, VALID_EBL_IMAGE[:-1]),
]

assert validators.validate_ota_image(image) == ValidationResult.INVALID


def test_validate_ota_image_mixed_valid():
image = OTAImage()
image.subelements = [
create_subelement(ElementTagId.UPGRADE_IMAGE, b"unknown1"),
create_subelement(ElementTagId.UPGRADE_IMAGE, VALID_EBL_IMAGE),
]

assert validators.validate_ota_image(image) == ValidationResult.UNKNOWN


def test_validate_ota_image_empty():
image = OTAImage()

assert validators.validate_ota_image(image) == ValidationResult.UNKNOWN
21 changes: 18 additions & 3 deletions zigpy/ota/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from zigpy.config import CONF_OTA, CONF_OTA_DIR, CONF_OTA_IKEA, CONF_OTA_LEDVANCE
from zigpy.ota.image import ImageKey, OTAImage
import zigpy.ota.provider
from zigpy.ota.validators import ValidationResult, validate_ota_image
from zigpy.typing import ControllerApplicationType
import zigpy.util

Expand Down Expand Up @@ -69,11 +70,25 @@ async def get_ota_image(self, manufacturer_id, image_type) -> Optional[OTAImage]
return self._image_cache[key]

images = await self.async_event("get_image", key)
images = [img for img in images if img]
if not images:
valid_images = []

for image in images:
if image is None:
continue

result = validate_ota_image(image)
LOGGER.debug("Validation result for OTA image %s: %s", image, result)

if result == ValidationResult.INVALID:
LOGGER.error("OTA image %s is invalid!", image)
continue

valid_images.append(image)

if not valid_images:
return None

cached = CachedImage.new(max(images, key=lambda img: img.version))
cached = CachedImage.new(max(valid_images, key=lambda img: img.version))
self._image_cache[key] = cached
return cached

Expand Down

0 comments on commit f9d2f7c

Please sign in to comment.