# Using biotrainer autoeval for pLM evaluation

This notebook shows how to use the biotrainer `autoeval` module to evaluate a protein language model (pLM) automatically. It uses the [PBC](https://github.com/Rostlab/pbc) framework, which provides curated datasets that are established for pLM benchmarking.

In [None]:
# Install biotrainer if you haven't
# !pip install biotrainer

## Default Use Case: Model Download from Hugging Face

The most convenient way to use the biotrainer autoeval pipeline is to pass the Hugging Face id of your pLM. The pipeline then downloads the model, calculates the embeddings and runs the evaluation automatically.

In [None]:
import torch

# Define variables
embedder_name = "Rostlab/prot_t5_xl_uniref50"  # Replace with your plm's huggingface id. For alternatives, see "Advanced options" below
framework = "pbc"
min_seq_length = 0  # Default
max_seq_length = 2000  # Default

In [None]:
# Run the pipeline
from biotrainer.autoeval import autoeval_pipeline

current_progress = None
for progress in autoeval_pipeline(embedder_name=embedder_name,
                                  framework=framework,
                                  min_seq_length=min_seq_length,
                                  max_seq_length=max_seq_length):
    print(progress)  # The pipeline is a generator function to inform the user about the current progress.
    current_progress = progress

In [None]:
# Let's look at the results
if current_progress is None or current_progress.final_report is None:
    print("No results found.")  # Something went wrong
else:
    final_report = current_progress.final_report
    scl_result = final_report["results"]["PBC-scl"]["test_results"]['test']['metrics']['accuracy']
    sec_struct_result_newpisces364 = final_report["results"]["PBC-secondary_structure"]["test_results"]['newpisces364']['metrics']['accuracy']
    sec_struct_result_casp12 = final_report["results"]["PBC-secondary_structure"]["test_results"]['casp12']['metrics']['accuracy']
    sec_struct_result_casp13 = final_report["results"]["PBC-secondary_structure"]["test_results"]['casp13']['metrics']['accuracy']
    sec_struct_result_casp14 = final_report["results"]["PBC-secondary_structure"]["test_results"]['casp14']['metrics']['accuracy']

    print(f"PBC-scl results: {scl_result} (accuracy on test)\n")
    print("PBC-secondary_structure results:")
    print(f"newpisces364: {sec_struct_result_newpisces364} (accuracy)")
    print(f"casp12: {sec_struct_result_casp12} (accuracy)")
    print(f"casp13: {sec_struct_result_casp13} (accuracy)")
    print(f"casp14: {sec_struct_result_casp14} (accuracy)")

The full report file can be found at `autoeval_output/{embedder_name}/autoeval_report_{embedder_name}.json`
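
The nested lookups in the cell above all follow one pattern: `results`, then the task, then `test_results`, then the test set, then `metrics`. As a minimal sketch, assuming the report behaves like a plain nested dict and that every task uses this layout (this notebook only confirms it for the two PBC tasks shown), a small helper can flatten the report into one table. The name `collect_metric` and the toy report with made-up numbers are hypothetical and only illustrate the structure.

```python
def collect_metric(report: dict, metric: str = "accuracy") -> dict:
    """Flatten a report into {(task, test_set): value}.

    Assumed layout (taken from the lookups in the cell above):
    report["results"][task]["test_results"][test_set]["metrics"][metric]
    """
    flat = {}
    for task, task_result in report.get("results", {}).items():
        for test_set, set_result in task_result.get("test_results", {}).items():
            metrics = set_result.get("metrics", {})
            if metric in metrics:  # skip test sets that do not report this metric
                flat[(task, test_set)] = metrics[metric]
    return flat


# Toy report with made-up numbers, only to illustrate the structure
toy_report = {
    "results": {
        "PBC-scl": {"test_results": {"test": {"metrics": {"accuracy": 0.5}}}},
        "PBC-secondary_structure": {
            "test_results": {
                "casp12": {"metrics": {"accuracy": 0.25}},
                "casp13": {"metrics": {"loss": 1.0}},
            }
        },
    }
}

for (task, test_set), value in collect_metric(toy_report).items():
    print(f"{task} / {test_set}: {value}")
```

With a real run, `collect_metric(final_report)` would replace the toy report, provided the assumptions above hold.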

## Advanced options 1: Using a custom embedding function

If you run biotrainer autoeval directly after training your model, the model is probably only available locally and not on Hugging Face. In that case, you can provide custom embedding functions for per-sequence and per-residue embeddings, which makes you independent of the biotrainer embedding module. Each function takes a list of strings (sequences) as input and must yield, for each sequence, a tuple of the sequence and its embedding. Returning the sequence together with the embedding ensures that every embedding is mapped to the correct sequence.

*What is a generator function?*:

A generator function returns each result as soon as it is available and only computes the next one after the previous one has been processed. Here, this is useful because each embedding can be saved right after it is computed, which keeps the embeddings from filling up the RAM.
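
To make the lazy behaviour concrete, here is a small self-contained sketch. The function `embed_lazily` and its fake "embedding" (a list holding the sequence length) are hypothetical stand-ins and not part of biotrainer; they only show when the work actually happens.

```python
def embed_lazily(sequences, log):
    """Toy generator: 'embeds' one sequence at a time and records when it does so."""
    for seq in sequences:
        log.append(f"computed {seq}")
        yield seq, [float(len(seq))]  # stand-in for a real embedding tensor


log = []
gen = embed_lazily(["MKT", "SEQVENCE"], log)
# Nothing has been computed yet: the function body only runs on demand
assert log == []

first = next(gen)  # computes exactly one result
print(first)       # ('MKT', [3.0])
print(log)         # ['computed MKT']

rest = list(gen)   # consuming the generator computes the remaining results
print(log)         # ['computed MKT', 'computed SEQVENCE']
```

A consumer can therefore store each embedding and release it before the next one is produced.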

In [None]:
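# Abstract explanation: each function receives a list of sequences and must
# yield one (sequence, embedding) pair per sequence.
# torch.empty(...) and the dimension 1024 are placeholders; replace them with your model's output.
def custom_embedding_function_per_sequence(sequences):
    for seq in sequences:
        yield seq, torch.empty(1024)  # per-sequence embedding, shape (embedding_dim,)

def custom_embedding_function_per_residue(sequences):
    for seq in sequences:
        yield seq, torch.empty(len(seq), 1024)  # per-residue embedding, shape (len(seq), embedding_dim)
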
for progress in autoeval_pipeline(embedder_name=embedder_name,
                                  framework=framework,
                                  custom_embedding_function_per_sequence=custom_embedding_function_per_sequence,
                                  custom_embedding_function_per_residue=custom_embedding_function_per_residue,
                                  min_seq_length=min_seq_length,
                                  max_seq_length=max_seq_length):
    print(progress)  # The pipeline is a generator function to inform the user about the current progress.

### Concrete Example with the ProtT5 QuickStart Tutorial

We now use the [ProtT5 QuickStart Tutorial](https://github.com/agemagician/ProtTrans?tab=readme-ov-file#-quick-start) to show how to plug ProtT5 into autoeval via the custom embedding functions. Note that this example embeds one sequence at a time and therefore does not batch efficiently, but it is a good starting point for your own implementation:

In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re

# Use a torch.device object so that the CPU check below compares like with like
device = torch.device('cpu')  # or: torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

# Load the model
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
if device == torch.device("cpu"):
    model.to(torch.float32)

def embed_per_residue(sequences: list):
    for sequence in sequences:
        # replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids
        sequence_cleaned = [" ".join(list(re.sub(r"[UZOB]", "X", sequence)))]
        ids = tokenizer(sequence_cleaned, add_special_tokens=True, padding="longest")

        input_ids = torch.tensor(ids['input_ids']).to(device)
        attention_mask = torch.tensor(ids['attention_mask']).to(device)
        # generate embeddings
        with torch.no_grad():
            embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)
        embedding = embedding_repr.last_hidden_state[0,:len(sequence)]
        yield sequence, embedding

def embed_per_sequence(sequences: list):
    # Reuse the per-residue generator and mean-pool over the residue dimension
    for sequence, embedding in embed_per_residue(sequences):
        yield sequence, embedding.mean(dim=0)  # shape (1024,)

# Run Autoeval
for progress in autoeval_pipeline(embedder_name="ProtT5-custom",
                                  framework=framework,
                                  custom_embedding_function_per_sequence=embed_per_sequence,
                                  custom_embedding_function_per_residue=embed_per_residue,
                                  min_seq_length=min_seq_length,
                                  max_seq_length=max_seq_length):
    print(progress)  # The pipeline is a generator function to inform the user about the current progress.
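
The example above runs one forward pass per sequence. One common way to batch more efficiently is to sort the sequences by length and group them under a residue budget, so that little compute is spent on padding. The helper below is only a sketch of that grouping step: `batch_by_length` and `max_residues` are hypothetical names, not part of biotrainer or ProtTrans, and a suitable budget depends on your hardware.

```python
from typing import Iterator, List


def batch_by_length(sequences: List[str], max_residues: int = 4000) -> Iterator[List[str]]:
    """Yield batches of sequences whose padded size stays within max_residues.

    Padded size = (number of sequences in the batch) * (length of the longest one).
    A sequence longer than max_residues is yielded in a batch of its own.
    """
    batch: List[str] = []
    for seq in sorted(sequences, key=len):
        # Sorted ascending, so the newest sequence is the longest in the batch
        if batch and (len(batch) + 1) * len(seq) > max_residues:
            yield batch
            batch = []
        batch.append(seq)
    if batch:
        yield batch


batches = list(batch_by_length(["AAAA", "CC", "GGGGGGGG", "T"], max_residues=8))
print(batches)  # [['T', 'CC'], ['AAAA'], ['GGGGGGGG']]
```

Inside an embedding function, each batch could then be tokenized in one call with `padding="longest"`, and each row of `last_hidden_state` sliced to the length of its own sequence before it is yielded.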

## Advanced options 2: Precomputed embeddings file

Another option is to use precomputed embedding files, for example if you have already computed them. Make sure that the files include embeddings for all framework sequences and that the embeddings are stored by sequence hash, according to biotrainer standards.

In [None]:
from pathlib import Path
from biotrainer.autoeval import get_unique_framework_sequences

_, per_residue_seqs, per_sequence_seqs = get_unique_framework_sequences(framework=framework,
                                                                        min_seq_length=min_seq_length,
                                                                        max_seq_length=max_seq_length)
# per_residue_seqs and per_sequence_seqs are dictionaries mapping sequence hashes to
# BiotrainerSequenceRecord objects. Use that hash as the id when storing your embeddings.

per_residue_path = Path()  # TODO Your per-residue embeddings path
per_sequence_path = Path()  # TODO Your per-sequence embeddings path
for progress in autoeval_pipeline(embedder_name=embedder_name,
                                  framework=framework,
                                  precomputed_per_residue_embeddings=per_residue_path,
                                  precomputed_per_sequence_embeddings=per_sequence_path,
                                  min_seq_length=min_seq_length,
                                  max_seq_length=max_seq_length):
    print(progress)