In [None]:
import os
import sys
import time
from pathos.multiprocessing import ProcessingPool as Pool
import torch

from get_data_from_siesta import get_data_from_siesta
from get_rotate_coord import get_rc
from rotate import get_rh
from data import HData

In [None]:
# Record the directory the notebook was launched from; main() inserts it at the
# front of sys.path so the local helper modules stay importable.
current_sys_path = os.path.abspath(os.curdir)
print("current sys path:", current_sys_path)

In [None]:
# Root of the generated dataset. Set the DEEPQT_DATASET_ROOT environment variable to
# run this notebook somewhere other than the original cluster; the default keeps the
# original location.
DATASET_ROOT = os.environ.get(
    "DEEPQT_DATASET_ROOT",
    "/fs2/home/ndsim10/DeepQT/DeepQTH/0_generate_dataset/expand_dataset",
)

config = {
    "basic": {
        # Walked recursively by main(); every directory containing a *.HSX or *.TSHS file is preprocessed.
        "raw_dir": os.path.join(DATASET_ROOT, "raw", ""),
        # Per-structure preprocessed outputs are written here, mirroring raw_dir's layout.
        "processed_data_dir": os.path.join(DATASET_ROOT, "processed", ""),
        # Not read by main(); presumably where HData stores the built graphs — TODO confirm.
        "graph_dir": os.path.join(DATASET_ROOT, "graph", ""),
        "target": "hamiltonian",          # main() only accepts "hamiltonian"
        "interface": "siesta",            # forwarded to get_data_from_siesta
        "data_format": "h5",              # not read by main(); presumably consumed by HData
        "input_file": "input.fdf",        # main() only accepts "input.fdf"
        # Worker count for preprocessing: >0 -> that many pathos workers,
        # <0 -> pathos default pool size, 0 -> run sequentially.
        "multiprocessing": 8,
        # If True, main() also runs get_rc / get_rh (local coordinates + rotated Hamiltonian).
        "local_coordinate": True,
        "material_dimension": 2,          # not read by main(); presumably used by HData — confirm
    },
    "interpreter": {
        # Not read by main(). Note: "~" is not expanded automatically by subprocess calls.
        "python_interpreter": "~/miniconda3/envs/deeph-cpu/bin/python"
    },
    "graph": {
        # Neighbour cutoff passed to get_rc: graphene 7.0 Å, MoS2 8.0 Å, and silicon 9.0 Å.
        "radius": 7.0,
        # The remaining graph options are not read by main(); presumably consumed by HData.
        "num_l": 4,
        "if_lcmp_graph": True,
        "shortest_path_length": 5,
    }
}

In [None]:
def _find_siesta_dirs(raw_dir):
    """Return sorted ``(relpath, abspath)`` pairs for directories holding SIESTA output.

    A directory under ``raw_dir`` (including ``raw_dir`` itself) qualifies when any
    file name in it contains ``.HSX`` or ``.TSHS``. ``relpath`` is relative to
    ``raw_dir`` (``'.'`` for ``raw_dir`` itself); ``abspath`` is absolute.
    """
    found = []
    for root, _dirs, files in os.walk(raw_dir):
        if any('.HSX' in fname or '.TSHS' in fname for fname in files):
            found.append((os.path.relpath(root, raw_dir), os.path.abspath(root)))
    # os.walk order is filesystem-dependent; sort so runs are reproducible.
    found.sort()
    return found


def main(config, current_sys_path):
    """Preprocess SIESTA/TranSIESTA outputs and build the ``HData`` dataset.

    Every directory below ``config['basic']['raw_dir']`` that contains a ``.HSX`` or
    ``.TSHS`` file is mirrored under ``config['basic']['processed_data_dir']``. That
    directory is filled by ``get_data_from_siesta``. When
    ``config['basic']['local_coordinate']`` is true, ``get_rc`` and ``get_rh`` are
    then run on it as well.

    Directories are processed by a pathos pool sized by
    ``config['basic']['multiprocessing']``: >0 means that many workers, <0 means the
    pathos default, and 0 means sequentially in this process. All paths are absolute,
    so the process working directory is never changed.

    Args:
        config: Nested settings dict with ``'basic'`` and ``'graph'`` sections.
        current_sys_path: Directory inserted at the front of ``sys.path`` (if absent)
            before ``HData`` is constructed, so local modules stay importable.

    Returns:
        The ``HData`` dataset built from ``config``.

    Raises:
        AssertionError: If ``target`` is not ``'hamiltonian'`` or ``input_file`` is
            not ``'input.fdf'``.
    """
    assert config['basic']['target'] in ['hamiltonian']
    assert config['basic']['input_file'] in ['input.fdf']

    raw_dir = os.path.abspath(config['basic']['raw_dir'])
    print("raw_dir:", raw_dir)
    processed_data_dir = os.path.abspath(config['basic']['processed_data_dir'])
    print("processed_data_dir:", processed_data_dir)

    target = config['basic']['target']  # hamiltonian
    interface = config['basic']['interface']  # siesta / transiesta
    input_file = config['basic']['input_file']  # *.fdf
    local_coordinate = config['basic']['local_coordinate']
    # Named n_workers so the stdlib ``multiprocessing`` module name is not shadowed.
    n_workers = config['basic']['multiprocessing']
    radius = config['graph']['radius']

    dir_pairs = _find_siesta_dirs(raw_dir)
    os.makedirs(processed_data_dir, exist_ok=True)
    print(f"Found {len(dir_pairs)} directories to preprocess")

    def worker(pair):
        """Preprocess one ``(relpath, abspath)`` raw directory into its mirrored output dir."""
        relpath, abspath = pair
        # Absolute output path: correct regardless of the current working directory.
        out_dir = os.path.normpath(os.path.join(processed_data_dir, relpath))
        os.makedirs(out_dir, exist_ok=True)

        get_data_from_siesta(abspath, out_dir, interface, input_file)

        if local_coordinate:
            # Obtain the local coordinate system.
            get_rc(out_dir, out_dir, radius=radius, neighbour_file='hamiltonians.h5')
            # Obtain the Hamiltonian after rotation.
            get_rh(out_dir, out_dir, target)

    begin_time = time.time()
    if n_workers != 0:
        pool_dict = {'nodes': n_workers} if n_workers > 0 else {}
        with Pool(**pool_dict) as pool:
            nodes = pool.nodes
            print(f'Use multiprocessing (nodes = {nodes})')
            pool.map(worker, dir_pairs)
    else:
        for pair in dir_pairs:
            worker(pair)
    print(f'\nFinished preprocess {len(dir_pairs)} directories and cost {time.time() - begin_time:.2f} seconds')

    if current_sys_path not in sys.path:
        sys.path.insert(0, current_sys_path)
    dataset = HData(config, default_dtype_torch=torch.get_default_dtype())
    return dataset
    
if __name__ == '__main__':
    # Run preprocessing and dataset construction when executed as the main program / notebook.
    dataset = main(config=config, current_sys_path=current_sys_path)
    

In [None]:
# Edge- and node-feature dimensionality of the processed graphs, one value per line.
print(dataset.num_edge_features, dataset.num_node_features, sep="\n")

In [None]:
# Text summary (str() form) of the first graph in the dataset.
print(dataset[0])

In [None]:
# Node feature matrix of the first graph; presumably one row per atom — TODO confirm against HData.
dataset[0].x

In [None]:
# Per-atom orbital counts of the first graph (reading based on the name only; verify in HData).
dataset[0].atom_num_orbital

In [None]:
# Dataset-level metadata exposed by HData.
# NOTE(review): if `info` is a method rather than an attribute, this displays the bound
# method instead of calling it — confirm and append `()` if so.
dataset.info