In [None]:
import os
import sys    
import yaml
from lamin_dataloader.dataset import GeneIdTokenizer
from datamodule import MappedCollectionDataModule
from pathlib import Path
import pandas as pd

In [3]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local .py files (e.g. the
# datamodule) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Expected on-disk layout:
#   data_root/
#     gene_token_mapping.pkl
#     h5ads/
#       file1.h5ad
#       file2.h5ad

data_root = Path('./temp_data')
dataset_path = data_root.joinpath('h5ads/')

# Create both directories up front; directories that already exist are kept.
for directory in (data_root, dataset_path):
    os.makedirs(directory, exist_ok=True)

In [None]:
import subprocess
import anndata as ad

# Example datasets fetched from the CELLxGENE download host. The first file
# is also the one the gene vocabulary is derived from below.
h5ad_filenames = [
    "4724c395-0c46-46d2-81f7-60fd271fb488.h5ad",
    "01209dce-3575-4bed-b1df-129f57fbc031.h5ad",
]

for filename in h5ad_filenames:
    # Pass curl an argument list instead of a shell string so the target path
    # is handed over verbatim (no quoting issues). `--fail` together with
    # `check=True` turns an HTTP/network error into an exception instead of
    # silently saving an error page under a .h5ad name.
    subprocess.run(
        [
            "curl", "--fail", "--location",
            "-o", str(dataset_path / filename),
            f"https://datasets.cellxgene.cziscience.com/{filename}",
        ],
        check=True,
    )

# The files are downloaded into dataset_path (data_root / 'h5ads'), so the
# AnnData object has to be read from there rather than from data_root.
adata = ad.read_h5ad(dataset_path / h5ad_filenames[0])

# Persist a gene -> integer token id mapping. The layout follows the example
# documented in the cell that loads this pickle: special tokens first
# ('<cls>': 0, '<pad>': 1), then one id per gene in var_names order.
# NOTE(review): the special-token names are taken from that example; confirm
# they match what GeneIdTokenizer expects.
# Series.to_pickle() returns None, so its result is deliberately not assigned.
special_tokens = ['<cls>', '<pad>']
vocabulary = special_tokens + list(adata.var_names)
pd.Series(range(len(vocabulary)), index=vocabulary).to_pickle(data_root / 'gene_token_mapping.pkl')

In [64]:
# Load the pickled gene mapping written during data preparation and turn it
# into a plain dict for the tokenizer.
# NOTE: unpickling can execute arbitrary code; only load mapping files that
# were created locally / come from a trusted source.
gene_mapping_path = data_root.joinpath('gene_token_mapping.pkl')
gene_mapping = pd.read_pickle(gene_mapping_path).to_dict()
# Expected shape of gene_mapping:
# {
#   '<cls>': 0,
#   '<pad>': 1,
#   'ENSG00000233576': 2,
#   'ENSG00000121410': 3,
#   'ENSG00000268895': 4,
#   'ENSG00000148584': 5,
#   'ENSG00000175899': 6,
#   ...
# }

# This demo uses the same two files for both splits; each split gets its own
# list object.
h5ad_files = (
    '4724c395-0c46-46d2-81f7-60fd271fb488.h5ad',
    '01209dce-3575-4bed-b1df-129f57fbc031.h5ad',
)
split = {stage: list(h5ad_files) for stage in ('train', 'val')}

# Per-split dataset options.
dataset_args = {stage: {'max_tokens': 1000} for stage in ('train', 'val')}

# Loader options shared by both splits.
shared_loader_args = {
    'shuffle': True,
    'num_workers': 1,
    'drop_last': True,
    'batch_size': 32,
    'num_samples': None,
    'pin_memory': True,
}
dataloader_args = {
    # Only the train loader carries a 'filter' entry; presumably it restricts
    # sampling to cells with these obs values -- confirm against the datamodule.
    'train': {
        **shared_loader_args,
        'filter': {'cell_type': ['neuron', 'ependymal cell'], 'sex': ['male']},
    },
    'val': dict(shared_loader_args),
}

datamodule_args = dict(
    dataset_path=dataset_path,
    split=split,
    columns=['cell_type', 'sex'],
    normalization='log1p',
    gene_sampling_strategy='top',  # 'random' or 'top'
    dataset_kwargs=dataset_args,
    dataloader_kwargs=dataloader_args,
    tokenizer=GeneIdTokenizer(gene_mapping),
)

datamodule = MappedCollectionDataModule(**datamodule_args)

Caching cell_type...
Caching sex...
Dataset 0: 59357 / 59357 tokens
Dataset 1: 54765 / 54765 tokens
coverage macro: 1.0
covarage micro: 1.0
Dataset 0: 59357 / 59357 tokens
Dataset 1: 54765 / 54765 tokens
coverage macro: 1.0
covarage micro: 1.0


In [65]:
# Build the training dataloader and pull exactly one batch for inspection.
train_loader = datamodule.train_dataloader()
# next(iter(...)) raises StopIteration right here if the loader yields
# nothing, instead of leaving `batch` undefined (or stale from an earlier
# run) the way a for/break loop would. It also avoids the unused loop index.
batch = next(iter(train_loader))

Creating train dataloader by 690 batches of size 32 taking 22080 samples from 87226 total samples; num_replicas=1; sum of indices: 1970327325; num_workers=1


In [66]:
# Display the batch most recently assigned to `batch` (a dict of tensors,
# as shown in the output below).
batch

{'tokens': tensor([[12598, 30405,  2578,  ...,  5882, 15036, 33474],
         [12598, 20065, 30405,  ..., 17867,  1023, 15333],
         [12598, 30405,  4399,  ..., 19859, 24154, 19100],
         ...,
         [12598, 30405,  2578,  ..., 16171,  6910, 28087],
         [12598, 30099,  2971,  ..., 44072, 20280, 16712],
         [12598, 30099, 21314,  ..., 15407, 41295,  7714]]),
 'values': tensor([[7.0800, 5.7236, 5.3132,  ..., 1.7918, 1.7918, 1.7918],
         [8.1637, 5.3936, 5.1705,  ..., 1.7918, 1.7918, 1.7918],
         [7.0892, 4.7362, 4.6052,  ..., 1.3863, 1.3863, 1.3863],
         ...,
         [7.0934, 5.8289, 5.4972,  ..., 1.7918, 1.7918, 1.7918],
         [6.3986, 4.2195, 4.0775,  ..., 1.0986, 1.0986, 1.0986],
         [5.5947, 4.3944, 4.3041,  ..., 0.6931, 0.6931, 0.6931]]),
 'dataset_id': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'cell_type': tensor([12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 

In [67]:
# Build the validation dataloader and pull exactly one batch for inspection.
val_loader = datamodule.val_dataloader()
# next(iter(...)) raises StopIteration if the loader yields nothing; a
# for/break loop would instead silently keep whatever `batch` held before
# (e.g. a training batch), which would then be displayed as validation data.
batch = next(iter(val_loader))

Creating train dataloader by 2725 batches of size 32 taking 87200 samples from 87226 total samples; num_replicas=1; sum of indices: 1970327325; num_workers=1


In [68]:
# Display the batch most recently assigned to `batch` (a dict of tensors,
# as shown in the output below).
batch

{'tokens': tensor([[22467,  5734, 22646,  ..., 11946, 29187, 52260],
         [12598,  8151, 16572,  ...,  8375, 29801, 20108],
         [12598,  7353,  5388,  ..., 23629, 23418, 10772],
         ...,
         [ 8057, 29893, 26825,  ..., 39853, 38195, 59538],
         [12598, 30405,  2578,  ...,  2710,  1006, 19626],
         [26825, 22646, 29893,  ..., 28781, 53240, 25462]]),
 'values': tensor([[3.7612, 3.7377, 3.7136,  ..., 0.0000, 0.0000, 0.0000],
         [5.1180, 3.8501, 3.8286,  ..., 0.6931, 0.6931, 0.6931],
         [4.9127, 3.9120, 3.8918,  ..., 0.6931, 0.6931, 0.6931],
         ...,
         [4.1744, 4.0254, 3.7612,  ..., 0.0000, 0.0000, 0.0000],
         [7.2189, 6.3421, 5.7104,  ..., 1.9459, 1.9459, 1.9459],
         [3.3673, 3.2189, 3.1355,  ..., 0.0000, 0.0000, 0.0000]]),
 'dataset_id': tensor([1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
         0, 1, 1, 1, 1, 1, 0, 1]),
 'cell_type': tensor([ 0,  4,  5,  2, 12,  2, 12, 12, 12, 12,  0,  9,  0, 