# Efficient co-occurrence scores

In [10]:
from campa.tl import Experiment, FeatureExtractor
import os
from campa.pl import plot_mean_intensity, get_intensity_change, plot_intensity_change, plot_mean_size
import anndata as ad
from campa.pl import plot_co_occurrence, plot_co_occurrence_grid
import numpy as np
import squidpy as sq
from campa.tl._cluster import annotate_clustering
from campa.utils import init_logging
init_logging()

In [4]:
# Load the existing experiment from its directory.
exp = Experiment.from_dir('VAE_all/CondVAE_pert-CC')
# Only the first data dir is used in this notebook; to extract features
# for all data dirs, run the script instead.
data_dir = exp.data_params['data_dirs'][0]
# Alternative ways to set up the extractor (kept for reference):
#extr = FeatureExtractor(exp, data_dir=data_dir, cluster_name='clustering_res0.9_sub-0.33_seed1', 
#                        cluster_dir='aggregated/sub-0.005_sub-0.33')
#extr = FeatureExtractor(exp, data_dir=data_dir, cluster_name='clustering_res0.5', cluster_col='annotation')
# Restore the extractor from the features file stored for this data dir.
features_path = os.path.join(exp.full_path, 'aggregated/full_data', data_dir, 'features_annotation.h5ad')
extr = FeatureExtractor.from_adata(features_path)

# 10 distance thresholds, spaced evenly in log2 between 2 and 80, stored as float32.
interval = np.logspace(np.log2(2), np.log2(80), num=10, base=2).astype(np.float32)

INFO:Experiment:Setting up experiment VAE_all/CondVAE_pert-CC
INFO:Experiment:Initialised from existing experiment in VAE_all/CondVAE_pert-CC
INFO:Experiment:Setting up experiment VAE_all/CondVAE_pert-CC
INFO:Experiment:Initialised from existing experiment in VAE_all/CondVAE_pert-CC
INFO:Experiment:Cluster annotation: using cluster data in aggregated/sub-0.001


In [5]:
extr.params

{'data_dir': '184A1_unperturbed/I09',
 'cluster_name': 'clustering_res0.5',
 'cluster_dir': None,
 'cluster_col': 'annotation',
 'exp_name': 'VAE_all/CondVAE_pert-CC'}

## Manually test implementations

In [7]:
def _prepare_co_occ(interval):
    """
    return lists of coordinates to consider for each interval. Coordinates are relative to [0,0].
    """
    arr = np.zeros((int(interval[-1])*2+1, int(interval[-1])*2+1))
    # calc distances for interval range (assuming c as center px)
    c = int(interval[-1])+1
    c = np.array([c,c]).T
    xx, yy = np.meshgrid(np.arange(len(arr)), np.arange(len(arr)))
    coords = np.array([xx.flatten(), yy.flatten()]).T
    dists = np.sqrt(np.sum((coords - c)**2, axis=-1))
    dists = dists.reshape(int(interval[-1])*2+1, -1)
    # calc coords for each thresh interval
    coord_lists = []
    for thres_min, thres_max in zip(interval[:-1], interval[1:]):
        xy = np.where((dists <= thres_max) & (dists > thres_min))
        coord_lists.append(xy - c[:,np.newaxis])

    return coord_lists

# Offsets per distance interval; computed once here and reused by the numba loop cell below.
coord_lists = _prepare_co_occ(interval)


In [11]:
def squidpy_co_occ(extr, interval):
    """
    Reference co-occurrence scores for the first object of ``extr``, computed with squidpy.

    Parameters
    ----------
    extr
        Feature extractor providing ``mpp_data``, ``clusters``, ``params``,
        ``annotation`` and ``log``.
    interval
        Distance thresholds passed on to ``sq.gr.co_occurrence``.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(extr.clusters), len(extr.clusters), len(interval) - 1)``.
        Entry ``[i, j, k]`` holds the squidpy score of clusters ``i`` and ``j``
        (indexed as in ``extr.clusters``) for interval ``k``; clusters that do
        not occur in the object keep the value 0.
    """
    obj_id = extr.mpp_data.unique_obj_ids[0]
    cluster_name = extr.params['cluster_name']
    cluster_col = extr.params['cluster_col']

    # position of every cluster label in the output array
    cluster_names = {n: i for i, n in enumerate(extr.clusters)}

    adata = extr.mpp_data.subset(obj_ids=[obj_id], copy=True).get_adata(obs=[cluster_name])
    # ensure that cluster annotation is present in adata
    if cluster_name != cluster_col:
        adata.obs[cluster_col] = annotate_clustering(adata.obs[cluster_name], extr.annotation,
            cluster_name, cluster_col)
    adata.obs[cluster_col] = adata.obs[cluster_col].astype('category')
    extr.log.info(f'co-occurrence for {obj_id}, with shape {adata.shape}')
    cur_co_occ, _ = sq.gr.co_occurrence(
        adata,
        cluster_key=cluster_col,
        spatial_key='spatial',
        interval=interval,
        copy=True, show_progress_bar=False,
        n_splits=1
    )
    # ensure that co_occ has the correct format in case of missing clusters
    co_occ = np.zeros((len(extr.clusters), len(extr.clusters), len(interval)-1))
    cur_clusters = np.vectorize(cluster_names.__getitem__)(np.array(adata.obs[cluster_col].cat.categories))
    # np.ix_ keeps the (row, column) orientation: cur_co_occ[i, j] is written to
    # co_occ[cur_clusters[i], cur_clusters[j]]. A default ('xy') meshgrid would
    # swap the two axes and store the transposed matrix.
    co_occ[np.ix_(cur_clusters, cur_clusters)] = cur_co_occ

    return co_occ


# numba compiles the pair-counting loops defined below
from numba import njit, jit
import numba.types as nt
# short aliases for the numba dtypes used in the explicit jit signatures
ft = nt.float32
it = nt.int64
# only referenced by the commented-out timing lines in _co_occ_opt
import time

@jit(ft[:,:](it[:],it[:],it), fastmath=True)
def _count_co_occ(clus1:np.ndarray, clus2:np.ndarray, num_clusters:int) -> np.ndarray:
    """
    Count how often each (clus1[k], clus2[k]) pair of cluster ids occurs.

    Returns a float32 array of shape (num_clusters, num_clusters) whose entry
    [i, j] is the number of positions k with clus1[k] == i and clus2[k] == j.
    Positions beyond the shorter of the two inputs are ignored.
    """
    # NOTE(review): counts are accumulated in float32, which represents
    # integers exactly only up to 2**24 - confirm this is enough for the
    # largest intervals.
    counts = np.zeros((num_clusters, num_clusters), dtype=np.float32)
    n_pairs = min(len(clus1), len(clus2))
    for k in range(n_pairs):
        counts[clus1[k], clus2[k]] += 1
    return counts

def _co_occ_opt(coords1: np.ndarray, # int64
               coords2_list: np.ndarray, # int64
               clusters1: np.ndarray, # int64
               img: np.ndarray, # int64
               num_clusters: int
              ) -> np.ndarray:

    out = np.zeros((num_clusters, num_clusters, len(coords2_list)), dtype=np.float32)
    for idx, coords2 in enumerate(coords2_list):
        #t1 = time.time()
        
        #co_occur = np.zeros((num_clusters, num_clusters), dtype=np.float32)
        probs_con = np.zeros((num_clusters, num_clusters), dtype=np.float32)

        # get img coords to consider for this interval (len(interval_coords), num obs)
        cur_coords = np.expand_dims(coords2, 2) + np.expand_dims(coords1, 1)

        # get cluster of center pixel + repeat for len(interval_coords)
        clus1 = np.tile(clusters1, [cur_coords.shape[1], 1])
        # reshape to (2, xxx)
        cur_coords = cur_coords.reshape((2,-1))
        clus1 = clus1.reshape([-1])

        # filter cur_coords that are outside image
        shape = np.expand_dims(np.array([img.shape[1],img.shape[0]]), 1)
        mask = np.all((cur_coords >= 0) & (cur_coords < shape), axis=0)
        cur_coords = cur_coords[:,mask]
        clus1 = clus1[mask]

        # get cluster of cur_coords
        clus2 = img[cur_coords[1], cur_coords[0]].flatten()

        # remove those pairs where clus2 is outside of this image (cluster id is not a valid id)
        mask = clus2 < num_clusters
        #assert (clus1 < num_clusters).all()
        clus1 = clus1[mask]
        clus2 = clus2[mask]

        co_occur = _count_co_occ(clus1, clus2, num_clusters)

        probs_matrix = co_occur / np.sum(co_occur)
        probs = np.sum(probs_matrix, axis=1)

        for c in np.unique(img):
            # do not consider background value in img
            if c >= num_clusters:
                continue
            probs_conditional = co_occur[c] / np.sum(co_occur[c])
            probs_con[c, :] = probs_conditional / probs

        out[:, :, idx] = probs_con
        #t2 = time.time()
        #print(idx, coords2.shape, t1-t2, (t1-t2)/coords2.shape[1])
    return out

def opt_co_occ(extr, interval):
    """
    Co-occurrence scores of the first object in ``extr``, computed with ``_co_occ_opt``.

    The cluster image and the per-pixel cluster labels are annotated to
    ``cluster_col``, converted to integer ids and handed to ``_co_occ_opt``
    together with the offsets of every distance interval.
    Returns the array produced by ``_co_occ_opt``.
    """
    params = extr.params
    obj_id = extr.mpp_data.unique_obj_ids[0]

    # label -> integer id; the extra '' entry receives the id len(extr.clusters)
    # NOTE(review): `+ ['']` assumes extr.clusters is a list - TODO confirm
    label_ids = {label: pos for pos, label in enumerate(extr.clusters + [''])}
    to_id = np.vectorize(label_ids.__getitem__)
    coords2_list = _prepare_co_occ(interval)

    mpp_data = extr.mpp_data.subset(obj_ids=[obj_id], copy=True)
    annotation_kwargs = {'annotation': extr.annotation, 'to_col': params['cluster_col']}
    img, (pad_x, pad_y) = mpp_data.get_object_img(
        obj_id, data=params['cluster_name'], annotation_kwargs=annotation_kwargs
    )
    # convert labels to numbers
    img = to_id(img)
    annotated = annotate_clustering(
        mpp_data.data(params['cluster_name']), extr.annotation,
        params['cluster_name'], params['cluster_col']
    )
    clusters1 = to_id(annotated)
    # subtract the image padding so that the coords index into img
    padding = np.array([pad_x, pad_y])[:, np.newaxis]
    coords1 = (np.array([mpp_data.x, mpp_data.y]) - padding).astype(np.int64)
    extr.log.info(f'co-occurrence for {obj_id}, with {len(mpp_data.x)} elements')

    return _co_occ_opt(coords1, coords2_list, clusters1, img, num_clusters=len(extr.clusters))
    
    

In [12]:
%%time
# squidpy co occ: reference result, compared against op_co_occ further down
sq_co_occ = squidpy_co_occ(extr, interval)

INFO:MPPData:Before subsetting: 557 objects
INFO:MPPData:Subsetting to 1 objects
INFO:MPPData:Created new: MPPData for NascentRNA (17008 mpps with shape (1, 1, 34) from 1 objects). Data keys: ['x', 'y', 'obj_ids', 'mpp', 'labels', 'clustering_res0.5', 'latent', 'conditions'].
INFO:FeatureExtractor:co-occurrence for 199632, with shape (17008, 34)


CPU times: user 29.8 s, sys: 3.07 s, total: 32.9 s
Wall time: 33 s


In [8]:
@jit(
    ft[:,:](it[:,:], it[:,:], it[:], it[:,:], it),
    fastmath=True
)
def co_occ_numba_loop(coords1: np.ndarray, # int64
               coords2: np.ndarray, # int64
               clusters1: np.ndarray, # int64
               img: np.ndarray, # int64
               num_clusters: int
              ) -> np.ndarray:
               
    """
    Co-occurrence scores for one set of offsets, computed with explicit loops.

    coords1 holds one center pixel per column and coords2 one offset per
    column; in both, row 0 indexes the second axis of img and row 1 the first.
    clusters1 gives the cluster id of each center pixel. Ids >= num_clusters
    in img are skipped. Returns a float32 array of shape
    (num_clusters, num_clusters).
    """
    co_occur = np.zeros((num_clusters, num_clusters), np.float32)
    probs_con = np.zeros((num_clusters, num_clusters), np.float32)

    # count every (center cluster, neighbour cluster) pair
    for c1,clus1 in zip(coords1.T,clusters1):
        for c2_ in coords2.T:
            # absolute image position of the neighbour pixel
            c2 = c2_ + c1
            # is invalid?
            if c2[0] < 0 or c2[1] < 0 or c2[0] >= img.shape[1] or c2[1] >= img.shape[0]:
                continue
            clus2 = img[c2[1], c2[0]]
            # is outside of cell?
            if clus2 >= num_clusters:
                continue
            co_occur[clus1, clus2] += 1

    # fraction of all counted pairs per center cluster
    probs_matrix = co_occur / np.sum(co_occur)
    probs = np.sum(probs_matrix, axis=1)

    # NOTE(review): probs is 0 for cluster ids without any counted pair, so the
    # division below looks like it yields 0/0 for those columns - confirm how
    # missing clusters should be handled here.
    for c in np.unique(img):
        # do not consider background value in img
        if c >= num_clusters:
            continue
        probs_conditional = co_occur[c] / np.sum(co_occur[c])
        probs_con[c, :] = probs_conditional / probs
    return probs_con



In [13]:
%%time
# OPT with numpy (only the pair counting in _count_co_occ is jitted)
op_co_occ = opt_co_occ(extr, interval)

INFO:MPPData:Before subsetting: 557 objects
INFO:MPPData:Subsetting to 1 objects
INFO:MPPData:Created new: MPPData for NascentRNA (17008 mpps with shape (1, 1, 34) from 1 objects). Data keys: ['x', 'y', 'obj_ids', 'mpp', 'labels', 'clustering_res0.5', 'latent', 'conditions'].
INFO:FeatureExtractor:co-occurrence for 199632, with 17008 elements


CPU times: user 30.5 s, sys: 6.68 s, total: 37.2 s
Wall time: 37.3 s


In [None]:
%%time
# OPT with numba

# TODO might do some cluster annotation here
#df = pd.DataFrame(mpp_data.data(extr.params['cluster_name']), index=zip(mpp_data.x, mpp_data.y))

obj_id = extr.mpp_data.unique_obj_ids[0]

num_clusters = len(extr.clusters)
cluster_names = {n: i for i,n in enumerate(list(extr.clusters)+[''])}

mpp_data = extr.mpp_data.subset(obj_ids=[obj_id], copy=True)
img, (pad_x, pad_y) = mpp_data.get_object_img(obj_id, data=extr.params['cluster_name'])
# convert labels to numbers
img = np.vectorize(cluster_names.__getitem__)(img)
cluster = np.vectorize(cluster_names.__getitem__)(mpp_data.data(extr.params['cluster_name']))

# shift coords according to image padding, st coords correspond to img coords
coords1 = (np.array([mpp_data.x, mpp_data.y]) - np.array([pad_x, pad_y])[:,np.newaxis]).astype(np.int64)


out = np.zeros((num_clusters, num_clusters, len(coord_lists)), dtype=np.float32)
for idx, coords2 in enumerate(coord_lists):
    out[:,:,idx] = co_occ_numba_loop(coords1, coords2, cluster, img[:,:,0], num_clusters)

In [21]:
# squidpy scores for the first distance interval
sq_co_occ[:,:,0]

array([[6.31509495e+00, 0.00000000e+00, 6.78164721e-01, 9.90489542e-01,
        3.87033410e-02, 1.36381984e+00, 1.24141657e+00],
       [0.00000000e+00, 1.24514322e+01, 1.10505450e+00, 2.39536981e-03,
        2.51594037e-02, 5.21539032e-01, 1.70420744e-02],
       [6.78164780e-01, 1.10505450e+00, 6.64635038e+00, 1.32795736e-01,
        3.26998711e-01, 1.15143716e+00, 1.76081493e-01],
       [9.90489662e-01, 2.39536981e-03, 1.32795736e-01, 4.88185644e+00,
        7.99145028e-02, 4.83710438e-01, 5.57625473e-01],
       [3.87033410e-02, 2.51594037e-02, 3.26998681e-01, 7.99145028e-02,
        4.71955538e+00, 1.56706184e-01, 6.83024451e-02],
       [1.36381984e+00, 5.21539092e-01, 1.15143716e+00, 4.83710408e-01,
        1.56706184e-01, 1.57430923e+00, 4.76014942e-01],
       [1.24141645e+00, 1.70420744e-02, 1.76081479e-01, 5.57625353e-01,
        6.83024451e-02, 4.76014912e-01, 9.61888409e+00]])

In [22]:
# optimised scores for the first distance interval (compare with the squidpy output above)
op_co_occ[:,:,0]

array([[6.3150949e+00, 0.0000000e+00, 6.7816484e-01, 9.9048966e-01,
        3.8703345e-02, 1.3638198e+00, 1.2414167e+00],
       [0.0000000e+00, 1.2451433e+01, 1.1050546e+00, 2.3953700e-03,
        2.5159407e-02, 5.2153909e-01, 1.7042076e-02],
       [6.7816472e-01, 1.1050545e+00, 6.6463513e+00, 1.3279574e-01,
        3.2699874e-01, 1.1514372e+00, 1.7608149e-01],
       [9.9048954e-01, 2.3953698e-03, 1.3279575e-01, 4.8818569e+00,
        7.9914503e-02, 4.8371044e-01, 5.5762547e-01],
       [3.8703341e-02, 2.5159404e-02, 3.2699874e-01, 7.9914503e-02,
        4.7195563e+00, 1.5670618e-01, 6.8302453e-02],
       [1.3638198e+00, 5.2153909e-01, 1.1514373e+00, 4.8371044e-01,
        1.5670618e-01, 1.5743092e+00, 4.7601497e-01],
       [1.2414166e+00, 1.7042074e-02, 1.7608151e-01, 5.5762547e-01,
        6.8302453e-02, 4.7601494e-01, 9.6188850e+00]], dtype=float32)

In [23]:
# Both implementations must agree for every cluster pair and interval.
# np.isclose treats NaN as unequal by default, so a NaN in either result makes this False.
np.isclose(sq_co_occ,op_co_occ, rtol=1e-03).all()

True

## Test implementation in code