In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

%matplotlib inline

In [2]:
# Load the input data.
# Molecule tables per split. Later cells read a 'mol' column of mol blocks from these,
# and the test file also carries a 'molUFF' column (UFF-reoptimized geometries).
mol_train = pd.read_csv('data/mol_train_uff.csv.gz')
mol_valid = pd.read_csv('data/mol_valid_uff.csv.gz')
mol_test = pd.read_csv('data/mol_test_uff.csv.gz')

# SMILES tables per split: a 'smile' column plus the target property columns;
# the first CSV column becomes the index.
smiles_train = pd.read_csv('data/smiles_train.csv.gz', index_col=0)
smiles_valid = pd.read_csv('data/smiles_valid.csv.gz', index_col=0)
smiles_test = pd.read_csv('data/smiles_test.csv.gz', index_col=0)

In [3]:
from keras.models import load_model
from nfp import custom_layers
from nfp.preprocessing import GraphSequence
import warnings
import pickle
from tqdm import tqdm

Using TensorFlow backend.


In [4]:
# Target properties evaluated throughout this notebook. Order matters: model outputs
# are labelled positionally with this list.
props = [
    'gap', 'homo', 'lumo', 'spectral_overlap',
    'homo_extrapolated', 'lumo_extrapolated', 'gap_extrapolated',
    'optical_lumo_extrapolated',
]


def mae(true, pred):
    """Mean absolute error along the first axis, ignoring NaN entries.

    Returns a scalar for 1-D inputs and one value per column for 2-D inputs.
    """
    abs_err = np.abs(true - pred)
    return np.nanmean(abs_err, axis=0)

In [27]:
# Load the multitarget 2D (graph-only) model, plus the preprocessor and target
# scaler pickled in the same run directory.
with warnings.catch_warnings():
    # Silence warnings raised while deserializing the saved Keras model.
    warnings.simplefilter("ignore")
    model = load_model('b3lyp_2D_mae_multitarget/best_model.hdf5',
                       custom_objects=custom_layers)

# NOTE: pickle.load can execute arbitrary code -- only load trusted artifacts.
with open('b3lyp_2D_mae_multitarget/preprocessor.p', 'rb') as f:
    preprocessor = pickle.load(f)

with open('b3lyp_2D_mae_multitarget/y_scaler.p', 'rb') as f:
    y_scaler = pickle.load(f)

# Featurize the test SMILES with the training-time preprocessor. The sequence is
# unshuffled so prediction rows follow smiles_test order; final_batch=True presumably
# keeps the last partial batch (TODO confirm against nfp's GraphSequence).
inputs_test = preprocessor.predict(smiles_test.smile)
test_inputs_seq = GraphSequence(inputs_test, shuffle=False, final_batch=True, batch_size=128)


  0%|          | 0/5000 [00:00<?, ?it/s][A
  1%|          | 51/5000 [00:00<00:09, 504.44it/s][A
  2%|▏         | 101/5000 [00:00<00:09, 499.73it/s][A
  3%|▎         | 149/5000 [00:00<00:09, 492.62it/s][A
  4%|▍         | 203/5000 [00:00<00:09, 503.39it/s][A
  5%|▌         | 264/5000 [00:00<00:08, 528.74it/s][A
  6%|▋         | 322/5000 [00:00<00:08, 543.11it/s][A
  8%|▊         | 379/5000 [00:00<00:08, 550.07it/s][A
  9%|▊         | 434/5000 [00:00<00:08, 549.87it/s][A
 10%|▉         | 488/5000 [00:00<00:08, 544.37it/s][A
 11%|█         | 541/5000 [00:01<00:08, 539.11it/s][A
 12%|█▏        | 601/5000 [00:01<00:07, 552.77it/s][A
 13%|█▎        | 657/5000 [00:01<00:07, 553.09it/s][A
 14%|█▍        | 712/5000 [00:01<00:08, 535.22it/s][A
 15%|█▌        | 766/5000 [00:01<00:07, 530.16it/s][A
 16%|█▋        | 821/5000 [00:01<00:07, 535.46it/s][A
 18%|█▊        | 876/5000 [00:01<00:07, 535.81it/s][A
 19%|█▊        | 933/5000 [00:01<00:07, 545.32it/s][A
 20%|█▉        | 991/

In [28]:
# Featurize every training SMILES with the multitarget preprocessor; the next cell
# uses the result to find the largest bond count.
# NOTE(review): this pass over the full training set takes minutes -- consider caching.
inputs_train = preprocessor.predict(smiles_train.smile)


  0%|          | 0/80823 [00:00<?, ?it/s][A
  0%|          | 22/80823 [00:00<06:21, 211.82it/s][A
  0%|          | 45/80823 [00:00<06:16, 214.71it/s][A
  0%|          | 68/80823 [00:00<06:10, 218.03it/s][A
  0%|          | 92/80823 [00:00<06:03, 222.01it/s][A
  0%|          | 116/80823 [00:00<05:56, 226.51it/s][A
  0%|          | 144/80823 [00:00<05:37, 238.82it/s][A
  0%|          | 172/80823 [00:00<05:24, 248.22it/s][A
  0%|          | 200/80823 [00:00<05:13, 256.87it/s][A
  0%|          | 232/80823 [00:00<04:58, 270.38it/s][A
  0%|          | 263/80823 [00:01<04:48, 279.51it/s][A
  0%|          | 293/80823 [00:01<04:42, 284.77it/s][A
  0%|          | 323/80823 [00:01<04:39, 287.75it/s][A
  0%|          | 357/80823 [00:01<04:27, 301.03it/s][A
  0%|          | 389/80823 [00:01<04:23, 305.42it/s][A
  1%|          | 420/80823 [00:01<04:22, 305.79it/s][A
  1%|          | 453/80823 [00:01<04:17, 312.23it/s][A
  1%|          | 485/80823 [00:01<04:17, 312.31it/s][A
  1%| 

  9%|▉         | 7495/80823 [00:14<01:46, 691.06it/s][A
  9%|▉         | 7565/80823 [00:14<01:48, 678.00it/s][A
  9%|▉         | 7635/80823 [00:14<01:47, 682.59it/s][A
 10%|▉         | 7707/80823 [00:14<01:45, 692.68it/s][A
 10%|▉         | 7777/80823 [00:15<01:45, 694.12it/s][A
 10%|▉         | 7847/80823 [00:15<01:45, 689.43it/s][A
 10%|▉         | 7918/80823 [00:15<01:45, 694.24it/s][A
 10%|▉         | 7989/80823 [00:15<01:44, 697.68it/s][A
 10%|▉         | 8059/80823 [00:15<01:44, 696.92it/s][A
 10%|█         | 8133/80823 [00:15<01:42, 707.35it/s][A
 10%|█         | 8210/80823 [00:15<01:40, 722.37it/s][A
 10%|█         | 8283/80823 [00:15<01:40, 720.13it/s][A
 10%|█         | 8357/80823 [00:15<01:39, 725.53it/s][A
 10%|█         | 8430/80823 [00:15<01:39, 724.50it/s][A
 11%|█         | 8503/80823 [00:16<01:41, 713.39it/s][A
 11%|█         | 8578/80823 [00:16<01:39, 723.79it/s][A
 11%|█         | 8653/80823 [00:16<01:38, 731.32it/s][A
 11%|█         | 8728/80823 [00

 19%|█▉        | 15600/80823 [00:29<02:26, 446.31it/s][A
 19%|█▉        | 15646/80823 [00:29<02:31, 428.93it/s][A
 19%|█▉        | 15692/80823 [00:29<02:30, 434.07it/s][A
 19%|█▉        | 15736/80823 [00:30<02:30, 431.60it/s][A
 20%|█▉        | 15780/80823 [00:30<02:31, 428.84it/s][A
 20%|█▉        | 15826/80823 [00:30<02:29, 434.69it/s][A
 20%|█▉        | 15873/80823 [00:30<02:26, 442.12it/s][A
 20%|█▉        | 15918/80823 [00:30<02:33, 423.29it/s][A
 20%|█▉        | 15964/80823 [00:30<02:29, 432.56it/s][A
 20%|█▉        | 16010/80823 [00:30<02:27, 439.98it/s][A
 20%|█▉        | 16055/80823 [00:30<02:27, 439.05it/s][A
 20%|█▉        | 16100/80823 [00:30<02:26, 440.97it/s][A
 20%|█▉        | 16145/80823 [00:31<02:28, 434.47it/s][A
 20%|██        | 16189/80823 [00:31<02:28, 434.63it/s][A
 20%|██        | 16236/80823 [00:31<02:25, 442.61it/s][A
 20%|██        | 16283/80823 [00:31<02:23, 448.74it/s][A
 20%|██        | 16328/80823 [00:31<02:24, 445.70it/s][A
 20%|██       

 27%|██▋       | 22049/80823 [00:44<02:06, 465.65it/s][A
 27%|██▋       | 22096/80823 [00:44<02:09, 452.37it/s][A
 27%|██▋       | 22142/80823 [00:44<02:09, 453.26it/s][A
 27%|██▋       | 22188/80823 [00:44<02:15, 433.80it/s][A
 28%|██▊       | 22236/80823 [00:44<02:11, 446.39it/s][A
 28%|██▊       | 22284/80823 [00:45<02:09, 453.44it/s][A
 28%|██▊       | 22330/80823 [00:45<02:12, 442.74it/s][A
 28%|██▊       | 22375/80823 [00:45<02:12, 439.66it/s][A
 28%|██▊       | 22427/80823 [00:45<02:07, 459.69it/s][A
 28%|██▊       | 22474/80823 [00:45<02:10, 447.63it/s][A
 28%|██▊       | 22520/80823 [00:45<02:10, 448.30it/s][A
 28%|██▊       | 22566/80823 [00:45<02:11, 443.98it/s][A
 28%|██▊       | 22611/80823 [00:45<02:13, 437.41it/s][A
 28%|██▊       | 22656/80823 [00:45<02:12, 438.32it/s][A
 28%|██▊       | 22703/80823 [00:45<02:10, 445.88it/s][A
 28%|██▊       | 22754/80823 [00:46<02:06, 460.48it/s][A
 28%|██▊       | 22801/80823 [00:46<02:05, 460.67it/s][A
 28%|██▊      

 35%|███▌      | 28419/80823 [00:59<02:01, 430.45it/s][A
 35%|███▌      | 28464/80823 [00:59<02:00, 434.64it/s][A
 35%|███▌      | 28508/80823 [00:59<02:10, 399.51it/s][A
 35%|███▌      | 28549/80823 [00:59<02:09, 402.27it/s][A
 35%|███▌      | 28592/80823 [00:59<02:07, 408.40it/s][A
 35%|███▌      | 28636/80823 [00:59<02:05, 416.46it/s][A
 35%|███▌      | 28678/80823 [00:59<02:06, 412.65it/s][A
 36%|███▌      | 28724/80823 [01:00<02:03, 422.93it/s][A
 36%|███▌      | 28767/80823 [01:00<02:03, 422.20it/s][A
 36%|███▌      | 28818/80823 [01:00<01:56, 444.70it/s][A
 36%|███▌      | 28863/80823 [01:00<02:00, 432.95it/s][A
 36%|███▌      | 28907/80823 [01:00<02:03, 418.76it/s][A
 36%|███▌      | 28956/80823 [01:00<01:58, 437.77it/s][A
 36%|███▌      | 29001/80823 [01:00<02:02, 424.52it/s][A
 36%|███▌      | 29045/80823 [01:00<02:00, 428.73it/s][A
 36%|███▌      | 29089/80823 [01:00<02:03, 420.23it/s][A
 36%|███▌      | 29132/80823 [01:00<02:05, 412.44it/s][A
 36%|███▌     

 43%|████▎     | 34688/80823 [01:14<01:52, 409.69it/s][A
 43%|████▎     | 34730/80823 [01:14<01:52, 410.87it/s][A
 43%|████▎     | 34772/80823 [01:14<01:51, 412.63it/s][A
 43%|████▎     | 34817/80823 [01:14<01:48, 422.28it/s][A
 43%|████▎     | 34860/80823 [01:14<01:49, 421.63it/s][A
 43%|████▎     | 34903/80823 [01:14<01:51, 410.14it/s][A
 43%|████▎     | 34945/80823 [01:14<01:54, 399.92it/s][A
 43%|████▎     | 34990/80823 [01:14<01:51, 412.30it/s][A
 43%|████▎     | 35032/80823 [01:14<01:52, 406.35it/s][A
 43%|████▎     | 35073/80823 [01:14<01:55, 394.64it/s][A
 43%|████▎     | 35122/80823 [01:15<01:49, 417.97it/s][A
 44%|████▎     | 35165/80823 [01:15<01:48, 420.45it/s][A
 44%|████▎     | 35208/80823 [01:15<01:50, 414.19it/s][A
 44%|████▎     | 35250/80823 [01:15<01:50, 411.07it/s][A
 44%|████▎     | 35292/80823 [01:15<01:50, 412.68it/s][A
 44%|████▎     | 35334/80823 [01:15<01:49, 413.61it/s][A
 44%|████▍     | 35379/80823 [01:15<01:47, 421.20it/s][A
 44%|████▍    

 51%|█████     | 40903/80823 [01:28<01:36, 411.76it/s][A
 51%|█████     | 40947/80823 [01:28<01:35, 418.27it/s][A
 51%|█████     | 40989/80823 [01:28<01:36, 411.86it/s][A
 51%|█████     | 41031/80823 [01:29<01:37, 407.59it/s][A
 51%|█████     | 41076/80823 [01:29<01:35, 415.61it/s][A
 51%|█████     | 41123/80823 [01:29<01:32, 428.30it/s][A
 51%|█████     | 41170/80823 [01:29<01:30, 438.02it/s][A
 51%|█████     | 41214/80823 [01:29<01:33, 425.21it/s][A
 51%|█████     | 41257/80823 [01:29<01:35, 415.50it/s][A
 51%|█████     | 41299/80823 [01:29<01:37, 405.62it/s][A
 51%|█████     | 41343/80823 [01:29<01:35, 413.05it/s][A
 51%|█████     | 41386/80823 [01:29<01:34, 417.44it/s][A
 51%|█████▏    | 41428/80823 [01:30<01:37, 404.09it/s][A
 51%|█████▏    | 41471/80823 [01:30<01:35, 410.37it/s][A
 51%|█████▏    | 41513/80823 [01:30<01:37, 403.85it/s][A
 51%|█████▏    | 41555/80823 [01:30<01:36, 406.17it/s][A
 51%|█████▏    | 41603/80823 [01:30<01:32, 425.04it/s][A
 52%|█████▏   

 58%|█████▊    | 47066/80823 [01:43<01:21, 413.90it/s][A
 58%|█████▊    | 47109/80823 [01:43<01:20, 418.58it/s][A
 58%|█████▊    | 47152/80823 [01:43<01:20, 418.24it/s][A
 58%|█████▊    | 47194/80823 [01:43<01:21, 413.56it/s][A
 58%|█████▊    | 47236/80823 [01:43<01:20, 415.19it/s][A
 58%|█████▊    | 47278/80823 [01:44<01:24, 397.00it/s][A
 59%|█████▊    | 47320/80823 [01:44<01:23, 402.27it/s][A
 59%|█████▊    | 47362/80823 [01:44<01:22, 405.63it/s][A
 59%|█████▊    | 47403/80823 [01:44<01:22, 404.68it/s][A
 59%|█████▊    | 47447/80823 [01:44<01:20, 413.19it/s][A
 59%|█████▉    | 47489/80823 [01:44<01:21, 411.51it/s][A
 59%|█████▉    | 47534/80823 [01:44<01:19, 421.18it/s][A
 59%|█████▉    | 47577/80823 [01:44<01:18, 423.60it/s][A
 59%|█████▉    | 47624/80823 [01:44<01:16, 436.19it/s][A
 59%|█████▉    | 47673/80823 [01:44<01:14, 446.00it/s][A
 59%|█████▉    | 47718/80823 [01:45<01:15, 438.99it/s][A
 59%|█████▉    | 47763/80823 [01:45<01:16, 429.51it/s][A
 59%|█████▉   

 66%|██████▌   | 53189/80823 [01:58<01:09, 398.50it/s][A
 66%|██████▌   | 53229/80823 [01:58<01:09, 398.71it/s][A
 66%|██████▌   | 53272/80823 [01:58<01:08, 404.30it/s][A
 66%|██████▌   | 53313/80823 [01:58<01:07, 405.50it/s][A
 66%|██████▌   | 53356/80823 [01:58<01:07, 408.34it/s][A
 66%|██████▌   | 53398/80823 [01:58<01:06, 411.31it/s][A
 66%|██████▌   | 53440/80823 [01:58<01:08, 398.29it/s][A
 66%|██████▌   | 53483/80823 [01:58<01:07, 407.22it/s][A
 66%|██████▌   | 53524/80823 [01:58<01:07, 402.53it/s][A
 66%|██████▋   | 53565/80823 [01:59<01:09, 392.00it/s][A
 66%|██████▋   | 53613/80823 [01:59<01:06, 411.38it/s][A
 66%|██████▋   | 53655/80823 [01:59<01:07, 402.72it/s][A
 66%|██████▋   | 53697/80823 [01:59<01:06, 407.01it/s][A
 66%|██████▋   | 53738/80823 [01:59<01:07, 400.94it/s][A
 67%|██████▋   | 53780/80823 [01:59<01:06, 406.43it/s][A
 67%|██████▋   | 53822/80823 [01:59<01:06, 408.55it/s][A
 67%|██████▋   | 53863/80823 [01:59<01:06, 406.18it/s][A
 67%|██████▋  

 73%|███████▎  | 59248/80823 [02:12<00:51, 416.46it/s][A
 73%|███████▎  | 59290/80823 [02:12<00:51, 414.51it/s][A
 73%|███████▎  | 59334/80823 [02:13<00:51, 419.86it/s][A
 73%|███████▎  | 59377/80823 [02:13<00:51, 420.16it/s][A
 74%|███████▎  | 59420/80823 [02:13<00:51, 418.63it/s][A
 74%|███████▎  | 59462/80823 [02:13<00:51, 415.88it/s][A
 74%|███████▎  | 59508/80823 [02:13<00:50, 425.14it/s][A
 74%|███████▎  | 59551/80823 [02:13<00:51, 409.86it/s][A
 74%|███████▎  | 59596/80823 [02:13<00:50, 419.84it/s][A
 74%|███████▍  | 59640/80823 [02:13<00:50, 423.59it/s][A
 74%|███████▍  | 59683/80823 [02:13<00:50, 418.78it/s][A
 74%|███████▍  | 59725/80823 [02:14<00:50, 419.01it/s][A
 74%|███████▍  | 59767/80823 [02:14<00:50, 414.33it/s][A
 74%|███████▍  | 59811/80823 [02:14<00:49, 420.90it/s][A
 74%|███████▍  | 59856/80823 [02:14<00:48, 428.31it/s][A
 74%|███████▍  | 59899/80823 [02:14<00:49, 425.49it/s][A
 74%|███████▍  | 59945/80823 [02:14<00:48, 431.66it/s][A
 74%|███████▍ 

 81%|████████  | 65408/80823 [02:27<00:37, 410.29it/s][A
 81%|████████  | 65450/80823 [02:27<00:37, 410.08it/s][A
 81%|████████  | 65494/80823 [02:27<00:36, 416.49it/s][A
 81%|████████  | 65536/80823 [02:27<00:36, 413.18it/s][A
 81%|████████  | 65578/80823 [02:28<00:37, 411.10it/s][A
 81%|████████  | 65620/80823 [02:28<00:37, 405.99it/s][A
 81%|████████  | 65661/80823 [02:28<00:38, 391.91it/s][A
 81%|████████▏ | 65701/80823 [02:28<00:38, 389.09it/s][A
 81%|████████▏ | 65747/80823 [02:28<00:36, 407.73it/s][A
 81%|████████▏ | 65790/80823 [02:28<00:36, 412.79it/s][A
 81%|████████▏ | 65836/80823 [02:28<00:35, 425.23it/s][A
 82%|████████▏ | 65879/80823 [02:28<00:35, 426.15it/s][A
 82%|████████▏ | 65926/80823 [02:28<00:34, 438.03it/s][A
 82%|████████▏ | 65971/80823 [02:28<00:35, 424.08it/s][A
 82%|████████▏ | 66014/80823 [02:29<00:34, 424.45it/s][A
 82%|████████▏ | 66057/80823 [02:29<00:35, 421.26it/s][A
 82%|████████▏ | 66100/80823 [02:29<00:35, 413.36it/s][A
 82%|████████▏

 88%|████████▊ | 71412/80823 [02:42<00:21, 430.39it/s][A
 88%|████████▊ | 71456/80823 [02:42<00:21, 427.99it/s][A
 88%|████████▊ | 71501/80823 [02:42<00:21, 434.07it/s][A
 89%|████████▊ | 71545/80823 [02:42<00:21, 429.32it/s][A
 89%|████████▊ | 71589/80823 [02:42<00:22, 416.68it/s][A
 89%|████████▊ | 71632/80823 [02:42<00:22, 416.09it/s][A
 89%|████████▊ | 71676/80823 [02:42<00:21, 418.31it/s][A
 89%|████████▊ | 71718/80823 [02:43<00:22, 411.88it/s][A
 89%|████████▉ | 71761/80823 [02:43<00:21, 415.81it/s][A
 89%|████████▉ | 71803/80823 [02:43<00:21, 413.02it/s][A
 89%|████████▉ | 71845/80823 [02:43<00:21, 414.54it/s][A
 89%|████████▉ | 71887/80823 [02:43<00:21, 409.02it/s][A
 89%|████████▉ | 71928/80823 [02:43<00:22, 400.08it/s][A
 89%|████████▉ | 71971/80823 [02:43<00:21, 407.19it/s][A
 89%|████████▉ | 72012/80823 [02:43<00:22, 397.63it/s][A
 89%|████████▉ | 72052/80823 [02:43<00:22, 398.05it/s][A
 89%|████████▉ | 72092/80823 [02:44<00:22, 384.37it/s][A
 89%|████████▉

 96%|█████████▌| 77376/80823 [02:57<00:08, 398.74it/s][A
 96%|█████████▌| 77416/80823 [02:57<00:08, 390.59it/s][A
 96%|█████████▌| 77457/80823 [02:57<00:08, 395.87it/s][A
 96%|█████████▌| 77498/80823 [02:57<00:08, 397.69it/s][A
 96%|█████████▌| 77540/80823 [02:57<00:08, 403.69it/s][A
 96%|█████████▌| 77584/80823 [02:57<00:07, 411.78it/s][A
 96%|█████████▌| 77626/80823 [02:57<00:07, 412.76it/s][A
 96%|█████████▌| 77669/80823 [02:57<00:07, 415.71it/s][A
 96%|█████████▌| 77711/80823 [02:57<00:07, 410.42it/s][A
 96%|█████████▌| 77754/80823 [02:57<00:07, 413.93it/s][A
 96%|█████████▋| 77796/80823 [02:58<00:07, 410.87it/s][A
 96%|█████████▋| 77838/80823 [02:58<00:07, 411.38it/s][A
 96%|█████████▋| 77880/80823 [02:58<00:07, 405.93it/s][A
 96%|█████████▋| 77927/80823 [02:58<00:06, 420.39it/s][A
 96%|█████████▋| 77970/80823 [02:58<00:06, 409.84it/s][A
 97%|█████████▋| 78014/80823 [02:58<00:06, 417.82it/s][A
 97%|█████████▋| 78056/80823 [02:58<00:06, 411.15it/s][A
 97%|█████████

In [29]:
# Largest bond count over all featurized training molecules (displayed as the cell output).
max(graph['n_bond'] for graph in inputs_train)

424

In [30]:
# The multitarget model's output columns are labelled positionally with `props`.
# The true values were previously taken positionally from the CSV, so make that
# implicit assumption explicit: the target columns after 'smile' must follow `props`
# order (presumably the order the model was trained on).
assert list(smiles_test.set_index('smile').columns[:len(props)]) == props, (
    'smiles_test target columns are not in `props` order')

# Test-set predictions from the multitarget model, mapped back to the original target scale.
y_pred_test = y_scaler.inverse_transform(
    model.predict_generator(test_inputs_seq, verbose=True))

# Select targets by name; iterate over `props` rather than a hard-coded count.
y_true = smiles_test[props].values
mae_2d = pd.Series({prop: mae(y_true[:, i], y_pred_test[:, i]) for i, prop in enumerate(props)},
                   name='b3lyp_2D_mae_multitarget')

# Scale energies by 1000 for display (presumably eV -> meV, TODO confirm units);
# spectral_overlap keeps its native scale.
mae_2d = mae_2d * 1000
mae_2d['spectral_overlap'] /= 1000
mae_2d.round(1)



gap                           35.4
homo                          29.4
lumo                          29.2
spectral_overlap             149.2
homo_extrapolated             47.4
lumo_extrapolated             46.8
gap_extrapolated              56.3
optical_lumo_extrapolated     43.9
Name: b3lyp_2D_mae_multitarget, dtype: float64

In [10]:
# Multitarget test predictions as a labelled frame, one column per entry of `props`.
# NOTE(review): relies on y_pred_test still holding the multitarget predictions from the
# MAE cell above (whose execution count is higher than this cell's). The next cell
# overwrites y_pred_test, so re-running this cell after it would build the frame from
# the wrong array.
y_multitarget = pd.DataFrame(y_pred_test, columns=props)

In [11]:
# Single-target 2D models: one run directory per property, each with its own saved model,
# preprocessor and target scaler. Predictions are collected in `props` order.
# NOTE(review): this loop reassigns the globals preprocessor / y_scaler / model /
# inputs_test / test_inputs_seq / y_pred_test, replacing the multitarget objects loaded
# earlier -- re-run those cells before reusing them.
st_predictions = []

for prop in props:
    print(prop)
    model_name = 'b3lyp_2d_{}'.format(prop)

    # Saved Keras model for this property (deserialization warnings silenced).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = load_model('{}/best_model.hdf5'.format(model_name),
                           custom_objects=custom_layers)

    # Training-time preprocessor and target scaler stored next to the model.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted artifacts.
    with open('{}/preprocessor.p'.format(model_name), 'rb') as f:
        preprocessor = pickle.load(f)
    with open('{}/y_scaler.p'.format(model_name), 'rb') as f:
        y_scaler = pickle.load(f)

    # Featurize the test SMILES; unshuffled so rows stay aligned with smiles_test.
    inputs_test = preprocessor.predict(smiles_test.smile)
    test_inputs_seq = GraphSequence(inputs_test, shuffle=False, final_batch=True, batch_size=128)

    y_pred_test = y_scaler.inverse_transform(
        model.predict_generator(test_inputs_seq, verbose=True))
    st_predictions.append(y_pred_test)

  1%|          | 57/5000 [00:00<00:08, 559.86it/s]

gap


100%|██████████| 5000/5000 [00:08<00:00, 618.00it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

homo


100%|██████████| 5000/5000 [00:07<00:00, 635.54it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

lumo


100%|██████████| 5000/5000 [00:07<00:00, 676.61it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

spectral_overlap


100%|██████████| 5000/5000 [00:07<00:00, 678.37it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

homo_extrapolated


100%|██████████| 5000/5000 [00:07<00:00, 687.79it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

lumo_extrapolated


100%|██████████| 5000/5000 [00:07<00:00, 664.03it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

gap_extrapolated


100%|██████████| 5000/5000 [00:07<00:00, 664.58it/s]




  0%|          | 0/5000 [00:00<?, ?it/s]

optical_lumo_extrapolated


100%|██████████| 5000/5000 [00:07<00:00, 671.40it/s]




In [12]:
# Single-target predictions as a frame with one column per property.
y_2d_st = pd.DataFrame([i.flatten() for i in st_predictions], index=props).T

# Pair true and predicted values by column name instead of a hard-coded range(8),
# so the table stays correct if `props` changes.
mae_2d_st = pd.Series({prop: mae(smiles_test[prop].values, y_2d_st[prop].values) for prop in props},
                      name='b3lyp_2d_st')

# Scale energies by 1000 for display (presumably eV -> meV, TODO confirm units);
# spectral_overlap keeps its native scale.
mae_2d_st = mae_2d_st * 1000
mae_2d_st['spectral_overlap'] /= 1000
mae_2d_st.round(1)

gap                           36.9
homo                          32.1
lumo                          27.9
spectral_overlap             149.3
homo_extrapolated             49.1
lumo_extrapolated             47.8
gap_extrapolated              57.1
optical_lumo_extrapolated     47.8
Name: b3lyp_2d_st, dtype: float64

In [13]:
def rbf_expansion(distances, mu=0, delta=0.2, kmax=150):
    """Expand distances onto a grid of Gaussian radial basis functions.

    The ``kmax`` centres are ``delta * k - mu`` for k = 0 .. kmax-1 (note the
    ``-mu`` offset), and each basis value is exp(-(d - centre)**2 / delta).

    Returns one row per distance and one column per centre; a 1-D
    ``distances`` of length n yields an array of shape (n, kmax).
    """
    centers = delta * np.arange(0, kmax) - mu
    column = np.atleast_2d(distances).T
    return np.exp(-(column - centers) ** 2 / delta)

def precalc_rbfs(inputs):
    """Replace each item's 'distance' entry with its RBF expansion, in place.

    Every dict in ``inputs`` gains a 'distance_rbf' key computed by
    ``rbf_expansion`` and loses its 'distance' key. The same (mutated)
    collection is returned for convenience.
    """
    for graph in tqdm(inputs):
        graph['distance_rbf'] = rbf_expansion(graph['distance'])
        graph.pop('distance')
    return inputs

In [14]:
from rdkit.Chem import MolFromMolBlock
from nfp.preprocessing import RobustNanScaler

# SchNet (3D) single-target models evaluated on the test geometries in mol_test.mol.
schnet_predictions = []
for prop in props:
    print(prop)

    model_name = 'b3lyp_schnet2_{}'.format(prop)

    # NOTE: pickle.load can execute arbitrary code -- only load trusted artifacts.
    with open('{}/schnet_preprocessor.p'.format(model_name), 'rb') as f:
        schnet_preprocessor = pickle.load(f)

    # Iterate the Series values directly: Series.iteritems() was removed in pandas 2.0.
    inputs_test = schnet_preprocessor.predict(MolFromMolBlock(mol) for mol in mol_test.mol)
    # Precompute RBF distance features (mutates the graph dicts in place).
    rbf_inputs_test = precalc_rbfs(list(inputs_test))
    # Unshuffled, with final_batch set explicitly as in the 2D-model cells, so prediction
    # rows stay aligned with mol_test.
    rbf_input_seq = GraphSequence(rbf_inputs_test, shuffle=False, final_batch=True, batch_size=32)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = load_model('{}/best_model.hdf5'.format(model_name),
                           custom_objects=custom_layers)

    # Rebuild the target scaler from the non-missing training values of this property.
    # NOTE(review): no pickled y_scaler is loaded for the SchNet models; this assumes the
    # refit scaler matches the one used at training time -- confirm.
    itrain = mol_train[mol_train[prop].notna()]

    # Rescale Y matrix
    y_train_raw = itrain[[prop]].values

    y_scaler = RobustNanScaler()
    y_scaler.fit(y_train_raw)

    y_pred = model.predict_generator(rbf_input_seq, verbose=True)
    schnet_predictions += [y_scaler.inverse_transform(y_pred)]

31it [00:00, 299.56it/s]

gap


5000it [00:16, 310.85it/s]
100%|██████████| 5000/5000 [00:16<00:00, 297.04it/s]


homo


5000it [00:16, 307.66it/s]
100%|██████████| 5000/5000 [00:17<00:00, 288.46it/s]


lumo


5000it [00:16, 301.84it/s]
100%|██████████| 5000/5000 [00:15<00:00, 322.52it/s]




0it [00:00, ?it/s]

spectral_overlap


5000it [00:16, 289.22it/s]
100%|██████████| 5000/5000 [00:15<00:00, 322.76it/s]




0it [00:00, ?it/s]

homo_extrapolated


5000it [00:16, 303.89it/s]
100%|██████████| 5000/5000 [00:16<00:00, 307.87it/s]




0it [00:00, ?it/s]

lumo_extrapolated


5000it [00:17, 293.21it/s]
100%|██████████| 5000/5000 [00:17<00:00, 289.13it/s]


gap_extrapolated


5000it [00:16, 297.04it/s]
100%|██████████| 5000/5000 [00:16<00:00, 304.49it/s]




0it [00:00, ?it/s]

optical_lumo_extrapolated


5000it [00:16, 306.87it/s]
100%|██████████| 5000/5000 [00:16<00:00, 301.61it/s]




In [15]:
# SchNet predictions as a frame with one column per property (rows follow mol_test order).
y_schnet = pd.DataFrame([pred.flatten() for pred in schnet_predictions], index=props).transpose()

In [16]:
# Pair true and predicted values by column name instead of a hard-coded range(8),
# so the table stays correct if `props` changes.
mae_schnet = pd.Series({prop: mae(mol_test[prop].values, y_schnet[prop].values) for prop in props},
                       name='b3lyp_schnet')
# Scale energies by 1000 for display (presumably eV -> meV, TODO confirm units);
# spectral_overlap keeps its native scale.
mae_schnet = mae_schnet * 1000
mae_schnet['spectral_overlap'] /= 1000
mae_schnet.round(1)

gap                          32.7
homo                         27.0
lumo                         24.8
spectral_overlap             96.6
homo_extrapolated            56.9
lumo_extrapolated            56.8
gap_extrapolated             69.8
optical_lumo_extrapolated    57.2
Name: b3lyp_schnet, dtype: float64

# Predictions using UFF-reoptimized molecules

Test molecules without a UFF geometry (missing `molUFF`) are dropped before prediction.

In [20]:
# UFF-reoptimized test set: mol_test was loaded from this same file
# ('data/mol_test_uff.csv.gz'), so reuse it instead of re-reading the CSV.
# Keep only rows that have a UFF geometry; .copy() makes the result an independent frame.
mol_test_uff = mol_test[mol_test['molUFF'].notna()].copy()

In [21]:
# SchNet (3D) single-target models evaluated on the UFF-reoptimized test geometries.
schnet_predictions_uff = []
for prop in props:
    print(prop)

    model_name = 'b3lyp_schnet2_uff_{}'.format(prop)

    # NOTE: pickle.load can execute arbitrary code -- only load trusted artifacts.
    with open('{}/schnet_preprocessor.p'.format(model_name), 'rb') as f:
        schnet_preprocessor = pickle.load(f)

    # Iterate the Series values directly: Series.iteritems() was removed in pandas 2.0.
    inputs_test = schnet_preprocessor.predict(MolFromMolBlock(mol) for mol in mol_test_uff.molUFF)
    # Precompute RBF distance features (mutates the graph dicts in place).
    rbf_inputs_test = precalc_rbfs(list(inputs_test))
    # Unshuffled, with final_batch set explicitly as in the 2D-model cells, so prediction
    # rows stay aligned with mol_test_uff.
    rbf_input_seq = GraphSequence(rbf_inputs_test, shuffle=False, final_batch=True, batch_size=32)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = load_model('{}/best_model.hdf5'.format(model_name),
                           custom_objects=custom_layers)

    # Rebuild the target scaler from the non-missing training values of this property.
    # NOTE(review): no pickled y_scaler is loaded for these models; this assumes the
    # refit scaler matches the one used at training time -- confirm.
    itrain = mol_train[mol_train[prop].notna()]

    # Rescale Y matrix
    y_train_raw = itrain[[prop]].values

    y_scaler = RobustNanScaler()
    y_scaler.fit(y_train_raw)

    y_pred = model.predict_generator(rbf_input_seq, verbose=True)
    schnet_predictions_uff += [y_scaler.inverse_transform(y_pred)]


0it [00:00, ?it/s][A
30it [00:00, 293.77it/s][A

gap



59it [00:00, 292.34it/s][A
90it [00:00, 295.20it/s][A
119it [00:00, 291.51it/s][A
145it [00:00, 279.34it/s][A
177it [00:00, 288.98it/s][A
205it [00:00, 286.10it/s][A
232it [00:00, 273.22it/s][A
260it [00:00, 275.15it/s][A
289it [00:01, 279.08it/s][A
318it [00:01, 280.98it/s][A
349it [00:01, 287.52it/s][A
382it [00:01, 297.85it/s][A
412it [00:01, 295.58it/s][A
443it [00:01, 295.37it/s][A
473it [00:01, 285.31it/s][A
502it [00:01, 284.90it/s][A
531it [00:01, 283.25it/s][A
562it [00:01, 290.16it/s][A
592it [00:02, 290.27it/s][A
622it [00:02, 284.70it/s][A
651it [00:02, 282.44it/s][A
681it [00:02, 285.71it/s][A
711it [00:02, 288.50it/s][A
742it [00:02, 292.42it/s][A
772it [00:02, 291.10it/s][A
802it [00:02, 293.64it/s][A
833it [00:02, 296.26it/s][A
863it [00:02, 297.06it/s][A
893it [00:03, 293.82it/s][A
923it [00:03, 283.23it/s][A
952it [00:03, 278.50it/s][A
984it [00:03, 287.21it/s][A
1013it [00:03, 287.16it/s][A
1044it [00:03, 288.49it/s][A
1077it [00:03

 41%|████      | 2046/4996 [00:06<00:08, 336.46it/s][A
 42%|████▏     | 2083/4996 [00:06<00:08, 345.42it/s][A
 42%|████▏     | 2119/4996 [00:06<00:09, 297.99it/s][A
 43%|████▎     | 2151/4996 [00:06<00:09, 295.49it/s][A
 44%|████▎     | 2182/4996 [00:07<00:09, 294.76it/s][A
 44%|████▍     | 2213/4996 [00:07<00:09, 287.23it/s][A
 45%|████▌     | 2250/4996 [00:07<00:08, 306.09it/s][A
 46%|████▌     | 2287/4996 [00:07<00:08, 321.57it/s][A
 47%|████▋     | 2328/4996 [00:07<00:07, 342.20it/s][A
 47%|████▋     | 2364/4996 [00:07<00:08, 327.41it/s][A
 48%|████▊     | 2398/4996 [00:07<00:08, 308.98it/s][A
 49%|████▊     | 2430/4996 [00:07<00:08, 290.10it/s][A
 49%|████▉     | 2460/4996 [00:08<00:09, 262.15it/s][A
 50%|████▉     | 2492/4996 [00:08<00:09, 275.45it/s][A
 51%|█████     | 2526/4996 [00:08<00:08, 290.63it/s][A
 51%|█████▏    | 2565/4996 [00:08<00:07, 312.14it/s][A
 52%|█████▏    | 2598/4996 [00:08<00:07, 314.68it/s][A
 53%|█████▎    | 2631/4996 [00:08<00:07, 311.35i




0it [00:00, ?it/s][A
17it [00:00, 168.17it/s][A

homo



45it [00:00, 189.76it/s][A
76it [00:00, 213.82it/s][A
104it [00:00, 229.21it/s][A
134it [00:00, 243.10it/s][A
164it [00:00, 255.79it/s][A
194it [00:00, 267.14it/s][A
222it [00:00, 270.12it/s][A
250it [00:00, 269.41it/s][A
280it [00:01, 277.27it/s][A
310it [00:01, 283.55it/s][A
341it [00:01, 288.61it/s][A
376it [00:01, 302.25it/s][A
407it [00:01, 297.13it/s][A
437it [00:01, 294.99it/s][A
467it [00:01, 293.73it/s][A
497it [00:01, 292.82it/s][A
527it [00:01, 288.94it/s][A
560it [00:01, 297.67it/s][A
590it [00:02, 295.49it/s][A
620it [00:02, 293.38it/s][A
650it [00:02, 288.63it/s][A
680it [00:02, 290.74it/s][A
712it [00:02, 297.46it/s][A
742it [00:02, 296.72it/s][A
772it [00:02, 294.90it/s][A
802it [00:02, 292.26it/s][A
833it [00:02, 297.24it/s][A
863it [00:02, 294.37it/s][A
893it [00:03, 293.75it/s][A
923it [00:03, 285.44it/s][A
952it [00:03, 282.27it/s][A
984it [00:03, 291.93it/s][A
1014it [00:03, 291.13it/s][A
1045it [00:03, 294.95it/s][A
1080it [00:03

 42%|████▏     | 2082/4996 [00:06<00:07, 369.08it/s][A
 42%|████▏     | 2122/4996 [00:06<00:07, 376.74it/s][A
 43%|████▎     | 2161/4996 [00:06<00:07, 372.72it/s][A
 44%|████▍     | 2199/4996 [00:07<00:07, 373.45it/s][A
 45%|████▍     | 2237/4996 [00:07<00:07, 370.40it/s][A
 46%|████▌     | 2275/4996 [00:07<00:07, 372.16it/s][A
 46%|████▋     | 2316/4996 [00:07<00:07, 382.05it/s][A
 47%|████▋     | 2355/4996 [00:07<00:07, 363.84it/s][A
 48%|████▊     | 2392/4996 [00:07<00:07, 352.87it/s][A
 49%|████▊     | 2428/4996 [00:07<00:07, 346.89it/s][A
 49%|████▉     | 2463/4996 [00:07<00:07, 345.79it/s][A
 50%|█████     | 2498/4996 [00:07<00:07, 344.10it/s][A
 51%|█████     | 2534/4996 [00:07<00:07, 348.67it/s][A
 51%|█████▏    | 2569/4996 [00:08<00:08, 303.20it/s][A
 52%|█████▏    | 2602/4996 [00:08<00:07, 309.75it/s][A
 53%|█████▎    | 2634/4996 [00:08<00:08, 290.20it/s][A
 54%|█████▎    | 2673/4996 [00:08<00:07, 312.53it/s][A
 54%|█████▍    | 2711/4996 [00:08<00:06, 329.33i

lumo



0it [00:00, ?it/s][A
27it [00:00, 265.84it/s][A
54it [00:00, 266.65it/s][A
86it [00:00, 277.69it/s][A
114it [00:00, 276.56it/s][A
141it [00:00, 274.36it/s][A
172it [00:00, 281.89it/s][A
202it [00:00, 287.09it/s][A
229it [00:00, 271.78it/s][A
256it [00:00, 271.04it/s][A
287it [00:01, 281.00it/s][A
315it [00:01, 279.85it/s][A
347it [00:01, 289.28it/s][A
380it [00:01, 299.23it/s][A
410it [00:01, 298.11it/s][A
440it [00:01, 297.03it/s][A
470it [00:01, 288.01it/s][A
500it [00:01, 288.31it/s][A
529it [00:01, 284.30it/s][A
562it [00:01, 293.75it/s][A
592it [00:02, 293.13it/s][A
622it [00:02, 287.49it/s][A
651it [00:02, 285.55it/s][A
681it [00:02, 289.31it/s][A
712it [00:02, 293.58it/s][A
742it [00:02, 294.51it/s][A
772it [00:02, 291.74it/s][A
803it [00:02, 294.41it/s][A
834it [00:02, 296.24it/s][A
864it [00:02, 296.59it/s][A
894it [00:03, 293.98it/s][A
924it [00:03, 283.35it/s][A
953it [00:03, 279.96it/s][A
984it [00:03, 287.83it/s][A
1013it [00:03, 286.16it

 39%|███▉      | 1948/4996 [00:06<00:09, 314.10it/s][A
 40%|███▉      | 1987/4996 [00:06<00:09, 331.93it/s][A
 41%|████      | 2027/4996 [00:06<00:08, 348.80it/s][A
 41%|████▏     | 2068/4996 [00:06<00:08, 363.50it/s][A
 42%|████▏     | 2106/4996 [00:07<00:08, 337.35it/s][A
 43%|████▎     | 2141/4996 [00:07<00:08, 334.20it/s][A
 44%|████▎     | 2178/4996 [00:07<00:08, 343.28it/s][A
 44%|████▍     | 2213/4996 [00:07<00:08, 342.36it/s][A
 45%|████▌     | 2250/4996 [00:07<00:07, 349.71it/s][A
 46%|████▌     | 2288/4996 [00:07<00:07, 357.90it/s][A
 47%|████▋     | 2325/4996 [00:07<00:07, 348.97it/s][A
 47%|████▋     | 2361/4996 [00:07<00:08, 313.11it/s][A
 48%|████▊     | 2394/4996 [00:07<00:08, 305.23it/s][A
 49%|████▊     | 2427/4996 [00:08<00:08, 309.08it/s][A
 49%|████▉     | 2460/4996 [00:08<00:08, 313.83it/s][A
 50%|█████     | 2498/4996 [00:08<00:07, 329.15it/s][A
 51%|█████     | 2535/4996 [00:08<00:07, 339.96it/s][A
 51%|█████▏    | 2572/4996 [00:08<00:06, 347.32i




0it [00:00, ?it/s][A

spectral_overlap



29it [00:00, 282.55it/s][A
58it [00:00, 283.16it/s][A
90it [00:00, 290.56it/s][A
119it [00:00, 289.71it/s][A
146it [00:00, 281.84it/s][A
178it [00:00, 290.11it/s][A
208it [00:00, 290.01it/s][A
235it [00:00, 281.10it/s][A
264it [00:00, 282.25it/s][A
295it [00:01, 288.59it/s][A
324it [00:01, 288.40it/s][A
358it [00:01, 301.18it/s][A
389it [00:01, 302.77it/s][A
421it [00:01, 306.09it/s][A
452it [00:01, 296.97it/s][A
482it [00:01, 291.81it/s][A
512it [00:01, 288.84it/s][A
545it [00:01, 298.73it/s][A
575it [00:01, 298.48it/s][A
606it [00:02, 300.17it/s][A
637it [00:02, 291.20it/s][A
667it [00:02, 290.08it/s][A
697it [00:02, 292.06it/s][A
728it [00:02, 295.76it/s][A
759it [00:02, 299.77it/s][A
790it [00:02, 295.70it/s][A
820it [00:02, 295.51it/s][A
852it [00:02, 302.43it/s][A
883it [00:03, 294.42it/s][A
913it [00:03, 286.49it/s][A
942it [00:03, 286.44it/s][A
973it [00:03, 289.72it/s][A
1003it [00:03, 292.44it/s][A
1033it [00:03, 291.80it/s][A
1064it [00:03,

 40%|███▉      | 1990/4996 [00:06<00:09, 303.20it/s][A
 41%|████      | 2032/4996 [00:06<00:09, 328.96it/s][A
 41%|████▏     | 2072/4996 [00:07<00:08, 344.93it/s][A
 42%|████▏     | 2108/4996 [00:07<00:08, 346.80it/s][A
 43%|████▎     | 2145/4996 [00:07<00:08, 351.16it/s][A
 44%|████▎     | 2181/4996 [00:07<00:08, 346.11it/s][A
 44%|████▍     | 2220/4996 [00:07<00:07, 356.29it/s][A
 45%|████▌     | 2257/4996 [00:07<00:07, 360.11it/s][A
 46%|████▌     | 2297/4996 [00:07<00:07, 369.78it/s][A
 47%|████▋     | 2335/4996 [00:07<00:07, 359.40it/s][A
 47%|████▋     | 2372/4996 [00:07<00:07, 342.18it/s][A
 48%|████▊     | 2407/4996 [00:08<00:08, 293.45it/s][A
 49%|████▉     | 2438/4996 [00:08<00:08, 297.73it/s][A
 50%|████▉     | 2474/4996 [00:08<00:08, 312.27it/s][A
 50%|█████     | 2510/4996 [00:08<00:07, 324.57it/s][A
 51%|█████     | 2546/4996 [00:08<00:07, 332.38it/s][A
 52%|█████▏    | 2580/4996 [00:08<00:08, 289.72it/s][A
 52%|█████▏    | 2611/4996 [00:08<00:08, 280.02i




0it [00:00, ?it/s][A

homo_extrapolated



22it [00:00, 216.03it/s][A
49it [00:00, 228.51it/s][A
81it [00:00, 247.85it/s][A
109it [00:00, 255.59it/s][A
137it [00:00, 262.31it/s][A
166it [00:00, 268.75it/s][A
197it [00:00, 279.82it/s][A
224it [00:00, 272.48it/s][A
251it [00:00, 269.28it/s][A
282it [00:01, 278.37it/s][A
312it [00:01, 282.59it/s][A
342it [00:01, 286.70it/s][A
377it [00:01, 301.48it/s][A
408it [00:01, 298.17it/s][A
438it [00:01, 295.58it/s][A
468it [00:01, 291.92it/s][A
498it [00:01, 290.76it/s][A
528it [00:01, 287.09it/s][A
560it [00:01, 296.07it/s][A
590it [00:02, 293.32it/s][A
620it [00:02, 292.23it/s][A
650it [00:02, 287.78it/s][A
680it [00:02, 289.70it/s][A
711it [00:02, 294.27it/s][A
742it [00:02, 296.44it/s][A
772it [00:02, 294.28it/s][A
803it [00:02, 297.09it/s][A
834it [00:02, 299.08it/s][A
865it [00:02, 299.09it/s][A
895it [00:03, 297.40it/s][A
925it [00:03, 285.10it/s][A
954it [00:03, 283.14it/s][A
986it [00:03, 292.98it/s][A
1016it [00:03, 291.48it/s][A
1046it [00:03, 

 40%|███▉      | 1976/4996 [00:06<00:09, 335.13it/s][A
 40%|████      | 2016/4996 [00:06<00:08, 350.53it/s][A
 41%|████      | 2057/4996 [00:07<00:08, 364.98it/s][A
 42%|████▏     | 2095/4996 [00:07<00:08, 340.05it/s][A
 43%|████▎     | 2130/4996 [00:07<00:08, 320.56it/s][A
 43%|████▎     | 2163/4996 [00:07<00:09, 289.56it/s][A
 44%|████▍     | 2199/4996 [00:07<00:09, 307.56it/s][A
 45%|████▍     | 2235/4996 [00:07<00:08, 320.22it/s][A
 45%|████▌     | 2272/4996 [00:07<00:08, 332.76it/s][A
 46%|████▋     | 2312/4996 [00:07<00:07, 349.95it/s][A
 47%|████▋     | 2350/4996 [00:07<00:07, 357.32it/s][A
 48%|████▊     | 2387/4996 [00:08<00:07, 335.65it/s][A
 48%|████▊     | 2422/4996 [00:08<00:08, 316.50it/s][A
 49%|████▉     | 2456/4996 [00:08<00:07, 322.44it/s][A
 50%|████▉     | 2492/4996 [00:08<00:07, 332.05it/s][A
 51%|█████     | 2528/4996 [00:08<00:07, 337.80it/s][A
 51%|█████▏    | 2565/4996 [00:08<00:07, 345.33it/s][A
 52%|█████▏    | 2602/4996 [00:08<00:06, 351.71i

lumo_extrapolated



0it [00:00, ?it/s][A
26it [00:00, 258.98it/s][A
55it [00:00, 267.05it/s][A
86it [00:00, 277.38it/s][A
115it [00:00, 277.51it/s][A
143it [00:00, 277.76it/s][A
174it [00:00, 285.05it/s][A
204it [00:00, 287.87it/s][A
231it [00:00, 274.91it/s][A
261it [00:00, 278.83it/s][A
292it [00:01, 284.81it/s][A
321it [00:01, 285.27it/s][A
354it [00:01, 295.58it/s][A
386it [00:01, 302.07it/s][A
417it [00:01, 302.78it/s][A
448it [00:01, 297.85it/s][A
478it [00:01, 289.24it/s][A
508it [00:01, 289.55it/s][A
537it [00:01, 287.06it/s][A
571it [00:01, 298.23it/s][A
601it [00:02, 297.95it/s][A
631it [00:02, 295.18it/s][A
661it [00:02, 290.24it/s][A
692it [00:02, 294.54it/s][A
723it [00:02, 298.00it/s][A
754it [00:02, 301.24it/s][A
785it [00:02, 295.58it/s][A
816it [00:02, 297.66it/s][A
848it [00:02, 303.06it/s][A
879it [00:03, 297.14it/s][A
909it [00:03, 288.90it/s][A
938it [00:03, 287.79it/s][A
968it [00:03, 289.91it/s][A
999it [00:03, 294.30it/s][A
1029it [00:03, 292.72it

 37%|███▋      | 1842/4996 [00:06<00:09, 320.00it/s][A
 38%|███▊      | 1879/4996 [00:06<00:09, 332.23it/s][A
 38%|███▊      | 1919/4996 [00:06<00:08, 348.90it/s][A
 39%|███▉      | 1955/4996 [00:06<00:09, 304.63it/s][A
 40%|███▉      | 1987/4996 [00:06<00:10, 287.58it/s][A
 41%|████      | 2028/4996 [00:06<00:09, 314.99it/s][A
 41%|████▏     | 2068/4996 [00:07<00:08, 335.66it/s][A
 42%|████▏     | 2104/4996 [00:07<00:08, 342.56it/s][A
 43%|████▎     | 2141/4996 [00:07<00:08, 349.42it/s][A
 44%|████▎     | 2179/4996 [00:07<00:08, 347.94it/s][A
 44%|████▍     | 2215/4996 [00:07<00:08, 329.05it/s][A
 45%|████▌     | 2251/4996 [00:07<00:08, 335.57it/s][A
 46%|████▌     | 2290/4996 [00:07<00:07, 348.94it/s][A
 47%|████▋     | 2330/4996 [00:07<00:07, 361.39it/s][A
 47%|████▋     | 2370/4996 [00:07<00:07, 371.98it/s][A
 48%|████▊     | 2408/4996 [00:08<00:07, 330.00it/s][A
 49%|████▉     | 2443/4996 [00:08<00:08, 314.53it/s][A
 50%|████▉     | 2481/4996 [00:08<00:07, 330.50i




0it [00:00, ?it/s][A

gap_extrapolated



27it [00:00, 260.56it/s][A
57it [00:00, 270.47it/s][A
88it [00:00, 279.19it/s][A
117it [00:00, 281.67it/s][A
145it [00:00, 277.49it/s][A
178it [00:00, 287.64it/s][A
208it [00:00, 287.55it/s][A
235it [00:00, 277.92it/s][A
264it [00:00, 278.58it/s][A
295it [00:01, 285.01it/s][A
324it [00:01, 284.90it/s][A
357it [00:01, 295.46it/s][A
388it [00:01, 299.62it/s][A
419it [00:01, 302.19it/s][A
450it [00:01, 294.24it/s][A
480it [00:01, 281.34it/s][A
509it [00:01, 282.55it/s][A
538it [00:01, 283.87it/s][A
570it [00:01, 293.43it/s][A
600it [00:02, 293.33it/s][A
630it [00:02, 289.67it/s][A
660it [00:02, 284.73it/s][A
691it [00:02, 289.83it/s][A
722it [00:02, 292.88it/s][A
752it [00:02, 293.24it/s][A
782it [00:02, 292.93it/s][A
813it [00:02, 297.41it/s][A
844it [00:02, 297.09it/s][A
874it [00:03, 293.80it/s][A
904it [00:03, 287.48it/s][A
933it [00:03, 287.84it/s][A
962it [00:03, 280.92it/s][A
994it [00:03, 289.68it/s][A
1024it [00:03, 291.82it/s][A
1054it [00:03, 

 41%|████      | 2034/4996 [00:06<00:09, 302.75it/s][A
 41%|████▏     | 2066/4996 [00:07<00:15, 188.54it/s][A
 42%|████▏     | 2102/4996 [00:07<00:13, 219.70it/s][A
 43%|████▎     | 2131/4996 [00:07<00:12, 222.66it/s][A
 43%|████▎     | 2160/4996 [00:07<00:11, 239.26it/s][A
 44%|████▍     | 2194/4996 [00:07<00:10, 261.33it/s][A
 45%|████▍     | 2229/4996 [00:07<00:09, 282.27it/s][A
 45%|████▌     | 2264/4996 [00:07<00:09, 299.40it/s][A
 46%|████▌     | 2297/4996 [00:07<00:09, 281.64it/s][A
 47%|████▋     | 2331/4996 [00:07<00:08, 296.62it/s][A
 47%|████▋     | 2363/4996 [00:08<00:09, 271.50it/s][A
 48%|████▊     | 2399/4996 [00:08<00:08, 292.38it/s][A
 49%|████▉     | 2436/4996 [00:08<00:08, 311.67it/s][A
 49%|████▉     | 2469/4996 [00:08<00:08, 314.54it/s][A
 50%|█████     | 2502/4996 [00:08<00:08, 298.71it/s][A
 51%|█████     | 2533/4996 [00:08<00:08, 277.84it/s][A
 51%|█████▏    | 2562/4996 [00:08<00:09, 252.12it/s][A
 52%|█████▏    | 2596/4996 [00:08<00:08, 272.09i




0it [00:00, ?it/s][A

optical_lumo_extrapolated



25it [00:00, 244.80it/s][A
54it [00:00, 253.45it/s][A
86it [00:00, 267.96it/s][A
114it [00:00, 270.71it/s][A
142it [00:00, 271.04it/s][A
173it [00:00, 280.44it/s][A
202it [00:00, 283.18it/s][A
229it [00:00, 267.79it/s][A
256it [00:00, 267.23it/s][A
287it [00:01, 277.44it/s][A
315it [00:01, 276.27it/s][A
347it [00:01, 287.21it/s][A
380it [00:01, 297.41it/s][A
410it [00:01, 295.63it/s][A
440it [00:01, 294.71it/s][A
470it [00:01, 286.70it/s][A
499it [00:01, 287.07it/s][A
528it [00:01, 283.96it/s][A
560it [00:01, 292.07it/s][A
590it [00:02, 288.52it/s][A
619it [00:02, 286.75it/s][A
648it [00:02, 282.01it/s][A
677it [00:02, 282.50it/s][A
708it [00:02, 288.43it/s][A
737it [00:02, 287.91it/s][A
767it [00:02, 290.71it/s][A
797it [00:02, 289.12it/s][A
827it [00:02, 291.42it/s][A
857it [00:02, 293.37it/s][A
887it [00:03, 287.71it/s][A
916it [00:03, 277.03it/s][A
945it [00:03, 279.90it/s][A
975it [00:03, 285.16it/s][A
1005it [00:03, 286.17it/s][A
1034it [00:03, 

 39%|███▉      | 1943/4996 [00:06<00:09, 309.76it/s][A
 40%|███▉      | 1976/4996 [00:06<00:10, 297.88it/s][A
 40%|████      | 2018/4996 [00:06<00:09, 325.54it/s][A
 41%|████      | 2059/4996 [00:06<00:08, 335.93it/s][A
 42%|████▏     | 2095/4996 [00:07<00:08, 332.19it/s][A
 43%|████▎     | 2134/4996 [00:07<00:08, 347.18it/s][A
 43%|████▎     | 2173/4996 [00:07<00:07, 356.78it/s][A
 44%|████▍     | 2210/4996 [00:07<00:07, 351.45it/s][A
 45%|████▍     | 2246/4996 [00:07<00:07, 343.78it/s][A
 46%|████▌     | 2281/4996 [00:07<00:08, 318.36it/s][A
 46%|████▋     | 2321/4996 [00:07<00:07, 338.89it/s][A
 47%|████▋     | 2359/4996 [00:07<00:07, 350.20it/s][A
 48%|████▊     | 2398/4996 [00:07<00:07, 360.77it/s][A
 49%|████▊     | 2435/4996 [00:08<00:07, 326.26it/s][A
 49%|████▉     | 2469/4996 [00:08<00:08, 299.29it/s][A
 50%|█████     | 2501/4996 [00:08<00:08, 292.78it/s][A
 51%|█████     | 2532/4996 [00:08<00:08, 295.14it/s][A
 51%|█████▏    | 2568/4996 [00:08<00:07, 310.82i



In [22]:
# Tabulate the SchNet (UFF geometry) predictions: `schnet_predictions_uff`
# holds one array per target in `props` (same order). Each array is flattened
# and becomes a row; the transpose then gives one column per target, with rows
# indexing the flattened prediction entries (presumably one per test molecule
# — TODO confirm).
y_schnet_uff = pd.DataFrame(
    [target_pred.flatten() for target_pred in schnet_predictions_uff],
    index=props,
).transpose()

In [31]:
# Per-target mean absolute error of the SchNet (UFF geometry) model on the
# test set. `mae` (defined earlier) ignores NaN targets via np.nanmean.
# NOTE(review): `mol_test_uff` is not created in any visible cell. The UFF test
# set is loaded as `mol_test` near the top, so confirm this name exists on a
# fresh kernel (Restart & Run All).
assert len(schnet_predictions_uff) == len(props), \
    "expected one prediction array per target in `props`"

mae_schnet_uff = pd.Series(
    {prop: mae(mol_test_uff[prop].values, pred.flatten())
     for prop, pred in zip(props, schnet_predictions_uff)},
    name='b3lyp_schnet_uff')

# Report the energy-like targets x1000 (presumably eV -> meV; confirm units).
# spectral_overlap is left unscaled.
energy_props = [prop for prop in props if prop != 'spectral_overlap']
mae_schnet_uff.loc[energy_props] *= 1000
mae_schnet_uff.round(1)

gap                           45.1
homo                          33.1
lumo                          31.9
spectral_overlap             170.0
homo_extrapolated             64.8
lumo_extrapolated             63.0
gap_extrapolated              74.3
optical_lumo_extrapolated     60.2
Name: b3lyp_schnet_uff, dtype: float64

In [25]:
# Gather each model's prediction frame under a short label. Insertion order is
# preserved, so the labels appear in the order listed below.
model_predictions = dict([
    ('2d_multitarget', y_multitarget),
    ('2d_singletarget', y_2d_st),
    ('schnet', y_schnet),
    ('schnet_uff', y_schnet_uff),
])

# Persist all prediction frames to a single pickle file.
# NOTE: only unpickle this file from trusted sources (pickle can execute code).
with open('model_predictions.p', 'wb') as out_file:
    pickle.dump(model_predictions, out_file)