In [1]:
import logging
import os
import pickle as pkl
import random
import string
from datetime import datetime

import numpy as np
import tensorflow as tf
import yaml
from tqdm.notebook import tqdm

from dimenet.model.dimenet import DimeNet
from dimenet.model.activations import swish
from dimenet.training.trainer import Trainer
from dimenet.training.metrics import Metrics
from dimenet.training.data_container import DataContainer
from dimenet.training.data_provider import DataProvider

In [2]:
# Root logger: a single console handler with a timestamped format, at INFO level.
formatter = logging.Formatter(
    fmt='%(asctime)s (%(levelname)s): %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger()
logger.handlers = [ch]  # discard any previously attached handlers, keep only ours
logger.setLevel('INFO')

# TensorFlow verbosity settings. These stay after the root logger configuration,
# in the same relative order as before.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.get_logger().setLevel('WARN')
tf.autograph.set_verbosity(2)

### Load config

In [3]:
# Read the run configuration from config.yaml and echo it for reference.
with open('config.yaml', mode='r') as config_file:
    config = yaml.safe_load(config_file)
print(config)

{'emb_size': 128, 'num_blocks': 6, 'num_bilinear': 8, 'num_spherical': 7, 'num_radial': 6, 'cutoff': 5.0, 'envelope_exponent': 5, 'num_before_skip': 1, 'num_after_skip': 2, 'num_dense_output': 3, 'num_train': 110000, 'num_valid': 10000, 'data_seed': 42, 'dataset': 'data/qm9_eV.npz', 'logdir': '.', 'num_steps': 3000000, 'ema_decay': 0.999, 'learning_rate': 0.001, 'warmup_steps': 3000, 'decay_rate': 0.01, 'decay_steps': 4000000, 'batch_size': 32, 'evaluation_interval': 10000, 'save_interval': 10000, 'restart': 'None', 'comment': 'final', 'targets': ['U0']}


In [50]:
# Pull the individual settings out of the config dict, grouped by the keys they come from.

# emb_size / num_blocks
emb_size, num_blocks = (config[key] for key in ('emb_size', 'num_blocks'))

# num_bilinear / num_spherical / num_radial
num_bilinear, num_spherical, num_radial = (
    config[key] for key in ('num_bilinear', 'num_spherical', 'num_radial'))

# cutoff / envelope_exponent
cutoff, envelope_exponent = (config[key] for key in ('cutoff', 'envelope_exponent'))

# num_before_skip / num_after_skip / num_dense_output
num_before_skip, num_after_skip, num_dense_output = (
    config[key] for key in ('num_before_skip', 'num_after_skip', 'num_dense_output'))

# Dataset split sizes, split seed and dataset file
num_train, num_valid, data_seed, dataset_path = (
    config[key] for key in ('num_train', 'num_valid', 'data_seed', 'dataset'))

batch_size = config['batch_size']

#####################################################################
# Targets as listed in the config file ...
targets = config['targets']
#####################################################################
# ... overridden here for this notebook. Edit this list to extract features
# for a different target, e.g. ['U0'].
targets = ['lumo']

### Load dataset

In [51]:
# Load the dataset file for the selected targets.
data_container = DataContainer(dataset_path, target_keys=targets, cutoff=cutoff)

# The DataProvider splits the data into training, validation and test sets based on data_seed.
data_provider = DataProvider(
    data_container,
    num_train,
    num_valid,
    batch_size,
    randomized=True,
    seed=data_seed,
)

# Only the test split is iterated in this notebook.
test_split = data_provider.get_dataset('test')
dataset = test_split.prefetch(tf.data.experimental.AUTOTUNE)
dataset_iter = iter(dataset)

### Initialize model

In [52]:
# Collect the constructor arguments in one place, then build the network.
# One output is predicted per entry in `targets`.
model_kwargs = dict(
    emb_size=emb_size,
    num_blocks=num_blocks,
    num_bilinear=num_bilinear,
    num_spherical=num_spherical,
    num_radial=num_radial,
    cutoff=cutoff,
    envelope_exponent=envelope_exponent,
    num_before_skip=num_before_skip,
    num_after_skip=num_after_skip,
    num_dense_output=num_dense_output,
    num_targets=len(targets),
    activation=swish,
)
model = DimeNet(**model_kwargs)

### Initialize trainer

In [53]:
# Wrap the model in the project's Trainer with its default settings.
# In this notebook the trainer is only used for `extract_on_batch` in the extract loop.
trainer = Trainer(model)

### Load weights from model at best step

In [54]:
#####################################################################
# Option A: load the trained model from your own training run.
# To use it, uncomment the next two lines and comment out Option B below.
# directory = "/path/to/log/dir"  # Fill this in
# best_ckpt_file = os.path.join(directory, 'best', 'ckpt')
#####################################################################
# Option B (active): use the pretrained checkpoint stored under
# pretrained/<first target>/ckpt.
directory = f"pretrained/{targets[0]}"
best_ckpt_file = os.path.join(directory, 'ckpt')
#####################################################################

# Kept as the last expression of the cell, so the object returned by
# load_weights is displayed as the cell output.
model.load_weights(best_ckpt_file)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fa4804809b0>

### Extract loop

In [55]:
# Metrics accumulator that is handed to the trainer on every extracted batch.
metrics = Metrics('val', targets)

# One list per value returned by trainer.extract_on_batch; each gets one entry per batch.
pred_list = []
input_list = []
target_list = []
P_list = []
feature_list = []
sum_feature_list = []

In [56]:
# Number of batches needed to cover the test split once.
num_test_samples = data_provider.nsamples['test']
steps_per_epoch = int(np.ceil(num_test_samples / batch_size))

# Destination lists, in the same order as the values returned by extract_on_batch.
collectors = (pred_list, input_list, target_list, P_list, feature_list, sum_feature_list)

for step in tqdm(range(steps_per_epoch)):
    preds, inputs, target, P, features, sum_features = trainer.extract_on_batch(dataset_iter, metrics)
    batch_outputs = (preds, inputs, target, P, features, sum_features)
    for collector, value in zip(collectors, batch_outputs):
        collector.append(value)

HBox(children=(FloatProgress(value=0.0, max=339.0), HTML(value='')))




In [57]:
def tensor2numpy(tensor):
    """Convert a tensor (or any array-like) to a NumPy array.

    Objects exposing a callable ``numpy()`` method (as TensorFlow eager tensors
    do) are converted through it; anything else goes through ``np.asarray``.
    This avoids serializing the data into a TensorProto and parsing it back,
    which costs extra copies and is bounded by the protobuf message size limit.

    Args:
        tensor: A tensor-like object, NumPy array or nested sequence.

    Returns:
        np.ndarray: Always an ndarray (0-d for scalars), with the dtype that
        ``numpy()`` / ``np.asarray`` produces for the input.
    """
    to_numpy = getattr(tensor, 'numpy', None)
    if callable(to_numpy):
        # numpy() may hand back a NumPy scalar for 0-d tensors; normalize to ndarray.
        return np.asarray(to_numpy())
    return np.asarray(tensor)

# Split every extracted batch into one record per sample.
# Each record holds the per-node features of one sample, its prediction and
# target, and its edges re-indexed so that node ids start at 0 within the sample.
data_list = []
batch_iter = zip(pred_list, input_list, target_list, P_list, feature_list, sum_feature_list)
# zip() has no length, so the batch count is passed explicitly; without it
# tqdm cannot show a meaningful progress bar.
for preds, inputs, target, P, features, sum_features in tqdm(batch_iter, total=len(pred_list)):
    batch_seg = tensor2numpy(inputs['batch_seg'])
    idnb_i, idnb_j = tensor2numpy(inputs['idnb_i']), tensor2numpy(inputs['idnb_j'])
    preds, target, P = tensor2numpy(preds), tensor2numpy(target), tensor2numpy(P)
    features, sum_features = tensor2numpy(features), tensor2numpy(sum_features)

    # The offsets in split_index are only valid if sample si occupies exactly the
    # node range [split_index[si], split_index[si+1]), i.e. if batch_seg is
    # non-decreasing. Counting 0/1 transitions per sample is not enough: layouts
    # such as [1, 1, 0, 0] or [1, 0, 0, 1] would pass that check.
    assert np.all(batch_seg[1:] >= batch_seg[:-1]), \
        "batch_seg must be sorted by sample index"

    sample_num = np.max(batch_seg)+1
    # Compute each sample's node mask once and reuse it below.
    node_masks = [batch_seg == si for si in range(sample_num)]
    node_num_list = [np.sum(node_mask) for node_mask in node_masks]
    split_index = [0,]+np.cumsum(node_num_list).tolist()

    for si in range(sample_num):
        node_mask = node_masks[si]
        start, stop = split_index[si], split_index[si+1]
        # An edge belongs to this sample if its idnb_i endpoint lies in the node range;
        # the idnb_j endpoint must agree, otherwise the edge would cross samples.
        edge_flag = np.logical_and(idnb_i < stop, idnb_i >= start)
        assert np.all(edge_flag == np.logical_and(idnb_j < stop, idnb_j >= start)), \
            "edge endpoints must belong to the same sample"
        data_list.append({
            'feat': features[node_mask],
            'out_feat': sum_features[node_mask],
            'final_feat': P[node_mask],
            'pred': preds[si],
            'y': target[si],
            # Row 0 holds idnb_j, row 1 holds idnb_i, both shifted to sample-local ids.
            'edge_index': np.stack([
                idnb_j[edge_flag]-start,
                idnb_i[edge_flag]-start,
            ], axis=0)
        })

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [58]:
# Persist the extracted per-sample records as extracted_features/<first target>.pkl.
# makedirs(exist_ok=True) creates the directory if needed and does not fail when
# it already exists, so no separate existence check is required.
os.makedirs('extracted_features', exist_ok=True)
filename = f"extracted_features/{targets[0]}.pkl"
with open(filename, 'wb') as f:
    pkl.dump(data_list, f)