In [1]:
%load_ext autoreload

In [16]:
import torch
import gc
import os

In [2]:
%autoreload 2
from self_supervision.estimator.cellnet import EstimatorAutoEncoder
from self_supervision.paths import DATA_DIR

In [3]:
# Location of the training data store passed to the estimator below
# (presumably a Merlin-format CellxGene snapshot; the "sf-log1p" suffix
# suggests size-factor + log1p normalised counts — TODO confirm).
STORE_DIR = os.path.join(
    DATA_DIR,
    'merlin_cxg_2023_05_15_sf-log1p',
)

In [4]:
# Build the autoencoder estimator on top of the data store.
# NOTE(review): hvg=False presumably means "use all genes" rather than a
# highly-variable-gene subset — confirm against EstimatorAutoEncoder.
estim = EstimatorAutoEncoder(
    STORE_DIR,
    hvg=False,
)

In [13]:
# Set up the data module that provides train/val/test dataloaders
# (4096 rows per batch).
estim.init_datamodule(
    batch_size=4096,
)

In [8]:
# Initialise a randomly-initialised MLP classifier (no pre-trained weights loaded here).
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler
# steps (presumably epochs — depends on how the estimator steps it).
# NOTE(review): recent PyTorch releases deprecate the `verbose` argument of
# LR schedulers and warn when it is passed; drop it when upgrading.
estim.init_model(
    model_type='mlp_clf',
    model_kwargs=dict(
        learning_rate=1e-3,
        weight_decay=0.1,
        lr_scheduler=torch.optim.lr_scheduler.StepLR,
        dropout=0.1,
        lr_scheduler_kwargs=dict(
            step_size=2,
            gamma=0.9,
            verbose=True,
        ),
        units=[512, 512, 256, 256, 64],  # hidden layer widths
        supervised_subset=None,
    ),
)

In [9]:
# Sanity check: fetch only the first training batch and print it to inspect
# its structure — a (features_dict, None) tuple with 'X', 'cell_type' and
# 'dataset_id' tensors, per the captured output.
for batch in estim.datamodule.train_dataloader():
    print('batch: ', batch)
    break  # one batch is enough for inspection

batch:  ({'X': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'), 'cell_type': tensor([  3, 136, 122,  ...,  14,  19, 129], device='cuda:0'), 'dataset_id': tensor([168, 169, 168,  ..., 143, 178, 191], device='cuda:0')}, None)


In [17]:
def count_samples(dataloader):
    """
    Return the total number of rows across all batches of ``dataloader``.

    Every batch is indexed as ``batch[0]['X']`` and its ``len()`` is summed.
    A full garbage-collection pass is forced after each batch is measured.
    An empty dataloader yields 0.
    """
    def rows_in(batch):
        # Measure first, then collect — same order as a plain loop would use.
        n_rows = len(batch[0]['X'])
        gc.collect()
        return n_rows

    return sum(rows_in(b) for b in dataloader)

def count_cell_types(dataloader):
    """
    Count the number of unique cell types in a given dataloader.

    Each batch is indexed as ``batch[0]['cell_type']``, assumed to be a 1-D
    tensor of integer cell-type codes (matches the batch preview printed
    earlier in this notebook — TODO confirm for every split).

    Returns:
        int: number of distinct ``cell_type`` values seen across all batches
        (0 for an empty dataloader).
    """
    unique_cell_types = set()
    for batch in dataloader:
        # Deduplicate on the tensor's own device first, so only the few
        # distinct codes per batch (not every row) become Python ints.
        unique_cell_types.update(batch[0]['cell_type'].unique().tolist())
        gc.collect()  # per-batch full GC, kept as in the original notebook
    return len(unique_cell_types)

def filter_by_dataset_id(dataloader, dataset_id):
    """
    Filter the batches by a given dataset_id.

    Iterates ``dataloader`` once. For each batch whose ``dataset_id`` column
    contains ``dataset_id``, keeps only the matching rows of ``X``,
    ``cell_type`` and ``dataset_id``; any other keys of the feature dict are
    dropped. Batches with no matching row are skipped entirely.

    Args:
        dataloader: iterable of ``(features, extra)`` batches where
            ``features`` maps 'X', 'cell_type' and 'dataset_id' to tensors
            that accept a boolean row mask.
        dataset_id: value compared element-wise with ``features['dataset_id']``.

    Returns:
        list of ``(filtered_features, extra)`` tuples; ``extra`` (``batch[1]``)
        is passed through unchanged.

    Note:
        All filtered ``X`` slices are retained in the returned list, on
        whatever device the loader produced them — this can be large for a
        big dataset.
    """
    filtered_batches = []
    for batch in dataloader:
        features = batch[0]
        mask = features['dataset_id'] == dataset_id
        # Reduce on the tensor itself: the builtin any() would walk the mask
        # element by element in Python (one device->host sync each on GPU).
        if not mask.any():
            continue
        filtered_batches.append((
            {key: features[key][mask] for key in ('X', 'cell_type', 'dataset_id')},
            batch[1],
        ))
    return filtered_batches

def _summarize_split(dataloader, dataset_id=None):
    """Return ``(n_samples, n_unique_cell_types)`` for one split in a single pass.

    When ``dataset_id`` is given, only rows whose ``dataset_id`` equals it are
    counted; rows are counted from the boolean mask, so ``X`` is never sliced
    or retained. Works with one-shot iterables (generators) as well.
    """
    n_samples = 0
    cell_types = set()
    for batch in dataloader:
        features = batch[0]
        batch_cell_types = features['cell_type']
        if dataset_id is None:
            n_samples += len(features['X'])
        else:
            mask = features['dataset_id'] == dataset_id
            batch_cell_types = batch_cell_types[mask]
            n_samples += int(mask.sum())
        # Deduplicate on-device before converting to Python ints.
        cell_types.update(batch_cell_types.unique().tolist())
        gc.collect()  # per-batch full GC, kept as in the original notebook
    return n_samples, len(cell_types)


def print_dataset_info(train_dataloader, val_dataloader, test_dataloader, dataset_id=None, dataset_name=""):
    """
    Calculate and print the dataset information.

    For each split, prints the number of samples and of unique cell types.
    The statistics cover either the complete dataset (``dataset_id is None``)
    or only the rows whose ``dataset_id`` equals ``dataset_id``.

    Each dataloader is iterated exactly once. Sample and cell-type counts are
    gathered in the same pass, and filtered rows are counted on the fly, so
    no expression matrices are accumulated in memory.

    Args:
        train_dataloader, val_dataloader, test_dataloader: iterables of
            ``(features, extra)`` batches with 'X', 'cell_type' and
            'dataset_id' entries in ``features``.
        dataset_id: optional dataset id to restrict the statistics to.
        dataset_name: label used only in the printed header.
    """
    if dataset_id is not None:
        print(f"Information for dataset '{dataset_name}' (ID: {dataset_id}):")
    else:
        print("Information for the complete dataset:")

    train_samples, train_cell_types = _summarize_split(train_dataloader, dataset_id)
    val_samples, val_cell_types = _summarize_split(val_dataloader, dataset_id)
    test_samples, test_cell_types = _summarize_split(test_dataloader, dataset_id)

    print(f"Train samples: {train_samples}")
    print(f"Validation samples: {val_samples}")
    print(f"Test samples: {test_samples}")

    print(f"Unique cell types in Train set: {train_cell_types}")
    print(f"Unique cell types in Validation set: {val_cell_types}")
    print(f"Unique cell types in Test set: {test_cell_types}")
    print("\n")

In [None]:
# Statistics for the complete dataset (no dataset_id filter).
print_dataset_info(
    train_dataloader=estim.datamodule.train_dataloader(),
    val_dataloader=estim.datamodule.val_dataloader(),
    test_dataloader=estim.datamodule.test_dataloader(),
)

Information for the complete dataset:


In [None]:
# Per-dataset statistics for a few reference datasets (name -> dataset_id).
# NOTE(review): each print_dataset_info call re-reads all three splits in
# full, so this loop costs one complete pass over the store per dataset.
dataset_ids = {
    'HLCA': 148,
    'Tabula Sapiens': 87,
    'PBMC': 41,
}
for name, dataset_id in dataset_ids.items():
    print_dataset_info(
        estim.datamodule.train_dataloader(),
        estim.datamodule.val_dataloader(),
        estim.datamodule.test_dataloader(),
        dataset_id=dataset_id,
        dataset_name=name,
    )