In [1]:
%load_ext autoreload

In [2]:
# Import necessary libraries and modules
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import os
import dask.dataframe as dd
import pickle
import numpy as np
import pandas as pd
from pathlib import Path

In [3]:
%autoreload 2
from self_supervision.estimator.cellnet import EstimatorAutoEncoder
from self_supervision.paths import DATA_DIR, TRAINING_FOLDER

In [4]:
# Run configuration, gathered in one place.
# Keys read by the cells below: data_path, hvg, batch_size, checkpoint_interval,
# model, lr, weight_decay, dropout, masking_strategy, mask_rate, hidden_units,
# decoder.
# NOTE(review): gp_file, mask_type, version, num_hvgs, missing_tolerance and
# model_path are not read by any visible cell — confirm they are still needed.
LARGE_PARAMS = dict(
    decoder=False,            # False -> init_model receives units_decoder=[]
    model="VAE",              # any value other than "MLP" selects 'mlp_vae'
    mask_rate=0.5,            # passed to the model as masking_rate
    masking_strategy="random",
    gp_file="C5",
    weight_decay=0.0,
    dropout=0.0,
    batch_size=8,             # NOTE(review): very small; confirm before a real training run
    mask_type="sparsemax",
    version="",
    lr=0.1,                   # passed to the model as learning_rate
    hidden_units=[512, 512, 256, 256, 64],  # encoder layer widths
    checkpoint_interval=1,    # every_n_epochs of the best-loss ModelCheckpoints
    hvg=False,
    num_hvgs=2000,
    missing_tolerance=0,
    data_path=os.path.join(DATA_DIR, "merlin_cxg_2023_05_15_sf-log1p"),
    model_path=TRAINING_FOLDER,
)


In [5]:
# Parent directory of the kernel's current working directory.
# NOTE(review): `root` is not referenced by any visible cell — confirm it is still needed.
root = os.path.abspath(os.path.join(os.curdir, os.pardir))
# Lightning's default_root_dir and the TensorBoard logger both use this path,
# i.e. outputs land relative to the working directory.
CHECKPOINT_PATH = "."

In [6]:
# Estimator over the configured dataset; the hvg flag comes from the config.
estim = EstimatorAutoEncoder(
    data_path=LARGE_PARAMS["data_path"],
    hvg=LARGE_PARAMS["hvg"],
)

# Datamodule using the configured batch size.
estim.init_datamodule(batch_size=LARGE_PARAMS["batch_size"])

# Trainer: one GPU, gradient-norm clipping at 1.0, validation every epoch,
# TensorBoard logging and default_root_dir both pointed at CHECKPOINT_PATH.
estim.init_trainer(
    trainer_kwargs=dict(
        max_epochs=1000,
        gradient_clip_val=1.0,
        gradient_clip_algorithm='norm',
        default_root_dir=CHECKPOINT_PATH,
        accelerator='gpu',
        devices=1,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=1,
        logger=[TensorBoardLogger(CHECKPOINT_PATH, name='default')],
        log_every_n_steps=100,
        detect_anomaly=False,
        enable_progress_bar=True,
        enable_model_summary=False,
        enable_checkpointing=True,
        callbacks=[
            TQDMProgressBar(refresh_rate=300),
            LearningRateMonitor(logging_interval='step'),
            # keep the single checkpoint with the lowest epoch-level training loss
            ModelCheckpoint(
                filename='best_checkpoint_train',
                monitor='train_loss_epoch',
                mode='min',
                every_n_epochs=LARGE_PARAMS["checkpoint_interval"],
                save_top_k=1,
            ),
            # keep the single checkpoint with the lowest validation loss
            ModelCheckpoint(
                filename='best_checkpoint_val',
                monitor='val_loss',
                mode='min',
                every_n_epochs=LARGE_PARAMS["checkpoint_interval"],
                save_top_k=1,
            ),
            # no monitored metric: just the most recent checkpoint
            ModelCheckpoint(filename='last_checkpoint', monitor=None),
        ],
    )
)

# Model: "MLP" maps to 'mlp_ae', every other value to 'mlp_vae'. With
# decoder=False no decoder widths are passed; otherwise the decoder mirrors the
# encoder widths without the encoder's last layer.
estim.init_model(
    model_type={'MLP': 'mlp_ae'}.get(LARGE_PARAMS["model"], 'mlp_vae'),
    model_kwargs=dict(
        learning_rate=LARGE_PARAMS["lr"],
        weight_decay=LARGE_PARAMS["weight_decay"],
        dropout=LARGE_PARAMS["dropout"],
        lr_scheduler=torch.optim.lr_scheduler.StepLR,
        # multiply the LR by 0.9 every 2 scheduler steps.
        # NOTE(review): StepLR's `verbose` argument is deprecated in recent torch
        # releases; LearningRateMonitor above already logs the LR.
        lr_scheduler_kwargs=dict(step_size=2, gamma=0.9, verbose=True),
        masking_strategy=LARGE_PARAMS["masking_strategy"],
        masking_rate=LARGE_PARAMS["mask_rate"],
        # encoded_gene_program=encoded_gene_program,  (disabled)
        units_encoder=LARGE_PARAMS["hidden_units"],
        units_decoder=LARGE_PARAMS["hidden_units"][:-1][::-1] if LARGE_PARAMS["decoder"] else [],
    ),
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Training DataLoader of the estimator's datamodule (configured via
# init_datamodule in the setup cell); used below only to inspect one batch.
dataloader = estim.datamodule.train_dataloader()

In [8]:
# Each item the loader yields unpacks into exactly two parts; the first is the
# feature dict with 'X', 'cell_type' and 'dataset_id' (see the output below).
# Take only the first item instead of looping and breaking.
batch_features, _batch_rest = next(iter(dataloader))
print('batch: ', batch_features)

batch:  {'X': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'), 'cell_type': tensor([ 12,  50, 124,  10, 162, 129, 151, 122], device='cuda:0'), 'dataset_id': tensor([ 35, 220, 168,   3, 124, 147, 160,  41], device='cuda:0')}


In [9]:
def get_count_matrix_and_obs(ddf, n_genes=19331, partition_lengths=None):
    """Split a Dask dataframe of cells into a lazy count matrix and cell metadata.

    Parameters
    ----------
    ddf : dask.dataframe.DataFrame
        One row per cell. Column ``'X'`` holds that cell's vector of
        ``n_genes`` values; ``'cell_type'`` and ``'dataset_id'`` hold its labels.
    n_genes : int, default 19331
        Length of every ``'X'`` vector; used to build the Dask ``meta``.
    partition_lengths : sequence of int, optional
        Row count of each partition of ``ddf``. If omitted it is measured on the
        ``'cell_type'`` column, so partitions of any size (e.g. a shorter
        trailing row group) get correct chunk sizes.

    Returns
    -------
    x : dask.array.Array
        Lazy float32 array of shape ``(n_cells, n_genes)`` with known chunks.
    obs : pandas.DataFrame
        The ``'cell_type'`` and ``'dataset_id'`` columns, computed into memory.

    Raises
    ------
    ValueError
        If ``partition_lengths`` does not have one entry per partition.
    """
    def _stack_partition(x_col):
        # One partition of per-cell vectors -> (rows, n_genes) frame whose dtype
        # matches the declared 'f4' meta; np.vstack raises on an empty list.
        if len(x_col) == 0:
            return pd.DataFrame(np.empty((0, n_genes), dtype=np.float32))
        return pd.DataFrame(np.vstack(x_col.tolist()).astype(np.float32, copy=False))

    if partition_lengths is None:
        # Nothing guarantees every partition holds exactly 1024 rows, so measure
        # them on a cheap column rather than assuming a fixed size.
        partition_lengths = ddf['cell_type'].map_partitions(len).compute()
    partition_lengths = tuple(int(n) for n in partition_lengths)
    if len(partition_lengths) != ddf.npartitions:
        raise ValueError(
            f"expected {ddf.npartitions} partition lengths, got {len(partition_lengths)}"
        )

    x = (
        ddf['X']
        .map_partitions(_stack_partition, meta={col: 'f4' for col in range(n_genes)})
        .to_dask_array(lengths=partition_lengths)
    )
    obs = ddf[['cell_type', 'dataset_id']].compute()

    return x, obs

In [None]:
# Root of the merlin dataset (holds the 'train' split, var.parquet and
# categorical_lookup/). Reuse the config value so the path is defined once.
PATH = LARGE_PARAMS["data_path"]

In [11]:
# Training split; with split_row_groups=True dask builds one partition per
# parquet row group.
ddf = dd.read_parquet(
    path=os.path.join(PATH, 'train'),
    split_row_groups=True,
)
# x stays lazy (dask array); obs ('cell_type', 'dataset_id') is computed.
x, obs = get_count_matrix_and_obs(ddf)
# Per-gene table stored next to the splits.
var = pd.read_parquet(path=os.path.join(PATH, 'var.parquet'))

In [30]:
# Dataset UUIDs (presumably CELLxGENE dataset ids) of the reference datasets.
# Choose which one to look up by name instead of overwriting target_id.
TARGET_DATASET_IDS = {
    "HLCA": "9f222629-9e39-47d0-b83f-e08d610c7479",
    "Tabula Sapiens": "53d208b0-2cfd-4366-9866-c3c6114081bc",
    "PBMC": "2a498ace-872a-4935-984b-1afa70fd9886",
}
TARGET_NAME = "PBMC"
target_id = TARGET_DATASET_IDS[TARGET_NAME]

In [31]:
# Categorical lookup for dataset_id: has a 'label' column; the integer code is
# read from its index in the next cell.
dataset_id_mapping = pd.read_parquet(
    os.path.join(PATH, 'categorical_lookup', 'dataset_id.parquet')
)

In [32]:
# Rows of the lookup whose 'label' equals the target dataset UUID; the integer
# code is taken from the table's index.
result = dataset_id_mapping[dataset_id_mapping['label'] == target_id]

if result.empty:
    # Keep None as the "not found" value so later cells can test for it.
    corresponding_int = None
    print(f'no dataset_id code found for label {target_id!r}')
else:
    if len(result) > 1:
        # Duplicate labels would make the code ambiguous; say so instead of
        # silently taking the first row.
        print(f'warning: {len(result)} rows match label {target_id!r}; using the first')
    corresponding_int = result.index[0]
    print('corresponding int is: ', corresponding_int)


corresponding int is:  41


The `dataset_id` of the HLCA is encoded as 148.

The `dataset_id` of Tabula Sapiens is encoded as 87.

The `dataset_id` of PBMC is encoded as 41.