In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import scanpy as sc
from os.path import join, exists
from os import listdir
import anndata
import scipy
import numpy as np
import sys

from utils import *

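# TODO(review): counts are not converted to float32 anywhere in this cell - confirm whether that still needs to happen
# Convenience helpers for reporting the in-memory size of objects as an 'X MB' string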
def print_size_in_MB(x):
    """Return the size reported by ``x.__sizeof__()`` as an ``'<value> MB'`` string.

    The byte count is divided by 1e6 and rendered with three significant
    digits. Despite the name, nothing is printed: the string is returned.
    """
    megabytes = x.__sizeof__() / 1e6
    return f'{megabytes:.3} MB'

def print_size_in_MB_sparse_matrix(a):
    """Return the memory footprint of a CSR/CSC sparse matrix as an ``'<value> MB'`` string.

    The footprint is the total number of bytes held by the three arrays that
    back a compressed sparse matrix (``data``, ``indices`` and ``indptr``),
    divided by 1e6 and rendered with three significant digits.

    Parameters
    ----------
    a : sparse matrix exposing ``data``, ``indices`` and ``indptr`` arrays
        (e.g. ``scipy.sparse.csr_matrix`` / ``csc_matrix``).

    Returns
    -------
    str
        The size formatted as ``'<value> MB'``; nothing is printed.
    """
    # Use nbytes (bytes), not size (element count), and include the index
    # arrays, which can be as large as the values themselves.
    n_bytes = a.data.nbytes + a.indices.nbytes + a.indptr.nbytes
    # Same unit convention (1 MB = 1e6 bytes) as print_size_in_MB.
    return '{:.3} MB'.format(n_bytes / 1e6)

import warnings
# Silence every warning for the rest of the session (no category or module
# filter is given, so all warnings are ignored). Remove this line while
# debugging to see them again.
warnings.filterwarnings("ignore")

In [None]:
# Final outputs that already exist on disk are skipped unless this is True.
# NOTE(review): the flag only guards the final output file; intermediate
# `_part1` / `_part2` files are always reused when present - delete them
# manually to force a rebuild.
overwrite = False

# The datasets are loaded in two groups so that only one group has to be held
# in memory while the intermediate files are being built.
dataset_names = ["Wong", "Scheetz", "Chen_c", "Hafler", "Roska", "Chen_a", "Sanes", "Hackney", "Chen_b"]
dataset_groups = [dataset_names[:4], dataset_names[4:]]

# A merged batch is kept only if it has strictly more cells than this.
min_cells_per_batch = 100


def _write_part(names, part_path, code_n_cells, part_no):
    """Load one group of datasets and cache it as an intermediate .h5ad file.

    names: dataset names forwarded to ``get_datasets``.
    part_path: destination path of the intermediate file.
    code_n_cells: suffix forwarded to ``get_datasets`` ('' or '_<n>').
    part_no: 1-based group number, used only in the log messages.

    Returns nothing; the loaded object is only referenced locally, so it can
    be garbage-collected as soon as the function returns.
    """
    print('loading', names)
    part = get_datasets(names, code_n_cells=code_n_cells)
    print('ad%d' % part_no)
    print(part)
    print('loading datasets %d done...' % part_no)
    print(part.obs.dataset.value_counts())

    # keep only the cells whose `dataset` label belongs to this group
    part = part[part.obs.dataset.isin(set(names)), :]
    part.write(part_path, compression='lzf')
    print(part_path)


for n_sample_per_batch in [500, 1000, None]:
    # suffix forwarded to get_datasets: '_<n>', or '' when no cell number is given
    code_n_cells = '' if n_sample_per_batch is None else '_' + str(n_sample_per_batch)
    print(code_n_cells)
    print('# of cells (input argument)', n_sample_per_batch)

    # the output file name uses '_all' instead of an empty suffix
    code_output = code_n_cells if code_n_cells else '_all'
    output_path = '../../data/integration_march_2021/input/input%s_cells.h5ad' % code_output
    print(output_path)

    if exists(output_path) and not overwrite:
        continue

    # build (or reuse) the two intermediate files
    part_paths = [output_path.replace('.h5ad', '_part%d.h5ad' % part_no) for part_no in (1, 2)]
    for part_no, (names, part_path) in enumerate(zip(dataset_groups, part_paths), start=1):
        if not exists(part_path):
            _write_part(names, part_path, code_n_cells, part_no)
    p1, p2 = part_paths

    ad1 = sc.read_h5ad(p1)
    ad2 = sc.read_h5ad(p2)

    print('concatenating...')
    # anndata.concat defaults to join='inner': only variables present in both
    # parts are kept in the merged object.
    ad_final = anndata.concat([ad1, ad2])
    print('done...')

    print('ad final')
    print(ad1.shape, ad2.shape)
    print(ad_final.shape)

    # the parts are no longer needed once merged; release them before the
    # filtering / writing steps and before the next (larger) iteration
    del ad1, ad2

    # define a unified batch code across datasets: every distinct
    # '<dataset>:<batch>' pair gets an integer code, stored as a string
    batch_labels = ad_final.obs['dataset'].astype(str) + ':' + ad_final.obs['batch'].astype(str)
    ad_final.obs['batch.merged'] = batch_labels.astype('category').cat.codes.astype(str)

    # we only care about genes detected in at least X cells or more (X=50)
    # min_cells = 50
    # sc.pp.filter_genes(ad, min_cells=min_cells)

    # drop the cells that belong to batches that are too small
    batch_sizes = ad_final.obs['batch.merged'].value_counts()
    cells_in_own_batch = ad_final.obs['batch.merged'].map(batch_sizes)
    ad_final = ad_final[cells_in_own_batch > min_cells_per_batch, :]

    print('saving to output...')
    ad_final.write(output_path, compression='lzf')
    print('done...')