In [1]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.nn import Linear
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import GINConv
import logging
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
# All inference in this notebook runs on the CPU (model and graphs are mapped here).
device = torch.device("cpu")

In [3]:
class GNN(torch.nn.Module):
    """Graph-level classifier: three GraphSAGE convolutions, mean pooling, linear head.

    Each convolution is followed by ReLU and dropout. Node embeddings are then
    averaged per graph with ``global_mean_pool`` and mapped to ``num_classes`` logits.

    NOTE: the checkpoint loaded in this notebook is a pickled full module, so the
    attribute names (initial_conv, conv1, conv2, dropout, out) must not change.
    """

    def __init__(self, num_features=15, embedding_size=64, dropout_rate=0.3, num_classes=5):
        """Build the layers.

        Args:
            num_features: Input feature size per node.
            embedding_size: Hidden size shared by all convolution layers.
            dropout_rate: Dropout probability applied after every convolution.
            num_classes: Number of output logits per graph.
        """
        super().__init__()
        # Registration order is kept as-is (affects parameter init order and state_dict order).
        self.initial_conv = SAGEConv(num_features, embedding_size)
        self.conv1 = SAGEConv(embedding_size, embedding_size)
        self.conv2 = SAGEConv(embedding_size, embedding_size)  # third (extra) conv layer
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.out = Linear(embedding_size, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute per-graph class logits.

        Args:
            x: Node feature matrix (presumably (num_nodes, num_features) — TODO confirm).
            edge_index: Graph connectivity passed to each SAGEConv.
            batch: Node-to-graph assignment vector used for pooling.

        Returns:
            Logits with one row per graph and ``num_classes`` columns.
        """
        hidden = x
        # conv -> ReLU -> dropout (regularization), applied for each conv layer in order
        for layer in (self.initial_conv, self.conv1, self.conv2):
            hidden = self.dropout(F.relu(layer(hidden, edge_index)))
        # Aggregate node features into graph-level features
        graph_embedding = global_mean_pool(hidden, batch)
        return self.out(graph_embedding)
    
# Load the complete pickled model (architecture + weights, not just a state_dict).
# The GNN class must be defined in this kernel for unpickling to succeed.
complete_model_save_path = 'C_model/complete_model_depression.pth'
# weights_only=False is required to unpickle a full nn.Module: PyTorch >= 2.6
# defaults to weights_only=True, which rejects custom classes such as GNN.
# This executes pickle, so only load checkpoints from a trusted source.
model = torch.load(complete_model_save_path, map_location=device, weights_only=False)

# Switch to evaluation mode (disables dropout for inference)
model.eval()

GNN(
  (initial_conv): SAGEConv(15, 64, aggr=mean)
  (conv1): SAGEConv(64, 64, aggr=mean)
  (conv2): SAGEConv(64, 64, aggr=mean)
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=64, out_features=5, bias=True)
)

In [5]:
def predict_from_graph(graph_load_path, model):
    """Load a saved graph from disk and return the model's predicted class(es).

    Args:
        graph_load_path: Path to a graph object saved with ``torch.save``
            (presumably a ``torch_geometric.data.Data`` with ``x`` and
            ``edge_index`` — TODO confirm).
        model: Graph classifier called as ``model(x, edge_index, batch)``.
            It is moved to ``device`` and switched to eval mode.

    Returns:
        A list with one predicted class index per graph (a single element here,
        since exactly one graph is loaded), or ``None`` if the file does not exist.

    Raises:
        Any exception raised while loading or predicting is logged and re-raised.
    """
    try:
        logging.info(f"Trying to load graph data from: {graph_load_path}")

        # Check that the file exists before attempting to load it
        if not os.path.exists(graph_load_path):
            logging.error("Graph file does not exist.")
            return None

        # Map the graph onto the inference device so a graph saved on GPU still
        # loads on a CPU-only machine. weights_only=False is needed to unpickle
        # a PyG Data object (PyTorch >= 2.6 defaults to True); this runs pickle,
        # so only load graph files from trusted sources.
        standardized_graph = torch.load(graph_load_path, map_location=device, weights_only=False)
        logging.info(f"Standardized graph data loaded from {graph_load_path}")

        # Wrap the single graph in a DataLoader so it is collated with a `batch` vector
        data_loader = DataLoader([standardized_graph], batch_size=1, shuffle=False)

        # Force eval mode so dropout is disabled even if the caller forgot model.eval();
        # otherwise predictions would be non-deterministic.
        model.to(device)
        model.eval()

        predictions = []
        with torch.no_grad():
            for data in data_loader:
                data = data.to(device)
                output = model(data.x, data.edge_index, data.batch)
                # Predicted class = argmax over the logits (one row, since batch_size=1)
                predictions.append(torch.argmax(output, dim=1).item())

        logging.info(f"Prediction results: {predictions}")
        return predictions

    except Exception as e:
        logging.error(f"Error in prediction: {e}")
        raise

if __name__ == "__main__":
    # Load the pretrained model (full pickled nn.Module). weights_only=False is
    # required on PyTorch >= 2.6 and means the checkpoint must be trusted.
    model_path = 'C_model/complete_model_depression.pth'
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()

    # Graph file to classify. It is set explicitly so this cell runs on a fresh
    # kernel; previously it was only defined in commented-out lines.
    # Other sample graph available: "tt/2.2_1.pt"
    graph_load_path = "513d6902fdc9f03587004592.pt"

    # Run the prediction
    predictions = predict_from_graph(graph_load_path, model)
    if predictions is not None:
        print(f"Predicted classes: {predictions}")

Predicted classes: [3]


In [None]:
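# Perception dimensions: beautiful, safe, depression, wealthy, boring
# NOTE(review): presumably one trained model per dimension (cf. complete_model_depression.pth) — confirm.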