In [65]:
import torch
from matplotlib import pyplot as plt
@torch.jit.script_if_tracing
def MyLinearSumAssignment(TruthTensor, maximize=True,lookahead=2):
    '''
    If Maximize is False, I'm trying to minimize the costs. 
    This means that the mask must instead make all the weights far above all the others - 'inf' kind of thing. 
    '''
    #assert truthtensor is 2d and nonzero
    mask=torch.zeros_like(TruthTensor)
    results=torch.zeros_like(TruthTensor,dtype=torch.bool)

    finder=torch.argmax if maximize else torch.argmin
    TruthTensor=TruthTensor-(torch.min(torch.min(TruthTensor)))
    replaceval=-1 if maximize else (torch.max(torch.max(TruthTensor)))
    #add a small amount of noise to the tensor to break ties
    TruthTensor=TruthTensor+torch.randn_like(TruthTensor)*1e-6
    dimsizes=torch.tensor(TruthTensor.shape)
    #select index of the smallest value
    bigdim=torch.argmax(dimsizes).item()   # 0 
    small_dim=1-bigdim          # 1
    
    for i in range(TruthTensor.shape[small_dim]-1): # number of rows
        
        arr=torch.where(mask==1,replaceval,TruthTensor)
        deltas=torch.diff(torch.topk(arr,lookahead,dim=bigdim,largest=maximize).values,n=lookahead-1,dim=bigdim)
        col_index=torch.argmax(torch.abs(deltas),dim=small_dim) #this is the column to grab,  Note this measures step so its not important to do argmin...
        row_index=finder(torch.select(arr,small_dim,col_index))
        torch.select(mask,small_dim,col_index).fill_(1)
        torch.select(mask,bigdim,row_index).fill_(1)

        torch.select(torch.select(results,small_dim,col_index),0,row_index).fill_(True)
        # plt.subplot(1,3,1)
        # plt.imshow(arr.detach().cpu().numpy())
        # plt.subplot(1,3,2)
        # plt.imshow(mask.detach().cpu().numpy())
        # plt.subplot(1,3,3)
        # plt.imshow(results.detach().cpu().numpy())
        # plt.show()


    return torch.logical_or(results,torch.logical_not(mask))

In [67]:
#do trials 
from tqdm import tqdm
import torch
for i in tqdm(range(10000)):

    rand_input=torch.rand(10,10,requires_grad=True,device='cuda' if torch.cuda.is_available() else 'cpu')
    mask=torch.rand(10,10,requires_grad=True,device='cuda' if torch.cuda.is_available() else 'cpu')>0.5

    rand_input=torch.where(mask,rand_input,torch.zeros_like(rand_input))
    torch.select(rand_input,0,0).fill_(0)
    torch.select(rand_input,1,1).fill_(0)
    results=MyLinearSumAssignment(rand_input, maximize=True,lookahead=2)
    assert torch.all(torch.sum(results,dim=0)==1)
    assert torch.all(torch.sum(results,dim=1)==1)


 56%|█████▌    | 5572/10000 [00:15<00:12, 357.76it/s]