In [1]:
print("Hello world")

Hello world


In [2]:
import os
import torch
from tqdm import tqdm
from torch_geometric.data import Data, InMemoryDataset
import pickle

class CustomDataset(InMemoryDataset):
    def __init__(self, root, data_list, transform=None, pre_transform=None):
        self.data_list = data_list
        super(CustomDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = self.collate(data_list)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return []

    def download(self):
        pass

    def process(self):
        pass

In [3]:
train_graph_folder = 'train_graph'
train_label_folder = 'train_label'

pt_files = [os.path.join(train_graph_folder, f) for f in os.listdir(train_graph_folder) if f.endswith('.pt')]
pkl_files = [os.path.join(train_label_folder, f) for f in os.listdir(train_label_folder) if f.endswith('.pkl')]

In [4]:
all_data = []
all_labels = []

for pt_file, pkl_file in tqdm(zip(pt_files, pkl_files), desc="Merging files", total=len(pt_files)):
    data = torch.load(pt_file)
    with open(pkl_file, 'rb') as f:
        label = pickle.load(f)

    all_data.append(data)
    all_labels.append(label)

Merging files: 100%|██████████| 441212/441212 [03:55<00:00, 1876.68it/s]


In [7]:
output_graph_file = 'train_graph.pt'
all_data_formal = []
for data in tqdm(all_data):
    if isinstance(data, dict):
        data = Data(**data)
    all_data_formal.append(data)
# 保存合并后的图数据到一个文件中
dataset = CustomDataset('', all_data_formal)
torch.save(dataset, output_graph_file)

100%|██████████| 441212/441212 [00:12<00:00, 34914.57it/s]
Processing...
Done!


In [8]:
output_label_file = 'train_label.pkl'
# Save merged labels to a file
with open(output_label_file, 'wb') as f:
    pickle.dump(all_labels, f)

: 

In [7]:
import os
import torch
from tqdm import tqdm
from torch_geometric.data import Data, InMemoryDataset
import pickle

class CustomDataset(InMemoryDataset):
    def __init__(self, root, data_list, transform=None, pre_transform=None):
        self.data_list = data_list
        super(CustomDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = self.collate(data_list)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return []

    def download(self):
        pass

    def process(self):
        pass

def convert_dict_to_data(data_dict):
    node_type = data_dict['node_type'].float().unsqueeze(1)  # Add feature dimension and convert to float
    num_inverted_predecessors = data_dict['num_inverted_predecessors'].float().unsqueeze(1)  # Add feature dimension and convert to float
    node_features = torch.cat([node_type, num_inverted_predecessors], dim=1)
    edge_index = data_dict['edge_index']
    data = Data(x=node_features, edge_index=edge_index)
    return data

def merge_datasets(train_graph_folder, train_label_folder, output_graph_file, output_label_file):
    pt_files = sorted([os.path.join(train_graph_folder, f) for f in os.listdir(train_graph_folder) if f.endswith('.pt')])
    pkl_files = sorted([os.path.join(train_label_folder, f) for f in os.listdir(train_label_folder) if f.endswith('.pkl')])

    all_data = []
    all_labels = []

    for pt_file, pkl_file in tqdm(zip(pt_files, pkl_files), desc="Merging files", total=len(pt_files)):
        data_dict = torch.load(pt_file)
        data = convert_dict_to_data(data_dict)
        all_data.append(data)

        with open(pkl_file, 'rb') as f:
            label = pickle.load(f)
        all_labels.append(label)

    dataset = CustomDataset('', all_data)
    torch.save(dataset, output_graph_file)

    with open(output_label_file, 'wb') as f:
        pickle.dump(all_labels, f)

    print(f'Merged graph data saved to {output_graph_file}')
    print(f'Merged labels saved to {output_label_file}')

if __name__ == "__main__":
    train_graph_folder = 'tmp_data/train_graph'
    train_label_folder = 'tmp_data/train_label'
    output_graph_file = 'data/merged_graph_data.pt'
    output_label_file = 'data/merged_labels.pkl'
    merge_datasets(train_graph_folder, train_label_folder, output_graph_file, output_label_file)


Merging files: 100%|██████████| 441212/441212 [05:40<00:00, 1294.85it/s]
Processing...
Done!


Merged graph data saved to data/merged_graph_data.pt
Merged labels saved to data/merged_labels.pkl


In [6]:
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
# Load merged graph data and labels
dataset = torch.load('data/merged_graph_data.pt')
with open('data/merged_labels.pkl', 'rb') as f:
    labels = pickle.load(f)

# Create data loader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.fc = torch.nn.Linear(out_channels, 1)  # Output is a single number

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = torch.mean(x, dim=0)  # Mean of node features
        x = self.fc(x)
        return x

model = GNN(in_channels=2, hidden_channels=64, out_channels=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

model.train()
for epoch in range(2):
    total_loss = 0
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        out = model(data)
        target = torch.tensor([labels[i]], dtype=torch.float32)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}')


Epoch 1, Loss: 0.033740330807631835
Epoch 2, Loss: 0.011013712850399315
