# Imports and Settings

In [1]:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import os
from scipy.spatial.distance import squareform, cdist
import time
from numpy.testing import *

In [2]:
from Code.DataGeneration.printer import ProgressTimer
from Code.DataGeneration.saver import create_path
from Code.DataGeneration.transform import get_spherical, change_base, get_input_data

In [3]:
# Location of the ISO17 reference database that the next cell opens with ase.db.
path_to_db = "./Dataset/iso17/reference.db"

In [4]:
from ase.db import connect

# For every database row collect a (19, 4) array whose first column holds the
# atomic numbers and whose remaining columns hold the atom positions, together
# with the row's total energy.
molecules, energies = [], []
with connect(path_to_db) as conn:
    for row in conn.select():
        atomic_numbers = row['numbers'].reshape((19, 1))
        molecules.append(np.hstack((atomic_numbers, row['positions'])))
        energies.append(row['total_energy'])

In [5]:
# Sanity check: stack the per-molecule arrays into one temporary array and show its shape.
np.array(molecules).shape

(404000, 19, 4)

# Run Calculations

In [6]:
# Build the network input for every molecule and report the wall-clock
# duration of the computation in seconds.
t_begin = time.time()
network_in = np.array(get_input_data(molecules))
elapsed = time.time() - t_begin
print('time: {}'.format(elapsed))

0.0  % ---  70.7
0.0  % ---  73.2
0.1  % ---  79.9
0.1  % ---  76.8
0.1  % ---  74.3
0.1  % ---  70.9
0.2  % ---  69.4
0.2  % ---  67.5
0.2  % ---  66.1
0.2  % ---  64.9
0.3  % ---  64.2
0.3  % ---  64.7
0.3  % ---  63.8
0.3  % ---  63.6
0.4  % ---  62.9
0.4  % ---  63.2
0.4  % ---  62.7
0.4  % ---  62.0
0.5  % ---  61.7
0.5  % ---  61.5
0.5  % ---  61.7
0.5  % ---  61.2
0.6  % ---  61.0
0.6  % ---  60.8
0.6  % ---  60.6
0.6  % ---  60.8
0.7  % ---  60.8
0.7  % ---  60.4
0.7  % ---  60.4
0.7  % ---  60.6
0.8  % ---  60.5
0.8  % ---  60.5
0.8  % ---  60.5
0.8  % ---  60.5
0.9  % ---  60.6
0.9  % ---  60.3
0.9  % ---  60.1
0.9  % ---  60.0
1.0  % ---  60.0
1.0  % ---  60.2
1.0  % ---  60.1
1.0  % ---  59.9
1.1  % ---  59.6
1.1  % ---  59.6
1.1  % ---  59.4
1.1  % ---  59.4
1.2  % ---  59.3
1.2  % ---  59.4
1.2  % ---  59.5
1.2  % ---  59.7
1.3  % ---  60.1
1.3  % ---  60.2
1.3  % ---  60.1
1.3  % ---  60.0
1.4  % ---  59.9
1.4  % ---  59.8
1.4  % ---  59.8
1.4  % ---  59.9
1.5  % ---  59

11.9  % ---  50.9
11.9  % ---  50.9
11.9  % ---  50.9
11.9  % ---  50.9
12.0  % ---  50.8
12.0  % ---  50.8
12.0  % ---  50.8
12.0  % ---  50.7
12.1  % ---  50.7
12.1  % ---  50.7
12.1  % ---  50.7
12.1  % ---  50.7
12.2  % ---  50.7
12.2  % ---  50.6
12.2  % ---  50.6
12.2  % ---  50.7
12.3  % ---  50.6
12.3  % ---  50.6
12.3  % ---  50.6
12.3  % ---  50.6
12.4  % ---  50.5
12.4  % ---  50.5
12.4  % ---  50.5
12.4  % ---  50.5
12.5  % ---  50.4
12.5  % ---  50.4
12.5  % ---  50.4
12.5  % ---  50.3
12.5  % ---  50.3
12.6  % ---  50.3
12.6  % ---  50.3
12.6  % ---  50.3
12.6  % ---  50.2
12.7  % ---  50.2
12.7  % ---  50.2
12.7  % ---  50.2
12.7  % ---  50.2
12.8  % ---  50.2
12.8  % ---  50.2
12.8  % ---  50.1
12.8  % ---  50.1
12.9  % ---  50.1
12.9  % ---  50.1
12.9  % ---  50.0
12.9  % ---  50.0
13.0  % ---  50.0
13.0  % ---  50.0
13.0  % ---  49.9
13.0  % ---  49.9
13.1  % ---  49.9
13.1  % ---  49.9
13.1  % ---  49.8
13.1  % ---  49.8
13.2  % ---  49.8
13.2  % ---  49.7
13.2  % --

23.1  % ---  42.4
23.2  % ---  42.4
23.2  % ---  42.4
23.2  % ---  42.3
23.2  % ---  42.3
23.3  % ---  42.3
23.3  % ---  42.3
23.3  % ---  42.3
23.3  % ---  42.3
23.4  % ---  42.2
23.4  % ---  42.2
23.4  % ---  42.2
23.4  % ---  42.2
23.5  % ---  42.2
23.5  % ---  42.2
23.5  % ---  42.2
23.5  % ---  42.1
23.6  % ---  42.1
23.6  % ---  42.1
23.6  % ---  42.1
23.6  % ---  42.1
23.7  % ---  42.1
23.7  % ---  42.0
23.7  % ---  42.0
23.7  % ---  42.0
23.8  % ---  42.0
23.8  % ---  42.0
23.8  % ---  42.0
23.8  % ---  41.9
23.9  % ---  41.9
23.9  % ---  41.9
23.9  % ---  41.9
23.9  % ---  41.9
24.0  % ---  41.9
24.0  % ---  41.9
24.0  % ---  41.8
24.0  % ---  41.8
24.1  % ---  41.8
24.1  % ---  41.8
24.1  % ---  41.8
24.1  % ---  41.8
24.2  % ---  41.7
24.2  % ---  41.7
24.2  % ---  41.7
24.2  % ---  41.7
24.3  % ---  41.7
24.3  % ---  41.7
24.3  % ---  41.7
24.3  % ---  41.6
24.4  % ---  41.6
24.4  % ---  41.6
24.4  % ---  41.6
24.4  % ---  41.6
24.5  % ---  41.6
24.5  % ---  41.5
24.5  % --

34.4  % ---  35.3
34.5  % ---  35.3
34.5  % ---  35.3
34.5  % ---  35.2
34.5  % ---  35.2
34.6  % ---  35.2
34.6  % ---  35.2
34.6  % ---  35.2
34.6  % ---  35.2
34.7  % ---  35.1
34.7  % ---  35.1
34.7  % ---  35.1
34.7  % ---  35.1
34.8  % ---  35.1
34.8  % ---  35.1
34.8  % ---  35.1
34.8  % ---  35.0
34.9  % ---  35.0
34.9  % ---  35.0
34.9  % ---  35.0
34.9  % ---  35.0
35.0  % ---  35.0
35.0  % ---  34.9
35.0  % ---  34.9
35.0  % ---  34.9
35.0  % ---  34.9
35.1  % ---  34.9
35.1  % ---  34.9
35.1  % ---  34.9
35.1  % ---  34.8
35.2  % ---  34.8
35.2  % ---  34.8
35.2  % ---  34.8
35.2  % ---  34.8
35.3  % ---  34.8
35.3  % ---  34.8
35.3  % ---  34.7
35.3  % ---  34.7
35.4  % ---  34.7
35.4  % ---  34.7
35.4  % ---  34.7
35.4  % ---  34.7
35.5  % ---  34.6
35.5  % ---  34.6
35.5  % ---  34.6
35.5  % ---  34.6
35.6  % ---  34.6
35.6  % ---  34.6
35.6  % ---  34.6
35.6  % ---  34.5
35.7  % ---  34.5
35.7  % ---  34.5
35.7  % ---  34.5
35.7  % ---  34.5
35.8  % ---  34.5
35.8  % --

45.7  % ---  29.0
45.7  % ---  28.9
45.8  % ---  28.9
45.8  % ---  28.9
45.8  % ---  28.9
45.8  % ---  28.9
45.9  % ---  28.9
45.9  % ---  28.9
45.9  % ---  28.9
45.9  % ---  28.8
46.0  % ---  28.8
46.0  % ---  28.8
46.0  % ---  28.8
46.0  % ---  28.8
46.1  % ---  28.8
46.1  % ---  28.8
46.1  % ---  28.8
46.1  % ---  28.7
46.2  % ---  28.7
46.2  % ---  28.7
46.2  % ---  28.7
46.2  % ---  28.7
46.3  % ---  28.7
46.3  % ---  28.7
46.3  % ---  28.7
46.3  % ---  28.6
46.4  % ---  28.6
46.4  % ---  28.6
46.4  % ---  28.6
46.4  % ---  28.6
46.5  % ---  28.6
46.5  % ---  28.6
46.5  % ---  28.6
46.5  % ---  28.5
46.6  % ---  28.5
46.6  % ---  28.5
46.6  % ---  28.5
46.6  % ---  28.5
46.7  % ---  28.5
46.7  % ---  28.5
46.7  % ---  28.5
46.7  % ---  28.4
46.8  % ---  28.4
46.8  % ---  28.4
46.8  % ---  28.4
46.8  % ---  28.4
46.9  % ---  28.4
46.9  % ---  28.4
46.9  % ---  28.3
46.9  % ---  28.3
47.0  % ---  28.3
47.0  % ---  28.3
47.0  % ---  28.3
47.0  % ---  28.3
47.1  % ---  28.3
47.1  % --

57.0  % ---  23.1
57.0  % ---  23.1
57.1  % ---  23.0
57.1  % ---  23.0
57.1  % ---  23.0
57.1  % ---  23.0
57.2  % ---  23.0
57.2  % ---  23.0
57.2  % ---  23.0
57.2  % ---  22.9
57.3  % ---  22.9
57.3  % ---  22.9
57.3  % ---  22.9
57.3  % ---  22.9
57.4  % ---  23.0
57.4  % ---  23.0
57.4  % ---  22.9
57.4  % ---  22.9
57.5  % ---  22.9
57.5  % ---  22.9
57.5  % ---  22.9
57.5  % ---  22.9
57.5  % ---  22.9
57.6  % ---  22.9
57.6  % ---  22.8
57.6  % ---  22.8
57.6  % ---  22.8
57.7  % ---  22.8
57.7  % ---  22.8
57.7  % ---  22.8
57.7  % ---  22.8
57.8  % ---  22.8
57.8  % ---  22.7
57.8  % ---  22.7
57.8  % ---  22.7
57.9  % ---  22.7
57.9  % ---  22.7
57.9  % ---  22.7
57.9  % ---  22.7
58.0  % ---  22.6
58.0  % ---  22.6
58.0  % ---  22.6
58.0  % ---  22.6
58.1  % ---  22.6
58.1  % ---  22.6
58.1  % ---  22.6
58.1  % ---  22.5
58.2  % ---  22.5
58.2  % ---  22.5
58.2  % ---  22.5
58.2  % ---  22.5
58.3  % ---  22.5
58.3  % ---  22.5
58.3  % ---  22.5
58.3  % ---  22.4
58.4  % --

68.3  % ---  16.9
68.3  % ---  16.9
68.3  % ---  16.9
68.4  % ---  16.9
68.4  % ---  16.9
68.4  % ---  16.8
68.4  % ---  16.8
68.5  % ---  16.8
68.5  % ---  16.8
68.5  % ---  16.8
68.5  % ---  16.8
68.6  % ---  16.8
68.6  % ---  16.8
68.6  % ---  16.7
68.6  % ---  16.7
68.7  % ---  16.7
68.7  % ---  16.7
68.7  % ---  16.7
68.7  % ---  16.7
68.8  % ---  16.7
68.8  % ---  16.6
68.8  % ---  16.6
68.8  % ---  16.6
68.9  % ---  16.6
68.9  % ---  16.6
68.9  % ---  16.6
68.9  % ---  16.6
69.0  % ---  16.5
69.0  % ---  16.5
69.0  % ---  16.5
69.0  % ---  16.5
69.1  % ---  16.5
69.1  % ---  16.5
69.1  % ---  16.5
69.1  % ---  16.5
69.2  % ---  16.4
69.2  % ---  16.4
69.2  % ---  16.4
69.2  % ---  16.4
69.3  % ---  16.4
69.3  % ---  16.4
69.3  % ---  16.4
69.3  % ---  16.3
69.4  % ---  16.3
69.4  % ---  16.3
69.4  % ---  16.3
69.4  % ---  16.3
69.5  % ---  16.3
69.5  % ---  16.3
69.5  % ---  16.2
69.5  % ---  16.2
69.6  % ---  16.2
69.6  % ---  16.2
69.6  % ---  16.2
69.6  % ---  16.2
69.7  % --

79.6  % ---  10.9
79.6  % ---  10.9
79.6  % ---  10.9
79.7  % ---  10.8
79.7  % ---  10.8
79.7  % ---  10.8
79.7  % ---  10.8
79.8  % ---  10.8
79.8  % ---  10.8
79.8  % ---  10.8
79.8  % ---  10.7
79.9  % ---  10.7
79.9  % ---  10.7
79.9  % ---  10.7
79.9  % ---  10.7
80.0  % ---  10.7
80.0  % ---  10.7
80.0  % ---  10.7
80.0  % ---  10.6
80.0  % ---  10.6
80.1  % ---  10.6
80.1  % ---  10.6
80.1  % ---  10.6
80.1  % ---  10.6
80.2  % ---  10.6
80.2  % ---  10.5
80.2  % ---  10.5
80.2  % ---  10.5
80.3  % ---  10.5
80.3  % ---  10.5
80.3  % ---  10.5
80.3  % ---  10.5
80.4  % ---  10.5
80.4  % ---  10.4
80.4  % ---  10.4
80.4  % ---  10.4
80.5  % ---  10.4
80.5  % ---  10.4
80.5  % ---  10.4
80.5  % ---  10.4
80.6  % ---  10.3
80.6  % ---  10.3
80.6  % ---  10.3
80.6  % ---  10.3
80.7  % ---  10.3
80.7  % ---  10.3
80.7  % ---  10.3
80.7  % ---  10.3
80.8  % ---  10.2
80.8  % ---  10.2
80.8  % ---  10.2
80.8  % ---  10.2
80.9  % ---  10.2
80.9  % ---  10.2
80.9  % ---  10.2
80.9  % --

91.4  % ---  4.7
91.4  % ---  4.6
91.5  % ---  4.6
91.5  % ---  4.6
91.5  % ---  4.6
91.5  % ---  4.6
91.6  % ---  4.6
91.6  % ---  4.6
91.6  % ---  4.5
91.6  % ---  4.5
91.7  % ---  4.5
91.7  % ---  4.5
91.7  % ---  4.5
91.7  % ---  4.5
91.8  % ---  4.5
91.8  % ---  4.5
91.8  % ---  4.4
91.8  % ---  4.4
91.9  % ---  4.4
91.9  % ---  4.4
91.9  % ---  4.4
91.9  % ---  4.4
92.0  % ---  4.4
92.0  % ---  4.3
92.0  % ---  4.3
92.0  % ---  4.3
92.1  % ---  4.3
92.1  % ---  4.3
92.1  % ---  4.3
92.1  % ---  4.3
92.2  % ---  4.3
92.2  % ---  4.2
92.2  % ---  4.2
92.2  % ---  4.2
92.3  % ---  4.2
92.3  % ---  4.2
92.3  % ---  4.2
92.3  % ---  4.2
92.4  % ---  4.1
92.4  % ---  4.1
92.4  % ---  4.1
92.4  % ---  4.1
92.5  % ---  4.1
92.5  % ---  4.1
92.5  % ---  4.1
92.5  % ---  4.0
92.5  % ---  4.0
92.6  % ---  4.0
92.6  % ---  4.0
92.6  % ---  4.0
92.6  % ---  4.0
92.7  % ---  4.0
92.7  % ---  4.0
92.7  % ---  3.9
92.7  % ---  3.9
92.8  % ---  3.9
92.8  % ---  3.9
92.8  % ---  3.9
92.8  % ---  3

In [7]:
# Sanity check: shape of the computed network input array.
network_in.shape

(404000, 19, 72)

In [8]:
# Sanity check: number of energy labels collected from the database.
len(energies)

404000

## Save arrays to file

In [9]:
# Output targets for the network input array (X) and the energy labels (Y).
data_path = './Dataset/c7o2h10_X'
label_path = './Dataset/c7o2h10_Y'

In [10]:
# Persist features and labels. np.save appends ".npy" to a file name that
# lacks that extension, and converts the plain energy list to an array.
for target, payload in ((data_path, network_in), (label_path, energies)):
    np.save(target, payload)

# Testing

## Test Functions

In [11]:
def test_get_spherical():
    """Compare get_spherical against three hand-computed points.

    Each expected row has the form [1/r, cos(theta), cos(phi), sin(phi)],
    built from the radius r, the angle theta = arccos(z/r) and the azimuth
    phi of the corresponding Cartesian point. Prints 'pass!' when the arrays
    agree; raises AssertionError otherwise.
    """
    cartesian = np.array([[0, 1, 2],
                          [1, 1, 1],
                          [-1, 2, 1]])
    # (r, theta, phi) worked out by hand for the three points above.
    hand_computed = [
        (np.sqrt(5), np.arccos(2/np.sqrt(5)), np.pi/2),
        (np.sqrt(3), np.arccos(1/np.sqrt(3)), np.arctan(1)),
        (np.sqrt(6), np.arccos(1/np.sqrt(6)), np.arctan(-2) + np.pi),
    ]
    expected = np.array([[1/r, np.cos(theta), np.cos(phi), np.sin(phi)]
                         for r, theta, phi in hand_computed])
    actual = get_spherical(cartesian)
    assert_array_almost_equal(expected, actual)
    print('pass!')

In [12]:
def test_change_base():
    """Compare change_base against a hand-computed example.

    The basis vectors used here are deliberately not orthonormal. Each
    expected row c satisfies c[0]*x + c[1]*y + c[2]*z == point - origin for
    the corresponding input point. Prints 'pass!' when the arrays agree;
    raises AssertionError otherwise.
    """
    points = np.array([[0, 1, 2],
                       [1, 1, 1],
                       [-1, 2, 1]])
    axis_x, axis_y, axis_z = (np.array(v)
                              for v in ([1, 1, 0], [0, 0, 1], [2, 1, 3]))
    origin = np.array([-1, 4, 3])
    expected = np.array([[-7., -13., 4.],
                         [-8., -17., 5.],
                         [-4., -8., 2.]])
    actual = change_base(points, axis_x, axis_y, axis_z, origin)
    assert_array_almost_equal(expected, actual)
    print('pass!')

In [13]:
def _expected_descriptor(test_mol, i, n_hydrogen=10):
    """Rebuild, step by step, the expected descriptor row of atom ``i``.

    ``test_mol`` holds one atom per row as [atomic number, position...] and
    must already be sorted by atomic number, so that its first ``n_hydrogen``
    rows are the hydrogens and the remaining rows are the heavy atoms.

    A local frame is centred on atom ``i``: x points to the nearest heavy
    atom (other than ``i`` itself), z is perpendicular to the plane spanned
    with the second-nearest heavy atom, and y completes the frame. All other
    atoms are expressed in that frame, converted with get_spherical, ordered
    by the first spherical column and then stably regrouped by atomic number.

    Returns the flattened 1-D array of the resulting rows.
    """
    positions = test_mol[:, 1:]
    zero = positions[i]

    # Heavy atoms other than the centre are the only candidates for the two
    # frame-defining neighbours.
    heavy = np.array([j for j in range(n_hydrogen, len(test_mol)) if j != i])
    dists = cdist(zero[np.newaxis], positions[heavy])
    one, two = heavy[dists.argsort()[0, :2]]

    x = positions[one] - zero
    z = np.cross(x, positions[two] - zero)
    y = np.cross(z, x)
    x /= np.linalg.norm(x)
    y /= np.linalg.norm(y)
    z /= np.linalg.norm(z)

    others = np.vstack((test_mol[:i], test_mol[i+1:]))
    ch_b_others = change_base(others[:, 1:], x, y, z, zero)
    sph_others = get_spherical(ch_b_others)

    # First order by the first spherical column, then regroup by atomic
    # number; mergesort is stable, so the first ordering survives inside
    # every element group.
    ids = sph_others[:, 0].argsort(kind='mergesort')
    sph_others = sph_others[ids]
    others = others[ids]
    sph_others = sph_others[others[:, 0].argsort(kind='mergesort')]
    return sph_others.reshape(-1)


def test_input_data():
    """Check get_input_data against an independent reference computation.

    Takes the first loaded molecule, sorts its atoms by atomic number
    (10 H, then 7 C, then 2 O), rebuilds the (19, 72) descriptor array one
    atom at a time and compares it with what get_input_data returns for the
    same molecule. Prints 'pass!' on agreement; raises AssertionError
    otherwise.
    """
    # load molecule and sort
    test_mol = molecules[0]
    test_mol = test_mol[test_mol[:, 0].argsort(kind='mergesort')]

    results = np.zeros((19, 72))
    # The H, C and O atoms all follow the same recipe, so one loop covers
    # every centre atom.
    for i in range(19):
        results[i] = _expected_descriptor(test_mol, i)

    test_result = np.array(get_input_data(test_mol[np.newaxis])).reshape(19, 72)

    assert_array_almost_equal(test_result, results)
    print('pass!')

## Run Tests

In [14]:
# Run the get_spherical check; it prints 'pass!' when the assertion holds.
test_get_spherical()

pass!


In [15]:
# Run the change_base check; it prints 'pass!' when the assertion holds.
test_change_base()

pass!


In [16]:
# Run the end-to-end get_input_data check; it prints 'pass!' when the assertion holds.
test_input_data()

pass!
