In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, precision_recall_curve, auc, roc_auc_score
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import TabTransformerConfig, TabTransformerForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

ImportError: cannot import name 'TabTransformerConfig' from 'transformers' (/root/miniconda3/lib/python3.8/site-packages/transformers/__init__.py)

In [None]:
# 读取数据
train = pd.read_csv('/root/autodl-fs/data/train_revise+45缩减到100特征 数量1000个 去掉三列和Name.csv')

# 分离特征和标签
X = train.drop(['senolytic'], axis=1)
y = train['senolytic']

# 标签编码
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)

# 特征标准化
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
# 转换为Tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# 配置TabTransformer模型
config = TabTransformerConfig(
    num_labels=2,
    hidden_size=32,
    num_attention_heads=2,
    num_hidden_layers=2,
    intermediate_size=64,
)

# 初始化TabTransformer模型
model = TabTransformerForSequenceClassification(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [None]:
# 定义训练函数
def train(model, train_loader, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for batch in train_loader:
            inputs, labels = batch
            outputs = model(inputs)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs} - Loss: {loss.item()}")

# 定义测试函数
def evaluate(model, test_loader):
    model.eval()
    y_pred, y_pred_proba, y_true = [], [], []
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            outputs = model(inputs)
            probs = outputs.logits.softmax(dim=1)[:, 1].cpu().numpy()
            predictions = outputs.logits.argmax(dim=1).cpu().numpy()
            y_pred.extend(predictions)
            y_pred_proba.extend(probs)
            y_true.extend(labels.cpu().numpy())
    return np.array(y_true), np.array(y_pred), np.array(y_pred_proba)

In [None]:
# 训练模型
train(model, train_loader, optimizer, epochs=10)

# 评估模型
y_true, y_pred, y_pred_proba = evaluate(model, test_loader)

# 计算评价指标
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_pred_proba)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)
print("ROC AUC Score:", roc_auc)

In [None]:
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
plt.figure(figsize=(10, 5))
plt.plot(fpr, tpr, label=f'TabTransformer (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()

In [None]:
# 绘制PR曲线并计算AUC
precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)
pr_auc = auc(recall_vals, precision_vals)
plt.figure(figsize=(10, 5))
plt.plot(recall_vals, precision_vals, label=f'TabTransformer (AUC = {pr_auc:.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend()
plt.show()