# Embedding Novel Species

This notebook will create the files you need to embed a novel species that wasn't included in the training data.

To start, you will need to download the ESM2 protein embeddings and the reference proteome for the species.

You can find precalculated ESM2 protein embeddings for many species [here](https://drive.google.com/drive/folders/1_Dz7HS5N3GoOAG6MdhsXWY1nwLoN13DJ?usp=drive_link).

For reference proteomes, you can download them from [here](https://useast.ensembl.org/info/about/species.html).

If there is no protein embedding for the species you are interested in, you can request one via GitHub or email, or you can create it yourself by following the instructions [here](https://github.com/snap-stanford/SATURN/tree/main/protein_embeddings).

In [99]:
import numpy as np
import pickle as pkl
import pandas as pd

In [117]:
SPECIES_NAME = "gorilla"  # shorthand name for this species; used in arguments and file names

# Path to the species proteome
SPECIES_PROTEIN_FASTA_PATH = "/dfs/project/cross-species/yanay/data/proteome/Gorilla_gorilla.gorGor4.pep.all.fa"

# Path to the ESM2 Embeddings
SPECIES_PROTEIN_EMBEDDINGS_PATH = "/dfs/project/cross-species/yanay/data/proteome/embeddings/Gorilla_gorilla.gorGor4.pep.all.gene_symbol_to_embedding_ESM2.pt"

# primary_assembly name; this must match the FASTA file
ASSEMBLY_NAME = "Gorilla_gorilla.gorGor4"
# NCBI Taxonomy ID. Please set this so that, if someone else embeds the same species,
# the randomly generated chromosome tokens will be the same
TAXONOMY_ID = 9595

You can view the FASTA format here. Please confirm that the primary_assembly name is correct.

In [118]:
!head {SPECIES_PROTEIN_FASTA_PATH}

>ENSGGOP00000050442.1 pep chromosome:gorGor4:7:142109302:142109637:1 gene:ENSGGOG00000038962.1 transcript:ENSGGOT00000048834.1 gene_biotype:TR_V_gene transcript_biotype:TR_V_gene
DAGVTQSPTHLIKTRGQQVTLRCSPQSGHNTVSWYQQALGQGPQFIFEYYEKEERGRGNF
PDRFSARQFPNYSSELNVNALLLGDSALYLCASSLAQPSRVTDVLYINFLP
>ENSGGOP00000047879.1 pep chromosome:gorGor4:14:23497737:23498111:-1 gene:ENSGGOG00000040725.1 transcript:ENSGGOT00000067180.1 gene_biotype:TR_V_gene transcript_biotype:TR_V_gene
GENVEQHPSTLSVQEGDSAVIKCTYSDSASNYFPWYKQELGKGPQLLIDIRSNVGEKKDQ
GITVTLNKTAKHFSLHITETQPEDSAVYFCAASTHCFPGTCHPCTNLRLELKLHPLSFVI
DRQL
>ENSGGOP00000045042.1 pep chromosome:gorGor4:14:23364087:23364422:-1 gene:ENSGGOG00000044166.1 transcript:ENSGGOT00000059848.1 gene_biotype:TR_V_gene transcript_biotype:TR_V_gene gene_symbol:TRAV17 description:T cell receptor alpha variable 17 [Source:HGNC Symbol;Acc:HGNC:12113]
SQQGEEDPQALSIQEGENATMNCSYKTSINNLQWYRQNSGRGLVHLILIRSNEREKHSGR
LRVTLDTSKKSSSLLITASRAADTASYFCATDAQCSPGTCSLYANPAKATS
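
To make the header layout above easier to check, the location token can be split into its fields. This is a minimal sketch and not part of the original notebook: `parse_location_token` and `Location` are hypothetical names, and the code assumes Ensembl-style tokens of the form `<coord_system>:<assembly>:<region>:<start>:<end>:<strand>`, as seen in the headers printed above.

```python
from typing import NamedTuple, Optional

# Coordinate-system prefixes, in the same priority order the notebook uses
LOCATION_PREFIXES = ("chromosome:", "primary_assembly:", "scaffold:")


class Location(NamedTuple):
    coord_system: str
    assembly: str
    region: str
    start: int


def parse_location_token(header: str) -> Optional[Location]:
    """Return the location fields of a FASTA header, or None if no location token is found."""
    tokens = header.split()
    for prefix in LOCATION_PREFIXES:
        for token in tokens:
            if token.startswith(prefix):
                fields = token.split(":")
                return Location(fields[0], fields[1], fields[2], int(fields[3]))
    return None


# First header from the gorilla proteome shown above (shortened to the relevant tokens)
header = (
    ">ENSGGOP00000050442.1 pep chromosome:gorGor4:7:142109302:142109637:1 "
    "gene:ENSGGOG00000038962.1 transcript:ENSGGOT00000048834.1"
)
loc = parse_location_token(header)
print(loc)
```

For this header the assembly field is `gorGor4`, which appears as the suffix of `Gorilla_gorilla.gorGor4`. If the assembly field in your FASTA file looks different from what you set in `ASSEMBLY_NAME`, that is worth investigating before continuing.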


In [119]:
species_to_paths = {
    SPECIES_NAME: SPECIES_PROTEIN_FASTA_PATH,
}

species_to_ids = {
    SPECIES_NAME: ASSEMBLY_NAME,
}

In [120]:
all_pos_def = []

missing_genes = {}
for species in species_to_ids.keys():
    missing_genes[species] = []
    proteome_path = species_to_paths[species]
    species_id = species_to_ids[species]

    with open(proteome_path) as f:
        proteome_lines = f.readlines()

    gene_symbol_to_location = {}
    gene_symbol_to_chrom = {}

    for line in proteome_lines:
        if line.startswith(">"):
            split_line = line.split()
            gene_symbol = [token for token in split_line if token.startswith("gene_symbol")]
            if len(gene_symbol) > 0:
                gene_symbol = gene_symbol[0].split(":")
                
                if len(gene_symbol) == 2:
                    gene_symbol = gene_symbol[1]
                elif len(gene_symbol) > 2:
                    gene_symbol = ":".join(gene_symbol[1:]) # fix for annoying zebrafish gene names with colons in them
                else:
                    # A "gene_symbol" token with no value is malformed; fail loudly
                    raise ValueError(f"Malformed gene_symbol token in header: {line.strip()}")
                
                
                chrom = None
                
                chrom_arr = [token for token in split_line if token.startswith("chromosome:")]
                if len(chrom_arr) > 0:
                    chrom = chrom_arr[0].replace("chromosome:", "")
                else:
                    chrom_arr = [token for token in split_line if token.startswith("primary_assembly:")]
                    if len(chrom_arr) > 0:
                        chrom = chrom_arr[0].replace("primary_assembly:", "")
                    else:
                        chrom_arr = [token for token in split_line if token.startswith("scaffold:")] 
                        if len(chrom_arr) > 0:
                            chrom = chrom_arr[0].replace("scaffold:", "")
                if chrom is not None:
                    gene_symbol_to_location[gene_symbol] = chrom.split(":")[2]
                    gene_symbol_to_chrom[gene_symbol] = chrom.split(":")[1]
                else:
                    missing_genes[species].append(gene_symbol)
                    

    positional_df = pd.DataFrame()
    positional_df["gene_symbol"] = [gn.upper() for gn in list(gene_symbol_to_chrom.keys())]
    positional_df["chromosome"] = list(gene_symbol_to_chrom.values())
    positional_df["start"] = list(gene_symbol_to_location.values())
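    # Note: "start" holds strings parsed from the header, so rows sort lexicographically within a chromosome
    positional_df = positional_df.sort_values(["chromosome", "start"])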
    #positional_df = positional_df.set_index("gene_symbol")
    positional_df["species"] = species
    all_pos_def.append(positional_df)
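
The gene symbol handling in the cell above can be isolated into a small function. This is a sketch under assumptions rather than the notebook's own code: `extract_gene_symbol` is a hypothetical helper, and it uses `split(":", 1)` so that symbols containing colons survive intact, which is the same case the cell handles by re-joining the pieces. The colon-containing symbol in the second header is an illustrative example only.

```python
from typing import Optional


def extract_gene_symbol(header: str) -> Optional[str]:
    """Return the value of the gene_symbol token in a FASTA header, or None if absent."""
    for token in header.split():
        if token.startswith("gene_symbol:"):
            # Split only on the first colon so symbols that contain colons stay whole
            symbol = token.split(":", 1)[1]
            if not symbol:
                raise ValueError(f"Empty gene_symbol in header: {header}")
            return symbol
    return None


with_symbol = ">P1 pep chromosome:gorGor4:14:23364087:23364422:-1 gene_symbol:TRAV17 description:T cell"
with_colons = ">P2 pep chromosome:asm:1:100:200:1 gene_symbol:si:ch211-1a19.3"
without_symbol = ">P3 pep chromosome:gorGor4:7:142109302:142109637:1 gene:ENSGGOG00000038962.1"

print(extract_gene_symbol(with_symbol))
print(extract_gene_symbol(with_colons))
print(extract_gene_symbol(without_symbol))
```

Headers without a `gene_symbol` token return `None`, which mirrors how the cell above skips those proteins.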

In [121]:
master_pos_def = pd.concat(all_pos_def)
master_pos_def

Unnamed: 0,gene_symbol,chromosome,start,species
3967,PLPPR5,1,100279777,gorilla
8070,PLPPR4,1,100649238,gorilla
2133,PALMD,1,101051168,gorilla
4858,AGL,1,101242303,gorilla
15395,SASS6,1,101474949,gorilla
...,...,...,...,...
1802,GPR143,X,9441264,gorilla
6848,SHROOM2,X,9629546,gorilla
6712,DIAPH2,X,96542787,gorilla
384,WWC3,X,9733923,gorilla


In [122]:
master_pos_def["species"].value_counts() # double check how many genes are mapped

species
gorilla    15926
Name: count, dtype: int64

In [123]:
for k, v in missing_genes.items():
    print(f"{k}: {len(v)}") # are any genes missing?

gorilla: 0


In [124]:
# Count genes per chromosome
for species in species_to_ids.keys():
    print("*********")
    print(species)
    display(master_pos_def[master_pos_def["species"] == species]["chromosome"].value_counts().head(50))
    print("*********")

*********
gorilla


chromosome
1                  1644
5                  1186
19                 1078
11                 1043
3                   929
12                  867
6                   826
7                   728
16                  653
4                   635
9                   608
10                  598
X                   590
8                   548
2B                  545
14                  518
2A                  487
17                  469
15                  453
20                  448
22                  338
13                  271
18                  223
21                  160
MT                   13
CABD030168792.1       4
CABD030153847.1       2
CABD030167992.1       2
CABD030168825.1       2
CABD030168793.1       2
CABD030166645.1       1
CABD030166187.1       1
CABD030166558.1       1
CABD030165914.1       1
CABD030165956.1       1
CABD030165925.1       1
CABD030167128.1       1
CABD030163968.1       1
CABD030163720.1       1
CABD030163716.1       1
CABD030166995.1       1
CABD0

*********


In [125]:
master_pos_def.to_csv(f"{SPECIES_NAME}_to_chrom_pos.csv", index=False) # Save the DF

In [126]:
# The chromosome file path will be:
print(f"{SPECIES_NAME}_to_chrom_pos.csv")

gorilla_to_chrom_pos.csv


In [127]:
N_UNIQ_CHROM = len(master_pos_def[master_pos_def["species"] == SPECIES_NAME]["chromosome"].unique())
N_UNIQ_CHROM

86

# Generate token file

In [128]:
import torch
import pickle
token_dim = 5120

The next cell creates the token file. Please note the printed `CHROM_TOKEN_OFFSET` value; you will need to pass it as an argument when embedding the new species.

In [129]:
species_to_offsets = {}

# Read the first four rows of the existing token file so that the special
# vocab tokens are the same for different seeds
all_pe = torch.load("/dfs/project/cross-species/yanay/code/uce_code/UCE_public/model_files/all_tokens.torch")[0:4]

offset = len(all_pe) # special tokens at the top!

PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)

pe_stacked = torch.stack(list(PE.values()))
all_pe = torch.vstack((all_pe, pe_stacked))
species_to_offsets[SPECIES_NAME] = offset

print("CHROM_TOKEN_OFFSET:", all_pe.shape[0])
torch.manual_seed(TAXONOMY_ID)
CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, token_dim))
# N_UNIQ_CHROM is the total number of chromosome choices; it is hardcoded for species in the
# training data and computed above for this species
all_pe = torch.vstack(
    (all_pe, CHROM_TENSORS))  # Add the chrom tensors to the end
all_pe.requires_grad = False


torch.save(all_pe, f"{SPECIES_NAME}_pe_tokens.torch")

with open(f"{SPECIES_NAME}_offsets.pkl", "wb+") as f:
    pickle.dump(species_to_offsets, f)
print("Saved PE, offsets file")

CHROM_TOKEN_OFFSET: 15930
Saved PE, offsets file


In [130]:
all_pe.shape

torch.Size([16016, 5120])
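
The numbers printed above follow directly from how the token file is stacked: special tokens first, then one row per protein embedding, then one row per chromosome token. The sketch below reproduces that arithmetic in plain Python. `token_layout` is a hypothetical helper, and the protein count of 15926 is inferred from the printed `CHROM_TOKEN_OFFSET` (15930) minus the 4 special tokens.

```python
def token_layout(n_special: int, n_proteins: int, n_chroms: int):
    """Return (protein_offset, chrom_token_offset, total_rows) for the stacked token file."""
    protein_offset = n_special                    # first protein row, stored in the offsets pickle
    chrom_token_offset = n_special + n_proteins   # first chromosome-token row
    total_rows = chrom_token_offset + n_chroms    # number of rows in the saved tensor
    return protein_offset, chrom_token_offset, total_rows


# Gorilla values from this notebook: 4 special tokens, 15926 proteins, 86 chromosomes/scaffolds
layout = token_layout(n_special=4, n_proteins=15926, n_chroms=86)
print(layout)  # (4, 15930, 16016)
```

If `all_pe.shape[0]` does not equal the third value for your species, one of the inputs (embedding count or `N_UNIQ_CHROM`) is probably not what you expect.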

In [131]:
print(f"{SPECIES_NAME}_offsets.pkl")

gorilla_offsets.pkl


In [132]:
SPECIES_PROTEIN_EMBEDDINGS_PATH

'/dfs/project/cross-species/yanay/data/proteome/embeddings/Gorilla_gorilla.gorGor4.pep.all.gene_symbol_to_embedding_ESM2.pt'

# Example evaluation of new species

**Note: when you evaluate a new species, you need to change some arguments and modify some files:**

You will need to modify the CSV at `model_files/new_species_protein_embeddings.csv` to include the new protein embeddings file you downloaded.

In the file add a row for the new species with the format:
`species name,full path to protein embedding file`
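
One way to add that row programmatically is sketched below. This is an illustration only: `add_species_row` is a hypothetical helper, and the sketch assumes the CSV is a plain two-column file in the `species name,full path` format described above. The demo writes to a temporary directory rather than to `model_files/`, and the paths in it are placeholders.

```python
import csv
import tempfile
from pathlib import Path


def add_species_row(csv_path, species_name, embedding_path):
    """Append `species_name,embedding_path` unless the species is already listed.

    Returns True if a row was written and False if the species was already present.
    """
    csv_path = Path(csv_path)
    existing = csv_path.read_text() if csv_path.exists() else ""
    if any(row and row[0] == species_name for row in csv.reader(existing.splitlines())):
        return False
    with csv_path.open("a", newline="") as f:
        # Make sure the new row starts on its own line
        if existing and not existing.endswith("\n"):
            f.write("\n")
        csv.writer(f, lineterminator="\n").writerow([species_name, embedding_path])
    return True


# Demo against a throwaway file (placeholder paths, not the real model_files CSV)
with tempfile.TemporaryDirectory() as tmp:
    demo_csv = Path(tmp) / "new_species_protein_embeddings.csv"
    demo_csv.write_text("chicken,/path/to/chicken_ESM2.pt")  # no trailing newline on purpose
    first = add_species_row(demo_csv, "gorilla", "/path/to/gorilla_ESM2.pt")
    second = add_species_row(demo_csv, "gorilla", "/path/to/gorilla_ESM2.pt")
    demo_contents = demo_csv.read_text()

print(demo_contents)
```

The second call is a no-op, so re-running the cell does not duplicate the row.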

Please also add the same entry (species name and path to the protein embedding file) to the dictionary created on line 247 in the file `data_proc/data_utils.py`.

When you want to embed this new species, you will need to specify these newly created files as arguments.
- `CHROM_TOKEN_OFFSET`: This tells UCE where the rows corresponding to chromosome tokens start.
- `spec_chrom_csv_path`: This is the new CSV, created by this notebook, which maps genes to chromosomes and genomic positions.
- `token_file`: This is the new token file, which works only for this species. The embeddings generated will still be universal, though!
- `offset_pkl_path`: This is the pickle file, created by this notebook, that stores the species' offset into the token file.


```
accelerate launch eval_single_anndata.py chicken_heart.h5ad --species=chicken --CHROM_TOKEN_OFFSET=13275 --spec_chrom_csv_path=data_proc/chicken_to_chrom_pos.csv --token_file=data_proc/chicken_pe_tokens.torch --offset_pkl_path=data_proc/chicken_offsets.pkl --dir=... --multi_gpu=True
```
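
The same command can be assembled from the values produced in this notebook. The sketch below only builds and prints the command string; it does not run anything. `build_eval_command` is a hypothetical helper, the flags are copied from the example above, and the input file `gorilla_sample.h5ad`, the output directory `./gorilla_out/`, and the assumption that the generated files were moved into `data_proc/` are all placeholders.

```python
import shlex


def build_eval_command(adata_path, species, chrom_token_offset, out_dir, data_dir="data_proc"):
    """Build the eval_single_anndata.py command line for a newly embedded species."""
    args = [
        "accelerate", "launch", "eval_single_anndata.py", adata_path,
        f"--species={species}",
        f"--CHROM_TOKEN_OFFSET={chrom_token_offset}",
        # File names follow the f-strings used when saving in this notebook
        f"--spec_chrom_csv_path={data_dir}/{species}_to_chrom_pos.csv",
        f"--token_file={data_dir}/{species}_pe_tokens.torch",
        f"--offset_pkl_path={data_dir}/{species}_offsets.pkl",
        f"--dir={out_dir}",
        "--multi_gpu=True",
    ]
    return shlex.join(args)


# 15930 is the CHROM_TOKEN_OFFSET printed for gorilla earlier in this notebook
cmd = build_eval_command("gorilla_sample.h5ad", "gorilla", 15930, "./gorilla_out/")
print(cmd)
```

Building the command from `SPECIES_NAME` and the printed offset avoids copying the chicken values from the example by mistake.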