<a href="https://colab.research.google.com/github/waynelee9511cloud/my-colab-notebooks/blob/main/LLM_CDQC_v0_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title 第一步：安裝環境與指定模型

# Install the notebook's dependencies quietly (-q). pandas is pinned to an
# exact version first; the second command installs/upgrades the model,
# Excel-reading and fine-tuning packages.
!pip -q install -U "pandas==2.2.2"
!pip -q install -U "transformers>=4.50.0" accelerate huggingface_hub openpyxl "pillow<12.0" einops safetensors tqdm trl peft datasets

import os, json, re, time
import pandas as pd
import torch
from tqdm.auto import tqdm
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer

# ===========================
# 1) Basic settings
# ===========================
# Alternative checkpoints are kept below as commented-out options; exactly
# one MODEL_ID assignment should be active.
#MODEL_ID = "BioMistral/BioMistral-7B"  # BioMistral 7B model
#MODEL_ID = "WayneLee9511/WL-BMS-SFT-1400-1020-40epochs"  # Fine tune model: BioMistral_sft_out/final
MODEL_ID = "google/medgemma-4b-it"
#MODEL_ID = "WayneLee9511/medgemma-4b-it-sft-lora-crc100k" #medgemma fine tune
# Path of the Excel workbook to validate; None is meant to trigger the upload dialog.
# NOTE(review): the upload cell always prompts for a file and overwrites this value.
EXCEL_PATH = None
# NOTE(review): not referenced by the cells visible here; the output file names are hard-coded.
OUTPUT_CSV = "validation_results.csv"
MAX_ROWS = None  # e.g. 500 to keep only the first 500 rows of each sheet; None processes every row
#DO_FINETUNE = True  # set to True to fine-tune (requires your own / generated training data)
DO_FINETUNE = False  # set to True to fine-tune (requires your own / generated training data)


# Per-worksheet prompt settings, keyed by worksheet name. Each entry names
# the column to spell-check, the label used when reporting an error, one
# misspelled example (in "[wrong] should be [right]." form) and one correctly
# spelled example value.
PROMPT_CONFIG = dict(
    MH=dict(
        target_column="[Diagnosis/ Conditions]",
        error_type="Incorrect diagnosis name",
        example_error="[Mypia] should be [Myopia].",
        example_correct="Hypertension",
    ),
    CM=dict(
        target_column="[Generic/Brand Name]",
        error_type="Incorrect drug name",
        example_error="[Lisnopril] should be [Lisinopril].",
        example_correct="Aspirin 100mg",
    ),
    AE=dict(
        target_column="[Event]",
        error_type="Incorrect event name",
        example_error="[Headchae] should be [Headache].",
        example_correct="Nausea",
    ),
)



# Generic (all-sheets) system prompt, stored as a tuple of text segments.
# Adjacent string literals without a comma between them are concatenated by
# Python into one segment; that is intentional for the multi-line examples.
# The first two instruction sentences are separate segments (a missing comma
# previously fused them into "...fields.Your ONLY task...").
# Despite the "_ZH" suffix the prompt text is English.
SYSTEM_PROMPT_ZH = (
    "You are a specialized tool for identifying spelling errors in specific medical data fields.",
    "Your ONLY task is to check for spelling errors in the values of the following three fields: '[Diagnosis/ Conditions]', '[Event]', and '[Generic/Brand Name]'.",
    "You MUST IGNORE all other fields, such as IDs, dates, flags, sequence numbers, or any other columns. Do not report any issues for these other fields.",
    "If you find a spelling error, your reason MUST be specific and follow the format from the examples below. If there are no spelling errors in the specified fields, your summary must be 'No issues found.'.",
    "Your entire response MUST be a single, valid JSON object and nothing else.",

    # --- Example 1: Data with a diagnosis spelling error ---
    "Example Input Data:\n"
    "[Diagnosis/ Conditions]: Mypia\n",

    "Example JSON Output:\n"
    '{\n'
    '  "row_index": 0,\n'
    '  "issues": [\n'
    '    {\n'
    '      "field": "[Diagnosis/ Conditions]",\n'
    '      "reason": "Incorrect diagnosis name: [Mypia] should be [Myopia].",\n'
    '      "severity": "error"\n'
    '    }\n'
    '  ],\n'
    '  "summary": "Incorrect diagnosis name: [Mypia] should be [Myopia]."\n'
    '}',

    # --- Example 2: Data with a drug spelling error ---
    "Example Input Data:\n"
    "[Generic/Brand Name]: Lisnopril 10mg\n",

    "Example JSON Output:\n"
    '{\n'
    '  "row_index": 1,\n'
    '  "issues": [\n'
    '    {\n'
    '      "field": "[Generic/Brand Name]",\n'
    '      "reason": "Incorrect drug name: [Lisnopril] should be [Lisinopril].",\n'
    '      "severity": "error"\n'
    '    }\n'
    '  ],\n'
    '  "summary": "Incorrect drug name: [Lisnopril] should be [Lisinopril]."\n'
    '}',

    # --- Example 3: Data with no spelling issues ---
    "Example Input Data:\n"
    "[Diagnosis/ Conditions]: Hypertension\n",

    "Example JSON Output:\n"
    '{\n'
    '  "row_index": 2,\n'
    '  "issues": [],\n'
    '  "summary": "No issues found."\n'
    '}',
)


print("Packages installed successfully. Ready to load model.")

In [None]:
#@title  第二步：將模型掛載於google drive方便日後快速讀取

# ===========================
# 3) Load the model and tokenizer
#    (bfloat16 when CUDA is available, float32 otherwise; device_map="auto")
# NOTE(review): the original header said `dtype` replaces the older
# `torch_dtype` argument, but the call below still passes `torch_dtype`.
# ===========================
from google.colab import drive
drive.mount('/content/drive')

import os

# Cache directory on the mounted Drive, used so downloaded weights can be
# reused in later sessions.
CACHE_DIR = '/content/drive/MyDrive/LLMcache6'

# These variables are read by huggingface_hub / transformers when they are
# first imported. The install cell has already imported both libraries, so
# setting them here does not move the cache for this session; they are kept
# for any code that reads them later, and `cache_dir` is passed explicitly
# below so the Drive cache is actually used.
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR


#MODEL_ID = "BioMistral/BioMistral-7B"
#MODEL_ID = "WayneLee9511/WL-BMS-SFT-1400-1020-40epochs"
MODEL_ID = "google/medgemma-4b-it"
#MODEL_ID = "WayneLee9511/medgemma-4b-it-sft-lora-crc100k" #medgemma fine tune
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    device_map="auto",
    cache_dir=CACHE_DIR,
)

# Load the tokenizer; if it defines no padding token, reuse the EOS token so
# that batched calls with padding=True work.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# The rest of the notebook refers to the tokenizer as `processor`.
processor = tokenizer
model.eval()

In [None]:
#@title  第三步：模型會跳出上傳按鍵，於讀Excel後產結果

from google.colab import files

# Ask the user for an .xlsx workbook through Colab's upload widget.
# NOTE(review): the dialog is shown unconditionally, even when EXCEL_PATH was
# already set in the settings cell.
print("請上傳 xlsx 檔案...")
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("未上傳任何檔案")

# Use the first uploaded file name as the workbook path.
EXCEL_PATH = list(uploaded)[0]
print("已選擇檔案：", EXCEL_PATH)

# sheet_name=None loads every worksheet into a {sheet name: DataFrame} dict.
sheets = pd.read_excel(EXCEL_PATH, sheet_name=None, engine="openpyxl")

# Optionally keep only the first MAX_ROWS rows of each worksheet.
if MAX_ROWS is not None:
    sheets = {name: frame.head(MAX_ROWS) for name, frame in sheets.items()}

total_rows = sum(map(len, sheets.values()))
print(f"已讀取工作表數：{len(sheets)}，總筆數：{total_rows}")


def row_to_yaml(row_dict):
    """Render one row (column name -> value) as YAML text for the prompt.

    Values whose type name contains "Timestamp" are converted with ``str()``
    before dumping. The caller's mapping is left untouched.

    Args:
        row_dict: Mapping of column names to cell values.

    Returns:
        The YAML text produced by ``yaml.safe_dump``. If PyYAML cannot be
        imported, or it cannot represent one of the values, plain
        ``key: value`` lines joined by newlines are returned instead.
    """
    # Build a converted copy instead of overwriting the caller's values.
    printable = {
        key: (str(value) if "Timestamp" in str(type(value)) else value)
        for key, value in row_dict.items()
    }

    def _plain_text():
        # Fallback rendering used whenever YAML output is unavailable.
        return "\n".join([f"{k}: {v}" for k, v in printable.items()])

    try:
        import yaml
    except ImportError:
        return _plain_text()

    try:
        return yaml.safe_dump(printable, allow_unicode=True, sort_keys=False)
    except yaml.YAMLError:
        # safe_dump rejects objects it has no representer for; degrade to
        # plain text rather than aborting the whole run.
        return _plain_text()

def build_messages(row_index, row_dict, dynamic_system_prompt):
    """Build the single-turn chat message list for one data row.

    Only columns whose name occurs as a substring of ``dynamic_system_prompt``
    are sent to the model; this is how the target column named in the prompt
    is selected.

    Args:
        row_index: Position of the row; accepted for interface symmetry but
            not used when building the message.
        row_dict: Mapping of column names to cell values for the row.
        dynamic_system_prompt: Instruction text placed before the row data.

    Returns:
        A list holding one ``{"role": "user", "content": ...}`` dict, or
        ``None`` when no column name of the row occurs in the prompt.
    """
    # Non-string column labels (e.g. integers) cannot be substring-tested
    # against the prompt and can never be a target column, so skip them
    # instead of raising TypeError.
    filtered_dict = {
        key: value
        for key, value in row_dict.items()
        if isinstance(key, str) and key in dynamic_system_prompt
    }

    if not filtered_dict:
        return None

    yaml_text = row_to_yaml(filtered_dict)
    # Use real newlines: the doubled backslash previously put the two
    # characters "\" and "n" into the prompt instead of a line break.
    user_text = (
        f"The following is a single data entry（YAML）：\n{yaml_text}\n"
        " Please review it according to the guidelines and return valid JSON."
    )
    return [
        {"role": "user", "content": f"{dynamic_system_prompt}\n{user_text}"}
    ]

@torch.inference_mode()
def validate_one(row_index, row_dict, dynamic_system_prompt=None):
    """Validate a single row with the model and parse its JSON answer.

    Args:
        row_index: Position of the row, forwarded to ``build_messages``.
        row_dict: Mapping of column names to cell values for the row.
        dynamic_system_prompt: Instruction text for the model. When omitted,
            the segments of ``SYSTEM_PROMPT_ZH`` joined by newlines are used.

    Returns:
        A dict with keys ``raw`` (decoded model text), ``json_str`` (the
        extracted JSON text), ``json_obj`` (the parsed object or ``None``)
        and ``is_valid_json`` (bool). When the row has no column mentioned
        in the prompt, a "skipped" result with the same keys is returned.
    """
    if dynamic_system_prompt is None:
        dynamic_system_prompt = "\n".join(SYSTEM_PROMPT_ZH)

    # build_messages requires the prompt; calling it with two arguments
    # raised TypeError on every call.
    messages = build_messages(row_index, row_dict, dynamic_system_prompt)
    if messages is None:
        return {
            "raw": "Skipped: No target column found.",
            "json_str": None,
            "json_obj": None,
            "is_valid_json": False,
        }

    prompt = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Only move the tensors to the model's device. `processor` is the
    # tokenizer in this notebook, and its BatchEncoding.to() takes no dtype
    # argument (token ids must stay integer anyway).
    inputs = processor(
        text=prompt,
        return_tensors="pt",
        padding=True
    ).to(model.device)

    input_len = inputs["input_ids"].shape[-1]
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=512,  # can be raised from 128 up to 512
        do_sample=False,
        temperature=None,
        repetition_penalty=1.05,
        use_cache=True,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    text = processor.batch_decode(gen_ids[:, input_len:], skip_special_tokens=True)[0]

    # NOTE(review): extract_json_str is not defined in the cells visible
    # here; it must be provided elsewhere before this function is used.
    jtxt = extract_json_str(text)
    try:
        parsed = json.loads(jtxt) if jtxt else None
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError subclass.
        parsed = None
    is_valid_json = bool(parsed)

    return {
        "raw": text,
        "json_str": jtxt,
        "json_obj": parsed,
        "is_valid_json": is_valid_json
    }


def validate_many(start_index, rows, dynamic_system_prompt):
    """Validate a batch of rows with one ``model.generate`` call.

    Args:
        start_index: Index of the first row of this batch within its sheet;
            used to number the rows passed to ``build_messages``.
        rows: List of row dicts (column name -> value).
        dynamic_system_prompt: Instruction text placed before each row.

    Returns:
        One dict per input row, in input order, with keys ``raw`` (decoded
        model text) and ``summary``. Rows without a column mentioned in the
        prompt get a "Skipped" entry and are not sent to the model.
    """
    skipped = {"raw": "Skipped: No target column found.", "summary": "Skipped"}

    prompts = []
    for row_index, row_dict in enumerate(rows, start=start_index):
        messages = build_messages(row_index, row_dict, dynamic_system_prompt)
        if messages is None:
            prompts.append(None)
            continue

        prompts.append(
            processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
        )

    # Remember where each usable prompt came from so the generated texts can
    # be put back at the matching row position.
    valid_positions = [i for i, p in enumerate(prompts) if p is not None]
    if not valid_positions:
        return [dict(skipped) for _ in prompts]

    # Tokenize only the real prompts; passing the list that still contains
    # None entries makes the tokenizer fail.
    valid_prompts = [prompts[i] for i in valid_positions]
    inputs = processor(text=valid_prompts, return_tensors="pt", padding=True)
    for k in list(inputs.keys()):
        if hasattr(inputs[k], "to"):
            inputs[k] = inputs[k].to(model.device)

    # NOTE(review): slicing every sequence at the padded input length assumes
    # the tokenizer pads on the left - confirm `padding_side` for each model.
    input_len = inputs["input_ids"].shape[-1]
    gen_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False, repetition_penalty=1.05, use_cache=True)
    texts = processor.batch_decode(gen_ids[:, input_len:], skip_special_tokens=True)

    # The summary is pulled out with regexes rather than json.loads, so
    # answers that are not strictly valid JSON still yield a result.
    summary_regex = re.compile(r'"summary":\s*"(.*?)"', re.DOTALL)
    reason_regex = re.compile(r'"reason":\s*"(.*?)"', re.DOTALL)

    outs = [dict(skipped) for _ in prompts]
    for position, t in zip(valid_positions, texts):
        summaries_found = summary_regex.findall(t)
        if summaries_found:
            summary_text = summaries_found[0]
        else:
            # No "summary" key: fall back to the individual issue reasons.
            reasons_found = reason_regex.findall(t)
            if reasons_found:
                summary_text = "; ".join(reasons_found)
            else:
                summary_text = "Skipped validation due to uncertain words."

        outs[position] = {
            "raw": t,
            "summary": summary_text
        }
    return outs


def _safe_name(name):
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name))

BATCH = 64  # number of rows sent to the model per validate_many() call
all_records = []        # result records of every processed sheet, in order
per_sheet_records = {}  # sheet name -> list of that sheet's result records

# Matches "[wrong] should be [right]" inside a summary.
_CORRECTION_RE = re.compile(r'\[(.*?)\]\s*should be\s*\[(.*?)\]')

for sheet_name, df in sheets.items():
    print(f"處理工作表：{sheet_name}，筆數：{len(df)}")

    # Sheets without an entry in PROMPT_CONFIG are skipped entirely.
    config = PROMPT_CONFIG.get(sheet_name)
    if not config:
        print(f"警告：在 PROMPT_CONFIG 中找不到工作表 '{sheet_name}' 的設定，將略過此工作表。")
        continue

    target_column = config["target_column"]
    error_reason = f"{config['error_type']}: {config['example_error']}"
    # The misspelled word of the example, without its surrounding brackets
    # ("[Mypia] should be [Myopia]." -> "Mypia"). Slicing with [1:] left the
    # closing bracket in the example input.
    misspelled_example = config["example_error"].split(" should be ")[0].strip("[]")

    # Sheet-specific instructions plus one failing and one passing example;
    # the segments are joined with newlines below.
    dynamic_system_prompt = (
        f"You are a specialized tool for identifying spelling errors in the '{target_column}' field.",
        f"Your ONLY task is to check for spelling errors in the value of the '{target_column}' field.",
        "You MUST IGNORE all other fields.",
        f"If you find a spelling error, your reason MUST be '{error_reason}'.",
        "If there are no spelling errors, your summary must be 'No issues found.'.",
        "Your entire response MUST be a single, valid JSON object and nothing else.",

        "\n--- Example 1: Data with a spelling error ---\n"
        "Example Input Data:\n"
        f"{target_column}: {misspelled_example}\n"
        "Example JSON Output:\n"
        '{\n'
        '  "row_index": 0,\n'
        '  "issues": [\n'
        '    {\n'
        f'      "field": "{target_column}",\n'
        f'      "reason": "{error_reason}",\n'
        '      "severity": "error"\n'
        '    }\n'
        '  ],\n'
        f'  "summary": "{error_reason}"\n'
        '}',

        "\n--- Example 2: Data with no spelling issues ---\n"
        "Example Input Data:\n"
        f"{target_column}: {config['example_correct']}\n"
        "Example JSON Output:\n"
        '{\n'
        '  "row_index": 1,\n'
        '  "issues": [],\n'
        '  "summary": "No issues found."\n'
        '}'
    )

    dynamic_system_prompt = "\n".join(dynamic_system_prompt)

    sheet_recs = []
    rows = df.to_dict('records')

    for s in range(0, len(rows), BATCH):
        batch_rows = rows[s:s+BATCH]
        outs = validate_many(s, batch_rows, dynamic_system_prompt)

        for j, out in enumerate(outs):
            row_idx = s + j
            summary = out.get("summary", "Error: Summary not found.")

            # A "correction" that only changes letter case is not treated
            # as a spelling error.
            if "should be" in summary:
                match = _CORRECTION_RE.search(summary)
                if match:
                    original_word = match.group(1)
                    corrected_word = match.group(2)

                    if original_word.lower() == corrected_word.lower():
                        summary = "No issues found. "

            rec = {
                "sheet_name": sheet_name,
                "row_index": int(row_idx),
                "BioMistral_summary": summary,
            }
            sheet_recs.append(rec)
            all_records.append(rec)

    per_sheet_records[sheet_name] = sheet_recs


# Overview file: one line per validated row across all processed sheets.
overview_df = pd.DataFrame(all_records)
overview_df.to_csv("validation_results__overview.csv", index=False, encoding="utf-8-sig")
print("已輸出：validation_results__overview.csv")

# One CSV per sheet: the original columns plus the model's summary, joined
# on the positional row index.
for sheet_name, df in sheets.items():
    sheet_records = per_sheet_records.get(sheet_name)
    if not sheet_records:
        # Sheets that were skipped (no PROMPT_CONFIG entry) have no entry in
        # per_sheet_records, and sheets without rows have no "row_index"
        # column to merge on; neither produces a per-sheet file.
        continue

    result_df = pd.DataFrame(sheet_records)

    # reset_index() exposes the row position as a column to join on.
    df_with_index = df.reset_index().rename(columns={'index': 'row_index'})

    merged = pd.merge(df_with_index, result_df, on="row_index", how="left")

    out_name = f"validation_results__{_safe_name(sheet_name)}.csv"
    merged.to_csv(out_name, index=False, encoding="utf-8-sig")
    print(f"已輸出：{out_name}")




# ===========================
# (Optional) Fine-tuning: SFT (text-only), disabled by default
# ===========================
# Builds chat-style training samples from the validation records, optionally
# wraps the model with LoRA adapters, trains for one epoch with the Hugging
# Face Trainer and saves the model and tokenizer to Google Drive.
if DO_FINETUNE:
    from datasets import Dataset
    from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
    from peft import LoraConfig, get_peft_model

    # NOTE(review): the records appended to all_records by the validation
    # loop carry only "sheet_name", "row_index" and "BioMistral_summary", so
    # this filter appears to select nothing and the "too few samples" branch
    # below is taken - confirm where "is_valid_json" is meant to be set.
    valid_records = [r for r in all_records if r.get("is_valid_json", False)]

    train_samples = []
    for r in valid_records:
        row_idx = r["row_index"]
        sheet_name = r["sheet_name"]

        # Skip records that cannot be mapped back to a row of the source sheet.
        original_df = sheets.get(sheet_name)
        if original_df is None or row_idx >= len(original_df):
            continue

        # Replace missing cells (NaN/NaT) with None before serializing the row.
        row_dict = {k: (None if pd.isna(v) else v) for k, v in original_df.iloc[row_idx].to_dict().items()}
        user_yaml = row_to_yaml(row_dict)
        user_text = f"Based on the following data（YAML）：\n{user_yaml}\n Please review according to the guidelines and return valid JSON."
        # NOTE(review): "medgemma_json" is not written by the validation loop
        # visible in this notebook - verify the key before enabling this path.
        assistant_text = r["medgemma_json"]
        messages = [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ]
        train_samples.append({"messages": messages})

    # Fewer than 10 usable samples: report it and do not train.
    if len(train_samples) < 10:
        print(f"可用有效訓練樣本太少 ({len(train_samples)} < 10)，略過微調。請自行提供標註資料或使用能產出更多有效 JSON 的模型，並重試。")
    else:
        print(f"找到 {len(train_samples)} 個有效訓練樣本，開始微調。")
        raw_ds = Dataset.from_list(train_samples)

        def encode_fn(ex):
            # Render the whole conversation with the chat template and
            # tokenize it; the labels are the same token ids as the inputs.
            msgs = ex["messages"]
            prompt = processor.apply_chat_template(
                msgs, add_generation_prompt=False, tokenize=False
            )
            tokenized = processor(text=prompt, return_tensors=None)
            return {"input_ids": tokenized["input_ids"], "labels": tokenized["input_ids"]}

        ds = raw_ds.map(encode_fn, remove_columns=raw_ds.column_names, batched=False)

        # Try to attach LoRA adapters to the attention and MLP projection
        # layers. On failure the error is printed and training continues
        # with `model` as it was left.
        try:
            peft_config = LoraConfig(
                r=8, lora_alpha=16, lora_dropout=0.05,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                task_type="CAUSAL_LM",
            )
            model.enable_input_require_grads()
            model = get_peft_model(model, peft_config)
        except Exception as e:
            print(f"LoRA 套用失敗，改用全參數微調（小批次）：{e}")

        # Pads each batch; padded label positions get -100.
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=processor,
            model=model,
            padding=True,
            max_length=4096,
            label_pad_token_id=-100,
        )

        # Batch size 1 with 8 gradient-accumulation steps; bf16 only when
        # CUDA is available.
        args = TrainingArguments(
            output_dir="medgemma_sft_out",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=8,
            learning_rate=2e-5,
            num_train_epochs=1,
            logging_steps=10,
            save_steps=200,
            save_total_limit=2,
            bf16=torch.cuda.is_available(),
            gradient_checkpointing=True,
            optim="adamw_torch",
            lr_scheduler_type="cosine",
            warmup_ratio=0.03,
            report_to="none",
        )

        # NOTE(review): newer transformers releases deprecate the `tokenizer`
        # argument of Trainer in favour of `processing_class` - confirm
        # against the installed version.
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=ds,
            tokenizer=processor,
            data_collator=data_collator,
        )
        trainer.train()
        # Save the trained model and the tokenizer to the same Drive folder.
        trainer.save_model("/content/drive/MyDrive/BioMistral_train_backup/BioMistral_sft_out/final")
        processor.save_pretrained("/content/drive/MyDrive/BioMistral_train_backup/BioMistral_sft_out/final")
        print("微調完成，輸出目錄：/content/drive/MyDrive/BioMistral_train_backup/BioMistral_sft_out/final")