<a href="https://colab.research.google.com/github/t8101349/group-project-202503/blob/main/gradio_web_0326.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pandas
!pip install numpy
!pip install gradio
!pip install rdkit
!pip install scikit-learn
!pip install xgboost
!pip install lightgbm
!pip install psutil

In [None]:
# 以上套件需列入 Hugging Face Space 的 requirements.txt

In [None]:
import gradio as gr
import pandas as pd
import traceback
from gradio.themes import Base
from process import predict_process
import os
from rdkit import Chem
import psutil
import time

# 放進huggingface space app.py

# Resource monitor (testing only): streams RAM / CPU usage text to the UI.
def update_memory_info():
    """Yield a RAM / CPU usage status line, then sleep 2 seconds; never stops.

    RAM is this process's resident set size divided by 1024 * 1024; CPU is
    whatever ``psutil.cpu_percent()`` reports when called without an interval.
    """
    bytes_per_mb = 1024 * 1024
    while True:
        rss_mb = psutil.Process().memory_info().rss / bytes_per_mb
        cpu_usage = psutil.cpu_percent()
        yield f"RAM 使用量: {rss_mb:.2f} MB | CPU 使用率: {cpu_usage}%"
        time.sleep(2)  # refresh interval in seconds


# Page components: one centered column (max 800px via the CSS below) titled
# "新藥預測工具", with the default Gradio footer hidden.
with gr.Blocks(
    # NOTE(review): presumably (frequency, age) in seconds for Gradio's
    # automatic cleanup of cached files — confirm against the installed version.
    delete_cache=(3600, 7200),
    theme=Base(primary_hue="cyan", secondary_hue="teal", neutral_hue="gray"),
    title="新藥預測工具",
    css="""
    .gradio-container {
        width: 100% !important;
        max-width: 800px !important;
        margin: 0 auto !important;
        padding: 0.5rem;
    }
    .gradient-title h1 {
        background: linear-gradient(45deg, #13A9E6, #3DD69E);
        -webkit-background-clip: text;
        background-clip: text;
        color: transparent;
        text-align: center;
        font-size: clamp(2rem, 5vw, 3.5rem);
        font-weight: bold;
    }
    footer {
        display: none !important;
    }
    .file-status {
        font-weight: bold;
    }
    """
) as demo:
    # Memory / CPU usage box at the top (testing only); its value is streamed
    # by the update_memory_info generator registered with demo.load.
    memory_box = gr.Textbox(
        label="系統資源使用情況",
        value="RAM 使用量: 計算中... | CPU 使用率: 計算中...",
        elem_classes=["memory-info"],
        interactive=False
    )

    gr.Markdown("# 新藥預測工具", elem_classes=["gradient-title"])

    # Collapsible usage guide: features, steps, upload limits, required columns.
    with gr.Accordion("點此查看工具詳細說明", open=False):
        gr.Markdown("""
            **功能介紹：**
            此工具可將 SMILES 字串的分子資料集根據機器學習模型預測出每個分子是否與指定的
            蛋白質標靶(sEH, BRD4, HSA 其中之一)結合，快速篩選出可能的藥物分子資料集。

            **操作說明：**
            1. 上傳分子數據集
            2. 確認檔案
            3. 執行預測
            4. 選擇下載格式
            5. 產生預測檔案
            6. 下載預測檔案

            **上傳資料限制：**
            1. 支援 CSV 與 Parquet 格式
            2. 包含必要欄位
            3. 檔案大小上限50MB
            4. 資料筆數上限50萬筆

            **必要欄位說明：**
            - molecule_smiles: 分子的 SMILES 字串
            - protein_name: 蛋白質名稱 (必須為 sEH, BRD4, HSA 其中之一)
            """)

    # Session state (gr.State) read and written by the handlers: the confirmed
    # input DataFrame ("df"), the prediction output ("result_df"), and the
    # upload / completion / cancel-request flags.
    state = gr.State(value={
        "df": None,
        "result_df": None,
        "file_uploaded": False,
        "prediction_done": False,
        "cancel_prediction": False
    })

    with gr.Column():
        # Upload widget; type="filepath" hands the handler a path string.
        file_input = gr.File(
            label="上傳分子數據集，拖曳檔案至此或點擊上傳，上限50MB",
            file_types=[".csv", ".parquet"],
            type="filepath"
        )

        # Step 1: validate the uploaded file; the result text appears below.
        confirm_btn = gr.Button("確認檔案", variant="primary")
        file_status = gr.Textbox(
            label="檔案確認狀態",
            elem_classes=["file-status"],
            interactive=False
        )

        # Step 2: run / cancel prediction. Both buttons start disabled
        # (interactive=False); the event wiring further down toggles them.
        with gr.Row():
            predict_btn = gr.Button("執行預測", variant="primary", interactive=False)
            cancel_btn = gr.Button("取消預測", variant="primary", interactive=False)

        predict_status = gr.Textbox(
            label="預測狀態",
            elem_classes=["file-status"],
            interactive=False
        )

        # Step 3: choose the export format (CSV by default).
        with gr.Row():
            download_format = gr.Radio(
                choices=["CSV", "Parquet"],
                label="選擇下載格式",
                value="CSV"
            )

        # Step 4: write the result file; starts disabled until a prediction succeeds.
        generate_btn = gr.Button("產生下載檔案", variant="primary", interactive=False)
        generate_status = gr.Textbox(
            label="檔案生成狀態",
            elem_classes=["file-status"],
            interactive=False
        )

        # Step 5: download the generated file; starts disabled.
        download_btn = gr.DownloadButton(
            label="下載預測結果",
            variant="primary",
            interactive=False,
            visible=True
        )

        download_status = gr.Textbox(
            label="下載狀態",
            elem_classes=["file-status"],
            interactive=False
        )

    # File confirmation: load the uploaded dataset and validate it before prediction.
    def _read_dataset(path):
        """Load a CSV or Parquet file into a DataFrame.

        The extension is compared case-insensitively. CSV is read as UTF-8
        first, falling back to Windows-1252 when UTF-8 decoding fails.
        Returns None for any other extension.
        """
        extension = os.path.splitext(path)[1].lower()
        if extension == ".csv":
            try:
                return pd.read_csv(path, encoding="utf-8")
            except UnicodeDecodeError:
                return pd.read_csv(path, encoding="Windows-1252")
        if extension == ".parquet":
            return pd.read_parquet(path)
        return None

    def _format_invalid_rows(title, rows, remainder_label, hint):
        """Build an error message listing at most five (index, value) rows.

        Rows are shown as 1-based data-row numbers; when more than five are
        invalid, a summary line with the remaining count is appended.
        """
        lines = [title]
        lines.extend(f"行 {idx + 1}: {value}" for idx, value in rows[:5])
        if len(rows) > 5:
            lines.append(f"...還有 {len(rows) - 5} 個{remainder_label}")
        lines.append(hint)
        return "\n".join(lines)

    def confirm_file(file, state):
        """Validate an uploaded dataset and store it in the session state.

        Checks, in order: supported extension, required columns, non-empty,
        at most 500,000 rows, no missing values in the required columns,
        parseable SMILES, and protein names limited to sEH / BRD4 / HSA.

        Args:
            file: Path of the uploaded file, or None when nothing was uploaded.
            state: Session state dict; on success "df" holds the DataFrame and
                "file_uploaded" is True.

        Returns:
            ``(status_message, state)``. Only a successful confirmation
            produces a message containing "✅", which the UI uses to enable
            the predict button.
        """
        # Forget any previously confirmed dataset and its predictions first,
        # so a failed re-upload can never leave stale data usable downstream.
        state["df"] = None
        state["result_df"] = None
        state["file_uploaded"] = False
        state["prediction_done"] = False

        if file is None:
            return "請上傳分子數據集！", state
        try:
            df = _read_dataset(file)
            if df is None:
                return "❌ 不支援的檔案格式，請上傳 CSV 或 Parquet 檔案！", state

            required_columns = ["molecule_smiles", "protein_name"]
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                return f"檔案缺少必要的欄位：{', '.join(missing_columns)}", state

            if df.empty:
                return "❌ 檔案沒有任何資料，請確認檔案內容！", state

            if len(df) > 500000:
                return "資料筆數超過 50 萬筆，請減少資料量！", state

            # Reject missing values in the required columns.
            for column in required_columns:
                if df[column].isnull().any():
                    return f"❌ 欄位 '{column}' 含有缺值，請檢查並補充完整數據！", state

            # Every molecule_smiles value must parse with RDKit.
            invalid_smiles = []
            for idx, smiles in enumerate(df["molecule_smiles"]):
                try:
                    mol = Chem.MolFromSmiles(smiles)
                except Exception:  # e.g. a non-string cell: treat as invalid
                    mol = None
                if mol is None:
                    invalid_smiles.append((idx, smiles))

            if invalid_smiles:
                return _format_invalid_rows(
                    "❌ 以下分子 SMILES 格式無效：",
                    invalid_smiles,
                    "無效 SMILES",
                    "請檢查並修正 SMILES 格式！",
                ), state

            # protein_name must be one of the three supported targets.
            valid_proteins = {"sEH", "BRD4", "HSA"}
            invalid_proteins = [
                (idx, protein)
                for idx, protein in enumerate(df["protein_name"])
                if protein not in valid_proteins
            ]

            if invalid_proteins:
                return _format_invalid_rows(
                    "❌ 以下蛋白質名稱無效（僅允許 sEH, BRD4 或 HSA）：",
                    invalid_proteins,
                    "無效蛋白質名稱",
                    "請將蛋白質名稱修正為 sEH, BRD4 或 HSA！",
                ), state

            state["df"] = df
            state["file_uploaded"] = True
            state["prediction_done"] = False
            state["cancel_prediction"] = False  # reset the cancel request

            filename = os.path.basename(file)
            return f"✅ 已成功上傳檔案：{filename}，共 {len(df)} 筆資料", state
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"檔案處理錯誤：{str(e)}\n{error_details}")
            return f"❌ 檔案處理錯誤：{str(e)}", state

    # Run prediction on the confirmed dataset.
    def run_prediction(state):
        """Predict binding for the DataFrame stored by confirm_file.

        Args:
            state: Session state dict. Reads "df" and "file_uploaded"; writes
                "result_df" and "prediction_done"; clears "cancel_prediction"
                before starting.

        Returns:
            ``(status_message, state)``. The message contains "✅" only on
            success; the UI event chain checks for it to enable the
            generate-file button.
        """
        if not state["file_uploaded"] or state["df"] is None:
            return "❌ 請先上傳並確認檔案！", state

        # Discard results from any earlier run up front, so a cancelled or
        # failed re-run never leaves the previous predictions marked as ready.
        state["result_df"] = None
        state["prediction_done"] = False
        try:
            # Clear any stale cancel request before starting.
            state["cancel_prediction"] = False

            # predict_process polls this callback and returns None once it
            # sees a cancel request (set by the cancel handler).
            result_df = predict_process(
                state["df"], cancel_flag=lambda: state["cancel_prediction"]
            )
            if result_df is None or state["cancel_prediction"]:
                return "❌ 預測已取消！", state

            state["result_df"] = result_df
            state["prediction_done"] = True
            return f"✅ 預測完成！共處理 {len(result_df)} 筆資料", state
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"預測錯誤：{str(e)}\n{error_details}")
            return f"❌ 預測錯誤：{str(e)}", state

    # Cancel handler: raise the flag that the running prediction polls.
    def cancel_prediction(state):
        """Request cancellation of an in-progress prediction.

        Sets ``state["cancel_prediction"]`` to True. run_prediction passes a
        callback reading this key to predict_process.

        Returns:
            ``(status_message, state)`` with the same state dict.
        """
        state.update(cancel_prediction=True)
        message = "正在取消預測..."
        return message, state

    # Generate the download file.
    def generate_file(format_choice, state):
        """Write the prediction results to a CSV or Parquet file for download.

        Each call writes into its own newly created temporary directory. A
        fixed path in the shared system temp dir would let concurrent sessions
        overwrite each other's file, so one user could download another
        user's predictions. The file name stays prediction.csv /
        prediction.parquet.

        Args:
            format_choice: "CSV" writes CSV; any other value writes Parquet.
            state: Session state dict; reads "prediction_done" and "result_df".

        Returns:
            ``(filepath or None, status_message, state)``.
        """
        import tempfile
        if not state["prediction_done"] or state["result_df"] is None:
            return None, "❌ 請先執行預測！", state
        try:
            # NOTE(review): these per-call directories are not removed here;
            # consider periodic cleanup if disk space on the host is tight.
            output_dir = tempfile.mkdtemp(prefix="prediction_")
            if format_choice == "CSV":
                filename = "prediction.csv"
                filepath = os.path.join(output_dir, filename)
                state["result_df"].to_csv(filepath, index=False)
            else:
                filename = "prediction.parquet"
                filepath = os.path.join(output_dir, filename)
                state["result_df"].to_parquet(filepath, index=False)
            return filepath, f"✅ 已生成 {filename}，點擊下方按鈕即可下載", state
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"生成檔案錯誤：{str(e)}\n{error_details}")
            return None, f"❌ 生成檔案錯誤：{str(e)}", state

    # After a download click: report which file was delivered.
    def update_download_status(filepath):
        """Return the download status line for the DownloadButton's value.

        Any falsy ``filepath`` (None or empty string) is reported as a failed
        download; otherwise the message names the file's base name.
        """
        if not filepath:
            return "❌ 下載失敗"
        return f"✅ 已下載 {os.path.basename(filepath)}"

    # Memory monitor (testing only). update_memory_info is an endless
    # generator; under Gradio's default per-listener concurrency limit of 1,
    # one session's monitor would hold the listener forever and every other
    # visitor's load event would wait in the queue. concurrency_limit=None
    # lets each session run its own monitor.
    demo.load(
        fn=update_memory_info,
        inputs=None,
        outputs=memory_box,
        concurrency_limit=None,
    )

    # Event wiring.
    # Confirm the upload, then enable "predict" only when validation succeeded.
    confirm_btn.click(
        fn=confirm_file,
        inputs=[file_input, state],
        outputs=[file_status, state]
    ).then(
        fn=lambda status_text: gr.update(interactive="✅" in status_text),
        inputs=file_status,
        outputs=predict_btn
    )

    # Prediction chain. The cancel button must be enabled BEFORE the
    # long-running step, because a step chained after it only runs once the
    # prediction has finished. Generating a file is disabled while results
    # are being replaced.
    predict_btn.click(
        fn=lambda: (gr.update(interactive=True), gr.update(interactive=False)),
        inputs=None,
        outputs=[cancel_btn, generate_btn],
    ).then(
        fn=run_prediction,
        inputs=state,
        outputs=[predict_status, state],
    ).then(
        # Finished: allow generating a file only on success; cancelling is no
        # longer possible.
        fn=lambda status_text: (
            gr.update(interactive="✅" in status_text),
            gr.update(interactive=False),
        ),
        inputs=predict_status,
        outputs=[generate_btn, cancel_btn],
    )

    # Cancel sets state["cancel_prediction"], which the callback passed to
    # predict_process reads. It is a separate listener, so presumably it is
    # not queued behind the running prediction.
    # NOTE(review): this relies on Gradio giving both handlers the same
    # session-state dict object — confirm with the installed Gradio version.
    cancel_btn.click(
        fn=cancel_prediction,
        inputs=state,
        outputs=[predict_status, state],
    )

    # Write the result file, then enable the download button only on success.
    generate_btn.click(
        fn=generate_file,
        inputs=[download_format, state],
        outputs=[download_btn, generate_status, state]
    ).then(
        fn=lambda filepath, status: gr.update(value=filepath, interactive="✅" in status),
        inputs=[download_btn, generate_status],
        outputs=download_btn
    )

    # Report the downloaded file name.
    download_btn.click(
        fn=update_download_status,
        inputs=download_btn,
        outputs=download_status
    )

# Launch the app when this file is executed directly. max_file_size is
# 50 * 1024 * 1024 (50 MB), matching the upload limit stated in the UI.
if __name__ == "__main__":
    demo.launch(max_file_size=50 * 1024 * 1024)

In [None]:
import pandas as pd
import numpy as np
import joblib
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from huggingface_hub import hf_hub_download

# 放進huggingface space process.py

# Morgan fingerprint generators keyed by (radius, n_bits). Building a new one
# for every molecule is pure overhead when featurizing large datasets.
_MORGAN_GENERATORS = {}


def smiles_to_morgan_fingerprint(smiles, n_bits=2048, radius=2):
    """Convert a SMILES string into a Morgan fingerprint bit vector.

    Args:
        smiles: SMILES string parsed with ``Chem.MolFromSmiles``.
        n_bits: Fingerprint length (``fpSize``).
        radius: Morgan radius; defaults to 2.

    Returns:
        A 1-D NumPy array of dtype ``int`` with ``n_bits`` 0/1 entries. It is
        all zeros when RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return np.zeros(n_bits, dtype=int)
    key = (radius, n_bits)
    generator = _MORGAN_GENERATORS.get(key)
    if generator is None:
        generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
        _MORGAN_GENERATORS[key] = generator
    # GetFingerprintAsNumPy fills the array natively. np.array() on the bit
    # vector object would read it element by element through Python.
    return generator.GetFingerprintAsNumPy(mol).astype(int)

def predict_process(df, cancel_flag=None):
    """Predict whether each molecule binds its listed protein target.

    ``df`` is NOT modified. Callers may keep it and predict again, so all
    feature engineering happens on local arrays.

    Args:
        df: DataFrame with ``molecule_smiles`` (SMILES strings) and
            ``protein_name`` (one of "BRD4", "HSA", "sEH") columns.
        cancel_flag: Optional zero-argument callable, polled before each
            molecule is featurized and before inference. When it returns a
            truthy value, the function stops and returns ``None``.

    Returns:
        A new DataFrame with columns ``id`` (1-based row number),
        ``molecule_smiles``, ``protein_name`` and ``binds``. ``binds`` is 1
        when the predicted probability is >= 0.5, else 0. Returns ``None``
        if cancelled. An empty input yields an empty result without loading
        the model.
    """
    n_bits = 2048
    threshold = 0.5
    # All targets, in the sorted order pd.get_dummies produces when all three
    # are present. NOTE(review): presumably the model was trained with all
    # three protein_* columns — confirm against the training code.
    protein_categories = ["BRD4", "HSA", "sEH"]

    def cancelled():
        return bool(cancel_flag and cancel_flag())

    if len(df) == 0:
        return pd.DataFrame(columns=["id", "molecule_smiles", "protein_name", "binds"])

    # Load the model. joblib.load unpickles the file, so only point this at a
    # trusted repository.
    model_path = hf_hub_download(repo_id='sinanju/model_voting', filename='voting_model.bin')
    model = joblib.load(model_path)

    # Fingerprint bits are 0/1, so they go straight into a preallocated int8
    # matrix. That is one byte per bit, instead of int64 values held in both a
    # list of arrays and a DataFrame; this matters near the 500k-row limit.
    fingerprints = np.zeros((len(df), n_bits), dtype=np.int8)
    for row, smiles in enumerate(df['molecule_smiles']):
        if cancelled():
            return None
        fingerprints[row] = smiles_to_morgan_fingerprint(smiles, n_bits=n_bits)

    # Fixed categories guarantee every protein_* column exists even when the
    # upload contains only one or two targets. Without them, the feature set
    # would not match what the model expects.
    proteins = pd.Categorical(df['protein_name'], categories=protein_categories)
    protein_onehot = pd.get_dummies(proteins, prefix='protein').astype(int)
    X_test = pd.concat([pd.DataFrame(fingerprints), protein_onehot], axis=1)
    X_test.columns = X_test.columns.astype(str)  # feature names as strings

    # Predict probabilities and convert them to binary labels.
    if cancelled():
        return None
    probabilities = model.predict_proba(X_test)[:, 1]
    predictions = (probabilities >= threshold).astype(int)

    return pd.DataFrame({
        'id': np.arange(1, len(df) + 1),
        'molecule_smiles': df['molecule_smiles'].to_numpy(),
        'protein_name': df['protein_name'].to_numpy(),
        'binds': predictions,
    })