In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from abc_atlas_access.abc_atlas_cache.abc_project_cache import AbcProjectCache
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import anndata
import scipy.sparse
import scanpy as sc
import os

In [2]:
# Directory holding the local ABC Atlas cache; the cache object below is built from it.
download_base = Path('/gpfs/scratch/blukacsy/abc_atlas')
abc_cache = AbcProjectCache.from_cache_dir(download_base)
# Show which release manifest the cache reports as current.
print(f"Current manifest: {abc_cache.current_manifest}")

Current manifest: releases/20250531/manifest.json


In [3]:
abc_cache.list_directories

['ASAP-PMDBS-10X',
 'ASAP-PMDBS-taxonomy',
 'Allen-CCF-2020',
 'HMBA-10xMultiome-BG',
 'HMBA-10xMultiome-BG-Aligned',
 'HMBA-BG-taxonomy-CCN20250428',
 'MERFISH-C57BL6J-638850',
 'MERFISH-C57BL6J-638850-CCF',
 'MERFISH-C57BL6J-638850-imputed',
 'MERFISH-C57BL6J-638850-sections',
 'SEAAD',
 'SEAAD-taxonomy',
 'WHB-10Xv3',
 'WHB-taxonomy',
 'WMB-10X',
 'WMB-10XMulti',
 'WMB-10Xv2',
 'WMB-10Xv3',
 'WMB-neighborhoods',
 'WMB-taxonomy',
 'Zeng-Aging-Mouse-10Xv3',
 'Zeng-Aging-Mouse-WMB-taxonomy',
 'Zhuang-ABCA-1',
 'Zhuang-ABCA-1-CCF',
 'Zhuang-ABCA-2',
 'Zhuang-ABCA-2-CCF',
 'Zhuang-ABCA-3',
 'Zhuang-ABCA-3-CCF',
 'Zhuang-ABCA-4',
 'Zhuang-ABCA-4-CCF']

In [4]:
abc_cache.list_metadata_files('WMB-10X')

['cell_metadata',
 'cell_metadata_with_cluster_annotation',
 'example_genes_all_cells_expression',
 'gene',
 'region_of_interest_metadata']

In [5]:
# Load the WMB-10X per-cell metadata table, forcing `cell_label` to be read as a string.
# NOTE(review): the identical load is repeated in a later cell of this notebook.
cell = abc_cache.get_metadata_dataframe(
    directory='WMB-10X',
    file_name='cell_metadata',
    dtype={'cell_label': str}
)
# Use cell_label as the row index (modifies `cell` in place).
cell.set_index('cell_label', inplace=True)
print("Number of cells = ", len(cell))

cell_metadata.csv: 100%|██████████| 1.01G/1.01G [00:45<00:00, 22.1MMB/s]   


Number of cells =  4042976


In [6]:
cell["dataset_label"].value_counts()

dataset_label
WMB-10Xv3       2341350
WMB-10Xv2       1699939
WMB-10XMulti       1687
Name: count, dtype: int64

In [7]:
cell["feature_matrix_label"].value_counts()

feature_matrix_label
WMB-10Xv3-MB             337101
WMB-10Xv3-STR            283782
WMB-10Xv2-Isocortex-2    249360
WMB-10Xv2-Isocortex-3    249356
WMB-10Xv2-Isocortex-4    248784
WMB-10Xv2-Isocortex-1    248776
WMB-10Xv3-Isocortex-1    227670
WMB-10Xv3-Isocortex-2    227537
WMB-10Xv2-HPF            207281
WMB-10Xv2-OLF            192182
WMB-10Xv3-MY             191746
WMB-10Xv3-CB             181723
WMB-10Xv3-HPF            181055
WMB-10Xv3-HY             162296
WMB-10Xv3-P              143157
WMB-10Xv2-TH             130555
WMB-10Xv3-TH             130454
WMB-10Xv3-PAL            108046
WMB-10Xv2-HY              99879
WMB-10Xv3-OLF             88560
WMB-10Xv3-CTXsp           78223
WMB-10Xv2-CTXsp           43985
WMB-10Xv2-MB              29781
WMB-10XMulti               1687
Name: count, dtype: int64

In [8]:
abc_cache.list_data_files('WMB-10XMulti')

['WMB-10XMulti/log2', 'WMB-10XMulti/raw']

In [9]:
abc_cache.list_data_files('WMB-10Xv2')

['WMB-10Xv2-CTXsp/log2',
 'WMB-10Xv2-CTXsp/raw',
 'WMB-10Xv2-HPF/log2',
 'WMB-10Xv2-HPF/raw',
 'WMB-10Xv2-HY/log2',
 'WMB-10Xv2-HY/raw',
 'WMB-10Xv2-Isocortex-1/log2',
 'WMB-10Xv2-Isocortex-1/raw',
 'WMB-10Xv2-Isocortex-2/log2',
 'WMB-10Xv2-Isocortex-2/raw',
 'WMB-10Xv2-Isocortex-3/log2',
 'WMB-10Xv2-Isocortex-3/raw',
 'WMB-10Xv2-Isocortex-4/log2',
 'WMB-10Xv2-Isocortex-4/raw',
 'WMB-10Xv2-MB/log2',
 'WMB-10Xv2-MB/raw',
 'WMB-10Xv2-OLF/log2',
 'WMB-10Xv2-OLF/raw',
 'WMB-10Xv2-TH/log2',
 'WMB-10Xv2-TH/raw']

In [10]:
abc_cache.list_data_files('WMB-10Xv3')

['WMB-10Xv3-CB/log2',
 'WMB-10Xv3-CB/raw',
 'WMB-10Xv3-CTXsp/log2',
 'WMB-10Xv3-CTXsp/raw',
 'WMB-10Xv3-HPF/log2',
 'WMB-10Xv3-HPF/raw',
 'WMB-10Xv3-HY/log2',
 'WMB-10Xv3-HY/raw',
 'WMB-10Xv3-Isocortex-1/log2',
 'WMB-10Xv3-Isocortex-1/raw',
 'WMB-10Xv3-Isocortex-2/log2',
 'WMB-10Xv3-Isocortex-2/raw',
 'WMB-10Xv3-MB/log2',
 'WMB-10Xv3-MB/raw',
 'WMB-10Xv3-MY/log2',
 'WMB-10Xv3-MY/raw',
 'WMB-10Xv3-OLF/log2',
 'WMB-10Xv3-OLF/raw',
 'WMB-10Xv3-P/log2',
 'WMB-10Xv3-P/raw',
 'WMB-10Xv3-PAL/log2',
 'WMB-10Xv3-PAL/raw',
 'WMB-10Xv3-STR/log2',
 'WMB-10Xv3-STR/raw',
 'WMB-10Xv3-TH/log2',
 'WMB-10Xv3-TH/raw']

In [11]:
# get all log gene expression matrices
def get_matrices(dir_name):
    """Return the Isocortex-1, Isocortex-2 and STR log2 matrix names under `dir_name`.

    Names are taken from ``abc_cache.list_data_files(dir_name)`` and returned in
    the order that call yields them.
    """
    wanted_suffixes = ('Isocortex-1/log2', 'Isocortex-2/log2', 'STR/log2')
    selected = []
    for name in abc_cache.list_data_files(dir_name):
        # str.endswith accepts a tuple and matches if any suffix matches.
        if name.endswith(wanted_suffixes):
            selected.append(name)
    return selected

# Earlier variant that also included the 10XMulti and 10Xv2 directories (disabled):
# all_matrices = sorted(get_matrices('WMB-10XMulti') + get_matrices('WMB-10Xv2') + get_matrices('WMB-10Xv3'))
# Only the WMB-10Xv3 matrices picked by get_matrices are used, in sorted order.
all_matrices = sorted(get_matrices('WMB-10Xv3'))

In [12]:
all_matrices

['WMB-10Xv3-Isocortex-1/log2',
 'WMB-10Xv3-Isocortex-2/log2',
 'WMB-10Xv3-STR/log2']

In [13]:
# Read one AnnData object per selected expression matrix.
all_adata = []
for matrix in all_matrices:
    # Every entry of all_matrices comes from 'WMB-10Xv3', so the directory is fixed.
    directory = 'WMB-10Xv3'
    # Disabled directory selection from when several datasets were loaded:
    # directory = "WMB-10XMulti"
    # if ('10Xv3' in matrix): directory = 'WMB-10Xv3'
    # if ('10Xv2' in matrix): directory = 'WMB-10Xv2'
    # Resolve the local path of the matrix file (the recorded output shows it being downloaded).
    file = abc_cache.get_data_path(directory=directory, file_name=matrix)
    temp_adata = anndata.read_h5ad(file)
    all_adata.append(temp_adata)

WMB-10Xv3-Isocortex-1-log2.h5ad: 100%|██████████| 11.8G/11.8G [10:21<00:00, 18.9MMB/s]    
WMB-10Xv3-Isocortex-2-log2.h5ad: 100%|██████████| 8.36G/8.36G [06:20<00:00, 22.0MMB/s]    
WMB-10Xv3-STR-log2.h5ad: 100%|██████████| 11.9G/11.9G [09:13<00:00, 21.5MMB/s]    


In [14]:
len(all_adata)

3

In [15]:
# Stack the per-matrix objects along axis 0; join='outer' keeps the union of the
# other axis' labels across inputs (see the anndata.concat documentation).
adata = anndata.concat(all_adata, axis=0, join='outer')
print("done")

done


In [17]:
adata

AnnData object with n_obs × n_vars = 741624 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label'

In [18]:
adata.write("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [19]:
adata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [20]:
adata

AnnData object with n_obs × n_vars = 741624 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label'

In [21]:
# Reload the WMB-10X per-cell metadata (same call as the earlier metadata cell),
# indexed by cell_label.
cell = abc_cache.get_metadata_dataframe(
    directory='WMB-10X',
    file_name='cell_metadata',
    dtype={'cell_label': str}
)
cell.set_index('cell_label', inplace=True)
print(f"Total cells: {len(cell)}")

# Per-cluster taxonomy annotations, indexed by cluster_alias.
# NOTE(review): keep_default_na=False is presumably forwarded to the CSV reader so
# empty fields stay empty strings instead of NaN — consistent with the blank
# neurotransmitter category in the value_counts output further down; confirm.
cluster_details = abc_cache.get_metadata_dataframe(
    directory='WMB-taxonomy',
    file_name='cluster_to_cluster_annotation_membership_pivoted',
    keep_default_na=False
)
cluster_details.set_index('cluster_alias', inplace=True)

Total cells: 4042976


cluster_to_cluster_annotation_membership_pivoted.csv: 100%|██████████| 531k/531k [00:00<00:00, 1.12MMB/s] 


In [22]:
cell

Unnamed: 0_level_0,cell_barcode,barcoded_cell_sample_label,library_label,feature_matrix_label,entity,brain_section_label,library_method,region_of_interest_acronym,donor_label,donor_genotype,donor_sex,dataset_label,x,y,cluster_alias,abc_sample_id
cell_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
GCGAGAAGTTAAGGGC-410_B05,GCGAGAAGTTAAGGGC,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.146826,-3.086639,1,484be5df-5d44-4bfe-9652-7b5bc739c211
AATGGCTCAGCTCCTT-411_B06,AATGGCTCAGCTCCTT,411_B06,L8TX_201029_01_E10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550851,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.138481,-3.022000,1,5638505d-e1e8-457f-9e5b-59e3e2302417
AACACACGTTGCTTGA-410_B05,AACACACGTTGCTTGA,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.472557,-2.992709,1,a0544e29-194f-4d34-9af4-13e7377b648f
CACAGATAGAGGCGGA-410_A05,CACAGATAGAGGCGGA,410_A05,L8TX_201029_01_A10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.379622,-3.043442,1,c777ac0b-77e1-4d76-bf8e-2b3d9e08b253
AAAGTGAAGCATTTCG-410_B05,AAAGTGAAGCATTTCG,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,F,WMB-10Xv3,23.909480,-2.601536,1,49860925-e82b-46df-a228-fd2f97e75d39
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GTGTGAGCAAACGCGA-1350_C05,GTGTGAGCAAACGCGA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,M,WMB-10XMulti,-7.716915,0.223654,8861,ba1d0e38-bea7-4d4f-bfcd-49121938e743
TTAGCAATCCCTGTTA-1350_C05,TTAGCAATCCCTGTTA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,M,WMB-10XMulti,-3.115098,-3.024478,8215,342bd0bb-cbe5-479b-9c70-fef59a730255
TTTGGCTGTCGCGCAA-1350_C05,TTTGGCTGTCGCGCAA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,M,WMB-10XMulti,-7.950964,0.409335,8798,4634de09-d8e0-4e40-a49b-eba311de08b5
ATCCACCTCACAGACT-1320_B04,ATCCACCTCACAGACT,1320_B04,L8XR_220630_02_B10,WMB-10XMulti,cell,,10xRSeq_Mult,OLF,C57BL6J-625156,wt/wt,F,WMB-10XMulti,4.579441,12.135833,8798,5b3061de-1cb8-47b6-9368-52824e1031ce


In [23]:
cluster_details

Unnamed: 0_level_0,neurotransmitter,class,subclass,supertype,cluster
cluster_alias,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
2,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0327 L2 IT PPP-APr Glut_3
3,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0081 L2 IT PPP-APr Glut_2,0322 L2 IT PPP-APr Glut_2
4,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0081 L2 IT PPP-APr Glut_2,0323 L2 IT PPP-APr Glut_2
5,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0081 L2 IT PPP-APr Glut_2,0325 L2 IT PPP-APr Glut_2
...,...,...,...,...,...
34368,GABA-Glyc,27 MY GABA,288 MDRN Hoxb5 Ebf2 Gly-Gaba,1102 MDRN Hoxb5 Ebf2 Gly-Gaba_1,4955 MDRN Hoxb5 Ebf2 Gly-Gaba_1
34372,GABA-Glyc,27 MY GABA,285 MY Lhx1 Gly-Gaba,1091 MY Lhx1 Gly-Gaba_3,4901 MY Lhx1 Gly-Gaba_3
34374,GABA-Glyc,27 MY GABA,285 MY Lhx1 Gly-Gaba,1091 MY Lhx1 Gly-Gaba_3,4902 MY Lhx1 Gly-Gaba_3
34376,GABA-Glyc,27 MY GABA,285 MY Lhx1 Gly-Gaba,1091 MY Lhx1 Gly-Gaba_3,4903 MY Lhx1 Gly-Gaba_3


In [24]:
# Attach the taxonomy columns to each cell by looking up its cluster_alias in
# the index of cluster_details.
cell_extended = cell.join(cluster_details, on='cluster_alias')
# NOTE: DataFrame.join defaults to a left join, so this is the total row count of
# `cell`, not the number of cells that actually matched an annotation.
print(f"Cells with annotations: {len(cell_extended)}")

Cells with annotations: 4042976


In [25]:
cell_extended

Unnamed: 0_level_0,cell_barcode,barcoded_cell_sample_label,library_label,feature_matrix_label,entity,brain_section_label,library_method,region_of_interest_acronym,donor_label,donor_genotype,...,dataset_label,x,y,cluster_alias,abc_sample_id,neurotransmitter,class,subclass,supertype,cluster
cell_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
GCGAGAAGTTAAGGGC-410_B05,GCGAGAAGTTAAGGGC,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.146826,-3.086639,1,484be5df-5d44-4bfe-9652-7b5bc739c211,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
AATGGCTCAGCTCCTT-411_B06,AATGGCTCAGCTCCTT,411_B06,L8TX_201029_01_E10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550851,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.138481,-3.022000,1,5638505d-e1e8-457f-9e5b-59e3e2302417,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
AACACACGTTGCTTGA-410_B05,AACACACGTTGCTTGA,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.472557,-2.992709,1,a0544e29-194f-4d34-9af4-13e7377b648f,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
CACAGATAGAGGCGGA-410_A05,CACAGATAGAGGCGGA,410_A05,L8TX_201029_01_A10,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.379622,-3.043442,1,c777ac0b-77e1-4d76-bf8e-2b3d9e08b253,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
AAAGTGAAGCATTTCG-410_B05,AAAGTGAAGCATTTCG,410_B05,L8TX_201030_01_C12,WMB-10Xv3-HPF,cell,,10Xv3,RHP,Snap25-IRES2-Cre;Ai14-550850,Ai14(RCL-tdT)/wt,...,WMB-10Xv3,23.909480,-2.601536,1,49860925-e82b-46df-a228-fd2f97e75d39,Glut,01 IT-ET Glut,018 L2 IT PPP-APr Glut,0082 L2 IT PPP-APr Glut_3,0326 L2 IT PPP-APr Glut_3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GTGTGAGCAAACGCGA-1350_C05,GTGTGAGCAAACGCGA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,...,WMB-10XMulti,-7.716915,0.223654,8861,ba1d0e38-bea7-4d4f-bfcd-49121938e743,GABA-Glyc,26 P GABA,278 NLL Gata3 Gly-Gaba,1074 NLL Gata3 Gly-Gaba_1,4804 NLL Gata3 Gly-Gaba_1
TTAGCAATCCCTGTTA-1350_C05,TTAGCAATCCCTGTTA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,...,WMB-10XMulti,-3.115098,-3.024478,8215,342bd0bb-cbe5-479b-9c70-fef59a730255,Glut,19 MB Glut,157 RN Spp1 Glut,0682 RN Spp1 Glut_1,2761 RN Spp1 Glut_1
TTTGGCTGTCGCGCAA-1350_C05,TTTGGCTGTCGCGCAA,1350_C05,L8XR_220728_01_A05,WMB-10XMulti,cell,,10xRSeq_Mult,MB,C57BL6J-641405,wt/wt,...,WMB-10XMulti,-7.950964,0.409335,8798,4634de09-d8e0-4e40-a49b-eba311de08b5,GABA-Glyc,26 P GABA,278 NLL Gata3 Gly-Gaba,1076 NLL Gata3 Gly-Gaba_3,4806 NLL Gata3 Gly-Gaba_3
ATCCACCTCACAGACT-1320_B04,ATCCACCTCACAGACT,1320_B04,L8XR_220630_02_B10,WMB-10XMulti,cell,,10xRSeq_Mult,OLF,C57BL6J-625156,wt/wt,...,WMB-10XMulti,4.579441,12.135833,8798,5b3061de-1cb8-47b6-9368-52824e1031ce,GABA-Glyc,26 P GABA,278 NLL Gata3 Gly-Gaba,1076 NLL Gata3 Gly-Gaba_3,4806 NLL Gata3 Gly-Gaba_3


In [26]:
cell_extended["neurotransmitter"].value_counts()

neurotransmitter
Glut         2054137
             1089152
GABA          834601
GABA-Glyc      36490
Dopa            9396
Glut-GABA       8989
Chol            7582
Sero            1469
Nora             626
Hist             534
Name: count, dtype: int64

In [27]:
cell_extended["class"].value_counts()

class
01 IT-ET Glut        1095484
31 OPC-Oligo          545179
02 NP-CT-L6b Glut     310198
30 Astro-Epen         308681
29 CB Glut            141106
06 CTX-CGE GABA       139032
33 Vascular           137493
07 CTX-MGE GABA       122085
19 MB Glut            120552
18 TH Glut            115401
09 CNU-LGE GABA       115160
05 OB-IMN GABA        107502
34 Immune              92580
12 HY GABA             90376
04 DG-IMN Glut         84352
20 MB GABA             82032
11 CNU-HYa GABA        78482
14 HY Glut             66984
28 CB GABA             51226
13 CNU-HYa Glut        45685
27 MY GABA             33311
10 LSX GABA            30313
24 MY Glut             27543
23 P Glut              25303
26 P GABA              20206
08 CNU-MGE GABA        18849
16 HY MM Glut          13730
17 MH-LH Glut          10770
03 OB-CR Glut           4767
21 MB Dopa              4301
22 MB-HB Sero           2466
32 OEC                  1132
25 Pineal Glut           504
15 HY Gnrh1 Glut         191
Name: co

In [28]:
# cluster_details was loaded with keep_default_na=False, so unlabeled cells carry an
# empty string rather than NaN (see the blank row in the value_counts output above).
# Report both counts; the NaN count alone reads 0 and hides the unlabeled cells.
print(f"Cells missing neuro: {cell_extended['neurotransmitter'].isna().sum()}")
print(f"Cells with blank neuro: {cell_extended['neurotransmitter'].astype('string').str.strip().eq('').sum()}")

Cells missing neuro: 0


In [29]:
cell_extended.index

Index(['GCGAGAAGTTAAGGGC-410_B05', 'AATGGCTCAGCTCCTT-411_B06',
       'AACACACGTTGCTTGA-410_B05', 'CACAGATAGAGGCGGA-410_A05',
       'AAAGTGAAGCATTTCG-410_B05', 'GATCGTATCGAATCCA-411_B06',
       'AGATGAAAGGACCCAA-410_A05', 'TCTCACGGTCAGGAGT-411_A06',
       'GATTCTTGTTCGCGTG-410_B05', 'TTTCGATAGTAAAGCT-410_B05',
       ...
       'TGTCCTTCATCTAGCA-1320_D04', 'TTAACTGAGTCAGGCC-1320_D04',
       'TGGTTCTGTCTATCGT-1315_A01', 'ATCATGTCATCATGGC-1350_C05',
       'ATTCCTCCAGCCAGAA-1350_C05', 'GTGTGAGCAAACGCGA-1350_C05',
       'TTAGCAATCCCTGTTA-1350_C05', 'TTTGGCTGTCGCGCAA-1350_C05',
       'ATCCACCTCACAGACT-1320_B04', 'TCGTTAGCATTGTCCT-1320_B04'],
      dtype='object', name='cell_label', length=4042976)

In [None]:
# adata = adata[cell_extended.index, :]

In [60]:
# adata

AnnData object with n_obs × n_vars = 741624 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter'

In [30]:
# Assignment aligns on the index (cell_label). Cells in adata that have no row in
# cell_extended end up as NaN (the diagnostic cell below counts them).
adata.obs['neurotransmitter'] = cell_extended['neurotransmitter']

In [31]:
adata.obs['neurotransmitter']

cell_label
GCACTAAGTACAAGTA-399_B02    Glut
CAGAGCCGTCCTGGTG-399_A02    Glut
CGCCATTCACGACAGA-403_A06    Glut
GGGTTTATCCGATCGG-399_B02    Glut
GTGTGATCAGACAAAT-399_B02    Glut
                            ... 
TTTGTTGTCCGGTTCT-346_C06        
TTTGTTGTCGGAACTT-417_B03        
TTTGTTGTCGGTAAGG-454_B03        
TTTGTTGTCTACTGAG-577_B03        
TTTGTTGTCTATGTGG-415_C01        
Name: neurotransmitter, Length: 741624, dtype: object

In [32]:
# Sanity-check the index-aligned assignment of neurotransmitter labels.
s_assigned = adata.obs["neurotransmitter"]
s_expected = cell_extended["neurotransmitter"].reindex(adata.obs_names)

# 1) Are the two series identical? (Series.equals treats NaNs in the same positions as equal.)
print("identical (incl NaNs):", s_assigned.equals(s_expected))

# 2) Count real mismatches, ignoring rows where both sides are NaN.
real_mismatch = (s_assigned != s_expected) & ~(s_assigned.isna() & s_expected.isna())
print("real mismatches:", int(real_mismatch.sum()))

# 3) Explain the NaNs: first, cells in adata that have no row in cell_extended ...
not_in_global = ~adata.obs_names.isin(cell_extended.index)
print("subset cells not found in cell_extended index:", int(not_in_global.sum()))

# ... then cells that are present in cell_extended but whose label is NaN there.
present_idx = adata.obs_names[~not_in_global]
nan_in_global = cell_extended.loc[present_idx, "neurotransmitter"].isna().sum()
print("present but neurotransmitter is NaN in global:", int(nan_in_global))

identical (incl NaNs): True
real mismatches: 0
subset cells not found in cell_extended index: 2635
present but neurotransmitter is NaN in global: 0


In [33]:
# Copy the remaining taxonomy levels into adata.obs. Each assignment aligns on
# the cell_label index, so cells absent from cell_extended get NaN.
adata.obs['class'] = cell_extended['class']
adata.obs['subclass'] = cell_extended['subclass']
adata.obs['supertype'] = cell_extended['supertype']
adata.obs['cluster'] = cell_extended['cluster']

In [34]:
adata

AnnData object with n_obs × n_vars = 741624 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster'

In [35]:
adata.obs["neurotransmitter"]

cell_label
GCACTAAGTACAAGTA-399_B02    Glut
CAGAGCCGTCCTGGTG-399_A02    Glut
CGCCATTCACGACAGA-403_A06    Glut
GGGTTTATCCGATCGG-399_B02    Glut
GTGTGATCAGACAAAT-399_B02    Glut
                            ... 
TTTGTTGTCCGGTTCT-346_C06        
TTTGTTGTCGGAACTT-417_B03        
TTTGTTGTCGGTAAGG-454_B03        
TTTGTTGTCTACTGAG-577_B03        
TTTGTTGTCTATGTGG-415_C01        
Name: neurotransmitter, Length: 741624, dtype: object

In [36]:
adata.obs["class"]

cell_label
GCACTAAGTACAAGTA-399_B02    01 IT-ET Glut
CAGAGCCGTCCTGGTG-399_A02    01 IT-ET Glut
CGCCATTCACGACAGA-403_A06    01 IT-ET Glut
GGGTTTATCCGATCGG-399_B02    01 IT-ET Glut
GTGTGATCAGACAAAT-399_B02    01 IT-ET Glut
                                ...      
TTTGTTGTCCGGTTCT-346_C06        34 Immune
TTTGTTGTCGGAACTT-417_B03      33 Vascular
TTTGTTGTCGGTAAGG-454_B03      33 Vascular
TTTGTTGTCTACTGAG-577_B03      33 Vascular
TTTGTTGTCTATGTGG-415_C01      33 Vascular
Name: class, Length: 741624, dtype: object

In [37]:
adata.obs["neurotransmitter"].value_counts()

neurotransmitter
Glut         289785
             249656
GABA         198680
Glut-GABA       453
Chol            382
Dopa             30
Hist              1
Nora              1
GABA-Glyc         1
Name: count, dtype: int64

In [38]:
adata.obs["class"].value_counts()

class
01 IT-ET Glut        202054
30 Astro-Epen         90255
31 OPC-Oligo          80103
09 CNU-LGE GABA       78589
02 NP-CT-L6b Glut     65425
11 CNU-HYa GABA       40201
33 Vascular           40049
34 Immune             37720
10 LSX GABA           29176
06 CTX-CGE GABA       24102
13 CNU-HYa Glut       20304
07 CTX-MGE GABA       19366
08 CNU-MGE GABA        5266
05 OB-IMN GABA         4824
03 OB-CR Glut           620
04 DG-IMN Glut          344
14 HY Glut              265
12 HY GABA              147
18 TH Glut               63
32 OEC                   46
16 HY MM Glut            23
25 Pineal Glut           19
19 MB Glut               10
20 MB GABA                5
29 CB Glut                5
15 HY Gnrh1 Glut          2
24 MY Glut                2
17 MH-LH Glut             2
27 MY GABA                1
21 MB Dopa                1
Name: count, dtype: int64

In [39]:
# Drop cells whose neurotransmitter is NaN, then give blank labels an explicit name.
col = "neurotransmitter"

n_before = adata.n_obs
n_nan = int(adata.obs[col].isna().sum())  # 2635 in the recorded run
print(n_before)
print(n_nan)

# Keep only rows with a non-NaN label; .copy() makes adata a standalone object
# rather than a view of the previous one.
adata = adata[~adata.obs[col].isna()].copy()
# Labels that are empty or whitespace-only are relabelled "no-name".
s = adata.obs[col].astype("string")
adata.obs.loc[s.str.strip().eq(""), col] = "no-name"

print(f"Dropped {n_nan} unlabeled cells; {n_before} → {adata.n_obs}")
print(adata.obs[col].value_counts(dropna=False))

741624
2635
Dropped 2635 unlabeled cells; 741624 → 738989
neurotransmitter
Glut         289785
no-name      249656
GABA         198680
Glut-GABA       453
Chol            382
Dopa             30
Hist              1
Nora              1
GABA-Glyc         1
Name: count, dtype: int64


In [40]:
def print_stats(anndata, level):
    """Print the distinct values of ``anndata.obs[level]`` and how many there are.

    Emits the unique values, a count line, and a trailing blank line.
    Note that the `anndata` parameter shadows the imported `anndata` module
    inside this function.
    """
    labels = anndata.obs[level]
    print(labels.unique())
    print(f"Number of unique sub cell types: {labels.nunique()}")
    print()

In [41]:
print_stats(adata, "neurotransmitter")
print_stats(adata, "class")

['Glut' 'GABA' 'no-name' 'Dopa' 'Hist' 'Chol' 'Nora' 'GABA-Glyc'
 'Glut-GABA']
Number of unique sub cell types: 9

['01 IT-ET Glut' '11 CNU-HYa GABA' '05 OB-IMN GABA' '09 CNU-LGE GABA'
 '07 CTX-MGE GABA' '02 NP-CT-L6b Glut' '04 DG-IMN Glut' '12 HY GABA'
 '15 HY Gnrh1 Glut' '13 CNU-HYa Glut' '14 HY Glut' '03 OB-CR Glut'
 '16 HY MM Glut' '24 MY Glut' '27 MY GABA' '25 Pineal Glut'
 '30 Astro-Epen' '08 CNU-MGE GABA' '31 OPC-Oligo' '32 OEC' '33 Vascular'
 '34 Immune' '06 CTX-CGE GABA' '10 LSX GABA' '20 MB GABA' '19 MB Glut'
 '21 MB Dopa' '17 MH-LH Glut' '18 TH Glut' '29 CB Glut']
Number of unique sub cell types: 30



In [42]:
adata.obs["cluster"].value_counts()

cluster
5225 Astro-TE NN_3                         60004
0109 L2/3 IT CTX Glut_2                    40156
5312 Microglia NN_1                        35218
5285 MOL NN_4                              33538
5269 OPC NN_1                              29289
                                           ...  
2546 PH-SUM Foxa1 Glut_3                       1
2542 PH-SUM Foxa1 Glut_2                       1
2559 PH-SUM Foxa1 Glut_6                       1
2608 LH Pou4f1 Sox1 Glut_3                     1
3656 LGv-SPFp-SPFm Nkx2-2 Tcf7l2 Gaba_4        1
Name: count, Length: 1799, dtype: int64

In [43]:
# Taxonomy level to filter on and the minimum number of cells a label must have.
level = "cluster"
min_count = 6

# Drop rare clusters (< 6 cells): an 80/20 split is applied twice downstream,
# so 6 cells -> 5 train / 1 test, and those 5 -> 4 train / 1 validation.
s = adata.obs[level]
counts = s.value_counts()
rare_labels = counts[counts < min_count].index
keep_mask = ~s.isin(rare_labels)

print(f"Dropping {(~keep_mask).sum()} cells from {len(rare_labels)} rare labels: {list(rare_labels)}")
# .copy() makes the filtered result a standalone object rather than a view.
adata = adata[keep_mask].copy()

Dropping 880 cells from 479 rare labels: ['1283 MEA-BST Lhx6 Nfib Gaba_1', '0584 OB Dopa-Gaba_1', '0577 OB-out Frmd7 Gaba_2', '0574 OB-out Frmd7 Gaba_1', '0022 IT EP-CLA Glut_4', '0337 L2/3 IT PPP Glut_3', '1266 MEA-BST Lhx6 Sp9 Gaba_7', '0892 NDB-SI-MA-STRv Lhx8 Gaba_4', '1307 MEA-BST Lhx6 Nfib Gaba_5', '0520 OB Meis2 Thsd7b Gaba_1', '0010 IT EP-CLA Glut_1', '0328 L2 IT PPP-APr Glut_4', '2574 MM Foxb1 Glut_2', '1540 BST-MPN Six3 Nrgn Gaba_2', '0705 RHP-COA Ndnf Gaba_6', '2186 SI-MA-LPO-LHA Skor1 Glut_6', '0924 PAL-STR Gaba-Chol_1', '0198 MEA Slc17a7 Glut_1', '0553 OB-in Frmd7 Gaba_3', '2238 MPN-MPO-PVpo Hmx2 Glut_1', '1989 BST-po Iigp1 Glut_1', '1600 RT-ZI Gnb3 Gaba_5', '2202 SI-MA-LPO-LHA Skor1 Glut_8', '1526 PVR Six3 Sox3 Gaba_7', '0403 L6b EPd Glut_1', '5197 CB Granule Glut_1', '5222 Astro-TE NN_2', '0630 Vip Gaba_2', '5290 OEC NN_1', '1481 MPO-ADP Lhx8 Gaba_5', '0406 L6b EPd Glut_2', '0934 PAL-STR Gaba-Chol_3', '0866 NDB-SI-MA-STRv Lhx8 Gaba_1', '2072 MS-SF Bsx Glut_1', '2054 MS-S

In [44]:
adata.obs["cluster"].value_counts()

cluster
5225 Astro-TE NN_3                 60004
0109 L2/3 IT CTX Glut_2            40156
5312 Microglia NN_1                35218
5285 MOL NN_4                      33538
5269 OPC NN_1                      29289
                                   ...  
2036 COAa-PAA-MEA Barhl2 Glut_5        6
1494 BST Tac2 Gaba_1                   6
2037 COAa-PAA-MEA Barhl2 Glut_5        6
0775 Sst Gaba_3                        6
5320 ILC NN_2                          6
Name: count, Length: 1320, dtype: int64

In [45]:
# Record which source matrix each cell came from (index-aligned on cell_label);
# used as the batch key for highly-variable-gene selection below.
adata.obs['feature_matrix_label'] = cell_extended['feature_matrix_label']

In [46]:
adata.obs['feature_matrix_label'].value_counts()

feature_matrix_label
WMB-10Xv3-STR            283132
WMB-10Xv3-Isocortex-2    227534
WMB-10Xv3-Isocortex-1    227443
Name: count, dtype: int64

In [47]:
# Flag the top 3500 highly variable genes, using feature_matrix_label as the batch key.
# The results are written into adata.var / adata.uns (see the repr in the next cell).
sc.pp.highly_variable_genes(adata, n_top_genes=3500, batch_key='feature_matrix_label')

In [48]:
adata

AnnData object with n_obs × n_vars = 738109 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [49]:
mouse_whole_brain_anndata = adata

In [50]:
mouse_whole_brain_anndata.write("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_anndata.h5ad")

In [51]:
mouse_whole_brain_anndata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_anndata.h5ad")

In [52]:
mouse_whole_brain_anndata

AnnData object with n_obs × n_vars = 738109 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [53]:
from pandas.api.types import is_categorical_dtype

# This label occurs in both the supertype and cluster columns (the rename counts
# printed by this cell are non-zero for both), so each level gets its own name.
# NOTE(review): presumably needed because add_path skips a label equal to the
# one before it in a path — confirm.
label = "0001 CLA-EPd-CTX Car3 Glut_1"
parent_col = "supertype"
child_col = "cluster"

# Replacement names: "_broad" for the parent level, "_fine" for the child level.
parent_new = f"{label}_broad"
child_new  = f"{label}_fine"

def explicit_rename(ad, parent_col, child_col, old, new_parent, new_child):
    """Rename the label `old` independently in two columns of ``ad.obs``.

    Rows where ``ad.obs[parent_col] == old`` are set to `new_parent`, and rows
    where ``ad.obs[child_col] == old`` are set to `new_child`. ``ad.obs`` is
    modified in place and the number of renamed rows per column is printed.

    Categorical columns get the new category added before assignment (skipped
    if it already exists, so the cell can be re-run) and unused categories
    removed afterwards. Non-categorical columns are assigned directly.

    Parameters
    ----------
    ad : object exposing a pandas DataFrame as ``.obs``
    parent_col, child_col : str
        Column names in ``ad.obs``.
    old : str
        Label to replace.
    new_parent, new_child : str
        Replacement labels for `parent_col` and `child_col` respectively.
    """
    # Both masks are computed before anything is modified. Missing values compare
    # as NA under the "string" dtype, so treat them as "no match".
    parent_hits = ad.obs[parent_col].astype("string").eq(old).fillna(False)
    child_hits = ad.obs[child_col].astype("string").eq(old).fillna(False)

    n_parent = int(parent_hits.sum())
    n_child = int(child_hits.sum())

    plan = (
        (parent_col, parent_hits, new_parent),
        (child_col, child_hits, new_child),
    )
    for col, hits, new_label in plan:
        # isinstance on the dtype replaces the deprecated is_categorical_dtype().
        categorical = isinstance(ad.obs[col].dtype, pd.CategoricalDtype)
        if categorical and new_label not in ad.obs[col].cat.categories:
            ad.obs[col] = ad.obs[col].cat.add_categories([new_label])

        ad.obs.loc[hits, col] = new_label

        if categorical:
            ad.obs[col] = ad.obs[col].cat.remove_unused_categories()

    print(f"Renamed {n_parent} in {parent_col} → {new_parent}")
    print(f"Renamed {n_child} in {child_col}  → {new_child}")

explicit_rename(mouse_whole_brain_anndata, parent_col, child_col, label, parent_new, child_new)

Renamed 1663 in supertype → 0001 CLA-EPd-CTX Car3 Glut_1_broad
Renamed 78 in cluster  → 0001 CLA-EPd-CTX Car3 Glut_1_fine


  if is_categorical_dtype(ad.obs[parent_col]):
  if is_categorical_dtype(ad.obs[child_col]):
  if is_categorical_dtype(ad.obs[parent_col]):
  if is_categorical_dtype(ad.obs[child_col]):


In [54]:
mouse_whole_brain_anndata.write("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_anndata.h5ad")

In [55]:
def holdout_subset(name, anndata, split, seed, stratify_col="cluster",
                   out_dir="/gpfs/scratch/blukacsy/abc_atlas/data"):
    """Split `anndata` into stratified train/test subsets and write both to disk.

    Parameters
    ----------
    name : str
        Prefix for the printed messages and the output file names.
    anndata : AnnData-like
        Object to split; indexed by observation label. Note that this
        parameter shadows the imported `anndata` module inside the function.
    split : float or int
        Passed to ``train_test_split`` as ``test_size``.
    seed : int
        Passed to ``train_test_split`` as ``random_state``.
    stratify_col : str, default "cluster"
        Column of ``anndata.obs`` used for stratification.
    out_dir : str, default "/gpfs/scratch/blukacsy/abc_atlas/data"
        Directory receiving ``{name}_train_anndata.h5ad`` and
        ``{name}_test_anndata.h5ad``.

    Returns
    -------
    None
    """
    # Split the observation labels, not the matrix, so only the index is shuffled here.
    train_idx, test_idx = train_test_split(
        anndata.obs.index,
        test_size=split,
        random_state=seed,
        stratify=anndata.obs[stratify_col].values,
    )
    print(f'{name}_train_anndata shape', train_idx.shape)
    print(f'{name}_test_anndata shape', test_idx.shape)
    print()

    # Materialise each subset as its own object before writing.
    train_anndata = anndata[train_idx].copy()
    test_anndata = anndata[test_idx].copy()

    train_anndata.write(f"{out_dir}/{name}_train_anndata.h5ad")
    test_anndata.write(f"{out_dir}/{name}_test_anndata.h5ad")

In [56]:
# Fixed random seed and test fraction (20%) for the train/test holdout.
seed = 8653
split = 0.2
# Writes the train and test .h5ad files; nothing is returned.
holdout_subset("mouse_whole_brain", mouse_whole_brain_anndata, split, seed)

mouse_whole_brain_train_anndata shape (590487,)
mouse_whole_brain_test_anndata shape (147622,)



In [57]:
mouse_whole_brain_train_anndata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_train_anndata.h5ad")
mouse_whole_brain_test_anndata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_test_anndata.h5ad")

In [58]:
print_stats(mouse_whole_brain_train_anndata, "cluster")
print_stats(mouse_whole_brain_test_anndata, "cluster")

['0953 STR D1 Gaba_4', '1279 MEA-BST Lhx6 Nr2e1 Gaba_2', '5227 Astro-TE NN_3', '5284 MOL NN_4', '1291 MEA-BST Lhx6 Nfib Gaba_2', ..., '0869 NDB-SI-MA-STRv Lhx8 Gaba_1', '2191 SI-MA-LPO-LHA Skor1 Glut_6', '0401 CA2-FC-IG Glut_2', '0868 NDB-SI-MA-STRv Lhx8 Gaba_1', '1286 MEA-BST Lhx6 Nfib Gaba_1']
Length: 1320
Categories (1320, object): ['0002 CLA-EPd-CTX Car3 Glut_1', '0003 CLA-EPd-CTX Car3 Glut_1', '0004 CLA-EPd-CTX Car3 Glut_1', '0006 IT EP-CLA Glut_1', ..., '5320 ILC NN_2', '5321 NK cells NN_3', '5322 T cells NN_4', '0001 CLA-EPd-CTX Car3 Glut_1_fine']
Number of unique sub cell types: 1320

['5311 Endo NN_1', '0116 L2/3 IT CTX Glut_3', '0962 STR D1 Gaba_8', '0612 OB-STR-CTX Inh IMN_4', '1384 CEA-BST Ebf1 Pdyn Gaba_1', ..., '0694 RHP-COA Ndnf Gaba_3', '0564 OB-in Frmd7 Gaba_5', '0869 NDB-SI-MA-STRv Lhx8 Gaba_1', '1021 IA Mgp Gaba_3', '0525 OB Meis2 Thsd7b Gaba_3']
Length: 1320
Categories (1320, object): ['0002 CLA-EPd-CTX Car3 Glut_1', '0003 CLA-EPd-CTX Car3 Glut_1', '0004 CLA-EPd-CTX

In [59]:
mouse_whole_brain_train_anndata.obs["cluster"].value_counts()

cluster
5225 Astro-TE NN_3                    48003
0109 L2/3 IT CTX Glut_2               32125
5312 Microglia NN_1                   28174
5285 MOL NN_4                         26830
5269 OPC NN_1                         23431
                                      ...  
1522 PVR Six3 Sox3 Gaba_5                 5
1489 MPN-MPO-LPO Lhx6 Zfhx3 Gaba_2        5
0560 OB-in Frmd7 Gaba_5                   5
0568 OB-out Frmd7 Gaba_1                  5
5320 ILC NN_2                             5
Name: count, Length: 1320, dtype: int64

In [60]:
mouse_whole_brain_test_anndata.obs["cluster"].value_counts()

cluster
5225 Astro-TE NN_3              12001
0109 L2/3 IT CTX Glut_2          8031
5312 Microglia NN_1              7044
5285 MOL NN_4                    6708
5269 OPC NN_1                    5858
                                ...  
0147 L2/3 IT PIR-ENTl Glut_2        1
0152 L2/3 IT PIR-ENTl Glut_3        1
0158 L2/3 IT PIR-ENTl Glut_4        1
0654 Vip Gaba_9                     1
5320 ILC NN_2                       1
Name: count, Length: 1320, dtype: int64

In [61]:
hierarchy = ['class', 'subclass', 'supertype', 'cluster']

In [62]:
mouse_whole_brain_hierarchy_dict = {}

In [63]:
def add_path(root, path):
    """Insert `path` into the nested dict `root`, one level per label.

    Each label becomes a key under the previous label's dict (created empty if
    absent). A label equal to the one immediately before it in `path` is
    skipped, so a repeated name does not create a node nested under itself.
    `root` is modified in place; nothing is returned.
    """
    node = root
    prev = None
    for label in path:
        if label == prev:
            continue
        # setdefault reuses an existing child dict or creates an empty one.
        node = node.setdefault(label, {})
        prev = label

def create_hierarchy_dict(anndata, hierarchy_dict):
    """Fill `hierarchy_dict` with every distinct label path found in ``anndata.obs``.

    The columns listed in the module-level `hierarchy` are selected, duplicate
    rows are dropped, and each remaining row is inserted via `add_path`.
    `hierarchy_dict` is modified in place.
    """
    distinct_rows = anndata.obs[hierarchy].drop_duplicates()
    for lineage in distinct_rows.values:
        add_path(hierarchy_dict, lineage)

In [64]:
create_hierarchy_dict(mouse_whole_brain_train_anndata, mouse_whole_brain_hierarchy_dict)

In [65]:
import json
def save_hierarchy_dict(name, hierarchy_dict):
    """Write `hierarchy_dict` as 4-space-indented JSON to ``{name}_hierarchy_dict.json``
    in the data directory."""
    out_path = f"/gpfs/scratch/blukacsy/abc_atlas/data/{name}_hierarchy_dict.json"
    with open(out_path, "w") as handle:
        json.dump(hierarchy_dict, handle, indent=4)

In [66]:
save_hierarchy_dict("mouse_whole_brain", mouse_whole_brain_hierarchy_dict)

In [67]:
import re
def create_name(input):
    """Turn a label into a lowercase, file-name-safe slug.

    Every run of characters other than ASCII letters and digits becomes a single
    underscore; leading and trailing underscores are removed.
    """
    underscored = re.sub(r"[^A-Za-z0-9]+", "_", input)
    trimmed = underscored.strip("_")
    return trimmed.lower()

In [68]:
def get_leaves(tree):
    """Return the keys of ``tree`` that have no children, in depth-first order.

    ``tree`` is a nested dict; a key whose value is empty (falsy) counts as a leaf.
    """
    leaves = []
    for label, subtree in tree.items():
        if not subtree:
            leaves.append(label)
            continue
        leaves += get_leaves(subtree)
    return leaves

In [69]:
def preprocess_node(cell_name, dataset_name, node, anndata, split, seed):
    """Build and save the training set for the classifier at an internal tree node.

    Each direct child of ``node`` is one class. A cell is assigned to the child
    whose subtree contains the cell's finest-level label (``hierarchy[-1]``);
    cells whose label is not under ``node`` are excluded.

    Files written (hard-coded abc_atlas scratch directories, names built from
    ``dataset_name`` and ``create_name(cell_name)``):
      * data/: child -> integer mapping (JSON) and the HVG-restricted AnnData
        subset (all selected cells, with ``cell_names`` / ``cell_integers`` in obs);
      * arrays/: stratified train/validation features and labels, and the
        balanced class weights computed from the training labels.

    Args:
        cell_name: label of the node; used only for file names.
        dataset_name: file-name prefix.
        node: nested dict whose keys are the node's children.
        anndata: AnnData with the ``hierarchy`` columns in ``obs`` and a boolean
            ``highly_variable`` column in ``var``.
        split: validation fraction passed to ``train_test_split``.
        seed: random state passed to ``train_test_split``.
    """
    # child -> finest-level labels below it (a childless child stands for itself)
    leaves_by_child = {}
    for child, subtree in node.items():
        leaves_by_child[child] = get_leaves(subtree) if subtree else [child]
    int_mapping = {child: position for position, child in enumerate(leaves_by_child)}

    cell_name = create_name(cell_name)
    with open(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_int_mapping_{cell_name}.json", "w") as file:
        json.dump(int_mapping, file, indent=4)

    # finest-level label -> child that owns it
    reverse_mapping = {}
    for child, leaf_labels in leaves_by_child.items():
        for leaf_label in leaf_labels:
            reverse_mapping[leaf_label] = child

    finest_level = hierarchy[-1]
    under_node = anndata.obs[finest_level].isin(reverse_mapping)
    anndata_subset = anndata[under_node].copy()
    anndata_subset.obs["cell_names"] = anndata_subset.obs[finest_level].map(reverse_mapping)
    anndata_subset.obs["cell_integers"] = anndata_subset.obs["cell_names"].map(int_mapping)
    hvg_mask = anndata_subset.var['highly_variable']
    anndata_subset_hvg = anndata_subset[:, hvg_mask].copy()

    # Dense feature matrix for scikit-learn / NumPy export.
    X = anndata_subset_hvg.X
    if scipy.sparse.issparse(X):
        X = X.toarray()

    y = anndata_subset_hvg.obs["cell_integers"].values

    train_features, val_features, train_labels, val_labels = train_test_split(
        X, y, test_size=split, random_state=seed, stratify=y
    )
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_train_features_hvg_{cell_name}.npy', np.array(train_features))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_val_features_hvg_{cell_name}.npy', np.array(val_features))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_train_labels_hvg_{cell_name}.npy', np.array(train_labels))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_val_labels_hvg_{cell_name}.npy', np.array(val_labels))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_weights_hvg_{cell_name}.npy', np.array(weights))

    anndata_subset_hvg.write(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_train_anndata_hvg_{cell_name}.h5ad")

In [70]:
def preprocess_leaf(cell_name, dataset_name, leaf, anndata, split, seed):
    """Build and save the training set for a node whose children are all leaves.

    Each label in ``leaf`` is one class, numbered by its position. Only cells
    whose finest-level label (``hierarchy[-1]``) is in ``leaf`` are kept.

    Files written (hard-coded abc_atlas scratch directories, names built from
    ``dataset_name`` and ``create_name(cell_name)``):
      * data/: label -> integer mapping (JSON) and the HVG-restricted AnnData
        subset (all selected cells, with ``cell_integers`` in obs);
      * arrays/: stratified train/validation features and labels, and the
        balanced class weights computed from the training labels.

    Args:
        cell_name: label of the parent node; used only for file names.
        dataset_name: file-name prefix.
        leaf: sequence of finest-level labels to classify between.
        anndata: AnnData with the ``hierarchy`` columns in ``obs`` and a boolean
            ``highly_variable`` column in ``var``.
        split: validation fraction passed to ``train_test_split``.
        seed: random state passed to ``train_test_split``.
    """
    int_mapping = {}
    for position, label in enumerate(leaf):
        int_mapping[label] = position

    cell_name = create_name(cell_name)
    with open(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_int_mapping_{cell_name}.json", "w") as file:
        json.dump(int_mapping, file, indent=4)

    finest_level = hierarchy[-1]
    is_member = anndata.obs[finest_level].isin(leaf)
    anndata_subset = anndata[is_member].copy()
    anndata_subset.obs["cell_integers"] = anndata_subset.obs[finest_level].map(int_mapping)
    hvg_mask = anndata_subset.var['highly_variable']
    anndata_subset_hvg = anndata_subset[:, hvg_mask].copy()

    # Dense feature matrix for scikit-learn / NumPy export.
    X = anndata_subset_hvg.X
    if scipy.sparse.issparse(X):
        X = X.toarray()

    y = anndata_subset_hvg.obs["cell_integers"].values

    train_features, val_features, train_labels, val_labels = train_test_split(
        X, y, test_size=split, random_state=seed, stratify=y
    )
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_train_features_hvg_{cell_name}.npy', np.array(train_features))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_val_features_hvg_{cell_name}.npy', np.array(val_features))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_train_labels_hvg_{cell_name}.npy', np.array(train_labels))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_val_labels_hvg_{cell_name}.npy', np.array(val_labels))
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_weights_hvg_{cell_name}.npy', np.array(weights))

    anndata_subset_hvg.write(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_train_anndata_hvg_{cell_name}.h5ad")

In [71]:
def hierarchical_classification(cell_name, dataset_name, dict, anndata, split, seed):
    """Recursively write one training set per tree node that has children.

    * No children: nothing is written.
    * All children are leaves: ``preprocess_leaf`` is called for this node and
      recursion stops.
    * Otherwise: ``preprocess_node`` is called for this node, then every child
      is processed in turn (children without children are no-ops).

    Args:
        cell_name: label of the current node (used for output file names).
        dataset_name: file-name prefix forwarded to the preprocess helpers.
        dict: nested dict subtree rooted at the current node (note: this
            parameter shadows the built-in ``dict`` inside the function).
        anndata, split, seed: forwarded unchanged to the preprocess helpers.
    """
    children = list(dict)
    if len(children) == 0:
        return

    has_grandchildren = any(dict[child] for child in children)
    if not has_grandchildren:
        preprocess_leaf(cell_name, dataset_name, children, anndata, split, seed)
        return

    preprocess_node(cell_name, dataset_name, dict, anndata, split, seed)
    for child in children:
        hierarchical_classification(child, dataset_name, dict[child], anndata, split, seed)

In [72]:
# Random state and validation fraction forwarded to train_test_split by the preprocess helpers.
seed = 6296
split = 0.2
# Walk the whole taxonomy tree; the root's output files are named with "class".
hierarchical_classification("class", "mouse_whole_brain", mouse_whole_brain_hierarchy_dict, mouse_whole_brain_train_anndata, split, seed)
print("done")

done


In [73]:
def flat_classification(dataset_name, anndata, split, seed, level):
    """Build and save the training set for a single flat classifier over ``level``.

    Every distinct value of ``anndata.obs[level]`` is one class, numbered in
    sorted order. All cells are used, restricted to highly variable genes.

    Files written (hard-coded abc_atlas scratch directories, names built from
    ``dataset_name`` and ``level``):
      * data/: label -> integer mapping (JSON) and the HVG-restricted AnnData
        (with ``cell_integers`` in obs);
      * arrays/: stratified train/validation features and labels, and the
        balanced class weights computed from the training labels.

    The ``cell_integers`` column is added to the HVG copy only, so the caller's
    AnnData is left untouched and repeated calls for different levels do not
    overwrite each other's column on the shared object.

    Args:
        dataset_name: file-name prefix.
        anndata: AnnData with ``level`` in ``obs`` and a boolean
            ``highly_variable`` column in ``var``.
        split: validation fraction passed to ``train_test_split``.
        seed: random state passed to ``train_test_split``.
        level: name of the ``obs`` column to classify.
    """
    finest_level = level # hierarchy[-1]
    cell_types = sorted(anndata.obs[finest_level].unique())

    int_mapping = {groups: i for i, groups in enumerate(cell_types)}
    with open(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_{level}_int_mapping_flat.json", "w") as file:
        json.dump(int_mapping, file, indent=4)

    # Label the copy, not the argument, to avoid mutating the caller's object.
    anndata_hvg = anndata[:, anndata.var['highly_variable']].copy()
    anndata_hvg.obs["cell_integers"] = anndata_hvg.obs[finest_level].map(int_mapping)

    if scipy.sparse.issparse(anndata_hvg.X):
        X = anndata_hvg.X.toarray()
    else:
        X = anndata_hvg.X

    y = anndata_hvg.obs["cell_integers"].values

    train_features, val_features, train_labels, val_labels = train_test_split(X, y, test_size=split, random_state=seed, stratify=y)
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

    # asarray converts non-ndarray results without duplicating arrays that are
    # already ndarrays (the dense feature matrices can be very large).
    train_features = np.asarray(train_features)
    val_features = np.asarray(val_features)
    train_labels = np.asarray(train_labels)
    val_labels = np.asarray(val_labels)
    weights = np.asarray(weights)

    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_train_features_hvg_flat.npy', train_features)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_val_features_hvg_flat.npy', val_features)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_train_labels_hvg_flat.npy', train_labels)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_val_labels_hvg_flat.npy', val_labels)
    np.save(f'/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_weights_hvg_flat.npy', weights)

    anndata_hvg.write(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_{level}_train_anndata_hvg_flat.h5ad")

In [74]:
# Same seed / validation fraction as the hierarchical run.
seed = 6296
split = 0.2
# One flat dataset per obs column; each call writes its own `{level}`-tagged files.
flat_classification("mouse_whole_brain", mouse_whole_brain_train_anndata, split, seed, "neurotransmitter")
flat_classification("mouse_whole_brain", mouse_whole_brain_train_anndata, split, seed, "class")
flat_classification("mouse_whole_brain", mouse_whole_brain_train_anndata, split, seed, "subclass")
flat_classification("mouse_whole_brain", mouse_whole_brain_train_anndata, split, seed, "supertype")
flat_classification("mouse_whole_brain", mouse_whole_brain_train_anndata, split, seed, "cluster")
print("boo")

boo


In [77]:
# look at data to confirm
# Reload the files written for one hierarchical node and print their contents/shapes.

cell_name = create_name("18_th_glut")
dataset_name = "mouse_whole_brain"

# NOTE(review): this rebinds the global name `anndata`, which the first cell
# imported as a module (`import anndata`); the module is no longer reachable
# under that name in this kernel after this line.
anndata = sc.read_h5ad(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_train_anndata_hvg_{cell_name}.h5ad")

print(anndata)
print()
print(anndata.obs[hierarchy[-1]].value_counts())
print(anndata.obs[hierarchy[-1]].unique())
print()
# `cell_names` is only written by preprocess_node (not preprocess_leaf).
print(anndata.obs["cell_names"].value_counts())
print(anndata.obs["cell_names"].unique())
print()
print(anndata.obs["cell_integers"].value_counts())
print(anndata.obs["cell_integers"].unique())
print()

train_features = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_train_features_hvg_{cell_name}.npy")
val_features = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_val_features_hvg_{cell_name}.npy")
train_labels = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_train_labels_hvg_{cell_name}.npy")
val_labels = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_val_labels_hvg_{cell_name}.npy")
weights = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_weights_hvg_{cell_name}.npy")

print('train features shape:', train_features.shape)
print('val features shape:', val_features.shape)
print('train labels shape:', train_labels.shape)
print('val labels shape:', val_labels.shape)
print('weights shape:', weights.shape)
print()

# Position in `weights` -> weight, keyed 0..n-1.
class_weights = dict(enumerate(weights))
print(class_weights)

AnnData object with n_obs × n_vars = 36 × 3500
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label', 'cell_names', 'cell_integers'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

cluster
2651 TH Prkcd Grin2c Glut_3    22
2616 AV Col27a1 Glut_1         14
Name: count, dtype: int64
['2651 TH Prkcd Grin2c Glut_3', '2616 AV Col27a1 Glut_1']
Categories (2, object): ['2616 AV Col27a1 Glut_1', '2651 TH Prkcd Grin2c Glut_3']

cell_names
151 TH Prkcd Grin2c Glut    22
148 AV Col27a1 Glut         14
Name: count, dtype: int64
['151 TH Prkcd Grin2c Glut', '148 AV Col27a1 Glut']
Categories (2, object): ['148 AV Col27a1 Glut', '151 TH Prkcd Grin2c Glut']

cell_integers
0    22
1    14
Name: count, dtype: int64
[0, 1]
Categories (2, int64): [1, 0]

train features shape: (28, 3500)
val features shape

In [39]:
# Reload the files written by flat_classification for one level and print their contents/shapes.
# NOTE(review): relies on `dataset_name` assigned in an earlier cell, and rebinds
# the global name `anndata` (imported as a module in the first cell).
level = "class"
anndata = sc.read_h5ad(f"/gpfs/scratch/blukacsy/abc_atlas/data/{dataset_name}_{level}_train_anndata_hvg_flat.h5ad")

print(anndata)
print()
print(anndata.obs[level].value_counts())
print(anndata.obs[level].unique())
print()
print(anndata.obs["cell_integers"].value_counts())
print(anndata.obs["cell_integers"].unique())
print()

train_features = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_train_features_hvg_flat.npy")
val_features = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_val_features_hvg_flat.npy")
train_labels = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_train_labels_hvg_flat.npy")
val_labels = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_val_labels_hvg_flat.npy")
weights = np.load(f"/gpfs/scratch/blukacsy/abc_atlas/arrays/{dataset_name}_{level}_weights_hvg_flat.npy")

print('train features shape:', train_features.shape)
print('val features shape:', val_features.shape)
print('train labels shape:', train_labels.shape)
print('val labels shape:', val_labels.shape)
print('weights shape:', weights.shape)
print()

# Position in `weights` -> weight, keyed 0..n-1.
class_weights = dict(enumerate(weights))
print(class_weights)

AnnData object with n_obs × n_vars = 590487 × 3500
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label', 'cell_integers'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

class
01 IT-ET Glut        161547
30 Astro-Epen         72176
31 OPC-Oligo          64081
09 CNU-LGE GABA       62862
02 NP-CT-L6b Glut     52318
11 CNU-HYa GABA       32076
33 Vascular           32037
34 Immune             30176
10 LSX GABA           23341
06 CTX-CGE GABA       19259
13 CNU-HYa Glut       16174
07 CTX-MGE GABA       15468
08 CNU-MGE GABA        4178
05 OB-IMN GABA         3801
03 OB-CR Glut           487
04 DG-IMN Glut          263
14 HY Glut              142
18 TH Glut               36
32 OEC                   31
25 Pineal Glut           15
12 HY GABA               14
16 HY MM Glut             5
Name

In [86]:
# # Flatten all cluster leaves from the tree (across all classes)
# def all_leaves(tree):
#     out = []
#     for k, v in tree.items():
#         if v:
#             out.extend(all_leaves(v))
#         else:
#             out.append(k)
#     return out

# leaves = set(all_leaves(mouse_whole_brain_hierarchy_dict))  # expected cluster names

# df = mouse_whole_brain_train_anndata.obs[['class','subclass','supertype','cluster']].copy()

# # rows that FAIL membership in your leaf set (these are currently being dropped)
# bad_mask = ~df['cluster'].isin(leaves)
# print("Unmapped rows:", int(bad_mask.sum()))
# print(df.loc[bad_mask, 'class'].value_counts().head(10))
# print(df.loc[bad_mask].head(10))

Unmapped rows: 62
class
01 IT-ET Glut        62
02 NP-CT-L6b Glut     0
03 OB-CR Glut         0
04 DG-IMN Glut        0
05 OB-IMN GABA        0
06 CTX-CGE GABA       0
07 CTX-MGE GABA       0
08 CNU-MGE GABA       0
09 CNU-LGE GABA       0
10 LSX GABA           0
Name: count, dtype: int64
                                    class                   subclass  \
cell_label                                                             
ATGCATGTCTATCACT-390_D02    01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
TGGAACTTCGTTCGCT-075.2_A01  01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
GAAGAATCAGAGTTCT-117_C01    01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
CTGTGGGCATGGGCAA-117_D01    01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
TGTAGACTCCACAAGT-382_B03    01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
CTGCAGGAGAATGTTG-391_C03    01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
GAGATGGAGTACTCGT-381_B02    01 IT-ET Glut  001 CLA-EPd-CTX Car3 Glut   
CCCTCTCAGGAGAGTA-117_A01    01 IT-ET Glut  001 CLA-EPd-CTX Car

In [87]:
# def all_leaves(tree):
#     out = []
#     for k, v in tree.items():
#         out.extend(all_leaves(v) if v else [k])
#     return out

# leaves = set(all_leaves(mouse_whole_brain_hierarchy_dict))
# target = "0001 CLA-EPd-CTX Car3 Glut_1"
# print(target in leaves)  # -> False right now

False


In [75]:
import argparse
from pathlib import Path
import numpy as np

# Names collected by print_trainable; reset each time this cell runs.
total = []

# Directory holding the saved .npy arrays that print_trainable scans.
directory = Path("/gpfs/scratch/blukacsy/abc_atlas/arrays")
def print_trainable(dataset):
    """Print and record every hierarchical node of ``dataset`` with at least two classes.

    Scans the module-level ``directory`` for ``{dataset}_train_labels_hvg*.npy``
    files, skips those whose labels contain fewer than two distinct values, and
    for the rest prints ``"<dataset> <node name>"`` and appends the node name to
    the module-level ``total`` list.
    """
    for label_file in directory.glob(f"{dataset}_train_labels_hvg*.npy"):
        labels = np.load(label_file)
        n_classes = len(np.unique(labels))
        if n_classes >= 2:
            name = label_file.stem.removeprefix(f"{dataset}_train_labels_hvg_")
            print(dataset + " " + name)
            total.append(name)

# ap = argparse.ArgumentParser()
# ap.add_argument("--dataset", required=True)
# args = ap.parse_args()
# List the trainable hierarchical nodes, then report how many were found.
print_trainable("mouse_whole_brain")
print()
print("num", len(total))

mouse_whole_brain 0021_l5_it_ctx_glut_4
mouse_whole_brain 0137_dg_glut_2
mouse_whole_brain 0122_l5_np_ctx_glut_1
mouse_whole_brain 0369_cea_bst_six3_cyp26b1_gaba_3
mouse_whole_brain 0510_coaa_paa_mea_barhl2_glut_2
mouse_whole_brain 056_sst_chodl_gaba
mouse_whole_brain 0218_sst_gaba_5
mouse_whole_brain 0382_acb_bst_fs_d1_gaba_2
mouse_whole_brain 0346_mea_bst_sox6_gaba_10
mouse_whole_brain 0015_l6_it_ctx_glut_3
mouse_whole_brain 055_str_lhx8_gaba
mouse_whole_brain 330_vlmc_nn
mouse_whole_brain 0151_ob_in_frmd7_gaba_2
mouse_whole_brain 0277_str_d2_gaba_4
mouse_whole_brain 0177_vip_gaba_5
mouse_whole_brain 0252_ndb_si_ma_strv_lhx8_gaba_9
mouse_whole_brain 0302_lsx_otx2_gaba_2
mouse_whole_brain 0362_mea_bst_lhx6_nfib_gaba_6
mouse_whole_brain 0358_mea_bst_lhx6_nfib_gaba_2
mouse_whole_brain 0322_lsx_prdm12_slit2_gaba_6
mouse_whole_brain 0297_ndb_si_ant_prdm12_gaba_3
mouse_whole_brain 0235_str_prox1_lhx6_gaba_3
mouse_whole_brain 1188_vlmc_nn_2
mouse_whole_brain 0368_cea_bst_six3_cyp26b1_gaba_2

In [76]:
import argparse
from pathlib import Path
import numpy as np

# Directory holding the saved .npy arrays that print_trainable scans.
directory = Path("/gpfs/scratch/blukacsy/abc_atlas/arrays")
def print_trainable(dataset, level):
    """Print ``"<dataset>_<level> flat"`` for each flat label file with at least two classes.

    Scans the module-level ``directory`` for
    ``{dataset}_{level}_train_labels_hvg*.npy`` and prints one line per matching
    file whose labels contain two or more distinct values.

    Note: this definition replaces the earlier one-argument ``print_trainable``
    once the cell has run.
    """
    for label_file in directory.glob(f"{dataset}_{level}_train_labels_hvg*.npy"):
        labels = np.load(label_file)
        n_classes = len(np.unique(labels))
        if n_classes >= 2:
            print(dataset + "_" + level + " flat")

# ap = argparse.ArgumentParser()
# ap.add_argument("--dataset", required=True)
# args = ap.parse_args()
# Check each flat level's saved training labels.
print_trainable("mouse_whole_brain", "neurotransmitter")
print_trainable("mouse_whole_brain", "class")
print_trainable("mouse_whole_brain", "subclass")
print_trainable("mouse_whole_brain", "supertype")
print_trainable("mouse_whole_brain", "cluster")

mouse_whole_brain_neurotransmitter flat
mouse_whole_brain_class flat
mouse_whole_brain_subclass flat
mouse_whole_brain_supertype flat
mouse_whole_brain_cluster flat


In [77]:
# Distinct `class` labels present in the training AnnData.
mouse_whole_brain_train_anndata.obs["class"].unique()

['09 CNU-LGE GABA', '11 CNU-HYa GABA', '30 Astro-Epen', '31 OPC-Oligo', '01 IT-ET Glut', ..., '18 TH Glut', '16 HY MM Glut', '12 HY GABA', '25 Pineal Glut', '32 OEC']
Length: 22
Categories (22, object): ['01 IT-ET Glut', '02 NP-CT-L6b Glut', '03 OB-CR Glut', '04 DG-IMN Glut', ..., '31 OPC-Oligo', '32 OEC', '33 Vascular', '34 Immune']

In [78]:
# Print "<dataset> <slug>" for every class label, using the same slug that names the per-node files.
for x in mouse_whole_brain_train_anndata.obs["class"].unique():
    print("mouse_whole_brain", create_name(x))

mouse_whole_brain 09_cnu_lge_gaba
mouse_whole_brain 11_cnu_hya_gaba
mouse_whole_brain 30_astro_epen
mouse_whole_brain 31_opc_oligo
mouse_whole_brain 01_it_et_glut
mouse_whole_brain 34_immune
mouse_whole_brain 10_lsx_gaba
mouse_whole_brain 33_vascular
mouse_whole_brain 13_cnu_hya_glut
mouse_whole_brain 02_np_ct_l6b_glut
mouse_whole_brain 06_ctx_cge_gaba
mouse_whole_brain 08_cnu_mge_gaba
mouse_whole_brain 07_ctx_mge_gaba
mouse_whole_brain 05_ob_imn_gaba
mouse_whole_brain 03_ob_cr_glut
mouse_whole_brain 04_dg_imn_glut
mouse_whole_brain 14_hy_glut
mouse_whole_brain 18_th_glut
mouse_whole_brain 16_hy_mm_glut
mouse_whole_brain 12_hy_gaba
mouse_whole_brain 25_pineal_glut
mouse_whole_brain 32_oec


In [3]:
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import scipy.sparse
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import keras as ks
import sklearn.metrics as metrics
import pandas as pd
import re
import json
import os

2025-10-01 09:27:16.580559: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-01 09:27:16.787441: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759325236.830497 2010299 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759325236.840428 2010299 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1759325236.922244 2010299 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [4]:
!pip install openpyxl==3.1.5



In [10]:
def flat_prediction(level):
    """Score the flat classifier for ``level`` on the saved test AnnData.

    Loads the test AnnData and the ``{level}`` flat Keras model from the
    hard-coded abc_atlas scratch paths, predicts one label per cell (argmax of
    the model output, translated through the saved integer mapping), then writes:
      * results/mouse_whole_brain_{level}_flat_classification.xlsx with sheets
        ``overall_metrics``, ``per_cell_accuracy``, ``classification_report``
        and ``predictions``;
      * data/mouse_whole_brain_{level}_test_anndata_flat_predictions.h5ad, the
        HVG-restricted test AnnData with a ``predicted_cell_type`` obs column.

    Args:
        level: name of the ``obs`` column holding the true labels; also selects
            the model and mapping files.
    """
    test_anndata = sc.read_h5ad(f"/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_test_anndata.h5ad")
    hvg_mask = test_anndata.var['highly_variable']
    test_anndata_hvg = test_anndata[:, hvg_mask].copy()

    model_path = f"/gpfs/scratch/blukacsy/abc_atlas/models/mouse_whole_brain_{level}_hvg_flat_jax_v1.keras"
    model = ks.models.load_model(model_path, custom_objects={'LeakyReLU': ks.layers.LeakyReLU}, compile=False)

    # Dense feature matrix for the model.
    X = test_anndata_hvg.X
    if scipy.sparse.issparse(X):
        X = X.toarray()

    scores = model.predict(X)
    best_index_per_cell = np.argmax(scores, axis=1)

    path = f"/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_{level}_int_mapping_flat.json"
    with open(path) as handle:
        label_to_index = json.load(handle)

    # Invert the saved label -> integer mapping to decode the argmax indices.
    index_to_label = {index: label for label, index in label_to_index.items()}
    predictions = [index_to_label[index] for index in best_index_per_cell]

    test_anndata_hvg.obs['predicted_cell_type'] = predictions
    true = test_anndata_hvg.obs[level].values

    is_correct = pd.Series(predictions == true)
    per_class_accuracy = is_correct.groupby(pd.Series(true)).mean().to_dict()

    overall = {
        'accuracy': metrics.accuracy_score(true, predictions),
        'balanced_accuracy': metrics.balanced_accuracy_score(true, predictions),
        'precision': metrics.precision_score(true, predictions, average='macro', zero_division=0),
        'recall': metrics.recall_score(true, predictions, average='macro', zero_division=0),
        'f1_score': metrics.f1_score(true, predictions, average='macro', zero_division=0),
        'average_precision': metrics.average_precision_score(true, scores, average='macro')
    }

    class_report = metrics.classification_report(true, predictions, digits=2, zero_division=0, output_dict=True)
    class_report_df = pd.DataFrame(class_report).transpose().rename_axis("label").reset_index()

    per_cell_df = pd.DataFrame({
        "cell_id": test_anndata_hvg.obs_names.to_numpy(),
        "true_cell": true,
        "predicted_cell": predictions,
    })

    overall_df = pd.Series(overall, name="value").to_frame()
    accuracy_df = pd.Series(per_class_accuracy, name="accuracy").to_frame()

    xlsx_path = f"/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_{level}_flat_classification.xlsx"
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        overall_df.to_excel(writer, sheet_name="overall_metrics")
        accuracy_df.to_excel(writer, sheet_name="per_cell_accuracy")
        class_report_df.to_excel(writer, sheet_name="classification_report", index=False)
        per_cell_df.to_excel(writer, sheet_name="predictions", index=False)

    print(f"Excel written to `{xlsx_path}`")

    test_anndata_hvg.write(f"/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_{level}_test_anndata_flat_predictions.h5ad")

In [11]:
# Evaluate the flat model of every level in turn; each call reloads the test AnnData.
flat_prediction("neurotransmitter")
print("neuro done")
flat_prediction("class")
print("class done")
flat_prediction("subclass")
print("sub done")
flat_prediction("supertype")
print("super done")
flat_prediction("cluster")
print("clust done")

2025-10-01 09:47:27.311713: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


[1m4614/4614[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 6ms/step


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_neurotransmitter_flat_classification.xlsx`
neuro done
[1m4614/4614[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 3ms/step


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_class_flat_classification.xlsx`
class done
[1m4614/4614[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 6ms/step


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_subclass_flat_classification.xlsx`
sub done
[1m4614/4614[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 6ms/step


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_supertype_flat_classification.xlsx`
super done
[1m4614/4614[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 5ms/step


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_cluster_flat_classification.xlsx`
clust done


In [41]:
def create_name(input):
    """Turn a label into a lowercase, file-name-safe slug.

    Every run of characters other than ASCII letters and digits becomes a single
    underscore; leading and trailing underscores are removed.

    Note: an identical ``create_name`` is defined earlier in this notebook; the
    two must stay in sync because saved file names are built from the result.
    """
    underscored = re.sub(r"[^A-Za-z0-9]+", "_", input)
    trimmed = underscored.strip("_")
    return trimmed.lower()

In [42]:
class Node:
    """One node of the hierarchical classifier tree.

    Attributes:
        name: label returned when prediction ends at this node.
        model: object with ``predict(batch, verbose=0)`` returning one row of
            scores per input row, or ``None`` when no model is available.
        index_to_child: maps the argmax position of the model output to a key
            of ``children``.
        children: dict of child label -> ``Node``; empty for a leaf.
    """

    def __init__(self, name, model, index_to_child, children):
        self.name = name
        self.model = model
        self.index_to_child = index_to_child
        self.children = children

    def predict(self, gene):
        """Return the leaf label predicted for one feature vector ``gene``.

        A leaf returns its own name. A node with a single child, or without a
        model, descends into its first child without scoring (so a node that
        has several children but no model always picks the first one).
        Otherwise the model scores ``gene`` as a batch of one and the child at
        the argmax position is followed.
        """
        if not self.children:
            return self.name

        # Identity check: never invoke a model's own __eq__ just to test for absence.
        if self.model is None or len(self.children) == 1:
            first_child = next(iter(self.children.values()))
            return first_child.predict(gene)

        logits = self.model.predict(np.array([gene]), verbose=0)
        max_index = np.argmax(logits, axis=1)[0]
        child_name = self.index_to_child[max_index]

        return self.children[child_name].predict(gene)

In [43]:
def create_tree(hierarchy_dict):
    """Build the tree of ``Node`` objects for hierarchical prediction.

    For every label that has children, the Keras model and the child -> integer
    mapping are looked up by ``create_name(label)`` under the hard-coded
    abc_atlas scratch paths. A missing model leaves ``Node.model`` as ``None``;
    a missing mapping falls back to numbering the children in dict order.

    Leaves are created without touching the model directory: ``Node.predict``
    returns a leaf's name without consulting a model, so loading one for every
    leaf (which happens whenever a leaf's slug matches a saved model, e.g. in a
    truncated tree) only cost time and memory.

    Args:
        hierarchy_dict: nested dict tree; the root node is named ``"class"``.

    Returns:
        The root ``Node``.
    """
    def recurse(node_name, subtree):
        if not subtree:
            return Node(node_name, None, None, {})

        name = create_name(node_name)
        model_path = f"/gpfs/scratch/blukacsy/abc_atlas/models/mouse_whole_brain_hvg_{name}_jax_v1.keras"
        model = None

        if os.path.exists(model_path):
            model = ks.models.load_model(model_path, custom_objects={'LeakyReLU': ks.layers.LeakyReLU}, compile=False)

        children = list(subtree)
        int_mapping = f"/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_int_mapping_{name}.json"

        if os.path.exists(int_mapping):
            with open(int_mapping) as file:
                child_to_index = json.load(file)

            # JSON stores child -> index; prediction needs index -> child.
            index_to_child = {int(value): key for key, value in child_to_index.items()}
        else:
            index_to_child = {i: child for i, child in enumerate(children)}

        child_nodes = {child: recurse(child, subtree[child]) for child in children}
        return Node(node_name, model, index_to_child, child_nodes)

    return recurse("class", hierarchy_dict)

In [49]:
def hierarchical_predictions(level, hierarchy_dict):
    """Score the hierarchical classifier tree against ``level`` on the saved test AnnData.

    Builds the ``Node`` tree from ``hierarchy_dict`` with ``create_tree``, calls
    ``predict`` on the root once per cell (row of the HVG feature matrix), then
    writes:
      * results/mouse_whole_brain_{level}_hierarchical_classification.xlsx with
        sheets ``overall_metrics``, ``per_cell_accuracy``,
        ``classification_report`` and ``predictions``;
      * data/mouse_whole_brain_{level}_test_anndata_hierarchy_predictions.h5ad,
        the HVG-restricted test AnnData with a ``predicted_cell_type`` obs column.

    Args:
        level: name of the ``obs`` column holding the true labels.
        hierarchy_dict: nested dict tree whose leaves are the labels to predict.
    """
    test_anndata = sc.read_h5ad(f"/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_test_anndata.h5ad")
    hvg_mask = test_anndata.var['highly_variable']
    test_anndata_hvg = test_anndata[:, hvg_mask].copy()

    root = create_tree(hierarchy_dict)

    # Dense feature matrix, iterated row by row below.
    X = test_anndata_hvg.X
    if scipy.sparse.issparse(X):
        X = X.toarray()

    predictions = [root.predict(cell_features) for cell_features in X]

    test_anndata_hvg.obs['predicted_cell_type'] = predictions
    true = test_anndata_hvg.obs[level].values

    is_correct = pd.Series(predictions == true)
    per_class_accuracy = is_correct.groupby(pd.Series(true)).mean().to_dict()

    overall = {
        'accuracy': metrics.accuracy_score(true, predictions),
        'balanced_accuracy': metrics.balanced_accuracy_score(true, predictions),
        'precision': metrics.precision_score(true, predictions, average='macro', zero_division=0),
        'recall': metrics.recall_score(true, predictions, average='macro', zero_division=0),
        'f1_score': metrics.f1_score(true, predictions, average='macro', zero_division=0),
        'AUPRC': 'N/A'
    }

    class_report = metrics.classification_report(true, predictions, digits=2, zero_division=0, output_dict=True)
    class_report_df = pd.DataFrame(class_report).transpose().rename_axis("label").reset_index()

    per_cell_df = pd.DataFrame({
        "cell_id": test_anndata_hvg.obs_names.to_numpy(),
        "true_cell": true,
        "predicted_cell": predictions,
    })

    overall_df = pd.Series(overall, name="value").to_frame()
    accuracy_df = pd.Series(per_class_accuracy, name="accuracy").to_frame()

    xlsx_path = f"/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_{level}_hierarchical_classification.xlsx"
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        overall_df.to_excel(writer, sheet_name="overall_metrics")
        accuracy_df.to_excel(writer, sheet_name="per_cell_accuracy")
        class_report_df.to_excel(writer, sheet_name="classification_report", index=False)
        per_cell_df.to_excel(writer, sheet_name="predictions", index=False)

    print(f"Excel written to `{xlsx_path}`")

    test_anndata_hvg.write(f"/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_{level}_test_anndata_hierarchy_predictions.h5ad")

In [44]:
def load_hierarchy_dict(path):
    """Read the JSON file at ``path`` and return the parsed object."""
    with open(path, "r") as handle:
        tree = json.load(handle)
    return tree

In [45]:
def truncate_hierarchy(tree: dict, max_depth: int, _depth: int = 1) -> dict:
    """Return a copy of ``tree`` cut off below ``max_depth`` levels.

    Keys at depth ``max_depth`` (the top level is depth 1) are kept with empty
    children; shallower levels are copied recursively. A non-dict or empty
    ``tree`` yields ``{}``. ``_depth`` is the current depth used by the recursion.
    """
    if not (isinstance(tree, dict) and tree):
        return {}

    at_limit = _depth == max_depth
    pruned = {}
    for label, subtree in tree.items():
        pruned[label] = {} if at_limit else truncate_hierarchy(subtree, max_depth, _depth + 1)
    return pruned

In [46]:
# Tree depth passed to truncate_hierarchy for each taxonomy level
# (depth 1 keeps only the top-level keys of the hierarchy dict).
level_to_depth = {
    "class": 1,
    "subclass": 2,
    "supertype": 3,
    "cluster": 4,
}

In [47]:
# Reload the taxonomy tree from the JSON file written by save_hierarchy_dict.
mouse_whole_brain_hierarchy_dict = load_hierarchy_dict("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_hierarchy_dict.json")

In [55]:
# Evaluate the hierarchical classifier at each level by cutting the tree at the
# matching depth, so prediction stops at that level's labels.
for level in ["class", "subclass", "supertype", "cluster"]:
    truncated = truncate_hierarchy(mouse_whole_brain_hierarchy_dict, level_to_depth[level])
    hierarchical_predictions(level, truncated)
    print(f"{level} done")

  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_class_hierarchical_classification.xlsx`
class done


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_subclass_hierarchical_classification.xlsx`
subclass done


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_supertype_hierarchical_classification.xlsx`
supertype done


  per_class_accuracy = (pd.Series(predictions == true).groupby(pd.Series(true)).mean().to_dict())


Excel written to `/gpfs/scratch/blukacsy/abc_atlas/results/mouse_whole_brain_cluster_hierarchical_classification.xlsx`
cluster done


In [3]:
# Load the clustered striatum dataset from disk.
striatum_clustered = sc.read_h5ad("/gpfs/scratch/blukacsy/striatum_clustered.h5ad")

In [76]:
# Show the AnnData summary (dimensions and annotation fields) of the striatum dataset.
striatum_clustered

AnnData object with n_obs × n_vars = 42043 × 21579
    obs: 'doublet', 'doublet_score', 'doublet_bool', 'n_genes_by_counts', 'total_counts', 'type', 'sample', 'batch_type', 'batch', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'pct_counts_hb', 'n_genes', 'n_counts', 'size_factors', 'leiden_0.2', 'leiden_0.3', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1.0', 'leiden_1.4', 'leiden_1.7', 'leiden_2.0', 'leiden_2.5'
    var: 'gene_ids', 'feature_types', 'n_cells-0', 'n_cells_by_counts-0', 'mean_counts-0', 'pct_dropout_by_counts-0', 'total_counts-0', 'n_cells-1', 'n_cells_by_counts-1', 'mean_counts-1', 'pct_dropout_by_counts-1', 'total_counts-1', 'n_cells-2', 'n_cells_by_counts-2', 'mean_counts-2', 'pct_dropout_by_counts-2', 'total_counts-2', 'n_cells-3', 'n_cells_by

In [98]:
# Count cells per batch_type value.
striatum_clustered.obs["batch_type"].value_counts()

batch_type
batch1    18544
batch4    13269
batch2    10230
Name: count, dtype: int64

In [27]:
# Keep only the genes flagged in var['highly_variable']; .copy() makes the subset a
# standalone object instead of a view of striatum_clustered.
# NOTE(review): `test` is a scratch name — consider something descriptive.
test = striatum_clustered[:, striatum_clustered.var['highly_variable'] ].copy()

In [28]:
# Show the AnnData summary of the highly-variable-gene subset.
test

AnnData object with n_obs × n_vars = 42043 × 2950
    obs: 'doublet', 'doublet_score', 'doublet_bool', 'n_genes_by_counts', 'total_counts', 'type', 'sample', 'batch_type', 'batch', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'pct_counts_hb', 'n_genes', 'n_counts', 'size_factors', 'leiden_0.2', 'leiden_0.3', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1.0', 'leiden_1.4', 'leiden_1.7', 'leiden_2.0', 'leiden_2.5'
    var: 'gene_ids', 'feature_types', 'n_cells-0', 'n_cells_by_counts-0', 'mean_counts-0', 'pct_dropout_by_counts-0', 'total_counts-0', 'n_cells-1', 'n_cells_by_counts-1', 'mean_counts-1', 'pct_dropout_by_counts-1', 'total_counts-1', 'n_cells-2', 'n_cells_by_counts-2', 'mean_counts-2', 'pct_dropout_by_counts-2', 'total_counts-2', 'n_cells-3', 'n_cells_by_

In [138]:
# Load the striatum dataset with the saved prediction columns.
res = sc.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/results/striatum_clustered_mwb_predictions.h5ad")

In [139]:
# Show the AnnData summary, including the pred_flat_* / pred_hier_* obs columns.
res

AnnData object with n_obs × n_vars = 42043 × 21579
    obs: 'doublet', 'doublet_score', 'doublet_bool', 'n_genes_by_counts', 'total_counts', 'type', 'sample', 'batch_type', 'batch', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'pct_counts_hb', 'n_genes', 'n_counts', 'size_factors', 'leiden_0.2', 'leiden_0.3', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1.0', 'leiden_1.4', 'leiden_1.7', 'leiden_2.0', 'leiden_2.5', 'pred_flat_neurotransmitter', 'pred_flat_class', 'pred_flat_subclass', 'pred_flat_supertype', 'pred_flat_cluster', 'pred_hier_class', 'pred_hier_subclass', 'pred_hier_supertype', 'pred_hier_cluster'
    var: 'gene_ids', 'feature_types', 'n_cells-0', 'n_cells_by_counts-0', 'mean_counts-0', 'pct_dropout_by_counts-0', 'total_counts-0', 'n_cells-1', 'n_cell

In [140]:
# Peek at the expression matrix stored alongside the predictions.
res.X

array([[ 1.1843488e+00, -2.6635351e-02, -1.2202565e-02, ...,
        -1.2682188e-04,  3.6168651e-05,  7.7891731e-01],
       [ 9.6031862e-01,  8.3774239e-01,  3.9538506e-01, ...,
        -1.2682188e-04,  3.6168651e-05,  1.0415727e+00],
       [ 9.0260506e-01, -2.6635351e-02, -1.2202565e-02, ...,
        -1.2682188e-04,  3.6168651e-05,  9.7604543e-01],
       ...,
       [ 7.0522428e-03, -9.9165067e-03, -4.5041265e-03, ...,
         1.6537460e-04, -1.6345401e-04,  1.7385267e+00],
       [ 7.0522428e-03, -9.9165067e-03, -4.5041265e-03, ...,
         1.6537460e-04, -1.6345401e-04,  1.3682370e+00],
       [ 7.0522428e-03, -9.9165067e-03, -4.5041265e-03, ...,
         1.6537460e-04, -1.6345401e-04,  1.6115695e-01]], dtype=float32)

In [148]:
# Smallest value in the prediction input matrix.
res.X.min()

np.float32(-0.7174092)

In [149]:
# Largest value in the prediction input matrix.
res.X.max()

np.float32(7.474733)

In [141]:
# Load the mouse-whole-brain training AnnData.
# NOTE(review): `boo` is a scratch name — consider something descriptive.
boo = sc.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/mouse_whole_brain_train_anndata.h5ad")

In [142]:
# Show the AnnData summary of the training set.
boo

AnnData object with n_obs × n_vars = 590487 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [143]:
# Keep only the genes flagged in var['highly_variable']; .copy() makes the subset a
# standalone object instead of a view of the training set.
mwb_hvg = boo[:, boo.var['highly_variable'] ].copy()

In [144]:
# Show the AnnData summary of the highly-variable-gene training subset.
mwb_hvg

AnnData object with n_obs × n_vars = 590487 × 3500
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [146]:
# Peek at the training matrix.
# NOTE(review): .toarray() densifies the entire matrix just to display it.
mwb_hvg.X.toarray()

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [150]:
# Smallest value in the training matrix. The sparse .min() accounts for the implicit
# zeros, so the matrix does not have to be densified first.
mwb_hvg.X.min()

np.float32(0.0)

In [151]:
# Largest value in the training matrix. The sparse .max() accounts for the implicit
# zeros, so the matrix does not have to be densified first.
mwb_hvg.X.max()

np.float32(18.62102)

In [90]:
# Number of distinct neurotransmitter labels in the training subset.
mwb_hvg.obs["neurotransmitter"].nunique()

6

In [91]:
# Number of distinct class labels in the training subset.
mwb_hvg.obs["class"].nunique()

22

In [92]:
# Number of distinct subclass labels in the training subset.
mwb_hvg.obs["subclass"].nunique()

117

In [93]:
# Number of distinct supertype labels in the training subset.
mwb_hvg.obs["supertype"].nunique()

412

In [94]:
# Number of distinct cluster labels in the training subset.
mwb_hvg.obs["cluster"].nunique()

1320

In [95]:
# Collect obs columns that look like categorical annotations: object or categorical
# dtype with more than one and at most 200 distinct non-null values.
cands = []
for col in striatum_clustered.obs.columns:
    s = striatum_clustered.obs[col]
    # The isinstance check replaces pd.api.types.is_categorical_dtype, which is
    # deprecated in recent pandas releases and emitted a warning here.
    if pd.api.types.is_object_dtype(s) or isinstance(s.dtype, pd.CategoricalDtype):
        nunq = s.nunique(dropna=True)
        if 1 < nunq <= 200:
            cands.append((col, nunq))
# Display the candidates ordered by their number of distinct values.
sorted(cands, key=lambda x: x[1])

  if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s):


[('type', 2),
 ('batch_type', 3),
 ('sample', 6),
 ('batch', 6),
 ('leiden_0.2', 14),
 ('leiden_0.3', 15),
 ('leiden_0.4', 15),
 ('leiden_0.6', 19),
 ('leiden_0.8', 20),
 ('leiden_1.0', 24),
 ('leiden_1.4', 28),
 ('leiden_1.7', 30),
 ('leiden_2.0', 33),
 ('leiden_2.5', 39)]

In [11]:
# Count training cells per class label.
boo.obs["class"].value_counts()

class
01 IT-ET Glut        161547
30 Astro-Epen         72176
31 OPC-Oligo          64081
09 CNU-LGE GABA       62862
02 NP-CT-L6b Glut     52318
11 CNU-HYa GABA       32076
33 Vascular           32037
34 Immune             30176
10 LSX GABA           23341
06 CTX-CGE GABA       19259
13 CNU-HYa Glut       16174
07 CTX-MGE GABA       15468
08 CNU-MGE GABA        4178
05 OB-IMN GABA         3801
03 OB-CR Glut           487
04 DG-IMN Glut          263
14 HY Glut              142
18 TH Glut               36
32 OEC                   31
25 Pineal Glut           15
12 HY GABA               14
16 HY MM Glut             5
Name: count, dtype: int64

In [74]:
import os, json, re, gc
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import keras as ks

# Root of the ABC-atlas working area and the sub-directories used below.
BASE = "/gpfs/scratch/blukacsy/abc_atlas"
DATA_DIR = f"{BASE}/data"
MODELS_DIR = f"{BASE}/models"
RESULTS_DIR = f"{BASE}/results"

# File-name prefix of the source (training) artefacts and the dataset to annotate.
SOURCE_BASE = "mouse_whole_brain"
TARGET_PATH = "/gpfs/scratch/blukacsy/striatum_clustered.h5ad"

# Levels predicted by the flat models and by the hierarchical model, respectively.
LEVELS_FLAT = ["neurotransmitter", "class", "subclass", "supertype", "cluster"]
LEVELS_HIER = ["class", "subclass", "supertype", "cluster"]

def _flat_intmap_path(level: str) -> str:
    return f"{DATA_DIR}/{SOURCE_BASE}_{level}_int_mapping_flat.json"

def _flat_hvg_path(level: str) -> str:
    return f"{DATA_DIR}/{SOURCE_BASE}_{level}_train_anndata_hvg_flat.h5ad"

def _flat_model_path(level: str) -> str:
    return f"{MODELS_DIR}/{SOURCE_BASE}_{level}_hvg_flat_jax_v1.keras"

def _hierarchy_json_path() -> str:
    return f"{DATA_DIR}/{SOURCE_BASE}_hierarchy_dict.json"

def _load_hvgs_for_level(level: str) -> pd.Index:
    p = _flat_hvg_path(level)
    if not os.path.exists(p):
        raise FileNotFoundError(f"Missing HVG training file for level '{level}': {p}")
    ad = sc.read_h5ad(p)
    return ad.var_names.copy()

def _align_to_genes(ad_tgt: sc.AnnData, source_genes: pd.Index) -> np.ndarray:
    """
    Build a dense float32 matrix of the target cells with columns in source_genes order.

    Target genes are matched against ``source_genes`` through one of three identifier
    sets: ``var_names``, ``var['gene_ids']``, or ``var['gene_ids']`` with a trailing
    ``.<version>`` suffix removed. The sets are tried in a preferred order (gene IDs
    first when the source looks like Ensembl IDs, ``var_names`` first otherwise) and the
    first one with at least one hit is used. The chosen mapping is logged.

    Parameters
    ----------
    ad_tgt : AnnData whose ``X`` (dense or scipy-sparse) supplies the values.
    source_genes : gene identifiers, in the column order to produce.

    Returns
    -------
    np.ndarray of shape ``(ad_tgt.n_obs, len(source_genes))``. Columns of source genes
    that have no match in the target are left as zeros.
    """
    def is_ensembl_like(names: pd.Index) -> bool:
        # True when at least 80% of the first (at most) 5000 names look like ENSMUSG IDs.
        pat = re.compile(r"^ENSMUSG\d+(?:\.\d+)?$")
        n = len(names)
        if n == 0:
            return False
        k = sum(1 for g in names[:min(5000, n)] if pat.match(str(g)))
        return (k / min(5000, n)) >= 0.8

    def strip_version(idx: pd.Index) -> pd.Index:
        # Drop a trailing ".<digits>" version suffix from every identifier.
        return idx.astype(str).str.replace(r"\.\d+$", "", regex=True)

    def position_map(ids: pd.Index) -> dict:
        # identifier -> column position in the target; on duplicates the last one wins.
        return {g: i for i, g in enumerate(ids)}

    n, m = ad_tgt.n_obs, len(source_genes)
    Xfull = ad_tgt.X

    # Compare identifiers as strings on both sides.
    src = pd.Index(source_genes).astype(str)
    # The version-stripped candidate must be looked up with version-stripped source IDs;
    # otherwise a versioned source ID could never match it.
    src_novers = strip_version(src)

    # label -> (target lookup dict, source keys to look up)
    candidates = {
        "var_names": (position_map(pd.Index(ad_tgt.var_names.astype(str))), src),
    }
    if "gene_ids" in ad_tgt.var.columns:
        gids = pd.Index(ad_tgt.var["gene_ids"].astype(str))
        candidates["gene_ids"] = (position_map(gids), src)
        candidates["gene_ids_novers"] = (position_map(strip_version(gids)), src_novers)

    # prefer IDs when the source looks like Ensembl
    if is_ensembl_like(src):
        label_order = ["gene_ids", "gene_ids_novers", "var_names"]
    else:
        label_order = ["var_names", "gene_ids", "gene_ids_novers"]

    # Take the first preferred candidate with any hit. If nothing matches at all, the
    # first available candidate is reported with zero hits.
    hit_label, hits = None, []
    for label in label_order:
        if label not in candidates:
            continue
        idx_map, keys = candidates[label]
        found = [(pos, idx_map[g]) for pos, g in enumerate(keys) if g in idx_map]
        if hit_label is None or len(found) > len(hits):
            hit_label, hits = label, found
        if found:
            break

    present = len(hits)

    X = np.zeros((n, m), dtype=np.float32)
    if present > 0:
        src_pos, tgt_idx = zip(*hits)
        Xp = Xfull[:, list(tgt_idx)]
        Xp = Xp.toarray() if sp.issparse(Xp) else np.asarray(Xp)
        X[:, list(src_pos)] = Xp.astype(np.float32, copy=False)

    print(f"[align] target cells={n} | source genes={m} | via={hit_label} | present={present} | padded={m-present}")
    return X

def _predict_flat_level(ad_tgt: sc.AnnData, level: str) -> pd.Series:
    """
    Predict one label per target cell with the flat model trained for ``level``.

    Loads the level's HVG list, saved Keras model and name-to-integer mapping, aligns
    the target matrix to the HVG order, and maps the arg-max model output of each cell
    back to its label name.

    Returns
    -------
    pd.Series named ``pred_flat_<level>``, indexed by ``ad_tgt.obs_names``.

    Raises
    ------
    FileNotFoundError
        If the HVG file, the model file or the mapping file is missing.
    """
    hvgs = _load_hvgs_for_level(level)

    model_file = _flat_model_path(level)
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Flat model missing for level '{level}': {model_file}")
    model = ks.models.load_model(
        model_file,
        custom_objects={'LeakyReLU': ks.layers.LeakyReLU},
        compile=False,
    )

    mapping_file = _flat_intmap_path(level)
    if not os.path.exists(mapping_file):
        raise FileNotFoundError(f"Flat int mapping missing for level '{level}': {mapping_file}")
    with open(mapping_file) as fh:
        name_to_int = json.load(fh)
    # Invert the stored mapping so model output positions can be turned back into names.
    int_to_name = {int(code): label for label, code in name_to_int.items()}

    X = _align_to_genes(ad_tgt, hvgs)
    logits = model.predict(X, verbose=0)
    winners = np.argmax(logits, axis=1)
    predicted_names = [int_to_name[int(w)] for w in winners]
    preds = pd.Series(predicted_names, index=ad_tgt.obs_names, name=f"pred_flat_{level}")

    # Release the model and the dense matrices before the next level is processed.
    del model, X, logits
    gc.collect()
    return preds

def _predict_hier_level(ad_tgt: sc.AnnData, level: str, hierarchy_dict: dict, level_to_depth: dict) -> pd.Series:
    """
    Predict one label per target cell with the hierarchical classifier for ``level``.

    The hierarchy is truncated to the depth registered for ``level``, a classifier tree
    is built from it with ``create_tree``, and every aligned row is passed through
    ``root.predict`` one cell at a time.

    Returns
    -------
    pd.Series named ``pred_hier_<level>``, indexed by ``ad_tgt.obs_names``.
    """
    truncated = truncate_hierarchy(hierarchy_dict, max_depth=level_to_depth[level])
    root = create_tree(truncated)

    X = _align_to_genes(ad_tgt, _load_hvgs_for_level(level))
    n_cells = X.shape[0]

    out = []
    for done, row in enumerate(X, start=1):
        out.append(root.predict(row))
        # progress line every 10,000 cells
        if done % 10000 == 0:
            print(f"[hier {level}] predicted {done}/{n_cells}")

    preds = pd.Series(out, index=ad_tgt.obs_names, name=f"pred_hier_{level}")
    del X, out
    gc.collect()
    return preds

def run_striatum_predictions_and_save():
    """
    Annotate the target dataset with flat and hierarchical predictions and save it.

    For every level in ``LEVELS_FLAT`` a ``pred_flat_<level>`` obs column is added, and
    for every level in ``LEVELS_HIER`` a ``pred_hier_<level>`` column (all categorical).
    The levels and source name are recorded under ``uns['mwb_prediction_levels']`` and
    the result is written to ``RESULTS_DIR/striatum_clustered_mwb_predictions.h5ad``.

    Relies on the notebook-level ``level_to_depth`` mapping and ``load_hierarchy_dict``.

    Raises
    ------
    FileNotFoundError
        If the hierarchy JSON (or any per-level file needed by the predictors) is missing.
    """
    # Fail fast: verify the hierarchy file before the slow flat predictions, not after.
    hpath = _hierarchy_json_path()
    if not os.path.exists(hpath):
        raise FileNotFoundError(f"Hierarchy dict missing: {hpath}")

    print(f"[load] {TARGET_PATH}")
    ad = sc.read_h5ad(TARGET_PATH)

    # Negative values suggest X was scaled/centered. If the models were trained on
    # non-negative expression values, such input can push predictions onto a single label.
    # NOTE(review): this only warns — confirm which matrix the models expect.
    if ad.X is not None:
        x_min = ad.X.min()
        if x_min < 0:
            print(f"[WARN] target X contains negative values (min={float(x_min):.4g}); it looks "
                  "scaled/centered. Confirm this matches the preprocessing used to train the models.")

    for lvl in LEVELS_FLAT:
        print(f"\n=== FLAT prediction level={lvl} ===")
        ad.obs[f"pred_flat_{lvl}"] = _predict_flat_level(ad, lvl).astype("category")

    full_hierarchy = load_hierarchy_dict(hpath)

    for lvl in LEVELS_HIER:
        print(f"\n=== HIERARCHICAL prediction level={lvl} ===")
        ad.obs[f"pred_hier_{lvl}"] = _predict_hier_level(ad, lvl, full_hierarchy, level_to_depth).astype("category")

    # Record which levels were predicted and from which source.
    ad.uns.setdefault("mwb_prediction_levels", {})
    ad.uns["mwb_prediction_levels"]["flat"] = LEVELS_FLAT
    ad.uns["mwb_prediction_levels"]["hier"] = LEVELS_HIER
    ad.uns["mwb_prediction_levels"]["source"] = SOURCE_BASE

    os.makedirs(RESULTS_DIR, exist_ok=True)
    out_h5ad = f"{RESULTS_DIR}/striatum_clustered_mwb_predictions.h5ad"
    ad.write(out_h5ad)
    print(f"\n[SAVED] {out_h5ad}  (added obs columns: "
          + ", ".join([f"pred_flat_{l}" for l in LEVELS_FLAT] + [f"pred_hier_{l}" for l in LEVELS_HIER]) + ")")

# Run the full flat + hierarchical prediction pipeline and write the annotated file.
run_striatum_predictions_and_save()

[load] /gpfs/scratch/blukacsy/striatum_clustered.h5ad

=== FLAT prediction level=neurotransmitter ===
[align] target cells=42043 | source genes=3500 | via=gene_ids | present=2512 | padded=988

=== FLAT prediction level=class ===
[align] target cells=42043 | source genes=3500 | via=gene_ids | present=2512 | padded=988

=== FLAT prediction level=subclass ===
[align] target cells=42043 | source genes=3500 | via=gene_ids | present=2512 | padded=988

=== FLAT prediction level=supertype ===
[align] target cells=42043 | source genes=3500 | via=gene_ids | present=2512 | padded=988

=== FLAT prediction level=cluster ===
[align] target cells=42043 | source genes=3500 | via=gene_ids | present=2512 | padded=988

=== HIERARCHICAL prediction level=class ===
[align] target cells=42043 | source genes=3500 | via=gene_ids | present=2512 | padded=988
[hier class] predicted 10000/42043
[hier class] predicted 20000/42043
[hier class] predicted 30000/42043
[hier class] predicted 40000/42043

=== HIERARCHICA

In [75]:
# NOTE(review): looks like a leftover debugging cell — consider removing it.
print("foo")

foo


In [99]:
# Load the striatum dataset with the saved prediction columns.
res = sc.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/results/striatum_clustered_mwb_predictions.h5ad")

In [100]:
# Show the AnnData summary, including the pred_flat_* / pred_hier_* obs columns.
res

AnnData object with n_obs × n_vars = 42043 × 21579
    obs: 'doublet', 'doublet_score', 'doublet_bool', 'n_genes_by_counts', 'total_counts', 'type', 'sample', 'batch_type', 'batch', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'pct_counts_hb', 'n_genes', 'n_counts', 'size_factors', 'leiden_0.2', 'leiden_0.3', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1.0', 'leiden_1.4', 'leiden_1.7', 'leiden_2.0', 'leiden_2.5', 'pred_flat_neurotransmitter', 'pred_flat_class', 'pred_flat_subclass', 'pred_flat_supertype', 'pred_flat_cluster', 'pred_hier_class', 'pred_hier_subclass', 'pred_hier_supertype', 'pred_hier_cluster'
    var: 'gene_ids', 'feature_types', 'n_cells-0', 'n_cells_by_counts-0', 'mean_counts-0', 'pct_dropout_by_counts-0', 'total_counts-0', 'n_cells-1', 'n_cell

In [137]:
# Peek at the expression matrix stored alongside the predictions.
res.X

array([[ 1.1843488e+00, -2.6635351e-02, -1.2202565e-02, ...,
        -1.2682188e-04,  3.6168651e-05,  7.7891731e-01],
       [ 9.6031862e-01,  8.3774239e-01,  3.9538506e-01, ...,
        -1.2682188e-04,  3.6168651e-05,  1.0415727e+00],
       [ 9.0260506e-01, -2.6635351e-02, -1.2202565e-02, ...,
        -1.2682188e-04,  3.6168651e-05,  9.7604543e-01],
       ...,
       [ 7.0522428e-03, -9.9165067e-03, -4.5041265e-03, ...,
         1.6537460e-04, -1.6345401e-04,  1.7385267e+00],
       [ 7.0522428e-03, -9.9165067e-03, -4.5041265e-03, ...,
         1.6537460e-04, -1.6345401e-04,  1.3682370e+00],
       [ 7.0522428e-03, -9.9165067e-03, -4.5041265e-03, ...,
         1.6537460e-04, -1.6345401e-04,  1.6115695e-01]], dtype=float32)

In [129]:
# Count cells per predicted flat cluster label.
# NOTE(review): in the recorded output nearly all cells share one label — worth checking
# that the input preprocessing matches what the models were trained on.
res.obs["pred_flat_cluster"].value_counts()

pred_flat_cluster
0096 L4/5 IT CTX Glut_5    41514
0040 L6 IT CTX Glut_2        526
5310 Endo NN_1                 3
Name: count, dtype: int64

In [136]:
# Show the per-cell hierarchical cluster predictions.
res.obs["pred_hier_cluster"]

AAACCCAAGACCATAA_sc20_1    0601 OB-STR-CTX Inh IMN_2
AAACCCAAGTCTTCCC_sc20_1    0601 OB-STR-CTX Inh IMN_2
AAACCCACAGAGCCCT_sc20_1    0601 OB-STR-CTX Inh IMN_2
AAACCCACATCAGTCA_sc20_1    0601 OB-STR-CTX Inh IMN_2
AAACCCAGTCTTGGTA_sc20_1    0601 OB-STR-CTX Inh IMN_2
                                     ...            
TTTGTTGGTACTGGGA_sc41_2    0601 OB-STR-CTX Inh IMN_2
TTTGTTGGTATCGTTG_sc41_2    0601 OB-STR-CTX Inh IMN_2
TTTGTTGGTTGGCCGT_sc41_2    0601 OB-STR-CTX Inh IMN_2
TTTGTTGGTTTCGACA_sc41_2    0601 OB-STR-CTX Inh IMN_2
TTTGTTGTCCAGCCTT_sc41_2    0601 OB-STR-CTX Inh IMN_2
Name: pred_hier_cluster, Length: 42043, dtype: category
Categories (7, object): ['0601 OB-STR-CTX Inh IMN_2', '0605 OB-STR-CTX Inh IMN_2', '0607 OB-STR-CTX Inh IMN_2', '0616 OB-STR-CTX Inh IMN_5', '0618 OB-STR-CTX Inh IMN_5', '0621 OB-STR-CTX Inh IMN_6', '5310 Endo NN_1']

In [135]:
import os, re
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# Root of the ABC-atlas working area and the sub-directories used by the plotting code.
BASE = "/gpfs/scratch/blukacsy/abc_atlas"
DATA_DIR = f"{BASE}/data"
RESULTS_DIR = f"{BASE}/results"

# Input file with the prediction columns and the folder the UMAP images are written to
# (created here if it does not exist yet).
PRED_FILE = f"{RESULTS_DIR}/striatum_clustered_mwb_predictions.h5ad"
OUTDIR = f"{RESULTS_DIR}/umaps"
os.makedirs(OUTDIR, exist_ok=True)

# Levels predicted by the flat models and by the hierarchical model, respectively.
LEVELS_FLAT = ["neurotransmitter", "class", "subclass", "supertype", "cluster"]
LEVELS_HIER = ["class", "subclass", "supertype", "cluster"]

def _abbrev(name: str):
    s = re.sub(r"\(.*?\)", "", str(name))  # strip parentheticals

    # remove one or more leading 2–4 digit blocks plus common separators
    # e.g., "0123 - 045  Inhibitory Neuron" -> "Inhibitory Neuron"
    s = re.sub(r"^(?:\s*\d{2,4}[\s_\-:./]*)+", "", s)

    # split into tokens
    words = [w for w in re.split(r"[ /_\-]+", s.strip()) if w]

    # drop numeric tokens at the very front (2–4 digits), just in case
    while words and re.fullmatch(r"\d{2,4}", words[0]):
        words.pop(0)

    # if the first remaining token still begins with 2–4 digits (e.g., "1234Astro"),
    # strip those digits from that token
    if words:
        words[0] = re.sub(r"^\d{2,4}", "", words[0])
        if words[0] == "":
            words.pop(0)

    if not words:
        return "NA"

    if len(words) >= 3:
        return "".join(w[0].upper() for w in words[:3])

    if len(words) == 2:
        first, second = words
        return (first[:2] + second[:1]).upper() if len(first) > 1 else (first[:1] + second[:2]).upper()

    # single word
    return words[0][:3].upper()

def _palette(categories):
    """
    Assign a colour to every category, in order, cycling through matplotlib's default
    colour cycle (or a small built-in list when that cycle defines no colours).
    Colours repeat once the cycle is exhausted.
    """
    from itertools import cycle
    from matplotlib import rcParams

    colours = rcParams['axes.prop_cycle'].by_key().get('color', [])
    if not colours:
        colours = ["#377eb8","#4daf4a","#984ea3","#ff7f00","#e41a1c"]
    wheel = cycle(colours)
    return {category: next(wheel) for category in categories}

def _ensure_umap(adata):
    if "X_umap" in adata.obsm:
        return
    if "X_pca" in adata.obsm:
        print("using PCA")
        sc.pp.neighbors(adata, use_rep="X_pca")
    else:
        print("creating custom UMAP")
        sc.pp.pca(adata, n_comps=min(50, adata.n_vars))
        sc.pp.neighbors(adata, n_pcs=min(50, adata.n_vars))
    sc.tl.umap(adata, min_dist=0.5)

def plot_pred_umap(adata, column, title, savepath, annotate_centroids=True, legend_max=40, s=2, alpha=0.7):
    """
    Scatter the UMAP of ``adata`` coloured by the obs column ``column`` and save it.

    Parameters
    ----------
    adata : AnnData; a UMAP is computed in place when ``obsm['X_umap']`` is missing.
    column : obs column with the labels. If absent, a warning is printed and nothing is drawn.
    title : figure title.
    savepath : output image path (written at 300 dpi).
    annotate_centroids : write an abbreviated tag at the mean position of each listed label.
    legend_max : only this many most frequent labels get a legend entry and a tag
        (``None`` means 40).
    s, alpha : marker size and opacity of the scatter points.
    """
    if column not in adata.obs:
        print(f"[WARN] missing {column}; skip")
        return
    _ensure_umap(adata)

    labels = adata.obs[column].astype("category")
    coords = pd.DataFrame(adata.obsm["X_umap"], columns=["UMAP1", "UMAP2"], index=adata.obs_names)
    coords["label"] = labels

    palette = _palette(list(labels.cat.categories))

    # Only the most frequent labels are listed in the legend / annotated.
    n_listed = 40 if legend_max is None else legend_max
    show_cats = list(coords["label"].value_counts().head(n_listed).index)

    fig, ax = plt.subplots(figsize=(9,8))
    for cat, group in coords.groupby("label", observed=True):
        listed = cat in show_cats
        legend_text = f"{str(cat)} (n={len(group)})" if listed else None
        ax.scatter(
            group["UMAP1"], group["UMAP2"],
            s=s, alpha=alpha, lw=0,
            color=palette.get(cat, "#999999"),
            label=legend_text,
        )
        if listed and annotate_centroids:
            ax.text(
                group["UMAP1"].mean(), group["UMAP2"].mean(), _abbrev(cat),
                fontsize=7, fontweight="bold", ha="center", va="center",
                color="black",
                bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", pad=1),
            )

    ax.set_title(title)
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    if show_cats:
        ax.legend(
            bbox_to_anchor=(1.02, 1),
            loc="upper left",
            fontsize=8,
            title="Predicted (n)",
            markerscale=4.0,
            scatterpoints=1
        )
    plt.tight_layout()
    plt.savefig(savepath, dpi=300, bbox_inches="tight")
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def make_all_umaps():
    """
    Load the prediction file and save one UMAP image per prediction column.

    Covers the flat levels in ``LEVELS_FLAT`` (``pred_flat_<level>``) followed by the
    hierarchical levels in ``LEVELS_HIER`` (``pred_hier_<level>``). Images are written
    to ``OUTDIR`` as ``striatum_umap_mwb_<level>_<flat|hier>.png``.
    """
    print(f"[load] {PRED_FILE}")
    ad = sc.read_h5ad(PRED_FILE)

    # One loop handles both families; flat levels first, then hierarchical.
    for kind, levels in (("flat", LEVELS_FLAT), ("hier", LEVELS_HIER)):
        for lvl in levels:
            col = f"pred_{kind}_{lvl}"
            title = f"Striatum UMAP w/ MWB {lvl} {kind} preds"
            out = f"{OUTDIR}/striatum_umap_mwb_{lvl}_{kind}.png"
            plot_pred_umap(ad, col, title, out,
                           annotate_centroids=True,
                           # cluster level has the most labels, so list fewer of them
                           legend_max=(25 if lvl == "cluster" else 40),
                           s=2, alpha=0.7)
            # plot_pred_umap skips (with a warning) when the column is missing, so only
            # report a saved file when something was actually drawn.
            if col in ad.obs:
                print(f"[saved] {out}")

    print("done")

# Render and save all nine prediction UMAPs (five flat levels, four hierarchical levels).
make_all_umaps()

[load] /gpfs/scratch/blukacsy/abc_atlas/results/striatum_clustered_mwb_predictions.h5ad
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_neurotransmitter_flat.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_class_flat.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_subclass_flat.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_supertype_flat.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_cluster_flat.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_class_hier.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_subclass_hier.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_supertype_hier.png
[saved] /gpfs/scratch/blukacsy/abc_atlas/results/umaps/striatum_umap_mwb_cluster_hier.png
done


In [65]:
import scanpy as sc
import pandas as pd
import numpy as np
import re

# Target dataset and the folder holding the per-level HVG training files.
STRI = "/gpfs/scratch/blukacsy/striatum_clustered.h5ad"
BASE = "/gpfs/scratch/blukacsy/abc_atlas/data"

# level -> path of that level's HVG training AnnData file
mwb_hvg_paths = {
    lvl: f"{BASE}/mouse_whole_brain_{lvl}_train_anndata_hvg_flat.h5ad"
    for lvl in ("neurotransmitter", "class", "subclass", "supertype", "cluster")
}

def is_ensembl_like(names: pd.Index) -> bool:
    """
    Heuristic check for mouse Ensembl gene IDs.

    Looks at the first (at most) 5000 names and returns True when at least 80% of them
    match ``ENSMUSG<digits>`` with an optional ``.<version>`` suffix. An empty input
    yields False.
    """
    pattern = re.compile(r"^ENSMUSG\d+(?:\.\d+)?$")
    sample = names[:5000]
    sample_size = len(sample)
    if sample_size == 0:
        return False
    n_matching = sum(1 for gene in sample if pattern.match(str(gene)))
    return (n_matching / sample_size) >= 0.8

def strip_version(ids: pd.Index) -> pd.Index:
    """Return the identifiers as strings with any trailing ``.<digits>`` version suffix removed."""
    as_text = ids.astype(str)
    return as_text.str.replace(r"\.\d+$", "", regex=True)

# Report, for every level, how many of the source HVGs can be found in the striatum
# dataset under each of its identifier sets.
print("[load] striatum")
ad_s = sc.read_h5ad(STRI)

# Three identifier sets of the target: var_names, var['gene_ids'], and gene_ids with
# the version suffix stripped (the latter two are empty if there is no gene_ids column).
stri_names = pd.Index(ad_s.var_names.astype(str))
stri_ids   = pd.Index(ad_s.var["gene_ids"].astype(str)) if "gene_ids" in ad_s.var.columns else pd.Index([])
stri_ids_novers = strip_version(stri_ids) if len(stri_ids) else pd.Index([])

print(f"striatum vars: names={len(stri_names)}, gene_ids={len(stri_ids)}")

for lvl, path in mwb_hvg_paths.items():
    ad = sc.read_h5ad(path)
    hvgs = pd.Index(ad.var_names.astype(str))

    mwb_is_ens = is_ensembl_like(hvgs)
    print(f"\n[{lvl}] MWB HVGs = {len(hvgs)} | ensembl_like={mwb_is_ens}")

    # direct vs var_names
    direct = hvgs.intersection(stri_names)
    # vs gene_ids (as-is)
    ids_hit = hvgs.intersection(stri_ids) if len(stri_ids) else pd.Index([])
    # vs gene_ids (no version)
    ids_nv_hit = hvgs.intersection(stri_ids_novers) if len(stri_ids_novers) else pd.Index([])

    # Share of this level's HVGs contained in x, as a percentage string
    # (max(..., 1) avoids dividing by zero when there are no HVGs).
    def pct(x): return f"{(100*len(x)/max(len(hvgs),1)):.2f}%"

    print(f"  overlap vs striatum.var_names:   {len(direct)} ({pct(direct)})")
    if len(ids_hit):
        print(f"  overlap vs striatum.gene_ids:    {len(ids_hit)} ({pct(ids_hit)})")
    if len(ids_nv_hit):
        print(f"  overlap vs gene_ids(no-version): {len(ids_nv_hit)} ({pct(ids_nv_hit)})")

    # show a few examples from the first identifier set that has any overlap
    for label, ix in [("names", direct), ("gene_ids", ids_hit), ("gene_ids_nover", ids_nv_hit)]:
        if len(ix) > 0:
            print(f"    e.g. ({label}): {list(ix[:8])}")
            break

[load] striatum
striatum vars: names=21579, gene_ids=21579

[neurotransmitter] MWB HVGs = 3500 | ensembl_like=True
  overlap vs striatum.var_names:   0 (0.00%)
  overlap vs striatum.gene_ids:    2512 (71.77%)
  overlap vs gene_ids(no-version): 2512 (71.77%)
    e.g. (gene_ids): ['ENSMUSG00000025900', 'ENSMUSG00000002459', 'ENSMUSG00000033740', 'ENSMUSG00000079671', 'ENSMUSG00000097893', 'ENSMUSG00000042501', 'ENSMUSG00000048960', 'ENSMUSG00000016918']

[class] MWB HVGs = 3500 | ensembl_like=True
  overlap vs striatum.var_names:   0 (0.00%)
  overlap vs striatum.gene_ids:    2512 (71.77%)
  overlap vs gene_ids(no-version): 2512 (71.77%)
    e.g. (gene_ids): ['ENSMUSG00000025900', 'ENSMUSG00000002459', 'ENSMUSG00000033740', 'ENSMUSG00000079671', 'ENSMUSG00000097893', 'ENSMUSG00000042501', 'ENSMUSG00000048960', 'ENSMUSG00000016918']

[subclass] MWB HVGs = 3500 | ensembl_like=True
  overlap vs striatum.var_names:   0 (0.00%)
  overlap vs striatum.gene_ids:    2512 (71.77%)
  overlap vs gen

In [67]:
# Show the gene identifiers used as var_names in the training set.
boo.var_names

Index(['ENSMUSG00000051951', 'ENSMUSG00000089699', 'ENSMUSG00000102331',
       'ENSMUSG00000102343', 'ENSMUSG00000025900', 'ENSMUSG00000025902',
       'ENSMUSG00000104238', 'ENSMUSG00000104328', 'ENSMUSG00000033845',
       'ENSMUSG00000025903',
       ...
       'ENSMUSG00000096550', 'ENSMUSG00000094172', 'ENSMUSG00000094887',
       'ENSMUSG00000091585', 'ENSMUSG00000095763', 'ENSMUSG00000095523',
       'ENSMUSG00000095475', 'ENSMUSG00000094855', 'ENSMUSG00000095019',
       'ENSMUSG00000095041'],
      dtype='object', name='gene_identifier', length=32285)

In [68]:
# Show the gene identifiers used as var_names in the striatum dataset.
striatum_clustered.var_names

Index(['Xkr4', 'Gm1992', 'Gm19938', 'Rp1', 'Mrpl15', 'Lypla1', 'Tcea1',
       'Rgs20', 'Gm16041', 'Atp6v1h',
       ...
       'AC168977.1', 'CAAA01118383.1', 'AC132444.5', 'Vamp7', 'Spry3', 'Tmlhe',
       '4933409K07Rik', 'Gm10931', 'CAAA01147332.1', 'AC149090.1'],
      dtype='object', length=21579)

In [70]:
# Show the AnnData summary of the training set.
boo

AnnData object with n_obs × n_vars = 590487 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [73]:
# Show the gene_ids column of the striatum dataset (indexed by var_names).
striatum_clustered.var["gene_ids"]

Xkr4              ENSMUSG00000051951
Gm1992            ENSMUSG00000089699
Gm19938           ENSMUSG00000102331
Rp1               ENSMUSG00000025900
Mrpl15            ENSMUSG00000033845
                         ...        
Tmlhe             ENSMUSG00000079834
4933409K07Rik     ENSMUSG00000095552
Gm10931           ENSMUSG00000094350
CAAA01147332.1    ENSMUSG00000095742
AC149090.1        ENSMUSG00000095041
Name: gene_ids, Length: 21579, dtype: object

In [71]:
def integer_mappings(group):
    """
    Encode the labels of one taxonomy level as integers.

    The distinct values of ``cell_extended[group]`` are sorted and numbered from 0. The
    encoded labels are written to ``adata.obs['<group>_labels']`` (a side effect on the
    notebook-level ``adata``) and the number of distinct labels is printed.

    Parameters
    ----------
    group : name of the column in the notebook-level ``cell_extended`` frame.

    Returns
    -------
    pd.DataFrame with one row per distinct label: the label (column named after
    ``group``) and its integer code (column ``label``).
    """
    unique_groups = sorted(cell_extended[group].unique())
    group_to_int = {name: code for code, name in enumerate(unique_groups)}

    # Series.map builds a new Series and leaves cell_extended untouched, so the frame
    # does not need to be copied first.
    adata.obs[f'{group}_labels'] = cell_extended[group].map(group_to_int)
    print(f'# unique groups: {len(unique_groups)}')

    group_mapping = pd.DataFrame({
        f'{group}': list(group_to_int.keys()),
        'label': list(group_to_int.values())
    })

    return group_mapping

In [72]:
# Build the integer label mapping for every taxonomy level; each call also writes a
# '<level>_labels' column into adata.obs.
class_mapping = integer_mappings('class')
subclass_mapping = integer_mappings('subclass')
supertype_mapping = integer_mappings('supertype')
cluster_mapping = integer_mappings('cluster')

# unique groups: 34
# unique groups: 338
# unique groups: 1201
# unique groups: 5322


In [84]:
# Show the AnnData summary, now including the *_labels obs columns.
adata

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels'

In [85]:
# Persist the AnnData object to disk.
adata.write("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [8]:
# Reload the saved AnnData object from disk.
adata = anndata.read_h5ad("/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad")

In [9]:
# Show the AnnData summary of the reloaded object.
adata

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels'

In [26]:
# Copy feature_matrix_label from cell_extended into adata.obs.
# NOTE(review): assumes cell_extended and adata.obs share the same cell index — confirm.
adata.obs['feature_matrix_label'] = cell_extended['feature_matrix_label']

In [32]:
# Flag 3500 highly variable genes, using feature_matrix_label as the batch key.
# The result is stored on adata itself.
# NOTE(review): presumably adata.X is log-normalized, as scanpy's default flavor expects — confirm.
sc.pp.highly_variable_genes(adata, n_top_genes=3500, batch_key='feature_matrix_label')

In [33]:
# Show the AnnData summary, now including the highly-variable-gene var columns.
adata

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [42]:
# Keep only the genes flagged in var['highly_variable']; .copy() makes the subset a
# standalone object instead of a view of adata.
adata_hvg = adata[:, adata.var['highly_variable'] ].copy()

In [57]:
# Show the AnnData summary of the highly-variable-gene subset.
adata_hvg

AnnData object with n_obs × n_vars = 4042976 × 3500
    obs: 'cell_barcode', 'library_label', 'anatomical_division_label', 'class', 'subclass', 'supertype', 'cluster', 'class_labels', 'subclass_labels', 'supertype_labels', 'cluster_labels', 'feature_matrix_label'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'hvg'

In [87]:
def _rows_as_dense(matrix, row_positions):
    """Return the rows of ``matrix`` at ``row_positions`` as a dense NumPy array."""
    rows = matrix[row_positions]
    if scipy.sparse.issparse(rows):
        return rows.toarray()
    return np.asarray(rows)


def train_test_val_split(adata, dimension, group, seed1, seed2,
                         output_dir='/gpfs/scratch/blukacsy/abc_atlas/arrays'):
    """Create a stratified 80/10/10 train/test/val split and save it as ``.npy`` files.

    The split is stratified on ``adata.obs[f'{group}_labels']``. A first split holds
    out 20% of the cells, and a second split halves that hold-out into test and
    validation sets. Balanced class weights are computed from the training labels.

    Row positions are split instead of the feature matrix itself, and each feature
    split is densified and written one at a time. For the same seeds and labels this
    yields the same partitions as splitting the matrix directly, without holding
    several dense copies of the data in memory at once.

    Parameters
    ----------
    adata : AnnData
        Object whose ``X`` (dense or scipy sparse) holds the features and whose
        ``obs`` contains a ``f'{group}_labels'`` column.
    dimension : str
        Tag for the feature set (e.g. ``"hvg"``); used only in the file names.
    group : str
        Label level to stratify on; selects the ``f'{group}_labels'`` column and is
        used in the file names.
    seed1 : int
        Random state of the train / hold-out split.
    seed2 : int
        Random state of the test / validation split.
    output_dir : str or Path, optional
        Directory receiving the arrays; created if missing. Defaults to the
        location this notebook has always written to.

    Returns
    -------
    None
        Results are written to ``output_dir`` as
        ``{train,test,val}_{features,labels,index}_{dimension}_{group}.npy`` and
        ``weights_{dimension}_{group}.npy``. Nothing is returned, so calling the
        function as the last statement of a cell does not echo large arrays.

    Raises
    ------
    KeyError
        If ``adata.obs`` has no ``f'{group}_labels'`` column.
    ValueError
        Propagated from scikit-learn when a stratified split is not possible
        (for example a class with too few members).
    """
    # Fail early on an unusable destination, before any expensive work is done.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    X = adata.X
    if scipy.sparse.issparse(X) and X.format != 'csr':
        # Row selection is only efficient (and universally supported) on CSR.
        X = X.tocsr()

    print(X.min())
    print(X.max())

    y = np.asarray(adata.obs[f'{group}_labels'].values)
    indices = adata.obs.index.to_numpy()
    positions = np.arange(X.shape[0])

    # The partition depends only on the sample count, the stratify labels and the
    # seed, so splitting positions selects the same rows as splitting X directly.
    train_pos, holdout_pos = train_test_split(
        positions, test_size=0.2, random_state=seed1, stratify=y
    )
    test_pos, val_pos = train_test_split(
        holdout_pos, test_size=0.5, random_state=seed2, stratify=y[holdout_pos]
    )

    split_positions = {'train': train_pos, 'test': test_pos, 'val': val_pos}
    split_labels = {name: y[pos] for name, pos in split_positions.items()}
    split_index = {name: indices[pos] for name, pos in split_positions.items()}

    train_labels = split_labels['train']
    weights = np.asarray(
        compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
    )

    # Feature shapes are known without materialising the dense arrays.
    n_features = X.shape[1]
    feature_shapes = {name: (len(pos), n_features) for name, pos in split_positions.items()}

    print()
    print('Train features shape:', feature_shapes['train'])
    print('Val features shape:', feature_shapes['val'])
    print('Test features shape:', feature_shapes['test'])
    print('Train labels shape:', split_labels['train'].shape)
    print('Val labels shape:', split_labels['val'].shape)
    print('Test labels shape:', split_labels['test'].shape)
    print('Weights shape:', weights.shape)
    print()
    print('Train index shape:', split_index['train'].shape)
    print('Test index shape:', split_index['test'].shape)
    print('Val index shape:', split_index['val'].shape)
    print()

    suffix = f'{dimension}_{group}.npy'

    # Densify and write one split at a time so each dense block can be released
    # before the next one is built.
    for name, pos in split_positions.items():
        np.save(out_dir / f'{name}_features_{suffix}', _rows_as_dense(X, pos))
    for name, labels in split_labels.items():
        np.save(out_dir / f'{name}_labels_{suffix}', labels)
    np.save(out_dir / f'weights_{suffix}', weights)
    for name, index_values in split_index.items():
        np.save(out_dir / f'{name}_index_{suffix}', index_values)

In [83]:
# Export the label mapping table of every level to its own CSV file.
mapping_exports = [
    (class_mapping, '/gpfs/scratch/blukacsy/abc_atlas/data/class_mapping.csv'),
    (subclass_mapping, '/gpfs/scratch/blukacsy/abc_atlas/data/subclass_mapping.csv'),
    (supertype_mapping, '/gpfs/scratch/blukacsy/abc_atlas/data/supertype_mapping.csv'),
    (cluster_mapping, '/gpfs/scratch/blukacsy/abc_atlas/data/cluster_mapping.csv'),
]
for mapping_table, csv_path in mapping_exports:
    mapping_table.to_csv(csv_path)

In [84]:
# Save the full object and the HVG-only subset, then report completion.
full_h5ad_path = "/gpfs/scratch/blukacsy/abc_atlas/data/adata.h5ad"
hvg_h5ad_path = "/gpfs/scratch/blukacsy/abc_atlas/data/adata_hvg.h5ad"

adata.write(full_h5ad_path)
adata_hvg.write(hvg_h5ad_path)
print("done writing")

done writing


In [85]:
# Build and save the train/val/test arrays for the "class" labels from the HVG subset.
# 7105 and 3870 are the random seeds of the first and second split, respectively.
train_test_val_split(adata_hvg, "hvg", "class", 7105, 3870)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (34,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)


In [88]:
# Repeat the HVG split for the remaining label levels, with the same two seeds as the "class" run.
for label_level in ("subclass", "supertype", "cluster"):
    train_test_val_split(adata_hvg, "hvg", label_level, 7105, 3870)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (338,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (1201,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)

0.0
18.814487

Train features shape: (3234380, 3500)
Val features shape: (404298, 3500)
Test features shape: (404298, 3500)
Train labels shape: (3234380,)
Val labels shape: (404298,)
Test labels shape: (404298,)
Weights shape: (5322,)

Train index shape: (3234380,)
Test index shape: (404298,)
Val index shape: (404298,)



In [None]:
# Same "class" split on the full gene set; output files are tagged "original" instead of "hvg".
# NOTE(review): the feature arrays are saved dense with all 32285 genes per cell, so this needs far
# more memory and disk than the HVG runs — confirm the node can hold it before running.
train_test_val_split(adata, "original", "class", 7105, 3870)

0.0
18.814487
