In [None]:
# ==============================
# PubMedBERT CPT (MLM only) + RAdam + loss logging
# + Validation Loss & Early Stopping
# ==============================
from transformers import (
    BertTokenizerFast,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,  # <--- 追加
)
from datasets import Dataset
import torch
import os
import pandas as pd
import itertools
import sqlite3

# ---- Hyperparameter search space ----
# Every combination of these lists (itertools.product) is trained as its own run.
param_grid = dict(
    learning_rate=[2e-5],
    weight_decay=[1e-4],
    warmup_ratio=[0.1],
)

# ---- Training loop ----
# Everything before the `for` loop is independent of the hyperparameters, so it
# runs once instead of once per grid point. The sampling and the train/eval
# split are seeded, so every run still sees the same data.

# ---- Paths / base model ----
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
cpt_path = "../../00_pubmed_ext_2cells/data_for_annotation.csv"
db_file_path = "../../00_pubmed_ext_2cells/hit_cells.db"

# ---- Tokenizer ----
tokenizer = BertTokenizerFast.from_pretrained(model_name)
# Added as regular (non-special) tokens; each run resizes the model's
# embedding matrix to len(tokenizer) to make room for them.
tokenizer.add_tokens(["[CELL0]", "[CELL1]"])

# ---- Annotated sentences from the SQLite DB ----
conn = sqlite3.connect(db_file_path)
try:
    query = "SELECT pmid, sent_id, ann_sentence as sentence FROM hit_cells"
    db_df = pd.read_sql_query(query, conn)
finally:
    # Close the connection even when the query raises.
    conn.close()

# Strip the 【 and 】 markers from the annotated sentences.
db_df["sentence"] = db_df["sentence"].str.replace(r"[【】]", "", regex=True)

# (pmid, sent_id) -> cleaned DB sentence. NULL sentences are excluded so that a
# DB hit can never replace a valid CSV sentence with a missing value.
has_db_text = db_df["sentence"].notna()
text_lookup_dict = dict(
    zip(
        zip(db_df.loc[has_db_text, "pmid"], db_df.loc[has_db_text, "sent_id"]),
        db_df.loc[has_db_text, "sentence"],
    )
)

# ---- CPT corpus (CSV) ----
cpt_df = pd.read_csv(cpt_path)

# Reproducible 50% sample of rows whose sentence is swapped for the DB version.
indices_to_modify = cpt_df.sample(frac=0.5, random_state=42).index


def get_db_text(row):
    """Return the DB sentence for the row's (pmid, sent_id) key.

    Falls back to the row's own ``Cleaned_Sentence`` when the key is not in
    ``text_lookup_dict``.
    """
    key = (row["pmid"], row["sent_id"])
    return text_lookup_dict.get(key, row["Cleaned_Sentence"])


# Diagnostic: if pmid/sent_id types differ between the CSV and the DB
# (e.g. int vs. str), no key matches and the swap silently does nothing.
n_db_hits = sum(
    key in text_lookup_dict
    for key in zip(
        cpt_df.loc[indices_to_modify, "pmid"],
        cpt_df.loc[indices_to_modify, "sent_id"],
    )
)
print(f"Rows replaced with DB text: {n_db_hits} / {len(indices_to_modify)}")

new_sentences = cpt_df.loc[indices_to_modify].apply(get_db_text, axis=1)
cpt_df.loc[indices_to_modify, "Cleaned_Sentence"] = new_sentences

# Missing (NaN) sentences are not valid tokenizer input; drop them explicitly.
n_rows_before = len(cpt_df)
cpt_df = cpt_df.dropna(subset=["Cleaned_Sentence"])
if len(cpt_df) < n_rows_before:
    print(f"Dropped {n_rows_before - len(cpt_df)} rows with a missing Cleaned_Sentence")

print(cpt_df.head())

dataset = Dataset.from_pandas(cpt_df, preserve_index=False)


# ---- Tokenization ----
def tokenize_function(examples):
    """Tokenize a batch of sentences, truncated/padded to exactly 128 tokens."""
    return tokenizer(
        examples["Cleaned_Sentence"],
        truncation=True,
        padding="max_length",
        max_length=128,
    )


tokenized_dataset_full = dataset.map(
    tokenize_function,
    batched=True,
    # Drop every raw CSV column (pmid, sent_id, ID_A, ..., Cleaned_Sentence) so
    # only tokenizer outputs remain; robust to added or missing CSV columns.
    remove_columns=dataset.column_names,
)

# ---- Train / eval split (95% / 5%, seeded) ----
split_dataset = tokenized_dataset_full.train_test_split(test_size=0.05, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f"Train dataset size: {len(train_dataset)}, Eval dataset size: {len(eval_dataset)}")

# ---- Data collator: dynamic MLM masking of 15% of tokens ----
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

for lr, wd, warmup in itertools.product(
    param_grid["learning_rate"],
    param_grid["weight_decay"],
    param_grid["warmup_ratio"],
):
    exp_name = f"lr{lr}_wd{wd}_warmup{warmup}"
    output_dir = f"./pretrain_half/{exp_name}"
    os.makedirs(output_dir, exist_ok=True)

    # ---- Model: fresh copy of the base checkpoint for every run ----
    model = BertForMaskedLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))

    # ---- Early stopping ----
    # Created per run so no early-stopping state carries over between grid points.
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,  # stop after 5 consecutive evals without eval_loss improvement
        early_stopping_threshold=0.001,  # improvements not exceeding 0.001 count as "no improvement"
    )

    # ---- TrainingArguments ----
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=2,  # upper bound; early stopping may end training sooner
        per_device_train_batch_size=128,
        gradient_accumulation_steps=4,
        learning_rate=lr,
        weight_decay=wd,
        warmup_ratio=warmup,
        # NOTE(review): the transformers docs recommend lr_scheduler_type="constant"
        # for schedule-free optimizers (and say schedule_free_radam needs no
        # warmup); a linear schedule may interfere with its intended behavior.
        # Kept to preserve this experiment — confirm it is intended.
        lr_scheduler_type="linear",

        # Log, evaluate and checkpoint on the same 200-step cadence; matching
        # eval/save steps lets load_best_model_at_end pick a saved checkpoint.
        logging_steps=200,
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,

        load_best_model_at_end=True,  # reload the lowest-eval_loss checkpoint at the end
        metric_for_best_model="eval_loss",
        greater_is_better=False,  # lower loss is better
        save_total_limit=2,

        dataloader_num_workers=8,
        # Mixed precision only when a CUDA device exists; fp16 needs a GPU.
        fp16=torch.cuda.is_available(),
        report_to="none",
        optim="schedule_free_radam",
        logging_dir=f"{output_dir}/logs",
        log_level="info",
    )

    # ---- Trainer ----
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        callbacks=[early_stopping_callback],
    )

    # ---- Training ----
    print(f"\n=== Training {exp_name} ===")
    train_result = trainer.train()

    # ---- Final metrics (final_step records where early stopping ended the run) ----
    metrics = train_result.metrics
    metrics["final_step"] = trainer.state.global_step
    pd.DataFrame([metrics]).to_csv(f"{output_dir}/final_training_metrics.csv", index=False)

    # ---- Full per-step log history (training and eval entries) ----
    log_history_df = pd.DataFrame(trainer.state.log_history)
    log_history_df.to_csv(f"{output_dir}/full_log_history.csv", index=False)

    # ---- Save ----
    # With load_best_model_at_end=True, trainer.model is already the checkpoint
    # with the lowest eval_loss.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Saved best model (based on eval_loss) and tokenizer to {output_dir}")