In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from abc_atlas_access.abc_atlas_cache.abc_project_cache import AbcProjectCache
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import anndata
import scipy.sparse
import scanpy as sc
import os

In [2]:
download_base = Path('/gpfs/scratch/blukacsy/abc_atlas')
abc_cache = AbcProjectCache.from_cache_dir(download_base)
print(f"Current manifest: {abc_cache.current_manifest}")

Current manifest: releases/20250531/manifest.json


In [3]:
abc_cache.list_directories

['ASAP-PMDBS-10X',
 'ASAP-PMDBS-taxonomy',
 'Allen-CCF-2020',
 'HMBA-10xMultiome-BG',
 'HMBA-10xMultiome-BG-Aligned',
 'HMBA-BG-taxonomy-CCN20250428',
 'MERFISH-C57BL6J-638850',
 'MERFISH-C57BL6J-638850-CCF',
 'MERFISH-C57BL6J-638850-imputed',
 'MERFISH-C57BL6J-638850-sections',
 'SEAAD',
 'SEAAD-taxonomy',
 'WHB-10Xv3',
 'WHB-taxonomy',
 'WMB-10X',
 'WMB-10XMulti',
 'WMB-10Xv2',
 'WMB-10Xv3',
 'WMB-neighborhoods',
 'WMB-taxonomy',
 'Zeng-Aging-Mouse-10Xv3',
 'Zeng-Aging-Mouse-WMB-taxonomy',
 'Zhuang-ABCA-1',
 'Zhuang-ABCA-1-CCF',
 'Zhuang-ABCA-2',
 'Zhuang-ABCA-2-CCF',
 'Zhuang-ABCA-3',
 'Zhuang-ABCA-3-CCF',
 'Zhuang-ABCA-4',
 'Zhuang-ABCA-4-CCF']

In [4]:
abc_cache.list_metadata_files('WMB-10X')

['cell_metadata',
 'cell_metadata_with_cluster_annotation',
 'example_genes_all_cells_expression',
 'gene',
 'region_of_interest_metadata']

In [5]:
abc_cache.list_data_files('WMB-10Xv2')

['WMB-10Xv2-CTXsp/log2',
 'WMB-10Xv2-CTXsp/raw',
 'WMB-10Xv2-HPF/log2',
 'WMB-10Xv2-HPF/raw',
 'WMB-10Xv2-HY/log2',
 'WMB-10Xv2-HY/raw',
 'WMB-10Xv2-Isocortex-1/log2',
 'WMB-10Xv2-Isocortex-1/raw',
 'WMB-10Xv2-Isocortex-2/log2',
 'WMB-10Xv2-Isocortex-2/raw',
 'WMB-10Xv2-Isocortex-3/log2',
 'WMB-10Xv2-Isocortex-3/raw',
 'WMB-10Xv2-Isocortex-4/log2',
 'WMB-10Xv2-Isocortex-4/raw',
 'WMB-10Xv2-MB/log2',
 'WMB-10Xv2-MB/raw',
 'WMB-10Xv2-OLF/log2',
 'WMB-10Xv2-OLF/raw',
 'WMB-10Xv2-TH/log2',
 'WMB-10Xv2-TH/raw']

In [6]:
# get all log gene expression matrices
def get_matrices(dir_name):
    return [matrix for matrix in abc_cache.list_data_files(dir_name) if matrix.endswith('/log2')]

all_matrices = sorted(get_matrices('WMB-10XMulti') + get_matrices('WMB-10Xv2') + get_matrices('WMB-10Xv3')) 

In [7]:
all_matrices

['WMB-10XMulti/log2',
 'WMB-10Xv2-CTXsp/log2',
 'WMB-10Xv2-HPF/log2',
 'WMB-10Xv2-HY/log2',
 'WMB-10Xv2-Isocortex-1/log2',
 'WMB-10Xv2-Isocortex-2/log2',
 'WMB-10Xv2-Isocortex-3/log2',
 'WMB-10Xv2-Isocortex-4/log2',
 'WMB-10Xv2-MB/log2',
 'WMB-10Xv2-OLF/log2',
 'WMB-10Xv2-TH/log2',
 'WMB-10Xv3-CB/log2',
 'WMB-10Xv3-CTXsp/log2',
 'WMB-10Xv3-HPF/log2',
 'WMB-10Xv3-HY/log2',
 'WMB-10Xv3-Isocortex-1/log2',
 'WMB-10Xv3-Isocortex-2/log2',
 'WMB-10Xv3-MB/log2',
 'WMB-10Xv3-MY/log2',
 'WMB-10Xv3-OLF/log2',
 'WMB-10Xv3-P/log2',
 'WMB-10Xv3-PAL/log2',
 'WMB-10Xv3-STR/log2',
 'WMB-10Xv3-TH/log2']

In [None]:
all_adata = []
for matrix in all_matrices:
    directory = "WMB-10XMulti"
    if ('10Xv3' in matrix): directory = 'WMB-10Xv3'
    if ('10Xv2' in matrix): directory = 'WMB-10Xv2'
    file = abc_cache.get_data_path(directory=directory, file_name=matrix)
    temp_adata = anndata.read_h5ad(file)
    all_adata.append(temp_adata)

In [25]:
len(all_adata)

24

In [26]:
adata = anndata.concat(all_adata, axis=0, join='outer')
print("done")

done


In [30]:
adata

AnnData object with n_obs × n_vars = 4059388 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label'

In [36]:
adata.write("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [6]:
adata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [7]:
adata

AnnData object with n_obs × n_vars = 4059388 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label'

In [12]:
cell = abc_cache.get_metadata_dataframe(
    directory='WMB-10X',
    file_name='cell_metadata',
    dtype={'cell_label': str}
)
cell.set_index('cell_label', inplace=True)
print(f"Total cells: {len(cell)}")

cluster_details = abc_cache.get_metadata_dataframe(
    directory='WMB-taxonomy',
    file_name='cluster_to_cluster_annotation_membership_pivoted',
    keep_default_na=False
)
cluster_details.set_index('cluster_alias', inplace=True)

Total cells: 4042976


In [13]:
cell

Unnamed: 0_level_0,cell_barcode,barcoded_cell_sample_label,library_label,feature_matrix_label,entity,brain_section_label,library_method,region_of_interest_acronym,donor_label,donor_genotype,donor_sex,dataset_label,x,y,cluster_alias,abc_sample_id
cell_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
GCGAGAAGTTAAGGGC-410_B05,GCGAGAAGTTAAGGGC,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.146826,-3.086639,1,484be5df-5d44-4bfe-9652-7b5bc739c211
AATGGCTCAGCTCCTT-411_B06,AATGGCTCAGCTCCTT,411_B06,L8TX_201029_01_E10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550851,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.138481,-3.022000,1,5638505d-e1e8-457f-9e5b-59e3e2302417
AACACACGTTGCTTGA-410_B05,AACACACGTTGCTTGA,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.472557,-2.992709,1,a0544e29-194f-4d34-9af4-13e7377b648f
CACAGATAGAGGCGGA-410_A05,CACAGATAGAGGCGGA,410_A05,L8TX_201029_01_A10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.379622,-3.043442,1,c777ac0b-77e1-4d76-bf8e-2b3d9e08b253
AAAGTGAAGCATTTCG-410_B05,AAAGTGAAGCATTTCG,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.909480,-2.601536,1,49860925-e82b-46df-a228-fd2f97e75d39
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GTGTGAGCAAACGCGA-1350_C05,GTGTGAGCAAACGCGA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,M,WMB-10XMulti,-7.716915,0.223654,8861,ba1d0e38-bea7-4d4f-bfcd-49121938e743
TTAGCAATCCCTGTTA-1350_C05,TTAGCAATCCCTGTTA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,M,WMB-10XMulti,-3.115098,-3.024478,8215,342bd0bb-cbe5-479b-9c70-fef59a730255
TTTGGCTGTCGCGCAA-1350_C05,TTTGGCTGTCGCGCAA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,M,WMB-10XMulti,-7.950964,0.409335,8798,4634de09-d8e0-4e40-a49b-eba311de08b5
ATCCACCTCACAGACT-1320_B04,ATCCACCTCACAGACT,1320_B04,L8XR_220630_02_B10,WMB-10XMulti,cell,,10xRSeq_Mult,OLF,C57BL6J-625156,wt/wt,F,WMB-10XMulti,4.579441,12.135833,8798,5b3061de-1cb8-47b6-9368-52824e1031ce


In [14]:
cluster_details

Unnamed: 0_level_0,neurotransmitter,class,subclass,supertype,cluster
cluster_alias,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
2,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0327 L2 IT PPP-APr Glut_3
3,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0081 L2 IT PPP-APr Glut_2,0322 L2 IT PPP-APr Glut_2
4,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0081 L2 IT PPP-APr Glut_2,0323 L2 IT PPP-APr Glut_2
5,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0081 L2 IT PPP-APr Glut_2,0325 L2 IT PPP-APr Glut_2
...,...,...,...,...,...
34368,GABA-Glyc,27 MY GABA,288 MDRN Hoxb5 Ebf2 Gly-Gaba,1102 MDRN Hoxb5 Ebf2 Gly-Gaba_1,4955 MDRN Hoxb5 Ebf2 Gly-Gaba_1
34372,GABA-Glyc,27 MY GABA,285 MY Lhx1 Gly-Gaba,1091 MY Lhx1 Gly-Gaba_3,4901 MY Lhx1 Gly-Gaba_3
34374,GABA-Glyc,27 MY GABA,285 MY Lhx1 Gly-Gaba,1091 MY Lhx1 Gly-Gaba_3,4902 MY Lhx1 Gly-Gaba_3
34376,GABA-Glyc,27 MY GABA,285 MY Lhx1 Gly-Gaba,1091 MY Lhx1 Gly-Gaba_3,4903 MY Lhx1 Gly-Gaba_3


In [15]:
cell_extended = cell.join(cluster_details, on='cluster_alias')
print(f"Cells with annotations: {len(cell_extended)}")

Cells with annotations: 4042976


In [16]:
cell_extended

Unnamed: 0_level_0,cell_barcode,barcoded_cell_sample_label,library_label,feature_matrix_label,entity,brain_section_label,library_method,region_of_interest_acronym,donor_label,donor_genotype,...,dataset_label,x,y,cluster_alias,abc_sample_id,neurotransmitter,class,subclass,supertype,cluster
cell_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
GCGAGAAGTTAAGGGC-410_B05,GCGAGAAGTTAAGGGC,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.146826,-3.086639,1,484be5df-5d44-4bfe-9652-7b5bc739c211,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
AATGGCTCAGCTCCTT-411_B06,AATGGCTCAGCTCCTT,411_B06,L8TX_201029_01_E10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550851,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.138481,-3.022000,1,5638505d-e1e8-457f-9e5b-59e3e2302417,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
AACACACGTTGCTTGA-410_B05,AACACACGTTGCTTGA,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.472557,-2.992709,1,a0544e29-194f-4d34-9af4-13e7377b648f,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
CACAGATAGAGGCGGA-410_A05,CACAGATAGAGGCGGA,410_A05,L8TX_201029_01_A10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.379622,-3.043442,1,c777ac0b-77e1-4d76-bf8e-2b3d9e08b253,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
AAAGTGAAGCATTTCG-410_B05,AAAGTGAAGCATTTCG,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.909480,-2.601536,1,49860925-e82b-46df-a228-fd2f97e75d39,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GTGTGAGCAAACGCGA-1350_C05,GTGTGAGCAAACGCGA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,...,WMB-10XMulti,-7.716915,0.223654,8861,ba1d0e38-bea7-4d4f-bfcd-49121938e743,GABA-Glyc,26 P GABA,278 NLL Gata3 Gly-Gaba,1074 NLL Gata3 Gly-Gaba_1,4804 NLL Gata3 Gly-Gaba_1
TTAGCAATCCCTGTTA-1350_C05,TTAGCAATCCCTGTTA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,...,WMB-10XMulti,-3.115098,-3.024478,8215,342bd0bb-cbe5-479b-9c70-fef59a730255,Glut,19 MB Glut,157 RN Spp1 Glut,0682 RN Spp1 Glut_1,2761 RN Spp1 Glut_1
TTTGGCTGTCGCGCAA-1350_C05,TTTGGCTGTCGCGCAA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,...,WMB-10XMulti,-7.950964,0.409335,8798,4634de09-d8e0-4e40-a49b-eba311de08b5,GABA-Glyc,26 P GABA,278 NLL Gata3 Gly-Gaba,1076 NLL Gata3 Gly-Gaba_3,4806 NLL Gata3 Gly-Gaba_3
ATCCACCTCACAGACT-1320_B04,ATCCACCTCACAGACT,1320_B04,L8XR_220630_02_B10,WMB-10XMulti,cell,,10xRSeq_Mult,OLF,C57BL6J-625156,wt/wt,...,WMB-10XMulti,4.579441,12.135833,8798,5b3061de-1cb8-47b6-9368-52824e1031ce,GABA-Glyc,26 P GABA,278 NLL Gata3 Gly-Gaba,1076 NLL Gata3 Gly-Gaba_3,4806 NLL Gata3 Gly-Gaba_3


In [45]:
print(f"Cells missing subclass: {cell_extended['class'].isna().sum()}")

Cells missing subclass: 0


In [46]:
cell_extended.index

Index(['GCGAGAAGTTAAGGGC-410_B05', 'AATGGCTCAGCTCCTT-411_B06',
       'AACACACGTTGCTTGA-410_B05', 'CACAGATAGAGGCGGA-410_A05',
       'AAAGTGAAGCATTTCG-410_B05', 'GATCGTATCGAATCCA-411_B06',
       'AGATGAAAGGACCCAA-410_A05', 'TCTCACGGTCAGGAGT-411_A06',
       'GATTCTTGTTCGCGTG-410_B05', 'TTTCGATAGTAAAGCT-410_B05',
       ...
       'TGTCCTTCATCTAGCA-1320_D04', 'TTAACTGAGTCAGGCC-1320_D04',
       'TGGTTCTGTCTATCGT-1315_A01', 'ATCATGTCATCATGGC-1350_C05',
       'ATTCCTCCAGCCAGAA-1350_C05', 'GTGTGAGCAAACGCGA-1350_C05',
       'TTAGCAATCCCTGTTA-1350_C05', 'TTTGGCTGTCGCGCAA-1350_C05',
       'ATCCACCTCACAGACT-1320_B04', 'TCGTTAGCATTGTCCT-1320_B04'],
      dtype='object', name='cell_label', length=4042976)

In [39]:
adata = adata[cell_extended.index, :]

In [40]:
adata

View of AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster'

In [41]:
adata.obs['class'] = cell_extended['class']
adata.obs['subclass'] = cell_extended['subclass']
adata.obs['supertype'] = cell_extended['supertype']
adata.obs['cluster'] = cell_extended['cluster']

  adata.obs['class'] = cell_extended['class']


In [71]:
def integer_mappings(group):
    unique_groups = sorted(cell_extended[group].unique())
    group_to_int = {groups: i for i, groups in enumerate(unique_groups)}

    cell_subset = cell_extended.copy()
    adata.obs[f'{group}_labels'] = cell_subset[group].map(group_to_int)
    print(f'# unique groups: {len(unique_groups)}')

    group_mapping = pd.DataFrame({
        f'{group}': list(group_to_int.keys()),
        'label': list(group_to_int.values())
    })

    return group_mapping

In [72]:
class_mapping = integer_mappings('class')
subclass_mapping = integer_mappings('subclass')
supertype_mapping = integer_mappings('supertype')
cluster_mapping = integer_mappings('cluster')

# unique groups: 34
# unique groups: 338
# unique groups: 1201
# unique groups: 5322


In [84]:
adata

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels'

In [85]:
adata.write("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [8]:
adata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [9]:
adata

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels'

In [26]:
adata.obs['feature_matrix_label'] = cell_extended['feature_matrix_label']

In [None]:
# batching by 10x difference (10Xv2, 10Xv3, 10XMulti and their subtypes, but no idea if this is good practice)
sc.pp.highly_variable_genes(adata, n_top_genes=3500, batch_key='feature_matrix_label')

In [33]:
adata

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [42]:
adata_hvg = adata[:, adata.var['highly_variable'] ].copy()

In [57]:
adata_hvg

AnnData object with n_obs × n_vars = 4042976 × 3500
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [87]:
def train_test_val_split(adata, dimension, group, seed1, seed2):
    
    if scipy.sparse.issparse(adata.X):
        X = adata.X.toarray()
    else:
        X = adata.X

    print(X.min())
    print(X.max())

    y = adata.obs[f'{group}_labels'].values

    indices = adata.obs.index.to_numpy()
    
    train_features, test_features, train_labels, test_labels, train_index, test_index = train_test_split(X, y, indices, test_size = 0.2, random_state=seed1, stratify=y)
    test_features, val_features, test_labels, val_labels, test_index, val_index = train_test_split(test_features, test_labels, test_index, test_size = 0.5, random_state=seed2, stratify=test_labels)
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    train_features = np.array(train_features)
    test_features = np.array(test_features)
    val_features = np.array(val_features)
    train_labels = np.array(train_labels)
    test_labels = np.array(test_labels)
    val_labels = np.array(val_labels)
    weights = np.array(weights)
    train_index = np.array(train_index)
    test_index = np.array(test_index)
    val_index = np.array(val_index)

    print()
    print('Train features shape:', train_features.shape)
    print('Val features shape:', val_features.shape)
    print('Test features shape:', test_features.shape)
    print('Train labels shape:', train_labels.shape)
    print('Val labels shape:', val_labels.shape)
    print('Test labels shape:', test_labels.shape)
    print('Weights shape:', weights.shape)
    print()
    print('Train index shape:', train_index.shape)
    print('Test index shape:', test_index.shape)
    print('Val index shape:', val_index.shape)
    print()

    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/train_features_{dimension}_{group}.npy', train_features)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/test_features_{dimension}_{group}.npy', test_features)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/val_features_{dimension}_{group}.npy', val_features)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/train_labels_{dimension}_{group}.npy', train_labels)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/test_labels_{dimension}_{group}.npy', test_labels)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/val_labels_{dimension}_{group}.npy', val_labels)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/weights_{dimension}_{group}.npy', weights)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/train_index_{dimension}_{group}.npy', train_index)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/test_index_{dimension}_{group}.npy', test_index)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/val_index_{dimension}_{group}.npy', val_index)

In [83]:
class_mapping.to_csv('/gpfs/scratch/blukacsy/abc_atlas/data/class_mapping.csv')
subclass_mapping.to_csv('/gpfs/scratch/blukacsy/abc_atlas/data/subclass_mapping.csv')
supertype_mapping.to_csv('/gpfs/scratch/blukacsy/abc_atlas/data/supertype_mapping.csv')
cluster_mapping.to_csv('/gpfs/scratch/blukacsy/abc_atlas/data/cluster_mapping.csv')

In [84]:
adata.write("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")
adata_hvg.write("/gpfs/scratch/blukacsy/abc_atlas/data/adata_hvg.h5ad")
print("done writing")

done writing


In [85]:
train_test_val_split(adata_hvg, "hvg", "class", 7105, 3870)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (34,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)


In [88]:
train_test_val_split(adata_hvg, "hvg", "subclass", 7105, 3870)
train_test_val_split(adata_hvg, "hvg", "supertype", 7105, 3870)
train_test_val_split(adata_hvg, "hvg", "cluster", 7105, 3870)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (338,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (1201,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (5322,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)



In [None]:
# hvg arrays total ~55 GB cold storage