-
Notifications
You must be signed in to change notification settings - Fork 580
/
data.py
128 lines (86 loc) · 2.88 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Functions returning copies of datasets as cheaply as possible,
i.e. without having to hit the disk or (in case of ``_pbmc3k_normalized``) recomputing normalization.
"""
from __future__ import annotations
import warnings
try:
    from functools import cache
except ImportError:  # Python < 3.9: functools.cache does not exist yet
    from functools import lru_cache

    # An unbounded lru_cache is exactly what functools.cache is on 3.9+.
    cache = lru_cache(maxsize=None)
from typing import TYPE_CHECKING
import dask.array as da
from dask import delayed
from scipy import sparse
import scanpy as sc
if TYPE_CHECKING:
from anndata import AnnData
from anndata._core.sparse_dataset import SparseDataset
# Functions returning the same objects (easy to misuse)
# NOTE: each of these memoized loaders returns the SAME AnnData instance on
# every call, so mutating the result would corrupt the cache. Callers should
# use the copying wrappers defined below instead.
_pbmc3k = cache(sc.datasets.pbmc3k)
_pbmc3k_processed = cache(sc.datasets.pbmc3k_processed)
_pbmc68k_reduced = cache(sc.datasets.pbmc68k_reduced)
_krumsiek11 = cache(sc.datasets.krumsiek11)
_paul15 = cache(sc.datasets.paul15)
# Functions returning copies
def pbmc3k() -> AnnData:
    """Return a fresh, independently mutable copy of the pbmc3k dataset."""
    cached = _pbmc3k()
    return cached.copy()
def pbmc3k_processed() -> AnnData:
    """Return a fresh, independently mutable copy of the processed pbmc3k dataset."""
    cached = _pbmc3k_processed()
    return cached.copy()
def pbmc68k_reduced() -> AnnData:
    """Return a fresh, independently mutable copy of the reduced pbmc68k dataset."""
    cached = _pbmc68k_reduced()
    return cached.copy()
def krumsiek11() -> AnnData:
    """Return a fresh copy of the krumsiek11 dataset.

    The loader is known to produce anndata's duplicate-observation-name
    warning; it is suppressed here so benchmarks stay quiet.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", "Observation names are not unique", module="anndata"
        )
        copied = _krumsiek11().copy()
    return copied
def paul15() -> AnnData:
    """Return a fresh, independently mutable copy of the paul15 dataset."""
    cached = _paul15()
    return cached.copy()
# Derived datasets
@cache
def _pbmc3k_normalized() -> AnnData:
    """Build and memoize a normalized pbmc3k dataset.

    Like the other memoized loaders above, this returns the SAME object on
    every call; use the copying wrapper ``pbmc3k_normalized`` instead of
    mutating the result.
    """
    pbmc = pbmc3k()
    pbmc.X = pbmc.X.astype("float64")  # For better accuracy
    sc.pp.filter_genes(pbmc, min_counts=1)
    sc.pp.log1p(pbmc)
    sc.pp.normalize_total(pbmc)
    sc.pp.highly_variable_genes(pbmc)
    return pbmc
def pbmc3k_normalized() -> AnnData:
    """Return a fresh, independently mutable copy of the normalized pbmc3k dataset."""
    cached = _pbmc3k_normalized()
    return cached.copy()
class CSRCallable:
    """Dummy class to bypass dask checks.

    Passed as ``meta`` to ``da.from_delayed``: when dask calls it like a
    type, ``__new__`` forwards to ``csr_callable`` so an empty CSR matrix
    is produced instead of an instance of this class.
    """

    def __new__(cls, shape, dtype):
        return csr_callable(shape, dtype)
def csr_callable(shape: tuple[int, int], dtype) -> sparse.csr_matrix:
    """Create an empty CSR matrix, padding ``shape`` out to two dimensions.

    A 0-d shape becomes ``(0, 0)``, a 1-d shape ``(n,)`` becomes ``(n, 0)``,
    and a 2-d shape is used as-is. Anything higher-dimensional is rejected.
    """
    ndim = len(shape)
    if ndim > 2:
        raise ValueError(shape)
    if ndim == 0:
        shape = (0, 0)
    elif ndim == 1:
        shape = (shape[0], 0)
    return sparse.csr_matrix(shape, dtype=dtype)
def make_dask_chunk(x: SparseDataset, start: int, end: int) -> da.Array:
    """Wrap rows ``[start, end)`` of ``x`` as a single lazy dask array chunk."""

    def _slice_rows(mat, rows):
        # Deferred so the on-disk dataset is only indexed when computed.
        return mat[rows]

    lazy_chunk = delayed(_slice_rows)(x, slice(start, end))
    return da.from_delayed(
        lazy_chunk,
        shape=(end - start, x.shape[1]),
        dtype=x.dtype,
        meta=CSRCallable,
    )
def sparse_dataset_as_dask(x: SparseDataset, stride: int):
    """Expose ``x`` as a dask array chunked along axis 0 in blocks of ``stride`` rows.

    The final chunk is shorter when ``stride`` does not evenly divide the
    number of rows.
    """
    n_rows = x.shape[0]
    chunks = [
        make_dask_chunk(x, begin, min(begin + stride, n_rows))
        for begin in range(0, n_rows, stride)
    ]
    return da.concatenate(chunks, axis=0)