這個 notebook 用來測試：針對真實訓練數據，我們應該採用哪些 schedule 參數以及多少個 diffusion step。

In [None]:
from polydiff.diffusion.forward import ExcitedStateDiffusion
from polydiff.diffusion.schedules import ExponentialSchedule
from polydiff.model import ChemBERTTokenizer

In [None]:
# Load the SMILES strings: one entry per line, blank lines dropped.
data_file = "/Users/wang-work/polydiffusion/data/smiles.txt"
with open(data_file, "r") as f:
    # Strip each line and keep only the ones that still have content.
    data = [smiles for smiles in map(str.strip, f) if smiles]

data

In [None]:
def initialize_diffusion(beta_start=1e-4, beta_end=0.02, num_timesteps=1000, vocab_size=100, mask_token_id=99, pad_token_id=0):
    """
    Build an ExcitedStateDiffusion backed by an ExponentialSchedule.

    A short summary of the configuration is printed to stdout once the
    instance has been constructed.

    Args:
        beta_start: starting beta value, forwarded to ExponentialSchedule
        beta_end: final beta value, forwarded to ExponentialSchedule
        num_timesteps: number of diffusion steps, forwarded to ExponentialSchedule
        vocab_size: vocabulary size, forwarded to ExcitedStateDiffusion
        mask_token_id: id of the mask token, forwarded to ExcitedStateDiffusion
        pad_token_id: id of the pad token, forwarded to ExcitedStateDiffusion

    Returns:
        ExcitedStateDiffusion: the newly constructed diffusion instance
    """
    # The schedule is only needed by the diffusion object, so build it inline.
    diffusion = ExcitedStateDiffusion(
        schedule=ExponentialSchedule(
            num_timesteps=num_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        ),
        vocab_size=vocab_size,
        mask_token_id=mask_token_id,
        pad_token_id=pad_token_id,
    )

    # Report the configuration that was used, one line per setting.
    summary_lines = (
        f"ExcitedStateDiffusion已初始化！",
        f"- Beta範圍: {beta_start} -> {beta_end}",
        f"- 時間步數: {num_timesteps}",
        f"- 詞彙表大小: {vocab_size}",
    )
    for summary_line in summary_lines:
        print(summary_line)

    return diffusion


In [None]:
def select_sequences_by_length(data, shortest_count=20, middle_count=50, longest_count=20, random_seed=42):
    """
    Pick a length-stratified subset of sequences.

    The sequences are ranked by ``len(seq)``. The ``shortest_count`` shortest
    and the ``longest_count`` longest are always taken, and up to
    ``middle_count`` more are sampled at random (reproducibly) from whatever
    lies in between. The three groups never overlap: when ``data`` is too
    small to satisfy both ends, the shortest group is filled first and the
    longest group gets what is left.

    Args:
        data: indexable collection of sequences (anything supporting ``len``)
        shortest_count: how many of the shortest sequences to take
        middle_count: maximum number of sequences sampled from the middle
        longest_count: how many of the longest sequences to take
        random_seed: seed for the middle sample

    Returns:
        list: selected sequences, ordered shortest group, then the middle
        sample sorted by length, then longest group

    Raises:
        ValueError: if any of the three counts is negative
    """
    import random

    if shortest_count < 0 or middle_count < 0 or longest_count < 0:
        raise ValueError(
            "shortest_count, middle_count and longest_count must be non-negative"
        )

    # (index, length) pairs ordered by length; sorted() is stable, so
    # sequences of equal length keep their input order.
    by_length = sorted(
        ((idx, len(seq)) for idx, seq in enumerate(data)),
        key=lambda pair: pair[1],
    )
    total = len(by_length)

    # Clamp both ends so the head and tail slices can never overlap.
    n_shortest = min(shortest_count, total)
    n_longest = min(longest_count, total - n_shortest)
    middle_stop = total - n_longest

    # Explicit positive bounds: a `[-n:]` slice would return the whole list
    # when n == 0.
    shortest_sequences = by_length[:n_shortest]
    middle_range = by_length[n_shortest:middle_stop]
    longest_sequences = by_length[middle_stop:]

    # A private generator yields the same sample as seeding the module-level
    # one, without resetting the global random state for the rest of the program.
    rng = random.Random(random_seed)
    middle_sequences = rng.sample(middle_range, min(middle_count, len(middle_range)))
    middle_sequences.sort(key=lambda pair: pair[1])

    chosen = shortest_sequences + middle_sequences + longest_sequences
    return [data[idx] for idx, _ in chosen]

# Run the selection on the loaded data, report how many sequences were picked, and display them.
selected_sequences = select_sequences_by_length(data)  # relies on `data` having been defined in an earlier cell
print(f"總共選取了 {len(selected_sequences)} 條序列")
selected_sequences


In [None]:
from transformers import AutoTokenizer
import torch
# Pretrained ChemBERTa tokenizer; the function below uses it for its default arguments and to encode sequences.
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
def initialize_and_test_diffusion(sequences,
                                  tokenizer=tokenizer,
                                  beta_start=1e-4,
                                  beta_end=0.02,
                                  num_timesteps=1000,
                                  vocab_size=tokenizer.vocab_size,
                                  mask_token_id=tokenizer.mask_token_id,
                                  pad_token_id=tokenizer.pad_token_id):
    """
    Build a diffusion process and find, per sequence, the first timestep at
    which the forward process masks every token.

    For each sequence the timesteps ``1 .. num_timesteps - 1`` are tried in
    order; the first ``t`` whose noised output consists solely of
    ``mask_token_id`` is recorded.

    Args:
        sequences: iterable of sequence strings to test
        tokenizer: tokenizer used to encode each sequence; its first and last
            encoded tokens are dropped before diffusion
        beta_start: starting beta value for the schedule
        beta_end: final beta value for the schedule
        num_timesteps: number of diffusion steps
        vocab_size: vocabulary size
        mask_token_id: mask token ID (also used for the fully-masked check)
        pad_token_id: pad token ID (also used for the non-pad check)

    Returns:
        tuple: ``(diffusion, results)`` where ``results`` maps each sequence
        string to its first fully-masked timestep, or to ``-99`` when no
        tested timestep masked it completely. Because the mapping is keyed by
        the sequence string, duplicate sequences share a single entry.

    NOTE(review): the defaults for ``vocab_size``, ``mask_token_id`` and
    ``pad_token_id`` are read from the module-level ``tokenizer`` when the
    function is defined, not from the ``tokenizer`` argument; pass them
    explicitly when supplying a different tokenizer.
    """
    # Build the exponential schedule.
    schedule = ExponentialSchedule(
        num_timesteps=num_timesteps,
        beta_start=beta_start,
        beta_end=beta_end
    )

    # Build the diffusion process on top of it.
    diffusion = ExcitedStateDiffusion(
        schedule=schedule,
        vocab_size=vocab_size,
        mask_token_id=mask_token_id,
        pad_token_id=pad_token_id
    )

    print(f"ExcitedStateDiffusion已初始化！")
    print(f"- Beta範圍: {beta_start} -> {beta_end}")
    print(f"- 時間步數: {num_timesteps}")
    print(f"- 詞彙表大小: {vocab_size}")

    # Sentinel stored for sequences that were never fully masked.
    not_fully_masked = -99
    results = {}

    for seq in sequences:
        # Encode the sequence, then drop the first and last tokens (the
        # special tokens added by the tokenizer).
        tokens = tokenizer.encode(seq, add_special_tokens=True, return_tensors="pt")
        tokens = tokens[:, 1:-1]

        # Does not depend on the timestep, so evaluate it once per sequence.
        # Uses the ids passed in, i.e. the same ones the diffusion was built with.
        has_real_tokens = bool((tokens != pad_token_id).any())

        for t in range(1, num_timesteps):
            timesteps = torch.tensor([t])

            # Apply the forward (masking) process at timestep t.
            noisy_tokens, _ = diffusion.forward_mask_process(tokens, timesteps)

            # Fully masked only counts when there was something to mask.
            if has_real_tokens and (noisy_tokens == mask_token_id).all():
                results[seq] = t
                break
        else:
            # No tested timestep masked the whole sequence.
            results[seq] = not_fully_masked

    return diffusion, results

import itertools

import matplotlib.pyplot as plt

# Hyper-parameter grid to sweep over.
beta_start = [0.01, 0.05, 0.1]
beta_end = [0.1, 0.2, 0.3]
num_timesteps = [20, 30, 40, 50]

# Token count per sequence, minus the two special tokens added by the
# tokenizer. It does not depend on the sweep parameters, so compute it once.
token_lengths = [len(tokenizer.encode(seq, add_special_tokens=True)) - 2 for seq in selected_sequences]

# product() iterates in the same order as the equivalent nested loops
# (the last list varies fastest).
for b_start, b_end, n_steps in itertools.product(beta_start, beta_end, num_timesteps):
    print(f"Testing with beta_start={b_start}, beta_end={b_end}, num_timesteps={n_steps}")
    diffusion, results = initialize_and_test_diffusion(
        selected_sequences,
        tokenizer=tokenizer,
        beta_start=b_start,
        beta_end=b_end,
        num_timesteps=n_steps
    )

    # `results` is keyed by sequence string, so look each sequence up
    # explicitly: this keeps x and y aligned (and the same length) even when
    # selected_sequences contains duplicates.
    masked_steps = [results[seq] for seq in selected_sequences]

    # Scatter plot of token length against the first fully-masked step.
    plt.figure(figsize=(10, 6))
    plt.scatter(token_lengths, masked_steps, alpha=0.6)
    plt.title('Token Length vs Diffusion Steps')
    plt.xlabel('Token Length')
    plt.ylabel('Diffusion Steps')
    plt.grid()
    plt.show()