In [13]:
# What am I solving
# A shortest augmenting path algorithm for rectangular assignment problem
import numpy as np
from scipy.optimize import linear_sum_assignment
# Each cost[i, j] is the cost of matching vertex i of the first partite set
# (a 'worker') with vertex j of the second partite set (a 'job').
# Goal: find an assignment of workers to jobs with minimal total cost.

# Reference problem: 3 workers (rows) x 3 jobs (columns).
cost = np.array([[4, 1, 3],
                 [2, 0, 5],
                 [3, 2, 2]])

# SciPy's solver provides the expected assignment for this matrix.
row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)

# Total cost of that assignment; left as the last expression so the cell displays it.
cost[row_ind, col_ind].sum()

[0 1 2] [1 0 2]


np.int64(5)

In [12]:
from numba import cuda
from numba import int32, float32, boolean
import numpy as np
import cupy as cp

@cuda.jit(device=True)
def augmenting_path(
    nc, cost, u, v, path, row4col,
    shortestPathCosts, i,
    SR, SC,
    remaining, p_minVal
):
    """Search for a shortest augmenting path that starts at row ``i``.

    Parameters
    ----------
    nc : number of columns; ``cost`` is indexed as ``cost[row * nc + col]``.
    cost : flat cost buffer.
    u, v : row / column dual values (only read here).
    path : ``path[col]`` is set to the row from which ``col`` was reached
        at the lowest reduced cost seen so far.
    row4col : ``row4col[col]`` is the row matched to ``col``, or -1 if free.
    shortestPathCosts : scratch; reset to +inf, then updated per column.
    i : row the search starts from.
    SR, SC : scratch flags; cleared, then set True for every row / column
        visited during the search.
    remaining : scratch list of columns that have not been visited yet.
    p_minVal : one-element array; ``p_minVal[0]`` receives the cost of the
        path when a sink is found (it is left untouched on failure).

    Returns
    -------
    The free column where the path ends, or -1 if every remaining column
    has infinite cost.
    """
    # Queue the columns in descending order: nc-1, nc-2, ..., 0.
    unscanned = nc
    for slot in range(nc):
        remaining[slot] = nc - slot - 1

    SR[:] = False
    SC[:] = False
    shortestPathCosts[:] = np.inf

    row = i
    dist = 0.0
    sink = -1
    while sink == -1:
        SR[row] = True
        best_slot = -1
        best_cost = np.inf

        for slot in range(unscanned):
            col = remaining[slot]
            reduced = dist + cost[row * nc + col] - u[row] - v[col]
            if reduced < shortestPathCosts[col]:
                path[col] = row
                shortestPathCosts[col] = reduced

            candidate = shortestPathCosts[col]
            # On a tie, prefer a column that is still free: picking it ends
            # the search immediately because it becomes the sink.
            if candidate < best_cost or (candidate == best_cost and row4col[col] == -1):
                best_cost = candidate
                best_slot = slot

        dist = best_cost
        if dist == np.inf:
            # Nothing reachable at finite cost.
            return -1

        col = remaining[best_slot]
        if row4col[col] == -1:
            sink = col
        else:
            # Column is taken: continue the search from the row that owns it.
            row = row4col[col]

        # Mark the column visited and drop it from the queue by overwriting
        # its slot with the last queued entry.
        SC[col] = True
        unscanned -= 1
        remaining[best_slot] = remaining[unscanned]

    p_minVal[0] = dist
    return sink

@cuda.jit(device=True)
def solve(nc, nr, cost, maximize, a, b):
    """Solve one assignment problem on the device, one row at a time.

    Parameters
    ----------
    nc, nr : number of columns and rows (note the order: columns first).
    cost : flat cost buffer read as ``cost[row * nc + col]``.
    maximize : accepted for interface compatibility but never read; the
        cost matrix is always used as given.
    a, b : output arrays; on success ``a[i] = i`` and ``b[i]`` is the column
        assigned to row ``i`` for ``i`` in ``range(nr)``.

    Returns
    -------
    0 on success. ``RECT_LSAP_INFEASIBLE`` when ``augmenting_path`` finds no
    sink for some row, or when the problem does not fit the N_MEM-sized
    scratch arrays. Outputs are not written in that case.

    Notes
    -----
    No transpose is performed here, so callers should pass ``nr <= nc``.
    """
    # Every scratch array below has exactly N_MEM slots and is indexed by a
    # row or column number; a larger problem would write out of bounds.
    if nr > N_MEM or nc > N_MEM:
        return RECT_LSAP_INFEASIBLE

    # Dual variables for rows (u) and columns (v).
    u = cuda.local.array((N_MEM,), float32)
    v = cuda.local.array((N_MEM,), float32)
    u[:] = 0
    v[:] = 0
    shortestPathCosts = cuda.local.array((N_MEM,), float32)
    path = cuda.local.array((N_MEM,), int32)
    shortestPathCosts[:] = np.inf
    path[:] = -1
    # Current matching in both directions; -1 means unmatched.
    col4row = cuda.local.array((N_MEM,), int32)
    row4col = cuda.local.array((N_MEM,), int32)
    col4row[:] = -1
    row4col[:] = -1
    # Scratch buffers that augmenting_path resets on every call.
    SR = cuda.local.array((N_MEM,), boolean)
    SC = cuda.local.array((N_MEM,), boolean)
    remaining = cuda.local.array((N_MEM,), int32)
    # One-element array used as an out-parameter for the path cost.
    minVal = cuda.local.array((1,), float32)

    for curRow in range(nr):
        sink = augmenting_path(nc, cost, u, v, path, row4col,
                               shortestPathCosts, curRow, SR, SC,
                               remaining, minVal)
        if sink < 0:
            return RECT_LSAP_INFEASIBLE

        # Update the dual variables of every row / column the search visited.
        u[curRow] += minVal[0]
        for i in range(nr):
            if SR[i] and i != curRow:
                u[i] += minVal[0] - shortestPathCosts[col4row[i]]

        for j in range(nc):
            if SC[j]:
                v[j] -= minVal[0] - shortestPathCosts[j]

        # Walk the path back from the sink, flipping matches until the
        # starting row is reached.
        j = sink
        while True:
            i = path[j]
            row4col[j] = i
            # Swap col4row[i] and j (std::swap in the C++ original).
            tmp = col4row[i]
            col4row[i] = j
            j = tmp
            if i == curRow:
                break

    for i in range(nr):
        a[i] = i
        b[i] = col4row[i]
    return 0

# Fixed length of every per-thread scratch array allocated in solve().
N_MEM = 20
# Status code solve() returns when augmenting_path() reports no sink.
RECT_LSAP_INFEASIBLE = -100

@cuda.jit#(debug=True, opt=False)
def rlsap_kernel(cost_batch, meta_batch, outp_batch, n_problems):
    """Solve a batch of assignment problems, one problem per thread.

    Parameters
    ----------
    cost_batch : ``cost_batch[p]`` is the flat cost buffer of problem ``p``.
    meta_batch : ``meta_batch[p, 0]`` is the column count and
        ``meta_batch[p, 1]`` the row count of problem ``p``.
    outp_batch : ``outp_batch[p, 0]`` receives the row indices and
        ``outp_batch[p, 1]`` the assigned columns. If the solver reports
        ``RECT_LSAP_INFEASIBLE`` both rows are filled with -1 instead.
    n_problems : number of valid problems; threads beyond it do nothing.
    """
    tid = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    if tid < n_problems:
        # solve() never reads this flag and always uses the costs as given,
        # so pass False to state what actually happens.
        maximize = False
        nc = meta_batch[tid, 0]
        nr = meta_batch[tid, 1]
        cost = cost_batch[tid] # nr x nc
        signal = solve(nc, nr, cost, maximize, outp_batch[tid, 0], outp_batch[tid, 1])
        if signal == RECT_LSAP_INFEASIBLE:
            # Device asserts only fire in debug builds, so also flag the
            # failure in the output where the host can see it.
            limit = min(nr, outp_batch.shape[2])
            for k in range(limit):
                outp_batch[tid, 0, k] = -1
                outp_batch[tid, 1, k] = -1
        assert signal != RECT_LSAP_INFEASIBLE

# 3x3 demo problem, stored as int32 to match the device buffers.
cost = np.array([
    [4, 1, 3],
    [2, 0, 5],
    [3, 2, 2]
], dtype=np.int32)

# The kernel's scratch arrays hold N_MEM entries per dimension.
assert max(cost.shape) <= N_MEM, "problem exceeds the kernel's N_MEM capacity"

n_problems = 8
# A flat cost row must be able to hold the largest supported problem.
cost_batch = cp.zeros((n_problems, N_MEM * N_MEM), dtype=np.int32)
meta_batch = cp.zeros((n_problems, 2), dtype=np.int32)
meta_batch[0, 0] = cost.shape[1]  # nc
meta_batch[0, 1] = cost.shape[0]  # nr
cost_batch[0, :cost.size] = cp.asarray(cost.ravel())
outp_batch = cp.zeros((n_problems, 2, N_MEM), dtype=np.int32)

# Launch enough threads to cover every problem; the kernel bounds-checks
# the thread index against n_problems, and unused slots have nr == 0.
threads_per_block = 32
blocks_per_grid = (n_problems + threads_per_block - 1) // threads_per_block

rlsap_kernel[blocks_per_grid, threads_per_block](cost_batch, meta_batch, outp_batch,
                                          np.uint32(n_problems))
print(outp_batch[0])



[[0 1 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]


In [None]:
# For each atom type:
# get the nodes with that atom type [n1, n2, ...]
# build the graph
# add nodes like (1, j) and (2, j) for those
# add all edges (complete bipartite graph) and assign a cost to each edge
# if nodes are missing, add the missing ones with weight...
# solve the full bipartite matching

## costs
# in i neigh, count atom types [c,n]
# in j neigh same [c]
# get all 