In [20]:
import os
import json
import pprint

# Location of the trained checkpoint, and the cache directory later passed to
# `from_pretrained` when the pretrained DistilBERT weights are fetched.
checkpoint_dir = '/om2/user/rogerjin/checkpoints/'
checkpoint_path = f'{checkpoint_dir}/super-wind-18/epoch=1-val_loss=12.54945.pt'
cache_dir = '/om2/user/rogerjin/.cache'

# Show which run configurations exist, then load the one used in this notebook.
config_dir = '/om2/user/rogerjin/GANOLI/configs'
print(os.listdir(config_dir))
config_path = f'{config_dir}/default.json'
print(config_path)
# Read through a context manager so the file handle is closed right away
# instead of lingering until the garbage collector gets to it.
with open(config_path) as config_file:
    config = json.load(config_file)
pprint.pprint(config)

['binary.json', 'binary_highlr.json', 'binary_relu.json', 'cls.json', 'cls_highlr.json', 'cls_veryhighlr.json', 'default.json', 'num_workers.json', 'relu.json', 'test.json', 'test_config.py']
/om2/user/rogerjin/GANOLI/configs/default.json
{'batch_size': 32,
 'epochs': 1000,
 'lr': 0.0005,
 'max_seq_len': 1600,
 'run_name': 'default'}


In [5]:
import torch

# NOTE(review): torch.load unpickles the file, so only point this at
# checkpoints from a trusted source.
# map_location='cpu' restores every stored tensor (model weights as well as
# the optimizer state) into host memory, so loading works regardless of which
# device the checkpoint was saved from. The model is moved to its target
# device separately, after its weights have been restored.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'train/loss', 'train/best_loss', 'train/best_loss_epoch', 'val/loss', 'val/best_loss', 'val/best_loss_epoch'])

In [9]:
from transformers import DistilBertModel, DistilBertConfig
from torch.nn import Linear

class SquishTransformer(torch.nn.Module):
    """DistilBERT encoder with a replaced token-embedding table and a linear head.

    The pretrained ``distilbert-base-uncased`` encoder is loaded, its word
    embedding table is swapped for a freshly initialised one with
    ``vocab_size`` rows, and two linear layers map the hidden state at
    sequence position 0 to ``output_dim`` values.

    Args:
        output_dim: Size of the output vector produced for each input sequence.
        vocab_size: Number of rows in the replacement embedding table.
            The default (116491) is presumably the number of ATAC features
            plus one extra index -- TODO confirm.

    Note:
        Reads the module-level ``cache_dir`` as the download cache for the
        pretrained weights.
    """

    def __init__(self, output_dim=13431, vocab_size=116491):
        super().__init__()
        self.output_dim = output_dim
        self.vocab_size = vocab_size
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased', cache_dir=cache_dir)
        # Take the hidden width from the loaded config instead of hard-coding
        # it, so the embedding table always matches the encoder.
        hidden_dim = self.distilbert.config.dim
        self.distilbert.embeddings.word_embeddings = torch.nn.Embedding(vocab_size, hidden_dim)
        self.pre_classifier = Linear(hidden_dim, hidden_dim)
        self.classifier = Linear(hidden_dim, output_dim)

    def forward(self, **kwargs):
        """Encode the inputs and return the head's output for each sequence.

        Args:
            **kwargs: Forwarded unchanged to the wrapped DistilBERT model.

        Returns:
            The result of applying ``pre_classifier`` and then ``classifier``
            to the hidden state at sequence position 0.
        """
        out = self.distilbert(**kwargs).last_hidden_state[:, 0]  # embedding of cls
        # NOTE(review): there is no activation between the two linear layers;
        # left as-is so weights trained with this architecture still apply.
        out = self.pre_classifier(out)
        out = self.classifier(out)
        return out

# Build the network and restore the trained weights from the checkpoint.
model = SquishTransformer()
model.load_state_dict(checkpoint['model_state_dict'])
# Use the first GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs instead of failing on a hard-coded 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# The model is used for inference in the cells shown here, so switch it to
# eval mode explicitly. `.eval()` returns the module itself, so `_` still
# only serves to suppress the cell's repr output.
_ = model.to(device).eval()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
import scanpy as sc

remote_atac_dir = '/om2/user/rogerjin/data/NeurIPS2021/multiome/atac'
remote_rna_dir = '/om2/user/rogerjin/data/NeurIPS2021/multiome/rna'
remote_atac_path = '/om2/user/rogerjin/data/NeurIPS2021/multiome/multiome_atac_processed_training.h5ad'
remote_rna_path = '/om2/user/rogerjin/data/NeurIPS2021/multiome/multiome_gex_processed_training.h5ad'

# Partitions to read from disk. The 'train' files follow the same naming
# pattern; add 'train' to this tuple to load them as well.
eval_partitions = ('val', 'test')

# One AnnData table per partition and modality, keyed by partition name.
atac = {
    partition: sc.read_h5ad(f'{remote_atac_dir}/atac_{partition}_sorted_decreasing_variance.h5ad')
    for partition in eval_partitions
}

rna = {
    partition: sc.read_h5ad(f'{remote_rna_dir}/rna_{partition}.h5ad')
    for partition in eval_partitions
}

In [12]:
from ganoli.GanoliDataset import GanoliMultimodalDataset
from muon import MuData

class MuDataWithLen(MuData):
    """MuData subclass that supports ``len()``.

    The length is the smallest ``len()`` among the contained modalities. It is
    computed on first use and cached on the instance as ``_len``, so it is not
    refreshed if the modalities change afterwards.
    """

    def __len__(self):
        """Return the cached length, computing it on the first call."""
        try:
            return self._len
        except AttributeError:
            # Only a missing cache should trigger the computation; any other
            # exception (including KeyboardInterrupt) must propagate.
            # default=0 makes a container without modalities report length 0.
            self._len = min((len(mod) for mod in self.mod.values()), default=0)
            return self._len

# Pair the ATAC and RNA tables of every loaded partition into one
# length-aware MuData container, keyed by partition name.
datasets = {
    name: MuDataWithLen({'atac': atac_table, 'rna': rna[name]})
    for name, atac_table in atac.items()
}
datasets

{'val': MuData object with n_obs × n_vars = 4249 × 129921
   var:	'feature_types'
   2 modalities
     atac:	4249 x 116490
       obs:	'nCount_peaks', 'atac_fragments', 'reads_in_peaks_frac', 'blacklist_fraction', 'nucleosome_signal', 'cell_type', 'pseudotime_order_ATAC', 'batch', 'pseudotime_order_GEX', 'is_train'
       var:	'feature_types'
       uns:	'dataset_id', 'gene_activity_var_names', 'organism', 'sample_pm_varnames'
       obsm:	'gene_activity', 'lsi_full', 'lsi_red', 'umap'
       layers:	'counts'
     rna:	4249 x 13431
       obs:	'pct_counts_mt', 'n_counts', 'n_genes', 'size_factors', 'phase', 'cell_type', 'pseudotime_order_GEX', 'batch', 'pseudotime_order_ATAC', 'is_train'
       var:	'gene_ids', 'feature_types', 'genome'
       uns:	'dataset_id', 'organism'
       obsm:	'X_pca', 'X_umap'
       layers:	'counts',
 'test': MuData object with n_obs × n_vars = 4250 × 129921
   var:	'feature_types'
   2 modalities
     atac:	4250 x 116490
       obs:	'nCount_peaks', 'atac_fr

In [21]:
from muon import MuData as md
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, RandomSampler
torch.manual_seed(42)

# Index-sampler class used for each partition.
samplers = dict(train=RandomSampler, val=SequentialSampler, test=SequentialSampler)

# todo: increase val/test batch size


def _unwrap_single(items):
    """Collate function: return the first element of the list it is handed."""
    return items[0]


def _make_loader(partition, dataset):
    """Build the DataLoader for one partition.

    A BatchSampler is passed as the DataLoader's `sampler`, so each "index"
    the loader requests from the dataset is a whole list of row indices of
    size `config['batch_size']` (the last list may be shorter).
    """
    index_batches = BatchSampler(
        samplers[partition](dataset),
        batch_size=config['batch_size'],
        drop_last=False,
    )
    return DataLoader(dataset, sampler=index_batches, collate_fn=_unwrap_single)


loaders = {
    partition: _make_loader(partition, dataset)
    for partition, dataset in datasets.items()
}

In [24]:
from squish_indexing import squish_and_embed

def forward_pass(batch, use_binary=config.get('use_binary', False)):
    """Run the global `model` on one batch and return its output.

    Args:
        batch: Object exposing `.mod['atac']`. Its `.X` or `.layers['counts']`
            must support `.tocsr()` (presumably a scipy sparse matrix --
            TODO confirm).
        use_binary: When true the ATAC `.X` matrix is used, otherwise the
            `'counts'` layer. The default is read from `config` once, when
            this function is defined, not on every call.

    Returns:
        Whatever `model` returns for the embedded batch.
    """
    atac_mod = batch.mod['atac']
    source_matrix = atac_mod.X if use_binary else atac_mod.layers['counts']
    atac_coo = source_matrix.tocsr().tocoo()

    # Embed with the model's own word-embedding table.
    word_embeddings = model.distilbert.embeddings.word_embeddings
    squished = squish_and_embed(atac_coo, word_embeddings, max_seq_len=config['max_seq_len'])

    return model(
        inputs_embeds=squished['embeddings'],
        attention_mask=squished['attention_mask'],
    )

In [26]:
# Smoke test: push only the first validation batch through the model, with
# gradient tracking disabled, and display the raw output.
with torch.no_grad():
    for batch in loaders['val']:
        display(forward_pass(batch))
        break

tensor([[-0.0049,  0.0184,  0.0211,  ...,  0.7366,  0.0163,  0.0040],
        [-0.0043,  0.0189,  0.0208,  ...,  0.7265,  0.0169,  0.0040],
        [ 0.0062,  0.0396,  0.0326,  ...,  0.4454,  0.0350,  0.0051],
        ...,
        [-0.0050,  0.0167,  0.0216,  ...,  0.7553,  0.0147,  0.0043],
        [-0.0050,  0.0147,  0.0219,  ...,  0.7745,  0.0129,  0.0034],
        [-0.0048,  0.0170,  0.0216,  ...,  0.7509,  0.0148,  0.0037]],
       device='cuda:0')