In [None]:
import pickle
import rdkit
from rdkit import Chem
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from __future__ import absolute_import, division, print_function, unicode_literals

# standard python
import numpy as np
import scipy
#import pathlib

import os
# plotting, especially for jupyter notebooks
import matplotlib
#matplotlib.rcParams['text.usetex'] = True # breaks for some endpoint labels
from matplotlib import pyplot as plt
from IPython.display import Image
import pandas as pd
# tensorflow
import tensorflow as tf
#tf.enable_eager_execution() # needed for tf version 1 or it stages operations but does not do them
from tensorflow import keras
from tensorflow.keras import layers, regularizers
tf.keras.backend.clear_session()  # For easy reset of notebook state.
# local routines
#from chemdataprep import load_PDBs,load_countsfromPDB,load_diametersfromPDB,find_chemnames
#from toxmathandler import load_tscores

#checkpoint_path = "/home2/ajgreen4/Read-Across_w_GAN/Models/cp.ckpt"
#checkpoint_dir = os.path.dirname(checkpoint_path)

from collections import defaultdict

# Report the TensorFlow version, whether eager execution is on, and how many
# GPUs TensorFlow can see. Uses the stable tf.config.list_physical_devices
# entry point rather than the older tf.config.experimental alias.
print("tensorflow version",tf.__version__,". Executing eagerly?",tf.executing_eagerly())
print("Number of GPUs: ", len(tf.config.list_physical_devices('GPU')))

# Load Data

In [None]:
# Load the training, validation and test splits from their pickle files.
# NOTE(review): pickle.load can execute arbitrary code; only use files from a
# trusted source.
# Docking-score convention (from the original notes): a more negative docking
# score means a better fit.
_split_paths = (
    'train.pkl',
    'validation_small_enantiomers_stable_full_screen_docking_MOL_margin3_49878_10368_5184.pkl',
    'test_small_enantiomers_stable_full_screen_docking_MOL_margin3_50571_10368_5184.pkl',
)

_loaded_splits = []
for file_path in _split_paths:
    with open(file_path, 'rb') as file:
        _loaded_splits.append(pickle.load(file))

# Order matches _split_paths: train, validation, test.
train_df, val_df, test_df = _loaded_splits

# Extract per-molecule atom lists (element symbol + 3D coordinates)

In [None]:
# Preview the first 20 rows of the training split.
train_df.head(20)

In [None]:

def atomic_number_to_symbol(atomic_number):
    """Translate an atomic number into its element symbol.

    Only the ten elements listed below are known; any other atomic number
    yields the string 'Unknown'.

    :parameter atomic_number: atomic number to look up
    :returns: the element symbol, or 'Unknown' if the number is not in the table
    :rtype: str
    """
    symbol_by_number = dict([
        (1, 'H'), (6, 'C'), (7, 'N'), (8, 'O'), (9, 'F'),
        (15, 'P'), (16, 'S'), (17, 'Cl'), (35, 'Br'), (53, 'I'),
    ])
    try:
        return symbol_by_number[atomic_number]
    except KeyError:
        return 'Unknown'

# Pickle file for each dataset split.
datasets = {
    'train': 'train.pkl',
    'val': 'validation_small_enantiomers_stable_full_screen_docking_MOL_margin3_49878_10368_5184.pkl',
    'test': 'test_small_enantiomers_stable_full_screen_docking_MOL_margin3_50571_10368_5184.pkl'
}


def _atoms_with_coordinates(mol):
    """Return [(element_symbol, xyz_array), ...] for the first conformer of mol.

    An empty list is returned when mol is None, is not an rdkit Chem.Mol, or
    has no conformers.
    """
    if mol is None or not isinstance(mol, Chem.Mol):
        return []
    if not mol.GetNumConformers() > 0:
        return []
    conformer = mol.GetConformer(0)
    atoms = []
    for atom in mol.GetAtoms():
        symbol = atomic_number_to_symbol(atom.GetAtomicNum())
        position = conformer.GetAtomPosition(atom.GetIdx())
        atoms.append((symbol, np.array([position.x, position.y, position.z])))
    return atoms


# split name -> list of molecules, one entry per dataframe row. Rows without a
# usable molecule contribute an empty list, so positions stay aligned with rows.
mollists = {}

for dataset_name, file_path in datasets.items():
    with open(file_path, 'rb') as file:
        df = pickle.load(file)

    mollist = [_atoms_with_coordinates(mol) for mol in df['rdkit_mol_cistrans_stereo']]
    mollists[dataset_name] = mollist

    print(f"{dataset_name.capitalize()} Sample (first molecule):")
    print(mollist[0] if mollist else "No valid molecules found.\n")


In [None]:
# Pull the per-split molecule lists out of the mollists dictionary.
train_mollist, val_mollist, test_mollist = (
    mollists[split_name] for split_name in ('train', 'val', 'test')
)

# Assemble features and labels for the train / validation / test splits

In [None]:
%%time
# Features are the per-molecule atom lists; labels are the 'top_score' column
# of the matching dataframe.
X_train, y_train = train_mollist, train_df['top_score']
X_val, y_val = val_mollist, val_df['top_score']
X_test, y_test = test_mollist, test_df['top_score']

In [None]:
# Lookup from element symbol to atomic number for the available atoms
def speciesmap(atom_type):
    """Map an element symbol to its atomic number, wrapped in a length-1 array.

    :parameter atom_type: element symbol such as 'C' or 'Cl'
    :returns: one-element numpy array holding the atomic number, or 0 when the
        symbol is not recognized
    :rtype: numpy.ndarray
    """
    symbols = ('H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I')
    atomic_numbers = (1, 6, 7, 8, 9, 15, 16, 17, 35, 53)
    number_by_symbol = dict(zip(symbols, atomic_numbers))
    return np.array([number_by_symbol.get(atom_type, 0)])

In [None]:
%%time

# Get weights and views for the training, validation and test datasets.
#
# Each split has its own module exporting a function named load_qm7_data.
# Importing all three under the same name made every later import shadow the
# previous one, so all three calls silently used the qm7_weightedviews_test
# version. Each import now gets its own alias.
# NOTE(review): confirm the three modules are meant to differ per split.
from qm7_weightedviews_train import load_qm7_data as load_qm7_data_train
from qm7_weightedviews_val import load_qm7_data as load_qm7_data_val
from qm7_weightedviews_test import load_qm7_data as load_qm7_data_test

# Keep the name the notebook previously ended up with (the test-module
# function) so any other cell that refers to load_qm7_data still works.
load_qm7_data = load_qm7_data_test

# Featurize training data (Natoms / Nviews are left for the loader to choose).
ws_train, vs_train, Natoms_train, Nviews_train = load_qm7_data_train(train_mollist, speciesmap, setNatoms=None, setNviews=None, carbonbased=False, verbose=0)
chiral_train = [ws_train, vs_train]

# Featurize validation data (Natoms / Nviews fixed to 29).
ws_val, vs_val, Natoms_val, Nviews_val = load_qm7_data_val(val_mollist, speciesmap, setNatoms=29, setNviews=29, carbonbased=False, verbose=0)
chiral_val = [ws_val, vs_val]

# Featurize test data (Natoms / Nviews fixed to 29).
ws_test, vs_test, Natoms_test, Nviews_test = load_qm7_data_test(test_mollist, speciesmap, setNatoms=29, setNviews=29, carbonbased=False, verbose=0)
chiral_test = [ws_test, vs_test]

# Alternative: load precomputed weights and views from .npy files

In [None]:
# Load the weights (ws_*) and views (vs_*) saved as .npy files, one pair per split.
ws_train, vs_train = np.load('ws_train.npy'), np.load('vs_train.npy')
ws_val, vs_val = np.load('ws_val.npy'), np.load('vs_val.npy')
ws_test, vs_test = np.load('ws_test.npy'), np.load('vs_test.npy')

# Load the stored dimensions (atom and view counts) for each split.
Natoms_train, Nviews_train = np.load('Natoms_train.npy'), np.load('Nviews_train.npy')
Natoms_val, Nviews_val = np.load('Natoms_val.npy'), np.load('Nviews_val.npy')
Natoms_test, Nviews_test = np.load('Natoms_test.npy'), np.load('Nviews_test.npy')

# Bundle each split as [weights, views].
chiral_train, chiral_val, chiral_test = (
    [ws_train, vs_train],
    [ws_val, vs_val],
    [ws_test, vs_test],
)

In [None]:
# Print the shapes of the weight and view arrays, split by split
# (train, validation, test; weights before views).
for _feature_array in (ws_train, vs_train, ws_val, vs_val, ws_test, vs_test):
    print(_feature_array.shape)

In [None]:

# Labels fed to the generator are the per-split target scores.
labelsG_train, labelsG_val, labelsG_test = y_train, y_val, y_test

Ntoxicity = 3

# Generator inputs: [weights, views] for each split. (The original cell also
# re-assigned ws_* / vs_* to themselves, which had no effect.)
dataG_train = [ws_train, vs_train]
dataG_val = [ws_val, vs_val]
dataG_test = [ws_test, vs_test]

# Neural Network Code

In [None]:
# generic dense NN
def multiDense(Nin,Nout,Nhidden,widthhidden=None,kernel_regularizer=None):
    """Build a plain feed-forward network of ReLU dense layers.

    The network has Nhidden hidden Dense layers with ReLU activation, named
    'dense0', 'dense1', ..., followed by an output Dense layer with no
    activation argument, named 'multiDense_output'.

    :parameter Nin: The number of inputs
    :type Nin: int
    :parameter Nout: The number of outputs
    :type Nout: int
    :parameter Nhidden: The number of hidden layers.
    :type Nhidden: int
    :parameter widthhidden: The width of each hidden layer.
        If left at None, Nin + Nout will be used.
    :parameter kernel_regularizer: regularizer applied to the hidden layers'
        kernels, such as regularizers.l2(0.001); a message is printed when set
    :type kernel_regularizer: tensorflow.keras.regularizers.xxx
    :returns: The NN model
    :rtype: keras.Model

    """
    hidden_width = Nin + Nout if widthhidden is None else widthhidden
    inputs = keras.Input(shape=(Nin,), name='multiDense_input')
    if kernel_regularizer is not None:
        print("Using regularization")

    hidden = inputs
    for layer_index in range(Nhidden):
        dense_layer = layers.Dense(
            hidden_width,
            activation='relu',
            kernel_regularizer=kernel_regularizer,
            name=f'dense{layer_index}',
        )
        hidden = dense_layer(hidden)

    # The output layer is not regularized and has no explicit activation.
    outputs = layers.Dense(Nout, name='multiDense_output')(hidden)
    return keras.Model(inputs=inputs, outputs=outputs)
if True:
    # Manual check of multiDense: 116 inputs, 3 outputs, 4 hidden layers of width 256.
    mmd = multiDense(116, 3, 4, widthhidden=256)
    mmd.summary()
    # (this base network is later used in the weighted sum over views)

In [None]:
def parallelwrapper(Nparallel,basemodel,insteadmax=False):
    """Apply basemodel to Nparallel input slices and combine the results.

    The main input is unstacked along axis 1, basemodel is applied to every
    slice, and the outputs are stacked back along axis 1. The stacked outputs
    are then reduced over the Nparallel direction by a weighted sum or a max.

    :parameter Nparallel: The number of times to apply in parallel
    :type Nparallel: int
    :parameter basemodel: a keras.Model inferred to have Nin inputs and Nout outputs.
    :type basemodel: a keras.Model
    :parameter insteadmax: If True, take the max of the results of the basemodel
        instead of the weighted sum. For compatibility, the model is still
        constructed with weights as inputs, but it ignores them.
    :type insteadmax: Boolean
    :returns: model with inputs shape [(?,Nparallel),(?,Nparallel,Nin)] and
        outputs shape (?,Nout). The first input is the scalar weights in the sum.
    :rtype: keras.Model

    """
    # Infer the base model's input and output widths.
    width_in = basemodel.inputs[0].shape[1]
    width_out = basemodel.outputs[0].shape[1]

    # Main input (?,Nparallel,Nin). Calling basemodel directly on this tensor
    # worked in earlier tensorflow versions; unstack / apply / re-stack is used
    # instead, which works but makes summary and graphing cumbersome.
    stacked_inputs = keras.Input(shape=(Nparallel, width_in), name='parallelwrapper_input0')
    per_slice_outputs = [
        basemodel(input_slice)
        for input_slice in tf.keras.ops.unstack(stacked_inputs, Nparallel, 1)
    ]
    # (?,Nparallel,Nout)
    stacked_outputs = tf.keras.ops.stack(per_slice_outputs, axis=1)

    # Scalar weights for the weighted sum (?,Nparallel)
    weight_inputs = keras.Input(shape=(Nparallel,), name='parallelScalars')

    if not insteadmax:
        # Weighted sum over the Nparallel direction -> (?,Nout)
        combined = layers.Dot((-2,-1))([stacked_outputs, weight_inputs])
    else:
        # Max over the Nparallel direction -> (?,1,Nout), then reshape to (?,Nout)
        pooled = layers.MaxPool1D(pool_size=Nparallel)(stacked_outputs)
        combined = layers.Reshape((width_out,))(pooled)

    return keras.Model(
        inputs=[weight_inputs, stacked_inputs],
        outputs=combined,
        name='parallelwrapper',
    )
if True:
    # Manual check: wrap a fresh multiDense base network over 29 parallel slices
    # using the weighted-sum aggregation, then print the model summary.
    mmd = multiDense(116, 3, 4, widthhidden=256)
    mpw = parallelwrapper(29, mmd, insteadmax=False)
    mpw.summary()

In [None]:
def init_generator(data, labels, baselayers, Nfeatures, endlayers, Nviews, base_regularizer=None, end_regularizer=None, sanity_check=False):
    """Build and compile the generator neural net.

    The generator chains three pieces: a base multiDense network applied to
    each view, a parallelwrapper that combines the per-view outputs with a
    weighted sum, and a final multiDense network mapping the Nfeatures combined
    values to a single output. The model is compiled with the 'adam' optimizer
    and 'mse' loss.

    :parameter data: [ws, vs]; vs.shape[2] is used as the base network's input width
    :parameter labels: currently unused; the output width is fixed to 1 because
        the targets are a one-dimensional array
    :parameter baselayers: number of hidden layers before the weighted sum
    :type baselayers: int
    :parameter Nfeatures: number of outputs of the weighted sum
    :type Nfeatures: int
    :parameter endlayers: number of hidden layers after the weighted sum
    :type endlayers: int
    :parameter Nviews: Number of views/parallel applications
    :type Nviews: int
    :parameter base_regularizer: kernel regularizer for the base network, or None
    :parameter end_regularizer: kernel regularizer for the end network, or None
    :parameter sanity_check: if True, compare direct calls against predict()
        for the base network, the wrapper and the whole generator on `data`,
        printing the norm of each difference (expected to be about 0)
    :type sanity_check: Boolean
    :returns: the compiled generator (only the generator; no discriminator is built)
    :rtype: keras.Model

    """
    ## Option changing how results of each view are aggregated
    insteadmax = False # weighted average; set True to take the max over views instead

    # base NN applied to every view
    Gbase = multiDense(data[1].shape[2], Nfeatures, baselayers, kernel_regularizer=base_regularizer)
    # parallel view wrapper
    Gpw = parallelwrapper(Nviews, Gbase, insteadmax)
    # features to the single regression output
    Gft = multiDense(Nfeatures, 1, endlayers, kernel_regularizer=end_regularizer)
    # string together
    generator = keras.Model(inputs=Gpw.inputs, outputs=Gft(Gpw.outputs), name='generator')
    # make trainable
    generator.compile(optimizer='adam', loss='mse')

    if sanity_check:
        # Each printed norm should be ~0 if call and predict agree.
        print("Sanity check:")
        ws, vs = data
        gbv0call = Gbase(vs[:,0,:]).numpy()
        gbv0predict = Gbase.predict(vs[:,0,:])
        print("base: 0 ?==", np.linalg.norm(gbv0call-gbv0predict))
        gpwcall = Gpw([ws,vs]).numpy()
        # Previously split across two lines as `G` / `pw.predict(...)`, which
        # raised NameError whenever the check was enabled.
        gpwpredict = Gpw.predict([ws,vs])
        print("wrapper: 0 ?==", np.linalg.norm(gpwcall-gpwpredict))
        gencall = generator([ws,vs]).numpy()
        genpredict = generator.predict([ws,vs])
        print("whole: 0 ?==", np.linalg.norm(gencall-genpredict))

    return generator

# Fixed number of views / parallel applications used by the generator.
Nviews = 29

# Architecture and regularization settings.
baselayers = 2  # hidden layers before weighted sum
base_reg = 0    # regularization for the base layers (0 disables it)
Nfeatures = 3   # number of outputs of weighted sum
endlayers = 2   # hidden layers after weighted sum
end_reg = 0.1   # regularization for the end layers (0 disables it)

# A strength of 0 means "no regularizer"; otherwise use an L2 penalty.
base_regularizer = regularizers.l2(base_reg) if base_reg != 0 else None
end_regularizer = regularizers.l2(end_reg) if end_reg != 0 else None

print("(baselayers, base_reg, Nfeatures, endlayers, end_reg, Nviews) =",
      (baselayers, base_reg, Nfeatures, endlayers, end_reg, Nviews))

# Build the generator with the settings above, passing Nviews explicitly.
generator = init_generator(
    dataG_train, labelsG_train, baselayers, Nfeatures, endlayers, Nviews,
    base_regularizer=base_regularizer,
    end_regularizer=end_regularizer,
)
# init_generator already compiles the model; it is compiled again here with
# the same optimizer and loss.
generator.compile(optimizer='adam', loss='mse')



In [None]:

from collections import defaultdict

def get_enantiomer_pairs(df):
    """Extract enantiomer pairs from a dataframe.

    Rows are grouped by 'SMILES_nostereo'. Within each group of at least two
    rows, rows are grouped again by 'top_score'; a group that splits into
    exactly two score buckets is treated as one enantiomer pair, each bucket
    holding the conformers of one enantiomer. Which bucket becomes
    "enantiomer1" is chosen at random (np.random).

    Conformers are reported as 0-based row positions in df, not index labels,
    because callers use them to index NumPy arrays row-aligned with df. With
    the default RangeIndex positions and labels coincide.

    :parameter df: dataframe with 'SMILES_nostereo' and 'top_score' columns
    :type df: pandas.DataFrame
    :returns: list of dicts with keys 'enantiomer1_conformers',
        'enantiomer2_conformers', 'enantiomer1_score', 'enantiomer2_score'
        and 'smiles_nostereo'
    :rtype: list
    """
    # Only the two grouping columns are needed; dropping the index turns the
    # group indices below into row positions.
    positional = df[['SMILES_nostereo', 'top_score']].reset_index(drop=True)

    pairs = []
    for smiles, group in positional.groupby('SMILES_nostereo'):
        if len(group) < 2:
            continue

        buckets = [
            {'conformers': score_group.index.tolist(), 'top_score': score}
            for score, score_group in group.groupby('top_score')
        ]
        if len(buckets) != 2:
            continue

        # Randomly decide which bucket is enantiomer1 and which is enantiomer2.
        first, second = buckets
        if not np.random.random() > 0.5:
            first, second = second, first

        pairs.append({
            'enantiomer1_conformers': first['conformers'],
            'enantiomer2_conformers': second['conformers'],
            'enantiomer1_score': first['top_score'],
            'enantiomer2_score': second['top_score'],
            'smiles_nostereo': smiles
        })

    print(f"Found {len(pairs)} enantiomer pairs")
    return pairs

def create_enantiomer_batch(pairs, data, labels, batch_size=32):
    """Sample a batch made of enantiomer pairs.

    batch_size // 2 pairs are drawn at random; for each, one conformer is
    picked from the first enantiomer's bucket and one from the second's, and
    both are appended to the batch (first enantiomer, then second).

    :parameter pairs: pair dictionaries with 'enantiomer1_conformers' and
        'enantiomer2_conformers' index lists
    :parameter data: [ws, vs], indexed by conformer index
    :parameter labels: scores, indexed by conformer index
    :parameter batch_size: target number of samples; batch_size // 2 pairs are drawn
    :returns: (weights, views, scores) arrays; scores has shape (-1, 1)
    """
    weights, views = data[0], data[1]
    pairs_wanted = batch_size // 2

    # Draw without replacement when there are enough pairs, and with
    # replacement only when there are fewer pairs than requested.
    allow_repeats = len(pairs) < pairs_wanted
    selected_pairs = np.random.choice(pairs, size=pairs_wanted, replace=allow_repeats)

    batch_ws, batch_vs, batch_scores = [], [], []
    for chosen in selected_pairs:
        first_idx = np.random.choice(chosen['enantiomer1_conformers'])
        second_idx = np.random.choice(chosen['enantiomer2_conformers'])
        for conformer_idx in (first_idx, second_idx):
            batch_ws.append(weights[conformer_idx])
            batch_vs.append(views[conformer_idx])
            batch_scores.append(labels[conformer_idx])

    return np.array(batch_ws), np.array(batch_vs), np.array(batch_scores).reshape(-1, 1)

def train_with_enantiomer_sampling(generator, dataG_train, labelsG_train, train_df, 
                                 dataG_val, labelsG_val, val_df, epochs=10, batch_size=32,
                                 patience=1):
    """Train the generator on batches of enantiomer pairs with early stopping.

    Each epoch runs len(pairs) // (batch_size // 2) batches (at least one),
    each sampled with create_enantiomer_batch and fed to
    generator.train_on_batch. After every epoch the validation ranking accuracy
    is computed with evaluate_ranking_accuracy; the weights of the best epoch
    are kept and restored before returning.

    :parameter generator: compiled keras model taking [weights, views]
    :parameter dataG_train: [ws, vs] for the training split
    :parameter labelsG_train: training scores, indexed like dataG_train
    :parameter train_df: training dataframe used to find enantiomer pairs
    :parameter dataG_val: [ws, vs] for the validation split
    :parameter labelsG_val: validation scores
    :parameter val_df: validation dataframe used to find enantiomer pairs
    :parameter epochs: maximum number of epochs
    :parameter batch_size: samples per batch; must be at least 2 because every
        batch is built from batch_size // 2 pairs
    :parameter patience: number of consecutive epochs without a validation
        improvement after which training stops (default 1, the previous
        hard-coded value)
    :returns: (generator with the best weights restored, best validation accuracy)
    :raises ValueError: if batch_size < 2, or if the training data contains no
        enantiomer pairs
    """
    # Fail early and clearly: batch_size // 2 == 0 would otherwise surface as
    # a ZeroDivisionError when computing the number of batches per epoch.
    pairs_per_batch = batch_size // 2
    if pairs_per_batch < 1:
        raise ValueError("batch_size must be at least 2 to hold one enantiomer pair")

    enantiomer_pairs_train = get_enantiomer_pairs(train_df)  # extract pairs from training data

    if not enantiomer_pairs_train:
        raise ValueError("No enantiomer pairs found in training data!")

    best_val_accuracy = 0
    best_weights = None
    patience_counter = 0

    num_batches_per_epoch = max(1, len(enantiomer_pairs_train) // pairs_per_batch)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        epoch_loss = 0.0

        for batch_idx in range(num_batches_per_epoch):
            batch_ws, batch_vs, batch_scores = create_enantiomer_batch(
                enantiomer_pairs_train, dataG_train, labelsG_train, batch_size
            )

            loss = generator.train_on_batch([batch_ws, batch_vs], batch_scores)

            # train_on_batch may return a scalar or a list/tuple; when it is a
            # list/tuple, the first entry is taken as the loss.
            if isinstance(loss, (list, tuple)):
                loss = float(loss[0])
            else:
                loss = float(loss)
            epoch_loss += loss

        avg_loss = epoch_loss / num_batches_per_epoch
        print(f"Training loss: {avg_loss:.4f}")

        val_accuracy = evaluate_ranking_accuracy(generator, dataG_val, labelsG_val, val_df)
        print(f"Validation ranking accuracy: {val_accuracy:.4f}")

        if val_accuracy > best_val_accuracy:
            # Strict improvement: remember these weights and reset the counter.
            best_val_accuracy = val_accuracy
            best_weights = generator.get_weights()
            patience_counter = 0
            print(f"New best model saved with accuracy: {best_val_accuracy:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best weights seen, if any epoch improved on the initial 0.
    if best_weights is not None:
        generator.set_weights(best_weights)
    return generator, best_val_accuracy

def evaluate_ranking_accuracy(model, data, labels, df):
    """Compute the fraction of enantiomer pairs whose order the model predicts correctly.

    Predictions are made for every conformer in chunks of 32. For each pair
    found by get_enantiomer_pairs(df), the mean prediction over each
    enantiomer's conformers is compared; the pair counts as correct when
    "mean prediction 1 < mean prediction 2" agrees with
    "true score 1 < true score 2".

    :parameter model: object with a predict([ws, vs], verbose=0) method
    :parameter data: [ws, vs], indexed by conformer index
    :parameter labels: unused; the true scores are read from the pairs
    :parameter df: dataframe passed to get_enantiomer_pairs
    :returns: correct pairs / total pairs, or 0 when there are no pairs
    """
    print("Getting predictions for all conformers...")

    chunk_size = 32
    weights, views = data[0], data[1]

    flat_predictions = []
    for start in range(0, len(weights), chunk_size):
        stop = start + chunk_size
        chunk_predictions = model.predict([weights[start:stop], views[start:stop]], verbose=0)
        flat_predictions.extend(chunk_predictions.flatten())
    all_predictions = np.array(flat_predictions)

    pairs = get_enantiomer_pairs(df)
    total_pairs = len(pairs)

    def _mean_prediction(conformer_indices):
        # Average the predictions over all conformers of one enantiomer.
        return np.mean([all_predictions[idx] for idx in conformer_indices])

    def _ranked_correctly(pair):
        predicted_first_lower = (
            _mean_prediction(pair['enantiomer1_conformers'])
            < _mean_prediction(pair['enantiomer2_conformers'])
        )
        true_first_lower = pair['enantiomer1_score'] < pair['enantiomer2_score']
        return predicted_first_lower == true_first_lower

    correct_rankings = sum(1 for pair in pairs if _ranked_correctly(pair))

    if total_pairs > 0:
        accuracy = correct_rankings / total_pairs
    else:
        accuracy = 0
    print(f"Evaluated {total_pairs} pairs, {correct_rankings} correct rankings")
    return accuracy



# Convert every split to float32 arrays: inputs stay as [weights, views]
# lists, labels become column vectors of shape (-1, 1).
print("Preparing data for training...")
dataG_train = [np.array(part, dtype='float32') for part in dataG_train[:2]]
labelsG_train = np.array(labelsG_train, dtype='float32').reshape(-1, 1)

dataG_val = [np.array(part, dtype='float32') for part in dataG_val[:2]]
labelsG_val = np.array(labelsG_val, dtype='float32').reshape(-1, 1)

dataG_test = [np.array(part, dtype='float32') for part in dataG_test[:2]]
labelsG_test = np.array(labelsG_test, dtype='float32').reshape(-1, 1)

print("Starting training with enantiomer pair sampling...")
print(f"Training data: {len(train_df)} conformers")
print(f"Validation data: {len(val_df)} conformers")
print(f"Test data: {len(test_df)} conformers")

# Train with early stopping on the validation ranking accuracy.
trained_generator, best_val_accuracy = train_with_enantiomer_sampling(
    generator,
    dataG_train, labelsG_train, train_df,
    dataG_val, labelsG_val, val_df,
    epochs=10,
    batch_size=32,
)

print(f"Best validation accuracy: {best_val_accuracy:.4f}")

# Final evaluation on the held-out test split.
print("\nEvaluating on test set...")
test_accuracy = evaluate_ranking_accuracy(trained_generator, dataG_test, labelsG_test, test_df)
print(f"Final test ranking accuracy: {test_accuracy:.4f}")

# Save the trained model, removing any file left over from an earlier run first.
print("\nSaving the best model...")
model_path = "best_docking_model.keras"
if os.path.exists(model_path):
    os.remove(model_path)
    print(f"Removed existing model file: {model_path}")

trained_generator.save(model_path)
print("Best model saved as 'best_docking_model.keras'")


