Skip to content

Commit

Permalink
[SymForce-External] Assert Matrix shape is correct when copying matrix
Browse files Browse the repository at this point in the history
Assert Matrix shape is correct when copying matrix

Previously it was possible to construct a `sf.V2` with `sf.V3(sf.V2())`.
Not only is this confusing, but it can also be used to trick the type
checker into thinking a matrix is of a type it is not (for example, mypy
assumes `sf.V3(sf.V2())` is a `sf.V3`).

An example where this change is useful is, for `a = sf.M23()` and
`b = sf.M34()`, `sf.M24(a * b)`. This is because mypy cannot tell that
`a * b` has type `sf.M24`. With this change, not only does wrapping the
expression in `sf.M24` communicate this fact to mypy, but it also
performs a runtime check that it does in fact have correct shape.

In `matrix.py`, had to change several methods to use `Matrix` instead of
`self.__class__` or `cls` to construct a new matrix object when the new
matrix was not guarenteed to have the same shape.

Topic: check_shape_copy_construct
Closes #290
GitOrigin-RevId: 276157ff645a3ad88c8b616a3f409d3dda399fb1
  • Loading branch information
bradley-solliday-skydio authored and aaron-skydio committed Jan 7, 2023
1 parent af2dff0 commit f8f7eb4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
15 changes: 10 additions & 5 deletions symforce/geo/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def __new__(cls, *args: _T.Any, **kwargs: _T.Any) -> Matrix:
# 2) Construct with another Matrix - this is easy
elif len(args) == 1 and hasattr(args[0], "is_Matrix") and args[0].is_Matrix:
rows, cols = args[0].shape
if cls._is_fixed_size():
assert cls.SHAPE == (
rows,
cols,
), f"Inconsistent shape: expected shape {cls.SHAPE} but found shape {(rows, cols)}"
flat_list = list(args[0])

# 3) If there's one argument and it's an array, works for fixed or dynamic size.
Expand Down Expand Up @@ -480,10 +485,10 @@ def transpose(self) -> Matrix:
"""
Matrix Transpose
"""
return self.__class__(self.mat.transpose())
return Matrix(self.mat.transpose())

def reshape(self, rows: int, cols: int) -> Matrix:
return self.__class__(self.mat.reshape(rows, cols))
return Matrix(self.mat.reshape(rows, cols))

def dot(self, other: Matrix) -> _T.Scalar:
"""
Expand Down Expand Up @@ -590,7 +595,7 @@ def __getitem__(self, item: _T.Any) -> _T.Any:
"""
ret = self.mat.__getitem__(item)
if isinstance(ret, sf.sympy.Matrix):
ret = self.__class__(ret)
ret = Matrix(ret)
return ret

def __setitem__(
Expand Down Expand Up @@ -662,9 +667,9 @@ def __mul__(
if typing_util.scalar_like(right):
return self.applyfunc(lambda x: x * right)
elif isinstance(right, Matrix):
return self.__class__(self.mat * right.mat)
return Matrix(self.mat * right.mat)
else:
return self.__class__(self.mat * right)
return Matrix(self.mat * right)

@_T.overload
def __rmul__(
Expand Down
1 change: 1 addition & 0 deletions test/geo_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_construction(self) -> None:
# 2) Matrix(sf.sympy.Matrix([[1, 2], [3, 4]])) # Matrix22 with [1, 2, 3, 4] data
self.assertIsInstance(sf.M(sf.sympy.Matrix([[1, 2], [3, 4]])), sf.M22)
self.assertEqual(sf.M(sf.sympy.Matrix([[1, 2], [3, 4]])), sf.M([[1, 2], [3, 4]]))
self.assertRaises(AssertionError, lambda: sf.V3(sf.V2()))

# 3A) Matrix([[1, 2], [3, 4]]) # Matrix22 with [1, 2, 3, 4] data
self.assertIsInstance(sf.M([[1, 2], [3, 4]]), sf.M22)
Expand Down

0 comments on commit f8f7eb4

Please sign in to comment.