In [1]:
import glob
import os
import random
import joblib

import numpy as np
import pandas as pd
import ase.io
from ase.io import extxyz
from ase.visualize.plot import plot_atoms
from ase.constraints import FixAtoms

In [65]:
input_dir = '/mnt/data1/non-tar/Zn-Cr-O'
output_dir = '/mnt/data1/ocp_data/data'
# Structure directory held out as the test set (excluded from data_list).
test_dir_name = 'ZPL_Cr9O23Zn9'


def _find_traj_path(sid):
    """Return the first *.xyz file in ``input_dir/sid``.

    Prints ``sid`` when several *.xyz files match (only the first is used) and
    raises FileNotFoundError when there is none.
    """
    matches = glob.glob(os.path.join(input_dir, sid, "*.xyz"))
    if not matches:
        raise FileNotFoundError(
            "no *.xyz trajectory in %s" % os.path.join(input_dir, sid))
    if len(matches) > 1:
        print(sid)
    return matches[0]


def _frame_records(sid):
    """Return one ``(sid, frame_index, 0.0)`` tuple per frame of sid's trajectory.

    The energy field is fixed at 0.0 here (presumably a placeholder — confirm).
    """
    frames = ase.io.read(_find_traj_path(sid), ":")
    return [(sid, fid, 0.0) for fid in range(len(frames))]


data_list = []
# NOTE(review): os.listdir order is filesystem-dependent, so the order of data_list
# (and therefore the later [:105000] cut and the seeded split) is only
# reproducible on the same filesystem.
sids = os.listdir(input_dir)
for sid in sids:
    if sid == test_dir_name:
        print('skipping test')
        continue
    data_list.extend(_frame_records(sid))

# Use test_dir_name explicitly — not the leftover loop variable `sid`, which is
# whatever directory os.listdir happened to return last.
test_list = _frame_records(test_dir_name)

skipping test


In [66]:
# (number of train/val candidate frames, number of held-out test frames)
tuple(len(records) for records in (data_list, test_list))

(107352, 115)

In [67]:
seed = 2021
np.random.seed(seed)
random.seed(seed)

# Fixed split sizes (number of frames). Only the first n_train + n_val records of
# data_list are split into train/val; any remaining records are not used.
# The test set is the separately held-out trajectory in test_list.
n_train = 80000
n_val = 25000
n_used = n_train + n_val
if len(data_list) < n_used:
    raise ValueError("need at least %d records in data_list, got %d"
                     % (n_used, len(data_list)))
used_data = data_list[:n_used]

split_indexes = list(range(n_used))
random.shuffle(split_indexes)
train_indexes = set(split_indexes[:n_train])
# Both splits keep data_list order; only which split a record lands in is random.
train_data = [rec for i, rec in enumerate(used_data) if i in train_indexes]
val_data = [rec for i, rec in enumerate(used_data) if i not in train_indexes]

# sid -> integer id, written as 'random<id>' in the export logs. A sid maps to the
# position of its last record in used_data; test sids are offset by n_used.
mapping_dict = {}
for i, rec in enumerate(used_data):
    mapping_dict[rec[0]] = i
for i, rec in enumerate(test_list):
    mapping_dict[rec[0]] = i + n_used

# Alias instead of `del test_list` so this cell (and the size cell above) can be re-run.
test_data = test_list

In [68]:
# Train / val / test sizes, space-separated on one line.
print(*(len(split) for split in (train_data, val_data, test_data)))

80000 25000 115


In [89]:
# One row per frame record. Passing the column names to the constructor also works
# when data_list is empty; assigning df.columns afterwards raises a length mismatch.
df = pd.DataFrame(data_list, columns=['sid', 'fid', 'energy'])

In [90]:
# Number of frames per structure directory (the 'fid' column holds the count),
# showing the five largest.
frames_per_sid = df.groupby('sid').agg({'fid': len})
frames_per_sid.sort_values(by='fid', ascending=False).head(5)

Unnamed: 0_level_0,fid
sid,Unnamed: 1_level_1
YLH_Cr2O47Zn46,11034
ZPL_Cr4O8Zn2,8604
ZPL_Cr4O7Zn2,7232
YLH_Cr2O42Zn46,5897
YLH_Cr2O46Zn46,4931


In [92]:
# Write the train split as chunked <file_id>.extxyz files plus matching <file_id>.txt id logs.
dir_name = 's2ef_ZnCrO_train_80K_uncompressed'
split_dir = os.path.join(output_dir, dir_name)
os.makedirs(split_dir, exist_ok=True)
# Frames per output file pair.
chunk_size = 5000
# sid -> trajectory path, so each structure directory is globbed once rather than once per frame.
traj_path_by_sid = {}
# max(..., 1): an empty split still writes one (empty) 0.txt / 0.extxyz pair.
for file_id, start in enumerate(range(0, max(len(train_data), 1), chunk_size)):
    traj_frames = []
    log_frames = []
    for i in range(start, min(start + chunk_size, len(train_data))):
        if i % 100 == 0:
            print(i)
        sid, fid, energy = train_data[i]
        if sid not in traj_path_by_sid:
            matches = glob.glob(os.path.join(input_dir, sid, "*.xyz"))
            if not matches:
                raise FileNotFoundError(
                    "no *.xyz trajectory in %s" % os.path.join(input_dir, sid))
            traj_path_by_sid[sid] = matches[0]
        # NOTE(review): reading one frame by index may re-scan the whole trajectory
        # file on every call; if this is slow, read each sid's trajectory once.
        traj_frames.append(ase.io.read(traj_path_by_sid[sid], fid))
        # Log line format: random<mapping id>,frame<fid>,<energy>
        log_frames.append(','.join(['random{}'.format(mapping_dict[sid]),
                                    'frame{}'.format(fid),
                                    str(energy)]) + "\n")
    with open(os.path.join(split_dir, "%d.txt" % file_id),
              "w") as ids_log:
        ids_log.writelines(log_frames)
    with open(os.path.join(split_dir, "%d.extxyz" % file_id),
              'w') as f:
        extxyz.write_xyz(f, traj_frames)

0




In [93]:
# Write the val split as chunked <file_id>.extxyz files plus matching <file_id>.txt id logs.
dir_name = 's2ef_ZnCrO_val_25K_uncompressed'
split_dir = os.path.join(output_dir, dir_name)
os.makedirs(split_dir, exist_ok=True)
# Frames per output file pair.
chunk_size = 5000
# sid -> trajectory path, so each structure directory is globbed once rather than once per frame.
traj_path_by_sid = {}
# max(..., 1): an empty split still writes one (empty) 0.txt / 0.extxyz pair.
for file_id, start in enumerate(range(0, max(len(val_data), 1), chunk_size)):
    traj_frames = []
    log_frames = []
    for i in range(start, min(start + chunk_size, len(val_data))):
        if i % 100 == 0:
            print(i)
        sid, fid, energy = val_data[i]
        if sid not in traj_path_by_sid:
            matches = glob.glob(os.path.join(input_dir, sid, "*.xyz"))
            if not matches:
                raise FileNotFoundError(
                    "no *.xyz trajectory in %s" % os.path.join(input_dir, sid))
            traj_path_by_sid[sid] = matches[0]
        # NOTE(review): reading one frame by index may re-scan the whole trajectory
        # file on every call; if this is slow, read each sid's trajectory once.
        traj_frames.append(ase.io.read(traj_path_by_sid[sid], fid))
        # Log line format: random<mapping id>,frame<fid>,<energy>
        log_frames.append(','.join(['random{}'.format(mapping_dict[sid]),
                                    'frame{}'.format(fid),
                                    str(energy)]) + "\n")
    with open(os.path.join(split_dir, "%d.txt" % file_id),
              "w") as ids_log:
        ids_log.writelines(log_frames)
    with open(os.path.join(split_dir, "%d.extxyz" % file_id),
              'w') as f:
        extxyz.write_xyz(f, traj_frames)

0




In [96]:
# Write the test split as chunked <file_id>.extxyz files plus matching <file_id>.txt id logs.
dir_name = 's2ef_ZnCrO_test_115_uncompressed'
split_dir = os.path.join(output_dir, dir_name)
os.makedirs(split_dir, exist_ok=True)
# Frames per output file pair.
chunk_size = 5000
# sid -> trajectory path, so each structure directory is globbed once rather than once per frame.
traj_path_by_sid = {}
# max(..., 1): an empty split still writes one (empty) 0.txt / 0.extxyz pair.
for file_id, start in enumerate(range(0, max(len(test_data), 1), chunk_size)):
    traj_frames = []
    log_frames = []
    for i in range(start, min(start + chunk_size, len(test_data))):
        if i % 100 == 0:
            print(i)
        sid, fid, energy = test_data[i]
        if sid not in traj_path_by_sid:
            matches = glob.glob(os.path.join(input_dir, sid, "*.xyz"))
            if not matches:
                raise FileNotFoundError(
                    "no *.xyz trajectory in %s" % os.path.join(input_dir, sid))
            traj_path_by_sid[sid] = matches[0]
        # NOTE(review): reading one frame by index may re-scan the whole trajectory
        # file on every call; if this is slow, read each sid's trajectory once.
        traj_frames.append(ase.io.read(traj_path_by_sid[sid], fid))
        # Log line format: random<mapping id>,frame<fid>,<energy>
        log_frames.append(','.join(['random{}'.format(mapping_dict[sid]),
                                    'frame{}'.format(fid),
                                    str(energy)]) + "\n")
    with open(os.path.join(split_dir, "%d.txt" % file_id),
              "w") as ids_log:
        ids_log.writelines(log_frames)
    with open(os.path.join(split_dir, "%d.extxyz" % file_id),
              'w') as f:
        extxyz.write_xyz(f, traj_frames)

0
100




In [99]:
# Persist the sid -> integer id mapping used in the 'random<id>' log entries.
mapping_path = os.path.join(output_dir, "ZnCrO_data_mapping.pkl")
joblib.dump(mapping_dict, filename=mapping_path)

['/mnt/data1/ocp_data/data/ZnCrO_data_mapping.pkl']