In [None]:
import anndata as ad

# Alternative input file, kept for reference (currently disabled).
# PATH = '/home/ec2-user/cell_census/tabula_sapiens__sample_single_cell__label_cell_type__processed.h5ad'
# AnnData file that is loaded below.
PATH = '/home/ec2-user/cell_census/tabula_sapiens__sample_donor_id__label_cell_type.h5ad'

# `adata.var_names` is what `filter_df_fn` uses to restrict the gene BED file.
adata = ad.read_h5ad(PATH)

from enformer_pytorch import GenomeIntervalDataset

class MyGenomeIntervalDataset(GenomeIntervalDataset):
    """`GenomeIntervalDataset` whose items are `(label, item)` pairs.

    The label is the fifth field (index 4) of the corresponding row in
    `self.df`; `item` is whatever the parent dataset returns for that index.
    Presumably the fifth field is the gene identifier ("column_5" in the BED
    table) -- confirm against the BED file.
    """

    def __init__(self, **kwargs):
        # Keyword-only construction; everything is forwarded to the parent.
        super().__init__(**kwargs)

    def __getitem__(self, ind):
        """Return `(label, item)` for row `ind`."""
        # Fetch the parent item first so index errors surface from the parent.
        parent_item = super().__getitem__(ind)
        row_fields = self.df.row(ind)
        return row_fields[4], parent_item


import torch
import polars as pl
import zarr
from enformer_pytorch import Enformer
from torch.utils.data import DataLoader
from tqdm import tqdm
import pyfaidx
from pathlib import Path

torch.multiprocessing.freeze_support()

# Root directory holding the genome files and the embedding outputs, and the
# torch device the model and batches are moved to.
BASE_PT = "/home/ec2-user/enformer"
DEVICE = "cuda:0"

# Alternative local/CPU configuration (disabled).
# BASE_PT = "/Users/nsofroniew/Documents/data/multiomics/enformer"
# DEVICE = "cpu"

# Paths derived from BASE_PT: reference FASTA, gene-interval BED file, and the
# two zarr stores for embeddings (full and TSS-only, per the commented-out
# zarr creation cell further down).
FASTA_PT = BASE_PT + "/Homo_sapiens.GRCh38.dna.toplevel.fa"
GENE_INTERVALS_PT = BASE_PT + "/Homo_sapiens.GRCh38.genes.bed"
EMBEDDING_PT = BASE_PT + "/Homo_sapiens.GRCh38.genes.enformer_embeddings.zarr"
EMBEDDING_PT_TSS = BASE_PT + "/Homo_sapiens.GRCh38.genes.enformer_embeddings_tss.zarr"
# Identifier passed to `Enformer.from_pretrained`.
MODEL_PT = "EleutherAI/enformer-official-rough"

def filter_df_fn(df):
    """Keep only the rows of `df` whose "column_5" value is in `adata.var_names`.

    Passed to the dataset as `filter_df_fn`; reads the notebook-level `adata`.
    """
    known_genes = list(adata.var_names)
    in_adata = pl.col("column_5").is_in(known_genes)
    return df.filter(in_adata)

# The Faidx object is constructed only for its side effect and then discarded
# (presumably this builds the FASTA index if it is missing -- confirm against
# the pyfaidx documentation).
print("Converting fasta file")
pyfaidx.Faidx(FASTA_PT)
print("Fasta file done")

# Load the pretrained model with a single "human" output head of 5313 tracks
# and move it to DEVICE. As the last expression, `model.to(...)` also displays
# the model in the cell output.
model = Enformer.from_pretrained(MODEL_PT, output_heads=dict(human = 5313), use_checkpointing = False)
model.to(DEVICE)

In [None]:
# Dataset over the gene intervals, restricted by `filter_df_fn` to genes
# present in `adata`; each item is a (label, sequence item) pair.
ds = MyGenomeIntervalDataset(
    bed_file=GENE_INTERVALS_PT,  # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file=FASTA_PT,  # path to fasta file
    return_seq_indices=False,  # return nucleotide indices (ACGTN) or one hot encodings
    rc_aug=False,
    filter_df_fn=filter_df_fn,
)
# Fixed order (shuffle=False) so that batch position maps back to dataset row.
dl = DataLoader(ds, batch_size=2, shuffle=False, num_workers=0) # type: DataLoader

# Create zarr files
# Dimensions used for the zarr stores: (NUM_GENES, SEQ_EMBED_DIM, EMBED_DIM)
# for the full store and (NUM_GENES, EMBED_DIM) for the TSS-only store.
SEQ_EMBED_DIM = 896
EMBED_DIM = 3072
NUM_GENES = len(ds)
# Index of the middle position along the sequence axis (896 // 2 == 448).
TSS = int(SEQ_EMBED_DIM // 2)

# NOTE(review): `paths` is not referenced by any other visible cell.
paths = (Path(EMBEDDING_PT), Path(EMBEDDING_PT_TSS))

In [None]:
ds[181][0]

In [None]:
import pandas as pd
import numpy as np

# targets_txt = 'https://raw.githubusercontent.com/calico/basenji/0.5/manuscripts/cross2020/targets_human.txt'
# df_targets = pd.read_csv(targets_txt, sep='\t')
# Local copy of the targets table. It is tab-separated like the remote file
# above, so `sep='\t'` is required: with the default comma separator the
# 'description' column used below would not exist.
df_targets = pd.read_csv(BASE_PT + '/targets_human.txt', sep='\t')
# Expected shape is (5313, 8). The row count must equal the size of the model's
# 'human' output head, because `cage_indices` indexes that output's last axis.
assert df_targets.shape[0] == 5313, f"expected 5313 target rows, got {df_targets.shape[0]}"
# Row positions of the tracks whose description starts with 'CAGE:'.
cage_indices = np.where(df_targets['description'].str.startswith('CAGE:'))[0]

In [None]:
# z_embedding_full = zarr.open(
#     EMBEDDING_PT,
#     mode="w",
#     shape=(NUM_GENES, SEQ_EMBED_DIM, EMBED_DIM),
#     chunks=(1, SEQ_EMBED_DIM, EMBED_DIM),
#     dtype='float32',
# )

# z_embedding_tss = zarr.open(
#     EMBEDDING_PT_TSS,
#     mode="w",
#     shape=(NUM_GENES, EMBED_DIM),
#     chunks=(1, EMBED_DIM),
#     dtype='float32',
# )

# index = 0
# for labels, batch in tqdm(dl):
#     # calculate embedding
#     with torch.no_grad():
#         output, embeddings = model(batch.to(DEVICE), return_embeddings=True)
#         embeddings = embeddings.detach().cpu().numpy()

#     tss_embedding = embeddings[:, TSS]

#     # save full and reduced embeddings
#     batch_size = len(embeddings)
#     z_embedding_full[index : index + batch_size] = embeddings
#     z_embedding_tss[index : index + batch_size] = tss_embedding
#     index += batch_size

In [None]:
labels, batch = next(iter(dl))

In [None]:
batch[0, :10, :]

In [None]:
batch[0, :10, :]

In [None]:
# Run the model on one batch without tracking gradients and pick, per sample,
# the embedding at the position with the highest mean CAGE prediction.
with torch.no_grad():
    output, embeddings = model(batch.to(DEVICE), return_embeddings=True)
    # Select the CAGE tracks on the last axis and average over them.
    cage_expression = output['human'][:, :, cage_indices].mean(dim=-1)
    # Per-sample index of the maximum along the remaining last axis.
    max_inds = torch.argmax(cage_expression, dim=-1)
    batch_size = len(embeddings)
    # Pairs sample i with position max_inds[i] on the second axis of embeddings.
    tss_embedding = embeddings[torch.arange(batch_size), max_inds]

In [None]:
max_inds

In [None]:
tss_embedding.shape

In [None]:
# NOTE(review): this applies `max_inds` to the THIRD axis of `embeddings`,
# whereas `tss_embedding` above applies it to the second axis -- confirm which
# is intended. The hard-coded 2 matches the DataLoader's batch_size=2.
result = embeddings[torch.arange(2), :, max_inds]


In [None]:
result.shape

In [None]:
cage_expression.shape

In [None]:
label, seq = ds[20]

In [None]:
seq

In [None]:
# Same procedure as the batched cell, but for the single item `seq` (no batch
# axis is added here). This overwrites `output`, `embeddings`,
# `cage_expression` and `tss_embedding` from the batched cell.
with torch.no_grad():
    output, embeddings = model(seq.to(DEVICE), return_embeddings=True)
    # Average the CAGE tracks (axis 1) to get one value per position.
    cage_expression = output['human'][:, cage_indices].mean(dim=1)
    # Position with the highest mean CAGE prediction.
    max_ind = torch.argmax(cage_expression)
    tss_embedding = embeddings[max_ind]

In [None]:
embeddings.shape

In [None]:
list(ds.df['column_5'])

In [None]:
max_ind = torch.argmax(cage_expression)

In [None]:
tss_embedding.shape

In [None]:
embeddings.shape

In [None]:
output['human'].shape

In [None]:
import matplotlib.pyplot as plt

In [None]:
cage_indices

In [None]:
# Vertical reference line at the middle position index `TSS` (y from 0 to 100).
plt.plot([TSS, TSS], [0, 100]);
# Mean prediction over the CAGE tracks at each position for the single item.
plt.plot(output['human'][:, cage_indices].mean(dim=1).detach().cpu());

In [None]:
# !pip install pyensembl

In [None]:
# from pyensembl import Genome

# gtf_file_path = BASE_PT + '/Homo_sapiens.GRCh38.77.gtf'

# genome = Genome(reference_name='GRCh38',
#     annotation_name='my_genome_features',
#     gtf_path_or_url=gtf_file_path)
# # genome.index()

In [None]:
label

In [None]:
from pyensembl import EnsemblRelease

genome = EnsemblRelease(77)

In [None]:
# Look up the gene record for `label` (assigned from `ds[20]` in an earlier cell).
gene = genome.gene_by_id(label)

# Observed and predicted gene expression values were obtained by summing up the observed/predicted CAGE read counts
# at all unique TSS locations of the gene. For each TSS location, we used the 128-bp bin overlapping the TSS as well
# as the two neighboring bins (3 bins in total).

# For each gene, look through all transcripts - protein coding / not, and record offsets from gene start
gene.transcripts

In [None]:
from scipy.ndimage import gaussian_filter1d

def get_tss(gene_id, tss=TSS, length=SEQ_EMBED_DIM, sigma=8):
    """Build a length-`length` vector marking the transcript starts of a gene.

    Each transcript of the gene is mapped to the index
    `tss + round((transcript.start - gene.start) / 128)`; indices outside
    `[0, length)` are dropped and the remaining ones are set to 1.0.

    Args:
        gene_id: identifier passed to `genome.gene_by_id`.
        tss: index that a transcript starting exactly at `gene.start` maps to.
        length: length of the returned vector.
        sigma: standard deviation for `gaussian_filter1d`; `None` skips smoothing.

    Returns:
        A 1-D float array of size `length`.

    NOTE(review): `.start` is used for every transcript without looking at the
    strand -- confirm this is the intended coordinate for reverse-strand genes.
    """
    gene = genome.gene_by_id(gene_id)
    # Offsets are expressed in 128-bp steps relative to the gene start.
    offsets = [tss + np.round((transcript.start - gene.start) / 128) for transcript in gene.transcripts]
    bins = np.array(offsets, dtype=int)
    inside = (bins >= 0) & (bins < length)
    bins = bins[inside]

    track = np.zeros(length)
    track[bins] = 1.0
    if sigma is None:
        return track
    return gaussian_filter1d(track, sigma)

In [None]:
plt.plot(get_tss(label))

In [None]:
def get_tss_locations(genes, tss=TSS):
    """Map each gene id to the list of its transcript-start indices.

    For every transcript the index is
    `tss + round((transcript.start - gene.start) / 128)`. Values are neither
    clipped nor cast to int, so they may be negative or beyond the sequence
    length.

    Args:
        genes: iterable of gene ids accepted by `genome.gene_by_id`.
        tss: index that a transcript starting exactly at `gene.start` maps to.

    Returns:
        dict of gene id -> list of indices (one per transcript).
    """
    def _indices_for(gene_id):
        gene = genome.gene_by_id(gene_id)
        return [tss + np.round((transcript.start - gene.start) / 128) for transcript in gene.transcripts]

    return {gene_id: _indices_for(gene_id) for gene_id in genes}

In [None]:
locations = get_tss_locations(list(ds.df['column_5']))

In [None]:
locations

In [None]:
gene = genome.gene_by_id(label)
start_diffs = [TSS + np.round((ts.start - gene.start) / 128) for ts in gene.transcripts]

In [None]:
start_diffs

In [None]:
# One black vertical line per transcript-start index in `start_diffs`.
for st in start_diffs:
    plt.plot([st, st], [0, 100], color='k');
# Overlay of the mean CAGE prediction (disabled).
# plt.plot(output['human'][:, cage_indices].mean(dim=1).cpu());
# Lines are drawn up to y=100 but the view is limited to y in [0, 20].
plt.ylim([0, 20])

In [None]:
start_diffs

In [None]:
def extract_embeddings(embeddings, cage_expression, tss_tensors, sigmas, tss=TSS):
    """Summarise per-position embeddings into several per-sample vectors.

    Assumes `embeddings` is (batch, seq, dim) and that `cage_expression` and
    `tss_tensors` are both (batch, seq) -- TODO confirm against the caller.

    The stacked result contains, in this order:
      0. embedding at position `tss`
      1. sum of embeddings over all positions
      2. embedding at the argmax of `cage_expression`
      3. embedding at the argmax of `cage_expression * tss_tensors`
      4. sum of embeddings weighted by `cage_expression * tss_tensors`
      then, for every sigma in `sigmas`, entries 3 and 4 recomputed with
      `tss_tensors` smoothed by `gaussian_filter_1d`.

    Args:
        embeddings: per-position embeddings.
        cage_expression: per-position score used for argmax and weighting.
        tss_tensors: per-position mask/weights multiplied into the score.
        sigmas: iterable of smoothing widths; may be empty.
        tss: position used for entry 0.

    Returns:
        Tensor of shape (5 + 2 * len(sigmas), batch, dim).

    NOTE(review): `gaussian_filter_1d` is not defined in the visible notebook;
    it must exist before this is called with a non-empty `sigmas`.
    """
    batch_size = embeddings.shape[0]
    batch_idx = torch.arange(batch_size)

    def _amax_and_weighted_sum(weights):
        # Embedding at the highest-weight position, and the weighted sum over
        # positions. `unsqueeze(-1)` lets the (batch, seq) weights broadcast
        # across the embedding dimension.
        peak = torch.argmax(weights, dim=-1)
        amax = embeddings[batch_idx, peak]
        weighted = (embeddings * weights.unsqueeze(-1)).sum(dim=1)
        return amax, weighted

    # Use the `tss` argument rather than the notebook-level constant.
    tss_emb = embeddings[:, tss]
    sum_emb = embeddings.sum(dim=1)
    amax_emb = embeddings[batch_idx, torch.argmax(cage_expression, dim=-1)]

    amax_tss_emb, sum_tss_emb = _amax_and_weighted_sum(cage_expression * tss_tensors)

    all_emb = [tss_emb, sum_emb, amax_emb, amax_tss_emb, sum_tss_emb]
    for sigma in sigmas:
        ks = 2 * int(sigma / 2 * 3)
        tss_tensors_conv = gaussian_filter_1d(tss_tensors, kernel_size=ks, sigma=sigma)
        amax_tss_emb, sum_tss_emb = _amax_and_weighted_sum(cage_expression * tss_tensors_conv)
        all_emb.append(amax_tss_emb)
        all_emb.append(sum_tss_emb)

    return torch.stack(all_emb, dim=0) # 5 + 2 * len(sigmas)

In [None]:
def extract_embeddings(embeddings, cage_expression, tss_tensors, sigmas, tss):
    """Collect six per-sample embeddings around two anchor positions.

    The stacked result contains, in this order:
      0-2. embeddings at positions `tss`, `tss - 1`, `tss + 1`
      3-5. embeddings at the argmax of `cage_expression * tss_tensors` and at
           its left/right neighbours, the neighbours being clipped to
           `[0, tss_tensors.shape[1] - 1]`

    Args:
        embeddings: tensor indexed as [sample, position, ...].
        cage_expression: per-position score, multiplied by `tss_tensors`.
        tss_tensors: per-position mask/weights; its second dimension sets the
            upper clipping bound.
        sigmas: unused; kept for signature compatibility.
        tss: fixed anchor position.

    Returns:
        Tensor with 6 entries stacked along a new leading dimension.
    """
    sample_idx = torch.arange(embeddings.shape[0])
    last_pos = tss_tensors.shape[1] - 1
    peak = torch.argmax(cage_expression * tss_tensors, dim=-1)

    # Fixed anchor and its two neighbours (not clipped).
    around_tss = [embeddings[:, tss + shift] for shift in (0, -1, 1)]

    # Data-driven anchor; only the neighbours are clipped to the valid range.
    around_peak = [
        embeddings[sample_idx, peak],
        embeddings[sample_idx, torch.clip(peak - 1, 0, last_pos)],
        embeddings[sample_idx, torch.clip(peak + 1, 0, last_pos)],
    ]

    return torch.stack(around_tss + around_peak, dim=0) # 6

In [None]:
def extract_embeddings(embeddings, cage_expression, tss_tensors, sigmas, tss):
    """Collect embeddings around the five highest-scoring positions.

    Positions are ranked by `cage_expression * tss_tensors`. For each of the
    top 5 positions (best first) the result holds the embedding at that
    position and at its left and right neighbours, the neighbours being
    clipped to `[0, tss_tensors.shape[1] - 1]`. Ranks whose score is exactly 0
    fall back to position `tss`.

    Args:
        embeddings: tensor indexed as [sample, position, ...].
        cage_expression: per-position score, multiplied by `tss_tensors`.
        tss_tensors: per-position mask/weights; its second dimension sets the
            upper clipping bound.
        sigmas: unused; kept for signature compatibility.
        tss: fallback position for zero-score ranks.

    Returns:
        Tensor with 15 entries stacked along a new leading dimension, ordered
        (rank 0, rank 0 - 1, rank 0 + 1, rank 1, ...).
    """
    batch_size = embeddings.shape[0]
    len_seq = tss_tensors.shape[1] - 1
    scaled_cage_expression = cage_expression * tss_tensors

    # torch.topk returns (values, indices), in that order.
    topk_values, topk_inds = torch.topk(scaled_cage_expression, 5, dim=-1)
    # Where the score is zero there is no real peak; use the fixed position.
    topk_inds[topk_values == 0] = tss

    batch_idx = torch.arange(batch_size)
    all_emb = []
    for i in range(topk_inds.shape[1]):
        max_inds = topk_inds[:, i]
        amax_emb = embeddings[batch_idx, max_inds]
        amax_emb_m1 = embeddings[batch_idx, torch.clip(max_inds - 1, 0, len_seq)]
        amax_emb_1 = embeddings[batch_idx, torch.clip(max_inds + 1, 0, len_seq)]
        all_emb += [amax_emb, amax_emb_m1, amax_emb_1]
    return torch.stack(all_emb, dim=0) # 15

In [2]:
import torch

In [12]:
A = torch.rand((2, 89))

In [16]:
torch.topk(A, 5, dim=-1).indices[:, 0]

tensor([21, 49])

In [15]:
torch.argmax(A, dim=-1)

tensor([21, 49])

In [None]:
# Scratch cell.
# NOTE(review): `scaled_cage_expression` is computed but not used; the argmax
# below is taken over the unscaled `cage_expression` -- confirm which was
# intended. `tss_tensors` is not assigned in any visible cell.
scaled_cage_expression = cage_expression * tss_tensors
max_inds = torch.argmax(cage_expression, dim=-1)
tss_embedding = embeddings[torch.arange(batch_size), max_inds]

In [1]:
import anndata as ad


# Rebinds PATH and `adata` to a different AnnData file than the one loaded at
# the top of the notebook.
PATH = '/home/ec2-user/cell_census/tabula_sapiens__tss_max__sample_donor_id__label_cell_type.h5ad'

adata = ad.read_h5ad(PATH)

In [2]:
adata.varm

AxisArrays with keys: embedding

In [3]:
del adata.varm['embedding']

In [4]:
import zarr

# Open the stored embeddings read-only; the recorded cell output shows shape
# (15, 19431, 3072).
z = zarr.open('/home/ec2-user/enformer/Homo_sapiens.GRCh38.genes.enformer_embeddings_trio_top5_pc_0.zarr', 'r')

In [5]:
z.shape

(15, 19431, 3072)

In [6]:
# import numpy as np
#     # Embeddings include
#     #   TSS
#     #   sum over all
#     #   argmax over all
#     #   argmax over TSS
#     #   sum over TSS
#     #   argmax over TSS sigma 3, 8, 16
#     #   sum over TSS sigma 3, 8, 16

# names = ['', '_sum', '_amax', '_tss_amax', '_tss_sum']
# sigmas = [3, 8, 16, 32, 64]
# for s in sigmas:
#     names.append('_tss_amax_' + str(s))
#     names.append('_tss_sum_' + str(s))

# for i, name in enumerate(names):
#     adata.varm['embedding' + name] = np.asarray(z[i])

In [7]:
# import numpy as np

# names = ['', '_m1', '_1', '_amax', '_amax_m1', '_amax_1']

# for i, name in enumerate(names):
#     adata.varm['embedding' + name] = np.asarray(z[i])

In [6]:
import numpy as np

# Copy each leading slice of `z` into `adata.varm` under the key
# 'embedding_amax_<i><suffix>', with i in 0..4 and suffix in '', '_m1', '_1'
# (15 keys in total). `k` walks the leading axis of `z` in the same order.
# Presumably this matches the (peak, peak - 1, peak + 1) per-rank ordering of
# the top-5 `extract_embeddings` -- confirm against how the zarr was written.
k = 0
for i in range(5):
    for j in ['', '_m1', '_1']:
        adata.varm['embedding_amax_' + str(i) + str(j)] = np.asarray(z[k])
        k += 1

In [8]:
# Destination file for the AnnData object with the newly attached `varm` entries.
PATH_3 = '/home/ec2-user/cell_census/tabula_sapiens__trio_top5_pc__sample_donor_id__label_cell_type.h5ad'

adata.write_h5ad(PATH_3)

In [9]:
adata.varm

AxisArrays with keys: embedding, embedding_m1, embedding_1, embedding_amax, embedding_amax_m1, embedding_amax_1

In [None]:
list(adata.varm.keys())

In [None]:
# Two previously written AnnData files, loaded side by side so their `varm`
# embeddings can be compared in the following cells.
PATH_1 = '/home/ec2-user/cell_census/tabula_sapiens__all_pc__sample_donor_id__label_cell_type.h5ad'
PATH_2 = '/home/ec2-user/cell_census/tabula_sapiens__all__sample_donor_id__label_cell_type.h5ad'

adata_1 = ad.read_h5ad(PATH_1)
adata_2 = ad.read_h5ad(PATH_2)

In [None]:
adata_1.varm['embedding_tss_amax'][:5,:5]

In [None]:
adata_2.varm['embedding_tss_amax'][:5, :5]

In [None]:
import numpy as np

In [None]:
np.testing.assert_almost_equal(adata_1.varm['embedding_tss_amax'], adata_2.varm['embedding_tss_amax'])

In [None]:
(adata_1.varm['embedding'] - adata_2.varm['embedding']).max()