In [1]:
import tensorflow as tf
import numpy as np
import os
import logging
import string
import random
import yaml
from datetime import datetime
from tqdm.notebook import tqdm

from dimenet.model.dimenet import DimeNet
from dimenet.model.dimenet_pp import DimeNetPP
from dimenet.model.activations import swish
from dimenet.training.trainer import Trainer
from dimenet.training.metrics import Metrics
from dimenet.training.data_container import DataContainer
from dimenet.training.data_provider import DataProvider

In [2]:
# Configure the root logger: a single console handler that prints
# "<timestamp> (<LEVEL>): <message>" for records at INFO and above.
LOG_FORMAT = '%(asctime)s (%(levelname)s): %(message)s'
LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logger = logging.getLogger()
# Replace whatever handlers are already attached (e.g. from an earlier run
# of this cell) so each message is emitted only once.
logger.handlers = []
logger.setLevel('INFO')

formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

# Reduce log output from TensorFlow's native (C++) runtime.
# NOTE(review): TensorFlow has already been imported when this runs; confirm
# the variable is still honoured when it is set after the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

# TensorFlow's Python logger: level 'WARN'.
tf_logger = tf.get_logger()
tf_logger.setLevel('WARN')

# Verbosity of AutoGraph's own logging.
tf.autograph.set_verbosity(2)

### Load config

In [3]:
# Path of the YAML file holding the model and data settings read in the
# following cells (use config.yaml for DimeNet, config_pp.yaml for DimeNet++).
config_path = 'config_dye.yaml'

# Decode explicitly as UTF-8 instead of relying on the platform's default
# encoding, so the file is read the same way on every machine.
with open(config_path, 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# An empty or non-mapping YAML document would otherwise only fail later with
# an unhelpful TypeError on the first config[...] lookup.
if not isinstance(config, dict):
    raise ValueError(
        f"Expected '{config_path}' to contain a YAML mapping, "
        f"got {type(config).__name__}")

In [4]:
# The architecture named in the config decides which extra settings are read.
model_name = config['model_name']

# Reject unsupported architectures before reading anything else.
if model_name not in ("dimenet", "dimenet++"):
    raise ValueError(f"Unknown model name: '{model_name}'")

if model_name == "dimenet++":
    # Settings that only DimeNet++ takes
    out_emb_size = config['out_emb_size']
    int_emb_size = config['int_emb_size']
    basis_emb_size = config['basis_emb_size']
    extensive = config['extensive']
else:
    # Setting that only the original DimeNet takes
    num_bilinear = config['num_bilinear']

# Settings shared by both architectures
emb_size, num_blocks = config['emb_size'], config['num_blocks']
num_spherical, num_radial = config['num_spherical'], config['num_radial']
output_init = config['output_init']
cutoff, envelope_exponent = config['cutoff'], config['envelope_exponent']
num_before_skip, num_after_skip = config['num_before_skip'], config['num_after_skip']
num_dense_output = config['num_dense_output']

# Dataset location and split settings
num_train, num_valid = config['num_train'], config['num_valid']
data_seed = config['data_seed']
dataset_path = config['dataset']

batch_size = config['batch_size']

#####################################################################
# Change this if you want to predict a different target, e.g. to ['U0']
# (but don't forget to change output_init as well)
targets = config['targets']
#####################################################################
#####################################################################

### Load dataset

In [5]:
# Build the data container from the dataset path, passing the cutoff and the
# target keys taken from the config.
data_container = DataContainer(dataset_path, cutoff=cutoff, target_keys=targets)

# The DataProvider splits the dataset into training, validation and test set
# based on data_seed.
data_provider = DataProvider(
    data_container, num_train, num_valid, batch_size,
    seed=data_seed, randomized=True,
)

# Prediction only uses the test split; prefetching is enabled on it and a
# single iterator is created for the prediction loop to consume.
test_split = data_provider.get_dataset('test')
dataset = test_split.prefetch(tf.data.experimental.AUTOTUNE)
dataset_iter = iter(dataset)

### Initialize model

In [6]:
# Keyword arguments passed to both model constructors.
common_kwargs = dict(
    emb_size=emb_size,
    num_blocks=num_blocks,
    num_spherical=num_spherical,
    num_radial=num_radial,
    cutoff=cutoff,
    envelope_exponent=envelope_exponent,
    num_before_skip=num_before_skip,
    num_after_skip=num_after_skip,
    num_dense_output=num_dense_output,
    num_targets=len(targets),  # one output per configured target
    activation=swish,
    output_init=output_init,
)

# Instantiate the architecture selected in the config, adding the arguments
# that only that architecture takes.
if model_name == "dimenet":
    model = DimeNet(num_bilinear=num_bilinear, **common_kwargs)
elif model_name == "dimenet++":
    model = DimeNetPP(
        out_emb_size=out_emb_size,
        int_emb_size=int_emb_size,
        basis_emb_size=basis_emb_size,
        extensive=extensive,
        **common_kwargs,
    )
else:
    raise ValueError(f"Unknown model name: '{model_name}'")

### Initialize trainer

In [7]:
# Wrap the model in a Trainer; its predict_on_batch method is what the
# prediction loop calls for every batch.
trainer = Trainer(model)

### Load weights from model at best step

In [8]:
#####################################################################
# Choose which checkpoint to evaluate:
#   use_pretrained = True  -> weights under pretrained/dimenet_pp/<first target>
#   use_pretrained = False -> best checkpoint of your own training run
# Previously both assignments ran unconditionally, so the pretrained path
# always overwrote the user-supplied one; the switch makes the choice explicit.
use_pretrained = True
#####################################################################

if use_pretrained:
    # NOTE(review): this path is hard-coded to the DimeNet++ weights and to
    # the first entry of `targets`; adjust it when model_name is "dimenet".
    directory = f"pretrained/dimenet_pp/{targets[0]}"
    best_ckpt_file = os.path.join(directory, 'ckpt')
else:
    # Load the trained model from your own training run
    directory = "/path/to/log/dir"  # Fill this in
    best_ckpt_file = os.path.join(directory, 'best', 'ckpt')

# Left as the last expression so the returned load-status object is displayed.
model.load_weights(best_ckpt_file)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f85b8679310>

### Prediction loop

In [9]:
# Accumulators filled by the prediction loop.
# NOTE(review): the metrics tag is 'val' although the test split is what gets
# evaluated — confirm this is intended.
metrics = Metrics('val', targets)

# Prediction buffer: one row per test sample, one column per target.
num_test_samples = data_provider.nsamples['test']
preds_total = np.zeros((num_test_samples, len(targets)), dtype=np.float32)

In [10]:
# Number of batches needed to cover the test split once (the last batch may
# be smaller than batch_size).
num_test = data_provider.nsamples['test']
steps_per_epoch = int(np.ceil(num_test / batch_size))

for step in tqdm(range(steps_per_epoch)):
    # Predict the next batch from the shared iterator; the call also receives
    # `metrics` so the aggregates are updated alongside.
    preds = trainer.predict_on_batch(dataset_iter, metrics)

    # Copy this batch's predictions into its rows of the full buffer,
    # clipping the upper bound for the final, possibly shorter batch.
    row_from = step * batch_size
    row_to = min(row_from + batch_size, num_test)
    preds_total[row_from:row_to] = preds.numpy()

HBox(children=(FloatProgress(value=0.0, max=339.0), HTML(value='')))




In [11]:
# Report the aggregated errors, labelled with the comma-separated target names.
target_label = ','.join(targets)
print(f"{target_label} MAE: {metrics.mean_mae}")
print(f"{target_label} logMAE: {metrics.mean_log_mae}")

U0 MAE: 0.006319692824035883
U0 logMAE: -5.064084529876709
