Skip to content

Commit

Permalink
Minor vmap fix in SO2 (facebookresearch#362)
Browse files Browse the repository at this point in the history
* Changed SO2._rotate_from_cos_sin to use cosine rather than point for creating new batched tensor.

* Changed SO2._rotate_from_cos_sin to avoid in-place indexing.
  • Loading branch information
luisenp committed Nov 16, 2022
1 parent 64c4171 commit edf7f89
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions theseus/geometry/so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def _rotate_from_cos_sin(
cosine: torch.Tensor,
sine: torch.Tensor,
) -> Point2:
batch_size = max(point.shape[0], cosine.shape[0])
if isinstance(point, torch.Tensor):
if point.ndim != 2 or point.shape[1] != 2:
raise ValueError(
Expand All @@ -260,10 +259,11 @@ def _rotate_from_cos_sin(
else:
point_tensor = point.tensor
px, py = point_tensor[:, 0], point_tensor[:, 1]
new_point_tensor = point_tensor.new_empty(batch_size, 2)
new_point_tensor[:, 0] = cosine * px - sine * py
new_point_tensor[:, 1] = sine * px + cosine * py
return Point2(tensor=new_point_tensor)
return Point2(
tensor=torch.stack(
[cosine * px - sine * py, sine * px + cosine * py], dim=1
)
)

def rotate(
self,
Expand Down

0 comments on commit edf7f89

Please sign in to comment.