In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper

2.9.0+cpu
None
False


In [None]:
ds = load_dataset("bandeiralab/Pep2Prob")
# Only the first 10k training peptides are used, to keep the demo small.
sequences = ds['train']['peptide'][:10000]

# Vocabulary: the 20 canonical amino acids (one character each), then two
# special tokens used for padding and for the masked/noise state.
vocab = list("ACDEFGHIKLMNPQRSTVWY") + ['[PAD]', '[MASK]']
vocab_size = len(vocab)
id_to_aa = dict(enumerate(vocab))
aa_to_id = {aa: i for i, aa in id_to_aa.items()}
pad_id, mask_id = aa_to_id['[PAD]'], aa_to_id['[MASK]']
max_len = 40  # sequences are truncated / right-padded to this many tokens

In [None]:
class PeptideDataset(Dataset):
    """Map-style dataset turning peptide strings into fixed-length id tensors.

    Each sequence is truncated to ``max_len`` residues and right-padded with the
    padding id, so every item is a 1-D ``torch.long`` tensor of length ``max_len``.

    Args:
        sequences: Indexable collection of peptide strings (one residue per character).
        max_len: Fixed output length.
        token_to_id: Residue -> id mapping. Defaults to the notebook-level ``aa_to_id``.
        pad_token_id: Id used for padding. Defaults to the notebook-level ``pad_id``.

    Raises:
        KeyError: (from ``__getitem__``) if a kept residue is not in ``token_to_id``.
    """

    def __init__(self, sequences, max_len, token_to_id=None, pad_token_id=None):
        self.sequences = sequences
        self.max_len = max_len
        # Fall back to the notebook globals so existing calls keep working, while
        # letting the class be used (and tested) without relying on them.
        self.token_to_id = aa_to_id if token_to_id is None else token_to_id
        self.pad_token_id = pad_id if pad_token_id is None else pad_token_id

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx][: self.max_len]
        try:
            ids = [self.token_to_id[aa] for aa in seq]
        except KeyError as err:
            # Surface which sequence / residue broke, instead of a bare KeyError
            # deep inside a DataLoader worker.
            raise KeyError(
                f"Sequence {idx} ({seq!r}) contains residue {err.args[0]!r} "
                "that is not in the vocabulary"
            ) from None
        ids += [self.pad_token_id] * (self.max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)

# Model
class VelocityModel(nn.Module):
    """Transformer predicting per-position vocabulary logits for discrete flow matching.

    Given a (partially masked) token sequence ``x`` and flow time ``t``, returns
    unnormalised logits over the vocabulary for every position. The notebook feeds
    these to ``MixturePathGeneralizedKL`` during training and (after softmax) to the
    discrete Euler solver during sampling.

    Args:
        vocab_size: Number of tokens, including special tokens.
        d_model: Hidden width; must be divisible by the 4 attention heads.
        n_layers: Number of transformer encoder layers.
        max_len: Longest sequence length supported by the learned positional
            embedding (default 512).
    """

    def __init__(self, vocab_size, d_model=64, n_layers=6, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Self-attention alone is permutation-equivariant: without positional
        # information every [MASK] position would receive identical logits.
        self.pos_embedding = nn.Embedding(max_len, d_model)
        self.transformer = nn.TransformerEncoder(
            # Inputs are (batch, seq, d_model). The PyTorch default
            # batch_first=False would make attention run across the batch axis.
            nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True),
            num_layers=n_layers,
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
        self.max_len = max_len

    def forward(self, x, t):
        """Compute logits.

        Args:
            x: Integer token ids of shape (batch, seq).
            t: Flow time, shape (batch,) or a scalar shared by the whole batch.

        Returns:
            Logits of shape (batch, seq, vocab_size).

        Raises:
            ValueError: if ``seq`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )
        positions = torch.arange(seq_len, device=x.device)
        emb = self.embedding(x) + self.pos_embedding(positions)
        # Broadcast t over sequence and channels (same conditioning as before,
        # without materialising a repeated tensor; also accepts a scalar t).
        t = torch.as_tensor(t, dtype=emb.dtype, device=emb.device).reshape(-1, 1, 1)
        out = self.transformer(emb + t)
        return self.fc_out(out)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Configuration for this training run.
SEED = 0          # seeds model init, DataLoader shuffling and the t draws
NUM_EPOCHS = 100  # lower (e.g. 5) for a quick demo run
torch.manual_seed(SEED)

dataset = PeptideDataset(sequences, max_len)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
model = VelocityModel(vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# FM components: polynomial scheduler with n=1.0 on a mixture discrete path.
scheduler = PolynomialConvexScheduler(n=1.0)
path = MixtureDiscreteProbPath(scheduler=scheduler)
loss_fn = MixturePathGeneralizedKL(path=path)

# Training
model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    for x_1 in dataloader:
        batch_size = x_1.size(0)
        # t in [0, 1 - 1e-3): kept strictly below 1, presumably because the
        # generalized-KL weighting is singular at t = 1 (confirm in flow_matching docs).
        t = torch.rand(batch_size) * (1.0 - 1e-3)
        x_0 = torch.full_like(x_1, mask_id)  # source distribution: every token masked
        sample = path.sample(t=t, x_0=x_0, x_1=x_1)
        logits = model(sample.x_t, sample.t)
        loss = loss_fn(logits=logits, x_1=sample.x_1, x_t=sample.x_t, t=sample.t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Report the mean over the epoch rather than the last (noisy) batch.
    print(f"Epoch {epoch + 1}: Loss {epoch_loss / max(n_batches, 1):.4f}")

# Sampling
class ProbabilityDenoiser(ModelWrapper):
    """Adapter turning the wrapped logit model into per-position probabilities.

    The discrete Euler solver calls this wrapper with the current state ``x`` and
    time ``t``; it returns a softmax over the vocabulary for every position.
    """

    def forward(self, x, t, **extras):
        # Use the wrapped model (passed to ModelWrapper at construction) rather than
        # the notebook-global ``model``, so the wrapper honours what it was built with.
        logits = self.model(x, t)
        # Softmax in float32 for numerical stability if the model runs in lower precision.
        return torch.softmax(logits.float(), dim=-1)

# Wrap the trained network so the solver receives probabilities, then build the
# Euler solver over the same mixture path used for training.
denoiser = ProbabilityDenoiser(model=model)
solver = MixtureDiscreteEulerSolver(
    path=path,
    model=denoiser,
    vocabulary_size=vocab_size,
)

In [None]:
N_SAMPLES = 3      # number of peptides to generate
STEP_SIZE = 0.01   # Euler step size in t

# Disable dropout (nn.TransformerEncoderLayer defaults to dropout=0.1) so sampling
# uses the deterministic network.
model.eval()
# Start from the all-[MASK] state, matching x_0 used in training.
x_init = torch.full((N_SAMPLES, max_len), mask_id, dtype=torch.long)
with torch.no_grad():
    x_generated = solver.sample(x_init=x_init, step_size=STEP_SIZE)

# Decode: drop special tokens ([PAD] and any leftover [MASK]) so only residues remain.
special_ids = {pad_id, mask_id}
generated_peptides = [
    ''.join(id_to_aa[tok] for tok in seq.tolist() if tok not in special_ids)
    for seq in x_generated
]
print("Generated peptides:", generated_peptides)