In [2]:
import hydra
from omegaconf import OmegaConf
from glob import  glob
# from rosa import  predict
from rosa.data import create_io_paths, RosaDataModule
from rosa.modeling.modules import RosaLightningModule
from pytorch_lightning import Trainer


BASE_DIR = "/Users/nsofroniew/Documents/data/rosa/outputs/2023-02-12/21-58-13"
config_dir = BASE_DIR + "/.hydra"

with hydra.initialize_config_dir(config_dir=config_dir):
    config = hydra.compose(config_name="config", overrides=OmegaConf.load(config_dir + "/overrides.yaml"))

    chkpts = BASE_DIR + "/checkpoints/*.ckpt"
    chkpt = glob(chkpts)[1]

    _, output_path = create_io_paths(config.paths)

    # Create Data Module
    rdm = RosaDataModule(
        output_path,
        config=config.data_module,
    )
    rdm.setup()

    # Load model from checkpoint
    rlm = RosaLightningModule.load_from_checkpoint(
        chkpt,
        in_dim=rdm.len_input,
        out_dim=rdm.len_target,
        config=config.module,
    )
    print(rlm)

    trainer = Trainer()
    predictions = trainer.predict(rlm, rdm)
    rdm.predict_dataset.predict(predictions)
    adata = rdm.predict_dataset.adata
    # adata = predict(config, chkpt)

display(adata)

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize_config_dir(config_dir=config_dir):


RosaLightningModule(
  (model): RosaSingleModel(
    (main): Sequential(
      (layer_norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
      (input_embed): Identity()
      (feed_forward): Identity()
      (dropout): Dropout(p=0.5, inplace=False)
      (expression_head): ProjectionExpressionHead(
        (model): Sequential(
          (projection): Linear(in_features=3072, out_features=369, bias=True)
          (activation): Softplus(beta=1, threshold=20)
        )
      )
    )
  )
)


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /Users/nsofroniew/Documents/GitHub/rosa/notebooks/lightning_logs
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

AnnData object with n_obs × n_vars = 369 × 19431
    obs: 'dataset_id', 'cell_type', 'cell_type_ontology_term_id', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_id', 'is_primary_data', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'label', 'sample', 'n_genes', 'train', 'marker_gene', 'marker_feature_name'
    var: 'soma_joinid', 'feature_name', 'feature_length', 'column_1', 'column_2', 'column_3', 'column_4', 'external_gene_name', 'gene_biotype', 'train', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'dendrogram_label', 'hvg', 'log1p', 'obs_embedding_pca', 'preprocessing', 'rank_genes_groups', 'var_embedding_pca'
    obsm: 'bin_edges', 'embedding'
    varm: 'embedding', 'embedding_pca'
    layers: 'binned', 'counts', 'log1p', 'normalized_counts', 'prediction'

In [3]:
train_cells = adata.obs['train']
train_genes = adata.var['train']

In [4]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [5]:
# imports from captum library
from captum.attr import IntegratedGradients, InputXGradient

In [6]:
# ig = IntegratedGradients(rlm)
ixg = InputXGradient(rlm)

In [7]:
test_input_tensor = rdm.predict_dataset[:][0]
target_tensor = rdm.predict_dataset[:][1]

In [8]:
rlm.forward(test_input_tensor).shape

torch.Size([19431, 369])

In [9]:
import zarr

results_shape = (len(rdm.predict_dataset), 50, rdm.len_input) # rdm.len_target
# results_shape = rdm.predict_dataset.adata.shape + (rdm.len_input,)
z = zarr.zeros(results_shape, chunks=(1, None, None), dtype=np.float32)

ind = 0
for x, y in tqdm(iter(rdm.predict_dataloader(batch_size=1)), leave=False):
    for target in tqdm(range(2)):
        x.requires_grad_()
        attribution = ixg.attribute(x, target=target)
        attribution = attribution.detach().numpy()
        # z[target, ind:ind+len(x), :] = attribution # for gene dataset
        z[ind:ind+len(x), target, :] = attribution # for cell dataset
        #### for joint have to do something more clever ...... maybe swap iterators
    ind += len(x)

100%|██████████| 2/2 [00:04<00:00,  2.29s/it]
100%|██████████| 2/2 [00:04<00:00,  2.20s/it]
100%|██████████| 2/2 [00:04<00:00,  2.26s/it]
100%|██████████| 2/2 [00:04<00:00,  2.24s/it]
100%|██████████| 2/2 [00:03<00:00,  1.67s/it]
                                             

In [12]:
z

<zarr.core.Array (19431, 50, 3072) float32>

In [None]:
full_attrs = []
n_targets = 20 #target_tensor.shape[1]
for targ in tqdm(range(n_targets)):
    test_input_tensor.requires_grad_()
    attr = ig.attribute(test_input_tensor, target=targ)
    attr = attr.detach().numpy()
    full_attrs.append(attr)
full_attrs = np.stack(full_attrs, axis=1)

In [None]:
plt.bar(np.arange(full_attrs.shape[2]), full_attrs[0, 0]);

In [None]:
full_attrs.shape

In [None]:
plt.bar(np.arange(full_attrs.shape[2]), np.mean(np.abs(full_attrs), axis=(0, 1)));
plt.xlabel('n emb');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(full_attrs.shape[2]), np.sort(np.mean(np.abs(full_attrs), axis=(0, 1)))[::-1]);
plt.xlabel('n emb');
plt.ylabel('mean attribution score');

In [None]:
plt.imshow(np.mean(np.abs(full_attrs), axis=(0,))[:, :85]);
plt.xlabel('n emb');
plt.ylabel('target gene');

In [None]:
import seaborn as sns

In [None]:
sns.clustermap(np.mean(np.abs(full_attrs), axis=0)[:, :])

## Switch to full enformer based model

In [None]:
from enformer_pytorch import Enformer


MODEL_PT = "EleutherAI/enformer-official-rough"
SEQ_EMBED_DIM = 896
EMBED_DIM = 3072
TSS = int(SEQ_EMBED_DIM // 2)

In [None]:
import torch.nn as nn
import torch

class FullModel(nn.Module):
    def __init__(self, rlm):
        super(FullModel, self).__init__()
        self.enformer = Enformer.from_pretrained(MODEL_PT, output_heads=dict(), use_checkpointing=False)
        self.model = rlm.model
        self.tss = TSS

    def forward(self, seq):
        _, embeddings = self.enformer(seq, return_embeddings=True)
        x = embeddings[:, self.tss]
        return self.model(x)

In [None]:
model = FullModel(rlm)

In [None]:
from enformer_pytorch import seq_indices_to_one_hot

test_input_tensor = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order (-1 for padding)
test_input_tensor = seq_indices_to_one_hot(test_input_tensor)

In [None]:
test_input_tensor.shape

In [None]:
model(test_input_tensor).shape

In [None]:
# ig = IntegratedGradients(model)
ig = InputXGradient(model)

full_attrs_G = []
for targ in tqdm(range(3)):
    test_input_tensor.requires_grad_()
    attr = ig.attribute(test_input_tensor, target=targ)
    # sum across all nucleotides (note value 0 for non used value)
    attr = attr.sum(dim=-1)
    # pool across sliding window
    # attr = torch.nn.functional.avg_pool1d(attr, 128)
    attr = attr.detach().numpy()
    full_attrs_G.append(attr)
full_attrs_G = np.stack(full_attrs_G, axis=1)

In [None]:
from scipy.signal import convolve

full_attrs_G_ds = convolve(full_attrs_G, (np.ones(128)/128)[np.newaxis, np.newaxis, :])[:, :, ::128]

In [None]:
full_attrs_G_ds.shape

In [None]:
plt.bar(np.arange(full_attrs_G_ds.shape[2]), np.mean(np.abs(full_attrs_G_ds), axis=(0, 1)));

In [None]:
plt.bar(np.arange(full_attrs_G_ds.shape[2]), np.sort(np.mean(np.abs(full_attrs_G_ds), axis=(0, 1)))[::-1]);

In [None]:
from enformer_pytorch import GenomeIntervalDataset
from torch.utils.data import DataLoader


BASE_PT = "/Users/nsofroniew/Documents/data/multiomics/enformer"

FASTA_PT = BASE_PT + "/Homo_sapiens.GRCh38.dna.toplevel.fa"
GENE_INTERVALS_PT = BASE_PT + "/Homo_sapiens.GRCh38.genes.bed"
EMBEDDING_PT = BASE_PT + "/Homo_sapiens.GRCh38.genes.enformer_embeddings.zarr"
EMBEDDING_PT_TSS = BASE_PT + "/Homo_sapiens.GRCh38.genes.enformer_embeddings_tss.zarr"
MODEL_PT = "EleutherAI/enformer-official-rough"

# print("Converting fasta file")
# pyfaidx.Faidx(FASTA_PT)
# print("Fasta file done")

# model = Enformer.from_pretrained(MODEL_PT, output_heads=dict(), use_checkpointing = False)
# model.to(DEVICE)

class MyGenomeIntervalDataset(GenomeIntervalDataset):
    def __init__(self, **kwargs):
        super(MyGenomeIntervalDataset, self).__init__(**kwargs)

    def __getitem__(self, ind):
        item = super().__getitem__(ind)
        label = self.df.row(ind)[4]
        return label, item


ds = MyGenomeIntervalDataset(
    bed_file=GENE_INTERVALS_PT,  # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file=FASTA_PT,  # path to fasta file
    return_seq_indices=False,  # return nucleotide indices (ACGTN) or one hot encodings
    rc_aug=False,
)
dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0) # type: DataLoader

In [None]:
ds[0][1]

In [None]:
from enformer_pytorch import seq_indices_to_one_hot

test_input_tensor = torch.stack([ds[i][1] for i in range(2)])

In [None]:
test_input_tensor.shape

In [None]:
# ig = IntegratedGradients(model)
ig = InputXGradient(model)

full_attrs_G = []
for targ in tqdm(range(3)):
    test_input_tensor.requires_grad_()
    attr = ig.attribute(test_input_tensor, target=targ)
    # sum across all nucleotides (note value 0 for non used value)
    attr = attr.sum(dim=-1)
    # pool across sliding window
    # attr = torch.nn.functional.avg_pool1d(attr, 128)
    attr = attr.detach().numpy()
    full_attrs_G.append(attr)
full_attrs_G = np.stack(full_attrs_G, axis=1)

In [None]:
plt.plot(np.arange(full_attrs_G.shape[2]), np.mean(np.abs(full_attrs_G), axis=(0, 1)), lw=0.05);

In [None]:
plt.plot(np.arange(full_attrs_G.shape[2]), np.mean(np.abs(full_attrs_G), axis=(1,))[0] / np.mean(np.abs(full_attrs_G), axis=(1,))[0].max(), lw=0.05);
plt.plot(np.arange(full_attrs_G.shape[2]), np.mean(np.abs(full_attrs_G), axis=(1,))[1] / np.mean(np.abs(full_attrs_G), axis=(1,))[1].max(), lw=0.05, alpha=0.5);

In [None]:
plt.plot(np.arange(full_attrs_G_ds.shape[2]), np.mean(np.abs(full_attrs_G_ds), axis=(0, 1)), lw=0.5);

In [None]:
plt.plot(np.arange(full_attrs_G.shape[2]), np.sort(np.mean(np.abs(full_attrs_G), axis=(0, 1)))[::-1], lw=1);

In [None]:
keep = slice(66_300, 66_800)
plt.plot(np.arange(full_attrs_G.shape[2])[keep], np.abs(full_attrs_G)[0, 0][keep], lw=0.5);
plt.plot(np.arange(full_attrs_G.shape[2])[keep], np.abs(full_attrs_G)[0, 1][keep], lw=0.5, alpha=0.5);
plt.plot(np.arange(full_attrs_G.shape[2])[keep], np.abs(full_attrs_G)[0, 2][keep], lw=0.5, alpha=0.25);

In [None]:
full_attrs_G.sum(axis=-1)

In [None]:
full_attrs_G.shape

In [None]:
adata.var.iloc[3]

In [None]:
ds[1][0]

In [None]:
196608 * 20_000 / 1e9

In [None]:
from modisco.visualization import viz_sequence

In [None]:
viz_sequence.plot_weights((full_attrs_G[0, 0, :, np.newaxis] * test_input_tensor[0].detach().numpy())[98304-100:98304+100])

In [None]:
196608 // 2

In [None]:
plt.bar(np.arange(full_attrs_G.shape[2]), np.sort(np.mean(np.abs(full_attrs_G), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
full_attrs_G_train = full_attrs_G[:, adata.var['train'][:200]]

In [None]:
full_attrs_G_test = full_attrs_G[:, np.logical_not(adata.var['train'][:200])]

In [None]:
plt.bar(np.arange(full_attrs_G_train.shape[2]), np.sort(np.mean(np.abs(full_attrs_G_train), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train), axis=(0, 1))[:140])[::-1]);
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train), axis=(0, 1))[140:280])[::-1], alpha=0.5);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(280), np.mean(np.abs(full_attrs_G_train), axis=(0, 1))[:280]);

In [None]:
# plt.imshow(np.mean(np.abs(full_attrs_G_train[np.logical_not(adata.obs['train'])]), axis=0)[:, :280]);
plt.imshow(np.mean(np.abs(full_attrs_G_train[np.logical_not(adata.obs['train'])]), axis=0)[:, :280]);

In [None]:
from scipy.cluster.hierarchy import linkage
keep = np.logical_not(adata.obs['train'])
D = np.mean(np.abs(full_attrs_G_train[keep]), axis=0)[:, :140]

link = linkage(D) # D being the measurement
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
diag = np.diagonal(np.mean(np.abs(full_attrs_G_train[keep]), axis=0)[:, :140])
vals = np.mean(adata[:, train_genes].X, axis=0)[:140]

In [None]:
plt.plot(vals, diag, '.');
plt.xlabel('mean expression')
plt.ylabel('mean self attribution score')

In [None]:
total_attr = np.mean(np.sum(full_attrs_G_train, axis=-1), axis=0)

In [None]:
plt.plot(total_attr, diag, '.');
plt.xlabel('total attribution');
plt.ylabel('mean self attribution score');

In [None]:
plt.plot(vals, total_attr, '.');
plt.xlabel('mean expression');
plt.ylabel('total attribution');

In [None]:
keep = np.logical_not(adata.obs['train'])
D = np.mean(np.abs(full_attrs_G_test[keep]), axis=0)
sns.clustermap(D)

In [None]:
total_attr

In [None]:
full_attrs_G_norm = full_attrs_G / np.expand_dims(np.sum(full_attrs_G, axis=-1), -1)

In [None]:
full_attrs_G_train_norm = full_attrs_G_norm[:, adata.var['train'][:200]]

In [None]:
plt.bar(np.arange(full_attrs_G_train_norm.shape[2]), np.sort(np.mean(np.abs(full_attrs_G_train_norm), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train_norm), axis=(0, 1))[:140])[::-1]);
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train_norm), axis=(0, 1))[140:280])[::-1], alpha=0.5);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.imshow(np.mean(np.abs(full_attrs_G_train_norm[np.logical_not(adata.obs['train'])]), axis=0)[:, :280]);

In [None]:
from scipy.cluster.hierarchy import linkage
keep = np.logical_not(adata.obs['train'])
D = np.mean(np.abs(full_attrs_G_train_norm[keep]), axis=0)[:, :140]

link = linkage(D) # D being the measurement
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
diag = np.diagonal(np.mean(np.abs(full_attrs_G_train_norm[keep]), axis=0)[:, :140])
vals = np.mean(adata[:, train_genes].X, axis=0)[:140]

plt.plot(vals, diag, '.');
plt.xlabel('mean expression')
plt.ylabel('mean self attribution score')

In [None]:
D = full_attrs_G_r[10, :, :140]
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
plt.imshow(np.mean(np.abs(full_attrs_G_t), axis=0)[:, :140]);

In [None]:
D = np.mean(np.abs(full_attrs_G_t), axis=0)[:, :140]
sns.clustermap(D)

In [None]:
np.sum(full_attrs_G, axis=-1).shape

In [None]:
pca.components_.shape

In [None]:
X = model(test_input_tensor) - model(torch.zeros_like(test_input_tensor))

In [None]:
abs(X[:, :100].detach().numpy() - np.sum(full_attrs_G, axis=-1)).max()

In [None]:
model(torch.zeros_like(test_input_tensor)).max()

In [None]:
plt.hist(model(torch.zeros_like(test_input_tensor))[0].detach().numpy(), bins=200);

In [None]:
pca.components_.shape

In [None]:
fa = np.cumsum(np.mean(np.abs(full_attrs), axis=(0, 2)))
fa = fa / fa[-1]

plt.plot(np.arange(256), fa);
plt.xlabel('n PC');
plt.ylabel('cumulative mean attribution score');

In [None]:
pca.components_[:256, :].shape

In [None]:
full_attrs.shape

In [None]:
rlm.fo

In [None]:
from rosa.utils import score_predictions, plot_expression_and_correlation, plot_marker_gene_heatmap


adata_test, results = score_predictions(adata)

In [None]:
plot_expression_and_correlation(adata_test, results)

In [None]:
import numpy as np


marker_genes = adata_test.var[adata_test.var['highly_variable']]['feature_name'].values
np.random.seed(42)
marker_genes = np.random.choice(marker_genes, 50)

plot_marker_gene_heatmap(adata_test, marker_genes)

In [None]:
marker_genes_dict = adata_test.obs.set_index('label').to_dict()['marker_feature_name']
plot_marker_gene_heatmap(adata_test, marker_genes_dict)