In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Test: swap labels in a large mask matrix

In [None]:
# Load the stitched label mask for Section 2 (presumably a cell-segmentation mask -- confirm).
f = '/greendata/Images2022/Gaby/dredFISH/DPNMF-FR_R1_4A_UC_R2_5C_2022Nov27/fishdata_2022Dec09/DPNMF-FR_R1_4A_UC_R2_5C_2022Nov27_Section2_total_mask_stitched.pt'
# map_location="cpu": a tensor saved from a GPU would otherwise fail to load on a
# CPU-only machine, and the `.numpy()` conversion below requires a CPU tensor anyway.
# NOTE(review): torch.load unpickles the file -- only load files from trusted sources.
mat = torch.load(f, map_location="cpu")
mat, mat.shape

In [None]:
# Convert the loaded mask to a NumPy array. The isinstance guard makes this cell
# idempotent: re-running it would otherwise raise AttributeError, since an
# ndarray has no `.numpy()`. detach()/cpu() also cover GPU or grad-tracking tensors.
if isinstance(mat, torch.Tensor):
    mat = mat.detach().cpu().numpy()
mat, mat.shape

In [None]:
# Set up a random relabeling of the mask and report some basic stats.
SEED = 0  # fixed seed so the relabeling (and every downstream result) is reproducible
rng = np.random.default_rng(SEED)

i, j = np.nonzero(mat)
unq = np.unique(mat[i,j])   # distinct nonzero (foreground) labels, sorted
lbl = rng.permutation(unq)  # new labels: a random permutation of the old ones

m, n = mat.shape
nnz = len(i)
ncl = len(unq)

# nnz/(m*n) is the fraction of NONZERO pixels (density), not sparsity.
print(f"""size: ({m}, {n})
nonzero fraction: {nnz/(m*n)}
n cells: {ncl}
old labels: {unq}
new labels: {lbl}
""")

In [None]:
%%time
# the problem:
# create, from the old mask matrix, a new matrix with the labels swapped (oldlbl->newlbl)

def swap_mask(mat, lookup_o2n):
    """Return a copy of a label mask with every nonzero label relabeled via a lookup table.

    Parameters
    ----------
    mat : np.ndarray
        Label mask of any dimensionality. 0 is treated as background and is
        left unchanged.
    lookup_o2n : pd.Series or mapping
        Old label -> new label. Its index (keys) must contain every nonzero
        label present in ``mat``. A plain dict (or other mapping) is converted
        to a ``pd.Series``.

    Returns
    -------
    np.ndarray
        New array with the same shape and dtype as ``mat``. New labels are cast
        to ``mat``'s dtype on assignment. ``mat`` itself is not modified.

    Raises
    ------
    KeyError
        If a nonzero label in ``mat`` is missing from ``lookup_o2n``.
    """
    if not isinstance(lookup_o2n, pd.Series):
        lookup_o2n = pd.Series(lookup_o2n)

    newmat = mat.copy()
    # Boolean foreground mask: works for N-D arrays, unlike the 2-D-only
    # `i, j = np.nonzero(mat)` unpacking, and selects pixels in the same C order.
    fg = mat != 0
    if not fg.any():
        return newmat  # all background: nothing to relabel

    # Look up each distinct label once, then scatter back to pixels through the
    # inverse indices, so the pandas lookup costs O(#labels) rather than O(#pixels).
    unq, inv = np.unique(mat[fg], return_inverse=True)
    newmat[fg] = lookup_o2n.loc[unq].to_numpy()[inv]
    return newmat

# Lookup table: index holds the old labels, values hold their new labels.
lookup_o2n = pd.Series(index=unq, data=lbl)
newmat = swap_mask(mat, lookup_o2n)

In [None]:
%%time
# Visual sanity check: the same crop before and after relabeling should show
# identical cell shapes, only with different label values (colors).
CROP_SIZE = 1000
CROP_START = 5000  # preferred top-left corner (row and column) of the crop

# Clamp the crop so it stays inside the matrix. For a mask smaller than
# CROP_START + CROP_SIZE the original fixed slice would be empty and imshow would fail.
r0 = min(CROP_START, max(mat.shape[0] - CROP_SIZE, 0))
c0 = min(CROP_START, max(mat.shape[1] - CROP_SIZE, 0))
window = (slice(r0, r0 + CROP_SIZE), slice(c0, c0 + CROP_SIZE))

fig, axs = plt.subplots(1, 2, figsize=(2*5, 1*5))
for ax, title, arr in zip(axs, ['Original', 'Swapped'], [mat, newmat]):
    # Nearest-neighbour interpolation keeps displayed colors equal to real label
    # values instead of filtered mixtures when the crop is downsampled for display.
    g = ax.imshow(arr[window], interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel(f'column (offset {c0})')
    ax.set_ylabel(f'row (offset {r0})')
    fig.colorbar(g, ax=ax)
plt.show()