In [1]:
import ijson
import numpy as np
import pandas as pd

from berkeley_pes.source.utils.data import *


In [2]:
def parse_json(json_filename, mode="normal", verbose=False):
    energies_raw, energies = [], []
    xyz_unformated, grad_unformated = [], []
    charge_list = []
    atom_count, element_count, element_list = [], [], []
    element_list_single_structure = []

    with open(json_filename, "rb") as input_file:
        # load json iteratively
        parser = ijson.parse(input_file)
        ind_track, atom_count_temp = 0, 0
        trigger_count, check_ind = 0, 0
        ind_mode, ind_current = -1, -1
        line_ind = 0
        frame_count_total = 0
        for prefix, event, value in parser:
            if mode == "flat":
                if prefix[0:13] == "item.molecule" or prefix[0:8] == "molecule":
                    # print('prefix={}, event={}, value={}'.format(prefix, event, value))

                    if ind_mode == -1 and event == "start_array":
                        try:
                            ind_current = int(prefix.split(".")[1])
                            ind_mode = 0
                            print("entering ind mode 0")
                        except:
                            trigger = False
                            ind_mode = 1
                            print("entering ind mode 1")

                    # print('prefix={}, event={}, value={}'.format(prefix, event, value))

                    if (
                        value == None
                        and event == "null"
                        and prefix == "item.molecule.@version"
                    ):
                        trigger = True
                        trigger_count += 1

                    elif event == "number":
                        if "xyz" in prefix.split("."):
                            xyz_unformated.append(float(value))
                        if "charge" in prefix.split("."):
                            charge_list.append(float(value))

                    elif event == "string":
                        if "element" in prefix.split("."):
                            element_list.append(str(value))
                            element_list_single_structure.append(str(value))
                            if trigger:
                                atom_count.append(atom_count_temp)
                                element_count.append(
                                    len(element_list_single_structure[0])
                                )
                                atom_count_temp = 1
                                trigger = False
                                element_list_single_structure = []

                            else:
                                atom_count_temp += 1

                if value != None:
                    if "gradient" in prefix.split("."):
                        if event == "number":
                            grad_unformated.append(float(value))

                    if prefix == "item.energy":
                        if event == "number":
                            energies_raw.append(float(value))

            else:
                if value is not None:
                    if event == "string":
                        if "formula_alphabetical" in prefix.split("."):
                            # sum all integers in string
                            elements = value.split()
                            # strip elements of alphabetical characters
                            elements = [
                                element.strip("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
                                for element in elements
                            ]
                            # strip lowercase letters
                            elements = [
                                element.strip("abcdefghijklmnopqrstuvwxyz")
                                for element in elements
                            ]
                            # print(elements)
                            num_atoms = sum(map(int, elements))
                            element_count.append(num_atoms)

                    if event == "number":
                        if "energy_trajectory" in prefix.split("."):
                            energies_raw.append(float(value))

                        if "gradient_trajectory" in prefix.split("."):
                            # print('prefix={}, event={}, value={}'.format(prefix, event, value))
                            grad_unformated.append(float(value))

                if "molecule_trajectory" in prefix.split("."):
                    # if prefix[0:13] == "item.molecule":
                    if ind_mode == -1 and event == "start_array":
                        try:
                            ind_current = int(prefix.split(".")[1])
                            ind_mode = 0
                            print("entering ind mode 0")
                        except:
                            trigger = False
                            ind_mode = 1
                            print("entering ind mode 1")

                    if ind_mode == 0:
                        if event == "number":
                            if "xyz" in prefix.split("."):
                                xyz_unformated.append(float(value))
                            if "charge" in prefix.split("."):
                                charge_list.append(float(value))

                        if event == "string":
                            # print('prefix={}, event={}, value={}'.format(prefix, event, value))
                            if "name" in prefix.split("."):
                                element_list.append(str(value))
                                ind_current = int(prefix.split(".")[1])
                                # print(ind_current, ind_track)
                                if ind_current != ind_track:
                                    ind_track = int(ind_current)
                                    atom_count.append(atom_count_temp)
                                    atom_count_temp = 1
                                else:
                                    atom_count_temp += 1
                        if (
                            value == None
                            and event == "end_array"
                            and "molecule_trajectory" in prefix.split(".")
                            and "sites" in prefix.split(".")
                            and "species" not in prefix.split(".")
                            and "xyz" not in prefix.split(".")
                        ):
                            frame_count_total += 1

                    if ind_mode == 1:
                        if (
                            value == None
                            and event == "end_array"
                            and prefix == "item.molecule_trajectory"
                        ):
                            trigger = True
                            trigger_count += 1

                        if event == "number":
                            if "xyz" in prefix.split("."):
                                xyz_unformated.append(float(value))
                            # print('prefix={}, event={}, value={}'.format(prefix, event, value))
                            if "charge" in prefix.split("."):
                                charge_list.append(float(value))

                        if event == "string":
                            if "name" in prefix.split("."):
                                element_list.append(str(value))
                                if trigger:
                                    atom_count.append(atom_count_temp)
                                    atom_count_temp = 1
                                    trigger = False
                                else:
                                    atom_count_temp += 1

    if ind_mode == 0:
        atom_count.append(atom_count_temp)
        if mode == "flat":
            element_count.append(len(element_list_single_structure))
    else:
        atom_count.append(atom_count_temp)
        if mode == "flat":
            element_count.append(len(element_list_single_structure))

    if mode == "flat":
        grad_unformated = np.array(grad_unformated)
        grad_formated = grad_unformated.reshape(-1, 3)
        xyz_unformated = np.array(xyz_unformated)
        xyz_formated = xyz_unformated.reshape(-1, 3)
        atom_count = np.array(atom_count)
        frames_per_mol = atom_count / element_count
        grad_format = np.split(grad_formated, np.cumsum(atom_count)[:-1])
        xyz_format = np.split(xyz_formated, np.cumsum(atom_count)[:-1])
        element_list = np.split(element_list, np.cumsum(atom_count)[:-1])
        energies = energies_raw

    else:
        grad_unformated = np.array(grad_unformated)
        grad_formated = grad_unformated.reshape(-1, 3)
        xyz_unformated = np.array(xyz_unformated)
        xyz_formated = xyz_unformated.reshape(-1, 3)
        atom_count = np.array(atom_count)
        frames_per_mol = atom_count / element_count

        grad_format = np.split(grad_formated, np.cumsum(atom_count)[:-1])
        xyz_format = np.split(xyz_formated, np.cumsum(atom_count)[:-1])
        element_list = np.split(element_list, np.cumsum(atom_count)[:-1])  #
        # charge_list = np.split(charge_list, np.cumsum(frames_per_mol))
        # split charges into one charges per frame per molecule

        # split energies into frames per molecule
        running_start = 0
        charges = []
        for i in range(len(frames_per_mol)):
            energies.append(
                energies_raw[running_start : running_start + int(frames_per_mol[i])]
            )
            charges.append(
                charge_list[running_start : running_start + int(frames_per_mol[i])]
            )
            running_start += int(frames_per_mol[i])

        xyz_format = [
            array.reshape(int(frames_per_mol[ind_frame]), element_count[ind_frame], 3)
            for ind_frame, array in enumerate(xyz_format)
        ]
        grad_format = [
            array.reshape(int(frames_per_mol[ind_frame]), element_count[ind_frame], 3)
            for ind_frame, array in enumerate(grad_format)
        ]
        element_list = [
            array.reshape(int(frames_per_mol[ind_frame]), element_count[ind_frame])
            for ind_frame, array in enumerate(element_list)
        ]

    composition_list = []
    for elements_repeated in element_list:
        composition = {}
        for element in elements_repeated[0]:
            if element in composition:
                composition[element] += 1
            else:
                composition[element] = 1
        # sort composition by key
        composition = dict(sorted(composition.items()))
        composition_list.append(composition)

    if verbose:
        print("element_list len:     {}".format(len(element_list)))
        print("element_count len:    {}".format(len(element_count)))
        print("atom_count len:       {}".format(len(atom_count)))
        print("xyz_format len:       {}".format(len(xyz_format)))
        print("energies len:         {}".format(len(energies)))
        print("grad_format len:      {}".format(len(grad_format)))
        print("composition_list len: {}".format(len(composition_list)))
        print("charge_list len:      {}".format(len(charges)))
        print("frames total:         {}".format(frame_count_total))
        print("frames per mol:       {}".format(frames_per_mol))
        print("sum frames per mol:   {}".format(np.sum(frames_per_mol)))

    data = {
        "energies": energies,
        "grads": grad_format,
        "xyz": xyz_format,
        "elements": element_list,
        "frames_per_mol": frames_per_mol,
        "atom_count": atom_count,
        "element_composition": composition_list,
        "charges": charges,
    }
    return data


In [3]:
# file_rapter = "../../data/20230414_rapter_tracks_initial.json"
# file_libe = "../../data/tasks_opt_trajectories_partial.json"
# file_test = "../../data/test_rapter.json"
# df_rapter = pd.read_json(file_test)
# print(df_rapter.shape)
# df_libe = dd.read_json(file_libe)
# pd.read_json(file_libe)

file_test = "../../../data/test_libe.json"
file_test_rapter = "../../../data/test_rapter.json"


In [45]:
def parse_json(json_filename, mode="normal", verbose=False):
    energies_raw, energies = [], []
    xyz_unformated, grad_unformated = [], []
    spin_list, charge_list = [], []
    atom_count, element_count, element_list = [], [], []
    element_list_single_structure = []

    with open(json_filename, "rb") as input_file:
        # load json iteratively
        parser = ijson.parse(input_file)
        ind_track, atom_count_temp = 0, 0
        trigger_count, check_ind = 0, 0
        ind_mode, ind_current = -1, -1
        line_ind = 0
        frame_count_total = 0
        for prefix, event, value in parser:
            # if "spin_multiplicity" in prefix and event == "number":
            #    print("prefix={}, event={}, value={}".format(prefix, event, value))

            if mode == "flat":
                if prefix[0:13] == "item.molecule" or prefix[0:8] == "molecule":
                    # print('prefix={}, event={}, value={}'.format(prefix, event, value))

                    if ind_mode == -1 and event == "start_array":
                        try:
                            ind_current = int(prefix.split(".")[1])
                            ind_mode = 0
                            print("entering ind mode 0")
                        except:
                            trigger = False
                            ind_mode = 1
                            print("entering ind mode 1")

                    # print('prefix={}, event={}, value={}'.format(prefix, event, value))

                    if (
                        value == None
                        and event == "null"
                        and prefix == "item.molecule.@version"
                    ):
                        trigger = True
                        trigger_count += 1

                    elif event == "number":
                        if "xyz" in prefix.split("."):
                            xyz_unformated.append(float(value))
                        if "charge" in prefix.split("."):
                            charge_list.append(float(value))

                    elif event == "string":
                        if "element" in prefix.split("."):
                            element_list.append(str(value))
                            element_list_single_structure.append(str(value))
                            if trigger:
                                atom_count.append(atom_count_temp)
                                element_count.append(
                                    len(element_list_single_structure[0])
                                )
                                atom_count_temp = 1
                                trigger = False
                                element_list_single_structure = []

                            else:
                                atom_count_temp += 1

                if value != None:
                    if "gradient" in prefix.split("."):
                        if event == "number":
                            grad_unformated.append(float(value))

                    if prefix == "item.energy":
                        if event == "number":
                            energies_raw.append(float(value))

            else:
                if value is not None:
                    if event == "string":
                        if "formula_alphabetical" in prefix.split("."):
                            # sum all integers in string
                            elements = value.split()
                            # strip elements of alphabetical characters
                            elements = [
                                element.strip("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
                                for element in elements
                            ]
                            # strip lowercase letters
                            elements = [
                                element.strip("abcdefghijklmnopqrstuvwxyz")
                                for element in elements
                            ]
                            # print(elements)
                            num_atoms = sum(map(int, elements))
                            element_count.append(num_atoms)

                    if event == "number":
                        if "energy_trajectory" in prefix.split("."):
                            energies_raw.append(float(value))

                        if "gradient_trajectory" in prefix.split("."):
                            # print('prefix={}, event={}, value={}'.format(prefix, event, value))
                            grad_unformated.append(float(value))

                if "molecule_trajectory" in prefix.split("."):
                    # if prefix[0:13] == "item.molecule":
                    if ind_mode == -1 and event == "start_array":
                        try:
                            ind_current = int(prefix.split(".")[1])
                            ind_mode = 0
                            print("entering ind mode 0")
                        except:
                            trigger = False
                            ind_mode = 1
                            print("entering ind mode 1")

                    if ind_mode == 0:
                        if event == "number":
                            if "xyz" in prefix.split("."):
                                xyz_unformated.append(float(value))
                            if "charge" in prefix.split("."):
                                charge_list.append(float(value))
                            if "spin_multiplicity" in prefix.split("."):
                                # print(
                                #    "prefix={}, event={}, value={}".format(
                                #        prefix, event, value
                                #    )
                                # )
                                spin_list.append(float(value))

                        if event == "string":
                            # print('prefix={}, event={}, value={}'.format(prefix, event, value))
                            if "name" in prefix.split("."):
                                element_list.append(str(value))
                                ind_current = int(prefix.split(".")[1])
                                # print(ind_current, ind_track)
                                if ind_current != ind_track:
                                    ind_track = int(ind_current)
                                    atom_count.append(atom_count_temp)
                                    atom_count_temp = 1
                                else:
                                    atom_count_temp += 1
                        if (
                            value == None
                            and event == "end_array"
                            and "molecule_trajectory" in prefix.split(".")
                            and "sites" in prefix.split(".")
                            and "species" not in prefix.split(".")
                            and "xyz" not in prefix.split(".")
                        ):
                            frame_count_total += 1

                    if ind_mode == 1:
                        if (
                            value == None
                            and event == "end_array"
                            and prefix == "item.molecule_trajectory"
                        ):
                            trigger = True
                            trigger_count += 1

                        if event == "number":
                            if "xyz" in prefix.split("."):
                                xyz_unformated.append(float(value))
                            # print('prefix={}, event={}, value={}'.format(prefix, event, value))
                            if "charge" in prefix.split("."):
                                charge_list.append(float(value))
                            if "spin_multiplicity" in prefix.split("."):
                                """print(
                                    "prefix={}, event={}, value={}".format(
                                        prefix, event, value
                                    )
                                )"""
                                spin_list.append(float(value))

                        if event == "string":
                            if "name" in prefix.split("."):
                                element_list.append(str(value))
                                if trigger:
                                    atom_count.append(atom_count_temp)
                                    atom_count_temp = 1
                                    trigger = False
                                else:
                                    atom_count_temp += 1

    if ind_mode == 0:
        atom_count.append(atom_count_temp)
        if mode == "flat":
            element_count.append(len(element_list_single_structure))
    else:
        atom_count.append(atom_count_temp)
        if mode == "flat":
            element_count.append(len(element_list_single_structure))

    if mode == "flat":
        grad_unformated = np.array(grad_unformated)
        grad_formated = grad_unformated.reshape(-1, 3)
        xyz_unformated = np.array(xyz_unformated)
        xyz_formated = xyz_unformated.reshape(-1, 3)
        atom_count = np.array(atom_count)
        frames_per_mol = atom_count / element_count
        grad_format = np.split(grad_formated, np.cumsum(atom_count)[:-1])
        xyz_format = np.split(xyz_formated, np.cumsum(atom_count)[:-1])
        element_list = np.split(element_list, np.cumsum(atom_count)[:-1])
        energies = energies_raw

    else:
        grad_unformated = np.array(grad_unformated)
        grad_formated = grad_unformated.reshape(-1, 3)
        xyz_unformated = np.array(xyz_unformated)
        xyz_formated = xyz_unformated.reshape(-1, 3)
        atom_count = np.array(atom_count)
        frames_per_mol = atom_count / element_count

        grad_format = np.split(grad_formated, np.cumsum(atom_count)[:-1])
        xyz_format = np.split(xyz_formated, np.cumsum(atom_count)[:-1])
        element_list = np.split(element_list, np.cumsum(atom_count)[:-1])  #
        # charge_list = np.split(charge_list, np.cumsum(frames_per_mol))
        # split charges into one charges per frame per molecule

        # split energies into frames per molecule
        running_start = 0
        charges = []
        spins = []
        for i in range(len(frames_per_mol)):
            energies.append(
                energies_raw[running_start : running_start + int(frames_per_mol[i])]
            )
            charges.append(
                charge_list[running_start : running_start + int(frames_per_mol[i])]
            )
            spins.append(
                spin_list[running_start : running_start + int(frames_per_mol[i])]
            )
            running_start += int(frames_per_mol[i])

        xyz_format = [
            array.reshape(int(frames_per_mol[ind_frame]), element_count[ind_frame], 3)
            for ind_frame, array in enumerate(xyz_format)
        ]
        grad_format = [
            array.reshape(int(frames_per_mol[ind_frame]), element_count[ind_frame], 3)
            for ind_frame, array in enumerate(grad_format)
        ]
        element_list = [
            array.reshape(int(frames_per_mol[ind_frame]), element_count[ind_frame])
            for ind_frame, array in enumerate(element_list)
        ]

    composition_list = []
    for elements_repeated in element_list:
        composition = {}
        for element in elements_repeated[0]:
            if element in composition:
                composition[element] += 1
            else:
                composition[element] = 1
        # sort composition by key
        composition = dict(sorted(composition.items()))
        composition_list.append(composition)

    if verbose:
        print("element_list len:     {}".format(len(element_list)))
        print("element_count len:    {}".format(len(element_count)))
        print("atom_count len:       {}".format(len(atom_count)))
        print("spins len:            {}".format(len(spins)))
        print("xyz_format len:       {}".format(len(xyz_format)))
        print("energies len:         {}".format(len(energies)))
        print("grad_format len:      {}".format(len(grad_format)))
        print("composition_list len: {}".format(len(composition_list)))
        print("charge_list len:      {}".format(len(charges)))
        print("frames total:         {}".format(frame_count_total))
        print("frames per mol:       {}".format(frames_per_mol))
        print("sum frames per mol:   {}".format(np.sum(frames_per_mol)))

    data = {
        "energies": energies,
        "grads": grad_format,
        "xyz": xyz_format,
        "elements": element_list,
        "frames_per_mol": frames_per_mol,
        "atom_count": atom_count,
        "element_composition": composition_list,
        "charges": charges,
        "spin": spins,
    }
    return data

In [46]:
data_test = parse_json(file_test_rapter, mode="normal", verbose=True)


entering ind mode 0
element_list len:     100
element_count len:    100
atom_count len:       100
spins len:            100
xyz_format len:       100
energies len:         100
grad_format len:      100
composition_list len: 100
charge_list len:      100
frames total:         2755
frames per mol:       [  5.   5.  82.  16. 132.   2.  33.  19.  27.  24.  11. 117.  44.  59.
  39.  50.  69.  15.  26.  10.   2.   4.  28.  26.   7.  54.  21.   2.
  87.  26.  13.  27.  17.   8.  20.  21.  80.  59.  14.  32.  27.   9.
  34.   5.  23.  14.   1.   3.   1.   3.   3.   4.  30.  19.  14.  52.
   3.   2.  55.  80.  79.   2.  26.   2.  24.   2.   9.   3.   2.   4.
  13.   2.   2.   2.  47.  33.  25.   4.  24.  19.  19.  27.  13.  26.
  23.  46.   9.   4. 145.  20.  49.   9.  16.   5.  51.  63.  94.  24.
   4.  69.]
sum frames per mol:   2755.0


In [47]:
data_test_rapter = parse_json(file_test_rapter, mode="normal", verbose=True)


entering ind mode 0
element_list len:     100
element_count len:    100
atom_count len:       100
spins len:            100
xyz_format len:       100
energies len:         100
grad_format len:      100
composition_list len: 100
charge_list len:      100
frames total:         2755
frames per mol:       [  5.   5.  82.  16. 132.   2.  33.  19.  27.  24.  11. 117.  44.  59.
  39.  50.  69.  15.  26.  10.   2.   4.  28.  26.   7.  54.  21.   2.
  87.  26.  13.  27.  17.   8.  20.  21.  80.  59.  14.  32.  27.   9.
  34.   5.  23.  14.   1.   3.   1.   3.   3.   4.  30.  19.  14.  52.
   3.   2.  55.  80.  79.   2.  26.   2.  24.   2.   9.   3.   2.   4.
  13.   2.   2.   2.  47.  33.  25.   4.  24.  19.  19.  27.  13.  26.
  23.  46.   9.   4. 145.  20.  49.   9.  16.   5.  51.  63.  94.  24.
   4.  69.]
sum frames per mol:   2755.0


In [72]:
data_test_rapter.keys()


dict_keys(['energies', 'grads', 'xyz', 'elements', 'frames_per_mol', 'atom_count', 'element_composition', 'charges', 'spin'])

In [108]:
rapter_file = (
    "/home/santiagovargas/dev/berkeley_pes/data/20230414_rapter_tracks_initial.json"
)
data_rapter = parse_json(rapter_file, mode="normal", verbose=True)


entering ind mode 1
element_list len:     14414
element_count len:    14414
atom_count len:       14414
spins len:            14414
xyz_format len:       14414
energies len:         14414
grad_format len:      14414
composition_list len: 14414
charge_list len:      14414
frames total:         0
frames per mol:       [ 5.  5. 82. ... 28. 14.  6.]
sum frames per mol:   579636.0


In [None]:
data = parse_json(
    "../../../data/tasks_opt_trajectories_partial.json", mode="alt", verbose=True
)


In [87]:
data_rapter.keys()


dict_keys(['energies', 'grads', 'xyz', 'elements', 'frames_per_mol', 'atom_count', 'element_composition', 'charge_list'])

In [89]:
charge_list = data_rapter["charge_list"]
list_charge_unique = np.unique(charge_list)
list_charge_unique


array([-3., -2., -1.,  0.,  1.,  2.,  3.])

In [104]:
data_rapter["element_composition"]
comp_string_list = []
for i in data_rapter["element_composition"]:
    comp_string = ""
    for key, value in sorted(i.items()):
        comp_string = comp_string + key + str(value) + "_"
    comp_string = comp_string[:-1]
    comp_string_list.append(comp_string)
print(len(comp_string_list))
print(len(list(set(comp_string_list))))


14414
1465


In [105]:
import os


def separate_into_charge_and_comp_and_spin(dict_info):
    energies = dict_info["energies"]
    grads_list = dict_info["grads"]
    spin_list = dict_info["spin"]
    charge_list = dict_info["charges"]
    xyzs_list = dict_info["xyz"]
    elements_list = dict_info["elements"]
    composition_list = dict_info["element_composition"]
    charge_list_single = [i[0] for i in charge_list]
    spin_list_single = [i[0] for i in spin_list]
    list_charge_unique = np.unique(charge_list_single)
    list_spin_unique = np.unique(spin_list_single)

    separate_charge_comp_dict = (
        {}
    )  # dict with charges and info for each frame in composition

    for charge in list_charge_unique:  # instantiate
        separate_charge_comp_dict[charge] = {}
        for spin in list_spin_unique:
            separate_charge_comp_dict[charge][spin] = {}

    comp_string_list = []
    print(separate_charge_comp_dict)
    for ind, i in enumerate(composition_list):
        comp_string = comp_dict_to_string(i)
        comp_string_list.append(comp_string)
        charge = charge_list[ind][0]
        spin = spin_list[ind][0]

        if comp_string not in separate_charge_comp_dict[charge][spin].keys():
            separate_charge_comp_dict[charge][spin][comp_string] = {
                "energies": [],
                "grads": [],
                "xyzs": [],
                "elements": [],
                "charge": [],
                "spin": [],
            }

        separate_charge_comp_dict[charge][spin][comp_string]["energies"].append(
            energies[ind]
        )
        separate_charge_comp_dict[charge][spin][comp_string]["grads"].append(
            grads_list[ind]
        )
        separate_charge_comp_dict[charge][spin][comp_string]["xyzs"].append(
            xyzs_list[ind]
        )
        separate_charge_comp_dict[charge][spin][comp_string]["elements"].append(
            elements_list[ind]
        )
        separate_charge_comp_dict[charge][spin][comp_string]["charge"].append(
            charge_list[ind]
        )
        separate_charge_comp_dict[charge][spin][comp_string]["spin"].append(
            spin_list[ind]
        )

    return separate_charge_comp_dict


def write_dict_to_ase_trajectory(
    dict_info,
    root_out,
    separate_charges=False,
    separate_composition=False,
    separate_spin=False,
):
    """
    Takes a dictionary organized by trajectories and writes it to an ase file
    """

    energies = dict_info["energies"]
    grads_list = dict_info["grads"]
    xyzs_list = dict_info["xyz"]
    elements_list = dict_info["elements"]
    composition_list = dict_info["element_composition"]
    charge_list = dict_info["charges"]
    charge_list_single = [i[0] for i in charge_list]
    list_charge_unique = np.unique(charge_list_single)

    if separate_spin:
        spin_list = dict_info["spin"]
        spin_list_single = [i[0] for i in spin_list]
        list_spin_unique = np.unique(spin_list_single)

    if separate_spin and separate_composition and separate_charges:
        separate_charge_comp_dict = separate_into_charge_and_comp_and_spin(dict_info)
    elif separate_composition and separate_charges:
        separate_charge_comp_dict = separate_into_charge_and_comp(dict_info)
    elif separate_composition:
        separate_charge_comp_dict = separate_into_comp(dict_info)
    else:
        separate_charge_comp_dict = None

    # create a folder for each charge
    frame_count_global = 0

    if separate_charges and separate_spin:
        for charge in list_charge_unique:
            frame_count_charge = 0
            if not os.path.exists(
                root_out + "/charge_" + str(charge)
            ) and not os.path.exists(root_out + "/charge_" + "neg_" + str(charge)):
                charge_temp = int(charge)
                if charge < 0:
                    charge_temp = "neg_" + str(abs(charge_temp))
                os.makedirs(root_out + "/charge_" + str(charge_temp), exist_ok=True)

            for spin in list_spin_unique:
                if not os.path.exists(
                    root_out + "/charge_" + str(charge_temp) + "/spin_" + str(int(spin))
                ):
                    os.makedirs(
                        root_out
                        + "/charge_"
                        + str(charge_temp)
                        + "/spin_"
                        + str(int(spin)),
                        exist_ok=True,
                    )

            for spin_key, dict_info_temp in separate_charge_comp_dict[charge].items():
                for comp_key, dict_info in dict_info_temp.items():
                    comp_count = 0
                    file = (
                        root_out
                        + "/charge_"
                        + str(charge_temp)
                        + "/spin_"
                        + str(int(spin_key))
                        + "/{}.xyz".format(comp_key)
                    )
                    with open(file, "w") as f:
                        for ind_frame, (
                            energies_frame,
                            grads_frame,
                            xyzs_frame,
                            elements_frame,
                        ) in enumerate(
                            zip(
                                dict_info["energies"],
                                dict_info["grads"],
                                dict_info["xyzs"],
                                dict_info["elements"],
                            )
                        ):
                            for ind_mol, (energy, grad, xyz, elements) in enumerate(
                                zip(
                                    energies_frame,
                                    grads_frame,
                                    xyzs_frame,
                                    elements_frame,
                                )
                            ):
                                frame_count_global += 1
                                write_ase(f, elements, energy, grad, xyz)
                                comp_count += 1
                                frame_count_charge += 1
                                frame_count_global += 1

                    print(
                        "frames to {}: \t\t {}".format(
                            file.split("/")[-1].split(".")[0], comp_count
                        )
                    )
            print(
                "frames charge {} folder:\t {}".format(charge_temp, frame_count_charge)
            )
        print("frames total: \t\t {}".format(frame_count_global))

    elif separate_charges:
        for charge in list_charge_unique:
            frame_count_charge = 0
            if not os.path.exists(root_out + str(charge)):
                charge_temp = int(charge)
                if charge < 0:
                    charge_temp = "neg_" + str(abs(charge_temp))
                os.makedirs(root_out + str(charge_temp), exist_ok=True)

            for comp_key, dict_info in separate_charge_comp_dict[charge].items():
                comp_count = 0
                file = root_out + str(charge_temp) + "/{}.xyz".format(comp_key)
                with open(file, "w") as f:
                    for ind_frame, (
                        energies_frame,
                        grads_frame,
                        xyzs_frame,
                        elements_frame,
                    ) in enumerate(
                        zip(
                            dict_info["energies"],
                            dict_info["grads"],
                            dict_info["xyzs"],
                            dict_info["elements"],
                        )
                    ):
                        for ind_mol, (energy, grad, xyz, elements) in enumerate(
                            zip(
                                energies_frame,
                                grads_frame,
                                xyzs_frame,
                                elements_frame,
                            )
                        ):
                            frame_count_global += 1
                            write_ase(f, elements, energy, grad, xyz)
                            comp_count += 1
                            frame_count_charge += 1
                            frame_count_global += 1

                print(
                    "frames to {}: \t\t {}".format(
                        file.split("/")[-1].split(".")[0], comp_count
                    )
                )

            print(
                "frames charge {} folder:\t {}".format(charge_temp, frame_count_charge)
            )
        print("frames total: \t\t {}".format(frame_count_global))

    else:
        print("writing as a single file")
        if not os.path.exists(root_out):
            # make it if it doesn't
            os.makedirs(root_out, exist_ok=True)

        with open(root_out + "combined.xyz", "w") as f:
            for ind_frame, (
                energies_frame,
                grads_frame,
                xyzs_frame,
                elements_frame,
            ) in enumerate(zip(energies, grads_list, xyzs_list, elements_list)):
                # print(ind_frame, len(energies_frame))
                for ind_mol, (energy, grad, xyz, elements) in enumerate(
                    zip(energies_frame, grads_frame, xyzs_frame, elements_frame)
                ):
                    frame_count_global += 1
                    n_atoms = len(elements)
                    # print(elements)
                    f.write(str(n_atoms) + "\n")
                    f.write(
                        'Properties=species:S:1:pos:R:3:forces:R:3 energy={} free_energy={} pbc="F F F"\n'.format(
                            energy, energy
                        )
                    )
                    for ind_atom, (xyz, grad) in enumerate(zip(xyz, grad)):
                        # print(elements[0])
                        f.write(
                            "{:2} {:15.8} {:15.8} {:15.8} {:15.8} {:15.8} {:15.8}\n".format(
                                elements[ind_atom],
                                xyz[0],
                                xyz[1],
                                xyz[2],
                                grad[0],
                                grad[1],
                                grad[2],
                            )
                        )
        print("Wrote {} frames to {}".format(frame_count_global, root_out))

In [109]:
write_dict_to_ase_trajectory(
    data_rapter,
    "../../../data/ase/rapter_full/",
    separate_composition=True,
    separate_spin=True,
    separate_charges=True,
)


{-3.0: {1.0: {}, 2.0: {}, 3.0: {}}, -2.0: {1.0: {}, 2.0: {}, 3.0: {}}, -1.0: {1.0: {}, 2.0: {}, 3.0: {}}, 0.0: {1.0: {}, 2.0: {}, 3.0: {}}, 1.0: {1.0: {}, 2.0: {}, 3.0: {}}, 2.0: {1.0: {}, 2.0: {}, 3.0: {}}, 3.0: {1.0: {}, 2.0: {}, 3.0: {}}}
frames to C2_H3_Li1_O3: 		 8
frames charge neg_3 folder:	 8
frames to C3_H6_O5: 		 247
frames to C6_H8_O6: 		 8
frames to C4_H6: 		 67
frames to C5_H4_O4: 		 52
frames to C4_H4_O2: 		 28
frames to C4_H6_O1: 		 374
frames to C2_H2_O2: 		 56
frames to C5_H2_O1: 		 29
frames to C4_H8_O1: 		 73
frames to C5_H8_O4: 		 19
frames to C6_H8_O3: 		 4
frames to C4_H6_O2: 		 229
frames to C3_H4_O2: 		 8
frames to C5_H8_O3: 		 6
frames to C6_H1_Li1_O2: 		 339
frames to C3_H1_Li1_O3: 		 55
frames to C2_F1_H3_O2: 		 2
frames to C2_F1_O1_P1: 		 9
frames to C3_F1_H2_Li1_O1: 		 5
frames to F1_H1_Li1_O2_P1: 		 37
frames to F2_H1_O1_P1: 		 70
frames to C2_F1_H3_O1: 		 50
frames to C3_F1_H3_O2: 		 339
frames to C2_F1_H1_O1: 		 15
frames to C5_H6_O2: 		 58
frames to C4_

In [81]:
write_dict_to_ase_trajectory(data_test, "../../data/ase/test_parser.xyz")
write_dict_to_ase_trajectory(data, "../../data/ase/20230414_rapter_tracks_initial.xyz")


Wrote 2755 frames to ../../data/npz/test_parser.xyz
Wrote 579636 frames to ../../data/npz/20230414_rapter_tracks_initial.xyz


In [16]:
# read with ase
from ase.io import read

test_out = read("../../../data/ase/libe_full.xyz", index=":")


In [17]:
# libe full - Wrote 1419271 frames to ../../../data/ase/libe_full.xyz
len(test_out)  # whoo!!


1419271

In [5]:
libe_test_file = "../../../data/test_libe.json"
import pandas as pd

libe_test_df = pd.read_json(libe_test_file)
libe_test_df.molecule.iloc[0]


{'@module': 'pymatgen.core.structure',
 '@class': 'Molecule',
 'charge': 0,
 'spin_multiplicity': 1,
 'sites': [{'name': 'O',
   'species': [{'element': 'O', 'occu': 1}],
   'xyz': [-2.6925826572, 0.7768596788000001, -0.6232921263],
   'properties': {}},
  {'name': 'C',
   'species': [{'element': 'C', 'occu': 1}],
   'xyz': [-1.8891435037000002, -0.0008620098, -1.1063488904],
   'properties': {}},
  {'name': 'C',
   'species': [{'element': 'C', 'occu': 1}],
   'xyz': [-1.0414994726, -0.9508093493, -0.3183232657],
   'properties': {}},
  {'name': 'C',
   'species': [{'element': 'C', 'occu': 1}],
   'xyz': [0.42070744720000003, -0.6954733305, -0.5925367549],
   'properties': {}},
  {'name': 'O',
   'species': [{'element': 'O', 'occu': 1}],
   'xyz': [-3.7580214127, 1.7756652839, 2.6235669689],
   'properties': {}},
  {'name': 'H',
   'species': [{'element': 'H', 'occu': 1}],
   'xyz': [-3.966847318, 2.0676673205, 3.5044355521],
   'properties': {}},
  {'name': 'H',
   'species': [{'eleme

In [11]:
def write_dict_to_ase_single_mol(dict_info, file_out):
    """
    Takes a dictionary with a single molecule/gradient/energy and writes it to an ase file
    """
    energies = dict_info["energies"]
    grads_list = dict_info["grads"]
    xyzs_list = dict_info["xyz"]
    elements_list = dict_info["elements"]
    frame_count_global = 0
    with open(file_out, "w") as f:
        for ind_frame, (energy, grad, xyz, elements) in enumerate(
            zip(energies, grads_list, xyzs_list, elements_list)
        ):
            frame_count_global += 1
            n_atoms = len(elements)
            # print(elements)
            f.write(str(n_atoms) + "\n")
            f.write(
                'Properties=species:S:1:pos:R:3:forces:R:3 energy={} free_energy={} pbc="F F F"\n'.format(
                    energy, energy
                )
            )
            for ind_atom, (xyz, grad) in enumerate(zip(xyz, grad)):
                # print(elements[0])
                f.write(
                    "{:2} {:15.8} {:15.8} {:15.8} {:15.8} {:15.8} {:15.8}\n".format(
                        elements[ind_atom],
                        xyz[0],
                        xyz[1],
                        xyz[2],
                        grad[0],
                        grad[1],
                        grad[2],
                    )
                )

    print("Wrote {} frames to {}".format(frame_count_global, file_out))


def write_dict_to_ase_trajectory(dict_info, file_out):
    """
    Takes a dictionary organized by trajectories and writes it to an ase file
    """
    energies = dict_info["energies"]
    grads_list = dict_info["grads"]
    xyzs_list = dict_info["xyz"]
    elements_list = dict_info["elements"]

    frame_count_global = 0
    with open(file_out, "w") as f:
        for ind_frame, (
            energies_frame,
            grads_frame,
            xyzs_frame,
            elements_frame,
        ) in enumerate(zip(energies, grads_list, xyzs_list, elements_list)):
            # print(ind_frame, len(energies_frame))
            for ind_mol, (energy, grad, xyz, elements) in enumerate(
                zip(energies_frame, grads_frame, xyzs_frame, elements_frame)
            ):
                frame_count_global += 1
                n_atoms = len(elements)
                # print(elements)
                f.write(str(n_atoms) + "\n")
                f.write(
                    'Properties=species:S:1:pos:R:3:forces:R:3 energy={} free_energy={} pbc="F F F"\n'.format(
                        energy, energy
                    )
                )
                for ind_atom, (xyz, grad) in enumerate(zip(xyz, grad)):
                    # print(elements[0])
                    f.write(
                        "{:2} {:15.8} {:15.8} {:15.8} {:15.8} {:15.8} {:15.8}\n".format(
                            elements[ind_atom],
                            xyz[0],
                            xyz[1],
                            xyz[2],
                            grad[0],
                            grad[1],
                            grad[2],
                        )
                    )

    print("Wrote {} frames to {}".format(frame_count_global, file_out))


write_dict_to_ase_single_mol(data, "../../../data/ase/libe_full.xyz")


Wrote 1419271 frames to ../../../data/ase/libe_full.xyz
