# Data handling in scvi-tools

In this tutorial we will cover how data is handled in scvi-tools. 

Sections:

1. Data Registration via `setup_anndata()`.
2. Introduction to the `registry` comprised of `data_registry`, `state_registry`, and `summary_stats`.
3. Explanation of `AnnDataField` classes and how they populate the `registry` via the `AnnDataManager`.
4. Data loading with `AnnDataLoader()`.

In [1]:
import sys

# if branch is "stable", install from PyPI; otherwise install from source
branch = "stable"
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB and branch == "stable":
    !pip install --quiet scvi-tools[tutorials]
elif IN_COLAB and branch != "stable":
    !pip install --quiet --upgrade jsonschema
    !pip install --quiet git+https://github.com/yoseflab/scvi-tools@$branch#egg=scvi-tools[tutorials]

In [2]:
import scvi
from scvi import REGISTRY_KEYS
import numpy as np

Global seed set to 0


## 1. Data Registration

scvi-tools knows which data to load into models via a data registration process handled by `setup_anndata()`.

The setup process produces an `AnnDataManager` object which wraps the `AnnData` object and creates a corresponding `registry`. We will go over the `registry` in subsequent sections.

Each model defines its own `setup_anndata()` method, which sets up the data fields specific to that model.

Here we will go over the parameters of one instance of a `setup_anndata()` method, `scvi.model.SCVI.setup_anndata()`:

- `adata` is the input `AnnData` object.
- `layer` is the key in `adata.layers` to use for the input data matrix. By default, this is `None` and the input data matrix is pulled from `adata.X`.
- `batch_key` is the key in `adata.obs` for batch information. If this is `None`, all the data is assumed to belong to the same batch.
- `labels_key` is the key in `adata.obs` for label information. If this is `None`, all the data is assumed to have the same label.
- `size_factor_key` is the key in `adata.obs` that optionally stores size factors for computing the likelihood. If this is `None`, the library size is used to compute the size factor.
- `categorical_covariate_keys` is a list of keys in `adata.obs` for categorical covariates.
- `continuous_covariate_keys` is a list of keys in `adata.obs` for continuous covariates.

Under the hood:

- For all categorical data (batch, labels, categorical covariates), scvi-tools automatically computes a mapping from values to integers. E.g., `['a','b','c','a']` becomes `[0,1,2,0]`.
- For data fields registered with `scvi.model.SCVI.setup_anndata()`, scvi-tools copies the data to a separate field in the AnnData object.
    - `batch_key` is copied to `adata.obs['_scvi_batch']` with its integer encoding.
    - `labels_key` is copied to `adata.obs['_scvi_labels']` with its integer encoding.
    - the columns named in `categorical_covariate_keys` are concatenated into a pandas DataFrame, integer-encoded, and stored in `adata.obsm['_scvi_extra_categorical_covs']`.
    - the columns named in `continuous_covariate_keys` are concatenated into a pandas DataFrame and stored in `adata.obsm['_scvi_extra_continuous_covs']`.

These preprocessing steps are detailed in the `register_field()` function of the `AnnDataField` subclasses representing each data field. Later in this tutorial we will go over how these `AnnDataField` objects tie into the `AnnDataManager` and how they are defined.
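
To make the value-to-integer mapping concrete, here is a minimal pure-Python sketch of the encoding. The helper `encode_categories` is hypothetical and assumes categories are ordered by sorting; it illustrates the idea only and is not the code scvi-tools runs, whose ordering rule may differ.

```python
def encode_categories(values):
    """Return (codes, mapping) such that mapping[code] recovers each original value."""
    mapping = sorted(set(values))  # assumption: categories are ordered by sorting
    index = {category: code for code, category in enumerate(mapping)}
    codes = [index[value] for value in values]
    return codes, mapping


codes, mapping = encode_categories(["a", "b", "c", "a"])
print(codes)    # [0, 1, 2, 0]
print(mapping)  # ['a', 'b', 'c']
```

The position of a category in `mapping` is its integer code, which is the same convention the state registry uses for `categorical_mapping` later in this tutorial.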

In the following code, we first format an example AnnData object for use with scvi-tools, then call `scvi.model.SCVI.setup_anndata()` to register all the tensors we want to load into the model during training.
For our example AnnData object, we build on the `synthetic_iid()` dataset, copy X to a layer, and add continuous and categorical covariates.

In [3]:
adata = scvi.data.synthetic_iid()
adata.layers['raw_counts'] = adata.X.copy()
adata.obs['my_categorical_covariate'] = ['A'] * 200 + ['B'] * 200
adata.obs['my_continuous_covariate'] = np.random.randint(0,100,400)
print(adata)

AnnData object with n_obs × n_vars = 400 × 100
    obs: 'batch', 'labels', 'my_categorical_covariate', 'my_continuous_covariate'
    uns: 'protein_names'
    obsm: 'protein_expression'
    layers: 'raw_counts'


In [4]:
scvi.model.SCVI.setup_anndata(
    adata,
    batch_key="batch",
    labels_key="labels",
    layer="raw_counts",
    categorical_covariate_keys=["my_categorical_covariate"],
    continuous_covariate_keys=["my_continuous_covariate"],
)

Under the hood, this method creates an `AnnDataManager` instance and stores it in a model-specific manager store until a model is initialized with the same `AnnData` object. We will go over this in detail in section 3 of this tutorial.

## 2. Model Registry

In this section we enumerate the fields in the model registry in the case of `scvi.model.SCVI`. The registry takes the form of a nested dictionary and is stored as an instance variable of a model, `model.registry_`. This is a pointer to the registry on the `AnnDataManager` object created by `setup_anndata()`, `model.adata_manager.registry`.

The top level of the registry contains the following keys:

- `scvi_version` keeps track of the version of scvi-tools used to set up the AnnData object.
- `model_name` and `setup_args` keep track of the model and arguments used to run `setup_anndata()`.
- `field_registries` is a dictionary which maps registry keys (e.g. `batch`, `labels`) to additional field-specific information.

Within each field registry, there are the following three keys:
- `data_registry` contains the location of data to load. This is what is used by the DataLoaders to iterate over the AnnData.
- `state_registry` contains any state (e.g. categorical mappings for batch) relevant to the field during `register_field()`.
- `summary_stats` contains summary statistics relevant to the field.

We can view a summary of a registry by running `view_anndata_setup()`.

In [10]:
model = scvi.model.SCVI(adata)
registry = model.adata_manager.registry
# The registry also holds a _scvi_uuid key, which uniquely identifies the AnnData object for later retrieval.
print(registry.keys())

dict_keys(['scvi_version', 'model_name', 'setup_kwargs', 'field_registries', '_scvi_uuid'])


In [11]:
model.view_anndata_setup()

The above summary incorporates all three of the components making up each field registry as mentioned before.

### Data Registry
First, let's turn our attention to the `data_registry`.

The DataLoaders use it to locate the data to load as they iterate over the AnnData. Each key of the data_registry is the name of a tensor and is used to retrieve that tensor from the dataloader output.

- All the data registered via `scvi.model.SCVI.setup_anndata()` has its keys globally set via `scvi.REGISTRY_KEYS` (we will see later that these can be user-defined keys as well).

The value of each key in the data_registry is a dictionary with two keys: `attr_name` and `attr_key`. 

- `attr_name` is the attribute of `adata` to load data from, e.g. `obs`, `obsm`, `layers`.
- `attr_key` is the key of the attribute to access the data.


For example, based off the following data_registry, batch information is loaded from `adata.obs['_scvi_batch']` and will be accessible via `REGISTRY_KEYS.BATCH_KEY`.
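
The lookup that a data registry entry describes can be sketched in a few lines. This is an illustration under stated assumptions: `fake_adata` is a hypothetical stand-in built from `SimpleNamespace` rather than a real `AnnData` object, and `lookup` is a hypothetical helper, not an scvi-tools function.

```python
from types import SimpleNamespace

# Hypothetical stand-in for an AnnData object, holding only the attributes used here.
fake_adata = SimpleNamespace(
    obs={"_scvi_batch": [0, 0, 1, 1]},
    layers={"raw_counts": [[1, 0], [0, 2], [3, 1], [0, 0]]},
)

data_registry = {
    "X": {"attr_name": "layers", "attr_key": "raw_counts"},
    "batch": {"attr_name": "obs", "attr_key": "_scvi_batch"},
}


def lookup(adata, registry, registry_key):
    """Resolve a registry key to its data via attr_name (attribute) and attr_key (key)."""
    entry = registry[registry_key]
    container = getattr(adata, entry["attr_name"])  # e.g. adata.obs
    return container[entry["attr_key"]]             # e.g. adata.obs["_scvi_batch"]


batch = lookup(fake_adata, data_registry, "batch")
print(batch)  # [0, 0, 1, 1]
```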

While the data registry dictionary is stored within the `registry`, the `AnnDataManager` provides a helper property, `adata_manager.data_registry`, which merges the data registries of all fields into a single dictionary. This property additionally wraps the dictionary in a custom `attrdict` class, which allows dictionary access via dot notation (e.g. `data_registry.batch.attr_name`).
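
For readers unfamiliar with attribute-style dictionaries, the following toy class shows one way such dot-notation access can work. `AttrDict` here is a hypothetical sketch written for this tutorial; the actual `attrdict` class in scvi-tools may be implemented differently.

```python
class AttrDict(dict):
    """Toy dictionary with attribute access; nested dicts are wrapped on construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in list(self.items()):
            if isinstance(value, dict) and not isinstance(value, AttrDict):
                self[key] = AttrDict(value)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err


toy_registry = AttrDict({"batch": {"attr_name": "obs", "attr_key": "_scvi_batch"}})
print(toy_registry.batch.attr_key)         # _scvi_batch
print(toy_registry["batch"]["attr_name"])  # obs
```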

In [15]:
data_registry = model.adata_manager.data_registry
data_registry

attrdict({'X': attrdict({'attr_name': 'layers', 'attr_key': 'raw_counts'}), 'batch': attrdict({'attr_name': 'obs', 'attr_key': '_scvi_batch'}), 'labels': attrdict({'attr_name': 'obs', 'attr_key': '_scvi_labels'}), 'extra_categorical_covs': attrdict({'attr_name': 'obsm', 'attr_key': '_scvi_extra_categorical_covs'}), 'extra_continuous_covs': attrdict({'attr_name': 'obsm', 'attr_key': '_scvi_extra_continuous_covs'})})

In [16]:
print(REGISTRY_KEYS.X_KEY)                 # key for X values
print(REGISTRY_KEYS.BATCH_KEY)             # key for batch info
print(REGISTRY_KEYS.LABELS_KEY)            # key for label data
print(REGISTRY_KEYS.PROTEIN_EXP_KEY)       # key for protein data
print(REGISTRY_KEYS.CAT_COVS_KEY)          # key for categorical covariate data
print(REGISTRY_KEYS.CONT_COVS_KEY)         # key for continuous covariate data

X
batch
labels
proteins
extra_categorical_covs
extra_continuous_covs


In [18]:
print(REGISTRY_KEYS.BATCH_KEY)
print(data_registry[REGISTRY_KEYS.BATCH_KEY])
print(data_registry.batch.attr_key)

batch
attrdict({'attr_name': 'obs', 'attr_key': '_scvi_batch'})
_scvi_batch


### State Registries

During the data registration process, we also keep track of additional information from the registration process, necessary for model initialization or downstream functionality. For example, for the batch field, scvi-tools keeps track of the location of the original data as well as the categorical to integer mappings.

The batch state registry holds the following two keys:
- `original_key` is the original key passed in by the user to load the data.
- `categorical_mapping` is the categorical to integer mapping of the data. The index of the category is its corresponding integer representation.
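
Because the index of a category is its integer code, decoding is a plain index lookup. The sketch below uses illustrative values: the `encoded` list is made up, and the mapping mirrors the two batch names of the synthetic dataset.

```python
# Example mapping; the position of each category is its integer code.
categorical_mapping = ["batch_0", "batch_1"]

# Hypothetical integer-encoded batch column.
encoded = [0, 1, 1, 0]

# Decode by indexing into the mapping.
decoded = [categorical_mapping[code] for code in encoded]
print(decoded)  # ['batch_0', 'batch_1', 'batch_1', 'batch_0']
```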

We can access a state registry via the function `AnnDataManager.get_state_registry()` which takes a registry key.

In [21]:
batch_state_registry = model.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY)
print(batch_state_registry.keys())

print(f"Categorical mapping: {batch_state_registry.categorical_mapping}")
print(f"Original key: {batch_state_registry.original_key}")

dict_keys(['categorical_mapping', 'original_key'])
Categorical mapping: ['batch_0' 'batch_1']
Original key: batch


Here is another example of a state registry, this time for the field `extra_categorical_covs`.

The extra categorical covariates state registry contains three keys:

- `mappings` is a dictionary where each key is an original obs key and each value is the categorical mapping for that key.
- `field_keys` are the column names of the pandas DataFrame in `adata.obsm['_scvi_extra_categorical_covs']` created by `setup_anndata()`.
- `n_cats_per_key` contains the number of categories per key.
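
The three keys are closely related: given `mappings`, the other two can be derived from it. The sketch below shows that relationship with plain dictionaries. The `donor` covariate is a made-up second entry added for illustration, and this is not necessarily how scvi-tools computes these values internally.

```python
# One mapping per categorical covariate; "donor" is a hypothetical extra covariate.
mappings = {
    "my_categorical_covariate": ["A", "B"],
    "donor": ["d1", "d2", "d3"],
}

# Column names of the stored DataFrame, in registration order.
field_keys = list(mappings)

# Number of categories for each covariate, aligned with field_keys.
n_cats_per_key = [len(mappings[key]) for key in field_keys]

print(field_keys)      # ['my_categorical_covariate', 'donor']
print(n_cats_per_key)  # [2, 3]
```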

In [28]:
extra_cat_state_registry = model.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)
extra_cat_state_registry.keys()

dict_keys(['mappings', 'field_keys', 'n_cats_per_key'])

In [29]:
print(f"Mappings: {extra_cat_state_registry.mappings}")
print(f"Keys: {extra_cat_state_registry.field_keys}")
print(f"N cats per key: {extra_cat_state_registry.n_cats_per_key}")

Mappings: {'my_categorical_covariate': array(['A', 'B'], dtype=object)}
Keys: ['my_categorical_covariate']
N cats per key: [2]


### Summary Stats

Lastly, the summary stats dictionary stores summary statistics that models use frequently, both to avoid recomputing them and to display them in `view_anndata_setup()`. As with the other two components, the `AnnDataManager` exposes it through a helper property, `adata_manager.summary_stats`.

In [30]:
model.adata_manager.summary_stats

attrdict({'n_cells': 400, 'n_vars': 100, 'n_batch': 2, 'n_labels': 3, 'n_extra_categorical_covs': 1, 'n_extra_continuous_covs': 1})

## 3. AnnDataManager and AnnDataFields

Now that we have gone over the final output of `setup_anndata()`, we can go over how the underlying logic is organized.

While the `AnnDataManager` provides the main interface to the data registration components, the logic specific to each field is encapsulated in `AnnDataField` classes (any child class of `BaseAnnDataField`).

An `AnnDataField` class contains four main functions to be implemented:
1. `register_field` sets up the relevant field on the AnnData object and returns the state registry for this field.
2. `validate_field` is a function called before `register_field`. E.g. checks if the data field is present on the AnnData object.
3. `transfer_field` is a function similar to `register_field`, but additionally takes a source `state_registry` which can modify the behavior of registration. E.g. for categorical fields we may want to maintain the source categories and append any additional categories on the target AnnData object for downstream transfer learning.
4. `get_summary_stats` is a function that takes a `state_registry` and outputs the summary stat dictionary. Note, this means the summary statistics must be a function of what is stored in `state_registry`.
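
The transfer behavior described for categorical fields can be sketched as follows. `transfer_categories` and its `extend_categories` flag are hypothetical names used for illustration; the sketch keeps the source categories in their original positions and appends unseen target categories at the end, so existing integer codes stay valid.

```python
def transfer_categories(source_mapping, target_values, extend_categories=False):
    """Encode target_values with the source mapping, optionally appending new categories."""
    mapping = list(source_mapping)  # copy so the source registry is left untouched
    for value in target_values:
        if value not in mapping:
            if not extend_categories:
                raise ValueError(f"Category {value!r} not found in source mapping.")
            mapping.append(value)  # new categories get the next free integer code
    index = {category: code for code, category in enumerate(mapping)}
    return [index[value] for value in target_values], mapping


target_codes, extended_mapping = transfer_categories(
    ["batch_0", "batch_1"], ["batch_1", "batch_2"], extend_categories=True
)
print(target_codes)      # [1, 2]
print(extended_mapping)  # ['batch_0', 'batch_1', 'batch_2']
```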

Together, the set of `AnnDataField`s produces the `registry` detailed in part 2.

The `AnnDataManager` takes a set of `AnnDataField`s and orchestrates calls to these functions and stores the resulting `registry`. As mentioned before, the `AnnDataManager` is constructed during `setup_anndata()` and retrieved during model initialization.
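
The division of labor between fields and manager can be illustrated with toy classes. Everything below is a simplified sketch: `ToyCategoricalField` and `ToyManager` are hypothetical, operate on a plain dictionary instead of an `AnnData` object, and do not inherit from the scvi-tools base classes.

```python
class ToyCategoricalField:
    """Toy stand-in for a categorical AnnDataField (not the scvi-tools class)."""

    def __init__(self, registry_key, obs_key):
        self.registry_key = registry_key
        self.obs_key = obs_key
        self.attr_key = f"_toy_{registry_key}"

    def validate_field(self, data):
        # Check that the requested column exists before registering it.
        if self.obs_key not in data["obs"]:
            raise KeyError(f"{self.obs_key} not found in obs")

    def register_field(self, data):
        # Validate, write the integer-encoded copy, and return the state registry.
        self.validate_field(data)
        values = data["obs"][self.obs_key]
        mapping = sorted(set(values))
        index = {category: code for code, category in enumerate(mapping)}
        data["obs"][self.attr_key] = [index[value] for value in values]
        return {"original_key": self.obs_key, "categorical_mapping": mapping}

    def get_summary_stats(self, state_registry):
        # Summary stats are computed only from what the state registry stores.
        return {f"n_{self.registry_key}": len(state_registry["categorical_mapping"])}


class ToyManager:
    """Toy stand-in for AnnDataManager: calls each field and collects the registry."""

    def __init__(self, fields):
        self.fields = fields
        self.registry = {"field_registries": {}}

    def register_fields(self, data):
        for field in self.fields:
            state_registry = field.register_field(data)
            self.registry["field_registries"][field.registry_key] = {
                "data_registry": {"attr_name": "obs", "attr_key": field.attr_key},
                "state_registry": state_registry,
                "summary_stats": field.get_summary_stats(state_registry),
            }


data = {"obs": {"batch": ["b1", "b0", "b1"]}}
manager = ToyManager([ToyCategoricalField("batch", "batch")])
manager.register_fields(data)
print(manager.registry["field_registries"]["batch"]["summary_stats"])  # {'n_batch': 2}
print(data["obs"]["_toy_batch"])  # [1, 0, 1]
```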

Here we have an abbreviated version of a `setup_anndata()` implementation for a model that only takes a `layer` kwarg and a `batch_key`:

```python
@classmethod
def setup_anndata(
    cls,
    adata: AnnData,
    layer: Optional[str] = None,
    batch_key: Optional[str] = None,
    **kwargs, # Used when loading a model with a new AnnData object.
):
    setup_method_args = cls._get_setup_method_args(**locals()) # Used for saving/loading purposes.
    anndata_fields = [
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
    ]
    adata_manager = AnnDataManager(
        fields=anndata_fields, setup_method_args=setup_method_args
    )
    adata_manager.register_fields(adata, **kwargs)
    cls.register_manager(adata_manager) # Stores the AnnDataManager in a class-specific manager store.
```

The `setup_anndata()` function itself is quite simple since any complexity in preprocessing is contained within the `AnnDataField` functions. By factorizing the preprocessing steps into each subclass, model developers can easily extend and reuse logic across models and fields.

## 4. DataLoaders

`AnnDataLoader` is the base dataloader for scvi-tools. In this section we show how the data registered is loaded by `AnnDataLoader`.

Parameters of `AnnDataLoader`:

- `adata_manager`: `AnnDataManager` object to load data from.
- `shuffle`: if True will shuffle the data beforehand.
- `indices`: can provide a subset of indices to load from (Useful when doing train/test splits).
- `data_and_attributes`: a dictionary where the key corresponds to its key in the `data_registry` and the value is the numpy data type. By default, all data is passed to the model as `np.float32`.
- `data_loader_kwargs`: additional keyword arguments passed to `torch.utils.data.DataLoader`.
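
To illustrate how `indices`, `shuffle`, and the batch size interact, here is a conceptual pure-Python sketch of minibatch index generation. `iter_minibatches` is a hypothetical helper, not part of scvi-tools, and it does not reflect how `AnnDataLoader` is implemented internally.

```python
import random


def iter_minibatches(n_obs, batch_size, indices=None, shuffle=False, seed=0):
    """Yield lists of observation indices, restricted to `indices` if given."""
    order = list(range(n_obs)) if indices is None else list(indices)
    if shuffle:
        random.Random(seed).shuffle(order)  # shuffle once, before batching
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


# Restrict loading to a hypothetical training split of the first 8 cells.
train_indices = range(0, 8)
batches = list(iter_minibatches(n_obs=10, batch_size=3, indices=train_indices))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7]]
```

Note that the last minibatch can be smaller than `batch_size` when the number of selected observations is not a multiple of it.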

First, we construct an `AnnDataLoader` and get the first batch. Then we will enumerate all the values in the batch. The variable **data_batch** contains the first batch of data. It is a dictionary whose values are the tensors registered in the previous section via `setup_anndata()`. 


In [31]:
from scvi.dataloaders import AnnDataLoader

# initialize an AnnDataLoader which will iterate over our anndata
adl = AnnDataLoader(model.adata_manager, shuffle=False, batch_size=10)

# get the first batch of data
data_batch = next(iter(adl))

For tensors set up with `setup_anndata()`, the keys come from `scvi.REGISTRY_KEYS`. Notice that the keys in **data_batch** are the same as the keys in the `data_registry`. See the Data Registry section above for a more detailed explanation.

In [32]:
print('data_batch_keys:')
print(data_batch.keys())

data_batch_keys:
dict_keys(['X', 'batch', 'labels', 'extra_categorical_covs', 'extra_continuous_covs'])


In [36]:
model.adata_manager.data_registry.keys()

dict_keys(['X', 'batch', 'labels', 'extra_categorical_covs', 'extra_continuous_covs'])

If we look at the labels in the first batch from the data loader, they correspond to the labels of the first 10 cells of our AnnData.

In [37]:
adata.obs['labels'][:10]

0    label_2
1    label_2
2    label_0
3    label_0
4    label_2
5    label_0
6    label_2
7    label_0
8    label_2
9    label_1
Name: labels, dtype: category
Categories (3, object): ['label_0', 'label_1', 'label_2']

In [39]:
# setup_anndata automatically encoded the categorical labels as integers
data_batch[REGISTRY_KEYS.LABELS_KEY] 

tensor([[2.],
        [2.],
        [0.],
        [0.],
        [2.],
        [0.],
        [2.],
        [0.],
        [2.],
        [1.]])

In [40]:
print(data_batch[REGISTRY_KEYS.X_KEY].shape) #shape is batch_size x n_genes
print(data_batch[REGISTRY_KEYS.BATCH_KEY].shape) #shape is batch_size x 1

torch.Size([10, 100])
torch.Size([10, 1])


By default, all the data loaded in scvi-tools is `np.float32`. If you wish to load as a different datatype, you can pass in a dictionary where the key corresponds to a key in the data registry and the value is the datatype.

In the following snippets, we first confirm the default dtype, then load the continuous covariates as `np.float64` and the count matrix as 64-bit integers.

In [41]:
adl = AnnDataLoader(model.adata_manager, shuffle=False, batch_size=10)
data_batch = next(iter(adl))

# by default data has the dtype np.float32
print(data_batch[REGISTRY_KEYS.X_KEY].dtype) 
print(data_batch[REGISTRY_KEYS.BATCH_KEY].dtype) 

torch.float32
torch.float32


In [42]:
data_batch.keys()

dict_keys(['X', 'batch', 'labels', 'extra_categorical_covs', 'extra_continuous_covs'])

To specify the datatype of each key, we can use the `data_and_attributes` parameter of `AnnDataLoader`. Here we load `X` as 64-bit integers and the continuous covariates as `np.float64`, but keep everything else as `np.float32`.

In [43]:
#the keys of data_and_attributes should correspond to keys in the data registry
data_registry_keys = model.adata_manager.data_registry.keys()
print("Data Registry keys:",data_registry_keys)

Data Registry keys: dict_keys(['X', 'batch', 'labels', 'extra_categorical_covs', 'extra_continuous_covs'])


In [44]:
data_and_attributes = {}
for key in data_registry_keys:
    if key == REGISTRY_KEYS.X_KEY:
        # np.long was deprecated in NumPy 1.20 and later removed; np.int64 is the explicit 64-bit integer type
        data_and_attributes[key] = np.int64
    elif key == REGISTRY_KEYS.CONT_COVS_KEY:
        data_and_attributes[key] = np.float64
    else:
        data_and_attributes[key] = np.float32
print(data_and_attributes)

{'X': <class 'numpy.int64'>, 'batch': <class 'numpy.float32'>, 'labels': <class 'numpy.float32'>, 'extra_categorical_covs': <class 'numpy.float32'>, 'extra_continuous_covs': <class 'numpy.float64'>}


In [45]:
adl = AnnDataLoader(model.adata_manager, shuffle=False, batch_size=10, data_and_attributes=data_and_attributes)
data_batch = next(iter(adl))

# the dtypes now follow data_and_attributes instead of the np.float32 default
print(data_batch[REGISTRY_KEYS.X_KEY].dtype)
print(data_batch[REGISTRY_KEYS.CONT_COVS_KEY].dtype)

torch.int64
torch.float64


Finally, if the `data_and_attributes` parameter is used, it will only load the keys of the passed in dictionary. For example, if the only key in the dictionary passed in to `data_and_attributes` is X, the data loader will only load X.
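
The select-and-cast behavior described above can be mimicked on a single toy record. `select_and_cast` is a hypothetical helper that uses Python's built-in `float` and `int` in place of NumPy dtypes; it only illustrates the semantics and is not how `AnnDataLoader` is implemented.

```python
def select_and_cast(record, data_and_attributes=None):
    """Return only the requested keys, cast to the requested types.

    With no mapping given, every key is returned and cast to float,
    mirroring the float default described above.
    """
    if data_and_attributes is None:
        return {key: float(value) for key, value in record.items()}
    return {key: cast(record[key]) for key, cast in data_and_attributes.items()}


record = {"X": 3.0, "batch": 1, "labels": 2}
print(select_and_cast(record))              # {'X': 3.0, 'batch': 1.0, 'labels': 2.0}
print(select_and_cast(record, {"X": int}))  # {'X': 3}
```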

In [47]:
# np.float was deprecated in NumPy 1.20 and later removed; np.float64 is the equivalent explicit dtype
data_and_attributes = {REGISTRY_KEYS.X_KEY: np.float64}
adl = AnnDataLoader(
    model.adata_manager, shuffle=False, batch_size=10, data_and_attributes=data_and_attributes
)
data_batch = next(iter(adl))

print(data_batch.keys())

dict_keys(['X'])
