In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
import matplotlib.pyplot as plt

torch.__version__    # show the installed PyTorch version string
# torch.version.cuda     # show the CUDA version of the PyTorch build (GPU)

'2.2.2+cu121'

In [9]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GAE, GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import os
import random
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np

# Encoder network for the graph autoencoder
class GCN(torch.nn.Module):
    """Two-layer graph convolutional encoder.

    Args:
        num_node_features: size of each input node feature vector.
        hidden_channels: output width of both GCNConv layers.
    """

    def __init__(self, num_node_features, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2 (no final activation)."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Graph autoencoder built on top of the encoder
class GAEModel(GAE):
    """Thin GAE subclass: it only hands the given encoder to the base class."""

    def __init__(self, encoder):
        super().__init__(encoder)

# Data processing and loading
def resize_tensor(input_tensor, target_length):
    """Linearly resample the last axis of ``input_tensor`` to ``target_length``.

    A singleton axis is inserted at dim 1 before calling ``F.interpolate``
    (so a 2-D ``(rows, length)`` input is seen as ``(rows, 1, length)``) and
    removed again before returning.
    """
    with_channel = input_tensor[:, None]
    resampled = F.interpolate(
        with_channel,
        size=target_length,
        mode='linear',
        align_corners=False,
    )
    return resampled.squeeze(1)

# Data location and the common length every recording is resampled to.
file_folder = "/home/featurize/work/ylx/MEA/data"
all_data = []
target_length = 4570

# Read every recording in the data folder.
# os.listdir() returns entries in arbitrary, filesystem-dependent order, so the
# names are sorted here: otherwise the seeded shuffle applied to the resulting
# list would still yield a different train/test split from machine to machine.
for file_name in sorted(os.listdir(file_folder)):
    file_path = os.path.join(file_folder, file_name)
    # Skip sub-directories (e.g. Jupyter's .ipynb_checkpoints); they are not
    # recordings and cannot be parsed by read_csv.
    if not os.path.isfile(file_path):
        continue

    data_sample = {}
    # The second-to-last '_'-separated token of the file name is the condition;
    # 'baseline' maps to label 0, anything else to label 1.
    cls_mea = file_name.split('_')[-2]
    cls_lhh = 0 if cls_mea == 'baseline' else 1
    df = pd.read_csv(file_path)
    data_np = df.values
    data_tensor = torch.tensor(data_np, dtype=torch.float32)
    # Resample every row to target_length values.
    data_tensor = resize_tensor(data_tensor, target_length)

    # Sample name = file name up to the first '.'.
    data_name = file_name.split('.')[0]
    data_sample['data'] = data_tensor
    data_sample['label'] = cls_lhh
    data_sample['data_name'] = data_name
    all_data.append(data_sample)

# Model dimensions: a node's feature vector is one resampled row of a recording.
num_node_features = target_length
hidden_channels = 3000
num_graphs = len(all_data)
graphs = []

# Turn every sample into a chain graph: one node per row, edges k -> k + 1.
for graph_idx in range(num_graphs):
    graph = all_data[graph_idx]
    node_features = graph['data']
    num_nodes = int(node_features.shape[0])

    # Build the (2, num_nodes - 1) edge list directly as a long tensor.
    # The previous list-of-pairs construction reused the outer loop index `i`
    # and, for a sample with a single row, produced a 1-D float tensor instead
    # of an empty (2, 0) integer edge_index.
    sources = torch.arange(max(num_nodes - 1, 0), dtype=torch.long)
    edge_index = torch.stack([sources, sources + 1], dim=0)

    # One class label per graph.
    y = torch.tensor([graph['label']], dtype=torch.long)
    graph_data = Data(x=node_features, edge_index=edge_index, y=y)
    graphs.append(graph_data)

# Seed both RNGs used from here on: Python's `random` drives the shuffle below,
# while torch's global RNG drives weight initialisation and the order in which
# the shuffling DataLoader yields batches. Seeding only `random` left training
# non-reproducible.
random.seed(42)
torch.manual_seed(42)

# 90 % / 10 % train/test split on the shuffled graph list.
random.shuffle(graphs)
train_int = int(len(graphs) * 0.9)
train_graphs = graphs[:train_int]
test_graphs = graphs[train_int:]

# Training batches are reshuffled each epoch; test graphs are served one by one.
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=1, shuffle=False)

# Build the autoencoder on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

encoder = GCN(
    num_node_features=num_node_features,
    hidden_channels=hidden_channels,
)
model = GAEModel(encoder)
model = model.to(device)

# Adam over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training: minimise the GAE reconstruction loss over the training graphs.
# NOTE(review): the recorded run printed a constant loss of ~34.5388 (which is
# -ln(1e-15)) for 280+ epochs, i.e. nothing was learned. Presumably the loss is
# pinned at a clamp value because the unscaled inputs saturate the decoder;
# verify, e.g. by normalising the node features and/or lowering the lr.
for epoch in range(600):
    batch_losses = []
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        z = model.encode(data.x, data.edge_index)
        loss = model.recon_loss(z, data.edge_index)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Mean of the per-batch losses of this epoch
    avg_loss = sum(batch_losses) / len(train_loader)
    # Report every 10th epoch (the printed epoch number is 1-based)
    if epoch % 10 == 0:
        print(f'Epoch {epoch + 1}, Loss: {avg_loss}')

# Evaluate the model.
# The GAE decoder scores edges: model.decode(z, edge_index) returns one value
# per edge (a 1-D tensor), so the former `.argmax(dim=1)` raised and, in any
# case, edge scores are not class predictions. The autoencoder is unsupervised,
# so graph labels are predicted from its embeddings instead: node embeddings
# are mean-pooled into one vector per graph and each test graph is assigned the
# class of the nearest training-class centroid.
model.eval()
correct = 0
total = 0
all_preds = []
all_labels = []


def _graph_embeddings(batch):
    """Encode a batch and mean-pool node embeddings into one vector per graph."""
    node_embeddings = model.encode(batch.x, batch.edge_index)
    return global_mean_pool(node_embeddings, batch.batch)


with torch.no_grad():
    # 1) Class centroids in embedding space, computed from the training graphs.
    train_emb, train_y = [], []
    for data in train_loader:
        data = data.to(device)
        train_emb.append(_graph_embeddings(data))
        train_y.append(data.y)
    train_emb = torch.cat(train_emb)
    train_y = torch.cat(train_y)
    class_ids = torch.unique(train_y)
    centroids = torch.stack(
        [train_emb[train_y == class_id].mean(dim=0) for class_id in class_ids]
    )

    # 2) Nearest-centroid prediction for every test graph.
    for data in test_loader:
        data = data.to(device)
        graph_emb = _graph_embeddings(data)
        nearest = torch.cdist(graph_emb, centroids).argmin(dim=1)
        pred = class_ids[nearest]
        all_preds.extend(pred.cpu().tolist())
        all_labels.extend(data.y.cpu().tolist())
        correct += int((pred == data.y).sum())
        total += data.y.size(0)

# Guard against an empty test split (very small datasets).
accuracy = correct / total if total else float('nan')
print(f'Test Accuracy: {accuracy}')

# Confusion matrix. labels=[0, 1] pins the result to 2x2 even when one class is
# missing from the test labels/predictions; without it the four-way unpacking
# below fails on a 1x1 matrix.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])

# Row = true label, column = predicted label -> [[TN, FP], [FN, TP]]
TN, FP, FN, TP = conf_matrix.ravel()

# Accuracy
n_samples = TP + TN + FP + FN
accuracy = (TP + TN) / n_samples if n_samples else float('nan')
print(f'Accuracy: {accuracy:.4f}')

# Precision (undefined -> nan when nothing was predicted positive)
precision = TP / (TP + FP) if (TP + FP) else float('nan')
print(f'Precision: {precision:.4f}')

# Recall (undefined -> nan when there are no positive samples)
recall = TP / (TP + FN) if (TP + FN) else float('nan')
print(f'Recall: {recall:.4f}')

# Visualise the confusion matrix with one "NAME=count" annotation per cell.
# The annotations are handed to seaborn so they are centred in the cells; the
# previous ax.text(j, i, ...) calls landed on the cell corners and overlapped
# the numbers drawn by annot=True.
labels = ['TN', 'FP', 'FN', 'TP']
annotations = np.array(
    [[f"{labels[i * 2 + j]}={conf_matrix[i, j]}" for j in range(2)] for i in range(2)]
)

fig, ax = plt.subplots(figsize=(4, 3))
sns.heatmap(
    conf_matrix,
    annot=annotations,
    fmt='',  # annotations are ready-made strings
    cmap='Blues',
    annot_kws={'color': 'red'},
    ax=ax,
)

ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix with Annotations')
plt.show()




Epoch 1, Loss: 34.538777351379395
Epoch 11, Loss: 34.53877544403076
Epoch 21, Loss: 34.53877639770508
Epoch 31, Loss: 34.538777351379395
Epoch 41, Loss: 34.53877830505371
Epoch 51, Loss: 34.538777351379395
Epoch 61, Loss: 34.53877639770508
Epoch 71, Loss: 34.538777351379395
Epoch 81, Loss: 34.53877830505371
Epoch 91, Loss: 34.538777351379395
Epoch 101, Loss: 34.53877639770508
Epoch 111, Loss: 34.538777351379395
Epoch 121, Loss: 34.538777351379395
Epoch 131, Loss: 34.538777351379395
Epoch 141, Loss: 34.53877830505371
Epoch 151, Loss: 34.53877544403076
Epoch 161, Loss: 34.53877639770508
Epoch 171, Loss: 34.53877639770508
Epoch 181, Loss: 34.53877639770508
Epoch 191, Loss: 34.53877544403076
Epoch 201, Loss: 34.538777351379395
Epoch 211, Loss: 34.53877639770508
Epoch 221, Loss: 34.53877639770508
Epoch 231, Loss: 34.538777351379395
Epoch 241, Loss: 34.538777351379395
Epoch 251, Loss: 34.538777351379395
Epoch 261, Loss: 34.538777351379395
Epoch 271, Loss: 34.53877544403076
Epoch 281, Loss: 3

KeyboardInterrupt: 