#### 加载数据

transformers==4.40.0

torch==1.11.0+cu113

torch-geometric==2.5.2

torch-scatter==2.0.9

torch-sparse==0.6.13

In [6]:
import pickle
# Read the pickled training and validation graph datasets into memory.
# NOTE(review): pickle can execute arbitrary code while loading -- only open
# files that come from a trusted source.
with open("/root/autodl-tmp/graph_data/mix_train.pkl", mode="rb") as train_file:
    hc3_train = pickle.load(train_file)

with open("/root/autodl-tmp/graph_data/mix_val.pkl", mode="rb") as val_file:
    hc3_val = pickle.load(val_file)

#### 训练模型

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from datetime import datetime
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import tqdm

# Build the GCN models.
class GCN2(nn.Module):
    """Two-layer GCN that produces a single sigmoid score per input graph.

    Args:
        input_dim: Feature width consumed by the first graph convolution.
        hidden_dim: Output width of the first graph convolution.
        output_dim: Output width of the second graph convolution; this is
            what the final linear layer maps down to one value.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.fc = nn.Linear(output_dim, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Score one graph.

        Args:
            data: Object exposing ``x`` and ``edge_index``.
                Assumes ``x`` is (num_nodes, input_dim) -- TODO confirm.

        Returns:
            Sigmoid of the per-row linear outputs averaged over dim 0
            (``keepdim=True``, so the leading axis has length 1).
        """
        features = data.x
        edges = data.edge_index

        # Each stage is convolution -> ReLU -> dropout.
        for conv in (self.conv1, self.conv2):
            features = self.dropout(F.relu(conv(features, edges)))

        # Project every row to a scalar first, then average the scalars.
        row_scores = self.fc(features)
        pooled = row_scores.mean(dim=0, keepdim=True)
        return torch.sigmoid(pooled)

class GCN4(nn.Module):
    """Four-layer GCN that produces a single sigmoid score per input graph.

    Args:
        input_dim: Feature width consumed by the first graph convolution.
        hidden_dim: Output width of the first convolution.
        hidden_dim2: Output width of the second convolution.
        hidden_dim3: Output width of the third convolution.
        output_dim: Output width of the fourth convolution; this is what the
            final linear layer maps down to one value.
    """

    def __init__(self, input_dim, hidden_dim, hidden_dim2, hidden_dim3, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim2)
        self.conv3 = GCNConv(hidden_dim2, hidden_dim3)
        self.conv4 = GCNConv(hidden_dim3, output_dim)
        self.fc = nn.Linear(output_dim, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Score one graph.

        Args:
            data: Object exposing ``x`` and ``edge_index``.
                Assumes ``x`` is (num_nodes, input_dim) -- TODO confirm.

        Returns:
            Sigmoid of the per-row linear outputs averaged over dim 0
            (``keepdim=True``, so the leading axis has length 1).
        """
        features = data.x
        edges = data.edge_index

        # Four identical stages: convolution -> ReLU -> dropout.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.dropout(F.relu(conv(features, edges)))

        # Project every row to a scalar first, then average the scalars.
        row_scores = self.fc(features)
        pooled = row_scores.mean(dim=0, keepdim=True)
        return torch.sigmoid(pooled)


class GCNMulticlass(nn.Module):
    """Two-layer GCN that classifies a whole graph into ``num_classes`` classes.

    Args:
        input_dim: Feature width consumed by the first graph convolution.
        hidden_dim1: Output width of the first convolution.
        hidden_dim2: Output width of the second convolution; this is the
            input width of the classification layer.
        num_classes: Number of values produced by the classification layer.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, num_classes):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim1)
        self.conv2 = GCNConv(hidden_dim1, hidden_dim2)
        self.fc = nn.Linear(hidden_dim2, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, data):
        """Classify one graph.

        Args:
            data: Object exposing ``x`` and ``edge_index``.
                Assumes ``x`` is (num_nodes, input_dim) -- TODO confirm.

        Returns:
            Log-softmax over dim 1 of the classification layer's output.
        """
        features = data.x
        edges = data.edge_index

        # Each stage is convolution -> ReLU -> dropout.
        for conv in (self.conv1, self.conv2):
            features = self.dropout(F.relu(conv(features, edges)))

        # Unlike the sigmoid models, the average over dim 0 is taken BEFORE
        # the linear layer, so the classifier sees one pooled feature row.
        pooled = features.mean(dim=0, keepdim=True)
        logits = self.fc(pooled)
        return F.log_softmax(logits, dim=1)

In [8]:
# Seed every PyTorch random generator with the same value.
seed = 2024
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Use the GPU when CUDA reports one as available, otherwise the CPU.
# NOTE(review): nothing in the cells visible here references `device` -- confirm
# whether the model and data are meant to be moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# Layer widths shared by the model constructors.
input_dim = 768  # input feature dimension
hidden_dim = 512  # hidden layer width (only used by the GCN4 alternative below)
hidden_dim2 = 256  # hidden layer width
hidden_dim3 = 128  # hidden layer width (only used by the GCN4 alternative below)
output_dim = 64  # width of the last graph convolution (not the class count)

# Alternative architectures -- uncomment one to swap the model:
# gcnmodel = GCN2(input_dim, hidden_dim2, output_dim)
# gcnmodel = GCN4(input_dim, hidden_dim, hidden_dim2, hidden_dim3, output_dim)
gcnmodel = GCNMulticlass(
    input_dim=input_dim,
    hidden_dim1=hidden_dim2,
    hidden_dim2=output_dim,
    num_classes=3,
)
optimizer = optim.Adam(params=gcnmodel.parameters(), lr=0.0001)

# Two losses are kept around: BCE for the sigmoid models (GCN2 / GCN4) and
# cross-entropy for the multi-class model.
criterion = nn.BCELoss()
CE_criterion = nn.CrossEntropyLoss()

In [10]:
import os

# Dataset sizes. Every sample is a single graph that is fed through the model
# on its own, with one optimizer step per graph.
train_len = len(hc3_train['y'])
val_len = len(hc3_val['y'])
epochs = 30

# Per-epoch history, kept in memory alongside the TensorBoard scalars.
train_loss = []
val_loss = []
train_acc = []
val_acc = []
val_max_acc = -1

# Create the checkpoint directory up front: without it the first torch.save
# would raise only after a whole epoch of training had already been spent.
checkpoint_path = './model/mix_gcn_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

writer = SummaryWriter('logs/mix'+ datetime.now().strftime("%Y%m%d-%H%M%S"))
try:
    for epoch in range(epochs):
        # ---- Training pass ----
        gcnmodel.train()
        epoch_loss = 0.0
        correct_predictions = 0
        for i in tqdm(range(train_len),  f"epoch: {epoch+1}, Training"):
            data = Data(x=hc3_train['all_token_embeddings'][i], edge_index=hc3_train['all_edge_index'][i], y=hc3_train['y'][i])
            optimizer.zero_grad()
            outputs = gcnmodel(data)
            # For the sigmoid models (GCN2 / GCN4) use instead:
            # loss = criterion(outputs, data.y.float().view(-1, 1))
            # CrossEntropyLoss wants a 1-D tensor of class indices.
            target = data.y.long().view(-1)
            loss = CE_criterion(outputs, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            _, predictions = torch.max(outputs, 1)
            correct_predictions += (predictions == target).sum().item()
        epoch_loss /= train_len
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        epoch_acc = correct_predictions / train_len
        writer.add_scalar('Acc/train', epoch_acc, epoch)
        print(f"epoch: {epoch+1}, train_loss: {epoch_loss}, train_acc: {epoch_acc}")
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        # ---- Validation pass (no gradients, dropout disabled via eval()) ----
        gcnmodel.eval()
        epoch_loss = 0.0
        correct_predictions = 0
        all_predictions = []
        with torch.no_grad():
            for i in tqdm(range(val_len),  f"epoch: {epoch+1}, Validation"):
                data = Data(x=hc3_val['all_token_embeddings'][i], edge_index=hc3_val['all_edge_index'][i], y=hc3_val['y'][i])
                outputs = gcnmodel(data)
                # For the sigmoid models (GCN2 / GCN4) use instead:
                # loss = criterion(outputs, data.y.float().view(-1, 1))
                target = data.y.long().view(-1)
                loss = CE_criterion(outputs, target)
                epoch_loss += loss.item()
                _, predictions = torch.max(outputs, 1)
                correct_predictions += (predictions == target).sum().item()
                all_predictions.append(predictions)
        epoch_loss /= val_len
        writer.add_scalar('Loss/val', epoch_loss, epoch)
        epoch_acc = correct_predictions / val_len
        writer.add_scalar('Acc/val', epoch_acc, epoch)
        print(f"epoch: {epoch+1}, val_loss: {epoch_loss}, val_acc: {epoch_acc}")
        val_loss.append(epoch_loss)
        val_acc.append(epoch_acc)

        # Keep the weights of the best validation accuracy seen so far; on a
        # tie the later epoch overwrites the earlier checkpoint.
        if epoch_acc >= val_max_acc:
            val_max_acc = epoch_acc
            torch.save(gcnmodel.state_dict(), checkpoint_path)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

epoch: 1, Training: 100%|██████████| 4800/4800 [00:33<00:00, 143.48it/s]


epoch: 1, train_loss: 0.2994169792495111, train_acc: 0.8908333333333334


epoch: 1, Validation: 100%|██████████| 750/750 [00:01<00:00, 571.13it/s]


epoch: 1, val_loss: 0.21089619938431134, val_acc: 0.9413333333333334


epoch: 2, Training: 100%|██████████| 4800/4800 [00:33<00:00, 141.71it/s]


epoch: 2, train_loss: 0.14478376848672023, train_acc: 0.9485416666666666


epoch: 2, Validation: 100%|██████████| 750/750 [00:01<00:00, 560.47it/s]


epoch: 2, val_loss: 0.18525216186672272, val_acc: 0.9466666666666667


epoch: 3, Training: 100%|██████████| 4800/4800 [00:34<00:00, 139.96it/s]


epoch: 3, train_loss: 0.11689543746787791, train_acc: 0.958125


epoch: 3, Validation: 100%|██████████| 750/750 [00:01<00:00, 565.17it/s]


epoch: 3, val_loss: 0.16888778464727375, val_acc: 0.948


epoch: 4, Training: 100%|██████████| 4800/4800 [00:34<00:00, 139.76it/s]


epoch: 4, train_loss: 0.10146218700105641, train_acc: 0.9639583333333334


epoch: 4, Validation: 100%|██████████| 750/750 [00:01<00:00, 563.73it/s]


epoch: 4, val_loss: 0.16654337128551713, val_acc: 0.948


epoch: 5, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.45it/s]


epoch: 5, train_loss: 0.08698971800438018, train_acc: 0.9677083333333333


epoch: 5, Validation: 100%|██████████| 750/750 [00:01<00:00, 572.98it/s]


epoch: 5, val_loss: 0.14555561547953136, val_acc: 0.952


epoch: 6, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.14it/s]


epoch: 6, train_loss: 0.07793093202659755, train_acc: 0.9735416666666666


epoch: 6, Validation: 100%|██████████| 750/750 [00:01<00:00, 567.42it/s]


epoch: 6, val_loss: 0.14495961527808648, val_acc: 0.948


epoch: 7, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.65it/s]


epoch: 7, train_loss: 0.06884527367170334, train_acc: 0.97625


epoch: 7, Validation: 100%|██████████| 750/750 [00:01<00:00, 576.68it/s]


epoch: 7, val_loss: 0.12429745967857116, val_acc: 0.956


epoch: 8, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.19it/s]


epoch: 8, train_loss: 0.06275307984279067, train_acc: 0.979375


epoch: 8, Validation: 100%|██████████| 750/750 [00:01<00:00, 583.42it/s]


epoch: 8, val_loss: 0.116917938007285, val_acc: 0.956


epoch: 9, Training: 100%|██████████| 4800/4800 [00:34<00:00, 139.97it/s]


epoch: 9, train_loss: 0.055482873643700234, train_acc: 0.98125


epoch: 9, Validation: 100%|██████████| 750/750 [00:01<00:00, 567.85it/s]


epoch: 9, val_loss: 0.11710165786956918, val_acc: 0.9573333333333334


epoch: 10, Training: 100%|██████████| 4800/4800 [00:33<00:00, 141.46it/s]


epoch: 10, train_loss: 0.05035135756555907, train_acc: 0.9839583333333334


epoch: 10, Validation: 100%|██████████| 750/750 [00:01<00:00, 574.70it/s]


epoch: 10, val_loss: 0.1375429515156075, val_acc: 0.9533333333333334


epoch: 11, Training: 100%|██████████| 4800/4800 [00:33<00:00, 142.96it/s]


epoch: 11, train_loss: 0.04600701296759033, train_acc: 0.9852083333333334


epoch: 11, Validation: 100%|██████████| 750/750 [00:01<00:00, 570.16it/s]


epoch: 11, val_loss: 0.10374065902018265, val_acc: 0.9626666666666667


epoch: 12, Training: 100%|██████████| 4800/4800 [00:33<00:00, 142.26it/s]


epoch: 12, train_loss: 0.04007588800546101, train_acc: 0.9877083333333333


epoch: 12, Validation: 100%|██████████| 750/750 [00:01<00:00, 580.63it/s]


epoch: 12, val_loss: 0.11837128373757734, val_acc: 0.9586666666666667


epoch: 13, Training: 100%|██████████| 4800/4800 [00:33<00:00, 142.87it/s]


epoch: 13, train_loss: 0.03673444270114696, train_acc: 0.9889583333333334


epoch: 13, Validation: 100%|██████████| 750/750 [00:01<00:00, 581.14it/s]


epoch: 13, val_loss: 0.10449965068572074, val_acc: 0.9653333333333334


epoch: 14, Training: 100%|██████████| 4800/4800 [00:33<00:00, 142.12it/s]


epoch: 14, train_loss: 0.032505197579229775, train_acc: 0.9897916666666666


epoch: 14, Validation: 100%|██████████| 750/750 [00:01<00:00, 582.40it/s]


epoch: 14, val_loss: 0.13182477585787794, val_acc: 0.9586666666666667


epoch: 15, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.43it/s]


epoch: 15, train_loss: 0.030274932968942694, train_acc: 0.9910416666666667


epoch: 15, Validation: 100%|██████████| 750/750 [00:01<00:00, 577.45it/s]


epoch: 15, val_loss: 0.1396700254884724, val_acc: 0.9586666666666667


epoch: 16, Training: 100%|██████████| 4800/4800 [00:34<00:00, 137.65it/s]


epoch: 16, train_loss: 0.025538020256886167, train_acc: 0.9933333333333333


epoch: 16, Validation: 100%|██████████| 750/750 [00:01<00:00, 567.93it/s]


epoch: 16, val_loss: 0.13760220393896042, val_acc: 0.96


epoch: 17, Training: 100%|██████████| 4800/4800 [00:34<00:00, 139.83it/s]


epoch: 17, train_loss: 0.023228963741200072, train_acc: 0.9927083333333333


epoch: 17, Validation: 100%|██████████| 750/750 [00:01<00:00, 575.17it/s]


epoch: 17, val_loss: 0.14640885205289283, val_acc: 0.9573333333333334


epoch: 18, Training: 100%|██████████| 4800/4800 [00:33<00:00, 141.80it/s]


epoch: 18, train_loss: 0.021430755768538663, train_acc: 0.993125


epoch: 18, Validation: 100%|██████████| 750/750 [00:01<00:00, 581.16it/s]


epoch: 18, val_loss: 0.1596042228017097, val_acc: 0.956


epoch: 19, Training: 100%|██████████| 4800/4800 [00:33<00:00, 141.27it/s]


epoch: 19, train_loss: 0.01835657666195522, train_acc: 0.9945833333333334


epoch: 19, Validation: 100%|██████████| 750/750 [00:01<00:00, 573.00it/s]


epoch: 19, val_loss: 0.16002887532703008, val_acc: 0.9613333333333334


epoch: 20, Training: 100%|██████████| 4800/4800 [00:33<00:00, 141.44it/s]


epoch: 20, train_loss: 0.016151697234945896, train_acc: 0.9947916666666666


epoch: 20, Validation: 100%|██████████| 750/750 [00:01<00:00, 581.20it/s]


epoch: 20, val_loss: 0.1375656074130234, val_acc: 0.9653333333333334


epoch: 21, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.61it/s]


epoch: 21, train_loss: 0.016278241133570397, train_acc: 0.99375


epoch: 21, Validation: 100%|██████████| 750/750 [00:01<00:00, 548.44it/s]


epoch: 21, val_loss: 0.12109982117990553, val_acc: 0.968


epoch: 24, Training: 100%|██████████| 4800/4800 [00:33<00:00, 141.41it/s]


epoch: 24, train_loss: 0.012350959743515565, train_acc: 0.995625


epoch: 24, Validation: 100%|██████████| 750/750 [00:01<00:00, 576.37it/s]


epoch: 24, val_loss: 0.2090264196116416, val_acc: 0.952


epoch: 25, Training: 100%|██████████| 4800/4800 [00:34<00:00, 141.15it/s]


epoch: 25, train_loss: 0.009624322594257712, train_acc: 0.9975


epoch: 25, Validation: 100%|██████████| 750/750 [00:01<00:00, 568.67it/s]


epoch: 25, val_loss: 0.2581162045689492, val_acc: 0.9453333333333334


epoch: 26, Training: 100%|██████████| 4800/4800 [00:34<00:00, 141.10it/s]


epoch: 26, train_loss: 0.007993334995681823, train_acc: 0.998125


epoch: 26, Validation: 100%|██████████| 750/750 [00:01<00:00, 572.81it/s]


epoch: 26, val_loss: 0.13349657668306006, val_acc: 0.972


epoch: 27, Training: 100%|██████████| 4800/4800 [00:33<00:00, 142.18it/s]


epoch: 27, train_loss: 0.00984551222535738, train_acc: 0.996875


epoch: 27, Validation: 100%|██████████| 750/750 [00:01<00:00, 573.66it/s]


epoch: 27, val_loss: 0.15826526121184106, val_acc: 0.9586666666666667


epoch: 28, Training: 100%|██████████| 4800/4800 [00:34<00:00, 141.11it/s]


epoch: 28, train_loss: 0.007226035031706587, train_acc: 0.998125


epoch: 28, Validation: 100%|██████████| 750/750 [00:01<00:00, 577.09it/s]


epoch: 28, val_loss: 0.1729710191350542, val_acc: 0.9626666666666667


epoch: 29, Training: 100%|██████████| 4800/4800 [00:34<00:00, 140.03it/s]


epoch: 29, train_loss: 0.008916154261791051, train_acc: 0.9975


epoch: 29, Validation: 100%|██████████| 750/750 [00:01<00:00, 573.79it/s]


epoch: 29, val_loss: 0.25695445893220825, val_acc: 0.9493333333333334


epoch: 30, Training: 100%|██████████| 4800/4800 [00:34<00:00, 139.70it/s]


epoch: 30, train_loss: 0.006575552438553746, train_acc: 0.9979166666666667


epoch: 30, Validation: 100%|██████████| 750/750 [00:01<00:00, 580.38it/s]

epoch: 30, val_loss: 0.18147696985641032, val_acc: 0.9613333333333334





#### 测试

In [12]:
import pickle
# Base name (no extension) of the pickled test split; it also labels the line
# written to the results file.
test_file = "mix_test"

# NOTE(review): pickle can execute arbitrary code while loading -- only open
# files that come from a trusted source.
with open(f"/root/autodl-tmp/graph_data/{test_file}.pkl", mode="rb") as test_handle:
    hc3_test = pickle.load(test_handle)

test_len = len(hc3_test['y'])

In [24]:
from sklearn.metrics import roc_auc_score, f1_score

# Rebuild the architecture and restore the checkpoint saved during training.
test_gcnmodel = GCNMulticlass(input_dim, hidden_dim2, output_dim, 3)
# test_gcnmodel = gcnmodel
saved_weights = torch.load('./model/mix_gcn_model.pth')
test_gcnmodel.load_state_dict(saved_weights)
test_gcnmodel.eval()

test_loss = 0.0
correct_predictions = 0
y_pred = []

# Evaluate one graph at a time with gradient tracking switched off.
with torch.no_grad():
    for idx in tqdm(range(test_len), "Test"):
        sample = Data(
            x=hc3_test['all_token_embeddings'][idx],
            edge_index=hc3_test['all_edge_index'][idx],
            y=hc3_test['y'][idx],
        )
        scores = test_gcnmodel(sample)
        labels = sample.y.long().view(-1, 1).squeeze(0)
        test_loss += CE_criterion(scores, labels).item()

        # Index of the highest score along dim 1 is the predicted class.
        predicted = torch.max(scores, 1).indices
        correct_predictions += (predicted == sample.y).sum().item()
        y_pred.append(predicted.item())

y_true = hc3_test['y']
test_loss = test_loss / test_len
test_acc = correct_predictions / test_len
test_f1 = f1_score(y_true, y_pred, average='weighted')
print(f"test_loss: {test_loss}, test_acc: {test_acc}, test_f1: {test_f1}")

Test: 100%|██████████| 750/750 [00:01<00:00, 570.34it/s]

test_loss: 0.12448984051934489, test_acc: 0.9693333333333334, test_f1: 0.9693918394675536





In [25]:
# Append one tab-separated summary line (split name, accuracy, F1, timestamp)
# to the shared results file.
with open("/root/autodl-tmp/result/test_result.txt", mode="a", encoding="utf-8") as result_file:
    summary_line = f"{test_file}\tacc: {test_acc}\tf1: {test_f1}\t{datetime.now()}\n"
    result_file.write(summary_line)