In [71]:
import torch as tn

def _LU(M):
    """
    Perform an LU decomposition and returns L, U and a permutation vector P. 

    Args:
        M (torch.tensor): [description]

    Returns:
        tuple[torch.tensor,torch.tensor,torch.tensor]: L, U, P
    """
    LU, P = tn.linalg.lu_factor(M)

    # LU decomposition Permutation, Lower, Upper
    P, L, U = tn.lu_unpack(LU, P)  # P transpose or not transpose?
    P = P@tn.reshape(tn.arange(P.shape[1],
                     dtype=P.dtype, device=P.device), [-1, 1])

    return L, U, tn.squeeze(P).to(tn.int64)

In [72]:
M = tn.rand(10, 5)

L, U, P = _LU(M)

In [73]:
P_new = P @ tn.reshape(tn.arange(P.shape[1],
                    dtype=P.dtype, device=P.device), [-1, 1])

IndexError: tuple index out of range

In [74]:
P

tensor([5, 3, 7, 1, 2, 0, 6, 4, 8, 9])

In [75]:
P_new

tensor([[9],
        [6],
        [1],
        [4],
        [7],
        [5],
        [2],
        [3],
        [8],
        [0]])

In [76]:
tn.allclose((P @ L @ U), M)

RuntimeError: expected m1 and m2 to have the same dtype, but got: long int != float

In [77]:
(P @ L @ U), M

RuntimeError: expected m1 and m2 to have the same dtype, but got: long int != float

In [78]:
def _max_matrix(M):

    # Maximum element along dimention k
    values, indices = M.flatten().topk(1)
    try:
        # Return the actual index
        indices = [tn.unravel_index(i, M.shape) for i in indices]
    except:
        indices = [np.unravel_index(i, M.shape) for i in indices]

    return values, indices

In [79]:
M = tn.rand(10, 5)

values, indices = M.flatten().topk(1)
indices = [tn.unravel_index(i, M.shape) for i in indices]

values, indices

(tensor([0.9047]), [(tensor(1), tensor(3))])

In [99]:
def _LU(M):
    """
    Perform an LU decomposition and returns L, U and a permutation vector P. 

    Args:
        M (torch.tensor): [description]

    Returns:
        tuple[torch.tensor,torch.tensor,torch.tensor]: L, U, P
    """
    LU, P = tn.linalg.lu_factor(M)
    P, L, U = tn.lu_unpack(LU, P)  # P transpose or not transpose?
    P = P@tn.reshape(tn.arange(P.shape[1],
                     dtype=P.dtype, device=P.device), [-1, 1])
    # P = tn.reshape(tn.arange(P.shape[1],dtype=P.dtype,device=P.device),[1,-1]) @ P

    return L, U, tn.squeeze(P).to(tn.int64)


def _max_matrix(M):

    values, indices = M.flatten().topk(1)
    try:
        indices = [tn.unravel_index(i, M.shape) for i in indices]
    except:
        indices = [np.unravel_index(i, M.shape) for i in indices]

    return values, indices


# Max volume submatrix
def _maxvol(M):
    """
    Maxvol

    Args:
        M (torch.tensor): input matrix.

    Returns:
        torch.tensor: indices of tha maxvol submatrix.
    """

    if M.shape[1] >= M.shape[0]:
        # more cols than row -> return all the row indices
        idx = tn.tensor(range(M.shape[0]), dtype=tn.int64)
        return idx
    else:
        L, U, P = _LU(M)
        idx = P[:M.shape[1]]

    Msub = M[idx, :]

    Mat = tn.linalg.solve(Msub.T, M.T).t()
    print(Mat.shape, Msub.shape, M.shape)
    
    for i in range(100):
        val_max, idx_max = _max_matrix(tn.abs(Mat))
        print(idx)
        idx_max = idx_max[0]
        if val_max <= 1+5e-2:
            print("sort")
            print(idx)
            idx = tn.sort(idx)[0]
            return idx
        Mat += tn.outer(Mat[:, idx_max[1]], Mat[idx[idx_max[1]]] -
                        Mat[idx_max[0], :])/Mat[idx_max[0], idx_max[1]]
        idx[idx_max[1]] = idx_max[0]
    return idx

In [100]:
M = tn.rand(10, 5)

_maxvol(M)

torch.Size([10, 5]) torch.Size([5, 5]) torch.Size([10, 5])
tensor([6, 0, 1, 2, 8])
tensor([6, 3, 1, 2, 8])
tensor([6, 3, 9, 2, 8])
tensor([7, 3, 9, 2, 8])
sort
tensor([7, 3, 9, 2, 8])


tensor([2, 3, 7, 8, 9])