In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import torch as nn
import warnings
import anndata as ann

from collections import defaultdict

import scvi
from scvi.dataloaders._ann_dataloader import AnnDataLoader

from pynndescent import NNDescent
from sklearn.utils.extmath import randomized_svd

Global seed set to 0


## Input data

In [2]:
adata=sc.AnnData(np.ones((4,5))*[1,2,3,4,5],
                obsm={'species_ratio':np.array([ # Species ratios
                    [1,0],
                    [0,1],
                    [0.5,0.5],
                    [0.1,0.9]]),
                    'eval_o':np.array([0,0,0,1]).reshape(-1,1), # Eval only orthologues},
                    'cov_species':np.ones((4,2,1))*[[1],[2]]}, # Species-metadata map
                var=pd.DataFrame({'species':['K']*3+['L']*2},
                                index=['a','b','c','d','e']), # Species of gene
                uns={'orthologues':pd.DataFrame({'K':['c'],'L':['d']}), # Orthgologue map
                    'species_order':['K','L']}) # Order of species
            

In [246]:
adata

AnnData object with n_obs × n_vars = 4 × 5
    var: 'species'
    uns: 'orthologues', 'species_order'
    obsm: 'species_ratio', 'eval_o', 'cov_species'

In [247]:
print('X')
display(adata.to_df())
print('var')
display(adata.var)
for k,v in adata.uns.items():
    print('uns',k)
    display(v)
for k,v in adata.obsm.items():
    print('obsm',k)
    display(v)

X


Unnamed: 0,a,b,c,d,e
0,1.0,2.0,3.0,4.0,5.0
1,1.0,2.0,3.0,4.0,5.0
2,1.0,2.0,3.0,4.0,5.0
3,1.0,2.0,3.0,4.0,5.0


var


Unnamed: 0,species
a,K
b,K
c,K
d,L
e,L


uns orthologues


Unnamed: 0,K,L
0,c,d


uns species_order


['K', 'L']

obsm species_ratio


array([[1. , 0. ],
       [0. , 1. ],
       [0.5, 0.5],
       [0.1, 0.9]])

obsm eval_o


array([[0],
       [0],
       [0],
       [1]])

obsm cov_species


array([[[1.],
        [2.]],

       [[1.],
        [2.]],

       [[1.],
        [2.]],

       [[1.],
        [2.]]])

In [3]:
scvi.data.setup_anndata(adata)
scvi.data.register_tensor_from_anndata(
    adata=adata,
    adata_attr_name="obsm",
    adata_key_name='species_ratio',
    registry_key="species_ratio",
    is_categorical=False,
)
scvi.data.register_tensor_from_anndata(
    adata=adata,
    adata_attr_name="obsm",
    adata_key_name='eval_o',
    registry_key="eval_o",
    is_categorical=False,
)
scvi.data.register_tensor_from_anndata(
    adata=adata,
    adata_attr_name="obsm",
    adata_key_name='cov_species',
    registry_key="cov_species",
    is_categorical=False,
)

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Successfully registered anndata object containing [1;36m4[0m cells, [1;36m5[0m vars, [1;36m1[0m batches, [1;36m1[0m     
         labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra  
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          




In [249]:
adata

AnnData object with n_obs × n_vars = 4 × 5
    obs: '_scvi_batch', '_scvi_labels'
    var: 'species'
    uns: 'orthologues', 'species_order', '_scvi'
    obsm: 'species_ratio', 'eval_o', 'cov_species'

In [250]:
scvi.data.view_anndata_setup(adata)



In [251]:
adata.uns['_scvi']

{'scvi_version': '0.0.0',
 'categorical_mappings': {'_scvi_batch': {'original_key': '_scvi_batch',
   'mapping': array([0])},
  '_scvi_labels': {'original_key': '_scvi_labels', 'mapping': array([0])}},
 'data_registry': {'X': {'attr_name': 'X', 'attr_key': 'None'},
  'batch_indices': {'attr_name': 'obs', 'attr_key': '_scvi_batch'},
  'labels': {'attr_name': 'obs', 'attr_key': '_scvi_labels'},
  'species_ratio': {'attr_name': 'obsm', 'attr_key': 'species_ratio'},
  'eval_o': {'attr_name': 'obsm', 'attr_key': 'eval_o'},
  'cov_species': {'attr_name': 'obsm', 'attr_key': 'cov_species'}},
 'summary_stats': {'n_batch': 1,
  'n_cells': 4,
  'n_vars': 5,
  'n_labels': 1,
  'n_proteins': 0,
  'n_continuous_covs': 0}}

In [6]:
adata.uns['_scvi']['data_registry']

{'X': {'attr_name': 'X', 'attr_key': 'None'},
 'batch_indices': {'attr_name': 'obs', 'attr_key': '_scvi_batch'},
 'labels': {'attr_name': 'obs', 'attr_key': '_scvi_labels'},
 'species_ratio': {'attr_name': 'obsm', 'attr_key': 'species_ratio'},
 'eval_o': {'attr_name': 'obsm', 'attr_key': 'eval_o'},
 'cov_species': {'attr_name': 'obsm', 'attr_key': 'cov_species'}}

In [252]:
scvi._CONSTANTS

_CONSTANTS_NT(X_KEY='X', BATCH_KEY='batch_indices', LABELS_KEY='labels', PROTEIN_EXP_KEY='protein_expression', CAT_COVS_KEY='cat_covs', CONT_COVS_KEY='cont_covs')

In [253]:
adl = AnnDataLoader(adata, shuffle=False, batch_size = 10)

In [254]:
data_batch = next(tensors for tensors in adl)
data_batch

{'X': tensor([[1., 2., 3., 4., 5.],
         [1., 2., 3., 4., 5.],
         [1., 2., 3., 4., 5.],
         [1., 2., 3., 4., 5.]]),
 'batch_indices': tensor([[0.],
         [0.],
         [0.],
         [0.]]),
 'labels': tensor([[0.],
         [0.],
         [0.],
         [0.]]),
 'species_ratio': tensor([[1.0000, 0.0000],
         [0.0000, 1.0000],
         [0.5000, 0.5000],
         [0.1000, 0.9000]]),
 'eval_o': tensor([[0.],
         [0.],
         [0.],
         [1.]]),
 'cov_species': tensor([[[1.],
          [2.]],
 
         [[1.],
          [2.]],
 
         [[1.],
          [2.]],
 
         [[1.],
          [2.]]])}

## Maps

In [255]:
adata

AnnData object with n_obs × n_vars = 4 × 5
    obs: '_scvi_batch', '_scvi_labels'
    var: 'species'
    uns: 'orthologues', 'species_order', '_scvi'
    obsm: 'species_ratio', 'eval_o', 'cov_species'

### Gene maps

In [527]:
# Species gene maps

# Map genes to orthologues (not species specific)

# TODO add check that gene is only 1x in orthologues of species
orthologues=adata.uns['orthologues'].values.ravel() 
n_genes=adata.var.shape[0]
n_genes_mapped=n_genes-orthologues.shape[0]+adata.uns['orthologues'].shape[0]
gene_map=np.zeros((n_genes,n_genes_mapped))

# Map index to integers for latter determining orthologue position
orthologues_df=adata.uns['orthologues'].copy()
orthologues_df.index=range(adata.uns['orthologues'].shape[0])

# Ensure orthologue and species ordering of genes
# Orthologue order
# Assumes that var names are unique across all species and each genes is in orthologues of species only 1x,
# e.g. each gene name is present in orthologues_df only 1x # TODO check
gene_order = {gene: idx
                    for idx, data in orthologues_df.iterrows()
                    for gene in data.values}

# Number of genes per gene group
# For orthologues specify number, for species-specific specify list with numbers per species
gene_numbers = {'orthologues': orthologues_df.shape[0],
                'species_specific': []}
        
# Species-specific order, starting at idx positions after orthologues
idx = orthologues_df.shape[0]
for species in adata.uns['species_order']:
    n_species_specific = 0
    genes_species = adata.var.query('species==@species').index
    for gene in genes_species:
        if gene not in orthologues:
            gene_order[gene] = idx
            idx += 1                    
            n_species_specific += 1
    gene_numbers['species_specific'].append(n_species_specific)
print(gene_order)
print(gene_numbers)

# Map between original genes order and sorted genes with merged orthologues across species
for gene_idx, gene in enumerate(adata.var.index):
    gene_map[gene_idx, gene_order[gene]] = 1

# Make species-specific gene-orthologue maps
# Modify general gene-orthologue map to contain only species specific genes
species_maps=[]
for species in adata.uns['species_order']:
    species_maps.append(gene_map*(adata.var['species']==species).values.reshape(-1,1))
species_maps=nn.tensor(np.array(species_maps))
display(species_maps)

{'c': 0, 'd': 0, 'a': 1, 'b': 2, 'e': 3}
{'orthologues': 1, 'species_specific': [2, 1]}


tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.]]], dtype=torch.float64)

In [192]:
# Map of whether a gene is orthologue or not
orthologue_map=nn.tensor(np.array([g in orthologues for g in adata.var.index]))
display(orthologue_map)

tensor([False, False,  True,  True, False])

### Mixup maps

In [163]:
# Mixup orthologues, dim = samples*species*genes
orthologue_mixup=data_batch['species_ratio'].unsqueeze(2)*\
    (orthologue_map.expand(data_batch['species_ratio'].shape[0],1,-1)).float() 
# Use species specific genes if species is present in mixup, dim = samples*species*genes
non_orthologue_mixup=(data_batch['species_ratio'].unsqueeze(2)>0).float()*\
    (~(orthologue_map.expand(data_batch['species_ratio'].shape[0],1,-1))).float()
# Make combined orthologue and species-specific genes maps, dim = samples*species*genes
gene_mixup_map=(orthologue_mixup+non_orthologue_mixup).unsqueeze(3)
# Make gene maps species specific (keep only genes from the species) and summarized
# in terms of orthologues, dim = samples*species*genes*genes_mapped
mixup_map=gene_mixup_map*species_maps
print(mixup_map.shape)
print(mixup_map)

torch.Size([4, 2, 5, 4])
tensor([[[[0.0000, 1.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 1.0000, 0.0000],
          [1.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 1.0000]]],


        [[[0.0000, 1.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 1.0000, 0.00

In [303]:
# Expression mixup across species
expr_mixup=(data_batch['X'].unsqueeze(1).expand(-1,mixup_map.shape[1],-1).unsqueeze(3)*\
            mixup_map).sum(1).sum(1)
print(expr_mixup.shape)
print(expr_mixup)

torch.Size([4, 4])
tensor([[3.0000, 1.0000, 2.0000, 0.0000],
        [4.0000, 0.0000, 0.0000, 5.0000],
        [3.5000, 1.0000, 2.0000, 5.0000],
        [3.9000, 1.0000, 2.0000, 5.0000]], dtype=torch.float64)


In [288]:
# Metadata/cov mixup
cov_mixup=nn.matmul(data_batch['species_ratio'].unsqueeze(1),data_batch['cov_species']
                   ).squeeze(1)
print(cov_mixup.shape)
print(cov_mixup)

torch.Size([4, 1])
tensor([[1.0000],
        [2.0000],
        [1.5000],
        [1.9000]])


In [404]:
# Genes to evaluate/compute loss on
# Genes to eval in each sample based on orthologues/not
eval_genes=nn.matmul(nn.logical_not(data_batch['eval_o']).float(),
          (~orthologue_map).float().unsqueeze(0) # Eval species specific genes or not
         )+orthologue_map # Orthologues - always evaluated
# Genes to eval in each sample based on mixup species and orthologue/not
eval_map=nn.matmul(eval_genes.unsqueeze(1),mixup_map.sum(axis=1).float()).squeeze(1)
print(eval_map.shape)
print(eval_map)

torch.Size([4, 4])
tensor([[1., 1., 1., 0.],
        [1., 0., 0., 1.],
        [1., 1., 1., 1.],
        [1., 0., 0., 0.]])


In [472]:
# Per-species expression with non-species genes masked, dim = samples * species * genes_mapped
expr_species=(data_batch['X'].unsqueeze(1).expand(-1,species_maps.shape[0],-1).unsqueeze(3)*\
          species_maps.unsqueeze(0).expand(data_batch['X'].shape[0],-1,-1,-1)).sum(axis=2)
print(expr_species.shape)
print(expr_species)

torch.Size([4, 2, 4])
tensor([[[3., 1., 2., 0.],
         [4., 0., 0., 5.]],

        [[3., 1., 2., 0.],
         [4., 0., 0., 5.]],

        [[3., 1., 2., 0.],
         [4., 0., 0., 5.]],

        [[3., 1., 2., 0.],
         [4., 0., 0., 5.]]], dtype=torch.float64)


### Metrics

In [535]:
def gaussian_nll_mask(m,x,v,mask):
    """
    Compute Gausian negative log likelihood loss with sample-specific masked features.
    :param m: Predicted mean of target
    :param x: True target
    :param v: predicted v of target
    :param mask: Sample-specific feature mask of same shape as x specifiying by 1/0 if
        sample-feature should be used for computing loss or not, respectively.
    :return: loss
    """
    l=nn.nn.GaussianNLLLoss(reduction='none')(m,x,v)
    l=l*mask # Set some sample-specific features to 0
    print(l)
    l=l.sum(dim=1)/mask.sum(dim=1) # Normalise accounting for masking
    return l

# Example output
a=nn.tensor(abs(np.random.normal(size=eval_map.shape))).float()
print(gaussian_nll_mask(a,a,a,eval_map))

tensor([[-0.9761, -0.5708, -0.1844, -0.0000],
        [-0.1604,  0.0000,  0.0000, -0.2468],
        [ 0.0865, -0.3398, -0.2585, -1.4728],
        [ 0.3304, -0.0000, -0.0000, -0.0000]])
tensor([-0.5771, -0.2036, -0.4962,  0.3304])


## Prepare mixup and adata

In [3]:
# Input adata - as expected from the user
adata_in=sc.AnnData(np.random.normal(size=(4,5))*[1,2,3,4,5],
                obs={'species':[ 'K']*2+['L','M'], # Species
                    'cov_c':['a','b']*2, # Covariates - categorical example
                    'cov_n':[0,1]*2}, # Covariates - continous (numerical) example
                var=pd.DataFrame({'species':['K']*2+['L']*2+['M']},
                                index=['a','b','c','d','e']), # Species of gene
                uns={'orthologues':pd.DataFrame(
                    {'K':['b'],'L':['c'],'M':['e']}), # Orthgologue map
                    }) 

In [4]:
adata_in

AnnData object with n_obs × n_vars = 4 × 5
    obs: 'species', 'cov_c', 'cov_n'
    var: 'species'
    uns: 'orthologues'

In [5]:
species_key='species'
cov_cat_keys=['cov_c']
cov_cont_keys=['cov_n']
seed=0
alpha=0.1

In [134]:
# Prepare data
adata_pp=sc.AnnData(adata_in.X,var=adata_in.var)
# Species encoding
species=adata_in.obs[species_key].values
species_order=pd.Categorical(species).categories.values
adata_pp.obsm['species_ratio']=pd.get_dummies(species)[species_order].values
# Covariate encoding
# One-hot encoding of categorical covariates
cov_cat_data=[]
for cov_cat_key in cov_cat_keys:
    cat_order=pd.Categorical(adata_in.obs[cov_cat_key]).categories.values
    cov_cat_data.append(pd.get_dummies(adata_in.obs[cov_cat_key])[cat_order].values)
    adata_pp.uns[cov_cat_key+'_order']=cat_order
# Prepare single cov array for all covariates and in per-species format
cov_data=np.concatenate(cov_cat_data+[adata_in.obs[cov_cont_keys].values],axis=1)
adata_pp.obsm['cov_species']=np.broadcast_to(np.expand_dims(cov_data,axis=1),
                (cov_data.shape[0],len(species_order),cov_data.shape[1]))
# Whether to eval only orthologues - always false here
adata_pp.obsm['eval_o']=np.array([0]*adata_pp.shape[0]).reshape(-1,1)

In [159]:
def create_mixup(indices,adata,obs_prefix,seed=0):
    """
    Mixup.
    Vars from outer scope: species, species_order, alpha
    :param indices: Iter (cell pairs form mixup) of iters (cells in pairs) - 2 cells 
    for mixup should be specified for each pair. Indices are obs positions from adata.
    :param adata: Adata used for making the mixup. Should have 
    col species in var and cov_species in obsm.
    :param seed: Set seed for mixup ratio generation. If None does not set the seed.
    """
    xs=[]
    covs=[]
    species_ratios=[]
    obs_names=[]
    if seed is not None:
        np.random.seed(seed)
    for i,j in indices: 
        mixup_ratio_i=np.random.beta(alpha, alpha)
        mixup_ratio_j=1-mixup_ratio_i
        species_i=species[i]
        species_j=species[j]
        # Get expression, expression of unused species genes will be set to 0
        x_i=adata[i,:].X.copy().ravel()
        x_i[adata.var['species']!=species_i]=0
        x_j=adata[j,:].X.copy().ravel()
        x_j[adata.var['species']!=species_j]=0
        xs.append(x_i+x_j)
        cov_i=adata.obsm['cov_species'][i,0]
        cov_j=adata.obsm['cov_species'][j,0]
        # For species that are not being validated just set cov to mixup ratio, 
        # but this is not very relevant for the model as it is not being validated
        cov_ij=cov_i*mixup_ratio_i+cov_j*mixup_ratio_j
        covs.append(np.array([cov_i,cov_j]+[cov_ij]*(len(species_order)-2)))
        species_ratio=np.zeros(len(species_order))
        species_ratio[species_order==species_i]=mixup_ratio_i
        species_ratio[species_order==species_j]=mixup_ratio_j
        species_ratios.append(species_ratio)
        obs_names.append('_'.join([obs_prefix,str(i),str(j)]))
    adata_mixup=sc.AnnData(
        X=pd.DataFrame(np.array(xs),index=obs_names,columns=adata.var_names),
        obsm={'cov_species':np.array(covs),'species_ratio':np.array(species_ratios)}
    )
    
    return adata_mixup

In [160]:
def count_species_pairs(adata,description):
    """
    Counts N species pairs per species combination
    :param adata: Adata with species_ratio in obsm
    :param description: For printing out, adata name/description
    """
    species_pairs=dict()
    for row_idx in range(adata.obsm['species_ratio'].shape[0]):
        species_idxs=np.argwhere(adata.obsm['species_ratio'][row_idx]>0)
        pair_name=' and '.join([species_order[idx][0] for idx in species_idxs])
        if pair_name not in species_pairs:
            species_pairs[pair_name]=0
        species_pairs[pair_name]=species_pairs[pair_name]+1
    print('N pairs per species combination for',description,':',species_pairs)

In [161]:
# Random cross-species mixup
random_mixup_ratio=1
desired_n=int(random_mixup_ratio*adata_pp.shape[0])
random_mixup_idx=set()
idxs=list(range(adata_pp.shape[0]))
if seed is not None:
    np.random.seed(seed)
# Try to generate N random corss-species pairs 
# TODO may be problematic as could run indefinitely, for now a quick fix is added
tries=0
while len(random_mixup_idx)<desired_n and\
    tries<adata_in.shape[0]*10: # Quick fix to stop if can not find combionations
    # Randomly sample cells and make sure that species differ
    # TODO could make quicker by selecting random cell pairs from pairs 
    # of species directly
    i=np.random.choice(idxs)
    j=np.random.choice(idxs)
    species_i=species[i]
    species_j=species[j]
    # TODO could add check not to use same cell 2x, but maybe less important as 
    # mixup ratio will differ
    if species_i!=species_j: 
        # Could be used for checking that same cell par was not used before
        random_mixup_idx.add(frozenset((i,j)))
    tries+=1
if len(random_mixup_idx)<desired_n:
    warnings.warn('Found less than desired  number of random mixup samples.')
    print('Found %i/%i random mixup samples'%(
                  len(random_mixup_idx),desired_n))
# Make adata from selected mixup cells
adata_random_species_mixup=create_mixup(indices=random_mixup_idx,
                                        adata=adata_pp,obs_prefix='mixup_species_random',
                                        seed=seed)
count_species_pairs(adata_random_species_mixup,description='random corss-species mixup')
adata_random_species_mixup.obsm['eval_o']=np.array([1]*adata_random_species_mixup.shape[0]
                                                   ).reshape(-1,1)

# Check what came out
print('X')
display(adata_random_species_mixup.to_df())
for k,v in adata_random_species_mixup.obsm.items():
    print('obsm',k)
    display(v)

N pairs per species combination for random corss-species mixup : {'K and L': 2, 'K and M': 2}
X


Unnamed: 0,a,b,c,d,e
mixup_species_random_0_2,1.764052,0.800314,2.283113,0.4867,0.0
mixup_species_random_1_3,-0.977278,1.900177,0.0,0.0,-4.270479
mixup_species_random_1_2,-0.977278,1.900177,2.283113,0.4867,0.0
mixup_species_random_0_3,1.764052,0.800314,0.0,0.0,-4.270479


obsm cov_species


array([[[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]],

       [[0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [0.00000000e+00, 1.00000000e+00, 1.00000000e+00]],

       [[0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [9.85473764e-01, 1.45262365e-02, 1.45262365e-02]],

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [8.08633422e-04, 9.99191367e-01, 9.99191367e-01]]])

obsm species_ratio


array([[6.61193791e-02, 9.33880621e-01, 0.00000000e+00],
       [7.32928707e-01, 0.00000000e+00, 2.67071293e-01],
       [1.45262365e-02, 9.85473764e-01, 0.00000000e+00],
       [8.08633422e-04, 0.00000000e+00, 9.99191367e-01]])

obsm eval_o


array([[1],
       [1],
       [1],
       [1]])

In [162]:
# Map for transofming input epxression to orthologues only
# matmul(X,otm) = expression of orthologues only, combined by orthologuesacross species
# dim otm = n_genes * n_orthologues (summarized across species)
orthologues_transform_map=np.zeros((adata_pp.shape[1],adata_in.uns['orthologues'].shape[0]))
for idx,(name,data) in enumerate(adata_in.uns['orthologues'].iterrows()):
    for gene in data:
        orthologues_transform_map[np.argwhere(adata_in.var_names==gene),idx]=1
print(orthologues_transform_map)

[[0.]
 [1.]
 [1.]
 [0.]
 [1.]]


In [163]:
# Similar cells cross-species mixup
# For all pairs of species find shared neighbours and then pick randomly 
# cell pairs from all possible cell pairs

n_pca=15 # PCS for embedding
k=30 # N neighbours

similar_mixup_ratio=1
desired_n=int(similar_mixup_ratio*adata_pp.shape[0])
similar_mixup_idx=set()
# Adata mapped to orthologues summarised across species
# Change for testing TODO!!!! remove - need larger object than original one to build neighbours
if False:
    adata_orthologues=sc.AnnData(np.matmul(adata_pp.X,orthologues_transform_map),
                             obs=pd.DataFrame({'species':species},index=adata_pp.obs_names))
else:
    adata_orthologues=sc.AnnData(np.random.normal(size=(99,5000)),
                             obs=pd.DataFrame({'species':['K','L','M']*33}))
# Compute embedding
sc.pp.highly_variable_genes(adata_orthologues,
                                 # Can only compute as many HVGs as there are genes
                                 n_top_genes=min([2000,adata_orthologues.shape[1]]),
                                 flavor='cell_ranger', subset=True, inplace=True, 
                                 batch_key='species')
sc.pp.scale(adata_orthologues)
sc.pp.pca(adata_orthologues, n_comps=n_pca, zero_center=True)
# Neighbours for all species pairs
pairs_all=[]
for i in range(len(species_order)-1):
    for j in range(i+1,len(species_order)):
        # Prepare species data
        s_i=species_order[i]
        s_j=species_order[j]
        e_i=adata_orthologues[adata_orthologues.obs['species']==s_i,:].obsm['X_pca']
        e_j=adata_orthologues[adata_orthologues.obs['species']==s_j,:].obsm['X_pca']
        # On which position was originally each cell
        idxs_i=np.argwhere(adata_orthologues.obs['species'].values==s_i).ravel()
        idxs_j=np.argwhere(adata_orthologues.obs['species'].values==s_j).ravel()
        # KNN
        index_i = NNDescent(e_i, metric='correlation', n_jobs=-1)
        neighbours_j, distances_j = index_i.query(e_j, k=k)
        index_j = NNDescent(e_j, metric='correlation', n_jobs=-1)
        neighbours_i, distances_i = index_j.query(e_i, k=k)
        neighbours=np.zeros((e_i.shape[0],e_j.shape[0]))
        # Parse KNN - count if presnet for both directions
        pairs={}
        for cell_j in range(neighbours_j.shape[0]):
            for idx_i in range(k):
                cell_i=neighbours_j[cell_j][idx_i]
                pair=str(cell_i)+'_'+str(cell_j)
                if pair not in pairs:
                    pairs[pair]=0
                pairs[pair]=pairs[pair]+1
        for cell_i in range(neighbours_i.shape[0]):
            for idx_j in range(k):
                cell_j=neighbours_i[cell_i][idx_j]
                pair=str(cell_i)+'_'+str(cell_j)
                if pair not in pairs:
                    pairs[pair]=0
                pairs[pair]=pairs[pair]+1
        # Get shared neighbors based on counts of directions
        for pair,n in pairs.items():
            if n==2:
                idx_i=int(pair.split('_')[0])
                idx_j=int(pair.split('_')[1])
                # Map neighbors to original indices
                pairs_all.append((idxs_i[idx_i],idxs_j[idx_j])) 
# Testing, TODO remove
if True:
    pairs_all=[(1,2),(1,3),(0,2),(0,3),(2,3)]

# Subset to desired N of pairs
if seed is not None:
    np.random.seed(seed)
pairs_all=np.array(pairs_all)
pairs_idx=np.random.choice(range(pairs_all.shape[0]), size=desired_n, replace=False) 
pairs_all=pairs_all[pairs_idx]

# Create adata
adata_similar_species_mixup=create_mixup(indices=pairs_all,
                                        adata=adata_pp,obs_prefix='mixup_species_similar',
                                         seed=seed)
count_species_pairs(adata_similar_species_mixup,description='similar corss-species mixup')
adata_similar_species_mixup.obsm['eval_o']=np.array([0]*adata_similar_species_mixup.shape[0]
                                                   ).reshape(-1,1)

# Check what came out
print('X')
display(adata_similar_species_mixup.to_df())
for k,v in adata_similar_species_mixup.obsm.items():
    print('obsm',k)
    display(v)

  c.reorder_categories(natsorted(c.categories), inplace=True)
... storing 'species' as categorical
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_block(indexer, value, name)


N pairs per species combination for similar corss-species mixup : {'K and L': 2, 'K and M': 2}
X


Unnamed: 0,a,b,c,d,e
mixup_species_similar_0_2,1.764052,0.800314,2.283113,0.4867,0.0
mixup_species_similar_1_2,-0.977278,1.900177,2.283113,0.4867,0.0
mixup_species_similar_1_3,-0.977278,1.900177,0.0,0.0,-4.270479
mixup_species_similar_0_3,1.764052,0.800314,0.0,0.0,-4.270479


obsm cov_species


array([[[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]],

       [[0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [2.67071293e-01, 7.32928707e-01, 7.32928707e-01]],

       [[0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [0.00000000e+00, 1.00000000e+00, 1.00000000e+00]],

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [8.08633422e-04, 9.99191367e-01, 9.99191367e-01]]])

obsm species_ratio


array([[6.61193791e-02, 9.33880621e-01, 0.00000000e+00],
       [7.32928707e-01, 2.67071293e-01, 0.00000000e+00],
       [1.45262365e-02, 0.00000000e+00, 9.85473764e-01],
       [8.08633422e-04, 0.00000000e+00, 9.99191367e-01]])

obsm eval_o


array([[0],
       [0],
       [0],
       [0]])

In [166]:
# Combine adatas
adata_combined=ann.concat([adata_pp,adata_random_species_mixup,adata_similar_species_mixup])
# Add extra info from processed adata
adata_combined.var=adata_pp.var
adata_combined.uns=adata_pp.uns
print(adata_combined)

AnnData object with n_obs × n_vars = 12 × 5
    var: 'species'
    uns: 'cov_c_order'
    obsm: 'cov_species', 'species_ratio', 'eval_o'
