In [None]:
import io
import itertools
import os
import pickle

from datetime import datetime
from os import path

import numpy as np
import tensorflow as tf

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, PowerTransformer
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback, EarlyStopping
from tensorflow.keras.backend import mean, square

from spektral.datasets import qm9
from spektral.layers import EdgeConditionedConv, GINConv, GatedGraphConv
from spektral.layers import ops, GlobalSumPool, GlobalAttentionPool
from spektral.utils import batch_iterator, numpy_to_disjoint
from spektral.utils import label_to_one_hot

In [None]:
def load_data(amount=None, mode='batch'):
    """Load QM9 through spektral and one-hot encode node and edge labels.

    Args:
        amount: forwarded to ``qm9.load_data``; ``None`` requests the
            entire dataset.
        mode: ``'batch'`` encodes the node/edge arrays as a whole,
            ``'disjoint'`` encodes them graph by graph and returns lists.

    Returns:
        Tuple ``(A_all, X_all, E_all, y_all)``. ``A_all`` and ``y_all`` are
        passed through unchanged from ``qm9.load_data``.

    Raises:
        ValueError: if ``mode`` is neither ``'batch'`` nor ``'disjoint'``.
    """
    if mode not in ('batch', 'disjoint'):
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")

    A_all, X_all, E_all, y_all = qm9.load_data(
        return_type='numpy',
        nf_keys='atomic_num',
        ef_keys='type',
        self_loops=True,
        amount=amount,  # None for entire dataset
    )

    def nonzero_labels(values):
        # Distinct label values, with 0 dropped so it gets no one-hot slot.
        labels = np.unique(values)
        return labels[labels != 0]

    # Preprocessing
    if mode == 'batch':
        node_labels = nonzero_labels(X_all)
        edge_labels = nonzero_labels(E_all)
        X_all = label_to_one_hot(X_all, node_labels)
        E_all = label_to_one_hot(E_all, edge_labels)
    else:
        # Collect the labels graph by graph before merging them.
        node_labels = nonzero_labels([v for x in X_all for v in np.unique(x)])
        edge_labels = nonzero_labels([v for e in E_all for v in np.unique(e)])
        X_all = [label_to_one_hot(x, labels=node_labels) for x in X_all]
        E_all = [label_to_one_hot(e, labels=edge_labels) for e in E_all]

    return A_all, X_all, E_all, y_all

In [None]:
def sample_from_data(sample_size, A_all, X_all, E_all, y_all, mode='batch'):
    """Draw a random subset of graphs (without replacement) and their targets.

    Args:
        sample_size: number of graphs to draw.
        A_all, X_all, E_all: arrays (batch mode) or per-graph sequences
            (disjoint mode) indexed by graph.
        y_all: DataFrame whose rows are aligned with the graphs by position.
        mode: ``'batch'`` or ``'disjoint'``.

    Returns:
        Tuple ``(A, X, E, y)``; ``y`` is a copy of the selected rows.

    Raises:
        ValueError: if ``mode`` is neither ``'batch'`` nor ``'disjoint'``.
    """
    if mode not in ('batch', 'disjoint'):
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")

    if mode == 'batch':
        picked = np.random.choice(X_all.shape[0], sample_size, replace=False)
        # Fancy-index the first (graph) axis of each array.
        A = A_all[picked, :, :]
        X = X_all[picked, :, :]
        E = E_all[picked, :, :, :]
    else:
        picked = np.random.choice(len(X_all), sample_size, replace=False)
        A, X, E = ([graphs[i] for i in picked]
                   for graphs in (A_all, X_all, E_all))

    # Copy so later in-place edits do not touch y_all.
    y = y_all.iloc[picked, :].copy()
    return A, X, E, y

In [None]:
def standardize(y):
    """Fit one ``PowerTransformer`` per task column and transform ``y`` in place.

    The first column of ``y`` is skipped; every remaining column is replaced
    by its transformed values.

    Args:
        y: DataFrame of targets; it is modified in place.

    Returns:
        Dict mapping each transformed column name to its fitted transformer.
    """
    task_to_scaler = {}
    for task in y.columns[1:]:
        transformer = PowerTransformer()
        # y[[task]] keeps the column 2-D, as the transformer requires.
        y.loc[:, task] = transformer.fit_transform(y[[task]])
        task_to_scaler[task] = transformer
    return task_to_scaler

In [None]:
def get_shape_params(*, A, X, E, mode='batch'):
    """Read the feature dimensions (and node count in batch mode) from the data.

    Args:
        A: accepted for signature symmetry with the other helpers; not read.
        X: node features; ``X[0].shape[-1]`` is the node feature size.
        E: edge features; ``E[0].shape[-1]`` is the edge feature size.
        mode: ``'batch'`` or ``'disjoint'``.

    Returns:
        ``(N, F, S)`` in batch mode, where ``N`` is ``X.shape[-2]``;
        ``(F, S)`` in disjoint mode.

    Raises:
        ValueError: if ``mode`` is neither ``'batch'`` nor ``'disjoint'``.
    """
    if mode not in ('batch', 'disjoint'):
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")

    node_dim = X[0].shape[-1]  # Dimension of node features
    edge_dim = E[0].shape[-1]  # Dimension of edge features
    if mode == 'disjoint':
        return node_dim, edge_dim

    num_nodes = X.shape[-2]    # Number of nodes in the graphs
    return num_nodes, node_dim, edge_dim

In [None]:
def get_input_tensors(*, A, X, E, mode='batch'):
    """Create the Keras ``Input`` placeholders that match the data's shapes.

    Args:
        A, X, E: adjacency, node-feature and edge-feature data; only used to
            derive shapes through ``get_shape_params``.
        mode: ``'batch'`` or ``'disjoint'``.

    Returns:
        ``(X_in, A_in, E_in)`` in batch mode; ``(X_in, A_in, E_in, I_in)`` in
        disjoint mode, where ``A_in`` is sparse and ``I_in`` holds int32
        segment ids.

    Raises:
        ValueError: if ``mode`` is neither ``'batch'`` nor ``'disjoint'``.
    """
    if mode not in ('batch', 'disjoint'):
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")

    dims = get_shape_params(A=A, X=X, E=E, mode=mode)

    if mode == 'batch':
        num_nodes, node_dim, edge_dim = dims
        return (
            Input(shape=(num_nodes, node_dim), name='X_in'),
            Input(shape=(num_nodes, num_nodes), name='A_in'),
            Input(shape=(num_nodes, num_nodes, edge_dim), name='E_in'),
        )

    node_dim, edge_dim = dims
    return (
        Input(shape=(node_dim,), name='X_in'),
        Input(shape=(None,), sparse=True, name='A_in'),
        Input(shape=(edge_dim,), name='E_in'),
        Input(shape=(), name='segment_ids_in', dtype=tf.int32),
    )

In [None]:
def build_single_task_model(*, A, X, E, learning_rate=1e-3, conv='ecc', mode='batch'):
    """Build a graph network with a single scalar regression output.

    Architecture: two graph convolutions (64 then 128 channels, ReLU),
    ``GlobalAttentionPool(256)``, ``Dense(256, relu)`` and ``Dense(1)``.

    Args:
        A, X, E: adjacency, node-feature and edge-feature data; only used to
            size the input tensors.
        learning_rate: Adam learning rate; only used in batch mode, where the
            model is compiled.
        conv: ``'ecc'`` (EdgeConditionedConv) or ``'gin'`` (GINConv).
        mode: ``'batch'`` or ``'disjoint'``.

    Returns:
        ``(model, loss_fn)``. In batch mode the model is compiled with Adam
        and ``loss_fn``; in disjoint mode it is returned uncompiled so that a
        custom training loop can drive it.

    Raises:
        ValueError: on an unknown ``mode`` or ``conv``, or when GIN is
            requested outside disjoint mode.
    """
    if mode not in ['batch', 'disjoint']:
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")
    if conv not in ['ecc', 'gin']:
        raise ValueError(f"convolution layer {conv} not recognized; "
                         "choose 'ecc' or 'gin'")
    # Checked up front with a real exception: an `assert` disappears under
    # `python -O` and previously ran only after the inputs had been created.
    if conv == 'gin' and mode != 'disjoint':
        raise ValueError('cannot run GIN in batch mode')

    inputs = list(get_input_tensors(A=A, X=X, E=E, mode=mode))
    X_in, A_in, E_in = inputs[:3]

    conv_layer = EdgeConditionedConv if conv == 'ecc' else GINConv
    gc1 = conv_layer(64, activation='relu')([X_in, A_in, E_in])
    gc2 = conv_layer(128, activation='relu')([gc1, A_in, E_in])

    if mode == 'batch':
        pool = GlobalAttentionPool(256)(gc2)
    else:
        # Disjoint mode pools per graph using the segment ids (4th input).
        pool = GlobalAttentionPool(256)([gc2, inputs[3]])
    dense = Dense(256, activation='relu')(pool)
    output = Dense(1)(dense)

    loss_fn = MeanSquaredError()
    model = Model(inputs=inputs, outputs=output)
    if mode == 'batch':
        # `learning_rate` replaces the deprecated `lr` keyword.
        model.compile(optimizer=Adam(learning_rate=learning_rate),
                      loss=loss_fn)

    return model, loss_fn

In [None]:
def build_hard_sharing_model(*, A, X, E, num_tasks,
                             learning_rate=1e-3, conv='ecc', mode='batch'):
    """Build a multi-task network whose tasks share the graph layers.

    All tasks share two graph convolutions (64 then 128 channels, ReLU) and a
    ``GlobalAttentionPool(256)``. Each task then gets its own
    ``Dense(256, relu)`` -> ``Dense(1)`` head.

    Args:
        A, X, E: adjacency, node-feature and edge-feature data; only used to
            size the input tensors.
        num_tasks: number of output heads to create.
        learning_rate: Adam learning rate; only used in batch mode, where the
            model is compiled.
        conv: ``'ecc'`` (EdgeConditionedConv) or ``'gin'`` (GINConv).
        mode: ``'batch'`` or ``'disjoint'``.

    Returns:
        ``(model, loss_fn)``. The model has one output per task. In batch mode
        it is compiled with Adam and ``loss_fn``; in disjoint mode it is
        returned uncompiled for use with a custom training loop.

    Raises:
        ValueError: on an unknown ``mode`` or ``conv``, or when GIN is
            requested outside disjoint mode.
    """
    if mode not in ['batch', 'disjoint']:
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")
    if conv not in ['ecc', 'gin']:
        raise ValueError(f"convolution layer {conv} not recognized; "
                         "choose 'ecc' or 'gin'")
    # Checked up front with a real exception: an `assert` disappears under
    # `python -O` and previously ran only after the inputs had been created.
    if conv == 'gin' and mode != 'disjoint':
        raise ValueError('cannot run GIN in batch mode')

    inputs = list(get_input_tensors(A=A, X=X, E=E, mode=mode))
    X_in, A_in, E_in = inputs[:3]

    # Shared trunk.
    conv_layer = EdgeConditionedConv if conv == 'ecc' else GINConv
    gc1 = conv_layer(64, activation='relu')([X_in, A_in, E_in])
    gc2 = conv_layer(128, activation='relu')([gc1, A_in, E_in])
    if mode == 'batch':
        pool = GlobalAttentionPool(256)(gc2)
    else:
        # Disjoint mode pools per graph using the segment ids (4th input).
        pool = GlobalAttentionPool(256)([gc2, inputs[3]])

    # One independent head per task on top of the shared pooled embedding.
    dense_list = [Dense(256, activation='relu')(pool)
                  for _ in range(num_tasks)]
    output_list = [Dense(1)(dense_layer) for dense_layer in dense_list]

    loss_fn = MeanSquaredError()
    model = Model(inputs=inputs, outputs=output_list)
    if mode == 'batch':
        # `learning_rate` replaces the deprecated `lr` keyword.
        model.compile(optimizer=Adam(learning_rate=learning_rate),
                      loss=loss_fn)

    return model, loss_fn

In [None]:
def build_soft_sharing_model(*, A, X, E, num_tasks, share_param,
                             learning_rate=1e-3, conv='ecc', mode='batch'):
    """Build a multi-task network with one tower per task and a tying penalty.

    Every task gets its own two graph convolutions, attention pool and dense
    head. The returned loss adds ``share_param`` times the average (over task
    pairs) mean squared difference between the corresponding convolution
    weights, which pulls the towers toward each other.

    Args:
        A, X, E: adjacency, node-feature and edge-feature data; only used to
            size the input tensors.
        num_tasks: number of task towers / outputs.
        share_param: weight of the weight-difference penalty in the loss.
        learning_rate: Adam learning rate; only used in batch mode, where the
            model is compiled.
        conv: ``'ecc'`` (EdgeConditionedConv) or ``'gin'`` (GINConv).
        mode: ``'batch'`` or ``'disjoint'``.

    Returns:
        ``(model, loss_fn)``. In batch mode the model is compiled with Adam
        and ``loss_fn``; in disjoint mode it is returned uncompiled.

    Raises:
        ValueError: on an unknown ``mode`` or ``conv``, or when GIN is
            requested outside disjoint mode.
    """
    if mode not in ['batch', 'disjoint']:
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")
    if conv not in ['ecc', 'gin']:
        raise ValueError(f"convolution layer {conv} not recognized; "
                         "choose 'ecc' or 'gin'")
    # Same restriction the other builders enforce.
    if conv == 'gin' and mode != 'disjoint':
        raise ValueError('cannot run GIN in batch mode')

    inputs = list(get_input_tensors(A=A, X=X, E=E, mode=mode))
    X_in, A_in, E_in = inputs[:3]

    conv_layer = EdgeConditionedConv if conv == 'ecc' else GINConv

    # Keep the layer objects: the penalty needs their weights, and the
    # tensors returned by calling a layer do not expose `trainable_weights`.
    gc1_layers = [conv_layer(64, activation='relu') for _ in range(num_tasks)]
    gc2_layers = [conv_layer(128, activation='relu') for _ in range(num_tasks)]

    gc1_list = [layer([X_in, A_in, E_in]) for layer in gc1_layers]
    gc2_list = [layer([gc1, A_in, E_in])
                for layer, gc1 in zip(gc2_layers, gc1_list)]
    if mode == 'batch':
        pool_list = [GlobalAttentionPool(256)(gc2) for gc2 in gc2_list]
    else:
        # Disjoint mode pools per graph using the segment ids (4th input).
        pool_list = [GlobalAttentionPool(256)([gc2, inputs[3]])
                     for gc2 in gc2_list]
    dense_list = [Dense(256, activation='relu')(pool) for pool in pool_list]
    output_list = [Dense(1)(dense) for dense in dense_list]

    task_pairs = list(itertools.combinations(range(num_tasks), 2))

    def loss_fn(y_actual, y_pred):
        """Mean squared error plus the soft parameter-sharing penalty."""
        avg_layer_diff = 0.0
        for i, j in task_pairs:
            for layers in (gc1_layers, gc2_layers):
                # Compare the towers weight tensor by weight tensor.
                for w_i, w_j in zip(layers[i].trainable_weights,
                                    layers[j].trainable_weights):
                    avg_layer_diff += mean(square(w_i - w_j))
        if task_pairs:  # a single task has no pairs; avoid dividing by zero
            avg_layer_diff /= len(task_pairs)

        # Targets may arrive as float64 from a custom training loop.
        y_pred = tf.convert_to_tensor(y_pred)
        y_actual = tf.cast(y_actual, y_pred.dtype)
        return mean(square(y_actual - y_pred)) + share_param*avg_layer_diff

    model = Model(inputs=inputs, outputs=output_list)
    if mode == 'batch':
        # `learning_rate` replaces the deprecated `lr` keyword.
        model.compile(optimizer=Adam(learning_rate=learning_rate),
                      loss=loss_fn)

    return model, loss_fn

In [None]:
def generate_model_filename(tasks, conv='ecc', mode='batch', folder_path='demo_models'):
    """Return the weights-file path for a task set, conv type and mode.

    The task names are sorted and concatenated without a separator, so the
    path does not depend on the order in which the tasks are listed.
    """
    stem = '_'.join(["".join(sorted(tasks)), conv, mode])
    return path.join(folder_path, stem + '.h5')

def generate_task_scaler_filename(task, folder_path='demo_models'):
    """Return the pickle path under ``folder_path`` for one task's scaler."""
    scaler_name = '{}_scaler.pkl'.format(task)
    return path.join(folder_path, scaler_name)

In [None]:
def save_model(model, tasks, task_to_scaler, mode='batch', conv='ecc'):
    """Save the model weights and the fitted scaler of every task to disk.

    Args:
        model: Keras model whose weights are written with ``save_weights``.
        tasks: task names the model predicts; they determine the file names.
        task_to_scaler: dict mapping each task name to its fitted scaler.
        mode: ``'batch'`` or ``'disjoint'``; part of the weights file name.
        conv: ``'ecc'`` or ``'gin'``; part of the weights file name.

    Raises:
        KeyError: if a task has no entry in ``task_to_scaler``.
    """
    def ensure_parent(filename):
        # Create the target folder so a fresh checkout does not fail at save
        # time, after training has already been paid for.
        parent = path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)

    weights_filename = generate_model_filename(tasks, conv=conv, mode=mode)
    ensure_parent(weights_filename)
    model.save_weights(weights_filename)

    for task in tasks:
        scaler = task_to_scaler[task]
        scaler_filename = generate_task_scaler_filename(task)
        ensure_parent(scaler_filename)
        with open(scaler_filename, 'wb') as f:
            pickle.dump(obj=scaler, file=f)

def load_hard_sharing_model(*, A, X, E, tasks, conv='ecc',
                            mode='batch', task_to_scaler=None):
    """Rebuild a hard-sharing model and load its saved weights and scalers.

    Args:
        A, X, E: adjacency, node-feature and edge-feature data; only used to
            size the rebuilt model's inputs.
        tasks: task names the model predicts; they select the files to load.
        conv: ``'ecc'`` or ``'gin'``.
        mode: ``'batch'`` or ``'disjoint'``.
        task_to_scaler: optional dict of scalers that are already loaded.
            Missing tasks are read from disk and added to this dict in place.
            When omitted, a fresh dict is created for the call.

    Returns:
        ``(model, task_to_scaler)``.
    """
    # A `dict()` default would be shared by every call and silently keep
    # scalers from earlier (possibly retrained) models.
    if task_to_scaler is None:
        task_to_scaler = {}

    model, _ = build_hard_sharing_model(A=A, X=X, E=E, conv=conv, mode=mode,
                                        num_tasks=len(tasks))
    model.load_weights(generate_model_filename(tasks, conv=conv, mode=mode))
    for task in tasks:
        if task not in task_to_scaler:
            with open(generate_task_scaler_filename(task), 'rb') as f:
                # NOTE(review): pickle.load can execute arbitrary code; only
                # load scaler files produced by this project.
                task_to_scaler[task] = pickle.load(f)
    return model, task_to_scaler

In [None]:
def train_multitask_disjoint(model, cluster, *, opt, loss_fn, batch_size,
                             epochs, A_train, X_train, E_train, y_train,
                             loss_logger=None):
    """Train a disjoint-mode model with a custom gradient-tape loop.

    Args:
        model: uncompiled Keras model taking ``[X, A, E, I]``.
        cluster: list of column names of ``y_train`` used as targets, one per
            model output.
        opt: optimizer whose ``apply_gradients`` updates the model.
        loss_fn: callable ``loss_fn(y_true, y_pred)`` returning a scalar.
        batch_size: number of graphs per batch.
        epochs: number of passes over the training data.
        A_train, X_train, E_train: per-graph adjacency, node and edge data.
        y_train: DataFrame of targets aligned with the graphs.
        loss_logger: optional object with a ``losses`` dict; the mean batch
            loss of each epoch is stored under its 1-based epoch number.
    """
    F, S = get_shape_params(A=A_train, X=X_train, E=E_train, mode='disjoint')
    @tf.function(
        input_signature=(tf.TensorSpec((None, F), dtype=tf.float64),
                         tf.SparseTensorSpec((None, None), dtype=tf.float64),
                         tf.TensorSpec((None, S), dtype=tf.float64),
                         tf.TensorSpec((None,), dtype=tf.int32),
                         tf.TensorSpec((None, len(cluster)), dtype=tf.float64)),
        experimental_relax_shapes=True)
    def train_step(X_, A_, E_, I_, y_):
        with tf.GradientTape() as tape:
            predictions = model([X_, A_, E_, I_], training=True)
            # A multi-output model returns one (batch, 1) tensor per task.
            # Passed to the loss as a list, they are stacked to
            # (tasks, batch, 1) and broadcast against the (batch, tasks)
            # targets, comparing every head with every target. Put the heads
            # side by side so column k is compared with target k only.
            if isinstance(predictions, (list, tuple)):
                predictions = tf.concat(predictions, axis=-1)
            loss = loss_fn(y_, predictions)
            loss += sum(model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    current_batch = 0
    model_loss = 0
    batches_in_epoch = np.ceil(len(A_train) / batch_size)

    print('Fitting model')
    batches_train = batch_iterator([X_train, A_train, E_train, y_train[cluster].values],
                                   batch_size=batch_size, epochs=epochs)
    epoch_num = 1
    for b in batches_train:
        # All but the last batch element are graph data; the last is y.
        X_, A_, E_, I_ = numpy_to_disjoint(*b[:-1])
        A_ = ops.sp_matrix_to_sp_tensor(A_)
        y_ = b[-1]
        outs = train_step(X_, A_, E_, I_, y_)

        model_loss += outs.numpy()
        current_batch += 1
        if current_batch == batches_in_epoch:
            # End of an epoch: report and reset the running loss.
            print('Loss: {}'.format(model_loss / batches_in_epoch))
            if loss_logger is not None:
                loss_logger.losses[epoch_num] = model_loss / batches_in_epoch
            model_loss = 0
            current_batch = 0
            epoch_num += 1

In [None]:
def test_multitask_disjoint(model, cluster, *, loss_fn, batch_size, A_test, X_test, E_test, y_test):
    """Evaluate a disjoint-mode model and return its mean batch loss.

    Args:
        model: Keras model taking ``[X, A, E, I]``.
        cluster: list of column names of ``y_test`` used as targets, one per
            model output.
        loss_fn: callable ``loss_fn(y_true, y_pred)`` returning a scalar.
        batch_size: number of graphs per batch.
        A_test, X_test, E_test: per-graph adjacency, node and edge data.
        y_test: DataFrame of targets aligned with the graphs.

    Returns:
        The loss summed over batches divided by the number of batches.
    """
    print('Testing model')
    model_loss = 0
    batches_in_epoch = np.ceil(len(A_test) / batch_size)
    batches_test = batch_iterator([X_test, A_test, E_test, y_test[cluster].values], batch_size=batch_size)
    for b in batches_test:
        # All but the last batch element are graph data; the last is y.
        X_, A_, E_, I_ = numpy_to_disjoint(*b[:-1])
        A_ = ops.sp_matrix_to_sp_tensor(A_)
        y_ = b[-1]

        predictions = model([X_, A_, E_, I_], training=False)
        # A multi-output model returns one (batch, 1) tensor per task. Put
        # the heads side by side so that column k is compared with target k
        # instead of being broadcast against every target column.
        if isinstance(predictions, (list, tuple)):
            predictions = tf.concat(predictions, axis=-1)
        model_loss += loss_fn(y_, predictions)
    model_loss /= batches_in_epoch
    print('Done. Test loss: {}'.format(model_loss))
    return model_loss

In [None]:
def predict_property(prop, mol_id, clusters, *, X_all, A_all, E_all,
                     mode='batch', conv='ecc', model=None,
                     task_to_scaler=None):
    """Predict one property of one molecule, in the property's original units.

    Args:
        prop: name of the property to predict.
        mol_id: 1-based position of the molecule in ``X_all``/``A_all``/
            ``E_all`` (the code reads index ``mol_id - 1``).
        clusters: list of task lists; the first one containing ``prop``
            selects the model.
        X_all, A_all, E_all: node, adjacency and edge data for all molecules.
        mode: ``'batch'`` or ``'disjoint'``.
        conv: ``'ecc'`` or ``'gin'``.
        model: optional model already loaded for the cluster; when omitted
            the model and any missing scalers are loaded from disk.
        task_to_scaler: optional dict mapping task names to fitted scalers.

    Returns:
        The prediction as a scalar, after inverting the task's scaler.

    Raises:
        ValueError: on an unknown ``mode`` or ``conv``.
        IndexError: if no cluster contains ``prop``.
        KeyError: if no scaler is available for ``prop``.
    """
    if mode not in ['batch', 'disjoint']:
        raise ValueError(f"mode {mode} not recognized; "
                         "choose 'batch' or 'disjoint'")
    if conv not in ['ecc', 'gin']:
        raise ValueError(f"convolution layer {conv} not recognized; "
                         "choose 'ecc' or 'gin'")

    # A `dict()` default would be shared by every call and keep stale scalers.
    if task_to_scaler is None:
        task_to_scaler = {}

    cluster = [c for c in clusters if prop in c][0]
    if model is None:
        model, task_to_scaler = load_hard_sharing_model(
            A=A_all, X=X_all, E=E_all, tasks=cluster,
            mode=mode, conv=conv, task_to_scaler=task_to_scaler
        )
    i = mol_id - 1

    # convert shape for batch mode
    if mode == 'batch':
        def wrap(a):
            # Add a leading batch axis of size 1.
            return a.reshape([1] + list(a.shape))
        x = list(map(wrap, [X_all[i], A_all[i], E_all[i]]))
        cluster_prediction = model.predict(x)

    if mode == 'disjoint':
        X_, A_, E_, I_ = numpy_to_disjoint([X_all[i]], [A_all[i]], [E_all[i]])
        A_ = ops.sp_matrix_to_sp_tensor(A_)
        cluster_prediction = model([X_, A_, E_, I_], training=False)

    # A single-output model returns a bare array/tensor rather than a list;
    # wrap it so indexing by task position selects the whole (1, 1) result
    # instead of a 1-D row the scaler cannot invert.
    if len(cluster) == 1 and not isinstance(cluster_prediction, (list, tuple)):
        cluster_prediction = [cluster_prediction]

    prediction = np.asarray(cluster_prediction[cluster.index(prop)])
    prediction = task_to_scaler[prop].inverse_transform(prediction)
    return prediction[0][0]

In [None]:
class LossLoggerCallback(Callback):
    """Keras callback that records the training loss of every epoch.

    ``losses`` maps the epoch index passed by Keras to the ``"loss"`` value
    found in that epoch's logs.
    """

    def __init__(self):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.losses = dict()

    def on_epoch_end(self, epoch, logs=None):
        """Store ``logs["loss"]`` for ``epoch`` when it is available."""
        # `logs` defaults to None; skip instead of failing on None["loss"].
        if logs and "loss" in logs:
            self.losses[epoch] = logs["loss"]

class ModelData:
    """Record of one training run: parameters, losses and test predictions.

    Possible params keys:
    mode: 'batch' or 'disjoint'
    conv: 'ecc' or 'gin'
    single_task: true or false
    cluster: only if single_task is false
    hard_sharing: true or false, only if single_task is false
    soft_weight: only if hard_sharing is false and single_task is false
    batch_size
    epochs
    num_sampled
    learning_rate
    model_summary
    loss_fn: string of loss function name
    optimizer: string of optimizer name
    """

    def __init__(self, params=None):
        """Start an empty record stamped with the current time.

        Args:
            params: dict describing the run (see the class docstring). A
                fresh dict is used when omitted, so instances never share one.
        """
        self.timestamp = datetime.now()
        self.loss_logger = LossLoggerCallback()
        self.params = {} if params is None else params

        # not using actual/pred dict in case values collide
        self.actual = list()
        self.pred = list()

    def get_losses(self):
        """Return the dict of per-epoch losses held by the loss logger."""
        return self.loss_logger.losses

    def add_test(self, actual, pred):
        """Append one (actual, predicted) test pair."""
        self.actual.append(actual)
        self.pred.append(pred)

    def _make_picklable(self):
        """Return the record as a plain dict of this instance's own data."""
        # Read from `self`: bare `params`/`actual`/`pred` would raise
        # NameError, or in a notebook silently pick up unrelated globals.
        return {'params': self.params,
                'actual': self.actual,
                'pred': self.pred,
                'losses': self.loss_logger.losses}

    def serialize(self, dirname='model_data', filename=''):
        """Pickle the record to disk.

        Args:
            dirname: folder used when ``filename`` is empty.
            filename: full target path. When empty, the file is named after
                the creation timestamp and placed in ``dirname``.
        """
        if filename == '':
            dt_string = self.timestamp.strftime('%d-%m-%Y_%H-%M-%S')
            filename = path.join(dirname, dt_string + '.pkl')
        # Create the target folder so saving works on a fresh checkout.
        parent = path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(filename, 'wb') as file:
            pickle.dump(self._make_picklable(), file)

In [None]:
if __name__ == '__main__' and '__file__' not in globals():
    # Run configuration. The guard keeps this cell from executing when the
    # notebook is exported and imported as a module.

    # Seed the global NumPy and TensorFlow generators so sampling, the
    # train/test split and weight initialisation repeat across
    # "Restart & Run All".
    seed = 0
    np.random.seed(seed)
    tf.random.set_seed(seed)

    mode = 'disjoint'
    conv = 'gin'
    batch_size = 32
    epochs = 40
    num_sampled = 100000
    learning_rate = 1e-3
    amount = None  # None loads the entire dataset
    A_all, X_all, E_all, y_all = load_data(amount=amount, mode=mode)

In [None]:
if __name__ == '__main__' and '__file__' not in globals():
    # Draw the working subset of molecules from the full dataset.
    A, X, E, y = sample_from_data(
        num_sampled, A_all, X_all, E_all, y_all, mode=mode
    )
    # standardize() rewrites the task columns of y in place and returns the
    # fitted per-task scalers.
    task_to_scaler = standardize(y)

In [None]:
# Groups of target properties; each group is trained as one multi-task model.
clusters = [
    ['A', 'lumo', 'homo'],
    ['B', 'r2', 'cv'],
    ['alpha', 'zpve'],
    ['C', 'u0', 'u298', 'mu'],
    ['g298', 'h298'],
]

In [None]:
if __name__ == '__main__' and '__file__' not in globals():
    # Hold out 10% of the sampled molecules for testing; all four
    # collections are split with the same partition.
    (A_train, A_test,
     X_train, X_test,
     E_train, E_test,
     y_train, y_test) = train_test_split(A, X, E, y, test_size=0.1)

In [None]:
if __name__ == '__main__' and '__file__' not in globals():
    print('begin training models')

    # Create the output folders before the long run starts. These names are
    # the defaults used by save_model() and ModelData.serialize().
    os.makedirs('demo_models', exist_ok=True)
    os.makedirs('model_data', exist_ok=True)

    # Train every property on its own, then every cluster jointly, once per
    # convolution type.
    tasks = [[task] for cluster in clusters for task in cluster]
    tasks_and_clusters = itertools.chain(tasks, clusters)
    for cluster, conv in itertools.product(tasks_and_clusters,
                                           ['ecc', 'gin']):
        if conv == 'gin' and mode != 'disjoint':
            # The model builder rejects this combination; skip it instead of
            # aborting the whole sweep.
            print(f'skipping {cluster}: cannot run GIN in {mode} mode')
            continue

        print(f'training {cluster} with {mode} mode on {conv} conv')

        model, loss_fn = build_hard_sharing_model(A=A_train,
                                                  X=X_train,
                                                  E=E_train,
                                                  num_tasks=len(cluster),
                                                  mode=mode,
                                                  conv=conv)
        # One optimizer per model, built from the configured learning rate
        # (`learning_rate` replaces the deprecated `lr` keyword).
        optimizer = Adam(learning_rate=learning_rate)

        # Capture the textual model summary so it can be stored with the run.
        stream = io.StringIO()
        model.summary(print_fn=lambda x: stream.write(x + '\n'))
        summary = stream.getvalue()

        params = {'mode': mode,
                  'conv': conv,
                  'batch_size': batch_size,
                  'epochs': epochs,
                  'num_sampled': num_sampled,
                  'learning_rate': learning_rate,
                  'cluster': cluster,
                  'hard_sharing': True,
                  'model_summary': summary,
                  'loss_fn': type(loss_fn).__name__,
                  'optimizer': type(optimizer).__name__}
        model_data = ModelData(params=params)

        if mode == 'batch':
            # training
            y_train_cluster = np.hsplit(y_train[cluster].values, len(cluster))
            model.compile(optimizer=optimizer,
                          loss=loss_fn)
            model.fit(x=[X_train, A_train, E_train],
                      y=y_train_cluster,
                      batch_size=batch_size,
                      validation_split=0.1,
                      epochs=epochs,
                      callbacks=[model_data.loss_logger])

            # testing
            y_test_cluster = np.hsplit(y_test[cluster].values, len(cluster))
            model_loss = model.evaluate(x=[X_test, A_test, E_test],
                                        y=y_test_cluster)
            print(f"Test loss on {cluster}: {model_loss}")
            cluster_pred = model.predict([X_test, A_test, E_test])

        if mode == 'disjoint':
            # training
            # Use the optimizer recorded in `params` so the configured
            # learning rate actually takes effect (previously a separate
            # Adam with a hard-coded 1e-3 was passed here).
            train_multitask_disjoint(model,
                                     cluster,
                                     opt=optimizer,
                                     loss_fn=loss_fn,
                                     batch_size=batch_size,
                                     epochs=epochs,
                                     A_train=A_train,
                                     X_train=X_train,
                                     E_train=E_train,
                                     y_train=y_train,
                                     loss_logger=model_data.loss_logger)
            # testing
            model_loss = test_multitask_disjoint(model,
                                                cluster,
                                                loss_fn=loss_fn,
                                                batch_size=batch_size,
                                                A_test=A_test,
                                                X_test=X_test,
                                                E_test=E_test,
                                                y_test=y_test)
            X_, A_, E_, I_ = numpy_to_disjoint(X_test, A_test, E_test)
            A_ = ops.sp_matrix_to_sp_tensor(A_)
            cluster_pred = model([X_, A_, E_, I_], training=False)

        # A single-output model returns a bare result; wrap it so the loop
        # below can treat every model as a list of per-task predictions.
        if len(cluster) == 1:
            cluster_pred = [cluster_pred]

        for prop, batch_pred in zip(cluster, cluster_pred):
            # Undo the scaling so predictions are comparable with the raw
            # values stored in y_all.
            batch_pred = task_to_scaler[prop].inverse_transform(batch_pred)
            for index, pred in zip(y_test.index.values, batch_pred):
                actual = y_all.loc[index, prop]
                model_data.add_test(actual, pred[0])

        save_model(model, cluster, task_to_scaler, mode=mode, conv=conv)
        model_data.serialize()