In [63]:
import torch
from transformers import DistilBertModel
from typing import Dict, List, Optional, Set, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput

class SquishTransformer(DistilBertModel):
    """DistilBERT variant with its own ``forward`` implementation.

    The override validates the inputs, fills in default masks, embeds
    ``input_ids`` when no embeddings were supplied, and delegates to
    ``self.transformer``.
    """

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
        """Run the transformer stack on token ids or pre-computed embeddings.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        ``inputs_embeds`` is handed to ``self.transformer`` as-is; only
        ``input_ids`` are routed through ``self.embeddings``.

        Args:
            input_ids: Token indices; embedded via ``self.embeddings``.
            attention_mask: Mask forwarded to the transformer as ``attn_mask``.
                Defaults to all ones with the input's leading shape.
            head_mask: Passed through ``self.get_head_mask`` before use.
            inputs_embeds: Already-embedded inputs (last axis is the feature axis).
            output_attentions: Falls back to ``self.config.output_attentions``.
            output_hidden_states: Falls back to ``self.config.output_hidden_states``.
            return_dict: Falls back to ``self.config.use_return_dict``.

        Returns:
            Whatever ``self.transformer`` returns for the given flags.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
        """
        # Resolve unset flags from the model configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Input validation: exactly one source of inputs is allowed.
        got_ids = input_ids is not None
        got_embeds = inputs_embeds is not None
        if got_ids and got_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not (got_ids or got_embeds):
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Derive the mask shape and target device from whichever input was given.
        if got_ids:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]  # drop the feature axis
            device = inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if got_ids:
            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)

        encoder_kwargs = dict(
            x=inputs_embeds,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.transformer(**encoder_kwargs)
    

# Directory where the pretrained checkpoint is cached (machine-specific path).
CACHE_DIR = "/om2/user/rogerjin/.cache"
# Number of rows in the replacement embedding table; equals n_vars of the
# ATAC AnnData loaded further down in this notebook (64 × 116490).
N_PEAKS = 116490

model = SquishTransformer.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)

# Swap the pretrained word-embedding table for a freshly initialised one with
# N_PEAKS rows. The width is read from the model config rather than
# hard-coded so it always matches the hidden size of the loaded checkpoint.
model.embeddings.word_embeddings = torch.nn.Embedding(N_PEAKS, model.config.dim)

device = 'cpu'
# device = 'cuda:0'
# Assign the result so the cell does not echo the module repr.
_ = model.to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing SquishTransformer: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing SquishTransformer from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SquishTransformer from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [64]:
# Smoke test: one sequence of two token ids through the model.
model(input_ids=torch.tensor([[1, 3]]))

hi


BaseModelOutput(last_hidden_state=tensor([[[ 0.4954,  0.3014,  0.2405,  ..., -0.3716, -0.0344, -0.1476],
         [ 0.5129,  0.5270,  0.1829,  ..., -0.4150,  0.0398, -0.0783]]],
       grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None)

In [60]:
import scanpy as sc
import anndata as ad

# Configure parallelism through scanpy's public settings instance so the
# value goes through the settings object's own attribute handling.
sc.settings.n_jobs = 4

# Small ATAC training subset (machine-specific absolute path).
atac_path = '/om2/user/rogerjin/data/NeurIPS2021/multiome/multiome_atac_processed_training_small.h5ad'
atac = sc.read_h5ad(atac_path)
# Last expression: display the AnnData summary.
atac

AnnData object with n_obs × n_vars = 64 × 116490
    obs: 'nCount_peaks', 'atac_fragments', 'reads_in_peaks_frac', 'blacklist_fraction', 'nucleosome_signal', 'cell_type', 'pseudotime_order_ATAC', 'batch', 'pseudotime_order_GEX', 'is_train'
    var: 'feature_types'
    uns: 'dataset_id', 'gene_activity_var_names', 'organism', 'sample_pm_varnames'
    obsm: 'gene_activity', 'lsi_full', 'lsi_red', 'umap'
    layers: 'counts'

In [61]:
from anndata.experimental.pytorch import AnnLoader
from squish_indexing import squish_and_embed

# Wrap the ATAC AnnData in a shuffling loader that yields batches of 8 on CPU.
dataloader = AnnLoader(atac, batch_size=8, shuffle=True, use_cuda=False)

# Pull the 'counts' layer out of each batch. Nothing is done with it yet, so
# after the loop `counts` holds only the last batch's values.
# NOTE(review): `squish_and_embed` is imported above but not called here —
# presumably the loop body is still to be written; confirm.
for batch in dataloader:
    counts = batch.layers['counts']

TypeError: squish_and_embed() missing 1 required positional argument: 'embeddings'

In [None]:
from transformers import 