In [11]:
import pandas as pd
import torch
import os
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
import matplotlib.pyplot as plt
import random


## dataloader

In [12]:
def resize_tensor(input_tensor, target_length):
    input_tensor = input_tensor.unsqueeze(1) 
    resized_tensor = F.interpolate(
        input_tensor, 
        size=target_length, 
        mode='linear', 
        align_corners=False
    )
    resized_tensor = resized_tensor.squeeze(1)
    
    return resized_tensor


In [13]:
## labels: day45 baseline 0, day 45 others 1, day90 baseline 2, day others 90 3, day 120 baseline 4, day 120 others 5
import pandas as pd
import torch
import os

## labels: day45 baseline 0, day 45 others 1, day90 baseline 2, day others 90 3, day 120 baseline 4, day 120 others 5
file_folder = "/home/featurize/work/xhh/MEA/data/overfitting"
sub_file_list = os.listdir(file_folder)
all_data = []
for classes in sub_file_list:
    class_path = os.path.join(file_folder, classes)
    graph_list = os.listdir(class_path)
    for file_name in graph_list:    
        file_path = os.path.join(class_path, file_name)
        data_sample = {}
        label = int(classes)
        df = pd.read_csv(file_path)
        data_np = df.values
        data_tensor = torch.tensor(data_np, dtype=torch.float32)
        target_length = 4500 
        data_tensor = resize_tensor(data_tensor, target_length)
        data_sample["data"] = data_tensor
        data_sample["label"] = label
        data_sample["data_name"] = file_name
        all_data.append(data_sample)


In [14]:
train_loader = DataLoader(all_data, batch_size=1, shuffle=True)



# Model

In [15]:
class Encoder(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
      
    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        # x = F.dropout(x, training=self.training)
        max_values, _ = torch.max(x, dim=0)  
        return max_values.unsqueeze(0)

In [16]:
class Decoder(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_channels, num_node_features)
    def forward(self, x):
        x = F.leaky_relu(x,0.5)
        x = self.linear(x)
        return x
    

In [17]:
class Autoencoder(torch.nn.Module):  
    def __init__(self, num_node_features, hidden_channels):  
        super().__init__()  
        self.encoder = Encoder(num_node_features, hidden_channels)  
        self.decoder = Decoder(num_node_features, hidden_channels) 
    def forward(self, x, edge_index, batch):  
        # Encode the input  
        encoded = self.encoder(x, edge_index, batch)  
        decoded = self.decoder(encoded.unsqueeze(0), edge_index, batch)  
        return encoded, decoded  
  

In [18]:
latent_dim = 16
num_node_features = 4500
hidden_channels = 256
model = Autoencoder(num_node_features, hidden_channels) 

# Train


In [19]:
num_graphs = len(all_data) 
graphs = []
for i in range(num_graphs):
    graph = all_data[i]
    num_nodes = int(graph['data'].shape[0])  # 每个图的节点数目
    node_features = graph['data']            # 节点特征矩阵
    # 创建依次连接的边缘索引
    edge_index = []
    for i in range(num_nodes - 1):
        edge_index.append([i, i+1])
    # 转换为PyTorch张量
    edge_index = torch.tensor(edge_index).t().contiguous()
    # edge_index.append([i+1, i])                      # 如果图是无向的，添加反方向的边

    y = torch.tensor([graph['label']], dtype=torch.long)
    graph_data = Data(x=node_features, edge_index=edge_index, y=y)
    graphs.append(graph_data)
random.seed(42)
random.shuffle(graphs)
train_graphs = graphs 
# test_graphs = graphs[int(num_graphs*0.8):]    # 后20%作为测试集
train_loader = DataLoader(train_graphs, batch_size=128, shuffle=True)
# test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

In [20]:
from tqdm import tqdm
epoch = 100   
lerning_rate = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
encoder = Encoder(num_node_features, hidden_channels).to(device)
decoder = Decoder(num_node_features, hidden_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lerning_rate, capturable=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch, eta_min=0.000001)
mse_loss = torch.nn.MSELoss() 
loss_values = []
for epoch in tqdm(range(epoch)):
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        encode= encoder(data.x, data.edge_index, data.batch)
        encode_copy = encode.repeat(data.x.shape[0], 1) 
        decode = decoder(encode_copy)
        loss = mse_loss(decode, data.x)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    avg_loss = total_loss / len(train_loader)
    loss_values.append(avg_loss)
    print(f'Epoch {epoch+1}, Loss: {avg_loss}')

  0%|          | 0/100 [00:00<?, ?it/s]

 27%|██▋       | 27/100 [00:00<00:00, 140.72it/s]

Epoch 1, Loss: 7.393680095672607
Epoch 2, Loss: 7.393679618835449
Epoch 3, Loss: 7.393680095672607
Epoch 4, Loss: 7.393679618835449
Epoch 5, Loss: 7.393679618835449
Epoch 6, Loss: 7.393679618835449
Epoch 7, Loss: 7.393679618835449
Epoch 8, Loss: 7.393679618835449
Epoch 9, Loss: 7.393679141998291
Epoch 10, Loss: 7.393680095672607
Epoch 11, Loss: 7.393679618835449
Epoch 12, Loss: 7.393679618835449
Epoch 13, Loss: 7.393679618835449
Epoch 14, Loss: 7.393679618835449
Epoch 15, Loss: 7.393679618835449
Epoch 16, Loss: 7.393679618835449
Epoch 17, Loss: 7.393680095672607
Epoch 18, Loss: 7.393679618835449
Epoch 19, Loss: 7.393679618835449
Epoch 20, Loss: 7.393679618835449
Epoch 21, Loss: 7.393679618835449
Epoch 22, Loss: 7.393679618835449
Epoch 23, Loss: 7.393679618835449
Epoch 24, Loss: 7.393679618835449
Epoch 25, Loss: 7.393679618835449
Epoch 26, Loss: 7.393679618835449
Epoch 27, Loss: 7.393679141998291
Epoch 28, Loss: 7.393679618835449
Epoch 29, Loss: 7.393679618835449
Epoch 30, Loss: 7.39367

 61%|██████    | 61/100 [00:00<00:00, 153.04it/s]

Epoch 32, Loss: 7.393679618835449
Epoch 33, Loss: 7.393679618835449
Epoch 34, Loss: 7.393679141998291
Epoch 35, Loss: 7.393679618835449
Epoch 36, Loss: 7.393679618835449
Epoch 37, Loss: 7.393679618835449
Epoch 38, Loss: 7.393679618835449
Epoch 39, Loss: 7.393679141998291
Epoch 40, Loss: 7.393679618835449
Epoch 41, Loss: 7.393679618835449
Epoch 42, Loss: 7.393680095672607
Epoch 43, Loss: 7.393679141998291
Epoch 44, Loss: 7.393679618835449
Epoch 45, Loss: 7.393679618835449
Epoch 46, Loss: 7.393679618835449
Epoch 47, Loss: 7.393679618835449
Epoch 48, Loss: 7.393679141998291
Epoch 49, Loss: 7.393679618835449
Epoch 50, Loss: 7.393679618835449
Epoch 51, Loss: 7.393679618835449
Epoch 52, Loss: 7.393679618835449
Epoch 53, Loss: 7.393679618835449
Epoch 54, Loss: 7.393680095672607
Epoch 55, Loss: 7.393679141998291
Epoch 56, Loss: 7.393679618835449
Epoch 57, Loss: 7.393679618835449
Epoch 58, Loss: 7.393679141998291
Epoch 59, Loss: 7.393680095672607
Epoch 60, Loss: 7.393679141998291
Epoch 61, Loss

100%|██████████| 100/100 [00:00<00:00, 157.24it/s]

Epoch 63, Loss: 7.393679618835449
Epoch 64, Loss: 7.393679618835449
Epoch 65, Loss: 7.393679618835449
Epoch 66, Loss: 7.393679618835449
Epoch 67, Loss: 7.393679618835449
Epoch 68, Loss: 7.393679618835449
Epoch 69, Loss: 7.393679618835449
Epoch 70, Loss: 7.393679618835449
Epoch 71, Loss: 7.393679618835449
Epoch 72, Loss: 7.393680095672607
Epoch 73, Loss: 7.393679141998291
Epoch 74, Loss: 7.393679618835449
Epoch 75, Loss: 7.393679618835449
Epoch 76, Loss: 7.393679618835449
Epoch 77, Loss: 7.393679618835449
Epoch 78, Loss: 7.393679141998291
Epoch 79, Loss: 7.393679618835449
Epoch 80, Loss: 7.393679618835449
Epoch 81, Loss: 7.393679618835449
Epoch 82, Loss: 7.393679618835449
Epoch 83, Loss: 7.393679618835449
Epoch 84, Loss: 7.393679618835449
Epoch 85, Loss: 7.393679618835449
Epoch 86, Loss: 7.393680095672607
Epoch 87, Loss: 7.393680095672607
Epoch 88, Loss: 7.393679618835449
Epoch 89, Loss: 7.393679618835449
Epoch 90, Loss: 7.393679618835449
Epoch 91, Loss: 7.393679618835449
Epoch 92, Loss


