In [1]:
import os
import re
import pandas as pd
import numpy as np

import torch
import networkx as nx
from torch_geometric.data import Data

from tqdm import tqdm
import time
import matplotlib.pyplot as plt

### Step 1: Extract internal coordinates with atomic info from the DFT log files

In [2]:
'''处理多个文件的函数'''
def extract_internal_coordinates_with_atomic_info(path, atomic_mass_to_element, atomic_mass_to_atomic_number):
    """
    Parse every ``.log`` file in a directory into per-molecule tables.

    For each file the function extracts:
      * the integer atomic weights listed after ``IAtWgt=`` (atom index -> mass),
      * the R/A/D internal coordinates found after the last "Optimized Parameters"
        marker (or in the whole file when the marker is absent),
      * the Cartesian coordinates of the last "Standard orientation" table,
      * the last "SCF Done" energy.

    Parameters:
    - path: str, the directory path containing log files
    - atomic_mass_to_element: dict, mapping from atomic mass to element symbol
      (accepted for interface compatibility; not used by this function)
    - atomic_mass_to_atomic_number: dict, mapping from atomic mass to atomic number
      (accepted for interface compatibility; not used by this function)

    Returns:
    - data_dicts: list of dicts, one per log file, with the keys
        'Filename'     : file name up to its first '.',
        'final_energy' : float, or None when no "SCF Done" line was found,
        'df'           : DataFrame with columns Definition / Value / Element Types / File,
        'coord_df'     : DataFrame with columns Atom Index (0-based) / Atomic Number /
                         X / Y / Z / Mass. It is empty (but still has these columns)
                         when no coordinate table could be located.
    """
    coord_columns = ["Atom Index", "Atomic Number", "X", "Y", "Z"]

    # The patterns do not depend on the file, so compile them once.
    atom_pattern = re.compile(r"IAtWgt=\s+([\d\s]+)")
    internal_coordinate_patterns = (
        ("R", re.compile(r"R\(([\d,]+)\)\s+([\d.]+)")),
        ("A", re.compile(r"A\(([\d,]+)\)\s+([\d.]+)")),
        ("D", re.compile(r"D\(([\d,]+)\)\s+(-?[\d.]+)")),
    )
    energy_pattern = re.compile(r'SCF Done:\s+E\([^\)]+\)\s+=\s+(-?\d+\.\d+(?:[DE][-+]\d+)?)')

    data_dicts = []

    # Get the list of .log files in the directory
    files = [f for f in os.listdir(path) if f.endswith('.log')]
    file_paths = [os.path.join(path, file) for file in files]

    for file_path in tqdm(file_paths, desc="Processing log files", unit="file"):
        # Only the read needs the open handle; parsing happens after it is closed.
        with open(file_path, "r") as file:
            log_content = file.read()
        filename = os.path.basename(file_path).split('.')[0]

        # Step 1: atom index (0-based) -> integer atomic weight
        atom_matches = atom_pattern.findall(log_content)
        atomic_weights = [int(weight) for line in atom_matches for weight in line.split()]
        atom_to_mass = dict(enumerate(atomic_weights))

        # Step 2: internal coordinates from the last "Optimized Parameters" section.
        # Rows are emitted in the order R, then A, then D.
        optimized_parameters_section = log_content.split("Optimized Parameters")[-1]
        data = []
        for kind, pattern in internal_coordinate_patterns:
            for definition, value in pattern.findall(optimized_parameters_section):
                # Definitions use 1-based atom numbers; atom_to_mass is 0-based.
                masses = [atom_to_mass.get(int(atom) - 1, "Unknown") for atom in definition.split(",")]
                data.append([f"{kind}({definition})", float(value), masses, filename])
        df = pd.DataFrame(data, columns=["Definition", "Value", "Element Types", "File"])

        # Step 3: Cartesian coordinates from the last "Standard orientation" table.
        # The fallback frame keeps the expected columns so the 'Mass' mapping below
        # (and downstream consumers) work on molecules without a coordinate table.
        coord_df = pd.DataFrame(columns=coord_columns)
        sections = log_content.split("Standard orientation")
        if len(sections) > 1:
            lines = sections[-1].strip().split("\n")
            # The table body sits between the second and third dashed separator lines.
            dash_line_indices = [i for i, line in enumerate(lines) if re.match(r'\s*-+\s*', line)]
            if len(dash_line_indices) >= 3:
                start = dash_line_indices[1] + 1
                end = dash_line_indices[2]
                coordinate_data = []
                for line in lines[start:end]:
                    tokens = line.strip().split()
                    if len(tokens) == 6:
                        # tokens: center number, atomic number, atomic type, X, Y, Z
                        coordinate_data.append([
                            int(tokens[0]) - 1,  # shift to 0-based atom index
                            int(tokens[1]),
                            float(tokens[3]),
                            float(tokens[4]),
                            float(tokens[5]),
                        ])
                coord_df = pd.DataFrame(coordinate_data, columns=coord_columns)
            else:
                print(f"Warning: Coordinate table not found in {filename}.")
        else:
            print(f"Warning: 'Standard orientation' not found in {filename}.")
        # Map mass to coord_df
        coord_df['Mass'] = coord_df['Atom Index'].map(atom_to_mass)

        # Step 4: final (last reported) SCF energy
        energy_matches = energy_pattern.findall(log_content)
        if energy_matches:
            # Fortran-style 'D' exponents must become 'E' before float conversion.
            final_energy = float(energy_matches[-1].replace('D', 'E'))
        else:
            final_energy = None
            print(f"Warning: Final energy not found in {filename}.")

        data_dicts.append({
            'Filename': filename,
            'final_energy': final_energy,
            'df': df,
            'coord_df': coord_df
        })
    return data_dicts

In [4]:
import os

# Single source of truth: integer atomic mass -> (element symbol, atomic number).
# Both lookup dictionaries are derived from it so they cannot drift apart.
_mass_table = {
    1: ("H", 1), 11: ("B", 5), 12: ("C", 6), 14: ("N", 7), 16: ("O", 8), 19: ("F", 9),
    28: ("Si", 14), 31: ("P", 15), 32: ("S", 16), 35: ("Cl", 17),
    79: ("Br", 35), 80: ("Se", 34), 127: ("I", 53),
}
atomic_mass_to_element = {mass: symbol for mass, (symbol, _) in _mass_table.items()}
atomic_mass_to_atomic_number = {mass: number for mass, (_, number) in _mass_table.items()}

# Directory holding the log files to parse
log_dir = './data/normal-log-files'

# Parse every log file into one dictionary per molecule
data_dicts = extract_internal_coordinates_with_atomic_info(
    log_dir, atomic_mass_to_element, atomic_mass_to_atomic_number
)

print(f"Processed {len(data_dicts)} molecules.")
# data_dicts

#### 在data_dicts中加入HOMO、LUMO、Gap数据以及光谱数据

In [3]:
# path = r'C:\Users\xiaoyu\Desktop\click\4-raman_info\Extract_info_from_initial_logs\Homo_Lumo_scaled.csv'
path = './generate_new_molecules/data/Homo_Lumo_scaled.csv'
homo_lumo_df = pd.read_csv(path)
homo_lumo_df.head()  # 365 rows × 4 columns

Unnamed: 0,HOMO,LUMO,Gap,Filename
0,1.471878,0.228315,-0.804732,1
1,1.199832,-0.002504,-0.818722,10
2,1.37828,-0.297496,-1.194667,100
3,0.0163,0.302149,0.249581,1000
4,-0.099167,-0.019409,0.050744,1001


In [4]:
homo_lumo_df[homo_lumo_df['Filename']==7562]

Unnamed: 0,HOMO,LUMO,Gap,Filename
7147,1.14516,-0.54936,-1.253306,7562


In [2]:
# # path2 = r'C:\Users\xiaoyu\Desktop\click\4-raman_info\3-all_raman_logs_No.2'
# # ir_df = pd.read_csv(os.path.join(path2, 'ir_expanded_scaled.csv'), index_col=False)
# # ir_df.head()  # 365 rows × 3601 columns

# # ir_path = r'C:\Users\xiaoyu\Desktop\click\4-raman_info\Extract_info_from_initial_logs'
# # ir_df_2 = pd.read_csv(os.path.join(ir_path, 'ir_expanded_scaled_2.csv'), index_col=False)

ir_path = 'generate_new_molecules/data/IrSpecInfo_Scaled.csv'
extent_ir_df = pd.read_csv(ir_path, index_col=False)
extent_ir_df.head()

In [7]:
# raman_df = pd.read_csv(os.path.join(path2, 'raman_expanded_scaled.csv'), index_col=False)
# raman_df.head()  # 365 rows × 3601 columns

In [8]:
len(data_dicts), 
# data_dicts[0]

(9277,)

In [9]:
# final_rdkit_descriptors_df = pd.read_csv(r'C:\Users\xiaoyu\Desktop\click\jupyter\final_rdkit_descriptors.csv')
# final_rdkit_descriptors_df.head()  # (363, 17)

In [10]:
data_dicts[0]

{'Filename': '1',
 'final_energy': -1656.87246616,
 'df':          Definition     Value     Element Types File
 0            R(1,2)    1.5082          [12, 12]    1
 1           R(1,44)    1.0957           [12, 1]    1
 2           R(1,45)    1.0983           [12, 1]    1
 3           R(1,46)    1.0947           [12, 1]    1
 4            R(2,3)    1.3982          [12, 12]    1
 ..              ...       ...               ...  ...
 409  D(70,39,40,71)    0.3324    [1, 12, 12, 1]    1
 410    D(5,42,43,2)    0.9751  [12, 12, 12, 12]    1
 411   D(5,42,43,73) -178.1881   [12, 12, 12, 1]    1
 412   D(72,42,43,2)  178.8797   [1, 12, 12, 12]    1
 413  D(72,42,43,73)   -0.2835    [1, 12, 12, 1]    1
 
 [414 rows x 4 columns],
 'coord_df':     Atom Index  Atomic Number         X         Y         Z  Mass
 0            0              6 -6.821811 -0.984771 -1.781673    12
 1            1              6 -5.376806 -1.083723 -1.360979    12
 2            2              6 -4.345296 -0.658895 -2.2

In [16]:
# 1. Index the parsed molecules by filename. Log-file stems are strings, while the
#    CSV 'Filename' column may be read as integers, so compare on str().
parsed_by_filename = {str(data_dict['Filename']): data_dict for data_dict in data_dicts}

# 2. Filenames for which HOMO/LUMO/Gap values exist. The set is iterated exactly as
#    before so the molecule order of `updated_data_dicts` does not change.
valid_filenames = set(homo_lumo_df['Filename'])

# 3. Keep only molecules present in BOTH sources. A filename listed in the CSV but
#    without a parsed log file is reported and skipped instead of raising KeyError.
data_dicts_dict = {}
missing_filenames = []
for filename in valid_filenames:
    parsed = parsed_by_filename.get(str(filename))
    if parsed is None:
        missing_filenames.append(filename)
    else:
        data_dicts_dict[filename] = parsed
if missing_filenames:
    print(f"Warning: {len(missing_filenames)} filename(s) in homo_lumo_df have no parsed log file and were skipped.")

# 4. Attach HOMO, LUMO and Gap to each kept molecule. Iterating the columns directly
#    (instead of iterrows) keeps 'Filename' in its own dtype rather than coercing it
#    to float together with the numeric columns.
for filename, homo, lumo, gap in zip(
    homo_lumo_df['Filename'], homo_lumo_df['HOMO'], homo_lumo_df['LUMO'], homo_lumo_df['Gap']
):
    if filename in data_dicts_dict:
        data_dict = data_dicts_dict[filename]
        data_dict['HOMO'] = homo
        data_dict['LUMO'] = lumo
        data_dict['Gap'] = gap

# Steps 5-7 (IR spectrum, Raman spectrum and RDKit descriptors) are currently disabled.
# They followed the same pattern as step 4: iterate the rows of ir_df_2 / raman_df /
# final_rdkit_descriptors_df and, for filenames present in data_dicts_dict, store the
# row values under 'IR_Spectrum', 'Raman_Spectrum' and 'rdkit_desc' respectively.

# 8. Final merged list (valid molecules only) and its size
updated_data_dicts = list(data_dicts_dict.values())
print(f'{len(updated_data_dicts)}')

9277


In [17]:
len(updated_data_dicts)

9277

# updated_data_dicts[0]

### Step 2: Construct the molecular graphs for the GNN

In [18]:
data_dicts[0]

{'Filename': '1',
 'final_energy': -1656.87246616,
 'df':          Definition     Value     Element Types File
 0            R(1,2)    1.5082          [12, 12]    1
 1           R(1,44)    1.0957           [12, 1]    1
 2           R(1,45)    1.0983           [12, 1]    1
 3           R(1,46)    1.0947           [12, 1]    1
 4            R(2,3)    1.3982          [12, 12]    1
 ..              ...       ...               ...  ...
 409  D(70,39,40,71)    0.3324    [1, 12, 12, 1]    1
 410    D(5,42,43,2)    0.9751  [12, 12, 12, 12]    1
 411   D(5,42,43,73) -178.1881   [12, 12, 12, 1]    1
 412   D(72,42,43,2)  178.8797   [1, 12, 12, 12]    1
 413  D(72,42,43,73)   -0.2835    [1, 12, 12, 1]    1
 
 [414 rows x 4 columns],
 'coord_df':     Atom Index  Atomic Number         X         Y         Z  Mass
 0            0              6 -6.821811 -0.984771 -1.781673    12
 1            1              6 -5.376806 -1.083723 -1.360979    12
 2            2              6 -4.345296 -0.658895 -2.2

In [19]:
def build_and_debug_molecule_graph(data_dicts, atomic_mass_to_element, atomic_mass_to_atomic_number):
    """
    Build one PyTorch Geometric ``Data`` graph per molecule.

    Nodes come from each molecule's 'coord_df' and carry the features
    [atom index, mass, atomic number, x, y, z]. Edges come from the bond-length
    ("R(i,j)") rows of 'df'; the bond length is the single edge feature. Each
    bond is stored once (one direction only), in the order NetworkX reports it.

    Parameters:
    - data_dicts: list of per-molecule dicts with the keys 'Filename',
      'final_energy', 'HOMO', 'LUMO', 'Gap', 'df' and 'coord_df'
    - atomic_mass_to_element: dict, atomic mass -> element symbol (used only to
      warn about unknown masses)
    - atomic_mass_to_atomic_number: dict, atomic mass -> atomic number

    Returns:
    - data_list: list of Data objects with x, edge_index, edge_attr ([E, 1]),
      global_features = [HOMO, LUMO, Gap, final_energy] and file_name.

    Note: a mass missing from atomic_mass_to_atomic_number (atomic number None) or a
    'final_energy' of None cannot be converted to a tensor and raises in torch.tensor.
    """
    data_list = []

    for data_dict in tqdm(data_dicts, desc="Building molecular graphs", unit="molecule"):
        file_name = data_dict['Filename']
        final_energy = data_dict['final_energy']
        HOMO = data_dict['HOMO']
        LUMO = data_dict['LUMO']
        gap = data_dict['Gap']

        df = data_dict['df']
        coord_df = data_dict['coord_df']
        molecule_df = df[df['File'] == file_name]

        # Node features keyed by 0-based atom index
        nodes = {}
        for _, row in coord_df.iterrows():
            atom_index = int(row['Atom Index'])
            mass = row['Mass']
            x_coord = float(row['X'])
            y_coord = float(row['Y'])
            z_coord = float(row['Z'])
            atomic_number = atomic_mass_to_atomic_number.get(mass, None)
            element = atomic_mass_to_element.get(mass, 'Unknown')
            if element == 'Unknown':
                print(f"警告：在文件 {file_name} 中未找到原子质量 {mass} 对应的元素。")
            if atomic_number is None:
                print(f"警告：在文件 {file_name} 中未找到原子质量 {mass} 对应的原子序数。")
            nodes[atom_index] = [atom_index, mass, atomic_number, x_coord, y_coord, z_coord]

        # Only bond lengths (R) become edges; angles (A) and dihedrals (D) are ignored.
        edges = []
        for _, row in molecule_df.iterrows():
            definition = row['Definition']
            if definition.startswith("R"):
                # "R(i,j)" -> 0-based atom indices [i-1, j-1]
                atoms = [int(a) - 1 for a in definition[2:-1].split(",")]
                edges.append((atoms[0], atoms[1], row['Value']))

        # Assemble the graph; for a bond listed twice the first value is kept.
        G = nx.Graph()
        for atom_index, features in nodes.items():
            G.add_node(atom_index, features=features)
        existing_edges = set()
        for u, v, value in edges:
            if (u, v) not in existing_edges and (v, u) not in existing_edges:
                G.add_edge(u, v, feature=value)
                existing_edges.add((u, v))

        edge_index = []
        edge_attr = []
        for u, v, data in G.edges(data=True):
            edge_index.append([u, v])
            edge_attr.append([data['feature']])

        # Node feature matrix, rows ordered by atom index
        x = torch.tensor([nodes[atom_index] for atom_index in sorted(nodes)], dtype=torch.float)

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attr, dtype=torch.float)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
            # Keep the [E, 1] layout used for non-empty graphs so edge-less graphs
            # can be batched together with the others.
            edge_attr = torch.empty((0, 1), dtype=torch.float)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

        # Graph-level features: HOMO, LUMO, Gap and final energy (IR/Raman spectra
        # and RDKit descriptors are currently not included).
        data.global_features = torch.cat([
            torch.tensor([HOMO, LUMO, gap], dtype=torch.float),
            torch.tensor([final_energy], dtype=torch.float),
        ], dim=0)

        data.file_name = file_name  # keep the source file name with the graph
        data_list.append(data)

    return data_list

传入未加入信息的data_dicts

In [104]:
# data_list = build_and_debug_molecule_graph(data_dicts, atomic_mass_to_element, atomic_mass_to_atomic_number)

传加入各种信息的updated_data_dicts

In [20]:
data_list = build_and_debug_molecule_graph(updated_data_dicts, atomic_mass_to_element, atomic_mass_to_atomic_number)

Building molecular graphs: 100%|█████████████████████████████████████████████| 9277/9277 [02:18<00:00, 67.13molecule/s]


In [21]:
data_list[0], data_list[0].file_name, data_list[0].global_features, len(data_list)

(Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], global_features=[4], file_name='1'),
 '1',
 tensor([ 1.4719e+00,  2.2831e-01, -8.0473e-01, -1.6569e+03]),
 9277)

In [108]:
# log_path = r'C:\Users\xiaoyu\Desktop\click\5-提取分子结构向量\logs\acceptor45.log'
# base_name = os.path.basename(log_path).split('.')[0]
# base_name

### prediction4-副本中用于预测的分子描述符表格，将其与以下信息合并，这样数据排序是一样的，用同样的切割种子切割出同样的训练集和测试集

In [23]:
final_rdkit_descriptors_df = pd.read_csv('./final_extend_data_rdkit_descriptors.csv')
final_rdkit_descriptors_df.head() # (363, 17)

Unnamed: 0,Filename,SMILES,MaxEStateIndex,MinEStateIndex,MinAbsEStateIndex,MaxPartialCharge,BCUT2D_MWHI,BCUT2D_MRLOW
0,1,Cc1ccc(-c2c3ccccc3c(-c3ccccc3)c3c(-c4ccccc4)c4...,2.308093,1.227435,1.227435,-0.000139,14.27395,1.474828
1,10,Brc1ccc2c(-c3ccccc3)c3c(-c4ccccc4)c4ccccc4c(-c...,3.838429,1.070255,1.070255,0.018137,79.918731,1.623463
2,100,Cc1ccc(C#Cc2c3ccccc3c(C#Cc3ccccc3)c3cc4ccccc4c...,3.533942,1.008111,1.008111,0.040654,14.143554,1.459498
3,1000,N#Cc1ccc(-c2c3ccccc3c(-c3ccc(I)cc3)c3ccccc23)cc1,9.173973,0.678583,0.678583,0.09911,126.912704,1.485139
4,1001,N#Cc1ccc(-c2c3ccccc3c(-c3ccccc3)c3ccccc23)cc1C#N,9.539119,0.398424,0.398424,0.100532,14.279392,1.431312


In [42]:
# 3. Attach a lifetime value to every molecular graph (a constant 0 placeholder here)
for data in data_list:
    filename = data.file_name  # file name of this molecular graph
    data.lifetime = 0
i = len(data_list)
print(f'共为data_list添加了{i}个Filename')

共为data_list添加了9277个Filename


In [44]:
data_list_with_lifetime = [data for data in data_list if hasattr(data, 'lifetime')]
len(data_list_with_lifetime), data_list_with_lifetime[0]

(9277,
 Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], global_features=[4], file_name='1', lifetime=0))

In [45]:
'''准备数据（图数据和目标属性）'''
def prepare_data_for_training(data_list, target_list, file_names):
    """
    Re-wrap each graph as a fresh ``Data`` object that carries its target value.

    Parameters:
    - data_list: graphs exposing x, edge_index, edge_attr and global_features
    - target_list: one target per graph; stored as a float32 tensor of shape [1] in ``y``
    - file_names: one file name per graph; stored as ``file_name``

    Returns:
    - list of Data objects, in the same order as ``data_list``
    """
    prepared_data = []
    for position, graph in enumerate(data_list):
        sample = Data(
            x=graph.x,
            edge_index=graph.edge_index,
            edge_attr=graph.edge_attr,
            y=torch.tensor([target_list[position]], dtype=torch.float32),
            global_features=graph.global_features,  # graph-level features for the model
        )
        sample.file_name = file_names[position]
        prepared_data.append(sample)
    return prepared_data

In [46]:
'''数据准备：切分数据集'''
import torch
from sklearn.model_selection import train_test_split

In [47]:
# 提取图数据和对应的lifetime值
data_list = [data for data in data_list_with_lifetime]  # 图的全部数据
target_list = [data.lifetime for data in data_list_with_lifetime]  # 目标lifetime值
file_names = [data.file_name for data in data_list_with_lifetime]  # 文件名列表

In [48]:
all_data = prepare_data_for_training(data_list, target_list, file_names)

In [50]:
all_data[:10]

[Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], y=[1], global_features=[4], file_name='1'),
 Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], y=[1], global_features=[4], file_name='2'),
 Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], y=[1], global_features=[4], file_name='3'),
 Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], y=[1], global_features=[4], file_name='4'),
 Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], y=[1], global_features=[4], file_name='5'),
 Data(x=[70, 6], edge_index=[2, 77], edge_attr=[77, 1], y=[1], global_features=[4], file_name='6'),
 Data(x=[70, 6], edge_index=[2, 77], edge_attr=[77, 1], y=[1], global_features=[4], file_name='7'),
 Data(x=[70, 6], edge_index=[2, 77], edge_attr=[77, 1], y=[1], global_features=[4], file_name='8'),
 Data(x=[70, 6], edge_index=[2, 77], edge_attr=[77, 1], y=[1], global_features=[4], file_name='9'),
 Data(x=[70, 6], edge_index=[2, 77], edge_attr=[77, 1], y=[1], global_features=[4], file_name='10')]

In [52]:
data_list[0], all_data[0]

(Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], global_features=[4], file_name='1', lifetime=0),
 Data(x=[73, 6], edge_index=[2, 80], edge_attr=[80, 1], y=[1], global_features=[4], file_name='1'))

In [121]:
data_list[0], all_data[0]

(Data(x=[36, 6], edge_index=[2, 39], edge_attr=[39, 1], global_features=[4], file_name='E1', lifetime=16.35322976),
 Data(x=[36, 6], edge_index=[2, 39], edge_attr=[39, 1], y=[1], global_features=[4], file_name='E1'))

In [60]:
'''构建GNN模型'''
import torch
import torch.nn as nn
import torch.optim as optim  # 添加这行导入
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.init as init

class GNNModel(nn.Module):
    """
    Two-layer GCN followed by a small MLP head that also consumes graph-level features.

    Pipeline: GCNConv -> ReLU -> dropout -> GCNConv -> ReLU, then every node embedding
    is concatenated with its graph's global features, the nodes of each graph are
    averaged, and the pooled vector goes through
    Linear(hidden_dim + global_dim, 32) -> BatchNorm -> ReLU -> Dropout -> Linear(32, output_dim).

    Parameters:
    - input_dim: number of node features
    - hidden_dim: width of both GCN layers
    - output_dim: size of the prediction
    - global_dim: number of global features per graph
    - p: dropout probability of the MLP head
    """

    def __init__(self, input_dim, hidden_dim, output_dim, global_dim, p=0.2):
        super(GNNModel, self).__init__()
        self.global_dim = global_dim
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Head without the high-dimensional spectra
        self.fc1 = nn.Linear(hidden_dim + self.global_dim, 32)  # pooled embedding + globals -> 32
        self.bn1 = nn.BatchNorm1d(32)
        self.dropout1 = nn.Dropout(p=p)
        self.fc2 = nn.Linear(32, output_dim)                    # output layer

        self._initialize_weights()

    def _initialize_weights(self):
        """Apply He (Kaiming normal) initialisation to the two linear layers."""
        init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x, edge_index, edge_attr, batch, global_features):
        """
        Parameters:
        - x: node feature matrix
        - edge_index, edge_attr: passed unchanged to both GCN layers
        - batch: graph id of every node; nodes of one graph must be contiguous and
          graphs must appear in increasing id order
        - global_features: 1-D tensor holding ``global_dim`` values per graph, graph
          after graph, in the same order as the graph ids

        Returns:
        - out: prediction per graph, shape [num_graphs, output_dim]
        - graph_features: the 32-dim hidden representation per graph (after dropout)
        """
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        # NOTE: this dropout uses F.dropout's default rate, not the constructor's `p`.
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, edge_index, edge_attr))

        # One pass over `batch` gives the node count of every graph.
        _, nodes_per_graph = torch.unique(batch, return_counts=True)
        num_graphs = nodes_per_graph.numel()

        # Row i holds the global features of the i-th graph; repeat it once per node.
        per_graph_globals = global_features[: num_graphs * self.global_dim].reshape(
            num_graphs, self.global_dim
        )
        global_features_broadcasted = torch.repeat_interleave(per_graph_globals, nodes_per_graph, dim=0)

        # Concatenate node embeddings with their graph's global features
        x = torch.cat([x, global_features_broadcasted], dim=-1)

        # Mean-pool the nodes of each graph
        graph_features = torch.stack(
            [nodes.mean(dim=0) for nodes in torch.split(x, nodes_per_graph.tolist())]
        )

        # MLP head
        graph_features = self.fc1(graph_features)
        graph_features = self.bn1(graph_features)
        graph_features = F.relu(graph_features)
        graph_features = self.dropout1(graph_features)

        out = self.fc2(graph_features)

        return out, graph_features

In [61]:
# len(all_predictions), len(all_true_values)

#### 从保存的训练模型中提取向量

In [62]:
def get_global_features(data_loader, model=None, model_path='data/epoch-8-MSE-0.8245-R2-0.7286.pth'):
    """
    Run a trained model over a loader and collect its per-graph hidden features.

    Parameters:
    - data_loader: iterable of batches exposing batch, x, edge_index, edge_attr,
      global_features, y and file_name
    - model: optional, an already prepared model whose call returns
      (prediction, hidden_features). When None, a GNNModel(6, 64, 1, global_dim=4, p=0)
      is created and its weights are loaded from ``model_path``.
    - model_path: checkpoint (state_dict) used when ``model`` is None

    Returns:
    - global_features: np.ndarray, hidden features of all graphs stacked along axis 0
    - labels: np.ndarray, the ``y`` values of all graphs, shape [num_graphs, 1] for 1-D ``y``
    - filenames: list of file names in loader order

    Raises:
    - ValueError: if the loader yields no batches
    """
    if model is None:
        model = GNNModel(6, 64, 1, global_dim=4, p=0)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()  # evaluation mode: dropout off, BatchNorm uses running statistics

    feature_batches = []
    label_batches = []
    filenames = []

    with torch.no_grad():  # inference only, no gradients needed
        for batch_data in data_loader:
            lifetime = batch_data.y
            lifetime = lifetime.unsqueeze(1) if lifetime.dim() == 1 else lifetime

            # Second output of the model is the per-graph hidden representation
            _, hidden_features = model(
                batch_data.x,
                batch_data.edge_index,
                batch_data.edge_attr,
                batch_data.batch,
                batch_data.global_features,
            )

            feature_batches.append(hidden_features.cpu().numpy())
            label_batches.append(lifetime.cpu().numpy())
            filenames.extend(batch_data.file_name)

    if not feature_batches:
        raise ValueError("data_loader yielded no batches; nothing to extract")

    global_features = np.concatenate(feature_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    return global_features, labels, filenames

In [75]:
batch_size = 32
# shuffle=False keeps the rows in the same order as all_data
data_loader = DataLoader(all_data, batch_size=batch_size, shuffle=False, drop_last=False)

global_features, labels, filenames = get_global_features(data_loader)
labels = labels.reshape(-1, 1)
features = np.concatenate((global_features, labels), axis=1)
# Column names: "1" .. "<n_features>" for the hidden features, then "Lifetime"
column_names = [str(i + 1) for i in range(global_features.shape[1])] + ["Lifetime"]
df = pd.DataFrame(features, columns=column_names)
# Use the file names collected alongside the features, so each row is labelled with
# the molecule it was computed from regardless of any other list in the namespace.
df['Filename'] = filenames
df

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,25,26,27,28,29,30,31,32,Lifetime,Filename
0,0.0,0.173797,0.0,0.0,0.118692,0.096499,0.0,0.774783,0.057839,0.0,...,0.700632,0.584249,0.0,0.0,0.698546,0.585029,1.198794,0.000000,0.0,1
1,0.0,0.167409,0.0,0.0,0.130081,0.095855,0.0,0.701808,0.054247,0.0,...,0.590315,0.583735,0.0,0.0,0.729140,0.584781,1.244438,0.000000,0.0,2
2,0.0,0.176095,0.0,0.0,0.162581,0.107649,0.0,0.759916,0.056983,0.0,...,0.414087,0.584021,0.0,0.0,0.730253,0.584782,1.195135,0.000000,0.0,3
3,0.0,0.172101,0.0,0.0,0.088470,0.093114,0.0,0.639463,0.068216,0.0,...,0.638360,0.579750,0.0,0.0,0.540187,0.581928,0.919073,0.000000,0.0,4
4,0.0,0.169747,0.0,0.0,0.072914,0.091454,0.0,0.624370,0.068020,0.0,...,0.672073,0.579613,0.0,0.0,0.528409,0.581919,0.911856,0.000000,0.0,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9272,0.0,0.030859,0.0,0.0,0.030664,0.010164,0.0,0.423441,0.000000,0.0,...,0.000000,0.495008,0.0,0.0,0.192335,0.506369,0.606767,0.013631,0.0,9443
9273,0.0,0.023474,0.0,0.0,0.000000,0.000000,0.0,0.369223,0.000000,0.0,...,0.393608,0.493881,0.0,0.0,0.199756,0.505564,0.716898,0.008310,0.0,9444
9274,0.0,0.024624,0.0,0.0,0.014264,0.001416,0.0,0.437334,0.000000,0.0,...,0.164071,0.496707,0.0,0.0,0.290335,0.507588,0.822531,0.011557,0.0,9445
9275,0.0,0.032188,0.0,0.0,0.042863,0.006269,0.0,0.394387,0.000000,0.0,...,0.000000,0.493374,0.0,0.0,0.156206,0.505097,0.548009,0.014121,0.0,9446


In [76]:
df = df.set_index('Filename').reindex(file_names).reset_index()
df

Unnamed: 0,Filename,1,2,3,4,5,6,7,8,9,...,24,25,26,27,28,29,30,31,32,Lifetime
0,1,0.0,0.173797,0.0,0.0,0.118692,0.096499,0.0,0.774783,0.057839,...,0.106762,0.700632,0.584249,0.0,0.0,0.698546,0.585029,1.198794,0.000000,0.0
1,2,0.0,0.167409,0.0,0.0,0.130081,0.095855,0.0,0.701808,0.054247,...,0.000000,0.590315,0.583735,0.0,0.0,0.729140,0.584781,1.244438,0.000000,0.0
2,3,0.0,0.176095,0.0,0.0,0.162581,0.107649,0.0,0.759916,0.056983,...,0.004420,0.414087,0.584021,0.0,0.0,0.730253,0.584782,1.195135,0.000000,0.0
3,4,0.0,0.172101,0.0,0.0,0.088470,0.093114,0.0,0.639463,0.068216,...,0.000000,0.638360,0.579750,0.0,0.0,0.540187,0.581928,0.919073,0.000000,0.0
4,5,0.0,0.169747,0.0,0.0,0.072914,0.091454,0.0,0.624370,0.068020,...,0.000000,0.672073,0.579613,0.0,0.0,0.528409,0.581919,0.911856,0.000000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9272,9443,0.0,0.030859,0.0,0.0,0.030664,0.010164,0.0,0.423441,0.000000,...,0.920237,0.000000,0.495008,0.0,0.0,0.192335,0.506369,0.606767,0.013631,0.0
9273,9444,0.0,0.023474,0.0,0.0,0.000000,0.000000,0.0,0.369223,0.000000,...,0.747993,0.393608,0.493881,0.0,0.0,0.199756,0.505564,0.716898,0.008310,0.0
9274,9445,0.0,0.024624,0.0,0.0,0.014264,0.001416,0.0,0.437334,0.000000,...,1.063094,0.164071,0.496707,0.0,0.0,0.290335,0.507588,0.822531,0.011557,0.0
9275,9446,0.0,0.032188,0.0,0.0,0.042863,0.006269,0.0,0.394387,0.000000,...,0.686672,0.000000,0.493374,0.0,0.0,0.156206,0.505097,0.548009,0.014121,0.0


In [77]:
# `extend_df` is not defined anywhere in this notebook, so the original line fails on a
# fresh kernel. The table built in the previous cells is `df` (Filename, 32 features,
# Lifetime). NOTE(review): confirm `df` is the table meant for this export.
df.to_csv('generate_new_molecules/data/global_features_4_32dim.csv', index=False)