### Linear Classification Model Code 

Linear classification model trained on the molecular scent dataset from the AIcrowd "Learning to Smell" challenge (https://www.aicrowd.com/challenges/learning-to-smell)

The code below is adapted from the example code in the Classification section of the "Deep Learning for Molecules and Materials" textbook (https://whitead.github.io/dmol-book/ml/classification.html)

In [None]:
import os

#Reminder printed every time the notebook is run
print('Remember to update CUDA_VISIBLE_DEVICES')
#On GPU nodes, change the value below to match the GPU that was allocated
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
#Install packages & imports

#!pip install mordred[full]  matplotlib numpy pandas seaborn jax jaxlib wandb 

#Code uses Weights & Biases to log results
import wandb
#If running this code in a notebook and you have not yet logged in to W&B, uncomment the lines below
#wandb.login()
#%env "WANDB_NOTEBOOK_NAME" "Linear Classification Model_Latest"

#Other imports
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import rdkit, rdkit.Chem, rdkit.Chem.Draw
from rdkit.Chem.Draw import IPythonConsole
import numpy as np
import jax.numpy as jnp
import mordred, mordred.descriptors
import jax.experimental.optimizers as optimizers
import jax
import sklearn.metrics
import warnings
#Suppress all Python warnings for the rest of the notebook
warnings.filterwarnings('ignore')

#Plot appearance: seaborn 'notebook' context and 'dark' style with the overrides below
sns.set_context('notebook')
axes_style_overrides = {
    'xtick.bottom': True,
    'ytick.left': True,
    'xtick.color': '#666666',
    'ytick.color': '#666666',
    'axes.edgecolor': '#666666',
    'axes.linewidth': 0.8,
    'figure.dpi': 300,
}
sns.set_style('dark', axes_style_overrides)

#Custom color cycle used by matplotlib for successive plot elements
color_cycle = ['#1BBC9B', '#F06060', '#5C4B51', '#F3B562', '#6e5687']
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle)

#Seed NumPy's global random generator (used later for the weight initialization)
np.random.seed(0)

In [None]:
# 1. Start a W&B run; the returned handle is kept so the run can be finished at the end
run = wandb.init(
    project='Linear_Model',
    entity='aseshad4',
)

In [None]:
# 2. Save model inputs and hyperparameters on the W&B run config
config = wandb.config
#Each (name, value) pair is stored as an attribute of the config, in this order
for _name, _value in (
    ('learning_rate', 0.04),
    ('numEpochs', 1000),
    ('regularization', False),
    ('earlyStopping', False),
    ('updatedCode', True),
):
    setattr(config, _name, _value)

In [None]:
#Load the labelled molecule data --> file uploaded to jhub (locally stored)
#Only train.csv is read here; it is split into training/validation/test sets further below
scentdata = pd.read_csv('train.csv')

#Read in vocabulary text file --> this file lists all of the scent classes used in the dataset, one per line
#The with-block closes the file handle as soon as the contents have been read
with open('vocabulary.txt') as file:
    #List that stores all scent classes
    #splitlines() (unlike split('\n')) does not add an empty class when the file ends with a newline
    scentClasses = file.read().splitlines()

In [None]:
#Descriptor calculator built from the mordred descriptor collection, with 3D descriptors ignored
calc = mordred.Calculator(mordred.descriptors, ignore_3D=True)
#Convert every SMILES string in the dataset into an RDKit molecule
allMolecules = list(map(rdkit.Chem.MolFromSmiles, scentdata.SMILES))

#Sanity check: uncomment the line below to view the first molecule
#allMolecules[0]

In [None]:
#Create a multi-hot label vector for each molecule: entry (i, c) is 1 when molecule i
#is described by scent class c (column order follows scentClasses), otherwise 0
numClasses = len(scentClasses)
numMolecules = len(allMolecules)

#Fill a mutable NumPy buffer and convert to a JAX array once at the end.
#jax.ops.index_update is no longer available in current JAX releases, and every call
#copied the whole label matrix, so per-entry updates on a JAX array were also slow.
labelBuffer = np.zeros((numMolecules, numClasses))
for i in range(numMolecules):
    #SENTENCE holds the comma-separated scents associated with molecule i
    for scent in scentdata.SENTENCE[i].split(','):
        #.index raises ValueError if a scent is missing from the vocabulary
        labelBuffer[i, scentClasses.index(scent)] = 1

#jnp.asarray applies JAX's default float precision, like the jnp.zeros used originally
labels = jnp.asarray(labelBuffer)

In [None]:
##Test, uncomment code in this cell to test creation of label vectors

##Check that y vector was created correctly (compare label vector for molecule 11 to description of molecule from dataset)
##Indices where labels has a 1 (molecule belongs to scent class with that index)

#indices = jnp.argwhere(labels[11]).ravel()
#for i in indices:
    #scentTemp = scentClasses[i]
    #print(scentTemp)
#print(scentdata.SENTENCE[11])

In [None]:
#Compute descriptor features for every molecule with the mordred calculator
#The result is used below as a pandas DataFrame (standardized column-wise, then sliced by row)
features = calc.pandas(allMolecules)

In [None]:
#Standardize features column by column: subtract the mean, then divide by the standard deviation
#(statistics are taken over all molecules, before the train/validation/test split)
column_means = features.mean()
features -= column_means
column_stds = features.std()
features /= column_stds

#Some columns now contain NaNs (likely because their std was 0); drop every such column
features.dropna(axis=1, inplace=True)

print(f'We have {len(features.columns)} features per molecule')

In [None]:
#Split the molecules, in dataset order, into 80% training, 10% validation and the remainder for testing
train_N = int(0.8 * numMolecules)
valid_N = int(0.1 * numMolecules)
test_N = numMolecules - (train_N + valid_N)

#Mini-batch size and the start row of every mini-batch inside the training set
batch_size = 32
batch_idx = range(0, train_N, batch_size)

#Row ranges of the three subsets
valid_end = train_N + valid_N
train_rows = slice(None, train_N)
valid_rows = slice(train_N, valid_end)
test_rows = slice(valid_end, None)

#Labels stay JAX arrays; features are converted to float32 NumPy arrays
train_data_labels = labels[train_rows]
train_data_features = features[train_rows].values.astype(np.float32)

valid_data_labels = labels[valid_rows]
valid_data_features = features[valid_rows].values.astype(np.float32)

test_data_labels = labels[test_rows]
test_data_features = features[test_rows].values.astype(np.float32)

print(f'Num Molecules: {numMolecules}, test_N: {test_N}, train_N: {train_N}, valid_N: {valid_N}')

In [None]:
def multiLabel_classifier(x,w,b):
    """Return per-class probabilities: the sigmoid of the linear scores dot(x, w) + b."""
    logits = jnp.dot(x, w) + b
    return jax.nn.sigmoid(logits)


def cross_entropy(yhat, y):
    """Return the binary cross-entropy of probabilities yhat against targets y.

    The result is a single value: the mean over every entry of the inputs.
    """
    #Small offset so that log() never receives an argument of exactly 0
    eps = 1e-10
    log_likelihood = y * jnp.log(yhat + eps) + (1 - y) * jnp.log(1 - yhat + eps)
    return -jnp.mean(log_likelihood)


def loss_wrapper(w,b,x,y):
    """Return the cross-entropy loss of the linear classifier (w, b) on features x and labels y."""
    return jnp.mean(cross_entropy(multiLabel_classifier(x, w, b), y))


#Gradient of the loss with respect to w and b (arguments 0 and 1); returns a (dw, db) tuple
loss_grad = jax.grad(loss_wrapper, argnums=(0, 1))

In [None]:
def lossFunc(train_x,train_y, valid_x, valid_y):
    """Train the linear multi-label classifier with mini-batch gradient descent.

    Reads ``config.numEpochs`` and ``config.learning_rate`` for the
    hyperparameters and the module-level ``batch_idx`` for the start row of
    each mini-batch. After every epoch the mean training loss and the mean
    validation loss (both averaged over the epoch's update steps) are printed
    and logged to Weights & Biases.

    Args:
        train_x: 2-D training feature array (rows are samples).
        train_y: 2-D training label array (rows are samples, columns are classes).
        valid_x: validation features, evaluated after every update step.
        valid_y: validation labels.

    Returns:
        A list ``[y, w, b]``: the labels of the last mini-batch processed,
        the trained weight matrix and the trained bias vector.

    Raises:
        ValueError: if ``batch_idx`` yields no mini-batch.
    """
    num_epochs = config.numEpochs
    eta = config.learning_rate
    loss_progress = np.zeros(num_epochs)
    valid_loss_progress = np.zeros(num_epochs)

    #Size the parameters from the data instead of a hard-coded class count,
    #so w and b always agree with the label matrix
    num_features = train_x.shape[1]
    num_classes = train_y.shape[1]
    w = np.random.normal(scale = 0.01, size = (num_features, num_classes))
    b = np.ones(num_classes)

    #Every start index opens a mini-batch that runs up to the next start index;
    #the last one runs to the end of the training data, so the final (possibly
    #shorter) mini-batch is trained on as well instead of being skipped
    batch_starts = list(batch_idx)
    batch_stops = batch_starts[1:] + [len(train_x)]
    batches = list(zip(batch_starts, batch_stops))
    if not batches:
        raise ValueError('batch_idx is empty: there are no mini-batches to train on')

    y = None
    for epoch in range(num_epochs):
        for start, stop in batches:
            x = train_x[start:stop]
            y = train_y[start:stop]

            #Plain gradient-descent step on both parameters
            grad_w, grad_b = loss_grad(w, b, x, y)
            w -= eta * grad_w
            b -= eta * grad_b

            #Losses are measured after the update and accumulated over the epoch
            loss_progress[epoch] += loss_wrapper(w, b, x, y)
            valid_loss_progress[epoch] += loss_wrapper(w, b, valid_x, valid_y)

        #Turn the accumulated sums into per-step averages
        loss_progress[epoch] /= len(batches)
        valid_loss_progress[epoch] /= len(batches)
        print(f'Training Loss, Epoch {epoch}: {loss_progress[epoch]}')
        print(f'Validation Loss, Epoch {epoch}: {valid_loss_progress[epoch]}')

        # 3. Log metrics over time to visualize performance (using Weights & Biases)
        wandb.log({'Training loss': loss_progress[epoch], 'Epoch': epoch})
        wandb.log({"Validation loss": valid_loss_progress[epoch], 'Epoch': epoch})

    resultsList = [y,w,b]
    return resultsList


In [None]:
#Train model
#Store w & b values from training
results = lossFunc(train_data_features, train_data_labels, valid_data_features , valid_data_labels)
wVals = results[1]
bVals = results[2]

In [None]:
#Classification Metrics - Accuracy (standard & competition) & AUROC (Same functions as what was used in GNN Model code)

#Accuracy function where accuracy is measured as |intersection of true and predicted labels|/|union of true and predicted labels|
def accuracy_fn(w,b, x, y, threshold=0.5):
    """Return the Jaccard accuracy of the thresholded predictions for x against labels y.

    The score is |true AND predicted| / |true OR predicted|, where a class counts
    as predicted when its probability exceeds ``threshold`` and as true when its
    label is non-zero.

    Args:
        w: weight matrix of the linear classifier.
        b: bias vector of the linear classifier.
        x: features of the sample(s) to score.
        y: 0/1 labels with the same shape as the classifier output for x.
        threshold: probability above which a class is predicted. The previous
            cut-off of 0 was met by every sigmoid output, so all classes were
            always predicted.

    Returns:
        The score as a float in [0, 1]; 0.0 when there are neither true nor
        predicted labels (the ratio would otherwise be 0/0).
    """
    yhat = multiLabel_classifier(x,w,b)
    true_mask = np.asarray(y) != 0
    #Convert probabilities to hard class decisions
    pred_mask = np.asarray(yhat) > threshold

    correctlyPredicted = int(np.count_nonzero(true_mask & pred_mask))
    #Union of the true and the predicted label sets
    totalLabels = int(np.count_nonzero(true_mask | pred_mask))
    if totalLabels == 0:
        return 0.0
    return correctlyPredicted/totalLabels

#Competition accuracy is measured as |intersection of true and predicted labels for top 3 predictions|/|union of true and predicted labels for top 3 predictions|
def competitionAccuracy_fn(w,b, x, y): 
    """Score the best of five 3-label guesses built from the 15 most probable classes.

    The 15 classes with the highest predicted probability are grouped, in
    descending order of probability, into 5 sets of 3 (set 0 is the top 3).
    For each set the number of labels shared with the true labels of y is
    counted, and the set with the most hits (the first one on ties) is scored as
    hits / (3 + (3 - hits)), i.e. the union size is computed as if there were
    exactly 3 true labels.

    Returns:
        (accuracy, index of the best-scoring set), with index 0 being the top-3 set.
    """
    true_scentIndices = jnp.nonzero(y)
    yhat = multiLabel_classifier(x,w,b)

    #Class indices ordered by increasing probability; the last 15 are the most probable
    ranked = np.argsort(yhat)
    top15 = ranked[len(ranked)-15:]
    #One row per guess; reversing puts the 3 most probable classes in row 0
    candidate_sets = np.asarray(top15, dtype=float).reshape(5, 3)[::-1]

    #Number of true labels found in each guess
    hits_per_set = np.array(
        [len(np.intersect1d(guess, true_scentIndices)) for guess in candidate_sets],
        dtype=float,
    )
    best_hits = hits_per_set.max()
    best_set = hits_per_set.argmax()

    #Union size = 3 true labels (by convention) + the predicted labels that were not hits
    union_size = 3 + (3 - best_hits)
    return best_hits/union_size, best_set

#Compute accuracy
#Each test molecule is scored on its own; the reported value is the mean over the test set
acc = np.array(
    [
        accuracy_fn(wVals, bVals, test_data_features[i], test_data_labels[i])
        for i in range(test_N)
    ],
    dtype=float,
)
print(f'Overall Accuracy: {np.mean(acc)}')
#Store the same mean in the W&B run summary
wandb.run.summary["Test set accuracy (standard accuracy)"] = np.mean(acc)

#Compute competition accuracy
#One (accuracy, best prediction set) pair per test molecule
competition_results = [
    competitionAccuracy_fn(wVals, bVals, test_data_features[i], test_data_labels[i])
    for i in range(test_N)
]
accCompetition = np.array([score for score, _ in competition_results], dtype=float)
topPredSet = np.array([best_set for _, best_set in competition_results], dtype=float)
print(f'Overall Accuracy (Competition): {np.mean(accCompetition)}')
wandb.run.summary["Test set accuracy (competition accuracy)"] = np.mean(accCompetition)

#Best set index 0 is the top-3 guess, so the zero entries count the molecules where the top 3 won
top3Pred_correctPercent = 100*((test_N - np.count_nonzero(topPredSet))/test_N)
print(f'Top 3 predictions were best on average (competition accuracy, regular params): {top3Pred_correctPercent}%')

#Compute AUROC using sklearn
#Predicted class probabilities for every test molecule, filled in one row at a time
test_yhat = np.empty((test_N, numClasses))
for i in range(test_N):
    test_yhat[i] = multiLabel_classifier(test_data_features[i], wVals, bVals)

#Number of molecules in the full dataset (not only the test set) labelled with each class
occurrences = np.array([np.sum(labels[:, c]) for c in range(numClasses)], dtype=float)

aurocs_scikit = []
aurocs_omitUncommonClasses = []
for c in range(numClasses):
    #Classes without a single positive test molecule are reported and skipped
    if np.count_nonzero(test_data_labels[:,c]) == 0:
        print(f'Test set does not have any molecules with scent {scentClasses[c]}')
        continue

    #Uncomment lines below if want to plot/generate ROC curves for each label
    #fpr, tpr, thresholds = sklearn.metrics.roc_curve(test_data_labels[:,c], test_yhat[:,c])
    #plt.plot(fpr, tpr, '-o', label='Trained Model')
    #plt.plot([0,1], [0, 1], label='Naive Classifier')
    #plt.ylabel('True Positive Rate')
    #plt.xlabel('False Positive Rate')
    #plt.title(f'ROC Curve for {scentClasses[c]}')
    #plt.legend()
    #plt.show()
    #plt.savefig(f'GNN_ROC_Curve_{scentClasses[c]}_{runName}.jpg')
    #plt.close()

    auroc = sklearn.metrics.roc_auc_score(test_data_labels[:,c], test_yhat[:,c])
    aurocs_scikit.append(auroc)
    #Classes with fewer than 30 occurrences are left out of the second average
    if occurrences[c] < 30:
        print(f'Omitted {scentClasses[c]}')
    else:
        aurocs_omitUncommonClasses.append(auroc)
        print(f'Included {scentClasses[c]}')
    print(f'AUROC for scent {scentClasses[c]}: {auroc}')

#Average over all scored classes, and over the common classes only
mean_AUROC = np.mean(aurocs_scikit)
mean_AUROC_omitUncommonScents = np.mean(aurocs_omitUncommonClasses)
print(f'Mean AUROC: {mean_AUROC}')
print(f'Mean AUROC (w/uncommon scent classes omitted): {mean_AUROC_omitUncommonScents}')
wandb.run.summary['Mean AUROC'] = mean_AUROC
wandb.run.summary['Mean AUROC w/uncommon scents omitted'] = mean_AUROC_omitUncommonScents

In [None]:
#Stop the W&B run (`run` is the handle returned by wandb.init earlier in the notebook)
run.finish()