diff --git a/CHANGELOG.md b/CHANGELOG.md index d16b65a..95083a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The Runway Model SDK follows [semantic versioning](https://semver.org/). Be awar Until version 1.0.0, expect that minor version changes may introduce breaking changes. We will take care not to introduce new behavior, features, or breaking changes in patch releases. If you require stability and reproducible behavior you *may* pin to a version or version range of the model SDK like `runway-python>=0.2.0` or `runway-python>=0.2,<0.3`. +## v.0.3.2 + +- Make segmentation serialize as a 3-channel color map when used as an output field, instead of a 1-channel label map. + ## v.0.3.1 - Remove default values for `min`, `max`, and `step` parameters of `number` data type. diff --git a/runway/__version__.py b/runway/__version__.py index e1424ed..73e3bb4 100644 --- a/runway/__version__.py +++ b/runway/__version__.py @@ -1 +1 @@ -__version__ = '0.3.1' +__version__ = '0.3.2' diff --git a/runway/data_types.py b/runway/data_types.py index dda5446..65cb600 100644 --- a/runway/data_types.py +++ b/runway/data_types.py @@ -485,12 +485,12 @@ class segmentation(BaseType): different object class. When used as an input data type, `segmentation` accepts a 1-channel base64-encoded PNG image, - where each pixel takes the value of one of the ids defined in `pixel_to_id`, or a 3-channel + where each pixel takes the value of one of the ids defined in `label_to_id`, or a 3-channel base64-encoded PNG colormap image, where each pixel takes the value of one of the colors - defined in `pixel_to_color`. + defined in `label_to_color`. - When used as an output data type, it serializes as a 1-channel base64-encoded PNG image, - where each pixel takes the value of one of the ids defined in `pixel_to_id`. + When used as an output data type, it serializes as a 3-channel base64-encoded PNG image, + where each pixel takes the value of one of the colors defined in `label_to_color`. .. code-block:: python @@ -564,6 +564,14 @@ def colormap_to_segmentation(self, img): seg[(cmap==color).all(axis=2)] = label_id return Image.fromarray(seg, 'L') + def segmentation_to_colormap(self, img): + seg = np.array(img) + cmap = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, id in self.label_to_id.items(): + label_color = self.label_to_color[label] + cmap[(seg==id)] = label_color + return Image.fromarray(cmap, 'RGB') + def deserialize(self, value): try: image = value[value.find(",")+1:] @@ -585,6 +593,8 @@ def serialize(self, value): im_pil = value else: raise InvalidArgumentError(self.name, 'value is not a PIL or numpy image') + if im_pil.mode == 'L': + im_pil = self.segmentation_to_colormap(im_pil) buffer = IO() im_pil.save(buffer, format='PNG') return 'data:image/png;base64,' + base64.b64encode(buffer.getvalue()).decode('utf8') diff --git a/tests/test_data_types.py b/tests/test_data_types.py index 3e12008..c4c9a8c 100644 --- a/tests/test_data_types.py +++ b/tests/test_data_types.py @@ -410,9 +410,16 @@ def test_segmentation_to_dict(): assert obj['labelToColor'] == {"background": [0, 0, 0], "person": [140, 59, 255]} assert obj['description'] == None -def test_segmentation_serialize_and_deserialize(): +def test_segmentation_serialize_and_deserialize_colormap(): directory = os.path.dirname(os.path.realpath(__file__)) - img = Image.open(os.path.join(directory, 'test_segmentation.png')) + img = Image.open(os.path.join(directory, 'test_segmentation_colormap.png')) + serialized_pil = segmentation(label_to_id={"background": 0, "person": 1}).serialize(img) + deserialized_pil = segmentation(label_to_id={"background": 0, "person": 1}).deserialize(serialized_pil) + assert issubclass(type(deserialized_pil), Image.Image) + +def test_segmentation_serialize_and_deserialize_labelmap(): + directory = os.path.dirname(os.path.realpath(__file__)) + img = Image.open(os.path.join(directory, 'test_segmentation_labelmap.png')) serialized_pil = segmentation(label_to_id={"background": 0, "person": 1}).serialize(img) deserialized_pil = segmentation(label_to_id={"background": 0, "person": 1}).deserialize(serialized_pil) assert issubclass(type(deserialized_pil), Image.Image) diff --git a/tests/test_segmentation.png b/tests/test_segmentation.png deleted file mode 100644 index 7917519..0000000 Binary files a/tests/test_segmentation.png and /dev/null differ diff --git a/tests/test_segmentation_colormap.png b/tests/test_segmentation_colormap.png new file mode 100644 index 0000000..8fbd082 Binary files /dev/null and b/tests/test_segmentation_colormap.png differ diff --git a/tests/test_segmentation_labelmap.png b/tests/test_segmentation_labelmap.png new file mode 100644 index 0000000..c22dd3c Binary files /dev/null and b/tests/test_segmentation_labelmap.png differ