In [None]:
import os
import sys
import time 

# Setup that must run BEFORE the third-party / project imports below:
# - make the parent directory importable (presumably where the AlphaGNN
#   module lives — TODO confirm);
# - hide every CUDA device ('-1' matches none) so TensorFlow runs on CPU.
#   The variable has to be set before TensorFlow initialises CUDA, which is
#   why it precedes `import tensorflow`.
sys.path.append('..')
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# NOTE(review): ALPHA0 is not used in the visible cells.
from AlphaGNN import Alpha, ALPHA0

In [None]:
# Each .npy file holds a single pickled dict, keyed by sample id (presumably one
# entry per molecule). Entries provide 'graph', 'elements', 'ratios' and a
# reference polarizability ('pol_ccsd' for the train/showcase sets, 'pol' for
# the experimental set).
# NOTE: allow_pickle=True unpickles the file contents, which can execute
# arbitrary code — only load these files from a trusted source.
_DATA_PATHS = (
    'data/DATA_ALPHA.npy',
    'data/DATA_ALPHA_SHOWCASE.npy',
    'data/DATA_ALPHA_EXP.npy',
)
DATA, DATA_SHOWCASE, DATA_EXP = (
    np.load(path, allow_pickle=True).item() for path in _DATA_PATHS
)
KEYS = list(DATA)

In [None]:
# Training configuration.
N_SAMPLES = 512  # keys drawn (without replacement) per epoch; one optimizer step per key
N_EPOCHS = 512
N_STEPS = 2      # forwarded to Alpha(n_steps=...) and encoded in the weights filename
SEED = 0         # makes key sampling and weight initialisation reproducible

# Seed Python's `random`, NumPy's global RNG and TensorFlow in one call so that
# Restart & Run All reproduces the same model. Must run before the model is built.
tf.keras.utils.set_random_seed(SEED)

# Exponential decay over the whole run (N_SAMPLES * N_EPOCHS optimizer steps):
# the LR falls from 5e-4 to 5e-4 * 2e-2 = 1e-5 at the final step.
# (An earlier setting was presumably initial LR 4e-4 with decay rate 1e-2.)
lr_fn = tf.optimizers.schedules.ExponentialDecay(5e-4, N_SAMPLES * N_EPOCHS, 2e-2)
optimizer = tf.keras.optimizers.Adam(lr_fn)
mae_alpha = tf.keras.metrics.MeanAbsoluteError()
alpha_scaler = Alpha(n_steps=N_STEPS)

In [None]:
# Train with batch size 1: each epoch draws up to N_SAMPLES distinct keys and takes
# one Adam step per key on the squared error of the summed polarizability.
for epoch in range(N_EPOCHS):
    epoch_start = time.time()

    # Sample indices rather than keys: the keys keep their original Python type
    # (np.random.choice would coerce them into a NumPy array), and capping at the
    # dataset size keeps a small dataset from making the draw raise.
    n_draw = min(N_SAMPLES, len(KEYS))
    sampled_keys = [KEYS[i] for i in np.random.permutation(len(KEYS))[:n_draw]]

    for key in sampled_keys:
        entry = DATA[key]
        target = entry['pol_ccsd']
        with tf.GradientTape() as tape:
            # `ratios[None]` adds a leading axis (presumably a batch axis — TODO confirm);
            # the model output is summed into one scalar prediction per sample.
            prediction = tf.reduce_sum(alpha_scaler(entry['graph'], entry['ratios'][None], entry['elements']))
            loss = tf.math.squared_difference(target, prediction)
        gradients = tape.gradient(loss, alpha_scaler.trainable_variables)
        optimizer.apply_gradients(zip(gradients, alpha_scaler.trainable_variables))
        mae_alpha.update_state([target], [prediction])

    # Evaluate the schedule directly: `optimizer.lr` is a deprecated alias (removed
    # in Keras 3) and on legacy optimizers it returns the schedule object itself.
    current_lr = float(lr_fn(optimizer.iterations))
    print('Epoch {}'.format(epoch))
    print('Time [s]: {:.1f}'.format(time.time() - epoch_start))
    print('MAE [A3]: {:.4f}'.format(float(mae_alpha.result())))
    print('LR ', current_lr)
    # `reset_state` replaces the deprecated `reset_states` (removed in Keras 3).
    mae_alpha.reset_state()

In [None]:
# Persist the trained weights. The filename encodes N_STEPS so that runs with a
# different number of steps do not overwrite each other. Create the target
# directory first so saving also works on a fresh checkout.
WEIGHTS_PATH = f'weights/ALPHAR{N_STEPS}'
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
alpha_scaler.save_weights(WEIGHTS_PATH)

In [None]:
# Training-set evaluation: the reference value stored under 'pol_ccsd' and the
# model's summed prediction for every entry, in dict iteration order.
targets = [sample['pol_ccsd'] for sample in DATA.values()]
predictions = [
    tf.reduce_sum(alpha_scaler(sample['graph'], sample['ratios'][None], sample['elements']))
    for sample in DATA.values()
]

In [None]:
# Showcase-set evaluation: reference ('pol_ccsd') and summed model prediction per entry.
targets_showcase, predictions_showcase = [], []
for sample in DATA_SHOWCASE.values():
    graph, elements, ratios = sample['graph'], sample['elements'], sample['ratios'][None]
    targets_showcase.append(sample['pol_ccsd'])
    predictions_showcase.append(tf.reduce_sum(alpha_scaler(graph, ratios, elements)))

In [None]:
# Experimental-set evaluation: for this set the reference value is stored under
# 'pol' (not 'pol_ccsd'); the prediction is the summed model output per entry.
targets_exp = [DATA_EXP[name]['pol'] for name in DATA_EXP]
predictions_exp = [
    tf.reduce_sum(alpha_scaler(DATA_EXP[name]['graph'], DATA_EXP[name]['ratios'][None], DATA_EXP[name]['elements']))
    for name in DATA_EXP
]

In [None]:
# Turn the collected per-sample values (numbers / scalar tensors) into NumPy
# arrays so that errors can be computed with vectorised arithmetic.
targets = np.array(targets)
predictions = np.array(predictions)
targets_showcase = np.array(targets_showcase)
predictions_showcase = np.array(predictions_showcase)
targets_exp = np.array(targets_exp)
predictions_exp = np.array(predictions_exp)

In [None]:
# Mean absolute error [A^3] for (train, showcase, experimental).
# Previously recorded for "N2" (presumably N_STEPS = 2; train and showcase only):
# (0.09475004, 0.4558718)
(
    np.abs(targets - predictions).mean(),
    np.abs(targets_showcase - predictions_showcase).mean(),
    np.abs(targets_exp - predictions_exp).mean(),
)

In [None]:
# Parity plot for the experimental set: ML prediction (x) vs reference (y).
# plt.subplots() creates a fresh figure; reusing plt.figure(0) could draw on top
# of another cell's figure when a non-inline backend keeps figures alive.
fig_exp, ax = plt.subplots(figsize=(8, 8), dpi=200)
# Labelled 'Experiment' (it was mislabelled 'Train'): these points come from DATA_EXP.
ax.scatter(predictions_exp, targets_exp, s=0.5, label='Experiment')
# The y = x line spans the plotted data instead of a fixed 2..24 range.
exp_values = np.concatenate([predictions_exp, targets_exp])
parity_lims = [exp_values.min(), exp_values.max()]
ax.plot(parity_lims, parity_lims, color='red')
ax.set_title('Experimental set')
ax.set_ylabel('Polarizability (Reference) [A^3]')
ax.set_xlabel('Polarizability (ML) [A^3]')
ax.legend(frameon=False);

In [None]:
# Parity plot for the training and showcase sets: ML prediction (x) vs reference (y).
# plt.subplots() creates a fresh figure; reusing plt.figure(0) could draw on top
# of another cell's figure when a non-inline backend keeps figures alive.
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.scatter(predictions, targets, s=0.5, label='Train')
ax.scatter(predictions_showcase, targets_showcase, s=0.5, label='Showcase')
# The y = x line spans all plotted data instead of a fixed 2..24 range.
plotted_values = np.concatenate([predictions, targets, predictions_showcase, targets_showcase])
parity_lims = [plotted_values.min(), plotted_values.max()]
ax.plot(parity_lims, parity_lims, color='red')
ax.set_ylabel('Polarizability (Reference) [A^3]')
ax.set_xlabel('Polarizability (ML) [A^3]')
ax.legend(frameon=False);

In [None]:
# N2 (presumably the N_STEPS = 2 run).
# TODO(review): this code plots the same train/showcase parity figure as the
# preceding plotting cell; keep one, or parameterise if both runs must be shown.
# plt.subplots() creates a fresh figure; reusing plt.figure(0) could draw on top
# of another cell's figure when a non-inline backend keeps figures alive.
fig_n2, ax = plt.subplots(figsize=(8, 8), dpi=200)
ax.scatter(predictions, targets, s=0.5, label='Train')
ax.scatter(predictions_showcase, targets_showcase, s=0.5, label='Showcase')
# The y = x line spans all plotted data instead of a fixed 2..24 range.
n2_values = np.concatenate([predictions, targets, predictions_showcase, targets_showcase])
n2_lims = [n2_values.min(), n2_values.max()]
ax.plot(n2_lims, n2_lims, color='red')
ax.set_ylabel('Polarizability (Reference) [A^3]')
ax.set_xlabel('Polarizability (ML) [A^3]')
ax.legend(frameon=False);