# _base_field.py
from abc import ABC, abstractmethod
from typing import Optional, Type, Union
import numpy as np
import pandas as pd
import rich
from scvi._types import AnnOrMuData
from scvi.data import _constants
from scvi.data._utils import get_anndata_attribute
class BaseAnnDataField(ABC):
    """
    Abstract class for a single AnnData/MuData field.

    A Field class defines how scvi-tools will map a data field used by a model
    to an attribute in an AnnData/MuData object.
    """

    @property
    @abstractmethod
    def registry_key(self) -> str:
        """The key that is referenced by models via a data loader."""

    @property
    @abstractmethod
    def attr_name(self) -> str:
        """The name of the AnnData attribute where the data is stored."""

    @property
    @abstractmethod
    def attr_key(self) -> Optional[str]:
        """The key of the data field within the relevant AnnData attribute."""

    @property
    def mod_key(self) -> Optional[str]:
        """
        The modality key of the data field within the MuData (if applicable).

        This base implementation returns ``None``; subclasses may override it.
        """
        return None

    @property
    @abstractmethod
    def is_empty(self) -> bool:
        """
        Returns True if the field is empty as a function of its kwargs.

        A field can be empty if it is composed of a collection of variables, and for a given
        instance of a model, the collection is empty. If empty, the field will be omitted from
        the registry, but included in the summary stats dictionary.
        """

    @abstractmethod
    def validate_field(self, adata: AnnOrMuData) -> None:
        """Validates whether an AnnData/MuData object is compatible with this field definition."""

    @abstractmethod
    def register_field(self, adata: AnnOrMuData) -> dict:
        """
        Sets up the AnnData/MuData object and creates a mapping for scvi-tools models to use.

        This base implementation only runs :meth:`validate_field` on ``adata`` and
        returns an empty state dictionary. Subclasses can call it via ``super()`` to
        reuse the validation step.

        Parameters
        ----------
        adata
            AnnData/MuData object being registered.

        Returns
        -------
        dict
            A dictionary containing any additional state required for scvi-tools models not
            stored directly on the AnnData/MuData object.
        """
        self.validate_field(adata)
        return {}

    @abstractmethod
    def transfer_field(
        self, state_registry: dict, adata_target: AnnOrMuData, **kwargs
    ) -> dict:
        """
        Takes an existing scvi-tools setup dictionary and transfers the same setup to the target AnnData.

        Used when one is running a pretrained model on a new AnnData object, which
        requires the mapping from the original data to be applied to the new AnnData object.
        This base implementation ignores its arguments and returns an empty dictionary.

        Parameters
        ----------
        state_registry
            state_registry dictionary created after registering an AnnData using an
            :class:`~scvi.data.AnnDataManager` object.
        adata_target
            AnnData/MuData object that is being registered.
        **kwargs
            Keyword arguments to modify transfer behavior.

        Returns
        -------
        dict
            A dictionary containing any additional state required for scvi-tools models not
            stored directly on the AnnData/MuData object.
        """
        return {}

    @abstractmethod
    def get_summary_stats(self, state_registry: dict) -> dict:
        """
        Returns a dictionary comprising summary statistics relevant to the field.

        Parameters
        ----------
        state_registry
            Dictionary returned by :meth:`~scvi.data.fields.BaseAnnDataField.register_field`.
            Summary stats should always be a function of information stored in this dictionary.

        Returns
        -------
        summary_stats_dict
            The dictionary is of the form {summary_stat_name: summary_stat_value}.
            This mapping is then combined with the mappings of other fields to make up
            the summary stats mapping.
        """

    # The return annotation is quoted so it is not evaluated at class-definition
    # time: ``import rich`` alone does not bind the ``rich.table`` submodule as an
    # attribute, so an unquoted ``rich.table.Table`` only works if some other module
    # happened to import ``rich.table`` first.
    @abstractmethod
    def view_state_registry(self, state_registry: dict) -> Optional["rich.table.Table"]:
        """
        Returns a :class:`rich.table.Table` summarizing a state registry produced by this field.

        Parameters
        ----------
        state_registry
            Dictionary returned by :meth:`~scvi.data.fields.BaseAnnDataField.register_field`.
            Printed summary should always be a function of information stored in this dictionary.

        Returns
        -------
        state_registry_summary
            Optional :class:`rich.table.Table` summarizing the ``state_registry``.
        """

    def get_field_data(self, adata: AnnOrMuData) -> Union[np.ndarray, pd.DataFrame]:
        """
        Returns the requested data as determined by the field for a given AnnData/MuData object.

        The lookup is delegated to ``get_anndata_attribute`` using this field's
        ``attr_name``, ``attr_key`` and ``mod_key``.

        Raises
        ------
        AssertionError
            If the field is empty (see :attr:`is_empty`).
        """
        # Kept as AssertionError (rather than ValueError) because existing callers
        # may catch this exact type.
        if self.is_empty:
            raise AssertionError(f"The {self.registry_key} field is empty.")
        return get_anndata_attribute(
            adata, self.attr_name, self.attr_key, mod_key=self.mod_key
        )

    def get_data_registry(self) -> dict:
        """
        Returns a dictionary which describes the mapping to the data field.

        The dictionary maps ``_constants._DR_ATTR_NAME`` to :attr:`attr_name` and
        ``_constants._DR_ATTR_KEY`` to :attr:`attr_key`. It additionally maps
        ``_constants._DR_MOD_KEY`` to :attr:`mod_key` when that is not ``None``.
        An empty dictionary is returned when the field is empty.
        This mapping is then combined with the mappings of other fields to make up the data registry.
        """
        if self.is_empty:
            return {}

        data_registry = {
            _constants._DR_ATTR_NAME: self.attr_name,
            _constants._DR_ATTR_KEY: self.attr_key,
        }
        # Only MuData-backed fields carry a modality key.
        if self.mod_key is not None:
            data_registry[_constants._DR_MOD_KEY] = self.mod_key
        return data_registry
# Convenience type alias: the class object of ``BaseAnnDataField`` or one of its
# subclasses (``Type[...]``), as opposed to an instance of a field.
AnnDataField = Type[BaseAnnDataField]