In [None]:
from ALLCools.clustering import *
from wmb import brain, cemba, aibs
import numpy as np
import pandas as pd
import anndata

import matplotlib.pyplot as plt
from ALLCools.plot import *

from ALLCools.integration.seurat_class import SeuratIntegration

In [None]:
# Categorical obs columns taken from the mC reference annotation; they are
# copied into the merged table and then transferred onto the query cells.
categorical_key = [
    'L1_annot',
    'DissectionRegion',
]

## Input PCA embeddings before integration

In [None]:
# Reference (mC) and query (MERFISH) datasets, per the file names each with a
# precomputed PCA embedding; 'X_pca' is the key the integration below uses.
ref_adata, query_adata = (
    anndata.read_h5ad(h5ad_path)
    for h5ad_path in ('./adata/mc_pca.h5ad', './adata/merfish_pca.h5ad')
)

In [None]:
# Dataset order is significant: index 0 is the reference and index 1 the query,
# matching alignments=[[[0], [1]]] and ref=[0] / qry=[1] used further down.
adata_list = [
    ref_adata,
    query_adata,
]

### Initialize an empty merged AnnData (`adata_merge`)

In [None]:
from scipy.sparse import csr_matrix

# Empty "shell" AnnData holding every cell of every dataset, stacked in
# adata_list order. X is an all-zero sparse placeholder; this object only
# carries one merged obs table and, later, the integrated embedding in obsm.
cells = sum(a.shape[0] for a in adata_list)
features = adata_list[0].shape[1]

# var is borrowed from the first dataset, so every dataset must expose the same
# number of features (the CCA anchoring on X presumably needs a shared feature
# space as well).
assert all(a.shape[1] == features for a in adata_list), (
    'all datasets in adata_list must have the same number of features'
)

# csr_matrix((M, N)) builds an empty (all-zero) sparse matrix of that shape.
adata_merge = anndata.AnnData(X=csr_matrix((cells, features)),
                              obs=pd.concat([a.obs for a in adata_list]),
                              var=adata_list[0].var)

# Labels are later copied between adata_merge.obs and each dataset's obs by
# index alignment, which raises on duplicated cell IDs; fail early and clearly.
assert adata_merge.obs_names.is_unique, (
    'cell IDs must be unique across all datasets in adata_list'
)

In [None]:
# Reference labels come from the CEMBA mC annotation. Each label is converted
# with .to_pandas() (assumed to yield a per-cell Series) and assigned to
# adata_merge.obs, which aligns on the cell index: cells missing from the
# annotation (presumably all MERFISH query cells) end up as NaN.
mc_annot = cemba.get_mc_annot()

reference_labels = {key: mc_annot[key].to_pandas() for key in categorical_key}
for key, labels in reference_labels.items():
    adata_merge.obs[key] = labels

In [None]:
# Push the merged label columns back into each dataset's own obs (presumably
# so label_transfer can read them from the reference). Assignment aligns on
# the obs index, so each dataset receives only the rows for its own cells.
for key in categorical_key:
    merged_column = adata_merge.obs[key]
    for adata in adata_list:
        adata.obs[key] = merged_column

In [None]:
# Inspect both datasets after the label columns were copied into their obs.
adata_list

In [None]:
# Inspect the merged shell AnnData (obs from all datasets, placeholder X).
adata_merge

## Integration and transform

In [None]:
# ALLCools anchor-based integrator. The later integrate / label_transfer calls
# receive no datasets or anchors explicitly, so this object carries that state
# between the steps below.
integrator = SeuratIntegration()

In [None]:
# Anchor-finding settings for SeuratIntegration.find_anchor, grouped by what
# the argument names suggest; see the ALLCools docs for exact semantics.
find_anchor_kwargs = {
    # local neighbourhood settings (embedding key in each dataset)
    'k_local': None,
    'key_local': 'X_pca',
    # anchor search: CCA on each dataset's X
    'k_anchor': 5,
    'key_anchor': 'X',
    'dim_red': 'cca',
    'max_cc_cells': 100000,
    'n_components': 50,
    'n_features': 200,
    'scale1': False,
    'scale2': False,
    # anchor scoring / filtering
    'k_score': 30,
    'k_filter': None,
    # align dataset 0 (reference) with dataset 1 (query)
    'alignments': [[[0], [1]]],
}
anchor = integrator.find_anchor(adata_list, **find_anchor_kwargs)

In [None]:
# Correct the 'X_pca' embedding across datasets with the anchors stored on the
# integrator, using the same dataset-0 vs dataset-1 alignment as above.
corrected = integrator.integrate(
    key_correct='X_pca',
    row_normalize=True,
    n_components=30,
    k_weight=100,
    sd=1,
    alignments=[[[0], [1]]],
)

# `corrected` is a sequence of arrays (presumably one per dataset, in
# adata_list order); stacking it in that order matches the row order of
# adata_merge.obs, which was built by concatenating the datasets the same way.
integrated_embedding = np.concatenate(corrected)
adata_merge.obsm['X_pca_integrate'] = integrated_embedding

## Label transfer

In [None]:
# Transfer every categorical label from dataset 0 (reference) to dataset 1
# (query). key_dist presumably names the embedding used for distance-based
# weighting; it is the same 'X_pca' key used for anchoring and correction.
transfer_results = integrator.label_transfer(ref=[0],
                                             qry=[1],
                                             categorical_key=categorical_key,
                                             key_dist='X_pca')

In [None]:
# Write one HDF5 file per entry of transfer_results (presumably one per label
# in categorical_key). mode='w' recreates the file on every run: with the
# default mode='a' the old 'data' node is replaced inside the existing file,
# but HDF5 does not reclaim the freed space, so re-running the notebook keeps
# growing the file and would also keep any stale keys around.
for label_name, result in transfer_results.items():
    result.to_hdf(f'{label_name}_transfer.hdf', key='data', mode='w')

In [None]:
# Store the label-transfer results on adata_merge (exact layout is defined by
# ALLCools' save_transfer_results_to_adata — presumably new obs columns).
integrator.save_transfer_results_to_adata(adata_merge, transfer_results)

## Save

In [None]:
# Writing the merged object is opt-in. Set SAVE_MERGED_ADATA = True to persist
# adata_merge (obsm['X_pca_integrate'] plus whatever
# save_transfer_results_to_adata added) to disk.
SAVE_MERGED_ADATA = False
if SAVE_MERGED_ADATA:
    adata_merge.write_h5ad('./adata/final.h5ad')

In [None]:
# Final look at adata_merge: obsm['X_pca_integrate'] plus whatever
# save_transfer_results_to_adata added.
adata_merge

In [None]:
# Persist the integrator state under 'integration' using ALLCools'
# SeuratIntegration.save (on-disk format/layout is defined by ALLCools).
integrator.save('integration')