Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changed hvg with PR to work with numba #2612

Merged
merged 14 commits into from Aug 22, 2023
152 changes: 113 additions & 39 deletions scanpy/experimental/pp/_highly_variable_genes.py
Expand Up @@ -3,10 +3,11 @@
from typing import Optional, Literal

import numpy as np
import numba as nb
import pandas as pd
import scipy.sparse as sp_sparse
from anndata import AnnData

from math import sqrt

from scanpy import logging as logg
from scanpy._settings import settings, Verbosity
Expand All @@ -27,6 +28,86 @@
)


@nb.njit(parallel=True)
def calculate_res_sparse(
    indptr,
    index,
    data,
    sums_genes,
    sums_cells,
    residuals,
    sum_total,
    clip,
    theta,
    n_genes,
    n_cells,
):
    """Write the per-gene variance of clipped Pearson residuals into `residuals`.

    The matrix is read column by column through `indptr`/`index`/`data`
    (the caller passes the arrays of a CSC matrix with one column per gene).
    For every gene and cell the expected count is
    ``mu = sums_genes[gene] * sums_cells[cell] / sum_total`` and the residual
    is ``(value - mu) / sqrt(mu + mu**2 / theta)``, clipped to
    ``[-clip, clip]``. The variance divides by `n_cells` (population variance).

    Parameters
    ----------
    indptr, index, data
        Column pointers, row indices and stored values of the sparse matrix.
        The walk below only matches a stored entry when its row index equals
        the current cell, so row indices within a column are assumed to be
        sorted in ascending order.
    sums_genes, sums_cells
        Per-gene and per-cell count totals, indexed by gene / cell.
    residuals
        Output array, one slot per gene; modified in place.
    sum_total
        Total count used to normalise the expected value.
    clip
        Residuals are limited to ``[-clip, clip]``.
    theta
        Overdispersion parameter in the residual's denominator.
    n_genes, n_cells
        Number of columns (genes) and rows (cells) to iterate over.
    """
    # Genes are independent of each other, so they are spread across threads.
    for gene in nb.prange(n_genes):
        col_start = indptr[gene]
        col_stop = indptr[gene + 1]
        gene_total = sums_genes[gene]

        # Pass 1: sum of clipped residuals over every cell (zeros included),
        # needed for the mean.
        cursor = col_start
        clipped_total = np.float64(0.0)
        for cell in range(n_cells):
            expected = gene_total * sums_cells[cell] / sum_total
            observed = np.float64(0.0)
            # Consume the next stored entry only if it belongs to this cell;
            # otherwise the cell holds an implicit zero.
            if cursor < col_stop and index[cursor] == cell:
                observed = data[cursor]
                cursor += 1
            deviation = observed - expected
            raw_res = deviation / sqrt(expected + expected * expected / theta)
            bounded_res = min(max(raw_res, -clip), clip)
            clipped_total += bounded_res

        clipped_mean = clipped_total / n_cells

        # Pass 2: recompute the residuals and accumulate squared deviations
        # from the mean, so no per-gene residual buffer has to be stored.
        cursor = col_start
        squared_dev_total = np.float64(0.0)
        for cell in range(n_cells):
            expected = gene_total * sums_cells[cell] / sum_total
            observed = np.float64(0.0)
            if cursor < col_stop and index[cursor] == cell:
                observed = data[cursor]
                cursor += 1
            deviation = observed - expected
            raw_res = deviation / sqrt(expected + expected * expected / theta)
            bounded_res = min(max(raw_res, -clip), clip)
            centered = bounded_res - clipped_mean
            squared_dev_total += centered * centered

        residuals[gene] = squared_dev_total / n_cells


@nb.njit(parallel=True)
def calculate_res_dense(
    matrix, sums_genes, sums_cells, residuals, sum_total, clip, theta, n_genes, n_cells
):
    """Write the per-gene variance of clipped Pearson residuals into `residuals`.

    For every gene and cell the expected count is
    ``mu = sums_genes[gene] * sums_cells[cell] / sum_total`` and the residual
    is ``(matrix[cell, gene] - mu) / sqrt(mu + mu**2 / theta)``, clipped to
    ``[-clip, clip]``. The variance divides by `n_cells` (population variance).

    Parameters
    ----------
    matrix
        Dense 2-D array indexed as ``matrix[cell, gene]``.
    sums_genes, sums_cells
        Per-gene and per-cell count totals, indexed by gene / cell.
    residuals
        Output array, one slot per gene; modified in place.
    sum_total
        Total count used to normalise the expected value.
    clip
        Residuals are limited to ``[-clip, clip]``.
    theta
        Overdispersion parameter in the residual's denominator.
    n_genes, n_cells
        Number of columns (genes) and rows (cells) to iterate over.
    """
    # Genes are independent of each other, so they are spread across threads.
    for gene in nb.prange(n_genes):
        gene_total = sums_genes[gene]

        # Pass 1: sum of clipped residuals, needed for the mean.
        clipped_total = np.float64(0.0)
        for cell in range(n_cells):
            expected = gene_total * sums_cells[cell] / sum_total
            observed = matrix[cell, gene]

            deviation = observed - expected
            raw_res = deviation / sqrt(expected + expected * expected / theta)
            bounded_res = min(max(raw_res, -clip), clip)
            clipped_total += bounded_res

        clipped_mean = clipped_total / n_cells

        # Pass 2: recompute the residuals and accumulate squared deviations
        # from the mean, so no per-gene residual buffer has to be stored.
        squared_dev_total = np.float64(0.0)
        for cell in range(n_cells):
            expected = gene_total * sums_cells[cell] / sum_total
            observed = matrix[cell, gene]

            deviation = observed - expected
            raw_res = deviation / sqrt(expected + expected * expected / theta)
            bounded_res = min(max(raw_res, -clip), clip)
            centered = bounded_res - clipped_mean
            squared_dev_total += centered * centered

        residuals[gene] = squared_dev_total / n_cells


def _highly_variable_pearson_residuals(
adata: AnnData,
theta: float = 100,
Expand All @@ -39,30 +120,6 @@ def _highly_variable_pearson_residuals(
subset: bool = False,
inplace: bool = True,
) -> Optional[pd.DataFrame]:
"""\
See `scanpy.experimental.pp.highly_variable_genes`.

Returns
-------
If `inplace=True`, `adata.var` is updated with the following fields. Otherwise,
returns the same fields as :class:`~pandas.DataFrame`.

highly_variable : bool
boolean indicator of highly-variable genes
means : float
means per gene
variances : float
variance per gene
residual_variances : float
Residual variance per gene. Averaged in the case of multiple batches.
highly_variable_rank : float
Rank of the gene according to residual variance, median rank in the case of multiple batches
highly_variable_nbatches : int
If `batch_key` given, denotes in how many batches genes are detected as HVG
highly_variable_intersection : bool
If `batch_key` given, denotes the genes that are highly variable in all batches
"""

view_to_actual(adata)
X = _get_obs_rep(adata, layer=layer)
computed_on = layer if layer else 'adata.X'
Expand Down Expand Up @@ -105,24 +162,41 @@ def _highly_variable_pearson_residuals(
if clip < 0:
raise ValueError("Pearson residuals require `clip>=0` or `clip=None`.")

residual_gene_var = np.zeros((X_batch.shape[1]), dtype=np.float64)
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
if sp_sparse.issparse(X_batch):
sums_genes = np.sum(X_batch, axis=0)
sums_cells = np.sum(X_batch, axis=1)
sums_genes = np.array(X_batch.sum(axis=0)).ravel()
sums_cells = np.array(X_batch.sum(axis=1)).ravel()
sum_total = np.sum(sums_genes).squeeze()
X_batch = X_batch.tocsc()
calculate_res_sparse(
X_batch.indptr,
X_batch.indices,
X_batch.data.astype(np.float64),
sums_genes,
sums_cells,
residual_gene_var,
np.float64(sum_total),
np.float64(clip),
np.float64(theta),
X_batch.shape[1],
X_batch.shape[0],
)
else:
sums_genes = np.sum(X_batch, axis=0, keepdims=True)
sums_cells = np.sum(X_batch, axis=1, keepdims=True)
sums_genes = np.sum(X_batch, axis=0).ravel()
sums_cells = np.sum(X_batch, axis=1).ravel()
sum_total = np.sum(sums_genes)

# Compute pearson residuals in chunks
residual_gene_var = np.empty((X_batch.shape[1]))
for start in np.arange(0, X_batch.shape[1], chunksize):
stop = start + chunksize
mu = np.array(sums_cells @ sums_genes[:, start:stop] / sum_total)
X_dense = X_batch[:, start:stop].toarray()
residuals = (X_dense - mu) / np.sqrt(mu + mu**2 / theta)
residuals = np.clip(residuals, a_min=-clip, a_max=clip)
residual_gene_var[start:stop] = np.var(residuals, axis=0)
X_batch = np.array(X_batch, dtype=np.float64, order='F')
calculate_res_dense(
X_batch,
sums_genes,
sums_cells,
residual_gene_var,
np.float64(sum_total),
np.float64(clip),
np.float64(theta),
X_batch.shape[1],
X_batch.shape[0],
)

# Add 0 values for genes that were filtered out
unmasked_residual_gene_var = np.zeros(len(nonzero_genes))
Expand Down