In [None]:
def column_optimal_swap(m: Mat2) -> Dict[int,int]:
    """Given a matrix m, tries to find a permutation of the columns such that
    there are as many ones on the diagonal as possible.
    This reduces the number of row operations needed to do Gaussian elimination.

    Args:
        m: the matrix. Only ``m.rows()``, ``m.cols()`` and the entries
            ``m.data[i][j]`` are read; a truthy entry counts as a one.

    Returns:
        A permutation of ``range(m.cols())`` as a dict ``{column: position}``.
        Columns that the search paired with a row ``i`` (where
        ``m.data[i][j]`` is truthy) get position ``i``, so moving column ``j``
        to index ``i`` puts a one on the diagonal. The remaining columns fill
        the remaining positions in increasing order.
    """
    r, c = m.rows(), m.cols()
    # Row -> columns and column -> rows adjacency of the nonzero entries.
    connections:  Dict[int,Set[int]] = {i: set() for i in range(r)}
    connectionsr: Dict[int,Set[int]] = {j: set() for j in range(c)}
    for i in range(r):
        for j in range(c):
            if m.data[i][j]:
                connections[i].add(j)
                connectionsr[j].add(i)

    found = _find_targets(connections, connectionsr) or {}

    # Keep only pairings that are usable as column positions, and keep them
    # injective. A row index >= c (tall matrices) is not a position in a
    # permutation of the c columns. A repeated row would send two columns to
    # the same place and leave more free positions than free columns.
    target: Dict[int,int] = {}
    used_positions: Set[int] = set()
    for col, row in found.items():
        if row < c and row not in used_positions:
            target[col] = row
            used_positions.add(row)

    # Both lists have length c - len(target), so zip pairs them exactly.
    free_columns = [j for j in range(c) if j not in target]
    free_positions = [p for p in range(c) if p not in used_positions]
    for col, pos in zip(free_columns, free_positions):
        target[col] = pos
    return target

def _find_targets(
        conn: Dict[int,Set[int]],
        connr: Dict[int,Set[int]],
        target:Dict[int,int]={}
        ) -> Optional[Dict[int,int]]:
    """Helper function for :func:`column_optimal_swap`.
    Recursively makes a choice for a permutation that places additional ones on the diagonal.
    Backtracks when it gets stuck in an unfavorable configuration."""
    target = target.copy()
    r = len(conn)
    c = len(connr)

    claimedcols = set(target.keys())
    claimedrows = set(target.values())

    while True:
        min_index = -1
        min_options = set(range(1000))
        for i in range(r):
            if i in claimedrows: continue
            s = conn[i] - claimedcols # The free columns
            if len(s) == 1:
                j = s.pop()
                target[j] = i
                claimedcols.add(j)
                claimedrows.add(i)
                break
            if len(s) == 0: return None # contradiction
            found_col = False
            for j in s:
                t = connr[j] - claimedrows
                if len(t) == 1: # j can only be connected to i
                    target[j] = i
                    claimedcols.add(j)
                    claimedrows.add(i)
                    found_col = True
                    break
            if found_col: break
            if len(s) < len(min_options):
                min_index = i
                min_options = s
        else: # Didn't find any forced choices
            if not (conn.keys() - claimedrows): # we are done
                return target
            if min_index == -1: raise ValueError("This shouldn't happen ever")
            # Start depth-first search
            tgt = target.copy()
            #print("backtracking on", min_index)
            for i2 in min_options:
                #print("trying option", i2)
                tgt[i2] = min_index
                new_target = _find_targets(conn, connr, tgt)
                if new_target: return new_target
            #print("Unsuccessful")
            return target

In [None]:
# `import Mat2` would bind a module object, and calling a module raises
# TypeError; the matrix class these helpers expect lives in pyzx.linalg.
from pyzx.linalg import Mat2

m = Mat2([[1,0,0,1],
          [0,1,1,0],
          [1,0,0,0],
          [0,1,1,1]])