In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import re
from collections import Counter
from sklearn.model_selection import train_test_split
import random

# 超参数配置
class Config:
    # 序列相关参数
    seq_vocab_size = 21  # 20种氨基酸 + 1个填充位置
    max_seq_len = 256
    seq_embed_dim = 128
    
    # 文本相关参数（单词级）
    text_vocab_size = 5000  # 单词词汇表大小
    max_text_len = 128     # 单词序列长度（更短）
    text_embed_dim = 128
    
    # 其他共享参数
    time_embed_dim = 64
    hidden_dim = 256
    num_heads = 8
    num_layers = 6
    dropout = 0.3
    beta_start = 1e-4
    beta_end = 0.02
    T = 1000
    batch_size = 32
    lr = 1e-4
    weight_decay = 1e-5
    val_split = 0.1
    # 数据增强参数
    augmentation_factor = 3  # 每条序列增强的次数
    mutation_rate = 0.01     # 每个氨基酸的变异概率

config = Config()

# 文本分词器（单词级）
class WordTokenizer:
    def __init__(self, word_counts=None):
        # 特殊标记
        self.special_tokens = ["<pad>", "<unk>", "<sos>", "<eos>"]
        self.word2idx = {}
        self.idx2word = {}
        
        # 构建词汇表
        self.vocab = self.special_tokens.copy()
        
        # 添加高频词
        if word_counts:
            sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
            for word, _ in sorted_words:
                if len(self.vocab) < config.text_vocab_size:
                    self.vocab.append(word)
                else:
                    break
        
        # 创建映射
        for i, word in enumerate(self.vocab):
            self.word2idx[word] = i
            self.idx2word[i] = word
        
        self.vocab_size = len(self.vocab)
    
    def tokenize(self, text):
        """将文本分割为单词/标点序列"""
        # 匹配单词（含连字符）和标点
        return re.findall(r"\b[\w'-]+\b|[^\w\s]", text.lower())
    
    def encode(self, text, max_len=None):
        max_len = max_len or config.max_text_len
        tokens = self.tokenize(text)
        
        # 添加起始和结束标记
        tokens = ["<sos>"] + tokens[:max_len-2] + ["<eos>"]
        
        # 转换为索引
        encoded = []
        for token in tokens:
            if token in self.word2idx:
                encoded.append(self.word2idx[token])
            else:
                encoded.append(self.word2idx["<unk>"])
        
        # 填充或截断
        if len(encoded) < max_len:
            encoded += [self.word2idx["<pad>"]] * (max_len - len(encoded))
        else:
            encoded = encoded[:max_len-1] + [self.word2idx["<eos>"]]  # 确保结束标记
        
        return encoded
    
    def decode(self, indices):
        """将索引序列解码为文本"""
        words = []
        for idx in indices:
            word = self.idx2word.get(idx, "<unk>")
            if word in ["<pad>", "<sos>", "<eos>"]:
                continue
            words.append(word)
        return ' '.join(words)

# 蛋白质序列分词器
class SequenceTokenizer:
    def __init__(self):
        self.aa_list = "ACDEFGHIKLMNPQRSTVWY"  # 20种标准氨基酸
        self.pad_token = "<pad>"
        
        self.aa2idx = {self.pad_token: 0}
        for aa in self.aa_list:
            self.aa2idx[aa] = len(self.aa2idx)
        
        self.idx2aa = {v: k for k, v in self.aa2idx.items()}
        self.vocab_size = len(self.aa2idx)
    
    def encode(self, sequence, max_len=None):
        max_len = max_len or config.max_seq_len
        encoded = []
        for aa in sequence[:max_len]:
            encoded.append(self.aa2idx.get(aa, self.aa2idx[self.pad_token]))
        
        if len(encoded) < max_len:
            encoded += [self.aa2idx[self.pad_token]] * (max_len - len(encoded))
        return encoded
    
    def decode(self, indices):
        return ''.join([self.idx2aa.get(idx, 'X') for idx in indices if idx != self.aa2idx[self.pad_token]])

# 时间步嵌入
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        # 计算频率
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # 时间步与频率相乘
        embeddings = t[:, None] * embeddings[None, :]
        # 拼接正弦和余弦
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

# 文本编码器
class TextEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 单词嵌入层
        self.embedding = nn.Embedding(config.text_vocab_size, config.text_embed_dim, padding_idx=0)
        
        # Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.text_embed_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim,
            dropout=config.dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.num_layers)
        
    def forward(self, text):
        embedded = self.embedding(text)
        encoded = self.encoder(embedded)
        return encoded

# 蛋白质扩散模型
class DiffusionModel(nn.Module):
    def __init__(self, config, text_tokenizer, seq_tokenizer):
        super().__init__()
        self.config = config
        self.text_tokenizer = text_tokenizer
        self.seq_tokenizer = seq_tokenizer
        self.pad_idx = seq_tokenizer.aa2idx["<pad>"]  # 序列填充索引
        
        # 文本条件编码器
        self.text_encoder = TextEncoder(config)
        
        # 时间步嵌入
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(config.time_embed_dim),
            nn.Linear(config.time_embed_dim, config.time_embed_dim),
            nn.GELU()
        )
        
        # 序列嵌入层（使用蛋白质词汇表）
        self.seq_embed = nn.Embedding(
            seq_tokenizer.vocab_size, 
            config.seq_embed_dim,
            padding_idx=self.pad_idx
        )
        
        # 条件融合：融合文本编码和时间嵌入
        self.condition_fuse = nn.Sequential(
            nn.Linear(config.text_embed_dim + config.time_embed_dim, config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.seq_embed_dim)  # 映射到序列嵌入维度
        )
        
        # Transformer解码器（用于去噪）
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.seq_embed_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim,
            dropout=config.dropout,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)
        
        # 输出层：将嵌入映射到氨基酸词汇表
        self.output_layer = nn.Linear(config.seq_embed_dim, seq_tokenizer.vocab_size)
        
        # 初始化扩散参数（beta/alpha系列）
        self.register_buffer('betas', torch.linspace(config.beta_start, config.beta_end, config.T))
        self.register_buffer('alphas', 1. - self.betas)
        self.register_buffer('alphas_bar', torch.cumprod(self.alphas, dim=0))  # 累积乘积
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(self.alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - self.alphas_bar))
    
    def forward_emb(self, xt_emb, t, text_condition):
        """处理加噪后的嵌入，预测原始嵌入（核心前向逻辑）"""
        # 1. 时间步嵌入
        t_emb = self.time_embed(t)  # (batch, time_embed_dim)
        
        # 2. 文本条件编码
        text_encoded = self.text_encoder(text_condition)  # (batch, text_len, text_embed_dim)
        
        # 3. 融合文本和时间条件
        # 文本编码取平均作为全局特征
        text_avg = text_encoded.mean(dim=1)  # (batch, text_embed_dim)
        # 拼接文本和时间嵌入
        cond = torch.cat([text_avg, t_emb], dim=-1)  # (batch, text_embed_dim + time_embed_dim)
        cond = self.condition_fuse(cond)  # (batch, seq_embed_dim)
        # 扩展到序列长度（与xt_emb对齐）
        cond = cond.unsqueeze(1).repeat(1, xt_emb.size(1), 1)  # (batch, seq_len, seq_embed_dim)
        
        # 4. 融合加噪嵌入与条件
        x_cond = xt_emb + cond  # (batch, seq_len, seq_embed_dim)
        
        # 5. Transformer解码（用文本编码作为memory增强条件）
        output_emb = self.decoder(tgt=x_cond, memory=text_encoded)  # (batch, seq_len, seq_embed_dim)
        
        return output_emb

    def p_loss(self, x0, text_condition):
        """计算扩散损失（基于嵌入空间的MSE）"""
        batch_size = x0.size(0)
        device = x0.device
        
        # 1. 随机采样时间步
        t = torch.randint(0, self.config.T, (batch_size,), device=device)
        
        # 2. 原始序列嵌入
        x0_emb = self.seq_embed(x0)  # (batch, seq_len, seq_embed_dim)
        
        # 3. 前向扩散：在嵌入空间添加噪声
        noise = torch.randn_like(x0_emb)  # 高斯噪声
        # 提取当前时间步的系数
        sqrt_alpha_bar = self.sqrt_alphas_bar[t][:, None, None]  # (batch, 1, 1)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_bar[t][:, None, None]  # (batch, 1, 1)
        # 计算加噪嵌入xt_emb
        xt_emb = sqrt_alpha_bar * x0_emb + sqrt_one_minus_alpha_bar * noise  # (batch, seq_len, seq_embed_dim)
        
        # 4. 模型预测原始嵌入
        pred_x0_emb = self.forward_emb(xt_emb, t, text_condition)  # (batch, seq_len, seq_embed_dim)
        
        # 5. 计算MSE损失（忽略填充位置）
        mask = (x0 != self.pad_idx).float().unsqueeze(-1)  # (batch, seq_len, 1)：填充位置为0
        loss = F.mse_loss(pred_x0_emb * mask, x0_emb * mask, reduction='mean')
        
        return loss

    @torch.no_grad()
    def p_sample(self, text_condition):
        """反向扩散采样：从噪声嵌入生成蛋白质序列"""
        device = next(self.parameters()).device
        batch_size = text_condition.size(0)
        seq_len = self.config.max_seq_len
        
        # 1. 初始噪声嵌入（从标准正态分布采样）
        xt_emb = torch.randn((batch_size, seq_len, self.config.seq_embed_dim), device=device)
        
        # 2. 逐步去噪
        for t_step in reversed(range(self.config.T)):
            t = torch.full((batch_size,), t_step, device=device, dtype=torch.long)  # 当前时间步
            
            # 预测原始嵌入
            pred_x0_emb = self.forward_emb(xt_emb, t, text_condition)  # (batch, seq_len, seq_embed_dim)
            
            # 计算反向扩散系数
            alpha = self.alphas[t_step]
            alpha_bar = self.alphas_bar[t_step]
            beta = self.betas[t_step]
            
            # 采样噪声（最后一步用0噪声）
            if t_step > 0:
                noise = torch.randn_like(xt_emb)
            else:
                noise = torch.zeros_like(xt_emb)
            
            # 反向更新公式：xt-1 = (1/sqrt(alpha)) * (xt - (1-alpha)/sqrt(1-alpha_bar) * pred_x0) + sqrt(beta) * noise
            xt_emb = (1 / torch.sqrt(alpha)) * (
                xt_emb - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * pred_x0_emb
            ) + torch.sqrt(beta) * noise
        
        # 3. 将最终嵌入映射到氨基酸索引
        logits = self.output_layer(xt_emb)  # (batch, seq_len, vocab_size)
        return torch.argmax(logits, dim=-1)  # 取概率最大的索引

# 数据增强函数
def augment_protein_data(sequences, descriptions, config):
    """通过轻微变异增强蛋白质序列数据"""
    augmented_seqs = []
    augmented_descs = []
    
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    
    for seq, desc in zip(sequences, descriptions):
        # 添加原始序列
        augmented_seqs.append(seq)
        augmented_descs.append(desc)
        
        # 创建变异序列
        for i in range(config.augmentation_factor):
            seq_list = list(seq)
            
            # 随机替换氨基酸
            for j in range(len(seq_list)):
                if random.random() < config.mutation_rate:
                    current_aa = seq_list[j]
                    possible_replacements = [aa for aa in amino_acids if aa != current_aa]
                    if possible_replacements:
                        seq_list[j] = random.choice(possible_replacements)
            
            augmented_seqs.append(''.join(seq_list))
            # 在描述中添加变体标记
            augmented_descs.append(f"{desc} (variant {i+1})")
    
    return augmented_seqs, augmented_descs

# 自定义数据集（包含数据增强）
class ProteinDataset(Dataset):
    def __init__(self, sequences, descriptions, seq_tokenizer, build_vocab=False, augment=False, config=None):
        self.sequences = sequences
        self.descriptions = descriptions 
        self.seq_tokenizer = seq_tokenizer
        self.augment = augment
        self.config = config
        
        # 如果需要数据增强
        if augment and config:
            print("Applying data augmentation...")
            self.sequences, self.descriptions = augment_protein_data(sequences, descriptions, config)
            print(f"Augmented dataset size: {len(self.sequences)}")
        
        # 构建单词词汇表
        if build_vocab:
            self.word_counts = self.build_word_counts(self.descriptions)
            self.text_tokenizer = WordTokenizer(self.word_counts)
        else:
            self.text_tokenizer = None
    
    def build_word_counts(self, descriptions):
        """从描述文本中统计词频"""
        word_counts = Counter()
        for desc in descriptions:
            # 简单分词（实际使用应与WordTokenizer相同的分词逻辑）
            words = re.findall(r"\b[\w'-]+\b|[^\w\s]", desc.lower())
            word_counts.update(words)
        return word_counts
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        seq = self.sequences[idx]
        desc = self.descriptions[idx]
        
        seq_indices = self.seq_tokenizer.encode(seq)
        desc_indices = self.text_tokenizer.encode(desc)
        
        return {
            "sequence": torch.tensor(seq_indices, dtype=torch.long),
            "description": torch.tensor(desc_indices, dtype=torch.long)
        }

# 训练函数（带验证集早停）
def train(model, train_loader, val_loader, optimizer, device, epochs=50, patience=5):
    model.train()
    best_val_loss = float('inf')
    counter = 0  # 早停计数器
    
    for epoch in range(epochs):
        # 训练阶段
        train_loss = 0.0
        model.train()
        for batch in train_loader:
            sequences = batch["sequence"].to(device)
            descriptions = batch["description"].to(device)
            
            optimizer.zero_grad()
            loss = model.p_loss(sequences, descriptions)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * sequences.size(0)  # 累计总损失
        avg_train_loss = train_loss / len(train_loader.dataset)
        
        # 验证阶段
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                sequences = batch["sequence"].to(device)
                descriptions = batch["description"].to(device)
                
                loss = model.p_loss(sequences, descriptions)
                val_loss += loss.item() * sequences.size(0)
        avg_val_loss = val_loss / len(val_loader.dataset)
        
        # 打印日志
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        
        # 早停判断（基于验证损失）
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            counter = 0
            torch.save(model.state_dict(), "best_protein_model.pth")  # 保存最优模型
            print("Saved best model (val loss improved)")
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
                break
    
    # 加载最优模型
    model.load_state_dict(torch.load("best_protein_model.pth"))
    return model

# 生成函数（根据文本描述生成蛋白质序列）
def generate_sequence(model, description, device):
    model.eval()
    
    # 编码文本描述
    desc_indices = model.text_tokenizer.encode(description)
    desc_tensor = torch.tensor([desc_indices], dtype=torch.long).to(device)  # (1, max_text_len)
    
    # 采样生成序列索引
    with torch.no_grad():
        generated_indices = model.p_sample(desc_tensor)  # (1, max_seq_len)
    
    # 解码为氨基酸序列
    return model.seq_tokenizer.decode(generated_indices[0].cpu().numpy())

# 主程序
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 初始化序列分词器
    seq_tokenizer = SequenceTokenizer()
    print(f"Sequence vocabulary size: {seq_tokenizer.vocab_size}")
    
    # 加载数据
    tsv_path = "uniprot-data.tsv"
    try:
        df = pd.read_csv(tsv_path, sep='\t')
        print(f"Loaded TSV with columns: {', '.join(df.columns)}")
    except Exception as e:
        print(f"Error loading TSV: {e}")
        # 创建示例数据
        sample_data = {
            "Sequence": [
                "MAGLRGLRVGAALAGLAVLGCAVALAVGGGQASFTSQQASALATPGGGQASFTSQQAT",
                "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPR",
                "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKT",
                "MGLLCSWSRHCSLHGLGRSAGALRRGPGGPGPLLGLAVLGLSGPGAPGLQALQALRGG"
            ],
            "Protein names": [
                "Hypothetical protein OS=Homo sapiens",
                "Transcription factor AP-1 OS=Homo sapiens",
                "Insulin OS=Homo sapiens",
                "Collagen alpha-1(I) chain OS=Homo sapiens"
            ]
        }
        df = pd.DataFrame(sample_data)
        df.to_csv("uniprot-data.tsv", sep='\t', index=False)
        print("Created sample dataset: uniprot-data.tsv")
    
    # 提取序列和描述
    try:
        seq_col = "Sequence"
        desc_col = "Protein names"
        sequences = df[seq_col].dropna().tolist()
        descriptions = df[desc_col].dropna().tolist()
        
        min_len = min(len(sequences), len(descriptions))
        sequences = sequences[:min_len]
        descriptions = descriptions[:min_len]
        print(f"Extracted {len(sequences)} raw sequence-description pairs")
    except KeyError as e:
        print(f"Missing column in TSV: {e}")
        exit(1)
    
    # 过滤短序列
    filtered_sequences = []
    filtered_descriptions = []
    for seq, desc in zip(sequences, descriptions):
        if 20 < len(seq) <= config.max_seq_len:
            filtered_sequences.append(seq)
            filtered_descriptions.append(desc)
    print(f"Filtered to {len(filtered_sequences)} valid pairs (length 20~{config.max_seq_len})")
    
    # 拆分训练集和验证集
    train_seqs, val_seqs, train_descs, val_descs = train_test_split(
        filtered_sequences, 
        filtered_descriptions, 
        test_size=config.val_split,
        random_state=42
    )
    
    # 创建数据集（在训练集上构建词汇表并应用数据增强）
    train_dataset = ProteinDataset(
        train_seqs, train_descs, seq_tokenizer, 
        build_vocab=True, augment=True, config=config
    )
    text_tokenizer = train_dataset.text_tokenizer
    print(f"\nText vocabulary size: {text_tokenizer.vocab_size}")
    print(f"Top 10 words: {text_tokenizer.vocab[4:14]}")  # 跳过特殊标记
    
    # 验证集使用相同的词汇表，但不应用数据增强
    val_dataset = ProteinDataset(val_seqs, val_descs, seq_tokenizer)
    val_dataset.text_tokenizer = text_tokenizer
    
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
    # 初始化模型和优化器
    model = DiffusionModel(config, text_tokenizer, seq_tokenizer).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), 
        lr=config.lr,
        weight_decay=config.weight_decay
    )
    
    # 训练模型
    print("\nStarting training...")
    model = train(model, train_loader, val_loader, optimizer, device, epochs=50, patience=5)
    
    # 保存完整模型
    torch.save({
        'model_state_dict': model.state_dict(),
        'text_tokenizer': text_tokenizer,
        'seq_tokenizer': seq_tokenizer,
        'config': config
    }, "protein_diffusion_model_final.pth")
    print("\nModel saved to 'protein_diffusion_model_final.pth'")
    
    # 生成示例序列
    print("\nGenerating sample sequences...")
    sample_descriptions = [
        "DNA binding protein involved in transcription",
        "Enzyme with catalytic activity for hydrolysis",
        "Membrane transport protein for ions"
    ]
    
    for desc in sample_descriptions:
        generated_seq = generate_sequence(model, desc, device)
        print(f"\nDescription: {desc}")
        print(f"Generated sequence: {generated_seq[:100]}...")
        print(f"Sequence length: {len(generated_seq)}")

Using device: cpu
Sequence vocabulary size: 21
Loaded TSV with columns: Entry, Protein names, Sequence
Extracted 20420 raw sequence-description pairs
Filtered to 5260 valid pairs (length 20~256)
Applying data augmentation...
Augmented dataset size: 18936

Text vocabulary size: 5000
Top 10 words: ['(', ')', 'protein', 'variant', '1', '2', '.', '3', 'subunit', 'ec']

Starting training...
Epoch 1/50
Train Loss: 0.5674 | Val Loss: 0.3721
Saved best model (val loss improved)
Epoch 2/50
Train Loss: 0.3799 | Val Loss: 0.3217
Saved best model (val loss improved)
Epoch 3/50
Train Loss: 0.3206 | Val Loss: 0.2737
Saved best model (val loss improved)
Epoch 4/50
Train Loss: 0.2798 | Val Loss: 0.2486
Saved best model (val loss improved)
Epoch 5/50
Train Loss: 0.2481 | Val Loss: 0.2156
Saved best model (val loss improved)
Epoch 6/50
Train Loss: 0.2150 | Val Loss: 0.1833
Saved best model (val loss improved)
Epoch 7/50
Train Loss: 0.1933 | Val Loss: 0.1584
Saved best model (val loss improved)
Epoch 8/5

  model.load_state_dict(torch.load("best_protein_model.pth"))



Model saved to 'protein_diffusion_model_final.pth'

Generating sample sequences...

Description: DNA binding protein involved in transcription
Generated sequence: QWMHNPMWWAMYCYAIYMWAFRMHWWYHWMDCRFEILVYYNWMHAERHEIWREFNEFRFKPAMFTHANEMEVMERAPYDMMEGYCEGYFRCLQMRWQQLG...
Sequence length: 237

Description: Enzyme with catalytic activity for hydrolysis
Generated sequence: ASRSNYIFGEKRSQETDDYNHADHGEEHAMLHYYKSFYKTDNWGMEEWGPPHEMAYDAHLANCWHWWTYFGEIIYEFYAYYLDLQFPQMFWTQNAHMIHR...
Sequence length: 244

Description: Membrane transport protein for ions
Generated sequence: KCGPRLEHPSHKCQWWYEEINDCNPCIFHEERPYHMKMMQENYRRFMYVESDSWHQCLFGMGDGNYYFGYYCDSAGVWGEGMRVTRYKFWWRFYGHFRSK...
Sequence length: 244


In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import math
# from torch.utils.data import Dataset, DataLoader
# import numpy as np
# import pandas as pd
# import re
# from collections import Counter
# from sklearn.model_selection import train_test_split

In [None]:
# # 超参数配置 - 更新文本相关参数
# class Config:
#     # 序列相关参数
#     seq_vocab_size = 21  # 20种氨基酸 + 1个填充位置
#     max_seq_len = 256
#     seq_embed_dim = 128
    
#     # 文本相关参数（单词级）
#     text_vocab_size = 5000  # 单词词汇表大小
#     max_text_len = 128     # 单词序列长度（更短）
#     text_embed_dim = 128
    
#     # 其他共享参数
#     time_embed_dim = 64
#     hidden_dim = 256
#     num_heads = 8
#     num_layers = 6
#     dropout = 0.3
#     beta_start = 1e-4
#     beta_end = 0.02
#     T = 1000
#     batch_size = 32
#     lr = 1e-4
#     weight_decay = 1e-5
#     val_split = 0.1

# config = Config()

In [None]:
# # 文本分词器（单词级）
# class WordTokenizer:
#     def __init__(self, word_counts=None):
#         # 特殊标记
#         self.special_tokens = ["<pad>", "<unk>", "<sos>", "<eos>"]
#         self.word2idx = {}
#         self.idx2word = {}
        
#         # 构建词汇表
#         self.vocab = self.special_tokens.copy()
        
#         # 添加高频词
#         if word_counts:
#             sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
#             for word, _ in sorted_words:
#                 if len(self.vocab) < config.text_vocab_size:
#                     self.vocab.append(word)
#                 else:
#                     break
        
#         # 创建映射
#         for i, word in enumerate(self.vocab):
#             self.word2idx[word] = i
#             self.idx2word[i] = word
        
#         self.vocab_size = len(self.vocab)
    
#     def tokenize(self, text):
#         """将文本分割为单词/标点序列"""
#         # 匹配单词（含连字符）和标点
#         return re.findall(r"\b[\w'-]+\b|[^\w\s]", text.lower())
    
#     def encode(self, text, max_len=None):
#         max_len = max_len or config.max_text_len
#         tokens = self.tokenize(text)
        
#         # 添加起始和结束标记
#         tokens = ["<sos>"] + tokens[:max_len-2] + ["<eos>"]
        
#         # 转换为索引
#         encoded = []
#         for token in tokens:
#             if token in self.word2idx:
#                 encoded.append(self.word2idx[token])
#             else:
#                 encoded.append(self.word2idx["<unk>"])
        
#         # 填充或截断
#         if len(encoded) < max_len:
#             encoded += [self.word2idx["<pad>"]] * (max_len - len(encoded))
#         else:
#             encoded = encoded[:max_len-1] + [self.word2idx["<eos>"]]  # 确保结束标记
        
#         return encoded
    
#     def decode(self, indices):
#         """将索引序列解码为文本"""
#         words = []
#         for idx in indices:
#             word = self.idx2word.get(idx, "<unk>")
#             if word in ["<pad>", "<sos>", "<eos>"]:
#                 continue
#             words.append(word)
#         return ' '.join(words)

In [None]:
# # 蛋白质序列分词器（保持不变）
# class SequenceTokenizer:
#     def __init__(self):
#         self.aa_list = "ACDEFGHIKLMNPQRSTVWY"  # 20种标准氨基酸
#         self.pad_token = "<pad>"
        
#         self.aa2idx = {self.pad_token: 0}
#         for aa in self.aa_list:
#             self.aa2idx[aa] = len(self.aa2idx)
        
#         self.idx2aa = {v: k for k, v in self.aa2idx.items()}
#         self.vocab_size = len(self.aa2idx)
    
#     def encode(self, sequence, max_len=None):
#         max_len = max_len or config.max_seq_len
#         encoded = []
#         for aa in sequence[:max_len]:
#             encoded.append(self.aa2idx.get(aa, self.aa2idx[self.pad_token]))
        
#         if len(encoded) < max_len:
#             encoded += [self.aa2idx[self.pad_token]] * (max_len - len(encoded))
#         return encoded
    
#     def decode(self, indices):
#         return ''.join([self.idx2aa.get(idx, 'X') for idx in indices if idx != self.aa2idx[self.pad_token]])

In [None]:
# # 时间步嵌入（正弦位置编码，用于扩散模型的时间步表示）
# class SinusoidalPositionEmbeddings(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.dim = dim

#     def forward(self, t):
#         device = t.device
#         half_dim = self.dim // 2
#         # 计算频率
#         embeddings = math.log(10000) / (half_dim - 1)
#         embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
#         # 时间步与频率相乘
#         embeddings = t[:, None] * embeddings[None, :]
#         # 拼接正弦和余弦
#         embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
#         return embeddings

In [None]:
# # 文本编码器（更新为处理单词嵌入）
# class TextEncoder(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         # 单词嵌入层
#         self.embedding = nn.Embedding(config.text_vocab_size, config.text_embed_dim, padding_idx=0)
        
#         # Transformer编码器
#         encoder_layer = nn.TransformerEncoderLayer(
#             d_model=config.text_embed_dim,
#             nhead=config.num_heads,
#             dim_feedforward=config.hidden_dim,
#             dropout=config.dropout,
#             batch_first=True
#         )
#         self.encoder = nn.TransformerEncoder(encoder_layer, config.num_layers)
        
#     def forward(self, text):
#         embedded = self.embedding(text)
#         encoded = self.encoder(embedded)
#         return encoded

In [None]:
# # 蛋白质扩散模型（核心模型：基于文本条件生成蛋白质序列）
# class DiffusionModel(nn.Module):
#     def __init__(self, config, text_tokenizer, seq_tokenizer):
#         super().__init__()
#         self.config = config
#         self.text_tokenizer = text_tokenizer
#         self.seq_tokenizer = seq_tokenizer
#         self.pad_idx = seq_tokenizer.aa2idx["<pad>"]  # 序列填充索引
        
#         # 文本条件编码器
#         self.text_encoder = TextEncoder(config)
        
#         # 时间步嵌入
#         self.time_embed = nn.Sequential(
#             SinusoidalPositionEmbeddings(config.time_embed_dim),
#             nn.Linear(config.time_embed_dim, config.time_embed_dim),
#             nn.GELU()
#         )
        
#         # 序列嵌入层（使用蛋白质词汇表）
#         self.seq_embed = nn.Embedding(
#             seq_tokenizer.vocab_size, 
#             config.seq_embed_dim,
#             padding_idx=self.pad_idx
#         )
        
#         # 条件融合：融合文本编码和时间嵌入
#         self.condition_fuse = nn.Sequential(
#             nn.Linear(config.text_embed_dim + config.time_embed_dim, config.hidden_dim),
#             nn.GELU(),
#             nn.Dropout(config.dropout),
#             nn.Linear(config.hidden_dim, config.seq_embed_dim)  # 映射到序列嵌入维度
#         )
        
#         # Transformer解码器（用于去噪）
#         decoder_layer = nn.TransformerDecoderLayer(
#             d_model=config.seq_embed_dim,
#             nhead=config.num_heads,
#             dim_feedforward=config.hidden_dim,
#             dropout=config.dropout,
#             batch_first=True
#         )
#         self.decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)
        
#         # 输出层：将嵌入映射到氨基酸词汇表
#         self.output_layer = nn.Linear(config.seq_embed_dim, seq_tokenizer.vocab_size)
        
#         # 初始化扩散参数（beta/alpha系列）
#         self.register_buffer('betas', torch.linspace(config.beta_start, config.beta_end, config.T))
#         self.register_buffer('alphas', 1. - self.betas)
#         self.register_buffer('alphas_bar', torch.cumprod(self.alphas, dim=0))  # 累积乘积
#         self.register_buffer('sqrt_alphas_bar', torch.sqrt(self.alphas_bar))
#         self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - self.alphas_bar))
#     def forward_emb(self, xt_emb, t, text_condition):
#         """处理加噪后的嵌入，预测原始嵌入（核心前向逻辑）"""
#         # 1. 时间步嵌入
#         t_emb = self.time_embed(t)  # (batch, time_embed_dim)
        
#         # 2. 文本条件编码
#         text_encoded = self.text_encoder(text_condition)  # (batch, text_len, text_embed_dim)
        
#         # 3. 融合文本和时间条件
#         # 文本编码取平均作为全局特征
#         text_avg = text_encoded.mean(dim=1)  # (batch, text_embed_dim)
#         # 拼接文本和时间嵌入
#         cond = torch.cat([text_avg, t_emb], dim=-1)  # (batch, text_embed_dim + time_embed_dim)
#         cond = self.condition_fuse(cond)  # (batch, seq_embed_dim)
#         # 扩展到序列长度（与xt_emb对齐）
#         cond = cond.unsqueeze(1).repeat(1, xt_emb.size(1), 1)  # (batch, seq_len, seq_embed_dim)
        
#         # 4. 融合加噪嵌入与条件
#         x_cond = xt_emb + cond  # (batch, seq_len, seq_embed_dim)
        
#         # 5. Transformer解码（用文本编码作为memory增强条件）
#         output_emb = self.decoder(tgt=x_cond, memory=text_encoded)  # (batch, seq_len, seq_embed_dim)
        
#         return output_emb


#     def p_loss(self, x0, text_condition):
#         """计算扩散损失（基于嵌入空间的MSE）"""
#         batch_size = x0.size(0)
#         device = x0.device
        
#         # 1. 随机采样时间步
#         t = torch.randint(0, self.config.T, (batch_size,), device=device)
        
#         # 2. 原始序列嵌入
#         x0_emb = self.seq_embed(x0)  # (batch, seq_len, seq_embed_dim)
        
#         # 3. 前向扩散：在嵌入空间添加噪声
#         noise = torch.randn_like(x0_emb)  # 高斯噪声
#         # 提取当前时间步的系数
#         sqrt_alpha_bar = self.sqrt_alphas_bar[t][:, None, None]  # (batch, 1, 1)
#         sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_bar[t][:, None, None]  # (batch, 1, 1)
#         # 计算加噪嵌入xt_emb
#         xt_emb = sqrt_alpha_bar * x0_emb + sqrt_one_minus_alpha_bar * noise  # (batch, seq_len, seq_embed_dim)
        
#         # 4. 模型预测原始嵌入
#         pred_x0_emb = self.forward_emb(xt_emb, t, text_condition)  # (batch, seq_len, seq_embed_dim)
        
#         # 5. 计算MSE损失（忽略填充位置）
#         mask = (x0 != self.pad_idx).float().unsqueeze(-1)  # (batch, seq_len, 1)：填充位置为0
#         loss = F.mse_loss(pred_x0_emb * mask, x0_emb * mask, reduction='mean')
        
#         return loss


#     @torch.no_grad()
#     def p_sample(self, text_condition):
#         """反向扩散采样：从噪声嵌入生成蛋白质序列"""
#         device = next(self.parameters()).device
#         batch_size = text_condition.size(0)
#         seq_len = self.config.max_seq_len
        
#         # 1. 初始噪声嵌入（从标准正态分布采样）
#         xt_emb = torch.randn((batch_size, seq_len, self.config.seq_embed_dim), device=device)
        
#         # 2. 逐步去噪
#         for t_step in reversed(range(self.config.T)):
#             t = torch.full((batch_size,), t_step, device=device, dtype=torch.long)  # 当前时间步
            
#             # 预测原始嵌入
#             pred_x0_emb = self.forward_emb(xt_emb, t, text_condition)  # (batch, seq_len, seq_embed_dim)
            
#             # 计算反向扩散系数
#             alpha = self.alphas[t_step]
#             alpha_bar = self.alphas_bar[t_step]
#             beta = self.betas[t_step]
            
#             # 采样噪声（最后一步用0噪声）
#             if t_step > 0:
#                 noise = torch.randn_like(xt_emb)
#             else:
#                 noise = torch.zeros_like(xt_emb)
            
#             # 反向更新公式：xt-1 = (1/sqrt(alpha)) * (xt - (1-alpha)/sqrt(1-alpha_bar) * pred_x0) + sqrt(beta) * noise
#             xt_emb = (1 / torch.sqrt(alpha)) * (
#                 xt_emb - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * pred_x0_emb
#             ) + torch.sqrt(beta) * noise
        
#         # 3. 将最终嵌入映射到氨基酸索引
#         logits = self.output_layer(xt_emb)  # (batch, seq_len, vocab_size)
#         return torch.argmax(logits, dim=-1)  # 取概率最大的索引

In [None]:
# # 自定义数据集（添加词汇表构建）
# class ProteinDataset(Dataset):
#     def __init__(self, sequences, descriptions, seq_tokenizer, build_vocab=False):
#         self.sequences = sequences
#         self.descriptions = descriptions 
#         self.seq_tokenizer = seq_tokenizer
        
#         # 构建单词词汇表
#         if build_vocab:
#             self.word_counts = self.build_word_counts(descriptions)
#             self.text_tokenizer = WordTokenizer(self.word_counts)
#         else:
#             self.text_tokenizer = None
    
#     def build_word_counts(self, descriptions):
#         """从描述文本中统计词频"""
#         word_counts = Counter()
#         for desc in descriptions:
#             # 简单分词（实际使用应与WordTokenizer相同的分词逻辑）
#             words = re.findall(r"\b[\w'-]+\b|[^\w\s]", desc.lower())
#             word_counts.update(words)
#         return word_counts
    
#     def __len__(self):
#         return len(self.sequences)
    
#     def __getitem__(self, idx):
#         seq = self.sequences[idx]
#         desc = self.descriptions[idx]
        
#         seq_indices = self.seq_tokenizer.encode(seq)
#         desc_indices = self.text_tokenizer.encode(desc)
        
#         return {
#             "sequence": torch.tensor(seq_indices, dtype=torch.long),
#             "description": torch.tensor(desc_indices, dtype=torch.long)
#         }

In [None]:
# # # 5. 训练函数
# # def train(model, dataloader, optimizer, device, epochs=10):
# #     model.train()
# #     for epoch in range(epochs):
# #         total_loss = 0
# #         for batch in dataloader:
# #             sequences = batch["sequence"].to(device)
# #             descriptions = batch["description"].to(device)
            
# #             optimizer.zero_grad()
# #             loss = model.p_loss(sequences, descriptions)
# #             loss.backward()
# #             optimizer.step()
            
# #             total_loss += loss.item()
        
# #         print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(dataloader):.4f}")
# # 训练函数（带验证集早停）
# def train(model, train_loader, val_loader, optimizer, device, epochs=50, patience=5):
#     model.train()
#     best_val_loss = float('inf')
#     counter = 0  # 早停计数器
    
#     for epoch in range(epochs):
#         # 训练阶段
#         train_loss = 0.0
#         model.train()
#         for batch in train_loader:
#             sequences = batch["sequence"].to(device)
#             descriptions = batch["description"].to(device)
            
#             optimizer.zero_grad()
#             loss = model.p_loss(sequences, descriptions)
#             loss.backward()
#             optimizer.step()
            
#             train_loss += loss.item() * sequences.size(0)  # 累计总损失
#         avg_train_loss = train_loss / len(train_loader.dataset)
        
#         # 验证阶段
#         val_loss = 0.0
#         model.eval()
#         with torch.no_grad():
#             for batch in val_loader:
#                 sequences = batch["sequence"].to(device)
#                 descriptions = batch["description"].to(device)
                
#                 loss = model.p_loss(sequences, descriptions)
#                 val_loss += loss.item() * sequences.size(0)
#         avg_val_loss = val_loss / len(val_loader.dataset)
        
#         # 打印日志
#         print(f"Epoch {epoch+1}/{epochs}")
#         print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        
#         # 早停判断（基于验证损失）
#         if avg_val_loss < best_val_loss:
#             best_val_loss = avg_val_loss
#             counter = 0
#             torch.save(model.state_dict(), "best_protein_model.pth")  # 保存最优模型
#             print("Saved best model (val loss improved)")
#         else:
#             counter += 1
#             if counter >= patience:
#                 print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
#                 break
    
#     # 加载最优模型
#     model.load_state_dict(torch.load("best_protein_model.pth"))
#     return model

In [None]:
# # 生成函数（根据文本描述生成蛋白质序列）
# def generate_sequence(model, description, device):
#     model.eval()
    
#     # 编码文本描述
#     desc_indices = model.text_tokenizer.encode(description)
#     desc_tensor = torch.tensor([desc_indices], dtype=torch.long).to(device)  # (1, max_text_len)
    
#     # 采样生成序列索引
#     with torch.no_grad():
#         generated_indices = model.p_sample(desc_tensor)  # (1, max_seq_len)
    
#     # 解码为氨基酸序列
#     return model.seq_tokenizer.decode(generated_indices[0].cpu().numpy())

In [None]:
# # 主程序（更新数据加载逻辑）
# if __name__ == "__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     print(f"Using device: {device}")
    
#     # 初始化序列分词器
#     seq_tokenizer = SequenceTokenizer()
#     print(f"Sequence vocabulary size: {seq_tokenizer.vocab_size}")
    
#     # 加载数据
#     tsv_path = "uniprot-data.tsv"
#     try:
#         df = pd.read_csv(tsv_path, sep='\t')
#         print(f"Loaded TSV with columns: {', '.join(df.columns)}")
#     except Exception as e:
#         print(f"Error loading TSV: {e}")
#         exit(1)
    
#     # 提取序列和描述
#     try:
#         seq_col = "Sequence"
#         desc_col = "Protein names"
#         sequences = df[seq_col].dropna().tolist()
#         descriptions = df[desc_col].dropna().tolist()
        
#         min_len = min(len(sequences), len(descriptions))
#         sequences = sequences[:min_len]
#         descriptions = descriptions[:min_len]
#         print(f"Extracted {len(sequences)} raw sequence-description pairs")
#     except KeyError as e:
#         print(f"Missing column in TSV: {e}")
#         exit(1)
    
#     # 过滤短序列
#     filtered_sequences = []
#     filtered_descriptions = []
#     for seq, desc in zip(sequences, descriptions):
#         if 20 < len(seq) <= config.max_seq_len:
#             filtered_sequences.append(seq)
#             filtered_descriptions.append(desc)
#     print(f"Filtered to {len(filtered_sequences)} valid pairs (length 20~{config.max_seq_len})")
    
#     # 拆分训练集和验证集
#     train_seqs, val_seqs, train_descs, val_descs = train_test_split(
#         filtered_sequences, 
#         filtered_descriptions, 
#         test_size=config.val_split,
#         random_state=42
#     )
    
#     # 创建数据集（在训练集上构建词汇表）
#     train_dataset = ProteinDataset(train_seqs, train_descs, seq_tokenizer, build_vocab=True)
#     text_tokenizer = train_dataset.text_tokenizer
#     print(f"\nText vocabulary size: {text_tokenizer.vocab_size}")
#     print(f"Top 10 words: {text_tokenizer.vocab[4:14]}")  # 跳过特殊标记
    
#     # 验证集使用相同的词汇表
#     val_dataset = ProteinDataset(val_seqs, val_descs, seq_tokenizer)
#     val_dataset.text_tokenizer = text_tokenizer
    
#     train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True)
#     val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
#     # 初始化模型和优化器
#     model = DiffusionModel(config, text_tokenizer, seq_tokenizer).to(device)
#     optimizer = torch.optim.Adam(
#         model.parameters(), 
#         lr=config.lr,
#         weight_decay=config.weight_decay
#     )
    
#     # 训练模型
#     print("\nStarting training...")
#     model = train(model, train_loader, val_loader, optimizer, device, epochs=50, patience=5)
    
#     # 保存完整模型
#     torch.save({
#         'model_state_dict': model.state_dict(),
#         'text_tokenizer': text_tokenizer,
#         'seq_tokenizer': seq_tokenizer,
#         'config': config
#     }, "protein_diffusion_model_final.pth")
#     print("\nModel saved to 'protein_diffusion_model_final.pth'")
    
#     # 生成示例序列
#     print("\nGenerating sample sequences...")
#     sample_descriptions = [
#         "DNA binding protein involved in transcription",
#         "Enzyme with catalytic activity for hydrolysis",
#         "Membrane transport protein for ions"
#     ]
    
#     for desc in sample_descriptions:
#         generated_seq = generate_sequence(model, desc, device)
#         print(f"\nDescription: {desc}")
#         print(f"Generated sequence: {generated_seq[:100]}...")
#         print(f"Sequence length: {len(generated_seq)}")