In [6]:
import torch
ala_lit_coords=torch.Tensor([[[[-0.525, 1.363, 0.000],[0.0000,0.000,0.000],[1.526,0.000,0.000]]]]).float()
ala_lit_coords.shape

torch.Size([1, 1, 3, 3])

In [7]:
import torch
from openfold.utils.rigid_utils import Rotation, Rigid
def get_bb_frames(coords):
    """
    Returns a local rotation frame defined by N, CA, C positions.
    Args:
        coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
        where the third dimension is in order of N, CA, C
    Returns:
        Local relative rotation frames in shape (batch_size x length x 3 x 3)
        Local translation in shape (batch_size x length x 3)
    """
    v1 = coords[:, :, 2] - coords[:, :, 1]
    v2 = coords[:, :, 0] - coords[:, :, 1]
    e1 = normalize(v1, dim=-1)  # [B, L, 3]
    u2 = v2 - e1 * torch.sum(e1 * v2, dim=-1, keepdim=True)
    e2 = normalize(u2, dim=-1)  # [B, L, 3]
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack([e1, e2, e3], dim=-1)
    t = coords[:, :, 1]  # translation is just the CA atom coordinate
    return Rigid(Rotation(R), t)

def norm(tensor, dim, eps=1e-8, keepdim=False):
    """
    Returns L2 norm along a dimension.
    """
    return torch.sqrt(
        torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim) + eps)


def normalize(tensor, dim=-1):
    """
    Normalizes a tensor along a dimension after removing nans.
    """
    return nan_to_num(
        torch.div(tensor, norm(tensor, dim=dim, keepdim=True))
    )


def nan_to_num(ts, val=0.0):
    """
    Replaces nans in tensor with a fixed value.    
    """
    val = torch.tensor(val, dtype=ts.dtype, device=ts.device)
    return torch.where(~torch.isfinite(ts), val, ts)

In [8]:
frame1=Rigid.from_3_points(
                            p_neg_x_axis=ala_lit_coords[..., 2, :],
                            origin=ala_lit_coords[...,1,:],
                            p_xy_plane=ala_lit_coords[...,0,:],
                            eps=1e-4
                            )
frame1._rots._rot_mats

tensor([[[[-9.9998e-01, -1.6442e-05,  0.0000e+00],
          [ 0.0000e+00,  9.9997e-01,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -9.9995e-01]]]])

In [9]:
frame2 = get_bb_frames(ala_lit_coords)
frame2._rots._rot_mats

tensor([[[[1., 0., 0.],
          [0., 1., 0.],
          [0., 0., 1.]]]])