# Embedding Novel Species

This notebook creates the files needed to embed a novel species that was not included in the training data.

To start, you will need to download the ESM2 protein embeddings and the reference proteome for the species.

Precomputed ESM2 protein embeddings for many species are available [here](https://drive.google.com/drive/folders/1_Dz7HS5N3GoOAG6MdhsXWY1nwLoN13DJ?usp=drive_link).

Reference proteomes can be downloaded from Ensembl [here](https://useast.ensembl.org/info/about/species.html).

If there is no protein embedding file for your species, you can request one via GitHub or email, or create it yourself by following the instructions [here](https://github.com/snap-stanford/SATURN/tree/main/protein_embeddings).
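Before going further, it can help to sanity-check the downloaded embedding file. The token cell later in this notebook loads it with `torch.load` and stacks its values, so it is treated here as a mapping from gene name to a 1-D tensor of length 5120 (the `token_dim` used later). The helper `check_protein_embeddings` is a hypothetical sketch under that assumption, demonstrated on a small synthetic file rather than a real download.

```python
import os
import tempfile

import torch


def check_protein_embeddings(path, expected_dim=5120):
    """Hypothetical helper: load a gene -> embedding mapping and check every shape.

    Assumes the file was saved with torch.save and holds a dict whose values are
    1-D tensors of length expected_dim. Returns the number of genes.
    """
    embeddings = torch.load(path)
    for gene, emb in embeddings.items():
        if emb.ndim != 1 or emb.shape[0] != expected_dim:
            raise ValueError(f"{gene}: expected shape ({expected_dim},), got {tuple(emb.shape)}")
    return len(embeddings)


# Demonstration on a tiny synthetic file (not real ESM2 output)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo_embeddings.pt")
    torch.save({"GENEA": torch.zeros(5120), "GENEB": torch.ones(5120)}, demo_path)
    n_genes = check_protein_embeddings(demo_path)

print(n_genes)  # 2
```

With real data, pass `SPECIES_PROTEIN_EMBEDDINGS_PATH` once it is defined below.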

In [1]:
import numpy as np
import pickle as pkl
import pandas as pd

In [30]:
SPECIES_NAME = "rabbit" # shorthand name for this species, used in arguments and file names

# Path to the species proteome
SPECIES_PROTEIN_FASTA_PATH = "/lfs/local/0/yanay/SATURN/protein_embeddings/data/Oryctolagus_cuniculus.OryCun2.0.pep.all.fa"

# Path to the ESM2 Embeddings
SPECIES_PROTEIN_EMBEDDINGS_PATH = "/lfs/local/0/yanay/SATURN/protein_embeddings//data/Oryctolagus_cuniculus.OryCun2.0.pep.all.gene_symbol_to_embedding_ESM2.pt"

# Assembly name (the primary_assembly name); this must match the FASTA file
ASSEMBLY_NAME = "Oryctolagus_cuniculus.OryCun2.0"
# NCBI Taxonomy ID. It seeds the randomly generated chromosome tokens, so anyone else
# who embeds the same species gets identical chromosome tokens
TAXONOMY_ID = 9986
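`TAXONOMY_ID` is passed to `torch.manual_seed` in the token cell before the chromosome tokens are drawn with `torch.normal`. A minimal sketch of why that makes the tokens reproducible: seeding with the same ID gives identical draws (within the same PyTorch version and device), while a different ID gives different ones. The function `make_chrom_tokens` is hypothetical, and the small sizes are for illustration only.

```python
import torch


def make_chrom_tokens(taxonomy_id, n_chrom, token_dim):
    """Hypothetical helper mirroring the token cell: seeded standard-normal chromosome tokens."""
    torch.manual_seed(taxonomy_id)
    return torch.normal(mean=0.0, std=1.0, size=(n_chrom, token_dim))


a = make_chrom_tokens(9986, n_chrom=3, token_dim=8)
b = make_chrom_tokens(9986, n_chrom=3, token_dim=8)
c = make_chrom_tokens(9606, n_chrom=3, token_dim=8)
print(torch.equal(a, b))  # True: same taxonomy ID, same tokens
print(torch.equal(a, c))  # False: different seed, different tokens
```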

The first lines of the FASTA file are shown below; please confirm that the primary_assembly name is correct.

In [31]:
!head {SPECIES_PROTEIN_FASTA_PATH}

>ENSOCUP00000018422.2 pep chromosome:OryCun2.0:17:42506545:42507786:1 gene:ENSOCUG00000023049.2 transcript:ENSOCUT00000027669.2 gene_biotype:TR_V_gene transcript_biotype:TR_V_gene
MAQTVTQTQPEMSVQEAETATLDCTYDTRDSDYYLFWYKQPPSGELVLIIRQEAYKPQNA
TQNRFSVNFQKVSKSFSLKISDSQLGDAGMYLCARMKEYFRMSF
>ENSOCUP00000036682.1 pep chromosome:OryCun2.0:10:21480758:21481093:-1 gene:ENSOCUG00000036137.1 transcript:ENSOCUT00000052566.1 gene_biotype:TR_C_gene transcript_biotype:TR_C_gene
DKKLDGDFSPKPTIFLPSIAETKLHKAGTYLCLLEKFFPDVIKVYWKEKNGDTILESQEG
NTMKINGTYMKFSWLTVPQKSFDKEHRCIVEHENTRESKQEILFPAINKGT
>ENSOCUP00000012669.3 pep chromosome:OryCun2.0:21:4870031:4880529:1 gene:ENSOCUG00000021096.2 transcript:ENSOCUT00000014734.3 gene_biotype:IG_C_gene transcript_biotype:IG_C_gene
GQPAVTPSVILFPPSSEELKDNKATLVCLINDFYPRTVKVNWKADGNSVTQGVDTTQPSK
QSNNKYAASSFLSLSANQWKSYQSVTCQVTHEGHTVEKSLAPAECQPAVTPSVILFPPSS
EELKDNKATLVCLINDFYPGTVKVNWKADGTPVTQGVDTTQPSKQSNSKYAASSFLSLSA
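The location token in each header has the form `<type>:<assembly>:<seq_name>:<start>:<end>:<strand>`, which is what the parsing cell below relies on. As a programmatic version of the check above, here is a minimal sketch that pulls those fields out of one header. `parse_pep_header` is hypothetical; unlike the notebook cell, it matches the `gene:` token explicitly, and it assumes `ASSEMBLY_NAME` follows the Ensembl `<Species>.<Assembly>` file-naming pattern.

```python
def parse_pep_header(header):
    """Hypothetical helper: extract gene ID and location fields from an Ensembl pep header."""
    tokens = header.lstrip(">").split()
    gene_tok = next((t for t in tokens if t.startswith("gene:")), None)
    gene = gene_tok.split(":", 1)[1] if gene_tok else None  # keeps any colons in the name
    # Same order of preference as the notebook's parsing cell
    for prefix in ("chromosome:", "primary_assembly:", "scaffold:"):
        loc = [t for t in tokens if t.startswith(prefix)]
        if loc:
            assembly, seq_name, start = loc[0][len(prefix):].split(":")[:3]
            return {"gene": gene, "assembly": assembly, "chromosome": seq_name, "start": int(start)}
    return {"gene": gene, "assembly": None, "chromosome": None, "start": None}


header = (">ENSOCUP00000018422.2 pep chromosome:OryCun2.0:17:42506545:42507786:1 "
          "gene:ENSOCUG00000023049.2 transcript:ENSOCUT00000027669.2 "
          "gene_biotype:TR_V_gene transcript_biotype:TR_V_gene")
info = parse_pep_header(header)
print(info)

# As set in the configuration cell; the part after the species name should match the header
ASSEMBLY_NAME = "Oryctolagus_cuniculus.OryCun2.0"
print(ASSEMBLY_NAME.split(".", 1)[1] == info["assembly"])  # True
```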


In [32]:
species_to_paths = {
    SPECIES_NAME: SPECIES_PROTEIN_FASTA_PATH,
}

species_to_ids = {
    SPECIES_NAME: ASSEMBLY_NAME,
}

In [49]:
all_pos_def = []

missing_genes = {}
for species in species_to_ids.keys():
    missing_genes[species] = []
    proteome_path = species_to_paths[species]
    species_id = species_to_ids[species]

    with open(proteome_path) as f:
        proteome_lines = f.readlines()

    gene_symbol_to_location = {}
    gene_symbol_to_chrom = {}

    for line in proteome_lines:
        if line.startswith(">"):  # FASTA header line
            split_line = line.split()
            # In Ensembl headers, the first token starting with "gene" is the "gene:<ID>" field
            gene_symbol = [token for token in split_line if token.startswith("gene")]
            if len(gene_symbol) > 0:
                gene_symbol = gene_symbol[0].split(":")

                if len(gene_symbol) == 2:
                    gene_symbol = gene_symbol[1]
                elif len(gene_symbol) > 2:
                    gene_symbol = ":".join(gene_symbol[1:])  # some zebrafish gene names contain colons
                else:
                    raise ValueError(f"Unexpected gene token in header: {line.strip()}")

                # Location token: <type>:<assembly>:<seq_name>:<start>:<end>:<strand>,
                # with <type> checked in this order of preference
                chrom = None
                for prefix in ("chromosome:", "primary_assembly:", "scaffold:"):
                    chrom_arr = [token for token in split_line if token.startswith(prefix)]
                    if len(chrom_arr) > 0:
                        chrom = chrom_arr[0].replace(prefix, "")
                        break

                if chrom is not None:
                    gene_symbol_to_location[gene_symbol] = chrom.split(":")[2]  # start position
                    gene_symbol_to_chrom[gene_symbol] = chrom.split(":")[1]  # chromosome or scaffold name
                else:
                    missing_genes[species].append(gene_symbol)
                    

    positional_df = pd.DataFrame()
    positional_df["gene_symbol"] = [gn.upper() for gn in list(gene_symbol_to_chrom.keys())]
    positional_df["chromosome"] = list(gene_symbol_to_chrom.values())
    positional_df["start"] = list(gene_symbol_to_location.values())
    positional_df = positional_df.sort_values(["chromosome", "start"])
    #positional_df = positional_df.set_index("gene_symbol")
    positional_df["species"] = species
    all_pos_def.append(positional_df)

In [50]:
master_pos_def = pd.concat(all_pos_def)
master_pos_def

| | gene_symbol | chromosome | start | species |
|---|---|---|---|---|
| 13563 | ENSOCUG00000030222.1 | 1 | 10033937 | rabbit |
| 13566 | ENSOCUG00000009135.4 | 1 | 10079693 | rabbit |
| 11518 | ENSOCUG00000000864.4 | 1 | 101192104 | rabbit |
| 12748 | ENSOCUG00000022316.2 | 1 | 101724659 | rabbit |
| 12752 | ENSOCUG00000010311.4 | 1 | 101780739 | rabbit |
| ... | ... | ... | ... | ... |
| 6720 | ENSOCUG00000005012.4 | X | 98832565 | rabbit |
| 7402 | ENSOCUG00000005013.4 | X | 98862791 | rabbit |
| 9551 | ENSOCUG00000006668.4 | X | 9919347 | rabbit |
| 9562 | ENSOCUG00000006592.4 | X | 9983507 | rabbit |
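In the parsing cell, `start` holds the substring taken from the FASTA header, so `sort_values` orders positions as strings within each chromosome; this is why `10033937` comes before `101192104`, and `9919347` after `98832565`, in the output above. If genomic order matters, a minimal sketch of a numeric sort on a toy frame with the same columns (the gene names and values here are illustrative):

```python
import pandas as pd

# Toy frame shaped like positional_df; start positions are strings,
# as produced by chrom.split(":")[2]
df = pd.DataFrame({
    "gene_symbol": ["G1", "G2", "G3"],
    "chromosome": ["1", "1", "1"],
    "start": ["10033937", "9919347", "101192104"],
})

lexicographic = df.sort_values(["chromosome", "start"])["start"].tolist()

df["start"] = df["start"].astype(int)  # convert to integers before sorting
numeric = df.sort_values(["chromosome", "start"])["start"].tolist()

print(lexicographic)  # ['10033937', '101192104', '9919347']
print(numeric)        # [9919347, 10033937, 101192104]
```

When the saved CSV is reloaded with `pd.read_csv`, the `start` column is parsed back to integers, but the row order in the file still reflects the string sort.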


In [51]:
master_pos_def["species"].value_counts() # double check how many genes are mapped

rabbit    20612
Name: species, dtype: int64

In [52]:
for k, v in missing_genes.items():
    print(f"{k}: {len(v)}") # are any genes missing?

rabbit: 0
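Zero missing genes means every parsed header had a location; it does not by itself show that the gene names in the CSV match the keys of the protein embedding file. A minimal sketch of that comparison, assuming the embedding keys are the same gene identifiers parsed from the FASTA and comparing case-insensitively because the notebook upper-cases `gene_symbol`. `gene_overlap` is a hypothetical helper, shown on toy inputs:

```python
def gene_overlap(csv_genes, embedding_keys):
    """Hypothetical helper: compare CSV gene names with embedding keys, ignoring case."""
    csv_set = {g.upper() for g in csv_genes}
    emb_set = {k.upper() for k in embedding_keys}
    return {
        "shared": len(csv_set & emb_set),
        "only_in_csv": sorted(csv_set - emb_set),
        "only_in_embeddings": sorted(emb_set - csv_set),
    }


# Toy example; with real data, pass master_pos_def["gene_symbol"] and the keys
# of the dict loaded from SPECIES_PROTEIN_EMBEDDINGS_PATH
result = gene_overlap(
    ["ENSOCUG00000030222.1", "ENSOCUG00000009135.4"],
    ["ensocug00000030222.1", "ENSOCUG00000000864.4"],
)
print(result)
```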


In [53]:
# Count genes per chromosome
for species in species_to_ids.keys():
    print("*********")
    print(species)
    display(master_pos_def[master_pos_def["species"] == species]["chromosome"].value_counts().head(50))
    print("*********")

*********
rabbit


1           1621
13          1379
12          1038
2            983
7            884
3            836
19           830
4            803
14           796
17           765
9            754
X            685
8            522
18           517
15           517
16           498
5            367
11           348
10           269
6            238
20           204
21           179
GL018699     118
GL018704     106
GL018717      95
GL018725      65
GL018786      64
GL018706      63
GL018752      58
GL018763      57
GL018734      54
GL018714      52
GL018758      52
GL018789      49
GL018705      48
GL018767      48
GL018702      45
GL018738      42
GL018723      41
GL018776      40
GL018828      40
GL018701      40
GL018816      39
GL018760      39
GL018700      39
GL018823      37
GL018747      34
GL018765      32
GL018792      31
GL018730      31
Name: chromosome, dtype: int64

*********


In [54]:
master_pos_def.to_csv(f"{SPECIES_NAME}_to_chrom_pos.csv", index=False) # Save the DF

In [55]:
# The chromosome file path will be:
print(f"{SPECIES_NAME}_to_chrom_pos.csv")

rabbit_to_chrom_pos.csv


In [56]:
N_UNIQ_CHROM = master_pos_def[master_pos_def["species"] == SPECIES_NAME]["chromosome"].nunique()  # unique chromosomes/scaffolds
N_UNIQ_CHROM

1026

# Generate token file

In [57]:
import torch
import pickle
token_dim = 5120

The next cell creates the token file. Note the printed `CHROM_TOKEN_OFFSET` value: you will pass it to UCE when embedding this species.

In [58]:
species_to_offsets = {}

# Keep the first 4 rows of the existing token file: these are the special vocab tokens,
# reused so that they are identical across token files generated with different seeds
all_pe = torch.load("/dfs/project/cross-species/yanay/code/uce_code/UCE_public/model_files/all_tokens.torch")[0:4]

offset = len(all_pe)  # special tokens come first; protein tokens start at this row

PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)  # gene -> ESM2 protein embedding

pe_stacked = torch.stack(list(PE.values()))
all_pe = torch.vstack((all_pe, pe_stacked))
species_to_offsets[SPECIES_NAME] = offset

# Chromosome tokens start right after the protein tokens
print("CHROM_TOKEN_OFFSET:", all_pe.shape[0])
torch.manual_seed(TAXONOMY_ID)  # reproducible chromosome tokens for this species
# One random token per chromosome choice; N_UNIQ_CHROM is computed above for this species
CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, token_dim))
all_pe = torch.vstack(
    (all_pe, CHROM_TENSORS))  # Add the chrom tensors to the end
all_pe.requires_grad = False


torch.save(all_pe, f"{SPECIES_NAME}_pe_tokens.torch")

with open(f"{SPECIES_NAME}_offsets.pkl", "wb+") as f:
    pickle.dump(species_to_offsets, f)
print("Saved PE, offsets file")

CHROM_TOKEN_OFFSET: 20616
Saved PE, offsets file


In [59]:
all_pe.shape

torch.Size([21642, 5120])
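The printed offset and shape follow from how the rows are stacked: special tokens first, then one row per protein embedding, then one random row per chromosome. A small sketch of that bookkeeping, using 4 special tokens (the `[0:4]` slice), 20612 protein rows (the printed `CHROM_TOKEN_OFFSET` of 20616 minus those 4), and `N_UNIQ_CHROM = 1026`. `token_layout` is a hypothetical helper:

```python
def token_layout(n_special, n_proteins, n_chrom):
    """Hypothetical helper: row offsets in the stacked token tensor.

    Rows [0, n_special) are special tokens, the next n_proteins rows are protein
    embeddings, and the last n_chrom rows are chromosome tokens.
    """
    protein_offset = n_special                   # value stored in the offsets pickle
    chrom_token_offset = n_special + n_proteins  # value passed as --CHROM_TOKEN_OFFSET
    total_rows = chrom_token_offset + n_chrom    # all_pe.shape[0]
    return protein_offset, chrom_token_offset, total_rows


print(token_layout(n_special=4, n_proteins=20612, n_chrom=1026))  # (4, 20616, 21642)
```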

In [61]:
print(f"{SPECIES_NAME}_offsets.pkl")

rabbit_offsets.pkl


In [62]:
SPECIES_PROTEIN_EMBEDDINGS_PATH

'/lfs/local/0/yanay/SATURN/protein_embeddings//data/Oryctolagus_cuniculus.OryCun2.0.pep.all.gene_symbol_to_embedding_ESM2.pt'

# Example evaluation of a new species

**Note: when you evaluate a new species, you need to change some arguments and modify some files:**

You will need to modify the CSV at `model_files/new_species_protein_embeddings.csv` to include the protein embeddings file you downloaded.

In that file, add a row for the new species in the format:
`species name,full path to protein embedding file`
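A minimal sketch of adding that row with Python's `csv` module. It assumes the file is a plain two-column CSV that ends with a newline and simply appends a line; the helper `add_species_row` and the embedding path are hypothetical, and the demo writes to a temporary file rather than the real `model_files/new_species_protein_embeddings.csv`.

```python
import csv
import os
import tempfile


def add_species_row(csv_path, species_name, embedding_path):
    """Hypothetical helper: append 'species name,full path to protein embedding file'."""
    with open(csv_path, "a", newline="") as f:
        csv.writer(f, lineterminator="\n").writerow([species_name, os.path.abspath(embedding_path)])


# Demonstration on a temporary file with a placeholder embedding path
with tempfile.TemporaryDirectory() as tmp:
    demo_csv = os.path.join(tmp, "new_species_protein_embeddings.csv")
    add_species_row(demo_csv, "rabbit", "/data/rabbit_embeddings.pt")
    with open(demo_csv) as f:
        content = f.read()

print(content)  # rabbit,/data/rabbit_embeddings.pt
```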

Also add the same species-to-path entry to the dictionary created on line 247 of `data_proc/data_utils.py`.

When you embed this new species, pass the newly created files as arguments:
- `CHROM_TOKEN_OFFSET`: the row at which the chromosome tokens start in the token file (the `CHROM_TOKEN_OFFSET` value printed above).
- `spec_chrom_csv_path`: the new CSV created by this notebook, which maps genes to chromosomes and genomic start positions.
- `token_file`: the new token file, which works only for this species. The resulting embeddings are still universal, though.
- `offset_pkl_path`: the new pickle file mapping the species name to the row at which its protein tokens start in the token file.


```

accelerate launch eval_single_anndata.py chicken_heart.h5ad --species=chicken --CHROM_TOKEN_OFFSET=13275 --spec_chrom_csv_path=data_proc/chicken_to_chrom_pos.csv --token_file=data_proc/chicken_pe_tokens.torch --offset_pkl_path=data_proc/chicken_offsets.pkl --dir=... --multi_gpu=True

```
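For the rabbit files created above, the same arguments can be filled in from the notebook's outputs. A minimal sketch that only assembles the command string, using exactly the flags from the chicken example; `build_eval_command`, the `data_dir` default of `data_proc` (which assumes you moved the generated files there), and the AnnData file name are assumptions, and the `--dir` value is left for you to supply.

```python
def build_eval_command(adata_path, species, chrom_token_offset, out_dir, data_dir="data_proc"):
    """Hypothetical helper: assemble the eval_single_anndata.py command for a new species.

    Assumes the files written by this notebook were moved into data_dir.
    """
    return " ".join([
        "accelerate launch eval_single_anndata.py",
        adata_path,
        f"--species={species}",
        f"--CHROM_TOKEN_OFFSET={chrom_token_offset}",
        f"--spec_chrom_csv_path={data_dir}/{species}_to_chrom_pos.csv",
        f"--token_file={data_dir}/{species}_pe_tokens.torch",
        f"--offset_pkl_path={data_dir}/{species}_offsets.pkl",
        f"--dir={out_dir}",
        "--multi_gpu=True",
    ])


# "rabbit_data.h5ad" and "results/" are placeholder values
print(build_eval_command("rabbit_data.h5ad", "rabbit", 20616, "results/"))
```

Calling it with the chicken values (`"chicken_heart.h5ad"`, `"chicken"`, `13275`, `"..."`) reproduces the command above.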