# Creating a new kernel

This tutorial shows you how to create your own kernel class which computes a cell-cell transition matrix.

## Import packages & data

In [2]:
from typing import Any
from copy import copy
from anndata import AnnData

import cellrank as cr
import numpy as np
import scipy.sparse as sp

In [3]:
adata = cr.datasets.pancreas()
adata

AnnData object with n_obs × n_vars = 2531 × 27998
    obs: 'day', 'proliferation', 'G2M_score', 'S_score', 'phase', 'clusters_coarse', 'clusters', 'clusters_fine', 'louvain_Alpha', 'louvain_Beta', 'palantir_pseudotime'
    var: 'highly_variable_genes'
    uns: 'clusters_colors', 'clusters_fine_colors', 'day_colors', 'louvain_Alpha_colors', 'louvain_Beta_colors', 'neighbors', 'pca'
    obsm: 'X_pca', 'X_umap'
    layers: 'spliced', 'unspliced'
    obsp: 'connectivities', 'distances'

## Minimal kernel

To create your own kernel class, you need to do three things:

- subclass `cellrank.tl.kernels.Kernel`;
- implement `.compute_transition_matrix`, which saves the row-stochastic transition matrix in the `._transition_matrix` attribute and returns `self`; the `._compute_transition_matrix` helper method takes care of the saving;
- implement `.copy`, which returns a copy of the kernel.

The `._compute_transition_matrix` helper row-normalizes the matrix passed to it, if needed (all elements must be non-negative). It can optionally compute the condition number, which can be costly and only works for dense matrices.
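To make the row normalization and the condition number concrete, here is a minimal sketch assuming a SciPy sparse input. `row_normalize` and `condition_number` are hypothetical helpers written for illustration; they are not part of CellRank's API and not its implementation.

```python
import numpy as np
import scipy.sparse as sp


def row_normalize(matrix: sp.spmatrix) -> sp.csr_matrix:
    """Scale each row of a non-negative sparse matrix so that it sums to 1.

    Hypothetical helper illustrating the idea; not CellRank's code.
    """
    matrix = sp.csr_matrix(matrix, dtype=np.float64)
    if matrix.nnz and matrix.data.min() < 0:
        raise ValueError("All elements must be non-negative.")
    row_sums = np.asarray(matrix.sum(axis=1)).ravel()
    if np.any(row_sums == 0):
        raise ValueError("Every row needs at least one positive entry.")
    # Left-multiplying by diag(1 / row_sums) divides row i by its sum.
    return sp.csr_matrix(sp.diags(1.0 / row_sums) @ matrix)


def condition_number(matrix: sp.spmatrix) -> float:
    # np.linalg.cond needs a dense array: densifying an n x n matrix costs
    # O(n^2) memory and the underlying SVD O(n^3) time, hence "optional".
    return float(np.linalg.cond(matrix.toarray()))


# A diagonal of 0.5s, as in the minimal kernel of this tutorial,
# becomes the identity after row normalization.
t = row_normalize(sp.diags((0.5,) * 4))
print(t.toarray())
print(condition_number(t))
```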

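The `MyKernel` class below is a minimal implementation of a kernel whose transition matrix is diagonal. Because `._compute_transition_matrix` row-normalizes it, the result is the identity matrix for any positive `some_parameter`, i.e. every cell transitions only to itself.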

In [30]:
class MyKernel(cr.tl.kernels.Kernel):
    
    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        transition_matrix = sp.diags((some_parameter,) * len(self.adata), dtype=np.float64)
        self._compute_transition_matrix(transition_matrix, density_normalize=True)
        return self
    
    def copy(self) -> "MyKernel":
        return copy(self)

In [32]:
k = MyKernel(adata).compute_transition_matrix()
k.transition_matrix.toarray()

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

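## Reading from AnnData

To load data that the kernel needs, override `._read_from_adata`. It runs when the kernel is created, and keyword arguments passed to the constructor, such as `obs_key` below, are forwarded to it. Call the parent implementation via `super()` so that the base class can still do its own reading.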

In [41]:
class MyKernel(cr.tl.kernels.Kernel):
    def __init__(self, adata: AnnData, obs_key: str = "palantir_pseudotime", **kwargs: Any):
        super().__init__(adata=adata, obs_key=obs_key, **kwargs)
    
    def _read_from_adata(self, obs_key: str, **kwargs: Any):
        super()._read_from_adata(**kwargs)
        
        print(f"Reading `adata.obs[{obs_key!r}]`")
        self.pseudotime = self.adata.obs[obs_key].values
    
    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        print("Accessing `.pseudotime`: ", self.pseudotime)
        transition_matrix = sp.diags((some_parameter,) * len(self.adata), dtype=np.float64)
        
        self._compute_transition_matrix(transition_matrix)
        
        return self
    
    def copy(self) -> "MyKernel":
        return copy(self)

In [34]:
k = MyKernel(adata).compute_transition_matrix()
k

Reading `adata.obs['palantir_pseudotime']`
Accessing `.pseudotime`:  [0.81281052 0.81832897 0.48974318 ... 0.73317134 0.92208156 0.8219729 ]


<MyKe>
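The hand-off of `obs_key` from the constructor to `._read_from_adata` is an instance of the template-method pattern. The plain-Python sketch below models it, assuming (consistent with the output above, where reading happens at construction) that the base class calls the hook from `__init__`. `BaseKernelSketch` and `PseudotimeSketch` are hypothetical names, the real `cellrank.tl.kernels.Kernel` takes more parameters, and a dict stands in for AnnData so the example runs on its own.

```python
from typing import Any


class BaseKernelSketch:
    """Hypothetical, simplified stand-in for a kernel base class.

    Only models how constructor keyword arguments reach `_read_from_adata`.
    """

    def __init__(self, adata: dict, **kwargs: Any):
        self.adata = adata
        # Template-method hook: subclasses load whatever data they need.
        self._read_from_adata(**kwargs)

    def _read_from_adata(self, **kwargs: Any) -> None:
        # In this sketch the base class has nothing to read.
        pass


class PseudotimeSketch(BaseKernelSketch):
    def __init__(self, adata: dict, obs_key: str = "palantir_pseudotime", **kwargs: Any):
        super().__init__(adata, obs_key=obs_key, **kwargs)

    def _read_from_adata(self, obs_key: str, **kwargs: Any) -> None:
        super()._read_from_adata(**kwargs)  # let the parent do its reading first
        self.pseudotime = self.adata["obs"][obs_key]


# A nested dict mimics `adata.obs` (hypothetical toy values).
toy = {"obs": {"palantir_pseudotime": [0.8, 0.2, 0.5], "dpt": [0.1, 0.9, 0.4]}}
print(PseudotimeSketch(toy).pseudotime)
print(PseudotimeSketch(toy, obs_key="dpt").pseudotime)
```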

## Caching values

Kernels can be combined using the element-wise operators `+` and `*`. Without caching, however, evaluating such an expression could compute the same transition matrix several times.

To avoid evaluating the same kernel more than once, we provide a method `._reuse_cache(parameters: Dict[str, Any]) -> bool` that returns `True` if a transition matrix computed with these parameters is already cached, and `False` otherwise. It also updates the parameters, which are accessible through the `.params` attribute.
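The behaviour of `._reuse_cache` described above can be modeled in plain Python. `CachedKernelSketch` is hypothetical, and our assumption that a cache hit means "a matrix exists and the parameters compare equal" is not a statement about CellRank's internals. The counter makes it visible that a repeated call with the same parameters does no work.

```python
from typing import Any, Dict, Optional

import numpy as np


class CachedKernelSketch:
    """Hypothetical model of parameter-based caching; not CellRank's code."""

    def __init__(self, n_cells: int):
        self.n_cells = n_cells
        self.params: Dict[str, Any] = {}
        self._transition_matrix: Optional[np.ndarray] = None
        self.n_computations = 0  # counts evaluations that were not cached

    def _reuse_cache(self, params: Dict[str, Any]) -> bool:
        # Cache hit: a matrix exists and was computed with equal parameters.
        if self._transition_matrix is not None and params == self.params:
            return True
        # Cache miss: record the new parameters; the caller recomputes.
        self.params = dict(params)
        return False

    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "CachedKernelSketch":
        if self._reuse_cache({"some_parameter": some_parameter}):
            return self
        self.n_computations += 1
        matrix = np.eye(self.n_cells) * some_parameter
        # Row-normalize, as described for `._compute_transition_matrix`.
        self._transition_matrix = matrix / matrix.sum(axis=1, keepdims=True)
        return self


k = CachedKernelSketch(n_cells=4)
k.compute_transition_matrix(some_parameter=0.1)
k.compute_transition_matrix(some_parameter=0.1)  # served from the cache
print(k.n_computations, k.params)
k.compute_transition_matrix(some_parameter=0.3)  # new parameters: recomputed
print(k.n_computations, k.params)
```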

In [35]:
class MyKernel(cr.tl.kernels.Kernel):
    def __init__(self, adata: AnnData, obs_key: str = "palantir_pseudotime", **kwargs: Any):
        super().__init__(adata=adata, obs_key=obs_key, **kwargs)
    
    def _read_from_adata(self, obs_key: str, **kwargs: Any):
        super()._read_from_adata(**kwargs)
        
        print(f"Reading `adata.obs[{obs_key!r}]`")
        self.pseudotime = self.adata.obs[obs_key].values
    
    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        if self._reuse_cache({"some_parameter": some_parameter}):
            print("Using cached values for parameters:", self.params)
            return self
        
        transition_matrix = sp.diags((some_parameter,) * len(self.adata), dtype=np.float64)
        
        self._compute_transition_matrix(transition_matrix, density_normalize=True)
        
        return self
    
    def copy(self) -> "MyKernel":
        return copy(self)

In [37]:
k = MyKernel(adata).compute_transition_matrix(some_parameter=0.1)
k.compute_transition_matrix(some_parameter=0.1)
print(k)

Reading `adata.obs['palantir_pseudotime']`
Using cached values for parameters: {'some_parameter': 0.1}
<MyKe[some_parameter=0.1]>


## Inverting a kernel

Kernels also have an associated direction, which can be inverted using the `~` operator.
Although this functionality is rarely needed, we recommend overriding the `__invert__` method.
This operation is in-place and it:

- changes the direction (i.e. the `.backward` attribute becomes `True` if it was `False`, and vice versa).
- invalidates the current transition matrix and the parameters.

How to implement this depends on the data loaded into the kernel object. In our case, we only need to reverse the `.pseudotime` attribute by subtracting it from its maximum.
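Independently of CellRank, the bookkeeping listed above (flip the direction, invalidate the matrix and parameters, reverse the loaded data, return `self`) can be modeled in plain NumPy. `InvertibleSketch` is a hypothetical class for illustration. As a design choice that differs from the kernel in this tutorial, it reflects pseudotime with `max + min - t` instead of `max - t`.

```python
from typing import Optional

import numpy as np


class InvertibleSketch:
    """Hypothetical model of an in-place `~` for a pseudotime-based kernel."""

    def __init__(self, pseudotime):
        self.pseudotime = np.asarray(pseudotime, dtype=np.float64)
        self.backward = False
        self.params: dict = {}
        self._transition_matrix: Optional[np.ndarray] = None

    def __invert__(self) -> "InvertibleSketch":
        self.backward = not self.backward
        # Any cached result belongs to the old direction, so drop it.
        self._transition_matrix = None
        self.params = {}
        # Reflect pseudotime within its own range: earliest becomes latest.
        self.pseudotime = self.pseudotime.max() + self.pseudotime.min() - self.pseudotime
        return self  # in-place, like the `~` described above


k = InvertibleSketch([0.2, 1.0, 0.6])
k_inv = ~k
print(k_inv is k, k.backward, k.pseudotime)
~k  # inverting twice restores the original direction and values
print(k.backward, k.pseudotime)
```

With `np.max(pseudotime) - pseudotime`, as in the kernel below, both reflections agree whenever the minimum pseudotime is 0; otherwise applying `~` twice shifts every value down by that minimum, whereas `max + min - t` is its own inverse.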

In [38]:
class MyKernel(cr.tl.kernels.Kernel):
    def __init__(self, adata: AnnData, obs_key: str = "palantir_pseudotime", **kwargs: Any):
        super().__init__(adata=adata, obs_key=obs_key, **kwargs)
    
    def _read_from_adata(self, obs_key: str, **kwargs: Any):
        super()._read_from_adata(**kwargs)
        
        print(f"Reading `adata.obs[{obs_key!r}]`")
        self.pseudotime = self.adata.obs[obs_key].values
    
    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        if self._reuse_cache({"some_parameter": some_parameter}):
            print("Using cached values for parameters:", self.params)
            return self
        
        transition_matrix = sp.diags((some_parameter,) * len(self.adata), dtype=np.float64)
        
        self._compute_transition_matrix(transition_matrix, density_normalize=True)
        
        return self
    
    def __invert__(self) -> "MyKernel":
        super().__invert__()
        self.pseudotime = np.max(self.pseudotime) - self.pseudotime
        return self
    
    def copy(self) -> "MyKernel":
        return copy(self)

In [39]:
k = MyKernel(adata)
k.pseudotime, k.backward

Reading `adata.obs['palantir_pseudotime']`


(array([0.81281052, 0.81832897, 0.48974318, ..., 0.73317134, 0.92208156,
        0.8219729 ]),
 False)

In [40]:
k_inv = ~k
assert k_inv is k  # operation is in-place
k.pseudotime, k.backward

(array([0.18718948, 0.18167103, 0.51025682, ..., 0.26682866, 0.07791844,
        0.1780271 ]),
 True)

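## Conclusion

In this tutorial, we created a kernel by subclassing `cellrank.tl.kernels.Kernel` and implementing `.compute_transition_matrix` and `.copy`, loaded data from `adata` in `._read_from_adata`, reused cached transition matrices via `._reuse_cache`, and supported inversion with `~` by overriding `__invert__`. If you implement a kernel that could be useful to others, consider contributing it to CellRank.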