#### 环境依赖 (Environment)

本 notebook 在以下库版本下运行：

- transformers==4.40.0
- torch==1.11.0+cu113
- torch-geometric==2.5.2
- torch-scatter==2.0.9
- torch-sparse==0.6.13

#### 加载数据与训练模型 (Load data & train)

In [1]:
import pickle


def _unpickle(path):
    """Deserialize one pre-built graph split from ``path``.

    NOTE: pickle.load can execute arbitrary code -- only open files you trust.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


hc3_train = _unpickle("/root/autodl-tmp/graph_data/gpt3.5_mixed_train_split.pkl")
hc3_val = _unpickle("/root/autodl-tmp/graph_data/gpt3.5_mixed_val_split.pkl")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from datetime import datetime
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import tqdm
import time

# GCN model definitions.
class GCN2(nn.Module):
    """Two-layer GCN graph classifier that emits one probability per graph.

    Node features pass through GCNConv -> ReLU -> Dropout twice. They are then
    projected to one logit per node by a linear layer. Those logits are averaged
    over all nodes before the sigmoid. Because the mean runs over every row of
    ``data.x``, each call is expected to hold exactly one graph (the loops in
    this notebook feed graphs one at a time).

    Args:
        input_dim: Width of each node feature vector (must match ``data.x.size(1)``).
        hidden_dim: Output width of the first GCN layer.
        output_dim: Output width of the second GCN layer, i.e. the embedding
            size fed to the linear head (not a number of classes).
        dropout: Dropout probability applied after each GCN layer. Defaults to
            0.5, the value that used to be hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        # Submodule names are part of the saved state_dict keys -- keep them stable.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.fc = nn.Linear(output_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return the graph-level probability for ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``. Assuming
        ``x`` is 2-D (nodes, features), the result has shape ``(1, 1)``.
        """
        x, edge_index = data.x, data.edge_index
        x = self.dropout(F.relu(self.conv1(x, edge_index)))
        x = self.dropout(F.relu(self.conv2(x, edge_index)))
        x = self.fc(x)
        # Average the per-node logits over the node axis -> one logit for the graph.
        x = torch.mean(x, dim=0, keepdim=True)
        return torch.sigmoid(x)

class GCN4(nn.Module):
    """Four-layer GCN graph classifier that emits one probability per graph.

    Same design as GCN2, but with four GCNConv -> ReLU -> Dropout stages before
    the linear head. Per-node logits are averaged over all nodes before the
    sigmoid, so each call is expected to hold exactly one graph.

    Args:
        input_dim: Width of each node feature vector (must match ``data.x.size(1)``).
        hidden_dim: Output width of GCN layer 1.
        hidden_dim2: Output width of GCN layer 2.
        hidden_dim3: Output width of GCN layer 3.
        output_dim: Output width of GCN layer 4, fed to the linear head (not a
            number of classes).
        dropout: Dropout probability applied after each GCN layer. Defaults to
            0.5, the value that used to be hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, hidden_dim2, hidden_dim3, output_dim, dropout=0.5):
        super().__init__()
        # Explicit conv1..conv4 attributes (not an nn.ModuleList) keep the
        # state_dict keys compatible with checkpoints saved by earlier runs.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim2)
        self.conv3 = GCNConv(hidden_dim2, hidden_dim3)
        self.conv4 = GCNConv(hidden_dim3, output_dim)
        self.fc = nn.Linear(output_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return the graph-level probability for ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``. Assuming
        ``x`` is 2-D (nodes, features), the result has shape ``(1, 1)``.
        """
        x, edge_index = data.x, data.edge_index
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.dropout(F.relu(conv(x, edge_index)))
        x = self.fc(x)
        # Average the per-node logits over the node axis -> one logit for the graph.
        x = torch.mean(x, dim=0, keepdim=True)
        return torch.sigmoid(x)

In [3]:
# Run configuration. `seed` also tags the TensorBoard run and the checkpoint file name.
dataset_name = 'gpt3.5'
seed = 2024

# Seed the CPU generator, then the current CUDA device, then every CUDA device.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Model / optimisation hyper-parameters.
input_dim = 768    # node-feature width; must equal the last dim of each `all_token_embeddings` entry
hidden_dim = 512   # first hidden width (used only by the GCN4 variant)
hidden_dim2 = 256  # hidden width of GCN2 (second hidden width of GCN4)
hidden_dim3 = 128  # third hidden width (used only by the GCN4 variant)
output_dim = 64    # width of the last GCN layer fed to the linear head -- NOT a class count (the head emits 1 logit)
learning_rate = 0.0001

gcnmodel = GCN2(input_dim, hidden_dim2, output_dim).to(device)
# Deeper alternative. Like GCN2 it must be moved to `device`, because every graph
# is moved there in the training loop. If you switch, also rebuild GCN4 in the test cell.
# gcnmodel = GCN4(input_dim, hidden_dim, hidden_dim2, hidden_dim3, output_dim).to(device)
optimizer = optim.Adam(gcnmodel.parameters(), lr=learning_rate)
# The models already end in a sigmoid, so plain BCELoss (not BCEWithLogitsLoss) is the matching criterion.
criterion = nn.BCELoss()

In [5]:
import os

train_len = len(hc3_train['y'])
val_len = len(hc3_val['y'])
epochs = 40
train_loss = []
val_loss = []
train_acc = []
val_acc = []
val_max_acc = -1
best_model_path = f'./model/{dataset_name}_gcn_model_{seed}.pth'
# torch.save does not create parent directories; make sure ./model exists on a fresh checkout.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)


def _run_epoch(split, split_len, epoch, train):
    """Run one pass over ``split`` and return ``(mean_loss, accuracy)``.

    ``split`` is one of the pickled dicts with keys ``all_token_embeddings``,
    ``all_edge_index`` and ``y``, consumed one graph at a time. With
    ``train=True`` the model is in train mode and the optimizer steps after
    every graph (batch size 1). Otherwise the model is in eval mode and
    gradients are disabled.
    """
    if split_len == 0:
        raise ValueError("cannot run an epoch over an empty split")
    gcnmodel.train(train)  # train(False) is equivalent to eval()
    stage = "Training" if train else "Validation"
    total_loss = 0.0
    n_correct = 0
    with torch.set_grad_enabled(train):
        for i in tqdm(range(split_len), f"epoch: {epoch+1}, {stage}"):
            data = Data(x=split['all_token_embeddings'][i], edge_index=split['all_edge_index'][i], y=split['y'][i]).to(device)
            if train:
                optimizer.zero_grad()
            outputs = gcnmodel(data)
            loss = criterion(outputs, data.y.float().view(-1, 1))
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            predictions = (outputs >= 0.5).long()
            n_correct += (predictions == data.y.view(-1, 1)).sum().item()
    return total_loss / split_len, n_correct / split_len


# Separate seed and timestamp with "_" so the seed stays readable in the run name.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(f'logs/{dataset_name}_{seed}_{run_stamp}')
start_time = time.time()
try:
    for epoch in range(epochs):
        # Training split
        epoch_loss, epoch_acc = _run_epoch(hc3_train, train_len, epoch, train=True)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('Acc/train', epoch_acc, epoch)
        print(f"epoch: {epoch+1}, train_loss: {epoch_loss}, train_acc: {epoch_acc}")
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        # Validation split
        epoch_loss, epoch_acc = _run_epoch(hc3_val, val_len, epoch, train=False)
        writer.add_scalar('Loss/val', epoch_loss, epoch)
        writer.add_scalar('Acc/val', epoch_acc, epoch)
        print(f"epoch: {epoch+1}, val_loss: {epoch_loss}, val_acc: {epoch_acc}")
        val_loss.append(epoch_loss)
        val_acc.append(epoch_acc)

        # `>=` keeps the latest epoch among ties for the best validation accuracy.
        if epoch_acc >= val_max_acc:
            val_max_acc = epoch_acc
            torch.save(gcnmodel.state_dict(), best_model_path)
finally:
    # Flush and close the TensorBoard writer even if training is interrupted.
    writer.close()
end_time = time.time()
elapsed_time = end_time - start_time
print(f"运行时间: {elapsed_time} 秒")

epoch: 1, Training: 100%|██████████| 6000/6000 [00:14<00:00, 416.94it/s]


epoch: 1, train_loss: 0.3317128727087111, train_acc: 0.8598333333333333


epoch: 1, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 924.74it/s]


epoch: 1, val_loss: 0.2067017246020332, val_acc: 0.921


epoch: 2, Training: 100%|██████████| 6000/6000 [00:14<00:00, 412.27it/s]


epoch: 2, train_loss: 0.1843417204907245, train_acc: 0.9253333333333333


epoch: 2, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 906.81it/s]


epoch: 2, val_loss: 0.18241078009221018, val_acc: 0.928


epoch: 3, Training: 100%|██████████| 6000/6000 [00:13<00:00, 429.12it/s]


epoch: 3, train_loss: 0.15727009968686612, train_acc: 0.9363333333333334


epoch: 3, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 921.13it/s]


epoch: 3, val_loss: 0.16768582736365767, val_acc: 0.934


epoch: 4, Training: 100%|██████████| 6000/6000 [00:14<00:00, 405.53it/s]


epoch: 4, train_loss: 0.1384437370926044, train_acc: 0.9451666666666667


epoch: 4, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 887.14it/s]


epoch: 4, val_loss: 0.16682239029288914, val_acc: 0.933


epoch: 5, Training: 100%|██████████| 6000/6000 [00:14<00:00, 422.71it/s]


epoch: 5, train_loss: 0.12509411526314906, train_acc: 0.952


epoch: 5, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 916.87it/s]


epoch: 5, val_loss: 0.1523254279697597, val_acc: 0.942


epoch: 6, Training: 100%|██████████| 6000/6000 [00:14<00:00, 402.81it/s]


epoch: 6, train_loss: 0.11428218914103852, train_acc: 0.9558333333333333


epoch: 6, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 890.36it/s]


epoch: 6, val_loss: 0.15758468784888546, val_acc: 0.939


epoch: 7, Training: 100%|██████████| 6000/6000 [00:14<00:00, 410.59it/s]


epoch: 7, train_loss: 0.10549680555692294, train_acc: 0.9611666666666666


epoch: 7, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 924.43it/s]


epoch: 7, val_loss: 0.15254263276795882, val_acc: 0.939


epoch: 8, Training: 100%|██████████| 6000/6000 [00:13<00:00, 429.74it/s]


epoch: 8, train_loss: 0.09750379501096498, train_acc: 0.965


epoch: 8, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 922.86it/s]


epoch: 8, val_loss: 0.15094361955850377, val_acc: 0.941


epoch: 9, Training: 100%|██████████| 6000/6000 [00:15<00:00, 398.37it/s]


epoch: 9, train_loss: 0.09012403258761757, train_acc: 0.9665


epoch: 9, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 920.07it/s]


epoch: 9, val_loss: 0.15831954561636563, val_acc: 0.938


epoch: 10, Training: 100%|██████████| 6000/6000 [00:15<00:00, 393.93it/s]


epoch: 10, train_loss: 0.08307746245707948, train_acc: 0.9688333333333333


epoch: 10, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 872.98it/s]


epoch: 10, val_loss: 0.14659493730972095, val_acc: 0.942


epoch: 11, Training: 100%|██████████| 6000/6000 [00:14<00:00, 410.90it/s]


epoch: 11, train_loss: 0.07748806906158943, train_acc: 0.9721666666666666


epoch: 11, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 923.53it/s]


epoch: 11, val_loss: 0.17975476404859864, val_acc: 0.931


epoch: 12, Training: 100%|██████████| 6000/6000 [00:15<00:00, 392.00it/s]


epoch: 12, train_loss: 0.07281944209764982, train_acc: 0.9726666666666667


epoch: 12, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 912.95it/s]


epoch: 12, val_loss: 0.15889206210903575, val_acc: 0.937


epoch: 13, Training: 100%|██████████| 6000/6000 [00:15<00:00, 398.94it/s]


epoch: 13, train_loss: 0.0664739165856706, train_acc: 0.9763333333333334


epoch: 13, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 921.37it/s]


epoch: 13, val_loss: 0.14615323005172387, val_acc: 0.945


epoch: 14, Training: 100%|██████████| 6000/6000 [00:14<00:00, 412.75it/s]


epoch: 14, train_loss: 0.06352481853373967, train_acc: 0.9783333333333334


epoch: 14, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 919.08it/s]


epoch: 14, val_loss: 0.16286972736397218, val_acc: 0.938


epoch: 15, Training: 100%|██████████| 6000/6000 [00:14<00:00, 419.02it/s]


epoch: 15, train_loss: 0.05805639920717336, train_acc: 0.9815


epoch: 15, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 898.99it/s]


epoch: 15, val_loss: 0.149984117932514, val_acc: 0.945


epoch: 16, Training: 100%|██████████| 6000/6000 [00:14<00:00, 414.87it/s]


epoch: 16, train_loss: 0.053043368824022524, train_acc: 0.9818333333333333


epoch: 16, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 915.91it/s]


epoch: 16, val_loss: 0.16111757571320692, val_acc: 0.941


epoch: 17, Training: 100%|██████████| 6000/6000 [00:14<00:00, 403.62it/s]


epoch: 17, train_loss: 0.051166591188321356, train_acc: 0.9833333333333333


epoch: 17, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 888.92it/s]


epoch: 17, val_loss: 0.1575117795466568, val_acc: 0.945


epoch: 18, Training: 100%|██████████| 6000/6000 [00:14<00:00, 410.16it/s]


epoch: 18, train_loss: 0.047487306599591834, train_acc: 0.9833333333333333


epoch: 18, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 933.75it/s]


epoch: 18, val_loss: 0.16883645765450095, val_acc: 0.941


epoch: 19, Training: 100%|██████████| 6000/6000 [00:14<00:00, 428.08it/s]


epoch: 19, train_loss: 0.045046522942487704, train_acc: 0.9841666666666666


epoch: 19, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 832.13it/s]


epoch: 19, val_loss: 0.1766342666039497, val_acc: 0.939


epoch: 20, Training: 100%|██████████| 6000/6000 [00:15<00:00, 394.36it/s]


epoch: 20, train_loss: 0.0416606475286322, train_acc: 0.9855


epoch: 20, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 925.71it/s]


epoch: 20, val_loss: 0.13751791900050062, val_acc: 0.954


epoch: 21, Training: 100%|██████████| 6000/6000 [00:15<00:00, 396.00it/s]


epoch: 21, train_loss: 0.03881185760529716, train_acc: 0.987


epoch: 21, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 923.81it/s]


epoch: 21, val_loss: 0.15233492611979188, val_acc: 0.949


epoch: 22, Training: 100%|██████████| 6000/6000 [00:13<00:00, 432.84it/s]


epoch: 22, train_loss: 0.03503070250885712, train_acc: 0.9883333333333333


epoch: 22, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 922.88it/s]


epoch: 22, val_loss: 0.13542438574990007, val_acc: 0.954


epoch: 23, Training: 100%|██████████| 6000/6000 [00:14<00:00, 403.39it/s]


epoch: 23, train_loss: 0.033535517029938344, train_acc: 0.9903333333333333


epoch: 23, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 920.38it/s]


epoch: 23, val_loss: 0.0950905996567799, val_acc: 0.969


epoch: 24, Training: 100%|██████████| 6000/6000 [00:13<00:00, 453.99it/s]


epoch: 24, train_loss: 0.030157196167234238, train_acc: 0.9903333333333333


epoch: 24, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 921.91it/s]


epoch: 24, val_loss: 0.15087211898116715, val_acc: 0.952


epoch: 25, Training: 100%|██████████| 6000/6000 [00:14<00:00, 401.61it/s]


epoch: 25, train_loss: 0.02874738203812067, train_acc: 0.9905


epoch: 25, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 906.12it/s]


epoch: 25, val_loss: 0.0842055003118576, val_acc: 0.969


epoch: 26, Training: 100%|██████████| 6000/6000 [00:15<00:00, 397.65it/s]


epoch: 26, train_loss: 0.02861915794432377, train_acc: 0.9913333333333333


epoch: 26, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 879.49it/s]


epoch: 26, val_loss: 0.0782917860403326, val_acc: 0.973


epoch: 27, Training: 100%|██████████| 6000/6000 [00:14<00:00, 414.95it/s]


epoch: 27, train_loss: 0.025093607591062394, train_acc: 0.9913333333333333


epoch: 27, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 927.40it/s]


epoch: 27, val_loss: 0.1523489919681401, val_acc: 0.953


epoch: 28, Training: 100%|██████████| 6000/6000 [00:14<00:00, 416.39it/s]


epoch: 28, train_loss: 0.025584121729745044, train_acc: 0.9918333333333333


epoch: 28, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 919.54it/s]


epoch: 28, val_loss: 0.07992366907148862, val_acc: 0.972


epoch: 29, Training: 100%|██████████| 6000/6000 [00:14<00:00, 405.20it/s]


epoch: 29, train_loss: 0.02097310537332482, train_acc: 0.9938333333333333


epoch: 29, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 920.05it/s]


epoch: 29, val_loss: 0.07927464390489207, val_acc: 0.975


epoch: 30, Training: 100%|██████████| 6000/6000 [00:14<00:00, 411.05it/s]


epoch: 30, train_loss: 0.023814305320411136, train_acc: 0.9928333333333333


epoch: 30, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 887.14it/s]


epoch: 30, val_loss: 0.1588398212659016, val_acc: 0.958


epoch: 31, Training: 100%|██████████| 6000/6000 [00:13<00:00, 433.53it/s]


epoch: 31, train_loss: 0.020727348430683263, train_acc: 0.9935


epoch: 31, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 885.09it/s]


epoch: 31, val_loss: 0.3932828428197518, val_acc: 0.923


epoch: 32, Training: 100%|██████████| 6000/6000 [00:14<00:00, 411.42it/s]


epoch: 32, train_loss: 0.022001310543557886, train_acc: 0.9936666666666667


epoch: 32, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 914.89it/s]


epoch: 32, val_loss: 0.07414237111204434, val_acc: 0.978


epoch: 33, Training: 100%|██████████| 6000/6000 [00:14<00:00, 409.88it/s]


epoch: 33, train_loss: 0.02004309225221597, train_acc: 0.9935


epoch: 33, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 920.15it/s]


epoch: 33, val_loss: 0.3481888787246085, val_acc: 0.929


epoch: 34, Training: 100%|██████████| 6000/6000 [00:14<00:00, 420.28it/s]


epoch: 34, train_loss: 0.019030178527009223, train_acc: 0.995


epoch: 34, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 917.21it/s]


epoch: 34, val_loss: 0.1063790812383599, val_acc: 0.969


epoch: 35, Training: 100%|██████████| 6000/6000 [00:15<00:00, 395.61it/s]


epoch: 35, train_loss: 0.018243067838416307, train_acc: 0.994


epoch: 35, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 921.53it/s]


epoch: 35, val_loss: 0.07613602158440726, val_acc: 0.98


epoch: 36, Training: 100%|██████████| 6000/6000 [00:15<00:00, 400.00it/s]


epoch: 36, train_loss: 0.014444600749182373, train_acc: 0.996


epoch: 36, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 924.42it/s]


epoch: 36, val_loss: 0.18659135045284236, val_acc: 0.972


epoch: 37, Training: 100%|██████████| 6000/6000 [00:15<00:00, 397.19it/s]


epoch: 37, train_loss: 0.014151426656140859, train_acc: 0.996


epoch: 37, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 916.58it/s]


epoch: 37, val_loss: 0.09375198996335539, val_acc: 0.974


epoch: 38, Training: 100%|██████████| 6000/6000 [00:13<00:00, 430.36it/s]


epoch: 38, train_loss: 0.015302775388887954, train_acc: 0.9956666666666667


epoch: 38, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 918.98it/s]


epoch: 38, val_loss: 0.0847155271675315, val_acc: 0.975


epoch: 39, Training: 100%|██████████| 6000/6000 [00:14<00:00, 412.56it/s]


epoch: 39, train_loss: 0.015895210690359783, train_acc: 0.9948333333333333


epoch: 39, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 901.34it/s]


epoch: 39, val_loss: 0.1876065002002825, val_acc: 0.974


epoch: 40, Training: 100%|██████████| 6000/6000 [00:14<00:00, 426.84it/s]


epoch: 40, train_loss: 0.014893286380229147, train_acc: 0.996


epoch: 40, Validation: 100%|██████████| 1000/1000 [00:01<00:00, 931.47it/s]

epoch: 40, val_loss: 0.11020838578334968, val_acc: 0.973
运行时间: 627.5200390815735 秒





#### 测试集评估 (Test evaluation)

In [13]:
import pickle

# Hold-out split evaluated below; its name is also written to the results file.
test_file = "matched_gpt3.5_mixed_test_split"
test_path = f"/root/autodl-tmp/graph_data/{test_file}.pkl"
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
with open(test_path, "rb") as handle:
    hc3_test = pickle.load(handle)
test_len = len(hc3_test['y'])

In [14]:
from sklearn.metrics import roc_auc_score, f1_score

# Rebuild the architecture trained above and load the best-validation checkpoint.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
test_gcnmodel = GCN2(input_dim, hidden_dim2, output_dim).to(device)
test_gcnmodel.load_state_dict(torch.load(f'./model/{dataset_name}_gcn_model_{seed}.pth', map_location=device))
test_gcnmodel.eval()
test_loss = 0.0
correct_predictions = 0
test_pres = list()  # per-graph probabilities; also consumed by the AUC cell
start_time = time.time()
with torch.no_grad():
    for i in tqdm(range(test_len), "Test"):
        data = Data(x=hc3_test['all_token_embeddings'][i], edge_index=hc3_test['all_edge_index'][i], y=hc3_test['y'][i]).to(device)
        outputs = test_gcnmodel(data)
        test_pres.append(outputs.item())
        loss = criterion(outputs, data.y.float().view(-1, 1))
        test_loss += loss.item()
        predictions = (outputs >= 0.5).long()
        correct_predictions += (predictions == data.y.view(-1, 1)).sum().item()
end_time = time.time()
elapsed_time = end_time - start_time
print(f"运行时间: {elapsed_time} 秒")
y_pred = [1 if prob >= 0.5 else 0 for prob in test_pres]
# sklearn expects a flat 1-D label array. .tolist() also works when the tensor
# lives on a GPU, where numpy conversion of a CUDA tensor would fail.
y_true = hc3_test['y'].view(-1).tolist()
test_loss /= test_len
test_acc = correct_predictions / test_len
test_f1 = f1_score(y_true, y_pred)
print(f"test_loss: {test_loss}, test_acc: {test_acc}, test_f1: {test_f1}")

Test: 100%|██████████| 1000/1000 [00:01<00:00, 898.84it/s]

运行时间: 1.115060567855835 秒
test_loss: 0.24888627747510203, test_acc: 0.954, test_f1: 0.9546351084812622





In [15]:
# ROC-AUC of the raw test probabilities collected in the evaluation cell.
auc = roc_auc_score(y_true=hc3_test['y'], y_score=test_pres)
auc

0.988476

In [16]:
import os

# Append one tab-separated summary line per evaluation run.
result_path = "/root/autodl-tmp/result/test_result.txt"
# open(..., "a") creates the file but not its parent directory.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, "a", encoding="utf-8") as w:
    w.write(f"{test_file}\t acc: {test_acc}\t auc: {auc}\t f1: {test_f1}\t seed: {seed}\t{datetime.now()}\n")