In [11]:
import torch

def chamfer_distance(S1, S2):
    """
    Compute the Chamfer distance between two sets of 3D points.
    
    Parameters:
    S1: torch.Tensor of shape (1, 1024, 3)
    S2: torch.Tensor of shape (1, 1024, 3)
    
    Returns:
    float: Chamfer distance
    """
    # Expand the dimensions of S1 and S2 for broadcasting
    S1_expand = S1.unsqueeze(2)  # (1, 1024, 1, 3)
    S2_expand = S2.unsqueeze(1)  # (1, 1, 1024, 3)
    # print("this is expanding S1", S1)
    # print("this is expanding S2", S2)

    # Compute pairwise distances between points in S1 and S2
    dists = torch.norm(S1_expand - S2_expand, dim=-1)  # (1, 1024, 1024)
    # print("this is subtract operation result", (S1_expand - S2_expand))
    # print("this is dist", dists)
    # For each point in S1, find the nearest point in S2
    min_dist_S1_to_S2, _ = torch.min(dists, dim=2)  # (1, 1024)
    
    # For each point in S2, find the nearest point in S1
    min_dist_S2_to_S1, _ = torch.min(dists, dim=1)  # (1, 1024)
    
    # Sum of minimum distances
    chamfer_dist = min_dist_S1_to_S2.mean() + min_dist_S2_to_S1.mean()
    
    return chamfer_dist.item()

# Example usage
S1 = torch.randn(1, 2, 3)  # Replace with your actual tensor
S2 = torch.randn(1, 2, 3)  # Replace with your actual tensor

S1 = [[0, 0, 0], [1,1,1]]
S2 = [[1,1,1], [2,2,2]]

S1_tensor = torch.tensor(S1, dtype=torch.float32).unsqueeze(0)  # Shape: (1, 2, 3)
S2_tensor = torch.tensor(S2, dtype=torch.float32).unsqueeze(0)  # Shape: (1, 2, 3)

print("S1 Tensor:", S1_tensor)
print("S2 Tensor:", S2_tensor)

dist = chamfer_distance(S1_tensor, S2_tensor)
print(f'Chamfer Distance: {dist}')


S1 Tensor: tensor([[[0., 0., 0.],
         [1., 1., 1.]]])
S2 Tensor: tensor([[[1., 1., 1.],
         [2., 2., 2.]]])
Chamfer Distance: 1.7320507764816284


In [14]:
import torch

def chamfer_distance(S1, S2):
    """
    Compute the Chamfer distance between two sets of 3D points.
    
    Parameters:
    S1: torch.Tensor of shape (1, N, 3) with requires_grad=True
    S2: torch.Tensor of shape (1, N, 3) with requires_grad=True
    
    Returns:
    torch.Tensor: Chamfer distance with requires_grad=True
    """
    # Expand the dimensions of S1 and S2 for broadcasting
    print(f'S1.requires_grad: {S1.requires_grad}')
    print(f'S2.requires_grad: {S2.requires_grad}')
    S1_expand = S1.unsqueeze(2)  # Shape: (1, N, 1, 3)
    S2_expand = S2.unsqueeze(1)  # Shape: (1, 1, N, 3)

    # Compute pairwise distances between points in S1 and S2
    dists = torch.norm(S1_expand - S2_expand, dim=-1)  # Shape: (1, N, N)
    
    # For each point in S1, find the nearest point in S2
    min_dist_S1_to_S2, _ = torch.min(dists, dim=2)  # Shape: (1, N)
    
    # For each point in S2, find the nearest point in S1
    min_dist_S2_to_S1, _ = torch.min(dists, dim=1)  # Shape: (1, N)
    
    # Sum of minimum distances
    chamfer_dist = min_dist_S1_to_S2.mean() + min_dist_S2_to_S1.mean()
    
    return chamfer_dist

# Example usage with backpropagation
# Generate some example data

S1 = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32).unsqueeze(0)  # Shape: (1, 2, 3)
S2 = torch.tensor([[1, 1, 1], [2, 2, 2]], dtype=torch.float32).unsqueeze(0)  # Shape: (1, 2, 3)

# Ensure the tensors require gradients
S1.requires_grad_(True)
S2.requires_grad_(True)

# Compute Chamfer Distance
loss = chamfer_distance(S1, S2)

# Perform backpropagation
loss.backward()

# Check the gradients
print(S1.grad)
print(S2.grad)


S1.requires_grad: True
S2.requires_grad: True
S1.requires_grad: True
S2.requires_grad: True
tensor([[[-0.2887, -0.2887, -0.2887],
         [-0.2887, -0.2887, -0.2887]]])
tensor([[[0.2887, 0.2887, 0.2887],
         [0.2887, 0.2887, 0.2887]]])
