In [1]:
import os
import mlx.core as mx
import math
import mlx.optimizers as optim
import mlx.nn as nn
import scipy as sp
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
import pickle
from featurer import save_smiles_dicts, get_smiles_dicts, get_smiles_array
from attfp_mlx_utils import AttFP, cosineannealingwarmrestartfactor
import psutil
import cProfile
import pstats
import io
from rdkit import Chem

In [2]:
random_seed = 108 
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

batch_size = 200
epochs = 200
p_dropout= 0.05
fingerprint_dim = 192

weight_decay = 5 # also known as l2_regularization_lambda
learning_rate = 2.5
output_units_num = 1 # for regression model
radius = 2
T = 2

task_name = 'solubility'
tasks = ['measured log solubility in mols per litre']
raw_filename = "delaney-processed.csv"

In [3]:
feature_filename = raw_filename.replace('.csv','.pickle')
filename = raw_filename.replace('.csv','')
prefix_filename = raw_filename.split('/')[-1].replace('.csv','')
smiles_tasks_df = pd.read_csv(raw_filename)
smilesList = smiles_tasks_df.smiles.values
print("number of all smiles: ",len(smilesList))
atom_num_dist = []
remained_smiles = []
canonical_smiles_list = []
for smiles in smilesList:
    try:        
        mol = Chem.MolFromSmiles(smiles)
        atom_num_dist.append(len(mol.GetAtoms()))
        remained_smiles.append(smiles)
        canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
    except:
        print(smiles)
        pass
print("number of successfully processed smiles: ", len(remained_smiles))
smiles_tasks_df = smiles_tasks_df[smiles_tasks_df["smiles"].isin(remained_smiles)]
# print(smiles_tasks_df)
smiles_tasks_df['cano_smiles'] =canonical_smiles_list

if os.path.isfile(feature_filename):
    feature_dicts = pickle.load(open(feature_filename, "rb" ))
else:
    feature_dicts = save_smiles_dicts(smilesList,filename)
# feature_dicts = get_smiles_dicts(smilesList)
remained_df = smiles_tasks_df[smiles_tasks_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = smiles_tasks_df.drop(remained_df.index)
print("not processed items")
uncovered_df

remained_df = remained_df.reset_index(drop=True)
test_df = remained_df.sample(frac=1/10, random_state=random_seed) # test set
training_data = remained_df.drop(test_df.index) # training data

# training data is further divided into validation set and train set
valid_df = training_data.sample(frac=1/9, random_state=random_seed) # validation set
train_df = training_data.drop(valid_df.index) # train set
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

# print(len(test_df),sorted(test_df.cano_smiles.values))

number of all smiles:  1128
number of successfully processed smiles:  1128
not processed items


In [5]:
device = mx.gpu # or mx.cpu
mx.set_default_device(device)

x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([canonical_smiles_list[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
x_atom = mx.array(x_atom)
x_bonds = mx.array(x_bonds)
x_atom_index = mx.array(x_atom_index)
x_bond_index = mx.array(x_bond_index)
x_mask = mx.array(x_mask)

model = AttFP(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)

# Example: Dynamically create learning rate schedules based on milestones
initial_lr = 10**-learning_rate
restarts = 20
decay_step = 10*5  # Decay steps for each cosine and warmup phase
warmup_factor = 0.95  # Warmup reduction factors

lr_schedule = cosineannealingwarmrestartfactor(initial_lr, restarts, decay_step, warmup_factor)

optimizer = optim.AdamW(learning_rate=lr_schedule, weight_decay=1**-weight_decay)


from functools import partial

def loss_fn(y_hat, y):
    y = mx.reshape(y, y_hat.shape)
    return mx.mean(nn.losses.mse_loss(y_hat, y))


state = [model, optimizer.state, mx.random.state]

def forward_fn(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    _, y_hat = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
    loss = loss_fn(y_hat, labels)
    return loss, y_hat

@partial(mx.compile, inputs=state, outputs=state)
def step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    loss_and_grad_fn = nn.value_and_grad(model, forward_fn)
    (loss, y_hat), grads = loss_and_grad_fn(
        x_atom=x_atom,
        x_bonds=x_bonds,
        x_atom_index=x_atom_index,
        x_bond_index=x_bond_index,
        x_mask=x_mask,
        labels=labels,
    )
    optimizer.update(model, grads)
    return loss
        
def train(dataset,e, batch_size=64, doprofile=False):
    print('train')
    loss_sum = 0.0
    np.random.seed(e)
    valList = np.arange(0,dataset.shape[0])
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)   
    #print('iter per batch:',len(batch_list))
    if doprofile:
        pr = cProfile.Profile()
        pr.enable()  # Start profiling

    for counter, train_batch in enumerate(batch_list):
        batch_df = dataset.loc[train_batch,:]
        smiles_list = batch_df.cano_smiles.values
        y_val = mx.array(batch_df[tasks[0]].values)
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        x_atom = mx.array(x_atom)
        x_bonds = mx.array(x_bonds)
        x_atom_index = mx.array(x_atom_index)
        x_bond_index = mx.array(x_bond_index)
        x_mask = mx.array(x_mask)

        loss = step(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, y_val)
        mx.eval(state)
        
    if doprofile:
        pr.disable()
        # Print profiling results
        s = io.StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats(pstats.SortKey.CUMULATIVE)
        ps.print_stats(200)  # Print top 10 results
        print(s.getvalue())    
    return loss_sum / len(dataset)

def test(test_dataset, batch_size=64):
    mse= 0.0
    valList = np.arange(0,test_dataset.shape[0])
    batch_list = []
    for i in range(0, test_dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
    
    model.eval()
        
    for counter, test_batch in enumerate(batch_list):
        batch_df = test_dataset.loc[test_batch,:]
        smiles_list = batch_df.cano_smiles.values
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, _ = get_smiles_array(smiles_list,feature_dicts)
        x_atom = mx.array(x_atom)
        x_bonds = mx.array(x_bonds)
        x_atom_index = mx.array(x_atom_index)
        x_bond_index = mx.array(x_bond_index)
        x_mask = mx.array(x_mask)
        y_val = mx.array(batch_df[tasks[0]].values)
        _, y_hat = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
        y_val = mx.reshape(y_val, y_hat.shape)
        mse += mx.square(y_hat - y_val).sum().item()
        
    val =  mse / len(test_dataset)
    model.train()
    return val, np.sqrt(val)


def epoch(e, batch_size=64):
    loss = train(train_df,e, batch_size=batch_size)
    train_mse, train_rmse = test(train_df, batch_size=2*batch_size)
    valid_mse, valid_rmse = test(valid_df, batch_size=2*batch_size)

    test_mse, test_rmse = test(test_df, batch_size=batch_size)
    return loss, train_mse, train_rmse, valid_mse, valid_rmse,  test_mse, test_rmse


r = []
best_test_mse = 1e9
print(batch_size)
for e in range(epochs):
    starttime = time.time()
    
    loss, train_mse, train_rmse, valid_mse, valid_rmse, test_mse, test_rmse = epoch(e, batch_size=batch_size)
    stoptime = time.time()

    #print('RAM memory % used:', psutil.virtual_memory()[2],'RAM Used (GB):', psutil.virtual_memory()[3]/1000000000)

    best_test_mse = min(best_test_mse, test_mse)
    r.append((train_rmse,test_rmse))
    print(
        " | ".join(
            [
                f"Epoch: {e:3d}",
                f"loss: {loss:.5f}",
                f"rmse: {train_rmse:.3f}",
                f"rmse: {valid_rmse:.3f}",
                f"rmse: {test_rmse:.3f}",
                f"LR: {np.array(optimizer.learning_rate):.6f}",
                f"Time: {stoptime-starttime}",

            ]
        )
    )
print(f"\n==> Best test mse: {best_test_mse:.3f},  rmse: {np.sqrt(best_test_mse):.3f}")



[50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950]
200
train
Epoch:   0 | loss: 0.00000 | rmse: 2.347 | rmse: 2.392 | rmse: 1.966 | LR: 0.003113 | Time: 0.7618570327758789
train
Epoch:   1 | loss: 0.00000 | rmse: 2.042 | rmse: 2.033 | rmse: 1.742 | LR: 0.002916 | Time: 0.6949000358581543
train
Epoch:   2 | loss: 0.00000 | rmse: 1.855 | rmse: 1.845 | rmse: 1.685 | LR: 0.002589 | Time: 0.6866519451141357
train
Epoch:   3 | loss: 0.00000 | rmse: 1.845 | rmse: 1.794 | rmse: 1.551 | LR: 0.002163 | Time: 0.6845080852508545
train
Epoch:   4 | loss: 0.00000 | rmse: 1.952 | rmse: 1.963 | rmse: 1.722 | LR: 0.001680 | Time: 0.6807301044464111
train
Epoch:   5 | loss: 0.00000 | rmse: 1.615 | rmse: 1.631 | rmse: 1.455 | LR: 0.001188 | Time: 0.6799459457397461
train
Epoch:   6 | loss: 0.00000 | rmse: 1.472 | rmse: 1.531 | rmse: 1.429 | LR: 0.000734 | Time: 0.6814389228820801
train
Epoch:   7 | loss: 0.00000 | rmse: 1.424 | rmse: 1.471 | rmse: 1.336 | LR: 0