In [None]:
%load_ext autoreload
%autoreload 2

import copy

import torch
import esm

from proteinttt.models.esm2 import ESM2TTT, DEFAULT_ESM2_35M_TTT_CFG
from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG
from proteinttt.base import TTTConfig

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## ESM2

Adaptation of an official [ESM2 example](https://github.com/facebookresearch/esm) to use ProteinTTT before predicting embeddings.

In [10]:
# Protein sequence that the model is adapted to and then embedded
seq = "HRQALGERLYPRVQAMQPAFASKITGMLLELSPAQLLLLLASEDSLRARVDEAMELII"

# Load the pretrained ESM-2 checkpoint together with its alphabet
model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
model = model.eval().to(device)  # eval mode disables dropout for deterministic results

# Tokenize a single (label, sequence) pair; the label is unused here, hence None
batch_converter = alphabet.get_batch_converter()
batch_labels, batch_strs, batch_tokens = batch_converter([(None, seq)])
# Number of non-padding tokens per sequence, counted before moving tokens to the device
batch_lens = batch_tokens.ne(alphabet.padding_idx).sum(dim=1)
batch_tokens = batch_tokens.to(device)

# ================ TTT ================
# Wrap the pretrained model with ProteinTTT and adapt it to `seq`
ttt_cfg = DEFAULT_ESM2_35M_TTT_CFG
model = ESM2TTT.ttt_from_pretrained(model, ttt_cfg)
model.ttt(seq)
# =====================================

# Extract per-residue representations from layer 12
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[12])
token_representations = results["representations"][12]

# One embedding per sequence: the mean over token positions 1 .. tokens_len - 2,
# i.e. skipping the first and last non-padding positions (presumably the special
# start/end tokens added by the batch converter -- confirm against the ESM docs).
sequence_representations = [
    token_representations[i, 1 : tokens_len - 1].mean(0)
    for i, tokens_len in enumerate(batch_lens)
]
print(sequence_representations[0].shape)

# Reset the model to its original state (after this, model.ttt can be called again on another protein)
# ================ TTT ================
model.ttt_reset()
# =====================================

2025-10-25 19:15:37,554 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: 0.00000, eval_step_time: 0.00000
2025-10-25 19:15:37,985 | INFO | step: 1, accumulated_step: 16, loss: 0.75260, perplexity: None, ttt_step_time: 0.43038, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-10-25 19:15:38,411 | INFO | step: 2, accumulated_step: 32, loss: 0.73837, perplexity: None, ttt_step_time: 0.42509, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-10-25 19:15:38,835 | INFO | step: 3, accumulated_step: 48, loss: 0.68426, perplexity: None, ttt_step_time: 0.42412, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-10-25 19:15:39,261 | INFO | step: 4, accumulated_step: 64, loss: 0.69365, perplexity: None, ttt_step_time: 0.42529, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-10-25 19:15:39,688 | INFO | step: 5, accumulated_step: 80, loss: 0.67766, perplexity: None, ttt_step_time: 0.42601, score_seq_time: 0.00000, eval_ste

## ESMFold

Adaptation of an official [ESMFold example](https://github.com/facebookresearch/esm) to use ProteinTTT before predicting protein structure.

In [3]:
# Imported up front so that a missing `biotite` installation fails immediately,
# rather than after the model has been loaded and test-time training has run.
import biotite.structure.io as bsio

# Load the pretrained ESMFold model, switch it to eval mode and move it to the
# device chosen in the setup cell (instead of unconditionally requiring CUDA).
model = esm.pretrained.esmfold_v1()
model = model.eval().to(device)

# Optionally, uncomment to set a chunk size for axial attention. This can help reduce memory.
# Lower sizes will have lower memory requirements at the cost of reduced speed.
# model.set_chunk_size(128)

sequence = "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDANIKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMFMRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE"

# ================ TTT ================
# Work on a private copy so the module-level default config is not mutated for
# every later use in this kernel (re-running cells would otherwise see stale edits).
ttt_cfg = copy.deepcopy(DEFAULT_ESMFOLD_TTT_CFG)
ttt_cfg.seed = 0  # Trying TTT with several different seeds may enable finding structure with higher pLDDT
ttt_cfg.steps = 10
# Pass the customised config explicitly (positionally, as in the ESM2 example above);
# otherwise the seed/steps set here would not be handed to the wrapper at all.
model = ESMFoldTTT.ttt_from_pretrained(model, ttt_cfg, esmfold_config=model.cfg)
# `df` keeps whatever model.ttt returns (presumably a per-step log table -- TODO confirm)
df = model.ttt(sequence)
# =====================================

# Predict the structure with the adapted model; no gradients are needed for inference
with torch.no_grad():
    output = model.infer_pdb(sequence)

# Save the predicted structure as a PDB file in the current working directory
with open("result.pdb", "w") as f:
    f.write(output)

# Read the file back, keeping the B-factor column
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean())  # this will be the pLDDT

# Reset the model to its original state (after this, model.ttt can be called again on another protein)
# ================ TTT ================
model.ttt_reset()
# =====================================

2025-10-25 18:43:22,370 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: 0.00000, eval_step_time: 2.49081, plddt: 38.43030
2025-10-25 18:43:24,946 | INFO | step: 1, accumulated_step: 4, loss: 2.54688, perplexity: None, ttt_step_time: 0.54727, score_seq_time: 0.00000, eval_step_time: 2.02741, plddt: 36.79425
2025-10-25 18:43:27,480 | INFO | step: 2, accumulated_step: 8, loss: 2.51953, perplexity: None, ttt_step_time: 0.51081, score_seq_time: 0.00000, eval_step_time: 2.02258, plddt: 38.01218
2025-10-25 18:43:30,014 | INFO | step: 3, accumulated_step: 12, loss: 2.48633, perplexity: None, ttt_step_time: 0.51053, score_seq_time: 0.00000, eval_step_time: 2.02268, plddt: 35.67231
2025-10-25 18:43:32,630 | INFO | step: 4, accumulated_step: 16, loss: 2.27734, perplexity: None, ttt_step_time: 0.51160, score_seq_time: 0.00000, eval_step_time: 2.02284, plddt: 76.67175
2025-10-25 18:43:35,166 | INFO | step: 5, accumulated_step: 20, loss: 2