Skip to content

Commit

Permalink
[Fix] Fix the potential NaN bug in calc_square_dist() (open-mmlab#2356)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZCMax authored and akozlov-outrider committed May 8, 2023
1 parent b121576 commit a7ca3f4
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions mmcv/ops/points_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,11 @@ def calc_square_dist(point_feat_a: Tensor,
torch.Tensor: (B, N, M) Square distance between each point pair.
"""
num_channel = point_feat_a.shape[-1]
# [bs, n, 1]
a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
# [bs, 1, m]
b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)

corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))

dist = a_square + b_square - 2 * corr_matrix
dist = torch.cdist(point_feat_a, point_feat_b)
if norm:
dist = torch.sqrt(dist) / num_channel
dist = dist / num_channel
else:
dist = torch.square(dist)
return dist


Expand Down

0 comments on commit a7ca3f4

Please sign in to comment.