From dbe6cd61502426cb57df56e46c58591c76361e9a Mon Sep 17 00:00:00 2001 From: Daniel Moreno Manzano Date: Tue, 18 Nov 2025 10:23:25 +0100 Subject: [PATCH] feat: Support all stroke versions --- py_path_signature/data_models/stroke.py | 35 +++++++++++++++++-- py_path_signature/path_signature_extractor.py | 6 ++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/py_path_signature/data_models/stroke.py b/py_path_signature/data_models/stroke.py index 0656fac..6991253 100644 --- a/py_path_signature/data_models/stroke.py +++ b/py_path_signature/data_models/stroke.py @@ -1,10 +1,21 @@ +from abc import abstractmethod +from typing import List, Tuple, Annotated, Iterable import math from py_path_signature.data_models.basic import BasicModel from py_path_signature.data_models.error_messages import ERROR_MESSAGES -from pydantic import root_validator, validator +from pydantic import root_validator, validator, Field from pydantic.types import conlist +class Stroke(BasicModel): + + @abstractmethod + def x_generator(self) -> Iterable[float]: + pass + + @abstractmethod + def y_generator(self) -> Iterable[float]: + pass class StrokeFormatError(Exception): """Custom error that is raised when an input stroke doesn't have the right format.""" @@ -14,7 +25,7 @@ def __init__(self, message: str) -> None: super().__init__(message) -class Stroke(BasicModel): +class StrokeDict(Stroke): x: conlist(float, min_items=1) y: conlist(float, min_items=1) @@ -35,3 +46,23 @@ def check_value_is_nan(cls, value): if math.isnan(value): raise StrokeFormatError(message=ERROR_MESSAGES["NAN_DETECTED"]) return value + + def x_generator(self) -> Iterable[float]: + for i in self.x: + yield i + + def y_generator(self) -> Iterable[float]: + for i in self.y: + yield i + +class StrokeTuple(Stroke): + + coordinates: Annotated[List[Tuple[float, float]], Field(min_length=2)] + + def x_generator(self) -> Iterable[float]: + for (x,y) in self.coordinates: + yield x + + def y_generator(self) -> Iterable[float]: + for (x,y) in self.coordinates: + yield y diff --git a/py_path_signature/path_signature_extractor.py b/py_path_signature/path_signature_extractor.py index eb4d275..12c940d 100644 --- a/py_path_signature/path_signature_extractor.py +++ b/py_path_signature/path_signature_extractor.py @@ -109,8 +109,8 @@ def calculate_bounding_box(strokes: List[Stroke]) -> Tuple[int, int, int, int]: if len(strokes) == 0: raise Exception("Empty list of strokes.") - x_values = [x for stroke in strokes for x in stroke.x] - y_values = [y for stroke in strokes for y in stroke.y] + x_values = [x for stroke in strokes for x in stroke.x_generator()] + y_values = [y for stroke in strokes for y in stroke.y_generator()] x_min = min(x_values) x_max = max(x_values) @@ -140,7 +140,7 @@ def from_coordinates_to_pixels(self, strokes: List[Stroke]) -> List[List[Tuple[i pixels = [] for stroke in strokes: out_stroke = [] - for (x, y) in zip(stroke.x, stroke.y): + for (x, y) in zip(stroke.x_generator(), stroke.y_generator()): x = (self.rendering_size[1] - 1) * (x - x_min) / w y = (self.rendering_size[0] - 1) * (y - y_min) / h