In [1]:
import os
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
torch.manual_seed(8)

import time
import numpy as np
import gc
import sys
sys.setrecursionlimit(50000)
import pickle
torch.backends.cudnn.benchmark = False
torch.nn.Module.dump_patches = True
import copy
import pandas as pd
#then import my own modules
from AttentiveFP import Fingerprint,  save_smiles_dicts, get_smiles_dicts, get_smiles_array

In [2]:
from rdkit import Chem
# from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import rdMolDescriptors, MolSurf
from rdkit.Chem.Draw import SimilarityMaps
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
%matplotlib inline
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.cm as cm
import matplotlib
import seaborn as sns; sns.set_style("darkgrid")
from IPython.display import SVG, display
import itertools
from sklearn.metrics import r2_score
import scipy

In [3]:
random_seed = 108 
start_time = str(time.ctime()).replace(':','-').replace(' ','_')

batch_size = 200
epochs = 200

p_dropout= 0.2
fingerprint_dim = 200

weight_decay = 5 # also known as l2_regularization_lambda
learning_rate = 2.5
output_units_num = 1 # for regression model
radius = 2
T = 2

In [4]:
task_name = 'solubility'
tasks = ['measured log solubility in mols per litre']

raw_filename = "delaney-processed.csv"
feature_filename = raw_filename.replace('.csv','.pickle')
filename = raw_filename.replace('.csv','')
prefix_filename = raw_filename.split('/')[-1].replace('.csv','')
smiles_tasks_df = pd.read_csv(raw_filename)
smilesList = smiles_tasks_df.smiles.values
print("number of all smiles: ",len(smilesList))
atom_num_dist = []
remained_smiles = []
canonical_smiles_list = []
for smiles in smilesList:
    try:        
        mol = Chem.MolFromSmiles(smiles)
        atom_num_dist.append(len(mol.GetAtoms()))
        remained_smiles.append(smiles)
        canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))
    except:
        print(smiles)
        pass
print("number of successfully processed smiles: ", len(remained_smiles))
smiles_tasks_df = smiles_tasks_df[smiles_tasks_df["smiles"].isin(remained_smiles)]
# print(smiles_tasks_df)
smiles_tasks_df['cano_smiles'] =canonical_smiles_list



number of all smiles:  1128
number of successfully processed smiles:  1128


In [5]:
if os.path.isfile(feature_filename):
    feature_dicts = pickle.load(open(feature_filename, "rb" ))
else:
    feature_dicts = save_smiles_dicts(smilesList,filename)
# feature_dicts = get_smiles_dicts(smilesList)
remained_df = smiles_tasks_df[smiles_tasks_df["cano_smiles"].isin(feature_dicts['smiles_to_atom_mask'].keys())]
uncovered_df = smiles_tasks_df.drop(remained_df.index)
print("not processed items")
uncovered_df

not processed items


Unnamed: 0,Compound ID,ESOL predicted log solubility in mols per litre,Minimum Degree,Molecular Weight,Number of H-Bond Donors,Number of Rings,Number of Rotatable Bonds,Polar Surface Area,measured log solubility in mols per litre,smiles,cano_smiles
934,Methane,-0.636,0,16.043,0,0,0,0.0,-0.9,C,C


In [6]:
remained_df = remained_df.reset_index(drop=True)
test_df = remained_df.sample(frac=1/10, random_state=random_seed) # test set
training_data = remained_df.drop(test_df.index) # training data

# training data is further divided into validation set and train set
valid_df = training_data.sample(frac=1/9, random_state=random_seed) # validation set
train_df = training_data.drop(valid_df.index) # train set
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

# print(len(test_df),sorted(test_df.cano_smiles.values))


In [7]:
test_df

Unnamed: 0,Compound ID,ESOL predicted log solubility in mols per litre,Minimum Degree,Molecular Weight,Number of H-Bond Donors,Number of Rings,Number of Rotatable Bonds,Polar Surface Area,measured log solubility in mols per litre,smiles,cano_smiles
0,1-Chlorobutane,-1.940,1,92.569,0,0,2,0.00,-2.03,CCCCCl,CCCCCl
1,"2,6-Dimethylphenol",-2.589,1,122.167,1,1,0,20.23,-1.29,Cc1cccc(C)c1O,Cc1cccc(C)c1O
2,RTI 24,-4.423,1,273.723,1,3,1,45.23,-5.36,CCN2c1cc(Cl)ccc1NC(=O)c3cccnc23,CCN1c2cc(Cl)ccc2NC(=O)c2cccnc21
3,1-Dodecanol,-3.523,1,186.339,1,0,10,20.23,-4.80,CCCCCCCCCCCCO,CCCCCCCCCCCCO
4,3-Pentanol,-0.970,1,88.150,1,0,2,20.23,-0.24,CCC(O)CC,CCC(O)CC
...,...,...,...,...,...,...,...,...,...,...,...
108,p-Hydroxybenzaldehyde,-2.003,1,122.123,1,1,1,37.30,-0.96,Oc1ccc(C=O)cc1,O=Cc1ccc(O)cc1
109,Propyl propanoate,-1.545,1,116.160,0,0,3,26.30,-1.34,CCCCC(=O)OC,CCCCC(=O)OC
110,Methylcyclopentane,-2.452,1,84.162,0,1,0,0.00,-3.30,CC1CCCC1,CC1CCCC1
111,Dimethyldisulfide,-1.524,1,94.204,0,0,1,0.00,-1.44,CSSC,CSSC


In [8]:
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
device

device(type='mps')

In [9]:
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([canonical_smiles_list[0]],feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features,
            fingerprint_dim, output_units_num, p_dropout)

# optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
optimizer = optim.Adam(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)
# optimizer = optim.SGD(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)

# tensorboard = SummaryWriter(log_dir="runs/"+start_time+"_"+prefix_filename+"_"+str(fingerprint_dim)+"_"+str(p_dropout))
"""
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data.shape)
"""        

def train(model, dataset, optimizer, loss_function):
    model.train()
    np.random.seed(epoch)
    valList = np.arange(0,dataset.shape[0])
    #shuffle them
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch)   
    for counter, train_batch in enumerate(batch_list):
        batch_df = dataset.loc[train_batch,:]
        smiles_list = batch_df.cano_smiles.values
        y_val = batch_df[tasks[0]].values
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.LongTensor(x_atom_index),torch.LongTensor(x_bond_index),torch.Tensor(x_mask))
        
        model.zero_grad()
        loss = loss_function(mol_prediction, torch.Tensor(y_val).view(-1,1))     
        loss.backward()
        optimizer.step()


def eval(model, dataset):
    model.eval()
    test_MAE_list = []
    test_MSE_list = []
    valList = np.arange(0,dataset.shape[0])
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i+batch_size]
        batch_list.append(batch) 
    for counter, test_batch in enumerate(batch_list):
        batch_df = dataset.loc[test_batch,:]
        smiles_list = batch_df.cano_smiles.values
#         print(batch_df)
        y_val = batch_df[tasks[0]].values
        
        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)
        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.LongTensor(x_atom_index),torch.LongTensor(x_bond_index),torch.Tensor(x_mask))
        MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')        
        MSE = F.mse_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')
#         print(x_mask[:2],atoms_prediction.shape, mol_prediction,MSE)
        
        test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
        test_MSE_list.extend(MSE.data.squeeze().cpu().numpy())
    return np.array(test_MAE_list).mean(), np.array(test_MSE_list).mean()


best_param ={}
best_param["train_epoch"] = 0
best_param["valid_epoch"] = 0
best_param["train_MSE"] = 9e8
best_param["valid_MSE"] = 9e8

st = time.time()
for epoch in range(800):
    train_MAE, train_MSE = eval(model, train_df)
    valid_MAE, valid_MSE = eval(model, valid_df)
#     tensorboard.add_scalars('MAE',{'train_MAE':valid_MAE, 'test_MAE':valid_MSE}, epoch)
#     tensorboard.add_scalars('MSE',{'train_MSE':valid_MAE, 'test_MSE':valid_MSE}, epoch)
    if train_MSE < best_param["train_MSE"]:
        best_param["train_epoch"] = epoch
        best_param["train_MSE"] = train_MSE
    if valid_MSE < best_param["valid_MSE"]:
        best_param["valid_epoch"] = epoch
        best_param["valid_MSE"] = valid_MSE
        if valid_MSE < 0.35:
             torch.save(model, 'saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(epoch)+'.pt')
    if (epoch - best_param["train_epoch"] >8) and (epoch - best_param["valid_epoch"] >10):        
        break
    
    tt = time.time()-st
    print(epoch, np.sqrt(train_MSE), np.sqrt(valid_MSE), tt)
    
    train(model, train_df, optimizer, loss_function)
    st = time.time()

0 3.737257 3.8358169 0.6889142990112305
1 2.3454325 2.3386853 0.44771814346313477
2 1.6861997 1.6299764 0.5241658687591553
3 1.7689035 1.7173629 0.45954203605651855
4 1.6851 1.6441001 0.5215370655059814


KeyboardInterrupt: 

In [64]:
# evaluate model
best_model = torch.load('saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(best_param["valid_epoch"])+'.pt')     

best_model_dict = best_model.state_dict()
best_model_wts = copy.deepcopy(best_model_dict)

model.load_state_dict(best_model_wts)
(best_model.align[0].weight == model.align[0].weight).all()
test_MAE, test_MSE = eval(model, test_df)
print("best epoch:",best_param["valid_epoch"],"\n","test RMSE:",np.sqrt(test_MSE))

FileNotFoundError: [Errno 2] No such file or directory: 'saved_models/model_delaney-processed_Thu_Oct_17_20-43-40_2024_5.pt'

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# Set the device to MPS (Metal Performance Shaders) if available, otherwise use CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Assume canonical_smiles_list and feature_dicts are already defined
x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([canonical_smiles_list[0]], feature_dicts)
num_atom_features = x_atom.shape[-1]
num_bond_features = x_bonds.shape[-1]

# Move your model to MPS
loss_function = nn.MSELoss()
model = Fingerprint(radius, T, num_atom_features, num_bond_features, fingerprint_dim, output_units_num, p_dropout).to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)

"""
# Print model parameters
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data.shape)
"""

def train(model, dataset, optimizer, loss_function):
    model.train()
    np.random.seed(epoch)
    valList = np.arange(0, dataset.shape[0])
    np.random.shuffle(valList)
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i + batch_size]
        batch_list.append(batch)

    for counter, train_batch in enumerate(batch_list):
        batch_df = dataset.loc[train_batch, :]
        smiles_list = batch_df.cano_smiles.values
        y_val = batch_df[tasks[0]].values

        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list, feature_dicts)

        # Move data to the MPS device
        x_atom = torch.Tensor(x_atom).to(device)
        x_bonds = torch.Tensor(x_bonds).to(device)
        x_atom_index = torch.LongTensor(x_atom_index).to(device)
        x_bond_index = torch.LongTensor(x_bond_index).to(device)
        x_mask = torch.Tensor(x_mask).to(device)
        y_val = torch.Tensor(y_val).to(device).view(-1, 1)

        # Forward pass
        atoms_prediction, mol_prediction = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)

        # Backpropagation
        model.zero_grad()
        loss = loss_function(mol_prediction, y_val)
        loss.backward()
        optimizer.step()


def eval(model, dataset):
    model.eval()
    test_MAE_list = []
    test_MSE_list = []
    valList = np.arange(0, dataset.shape[0])
    batch_list = []
    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i + batch_size]
        batch_list.append(batch)

    for counter, test_batch in enumerate(batch_list):
        batch_df = dataset.loc[test_batch, :]
        smiles_list = batch_df.cano_smiles.values
        y_val = batch_df[tasks[0]].values

        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list, feature_dicts)

        # Move data to the MPS device
        x_atom = torch.Tensor(x_atom).to(device)
        x_bonds = torch.Tensor(x_bonds).to(device)
        x_atom_index = torch.LongTensor(x_atom_index).to(device)
        x_bond_index = torch.LongTensor(x_bond_index).to(device)
        x_mask = torch.Tensor(x_mask).to(device)
        y_val = torch.Tensor(y_val).to(device).view(-1, 1)

        # Forward pass
        atoms_prediction, mol_prediction = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
        
        # Compute losses
        MAE = F.l1_loss(mol_prediction, y_val, reduction='none')
        MSE = F.mse_loss(mol_prediction, y_val, reduction='none')

        test_MAE_list.extend(MAE.data.squeeze().cpu().numpy())
        test_MSE_list.extend(MSE.data.squeeze().cpu().numpy())

    return np.array(test_MAE_list).mean(), np.array(test_MSE_list).mean()




# Training loop
best_param = {}
best_param["train_epoch"] = 0
best_param["valid_epoch"] = 0
best_param["train_MSE"] = 9e8
best_param["valid_MSE"] = 9e8

for epoch in range(800):
    st = time.time()

    train_MAE, train_MSE = eval(model, train_df)
    valid_MAE, valid_MSE = eval(model, valid_df)

    if train_MSE < best_param["train_MSE"]:
        best_param["train_epoch"] = epoch
        best_param["train_MSE"] = train_MSE
    if valid_MSE < best_param["valid_MSE"]:
        best_param["valid_epoch"] = epoch
        best_param["valid_MSE"] = valid_MSE
        if valid_MSE < 0.35:
            torch.save(model, 'saved_models/model_' + prefix_filename + '_' + start_time + '_' + str(epoch) + '.pt')
    if (epoch - best_param["train_epoch"] > 15) and (epoch - best_param["valid_epoch"] > 15):
        break
        
    tt = time.time()-st

    print(epoch, np.sqrt(train_MSE), np.sqrt(valid_MSE),tt)
    train(model, train_df, optimizer, loss_function)



0 3.615654 3.724313 0.2547619342803955
1 2.3772168 2.377616 0.20005011558532715
2 1.7480872 1.6968665 0.1998450756072998
3 1.8745983 1.8041772 0.19939208030700684
4 1.7044008 1.6572366 0.20171689987182617
5 1.583944 1.5409473 0.19936513900756836
6 1.491966 1.4652581 0.19896984100341797
7 1.3530041 1.3028105 0.19799494743347168
8 1.1328714 1.1743983 0.19801783561706543
9 1.0703236 1.2446014 0.1975080966949463
10 1.0124611 1.1530569 0.19768595695495605
11 0.9683418 1.0671117 0.19868111610412598
12 0.92986304 0.97338265 0.20084786415100098
13 0.8738954 0.97434 0.19850897789001465
14 0.87408006 0.9557488 0.19936895370483398
15 0.8477525 0.9257929 0.20505213737487793
16 0.82137585 0.90787846 0.19728493690490723
17 0.82134336 0.90912473 0.1980431079864502
18 0.75300795 0.8486903 0.1991558074951172
19 0.7268298 0.821751 0.19780588150024414
20 0.70267737 0.8030773 0.1992189884185791
21 0.720822 0.7841787 0.1996290683746338
22 0.7065823 0.7837175 0.20296597480773926
23 0.7530048 0.80178267 0.19

KeyboardInterrupt: 

In [48]:
def pred(model, dataset):
    model.eval()
  
    valList = np.arange(0, dataset.shape[0])
    batch_list = []
    y_pred = []
    y_true = []

    for i in range(0, dataset.shape[0], batch_size):
        batch = valList[i:i + batch_size]
        batch_list.append(batch)

    for counter, test_batch in enumerate(batch_list):
        batch_df = dataset.loc[test_batch, :]
        smiles_list = batch_df.cano_smiles.values
        y_val = batch_df[tasks[0]].values

        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list, feature_dicts)

        # Move data to the MPS device
        x_atom = torch.Tensor(x_atom).to(device)
        x_bonds = torch.Tensor(x_bonds).to(device)
        x_atom_index = torch.LongTensor(x_atom_index).to(device)
        x_bond_index = torch.LongTensor(x_bond_index).to(device)
        x_mask = torch.Tensor(x_mask).to(device)
        y_val = torch.Tensor(y_val).to(device).view(-1, 1)

        # Forward pass
        atoms_prediction, mol_prediction = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)
        
        # Compute losses
        
        y_pred.append(mol_prediction)
        y_true.append(y_val)

       
    return y_true, y_pred
y_pred, y_true = pred(model,valid_df)


In [53]:
# evaluate model
best_model = torch.load('saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(best_param["valid_epoch"])+'.pt')     

best_model_dict = best_model.state_dict()
best_model_wts = copy.deepcopy(best_model_dict)

model.load_state_dict(best_model_wts)
(best_model.align[0].weight == model.align[0].weight).all()
test_MAE, test_MSE = eval(model, test_df)
print("best epoch:",best_param["valid_epoch"],"\n","test RMSE:",np.sqrt(test_MSE))

best epoch: 72 
 test RMSE: 0.5546058


In [28]:
import torch.profiler

def profile_one_batch(model, dataset, optimizer, loss_function, feature_dicts, batch_size, tasks, device):
    model.train()  # Set the model to training mode
    model = model.to(device)  # Move the model to MPS or CPU

    # Select one batch
    valList = np.arange(dataset.shape[0])
    np.random.shuffle(valList)
    train_batch = valList[:batch_size]

    # Get the batch data
    batch_df = dataset.iloc[train_batch]
    smiles_list = batch_df.cano_smiles.values
    y_val = batch_df[tasks[0]].values

    # Get input arrays from SMILES strings and pre-computed feature dicts
    x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, _ = get_smiles_array(smiles_list, feature_dicts)

    # Move input data to the MPS device
    x_atom = torch.Tensor(x_atom).to(device)
    x_bonds = torch.Tensor(x_bonds).to(device)
    x_atom_index = torch.LongTensor(x_atom_index).to(device)
    x_bond_index = torch.LongTensor(x_bond_index).to(device)
    x_mask = torch.Tensor(x_mask).to(device)
    y_val = torch.Tensor(y_val).view(-1, 1).to(device)

    # Start profiling for just one batch
    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as profiler:
        
        # Forward pass through the model
        _, mol_prediction = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)

        # Zero gradients, backpropagation, and optimizer step
        optimizer.zero_grad()
        loss = loss_function(mol_prediction, y_val)  # Calculate the loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update the model weights

        # Step the profiler after this batch
        profiler.step()

    # Print the profiling results
    print(profiler.key_averages().table(sort_by="cpu_time_total", row_limit=50))


# Example usage
device =  torch.device('cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_function = torch.nn.MSELoss()
batch_size = 32
tasks = ['measured log solubility in mols per litre']

profile_one_batch(model, train_df, optimizer, loss_function, feature_dicts, batch_size, tasks, device)


STAGE:2024-10-17 11:38:42 87880:56463086 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-10-17 11:38:42 87880:56463086 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-10-17 11:38:42 87880:56463086 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          aten::dropout         0.05%      56.000us        35.83%      44.111ms       6.302ms             7  
                                       aten::bernoulli_        33.98%      41.829ms        33.98%      41.829ms       5.976ms             7  
    autograd::engine::evaluate_function: IndexBackward0         0.11%     140.000us        14.23%      17.515ms     547.344us            32  
                                         IndexBackward0         0.08%     101.000us        14.11%      17.375ms     542.969us            32  
      

In [36]:
import torch
import cProfile
import pstats
import io

# Profiling function using cProfile
def profile_one_batch_cProfile(model, dataset, optimizer, loss_function, feature_dicts, batch_size, tasks, device):
    model.train()  # Set the model to training mode
    model = model.to(device)  # Move the model to MPS or CPU

    # Select one batch
    valList = np.arange(dataset.shape[0])
    np.random.shuffle(valList)
    train_batch = valList[:batch_size]

    # Get the batch data
    batch_df = dataset.iloc[train_batch]
    smiles_list = batch_df.cano_smiles.values
    y_val = batch_df[tasks[0]].values

    # Get input arrays from SMILES strings and pre-computed feature dicts
    x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, _ = get_smiles_array(smiles_list, feature_dicts)

    # Move input data to the MPS device
    x_atom = torch.Tensor(x_atom).to(device)
    x_bonds = torch.Tensor(x_bonds).to(device)
    x_atom_index = torch.LongTensor(x_atom_index).to(device)
    x_bond_index = torch.LongTensor(x_bond_index).to(device)
    x_mask = torch.Tensor(x_mask).to(device)
    y_val = torch.Tensor(y_val).view(-1, 1).to(device)

    # Profiler setup
    pr = cProfile.Profile()
    pr.enable()  # Start profiling

    # Forward pass through the model
    _, mol_prediction = model(x_atom, x_bonds, x_atom_index, x_bond_index, x_mask)

    # Zero gradients, backpropagation, and optimizer step
    optimizer.zero_grad()
    loss = loss_function(mol_prediction, y_val)  # Calculate the loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update the model weights

    # Stop profiling
    pr.disable()

    # Print profiling results
    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats(pstats.SortKey.CUMULATIVE)
    ps.print_stats(100)  # Print top 10 results
    print(s.getvalue())

# Example usage
device = torch.device('cpu')  # Use 'mps' if on Apple Silicon, 'cpu' for CPU
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_function = torch.nn.MSELoss()
batch_size = 32
tasks = ['measured log solubility in mols per litre']

profile_one_batch_cProfile(model, train_df, optimizer, loss_function, feature_dicts, batch_size, tasks, device)


         1823 function calls (1767 primitive calls) in 0.117 seconds

   Ordered by: cumulative time
   List reduced from 168 to 100 due to restriction <100>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     24/2    0.000    0.000    0.069    0.034 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1528(_wrapped_call_impl)
     24/2    0.000    0.000    0.069    0.034 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1534(_call_impl)
        1    0.004    0.004    0.069    0.069 /Users/tgg/Github/mlx-graphs-last/AttentiveFP.py:432(forward)
        1    0.000    0.000    0.045    0.045 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/_tensor.py:466(backward)
        1    0.000    0.000    0.045    0.045 /Users/tgg/miniforge3/envs/mlxgraphenv-py311/lib/python3.11/site-packages/torch/autograd/__init__.py:165(backward)
        1    

In [37]:
torch.__version__

'2.3.1'

In [None]:
torch.mps.profiler.profile

In [39]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_total_params

863604