In [1]:
%load_ext autoreload
%autoreload 2

import copy

import biotite.structure.io as bsio
import esm
import torch

from proteinttt.models.esm2 import ESM2TTT, DEFAULT_ESM2_35M_TTT_CFG
from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG
from proteinttt.base import TTTConfig

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## ESM2

Adaptation of the official [ESM2 example](https://github.com/facebookresearch/esm), modified to run test-time training (TTT) on the input sequence before extracting its embeddings.

In [4]:
seq = "HRQALGERLYPRVQAMQPAFASKITGMLLELSPAQLLLLLASEDSLRARVDEAMELII"

# Load ESM-2 model and data
model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval().to(device)  # disables dropout for deterministic results
# Read embeddings from the final transformer layer (12 for the 35M model);
# derive the index from the model instead of hard-coding it.
repr_layer = model.num_layers
_, _, batch_tokens = batch_converter([(None, seq)])
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
batch_tokens = batch_tokens.to(device)

# ================ TTT ================
# Work on a copy so that any tweaks to ttt_cfg cannot silently modify the
# shared module-level default (cells in this notebook may run out of order).
ttt_cfg = copy.deepcopy(DEFAULT_ESM2_35M_TTT_CFG)
model = ESM2TTT.ttt_from_pretrained(model, ttt_cfg)
model.ttt(batch_tokens)
# =====================================

# Extract per-residue representations
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[repr_layer])
token_representations = results["representations"][repr_layer]

# Mean-pool per sequence. The slice drops the first and last token positions,
# presumably the special start/end tokens added by the batch converter (as in
# the upstream ESM example) -- confirm if the alphabet changes.
sequence_representations = [
    token_representations[i, 1 : tokens_len - 1].mean(0)
    for i, tokens_len in enumerate(batch_lens)
]
print(sequence_representations[0].shape)

# Reset model to its original state (after this, model.ttt can be called again on another protein)
# ================ TTT ================
model.ttt_reset()
# =====================================

2025-01-18 17:28:46,040 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: None, eval_step_time: 0.00000
2025-01-18 17:28:46,455 | INFO | step: 1, accumulated_step: 16, loss: 2.72167, perplexity: None, ttt_step_time: 0.41339, score_seq_time: None, eval_step_time: 0.00002
2025-01-18 17:28:46,849 | INFO | step: 2, accumulated_step: 32, loss: 2.27294, perplexity: None, ttt_step_time: 0.39393, score_seq_time: None, eval_step_time: 0.00002
2025-01-18 17:28:47,244 | INFO | step: 3, accumulated_step: 48, loss: 2.41207, perplexity: None, ttt_step_time: 0.39404, score_seq_time: None, eval_step_time: 0.00002
2025-01-18 17:28:47,638 | INFO | step: 4, accumulated_step: 64, loss: 2.37137, perplexity: None, ttt_step_time: 0.39358, score_seq_time: None, eval_step_time: 0.00002
2025-01-18 17:28:48,032 | INFO | step: 5, accumulated_step: 80, loss: 2.38664, perplexity: None, ttt_step_time: 0.39316, score_seq_time: None, eval_step_time: 0.00002
20

## ESMFold

Adaptation of the official [ESMFold example](https://github.com/facebookresearch/esm), modified to run test-time training (TTT) on the input sequence before predicting its structure.

In [3]:
# Load ESMFold onto the device chosen in the setup cell (instead of forcing CUDA).
model = esm.pretrained.esmfold_v1()
model = model.eval().to(device)

# Optionally, uncomment to set a chunk size for axial attention. This can help reduce memory.
# Lower sizes reduce memory requirements at the cost of slower inference.
# model.set_chunk_size(128)

sequence = "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDANIKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMFMRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE"

# ================ TTT ================
# Copy the shared default so the overrides below do not leak into
# DEFAULT_ESMFOLD_TTT_CFG (and into any other cell that uses it).
ttt_cfg = copy.deepcopy(DEFAULT_ESMFOLD_TTT_CFG)
ttt_cfg.seed = 0  # Trying TTT with several different seeds may enable finding structure with higher pLDDT
ttt_cfg.steps = 10
# Pass the customized config explicitly (as the ESM2 cell does) rather than
# relying on whatever default ttt_from_pretrained falls back to.
model = ESMFoldTTT.ttt_from_pretrained(model, ttt_cfg, esmfold_config=model.cfg)
model.ttt(sequence)
# =====================================

with torch.no_grad():
    output = model.infer_pdb(sequence)

with open("result.pdb", "w") as f:
    f.write(output)

struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean())  # this will be the pLDDT

# Reset model to its original state (after this, model.ttt can be called again on another protein)
# ================ TTT ================
model.ttt_reset()
# =====================================

2025-01-18 18:11:08,062 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: None, eval_step_time: 2.23341, plddt: 38.43025
2025-01-18 18:11:10,325 | INFO | step: 1, accumulated_step: 4, loss: 2.50000, perplexity: None, ttt_step_time: 0.51810, score_seq_time: None, eval_step_time: 1.74307, plddt: 34.06020
2025-01-18 18:11:12,543 | INFO | step: 2, accumulated_step: 8, loss: 2.66797, perplexity: None, ttt_step_time: 0.47896, score_seq_time: None, eval_step_time: 1.73836, plddt: 33.57555
2025-01-18 18:11:15,073 | INFO | step: 3, accumulated_step: 12, loss: 2.46094, perplexity: None, ttt_step_time: 0.47758, score_seq_time: None, eval_step_time: 1.73828, plddt: 76.67573
2025-01-18 18:11:17,298 | INFO | step: 4, accumulated_step: 16, loss: 2.37500, perplexity: None, ttt_step_time: 0.48372, score_seq_time: None, eval_step_time: 1.73908, plddt: 72.56865
2025-01-18 18:11:19,959 | INFO | step: 5, accumulated_step: 20, loss: 2.37891, perplex