# Embed sequences with the ProtT5 MLM model and its evotuned versions


<br>

In [1]:
from datasets import Dataset

from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

import torch
from torch import nn

from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput
from torch.nn import CrossEntropyLoss
from Bio import SeqIO
from datasets import Dataset
import math
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import scipy

import pandas as pd
from Bio import SeqIO, Seq
import time

import pickle

In [2]:
class T5LMHead(nn.Module):
    """Masked-language-modelling head placed on top of the T5 encoder output.

    Pipeline: Linear(d_model -> d_model) -> GELU -> LayerNorm -> Linear(d_model -> 128) + bias.
    The output width (128 vocabulary entries) is hard-coded.
    Modelled on the ESM masked-LM head (ESMForMaskedLM).
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        n_vocab = 128  # fixed vocabulary size of the output logits
        # Attribute names are part of the checkpoint's state-dict keys; keep them stable.
        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.decoder = nn.Linear(width, n_vocab, bias=False)
        # The output bias is a standalone parameter rather than part of `decoder`.
        self.bias = nn.Parameter(torch.zeros(n_vocab))

    @staticmethod
    def gelu(x):
        """Exact erf-based GELU, written out explicitly as in the original ESM code."""
        half_x = 0.5 * x
        return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def forward(self, features, **kwargs):
        """Map encoder features (..., d_model) to vocabulary logits (..., 128)."""
        hidden = self.layer_norm(self.gelu(self.dense(features)))
        return self.decoder(hidden) + self.bias

In [3]:
class T5EncoderMLM(T5EncoderModel):
    """T5 encoder with a masked-language-modelling head (``T5LMHead``) on top.

    The head is stored as ``custom_lm_head``; that attribute name is part of the
    saved checkpoint keys.
    """

    def __init__(self, config):
        super().__init__(config)
        self.custom_lm_head = T5LMHead(config)
        # Re-run weight initialisation now that the head has been attached.
        self.init_weights()
        print(config)

    def _init_weights(self, module):
        """Initialise one module's weights (overrides the parent class hook).

        * ``T5LMHead``: both Linear weights ~ N(0, std), all biases 0, LayerNorm weight 1.
        * ``nn.Linear``: weight ~ N(0, std), bias (if any) 0.
        * ``nn.LayerNorm``: weight 1, bias 0.

        Here std = initializer_factor * d_model ** -0.5. Other module types are not
        touched by this override.
        """
        # initializer_factor is the config's scaling knob (used for testing initialisation).
        std = self.config.initializer_factor * self.config.d_model ** -0.5

        if isinstance(module, T5LMHead):
            module.dense.weight.data.normal_(mean=0.0, std=std)
            module.dense.bias.data.zero_()
            module.layer_norm.weight.data.fill_(1.0)
            module.layer_norm.bias.data.zero_()
            module.decoder.weight.data.normal_(mean=0.0, std=std)
            module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], MaskedLMOutput]:
        """Encode the inputs and project every position to vocabulary logits.

        If ``labels`` is given, the cross-entropy loss is computed over all positions
        whose label is not -100.

        Returns a ``MaskedLMOutput`` when ``return_dict`` is truthy. Otherwise it returns
        a tuple ``(loss, logits, *extra_encoder_outputs)``; ``loss`` is omitted when no
        labels are passed.
        NOTE: ``return_dict=None`` is treated as False here (there is no fallback to
        ``config.use_return_dict``), so the tuple form is returned.
        """
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoded[0]
        logits = self.custom_lm_head(sequence_output)

        loss = None
        if labels is not None:
            # Move the labels onto the logits' device before computing the loss.
            targets = labels.to(logits.device)
            vocab_size = logits.size(-1)
            criterion = CrossEntropyLoss(ignore_index=-100)
            loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

        if return_dict:
            return MaskedLMOutput(
                loss=loss,
                logits=logits,
                attentions=encoded.attentions,
                hidden_states=encoded.hidden_states,
            )

        leading = (logits,) if loss is None else (loss, logits)
        return leading + tuple(encoded[1:])

In [1]:
# Tokenizer of the base ProtT5-XL (UniRef50) model. The embedding function in this
# notebook reads it as the module-level name `tokenizer`.
tokenizer = T5Tokenizer.from_pretrained(
    "Rostlab/prot_t5_xl_uniref50", do_lower_case=False
)

# Add a masking token to the tokenizer for the data collator to use:
tokenizer.add_special_tokens({"mask_token": "<mask>"})

# Random masking of 15% of tokens, applied per batch; returns PyTorch tensors.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, return_tensors="pt", mlm_probability=0.15
)

In [5]:

def _residue_logit_table(log_probs, token_names, seq_len):
    """Build the long per-residue amino-acid probability table for one sequence.

    log_probs   : 2-D array (tokens x vocabulary) of log-softmaxed logits.
    token_names : label of each vocabulary column, in token-id order.
    seq_len     : number of residues; only token rows 0..seq_len-1 are kept.

    Returns a DataFrame with columns ``pos`` (1-based residue index), ``variable``
    (amino-acid letter), ``value`` (log-probability), ``probability`` and
    ``token_adjusted_probability`` (probability renormalised over the 20 standard
    amino acids so it sums to 1 at each position).
    """
    real_amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
                        "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
    # The T5 tokenizer adds no start token and appends a single </s>, so token row i
    # is residue i (0-based). Keep only the residue rows; the trailing </s> is dropped.
    table = pd.DataFrame(np.asarray(log_probs)[:seq_len], columns=token_names)
    # 1-based residue numbering, matching the columns of the entropy CSV.
    table["pos"] = np.arange(1, len(table) + 1)
    table = table.melt(id_vars="pos").sort_values(["pos", "variable"])
    table["probability"] = np.exp(table["value"])
    table = table[table["variable"].isin(real_amino_acids)].copy()
    # The softmax also gives mass to the other vocabulary entries, so divide by the
    # per-position SUM over the 20 amino acids to get a proper distribution.
    per_site_total = table.groupby("pos")["probability"].transform("sum")
    table["token_adjusted_probability"] = table["probability"] / per_site_total
    return table


def _site_entropies(table):
    """Return the base-2 Shannon entropy of ``token_adjusted_probability`` per ``pos``, in ascending pos order."""
    from scipy.stats import entropy

    return [
        entropy(site["token_adjusted_probability"].to_numpy(), base=2)
        for _, site in table.groupby("pos", sort=True)
    ]


def embedall_n_entropy_T5(modnam, indf, outf):
    """Run every sequence in ``indf`` through a T5EncoderMLM checkpoint and save per-site outputs.

    Parameters
    ----------
    modnam : checkpoint path or hub id passed to ``T5EncoderMLM.from_pretrained``.
    indf   : DataFrame with columns ``node`` (sequence id), ``seq`` (amino-acid
             string; 'J' is replaced by 'X') and ``seqlen`` (used as the minimum
             width of the entropy table).
    outf   : output path prefix.

    Each sequence is passed through the model once, unmasked, using the module-level
    ``tokenizer``. The outputs are written as follows:

    * ``<outf>-site_entropy.csv``: one row per sequence, with ``name`` followed by the
      base-2 entropy for residues 1..N. N = max(seqlen), or the longest sequence if
      that is longer. Shorter rows are padded with ''.
    * ``<outf>.pickle``: a dict mapping node -> list of hidden states averaged over
      token positions, one vector per entry of the model's ``hidden_states``.
    * ``<outf>-all_logitprobs.csv``: a long table with columns pos, variable, value,
      probability, token_adjusted_probability and seqid.

    Returns the entropy DataFrame.
    """
    # Fall back to CPU instead of crashing when no GPU is present.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mod = T5EncoderMLM.from_pretrained(str(modnam), ignore_mismatched_sizes=True)
    mod = mod.to(device)
    mod.eval()  # inference only: make sure dropout is disabled
    if device.type == "cuda":
        print("%s transferred model to GPU" % modnam)

    print(len(indf))
    maxslen = max(indf.seqlen, default=0)

    # Label logit column j with the token whose id is j (no reliance on dict ordering).
    n_vocab = mod.custom_lm_head.decoder.out_features
    token_names = [tok.replace("▁", "")
                   for tok in tokenizer.convert_ids_to_tokens(list(range(n_vocab)))]

    entropy_rows = []   # (seqid, [entropy per residue])
    allstates = {}      # seqid -> per-layer mean hidden states
    alllogitdfs = []    # per-sequence long probability tables

    stt = time.time()
    # Iterate the rows directly instead of re-filtering indf for every node.
    for c, (seqid, raw_seq) in enumerate(zip(indf.node, indf.seq), start=1):
        if c % 100 == 0:
            print(c)
            print((time.time() - stt) / 60, 'min')

        seq = str(raw_seq).replace('J', 'X')
        token_encoding = tokenizer([" ".join(seq)], add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(token_encoding['input_ids']).to(device)
        attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)

        # Per-site results assume one token per residue followed by a single </s>.
        if (input_ids.shape[1] != len(seq) + 1
                or input_ids[0, -1].item() != tokenizer.eos_token_id):
            print("warning: %s tokenised to %d tokens for %d residues; "
                  "per-site positions may be misaligned"
                  % (seqid, input_ids.shape[1], len(seq)))

        # No gradients are needed. Attentions are not requested because they are never used.
        with torch.no_grad():
            m_results = mod(input_ids, attention_mask=attention_mask,
                            return_dict=True, output_hidden_states=True)

        m_logits = torch.nn.functional.log_softmax(m_results["logits"][0].cpu(), dim=-1)

        # Mean over token positions for every hidden-state entry (25 for the 24-layer model).
        # NOTE: the mean includes the trailing </s> token, as the original code did.
        m_hstates = [layer[0].cpu().numpy().mean(axis=0)
                     for layer in m_results["hidden_states"]]

        table = _residue_logit_table(m_logits.numpy(), token_names, len(seq))
        entropy_rows.append((seqid, _site_entropies(table)))
        allstates[seqid] = m_hstates
        table["seqid"] = seqid
        alllogitdfs.append(table)

    # Pad to a common width; widen beyond max(seqlen) rather than crash if a sequence is longer.
    width = max([maxslen] + [len(ents) for _, ents in entropy_rows])
    entrodf = pd.DataFrame(
        [[sid] + ents + [''] * (width - len(ents)) for sid, ents in entropy_rows],
        columns=['name'] + list(range(1, width + 1)),
    )
    entrodf.to_csv(outf + '-site_entropy.csv', index=False)

    # Export hidden states as a pickle.
    with open(outf + '.pickle', 'wb') as out:
        pickle.dump(allstates, out, pickle.HIGHEST_PROTOCOL)

    # Export the table of all logit probabilities (an empty table if indf was empty).
    if alllogitdfs:
        fnlogitdf = pd.concat(alllogitdfs)
    else:
        fnlogitdf = pd.DataFrame(columns=["pos", "variable", "value", "probability",
                                          "token_adjusted_probability", "seqid"])
    fnlogitdf.to_csv(outf + '-all_logitprobs.csv', index=False)

    return entrodf

<br>
<br>
<br>
<br>
<br>


In [2]:
# Example: score the H7 sequences with every checkpoint.
DATA_ROOT = '/media/spyros/HD-ADU3/spyros/flu_LLM_evol_data/'
MODEL_DIR = DATA_ROOT + 'ncbiflu_protT5_221124/'
H7_DIR = DATA_ROOT + 'gisaid_h7_270624_filt/'
EVOTUNED_PREFIX = 'T5EncoderMLM_uniref_ncbi_e10_ncbiflu_HA_all_110424_noX_clu99_filt_'
OUT_PREFIX = H7_DIR + 'gisaid_h7_270624_filt-ASR-T5_'

h7asrseqs = pd.read_csv(H7_DIR + 'gisaid_h7_270624_filt_nodedetails.csv')

# (checkpoint path relative to MODEL_DIR, output-name suffix), in run order.
RUNS = [
    # single serotypes
    (EVOTUNED_PREFIX + 'H1/trainer/checkpoint-1000', 'uniref_H1_221124'),
    (EVOTUNED_PREFIX + 'H5/trainer/checkpoint-344', 'uniref_H5_221124'),
    (EVOTUNED_PREFIX + 'H7/trainer/checkpoint-112', 'uniref_H7_221124'),
    (EVOTUNED_PREFIX + 'H3/trainer/checkpoint-1860/', 'uniref_H3_221124'),
    # HA-80
    (EVOTUNED_PREFIX + 'datesplit8020/trainer/checkpoint-3550/', 'uniref_8020_221124'),
    # base
    ('T5EncoderMLM_UniRef_2e_constant_20241106/trainer/checkpoint-127292', 'UniRef_e2'),
    # HA-all
    (EVOTUNED_PREFIX + 'all/trainer/checkpoint-3897', 'uniref_HA_221124'),
]

for checkpoint, suffix in RUNS:
    embedall_n_entropy_T5(MODEL_DIR + checkpoint, h7asrseqs, OUT_PREFIX + suffix)

