## **ReverTra - Reverse Translation with Transformers**

---

This is an inference script for the models from the paper [ref]. \
The models are built on the BART architecture from the Hugging Face platform; the training procedure and algorithms are described in the paper. \
The full training code and data are available on GitHub: https://github.com/siditom-cs/ReverTra.

This script accommodates the two inference types for the BART models described in the paper: (a) the target amino-acid sequence only; and (b) the target amino-acid sequence aligned with an additional codon sequence.
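As a rough illustration of the two modes, the sketch below builds the per-position token sequence that is aligned to the target: the codons of the additional sequence in mode (b), or one mask token per amino acid in mode (a). The helper name `build_subject_tokens` is hypothetical; the `<mask_X>` and `<gap>` token formats follow the input preparation code later in this script.

```python
def build_subject_tokens(qseq, inference_type, subject_dna_seq=None):
    """Return the tokens aligned to the target amino-acid sequence.

    Hypothetical helper, for illustration only.
    """
    if inference_type == "mimic":
        if subject_dna_seq is None:
            raise ValueError("mimic inference needs a subject codon sequence")
        # The subject is a space-delimited codon sequence.
        tokens = subject_dna_seq.split(" ")
    elif inference_type == "mask":
        # One mask token per residue; alignment gaps ('-') become <gap>.
        tokens = ["<gap>" if aa == "-" else f"<mask_{aa}>" for aa in qseq]
    else:
        raise ValueError(f"unknown inference_type: {inference_type!r}")
    if len(tokens) != len(qseq):
        raise ValueError("aligned sequences must have the same length")
    return tokens


print(build_subject_tokens("KS-", "mask"))
print(build_subject_tokens("KS-", "mimic", "AAA GCG GCT"))
```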

The script can also calculate accuracy and perplexity for a generated sequence when the original target codon sequence is provided.

**Available models:**

**[Mimic]**
1. Finetuned twice mimic model with fixed-win of size 10: "siditom/co-model_mimic-rexpr-10w_2ft" \
2. Finetuned twice mimic model with fixed-win of size 30: "siditom/co-model_mimic-rexpr-30w_2ft" \
3. Finetuned twice mimic model with fixed-win of size 50: "siditom/co-model_mimic-rexpr-50w_2ft" \
4. Finetuned twice mimic model with fixed-win of size 75: "siditom/co-model_mimic-rexpr-75w_2ft" \
5. Finetuned twice mimic model with fixed-win of size 100: "siditom/co-model_mimic-rexpr-100w_2ft" \
6. Finetuned twice mimic model with fixed-win of size 150: "siditom/co-model_mimic-rexpr-150w_2ft"

**[Mask]**
1. Finetuned once mask model with fixed-win of size 10: "siditom/co-model_mask-rexpr-10w_1ft" \
2. Finetuned once mask model with fixed-win of size 30: "siditom/co-model_mask-rexpr-30w_1ft" \
3. Finetuned once mask model with fixed-win of size 50: "siditom/co-model_mask-rexpr-50w_1ft" \
4. Finetuned once mask model with fixed-win of size 75: "siditom/co-model_mask-rexpr-75w_1ft" \
5. Finetuned once mask model with fixed-win of size 100: "siditom/co-model_mask-rexpr-100w_1ft" \
6. Finetuned once mask model with fixed-win of size 150: "siditom/co-model_mask-rexpr-150w_1ft"
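
The model identifiers above follow a regular pattern, so a small helper can select one from the inference type and window size. This is only a convenience sketch: `model_id` is a hypothetical name, and the pattern is inferred from the twelve identifiers listed above rather than from any official naming rule.

```python
WINDOW_SIZES = (10, 30, 50, 75, 100, 150)
# Listed mimic models are fine-tuned twice, listed mask models once.
FINETUNE_SUFFIX = {"mimic": "2ft", "mask": "1ft"}


def model_id(inference_type, sw_aa_size):
    """Build the Hugging Face model identifier for a listed model."""
    if inference_type not in FINETUNE_SUFFIX:
        raise ValueError(f"unknown inference_type: {inference_type!r}")
    if sw_aa_size not in WINDOW_SIZES:
        raise ValueError(f"no listed model for window size {sw_aa_size}")
    suffix = FINETUNE_SUFFIX[inference_type]
    return f"siditom/co-model_{inference_type}-rexpr-{sw_aa_size}w_{suffix}"


print(model_id("mimic", 50))
print(model_id("mask", 150))
```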


Implementation details:
---

The arguments of the predict function below include a configuration dictionary with the parameters needed for inference:

- **sw_aa_size**: window size for generating subsets of predictions. It should match the window size the model was trained with. Options: [10,30,50,75,100,150], see available models.
- **inference_type**: 'mimic'/'mask'.
- **calc_stats**: False/True. Whether the input includes the target codon sequence for calculating accuracy and perplexity.
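
To make the role of `sw_aa_size` concrete, here is a minimal sketch of how a sequence is cut into overlapping fixed-size windows. It mirrors the slicing used in the input preparation code later in this script (stride of one residue, and a single shorter window when the sequence is shorter than the window); the function name `sliding_windows` is illustrative.

```python
def sliding_windows(seq, sw_aa_size):
    """Split seq into overlapping windows of sw_aa_size with stride 1."""
    # A sequence shorter than the window still yields one (short) window.
    n_windows = max(1, len(seq) - sw_aa_size + 1)
    return [seq[i:i + sw_aa_size] for i in range(n_windows)]


print(sliding_windows("KSVKF", 3))  # ['KSV', 'SVK', 'VKF']
print(sliding_windows("KS", 3))     # ['KS']
```

With this scheme, a 69-residue protein and a window of 50 produce 20 windows, and interior positions are predicted by several windows.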

The input example is also a dictionary, with the following keys:
- **qseq**: amino-acid sequence of the target protein.
- **query_species**: the target's host species.
- **expr**: the expression-level token. Six tokens are supported, corresponding to the expression-level percentiles 90%-100%, 75%-90%, 50%-75%, 25%-50%, lower than 25%, and unspecified: [expr_top10, expr_pre75_90, expr_pre50_75, expr_pre25_50, expr_low25, expr_unk], respectively.
- **subject_dna_seq**: [Optional] - space-delimited codon sequence of the mimic protein. Required for inference_type='mimic'.
- **query_dna_seq**: [Optional] - space-delimited codon sequence of the target protein. Required for calc_stats=True.
- **subject_species**: [Optional] - the species the mimic sequence originates from. Required for inference_type='mimic'.
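
The requirements listed above can be checked before running inference. The following is a minimal sketch, not part of the original script: `check_inputs` is a hypothetical helper that returns a list of problems, and the one-codon-per-residue check reflects the assertions used in the examples below.

```python
EXPR_TOKENS = {"expr_top10", "expr_pre75_90", "expr_pre50_75",
               "expr_pre25_50", "expr_low25", "expr_unk"}
WINDOW_SIZES = {10, 30, 50, 75, 100, 150}


def check_inputs(config, example):
    """Return a list of problems; an empty list means the inputs look consistent."""
    problems = []
    if config.get("sw_aa_size") not in WINDOW_SIZES:
        problems.append("sw_aa_size must be one of 10, 30, 50, 75, 100, 150")
    n = len(example.get("qseq", ""))
    if n == 0:
        problems.append("qseq is missing or empty")
    if "query_species" not in example:
        problems.append("query_species is required")
    # A missing or None expression level falls back to the 'unspecified' token.
    if (example.get("expr") or "expr_unk") not in EXPR_TOKENS:
        problems.append("expr is not a known expression token")
    if config.get("inference_type") == "mimic":
        for key in ("subject_dna_seq", "subject_species"):
            if key not in example:
                problems.append(f"{key} is required for mimic inference")
    if config.get("calc_stats") and "query_dna_seq" not in example:
        problems.append("query_dna_seq is required when calc_stats is True")
    for key in ("subject_dna_seq", "query_dna_seq"):
        if key in example and len(example[key].split(" ")) != n:
            problems.append(f"{key} must contain one codon per residue of qseq")
    return problems


config = {"sw_aa_size": 50, "calc_stats": True, "inference_type": "mimic"}
example = {"qseq": "KS", "query_species": "S_cerevisiae", "expr": "expr_top10",
           "subject_dna_seq": "AAA GCG", "subject_species": "E_coli"}
print(check_inputs(config, example))
```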



In [None]:
!pip install transformers


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Imports and auxiliary functions:
---

The next cells include the auxiliary and wrapper functions that prepare the raw amino-acid (and/or codon) sequence for the model and aggregate the model outputs from windows into a single prediction.

- **calc_combined_gen_from_sliding_windows_logits**: aggregates the predicted logits of all windows in the sequence into a single (N * V) matrix, where N is the length of the amino-acid protein sequence and V is the size of the tokenizer vocabulary.

- **predict**: wrapper function that prepares the input sequence and configuration to be model compatible, applies the trained model to each window in the given input sequence, aggregates the predicted logits of the windows, and calculates metrics.
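
The window aggregation can be illustrated on a toy case. The sketch below is a simplified, pure-Python version under stated assumptions: each window supplies one log-probability vector per position it covers, every position is covered by at least one window, and the probabilities are averaged across windows before taking the best token. The actual function in the next cells works on torch tensors and additionally skips the first generated position of each window; the name `combine_windows` is hypothetical.

```python
import math


def combine_windows(window_log_probs, seqlen):
    """Average per-window probabilities for each sequence position.

    window_log_probs[i][j] is the log-probability vector predicted by
    window i for sequence position i + j.
    """
    vocab = len(window_log_probs[0][0])
    sums = [[0.0] * vocab for _ in range(seqlen)]
    counts = [0] * seqlen
    for i, window in enumerate(window_log_probs):
        for j, log_probs in enumerate(window):
            for v, lp in enumerate(log_probs):
                sums[i + j][v] += math.exp(lp)  # accumulate probabilities
            counts[i + j] += 1
    # Normalize by the number of windows covering each position, back to log space.
    combined = [
        [math.log(s / c) if s > 0 else float("-inf") for s in row]
        for row, c in zip(sums, counts)
    ]
    best = [max(range(vocab), key=row.__getitem__) for row in combined]
    return combined, best


# Two windows of size 2 over a sequence of length 3, vocabulary of 2 tokens.
w0 = [[math.log(0.9), math.log(0.1)], [math.log(0.6), math.log(0.4)]]
w1 = [[math.log(0.2), math.log(0.8)], [math.log(0.3), math.log(0.7)]]
combined, best = combine_windows([w0, w1], seqlen=3)
print(best)  # [0, 1, 1]
```

Position 1 is covered by both windows, so its probabilities are averaged to (0.4, 0.6) and the second token wins even though the first window alone preferred the first token.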


In the last cell of this section, we download the chosen model and the tokenizer that are applied to the examples below.



In [None]:
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration, LogitsWarper
from transformers import LogitsProcessor,LogitsProcessorList
import numpy as np
import json

In [None]:
class RestrictToAaLogitsWarper(LogitsWarper):
    def __init__(self, masked_input_ids: torch.LongTensor, restrict_dict: dict, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        self.masked_input_ids = masked_input_ids
        self.restrict_dict = restrict_dict
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]-1
        vocab_size = scores.shape[-1]
        if self.masked_input_ids.shape[-1] <= cur_len:
            return scores
        for bid in range(input_ids.shape[0]):
            cur_mask_input = str(int(self.masked_input_ids[bid][cur_len].item()))
            if cur_mask_input in self.restrict_dict.keys():
                # Keep only the token ids allowed for this mask token; ban all others
                allowed = torch.tensor(self.restrict_dict[cur_mask_input], dtype=torch.long, device=scores.device)
                banned = torch.ones(vocab_size, dtype=torch.bool, device=scores.device)
                banned[allowed] = False
                scores[bid][banned] = self.filter_value
        return scores
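
The effect of this warper on a single position can be sketched without torch. In the sketch below, `restrict_scores` is a hypothetical helper and the token ids are made up for illustration; the only assumption taken from the class above is that the restriction dictionary maps a mask token to the list of vocabulary ids that remain allowed.

```python
def restrict_scores(scores, allowed_ids, filter_value=float("-inf")):
    """Set the score of every token id outside allowed_ids to filter_value."""
    allowed = set(allowed_ids)
    return [s if i in allowed else filter_value for i, s in enumerate(scores)]


# Toy vocabulary of 4 tokens; only ids 0 and 3 encode the masked amino acid.
scores = [0.1, 2.0, -0.5, 1.5]
restricted = restrict_scores(scores, allowed_ids=[0, 3])
best = max(range(len(restricted)), key=restricted.__getitem__)
print(restricted, best)
```

Without the restriction, greedy decoding would pick id 1; with it, the highest-scoring allowed token (id 3) is chosen, so the generated codon always encodes the requested amino acid.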



In [None]:
def get_expr_token(expr):
    if expr is None:
        expr = 'expr_unk'
    return "<"+expr+"> "

def prepare_inputs(example, tokenizer, config):
        #Rename
        qaaseq, query_species, sw_aa_size = example['qseq'], example['query_species'],  config['sw_aa_size']
        if config['inference_type']=='mimic':
          sseq, subject_species = example['subject_dna_seq'].split(" "), example['subject_species']
        else:
          sseq, subject_species = ['<mask_'+aa+'>' if aa!='-' else '<gap>' for aa in qaaseq], example['query_species']


        #Prepare fixed-sized windows
        query_aa_wins = [qaaseq[i:i+sw_aa_size] for i in range(0,max(1,len(qaaseq)-sw_aa_size+1))]
        subject_dna_wins = [" ".join(sseq[i:i+sw_aa_size]) for i in range(0,max(1,len(qaaseq)-sw_aa_size+1))]
        mask_aa_wins = ["<"+query_species+"> "+" ".join(['<mask_'+aa+'>' if aa!='-' else '<gap>' for aa in wseq]) for wseq in query_aa_wins]
        query_aa_wins = ["<"+query_species+"> "+get_expr_token(example['expr'])+' '.join(['<mask_'+aa+'>' if aa!='-' else '<gap>' for aa in wseq]) for wseq in query_aa_wins]
        subject_dna_wins = ["<"+subject_species+"> "+wseq for wseq in subject_dna_wins]

        #Encode windows
        input_ids = tokenizer(query_aa_wins, subject_dna_wins, return_tensors="pt", padding='max_length', max_length=sw_aa_size*2+3).input_ids
        masked_ids = tokenizer(mask_aa_wins, return_tensors="pt").input_ids[:,1:-1]
        return input_ids,masked_ids

def generate_outputs(input_ids, masked_ids, mask_restriction_dict, model, sw_aa_size):
        logits_processor = LogitsProcessorList(
                [RestrictToAaLogitsWarper(masked_ids, mask_restriction_dict)])

        outputs = model.generate(input_ids, do_sample=False, output_scores = True, return_dict_in_generate = True, renormalize_logits = True, logits_processor=logits_processor, max_length=min((sw_aa_size+3),masked_ids.shape[-1]+2))
        outputs = torch.stack(outputs['scores'][:sw_aa_size+1],1)
        return outputs

def calc_combined_gen_from_sliding_windows_logits(sw_logits, seqlen, sw_aa_size):
        collect_logits = torch.zeros([seqlen, sw_logits.shape[-1]])
        counts = torch.zeros([1,seqlen])
        most_freq_pred = torch.zeros([seqlen,1])

        #Aggregating (summing) the probabilities predicted by the different windows. Only the relevant codons (restricted by AA) are summed.
        #Index 1+j skips the first generated score of each window.
        for i in range(sw_logits.shape[0]): # window num
            for j in range(min(sw_aa_size, seqlen)): # sequence len - codon index
                collect_logits[i+j, :] += torch.exp(sw_logits[i, 1+j, :])
                counts[0,i+j] += 1

        #Normalizing each position by the number of predictions (eg. first codon has only one prediction)
        for i in range(seqlen):
            collect_logits[i,:] /= counts[0,i]
        collect_logits = torch.log(collect_logits)
        collect_logits = collect_logits.log_softmax(dim=-1)

        for i in range(seqlen):
            most_freq_pred[i] = torch.argmax(collect_logits[i,:]).item()

        return collect_logits, most_freq_pred

def predict(config, example, mask_restriction_dict, tokenizer, model):



        input_ids, masked_ids = prepare_inputs(example, tokenizer, config)
        outputs = generate_outputs(input_ids, masked_ids, mask_restriction_dict, model, config['sw_aa_size'])
        logits, most_freq_pred = calc_combined_gen_from_sliding_windows_logits(outputs, len(example['qseq']), config['sw_aa_size'])

        ce = torch.nn.CrossEntropyLoss()
        most_freq_pred=most_freq_pred.clone().detach().reshape((1,-1))


        #print("decode: ", tokenizer.decode(most_freq_pred.numpy().astype(int)[0]))
        #print("truevals: ", tokenizer.decode(true_vals))
        res = dict()

        res['prot_len'] = len(example['qseq'])
        res['prot_AAs'] = example['qseq']
        res['pred_codons'] = tokenizer.decode(most_freq_pred.numpy().astype(int)[0])
        res['entropy'] = (-torch.nan_to_num(torch.exp(logits)*logits,nan=0.0).sum(dim=-1)).mean().item()

        assert(res['prot_len']==len(res['pred_codons'].split(" ")))

        if config['calc_stats'] and 'query_dna_seq' in example.keys():
          true_vals = tokenizer(example['query_dna_seq'], return_tensors="pt").input_ids[:,1:-1]
          mask = true_vals > 41 #special tokens threshold
          true_vals = true_vals.tolist()[0]
          masked_most_freq_pred = most_freq_pred.masked_select(mask).numpy().astype(int)
          masked_true_vals = torch.tensor(true_vals).masked_select(mask).numpy().astype(int)

          res['subject_codons'] = example.get('subject_dna_seq')  # None for 'mask' inference, which has no subject sequence
          res['num_of_correct_predicted_codons'] = sum([int(x==y) for x,y in zip(masked_true_vals, masked_most_freq_pred)])
          res['query_codons'] = example['query_dna_seq']
          res['cross_entropy_loss'] = ce(logits, torch.tensor(true_vals)).item()
          res['perplexity'] = np.exp(res['cross_entropy_loss'])
          res['accuracy'] = res['num_of_correct_predicted_codons'] / res['prot_len']
        #print(example['qseqid'], example['sseqid'],res['cross_entropy_loss'], res['entropy'],res['accuracy'])
        return res

In [None]:
tokenizer = AutoTokenizer.from_pretrained('siditom/tokenizer-codon_optimization-refined_expr')
!wget https://huggingface.co/siditom/tokenizer-codon_optimization-refined_expr/resolve/main/mask_restrict_dict.json
mask_restrict_dict = {}
with open('/content/mask_restrict_dict.json','r') as handle:
  mask_restrict_dict = json.load(handle)
model = BartForConditionalGeneration.from_pretrained("siditom/co-model_mimic-rexpr-50w_2ft")


--2023-11-08 13:59:37--  https://huggingface.co/siditom/tokenizer-codon_optimization-refined_expr/resolve/main/mask_restrict_dict.json
Resolving huggingface.co (huggingface.co)... 18.164.174.118, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1457 (1.4K) [text/plain]
Saving to: ‘mask_restrict_dict.json.2’


2023-11-08 13:59:37 (720 MB/s) - ‘mask_restrict_dict.json.2’ saved [1457/1457]



## **Example 1 - mimic codon sequence generation**

In [None]:

config = {
    'sw_aa_size':50,
    'calc_stats':False,
    'inference_type':'mimic'
}

example = {
    'subject_dna_seq':'AAA GCG GCT GTA CTG GTC AAG AAA GTT CTG GAA TCT GCC ATT GCT AAC GCT GAA CAC AAC GAT GGC GCT GAC ATT GAC GAT CTG AAA GTT ACG AAA ATT TTC GTA GAC GAA GGC CCG AGC ATG AAG CGC ATT ATG CCG CGT GCA AAA GGT CGT GCA GAT CGC ATC CTG AAG CGC ACC AGC CAC ATC ACT GTG GTT GTG TCC GAT CGC',
    'qseq':'KSVKFVQGLLQNAAANAEA-KGLDATKLYVSHIQVNQAPKQRRRTYRAHGRINKYESSPSHIELVVTEK',
    'query_species':'S_cerevisiae',
    'subject_species':'E_coli',
    'expr':'expr_top10'
}

assert(len(example['qseq'])==len(example['subject_dna_seq'].split(" ")))
res = json.dumps(predict(config, example, mask_restrict_dict, tokenizer, model),indent=2)
print(res)


{
  "prot_len": 69,
  "prot_AAs": "KSVKFVQGLLQNAAANAEA-KGLDATKLYVSHIQVNQAPKQRRRTYRAHGRINKYESSPSHIELVVTEK",
  "pred_codons": "AAG TCT GTT AAG TTT GTT CAA GGT TTG TTG CAA AAC GCT GCT GCT AAC GCT GAA GCT <gap> AAG GGT TTG GAT GCT ACC AAG TTG TAC GTT TCT CAC ATT CAA GTC AAC CAA GCT CCA AAG CAA AGA AGA AGA ACT TAC AGA GCT CAC GGT AGA ATC AAC AAG TAC GAA TCT TCT CCA TCT CAC ATT GAA TTG GTT GTT ACT GAA AAG",
  "entropy": 0.7021335959434509
}


## **Example 2 - mimic codon sequence generation with statistics**

This inference also calculates the accuracy and loss for a specific codon sequence. In addition to the amino-acid sequence of the target protein, you must provide the codon sequence of the target protein (query_dna_seq).
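
The reported statistics are related in a simple way: accuracy is the fraction of correctly predicted codons, and perplexity is the exponential of the mean cross-entropy loss. A minimal sketch on toy values follows; `sequence_stats` is a hypothetical helper, the codons and probabilities are made up, and the handling of special tokens in the actual predict function is omitted.

```python
import math


def sequence_stats(true_codons, pred_codons, true_log_probs):
    """Compute accuracy, cross-entropy and perplexity for one sequence.

    true_log_probs[i] is the log-probability the model assigned to the
    true codon at position i.
    """
    correct = sum(t == p for t, p in zip(true_codons, pred_codons))
    cross_entropy = -sum(true_log_probs) / len(true_log_probs)
    return {
        "num_of_correct_predicted_codons": correct,
        "accuracy": correct / len(true_codons),
        "cross_entropy_loss": cross_entropy,
        "perplexity": math.exp(cross_entropy),
    }


true_codons = ["AAA", "TCT", "GTT", "AAG"]
pred_codons = ["AAG", "TCT", "GTT", "AAG"]
log_probs = [math.log(0.25), math.log(0.5), math.log(0.5), math.log(0.5)]
stats = sequence_stats(true_codons, pred_codons, log_probs)
print(stats)
```

Here three of four codons match (accuracy 0.75), and the perplexity equals 2 ** 1.25, the geometric mean of the inverse probabilities assigned to the true codons.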

In [None]:
config = {
    'sw_aa_size':50,
    'calc_stats':True,
    'inference_type':'mimic'
}

example = {
    'query_dna_seq':  'AAA TCT GTT AAG TTC GTT CAA GGT TTG TTG CAA AAC GCC GCT GCC AAT GCT GAA GCT <gap> AAG GGT CTA GAT GCT ACC AAG TTG TAC GTT TCT CAC ATC CAA GTT AAC CAA GCA CCA AAG CAA AGA AGA AGA ACT TAC AGA GCC CAC GGT AGA ATC AAC AAG TAC GAA TCT TCT CCA TCT CAC ATT GAA TTG GTT GTT ACC GAA AAG',
    'subject_dna_seq':'AAA GCG GCT GTA CTG GTC AAG AAA GTT CTG GAA TCT GCC ATT GCT AAC GCT GAA CAC AAC GAT GGC GCT GAC ATT GAC GAT CTG AAA GTT ACG AAA ATT TTC GTA GAC GAA GGC CCG AGC ATG AAG CGC ATT ATG CCG CGT GCA AAA GGT CGT GCA GAT CGC ATC CTG AAG CGC ACC AGC CAC ATC ACT GTG GTT GTG TCC GAT CGC',
    'qseq':'KSVKFVQGLLQNAAANAEA-KGLDATKLYVSHIQVNQAPKQRRRTYRAHGRINKYESSPSHIELVVTEK',
    'query_species':'S_cerevisiae',
    'subject_species':'E_coli',
    'expr':'expr_top10'
}

assert(len(example['qseq'])==len(example['query_dna_seq'].split(" ")))
assert(len(example['qseq'])==len(example['subject_dna_seq'].split(" ")))
res = json.dumps(predict(config, example, mask_restrict_dict, tokenizer, model),indent=2)
print(res)

{
  "prot_len": 69,
  "prot_AAs": "KSVKFVQGLLQNAAANAEA-KGLDATKLYVSHIQVNQAPKQRRRTYRAHGRINKYESSPSHIELVVTEK",
  "pred_codons": "AAG TCT GTT AAG TTT GTT CAA GGT TTG TTG CAA AAC GCT GCT GCT AAC GCT GAA GCT <gap> AAG GGT TTG GAT GCT ACC AAG TTG TAC GTT TCT CAC ATT CAA GTC AAC CAA GCT CCA AAG CAA AGA AGA AGA ACT TAC AGA GCT CAC GGT AGA ATC AAC AAG TAC GAA TCT TCT CCA TCT CAC ATT GAA TTG GTT GTT ACT GAA AAG",
  "entropy": 0.7021335959434509,
  "subject_codons": "AAA GCG GCT GTA CTG GTC AAG AAA GTT CTG GAA TCT GCC ATT GCT AAC GCT GAA CAC AAC GAT GGC GCT GAC ATT GAC GAT CTG AAA GTT ACG AAA ATT TTC GTA GAC GAA GGC CCG AGC ATG AAG CGC ATT ATG CCG CGT GCA AAA GGT CGT GCA GAT CGC ATC CTG AAG CGC ACC AGC CAC ATC ACT GTG GTT GTG TCC GAT CGC",
  "num_of_correct_predicted_codons": 57,
  "query_codons": "AAA TCT GTT AAG TTC GTT CAA GGT TTG TTG CAA AAC GCC GCT GCC AAT GCT GAA GCT <gap> AAG GGT CTA GAT GCT ACC AAG TTG TAC GTT TCT CAC ATC CAA GTT AAC CAA GCA CCA AAG CAA AGA AGA AGA ACT TAC AGA GCC CAC GGT A

## **Example 3 - mask codon sequence generation**

This inference generates a codon sequence from the target amino-acid sequence alone. No mimic codon sequence or subject species is required.

In [None]:
config = {
    'sw_aa_size':50,
    'calc_stats':False,
    'inference_type':'mask'

}

example = {
    'qseq':'KSVKFVQGLLQNAAANAEA-KGLDATKLYVSHIQVNQAPKQRRRTYRAHGRINKYESSPSHIELVVTEK',
    'query_species':'S_cerevisiae',
    'expr':'expr_top10'
}

res = json.dumps(predict(config, example, mask_restrict_dict, tokenizer, model),indent=2)
print(res)


{
  "prot_len": 69,
  "prot_AAs": "KSVKFVQGLLQNAAANAEA-KGLDATKLYVSHIQVNQAPKQRRRTYRAHGRINKYESSPSHIELVVTEK",
  "pred_codons": "AAG TCT GTT AAG TTC GTT CAA GGT TTG TTG CAA AAC GCT GCT GCT AAC GCT GAA GCT <gap> AAG GGT TTG GAT GCT ACC AAG TTG TAC GTT TCT CAC ATT CAA GTC AAC CAA GCT CCA AAG CAA AGA AGA AGA ACT TAC AGA GCT CAC GGT AGA ATC AAC AAG TAC GAA TCT TCT CCA TCT CAC ATT GAA TTG GTT GTT ACT GAA AAG",
  "entropy": 0.66972815990448
}
