## Useful links

- DistilBertModel [[class]](https://github.com/huggingface/transformers/blob/06a6a4bd516f7d0ba7c4966a2d3d9c0bf07797ae/src/transformers/models/distilbert/modeling_distilbert.py#L459) [[forward]](https://github.com/huggingface/transformers/blob/06a6a4bd516f7d0ba7c4966a2d3d9c0bf07797ae/src/transformers/models/distilbert/modeling_distilbert.py#L538)
- [anndata x PyTorch](https://anndata-tutorials.readthedocs.io/en/latest/annloader.html)

In [2]:
%load_ext autoreload
%autoreload 2

In [15]:
import torch
LOCAL = not torch.cuda.is_available()
cache_dir=None if LOCAL else "/om2/user/rogerjin/.cache"
local_atac_path = '/home/rogerjin/Dropbox/Research/Kellis/masters/data/neurips2021/multiome_atac_processed_training_small.h5ad'
remote_atac_path = '/om2/user/rogerjin/data/NeurIPS2021/multiome/multiome_atac_processed_training_small.h5ad'
atac_path = local_atac_path if LOCAL else remote_atac_path
print("local" if LOCAL else "remote")

local


In [128]:
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained('distilbert-base-uncased', cache_dir=cache_dir)
model.embeddings.word_embeddings = torch.nn.Embedding(116491, 768)

device = 'cpu'
# device = 'cuda:0'
_ = model.to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
from squish_indexing import squish_and_embed

atac = torch.tensor([[1, 10, 0, 0, 0, 100], [0, 0, 100, 0, 1, 0]])
# manually checked that the embedding looks correct
atac_embed = squish_and_embed(atac, model.embeddings.word_embeddings)
atac_embed

tensor([[[ 1.2857e+00, -9.5825e-01, -8.6744e-03,  ..., -1.3337e-01,
          -8.8122e-01, -9.9369e-01],
         [ 1.1567e+00, -1.0683e+01, -1.6499e+01,  ...,  1.2961e+01,
           1.1456e+01,  2.2055e+01],
         [ 2.2540e+01, -8.0246e+00, -2.9985e+01,  ...,  3.1569e+01,
          -1.1161e+02, -3.2046e+01]],

        [[ 1.3465e+02,  7.3956e+00,  3.8956e+01,  ..., -1.7088e+02,
           1.1801e+01,  7.1604e+01],
         [ 4.5243e-01,  6.9270e-01, -1.2919e+00,  ..., -1.8618e+00,
          -2.6889e-02, -9.6906e-01],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00,  0.0000e+00]]], grad_fn=<MulBackward0>)

In [18]:
import pandas as pd
import scanpy as sc
atac = sc.read_h5ad(atac_path)

In [46]:
from anndata.experimental.pytorch import AnnLoader
from squish_indexing import squish_and_embed
from scipy.sparse import issparse

transform = {
    'layers': {
        # 'counts': lambda counts_vec: squish_and_embed(torch.tensor(counts_vec), model.embeddings.word_embeddings)
        'counts': lambda counts_vec: counts_vec
    }
}

dataloader = AnnLoader(atac, batch_size=8, shuffle=True, convert=transform, use_cuda=False, use_default_converter=False)

for batch in dataloader:
    display(batch)
    print(batch.layers['counts'].data)
    print(batch.layers['counts'].indices)
    print(batch.layers['counts'].indptr)
    print('coo')
    coo = batch.layers['counts'].tocsr().tocoo()
    print(batch.layers['counts'].tocoo().data)
    print(batch.layers['counts'].tocoo().row.shape)
    print(batch.layers['counts'].tocoo().col.shape)
    break

AnnCollectionView object with n_obs × n_vars = 8 × 116490
    obsm: 'gene_activity', 'lsi_full', 'lsi_red', 'umap'
    layers: 'counts'
    obs: 'nCount_peaks', 'atac_fragments', 'reads_in_peaks_frac', 'blacklist_fraction', 'nucleosome_signal', 'cell_type', 'pseudotime_order_ATAC', 'batch', 'pseudotime_order_GEX', 'is_train'

[4. 2. 2. ... 2. 1. 2.]
[5 7 3 ... 0 3 0]
[    0     0     0 ... 23013 23013 23013]
coo
[4. 2. 2. ... 2. 1. 2.]
(23013,)
(23013,)


In [129]:
import functorch

DEFAULT_ARANGE_LEN = 10
INDEX_PAD_TOKEN = 116490
COUNT_PAD_TOKEN = 0
BATCH_SIZE = 8

def cyclic_arange(lengths, arange=torch.arange(DEFAULT_ARANGE_LEN)):
    max_len = lengths.max().item()
    if max_len > DEFAULT_ARANGE_LEN:
        arange = torch.arange(max_len)
    arange_vmap = functorch.vmap(lambda length: (arange + 1) * (arange < length))
    vmapped = arange_vmap(lengths).flatten()
    return vmapped[vmapped.nonzero()].flatten() - 1

def index_and_pad(indices, data, pad_token):
    return torch.sparse_coo_tensor(indices, data - pad_token).to_dense() + pad_token

def squish_and_embed(batch_coo, embedding):
    rows = torch.tensor(batch_coo.row)
    _, num_nonzeros = torch.unique_consecutive(rows, return_counts=True)
    assert _.shape[0] == BATCH_SIZE # otherwise there's an all-0 sequence in the batch
    max_seq_len = num_nonzeros.max().item()
    sparse_indices = torch.stack([rows, cyclic_arange(num_nonzeros)])
    squish_indices = index_and_pad(sparse_indices, batch_coo.col, INDEX_PAD_TOKEN)
    counts = index_and_pad(sparse_indices, batch_coo.data, COUNT_PAD_TOKEN)
    display(squish_indices)
    display(counts)
    return {
        'indices': squish_indices,
        'counts': counts,
        'squish_embeddings': embedding(squish_indices) * counts.unsqueeze(-1)
    }

In [130]:
output, counts = torch.unique_consecutive(torch.tensor(coo.row), return_counts=True)
PAD_TOKEN = 0
seq_len = counts.max().item()
num_rows = output.shape[0]
indices = torch.stack([torch.tensor(coo.row), cyclic_arange(torch.tensor(counts))])
display(indices)
count_tensor = torch.sparse_coo_tensor(indices, coo.data - PAD_TOKEN).to_dense() + PAD_TOKEN
index_tensor = torch.sparse_coo_tensor(indices, coo.col - PAD_TOKEN).to_dense() + PAD_TOKEN
display(output)
display(counts)
display(seq_len)
display(count_tensor)
display(index_tensor)

  indices = torch.stack([torch.tensor(coo.row), cyclic_arange(torch.tensor(counts))])


tensor([[   0,    0,    0,  ...,    7,    7,    7],
        [   0,    1,    2,  ..., 3215, 3216, 3217]])

tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)

tensor([3596, 2013, 1309, 2753, 3675, 5674,  775, 3218])

5674

tensor([[ 6.,  2.,  2.,  ...,  0.,  0.,  0.],
        [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
        [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
        ...,
        [ 4., 18.,  2.,  ...,  2.,  2.,  2.],
        [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
        [ 2.,  4.,  4.,  ...,  0.,  0.,  0.]], dtype=torch.float64)

tensor([[     6,     21,     43,  ...,      0,      0,      0],
        [     6,     46,     52,  ...,      0,      0,      0],
        [   117,    191,    196,  ...,      0,      0,      0],
        ...,
        [     5,      6,     21,  ..., 116446, 116448, 116457],
        [     6,     10,     72,  ...,      0,      0,      0],
        [     5,      6,     86,  ...,      0,      0,      0]],
       dtype=torch.int32)

In [131]:
squish_and_embed(coo, model.embeddings.word_embeddings)

tensor([[     6,     21,     43,  ..., 116490, 116490, 116490],
        [     6,     46,     52,  ..., 116490, 116490, 116490],
        [   117,    191,    196,  ..., 116490, 116490, 116490],
        ...,
        [     5,      6,     21,  ..., 116446, 116448, 116457],
        [     6,     10,     72,  ..., 116490, 116490, 116490],
        [     5,      6,     86,  ..., 116490, 116490, 116490]],
       dtype=torch.int32)

tensor([[ 6.,  2.,  2.,  ...,  0.,  0.,  0.],
        [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
        [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
        ...,
        [ 4., 18.,  2.,  ...,  2.,  2.,  2.],
        [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
        [ 2.,  4.,  4.,  ...,  0.,  0.,  0.]], dtype=torch.float64)

{'indices': tensor([[     6,     21,     43,  ..., 116490, 116490, 116490],
         [     6,     46,     52,  ..., 116490, 116490, 116490],
         [   117,    191,    196,  ..., 116490, 116490, 116490],
         ...,
         [     5,      6,     21,  ..., 116446, 116448, 116457],
         [     6,     10,     72,  ..., 116490, 116490, 116490],
         [     5,      6,     86,  ..., 116490, 116490, 116490]],
        dtype=torch.int32),
 'counts': tensor([[ 6.,  2.,  2.,  ...,  0.,  0.,  0.],
         [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
         [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
         ...,
         [ 4., 18.,  2.,  ...,  2.,  2.,  2.],
         [ 2.,  2.,  2.,  ...,  0.,  0.,  0.],
         [ 2.,  4.,  4.,  ...,  0.,  0.,  0.]], dtype=torch.float64),
 'squish_embeddings': tensor([[[ 2.8917e+00, -8.3627e+00, -3.2458e+00,  ..., -3.7075e+00,
           -8.5471e+00, -3.6615e+00],
          [-3.3926e+00,  7.1447e-01,  2.6079e+00,  ..., -2.4539e-01,
           -2.2277e+00, -9.100