In [1]:
import hydra
from omegaconf import OmegaConf
from glob import  glob
# from rosa import  predict
from rosa.data import create_io_paths, RosaDataModule
from rosa.modeling.modules import RosaLightningModule
from pytorch_lightning import Trainer


# Hydra run directory that holds the saved config (`.hydra/`) and `checkpoints/`.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
BASE_DIR = "/Users/nsofroniew/Documents/data/rosa/outputs/2023-02-12/21-57-00"
config_dir = BASE_DIR + "/.hydra"

# version_base="1.1" pins the compatibility level Hydra reports it assumes when the
# argument is omitted, which also silences that warning.
with hydra.initialize_config_dir(config_dir=config_dir, version_base="1.1"):
    config = hydra.compose(config_name="config", overrides=OmegaConf.load(config_dir + "/overrides.yaml"))

    # glob() returns matches in arbitrary, filesystem-dependent order; sort them so
    # the index below always selects the same checkpoint.
    chkpts = BASE_DIR + "/checkpoints/*.ckpt"
    chkpt_paths = sorted(glob(chkpts))
    if len(chkpt_paths) < 2:
        raise FileNotFoundError(
            f"Expected at least 2 checkpoints matching {chkpts}, found {len(chkpt_paths)}"
        )
    chkpt = chkpt_paths[1]

    _, output_path = create_io_paths(config.paths)

    # Create Data Module
    rdm = RosaDataModule(
        output_path,
        config=config.data_module,
    )
    rdm.setup()

    # Load model from checkpoint; input/output widths come from the data module
    rlm = RosaLightningModule.load_from_checkpoint(
        chkpt,
        in_dim=rdm.len_input,
        out_dim=rdm.len_target,
        config=config.module,
    )
    print(rlm)

    # Run prediction and let the dataset store the result on its AnnData object
    trainer = Trainer()
    predictions = trainer.predict(rlm, rdm)
    rdm.predict_dataset.predict(predictions)
    adata = rdm.predict_dataset.adata
    # adata = predict(config, chkpt)

display(adata)

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize_config_dir(config_dir=config_dir):


RosaLightningModule(
  (model): RosaSingleModel(
    (main): Sequential(
      (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (input_embed): Identity()
      (feed_forward): Identity()
      (dropout): Dropout(p=0.5, inplace=False)
      (expression_head): ProjectionExpressionHead(
        (model): Sequential(
          (projection): Linear(in_features=256, out_features=19431, bias=True)
          (activation): Softplus(beta=1, threshold=20)
        )
      )
    )
  )
)


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /Users/nsofroniew/Documents/GitHub/rosa/notebooks/lightning_logs
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

AnnData object with n_obs × n_vars = 369 × 19431
    obs: 'dataset_id', 'cell_type', 'cell_type_ontology_term_id', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_id', 'is_primary_data', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'label', 'sample', 'n_genes', 'train', 'marker_gene', 'marker_feature_name'
    var: 'soma_joinid', 'feature_name', 'feature_length', 'column_1', 'column_2', 'column_3', 'column_4', 'external_gene_name', 'gene_biotype', 'train', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'dendrogram_label', 'hvg', 'log1p', 'obs_embedding_pca', 'preprocessing', 'rank_genes_groups', 'var_embedding_pca'
    obsm: 'bin_edges', 'embedding'
    varm: 'embedding', 'embedding_pca'
    layers: 'binned', 'counts', 'log1p', 'normalized_counts', 'prediction'

In [2]:
from sklearn.decomposition import PCA


# 'train' flags stored on the obs (cell) and var (gene) tables of `adata`
train_cells = adata.obs['train']
train_genes = adata.var['train']

# Restrict to the flagged cells x flagged genes for fitting
adata_split = adata[train_cells, train_genes]

# PCA.fit returns the estimator itself, so construct and fit in one expression
pca = PCA().fit(adata_split.X)

# Project every cell (not only the training ones) onto the fitted components,
# still restricted to the training genes
pca_expression = pca.transform(adata[:, train_genes].X)

# # add cell embeddings to obsm
# n_pcs = config.pcs
# n_pcs = min(n_pcs, pca_expression.shape[1] - 1)
# adata.obsm["embedding"] = pca_expression[:, :n_pcs]
# adata.uns["obs_embedding_pca"] = {
#     "explained_variance": np.cumsum(pca.explained_variance_ratio_)[n_pcs]
# }

In [3]:
# trans = np.einsum('ij, kj -> ik', adata[:, train_genes].X - pca.mean_, pca.components_)

In [4]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [5]:
# imports from captum library
from captum.attr import IntegratedGradients, InputXGradient

In [6]:
# ig = IntegratedGradients(rlm)
ixg = InputXGradient(rlm)

In [7]:
# `predict_dataset[:]` builds the full (input, target, ...) batch, so index it once
# and pick both elements from the same result instead of materialising it twice.
_full_batch = rdm.predict_dataset[:]
test_input_tensor = _full_batch[0]
target_tensor = _full_batch[1]

In [8]:
x, y = next(iter(rdm.predict_dataloader()))

In [26]:
y.shape[-1]

19431

In [23]:
import torch


def make_explain_iter(rdm, explainer, batch_size=1):
    """Lazily compute attributions for every batch of ``rdm.predict_dataloader``.

    Args:
        rdm: data module exposing ``len_target`` and
            ``predict_dataloader(batch_size=...)``, which yields ``(x, y)`` pairs.
        explainer: object with a Captum-style ``attribute(inputs, target=...)`` method.
        batch_size: batch size forwarded to ``rdm.predict_dataloader``.

    Yields:
        * ``rdm.len_target == 1``: a tuple with one tensor per element of ``x``,
          reshaped to ``(y.shape[0], y.shape[1], -1)``.
        * otherwise: a single tensor with the per-target attributions stacked
          along ``dim=1``.

        Yielded tensors are detached from autograd, so they can be converted to
        numpy / written to disk directly and do not keep graphs alive.
    """
    if rdm.len_target == 1:
        for x, y in rdm.predict_dataloader(batch_size=batch_size):
            # Collapse all leading dims so each input is 2-D, and track gradients on it
            x = tuple(x_ind.reshape(-1, x_ind.shape[-1]).requires_grad_() for x_ind in x)
            attribution = explainer.attribute(x)
            # Restore the leading dims of `y` on every attribution tensor
            yield tuple(a.detach().reshape(y.shape[0], y.shape[1], -1) for a in attribution)
    else:
        for x, y in rdm.predict_dataloader(batch_size=batch_size):
            # Setting the flag is idempotent; once per batch is enough
            x.requires_grad_()
            attribution = [
                explainer.attribute(x, target=target).detach()
                for target in range(rdm.len_target)
            ]
            yield torch.stack(attribution, dim=1)


In [24]:
explain_iter = make_explain_iter(rdm, ixg)

In [21]:
# Peek at a single batch from a throwaway generator so the shared `explain_iter`
# is not advanced (consuming it here would make a later full pass skip this batch).
attr = next(make_explain_iter(rdm, ixg))

In [22]:
attr.shape

torch.Size([1, 19431, 256])

In [None]:
import zarr

results_shape = (len(rdm.predict_dataset), rdm.len_target, rdm.len_input) # rdm.len_target
# results_shape = rdm.predict_dataset.adata.shape + (rdm.len_input,)
z = zarr.zeros(results_shape, chunks=(1, None, None), dtype=np.float32)

ind = 0
# Start from a fresh generator: `explain_iter` may already have been advanced by an
# earlier cell, which would silently leave the last rows of `z` unfilled.
for attr in tqdm(make_explain_iter(rdm, ixg)):
    if torch.is_tensor(attr):
        # Attributions can still track gradients; detach before handing them to zarr/numpy
        attr = attr.detach().cpu().numpy()
    z[ind:ind+len(attr), :, :] = attr # for cell dataset
    # z[:, ind:ind+len(attr), :] = attr # for gene dataset
    ind += len(attr)

In [None]:
import zarr

# Number of leading targets to explain; sizes the store and bounds the inner loop,
# so the two can no longer drift apart.
n_targets = 50
results_shape = (len(rdm.predict_dataset), n_targets, rdm.len_input) # rdm.len_target
# results_shape = rdm.predict_dataset.adata.shape + (rdm.len_input,)
z = zarr.zeros(results_shape, chunks=(1, None, None), dtype=np.float32)

ind = 0
for x, y in tqdm(iter(rdm.predict_dataloader(batch_size=1)), leave=False):
    # Gradient tracking only needs to be switched on once per batch
    x.requires_grad_()
    for target in tqdm(range(n_targets)):
        attribution = ixg.attribute(x, target=target)
        attribution = attribution.detach().numpy()
        # z[target, ind:ind+len(x), :] = attribution # for gene dataset
        z[ind:ind+len(x), target, :] = attribution # for cell dataset
        #### for joint have to do something more clever ...... maybe swap iterators
    ind += len(x)

In [None]:
import zarr

# This variant accumulates attributions in memory only. It deliberately does not
# allocate a zarr store: re-creating `z` as zeros here would discard the values a
# previous cell wrote into it, and `z` is what `full_attrs` is read from afterwards.
n_targets = 50  # number of leading targets to explain

results = []
for x, y in tqdm(iter(rdm.predict_dataloader()), leave=False):
    # Gradient tracking only needs to be switched on once per batch
    x.requires_grad_()
    results_batch = []
    for target in tqdm(range(n_targets)):
        attribution = ixg.attribute(x, target=target)
        attribution = attribution.detach().numpy()
        results_batch.append(attribution)
    # Stack per-target attributions on a new axis 1, right after the batch axis
    results_batch = np.stack(results_batch, axis=1)
    results.append(results_batch)
# Join all batches along axis 0
results = np.concatenate(results)

# results = results.swapaxes(0, 1)
# results = results.swapaxes(0, 1)

In [None]:
results.swapaxes(0, 1).shape

In [None]:
# cells x genes x cell_embedding - for cell model
# cells x genes x gene_embedding - for gene model (right now would be genes x cells - need a transpose)
# BOTH for joint model .....

# Allow for normal model + modified model.
# For modified model for cells include pca .... add pca components to adata `uns` ....?????
# For modified model for gene include enformer
# For modified model for joint include both

# Once attribution is working, get TF-MoDISco working ....
# Explore ground truth / databases ....

In [None]:
rdm.predict_dataset.adata.shape

In [None]:
full_attrs = np.asarray(z)

In [None]:
plt.bar(np.arange(full_attrs.shape[2]), full_attrs[0, 0])

In [None]:
plt.bar(np.arange(full_attrs.shape[2]), np.mean(np.abs(full_attrs), axis=(0, 1)));
plt.xlabel('n PC');
plt.ylabel('mean attribution score');

In [None]:
# x: explained-variance ratio of each PC; y: exp of the mean absolute attribution per PC.
# Axis labels describe exactly what is plotted.
plt.plot(pca.explained_variance_ratio_[:full_attrs.shape[2]], np.exp(np.mean(np.abs(full_attrs), axis=(0, 1))), '.');
plt.xlabel('explained variance ratio');
plt.ylabel('exp(mean attribution score)');


In [None]:
# Bar height is the explained-variance ratio of each PC (not the singular value)
plt.bar(np.arange(full_attrs.shape[2]), pca.explained_variance_ratio_[:full_attrs.shape[2]]);
plt.xlabel('n PC');
plt.ylabel('explained variance ratio');

In [None]:
plt.imshow(np.mean(np.abs(full_attrs), axis=(0,))[:, :85]);
plt.xlabel('n PC');
plt.ylabel('target gene');

In [None]:
print(full_attrs.shape)
print(pca.components_.shape)

In [None]:
# Map PC-space attributions back to gene space using as many leading components as
# the attributions have features, so the contraction axis always matches.
output = np.einsum('ijk, kl -> ijl', full_attrs, pca.components_[:full_attrs.shape[2], :])

In [None]:
print(output.shape)

In [None]:
plt.bar(np.arange(output.shape[2]), np.mean(np.abs(output), axis=(0, 1)));
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(output.shape[2]), np.sort(np.mean(np.abs(output), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
output.shape

In [None]:
# Keep the training genes among (at most) the first 100 targets. The mask is cut to
# the number of targets actually present so the boolean index always matches axis 1.
n_targets_shown = min(100, output.shape[1])
train_target_mask = adata.var['train'].to_numpy()[:n_targets_shown]
output_r = output[:, :n_targets_shown][:, train_target_mask]

In [None]:
plt.imshow(np.mean(np.abs(output_r), axis=0)[:, :70]);

In [None]:
import seaborn as sns

In [None]:
sns.clustermap(test_input_tensor.detach().numpy()[:, :25])

In [None]:
sns.clustermap(np.mean(np.abs(full_attrs), axis=0)[:, :])

In [None]:
from scipy.cluster.hierarchy import linkage
# Mean absolute attribution over axis 0, keeping the first 70 columns
D = np.mean(np.abs(output_r), axis=0)[:, :70]

# Hierarchical clustering of the rows of D (SciPy defaults)
link = linkage(D) # D being the measurement
# The row linkage is reused for the columns.
# NOTE(review): this is only valid if D is square with rows and columns in the same
# order — confirm that the number of rows here is also 70.
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
import torch.nn as nn
import torch

class FullModel(nn.Module):
    """Wrap ``rlm.model`` so it accepts inputs that have not been PCA-projected.

    The forward pass subtracts the PCA mean, projects onto the leading principal
    components, and passes the resulting scores to ``rlm.model``.

    Args:
        rlm: object exposing the trained network as ``rlm.model``.
        pca: fitted PCA-like object exposing ``mean_`` and ``components_`` as
            numpy arrays.
        n_components: number of leading components to project onto. Defaults to
            256, the value previously hard-coded here.
    """

    def __init__(self, rlm, pca, n_components=256):
        super().__init__()
        # Registered as non-persistent buffers: they follow the module through
        # .to()/.cuda() but are kept out of state_dict(), so saved state is unchanged.
        self.register_buffer("input_mean", torch.from_numpy(pca.mean_), persistent=False)
        self.register_buffer(
            "input_weights",
            torch.from_numpy(pca.components_[:n_components]),
            persistent=False,
        )
        self.model = rlm.model

    def forward(self, x):
        """Centre ``x``, project it onto the stored components, and run the wrapped model."""
        # PCA arrays may be float64 while `x` and the model are float32; follow the
        # input dtype so the projection does not silently upcast or mismatch.
        mean = self.input_mean.to(x.dtype)
        weights = self.input_weights.to(x.dtype)
        x = x - mean
        # (samples, features) x (components, features) -> (samples, components)
        x = torch.einsum('ij, kj -> ik', x, weights)
        return self.model(x)

In [None]:
model = FullModel(rlm, pca)

In [None]:
test_input_tensor = torch.from_numpy(adata[:, train_genes].X)

In [None]:
test_input_tensor.shape

In [None]:
model(test_input_tensor).shape

In [None]:
# ig = IntegratedGradients(model)
ig = InputXGradient(model)

# Attributions are taken with respect to the input, so it has to track gradients;
# the flag is switched on once before the loop.
test_input_tensor.requires_grad_()

full_attrs_G = []
for targ in tqdm(range(200)):
    # One attribution array per target, detached and converted to numpy
    attr = ig.attribute(test_input_tensor, target=targ).detach().numpy()
    full_attrs_G.append(attr)
# Targets end up on axis 1, right after the sample axis
full_attrs_G = np.stack(full_attrs_G, axis=1)

In [None]:
test_input_tensor.shape

In [None]:
plt.bar(np.arange(full_attrs_G.shape[2]), np.sort(np.mean(np.abs(full_attrs_G), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
full_attrs_G_train = full_attrs_G[:, adata.var['train'][:200]]

In [None]:
full_attrs_G_test = full_attrs_G[:, np.logical_not(adata.var['train'][:200])]

In [None]:
plt.bar(np.arange(full_attrs_G_train.shape[2]), np.sort(np.mean(np.abs(full_attrs_G_train), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train), axis=(0, 1))[:140])[::-1]);
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train), axis=(0, 1))[140:280])[::-1], alpha=0.5);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(280), np.mean(np.abs(full_attrs_G_train), axis=(0, 1))[:280]);

In [None]:
# plt.imshow(np.mean(np.abs(full_attrs_G_train[np.logical_not(adata.obs['train'])]), axis=0)[:, :280]);
plt.imshow(np.mean(np.abs(full_attrs_G_train[np.logical_not(adata.obs['train'])]), axis=0)[:, :280]);

In [None]:
from scipy.cluster.hierarchy import linkage
keep = np.logical_not(adata.obs['train'])
D = np.mean(np.abs(full_attrs_G_train[keep]), axis=0)[:, :140]

link = linkage(D) # D being the measurement
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
diag = np.diagonal(np.mean(np.abs(full_attrs_G_train[keep]), axis=0)[:, :140])
vals = np.mean(adata[:, train_genes].X, axis=0)[:140]

In [None]:
plt.plot(vals, diag, '.');
plt.xlabel('mean expression')
plt.ylabel('mean self attribution score')

In [None]:
total_attr = np.mean(np.sum(full_attrs_G_train, axis=-1), axis=0)

In [None]:
plt.plot(total_attr, diag, '.');
plt.xlabel('total attribution');
plt.ylabel('mean self attribution score');

In [None]:
plt.plot(vals, total_attr, '.');
plt.xlabel('mean expression');
plt.ylabel('total attribution');

In [None]:
keep = np.logical_not(adata.obs['train'])
D = np.mean(np.abs(full_attrs_G_test[keep]), axis=0)
sns.clustermap(D)

In [None]:
D.shape

In [None]:
# Divide every attribution vector by its sum over the last axis.
# NOTE(review): vectors whose attributions sum to (near) zero yield inf/nan here.
full_attrs_G_norm = full_attrs_G / full_attrs_G.sum(axis=-1, keepdims=True)

In [None]:
full_attrs_G_train_norm = full_attrs_G_norm[:, adata.var['train'][:200]]

In [None]:
plt.bar(np.arange(full_attrs_G_train_norm.shape[2]), np.sort(np.mean(np.abs(full_attrs_G_train_norm), axis=(0, 1)))[::-1]);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train_norm), axis=(0, 1))[:140])[::-1]);
plt.bar(np.arange(140), np.sort(np.mean(np.abs(full_attrs_G_train_norm), axis=(0, 1))[140:280])[::-1], alpha=0.5);
plt.xlabel('gene');
plt.ylabel('mean attribution score');

In [None]:
plt.imshow(np.mean(np.abs(full_attrs_G_train_norm[np.logical_not(adata.obs['train'])]), axis=0)[:, :280]);

In [None]:
from scipy.cluster.hierarchy import linkage
keep = np.logical_not(adata.obs['train'])
D = np.mean(np.abs(full_attrs_G_train_norm[keep]), axis=0)[:, :140]

link = linkage(D) # D being the measurement
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
diag = np.diagonal(np.mean(np.abs(full_attrs_G_train_norm[keep]), axis=0)[:, :140])
vals = np.mean(adata[:, train_genes].X, axis=0)[:140]

plt.plot(vals, diag, '.');
plt.xlabel('mean expression')
plt.ylabel('mean self attribution score')

In [None]:
# NOTE(review): `full_attrs_G_r` is not assigned anywhere in the visible notebook, so
# this cell fails with NameError on a fresh kernel — presumably an older name for one
# of the `full_attrs_G_*` arrays; confirm and rename.
# Takes index 10 along axis 0 and the first 140 columns.
D = full_attrs_G_r[10, :, :140]
# `link` is whatever linkage the most recently executed clustering cell left behind
sns.clustermap(D, row_linkage=link, col_linkage=link)

In [None]:
# NOTE(review): `full_attrs_G_t` is not assigned anywhere in the visible notebook
# (NameError on a fresh kernel) — confirm which array was meant.
# Mean absolute value over axis 0, first 140 columns.
plt.imshow(np.mean(np.abs(full_attrs_G_t), axis=0)[:, :140]);

In [None]:
# NOTE(review): `full_attrs_G_t` is not assigned anywhere in the visible notebook
# (NameError on a fresh kernel) — confirm which array was meant.
# Mean absolute value over axis 0, first 140 columns, clustered on both axes.
D = np.mean(np.abs(full_attrs_G_t), axis=0)[:, :140]
sns.clustermap(D)

In [None]:
np.sum(full_attrs_G, axis=-1).shape

In [None]:
pca.components_.shape

In [None]:
X = model(test_input_tensor) - model(torch.zeros_like(test_input_tensor))

In [None]:
# Largest gap between the output change relative to an all-zero input (`X`) and the
# summed attributions. Only as many output columns as there are explained targets are
# compared, so the two arrays always have the same shape.
n_explained = full_attrs_G.shape[1]
np.abs(X[:, :n_explained].detach().numpy() - np.sum(full_attrs_G, axis=-1)).max()

In [None]:
model(torch.zeros_like(test_input_tensor)).max()

In [None]:
plt.hist(model(torch.zeros_like(test_input_tensor))[0].detach().numpy(), bins=200);

In [None]:
pca.components_.shape

In [None]:
# Mean absolute attribution per input feature (last axis of `full_attrs`): average
# over axes 0 and 1 so one value per PC remains, then accumulate and normalise to 1.
fa = np.cumsum(np.mean(np.abs(full_attrs), axis=(0, 1)))
fa = fa / fa[-1]

# x positions follow the actual number of PCs instead of a fixed 256
plt.plot(np.arange(fa.size), fa);
plt.xlabel('n PC');
plt.ylabel('cumulative mean attribution score');

In [None]:
pca.components_[:256, :].shape

In [None]:
full_attrs.shape

In [None]:
# rlm.fo  # stray, incomplete attribute access (likely a tab-completion leftover); disabled so Run All does not stop here

In [None]:
from rosa.utils import score_predictions, plot_expression_and_correlation, plot_marker_gene_heatmap


adata_test, results = score_predictions(adata)

In [None]:
plot_expression_and_correlation(adata_test, results)

In [None]:
import numpy as np


# Names of all genes flagged as highly variable
hvg_names = adata_test.var[adata_test.var['highly_variable']]['feature_name'].values

# Draw up to 50 distinct genes. Sampling without replacement avoids the same gene
# appearing twice in the heatmap; the size is capped in case fewer than 50 exist.
np.random.seed(42)
n_markers = min(50, len(hvg_names))
marker_genes = np.random.choice(hvg_names, n_markers, replace=False)

plot_marker_gene_heatmap(adata_test, marker_genes)

In [None]:
# Build {label: marker_feature_name} from the obs table by converting only the
# one column that is needed, keyed by 'label'
marker_genes_dict = adata_test.obs.set_index('label')['marker_feature_name'].to_dict()
plot_marker_gene_heatmap(adata_test, marker_genes_dict)