# Disentanglement Playground with Pretrained Protein LMs

This notebook shows how to:
1. Download & encode sequences via ESM and ProtTrans
2. Build small latents (β-VAE / FactorVAE) on top of frozen embeddings
3. Add contrastive projection heads (SimCLR-style)
4. Train & inspect disentanglement metrics


## 1. Setup

Install dependencies (run once):

In [None]:
!pip install biopython torch torchvision pytorch-lightning transformers fair-esm scikit-learn

Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m91.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from transformers import AutoModel, AutoTokenizer
import esm  # if using FAIR’s esm repo

## Download FASTA-format Datasets

In [None]:
from Bio import Entrez, SeqIO
# Entrez.email is sent along with every request made through Bio.Entrez.
Entrez.email = "pompos002@gmail.com"
# NOTE(review): the same accession is listed twice below — presumably a second,
# different ID was intended; confirm before relying on len(records).
handle = Entrez.efetch(db="protein",
                       id=["P01308","P01308"],  # UniProt IDs
                       rettype="fasta", retmode="text")
try:
    # Materialise the records while the network handle is still open ...
    records = list(SeqIO.parse(handle, "fasta"))
finally:
    # ... and always release the connection, even if parsing fails.
    handle.close()


In [None]:
!pip install datasets

from datasets import load_dataset
from torch.utils.data import DataLoader

# 1) load FASTA as HF Dataset
# `datasets` has no "fasta" builder, so load_dataset("fasta", ...) fails with
# FileNotFoundError. Parse the file with Biopython instead and build the
# Dataset from plain Python lists.
# The path must already exist; the Workflow section of this notebook writes an
# example file to the same location.
from Bio import SeqIO
from datasets import Dataset as HFDataset  # aliased so torch's `Dataset` is not shadowed

fasta_records = list(SeqIO.parse("data/my_proteins.fasta", "fasta"))
hf_ds = HFDataset.from_dict({
    "sequence": [str(rec.seq) for rec in fasta_records],
    "description": [rec.description for rec in fasta_records],
})

# 2) wrap it so it returns just sequences
class SeqDataset(torch.utils.data.Dataset):
    """Thin view over a row-indexable dataset that exposes only the 'sequence' field.

    Args:
        hf_dataset: any object supporting ``len()`` and integer indexing whose
            rows are mappings containing a 'sequence' key.
    """

    def __init__(self, hf_dataset):
        self.ds = hf_dataset

    def __len__(self):
        n_rows = len(self.ds)
        return n_rows

    def __getitem__(self, i):
        # Each row is a mapping; only its sequence string is returned.
        row = self.ds[i]
        return row["sequence"]

seq_ds = SeqDataset(hf_ds)

# 3) plug into your ProteinEmbeddingDataset
# The ProteinEmbeddingDataset constructor accepts a FASTA path directly, so pass
# the file itself; seq_ds stays available for code that wants the raw sequences.
# The cell defining ProteinEmbeddingDataset must have been run before this one.
embed_device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back when no GPU is present
embed_ds = ProteinEmbeddingDataset(
    fasta_path="data/my_proteins.fasta",
    model_name="esm2_t33_650M_UR50D",
    aggregation="cls",
    max_len=512,
    device=embed_device,
)

loader = DataLoader(embed_ds, batch_size=32, shuffle=True)




The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


FileNotFoundError: Couldn't find a dataset script at /content/fasta/fasta.py or any data file in the same directory. Couldn't find 'fasta' on the Hugging Face Hub either: FileNotFoundError: Dataset 'fasta' doesn't exist on the Hub. If the repo is private or gated, make sure to log in with `huggingface-cli login`.

## 2. Data & Embedding Extraction

Define a simple FASTA dataset and functions to extract embeddings from ESM and ProtTrans.

In [None]:
class FastaDataset(Dataset):
    """Dataset of (sequence, tokenizer output) pairs read from a FASTA file.

    Args:
        fasta_path: path of the FASTA file to parse.
        tokenizer: callable invoked as ``tokenizer(seq, truncation=..., padding=...,
            max_length=..., return_tensors=...)``.
        max_len: value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, fasta_path, tokenizer, max_len=512):
        from Bio import SeqIO

        self.tokenizer = tokenizer
        self.max_len = max_len
        # Keep only the residue strings; record ids and descriptions are dropped.
        self.seqs = []
        for record in SeqIO.parse(fasta_path, "fasta"):
            self.seqs.append(str(record.seq))

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        sequence = self.seqs[idx]
        # Every item is truncated / padded to the same max_length by the tokenizer.
        tokenizer_kwargs = dict(
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
        )
        return sequence, self.tokenizer(sequence, **tokenizer_kwargs)

In [None]:
# Load ESM
# The pretrained constructor returns the model together with its alphabet; the
# alphabet supplies the batch converter that tokenizes (name, sequence) tuples.
esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_batch_converter = esm_alphabet.get_batch_converter()
# Put the model in eval mode. As the last expression of the cell, this call's
# return value (the module itself) is what the cell displays.
esm_model.eval()

# Load ProtTrans (ProtT5)
# pt_model = AutoModel.from_pretrained(
#     "Rostlab/prot_t5_xl_uniref50", trust_remote_code=True
# )
# pt_tokenizer = AutoTokenizer.from_pretrained(
#     "Rostlab/prot_t5_xl_uniref50", do_lower_case=False
# )

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

### ProteinEmbedding dataset class (model agnostic)

In [None]:
import os
from torch.utils.data import Dataset
import torch

class ProteinEmbeddingDataset(Dataset):
    """
    Dataset that turns protein sequences into fixed-size embedding vectors.

    Each item is computed on the fly by:
     - truncating the sequence to ``max_len`` residues
     - tokenizing it for the selected backend (fair-esm, or a Hugging Face
       model such as ProtTrans)
     - running the frozen model under ``torch.no_grad()``
     - pooling the per-token representations via 'mean', 'cls', or 'max'

    Args:
        fasta_path: FASTA file to read sequences from. May be omitted when
            ``sequences`` is given.
        model_name: names starting with "esm" are looked up in
            ``esm.pretrained``; anything else is loaded through
            ``AutoTokenizer`` / ``AutoModel``.
        aggregation: one of "mean", "cls", "max".
        max_len: maximum number of residues kept per sequence; also passed to
            the Hugging Face tokenizer as ``max_length``.
        device: device the model is moved to and on which inference runs.
        sequences: optional in-memory sequences (a single string, an iterable
            of strings, or any object supporting ``len()`` and integer
            indexing). Takes precedence over ``fasta_path``.

    Raises:
        ValueError: if ``aggregation`` is unknown, or if neither
            ``fasta_path`` nor ``sequences`` is provided.
    """

    _AGGREGATIONS = ("mean", "cls", "max")
    # ProtTrans preprocessing maps the rare residues U, Z, O and B to X.
    _RARE_TO_X = str.maketrans("UZOB", "XXXX")

    def __init__(
        self,
        fasta_path: str = None,
        model_name: str = "esm2_t33_650M_UR50D",    # or "prot_t5_xl_uniref50"
        aggregation: str = "mean",                   # one of "mean", "cls", "max"
        max_len: int = 512,
        device: str = "cpu",
        sequences=None,
    ):
        # Fail before loading a large checkpoint rather than on the first item.
        if aggregation not in self._AGGREGATIONS:
            raise ValueError(f"Unknown aggregation: {aggregation}")

        self.seqs = self._load_sequences(fasta_path, sequences)
        self.aggregation = aggregation
        self.max_len = max_len
        self.device = device

        if model_name.startswith("esm"):
            import esm
            self.model, self.alphabet = getattr(esm.pretrained, model_name)()
            self.batch_converter = self.alphabet.get_batch_converter()
            self.model.eval().to(device)
            self.backend = "esm"
        else:
            from transformers import AutoModel, AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
            # The forward pass below supplies no decoder inputs, so for
            # encoder-decoder checkpoints (e.g. T5-style) keep only the encoder.
            if getattr(model.config, "is_encoder_decoder", False):
                model = model.get_encoder()
            self.model = model
            self.model.eval().to(device)
            self.backend = "prottrans"

    @staticmethod
    def _load_sequences(fasta_path, sequences):
        """Return the sequences as a list of str, from memory or from a FASTA file."""
        if sequences is not None:
            if isinstance(sequences, str):
                # A bare string is one sequence, not an iterable of residues.
                return [sequences]
            if hasattr(sequences, "__len__") and hasattr(sequences, "__getitem__"):
                # Index explicitly so map-style datasets without __iter__ work.
                return [str(sequences[i]) for i in range(len(sequences))]
            return [str(s) for s in sequences]
        if fasta_path is None:
            raise ValueError("Provide either fasta_path or sequences")
        from Bio import SeqIO
        return [str(rec.seq) for rec in SeqIO.parse(fasta_path, "fasta")]

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return the pooled embedding of sequence ``idx`` as a CPU tensor of shape (C,)."""
        seq = self.seqs[idx][: self.max_len]  # truncate if too long
        if self.backend == "esm":
            reps, mask = self._embed_esm(idx, seq)
        else:
            reps, mask = self._embed_hf(seq)
        return self._pool(reps, mask).cpu()

    def _embed_esm(self, idx, seq):
        """Run the ESM model on one sequence; returns reps (L, C) and mask (L, 1)."""
        # esm wants a batch of tuples: (name, seq)
        _, _, tokens = self.batch_converter([(str(idx), seq)])
        tokens = tokens.to(self.device)
        layer = self.model.num_layers
        with torch.no_grad():
            out = self.model(tokens, repr_layers=[layer])
        # "representations" maps layer index to a (batch, L, C) tensor.
        reps = out["representations"][layer]
        # Only padding is masked out; any special tokens added by the batch
        # converter still take part in mean / max pooling.
        mask = (tokens != self.alphabet.padding_idx).unsqueeze(-1)
        return reps.squeeze(0), mask.squeeze(0)

    def _embed_hf(self, seq):
        """Run the Hugging Face model on one sequence; returns reps (L, C) and mask (L, 1)."""
        # ProtTrans tokenizers expect one whitespace-separated residue per
        # token, with rare residues replaced by X.
        spaced = " ".join(seq.translate(self._RARE_TO_X))
        enc = self.tokenizer(
            spaced,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            out = self.model(**enc, output_hidden_states=False)
        reps = out.last_hidden_state                      # (1, L, C)
        mask = enc["attention_mask"].unsqueeze(-1).bool()  # (1, L, 1)
        return reps.squeeze(0), mask.squeeze(0)

    def _pool(self, reps, mask):
        """Collapse (L, C) token representations into a single (C,) vector."""
        if self.aggregation == "mean":
            weights = mask.to(reps.dtype)
            return (reps * weights).sum(0) / weights.sum(0).clamp(min=1)
        if self.aggregation == "max":
            # Use the dtype's own minimum so the fill value is representable in
            # reduced-precision dtypes too.
            return reps.masked_fill(~mask, torch.finfo(reps.dtype).min).max(0).values
        if self.aggregation == "cls":
            # Position 0 is returned as-is. For ESM this is the leading special
            # token. NOTE(review): T5-style tokenizers prepend no such token, in
            # which case this is the first residue — confirm it is what you want.
            return reps[0]
        raise ValueError(f"Unknown aggregation: {self.aggregation}")


import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

# --- 1. Wrap the embedding dataset so it yields {'emb': tensor} for Lightning ---
class EmbeddingWrapper(Dataset):
    """Adapter that re-packages each item of an embedding dataset as ``{"emb": item}``.

    The training modules in this notebook read their input from ``batch['emb']``.
    """

    def __init__(self, embed_ds):
        self.embed_ds = embed_ds

    def __len__(self):
        n_items = len(self.embed_ds)
        return n_items

    def __getitem__(self, idx):
        # The wrapped item is passed through untouched under the "emb" key.
        return dict(emb=self.embed_ds[idx])

In [None]:
def get_esm_embeddings(batch_seqs):
    """Return one mean-pooled ESM embedding per sequence.

    Args:
        batch_seqs: list of (name, seq) tuples, as expected by the ESM batch
            converter.

    Returns:
        Tensor of shape (B, C) on the model's device: final-layer
        representations averaged over all non-padding token positions.

    Uses the module-level ``esm_model``, ``esm_alphabet`` and
    ``esm_batch_converter``.
    """
    _, _, toks = esm_batch_converter(batch_seqs)
    # Follow the model wherever it currently lives (it may have been moved to GPU).
    toks = toks.to(next(esm_model.parameters()).device)
    layer = esm_model.num_layers  # final layer, instead of a hard-coded index
    with torch.no_grad():
        out = esm_model(toks, repr_layers=[layer])
    reps = out["representations"][layer]                       # (B, L, C)
    # Shorter sequences in a batch are padded; exclude those positions so they
    # do not dilute the average.
    keep = (toks != esm_alphabet.padding_idx).unsqueeze(-1)    # (B, L, 1)
    return (reps * keep).sum(1) / keep.sum(1).clamp(min=1)

def get_pt_embeddings(encodings, model=None):
    """Return attention-masked mean embeddings for tokenized sequences.

    Args:
        encodings: mapping with "input_ids" and "attention_mask" tensors of
            shape (L,), (1, L) or (B, L).
        model: callable invoked as ``model(input_ids, attention_mask=...)`` whose
            result exposes ``last_hidden_state`` of shape (B, L, C). Defaults to
            the module-level ``pt_model``.
            NOTE(review): the cell that loads ``pt_model`` is commented out in
            this notebook, so the default only works once it is defined.

    Returns:
        Tensor of shape (B, C); (1, C) for a single sequence.
    """
    if model is None:
        model = pt_model
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    if input_ids.dim() == 1:
        # Promote a single unbatched sequence to a batch of one.
        input_ids = input_ids.unsqueeze(0)
        attention_mask = attention_mask.unsqueeze(0)
    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask)
    hidden = out.last_hidden_state                          # (B, L, C)
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)    # (B, L, 1)
    # Sum and count along the sequence axis (dim 1) so the divisor is the
    # number of real tokens per sequence.
    token_counts = mask.sum(1).clamp(min=1)                 # (B, 1)
    return (hidden * mask).sum(1) / token_counts

## 3. β-VAE / FactorVAE Modules

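Define the VAE architecture and its LightningModule wrapper. The code in this section implements the β-VAE objective (reconstruction loss plus a β-weighted KL term); it does not include a FactorVAE total-correlation term.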

In [None]:
class VAE(nn.Module):
    """Single-hidden-layer variational autoencoder over fixed-size vectors.

    Args:
        input_dim: size of the input (and reconstructed) vector.
        z_dim: size of the latent code.
        hidden_dim: width of the hidden layer in both encoder and decoder.
    """

    def __init__(self, input_dim, z_dim=32, hidden_dim=256):
        super().__init__()
        # Encoder: shared hidden layer feeding separate mean / log-variance heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, z_dim)
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)
        # Decoder: latent -> hidden -> reconstruction.
        self.fc_dec = nn.Linear(z_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        hidden = torch.relu(self.fc1(x))
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent code back to input space."""
        return self.fc_out(torch.relu(self.fc_dec(z)))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for ``x``."""
        mu, logvar = self.encode(x)
        sampled = self.reparameterize(mu, logvar)
        return self.decode(sampled), mu, logvar

In [None]:
class VAETrainer(pl.LightningModule):
    """LightningModule that trains a ``VAE`` with a beta-weighted KL penalty.

    Args:
        input_dim: dimensionality of the vectors found under ``batch['emb']``.
        z_dim: latent size passed to ``VAE``.
        beta: multiplier applied to the KL term of the loss.
    """

    def __init__(self, input_dim, z_dim=32, beta=4.0):
        super().__init__()
        self.beta = beta
        self.model = VAE(input_dim, z_dim)

    def training_step(self, batch, batch_idx):
        """Compute, log (as 'train_loss') and return the loss for one batch."""
        inputs = batch['emb']  # embedding vector
        recon, mu, logvar = self.model(inputs)
        # Both terms are averaged over every element of the batch.
        recon_term = F.mse_loss(recon, inputs, reduction='mean')
        kl_term = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        total = recon_term + self.beta * kl_term
        self.log('train_loss', total)
        return total

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

## 4. SimCLR Head

Define the projection head and NT-Xent loss.

In [None]:
class SimCLRHead(nn.Module):
    """Two-layer MLP projection head: Linear -> ReLU -> Linear.

    Args:
        emb_dim: size of the incoming embedding; also the hidden width.
        proj_dim: size of the projected output.
    """

    def __init__(self, emb_dim, proj_dim=64):
        super().__init__()
        layers = [
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, proj_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` from ``emb_dim`` to ``proj_dim`` features."""
        projected = self.net(x)
        return projected

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Args:
        z_i: projections of the first view, shape (N, D).
        z_j: projections of the second view, shape (N, D); row k of ``z_j`` is
            the positive for row k of ``z_i``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the loss averaged over all 2N anchors.

    Computed in log-space through ``cross_entropy`` rather than with explicit
    ``exp()`` calls, so small temperatures do not overflow, and with a matrix
    product instead of a broadcast (2N, 2N, D) intermediate.
    """
    n = z_i.size(0)
    # Unit-normalise once so that a matrix product yields cosine similarities.
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1, eps=1e-8)
    logits = (z @ z.t()) / temperature                       # (2N, 2N)
    # An anchor must never be contrasted with itself.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # The positive of anchor k is k + N in the first half and k - N in the second.
    targets = (torch.arange(2 * n, device=z.device) + n) % (2 * n)
    return F.cross_entropy(logits, targets)

## 5. Putting It Together

Combine embeddings with SimCLR training in a LightningModule.

In [None]:
class SimCLRTrainer(pl.LightningModule):
    """LightningModule that trains a ``SimCLRHead`` with the NT-Xent loss.

    Batches are expected to be mappings holding the two views of each sample
    under the keys 'emb1' and 'emb2'.

    Args:
        emb_dim: size of the incoming embeddings.
        proj_dim: output size of the projection head.
        temp: temperature passed to ``nt_xent_loss``.
    """

    def __init__(self, emb_dim, proj_dim=64, temp=0.5):
        super().__init__()
        self.temp = temp
        self.encoder = nn.Identity()  # embeddings precomputed
        self.head = SimCLRHead(emb_dim, proj_dim)

    def training_step(self, batch, batch_idx):
        """Project both views, then compute, log (as 'train_loss') and return the loss."""
        view_a = self.head(batch['emb1'])
        view_b = self.head(batch['emb2'])
        contrastive = nt_xent_loss(view_a, view_b, self.temp)
        self.log('train_loss', contrastive)
        return contrastive

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 3e-4."""
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
        return optimizer

## 6. Workflow

Prepare datasets and train.


In [None]:
# Install dependencies
!pip install numpy==1.26 torch torchvision pytorch-lightning transformers fair-esm biopython datasets


Collecting numpy==1.26
  Downloading numpy-1.26.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (58 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/58.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.5/58.5 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
Downloading numpy-1.26.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m60.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have n

In [None]:
# Create example FASTA so the path exists
import os

# Three short example records; each header line is followed by its residues.
example_fasta = """>protein1
MKWVTFISLLFLFSSAYSRGVFRRDTHKSEIAHRFKDLGE
>protein2
GILGYTEAQVKILDGGSGFYTNLTMATPLKAPIK
>protein3
MTIQTGLDSTGTTMTVVESKDLKELLEAQQGIQAYSQVGR
"""

# Create the target directory if needed, then (over)write the example file.
os.makedirs("data", exist_ok=True)
with open("data/my_proteins.fasta", "w") as fasta_file:
    fasta_file.write(example_fasta)


# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from Bio import SeqIO
from datasets import Dataset

import esm   # FAIR’s ESM library
import numpy as np # Import numpy

# 1) Parse the FASTA file into parallel id / sequence lists and wrap them in an HF Dataset
fasta_path = "data/my_proteins.fasta"
records = [record for record in SeqIO.parse(fasta_path, "fasta")]
ids = [record.id for record in records]
seqs = [str(record.seq) for record in records]
hf_ds = Dataset.from_dict(dict(id=ids, sequence=seqs))

# 2) Load the ESM model once, in eval mode, on GPU when one is available
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
esm_model = esm_model.to(device)
# Values read by embed_batch: final layer index, padding token id, tokenizer.
num_layer = esm_model.num_layers
pad_idx = esm_alphabet.padding_idx
batch_converter = esm_alphabet.get_batch_converter()

# 3) Compute embeddings via map()
def embed_batch(batch):
    """Embed one batch of sequences for ``Dataset.map(..., batched=True)``.

    Args:
        batch: mapping with parallel "id" and "sequence" lists.

    Returns:
        ``{"emb": [...]}`` holding one mean-pooled vector per sequence (nested
        Python lists), averaged over non-padding token positions.

    Uses the module-level ``batch_converter``, ``esm_model``, ``device``,
    ``num_layer`` and ``pad_idx``.
    """
    labelled = list(zip(batch["id"], batch["sequence"]))
    _, _, toks = batch_converter(labelled)
    toks = toks.to(device)
    with torch.no_grad():
        layer_out = esm_model(toks, repr_layers=[num_layer])["representations"][num_layer]  # (B, L, C)
    keep = (toks != pad_idx).unsqueeze(-1)          # (B, L, 1), False at padding
    token_counts = keep.sum(1).clamp(min=1)         # (B, 1)
    pooled = (layer_out * keep).sum(1) / token_counts
    return {"emb": pooled.cpu().numpy().tolist()}

# Embed the dataset in batches of 16, dropping the raw columns so only "emb" remains
hf_emb = hf_ds.map(
    embed_batch,
    batched=True,
    batch_size=16,
    remove_columns=["id", "sequence"],
)
# Have indexing return torch tensors for the "emb" column
hf_emb.set_format(type="torch", columns=["emb"])

# 4) Read the embedding width off the first row, then build the DataLoader
sample = hf_emb[0]["emb"]
input_dim = sample.shape[0]
print(f"Embedding dim = {input_dim}")

train_loader = DataLoader(
    hf_emb,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

# 5) Define VAE Trainer
class VAE(nn.Module):
    """Single-hidden-layer variational autoencoder over fixed-size vectors.

    Args:
        input_dim: size of the input (and reconstructed) vector.
        z_dim: size of the latent code.
        hidden_dim: width of the hidden layer in both encoder and decoder.
    """

    def __init__(self, input_dim, z_dim=32, hidden_dim=256):
        super().__init__()
        # Encoder: shared hidden layer feeding separate mean / log-variance heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, z_dim)
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)
        # Decoder: latent -> hidden -> reconstruction.
        self.fc_dec = nn.Linear(z_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        hidden = torch.relu(self.fc1(x))
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent code back to input space."""
        return self.fc_out(torch.relu(self.fc_dec(z)))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for ``x``."""
        mu, logvar = self.encode(x)
        sampled = self.reparameterize(mu, logvar)
        return self.decode(sampled), mu, logvar

class VAETrainer(pl.LightningModule):
    """LightningModule that trains a ``VAE`` with a beta-weighted KL penalty.

    Args:
        input_dim: dimensionality of the vectors found under ``batch["emb"]``.
        z_dim: latent size passed to ``VAE``.
        beta: multiplier applied to the KL term of the loss.
    """

    def __init__(self, input_dim, z_dim=32, beta=4.0):
        super().__init__()
        self.beta = beta
        self.model = VAE(input_dim, z_dim)

    def training_step(self, batch, batch_idx):
        """Compute, log (as "train_loss") and return the loss for one batch."""
        inputs = batch["emb"]
        recon, mu, logvar = self.model(inputs)
        # Both terms are averaged over every element of the batch.
        recon_term = F.mse_loss(recon, inputs, reduction='mean')
        kl_term = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        total = recon_term + self.beta * kl_term
        self.log("train_loss", total)
        return total

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 6) Train!
vae_module = VAETrainer(input_dim=input_dim, z_dim=32, beta=4.0)
# With the example data there is a single batch per epoch, fewer than
# Lightning's default logging interval of 50 steps, so `train_loss` would never
# be logged. Cap the interval at the number of batches per epoch.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    log_every_n_steps=max(1, min(50, len(train_loader))),
)
trainer.fit(vae_module, train_loader)

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.


Embedding dim = 1280


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | VAE  | 681 K  | train
---------------------------------------
681 K     Trainable params
0         Non-trainable params
681 K     Total params
2.727     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
