# Basics: Iterables and Iterators


## Creating Iterables From AnnData

When called, the `Ann2DataBasic` class expects an iterable of AnnData objects (`Iterable[AnnData]`), for example a list of AnnData objects. If you have a single AnnData object and want to split it into several, you can use one of the implementations of `geome.iterables.ToIterable`. The relevant signatures are as follows:

```python
class ToIterable(ABC):
    @abstractmethod
    def __call__(self, adata: AnnData) -> Iterable[AnnData]:
        pass

class Ann2DatAbstract(ABC):
    """Abstract class that transforms an iterable of AnnData to Pytorch Geometric Data objects."""

    def __init__(
        self,
        fields: dict[str, list[str]],
        adata2iter: Callable[[AnnData], Iterable[AnnData]] | None = None,
        ...
    ) -> None:
        pass

    def __call__(self, adata: AnnData | Iterable[AnnData]) -> Iterable[Data]:
        pass
```

You can pass an instance of a `ToIterable` class to the `adata2iter` parameter of a `geome.ann2data` class constructor (for example `Ann2DataBasic`), or you can call the instance directly on an AnnData object to get an iterable of AnnData objects. Passing the instance to the `adata2iter` parameter lets you reuse the same configuration to split multiple AnnData objects and to specify preprocessing strategies that run before the split happens.
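To make the contract concrete, here is a minimal sketch of the same pattern using only the standard library. A list of dicts stands in for an AnnData object, and the names `Table`, `ToIterableSketch`, and `SplitByKey` are hypothetical; they are not part of geome. The point is only that a splitter is a callable that takes one object and returns an iterable of pieces.

```python
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator

# Stand-in for AnnData (assumption): one dict per observation.
Table = list[dict]


class ToIterableSketch(ABC):
    """Hypothetical mirror of the ToIterable contract: one object in, many out."""

    @abstractmethod
    def __call__(self, table: Table) -> Iterable[Table]:
        ...


class SplitByKey(ToIterableSketch):
    """Hypothetical splitter that groups rows by the value stored under `key`."""

    def __init__(self, key: str) -> None:
        self.key = key

    def __call__(self, table: Table) -> Iterator[Table]:
        # dict.fromkeys keeps the first-seen order of the distinct values.
        for value in dict.fromkeys(row[self.key] for row in table):
            yield [row for row in table if row[self.key] == value]


cells = [
    {"cell_id": 0, "Cluster": "Tcell_CD4"},
    {"cell_id": 1, "Cluster": "Fibroblast"},
    {"cell_id": 2, "Cluster": "Tcell_CD4"},
]
splitter = SplitByKey("Cluster")  # configure once, reuse on any table
parts = list(splitter(cells))     # one sub-table per distinct cluster
```

Because the splitter is configured once and then called, the same instance can be applied to several inputs, which is the property the constructor parameter relies on.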

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from geome import iterables, transforms, ann2data
import squidpy as sq
import numpy as np
from anndata import AnnData

## Load data
First, let's load the data and see what it looks like. In this example, we want to split the data by the categories in `adata.obs["Cluster"]`.


In [3]:
# Load squidpy dataset
adata = sq.datasets.mibitof()
adata.obs["Cluster"].cat.categories

Index(['Endothelial', 'Epithelial', 'Fibroblast', 'Imm_other', 'Myeloid_CD11c',
       'Myeloid_CD68', 'Tcell_CD4', 'Tcell_CD8'],
      dtype='object')

### Create ToIterable instance: ToCategoryIterator
We will create an instance of the `ToCategoryIterator` class, which splits the AnnData object by the categories in `adata.obs["Cluster"]`. The signature of this class is as follows:

```python
class ToCategoryIterator(ToIterable):
    """Iterates over `adata` by category on the given axis (either obs(0) or var(1)).

    Preserves the categories in the resulting AnnData obs and var Series.
    """

    def __init__(self, category: str, axis: Literal[0, 1, "obs", "var"] = "obs", preserve_categories: bool = True):
        pass

    def __call__(self, adata: AnnData) -> Iterator[AnnData]:
        pass
```

#### Reminder: Iterator vs Iterable
- An `iterable` is an object that can be iterated over. It returns an iterator when `iter()` is called on it.
- An `iterator` is an object that produces the next value when `next()` is called on it.
- Every `iterator` is an `iterable`, but not every `iterable` is an `iterator`.

`ToCategoryIterator` returns an `iterator` of AnnData objects. Since every `iterator` is an `iterable`, this satisfies the requirements of the `Ann2Data` classes. Returning an `iterator` is useful because the splits can be produced lazily, one at a time, instead of all being held in memory at once.
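The distinction can be checked with plain Python. The sketch below uses a hypothetical generator function, `split_lazily`, as a stand-in for a splitter; it is not geome code. It also shows a practical consequence: an iterator can be consumed only once, so wrap it in `list(...)` if you need to go over the splits again.

```python
from collections.abc import Iterable, Iterator
from itertools import islice


def split_lazily(items: list, size: int) -> Iterator[list]:
    """Yield consecutive chunks of `items`; each chunk is built only when requested."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


chunks_list = [[1, 2], [3, 4]]                   # a list: iterable, not an iterator
chunks_gen = split_lazily([1, 2, 3, 4], size=2)  # a generator: an iterator

list_is_iterator = isinstance(chunks_list, Iterator)  # False
gen_is_iterable = isinstance(chunks_gen, Iterable)    # True
gen_is_iterator = isinstance(chunks_gen, Iterator)    # True

first_pass = list(chunks_gen)   # consumes the iterator
second_pass = list(chunks_gen)  # nothing left: the iterator is exhausted

# Take only the first two chunks without building the remaining ones.
first_two = list(islice(split_lazily(list(range(10)), size=2), 2))
```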


In [4]:
to_iterable: iterables.ToIterable = iterables.ToCategoryIterator("Cluster", axis="obs", preserve_categories=True)

In [5]:
split_adatas = list(to_iterable(adata))  # split by cluster
assert len(split_adatas) == len(adata.obs["Cluster"].cat.categories) # ensure all clusters have their own adata
split_adatas[:3]  # show first 3

[AnnData object with n_obs × n_vars = 115 × 36
     obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'
     var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'
     uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'
     obsm: 'X_scanorama', 'X_umap', 'spatial'
     obsp: 'connectivities', 'distances',
 AnnData object with n_obs × n_vars = 746 × 36
     obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'
     var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'
     uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'
     obsm: 'X_scanorama', 'X_umap', 'spatial'
     obsp: 'connectivities', 'distances',
 AnnData object with n_obs × n_vars = 270 × 36
     obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size',

#### Important Note about `preserve_categories` parameter

If `preserve_categories` is set to `True`, the categories of the original AnnData object are preserved in each split AnnData object. A category that does not occur in a given split is still listed among the categories of the corresponding `obs` column, just with zero observations. This is useful when you want to keep track of the categories that were present in the original AnnData object. If you set `preserve_categories` to `False` instead, the categories that do not occur in a split are removed from the categories of that split's `obs` column.

This matters when you want to one-hot encode the categories: with preserved categories, every split encodes the column with the same number of dimensions.
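The effect on one-hot encoding can be illustrated with plain pandas categoricals. This is a sketch of the general mechanism, not geome's implementation; the toy series and the use of `pd.get_dummies` are assumptions made for illustration.

```python
import pandas as pd

# Toy categorical column with three declared categories.
clusters = pd.Series(
    ["Tcell_CD4", "Tcell_CD4", "Fibroblast"],
    dtype=pd.CategoricalDtype(["Endothelial", "Fibroblast", "Tcell_CD4"]),
)

# Keep only one cluster, as a per-category split would.
subset = clusters[clusters == "Tcell_CD4"]

preserved = subset                                # still knows all 3 categories
dropped = subset.cat.remove_unused_categories()   # knows only the observed one

# One-hot width follows the number of declared categories, not observed values.
preserved_width = pd.get_dummies(preserved).shape[1]
dropped_width = pd.get_dummies(dropped).shape[1]
```

With preserved categories every split yields a one-hot matrix of the same width, so the encodings of different splits are directly comparable; with dropped categories each split would be encoded on its own.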

In [6]:
assert all(len(ad.obs["Cluster"].cat.categories) == len(adata.obs["Cluster"].cat.categories) for ad in split_adatas)  # ensure all splits keep the full set of categories

The case for `preserve_categories=False` is shown in the example below.

In [7]:
unpreserved_example = list(iterables.ToCategoryIterator("Cluster", axis="obs", preserve_categories=False)(adata))
assert all(len(ad.obs["Cluster"].cat.categories) == 1 for ad in unpreserved_example)
# you can see that the categories are not preserved and each split has only one category
[len(ad.obs["Cluster"].cat.categories) for ad in unpreserved_example]

[1, 1, 1, 1, 1, 1, 1, 1]

### The role of ToIterable in Ann2Data

In the following example, we show three ways of creating data objects that all produce the same result:
1. Use `Ann2DataBasic` and pass the `ToCategoryIterator` instance to the `adata2iter` parameter.
2. Use `Ann2DataBasic`, split the AnnData object with the `ToCategoryIterator` instance yourself, and pass the resulting iterable to `Ann2DataBasic`.
3. Use `Ann2DataByCategory`, a subclass of `Ann2DataBasic` that takes the category as a parameter and uses a `ToCategoryIterator` instance internally.

1. Use `Ann2DataBasic` and pass the `ToCategoryIterator` instance to the `adata2iter` parameter.

In [8]:
result1 = ann2data.Ann2DataBasic(
    fields={"x": ["X"]},
    adata2iter=iterables.ToCategoryIterator("Cluster", axis="obs"),
).to_list(adata)

2. Use `Ann2DataBasic`, split the AnnData object with the `ToCategoryIterator` instance yourself, and pass the resulting iterable to `Ann2DataBasic`.

In [9]:
result2 = ann2data.Ann2DataBasic(
    fields={"x": ["X"]},
).to_list(iterables.ToCategoryIterator("Cluster", axis="obs")(adata))

3. Use `Ann2DataByCategory`, a subclass of `Ann2DataBasic` that takes the category as a parameter and uses a `ToCategoryIterator` instance internally.

In [10]:
result3 = ann2data.Ann2DataByCategory(
    fields={"x": ["X"]},
    category="Cluster",
).to_list(adata)

The assertion below checks that all three approaches produce data objects with matching `x` values.

In [11]:
assert all(np.allclose(r1.x, r2.x) for r1, r2 in zip(result1, result2)) and all(np.allclose(r1.x, r3.x) for r1, r3 in zip(result1, result3))
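One caveat with this kind of check is that `zip` stops at the shorter input, so a result list with missing items would still pass. A slightly stricter comparison could look like the sketch below. The helper `same_x` is hypothetical, and `SimpleNamespace` objects stand in for the data objects, since only their `x` attribute is used.

```python
from types import SimpleNamespace

import numpy as np


def same_x(results_a, results_b) -> bool:
    """Return True if both result lists have equal length and pairwise-close `x`."""
    results_a, results_b = list(results_a), list(results_b)
    if len(results_a) != len(results_b):
        return False  # zip() alone would silently ignore the extra items
    return all(np.allclose(a.x, b.x) for a, b in zip(results_a, results_b))


# Stand-ins for the data objects (assumption): only `x` matters here.
full = [SimpleNamespace(x=np.ones((2, 3))), SimpleNamespace(x=np.zeros((4, 3)))]
truncated = full[:1]

# The zip-only check passes even though one item is missing.
zip_only = all(np.allclose(a.x, b.x) for a, b in zip(full, truncated))
# The length-aware check reports the mismatch.
with_length_check = same_x(full, truncated)
```

On Python 3.10 or newer, `zip(results_a, results_b, strict=True)` raises a `ValueError` on a length mismatch and can serve the same purpose.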