In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data, DataLoader
import torch.nn.functional as F

import scipy.io
from tqdm import tqdm
import random
from load_data import load_graph_data

In [2]:
# CNN特征提取器
class CNNFeatureExtractor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, output_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.fc1(x.squeeze(0).T))
        return self.fc2(x)


In [None]:
data_dir = './data'
graph_seq = 6
signal_len = 1000
all_data = load_graph_data(data_dir, graph_seq, signal_len)

Constructing Graph Structure...
1723
Loading Nodes Features...


In [None]:
# 训练 M，确保 PSD
class GraphDenoisingModel(torch.nn.Module):
    def __init__(self, input_dim, feature_dim, num_nodes, beta):
        super().__init__()
        self.cnn = CNNFeatureExtractor(input_dim, feature_dim)  # CNN 提取特征
        self.Q = torch.nn.Parameter(torch.randn(feature_dim, feature_dim))  # 训练 Q 确保 M 是 PSD
        self.beta = beta
        self.num_nodes = num_nodes
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

    def compute_M(self):
        """ M = Q Q^T 确保 PSD """
        return torch.matmul(self.Q, self.Q.T)

    def compute_distance(self, f):
        """ 计算 d_ij = (f_i - f_j)^T M (f_i - f_j) """
        M = self.compute_M()
        diff = f.unsqueeze(1) - f.unsqueeze(0)  # (N, N, d)
        dists = torch.einsum('bnd,dd,bnd->bn', diff, M, diff)  # 计算 (fi - fj)^T M (fi - fj)
        return dists

    def compute_weights(self, dists):
        """ 计算边权重 w_{i,j} = β_i β_j exp(-d_ij) """
        weights = torch.exp(-dists)  # 计算 exp(-d_ij)
        signs = self.beta.unsqueeze(1) * self.beta.unsqueeze(0)  # 计算 β_i * β_j
        return signs * weights  # 保证符号符合平衡图

    def forward(self, data):
        x = data.x.float()
        x = x + 0.1 * torch.randn_like(x)
        x = self.cnn(x.permute(1, 0).unsqueeze(0))  # CNN 提取特征 f_i
        dists = self.compute_distance(x)  # 计算距离 d_{i,j}
        weights = self.compute_weights(dists)  # 计算 w_{i,j}
        
        # 计算拉普拉斯矩阵 Lb
        degree = torch.diag(weights.sum(dim=1))
        Lb = degree - weights
        
        return Lb, x

    def optimize_beta(self, data):
        """ 迭代优化 β """
        num_nodes = self.num_nodes
        edge_index = data.edge_index
        x = data.x.float()
        Lb , _ = self.forward(data)
        origin_loss = torch.matmul(x.T, torch.matmul(Lb, x)).mean()
        
        for i in range(num_nodes):
            connected_nodes = edge_index[1][edge_index[0] == i]
            if len(connected_nodes) == 0:
                continue
            
            x = data.x.float()
            
            self.beta[i] *= -1
            updated_Lb, _ = self.forward(data)
            updated_loss = torch.matmul(x.T, torch.matmul(updated_Lb, x)).mean()
            if updated_loss > origin_loss:
                self.beta[i] *= -1
            else:
                origin_loss = updated_loss

    def train_step(self, data):
        """ 训练 CNN + M，优化 β """
        self.optimizer.zero_grad()
        Lb, x_pred = self.forward(data)

        # 计算损失
        x = data.x.float()
        loss_denoising = F.mse_loss(x_pred, x)
        # loss_GLR = torch.matmul(x.T, torch.matmul(Lb, x)).mean()
        
        loss = loss_denoising # + 0.1 * loss_GLR
        loss.backward()  # 反向传播
        self.optimizer.step()  # 更新 CNN 和 M

        return loss.item()


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_nodes = all_data[0].x.shape[0]
input_dim = all_data[0].x.shape[1]
feature_dim = signal_len
beta = scipy.io.loadmat('./results/cluster_labels.mat')['cluster_labels'][0]
beta = torch.tensor(list(beta) * graph_seq, dtype=torch.float32).to(device)
model = GraphDenoisingModel(input_dim, feature_dim, num_nodes, beta).to(device)
train_loader = DataLoader(all_data, batch_size=1, shuffle=True)
for epoch in range(100):
    loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        loss += model.train_step(data) / len(all_data)
    print(f"Epoch {epoch}, Loss: {loss}")
    if epoch % 10 == 9:
        for data in tqdm(train_loader):
            if random.random() < 0.1:
                model.optimize_beta(data.to(device))

100%|██████████| 10356/10356 [01:33<00:00, 110.58it/s]


Epoch 0, Loss: 1178.5118651288785


100%|██████████| 10356/10356 [01:34<00:00, 110.02it/s]


Epoch 1, Loss: 1178.372565959952


100%|██████████| 10356/10356 [01:39<00:00, 103.99it/s]


Epoch 2, Loss: 1178.8112490776589


100%|██████████| 10356/10356 [01:37<00:00, 106.12it/s]


Epoch 3, Loss: 1178.3734847616106


100%|██████████| 10356/10356 [01:40<00:00, 103.24it/s]


Epoch 4, Loss: 1178.3734074368256


100%|██████████| 10356/10356 [01:37<00:00, 106.03it/s]


Epoch 5, Loss: 1178.3738368759314


100%|██████████| 10356/10356 [01:40<00:00, 102.96it/s]


Epoch 6, Loss: 1178.3705490850152


100%|██████████| 10356/10356 [01:37<00:00, 106.28it/s]


Epoch 7, Loss: 1178.3703957567236


100%|██████████| 10356/10356 [01:36<00:00, 107.69it/s]


Epoch 8, Loss: 1178.3723926905188


100%|██████████| 10356/10356 [01:36<00:00, 107.61it/s]


Epoch 9, Loss: 1178.3741416454154


 25%|██▌       | 2620/10356 [06:31<19:14,  6.70it/s]  


KeyboardInterrupt: 