# Embed Sequences through ESM-2 and evotuned versions

<br>

In [3]:
from datasets import Dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments, EsmForMaskedLM, DataCollatorForLanguageModeling
import torch
import esm
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
import plotly.express as px
import plotly.graph_objects as go
import random
import pandas as pd
from Bio import SeqIO, Seq
import numpy as np
import scipy
from peft import LoraConfig, get_peft_model#, prepare_model_for_int8_training
import math
from scipy.spatial.distance import euclidean, cityblock, cosine
import time
import itertools
from sklearn.manifold import TSNE
import ete3
import os
import datetime
import re
import statistics
import pickle

In [4]:
# Load the small ESM-2 checkpoint shipped with the `esm` package together with
# its alphabet. `alphabet` and `batch_converter` are module-level names that
# embedall_n_entropy reads to tokenise sequences and to label the logit columns.
model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t6_8M_UR50D')
batch_converter = alphabet.get_batch_converter()

def embedall_n_entropy(modnam, indf, outf, device=None):
    """Embed every sequence of ``indf`` with an ESM masked LM and export the results.

    Each sequence is tokenised with the module-level ``batch_converter`` and
    passed through the model loaded from ``modnam``. Three things are collected
    per sequence:

    * the base-2 entropy, per site, of the predicted distribution over the 20
      standard amino acids (renormalised to sum to 1 at each position);
    * the position-averaged hidden state for every entry of the model's
      ``hidden_states`` output;
    * the long-format table of log-softmax values and probabilities per
      position and amino-acid token.

    Parameters
    ----------
    modnam : str
        Model name or path handed to ``EsmForMaskedLM.from_pretrained``.
    indf : pandas.DataFrame
        Table providing the columns ``node`` (sequence id), ``seq`` and
        ``seqlen``. If a node id occurs more than once, the sequence of its
        first row is used for every occurrence.
    outf : str
        Output prefix. Writes ``<outf>-site_entropy.csv``, ``<outf>.pickle``
        (dict of node id -> list of averaged hidden states) and
        ``<outf>-all_logitprobs.csv``.
    device : str or torch.device, optional
        Device to run on. When omitted, ``cuda:1`` is used if more than one
        GPU is visible (the historical default), otherwise ``cuda:0``, and the
        CPU when CUDA is unavailable.

    Returns
    -------
    pandas.DataFrame
        One row per node: a ``name`` column followed by columns ``1..maxslen``
        holding the per-site entropies, padded with ``''`` for shorter
        sequences.
    """
    #pick the device; keep 'cuda:1' as the default whenever that GPU exists
    if device is None:
        if not torch.cuda.is_available():
            device = "cpu"
        elif torch.cuda.device_count() > 1:
            device = "cuda:1"
        else:
            device = "cuda:0"
    device = torch.device(device)

    #load model and move it to the chosen device
    mod = EsmForMaskedLM.from_pretrained(modnam).to(device)
    #inference only: make sure dropout is disabled
    mod.eval()
    if device.type == "cuda":
        print("%s transferred model to GPU"%modnam)

    print(len(indf))

    #one pass over the table instead of a boolean-mask lookup per node;
    #setdefault keeps the first sequence seen for a node id
    first_seq = {}
    for node, raw_seq in zip(indf.node, indf.seq):
        first_seq.setdefault(node, str(raw_seq).replace('J', 'X'))

    #width of the entropy table; also cover sequences longer than the seqlen
    #column claims, which would otherwise break the column assignment below
    maxslen = max(list(indf.seqlen))
    if first_seq:
        maxslen = max(maxslen, max(len(s) for s in first_seq.values()))

    real_amino_acids = ["A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V"]

    #per-sequence entropy rows, averaged hidden states and logit dfs
    allentropies = []
    allstates = {}
    alllogitdfs = []

    stt = time.time()

    #parses sequences through the 'nodedetails' input df
    for c, seqid in enumerate(list(indf.node), start=1):

        #progress: sequences done and seconds elapsed
        if c % 100 == 0:
            print(c)
            print(time.time() - stt)

        #prepare sequence for embedding
        seq = first_seq[seqid]
        _, _, batch_tokens = batch_converter([('base', seq)])
        batch_tokens = batch_tokens.to(device=device, non_blocking=True)

        #embed; no autograd graph is needed for inference
        with torch.no_grad():
            m_results = mod(batch_tokens, output_hidden_states=True)

        #log-softmax over the token dimension to normalise the logits of this sequence
        m_logits = torch.nn.LogSoftmax(dim=1)(m_results["logits"][0].to(device="cpu")).detach()

        #average over positions for every entry of hidden_states
        #(34 entries for the 33-layer models this was written for)
        m_hstates = [
            np.mean(layer_states[0].to(device="cpu").detach().numpy(), axis=0)
            for layer_states in m_results['hidden_states']
        ]

        #make dataframe with logits (explicit numpy conversion keeps cells numeric)
        df = pd.DataFrame(m_logits.numpy())
        #add all model tokens as columns
        df.columns = alphabet.all_toks
        df.drop(".", inplace=True, axis=1)
        #make positions column
        df["pos"] = df.index.values
        df = df.melt(id_vars="pos").sort_values(["pos", "variable"])
        #get probabilities by exp the log-softmaxed values
        df["probability"] = np.exp(df.value)
        #only keep amino acid tokens
        df = df[df.variable.isin(real_amino_acids)].copy()

        #total amino-acid probability per position, computed in one grouped pass
        pos_totals = {pos: sum(grp) for pos, grp in df.groupby("pos")["probability"]}
        #normalise by that total so token probs within one position sum up to 1
        df["token_adjusted_probability"] = [pos_totals[pos] for pos in df.pos]
        df["token_adjusted_probability"] = df["probability"] / df["token_adjusted_probability"]

        #remove positions that match to start/end special tokens
        df = df[(df.pos >= 1) & (df.pos <= df.pos.max() - 1)].copy()

        #list of base 2 entropies of token adjusted probs across aa tokens for each position
        probs_by_pos = {
            pos: list(grp)
            for pos, grp in df.groupby("pos")["token_adjusted_probability"]
        }
        embentropies = [
            scipy.stats.entropy(probs_by_pos.get(i + 1, []), base=2)
            for i in range(len(seq))
        ]

        #add padding if the sequence is shorter than the longest of the alignment
        if len(embentropies) < maxslen:
            embentropies = embentropies + [''] * (maxslen - len(embentropies))

        allentropies.append([seqid] + embentropies)

        allstates.update({seqid: m_hstates})

        df['seqid'] = seqid

        alllogitdfs.append(df)

    #turn entropy lists into df
    entrodf = pd.DataFrame(allentropies)
    entrodf.columns = ['name'] + [i + 1 for i in range(maxslen)]

    entrodf.to_csv(outf + '-site_entropy.csv', index=False)

    #export hidden states as pickle
    with open(outf + '.pickle', 'wb') as out:
        pickle.dump(allstates, out, pickle.HIGHEST_PROTOCOL)

    #export df with all logit probs
    fnlogitdf = pd.concat(alllogitdfs)
    fnlogitdf.to_csv(outf + '-all_logitprobs.csv', index=False)

    return entrodf

<br>
<br>
<br>
<br>
<br>


In [3]:
#H7 seqs
# Node-details table for the H7 dataset; the preview below shows the columns
# Unnamed: 0, node, branchlen, subs, date, seq and seqlen.
h7asrseqs = pd.read_csv('/data2/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt_nodedetails.csv')
h7asrseqs

Unnamed: 0,node,branchlen,subs,date,seq,seqlen
0,NODE_0000079,0.13398,E100D,2001.73,MNIQILVAIACALIETKADKICLGHHAVANGTKVNTLTERGVEVVN...,570
1,NODE_0000080,0.31467,"E349K,I396V",2001.91,MNIQILVAIACALIETKADKICLGHHAVANGTKVNTLTERGVEVVN...,570
2,NODE_0000077,0.17278,"-340T,-341C,-342S,-343P,-344L,-345S,-346R,-347...",2001.59,MNIQILVAIACALIETKADKICLGHHAVANGTKVNTLTERGVEVVN...,570
3,NODE_0000076,3.81899,"G15E,T199N,I284T",2001.42,MNIQILVAIACALIETKADKICLGHHAVANGTKVNTLTERGVEVVN...,560
4,NODE_0000075,4.38225,"R182K,I278V,S310P,K321R",1997.60,MNIQILVAIACALIGTKADKICLGHHAVANGTKVNTLTERGVEVVN...,560
...,...,...,...,...,...,...
3848,EPI_ISL_285156,0.41557,R475G,2017.33,MNTQILVFALIAIIPTNADKICLGHHAVSNGTKVNTLTERGVEVVN...,560
3849,EPI_ISL_285037,0.43131,A170T,2017.35,MNTQILVFALIAIIPTNADKICLGHHAVSNGTKVNTLTERGVEVVN...,560
3850,EPI_ISL_266938,0.44590,A447V,2017.36,MNTQILVFALIAIIPTNADKICLGHHAVSNGTKVNTLTERGVEVVN...,560
3851,EPI_ISL_284988,0.48242,A286V,2017.40,MNTQILVFALIAIIPTNADKICLGHHAVSNGTKVNTLTERGVEVVN...,560


In [2]:
# Embed the H7 node table with each model variant in turn; every call writes
# its CSV/pickle outputs under the given `outf` prefix.
# NOTE(review): the first run uses a relative model path and output prefix while
# the remaining runs are absolute - confirm the intended working directory.
embedall_n_entropy(
    modnam='../models/esm2_t33-HA_all_110424_clu99_e10/',
    indf=h7asrseqs,
    outf='gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33_e10',
)

# base ESM-2 checkpoint
embedall_n_entropy(
    modnam='facebook/esm2_t33_650M_UR50D',
    indf=h7asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33',
)

# H1 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H1_110424_clu99_e10_071024/best-checkpoint-1150/',
    indf=h7asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33-H1_071024',
)

# H3 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H3_110424_clu99_e10_071024/best-checkpoint-2130/',
    indf=h7asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33-H3_071024',
)

# H5 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H5_110424_clu99_e10_071024/best-checkpoint-441/',
    indf=h7asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33-H5_071024',
)

# H7 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H7_110424_clu99_e10_071024/best-checkpoint-128/',
    indf=h7asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33-H7_071024',
)

In [6]:
#H5 seqs
# Node-details table for the H5 dataset; the preview below shows the columns
# Unnamed: 0, node, branchlen, subs, date, seq and seqlen.
h5asrseqs = pd.read_csv('/data2/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt_nodedetails.csv')
h5asrseqs

Unnamed: 0,node,branchlen,subs,date,seq,seqlen
0,NODE_0000022,0.29111,"I11T,K171Q",2004.11,MERIVIALAITSIVKADQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,564
1,NODE_0000026,0.24961,T541I,2010.76,MERIVIALAIISIVKADQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,564
2,NODE_0000028,0.20717,,2010.71,MERIVIALAIISIVKADQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,564
3,NODE_0000025,1.18875,,2010.51,MERIVIALAIISIVKADQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,564
4,NODE_0000024,3.33987,"V300I,R481K,N504D",2009.32,MERIVIALAIISIVKADQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,564
...,...,...,...,...,...,...
15064,EPI_ISL_19159885,0.00000,,2023.79,MENIVLLLAIVSLVKSDQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,567
15065,EPI_ISL_18691856,0.21169,"T143A,I180V",2024.00,MENIVLLLAIVSLVKSDQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,567
15066,EPI_ISL_18697726,0.15022,V5A,2023.81,MENIALLLAIVSLVKSDQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,567
15067,EPI_ISL_19064389,0.00000,,2023.82,MENIVLLLAIVSLVKSDQICIGYHANNSTEQVDTIMEKNVTVTHAQ...,567


In [1]:
# Embed the H5 node table with each model variant in turn; every call writes
# its CSV/pickle outputs under the given `outf` prefix.
# NOTE(review): the first run uses a relative model path and output prefix while
# the remaining runs are absolute - confirm the intended working directory.
embedall_n_entropy(
    modnam='../models/esm2_t33-HA_all_110424_clu99_e10/',
    indf=h5asrseqs,
    outf='gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33_e10',
)

# base ESM-2 checkpoint
embedall_n_entropy(
    modnam='facebook/esm2_t33_650M_UR50D',
    indf=h5asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33',
)

# H1 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H1_110424_clu99_e10_071024/best-checkpoint-1150/',
    indf=h5asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33-H1_071024',
)

# H3 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H3_110424_clu99_e10_071024/best-checkpoint-2130/',
    indf=h5asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33-H3_071024',
)

# H5 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H5_110424_clu99_e10_071024/best-checkpoint-441/',
    indf=h5asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33-H5_071024',
)

# H7 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H7_110424_clu99_e10_071024/best-checkpoint-128/',
    indf=h5asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33-H7_071024',
)


In [3]:
#H1 seqs
# Node-details table for the H1 dataset; the preview below shows the columns
# Unnamed: 0, node, branchlen, subs, date, seq and seqlen.
h1asrseqs = pd.read_csv('/data2/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt_nodedetails.csv')
h1asrseqs

Unnamed: 0,node,branchlen,subs,date,seq,seqlen
0,NODE_0002086,3.09304,"R77Q,N159D,S188L,G189N,T229I,I350V,E520G",2004.48,MKTIIALSYILCLVFAQKLPGNDNSMATLCLGHHAVPNGTLVKTIT...,567
1,NODE_0002085,15.11965,"T26M,F62S,E126D,S165Y,E227D,E531V",2001.39,MKTIIALSYILCLVFAQKLPGNDNSMATLCLGHHAVPNGTLVKTIT...,567
2,NODE_0002090,0.51807,I279V,2016.77,MKTVIALSYVFCLVFGQDFPGKGNNTATLCLGHHAVPNGTLVKTIT...,567
3,NODE_0002091,0.72017,I402T,2016.97,MKTVIALSYVFCLVFGQDFPGKGNNTATLCLGHHAVPNGTLVKTIT...,567
4,NODE_0002089,0.10261,,2016.25,MKTVIALSYVFCLVFGQDFPGKGNNTATLCLGHHAVPNGTLVKTIT...,567
...,...,...,...,...,...,...
13922,ARH17613,0.60056,P322L,2017.11,MKAILVVLLYTFTTANADTLCIGYHANNSTDTVDTVLEKNVTVTHS...,568
13923,APO20445,0.28914,P311S,2016.80,MKAILVVLLYTFTTANADTLCIGYHANNSTDTVDTVLEKNVTVTHS...,568
13924,AQS98486,0.55284,"I5T,E95D",2017.06,MKATLVVLLYTFTTANADTLCIGYHANNSTDTVDTVLEKNVTVTHS...,568
13925,APO20456,0.00000,,2016.85,MKAILVVLLYTFTTANADTLCIGYHANNSTDTVDTVLEKNVTVTHS...,568


In [3]:
# Embed the H1 node table with each model variant in turn; every call writes
# its CSV/pickle outputs under the given `outf` prefix.
# NOTE(review): the first run uses a relative model path and output prefix while
# the remaining runs are absolute - confirm the intended working directory.
embedall_n_entropy(
    modnam='../models/esm2_t33-HA_all_110424_clu99_e10/',
    indf=h1asrseqs,
    outf='ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33_e10',
)

# base ESM-2 checkpoint
embedall_n_entropy(
    modnam='facebook/esm2_t33_650M_UR50D',
    indf=h1asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33',
)

# H1 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H1_110424_clu99_e10_071024/best-checkpoint-1150/',
    indf=h1asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33-H1_071024',
)

# H3 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H3_110424_clu99_e10_071024/best-checkpoint-2130/',
    indf=h1asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33-H3_071024',
)

# H5 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H5_110424_clu99_e10_071024/best-checkpoint-441/',
    indf=h1asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33-H5_071024',
)

# H7 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H7_110424_clu99_e10_071024/best-checkpoint-128/',
    indf=h1asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33-H7_071024',
)


In [7]:
#H3 seqs
# Node-details table for the H3N2 dataset; the preview below shows the columns
# Unnamed: 0, node, branchlen, subs, date, seq and seqlen.
h3asrseqs = pd.read_csv('/data2/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt_nodedetails.csv')
h3asrseqs

Unnamed: 0,node,branchlen,subs,date,seq,seqlen
0,NODE_0000004,0.11299,L348I,1967.66,MKTIIALSYIFCLALGQDLPGNDNSTATLCLGHHAVPNGTLVKTIT...,566
1,NODE_0000006,0.10525,G235E,1967.65,MKTIIALSYIFCLALGQDLPGNDNSTATLCLGHHAVPNGTLVKTIT...,566
2,NODE_0000011,0.18555,D478E,1967.95,MKTIIALSYIFCLALGQDLPGNDNSTATLCLGHHAVPNGTLVKTIT...,566
3,NODE_0000010,0.21554,V199I,1967.76,MKTIIALSYIFCLALGQDLPGNDNSTATLCLGHHAVPNGTLVKTIT...,566
4,NODE_0000026,0.27832,,1970.90,MKTIIALSYIFYLALGQDLPGNDNSKATLCLGHHAVPNGTLVKTIT...,566
...,...,...,...,...,...,...
16952,WAJ06421,0.14689,R278Q,2022.26,MKTIIALSNILCLVFAQKIPGNDNSTATLCLGHHAVPNGTIVKTIT...,566
16953,USW63032,0.15237,N39D,2022.27,MKTIIALSNILCLVFAQKIPGNDNSTATLCLGHHAVPDGTIVKTIT...,566
16954,UUB81155,0.16607,L445I,2022.28,MKTIIALSNILCLVFAQKIPGNDNSTATLCLGHHAVPNGTIVKTIT...,566
16955,UUV81804,0.16607,Y528H,2022.28,MKTIIALSNILCLVFAQKIPGNDNSTATLCLGHHAVPNGTIVKTIT...,566


In [4]:
# Embed the H3 node table with each model variant in turn; every call writes
# its CSV/pickle outputs under the given `outf` prefix.
# NOTE(review): the first run uses a relative model path and output prefix while
# the remaining runs are absolute - confirm the intended working directory.
embedall_n_entropy(
    modnam='../models/esm2_t33-HA_all_110424_clu99_e10/',
    indf=h3asrseqs,
    outf='ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33_e10',
)

# base ESM-2 checkpoint
embedall_n_entropy(
    modnam='facebook/esm2_t33_650M_UR50D',
    indf=h3asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33',
)

# H1 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H1_110424_clu99_e10_071024/best-checkpoint-1150/',
    indf=h3asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33-H1_071024',
)

# H3 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H3_110424_clu99_e10_071024/best-checkpoint-2130/',
    indf=h3asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33-H3_071024',
)

# H5 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H5_110424_clu99_e10_071024/best-checkpoint-441/',
    indf=h3asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33-H5_071024',
)

# H7 checkpoint
embedall_n_entropy(
    modnam='/data2/spyros/flu_LLM_evol_data/ncbiflu_models_071024/esm2_t33-H7_110424_clu99_e10_071024/best-checkpoint-128/',
    indf=h3asrseqs,
    outf='/data2/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33-H7_071024',
)

<br>

### ESM-2 HA-80

In [None]:
# Reload the four node-details tables from the external drive.
h7asrseqs = pd.read_csv(
    '/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt_nodedetails.csv'
)
h5asrseqs = pd.read_csv(
    '/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt_nodedetails.csv'
)
h1asrseqs = pd.read_csv(
    '/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt_nodedetails.csv'
)
h3asrseqs = pd.read_csv(
    '/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt_nodedetails.csv'
)

# Run the same date-split (80/20) checkpoint over each subtype table in turn:
# H7, H5, H1, then H3.
embedall_n_entropy(
    modnam='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbiflu_models_datesplit241024/esm2_t33-HA_all_110424_datesplit8020_e10_241024/checkpoint-3645/',
    indf=h7asrseqs,
    outf='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/gisaid_h7_270624_filt/gisaid_h7_270624_filt-ASR-esm2_t33-8020_e10_241024',
)

embedall_n_entropy(
    modnam='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbiflu_models_datesplit241024/esm2_t33-HA_all_110424_datesplit8020_e10_241024/checkpoint-3645/',
    indf=h5asrseqs,
    outf='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/gisaid_h5_270624_filt/gisaid_h5_270624_filt-ASR-esm2_t33-8020_e10_241024',
)

embedall_n_entropy(
    modnam='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbiflu_models_datesplit241024/esm2_t33-HA_all_110424_datesplit8020_e10_241024/checkpoint-3645/',
    indf=h1asrseqs,
    outf='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbi_h1_110424_filt/ncbi_h1_110424_filt-ASR-esm2_t33-8020_e10_241024',
)

embedall_n_entropy(
    modnam='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbiflu_models_datesplit241024/esm2_t33-HA_all_110424_datesplit8020_e10_241024/checkpoint-3645/',
    indf=h3asrseqs,
    outf='/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/ncbi_h3n2_110424_filt/ncbi_h3n2_110424_filt-ASR-esm2_t33-8020_e10_241024',
)
