In [18]:
%load_ext nb_black

import anndata
import inspect
from torch_adata import AnnDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
import os
import torch
from typing import Union, List
import numpy as np
from licorice_font import font_format

from scdiffeq._core.utils import function_kwargs

# Location of the input AnnData file, relative to the notebook's working
# directory. Change it here rather than inside the call below.
H5AD_PATH = "./adata.Weinreb2020.h5ad"
adata = anndata.read_h5ad(H5AD_PATH)

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [19]:
class SplitSize:
    """Compute group sizes for splitting ``n_cells`` items into ``n_groups`` groups."""

    def __init__(self, n_cells: int, n_groups: int):

        self._n_cells = n_cells
        self._n_groups = n_groups

    def _sum_norm(self, vals: Union[List, np.ndarray]) -> np.ndarray:
        """Return ``vals`` as a float array rescaled so that it sums to 1."""
        # Convert before summing so plain lists (allowed by the annotation) work.
        arr = np.asarray(vals, dtype=float)
        return arr / arr.sum()

    def uniformly(self) -> List[int]:
        """Split as evenly as possible; the last group absorbs the remainder.

        e.g. 7 cells into 3 groups -> ``[2, 2, 3]``.
        """
        div, mod = divmod(self._n_cells, self._n_groups)
        return [div] * (self._n_groups - 1) + [div + mod]

    def proportioned(self, percentages=(0.8, 0.2), remainder_idx=-1) -> List[int]:
        """Split proportionally to ``percentages`` (normalized to sum to 1).

        Each size is floored; the cells lost to flooring are added to the group at
        ``remainder_idx`` so the sizes always sum to ``n_cells``. One size is
        returned per entry of ``percentages`` (``n_groups`` is not consulted).
        """
        fractions = self._sum_norm(percentages)
        split_lengths = [int(self._n_cells * frac) for frac in fractions]
        split_lengths[remainder_idx] += self._n_cells - sum(split_lengths)

        return split_lengths


class CellDataManager:
    """Data manager at the AnnData level.

    Keeps a working copy of ``adata.obs`` in ``self.df`` and exposes train / val /
    test / predict subsets of ``adata``, each selected by a boolean ``obs`` column.
    The column for each subset is looked up in ``self.data_keys``, which is built
    from the ``*_key`` constructor arguments (``train_key="train"`` gives
    ``data_keys["train"] == "train"``). If a train column exists but no val
    column, the train cells are randomly divided into new ``"fit_train"`` and
    ``"fit_val"`` columns, which are also written back to ``adata.obs``.
    """

    def __init__(
        self,
        adata,
        use_key="X_pca",
        groupby="Time point",
        obs_keys=None,
        train_key="train",
        val_key="val",
        test_key="test",
        predict_key="predict",
        n_groups=None,
        train_val_percentages=[0.8, 0.2],
        remainder_idx=-1,
        predict_all=True,
        attr_names={"obs": [], "aux": []},
        one_hot=False,
        aux_keys=None,
        silent=True,
    ):
        # NOTE(review): `train_val_percentages` and `attr_names` have mutable
        # defaults shared across instances; they are stored on `self` (and
        # presumably forwarded to AnnDataset), so they must never be mutated.
        self.__config__(locals())

    def __parse__(self, kwargs, ignore=("self",), hide=()):
        """Set each ``kwargs`` entry whose key is not in ``ignore`` as an attribute.

        ``hide`` is accepted for interface compatibility and is currently unused.
        """
        for key, val in kwargs.items():
            if key not in ignore:
                setattr(self, key, val)

    def __config__(self, kwargs, ignore=("self",), hide=()):
        """Configure the manager from the ``locals()`` captured in ``__init__``."""
        self.AnnDataset_kwargs = function_kwargs(
            func=AnnDataset, kwargs=kwargs, ignore=["adata"]
        )
        # `to_dataset` passes `adata` explicitly; drop it here whether or not
        # `function_kwargs` already honored `ignore`.
        self.AnnDataset_kwargs.pop("adata", None)
        self.__parse__(kwargs, ignore, hide)
        self.df = self.adata.obs.copy()
        self.data_keys = self._get_keys(kwargs)

        if not self.n_groups:
            # One group per percentage; fall back to a two-way (train/val) split
            # when no percentages are given instead of failing on len(None).
            self.n_groups = (
                len(self.train_val_percentages) if self.train_val_percentages else 2
            )

        self.split = SplitSize(self.n_cells, self.n_groups)

        # configure train-val split if it has train but not val
        if self.has_train_not_val:
            # Count train cells from the mask instead of copying the subset.
            self.n_fit = int(self._key_mask("train").sum())
            self._allocate_validation()

    # -- supporting methods: -------------------------------------------------------------
    def _get_keys(self, kwargs):
        """Map every ``<name>_key`` argument to its value under ``<name>``.

        e.g. ``{"train_key": "train", "use_key": "X_pca"}`` ->
        ``{"train": "train", "use": "X_pca"}``. Arguments not ending in ``_key``
        (such as ``aux_keys``) are skipped.
        """
        suffix = "_key"
        # Slice the suffix off: str.strip("_key") removes a *character set* from
        # both ends and would turn "use_key" into "us".
        return {
            attr[: -len(suffix)]: val
            for attr, val in kwargs.items()
            if attr.endswith(suffix)
        }

    def _key_mask(self, key):
        """Boolean numpy mask from the ``self.df`` column mapped to ``key``."""
        return self.df[self.data_keys[key]].to_numpy(dtype=bool)

    def _subset_adata(self, key):
        """Return a copy of ``adata`` restricted to the cells flagged for ``key``.

        If the column is missing, ``key == "predict"`` and ``predict_all`` is set,
        a column flagging every cell is created in ``self.df`` first.

        Raises:
            KeyError: if the column mapped to ``key`` is not in ``self.df``.
        """
        access_key = self.data_keys[key]
        if access_key not in self.df.columns:
            if (key == "predict") and self.predict_all:
                self.df[access_key] = True
            else:
                raise KeyError(
                    "Key: Access Key pair: {}:{} not found".format(key, access_key)
                )
        return self.adata[self._key_mask(key)].copy()

    def _set_new_idx(self, df, idx, key_added):
        """Add boolean column ``key_added`` to ``df``, True at row positions ``idx``."""
        mask = np.zeros(len(df), dtype=bool)
        mask[idx] = True
        df[key_added] = mask

    def _allocate_validation(self, remainder_idx=None):
        """Split the train cells into non-overlapping ``fit_train`` / ``fit_val`` sets.

        Invoked when a train column exists but no val column. Cells are drawn at
        random with the global (unseeded) ``np.random`` state, in the sizes given
        by ``train_val_percentages`` (or two equal halves when it is empty).
        Afterwards ``data_keys["train"]`` / ``data_keys["val"]`` point at the new
        columns, and ``self.df`` is assigned back to ``adata.obs``.

        Args:
            remainder_idx: which part absorbs the rounding remainder; defaults to
                the ``remainder_idx`` given to the constructor.
        """
        if remainder_idx is None:
            remainder_idx = self.remainder_idx

        # Row positions in self.df of the cells currently flagged as train.
        # Working with positions avoids assuming obs names are integer positions.
        train_positions = np.flatnonzero(self._key_mask("train"))
        n_train_cells = len(train_positions)

        # A train/val split always has exactly two parts.
        train_val_split = SplitSize(n_train_cells, 2)
        if not self.train_val_percentages:
            n_train, _ = train_val_split.uniformly()
        else:
            n_train, _ = train_val_split.proportioned(
                percentages=self.train_val_percentages, remainder_idx=remainder_idx
            )

        chosen = np.random.choice(n_train_cells, size=n_train, replace=False)
        is_fit_train = np.zeros(n_train_cells, dtype=bool)
        is_fit_train[chosen] = True

        self._set_new_idx(
            self.df, idx=train_positions[is_fit_train], key_added="fit_train"
        )
        self._set_new_idx(
            self.df, idx=train_positions[~is_fit_train], key_added="fit_val"
        )

        self.data_keys["train"] = "fit_train"
        self.data_keys["val"] = "fit_val"

        self.adata.obs = self.df

    def to_dataset(self, key):
        """Wrap the ``<key>_adata`` subset in an AnnDataset with the stored kwargs."""
        adata = getattr(self, "{}_adata".format(key))
        return AnnDataset(adata=adata, **self.AnnDataset_kwargs)

    # -- Properties: ---------------------------------------------------------------------
    @property
    def cell_idx(self):
        """Index (obs names) of all cells in ``adata``."""
        return self.adata.obs.index

    @property
    def n_cells(self):
        return self.adata.shape[0]

    @property
    def n_features(self):
        return self.adata.shape[1]

    @property
    def uniform_split(self) -> List[int]:
        """Even split of all cells into ``n_groups`` groups."""
        return self.split.uniformly()

    @property
    def proportioned_split(self) -> List[int]:
        """Split of all cells by ``train_val_percentages`` (remainder at ``remainder_idx``)."""
        return self.split.proportioned(
            percentages=self.train_val_percentages, remainder_idx=self.remainder_idx
        )

    @property
    def train_adata(self):
        return self._subset_adata("train")

    @property
    def val_adata(self):
        return self._subset_adata("val")

    @property
    def test_adata(self):
        return self._subset_adata("test")

    @property
    def predict_adata(self):
        return self._subset_adata("predict")

    @property
    def has_train_not_val(self):
        """True when ``self.df`` has the train column but not the val column."""
        columns = self.df.columns
        return (self.data_keys["train"] in columns) and (
            self.data_keys["val"] not in columns
        )

    @property
    def train_dataset(self):
        return self.to_dataset("train")

    @property
    def val_dataset(self):
        return self.to_dataset("val")

    @property
    def test_dataset(self):
        return self.to_dataset("test")

    @property
    def predict_dataset(self):
        return self.to_dataset("predict")

<IPython.core.display.Javascript object>

In [20]:
class SplitSize:
    """Compute group sizes for splitting ``n_cells`` items into ``n_groups`` groups."""

    def __init__(self, n_cells: int, n_groups: int):

        self._n_cells = n_cells
        self._n_groups = n_groups

    def _sum_norm(self, vals: Union[List, np.ndarray]) -> np.ndarray:
        """Return ``vals`` as a float array rescaled so that it sums to 1."""
        # Convert before summing so plain lists (allowed by the annotation) work.
        arr = np.asarray(vals, dtype=float)
        return arr / arr.sum()

    def uniformly(self) -> List[int]:
        """Split as evenly as possible; the last group absorbs the remainder.

        e.g. 7 cells into 3 groups -> ``[2, 2, 3]``.
        """
        div, mod = divmod(self._n_cells, self._n_groups)
        return [div] * (self._n_groups - 1) + [div + mod]

    def proportioned(self, percentages=(0.8, 0.2), remainder_idx=-1) -> List[int]:
        """Split proportionally to ``percentages`` (normalized to sum to 1).

        Each size is floored; the cells lost to flooring are added to the group at
        ``remainder_idx`` so the sizes always sum to ``n_cells``. One size is
        returned per entry of ``percentages`` (``n_groups`` is not consulted).
        """
        fractions = self._sum_norm(percentages)
        split_lengths = [int(self._n_cells * frac) for frac in fractions]
        split_lengths[remainder_idx] += self._n_cells - sum(split_lengths)

        return split_lengths


class CellDataManager:
    """Data manager at the AnnData level.

    Keeps a working copy of ``adata.obs`` in ``self.df`` and exposes train / val /
    test / predict subsets of ``adata``, each selected by a boolean ``obs`` column.
    The column for each subset is looked up in ``self.data_keys``, which is built
    from the ``*_key`` constructor arguments (``train_key="train"`` gives
    ``data_keys["train"] == "train"``). If a train column exists but no val
    column, the train cells are randomly divided into new ``"fit_train"`` and
    ``"fit_val"`` columns, which are also written back to ``adata.obs``.
    """

    def __init__(
        self,
        adata,
        use_key="X_pca",
        groupby="Time point",
        obs_keys=None,
        train_key="train",
        val_key="val",
        test_key="test",
        predict_key="predict",
        n_groups=None,
        train_val_percentages=[0.8, 0.2],
        remainder_idx=-1,
        predict_all=True,
        attr_names={"obs": [], "aux": []},
        one_hot=False,
        aux_keys=None,
        silent=True,
    ):
        # NOTE(review): `train_val_percentages` and `attr_names` have mutable
        # defaults shared across instances; they are stored on `self` (and
        # presumably forwarded to AnnDataset), so they must never be mutated.
        self.__config__(locals())

    def __parse__(self, kwargs, ignore=("self",), hide=()):
        """Set each ``kwargs`` entry whose key is not in ``ignore`` as an attribute.

        ``hide`` is accepted for interface compatibility and is currently unused.
        """
        for key, val in kwargs.items():
            if key not in ignore:
                setattr(self, key, val)

    def __config__(self, kwargs, ignore=("self",), hide=()):
        """Configure the manager from the ``locals()`` captured in ``__init__``."""
        self.AnnDataset_kwargs = function_kwargs(
            func=AnnDataset, kwargs=kwargs, ignore=["adata"]
        )
        # `to_dataset` passes `adata` explicitly; drop it here whether or not
        # `function_kwargs` already honored `ignore`.
        self.AnnDataset_kwargs.pop("adata", None)
        self.__parse__(kwargs, ignore, hide)
        self.df = self.adata.obs.copy()
        self.data_keys = self._get_keys(kwargs)

        if not self.n_groups:
            # One group per percentage; fall back to a two-way (train/val) split
            # when no percentages are given instead of failing on len(None).
            self.n_groups = (
                len(self.train_val_percentages) if self.train_val_percentages else 2
            )

        self.split = SplitSize(self.n_cells, self.n_groups)

        # configure train-val split if it has train but not val
        if self.has_train_not_val:
            # Count train cells from the mask instead of copying the subset.
            self.n_fit = int(self._key_mask("train").sum())
            self._allocate_validation()

    # -- supporting methods: -------------------------------------------------------------
    def _get_keys(self, kwargs):
        """Map every ``<name>_key`` argument to its value under ``<name>``.

        e.g. ``{"train_key": "train", "use_key": "X_pca"}`` ->
        ``{"train": "train", "use": "X_pca"}``. Arguments not ending in ``_key``
        (such as ``aux_keys``) are skipped.
        """
        suffix = "_key"
        # Slice the suffix off: str.strip("_key") removes a *character set* from
        # both ends and would turn "use_key" into "us".
        return {
            attr[: -len(suffix)]: val
            for attr, val in kwargs.items()
            if attr.endswith(suffix)
        }

    def _key_mask(self, key):
        """Boolean numpy mask from the ``self.df`` column mapped to ``key``."""
        return self.df[self.data_keys[key]].to_numpy(dtype=bool)

    def _subset_adata(self, key):
        """Return a copy of ``adata`` restricted to the cells flagged for ``key``.

        If the column is missing, ``key == "predict"`` and ``predict_all`` is set,
        a column flagging every cell is created in ``self.df`` first.

        Raises:
            KeyError: if the column mapped to ``key`` is not in ``self.df``.
        """
        access_key = self.data_keys[key]
        if access_key not in self.df.columns:
            if (key == "predict") and self.predict_all:
                self.df[access_key] = True
            else:
                raise KeyError(
                    "Key: Access Key pair: {}:{} not found".format(key, access_key)
                )
        return self.adata[self._key_mask(key)].copy()

    def _set_new_idx(self, df, idx, key_added):
        """Add boolean column ``key_added`` to ``df``, True at row positions ``idx``."""
        mask = np.zeros(len(df), dtype=bool)
        mask[idx] = True
        df[key_added] = mask

    def _allocate_validation(self, remainder_idx=None):
        """Split the train cells into non-overlapping ``fit_train`` / ``fit_val`` sets.

        Invoked when a train column exists but no val column. Cells are drawn at
        random with the global (unseeded) ``np.random`` state, in the sizes given
        by ``train_val_percentages`` (or two equal halves when it is empty).
        Afterwards ``data_keys["train"]`` / ``data_keys["val"]`` point at the new
        columns, and ``self.df`` is assigned back to ``adata.obs``.

        Args:
            remainder_idx: which part absorbs the rounding remainder; defaults to
                the ``remainder_idx`` given to the constructor.
        """
        if remainder_idx is None:
            remainder_idx = self.remainder_idx

        # Row positions in self.df of the cells currently flagged as train.
        # Working with positions avoids assuming obs names are integer positions.
        train_positions = np.flatnonzero(self._key_mask("train"))
        n_train_cells = len(train_positions)

        # A train/val split always has exactly two parts.
        train_val_split = SplitSize(n_train_cells, 2)
        if not self.train_val_percentages:
            n_train, _ = train_val_split.uniformly()
        else:
            n_train, _ = train_val_split.proportioned(
                percentages=self.train_val_percentages, remainder_idx=remainder_idx
            )

        chosen = np.random.choice(n_train_cells, size=n_train, replace=False)
        is_fit_train = np.zeros(n_train_cells, dtype=bool)
        is_fit_train[chosen] = True

        self._set_new_idx(
            self.df, idx=train_positions[is_fit_train], key_added="fit_train"
        )
        self._set_new_idx(
            self.df, idx=train_positions[~is_fit_train], key_added="fit_val"
        )

        self.data_keys["train"] = "fit_train"
        self.data_keys["val"] = "fit_val"

        self.adata.obs = self.df

    def to_dataset(self, key):
        """Wrap the ``<key>_adata`` subset in an AnnDataset with the stored kwargs."""
        adata = getattr(self, "{}_adata".format(key))
        return AnnDataset(adata=adata, **self.AnnDataset_kwargs)

    # -- Properties: ---------------------------------------------------------------------
    @property
    def cell_idx(self):
        """Index (obs names) of all cells in ``adata``."""
        return self.adata.obs.index

    @property
    def n_cells(self):
        return self.adata.shape[0]

    @property
    def n_features(self):
        return self.adata.shape[1]

    @property
    def uniform_split(self) -> List[int]:
        """Even split of all cells into ``n_groups`` groups."""
        return self.split.uniformly()

    @property
    def proportioned_split(self) -> List[int]:
        """Split of all cells by ``train_val_percentages`` (remainder at ``remainder_idx``)."""
        return self.split.proportioned(
            percentages=self.train_val_percentages, remainder_idx=self.remainder_idx
        )

    @property
    def train_adata(self):
        return self._subset_adata("train")

    @property
    def val_adata(self):
        return self._subset_adata("val")

    @property
    def test_adata(self):
        return self._subset_adata("test")

    @property
    def predict_adata(self):
        return self._subset_adata("predict")

    @property
    def has_train_not_val(self):
        """True when ``self.df`` has the train column but not the val column."""
        columns = self.df.columns
        return (self.data_keys["train"] in columns) and (
            self.data_keys["val"] not in columns
        )

    @property
    def train_dataset(self):
        return self.to_dataset("train")

    @property
    def val_dataset(self):
        return self.to_dataset("val")

    @property
    def test_dataset(self):
        return self.to_dataset("test")

    @property
    def predict_dataset(self):
        return self.to_dataset("predict")
    
class LightningAnnDataModule(LightningDataModule):
    """Lightning ``DataModule`` that serves AnnData subsets via ``CellDataManager``.

    All constructor arguments except ``adata`` are stored with
    ``save_hyperparameters``; those accepted by ``CellDataManager`` are also
    forwarded to it when the data are prepared.
    """

    def __init__(
        self,
        adata: Union[anndata.AnnData, None] = None,
        batch_size=2000,
        num_workers=os.cpu_count(),
        use_key="X_pca",
        groupby="Time point",
        obs_keys=None,
        train_key="train",
        val_key="val",
        test_key="test",
        predict_key="predict",
        n_groups=None,
        train_val_percentages=[0.8, 0.2],
        remainder_idx=-1,
        predict_all=True,
        attr_names={"obs": [], "aux": []},
        one_hot=False,
        aux_keys=None,
        silent=True,
        h5ad_path: Union[str, None] = None,
    ):
        """
        Args:
            adata: in-memory AnnData; takes precedence over ``h5ad_path``.
            batch_size: DataLoader batch size; a falsy value means ~10% of cells.
            num_workers: number of DataLoader worker processes.
            h5ad_path: ``.h5ad`` file read lazily when ``adata`` is not given.
            The remaining arguments are forwarded to ``CellDataManager``.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["adata"])
        self._adata = adata
        self.cell_data_manager_kwargs = function_kwargs(
            func=CellDataManager, kwargs=locals(), ignore=["adata"]
        )
        # `prepare_data` passes `adata` positionally; drop it here whether or
        # not `function_kwargs` honored `ignore`.
        self.cell_data_manager_kwargs.pop("adata", None)

    # -- Supporting methods --------------------------------------------------------------
    def _return_loader(self, dataset_key):
        """Build a DataLoader over the ``<dataset_key>_dataset`` created in ``setup``."""
        return DataLoader(
            getattr(self, "{}_dataset".format(dataset_key)),
            num_workers=self.hparams["num_workers"],
            # Use the property so its fallback for a falsy batch_size applies.
            batch_size=self.batch_size,
        )

    # -- Properties: ---------------------------------------------------------------------
    @property
    def adata(self):
        """The AnnData passed in, else one read (once) from ``h5ad_path``.

        Raises:
            ValueError: if neither ``adata`` nor ``h5ad_path`` was provided.
        """
        if isinstance(self._adata, anndata.AnnData):
            return self._adata
        h5ad_path = self.hparams.get("h5ad_path")
        if isinstance(h5ad_path, str):
            # Cache so repeated property access does not re-read the file.
            self._adata = anndata.read_h5ad(h5ad_path)
            return self._adata
        raise ValueError("Pass adata or h5ad_path")

    @property
    def n_cells(self):
        return self.adata.shape[0]

    @property
    def n_features(self):
        return self.adata.shape[1]

    @property
    def n_dims(self):
        """Width of the ``adata.obsm[use_key]`` representation."""
        return self.adata.obsm[self.hparams["use_key"]].shape[1]

    @property
    def batch_size(self):
        """Configured batch size, or ~10% of cells (at least 1) when it is falsy."""
        if not self.hparams["batch_size"]:
            return max(1, self.n_cells // 10)
        return self.hparams["batch_size"]

    # -- Standard methods: ---------------------------------------------------------------
    def prepare_data(self):
        """Build the ``CellDataManager`` (splitting train/val if needed)."""
        self.data = CellDataManager(self.adata, **self.cell_data_manager_kwargs)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` from the ``CellDataManager``."""
        # Lightning calls `prepare_data` on one process only in distributed
        # runs, so build the manager here when it is missing.
        if getattr(self, "data", None) is None:
            self.prepare_data()

        if stage in ["fit", "validate", "train", "val"]:
            self.train_dataset = self.data.train_dataset
            self.val_dataset = self.data.val_dataset

        elif stage == "test":
            self.test_dataset = self.data.test_dataset
        elif stage in [None, "predict"]:
            self.predict_dataset = self.data.predict_dataset
        else:
            print(
                "CURRENT STAGE: {} - no suitable subset found during `LightningDataModule.setup()`".format(
                    stage
                )
            )

    # -- Required DataLoader methods: ----------------------------------------------------
    def train_dataloader(self):
        return self._return_loader("train")

    def val_dataloader(self):
        return self._return_loader("val")

    def test_dataloader(self):
        return self._return_loader("test")

    def predict_dataloader(self):
        return self._return_loader("predict")

    def __repr__(self):
        return "⚡ {} ⚡".format(font_format("LightningAnnDataModule", ["PURPLE"]))

<IPython.core.display.Javascript object>

In [21]:
class SplitSize:
    """Compute group sizes for splitting ``n_cells`` items into ``n_groups`` groups."""

    def __init__(self, n_cells: int, n_groups: int):

        self._n_cells = n_cells
        self._n_groups = n_groups

    def _sum_norm(self, vals: Union[List, np.ndarray]) -> np.ndarray:
        """Return ``vals`` as a float array rescaled so that it sums to 1."""
        # Convert before summing so plain lists (allowed by the annotation) work.
        arr = np.asarray(vals, dtype=float)
        return arr / arr.sum()

    def uniformly(self) -> List[int]:
        """Split as evenly as possible; the last group absorbs the remainder.

        e.g. 7 cells into 3 groups -> ``[2, 2, 3]``.
        """
        div, mod = divmod(self._n_cells, self._n_groups)
        return [div] * (self._n_groups - 1) + [div + mod]

    def proportioned(self, percentages=(0.8, 0.2), remainder_idx=-1) -> List[int]:
        """Split proportionally to ``percentages`` (normalized to sum to 1).

        Each size is floored; the cells lost to flooring are added to the group at
        ``remainder_idx`` so the sizes always sum to ``n_cells``. One size is
        returned per entry of ``percentages`` (``n_groups`` is not consulted).
        """
        fractions = self._sum_norm(percentages)
        split_lengths = [int(self._n_cells * frac) for frac in fractions]
        split_lengths[remainder_idx] += self._n_cells - sum(split_lengths)

        return split_lengths


class CellDataManager:
    """Data manager at the AnnData level.

    Keeps a working copy of ``adata.obs`` in ``self.df`` and exposes train / val /
    test / predict subsets of ``adata``, each selected by a boolean ``obs`` column.
    The column for each subset is looked up in ``self.data_keys``, which is built
    from the ``*_key`` constructor arguments (``train_key="train"`` gives
    ``data_keys["train"] == "train"``). If a train column exists but no val
    column, the train cells are randomly divided into new ``"fit_train"`` and
    ``"fit_val"`` columns, which are also written back to ``adata.obs``.
    """

    def __init__(
        self,
        adata,
        use_key="X_pca",
        groupby="Time point",
        obs_keys=None,
        train_key="train",
        val_key="val",
        test_key="test",
        predict_key="predict",
        n_groups=None,
        train_val_percentages=[0.8, 0.2],
        remainder_idx=-1,
        predict_all=True,
        attr_names={"obs": [], "aux": []},
        one_hot=False,
        aux_keys=None,
        silent=True,
    ):
        # NOTE(review): `train_val_percentages` and `attr_names` have mutable
        # defaults shared across instances; they are stored on `self` (and
        # presumably forwarded to AnnDataset), so they must never be mutated.
        self.__config__(locals())

    def __parse__(self, kwargs, ignore=("self",), hide=()):
        """Set each ``kwargs`` entry whose key is not in ``ignore`` as an attribute.

        ``hide`` is accepted for interface compatibility and is currently unused.
        """
        for key, val in kwargs.items():
            if key not in ignore:
                setattr(self, key, val)

    def __config__(self, kwargs, ignore=("self",), hide=()):
        """Configure the manager from the ``locals()`` captured in ``__init__``."""
        self.AnnDataset_kwargs = function_kwargs(
            func=AnnDataset, kwargs=kwargs, ignore=["adata"]
        )
        # `to_dataset` passes `adata` explicitly; drop it here whether or not
        # `function_kwargs` already honored `ignore`.
        self.AnnDataset_kwargs.pop("adata", None)
        self.__parse__(kwargs, ignore, hide)
        self.df = self.adata.obs.copy()
        self.data_keys = self._get_keys(kwargs)

        if not self.n_groups:
            # One group per percentage; fall back to a two-way (train/val) split
            # when no percentages are given instead of failing on len(None).
            self.n_groups = (
                len(self.train_val_percentages) if self.train_val_percentages else 2
            )

        self.split = SplitSize(self.n_cells, self.n_groups)

        # configure train-val split if it has train but not val
        if self.has_train_not_val:
            # Count train cells from the mask instead of copying the subset.
            self.n_fit = int(self._key_mask("train").sum())
            self._allocate_validation()

    # -- supporting methods: -------------------------------------------------------------
    def _get_keys(self, kwargs):
        """Map every ``<name>_key`` argument to its value under ``<name>``.

        e.g. ``{"train_key": "train", "use_key": "X_pca"}`` ->
        ``{"train": "train", "use": "X_pca"}``. Arguments not ending in ``_key``
        (such as ``aux_keys``) are skipped.
        """
        suffix = "_key"
        # Slice the suffix off: str.strip("_key") removes a *character set* from
        # both ends and would turn "use_key" into "us".
        return {
            attr[: -len(suffix)]: val
            for attr, val in kwargs.items()
            if attr.endswith(suffix)
        }

    def _key_mask(self, key):
        """Boolean numpy mask from the ``self.df`` column mapped to ``key``."""
        return self.df[self.data_keys[key]].to_numpy(dtype=bool)

    def _subset_adata(self, key):
        """Return a copy of ``adata`` restricted to the cells flagged for ``key``.

        If the column is missing, ``key == "predict"`` and ``predict_all`` is set,
        a column flagging every cell is created in ``self.df`` first.

        Raises:
            KeyError: if the column mapped to ``key`` is not in ``self.df``.
        """
        access_key = self.data_keys[key]
        if access_key not in self.df.columns:
            if (key == "predict") and self.predict_all:
                self.df[access_key] = True
            else:
                raise KeyError(
                    "Key: Access Key pair: {}:{} not found".format(key, access_key)
                )
        return self.adata[self._key_mask(key)].copy()

    def _set_new_idx(self, df, idx, key_added):
        """Add boolean column ``key_added`` to ``df``, True at row positions ``idx``."""
        mask = np.zeros(len(df), dtype=bool)
        mask[idx] = True
        df[key_added] = mask

    def _allocate_validation(self, remainder_idx=None):
        """Split the train cells into non-overlapping ``fit_train`` / ``fit_val`` sets.

        Invoked when a train column exists but no val column. Cells are drawn at
        random with the global (unseeded) ``np.random`` state, in the sizes given
        by ``train_val_percentages`` (or two equal halves when it is empty).
        Afterwards ``data_keys["train"]`` / ``data_keys["val"]`` point at the new
        columns, and ``self.df`` is assigned back to ``adata.obs``.

        Args:
            remainder_idx: which part absorbs the rounding remainder; defaults to
                the ``remainder_idx`` given to the constructor.
        """
        if remainder_idx is None:
            remainder_idx = self.remainder_idx

        # Row positions in self.df of the cells currently flagged as train.
        # Working with positions avoids assuming obs names are integer positions.
        train_positions = np.flatnonzero(self._key_mask("train"))
        n_train_cells = len(train_positions)

        # A train/val split always has exactly two parts.
        train_val_split = SplitSize(n_train_cells, 2)
        if not self.train_val_percentages:
            n_train, _ = train_val_split.uniformly()
        else:
            n_train, _ = train_val_split.proportioned(
                percentages=self.train_val_percentages, remainder_idx=remainder_idx
            )

        chosen = np.random.choice(n_train_cells, size=n_train, replace=False)
        is_fit_train = np.zeros(n_train_cells, dtype=bool)
        is_fit_train[chosen] = True

        self._set_new_idx(
            self.df, idx=train_positions[is_fit_train], key_added="fit_train"
        )
        self._set_new_idx(
            self.df, idx=train_positions[~is_fit_train], key_added="fit_val"
        )

        self.data_keys["train"] = "fit_train"
        self.data_keys["val"] = "fit_val"

        self.adata.obs = self.df

    def to_dataset(self, key):
        """Wrap the ``<key>_adata`` subset in an AnnDataset with the stored kwargs."""
        adata = getattr(self, "{}_adata".format(key))
        return AnnDataset(adata=adata, **self.AnnDataset_kwargs)

    # -- Properties: ---------------------------------------------------------------------
    @property
    def cell_idx(self):
        """Index (obs names) of all cells in ``adata``."""
        return self.adata.obs.index

    @property
    def n_cells(self):
        return self.adata.shape[0]

    @property
    def n_features(self):
        return self.adata.shape[1]

    @property
    def uniform_split(self) -> List[int]:
        """Even split of all cells into ``n_groups`` groups."""
        return self.split.uniformly()

    @property
    def proportioned_split(self) -> List[int]:
        """Split of all cells by ``train_val_percentages`` (remainder at ``remainder_idx``)."""
        return self.split.proportioned(
            percentages=self.train_val_percentages, remainder_idx=self.remainder_idx
        )

    @property
    def train_adata(self):
        return self._subset_adata("train")

    @property
    def val_adata(self):
        return self._subset_adata("val")

    @property
    def test_adata(self):
        return self._subset_adata("test")

    @property
    def predict_adata(self):
        return self._subset_adata("predict")

    @property
    def has_train_not_val(self):
        """True when ``self.df`` has the train column but not the val column."""
        columns = self.df.columns
        return (self.data_keys["train"] in columns) and (
            self.data_keys["val"] not in columns
        )

    @property
    def train_dataset(self):
        return self.to_dataset("train")

    @property
    def val_dataset(self):
        return self.to_dataset("val")

    @property
    def test_dataset(self):
        return self.to_dataset("test")

    @property
    def predict_dataset(self):
        return self.to_dataset("predict")
    
class LightningAnnDataModule(LightningDataModule):
    """Lightning ``DataModule`` that serves AnnData subsets via ``CellDataManager``.

    All constructor arguments except ``adata`` are stored with
    ``save_hyperparameters``; those accepted by ``CellDataManager`` are also
    forwarded to it when the data are prepared.
    """

    def __init__(
        self,
        adata: Union[anndata.AnnData, None] = None,
        batch_size=2000,
        num_workers=os.cpu_count(),
        use_key="X_pca",
        groupby="Time point",
        obs_keys=None,
        train_key="train",
        val_key="val",
        test_key="test",
        predict_key="predict",
        n_groups=None,
        train_val_percentages=[0.8, 0.2],
        remainder_idx=-1,
        predict_all=True,
        attr_names={"obs": [], "aux": []},
        one_hot=False,
        aux_keys=None,
        silent=True,
        h5ad_path: Union[str, None] = None,
    ):
        """
        Args:
            adata: in-memory AnnData; takes precedence over ``h5ad_path``.
            batch_size: DataLoader batch size; a falsy value means ~10% of cells.
            num_workers: number of DataLoader worker processes.
            h5ad_path: ``.h5ad`` file read lazily when ``adata`` is not given.
            The remaining arguments are forwarded to ``CellDataManager``.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["adata"])
        self._adata = adata
        self.cell_data_manager_kwargs = function_kwargs(
            func=CellDataManager, kwargs=locals(), ignore=["adata"]
        )
        # `prepare_data` passes `adata` positionally; drop it here whether or
        # not `function_kwargs` honored `ignore`.
        self.cell_data_manager_kwargs.pop("adata", None)

    # -- Supporting methods --------------------------------------------------------------
    def _return_loader(self, dataset_key):
        """Build a DataLoader over the ``<dataset_key>_dataset`` created in ``setup``."""
        return DataLoader(
            getattr(self, "{}_dataset".format(dataset_key)),
            num_workers=self.hparams["num_workers"],
            # Use the property so its fallback for a falsy batch_size applies.
            batch_size=self.batch_size,
        )

    # -- Properties: ---------------------------------------------------------------------
    @property
    def adata(self):
        """The AnnData passed in, else one read (once) from ``h5ad_path``.

        Raises:
            ValueError: if neither ``adata`` nor ``h5ad_path`` was provided.
        """
        if isinstance(self._adata, anndata.AnnData):
            return self._adata
        h5ad_path = self.hparams.get("h5ad_path")
        if isinstance(h5ad_path, str):
            # Cache so repeated property access does not re-read the file.
            self._adata = anndata.read_h5ad(h5ad_path)
            return self._adata
        raise ValueError("Pass adata or h5ad_path")

    @property
    def n_cells(self):
        return self.adata.shape[0]

    @property
    def n_features(self):
        return self.adata.shape[1]

    @property
    def n_dims(self):
        """Width of the ``adata.obsm[use_key]`` representation."""
        return self.adata.obsm[self.hparams["use_key"]].shape[1]

    @property
    def batch_size(self):
        """Configured batch size, or ~10% of cells (at least 1) when it is falsy."""
        if not self.hparams["batch_size"]:
            return max(1, self.n_cells // 10)
        return self.hparams["batch_size"]

    # -- Standard methods: ---------------------------------------------------------------
    def prepare_data(self):
        """Build the ``CellDataManager`` (splitting train/val if needed)."""
        self.data = CellDataManager(self.adata, **self.cell_data_manager_kwargs)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` from the ``CellDataManager``."""
        # Lightning calls `prepare_data` on one process only in distributed
        # runs, so build the manager here when it is missing.
        if getattr(self, "data", None) is None:
            self.prepare_data()

        if stage in ["fit", "validate", "train", "val"]:
            self.train_dataset = self.data.train_dataset
            self.val_dataset = self.data.val_dataset

        elif stage == "test":
            self.test_dataset = self.data.test_dataset
        elif stage in [None, "predict"]:
            self.predict_dataset = self.data.predict_dataset
        else:
            print(
                "CURRENT STAGE: {} - no suitable subset found during `LightningDataModule.setup()`".format(
                    stage
                )
            )

    # -- Required DataLoader methods: ----------------------------------------------------
    def train_dataloader(self):
        return self._return_loader("train")

    def val_dataloader(self):
        return self._return_loader("val")

    def test_dataloader(self):
        return self._return_loader("test")

    def predict_dataloader(self):
        return self._return_loader("predict")

    def __repr__(self):
        return "⚡ {} ⚡".format(font_format("LightningAnnDataModule", ["PURPLE"]))
    
class LightningDataModuleConfiguration:
    """Construct a ``LightningAnnDataModule`` from the same arguments it accepts.

    The module is built eagerly in ``__init__`` and exposed through the
    ``LightningDataModule`` property.
    """

    def __init__(
        self,
        adata: Union[anndata.AnnData, None] = None,
        batch_size=2000,
        num_workers=os.cpu_count(),
        use_key="X_pca",
        groupby="Time point",
        obs_keys=None,
        train_key="train",
        val_key="val",
        test_key="test",
        predict_key="predict",
        n_groups=None,
        train_val_percentages=[0.8, 0.2],
        remainder_idx=-1,
        predict_all=True,
        attr_names={"obs": [], "aux": []},
        one_hot=False,
        aux_keys=None,
        silent=True,
    ):

        kwargs = function_kwargs(LightningAnnDataModule, locals())
        # `adata` is passed positionally below, so drop it from the keyword set;
        # tolerate a `function_kwargs` that already left it out.
        kwargs.pop("adata", None)

        self._LightningDataModule = LightningAnnDataModule(adata, **kwargs)

    @property
    def LightningDataModule(self):
        """The configured ``LightningAnnDataModule`` instance."""
        return self._LightningDataModule

<IPython.core.display.Javascript object>

In [25]:
# Wrap the loaded AnnData in a Lightning DataModule with default settings.
lit_data = LightningAnnDataModule(adata=adata)

<IPython.core.display.Javascript object>

In [26]:
# Build the CellDataManager, then materialize the fit-stage (train/val) datasets
# and display the two resulting DataLoaders.
lit_data.prepare_data()
lit_data.setup(stage="fit")
(lit_data.train_dataloader(), lit_data.val_dataloader())

(<torch.utils.data.dataloader.DataLoader at 0x7ff784199ac0>,
 <torch.utils.data.dataloader.DataLoader at 0x7ff6f11a35e0>)

<IPython.core.display.Javascript object>

#### Now test whether it works within the configuration class (`LightningDataModuleConfiguration`)

In [27]:
# No overrides: every setting falls back to the configuration class defaults.
DATA_KWARGS = dict()
lit_data = LightningDataModuleConfiguration(adata, **DATA_KWARGS).LightningDataModule
lit_data

⚡ [95mLightningAnnDataModule[0m ⚡

<IPython.core.display.Javascript object>