## Imports and Settings

In [1]:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import os
from scipy.spatial.distance import squareform, cdist
import time

In [2]:
# Inputs: the ISO17 ASE database that is loaded in the next cell, and the
# directory of raw .xyz files (in the visible notebook only the commented-out
# DataFrame pipeline reads these files).
path_to_db = "./../Dataset/iso17/reference.db"
path_to_files = "./../Dataset/dsC7O2H10nsd.xyz/"
filenames = os.listdir(path_to_files)
print('number of files: {}'.format(len(filenames)))

number of files: 6095


In [3]:
from ase.db import connect

# One entry per database row:
#   molecules[i] -- array of shape (n_atoms, 4) with columns
#                   (atomic number, x, y, z)
#   energies[i]  -- the row's 'total_energy'
molecules = []
energies = []
with connect(path_to_db) as conn:
    for row in conn.select():
        # reshape((-1, 1)) turns the atomic numbers into a column without
        # hard-coding the atom count (previously fixed to 19).
        numbers_col = row['numbers'].reshape((-1, 1))
        molecules.append(np.hstack((numbers_col, row['positions'])))
        energies.append(row['total_energy'])

In [4]:
# number of configurations loaded from the database
len(molecules)

404000

# Helper Functions

In [5]:
class ProgressTimer:
    """
    Minimal progress / ETA printer for long loops.

    Call ``how_long`` once per finished iteration. Every ``print_every``
    calls it prints the percentage completed and the estimated remaining
    time in minutes (linear extrapolation of the average time per
    iteration so far). Both values are rounded to one decimal.
    """

    def __init__(self, iterations, print_every=1):
        self.start_time = time.time()      # reference point for elapsed time
        self.iterations = iterations       # total number of expected iterations
        self.iteration = 0                 # iterations completed so far
        self.print_every = print_every     # report on every n-th iteration

    def how_long(self):
        """Count one finished iteration and print progress when it is due."""
        self.iteration += 1
        done = self.iteration
        if done % self.print_every != 0:
            return
        percent = np.round(100 * done / self.iterations, 1)
        seconds_per_iter = (time.time() - self.start_time) / done
        minutes_left = np.round(seconds_per_iter * (self.iterations - done) / 60, 1)
        print(percent, ' % --- ', minutes_left)

In [6]:
# scratch cell: 100 % 100 evaluates to 0 (the displayed output); presumably a
# quick check of the modulo test used in ProgressTimer.how_long
100%100

0

## Create a Dataframe from Input files

In [7]:
#list_ = []
#for file in filenames:
#    filepath = os.path.join(path_to_files, file)
#    try:
#        df_single = pd.read_csv(filepath, skiprows=2,
#                               skipfooter=3, delimiter='\t',
#                               names=['atomtype', 'x', 'y', 'z', 'charge'], 
#                               dtype=dict(atomtype=str, x=float, y=float, z=float, charge=float))
#    except:
#        print(file)
#    df_single['file'] = file
#    list_.append(df_single)
#df_all = pd.concat(list_)
#df_all.head(5)

## Prepare raw Data for Transformation

In [8]:
#n_atoms = 19
#h_atoms = 10
#mask_H = dict(H='ZZZ_H')
#df_all = df_all.replace(dict(atomtype=mask_H))
## sort by file and atomtype
#df_all = df_all.sort_values(['file', 'atomtype']).reset_index(drop=True)
## create file id column
#df_all['file_id'] = (df_all.index) // n_atoms + 1

In [9]:
#df_all.head(25)

## Transform Dataframe to Numpy Array for faster Calculations

In [10]:
#raw_matrix = df_all[['file_id', 'atomtype', 'x', 'y', 'z', 'charge']].values

## Transformation Functions

In [11]:
def get_spherical(positions):
    """
    Convert Cartesian coordinates into the spherical features used as
    neural-network input.

    Parameters
    ----------
    positions : ndarray
        Array of shape (N, 3) holding N Cartesian positions (x, y, z).

    Returns
    -------
    ndarray
        Float array of shape (N, 4). Row i is
        (1/r, cos(theta), cos(phi), sin(phi)) for position i, where theta is
        the polar angle measured from the z-axis and phi is the azimuth in
        the x-y plane. A position at the origin (r == 0) produces
        non-finite values.
    """
    xyz = positions.astype(float)
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    radius = np.linalg.norm(xyz, axis=1)
    polar = np.arccos(z / radius)
    azimuth = np.arctan2(y, x)
    features = np.stack(
        [1 / radius, np.cos(polar), np.cos(azimuth), np.sin(azimuth)], axis=1
    )
    return features

In [12]:
def change_base(positions, x, y, z, o):
    """
    Express positions in a new coordinate frame with origin o and axes x, y, z.

    Parameters
    ----------
    positions : np.array
        Positions in the standard basis, shape (N, 3); a single (3,) vector
        also works. The input array is NOT modified.
    x : np.array
        New x-axis, given in the standard basis.
    y : np.array
        New y-axis, given in the standard basis.
    z : np.array
        New z-axis, given in the standard basis.
    o : np.array
        New origin, given in the standard basis.

    Returns
    -------
    new_positions : np.array
        Float array with the same shape as positions, holding the coordinates
        of (positions - o) with respect to the basis (x, y, z).

    Raises
    ------
    numpy.linalg.LinAlgError
        If the basis matrix built from x, y, z is singular.
    """
    # Shift out of place into a float array. The previous `positions -= o`
    # silently modified the caller's array and raised a casting error when
    # integer positions were combined with a float origin.
    shifted = np.asarray(positions, dtype=float) - o
    # The new axes are the columns of the basis matrix, so multiplying by its
    # inverse maps standard coordinates to coordinates in the new frame.
    basis = np.vstack((x, y, z)).T
    basis_inv = np.linalg.inv(basis)
    new_positions = basis_inv.dot(shifted.T).T
    return new_positions


In [13]:
#raw_matrix = raw_matrix[:, 1:5]

In [14]:
def get_input_data(raw_matrix):
    """
    Calculate the sub-network inputs for every atom of every molecular configuration.

    For each focus atom a local frame is built from the focus atom and its
    two nearest non-H atoms:
    - x points to the nearest one;
    - z is normal to the plane of the three atoms;
    - y = x cross z.

    All other atoms are expressed in that frame and converted to spherical
    features with ``get_spherical``. They are then ordered by atomic number
    first and by distance to the focus atom second.

    Parameters
    ----------
    raw_matrix : sequence of np.array
        One array per configuration, each of shape (n_atoms, 4) with columns
        (atomic number, x, y, z). Every configuration must have the same
        number of atoms as the first one.

    Returns
    -------
    list
        Nested list of shape (n_configurations, n_atoms, (n_atoms - 1) * 4).
        For each atom (in ascending atomic-number order) it holds the
        flattened (1/r, cos(theta), cos(phi), sin(phi)) features of all other
        atoms. An empty input returns an empty list.

    Raises
    ------
    ValueError
        If some atom has fewer than two non-H atoms besides itself, or if a
        configuration's atom count differs from the first one.
    """
    if len(raw_matrix) == 0:
        return []
    timer = ProgressTimer(len(raw_matrix), print_every=100)
    n_atoms = raw_matrix[0].shape[0]
    network_inputs = []
    for molecule in raw_matrix:
        # Stable sort: the order of atoms with equal atomic number must not
        # depend on the NumPy version or CPU-specific sort kernels.
        molecule = molecule[molecule[:, 0].argsort(kind='stable')]
        timer.how_long()
        labels_all = molecule[:, 0]
        coords_all = molecule[:, 1:].astype(float)
        # Hydrogen has atomic number 1, so after sorting the H atoms occupy
        # rows 0 .. h_atoms-1 and the heavy atoms the remaining rows.
        h_atoms = np.sum(labels_all == 1)
        mol_input = []
        for atom in range(len(molecule)):
            zero = coords_all[atom]
            distances = cdist(zero.reshape(1, 3), coords_all)[0]
            nearest = distances.argsort(kind='stable')
            # Two nearest heavy atoms other than the focus atom. Excluding the
            # focus by index (rather than dropping the first entry) is also
            # correct when the focus atom is H, i.e. not in the heavy subset.
            heavy_neighbours = nearest[(nearest >= h_atoms) & (nearest != atom)]
            one_id, two_id = heavy_neighbours[:2]
            one = coords_all[one_id]
            two = coords_all[two_id]
            # Local frame. y = x cross z makes (x, y, z) left-handed; it is
            # applied to every atom alike, so the features stay consistent.
            # NOTE(review): collinear focus/one/two atoms give a zero z-axis and
            # NaN features.
            new_x = one - zero
            new_z = np.cross(new_x, two - zero)
            new_y = np.cross(new_x, new_z)
            new_x /= np.linalg.norm(new_x)
            new_y /= np.linalg.norm(new_y)
            new_z /= np.linalg.norm(new_z)
            # Every atom except the focus atom, kept row-aligned.
            other_labels = np.delete(labels_all, atom)
            other_coords = np.delete(coords_all, atom, axis=0)
            other_dist = np.delete(distances, atom)
            trans_coords = change_base(other_coords, new_x, new_y,
                                       new_z, zero)
            spherical_coords = get_spherical(trans_coords)
            # Sort by element first, then by distance to the focus atom, in one
            # stable pass so labels and rows stay aligned (lexsort uses the
            # last key as the primary one).
            order = np.lexsort((other_dist, other_labels))
            spherical_coords = spherical_coords[order]
            net_in_coords = spherical_coords.reshape((n_atoms - 1) * 4).tolist()
            mol_input.append(net_in_coords)
        network_inputs.append(mol_input)
    return network_inputs

## Run Calculations

In [15]:
# Build the network-input array for all loaded configurations and report the
# wall-clock time in seconds.
start = time.time()
network_in = np.asarray(get_input_data(molecules))
print('time: {}'.format(time.time() - start))

0.0  % ---  57.4
0.0  % ---  49.9
0.1  % ---  51.6
0.1  % ---  53.5
0.1  % ---  52.8
0.1  % ---  51.9
0.2  % ---  51.1
0.2  % ---  50.5
0.2  % ---  50.1
0.2  % ---  50.0
0.3  % ---  49.5
0.3  % ---  49.2
0.3  % ---  49.1
0.3  % ---  49.5
0.4  % ---  49.1
0.4  % ---  49.1
0.4  % ---  49.1
0.4  % ---  49.0
0.5  % ---  49.1
0.5  % ---  49.9
0.5  % ---  49.6
0.5  % ---  49.6
0.6  % ---  49.6
0.6  % ---  49.3
0.6  % ---  49.1
0.6  % ---  48.8
0.7  % ---  48.5
0.7  % ---  48.3
0.7  % ---  48.0
0.7  % ---  47.8
0.8  % ---  47.6
0.8  % ---  47.3
0.8  % ---  47.1
0.8  % ---  46.9
0.9  % ---  46.7
0.9  % ---  46.5
0.9  % ---  46.3
0.9  % ---  46.1
1.0  % ---  46.0
1.0  % ---  45.8
1.0  % ---  45.7
1.0  % ---  45.5
1.1  % ---  45.4
1.1  % ---  45.3
1.1  % ---  45.2
1.1  % ---  45.0
1.2  % ---  44.9
1.2  % ---  44.8
1.2  % ---  44.7
1.2  % ---  44.6
1.3  % ---  44.5
1.3  % ---  44.4
1.3  % ---  44.3
1.3  % ---  44.2
1.4  % ---  44.2
1.4  % ---  44.1
1.4  % ---  44.0
1.4  % ---  43.9
1.5  % ---  43

11.9  % ---  37.3
11.9  % ---  37.3
11.9  % ---  37.3
11.9  % ---  37.3
12.0  % ---  37.3
12.0  % ---  37.3
12.0  % ---  37.3
12.0  % ---  37.3
12.1  % ---  37.3
12.1  % ---  37.2
12.1  % ---  37.2
12.1  % ---  37.2
12.2  % ---  37.2
12.2  % ---  37.2
12.2  % ---  37.2
12.2  % ---  37.2
12.3  % ---  37.2
12.3  % ---  37.2
12.3  % ---  37.2
12.3  % ---  37.1
12.4  % ---  37.1
12.4  % ---  37.1
12.4  % ---  37.1
12.4  % ---  37.1
12.5  % ---  37.1
12.5  % ---  37.1
12.5  % ---  37.1
12.5  % ---  37.1
12.5  % ---  37.0
12.6  % ---  37.0
12.6  % ---  37.0
12.6  % ---  37.0
12.6  % ---  37.0
12.7  % ---  37.0
12.7  % ---  37.0
12.7  % ---  37.0
12.7  % ---  37.0
12.8  % ---  37.0
12.8  % ---  36.9
12.8  % ---  36.9
12.8  % ---  36.9
12.9  % ---  36.9
12.9  % ---  36.9
12.9  % ---  36.9
12.9  % ---  36.9
13.0  % ---  36.9
13.0  % ---  36.9
13.0  % ---  36.8
13.0  % ---  36.8
13.1  % ---  36.8
13.1  % ---  36.8
13.1  % ---  36.8
13.1  % ---  36.8
13.2  % ---  36.8
13.2  % ---  36.8
13.2  % --

23.1  % ---  32.8
23.2  % ---  32.7
23.2  % ---  32.7
23.2  % ---  32.7
23.2  % ---  32.7
23.3  % ---  32.7
23.3  % ---  32.7
23.3  % ---  32.7
23.3  % ---  32.7
23.4  % ---  32.7
23.4  % ---  32.7
23.4  % ---  32.6
23.4  % ---  32.6
23.5  % ---  32.6
23.5  % ---  32.6
23.5  % ---  32.6
23.5  % ---  32.6
23.6  % ---  32.6
23.6  % ---  32.6
23.6  % ---  32.6
23.6  % ---  32.6
23.7  % ---  32.5
23.7  % ---  32.5
23.7  % ---  32.5
23.7  % ---  32.5
23.8  % ---  32.5
23.8  % ---  32.5
23.8  % ---  32.5
23.8  % ---  32.5
23.9  % ---  32.5
23.9  % ---  32.5
23.9  % ---  32.4
23.9  % ---  32.4
24.0  % ---  32.4
24.0  % ---  32.4
24.0  % ---  32.4
24.0  % ---  32.4
24.1  % ---  32.4
24.1  % ---  32.4
24.1  % ---  32.4
24.1  % ---  32.4
24.2  % ---  32.3
24.2  % ---  32.3
24.2  % ---  32.3
24.2  % ---  32.3
24.3  % ---  32.3
24.3  % ---  32.3
24.3  % ---  32.3
24.3  % ---  32.3
24.4  % ---  32.3
24.4  % ---  32.3
24.4  % ---  32.2
24.4  % ---  32.2
24.5  % ---  32.2
24.5  % ---  32.2
24.5  % --

34.4  % ---  27.9
34.5  % ---  27.9
34.5  % ---  27.9
34.5  % ---  27.9
34.5  % ---  27.9
34.6  % ---  27.9
34.6  % ---  27.9
34.6  % ---  27.9
34.6  % ---  27.9
34.7  % ---  27.8
34.7  % ---  27.8
34.7  % ---  27.8
34.7  % ---  27.8
34.8  % ---  27.8
34.8  % ---  27.8
34.8  % ---  27.8
34.8  % ---  27.8
34.9  % ---  27.8
34.9  % ---  27.7
34.9  % ---  27.7
34.9  % ---  27.7
35.0  % ---  27.7
35.0  % ---  27.7
35.0  % ---  27.7
35.0  % ---  27.7
35.0  % ---  27.7
35.1  % ---  27.7
35.1  % ---  27.6
35.1  % ---  27.6
35.1  % ---  27.6
35.2  % ---  27.6
35.2  % ---  27.6
35.2  % ---  27.6
35.2  % ---  27.6
35.3  % ---  27.6
35.3  % ---  27.6
35.3  % ---  27.5
35.3  % ---  27.5
35.4  % ---  27.5
35.4  % ---  27.5
35.4  % ---  27.5
35.4  % ---  27.5
35.5  % ---  27.5
35.5  % ---  27.5
35.5  % ---  27.5
35.5  % ---  27.4
35.6  % ---  27.4
35.6  % ---  27.4
35.6  % ---  27.4
35.6  % ---  27.4
35.7  % ---  27.4
35.7  % ---  27.4
35.7  % ---  27.4
35.7  % ---  27.4
35.8  % ---  27.3
35.8  % --

45.7  % ---  23.3
45.7  % ---  23.3
45.8  % ---  23.3
45.8  % ---  23.3
45.8  % ---  23.3
45.8  % ---  23.3
45.9  % ---  23.3
45.9  % ---  23.2
45.9  % ---  23.2
45.9  % ---  23.2
46.0  % ---  23.2
46.0  % ---  23.2
46.0  % ---  23.2
46.0  % ---  23.2
46.1  % ---  23.2
46.1  % ---  23.2
46.1  % ---  23.2
46.1  % ---  23.1
46.2  % ---  23.1
46.2  % ---  23.1
46.2  % ---  23.1
46.2  % ---  23.1
46.3  % ---  23.1
46.3  % ---  23.1
46.3  % ---  23.1
46.3  % ---  23.1
46.4  % ---  23.0
46.4  % ---  23.0
46.4  % ---  23.0
46.4  % ---  23.0
46.5  % ---  23.0
46.5  % ---  23.0
46.5  % ---  23.0
46.5  % ---  23.0
46.6  % ---  23.0
46.6  % ---  22.9
46.6  % ---  22.9
46.6  % ---  22.9
46.7  % ---  22.9
46.7  % ---  22.9
46.7  % ---  22.9
46.7  % ---  22.9
46.8  % ---  22.9
46.8  % ---  22.9
46.8  % ---  22.8
46.8  % ---  22.8
46.9  % ---  22.8
46.9  % ---  22.8
46.9  % ---  22.8
46.9  % ---  22.8
47.0  % ---  22.8
47.0  % ---  22.8
47.0  % ---  22.8
47.0  % ---  22.7
47.1  % ---  22.7
47.1  % --

57.0  % ---  18.4
57.0  % ---  18.4
57.1  % ---  18.4
57.1  % ---  18.4
57.1  % ---  18.3
57.1  % ---  18.3
57.2  % ---  18.3
57.2  % ---  18.3
57.2  % ---  18.3
57.2  % ---  18.3
57.3  % ---  18.3
57.3  % ---  18.3
57.3  % ---  18.3
57.3  % ---  18.2
57.4  % ---  18.6
57.4  % ---  18.6
57.4  % ---  18.6
57.4  % ---  18.6
57.5  % ---  18.6
57.5  % ---  18.5
57.5  % ---  18.5
57.5  % ---  18.5
57.5  % ---  18.5
57.6  % ---  18.5
57.6  % ---  18.5
57.6  % ---  18.5
57.6  % ---  18.5
57.7  % ---  18.5
57.7  % ---  18.4
57.7  % ---  18.4
57.7  % ---  18.4
57.8  % ---  18.4
57.8  % ---  18.4
57.8  % ---  18.4
57.8  % ---  18.4
57.9  % ---  18.4
57.9  % ---  18.4
57.9  % ---  18.3
57.9  % ---  18.3
58.0  % ---  18.3
58.0  % ---  18.3
58.0  % ---  18.3
58.0  % ---  18.3
58.1  % ---  18.3
58.1  % ---  18.3
58.1  % ---  18.3
58.1  % ---  18.2
58.2  % ---  18.2
58.2  % ---  18.2
58.2  % ---  18.2
58.2  % ---  18.2
58.3  % ---  18.2
58.3  % ---  18.2
58.3  % ---  18.2
58.3  % ---  18.2
58.4  % --

68.3  % ---  13.8
68.3  % ---  13.7
68.3  % ---  13.7
68.4  % ---  13.7
68.4  % ---  13.7
68.4  % ---  13.7
68.4  % ---  13.7
68.5  % ---  13.7
68.5  % ---  13.7
68.5  % ---  13.7
68.5  % ---  13.6
68.6  % ---  13.6
68.6  % ---  13.6
68.6  % ---  13.6
68.6  % ---  13.6
68.7  % ---  13.6
68.7  % ---  13.6
68.7  % ---  13.6
68.7  % ---  13.6
68.8  % ---  13.6
68.8  % ---  13.5
68.8  % ---  13.5
68.8  % ---  13.5
68.9  % ---  13.5
68.9  % ---  13.5
68.9  % ---  13.5
68.9  % ---  13.5
69.0  % ---  13.5
69.0  % ---  13.5
69.0  % ---  13.4
69.0  % ---  13.4
69.1  % ---  13.4
69.1  % ---  13.4
69.1  % ---  13.4
69.1  % ---  13.4
69.2  % ---  13.4
69.2  % ---  13.4
69.2  % ---  13.4
69.2  % ---  13.3
69.3  % ---  13.3
69.3  % ---  13.3
69.3  % ---  13.3
69.3  % ---  13.3
69.4  % ---  13.3
69.4  % ---  13.3
69.4  % ---  13.3
69.4  % ---  13.3
69.5  % ---  13.2
69.5  % ---  13.2
69.5  % ---  13.2
69.5  % ---  13.2
69.6  % ---  13.2
69.6  % ---  13.2
69.6  % ---  13.2
69.6  % ---  13.2
69.7  % --

79.7  % ---  9.0
79.7  % ---  8.9
79.8  % ---  8.9
79.8  % ---  8.9
79.8  % ---  8.9
79.8  % ---  8.9
79.9  % ---  8.9
79.9  % ---  8.9
79.9  % ---  8.9
79.9  % ---  8.9
80.0  % ---  8.8
80.0  % ---  8.8
80.0  % ---  8.8
80.0  % ---  8.8
80.0  % ---  8.8
80.1  % ---  8.8
80.1  % ---  8.8
80.1  % ---  8.8
80.1  % ---  8.8
80.2  % ---  8.7
80.2  % ---  8.7
80.2  % ---  8.7
80.2  % ---  8.7
80.3  % ---  8.7
80.3  % ---  8.7
80.3  % ---  8.7
80.3  % ---  8.7
80.4  % ---  8.7
80.4  % ---  8.6
80.4  % ---  8.6
80.4  % ---  8.6
80.5  % ---  8.6
80.5  % ---  8.6
80.5  % ---  8.6
80.5  % ---  8.6
80.6  % ---  8.6
80.6  % ---  8.6
80.6  % ---  8.6
80.6  % ---  8.5
80.7  % ---  8.5
80.7  % ---  8.5
80.7  % ---  8.5
80.7  % ---  8.5
80.8  % ---  8.5
80.8  % ---  8.5
80.8  % ---  8.5
80.8  % ---  8.5
80.9  % ---  8.4
80.9  % ---  8.4
80.9  % ---  8.4
80.9  % ---  8.4
81.0  % ---  8.4
81.0  % ---  8.4
81.0  % ---  8.4
81.0  % ---  8.4
81.1  % ---  8.4
81.1  % ---  8.3
81.1  % ---  8.3
81.1  % ---  8

91.6  % ---  3.8
91.7  % ---  3.8
91.7  % ---  3.7
91.7  % ---  3.7
91.7  % ---  3.7
91.8  % ---  3.7
91.8  % ---  3.7
91.8  % ---  3.7
91.8  % ---  3.7
91.9  % ---  3.7
91.9  % ---  3.7
91.9  % ---  3.6
91.9  % ---  3.6
92.0  % ---  3.6
92.0  % ---  3.6
92.0  % ---  3.6
92.0  % ---  3.6
92.1  % ---  3.6
92.1  % ---  3.6
92.1  % ---  3.6
92.1  % ---  3.5
92.2  % ---  3.5
92.2  % ---  3.5
92.2  % ---  3.5
92.2  % ---  3.5
92.3  % ---  3.5
92.3  % ---  3.5
92.3  % ---  3.5
92.3  % ---  3.5
92.4  % ---  3.4
92.4  % ---  3.4
92.4  % ---  3.4
92.4  % ---  3.4
92.5  % ---  3.4
92.5  % ---  3.4
92.5  % ---  3.4
92.5  % ---  3.4
92.5  % ---  3.4
92.6  % ---  3.3
92.6  % ---  3.3
92.6  % ---  3.3
92.6  % ---  3.3
92.7  % ---  3.3
92.7  % ---  3.3
92.7  % ---  3.3
92.7  % ---  3.3
92.8  % ---  3.3
92.8  % ---  3.2
92.8  % ---  3.2
92.8  % ---  3.2
92.9  % ---  3.2
92.9  % ---  3.2
92.9  % ---  3.2
92.9  % ---  3.2
93.0  % ---  3.2
93.0  % ---  3.2
93.0  % ---  3.1
93.0  % ---  3.1
93.1  % ---  3

In [16]:
# (n_configurations, n_atoms, (n_atoms - 1) * 4 features per atom)
network_in.shape

(404000, 19, 72)

## Get Y-labels

In [17]:
# one total energy per configuration; should match len(molecules)
len(energies)

404000

## Save arrays to file

In [18]:
# Output file stems; np.save appends the '.npy' extension.
data_path = '../Dataset/network_inputs'
label_path = '../Dataset/network_labels'

In [None]:
# network inputs -> data_path; the list of total energies is converted to an
# array by np.save -> label_path
np.save(data_path, network_in)
np.save(label_path, energies)

# Testing

## Test Functions

In [None]:
from numpy.testing import *

In [None]:
def test_get_spherical():
    """Compare get_spherical with hand-derived spherical features."""
    points = np.array([[0, 1, 2],
                       [1, 1, 1],
                       [-1, 2, 1]])
    # The azimuth of (x, y) = (-1, 2) lies in the second quadrant.
    phi_third = np.arctan(-2) + np.pi
    # Rows: (1/r, cos(theta), cos(phi), sin(phi)) for each point above.
    expected = np.array([
        [1/np.sqrt(5), np.cos(np.arccos(2/np.sqrt(5))),
         np.cos(np.pi/2), np.sin(np.pi/2)],
        [1/np.sqrt(3), np.cos(np.arccos(1/np.sqrt(3))),
         np.cos(np.arctan(1)), np.sin(np.arctan(1))],
        [1/np.sqrt(6), np.cos(np.arccos(1/np.sqrt(6))),
         np.cos(phi_third), np.sin(phi_third)],
    ])
    assert_array_almost_equal(expected, get_spherical(points))

In [None]:
def test_change_base():
    """Compare change_base with coordinates solved by hand."""
    points = np.array([[0, 1, 2],
                       [1, 1, 1],
                       [-1, 2, 1]])
    axes = (np.array([1, 1, 0]),   # new x-axis
            np.array([0, 0, 1]),   # new y-axis
            np.array([2, 1, 3]))   # new z-axis
    expected = np.array([[-7., -13., 4.],
                         [-8., -17., 5.],
                         [-4., -8., 2.]])
    origin = np.array([-1, 4, 3])
    result = change_base(points, *axes, origin)
    assert_array_almost_equal(expected, result)

In [None]:
def test_get_input_data():
    """
    Smoke-test get_input_data on a single small 4-atom configuration.

    get_input_data takes ONE argument: a sequence of (n_atoms, 4) arrays
    with columns (atomic number, x, y, z). The previous version passed an
    extra positional argument (TypeError), a bare 2-D array instead of a list
    of configurations, and string element labels.

    Returns
    -------
    np.array
        The computed inputs, shape (1, 4, 12).
    """
    # rows: (atomic number, x, y, z); 6 = C, 8 = O, 1 = H
    test_mol = np.array([[6, 1, 1, 1],
                         [8, 1, 0, 0],
                         [8, 0, 3, 0],
                         [1, 0, 2, 0]], dtype=float)
    result = np.array(get_input_data([test_mol]))
    # one configuration, four focus atoms, three neighbours x 4 features each
    assert result.shape == (1, 4, 12)
    assert np.all(np.isfinite(result))
    return result

## Run Tests

In [None]:
# run the get_spherical check (numpy.testing raises AssertionError on mismatch)
test_get_spherical()

In [None]:
# run the change_base check (numpy.testing raises AssertionError on mismatch)
test_change_base()

In [None]:
# smoke test of get_input_data; the cell displays whatever the test returns
test_get_input_data()

In [None]:
# scratch: (-1, 2, -1) and (0, -1, -1) are the offsets of the atoms at
# (0, 3, 0) and (1, 0, 0) from the atom at (1, 1, 1), i.e. atoms of the test
# molecule in test_get_input_data; presumably a hand check of the frame's
# cross product
np.cross(np.array([-1, 2, -1]), np.array([0, -1, -1]))

In [None]:
# scratch: the same four test-molecule coordinates without atom labels
x = np.array([[1, 1, 1], [1, 0, 0], [0, 3, 0], [0, 2, 0]])

In [None]:
# scratch: pairwise distance matrix of those coordinates
cdist(x, x)#.argsort()