Skip to content

Commit

Permalink
Implement nearest neighbor approach for tcr-dist.
Browse files Browse the repository at this point in the history
* Also implement Levenshtein distance (Fix #32).
  • Loading branch information
grst committed Mar 20, 2020
1 parent d50f41b commit d73d8fe
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ nbsphinx
sphinxcontrib-bibtex
jupyter_client
ipykernel
ipywidgets
ipywidgets
python-levenshtein
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ requires = [
'pandas<1',
'numpy',
'parasail',
'scikit-learn'
'scikit-learn',
'python-levenshtein'
]

[tool.flit.metadata.requires-extra]
Expand Down
113 changes: 111 additions & 2 deletions sctcrpy/_tools/_tcr_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,102 @@
import textwrap
from io import StringIO
import umap
from scipy.sparse import coo_matrix
from Levenshtein import distance as levenshtein_dist
import scipy.spatial


def _get_sparse_matrix_from_indices_distances_umap(
knn_indices, knn_dists, n_obs, n_neighbors
):
"""This is from scanpy.neighbors [Wolf18]_."""
rows = np.zeros((n_obs * n_neighbors), dtype=np.int64)
cols = np.zeros((n_obs * n_neighbors), dtype=np.int64)
vals = np.zeros((n_obs * n_neighbors), dtype=np.float64)

for i in range(knn_indices.shape[0]):
for j in range(n_neighbors):
if knn_indices[i, j] == -1:
continue # We didn't get the full knn for i
if knn_indices[i, j] == i:
val = 0.0
else:
val = knn_dists[i, j]

rows[i * n_neighbors + j] = i
cols[i * n_neighbors + j] = knn_indices[i, j]
vals[i * n_neighbors + j] = val

result = coo_matrix((vals, (rows, cols)), shape=(n_obs, n_obs))
result.eliminate_zeros()
return result.tocsr()


def _compute_connectivities_umap(
    knn_indices,
    knn_dists,
    n_obs,
    n_neighbors,
    *,
    set_op_mix_ratio=1.0,
    local_connectivity=1.0,
):
    """\
    Derive sparse kNN distances and fuzzy connectivities from a kNN result.

    Adapted from scanpy.neighbors [Wolf18]_, which in turn took it from
    umap.fuzzy_simplicial_set [McInnes18]_.

    The fuzzy simplicial set (a fuzzy graph stored as a sparse matrix) is
    built by locally approximating geodesic distance at each point, creating
    one local fuzzy simplicial set per point, and merging all of them into a
    global one through a fuzzy union.

    Returns a tuple ``(distances, connectivities)``; the connectivities are
    converted to CSR format.
    """
    from umap.umap_ import fuzzy_simplicial_set

    # Empty (n_obs x 1) placeholder for the data argument; presumably umap
    # only needs the number of rows when knn_indices / knn_dists are
    # supplied -- TODO confirm against the installed umap-learn version.
    placeholder = coo_matrix(([], ([], [])), shape=(n_obs, 1))
    fuzzy_result = fuzzy_simplicial_set(
        placeholder,
        n_neighbors,
        None,
        None,
        knn_indices=knn_indices,
        knn_dists=knn_dists,
        set_op_mix_ratio=set_op_mix_ratio,
        local_connectivity=local_connectivity,
    )

    # In umap-learn 0.4, this returns (result, sigmas, rhos); keep the graph.
    connectivities = (
        fuzzy_result[0] if isinstance(fuzzy_result, tuple) else fuzzy_result
    )

    distances = _get_sparse_matrix_from_indices_distances_umap(
        knn_indices, knn_dists, n_obs, n_neighbors
    )

    return distances, connectivities.tocsr()


def _dist_to_connectivities(
    dist_mat: np.ndarray, n_neighbors: int, *, random_state: int = 0
):
    """Convert a distance matrix into a sparse nearest-neighbor distance
    matrix and a sparse adjacency matrix.

    Nearest neighbors are looked up with ``umap.umap_.nearest_neighbors``
    using the ``"precomputed"`` metric, and the adjacency (connectivity)
    matrix is derived from them via a fuzzy-simplicial-set embedding.

    Parameters
    ----------
    dist_mat
        Distance matrix; its first dimension is used as the number of
        observations.
    n_neighbors
        Number of nearest neighbors to retrieve per observation.
    random_state
        Random state forwarded to the umap nearest-neighbor search.

    Returns
    -------
    Tuple ``(distances, connectivities)`` as produced by
    ``_compute_connectivities_umap``.
    """
    # nearest_neighbors returns a third value that is not needed here.
    knn_indices, knn_dists, _ = umap.umap_.nearest_neighbors(
        dist_mat,
        n_neighbors=n_neighbors,
        metric="precomputed",
        metric_kwds={},
        angular=False,
        random_state=random_state,
    )

    return _compute_connectivities_umap(
        knn_indices, knn_dists, n_obs=dist_mat.shape[0], n_neighbors=n_neighbors
    )


class _DistanceCalculator(abc.ABC):
Expand Down Expand Up @@ -43,6 +139,16 @@ def calc_dist_mat(self, seqs: np.ndarray) -> np.ndarray:
return 1 - np.identity(len(seqs))


class _LevenshteinDistanceCalculator(_DistanceCalculator):
    """Distance calculator based on the Levenshtein (edit) distance."""

    def calc_dist_mat(self, seqs: np.ndarray) -> np.ndarray:
        """Return the square matrix of pairwise edit distances of `seqs`."""

        def _edit_distance(row_a, row_b):
            # pdist passes rows of the column vector; each holds one sequence.
            return levenshtein_dist(row_a[0], row_b[0])

        # pdist requires 2D input, so put every sequence into its own row.
        seq_column = seqs.reshape(-1, 1)
        condensed = scipy.spatial.distance.pdist(seq_column, metric=_edit_distance)
        return scipy.spatial.distance.squareform(condensed)


class _KideraDistanceCalculator(_DistanceCalculator):
KIDERA_FACTORS = textwrap.dedent(
"""
Expand Down Expand Up @@ -302,11 +408,12 @@ def _dist_for_chain(
def tcr_neighbors(
adata: AnnData,
*,
metric: Literal["alignment", "kidera", "identity"] = "alignment",
metric: Literal["alignment", "kidera", "identity", "levenshtein"] = "alignment",
n_neighbors: int = 15,
n_jobs: [int, None] = None,
inplace: bool = True,
reduction_same_chain=np.fmin,
reduction_other_chain=np.fmin
reduction_other_chain=np.fmin,
) -> Union[None, dict]:
"""Compute the TCRdist on CDR3 sequences.
The equivalent of scanpy.pp.neighbors for TCR sequences.
Expand All @@ -327,6 +434,8 @@ def tcr_neighbors(
dist_calc = _KideraDistanceCalculator(n_jobs=n_jobs)
elif metric == "identity":
dist_calc = _IdentityDistanceCalculator()
elif metric == "levenshtein":
dist_calc = _LevenshteinDistanceCalculator()
else:
raise ValueError("Invalid distance metric.")

Expand Down
13 changes: 13 additions & 0 deletions tests/test_tcr_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
_KideraDistanceCalculator,
_DistanceCalculator,
_IdentityDistanceCalculator,
_LevenshteinDistanceCalculator,
_dist_for_chain,
tcr_neighbors,
)
Expand All @@ -28,6 +29,11 @@ def identity():
return _IdentityDistanceCalculator()


@pytest.fixture
def levenshtein():
    """Provide a Levenshtein distance calculator instance to tests."""
    calculator = _LevenshteinDistanceCalculator()
    return calculator


@pytest.fixture
def adata_cdr3():
obs = pd.DataFrame(
Expand Down Expand Up @@ -111,6 +117,13 @@ def test_kidera_dist(kidera):
)


def test_levensthein_dist(levenshtein):
    """Check pairwise edit distances on a small set of sequences."""
    seqs = np.array(["A", "AA", "AAA", "AAR"])
    expected = np.array(
        [
            [0, 1, 2, 2],
            [1, 0, 1, 1],
            [2, 1, 0, 1],
            [2, 1, 1, 0],
        ]
    )
    observed = levenshtein.calc_dist_mat(seqs)
    npt.assert_almost_equal(observed, expected)


def test_align_row(aligner):
seqs = np.array(["AWAW", "VWVW", "HHHH"])
row0 = aligner._align_row(seqs, 0)
Expand Down

0 comments on commit d73d8fe

Please sign in to comment.