# Imports and Settings

In [1]:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import os
from scipy.spatial.distance import squareform, cdist
import time
from numpy.testing import *
from ase.db import connect

In [2]:
from Code.Tools.Data.printer import ProgressTimer
from Code.Tools.Data.saver import create_path
from Code.Tools.Data.transform import get_spherical, change_base, get_input_data

In [3]:
# Inputs: the two ISO17 databases that are opened with ase.db.connect further down.
reference_db = "./Dataset/iso17/reference.db"
within_db = "./Dataset/iso17/test_within.db"
# Outputs: path prefixes only -- the conversion cell appends 'X' / 'Y' before saving.
reference_npy = "./Dataset/DeepPotential/reference_"
within_npy = "./Dataset/DeepPotential/within_"
# Parallel lists: db_paths[k] is converted into the files starting with npy_paths[k].
db_paths, npy_paths = (list(column) for column in zip(
    (reference_db, reference_npy),
    (within_db, within_npy),
))

In [4]:
for db_path, npy_path in zip(db_paths, npy_paths):
    # Convert one ASE database into a pair of .npy files:
    #   <npy_path>X -- the output of get_input_data for every molecule
    #   <npy_path>Y -- the matching total energies
    # NOTE(review): the recorded output of this cell ends with
    # `IndexError: list index out of range` at ~92.7 % progress. The failing
    # call is not visible in this notebook, so the cause still has to be tracked down.
    molecules = []
    energies = []
    with connect(db_path) as conn:
        for row in conn.select():
            # One (n_atoms, 4) block per molecule: atomic number followed by the position.
            # reshape(-1, 1) infers the atom count instead of hard-coding 19.
            molecules.append(np.hstack((row['numbers'].reshape(-1, 1), row['positions'])))
            energies.append(row['total_energy'])
    # Report the dataset shape without materialising a second full copy of the
    # data (np.array(molecules) was only ever used for this print).
    print((len(molecules),) + (molecules[0].shape if molecules else ()))
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments during this long-running step.
    start = time.perf_counter()
    network_in = np.array(get_input_data(molecules))
    print('time: {}'.format(time.perf_counter() - start))
    np.save(npy_path + 'X', network_in)
    np.save(npy_path + 'Y', energies)

(404000, 19, 4)
0.0  % ---  69.8
0.0  % ---  60.9
0.1  % ---  58.3
0.1  % ---  57.2
0.1  % ---  57.3
0.1  % ---  56.8
0.2  % ---  56.2
0.2  % ---  55.8
0.2  % ---  55.7
0.2  % ---  55.7
0.3  % ---  55.6
0.3  % ---  55.3
0.3  % ---  55.1
0.3  % ---  55.0
0.4  % ---  54.9
0.4  % ---  54.8
0.4  % ---  55.2
0.4  % ---  55.2
0.5  % ---  55.3
0.5  % ---  55.2
0.5  % ---  55.2
0.5  % ---  55.2
0.6  % ---  55.2
0.6  % ---  55.0
0.6  % ---  54.9
0.6  % ---  54.8
0.7  % ---  54.7
0.7  % ---  54.8
0.7  % ---  54.7
0.7  % ---  54.6
0.8  % ---  54.6
0.8  % ---  54.5
0.8  % ---  54.5
0.8  % ---  54.5
0.9  % ---  54.5
0.9  % ---  54.4
0.9  % ---  54.4
0.9  % ---  54.5
1.0  % ---  54.5
1.0  % ---  54.6
1.0  % ---  54.6
1.0  % ---  54.6
1.1  % ---  54.6
1.1  % ---  54.5
1.1  % ---  54.5
1.1  % ---  54.6
1.2  % ---  54.5
1.2  % ---  54.6
1.2  % ---  54.6
1.2  % ---  54.7
1.3  % ---  54.7
1.3  % ---  54.8
1.3  % ---  54.7
1.3  % ---  54.7
1.4  % ---  54.6
1.4  % ---  55.8
1.4  % ---  55.8
1.4  % ---  55.

11.8  % ---  55.8
11.9  % ---  55.7
11.9  % ---  55.7
11.9  % ---  55.7
11.9  % ---  55.6
12.0  % ---  55.6
12.0  % ---  55.6
12.0  % ---  55.5
12.0  % ---  55.5
12.1  % ---  55.5
12.1  % ---  55.5
12.1  % ---  55.4
12.1  % ---  55.4
12.2  % ---  55.4
12.2  % ---  55.3
12.2  % ---  55.3
12.2  % ---  55.3
12.3  % ---  55.2
12.3  % ---  55.2
12.3  % ---  55.2
12.3  % ---  55.1
12.4  % ---  55.1
12.4  % ---  55.1
12.4  % ---  55.0
12.4  % ---  55.0
12.5  % ---  55.0
12.5  % ---  55.0
12.5  % ---  54.9
12.5  % ---  54.9
12.5  % ---  54.9
12.6  % ---  54.8
12.6  % ---  54.8
12.6  % ---  54.8
12.6  % ---  54.7
12.7  % ---  54.7
12.7  % ---  54.7
12.7  % ---  54.6
12.7  % ---  54.6
12.8  % ---  54.6
12.8  % ---  54.6
12.8  % ---  54.5
12.8  % ---  54.5
12.9  % ---  54.5
12.9  % ---  54.4
12.9  % ---  54.4
12.9  % ---  54.4
13.0  % ---  54.3
13.0  % ---  54.3
13.0  % ---  54.3
13.0  % ---  54.3
13.1  % ---  54.2
13.1  % ---  54.2
13.1  % ---  54.2
13.1  % ---  54.1
13.2  % ---  54.1
13.2  % --

23.1  % ---  45.7
23.1  % ---  45.7
23.2  % ---  45.6
23.2  % ---  45.6
23.2  % ---  45.6
23.2  % ---  45.6
23.3  % ---  45.6
23.3  % ---  45.5
23.3  % ---  45.5
23.3  % ---  45.5
23.4  % ---  45.5
23.4  % ---  45.5
23.4  % ---  45.4
23.4  % ---  45.4
23.5  % ---  45.4
23.5  % ---  45.4
23.5  % ---  45.4
23.5  % ---  45.4
23.6  % ---  45.3
23.6  % ---  45.3
23.6  % ---  45.3
23.6  % ---  45.3
23.7  % ---  45.3
23.7  % ---  45.2
23.7  % ---  45.2
23.7  % ---  45.2
23.8  % ---  45.2
23.8  % ---  45.2
23.8  % ---  45.1
23.8  % ---  45.1
23.9  % ---  45.1
23.9  % ---  45.1
23.9  % ---  45.2
23.9  % ---  45.1
24.0  % ---  45.1
24.0  % ---  45.1
24.0  % ---  45.1
24.0  % ---  45.1
24.1  % ---  45.1
24.1  % ---  45.0
24.1  % ---  45.0
24.1  % ---  45.0
24.2  % ---  45.0
24.2  % ---  45.0
24.2  % ---  44.9
24.2  % ---  44.9
24.3  % ---  44.9
24.3  % ---  44.9
24.3  % ---  44.9
24.3  % ---  44.9
24.4  % ---  44.8
24.4  % ---  44.8
24.4  % ---  44.8
24.4  % ---  44.8
24.5  % ---  44.8
24.5  % --

34.4  % ---  37.9
34.4  % ---  37.9
34.5  % ---  37.9
34.5  % ---  37.9
34.5  % ---  37.9
34.5  % ---  37.8
34.6  % ---  37.8
34.6  % ---  37.8
34.6  % ---  37.8
34.6  % ---  37.8
34.7  % ---  37.8
34.7  % ---  37.7
34.7  % ---  37.7
34.7  % ---  37.7
34.8  % ---  37.7
34.8  % ---  37.7
34.8  % ---  37.7
34.8  % ---  37.7
34.9  % ---  37.6
34.9  % ---  37.6
34.9  % ---  37.6
34.9  % ---  37.6
35.0  % ---  37.6
35.0  % ---  37.6
35.0  % ---  37.5
35.0  % ---  37.5
35.0  % ---  37.5
35.1  % ---  37.5
35.1  % ---  37.5
35.1  % ---  37.5
35.1  % ---  37.4
35.2  % ---  37.4
35.2  % ---  37.4
35.2  % ---  37.4
35.2  % ---  37.4
35.3  % ---  37.4
35.3  % ---  37.3
35.3  % ---  37.3
35.3  % ---  37.3
35.4  % ---  37.3
35.4  % ---  37.3
35.4  % ---  37.3
35.4  % ---  37.3
35.5  % ---  37.2
35.5  % ---  37.2
35.5  % ---  37.2
35.5  % ---  37.2
35.6  % ---  37.2
35.6  % ---  37.2
35.6  % ---  37.1
35.6  % ---  37.1
35.7  % ---  37.1
35.7  % ---  37.1
35.7  % ---  37.1
35.7  % ---  37.1
35.8  % --

45.7  % ---  31.0
45.7  % ---  31.0
45.7  % ---  31.0
45.8  % ---  31.0
45.8  % ---  31.0
45.8  % ---  30.9
45.8  % ---  30.9
45.9  % ---  30.9
45.9  % ---  30.9
45.9  % ---  30.9
45.9  % ---  30.9
46.0  % ---  30.9
46.0  % ---  30.8
46.0  % ---  30.8
46.0  % ---  30.8
46.1  % ---  30.8
46.1  % ---  30.8
46.1  % ---  30.8
46.1  % ---  30.7
46.2  % ---  30.7
46.2  % ---  30.7
46.2  % ---  30.7
46.2  % ---  30.7
46.3  % ---  30.7
46.3  % ---  30.7
46.3  % ---  30.6
46.3  % ---  30.6
46.4  % ---  30.6
46.4  % ---  30.6
46.4  % ---  30.6
46.4  % ---  30.6
46.5  % ---  30.6
46.5  % ---  30.5
46.5  % ---  30.5
46.5  % ---  30.5
46.6  % ---  30.5
46.6  % ---  30.5
46.6  % ---  30.5
46.6  % ---  30.5
46.7  % ---  30.4
46.7  % ---  30.4
46.7  % ---  30.4
46.7  % ---  30.4
46.8  % ---  30.4
46.8  % ---  30.4
46.8  % ---  30.4
46.8  % ---  30.3
46.9  % ---  30.3
46.9  % ---  30.3
46.9  % ---  30.3
46.9  % ---  30.3
47.0  % ---  30.3
47.0  % ---  30.3
47.0  % ---  30.2
47.0  % ---  30.2
47.1  % --

57.0  % ---  24.7
57.0  % ---  24.7
57.0  % ---  24.7
57.1  % ---  24.7
57.1  % ---  24.6
57.1  % ---  24.6
57.1  % ---  24.6
57.2  % ---  24.6
57.2  % ---  24.6
57.2  % ---  24.6
57.2  % ---  24.5
57.3  % ---  24.5
57.3  % ---  24.5
57.3  % ---  24.5
57.3  % ---  24.5
57.4  % ---  24.5
57.4  % ---  24.5
57.4  % ---  24.5
57.4  % ---  24.4
57.5  % ---  24.4
57.5  % ---  24.4
57.5  % ---  24.4
57.5  % ---  24.4
57.5  % ---  24.4
57.6  % ---  24.4
57.6  % ---  24.3
57.6  % ---  24.3
57.6  % ---  24.3
57.7  % ---  24.3
57.7  % ---  24.3
57.7  % ---  24.3
57.7  % ---  24.3
57.8  % ---  24.2
57.8  % ---  24.2
57.8  % ---  24.2
57.8  % ---  24.2
57.9  % ---  24.2
57.9  % ---  24.2
57.9  % ---  24.2
57.9  % ---  24.1
58.0  % ---  24.1
58.0  % ---  24.1
58.0  % ---  24.1
58.0  % ---  24.1
58.1  % ---  24.1
58.1  % ---  24.1
58.1  % ---  24.0
58.1  % ---  24.0
58.2  % ---  24.0
58.2  % ---  24.0
58.2  % ---  24.0
58.2  % ---  24.0
58.3  % ---  24.0
58.3  % ---  23.9
58.3  % ---  23.9
58.3  % --

68.3  % ---  18.4
68.3  % ---  18.4
68.3  % ---  18.4
68.3  % ---  18.4
68.4  % ---  18.4
68.4  % ---  18.3
68.4  % ---  18.3
68.4  % ---  18.3
68.5  % ---  18.3
68.5  % ---  18.3
68.5  % ---  18.3
68.5  % ---  18.2
68.6  % ---  18.2
68.6  % ---  18.2
68.6  % ---  18.2
68.6  % ---  18.2
68.7  % ---  18.2
68.7  % ---  18.2
68.7  % ---  18.1
68.7  % ---  18.1
68.8  % ---  18.1
68.8  % ---  18.1
68.8  % ---  18.1
68.8  % ---  18.1
68.9  % ---  18.1
68.9  % ---  18.0
68.9  % ---  18.0
68.9  % ---  18.0
69.0  % ---  18.0
69.0  % ---  18.0
69.0  % ---  18.0
69.0  % ---  18.0
69.1  % ---  17.9
69.1  % ---  17.9
69.1  % ---  17.9
69.1  % ---  17.9
69.2  % ---  17.9
69.2  % ---  17.9
69.2  % ---  17.9
69.2  % ---  17.8
69.3  % ---  17.8
69.3  % ---  17.8
69.3  % ---  17.8
69.3  % ---  17.8
69.4  % ---  17.8
69.4  % ---  17.8
69.4  % ---  17.7
69.4  % ---  17.7
69.5  % ---  17.7
69.5  % ---  17.7
69.5  % ---  17.7
69.5  % ---  17.7
69.6  % ---  17.6
69.6  % ---  17.6
69.6  % ---  17.6
69.6  % --

79.6  % ---  12.0
79.6  % ---  12.0
79.6  % ---  12.0
79.6  % ---  11.9
79.7  % ---  11.9
79.7  % ---  11.9
79.7  % ---  11.9
79.7  % ---  11.9
79.8  % ---  11.9
79.8  % ---  11.8
79.8  % ---  11.8
79.8  % ---  11.8
79.9  % ---  11.8
79.9  % ---  11.8
79.9  % ---  11.8
79.9  % ---  11.8
80.0  % ---  11.7
80.0  % ---  11.7
80.0  % ---  11.7
80.0  % ---  11.7
80.0  % ---  11.7
80.1  % ---  11.7
80.1  % ---  11.7
80.1  % ---  11.6
80.1  % ---  11.6
80.2  % ---  11.6
80.2  % ---  11.6
80.2  % ---  11.6
80.2  % ---  11.6
80.3  % ---  11.5
80.3  % ---  11.5
80.3  % ---  11.5
80.3  % ---  11.5
80.4  % ---  11.5
80.4  % ---  11.5
80.4  % ---  11.5
80.4  % ---  11.4
80.5  % ---  11.4
80.5  % ---  11.4
80.5  % ---  11.4
80.5  % ---  11.4
80.6  % ---  11.4
80.6  % ---  11.4
80.6  % ---  11.3
80.6  % ---  11.3
80.7  % ---  11.3
80.7  % ---  11.3
80.7  % ---  11.3
80.7  % ---  11.3
80.8  % ---  11.3
80.8  % ---  11.2
80.8  % ---  11.2
80.8  % ---  11.2
80.9  % ---  11.2
80.9  % ---  11.2
80.9  % --

91.3  % ---  5.0
91.3  % ---  5.0
91.3  % ---  5.0
91.4  % ---  5.0
91.4  % ---  5.0
91.4  % ---  5.0
91.4  % ---  5.0
91.5  % ---  4.9
91.5  % ---  4.9
91.5  % ---  4.9
91.5  % ---  4.9
91.6  % ---  4.9
91.6  % ---  4.9
91.6  % ---  4.9
91.6  % ---  4.8
91.7  % ---  4.8
91.7  % ---  4.8
91.7  % ---  4.8
91.7  % ---  4.8
91.8  % ---  4.8
91.8  % ---  4.7
91.8  % ---  4.7
91.8  % ---  4.7
91.9  % ---  4.7
91.9  % ---  4.7
91.9  % ---  4.7
91.9  % ---  4.7
92.0  % ---  4.6
92.0  % ---  4.6
92.0  % ---  4.6
92.0  % ---  4.6
92.1  % ---  4.6
92.1  % ---  4.6
92.1  % ---  4.6
92.1  % ---  4.5
92.2  % ---  4.5
92.2  % ---  4.5
92.2  % ---  4.5
92.2  % ---  4.5
92.3  % ---  4.5
92.3  % ---  4.5
92.3  % ---  4.4
92.3  % ---  4.4
92.4  % ---  4.4
92.4  % ---  4.4
92.4  % ---  4.4
92.4  % ---  4.4
92.5  % ---  4.4
92.5  % ---  4.3
92.5  % ---  4.3
92.5  % ---  4.3
92.5  % ---  4.3
92.6  % ---  4.3
92.6  % ---  4.3
92.6  % ---  4.3
92.6  % ---  4.2
92.7  % ---  4.2
92.7  % ---  4.2
92.7  % ---  4

IndexError: list index out of range

In [7]:
# Manual repeat of the save step outside the conversion loop. It relies on
# `npy_path`, `network_in` and `energies` still being defined in the kernel,
# i.e. on whatever the last (possibly interrupted) loop iteration left behind.
np.save(file=npy_path + 'X', arr=network_in)
np.save(file=npy_path + 'Y', arr=energies)

In [10]:
# Inspect the energies collected by the conversion loop (the recorded output is an empty list).
energies

[]

# Testing

## Test Functions

In [None]:
def test_get_spherical():
    """Compare `get_spherical` with hand-derived values for three Cartesian points.

    Every expected row is (1/r, cos(theta), cos(phi), sin(phi)), where r is the
    distance of the point from the origin, theta its angle to the z-axis and
    phi its azimuth in the x-y plane.

    Prints 'pass!' when the result agrees within the default tolerance of
    `assert_array_almost_equal`; otherwise the AssertionError propagates.
    """
    points = np.array([[0, 1, 2],
                       [1, 1, 1],
                       [-1, 2, 1]])
    # (r, theta, phi) worked out by hand for each of the points above
    hand_derived = [
        (np.sqrt(5), np.arccos(2/np.sqrt(5)), np.pi/2),
        (np.sqrt(3), np.arccos(1/np.sqrt(3)), np.arctan(1)),
        # x < 0, so arctan(y/x) has to be shifted by pi to land in the right quadrant
        (np.sqrt(6), np.arccos(1/np.sqrt(6)), np.arctan(-2) + np.pi),
    ]
    expected = np.array([[1/radius, np.cos(polar), np.cos(azimuth), np.sin(azimuth)]
                         for radius, polar, azimuth in hand_derived])
    assert_array_almost_equal(expected, get_spherical(points))
    print('pass!')

In [None]:
def test_change_base():
    """Compare `change_base` with hand-computed coordinates in a skewed basis.

    The expected rows are the coefficients (a, b, c) for which
    a*x + b*y + c*z equals the respective point minus the origin `o`;
    e.g. -7*x - 13*y + 4*z = [1, -3, -1] = [0, 1, 2] - o.

    Prints 'pass!' when the result agrees within the default tolerance of
    `assert_array_almost_equal`; otherwise the AssertionError propagates.
    """
    # New origin and the (deliberately non-orthogonal, non-normalised) basis vectors
    origin = np.array([-1, 4, 3])
    basis_x = np.array([1, 1, 0])
    basis_y = np.array([0, 0, 1])
    basis_z = np.array([2, 1, 3])
    points = np.array([[0, 1, 2],
                       [1, 1, 1],
                       [-1, 2, 1]])
    expected = np.array([[-7., -13., 4.],
                         [-8., -17., 5.],
                         [-4., -8., 2.]])
    transformed = change_base(points, basis_x, basis_y, basis_z, origin)
    assert_array_almost_equal(expected, transformed)
    print('pass!')

In [None]:
def _reference_descriptor(mol, i):
    """Independently build the descriptor row expected from `get_input_data` for atom `i`.

    Parameters
    ----------
    mol : ndarray, shape (n_atoms, 4)
        One molecule as rows of [atomic number, x, y, z], already sorted by
        atomic number.
    i : int
        Row index of the central atom.

    Returns
    -------
    ndarray, shape (4 * (n_atoms - 1),)
        Flattened `get_spherical` output for all other atoms, expressed in the
        local frame of atom `i`.
    """
    centre = mol[i, 1:]
    others = np.vstack((mol[:i], mol[i + 1:]))
    # The local frame is spanned by the two nearest non-hydrogen atoms
    # (atomic number > 1) other than the central atom itself.
    heavy_xyz = others[others[:, 0] > 1, 1:]
    dists = cdist(centre[np.newaxis], heavy_xyz)
    nearest, second = dists.argsort().reshape(-1)[:2]
    x = heavy_xyz[nearest] - centre
    z = np.cross(x, heavy_xyz[second] - centre)
    # y is derived before normalising, exactly as in the reference construction
    y = np.cross(z, x)
    x /= np.linalg.norm(x)
    y /= np.linalg.norm(y)
    z /= np.linalg.norm(z)
    sph_others = get_spherical(change_base(others[:, 1:], x, y, z, centre))
    # Order by the first spherical column, then stably by atomic number, so
    # atoms of one species stay ordered by that first column.
    ids = sph_others[:, 0].argsort(kind='mergesort')
    sph_others = sph_others[ids]
    others = others[ids]
    sph_others = sph_others[others[:, 0].argsort(kind='mergesort')]
    return sph_others.reshape(-1)


def test_input_data(test_mol=None):
    """Check `get_input_data` against a step-by-step reference for one molecule.

    Parameters
    ----------
    test_mol : ndarray, shape (n_atoms, 4), optional
        Molecule as rows of [atomic number, x, y, z]. When omitted, the first
        entry of the notebook-global `molecules` is used (the original
        behaviour), which requires the conversion cell to have been run.

    Prints 'pass!' when both results agree within the default tolerance of
    `assert_array_almost_equal`; otherwise the AssertionError propagates.
    """
    if test_mol is None:
        test_mol = molecules[0]
    # Stable sort by atomic number: hydrogens first, then the heavier atoms
    test_mol = test_mol[test_mol[:, 0].argsort(kind='mergesort')]
    n_atoms = test_mol.shape[0]
    # The former H-, C- and O-part loops were identical apart from how they
    # indexed the same set of non-hydrogen neighbours, so one helper covers all atoms.
    results = np.array([_reference_descriptor(test_mol, i) for i in range(n_atoms)])
    test_result = np.array(get_input_data(test_mol[np.newaxis])).reshape(n_atoms, -1)

    assert_array_almost_equal(test_result, results)
    print('pass!')

## Run Tests

In [None]:
# Run the get_spherical check; it prints 'pass!' on success and raises on a mismatch.
test_get_spherical()

In [None]:
# Run the change_base check; it prints 'pass!' on success and raises on a mismatch.
test_change_base()

In [None]:
# Run the get_input_data check. Called without arguments it reads the global
# `molecules`, so the conversion cell must have populated it first.
test_input_data()