In [1]:
%load_ext autoreload
%autoreload 2

import torch
import esm
import biotite.structure.io as bsio

from proteinttt.models.esm2 import ESM2TTT, DEFAULT_ESM2_35M_TTT_CFG
from proteinttt.models.esmfold import ESMFoldTTT, DEFAULT_ESMFOLD_TTT_CFG
from proteinttt.base import TTTConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm
  __import__("pkg_resources").declare_namespace(__name__)


## ESM2

Adaptation of an official [ESM2 example](https://github.com/facebookresearch/esm) to use ProteinTTT before predicting embeddings.

In [2]:
seq = "HRQALGERLYPRVQAMQPAFASKITGMLLELSPAQLLLLLASEDSLRARVDEAMELII"

# Load ESM-2 model and data
model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval().to(device)  # disables dropout for deterministic results
batch_labels, batch_strs, batch_tokens = batch_converter([(None, seq)])
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
batch_tokens = batch_tokens.to(device)

# ================ TTT ================
ttt_cfg = DEFAULT_ESM2_35M_TTT_CFG
model = ESM2TTT.ttt_from_pretrained(model, ttt_cfg)
model.ttt(seq)
# =====================================

# Extract per-residue representations
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[12])
token_representations = results["representations"][12]
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))
print(sequence_representations[0].shape)

# Reset model to original state (after this model.ttt can be called again on another protein)
# ================ TTT ================
model.ttt_reset()
# =====================================



2025-11-04 15:33:14,661 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: 0.00000, eval_step_time: 0.00001
2025-11-04 15:33:15,400 | INFO | step: 1, accumulated_step: 16, loss: 0.75260, perplexity: None, ttt_step_time: 0.73787, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-11-04 15:33:15,831 | INFO | step: 2, accumulated_step: 32, loss: 0.73837, perplexity: None, ttt_step_time: 0.43010, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-11-04 15:33:16,262 | INFO | step: 3, accumulated_step: 48, loss: 0.68426, perplexity: None, ttt_step_time: 0.43029, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-11-04 15:33:16,692 | INFO | step: 4, accumulated_step: 64, loss: 0.69365, perplexity: None, ttt_step_time: 0.42923, score_seq_time: 0.00000, eval_step_time: 0.00003
2025-11-04 15:33:17,117 | INFO | step: 5, accumulated_step: 80, loss: 0.67766, perplexity: None, ttt_step_time: 0.42461, score_seq_time: 0.00000, eval_ste

## ESMFold

Adaptation of an official [ESMFold example](https://github.com/facebookresearch/esm) to use ProteinTTT before predicting protein structure. Please note that rerunning the customization multiple times or with different random seeds (`ttt_cfg.seed=<seed>`) may lead to slightly different results. So, running several times can yield a better (i.e., higher-pLDDT) solution.

In [3]:
# Set your sequence
sequence = "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDANIKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMFMRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE"

# Load model
model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()

def predict_structure(model, sequence):
    with torch.no_grad():
        output = model.infer_pdb(sequence)

    with open("result.pdb", "w") as f:
        f.write(output)

    struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
    print('pLDDT:', struct.b_factor.mean())

predict_structure(model, sequence)

# ============ ProteinTTT =============
ttt_cfg = DEFAULT_ESMFOLD_TTT_CFG
ttt_cfg.steps = 10  # This is how you can modify the config
ttt_cfg.seed = 5
model = ESMFoldTTT.ttt_from_pretrained(model, ttt_cfg=ttt_cfg, esmfold_config=model.cfg)
model.ttt(sequence)
# =====================================

predict_structure(model, sequence)

# Reset model to original state (after this model.ttt can be called again on another protein)
# ============== ProteinTTT ===========
model.ttt_reset()
# =====================================

pLDDT: 37.87921248142645
2025-11-04 15:34:28,076 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: 0.00000, eval_step_time: 1.75401, plddt: 37.87921
2025-11-04 15:34:30,745 | INFO | step: 1, accumulated_step: 4, loss: 2.51367, perplexity: None, ttt_step_time: 0.59175, score_seq_time: 0.00000, eval_step_time: 1.75558, plddt: 54.04266
2025-11-04 15:34:33,507 | INFO | step: 2, accumulated_step: 8, loss: 2.60938, perplexity: None, ttt_step_time: 0.55520, score_seq_time: 0.00000, eval_step_time: 1.74990, plddt: 77.62975
2025-11-04 15:34:35,813 | INFO | step: 3, accumulated_step: 12, loss: 2.50000, perplexity: None, ttt_step_time: 0.55554, score_seq_time: 0.00000, eval_step_time: 1.74906, plddt: 72.34952
2025-11-04 15:34:38,393 | INFO | step: 4, accumulated_step: 16, loss: 2.12109, perplexity: None, ttt_step_time: 0.55375, score_seq_time: 0.00000, eval_step_time: 1.74888, plddt: 78.22857
2025-11-04 15:34:40,699 | INFO | step: 5, accu

## ProGen2
Example for an autoregressive model

https://www.cell.com/cell-systems/fulltext/S2405-4712(23)00272-7

Conda environment should be installed following https://github.com/salesforce/progen/blob/main/progen2/requirements.txt

In [None]:
# Download ProGen2 code and weights
!git clone https://github.com/salesforce/progen ../proteinttt/models && \
    cd ../proteinttt/models/progen/progen2 && \
    set model progen2-small && \
    wget -P checkpoints/$model https://storage.googleapis.com/sfr-progen-research/checkpoints/$model.tar.gz && \
    tar -xvf checkpoints/$model/$model.tar.gz -C checkpoints/$model/

c[3J--2025-11-28 17:08:26--  https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-small.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.36.123, 142.251.36.91, 142.251.38.155, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.36.123|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 561673660 (536M) [application/x-tar]
Saving to: â€˜checkpoints/progen2-small/progen2-small.tar.gz.2â€™


2025-11-28 17:09:11 (12.2 MB/s) - â€˜checkpoints/progen2-small/progen2-small.tar.gz.2â€™ saved [561673660/561673660]

x ./
x ./config.json
x ./pytorch_model.bin


In [2]:
from proteinttt.models.progen2 import ProGen2TTT
from tokenizers import Tokenizer
from proteinttt.models.progen.progen2.models.progen.modeling_progen import ProGenForCausalLM

sequence = "GIHLGELGLLPSTVLAIGYFENLVNIICESLNMLPKLEVSGKEYKKFKFTIVIPKDLDANIKKRAKIYFKQKSLIEIEIPTSSRNYPIHIQFDENSTDDILHLYDMPTTIGGIDKAIEMFMRKGHIGKTDQQKLLEERELRNFKTTLENLIATDAFAKEMVEVIIEE"
ckpts_dir = "../proteinttt/models/progen/progen2/checkpoints/"
model_name = "progen2-small"

with open(ckpts_dir + "tokenizer.json", "r") as f:
    tokenizer = Tokenizer.from_str(f.read())

model = ProGenForCausalLM.from_pretrained(ckpts_dir + model_name)
model = ProGen2TTT.ttt_from_pretrained(model=model, tokenizer=tokenizer, config=model.config)

model.ttt(sequence)

ProGenForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly defined. However, it doesn't directly inherit from `GenerationMixin`. From ðŸ‘‰v4.50ðŸ‘ˆ onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
ProGenForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly defined. However, it doesn't directly inherit from `GenerationMixin`. From ðŸ‘‰v4.50ðŸ‘ˆ onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are th

2025-11-28 17:13:55,565 | INFO | step: 0, accumulated_step: 0, loss: None, perplexity: None, ttt_step_time: 0.00000, score_seq_time: 0.00000, eval_step_time: 0.00000
2025-11-28 17:13:56,941 | INFO | step: 1, accumulated_step: 4, loss: 2.81965, perplexity: None, ttt_step_time: 1.37571, score_seq_time: 0.00000, eval_step_time: 0.00000
2025-11-28 17:13:58,399 | INFO | step: 2, accumulated_step: 8, loss: 2.82224, perplexity: None, ttt_step_time: 1.45787, score_seq_time: 0.00000, eval_step_time: 0.00000
2025-11-28 17:13:59,907 | INFO | step: 3, accumulated_step: 12, loss: 2.98479, perplexity: None, ttt_step_time: 1.50796, score_seq_time: 0.00000, eval_step_time: 0.00000
2025-11-28 17:14:01,341 | INFO | step: 4, accumulated_step: 16, loss: 2.97238, perplexity: None, ttt_step_time: 1.43282, score_seq_time: 0.00000, eval_step_time: 0.00001
2025-11-28 17:14:02,829 | INFO | step: 5, accumulated_step: 20, loss: 3.36476, perplexity: None, ttt_step_time: 1.48732, score_seq_time: 0.00000, eval_step_

{'ttt_step_data': defaultdict(dict,
             {0: {'eval_step_preds': {}},
              1: {'eval_step_preds': {}},
              2: {'eval_step_preds': {}},
              3: {'eval_step_preds': {}},
              4: {'eval_step_preds': {}},
              5: {'eval_step_preds': {}},
              6: {'eval_step_preds': {}},
              7: {'eval_step_preds': {}},
              8: {'eval_step_preds': {}},
              9: {'eval_step_preds': {}},
              10: {'eval_step_preds': {}},
              11: {'eval_step_preds': {}},
              12: {'eval_step_preds': {}},
              13: {'eval_step_preds': {}},
              14: {'eval_step_preds': {}},
              15: {'eval_step_preds': {}}}),
 'df':     step  accumulated_step      loss perplexity  ttt_step_time  \
 0      0                 0       NaN       None       0.000000   
 1      1                 4  2.819646       None       1.375708   
 2      2                 8  2.822244       None       1.457869   
 3      3 