In [1]:
#@title **Setup**

#@markdown ### Identification
huggingface_token = "" #@param {type:"string"}
github_token = "" #@param {type:"string"}
#@markdown ---

github_clone_path = f"https://{github_token}@github.com/Synthyra/TranslatorInference.git"

print("\nCloning the github repository...\n")
# !git clone {github_clone_path}
# %cd TranslatorInference
# !pip install -r requirements.txt --quiet

if huggingface_token:
    print("\nLogging into HuggingFace...\n")
    from huggingface_hub import login
    login(huggingface_token)

print("\nLoading the Annotation Vocabulary...\n")
import pickle
with open('id2label.pkl', 'rb') as f:
    id2label = pickle.load(f)

with open('label2id.pkl', 'rb') as f:
    label2id = pickle.load(f)


from annotation_mapping import name_ec, ec_name, name_go, go_name, name_ip, ip_name, name_gene, gene_name
annotation_vocab_dict = {
    'ec': (name_ec, ec_name),
    'go': (name_go, go_name),
    'ip': (name_ip, ip_name),
    '3d': (name_gene, gene_name)
}

print("\nImporting dependencies...\n")
import torch
import ipywidgets as widgets
import pandas as pd
from IPython.display import display, HTML
from utils import describe_prompt, return_preds, get_probs
from model import SeqToAnnTranslator, TranslatorConfig
from tqdm.auto import tqdm


print("\nLoading the model...\n")
model_path = 'lhallee/translator_seq_to_ann_final'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = TranslatorConfig.from_pretrained(model_path)
model = SeqToAnnTranslator(config).from_pretrained(model_path).eval().to(device)
tokenizer = model.esm.tokenizer


aspect_dict = {
    'ec': 'Enzyme Comission Number',
    'bp': 'GO Biological Process',
    'cc': 'GO Cellular Component',
    'mf': 'GO Molecular Function',
    'ip': 'InterPro',
    'threed': 'Gene3D',
    'keywords': 'Uniprot Keywords'
}
#@markdown *Press play to setup the environment*

#@markdown ---


Cloning the github repository...


Loading the Annotation Vocabulary...


Importing dependencies...


Loading the model...



modeling_fastesm.py:   0%|          | 0.00/41.4k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/Synthyra/ESM2-650M:
- modeling_fastesm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Some weights of FastEsmModel were not initialized from the model checkpoint at Synthyra/ESM2-650M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

Some weights of FastEsmModel were not initialized from the model checkpoint at Synthyra/ESM2-650M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
#@title **Inference**
#@markdown Enter a sequence to annotate
seq = "MDEMILLRRVLLAGFICALLVPSGLSCGPGRGIGTRKRFKKLTPLAYKQFTPNVPEKTLGASGRYEGKITRNSERFKELTPNYNPDIIFKDEENTGADRLMTQRCKDKLNALAISVMNQWPGVKLRVTEGWDEDGHHFEESLHYEGRAVDITTSDRDRSKYGMLARLAAEAGFDWVYFESKAHIHCSVKAENSVAAKSGGCFPGSATVALEQGVRIPVKDLRPGDRVLAADGLGKLVYSDFLLFMDKEETVRKVFYVIETSRERVRLTAAHLLFVGQAHPGNDSGGDFRSVFGSAGFRSMFASSVRAGHRVLTVDREGRGLREATVERVYLEEATGAYAPVTAHGTVVIDRVLASCYAVIEEHSWAHWAFAPLRVGLGILSFFSPQDYSSHSPPAPSQSEGVHWYSEILYRIGTWVLQEDTIHPLGMAAKSS" #@param {type:"string"}

num_annotations = 32
seqs = [seq]

probs = get_probs(
    model=model,
    tokenizer=tokenizer,
    seqs=seqs,
    num_annotations=num_annotations,
    device=device
)

#@markdown Press play to annotate
#@markdown ---

In [3]:
#@title A guide to topk and confidence

#@markdown ---
#@markdown **Translator** predicts a fixed number of protein annotations from the [**Annotation Vocabulary**](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1) from an input protein sequence.

#@markdown The `topk` parameter controls the number of annotations retrieved per "token."

#@markdown The `confidence` parameter controls the minimum predicted confidence score for an annotation to be included in the output.

#@markdown Shown below is a figure showcasing the trade-off between topk and confidence. A higher topk value will result in more annotations being retrieved, but at the cost of lower confidence and precision.

#@markdown Lower topk values will result in higher precision, meaning each annotation shown is more likely to be correct.

#@markdown Higher topk values will result in higher recall, meaning that within the set of annotations, more are likely to be retrieved.

#@markdown A very high topk value is way to explore possible annotations but is less likely to be accurate.

#@markdown The optimal topk value is a trade-off between precision and recall, often measured by their harmonic mean (F1 score), which is at topk=3 for our evaluation sets.

#@markdown The figure also showcasing the minimum confidence score at each topk value such that **every annotation above that confidence was correctly predicted**.

#@markdown Therefore, you can adjust topk and confidence to be more "sure" about the output or more "exploratory."

#@markdown ---

In [11]:
#@title **Press play to view results**

#@markdown ---
#@markdown Annotations are colored by confidence, with blue being the lowest confidence and red being the highest.
#@markdown In general, anything above 0.1 is fairly high confidence (as they sum to 1 over 88000 options!).

topk_text = widgets.BoundedIntText(
    value=3,
    min=1,
    max=100,
    step=1,
    description='TopK',
    continuous_update=False
)

confidence_text = widgets.BoundedFloatText(
    value=0.04,
    min=0.0,
    max=1.0,
    step=0.01,
    description='Min Confidence',
    continuous_update=False
)

# Create the Output widget and enforce white background + scrolling.
output = widgets.Output()
# Set a large height, enable scrolling, and explicitly set background to white:
output.layout = widgets.Layout(
    height='1000px',
    overflow='auto',         # or 'scroll'
    border='1px solid #ddd'
)

def color_text_by_conf(text, conf):
    """
    Returns an HTML <span> with color going from blue (conf=0) to red (conf=1).
    """
    red_val   = int(conf * 255)
    blue_val  = 255 - red_val
    return f"<span style='color:rgb({red_val},0,{blue_val}); font-weight:bold;'>{text} ({conf:.4f})</span>"

def update_output(*args):
    with output:
        output.clear_output()
        
        # Get the widget values:
        topk_val = topk_text.value
        confidence_val = confidence_text.value
        
        # Get filtered predictions
        final_ids, confidences = return_preds(
            probs=probs,
            topk=topk_val, 
            minimum_confidence=confidence_val
        )
        
        described, track_ids = describe_prompt(final_ids, id2label, annotation_vocab_dict)
        conf_map = dict(zip(final_ids, confidences))
        
        html_output = ""
        for aspect, entries in described.items():
            aspect_name = aspect_dict.get(aspect, aspect)
            html_output += f"<h4 style='color:black; font-weight:bold;'>{aspect_name}</h4>"
            for entry in entries:
                id_ = track_ids[entry]
                conf = conf_map.get(id_, 0.0)
                html_output += color_text_by_conf(entry, conf) + "<br>"
            html_output += "<hr>"
        
        # Display the HTML, wrapped in a white background div just in case:
        html_output = f"<div style='background-color:white;'>{html_output}</div>"
        display(HTML(html_output))
        
        # Force scroll to the top with a tiny JS snippet.
        # This attempts to find the container of the current output and reset scrollTop.
        scroll_to_top_js = """
        <script>
        // Grab the closest output area for this cell (the ipywidget .output_scroll or .jupyter-widgets-output-area)
        let out_area = this.closest('.jupyter-widgets-output-area') || this.closest('.output_scroll');
        if (out_area) {
            out_area.scrollTop = 0;
        }
        </script>
        """
        display(HTML(scroll_to_top_js))

# Observe changes
topk_text.observe(update_output, names='value')
confidence_text.observe(update_output, names='value')

# Display widgets and output
display(widgets.VBox([topk_text, confidence_text]), output)

# Initialize the display
update_output()

VBox(children=(BoundedIntText(value=3, description='TopK', min=1), BoundedFloatText(value=0.04, description='M…

Output(layout=Layout(border_bottom='1px solid #ddd', border_left='1px solid #ddd', border_right='1px solid #dd…

In [5]:
#@title **High throughput annotation**
topk = "3" #@param {type:"string"}
min_confidence = "0.04" #@param {type:"string"}
#@markdown Number of annotations refers to how many annotation `tokens` are predicted.
#@markdown The model is trained to predict between 1 and 62 annotations.
#@markdown 32 is a good default.
num_annotations = 32 #@param {type:"slider", min:1, max:62, step:1}

#@markdown ---
#@markdown Give local csv with "seqs" column or huggingface dataset path with one train split and "seqs" column.
#@markdown Outputs to a local csv
data_path = "" #@param {type:"string"}
local = False #@param {type:"boolean"}
output_path = "results.csv" #@param {type:"string"}

topk = int(topk)
confidence = float(min_confidence)
num_annotations = int(num_annotations)

if local:
    dataset = pd.read_csv(data_path)
    seqs = dataset["seqs"].tolist()
else:
    from datasets import load_dataset
    dataset = load_dataset(data_path)
    seqs = dataset["train"]["seqs"]

results = []
for seq in tqdm(seqs, desc="Annotating sequences"):
    probs = get_probs(
        model=model,
        tokenizer=tokenizer,
        seqs=[seq],
        num_annotations=num_annotations,
        device=device
    )

    final_ids, confidences = return_preds(
        probs=probs, 
        topk=topk, 
        minimum_confidence=confidence
    )

    described, track_ids = describe_prompt(final_ids, id2label, annotation_vocab_dict)
    conf_map = dict(zip(final_ids, confidences))

    result_dict = {"seqs": seq}
    
    for aspect, entries in described.items():
        aspect_name = aspect_dict.get(aspect, aspect)
        entries_with_conf = []
        for entry in entries:
            id_ = track_ids[entry]
            conf = conf_map.get(id_, 0.0)
            entries_with_conf.append((entry, conf))
        # Sort entries by confidence in descending order
        entries_with_conf.sort(key=lambda x: x[1], reverse=True)
        # Format entries after sorting
        formatted_entries = [f"{entry} ({conf:.4f})" for entry, conf in entries_with_conf]
        result_dict[aspect_name] = "; ".join(formatted_entries)
    
    results.append(result_dict)

# Create and save DataFrame
results_df = pd.DataFrame(results)
results_df.to_csv(output_path, index=False)
print(f"Results saved to {output_path}")

#@markdown Press play to annotate your dataset
#@markdown ---


IndexError: list index out of range