In [1]:
import tensorflow as tf
from tensorflow.keras import layers
from rdkit import Chem
import os
import pickle
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
import time
from featurer import  save_smiles_dicts, get_smiles_dicts, get_smiles_array
from AttFP_tf_utils import Fingerprint, CosineAnnealingLR_with_Restart

In [2]:
# Smoke test: construct a Fingerprint model and run one random batch through it.
# The positional arguments follow the same order as the later
# Fingerprint(radius, T, num_atom_features, num_bond_features,
#             fingerprint_dim, output_units_num, p_dropout) call in this notebook.
model = Fingerprint(2, 2, 39, 10, 192, 1, 0.05)

# Random stand-ins for a featurised batch of 200 entries.
x_atom = tf.random.normal(shape=(200, 56, 39))   # atom features
x_bonds = tf.random.normal(shape=(200, 63, 10))  # bond features
# Neighbour index tables: values are drawn from [0, 56) for atoms and [0, 63) for bonds.
x_atom_index = tf.random.uniform(
    shape=(200, 56, 6), minval=0, maxval=56, dtype=tf.int32
)
x_bond_index = tf.random.uniform(
    shape=(200, 56, 6), minval=0, maxval=63, dtype=tf.int32
)
x_mask = tf.ones(shape=(200, 56))  # atom mask (all ones here)

# doprint=True makes the model report intermediate tensor shapes.
target = model(
    [x_atom, x_bonds, x_atom_index, x_bond_index, x_mask],
    doprint=True,
)
target.shape

2024-10-27 15:36:08.519830: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Max
2024-10-27 15:36:08.519858: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 128.00 GB
2024-10-27 15:36:08.519863: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 48.00 GB
2024-10-27 15:36:08.519891: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-10-27 15:36:08.519908: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


atom_feature shape: (200, 56, 192)
atom_neighbor shape: (200, 56, 6, 39)
bond_neighbor shape: (200, 56, 6, 10)
Concatenated neighbor_feature shape: (200, 56, 6, 49)
Processed neighbor_feature shape: (200, 56, 6, 192)
feature_align shape: (200, 56, 6, 384)
align_score shape: (200, 56, 6, 1)
attention_weight shape: (200, 56, 6, 1)
context shape after reduce_sum: (200, 56, 192)
atom_features_reshape shape after GRU: (11200, 192)
neighbor_feature shape at radius 0: (200, 56, 6, 192)
feature_align shape at radius 0: (200, 56, 6, 384)
align_score shape at radius 0: (200, 56, 6, 1)
attention_weight shape at radius 0: (200, 56, 6, 1)
context shape after reduce_sum at radius 0: (200, 56, 192)


TensorShape([200, 1])

In [3]:
# Benchmark the forward pass of the example model on the random batch defined in the previous cell.
%timeit model([x_atom, x_bonds, x_atom_index, x_bond_index, x_mask])

48 ms ± 266 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
# --- Run bookkeeping ---------------------------------------------------------
random_seed = 108
# Timestamp with ':' and ' ' replaced by '-' (e.g. for use in file names).
start_time = re.sub(r'[: ]', '-', time.ctime())

# --- Optimisation ------------------------------------------------------------
batch_size = 200
epochs = 200
# The next two values are negative base-10 exponents: the optimizer cell
# converts them with 10**-learning_rate and 10**-weight_decay.
learning_rate = 2.5
weight_decay = 5  # also known as l2_regularization_lambda

# --- Model -------------------------------------------------------------------
fingerprint_dim = 192
p_dropout = 0.05
output_units_num = 1  # for regression model
radius = 2
T = 2

In [5]:

# Dataset description: a single regression target read from the Delaney CSV.
task_name = 'solubility'
tasks = ['measured log solubility in mols per litre']

# File names derived from the raw CSV name.
raw_filename = "delaney-processed.csv"
feature_filename = raw_filename.replace('.csv','.pickle')
filename = raw_filename.replace('.csv','')
prefix_filename = raw_filename.split('/')[-1].replace('.csv','')

smiles_tasks_df = pd.read_csv(raw_filename)
smilesList = smiles_tasks_df.smiles.values
print("number of all smiles: ",len(smilesList))

# Parse every SMILES once; keep only those RDKit can turn into a molecule.
atom_num_dist = []          # atom count of each parsed molecule
remained_smiles = []        # original SMILES strings that parsed
canonical_smiles_list = []  # canonical (isomeric) SMILES, aligned with remained_smiles
for smiles in smilesList:
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Best-effort, like the original bare `except:`, but without swallowing
        # KeyboardInterrupt/SystemExit. Presumably reached for non-string
        # entries such as NaN -- TODO confirm.
        mol = None
    if mol is None:
        # Unparseable entry: report it and skip.
        print(smiles)
        continue
    atom_num_dist.append(len(mol.GetAtoms()))
    remained_smiles.append(smiles)
    # Reuse the molecule parsed above instead of parsing the string a second time.
    canonical_smiles_list.append(Chem.MolToSmiles(mol, isomericSmiles=True))
print("number of successfully processed smiles: ", len(remained_smiles))

# .copy() gives an independent frame so the column assignment below does not
# write into a slice of the original (avoids SettingWithCopyWarning).
smiles_tasks_df = smiles_tasks_df[smiles_tasks_df["smiles"].isin(remained_smiles)].copy()
smiles_tasks_df['cano_smiles'] = canonical_smiles_list



number of all smiles:  1128
number of successfully processed smiles:  1128


In [6]:

# Load cached molecular features if present, otherwise compute them.
if os.path.isfile(feature_filename):
    # NOTE(security): pickle.load can execute arbitrary code; only load a cache
    # file that was produced locally by this notebook.
    # The context manager closes the file handle, which the original
    # `pickle.load(open(...))` left open.
    with open(feature_filename, "rb") as feature_file:
        feature_dicts = pickle.load(feature_file)
else:
    feature_dicts = save_smiles_dicts(smilesList,filename)

# Keep only rows whose canonical SMILES has an entry in the feature dictionaries.
remained_df = smiles_tasks_df[smiles_tasks_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = smiles_tasks_df.drop(remained_df.index)
print("not processed items")
uncovered_df

not processed items


Unnamed: 0,Compound ID,ESOL predicted log solubility in mols per litre,Minimum Degree,Molecular Weight,Number of H-Bond Donors,Number of Rings,Number of Rotatable Bonds,Polar Surface Area,measured log solubility in mols per litre,smiles,cano_smiles
934,Methane,-0.636,0,16.043,0,0,0,0.0,-0.9,C,C


In [17]:
# --- Train / validation / test split ----------------------------------------
# 1/10 of all rows go to test; 1/9 of the remaining rows (i.e. another ~1/10 of
# the total) go to validation; the rest is the training set.
remained_df = remained_df.reset_index(drop=True)
test_df = remained_df.sample(frac=1/10, random_state=random_seed)
training_data = remained_df.drop(test_df.index)

valid_df = training_data.sample(frac=1/9, random_state=random_seed)
train_df = training_data.drop(valid_df.index)

train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

# The three subsets must partition the usable rows.
assert len(train_df) + len(valid_df) + len(test_df) == len(remained_df)

# --- Featurised arrays for each subset ---------------------------------------
x_atom_train, x_bonds_train, x_atom_index_train, x_bond_index_train, x_mask_train, smiles_to_rdkit_list_train = get_smiles_array(train_df['cano_smiles'], feature_dicts)
x_atom_val, x_bonds_val, x_atom_index_val, x_bond_index_val, x_mask_val, smiles_to_rdkit_list_val = get_smiles_array(valid_df['cano_smiles'], feature_dicts)
x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test, smiles_to_rdkit_list_test = get_smiles_array(test_df['cano_smiles'], feature_dicts)

# Neighbour index tables are cast to int32 before being fed to the model.
x_atom_index_train = tf.cast(x_atom_index_train, tf.int32)
x_atom_index_val = tf.cast(x_atom_index_val, tf.int32)
x_atom_index_test = tf.cast(x_atom_index_test, tf.int32)

x_bond_index_train = tf.cast(x_bond_index_train, tf.int32)
x_bond_index_val = tf.cast(x_bond_index_val, tf.int32)
x_bond_index_test = tf.cast(x_bond_index_test, tf.int32)

# --- Targets as (n_samples, 1) NumPy arrays ----------------------------------
y_train = train_df[tasks[0]].values.reshape(-1, 1)
y_val = valid_df[tasks[0]].values.reshape(-1, 1)
y_test = test_df[tasks[0]].values.reshape(-1, 1)

num_atom_features = x_atom_train.shape[-1]
num_bond_features = x_bonds_train.shape[-1]
print(num_atom_features,num_bond_features)

# --- Model --------------------------------------------------------------------
# Seed TensorFlow's global RNG before creating the model so that random
# initialisation is repeatable from a fresh kernel. `random_seed` was previously
# only used for the pandas split. Bit-exact results on GPU backends are not
# guaranteed.
tf.random.set_seed(random_seed)
model = Fingerprint(radius, T, num_atom_features, num_bond_features, fingerprint_dim, output_units_num, p_dropout)

# One forward pass over the training tensors so the model's weights exist before
# compile()/summary().
mol_prediction = model([x_atom_train, x_bonds_train, x_atom_index_train, x_bond_index_train, x_mask_train])

# `learning_rate` is a negative base-10 exponent (2.5 -> 10**-2.5).
# NOTE(review): the custom training loop later in this notebook applies
# gradients with a separately created AdamW `optimizer`, not with the Adam
# instance compiled here -- confirm which one is intended.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=10**-learning_rate),
              loss='mse',
              metrics=['mae'])
model.summary()

39 10




Model: "fingerprint_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 atom_fc (Dense)             multiple                  7680      
                                                                 
 neighbor_fc (Dense)         multiple                  9600      
                                                                 
 atomgru_0 (GRUCell)         multiple                  222336    
                                                                 
 atomgru_1 (GRUCell)         multiple                  222336    
                                                                 
 align_0 (Dense)             multiple                  385       
                                                                 
 align_1 (Dense)             multiple                  385       
                                                                 
 attend_0 (Dense)            multiple                

In [18]:
# Prepare an optimizer
optimizer = tf.keras.optimizers.AdamW(learning_rate=10**-learning_rate, weight_decay=10**-weight_decay)

# Define a loss function (mean squared error in this case)
loss_fn = tf.keras.losses.MeanSquaredError()
train_loss = tf.keras.metrics.Mean(name="train_loss")  # Track training loss
val_loss = tf.keras.metrics.Mean(name="val_loss")      # Track validation loss
train_mse = tf.keras.metrics.MeanSquaredError(name="train_mse")  # Track RMSE
val_mse = tf.keras.metrics.MeanSquaredError(name="val_mse")      # Track RMSE
test_mse = tf.keras.metrics.MeanSquaredError(name="test_mse")    # Track RMSE

# A single training step function
@tf.function
def train_step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_true):
    """Run one optimisation step on a single batch.

    Performs a forward pass of the module-level `model` in training mode,
    computes `loss_fn(y_true, prediction)`, applies the gradients with the
    module-level `optimizer`, and updates the `train_loss` / `train_mse`
    metrics. Returns nothing.
    """
    batch_inputs = [x_atom, x_bonds, x_atom_index, x_bond_index, x_mask]

    # Record the forward pass and loss so gradients can be taken afterwards.
    with tf.GradientTape() as grad_tape:
        predictions = model(batch_inputs, training=True, doprint=False)
        batch_loss = loss_fn(y_true, predictions)

    # Back-propagate and update the weights.
    weights = model.trainable_variables
    grads = grad_tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    # Accumulate epoch-level statistics.
    train_loss.update_state(batch_loss)
    train_mse.update_state(y_true, predictions)

# A single validation step function
@tf.function
def val_step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_true):
    """Evaluate one validation batch and accumulate validation metrics.

    Runs the module-level `model` with ``training=False``, computes
    `loss_fn(y_true, prediction)` and updates `val_loss` / `val_mse`.
    No gradients are computed and no weights are changed. Returns nothing.
    """
    # The former debug `print('val')` was removed: inside a tf.function a Python
    # print executes only while the function is being traced, not on every call.
    mol_prediction = model([x_atom, x_bonds, x_atom_index, x_bond_index, x_mask], training=False, doprint=False)
    loss = loss_fn(y_true, mol_prediction)

    val_loss.update_state(loss)
    val_mse.update_state(y_true, mol_prediction)

def test_model(test_data):
    """Print the RMSE of the module-level `model` over `test_data`.

    Args:
        test_data: iterable yielding
            ``((x_atom, x_bonds, x_atom_index, x_bond_index, x_mask), y_true)``
            batches.

    The model is evaluated with whatever weights it currently holds; this
    function does not load any checkpoint itself. Returns nothing.
    """
    # `reset_state()` replaces the deprecated `reset_states()` spelling.
    test_mse.reset_state()

    for (x_atom, x_bonds, x_atom_index, x_bond_index, x_mask), y_true in test_data:
        # Forward pass in inference mode.
        mol_prediction = model([x_atom, x_bonds, x_atom_index, x_bond_index, x_mask], training=False)
        test_mse.update_state(y_true, mol_prediction)

    # The metric accumulates MSE; take the square root to report RMSE.
    print(f'Test RMSE: {np.sqrt(test_mse.result().numpy())}')

# Now you can call the training function
# train_model(train_dataset, val_dataset, epochs=80, callback=cosine_annealing_callback)




In [19]:
# Training loop
def train_model(train_data, val_data, epochs, callbacks=None):
    """Run the custom training loop and checkpoint the best-validation weights.

    Args:
        train_data: iterable yielding
            ``((x_atom, x_bonds, x_atom_index, x_bond_index, x_mask), y_true)``
            batches used for optimisation.
        val_data: iterable with the same structure, used for validation.
        epochs: number of passes over `train_data`.
        callbacks: optional list of callback objects. Each one gets its
            `.model` attribute set to the module-level `model` and has
            `on_train_begin()`, `on_epoch_begin(epoch)`, `on_epoch_end(epoch)`
            and `on_train_end()` invoked (no `logs` are passed).

    Uses the module-level `model`, `train_step`, `val_step`, the
    `train_*` / `val_*` metrics and `checkpoint_filepath`. Whenever the
    validation MSE improves, the weights are written to `checkpoint_filepath`.
    Returns nothing.
    """
    callbacks = callbacks or []
    best_val_mse = float('inf')  # lowest validation MSE seen so far

    # Make sure the checkpoint directory exists before the first save.
    checkpoint_dir = os.path.dirname(checkpoint_filepath)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # NOTE(review): callbacks only receive `model`. If a callback adjusts
    # `model.optimizer`, that is the optimizer given to `model.compile`, not the
    # module-level `optimizer` that `train_step` applies gradients with --
    # confirm the learning-rate schedule reaches the right optimizer.
    for callback in callbacks:
        callback.model = model
        callback.on_train_begin()

    for epoch in range(epochs):
        start = time.time()

        # Start every epoch with fresh metric accumulators.
        # `reset_state()` replaces the deprecated `reset_states()` spelling.
        for metric in (train_loss, train_mse, val_loss, val_mse):
            metric.reset_state()

        for callback in callbacks:
            callback.on_epoch_begin(epoch)

        # Optimisation pass.
        for (x_atom, x_bonds, x_atom_index, x_bond_index, x_mask), y_true in train_data:
            train_step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_true)

        # Validation pass.
        for (x_atom, x_bonds, x_atom_index, x_bond_index, x_mask), y_true in val_data:
            val_step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_true)

        current_val_mse = val_mse.result().numpy()
        for callback in callbacks:
            callback.on_epoch_end(epoch)

        # Metrics hold MSE; RMSE is derived here for display only.
        print(f'Epoch {epoch+1}, '
              f'Train Loss: {train_loss.result()}, '
              f'Train RMSE: {np.sqrt(train_mse.result().numpy())}, '
              f'Val RMSE: {np.sqrt(val_mse.result().numpy())}',
              f'time: {time.time()-start}')

        # Keep the weights with the lowest validation MSE.
        if current_val_mse < best_val_mse:
            print(f"Validation MSE improved from {best_val_mse} to {current_val_mse}. Saving model.")
            best_val_mse = current_val_mse
            model.save_weights(checkpoint_filepath)

    for callback in callbacks:
        callback.on_train_end()

In [20]:
# Convert your data into batched datasets

# Training batches are reshuffled at the start of every epoch; without this the
# model sees the exact same batches in the same order each epoch. The buffer
# covers the whole training set, and the seed keeps the shuffle order tied to
# `random_seed`.
train_data = (
    tf.data.Dataset.from_tensor_slices(((x_atom_train, x_bonds_train, x_atom_index_train, x_bond_index_train, x_mask_train), y_train))
    .shuffle(buffer_size=len(y_train), seed=random_seed, reshuffle_each_iteration=True)
    .batch(batch_size)
)
# Validation and test sets are only evaluated, so they keep their original order.
val_data = tf.data.Dataset.from_tensor_slices(((x_atom_val, x_bonds_val, x_atom_index_val, x_bond_index_val, x_mask_val), y_val)).batch(batch_size)
test_data = tf.data.Dataset.from_tensor_slices(((x_atom_test, x_bonds_test, x_atom_index_test, x_bond_index_test, x_mask_test), y_test)).batch(batch_size)

# Where train_model() stores the weights with the lowest validation MSE.
checkpoint_filepath = './best/best_model.h5'

# Keras ModelCheckpoint configured to monitor 'val_rmse' (lower is better).
# NOTE: it is not part of `callbacks` below; train_model() does its own
# best-weights checkpointing and passes no logs to callbacks.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_rmse',
    save_best_only=True,  # save only when the monitored value improves
    mode='min',           # smaller monitored value is better
    verbose=1
)

# Cosine-annealing learning-rate schedule with restarts.
cosine_annealing_callback = CosineAnnealingLR_with_Restart(
    T_max=10,  # Adjust T_max based on your needs
    T_mult=1,  # Adjust T_mult based on your needs
    eta_min=1e-6,  # Minimum learning rate
    verbose=1,
    lr_reduction_factor=0.95,  # Reduce learning rate at restarts
    out_dir="./model_snapshots"  # Directory to save model snapshots
)

callbacks = [
    #checkpoint_callback,
    cosine_annealing_callback
]

In [None]:
# Train for `epochs` epochs with the configured callbacks; train_model() prints
# per-epoch metrics and saves the weights whenever the validation MSE improves.
train_model(train_data, val_data, epochs=epochs, callbacks=callbacks)



Epoch 00001: CosineAnnealing lr to 0.003162277629598975.
























val
Epoch 1, Train Loss: 9.505282402038574, Train RMSE: 3.135585069656372, Val RMSE: 2.7220170497894287 time: 4.114591836929321
Validation MSE improved from inf to 7.409376621246338. Saving model.

Epoch 00002: CosineAnnealing lr to 0.0030849156595235887.
Epoch 2, Train Loss: 4.54384708404541, Train RMSE: 2.1706650257110596, Val RMSE: 1.8193855285644531 time: 0.519885778427124
Validation MSE improved from 7.409376621246338 to 3.3101634979248047. Saving model.

Epoch 00003: CosineAnnealing lr to 0.0028604024779409483.
Epoch 3, Train Loss: 2.7483673095703125, Train RMSE: 1.6658304929733276, Val RMSE: 1.6147785186767578 time: 0.4812319278717041
Validation MSE improved from 3.3101634979248047 to 2.6075098514556885. Saving model.

Epoch 00004: CosineAnnealing lr to 0.0025107149993396803.
Epoch 4, Train Loss: 2.1879830360412598, Train RMSE: 1.4847526550292969, Val RMSE: 1.3772107362747192 time: 0.473297119140625
Validation MSE improved from 2.6075098514556885 to 1.8967094421386719. Saving mo

In [None]:
# train_model() saves the weights with the lowest validation MSE to
# `checkpoint_filepath`, but `model` still holds the last-epoch weights.
# Restore the best checkpoint (when one exists) so the reported test RMSE
# belongs to the model that was selected on the validation set.
if os.path.isfile(checkpoint_filepath):
    model.load_weights(checkpoint_filepath)
test_model(test_data)
