In [2]:
ls

11tokens_BGCs.txt  perplexity_analysis.ipynb


In [8]:
import csv
import torch
from tokenizers import Tokenizer
from transformers import RobertaConfig, RobertaForMaskedLM, PreTrainedTokenizerFast
import random
import numpy as np
from tqdm import tqdm  # 進捗表示のため

# GPUの設定
device = torch.device('mps' if torch.cuda.is_available() else 'cpu')

# ボキャブラリーの読み込み

fin = '/Users/yuikoshitak/Documents/GraduateSchool/Master_Research/prj-evo_design/model_parameter/vocab_pfam_product_hmmer33.json'

#fin = '/working.3/Kawano/vocab_pfam_product_hmmer33_revised.json'

# トークナイザーを定義
tokenizer_obj = Tokenizer.from_file(fin)
tokenizer_obj.add_special_tokens(["MASK"])
tokenizer_obj.add_special_tokens(["[PAD]"])
tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_obj)
tokenizer.mask_token = "[MASK]"
tokenizer.pad_token = "[PAD]"
tokenizer.sep_token = "[SEP]"

# RoBERTaのモデルの設定
config = RobertaConfig(
    vocab_size=19723,
    max_position_embeddings=514,
    hidden_size=1024,
    num_attention_heads=16,
    num_hidden_layers=8,
    type_vocab_size=1
)
model = RobertaForMaskedLM(config).to(device)

# 学習済みパラメータのロード
model.load_state_dict(torch.load('/Users/yuikoshitak/Documents/GraduateSchool/Master_Research/prj-evo_design/model_parameter/RoBERTa_bacteria+fungi_20240918_epoch150.pth', map_location=torch.device('mps')))
#model.load_state_dict(torch.load('/working.3/Kawano/model-parameter/RoBERTa_bacteria+fungi20240918_epoch150.pth', map_location=torch.device('cuda')))

# 入力ファイル
input_file = "11tokens_BGCs.txt"

# トークナイザーの全語彙を取得
vocab = list(tokenizer.get_vocab().keys())

# 結果を保存するリスト
all_results = []

# テキストファイルを一行ずつ処理
with open(input_file, 'r', encoding='utf-8') as f:
    for line in tqdm(f, desc="Processing lines"):
        line = line.strip()
        if not line:
            continue

        # 行のトークンを分割し、3単語目以降を取得
        tokens = line.split()[2:]
        bgc_id = line.split()[0]

        # [SEP]を除いたトークンでターゲットを選定
        candidates = [token for token in tokens if token != "[SEP]"]
        max_probability = -1
        target_token = None
        token_count = len(candidates)

        for idx in range(len(tokens)):
            token = tokens[idx]
            # MASKして予測
            masked_tokens = tokens.copy()
            if masked_tokens[idx] == "[SEP]":
                continue
            #token_idx = masked_tokens.index(token)
            masked_tokens[idx] = "[MASK]"

            tokens_id = tokenizer.convert_tokens_to_ids(masked_tokens)
            tokens_tensor = torch.tensor([tokens_id]).to(device)

            with torch.no_grad():
                outputs = model(tokens_tensor)
                predictions = outputs[0]

            # MASK位置の予測スコアを取得
            softmax_scores = torch.softmax(predictions[0, idx], dim=-1)
            token_id = tokenizer.convert_tokens_to_ids([token])[0]

            if softmax_scores[token_id].item() > max_probability:
                max_probability = softmax_scores[token_id].item()
                target_token = token
                target_idx = idx

        # ターゲットトークンが見つからない場合はスキップ
        if target_token is None:
            continue

        # 行の結果を保存するリスト
        line_results = [bgc_id, target_token]

        for num_replacements in range(0, token_count):
            avg_perplexities = []
        
            for _ in range(10):
                # トークンリストをコピー
                modified_tokens = tokens.copy()
        
                # ランダムに置換
                if num_replacements > 0:
                    available_indices = [i for i in range(len(modified_tokens)) 
                         if i != target_idx and modified_tokens[i] != "[SEP]"]  
                    replace_indices = random.sample(available_indices, num_replacements)
                    for idx in replace_indices:
                        if modified_tokens[idx] != target_token:
                            sampled_token = random.choice(vocab)
                            modified_tokens[idx] = sampled_token

                # ターゲットトークンをMASK
                masked_tokens = modified_tokens.copy()
                masked_tokens[target_idx] = "[MASK]"
                # トークン列をテンソルに変換
                
                tokens_id = tokenizer.convert_tokens_to_ids(masked_tokens)
                tokens_tensor = torch.tensor([tokens_id]).to(device)
        
                # モデルに入力して予測
                with torch.no_grad():
                    outputs = model(tokens_tensor)
                    predictions = outputs[0]
        
                # 各トークンの負の対数確率を計算
                log_probs = []
                for idx, token in enumerate(modified_tokens):
                    token_id = tokenizer.convert_tokens_to_ids([target_token])[0]
                    softmax_scores = torch.softmax(predictions[0, target_idx], dim=-1)
                    token_prob = softmax_scores[token_id].item()
        
                    # 確率がゼロになることを防ぐための処理（log(0)を避ける）
                    token_prob = max(token_prob, 1e-9)
                    log_probs.append(-np.log(token_prob))  # 負の対数確率を保存
        
                # パープレキシティを計算
                perplexity = np.exp(np.mean(log_probs))
                avg_perplexities.append(perplexity)
        
            # 100回の実験の平均パープレキシティを取得
            mean_perplexity = np.mean(avg_perplexities)
            line_results.append(mean_perplexity)

        # 行ごとの結果を追加
        all_results.append(line_results)

# 結果をCSVに保存
output_file = "perplexity_11domains_BGCs.csv"
with open(output_file, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)

    # ヘッダー行を書き込む
    headers = ["BGC_ID", "masked_token"] + [f"Num_Replacements={i}" for i in range(11)]
    writer.writerow(headers)

    # 各行の結果を書き込む
    writer.writerows(all_results)

print("Done")

Processing lines: 3it [00:15,  5.14s/it]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Processing lines: 71it [06:03,  5.11s/it]

Done





In [7]:
available_indices
num_replacements

11

In [10]:
# 入力ファイル
input_file = "11_domains_sets_genome.txt"

# トークナイザーの全語彙を取得
vocab = list(tokenizer.get_vocab().keys())

# 結果を保存するリスト
all_results = []

# テキストファイルを一行ずつ処理
with open(input_file, 'r', encoding='utf-8') as f:
    for line in tqdm(f, desc="Processing lines"):
        line = line.strip()
        if not line:
            continue

        # 行のトークンを分割し、3単語目以降を取得
        tokens = line.split()[0:]
        #bgc_id = line.split()[0]

        # [SEP]を除いたトークンでターゲットを選定
        candidates = [token for token in tokens if token != "[SEP]"]
        max_probability = -1
        target_token = None
        token_count = len(candidates)

        for idx in range(len(tokens)):
            token = tokens[idx]
            # MASKして予測
            masked_tokens = tokens.copy()
            if masked_tokens[idx] == "[SEP]":
                continue
            #token_idx = masked_tokens.index(token)
            masked_tokens[idx] = "[MASK]"

            tokens_id = tokenizer.convert_tokens_to_ids(masked_tokens)
            tokens_tensor = torch.tensor([tokens_id]).to(device)

            with torch.no_grad():
                outputs = model(tokens_tensor)
                predictions = outputs[0]

            # MASK位置の予測スコアを取得
            softmax_scores = torch.softmax(predictions[0, idx], dim=-1)
            token_id = tokenizer.convert_tokens_to_ids([token])[0]

            if softmax_scores[token_id].item() > max_probability:
                max_probability = softmax_scores[token_id].item()
                target_token = token
                target_idx = idx

        # ターゲットトークンが見つからない場合はスキップ
        if target_token is None:
            continue

        # 行の結果を保存するリスト
        line_results = [target_token]

        for num_replacements in range(0, token_count):
            avg_perplexities = []
        
            for _ in range(10):
                # トークンリストをコピー
                modified_tokens = tokens.copy()
        
                # ランダムに置換
                if num_replacements > 0:
                    available_indices = [i for i in range(len(modified_tokens)) 
                         if i != target_idx and modified_tokens[i] != "[SEP]"]  
                    replace_indices = random.sample(available_indices, num_replacements)
                    for idx in replace_indices:
                        if modified_tokens[idx] != target_token:
                            sampled_token = random.choice(vocab)
                            modified_tokens[idx] = sampled_token

                # ターゲットトークンをMASK
                masked_tokens = modified_tokens.copy()
                masked_tokens[target_idx] = "[MASK]"
                # トークン列をテンソルに変換
                
                tokens_id = tokenizer.convert_tokens_to_ids(masked_tokens)
                tokens_tensor = torch.tensor([tokens_id]).to(device)
        
                # モデルに入力して予測
                with torch.no_grad():
                    outputs = model(tokens_tensor)
                    predictions = outputs[0]
        
                # 各トークンの負の対数確率を計算
                log_probs = []
                for idx, token in enumerate(modified_tokens):
                    token_id = tokenizer.convert_tokens_to_ids([target_token])[0]
                    softmax_scores = torch.softmax(predictions[0, target_idx], dim=-1)
                    token_prob = softmax_scores[token_id].item()
        
                    # 確率がゼロになることを防ぐための処理（log(0)を避ける）
                    token_prob = max(token_prob, 1e-9)
                    log_probs.append(-np.log(token_prob))  # 負の対数確率を保存
        
                # パープレキシティを計算
                perplexity = np.exp(np.mean(log_probs))
                avg_perplexities.append(perplexity)
        
            # 100回の実験の平均パープレキシティを取得
            mean_perplexity = np.mean(avg_perplexities)
            line_results.append(mean_perplexity)

        # 行ごとの結果を追加
        all_results.append(line_results)


# 結果をCSVに保存
output_file = "perplexity_11domains_genome.csv"
with open(output_file, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)

    # ヘッダー行を書き込む
    headers = ["masked_token"] + [f"Num_Replacements={i}" for i in range(11)]
    writer.writerow(headers)

    # 各行の結果を書き込む
    writer.writerows(all_results)

print("Done")

Processing lines: 100it [15:57,  9.58s/it]

Done





In [5]:
# 入力ファイル
input_file = "synthetic_domain_sets-Copy1.txt"

# トークナイザーの全語彙を取得
vocab = list(tokenizer.get_vocab().keys())

# 結果を保存するリスト
all_results = []

# テキストファイルを一行ずつ処理
with open(input_file, 'r', encoding='utf-8') as f:
    for line in tqdm(f, desc="Processing lines"):
        line = line.strip()
        if not line:
            continue

        # 行のトークンを分割し、3単語目以降を取得
        tokens = line.split()[2:]
        bgc_id = line.split()[0]

        # [SEP]を除いたトークンでターゲットを選定
        candidates = [token for token in tokens if token != "[SEP]"]
        max_probability = -1
        target_token = None

        for token in candidates:
            # MASKして予測
            masked_tokens = tokens.copy()
            token_idx = masked_tokens.index(token)
            masked_tokens[token_idx] = "[MASK]"

            tokens_id = tokenizer.convert_tokens_to_ids(masked_tokens)
            tokens_tensor = torch.tensor([tokens_id]).to(device)

            with torch.no_grad():
                outputs = model(tokens_tensor)
                predictions = outputs[0]

            # MASK位置の予測スコアを取得
            softmax_scores = torch.softmax(predictions[0, token_idx], dim=-1)
            token_id = tokenizer.convert_tokens_to_ids([token])[0]

            if softmax_scores[token_id].item() > max_probability:
                max_probability = softmax_scores[token_id].item()
                target_token = token

        # ターゲットトークンが見つからない場合はスキップ
        if target_token is None:
            continue

        # 行の結果を保存するリスト
        line_results = [bgc_id, target_token]

        for num_replacements in range(0, 11):
            avg_perplexities = []
        
            for _ in range(100):
                # トークンリストをコピー
                modified_tokens = tokens.copy()
        
                # ランダムに置換
                if num_replacements > 0:
                    replace_indices = random.sample(range(len(modified_tokens)), num_replacements)
                    for idx in replace_indices:
                        if modified_tokens[idx] != target_token:
                            sampled_token = random.choice(vocab)
                            modified_tokens[idx] = sampled_token
        
                # トークン列をテンソルに変換
                tokens_id = tokenizer.convert_tokens_to_ids(modified_tokens)
                tokens_tensor = torch.tensor([tokens_id]).to(device)
        
                # モデルに入力して予測
                with torch.no_grad():
                    outputs = model(tokens_tensor)
                    predictions = outputs[0]
        
                # 各トークンの負の対数確率を計算
                log_probs = []
                for idx, token in enumerate(modified_tokens):
                    token_id = tokenizer.convert_tokens_to_ids([token])[0]
                    softmax_scores = torch.softmax(predictions[0, idx], dim=-1)
                    token_prob = softmax_scores[token_id].item()
        
                    # 確率がゼロになることを防ぐための処理（log(0)を避ける）
                    token_prob = max(token_prob, 1e-9)
                    log_probs.append(-np.log(token_prob))  # 負の対数確率を保存
        
                # パープレキシティを計算
                perplexity = np.exp(np.mean(log_probs))
                avg_perplexities.append(perplexity)
        
            # 100回の実験の平均パープレキシティを取得
            mean_perplexity = np.mean(avg_perplexities)
            line_results.append(mean_perplexity)

        # 行ごとの結果を追加
        all_results.append(line_results)

# 結果をCSVに保存
output_file = "1token_prediction_changing_prompt_10_synthetic_domains.csv"
with open(output_file, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)

    # ヘッダー行を書き込む
    headers = ["BGC_ID", "masked_token"] + [f"Num_Replacements={i}" for i in range(11)]
    writer.writerow(headers)

    # 各行の結果を書き込む
    writer.writerows(all_results)

print("Done")

Processing lines: 100it [08:37,  5.18s/it]

Done



