In [26]:
import os
import mlx.core as mx
import math
import mlx.optimizers as optim
import scipy as sp
import numpy as np
import pandas as pd
from mlx_graphs.data import GraphData
from mlx_graphs.datasets.dataset import Dataset
from mlx_graphs.datasets.utils import download
from mlx_graphs.utils.transformations import to_sparse_adjacency_matrix
from typing import Tuple
from typing import Optional
from rdkit import Chem
from rdkit.Chem import Lipinski
from rdkit.Chem import rdMolDescriptors
import mlx.optimizers as optim
from mlx_graphs.loaders import Dataloader
import mlx.nn as nn
from mlx_graphs.nn import GINConv, global_mean_pool, global_max_pool, Linear
import time
from attfp import AttentiveFP
import matplotlib.pyplot as plt
import pickle
from AttentiveFP import save_smiles_dicts, get_smiles_dicts, get_smiles_array

import cProfile
import pstats
import io

random_seed = 108 
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

batch_size = 32
epochs = 200

p_dropout= 0.2
fingerprint_dim = 10

weight_decay = 5 # also known as l2_regularization_lambda
learning_rate = 2.5
output_units_num = 1 # for regression model
radius = 2
T = 2

task_name = 'solubility'
tasks = ['measured log solubility in mols per litre']

raw_filename = "delaney-processed.csv"
feature_filename = raw_filename.replace('.csv','.pickle')
filename = raw_filename.replace('.csv','')
prefix_filename = raw_filename.split('/')[-1].replace('.csv','')
smiles_tasks_df = pd.read_csv(raw_filename)
smilesList = smiles_tasks_df.smiles.values
print("number of all smiles: ",len(smilesList))
atom_num_dist = []
remained_smiles = []
canonical_smiles_list = []
for smiles in smilesList:
    try:        
        mol = Chem.MolFromSmiles(smiles)
        atom_num_dist.append(len(mol.GetAtoms()))
        remained_smiles.append(smiles)
        canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
    except:
        print(smiles)
        pass
print("number of successfully processed smiles: ", len(remained_smiles))
smiles_tasks_df = smiles_tasks_df[smiles_tasks_df["smiles"].isin(remained_smiles)]
# print(smiles_tasks_df)
smiles_tasks_df['cano_smiles'] =canonical_smiles_list



if os.path.isfile(feature_filename):
    feature_dicts = pickle.load(open(feature_filename, "rb" ))
else:
    feature_dicts = save_smiles_dicts(smilesList,filename)
# feature_dicts = get_smiles_dicts(smilesList)
remained_df = smiles_tasks_df[smiles_tasks_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = smiles_tasks_df.drop(remained_df.index)
print("not processed items")
uncovered_df

remained_df = remained_df.reset_index(drop=True)
test_df = remained_df.sample(frac=1/10, random_state=random_seed) # test set
training_data = remained_df.drop(test_df.index) # training data

# training data is further divided into validation set and train set
valid_df = training_data.sample(frac=1/9, random_state=random_seed) # validation set
train_df = training_data.drop(valid_df.index) # train set
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

# print(len(test_df),sorted(test_df.cano_smiles.values))



class GRUCell(nn.Module):
    """A GRU Cell that returns the final hidden state only."""
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        scale = 1.0 / math.sqrt(hidden_size)
        self.Wx = mx.random.uniform(
            low=-scale, high=scale, shape=(3 * hidden_size, input_size)
        )
        self.Wh = mx.random.uniform(
            low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)
        )
        self.b = (
            mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
            if bias
            else None
        )
        self.bhn = (
            mx.random.uniform(low=-scale, high=scale, shape=(hidden_size,))
            if bias
            else None
        )

    def __call__(self, x, hidden=None):
        if self.b is not None:
            x = mx.addmm(self.b, x, self.Wx.T)
        else:
            x = x @ self.Wx.T

        x_rz = x[..., : -self.hidden_size]
        x_n = x[..., -self.hidden_size :]

        for idx in range(x.shape[-2]):
            rz = x_rz[..., idx, :]
            if hidden is not None:
                h_proj = hidden @ self.Wh.T
                h_proj_rz = h_proj[..., : -self.hidden_size]
                h_proj_n = h_proj[..., -self.hidden_size :]

                if self.bhn is not None:
                    h_proj_n += self.bhn

                rz = rz + h_proj_rz

            rz = mx.sigmoid(rz)

            r, z = mx.split(rz, 2, axis=-1)

            n = x_n[..., idx, :]

            if hidden is not None:
                n = n + r * h_proj_n
            n = mx.tanh(n)

            if hidden is not None:
                hidden = (1 - z) * n + z * hidden
            else:
                hidden = (1 - z) * n

        return hidden

class AttFP(nn.Module):
    def __init__(self, radius, T, input_feature_dim, input_bond_dim, fingerprint_dim, output_units_num, p_dropout=0.1):
        super(AttFP, self).__init__()
        
        self.atom_fc =  nn.Linear(input_feature_dim, fingerprint_dim)
        self.neighbor_fc =  nn.Linear(input_feature_dim + input_bond_dim, fingerprint_dim)
        self.GRUCell = [GRUCell(fingerprint_dim, fingerprint_dim) for r in range(radius)]
        self.align = [nn.Linear(2 * fingerprint_dim, 1) for r in range(radius)]
        self.attend = [nn.Linear(fingerprint_dim, fingerprint_dim) for r in range(radius)]
        
        self.molGRU =  GRUCell(fingerprint_dim, fingerprint_dim)
        self.mol_align = nn.Linear(2 * fingerprint_dim, 1)
        self.mol_attend = nn.Linear(fingerprint_dim, fingerprint_dim)

        self.dropout = nn.Dropout(p=p_dropout)
        self.output = nn.Linear(fingerprint_dim, output_units_num)

        self.radius = radius
        self.T = T


    def __call__(self, atom_list, bond_list, atom_degree_list, bond_degree_list, atom_mask):
  
        atom_mask = atom_mask[:,:,None]
        batch_size, mol_length, num_atom_feat = atom_list.shape
        
        atom_feature = nn.leaky_relu(self.atom_fc(mx.array(atom_list)))
        
        bond_neighbor = [mx.array(bond_list[i][bond_degree_list[i]]) for i in range(batch_size)]
        bond_neighbor = mx.stack(bond_neighbor, axis=0)
                
        atom_neighbor = [mx.array(atom_list[i][atom_degree_list[i]]) for i in range(batch_size)]
        atom_neighbor = mx.stack(atom_neighbor, axis=0)
                
        neighbor_feature = mx.concatenate([atom_neighbor, bond_neighbor], axis=-1)
        neighbor_feature = nn.leaky_relu(self.neighbor_fc(neighbor_feature))
                
        # Generate mask to eliminate the influence of blank atoms
        attend_mask = mx.array(atom_degree_list)
        attend_mask = mx.where(attend_mask == mol_length - 1, mx.array(0.0), mx.array(1.0))
        attend_mask = attend_mask[:,:,:,None]
        
        softmax_mask = mx.array(atom_degree_list)
        softmax_mask = mx.where(softmax_mask == mol_length - 1,  mx.array(float("-inf")),mx.array(0.0))
        softmax_mask =  mx.expand_dims(softmax_mask, 3)
        
        batch_size, mol_length, max_neighbor_num, fingerprint_dim = neighbor_feature.shape
                
        atom_feature_expand = mx.expand_dims(atom_feature, 2) #[:,:,None,:]
        atom_feature_expand = mx.repeat(atom_feature_expand,  max_neighbor_num, axis=2)
        feature_align = mx.concatenate([atom_feature_expand, neighbor_feature], axis=-1)
        
        align_score = nn.leaky_relu(self.align[0](self.dropout(feature_align)))
        align_score = align_score + softmax_mask  # Ensure both tensors are on the same device
        attention_weight = nn.softmax(align_score, -2)
        attention_weight = attention_weight * attend_mask
        
        neighbor_feature_transform = self.attend[0](self.dropout(neighbor_feature))
        context = mx.sum(attention_weight * neighbor_feature_transform, axis=-2)
        context = nn.elu(context)
        
        context_reshape = mx.reshape(context,(batch_size * mol_length, fingerprint_dim))
        atom_feature_reshape = mx.reshape(atom_feature,(batch_size * mol_length, fingerprint_dim))
        
        atom_feature_reshape = self.GRUCell[0](context_reshape, atom_feature_reshape)
        atom_feature = mx.reshape(atom_feature_reshape,(batch_size, mol_length, fingerprint_dim))
        
        activated_features = nn.relu(atom_feature)
        
        for d in range(1,self.radius):
        
            atom_degree_list_mx = mx.array(atom_degree_list)
            neighbor_feature = [activated_features[i][atom_degree_list_mx[i]] for i in range(batch_size)]
            neighbor_feature = mx.stack(neighbor_feature, axis=0)
        
            atom_feature_expand = mx.expand_dims(activated_features, 2) #[:,:,None,:]
            atom_feature_expand = mx.repeat(atom_feature_expand,  max_neighbor_num, axis=2)
            feature_align = mx.concatenate([atom_feature_expand, neighbor_feature], axis=-1)
            
            align_score = nn.leaky_relu(self.align[d](self.dropout(feature_align)))
            align_score = align_score + softmax_mask  # Ensure both tensors are on the same device
            attention_weight = nn.softmax(align_score, -2)
            attention_weight = attention_weight * attend_mask
            
            neighbor_feature_transform = self.attend[d](self.dropout(neighbor_feature))
            context = mx.sum(attention_weight * neighbor_feature_transform, axis=-2)
            context = nn.elu(context)
            context_reshape = mx.reshape(context,(batch_size * mol_length, fingerprint_dim))
            
            atom_feature_reshape = self.GRUCell[d](context_reshape, atom_feature_reshape)
            atom_feature = mx.reshape(atom_feature_reshape,(batch_size, mol_length, fingerprint_dim))
            activated_features = nn.relu(atom_feature)
        
        mol_feature = mx.sum(activated_features * atom_mask, axis=-2)
        
        activated_features_mol = nn.relu(mol_feature)
        
        mol_softmax_mask = mx.array(atom_mask)
        mol_softmax_mask = mx.where(mol_softmax_mask == 1,  mx.array(0.0), mx.array(float("-inf")))  
        # this one is strange 
        #mol_softmax_mask[mol_softmax_mask == 0] = -9e8
        #mol_softmax_mask[mol_softmax_mask == 1] = 0
        
        for t in range(self.T):
            mol_prediction_expand = mx.expand_dims(activated_features_mol, 1) 
            mol_prediction_expand = mx.repeat(mol_prediction_expand,  mol_length, axis=1)
            
            mol_align = mx.concatenate([mol_prediction_expand, activated_features], axis=-1)
        
             
            mol_align_score = nn.leaky_relu(self.mol_align(mol_align))
            mol_align_score = mol_align_score + mol_softmax_mask  # Ensure both tensors are on the same device
            mol_attention_weight = nn.softmax(mol_align_score, -2)
            mol_attention_weight = mol_attention_weight * atom_mask
            
            activated_feature_transform = self.mol_attend(self.dropout(activated_features))
            mol_context = mx.sum(mol_attention_weight * activated_feature_transform, axis=-2)
            mol_context = nn.elu(mol_context)    
            mol_feature = self.molGRU(mol_context, mol_feature)
            activated_features_mol = nn.relu(mol_feature)
        
        mol_prediction = self.output(self.dropout(mol_feature))

        return atom_feature, mol_prediction


device = mx.gpu # or mx.cpu
mx.set_default_device(device)




x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([canonical_smiles_list[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
model = AttFP(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)

def loss_fn(y_hat, y, parameters=None):
    if len(y_hat.shape) != len(y.shape):
        y = mx.expand_dims(y,1)
    return mx.mean(nn.losses.mse_loss(y_hat, y))

def forward_fn(model, x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, labels):
    _, y_hat = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
    loss = loss_fn(y_hat, labels, model.parameters())
    return loss, y_hat
        
def train(dataset,e):
     # Profiler setup
    pr = cProfile.Profile()
    pr.enable()  # Start profiling

    loss_sum = 0.0
    np.random.seed(e)
    valList = np.arange(0,dataset.shape[0])
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)   
    print(len(batch_list))
    pr = cProfile.Profile()
    pr.enable()  # Start profiling
    for counter, train_batch in enumerate(batch_list):
        batch_df = dataset.loc[train_batch,:]
        smiles_list = batch_df.cano_smiles.values
        y_val = mx.array(batch_df[tasks[0]].values)
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        
        (loss, y_hat), grads = loss_and_grad_fn(
            model=model,
            x_atom=x_atom,
            x_bonds=x_bonds,
            x_atom_index=x_atom_index,
            x_bond_index=x_bond_index,
            x_mask=x_mask,
            labels=y_val,
        )
        
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        loss_sum += loss.item()
        print(counter)
    pr.disable()

    # Print profiling results
    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats(pstats.SortKey.CUMULATIVE)
    ps.print_stats(200)  # Print top 10 results
    print(s.getvalue())    

    return loss_sum / len(dataset)


def train1(dataset,e):

    loss_sum = 0.0
    np.random.seed(e)
    valList = np.arange(0,dataset.shape[0])
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)   
    print(len(batch_list))
    train_batch =       batch_list[0]

    batch_df = dataset.loc[train_batch,:]
    smiles_list = batch_df.cano_smiles.values
    y_val = mx.array(batch_df[tasks[0]].values)
    x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
    # Profiler setup
    pr = cProfile.Profile()
    pr.enable()  # Start profiling

    (loss, y_hat), grads = loss_and_grad_fn(
        model=model,
        x_atom=x_atom,
        x_bonds=x_bonds,
        x_atom_index=x_atom_index,
        x_bond_index=x_bond_index,
        x_mask=x_mask,
        labels=y_val,
    )
    
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    # Stop profiling
    pr.disable()

    # Print profiling results
    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats(pstats.SortKey.CUMULATIVE)
    ps.print_stats(200)  # Print top 10 results
    print(s.getvalue())    
    return loss_sum / len(dataset)


def test(test_dataset):
    mse= 0.0
    valList = np.arange(0,test_dataset.shape[0])
    batch_list = []
    for i in range(0, test_dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
        
    for counter, test_batch in enumerate(batch_list):
        batch_df = test_dataset.loc[test_batch,:]
        smiles_list = batch_df.cano_smiles.values
        y_val = mx.array(batch_df[tasks[0]].values)
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, _ = get_smiles_array(smiles_list,feature_dicts)
        _, y_hat = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
       
        mse += mx.square(y_hat - y_val).sum().item()
        
    val =  mse / len(test_dataset)
    return val, np.sqrt(val)


def epoch(e):
    loss = train(train_df,e)
    train_mse, train_rmse = test(train_df)
    test_mse, test_rmse = test(test_df)
    return loss, train_mse, train_rmse, test_mse, test_rmse


def epoch1(e):
    loss = train1(train_df,e)

mx.eval(model.parameters())
optimizer = optim.AdamW(10**-learning_rate, weight_decay=10**-weight_decay)
loss_and_grad_fn = nn.value_and_grad(model, forward_fn)



epochs = 1
r = []
best_test_mse = 1e9
for e in range(epochs):
    loss, train_mse, train_rmse, test_mse, test_rmse = epoch(e)
    best_test_mse = min(best_test_mse, test_mse)
    r.append((train_rmse,test_rmse))
    print(
        " | ".join(
            [
                f"Epoch: {e:3d}",
                f"Train loss: {loss:.3f}",
                f"Train mse: {train_mse:.3f}",
                f"Train rmse: {train_rmse:.3f}",
                f"Test mse: {test_mse:.3f}",
                f"Test rmse: {test_rmse:.3f}",
                f"LR: {np.array(optimizer.learning_rate)}",

            ]
        )
    )
print(f"\n==> Best test mse: {best_test_mse:.3f},  rmse: {np.sqrt(best_test_mse):.3f}")

epoch1(e)

train(train_df,0)

Exception ignored When destroying _lsprof profiler:
Traceback (most recent call last):
  File "/var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py", line 331, in train
RuntimeError: Cannot install a profile function while another profile function is being installed


number of all smiles:  1128
number of successfully processed smiles:  1128
not processed items
29
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
         698131 function calls (691134 primitive calls) in 2.779 seconds

   Ordered by: cumulative time
   List reduced from 266 to 200 due to restriction <200>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       29    1.736    0.060    2.738    0.094 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/mlx/nn/utils.py:33(wrapped_value_grad_fn)
       29    0.000    0.000    0.999    0.034 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/mlx/nn/utils.py:27(inner_fn)
       29    0.000    0.000    0.999    0.034 /var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py:310(forward_fn)
       29    0.005    0.000    0.996    0.034 /var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py:188(__call__)

Exception ignored When destroying _lsprof profiler:
Traceback (most recent call last):
  File "/var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py", line 331, in train
RuntimeError: Cannot install a profile function while another profile function is being installed


         23894 function calls (23669 primitive calls) in 0.101 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.064    0.064    0.101    0.101 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/mlx/nn/utils.py:33(wrapped_value_grad_fn)
        1    0.000    0.000    0.036    0.036 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/mlx/nn/utils.py:27(inner_fn)
        1    0.000    0.000    0.036    0.036 /var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py:310(forward_fn)
        1    0.000    0.000    0.036    0.036 /var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py:188(__call__)
        4    0.034    0.008    0.036    0.009 /var/folders/4w/xmf8nmhs51j4vjsttzcxmmm00000gn/T/ipykernel_44750/3573216986.py:129(__call__)
    10978    0.001    0.000    0.002    0.000 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib

nan