In [210]:
from typing import Union

import anndata as ad
import mudata as md
import muon as mu
import numpy as np
import pandas as pd
from anndata._core.anndata import AnnData
from anndata.compat import Index

from cellink._core.dummy_data import sim_adata, sim_gdata

In [334]:
class BaseDonorData:
    """Mixin that attaches cell-level data ``C`` to a donor-level container.

    The mixin does not stand alone: it expects to be combined with a container
    class that provides ``obs_names`` (the donor labels) and ``_gen_repr``, and
    the concrete subclass must define ``_base_cls`` (the container class whose
    name appears in the inherited repr).

    After ``_init_donor_data`` has run, two attributes are available:

    * ``self.C`` -- ``C`` restricted to cells whose donor label occurs in
      ``self.obs_names`` and reordered to follow the order of ``self.obs_names``.
    * ``self.donor_id`` -- name of the ``C.obs`` column holding the donor label
      of every cell.
    """

    def _sync_donors(self, C, donor_id="donor_id"):
        """Restrict ``C`` to known donors and sort its cells by donor order.

        Parameters
        ----------
        C
            Cell-level object exposing an ``obs`` table and supporting
            ``C[boolean_mask]`` / ``C[obs_labels]`` indexing.
        donor_id
            Column of ``C.obs`` holding the donor label of each cell.

        Returns
        -------
        The subset of ``C`` whose donors are in ``self.obs_names``, with cells
        grouped in the order donors appear in ``self.obs_names``. Cells of the
        same donor keep their original relative order.

        Raises
        ------
        KeyError
            If ``C.obs`` has no ``donor_id`` column.
        ValueError
            If no cell in ``C`` belongs to a donor in ``self.obs_names``.
        """
        if donor_id not in C.obs:
            # Same exception type as the plain column lookup would raise, but
            # with a message that tells the caller which column is expected.
            raise KeyError(f"C.obs has no column '{donor_id}' to match against DonorData.obs_names")

        overlap_donors = C.obs[donor_id].isin(self.obs_names)
        if not overlap_donors.any():
            # Interpolate the column name into the message; passing it as a
            # second exception argument would leave the '%s' unformatted.
            raise ValueError(
                "No overlapping donors between DonorData and C,"
                " make sure C.obs['%s'] exists and corresponds to DonorData.obs_names" % donor_id
            )
        C = C[overlap_donors]

        # Rank of each donor = its position in self.obs_names.
        donor_order = pd.Series(range(len(self.obs_names)), index=self.obs_names)
        # A stable sort keeps the original cell order inside each donor group;
        # the default quicksort gives no such guarantee.
        order_cells_by_donor = (
            C.obs[donor_id].map(donor_order).astype(int).sort_values(kind="mergesort").index
        )
        C = C[order_cells_by_donor]
        return C

    def _gen_repr_C(self) -> str:
        """Build the text block describing ``self.C`` for the object repr.

        The first line of ``str(self.C)`` is kept, followed by a summary line
        (number of donors, min-max cells per donor, number of variables), the
        donor column name, and then the remaining lines of ``str(self.C)``.
        """
        n_cells_per_donor = self.C.obs.groupby(self.donor_id, observed=True).size()
        min_n_cells = n_cells_per_donor.min()
        max_n_cells = n_cells_per_donor.max()
        _C_repr = str(self.C).split("\n")
        C_repr = [_C_repr[0]]
        C_repr += [
            f"    n_donors x n_cells x n_vars = {n_cells_per_donor.shape[0]} x [{min_n_cells}-{max_n_cells}] x {self.C.shape[1]}"
        ]
        C_repr += [f"    donor_id: {self.donor_id}"] + _C_repr[1:]
        # Continuation lines are indented so they line up under the "C: " prefix
        # added by _gen_repr.
        C_repr = "\n       ".join(C_repr)
        return C_repr

    def _gen_repr(self, *args, **kwargs) -> str:
        """Return the inherited repr with this class's name and a ``C:`` section.

        Delegates to the next ``_gen_repr`` in the MRO, swaps the base class
        name for the concrete class name, and appends the ``C`` description.
        """
        descr = super()._gen_repr(*args, **kwargs)
        # Replace only the first occurrence (the class name in the header line)
        # so that later text, e.g. a file path, is never rewritten.
        descr = descr.replace(self._base_cls.__name__, self.__class__.__name__, 1)
        descr += "\n    C: " + self._gen_repr_C()
        return descr

    def _init_donor_data(self, C, donor_id="donor_id"):
        """Store the donor-synchronised ``C`` and the donor column name on ``self``."""
        self.C = self._sync_donors(C, donor_id)
        self.donor_id = donor_id


class DonorAnnData(BaseDonorData, ad.AnnData):
    """AnnData whose observations are donors, carrying cell-level data in ``C``.

    All positional/keyword arguments other than ``C`` and ``donor_id`` are
    forwarded to ``AnnData.__init__``. ``C`` is synchronised with the donors by
    ``BaseDonorData._init_donor_data``.
    """

    # Class whose name is swapped out of the inherited repr by _gen_repr.
    _base_cls = ad.AnnData

    def __init__(self, X, obs=None, var=None, uns=None, *, C, donor_id="donor_id", **kwargs):
        """Build the donor-level AnnData, then attach and synchronise ``C``.

        Parameters
        ----------
        X, obs, var, uns, **kwargs
            Forwarded unchanged to ``AnnData.__init__``.
        C
            Cell-level object; must have the donor label of each cell in
            ``C.obs[donor_id]``.
        donor_id
            Name of the donor column in ``C.obs``.
        """
        super().__init__(X, obs, var, uns, **kwargs)
        self._init_donor_data(C, donor_id)

    def __getitem__(self, index: Index) -> AnnData:
        """Returns a sliced view of the object, with ``C`` restricted to the selected donors."""
        oidx, vidx = self._normalize_indices(index)
        selected_donors = self.obs_names[oidx]
        if not isinstance(selected_donors, pd.Index):
            # A scalar index (e.g. a single donor label) yields one label, which
            # Series.isin rejects; wrap it so single-donor selection works.
            selected_donors = [selected_donors]
        C = self.C[self.C.obs[self.donor_id].isin(selected_donors)]
        # Forward donor_id so slices of objects built with a custom donor column
        # do not fall back to the default "donor_id".
        return self.__class__(self, oidx=oidx, vidx=vidx, asview=True, C=C, donor_id=self.donor_id)


# Module-level call that sets mudata's `pull_on_update` option to False; it runs
# once when this cell executes, before DonorMuData is defined.
# NOTE(review): presumably this selects mudata's newer update behaviour and avoids
# the related warning -- confirm against the mudata documentation.
md.set_options(pull_on_update=False)


class DonorMuData(BaseDonorData, mu.MuData):
    """MuData whose observations are donors, carrying cell-level data in ``C``.

    ``C`` is synchronised with the donors by ``BaseDonorData._init_donor_data``.
    """

    # Class whose name is swapped out of the inherited repr by _gen_repr.
    _base_cls = mu.MuData

    def __init__(self, mdata, *, C, donor_id="donor_id", **kwargs):
        """Initialise from an existing ``mdata`` object, then attach ``C``.

        Parameters
        ----------
        mdata
            Object handed to ``MuData._init_as_actual``.
            NOTE(review): presumably this must already be a MuData, since
            ``MuData.__init__`` is bypassed -- confirm.
        C
            Cell-level object; must have the donor label of each cell in
            ``C.obs[donor_id]``.
        donor_id
            Name of the donor column in ``C.obs``.
        **kwargs
            Forwarded unchanged to ``MuData._init_as_actual``.
        """
        super()._init_as_actual(mdata, **kwargs)
        self._init_donor_data(C, donor_id)

    def __getitem__(self, index) -> Union["MuData", AnnData]:
        """Return a modality for a string key, otherwise a donor-sliced ``DonorMuData``.

        For non-string indices, ``C`` is restricted to cells of the selected
        donors and a new instance of this class is built around a MuData view.
        """
        if isinstance(index, str):
            return self.mod[index]
        else:
            # Private anndata helper, imported locally as in the original code.
            from anndata._core.index import _normalize_indices

            # Only the observation part of the index is needed to select donors.
            obsidx, _ = _normalize_indices(index, self.obs.index, self.var.index)
            selected_donors = self.obs_names[obsidx]
            if not isinstance(selected_donors, pd.Index):
                # A scalar index yields a single label, which Series.isin
                # rejects; wrap it so single-donor selection works.
                selected_donors = [selected_donors]
            C = self.C[self.C.obs[self.donor_id].isin(selected_donors)]
            mdata = mu.MuData(self, as_view=True, index=index)
            return self.__class__(mdata, C=C, donor_id=self.donor_id)

In [335]:
# Simulated inputs: `adata` is passed as the cell-level data C, `gdata` as the
# donor-level data (sim_gdata is given `adata`, see the call below).
adata = sim_adata()
gdata = sim_gdata(adata=adata)
# Shuffle the donors so that the donor order differs from the order in which
# donors occur in `adata`. The generator is seeded so that Restart & Run All
# produces the same donor order every time.
shuffled_index = np.random.default_rng(0).permutation(gdata.obs_names)
gdata = gdata[shuffled_index].copy()
dd = DonorAnnData(gdata, C=adata)
# After synchronisation the cells in C must be grouped in donor order.
assert (dd.obs_names == dd.C.obs["donor_id"].unique()).all()
dd

DonorAnnData object with n_obs × n_vars = 10 × 98
    var: 'chrom', 'pos', 'a0', 'a1', 'maf'
    C: View of AnnData object with n_obs × n_vars = 563 × 20
           n_donors x n_cells x n_vars = 10 x [14-94] x 20
           donor_id: donor_id
           obs: 'celltype', 'cov1', 'cov2', 'cov3', 'donor_id'
           var: 'chrom', 'start', 'end', 'strand'

In [336]:
# Wrap the donor-level data in a MuData with a single "geno" modality and
# attach the cell-level AnnData as C.
mgdata = mu.MuData({"geno": gdata})
mdd = DonorMuData(mgdata, C=adata)
# Same invariant as for DonorAnnData: cells in C are grouped in donor order.
assert (mdd.obs_names == mdd.C.obs["donor_id"].unique()).all()
# print() emits the plain-text repr, which includes the "C:" section.
print(mdd)

DonorMuData object with n_obs × n_vars = 10 × 98
  1 modality
    geno:	10 x 98
      var:	'chrom', 'pos', 'a0', 'a1', 'maf'
    C: View of AnnData object with n_obs × n_vars = 563 × 20
           n_donors x n_cells x n_vars = 10 x [14-94] x 20
           donor_id: donor_id
           obs: 'celltype', 'cov1', 'cov2', 'cov3', 'donor_id'
           var: 'chrom', 'start', 'end', 'strand'


In [338]:
# Wrap the cell-level AnnData in a MuData with a single "rna" modality and display it.
madata = mu.MuData({"rna": adata})
madata

In [343]:
# Inspect the donor labels (and their order) of the donor-level MuData.
mgdata.obs_names

Index(['D0', 'D3', 'D4', 'D6', 'D1', 'D9', 'D5', 'D8', 'D7', 'D2'], dtype='object', name='donor_id')

In [351]:
# Same as the DonorMuData example above, but with a MuData (not an AnnData) as C.
# NOTE: this rebinds `mgdata`, `madata` and `mdd` from earlier cells.
mgdata = mu.MuData({"geno": gdata})
madata = mu.MuData({"rna": adata})
# Copy the donor column from the "rna" modality to the MuData-level obs, since
# DonorMuData reads the donor label from C.obs[donor_id].
madata.obs["donor_id"] = madata["rna"].obs["donor_id"]
mdd = DonorMuData(mgdata, C=madata, donor_id="donor_id")
# Cells in C must be grouped in donor order after synchronisation.
assert (mdd.obs_names == mdd.C.obs["donor_id"].unique()).all()
# print() emits the plain-text repr, which includes the "C:" section.
print(mdd)

DonorMuData object with n_obs × n_vars = 10 × 98
  1 modality
    geno:	10 x 98
      var:	'chrom', 'pos', 'a0', 'a1', 'maf'
    C: View of MuData object with n_obs × n_vars = 563 × 20
           n_donors x n_cells x n_vars = 10 x [14-94] x 20
           donor_id: donor_id
         obs:	'donor_id'
         1 modality
           rna:	563 x 20
             obs:	'celltype', 'cov1', 'cov2', 'cov3', 'donor_id'
             var:	'chrom', 'start', 'end', 'strand'


In [353]:
# Slice the DonorMuData down to two donors by label; C is restricted to their cells.
# NOTE(review): no output is recorded for this cell -- confirm it runs on a fresh kernel.
mdd[["D0", "D7"]]