In [7]:
import torch
from einops import repeat
import numpy as np

def remove_diagonal(x):
    rem_dims = x.shape[:-2]
    n = x.shape[-1]
    x = x.flatten(start_dim=-2)[...,1:].view(*rem_dims,n-1, n+1)[...,:-1].reshape(*rem_dims,n, n-1)
    return x

def pairwise_vectors(x,y=None,box=None):
    if y is None:
        y = x
    print(x.shape)
    print(y.shape)
    pair_vec = (x.unsqueeze(-2) - y.unsqueeze(-3))
    if box is not None:
        #unsqueeze and expand the dimensions of box to match the dimensions of pair_vec
        while len(box.shape) < len(pair_vec.shape):
            box = box.unsqueeze(0)
        box = box.expand_as(pair_vec)
        pair_vec = pair_vec - box*torch.round(pair_vec/box)
    return pair_vec

def point_point_distance(x,y=None,box=None,remove_diag=False):
    pair_vec = pairwise_vectors(x,y,box)
    distances = torch.linalg.norm(pair_vec, axis=-1)
    if y is None and remove_diag:
        distances =  remove_diagonal(distances)
    return distances

def dot_along_last_dim(x,y):
            return torch.einsum("...i,...i->...",x,y)

def point_segment_distance(x, a, b, box=None):
    # compute the distances between all points in x and all segments defined by b-a
    # x: (n_points, dim)
    # a,b : (n_segments, dim)
    ab = b - a
    ab_norm = torch.linalg.norm(ab, axis=-1)
    #ap: (n_points, n_segments, dim)
    ap = pairwise_vectors(x,a,box)
    ab = ab/ab_norm.unsqueeze(-1)
    ab = ab.unsqueeze(-3).expand(ap.shape)
    ap_ab = dot_along_last_dim(ap,ab)
    ap_ab = torch.clamp(ap_ab, torch.zeros_like(ap_ab), ab_norm.unsqueeze(-2))
    dist = torch.linalg.norm(ap - ap_ab[...,None]*ab, axis=-1)
    return dist

def pbc(x, box=None):
    if box is None:
        return x
    while len(box.shape) < len(x.shape):
        box = box.unsqueeze(0)
    return x - box*torch.round(x/box)

def orientation(p, q, r):
    val = (q[...,1] - p[...,1]) * (r[...,0] - q[...,0]) - (q[...,0] - p[...,0]) * (r[...,1] - q[...,1])
    return torch.where(val == 0, 0, torch.where(val > 0, 1, 2))

def all_pairs(x,y=None):
    if y is None:
        y = x
    n = x.shape[-2]
    m = y.shape[-2]
    x = repeat(x,"... dim -> ... m dim",m=m)
    y = repeat(y,"... m dim -> ... n m dim",n=n,m=m)
    return x,y

def segment_segment_distance(p1,q1,p2,q2,box=None, remove_diag=False):
    # compute all the point-segment distances between the two sets of segments
    # p1,q1: (n_segments, dim)
    # p2,q2: (n_segments_2, dim)
    #return (n_segments, n_segments_2)
    #stack all pairs of segments
    assert (p1.shape == q1.shape) & (p2.shape == q2.shape)
    p1_2 = point_segment_distance(p1,p2,q2,box)
    q1_2 = point_segment_distance(q1,p2,q2,box)
    p2_1 = point_segment_distance(p2,p1,q1,box)
    q2_1 = point_segment_distance(q2,p1,q1,box)
    #check if segments intersect
    p1,p2 = all_pairs(p1,p2)
    q1,q2 = all_pairs(q1,q2)
    #print(p1[0,0],p2[0,0],q1[0,0],q2[0,0])
    o1 = orientation(p1, q1, p2)
    o2 = orientation(p1, q1, q2)
    o3 = orientation(p2, q2, p1)
    o4 = orientation(p2, q2, q1)
    if_intersect = (o1!=o2)*(o3!=o4)
    dist = torch.min(torch.min(p1_2,q1_2),torch.min(p2_1,q2_1).transpose(-1,-2))
    
    #return (n_segments, n_segments_2)
    dist[if_intersect] = 0
    if remove_diag:
       dist = remove_diagonal(dist)
    return dist

In [11]:
def get_endpoints(x,theta,len):
    segments = len*torch.stack([torch.cos(theta), torch.sin(theta)], axis=-1)
    return torch.stack([x + 0.5*segments,x - 0.5*segments], axis=0)

In [14]:
x = torch.randn(2,2)
x.requires_grad = True
theta = torch.randn(2)*torch.tensor(np.pi)
theta.requires_grad = True
len = 3
endpts = get_endpoints(x,theta,len)
print(endpts.shape)
distances = segment_segment_distance(endpts[0],endpts[1],endpts[0],endpts[1],box=None,remove_diag=True)
spring_energy = 0.5*distances**2
spring_energy.sum().backward()
print(x.grad,theta.grad)

torch.Size([2, 2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
tensor([[-0.0904, -2.4215],
        [ 0.0904,  2.4215]]) tensor([ 2.0380, -2.9101])
