<a href="https://colab.research.google.com/github/xingleiking/iodraw-files/blob/main/20250619esmGBRT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from google.colab import files

In [25]:
# Notebook-wide settings.
BATCH_SIZE = 8  # default number of sequences per batch in extract_esm_embeddings
ESM2_MODEL = "facebook/esm2_t33_650M_UR50D"  # default checkpoint name for load_esm2
# Use the GPU when torch reports one as available; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_esm2(model_name=ESM2_MODEL):
    """Load the tokenizer and model for a pretrained checkpoint.

    Args:
        model_name: Checkpoint name handed to both ``from_pretrained`` calls;
            defaults to ``ESM2_MODEL``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is created with
        ``output_hidden_states=True``, moved to ``DEVICE`` and put in eval mode.
    """
    esm_tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    esm_model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
    # Move to the configured device first, then switch off training-only behavior.
    esm_model = esm_model.to(DEVICE)
    esm_model = esm_model.eval()
    return esm_tokenizer, esm_model

def extract_esm_embeddings(sequences, tokenizer, model, batch_size=BATCH_SIZE, max_length=None):
    """Embed each sequence as the attention-mask-weighted mean of the last hidden layer.

    Sequences are tokenized batch by batch (with padding), run through
    ``model`` under ``torch.no_grad()``, and the vectors in
    ``out.hidden_states[-1]`` are averaged over the positions whose attention
    mask is non-zero, so padding does not contribute to the mean.

    Args:
        sequences: A single string (treated as one sequence) or any iterable
            of strings (list, pandas Series, numpy array, generator, ...).
        tokenizer: Callable returning ``input_ids`` and ``attention_mask``
            tensors when called with ``return_tensors="pt"``.
        model: Model whose output exposes ``hidden_states`` (``load_esm2``
            requests this with ``output_hidden_states=True``). Inputs are moved
            to ``DEVICE``, so the model is expected to live there as well.
        batch_size: Number of sequences per forward pass; must be >= 1.
        max_length: Optional cap on the tokenized length. ``truncation=True``
            alone does not truncate when the tokenizer has no predefined
            maximum length (the tokenizer then warns and falls back to no
            truncation), so pass a value here to actually truncate long inputs.
            ``None`` keeps the tokenizer's default behavior.

    Returns:
        A 2-D numpy array with one row per input sequence, in input order.
        For empty input the array has zero rows.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    if isinstance(sequences, str):
        # A bare string would otherwise be sliced into character chunks below.
        sequences = [sequences]
    else:
        # Materialize so that Series / ndarray / generator inputs slice into
        # plain lists, which is what the tokenizer accepts as a batch.
        sequences = list(sequences)

    if not sequences:
        # np.vstack([]) raises; return an empty matrix instead. The width is
        # taken from the model config when it is available, else 0.
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    all_embs = []
    with torch.no_grad():
        for start in range(0, len(sequences), batch_size):
            batch = sequences[start:start + batch_size]
            enc = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            input_ids = enc["input_ids"].to(DEVICE)
            attention_mask = enc["attention_mask"].to(DEVICE)
            out = model(input_ids, attention_mask=attention_mask)
            last_hidden = out.hidden_states[-1]
            # Add a trailing axis so the mask broadcasts over the hidden dimension.
            mask = attention_mask.unsqueeze(-1)
            summed = (last_hidden * mask).sum(dim=1)
            lengths = mask.sum(dim=1)
            pooled = summed / lengths
            all_embs.append(pooled.cpu().numpy())
    return np.vstack(all_embs)



In [26]:
def process_uploaded_csv():
    """Upload a CSV in Colab, embed its sequences, and evaluate a GBRT regressor.

    The first column of the uploaded file is used as the sequences and the
    second as numeric labels. Rows with a missing value in either column are
    dropped. Embeddings come from ``load_esm2`` / ``extract_esm_embeddings``;
    the data is split 80/20 with a fixed seed, a ``GradientBoostingRegressor``
    is fitted on the training part, and R² and MSE on the held-out part are
    printed.

    Returns:
        None. Progress, results, and input problems are reported via ``print``.

    Raises:
        ValueError: If a label cannot be converted to ``float``.
    """
    seed = 42  # shared by the split and the regressor so reruns are reproducible

    # Open the Colab file-upload dialog.
    uploaded = files.upload()
    if not uploaded:
        # Nothing was uploaded (e.g. the dialog was cancelled).
        print("未上传任何文件")
        return
    # Use the first uploaded file.
    fn = next(iter(uploaded))
    df = pd.read_csv(fn)
    if df.shape[1] < 2:
        print("CSV 至少要两列：sequence, label")
        return

    # A missing sequence would become the literal string "nan" after astype(str),
    # and a missing label would make fit() fail, so drop such rows up front.
    data = df.iloc[:, :2].dropna()
    n_dropped = len(df) - len(data)
    if n_dropped:
        print(f"已忽略 {n_dropped} 行含缺失值的数据")
    if len(data) < 2:
        # train_test_split needs at least one sample on each side.
        print("有效数据不足：至少需要 2 行才能划分训练集和测试集")
        return

    seqs = data.iloc[:, 0].astype(str).tolist()
    labels = data.iloc[:, 1].astype(float).values

    print("加载 ESM2 ...")
    tokenizer, model = load_esm2()
    print("提取特征 ...")
    X = extract_esm_embeddings(seqs, tokenizer, model)
    print("训练 GBRT ...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, labels, test_size=0.2, random_state=seed
    )
    gbrt = GradientBoostingRegressor(random_state=seed)
    gbrt.fit(X_train, y_train)
    y_pred = gbrt.predict(X_test)

    if np.ptp(y_test) == 0:
        # With no variance in the held-out labels R² carries no information.
        print("注意：测试集标签没有变化，R² 无参考意义")
    print(f"R² = {r2_score(y_test, y_pred):.4f}")
    print(f"MSE = {mean_squared_error(y_test, y_pred):.4f}")

In [28]:
# Run the upload -> embed -> train -> evaluate pipeline defined above.
process_uploaded_csv()


Saving data.csv to data.csv
加载 ESM2 ...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


提取特征 ...
训练 GBRT ...
R² = 1.0000
MSE = 0.0000


# 从 CSV 文本运行（main_from_csv_string）

In [None]:
# NOTE(review): neither `main_from_csv_string` nor `csv_text` is defined in the
# cells visible here, so this call raises NameError on a fresh kernel unless they
# are defined elsewhere — TODO restore the defining cell or remove this one.
main_from_csv_string(csv_text)

🔄 解析 CSV 数据...
📦 加载 ESM2 模型...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


🧬 提取特征...
✅ 特征形状: (10, 1280)
📈 拟合模型...
🎯 R²: -0.2457102440037513
📉 MSE: 0.7007120122521104
