In [11]:
import torch
import my_code.mysionna as sn

def matrix_sqrt(tensor):
    r"""Compute the square root of a batch of Hermitian PSD matrices.

    Given a batch of Hermitian positive semi-definite matrices
    :math:`\mathbf{A}`, returns matrices :math:`\mathbf{B}` such that
    :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`. The result is built from the
    eigendecomposition :math:`\mathbf{A} = \mathbf{U}\mathbf{S}\mathbf{U}^H`
    as :math:`\mathbf{B} = \mathbf{U}\mathbf{S}^{1/2}\mathbf{U}^H`, so
    :math:`\mathbf{B}` is itself Hermitian and also satisfies
    :math:`\mathbf{B}\mathbf{B} = \mathbf{A}`.

    The last two dimensions are the matrix rows and columns, respectively;
    any leading dimensions are treated as batch dimensions.

    Args:
        tensor ([..., M, M]): Real symmetric or complex Hermitian tensor of
            rank greater than or equal to two. ``torch.linalg.eigh`` only
            reads the lower triangle of each matrix.

    Returns:
        A tensor of the same shape and dtype as ``tensor`` containing the
        matrix square root of its last two dimensions.
    """
    # A single code path: the previous version branched on a config flag but
    # both branches were identical, and the guard called the non-existent
    # ``Tensor.is_grad_enabled`` (raising AttributeError when the flag was set).
    eigvals, eigvecs = torch.linalg.eigh(tensor)

    # Round-off can make eigenvalues of a PSD matrix slightly negative; taking
    # the magnitude keeps the square root real. For an indefinite input this
    # means the result is not a true square root.
    root = torch.sqrt(torch.abs(eigvals)).to(eigvecs.dtype)

    # U diag(sqrt(s)) U^H: scale the columns of U by broadcasting over the
    # row axis instead of materialising the diagonal matrix.
    return torch.matmul(
        eigvecs * root.unsqueeze(-2),
        eigvecs.conj().transpose(-2, -1),
    )
# Example usage: draw a random 3x3 matrix X, form the symmetric positive
# semi-definite matrix A = X X^T, and check that matrix_sqrt(A) multiplied
# by its transpose reproduces A.
base = torch.randn(3, 3, dtype=torch.float64)
tensor = torch.matmul(base, base.transpose(0, 1))  # X X^T is always PSD

sqrt_tensor = matrix_sqrt(tensor)
print("Original tensor:\n", tensor)
print("Square root of the tensor:\n", sqrt_tensor)
print(
    "Product of sqrt and its transpose:\n",
    torch.matmul(sqrt_tensor, sqrt_tensor.transpose(0, 1)),
)

Original tensor:
 tensor([[ 0.1571, -0.6021, -0.6281],
        [-0.6021,  4.9617,  2.9200],
        [-0.6281,  2.9200,  4.3224]], dtype=torch.float64)
Square root of the tensor:
 tensor([[ 0.2683, -0.1866, -0.2242],
        [-0.1866,  2.1024,  0.7119],
        [-0.2242,  0.7119,  1.9405]], dtype=torch.float64)
Product of sqrt and its transpose:
 tensor([[ 0.1571, -0.6021, -0.6281],
        [-0.6021,  4.9617,  2.9200],
        [-0.6281,  2.9200,  4.3224]], dtype=torch.float64)
