# Contextual Bayesian Optimisation via Large Language Models

This notebook will:
- Demonstrate the work of [BO-Lift](https://github.com/ur-whitelab/BO-LIFT), which focuses on few-shot/in-context learning (FS/ICL) for estimating the aqueous solubility (ESOL: Estimated SOLubility) of a compound, as well as predicting reaction yields from chemical compound interactions. An overview of ICL is available here: [ICL-Wiki](https://en.wikipedia.org/wiki/In-context_learning_(natural_language_processing)).
- Show attempts to extend and evaluate the work of [BO-Lift](https://arxiv.org/pdf/2304.05341.pdf) via (non-exhaustive):
    1. Implementation of advanced contextual prompting (not just compound+solubility or compound+yield) through automatic feature engineering (specifically, feature selection).
    2. A test comparison between BO-LIFT and the new framework.
    3. Contextual Bayesian Optimisation (EXTENSION).
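The Bayesian optimisation step in item 3 needs an acquisition function that ranks unlabelled pool candidates from the model's predictive distribution. Below is a minimal sketch of one common choice, expected improvement (EI), which assumes each candidate's prediction has been summarised as a Gaussian `(mean, std)`, for example from the LLM's sampled answers. The function names and the Gaussian summary are illustrative assumptions. They are not the API of `CEBO` or `Pool`.

```python
import math


def expected_improvement(mean, std, best, xi=0.01):
    """Expected improvement (maximisation) of a Gaussian N(mean, std**2) over `best`.

    Assumption: the predictive distribution is summarised as a Gaussian.
    `xi` is a small exploration margin.
    """
    gain = mean - best - xi
    if std <= 0.0:
        # Degenerate (deterministic) prediction: the improvement is known exactly
        return max(gain, 0.0)
    z = gain / std
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))          # standard normal CDF
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal PDF
    return gain * cdf + std * pdf


def select_next(pool, best, xi=0.01):
    """Return the key of the pool candidate with the highest expected improvement.

    `pool` maps a candidate name to its (mean, std) prediction (hypothetical format).
    """
    return max(pool, key=lambda name: expected_improvement(*pool[name], best, xi))


# Made-up (mean, std) predictions for three hypothetical candidates
pool = {"cand_a": (-0.9, 0.4), "cand_b": (-2.2, 0.3), "cand_c": (-6.0, 0.5)}
print(select_next(pool, best=-1.5))  # → cand_a
```

EI weighs the predicted gain over the best value seen so far against the model's uncertainty. An uncertain candidate can therefore still be chosen when its mean is slightly worse than the best.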

<DIV STYLE="background-color:#000000; height:10px; width:100%;">

# Import Libraries

In [2]:
# Standard Library
import os
import itertools
import copy
import tqdm

# Third Party
import numpy as np
import pandas as pd
from bolift import AskTellFewShotTopk

# Private
from model import CEBO, Pool

In [3]:
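# OpenAI API key: read it from the environment (or prompt for it) instead of hard-coding it in the notebook
import getpass
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")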

# Data Preparation

The original paper used the data from the [ESOL](https://www.researchgate.net/publication/8551133_ESOL_Estimating_Aqueous_Solubility_Directly_from_Molecular_Structure) paper. That dataset contains only 927 examples with 7 columns, of which only 3 are relevant. This is not enough information to compare alternative techniques, so we instead use the larger [AqSolDB](https://www.kaggle.com/datasets/sorkun/aqsoldb-a-curated-aqueous-solubility-dataset?resource=download) dataset, available on Kaggle. It includes many more molecular descriptors for each compound.
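The extra AqSolDB descriptors are what make contextual prompting possible. A BO-LIFT-style example only names the compound, while a contextual example also lists selected descriptor columns. The sketch below illustrates the difference. The `Q: ... A: ...###` template and the helper name `format_example` are hypothetical and are not the exact prompt built by BO-LIFT or `CEBO`. The record values are taken from the `vinyltoluene` row of the table below.

```python
def format_example(record, features=None, y_name="solubility", y_key="Solubility"):
    """Serialise one record as a few-shot prompt example (hypothetical template)."""
    # Name-only description, mirroring the BO-LIFT x_formatter used later
    x = f"compound id {record['Compound ID']}"
    # Contextual variant: append the selected descriptor columns
    if features:
        x += ", " + ", ".join(f"{name} {record[name]}" for name in features)
    question = f"Q: What is the {y_name} of {x}?\nA:"
    if y_key in record:
        # Told (training) example: include the answer and the ### terminator
        return f"{question} {record[y_key]:.6f}###"
    # Query example: the model is expected to complete the answer
    return question


record = {"Compound ID": "vinyltoluene", "MolWt": 118.179, "MolLogP": 2.63802, "Solubility": -3.123150}
print(format_example(record))                        # BO-LIFT-style: name only
print(format_example(record, ["MolWt", "MolLogP"]))  # contextual: name + descriptors
```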

In [4]:
# Load AqSolDB data
aqsoldb_df = pd.read_csv("data/aqsoldb.csv")

In [5]:
# Clean data
aqsoldb_df = aqsoldb_df.dropna()
aqsoldb_df = aqsoldb_df.drop_duplicates().reset_index(drop=True)
aqsoldb_df.rename(columns={'Name': 'Compound ID'}, inplace=True)
aqsoldb_df = aqsoldb_df.drop(["ID"], axis=1)

Because the prompts sent to the OpenAI language models are limited in token length, we keep only compounds whose names are shorter than 15 characters.

In [6]:
# Keep compounds that "read" easily
aqsoldb_df = aqsoldb_df[aqsoldb_df["Compound ID"].str.len()<15].reset_index(drop=True)
aqsoldb_df

| | Compound ID | InChI | InChIKey | SMILES | Solubility | SD | Ocurrences | Group | MolWt | MolLogP | ... | NumRotatableBonds | NumValenceElectrons | NumAromaticRings | NumSaturatedRings | NumAliphaticRings | RingCount | TPSA | LabuteASA | BalabanJ | BertzCT |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | vinyltoluene | InChI=1S/C9H10/c1-3-9-6-4-5-8(2)7-9/h3-7H,1H2,2H3 | JZHGRUMIRATHIU-UHFFFAOYSA-N | `Cc1cccc(C=C)c1` | -3.123150 | 0.000000 | 1 | G1 | 118.179 | 2.63802 | ... | 1.0 | 46.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.00 | 55.836626 | 3.070761 | 211.033225 |
| 1 | hydroxylamine | InChI=1S/H3NO/c1-2/h2H,1H2 | AVXURJPOCDRRFD-UHFFFAOYSA-N | `NO` | -0.763034 | 0.861298 | 7 | G4 | 33.030 | -0.66570 | ... | 0.0 | 14.0 | 0.0 | 0.0 | 0.0 | 0.0 | 46.25 | 12.462472 | 1.000000 | 2.000000 |
| 2 | molybdenum | InChI=1S/Mo | ZOKXTWBITQBERF-UHFFFAOYSA-N | `[Mo]` | -4.203848 | 0.000000 | 1 | G1 | 95.940 | -0.00250 | ... | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00 | 21.756566 | 0.000000 | 0.000000 |
| 3 | Prednisolone | InChI=1S/C21H28O5/c1-19-7-5-13(23)9-12(19)3-4-... | OIGNJSKKLXVSLS-VWUMJDOOSA-N | `C[C@]12C[C@H](O)[C@H]3[C@@H](CCC4=CC(=O)C=C[C@...` | -3.178447 | 0.015047 | 2 | G3 | 360.450 | 1.55760 | ... | 2.0 | 142.0 | 0.0 | 3.0 | 4.0 | 4.0 | 94.83 | 153.341308 | 1.747281 | 723.913082 |
| 4 | fluoromethane | InChI=1S/CH3F/c1-2/h1H3 | NBVXSUQYWXRMNV-UHFFFAOYSA-N | `CF` | -0.175874 | 0.000000 | 1 | G1 | 34.033 | 0.58570 | ... | 0.0 | 14.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00 | 12.904786 | 1.000000 | 2.000000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1888 | tetracaine | InChI=1S/C15H24N2O2/c1-4-5-10-16-14-8-6-13(7-9... | GKCBAIGFKIBETG-UHFFFAOYSA-N | `C(c1ccc(cc1)NCCCC)(=O)OCCN(C)C` | -3.010000 | 0.000000 | 1 | G1 | 264.369 | 2.61700 | ... | 8.0 | 106.0 | 1.0 | 0.0 | 0.0 | 1.0 | 41.57 | 115.300645 | 2.394548 | 374.236893 |
| 1889 | tetracycline | InChI=1S/C22H24N2O8/c1-21(31)8-5-4-6-11(25)12(... | OFVLGDICTFRJMM-WESIUVDSSA-N | `OC1=C(C(C2=C(O)[C@@](C(C(C(N)=O)=C(O)[C@H]3N(C...` | -2.930000 | 0.000000 | 1 | G1 | 444.440 | -0.21440 | ... | 2.0 | 170.0 | 1.0 | 0.0 | 3.0 | 4.0 | 181.62 | 182.429237 | 2.047922 | 1148.584975 |
| 1890 | thymol | InChI=1S/C10H14O/c1-7(2)9-5-4-8(3)6-10(9)11/h4... | MGSRCZKZVOBKFT-UHFFFAOYSA-N | `c1(cc(ccc1C(C)C)C)O` | -2.190000 | 0.019222 | 3 | G5 | 150.221 | 2.82402 | ... | 1.0 | 60.0 | 1.0 | 0.0 | 0.0 | 1.0 | 20.23 | 67.685405 | 3.092720 | 251.049732 |
| 1891 | verapamil | InChI=1S/C27H38N2O4/c1-20(2)27(19-28,22-10-12-... | SGTNSNPWRIOYBX-UHFFFAOYSA-N | `COc1ccc(CCN(C)CCCC(C#N)(C(C)C)c2ccc(OC)c(OC)c2...` | -3.980000 | 0.000000 | 1 | G1 | 454.611 | 5.09308 | ... | 13.0 | 180.0 | 2.0 | 0.0 | 0.0 | 2.0 | 63.95 | 198.569223 | 2.023333 | 938.203977 |


In [7]:
# Work with a random subset of 1,000 compounds
mini_df = aqsoldb_df.sample(n=1000, random_state=42).reset_index(drop=True)
mini_df

| | Compound ID | InChI | InChIKey | SMILES | Solubility | SD | Ocurrences | Group | MolWt | MolLogP | ... | NumRotatableBonds | NumValenceElectrons | NumAromaticRings | NumSaturatedRings | NumAliphaticRings | RingCount | TPSA | LabuteASA | BalabanJ | BertzCT |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | fluthiamide | InChI=1S/C14H13F4N3O2S/c1-8(2)21(10-5-3-9(15)4... | IANUJLZYFUDJIH-UHFFFAOYSA-N | `CC(C)N(C(=O)COc1sc(nn1)C(F)(F)F)c2ccc(F)cc2` | -3.812100 | 0.00000 | 1 | G1 | 363.336 | 3.5164 | ... | 5.0 | 130.0 | 2.0 | 0.0 | 0.0 | 2.0 | 55.32 | 138.423305 | 2.058571 | 700.790271 |
| 1 | n-undecane | InChI=1S/C11H24/c1-3-5-7-9-11-10-8-6-4-2/h3-11... | RSJKGSCJYJTIGS-UHFFFAOYSA-N | `CCCCCCCCCCC` | -7.550500 | 0.00000 | 1 | G1 | 156.313 | 4.5371 | ... | 8.0 | 68.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00 | 72.388672 | 2.690916 | 49.058650 |
| 2 | hexaconazole | InChI=1S/C14H17Cl2N3O/c1-2-3-6-14(20,8-19-10-1... | STMIIPIFODONDC-UHFFFAOYSA-N | `CCCCC(O)(Cn1cncn1)c2ccc(Cl)cc2Cl` | -4.266800 | 0.00000 | 1 | G1 | 314.216 | 3.6629 | ... | 6.0 | 108.0 | 2.0 | 0.0 | 0.0 | 2.0 | 50.94 | 127.905798 | 2.340948 | 559.551589 |
| 3 | 3-nitrophenol | InChI=1S/C6H5NO3/c8-6-3-1-2-5(4-6)7(9)10/h1-4,8H | RTZZCYNQPHTPPL-UHFFFAOYSA-N | `OC1=CC(=CC=C1)[N+]([O-])=O` | -1.070000 | 0.02850 | 2 | G3 | 139.110 | 1.3004 | ... | 1.0 | 52.0 | 1.0 | 0.0 | 0.0 | 1.0 | 63.37 | 56.878613 | 3.069899 | 258.924771 |
| 4 | resorufin | InChI=1S/C12H7NO3/c14-7-1-3-9-11(5-7)16-12-6-8... | HSSLDCABUXLXKM-UHFFFAOYSA-N | `Oc1ccc2N=C3C=CC(=O)C=C3Oc2c1` | -1.027700 | 0.00000 | 1 | G1 | 213.192 | 1.9984 | ... | 0.0 | 78.0 | 1.0 | 0.0 | 2.0 | 3.0 | 63.33 | 90.129407 | 2.456794 | 702.198677 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 3-penten-2-ol | InChI=1S/C5H10O/c1-3-4-5(2)6/h3-6H,1-2H3/b4-3+ | GJYMQFMQRRNLCY-ONEGZZNKSA-N | `C/C=C/C(C)O` | 0.015200 | 0.00000 | 1 | G1 | 86.134 | 0.9433 | ... | 1.0 | 36.0 | 0.0 | 0.0 | 0.0 | 0.0 | 20.23 | 38.303650 | 3.032355 | 45.900135 |
| 996 | molybdenum | InChI=1S/Mo | ZOKXTWBITQBERF-UHFFFAOYSA-N | `[Mo]` | -4.203848 | 0.00000 | 1 | G1 | 95.940 | -0.0025 | ... | 0.0 | 6.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00 | 21.756566 | 0.000000 | 0.000000 |
| 997 | pyrene | InChI=1S/C16H10/c1-3-11-7-9-13-5-2-6-14-10-8-1... | BBEAQIROQSPTKN-UHFFFAOYSA-N | `c1cc2ccc3cccc4ccc(c1)c2c34` | -6.178797 | 0.05882 | 3 | G5 | 202.256 | 4.5840 | ... | 0.0 | 74.0 | 4.0 | 0.0 | 0.0 | 4.0 | 0.00 | 93.455422 | 2.505956 | 666.619806 |
| 998 | Benzoin | InChI=1S/C13H10O2/c14-13(11-7-3-1-4-8-11)15-12... | FCJSHPDYVMKCHI-UHFFFAOYSA-N | `c1ccccc1C(=O)Oc2ccccc2` | -2.850000 | 0.00000 | 1 | G1 | 198.221 | 2.9058 | ... | 2.0 | 74.0 | 2.0 | 0.0 | 0.0 | 2.0 | 26.30 | 88.128507 | 2.116422 | 434.607837 |


# Tell-Predict Experimentation
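Every model in this experiment follows the same two-step protocol. `tell(x, y)` stores a labelled example, and `predict(x)` selects up to `selector_k` stored examples to condition on before producing an answer for a new compound. The class below is a minimal, LLM-free sketch of that protocol under stated assumptions. `ToyAskTell` is hypothetical. It ranks examples by `difflib` string similarity, which is a swapped-in stand-in and not BO-LIFT's actual example selector. It also returns the selected y-values directly instead of prompting a model. The told values are taken from the tables above.

```python
import difflib
import statistics


class ToyAskTell:
    """Hypothetical tell/predict stand-in: no LLM, string-similarity example selection."""

    def __init__(self, selector_k=5):
        self.selector_k = selector_k  # number of stored examples used per prediction
        self.examples = []            # list of (x, y) pairs that were told

    def tell(self, x, y):
        # Store one labelled example
        self.examples.append((x, y))

    def select(self, x):
        # Rank stored examples by similarity of their x-string to the query and keep the top k
        ranked = sorted(self.examples,
                        key=lambda ex: difflib.SequenceMatcher(None, x, ex[0]).ratio(),
                        reverse=True)
        return ranked[: self.selector_k]

    def predict(self, x):
        # A real model would put the selected examples in a prompt and sample answers;
        # here we simply return their y-values as "samples"
        return [y for _, y in self.select(x)]


model = ToyAskTell(selector_k=2)
for name, sol in [("vinyltoluene", -3.123150), ("fluoromethane", -0.175874), ("thymol", -2.190000)]:
    model.tell(name, sol)
samples = model.predict("chloromethane")
print(statistics.mean(samples))
```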

In [None]:
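# Hyperparameters
T_list = [0.7]                        # LLM sampling temperatures
k_list = [5]                          # number of told examples selected for each prompt (selector_k)
train_num_list = [5, 15, 25, 35, 45]  # number of training points told to each model
test_num_list = [10]                  # number of held-out points to predict
models_list = ["curie", "davinci"]    # OpenAI base models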
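Before running the ablation it helps to know how many model calls the grid implies. The ablation loop runs 10 repeats for every hyperparameter combination and 5 methods (BO-LIFT plus four variants), and each `predict` call queries the language model. The snippet below is a small worked count that re-declares the grid from the cell above so it is self-contained. The names `n_repeats` and `n_methods` are introduced here only for the arithmetic.

```python
import itertools

# Grid copied from the hyperparameter cell
T_list = [0.7]
k_list = [5]
train_num_list = [5, 15, 25, 35, 45]
test_num_list = [10]
models_list = ["curie", "davinci"]
n_repeats = 10  # repeats per combination in the ablation loop
n_methods = 5   # BO-LIFT + four variants

grid = list(itertools.product(T_list, k_list, train_num_list, test_num_list, models_list))
# One predict call per test point, per method, per repeat
n_predict_calls = sum(n_repeats * n_methods * num_test for _, _, _, num_test, _ in grid)
# One tell call per training point, per method, per repeat
n_tell_calls = sum(n_repeats * n_methods * num_train for _, _, num_train, _, _ in grid)
print(len(grid), n_predict_calls, n_tell_calls)  # → 10 5000 12500
```

So this modest-looking grid already amounts to 5,000 predictions. That is worth keeping in mind before adding more temperatures or models.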

In [None]:
def ablation_study(T_list, k_list, train_num_list, test_num_list, models_list, data, n_repeats=10):
    """Compare BO-LIFT with four contextual variants over a hyperparameter grid.

    Returns a dict mapping each method name to a list with one entry per
    hyperparameter combination; each entry is a list of `n_repeats` result dicts.
    """
    # Context features given to each method (None = compound name only, as in BO-LIFT)
    method_features = {
        "bo_lift": None,
        "cebo_lift_1": None,
        "cebo_lift_2": ["MolLogP", "MolMR"],
        "cebo_lift_3": ["MolLogP", "MolMR"],
        "cebo_lift_4": ["Ocurrences", "SD"],
    }
    results = {name: [] for name in method_features}
    # Loop over every hyperparameter combination
    for T, k, num_train, num_test, model in itertools.product(T_list, k_list, train_num_list,
                                                              test_num_list, models_list):
        combo_results = {name: [] for name in method_features}
        for seed in range(n_repeats):
            # Create data: use a different shuffle per repeat, otherwise all repeats are identical
            shuffled_df = data.sample(frac=1, random_state=seed)
            train_df = shuffled_df.iloc[:num_train]
            test_df = shuffled_df.iloc[num_train:].head(num_test)
            # Create the model objects
            methods = {
                # Baseline BO-LIFT (compound name only)
                "bo_lift": AskTellFewShotTopk(x_formatter=lambda x: f"compound id {x}",
                                              y_name="solubility",
                                              y_formatter=lambda y: f"{y:.6f}",
                                              model=model,
                                              selector_k=k,
                                              temperature=T),
                # BO-LIFT with a domain-expert prompt prefix
                "cebo_lift_1": AskTellFewShotTopk(x_formatter=lambda x: f"compound id {x}",
                                                  y_name="solubility",
                                                  y_formatter=lambda y: f"{y:.6f}",
                                                  model=model,
                                                  selector_k=k,
                                                  temperature=T,
                                                  prefix=("You are an expert chemist. "
                                                          "The following are correctly answered questions. "
                                                          "Each answer is numeric and ends with ###\n")),
                # CEBO with descriptor context, no domain prompt
                "cebo_lift_2": CEBO(y_name="solubility", model=model, selector_k=k, temperature=T,
                                    domain=None, features=method_features["cebo_lift_2"]),
                # CEBO with descriptor context and a chemist domain prompt
                "cebo_lift_3": CEBO(y_name="solubility", model=model, selector_k=k, temperature=T,
                                    domain="chemist", features=method_features["cebo_lift_3"]),
                # CEBO with measurement metadata as context and a chemist domain prompt
                "cebo_lift_4": CEBO(y_name="solubility", model=model, selector_k=k, temperature=T,
                                    domain="chemist", features=method_features["cebo_lift_4"]),
            }
            for name, method in methods.items():
                features = method_features[name]
                # Tell the training points to the model
                for _, row in train_df.iterrows():
                    if features is None:
                        method.tell(row["Compound ID"], row["Solubility"])
                    else:
                        method.tell(row[["Compound ID"] + features + ["Solubility"]].to_dict())
                # Predict the test points
                if features is None:
                    y_pred = [method.predict(row["Compound ID"]) for _, row in test_df.iterrows()]
                else:
                    y_pred = [method.predict(row[["Compound ID"] + features].to_dict())
                              for _, row in test_df.iterrows()]
                # Summarise each predictive distribution by its mean (NaN if the model returned nothing)
                y_pred = [sol.mean() if len(sol) >= 1 else np.nan for sol in y_pred]
                # Store values, together with the true solubilities for later evaluation
                combo_results[name].append({"T": T,
                                            "k": k,
                                            "Train": num_train,
                                            "Test": num_test,
                                            "Model": model,
                                            "Predictions": y_pred,
                                            "Truth": test_df["Solubility"].tolist()})
        # Add to final results
        for name in method_features:
            results[name].append(combo_results[name])
    return results
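The loop stores `np.nan` for any test compound where a model returned no usable answer. A fair comparison should therefore score only the valid predictions and report how many there were, so that a method which often fails to answer does not look artificially accurate. Below is a minimal sketch of such a NaN-aware mean absolute error. The helper name `nan_mae` is hypothetical. The true values would come from the `Solubility` column of the corresponding `test_df`.

```python
import math


def nan_mae(y_true, y_pred):
    """Mean absolute error over pairs whose prediction is not NaN.

    Returns (mae, n_valid); mae is NaN when no prediction could be parsed.
    """
    # Keep only (truth, prediction) pairs where the model produced a number
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if not math.isnan(p)]
    if not pairs:
        return float("nan"), 0
    return sum(abs(t - p) for t, p in pairs) / len(pairs), len(pairs)


print(nan_mae([1.0, 2.0, 3.0], [1.5, float("nan"), 2.0]))  # → (0.75, 2)
```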

## Results

<DIV STYLE="background-color:#000000; height:10px; width:100%;">