### imports

In [1]:
from omegaconf import OmegaConf as om
from mosaicfm.model import ComposerSCGPTModel
from mosaicfm.tasks import get_batch_embeddings
from mosaicfm.tokenizer import GeneVocab
import torch
import pandas as pd
import anndata as ad
import scanpy as sc
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### generate PCA embeddings

In [9]:
# Build a 15-component PCA embedding of the CCLE profiles and store it in obsm["pca"].
# Normalization and log1p are applied to a temporary matrix instead of adata.X:
# the next cell writes `adata` back over this same .h5ad, so transforming X in
# place would overwrite the stored matrix (presumably raw counts, since it is
# normalized here) and re-running the notebook would normalize/log it twice.
# NOTE(review): if an earlier run of the in-place version already rewrote this
# file, its X is log-normalized already — TODO confirm before regenerating PCA.
adata = sc.read_h5ad("/vevo/umair/data/sens-pred/embs/ccle.h5ad")
x_lognorm = sc.pp.log1p(sc.pp.normalize_total(adata, inplace=False)["X"])
adata.obsm["pca"] = sc.tl.pca(x_lognorm, n_comps=15, return_info=False)
del x_lognorm  # free the temporary full-size matrix

In [10]:
# Persist `adata` (now carrying obsm["pca"]) back to the file it was read from.
adata.write_h5ad(filename="/vevo/umair/data/sens-pred/embs/ccle.h5ad")

### generate model embeddings

In [2]:
base_path = "/vevo/umair/data/scgpt-models"

# Directory (under base_path) holding each model's vocab.json, model_config.yml
# and collator_config.yml. The pretrained, adapted and from-scratch 70M models
# all reuse the artifacts of the original pretrained run.
artifact_dirs = {
    "mosaicfm-70m-pretrained": "scgpt-70m-1024-fix-norm-apr24-data",
    "mosaicfm-70m-adapted": "scgpt-70m-1024-fix-norm-apr24-data",
    "mosaicfm-70m-from-scratch": "scgpt-70m-1024-fix-norm-apr24-data",
    "mosaicfm-70m-tahoe": "mosaicfm-70m-tahoe",
    "mosaicfm-70m-merged": "mosaicfm-70m-merged",
    "mosaicfm-v2-1_3b-merged": "mosaicfm-v2-1_3b-merged",
}

# Checkpoint file (relative to base_path) for each model.
checkpoint_files = {
    "mosaicfm-70m-pretrained": "scgpt-70m-1024-fix-norm-apr24-data/best-model.pt",
    "mosaicfm-70m-adapted": "mosaicfm-adapt-rif-constant-lr/ep50-ba354200-rank0.pt",
    "mosaicfm-70m-from-scratch": "mosaicfm-adapt-rif-no-load/ep50-ba354200-rank0.pt",
    "mosaicfm-70m-tahoe": "mosaicfm-70m-tahoe/best-model.pt",
    "mosaicfm-70m-merged": "mosaicfm-70m-merged/best-model.pt",
    "mosaicfm-v2-1_3b-merged": "mosaicfm-v2-1_3b-merged/best-model.pt",
}

# Absolute paths, keyed by model name.
model_paths = {name: f"{base_path}/{rel}" for name, rel in checkpoint_files.items()}
vocab_paths = {name: f"{base_path}/{d}/vocab.json" for name, d in artifact_dirs.items()}
model_config_paths = {name: f"{base_path}/{d}/model_config.yml" for name, d in artifact_dirs.items()}
collator_config_paths = {name: f"{base_path}/{d}/collator_config.yml" for name, d in artifact_dirs.items()}

# Column of adata.var whose values are looked up in each model's vocabulary.
feature_columns = {
    **dict.fromkeys(
        ["mosaicfm-70m-pretrained", "mosaicfm-70m-adapted", "mosaicfm-70m-from-scratch"],
        "feature_name",
    ),
    **dict.fromkeys(
        ["mosaicfm-70m-tahoe", "mosaicfm-70m-merged", "mosaicfm-v2-1_3b-merged"],
        "gene_id",
    ),
}

In [11]:
# Key into the model/vocab/config/feature-column dicts defined above; selects
# which model's vocabulary and checkpoint the following cells use.
model_name = "mosaicfm-70m-merged"

In [12]:
# load vocabulary
vocab_path = vocab_paths[model_name]
vocab = GeneVocab.from_file(vocab_path)
vocab.set_default_index(vocab["<pad>"])

# set data variables
# data_path = "/vevo/umair/data/sens-pred/embs/ccle.h5ad"
data_path = "/vevo/umair/data/sens-pred/embs/rif-dmso-t0-all.h5ad"
gene_col = feature_columns[model_name]

# load data
adata = sc.read_h5ad(data_path)
adata.var["id_in_vocab"] = [vocab[gene] if gene in vocab else -1 for gene in adata.var[gene_col]]
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
print(f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes in vocabulary of size {len(vocab)}.")
adata = adata[:, adata.var["id_in_vocab"] >= 0]
genes = adata.var[gene_col].tolist()
gene_ids = np.array(vocab(genes), dtype=int)

match 37392/37392 genes in vocabulary of size 62720.


In [13]:
# only get cell embeddings if they don't already exist
if model_name not in adata.obsm_keys():

    # load model
    model_config_path = model_config_paths[model_name]
    collator_config_path = collator_config_paths[model_name]
    model_file = model_paths[model_name]
    model_config = om.load(model_config_path)
    collator_config = om.load(collator_config_path)
    model = ComposerSCGPTModel(model_config = model_config, collator_config = collator_config)

    # Load checkpoint tensors onto the CPU first. Without map_location, torch.load
    # restores each tensor to the device it was saved from, so a checkpoint saved
    # on a GPU would fail to load on the CPU fallback selected below.
    checkpoint = torch.load(model_file, map_location="cpu")
    model.load_state_dict(checkpoint["state"]["model"], strict=True)
    del checkpoint  # only the model weights are needed; free the rest

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # get cell embeddings (gene embeddings are not requested)
    cell_embeddings = get_batch_embeddings(
        adata=adata,
        model=model.model,
        vocab=vocab,
        gene_ids=gene_ids,
        model_cfg=model_config,
        collator_cfg=collator_config,
        batch_size=64,
        max_length=2048,
        return_gene_embeddings=False
    )

    # add to AnnData and save back to the file it was loaded from
    adata.obsm[model_name] = cell_embeddings
    adata.write_h5ad(data_path)

Embedding cells: 100%|██████████| 30245/30245 [01:02<00:00, 481.11it/s]
  adata.obsm[model_name] = cell_embeddings


### compute mean embeddings per cell line

In [9]:
# Load the per-cell AnnData whose embeddings are averaged per cell line below, and display it.
adata = sc.read_h5ad(filename="/vevo/umair/data/sens-pred/dmso-t0.h5ad")
adata

AnnData object with n_obs × n_vars = 30245 × 37476
    obs: 'sample', 'species', 'gene_count', 'tscp_count', 'mread_count', 'bc1_well', 'bc2_well', 'bc3_well', 'bc1_wind', 'bc2_wind', 'bc3_wind', 'id', 'drugname_drugconc', 'drug', 'INT_ID', 'NUM.SNPS', 'NUM.READS', 'demuxlet_call', 'BEST.GUESS', 'BEST.LLK', 'NEXT.GUESS', 'NEXT.LLK', 'DIFF.LLK.BEST.NEXT', 'BEST.POSTERIOR', 'SNG.POSTERIOR', 'cell_line', 'SNG.BEST.LLK', 'SNG.NEXT.GUESS', 'SNG.NEXT.LLK', 'SNG.ONLY.POSTERIOR', 'DBL.BEST.GUESS', 'DBL.BEST.LLK', 'DIFF.LLK.SNG.DBL', 'sublibrary', 'BARCODE', 'batch', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'mlp_label', 'mlp_confidence', 'cell_line_orig', 'demuxlet_call_orig', 'pass_filter'
    var: 'gene_id', 'genome', 'SUB_LIB_ID-0-0', 'SUB_LIB_ID-1-0', 'SUB_LIB_ID-10-0', 'SUB_LIB_ID-11-0', 'SUB_LIB_ID-12-0', 'SUB_LIB_ID-13-0', 'SUB_LIB_ID-14-0', 'SUB_LIB_ID-15-0', 'SUB_LIB_ID-2-0', 'SUB_LIB_ID-3-0', 'SUB_LIB_ID-4-0', 'SUB_LIB_ID-5-0', 'SUB_LIB_ID-6-0', 'SUB_LIB_ID-7-0', 'SUB_LIB_ID-8-0'

In [15]:
# Total fraction of variance captured by the stored principal components.
np.sum(adata.uns["pca"]["variance_ratio"])

0.13700681465891257

In [3]:
# Average each embedding over the cells of every cell line.
# Output key -> obsm key it is read from (PCA is read from "X_pca").
emb_sources = {
    "pca": "X_pca",
    "mosaicfm-70m-pretrained": "mosaicfm-70m-pretrained",
    "mosaicfm-70m-adapted": "mosaicfm-70m-adapted",
    "mosaicfm-70m-from-scratch": "mosaicfm-70m-from-scratch",
}

# initialize arrays
cell_lines = []
emb_arrays = {key: [] for key in emb_sources}

# compute means: build each cell line's mask once and index the obsm matrices
# directly, instead of slicing a new AnnData view for every embedding
line_labels = adata.obs["cell_line_orig"]
source_embs = {key: adata.obsm[obsm_key] for key, obsm_key in emb_sources.items()}
for c in tqdm(line_labels.unique().tolist()):
    cell_lines.append(c)
    in_line = (line_labels == c).to_numpy()
    for key, emb in source_embs.items():
        emb_arrays[key].append(emb[in_line].mean(axis=0))

# convert to NumPy arrays (one row per cell line, in `cell_lines` order)
for k in emb_arrays.keys():
    emb_arrays[k] = np.array(emb_arrays[k])

100%|██████████| 52/52 [00:05<00:00,  9.39it/s]


In [5]:
# build initial AnnData: one row per cell line, X = mean PCA embedding
mean_adata = ad.AnnData(
    X=emb_arrays["pca"],
    obs=pd.DataFrame({"cell-line": cell_lines}),
    var=pd.DataFrame({"dim": list(range(emb_arrays["pca"].shape[1]))})
)

# add other embeddings
mean_adata.obsm["pca"] = mean_adata.X
# The loop variable is `model_key`, not `model`: reusing `model` here would
# clobber the ComposerSCGPTModel bound to that name earlier in the notebook.
for suffix in ["pretrained", "adapted", "from-scratch"]:
    model_key = f"mosaicfm-70m-{suffix}"
    mean_adata.obsm[model_key] = emb_arrays[model_key]

# inspect
mean_adata



AnnData object with n_obs × n_vars = 52 × 15
    obs: 'cell-line'
    var: 'dim'
    obsm: 'pca', 'mosaicfm-70m-pretrained', 'mosaicfm-70m-adapted', 'mosaicfm-70m-from-scratch'

In [6]:
# Save the per-cell-line mean embeddings.
# NOTE(review): the "add embeddings from new models" section reads
# embs/rif-dmso-t0-mean.h5ad, not this file — confirm which file is canonical.
mean_adata.write_h5ad(filename="/vevo/umair/data/sens-pred/dmso-mean-embs.h5ad")

### add embeddings from new models

In [16]:
# Load the per-cell AnnData holding the newer models' cell embeddings, and display it.
adata = sc.read_h5ad(filename="/vevo/umair/data/sens-pred/embs/rif-dmso-t0-all.h5ad")
adata

AnnData object with n_obs × n_vars = 30245 × 37392
    obs: 'sample', 'species', 'gene_count', 'tscp_count', 'mread_count', 'bc1_well', 'bc2_well', 'bc3_well', 'bc1_wind', 'bc2_wind', 'bc3_wind', 'id', 'drugname_drugconc', 'drug', 'INT_ID', 'NUM.SNPS', 'NUM.READS', 'demuxlet_call', 'BEST.GUESS', 'BEST.LLK', 'NEXT.GUESS', 'NEXT.LLK', 'DIFF.LLK.BEST.NEXT', 'BEST.POSTERIOR', 'SNG.POSTERIOR', 'cell_line', 'SNG.BEST.LLK', 'SNG.NEXT.GUESS', 'SNG.NEXT.LLK', 'SNG.ONLY.POSTERIOR', 'DBL.BEST.GUESS', 'DBL.BEST.LLK', 'DIFF.LLK.SNG.DBL', 'sublibrary', 'BARCODE', 'batch', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'mlp_label', 'mlp_confidence', 'cell_line_orig', 'demuxlet_call_orig', 'pass_filter'
    var: 'gene_id', 'genome', 'SUB_LIB_ID-0-0', 'SUB_LIB_ID-1-0', 'SUB_LIB_ID-10-0', 'SUB_LIB_ID-11-0', 'SUB_LIB_ID-12-0', 'SUB_LIB_ID-13-0', 'SUB_LIB_ID-14-0', 'SUB_LIB_ID-15-0', 'SUB_LIB_ID-2-0', 'SUB_LIB_ID-3-0', 'SUB_LIB_ID-4-0', 'SUB_LIB_ID-5-0', 'SUB_LIB_ID-6-0', 'SUB_LIB_ID-7-0', 'SUB_LIB_ID-8-0'

In [17]:
# Load the saved per-cell-line means that the new models' embeddings are added to.
mean_adata = sc.read_h5ad(filename="/vevo/umair/data/sens-pred/embs/rif-dmso-t0-mean.h5ad")
mean_adata

AnnData object with n_obs × n_vars = 52 × 15
    obs: 'cell-line'
    var: 'dim'
    obsm: 'mosaicfm-70m-adapted', 'mosaicfm-70m-from-scratch', 'mosaicfm-70m-pretrained', 'mosaicfm-70m-tahoe', 'pca'

In [18]:
# Average the newer models' embeddings over the cells of every cell line.
new_model_keys = ["mosaicfm-70m-merged", "mosaicfm-v2-1_3b-merged"]

# initialize arrays
cell_lines = []
emb_arrays = {key: [] for key in new_model_keys}

# compute means: build each cell line's mask once and index the obsm matrices
# directly, instead of slicing a new AnnData view for every embedding
line_labels = adata.obs["cell_line_orig"]
source_embs = {key: adata.obsm[key] for key in new_model_keys}
for c in tqdm(line_labels.unique().tolist()):
    cell_lines.append(c)
    in_line = (line_labels == c).to_numpy()
    for key, emb in source_embs.items():
        emb_arrays[key].append(emb[in_line].mean(axis=0))

# convert to NumPy arrays (one row per cell line, in `cell_lines` order)
for k in emb_arrays.keys():
    emb_arrays[k] = np.array(emb_arrays[k])

100%|██████████| 52/52 [00:02<00:00, 18.94it/s]


In [19]:
# Attach the newer models' per-cell-line means to mean_adata.
# mean_adata's rows come from the file saved earlier, while emb_arrays rows follow
# `cell_lines` computed from the separately loaded per-cell AnnData. Align rows by
# cell-line name instead of assuming both share the same order.
saved_lines = mean_adata.obs["cell-line"].tolist()
if set(saved_lines) != set(cell_lines):
    raise ValueError("cell lines in mean_adata do not match the newly computed embeddings")
row_for_line = {line: i for i, line in enumerate(cell_lines)}
row_order = [row_for_line[line] for line in saved_lines]
for key in ("mosaicfm-70m-merged", "mosaicfm-v2-1_3b-merged"):
    mean_adata.obsm[key] = emb_arrays[key][row_order]
mean_adata

AnnData object with n_obs × n_vars = 52 × 15
    obs: 'cell-line'
    var: 'dim'
    obsm: 'mosaicfm-70m-adapted', 'mosaicfm-70m-from-scratch', 'mosaicfm-70m-pretrained', 'mosaicfm-70m-tahoe', 'pca', 'mosaicfm-70m-merged', 'mosaicfm-v2-1_3b-merged'

In [20]:
# Write the updated means back over the file they were loaded from.
mean_adata.write_h5ad(filename="/vevo/umair/data/sens-pred/embs/rif-dmso-t0-mean.h5ad")