# Contextual Bayesian Optimisation via Large Language Models

This notebook briefly demonstrates the work of https://github.com/ur-whitelab/BO-LIFT, which focuses on few-shot/in-context learning (FS/ICL) for estimating the aqueous solubility (ESOL: Estimated SOLubility) of a compound. The advantages of ICL are described here: https://en.wikipedia.org/wiki/In-context_learning_(natural_language_processing).
The notebook then shows attempts to extend the work of https://arxiv.org/pdf/2304.05341.pdf via the following (non-exhaustive) directions:

1. Implementation of advanced contextual prompting (not simply just compound+solubility or compound+yield).
2. Experimenting with chain-of-thought prompting variations (https://www.promptingguide.ai/techniques/cot).
3. Experimenting with tree-of-thought prompting (https://www.promptingguide.ai/techniques/tot).
4. Multi-task Bayesian optimization, for instance optimizing not just for solubility but also for yield or some other property.

<DIV STYLE="background-color:#000000; height:10px; width:100%;">

# Import Libraries

In [1]:
# Standard Library
import json
import itertools
import os
import requests

# Third Party
import numpy as np
import pandas as pd
import openai
from langchain.prompts.prompt import PromptTemplate

# BO-LIFT
import bolift
from bolift.llm_model import GaussDist, DiscreteDist

In [2]:
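import getpass

# Seed NumPy's random number generator for reproducibility
np.random.seed(0)
# Read the OpenAI API key from the environment (or prompt for it) instead of hard-coding it
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")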

# Data Preparation

In [3]:
# Load the ESOL solubility dataset
esol_data = pd.read_csv("../paper/data/esol_iupac.csv")
esol_data

| | Compound ID | measured log(solubility:mol/L) | ESOL predicted log(solubility:mol/L) | SMILES | SELFIES | InChI | IUPAC |
|---|---|---|---|---|---|---|---|
| 0 | 1,1,1,2-Tetrachloroethane | -2.180 | -2.794 | ClCC(Cl)(Cl)Cl | [Cl][C][C][Branch1][C][Cl][Branch1][C][Cl][Cl] | InChI=1S/C2H2Cl4/c3-1-2(4,5)6/h1H2 | 1,1,1,2-tetrachloroethane |
| 1 | 1,1,1-Trichloroethane | -2.000 | -2.232 | CC(Cl)(Cl)Cl | [C][C][Branch1][C][Cl][Branch1][C][Cl][Cl] | InChI=1S/C2H3Cl3/c1-2(3,4)5/h1H3 | 1,1,1-trichloroethane |
| 2 | 1,1,2,2-Tetrachloroethane | -1.740 | -2.549 | ClC(Cl)C(Cl)Cl | [Cl][C][Branch1][C][Cl][C][Branch1][C][Cl][Cl] | InChI=1S/C2H2Cl4/c3-1(4)2(5)6/h1-2H | 1,1,2,2-tetrachloroethane |
| 3 | 1,1,2-Trichloroethane | -1.480 | -1.961 | ClCC(Cl)Cl | [Cl][C][C][Branch1][C][Cl][Cl] | InChI=1S/C2H3Cl3/c3-1-2(4)5/h2H,1H2 | 1,1,2-trichloroethane |
| 4 | 1,1,2-Trichlorotrifluoroethane | -3.040 | -3.077 | FC(F)(Cl)C(F)(Cl)Cl | [F][C][Branch1][C][F][Branch1][C][Cl][C][Branc... | InChI=1S/C2Cl3F3/c3-1(4,6)2(5,7)8 | 1,1,2-trichloro-1,2,2-trifluoroethane |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 922 | Valeraldehyde | -0.850 | -1.103 | CCCCC=O | [C][C][C][C][C][=O] | InChI=1S/C5H10O/c1-2-3-4-5-6/h5H,2-4H2,1H3 | pentanal |
| 923 | vamidothion | 1.144 | -1.446 | CNC(=O)C(C)SCCSP(=O)(OC)(OC) | [C][N][C][=Branch1][C][=O][C][Branch1][C][C][S... | InChI=1S/C8H18NO4PS2/c1-7(8(10)9-2)15-5-6-16-1... | 2-(2-dimethoxyphosphorylsulfanylethylsulfanyl)... |
| 924 | Vinclozolin | -4.925 | -4.377 | CC1(OC(=O)N(C1=O)c2cc(Cl)cc(Cl)c2)C=C | [C][C][Branch2][Ring1][O][O][C][=Branch1][C][=... | InChI=1S/C12H9Cl2NO3/c1-3-12(2)10(16)15(11(17)... | 3-(3,5-dichlorophenyl)-5-ethenyl-5-methyl-1,3-... |
| 925 | Xipamide | -3.790 | -3.642 | Cc1cccc(C)c1NC(=O)c2cc(c(Cl)cc2O)S(N)(=O)=O | [C][C][=C][C][=C][C][Branch1][C][C][=C][Ring1]... | InChI=1S/C15H15ClN2O4S/c1-8-4-3-5-9(2)14(8)18-... | 4-chloro-N-(2,6-dimethylphenyl)-2-hydroxy-5-su... |


In [4]:
# Keep only the IUPAC name and measured solubility, dropping rows with missing values
esol_df = esol_data[["IUPAC", "measured log(solubility:mol/L)"]]
esol_df = esol_df.dropna()
esol_df

| | IUPAC | measured log(solubility:mol/L) |
|---|---|---|
| 0 | 1,1,1,2-tetrachloroethane | -2.180 |
| 1 | 1,1,1-trichloroethane | -2.000 |
| 2 | 1,1,2,2-tetrachloroethane | -1.740 |
| 3 | 1,1,2-trichloroethane | -1.480 |
| 4 | 1,1,2-trichloro-1,2,2-trifluoroethane | -3.040 |
| ... | ... | ... |
| 922 | pentanal | -0.850 |
| 923 | 2-(2-dimethoxyphosphorylsulfanylethylsulfanyl)... | 1.144 |
| 924 | 3-(3,5-dichlorophenyl)-5-ethenyl-5-methyl-1,3-... | -4.925 |
| 925 | 4-chloro-N-(2,6-dimethylphenyl)-2-hydroxy-5-su... | -3.790 |


# ICL

## Ask-Tell

In [5]:
# Instantiate LLM model through ask-tell interface
asktell = bolift.AskTellFewShotTopk()
# Tell the model some points (few-shot/ICL)
asktell.tell("1-bromopropane", -1.730)
asktell.tell("1-bromopentane", -3.080)
asktell.tell("1-bromooctane", -5.060)
asktell.tell("1-bromonaphthalene", -4.35)
# Make a prediction for a molecule
yhat = asktell.predict("1-bromobutane")
print(f"Y_Hat for ICL (before BO): {yhat}")
print(f"Y_Hat Mean: {yhat.mean()}")
print(f"Y_Hat Standard Deviation: {yhat.std()}")

Y_Hat for ICL (before BO): DiscreteDist([-2.92 -2.68], [0.6 0.4])
Y_Hat Mean: -2.824
Y_Hat Standard Deviation: 0.11757550765359243
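The mean and standard deviation printed above are the probability-weighted moments of the discrete distribution over the model's top answers. The check below recomputes them with plain NumPy from the values and probabilities in the output; it illustrates the arithmetic only and is not BO-LIFT's `DiscreteDist` implementation.

```python
import numpy as np

# Support points and probabilities copied from the DiscreteDist printed above
values = np.array([-2.92, -2.68])
probs = np.array([0.6, 0.4])

# Probability-weighted mean: sum_i p_i * v_i
mean = np.sum(probs * values)
# Standard deviation: sqrt(sum_i p_i * (v_i - mean)^2)
std = np.sqrt(np.sum(probs * (values - mean) ** 2))

print(mean, std)
```

This reproduces the mean of -2.824 and the standard deviation of about 0.1176 reported by the notebook.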


## LLM as BO

In [6]:
# Now use the LLM model inside a BO protocol
pool_list = [
    "1-bromoheptane",
    "1-bromohexane",
    "1-bromo-2-methylpropane",
    "butan-1-ol"
]
# Create the pool object
pool = bolift.Pool(pool_list)
# Ask for the next point to evaluate: the pool member that maximises the acquisition function (UCB) given the points told so far
next_point = asktell.ask(pool)
print(f"The next point for the optimiser to try is: {next_point}")

The next point for the optimiser to try is: (['1-bromo-2-methylpropane'], [-1.284916344093158], [-1.9199999999999997])
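To make the UCB step concrete, the sketch below scores each candidate as predicted mean plus λ times predicted standard deviation and picks the highest score. The pool names, means, standard deviations and λ = 0.5 are made-up placeholders, not BO-LIFT outputs or defaults.

```python
import numpy as np

# Hypothetical predicted means and standard deviations for a candidate pool
# (illustrative numbers only, not outputs of BO-LIFT)
pool = ["candidate_a", "candidate_b", "candidate_c"]
means = np.array([-3.1, -2.0, -2.4])
stds = np.array([0.2, 0.1, 1.2])


def upper_confidence_bound(mu, sigma, lam=0.5):
    """UCB score: predicted mean plus an exploration bonus proportional to uncertainty."""
    return mu + lam * sigma


scores = upper_confidence_bound(means, stds, lam=0.5)
best = int(np.argmax(scores))
print(pool[best], scores[best])
```

Here the uncertain `candidate_c` wins even though `candidate_b` has the highest predicted mean; with `lam=0.0` the same code selects `candidate_b`, so λ controls the trade-off between exploitation and exploration.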


In [7]:
# Tell the model the measured solubility of the suggested compound, then predict again
x_next = next_point[0][0]
y_next = esol_df.loc[esol_df["IUPAC"] == x_next, "measured log(solubility:mol/L)"].iloc[0]
asktell.tell(x_next, y_next)
yhat = asktell.predict("1-bromobutane")
print(f"Y_Hat for ICL+BO: {yhat}")
print(f"Y_Hat Mean: {yhat.mean()}")
print(f"Y_Hat Standard Deviation: {yhat.std()}")

Y_Hat for ICL+BO: DiscreteDist([-1.89 -1.86], [0.6 0.4])
Y_Hat Mean: -1.878
Y_Hat Standard Deviation: 0.014696938456698972
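The two cells above amount to one iteration of an ask-tell loop: ask for the pool point with the best acquisition score, look up its measured value, tell it to the model, and repeat. A self-contained sketch of such a loop follows. `ToySurrogate` (nearest-neighbour regression on a single numeric feature) and the `oracle` values are hypothetical stand-ins for the LLM and the ESOL lookup, chosen only so the loop runs without an API key.

```python
import numpy as np


class ToySurrogate:
    """Hypothetical stand-in for the LLM: k-nearest-neighbour regression on a scalar feature."""

    def __init__(self, k=1):
        self.k = k
        self.xs, self.ys = [], []

    def tell(self, x, y):
        self.xs.append(x)
        self.ys.append(y)

    def predict(self, x):
        # Mean of the k closest observations; uncertainty grows with distance to them
        xs, ys = np.array(self.xs, dtype=float), np.array(self.ys, dtype=float)
        dist = np.abs(xs - x)
        idx = np.argsort(dist, kind="stable")[: self.k]
        return ys[idx].mean(), ys[idx].std() + dist[idx].mean()

    def ask(self, pool, lam=0.5):
        # Upper confidence bound over the remaining pool (higher property is better)
        scores = [m + lam * s for m, s in (self.predict(x) for x in pool)]
        return pool[int(np.argmax(scores))]


# Hypothetical oracle: feature value -> "measured" property (made-up numbers)
oracle = {3: -1.7, 4: -2.4, 5: -3.1, 6: -3.8, 7: -4.4, 8: -5.1}

model = ToySurrogate(k=1)
for x in (3, 8):  # few-shot examples told up front
    model.tell(x, oracle[x])

pool = [4, 5, 6, 7]
chosen = []
for step in range(2):
    x_next = model.ask(pool)          # ask
    pool.remove(x_next)               # a point is evaluated only once
    model.tell(x_next, oracle[x_next])  # tell
    chosen.append(x_next)
    print(step, x_next, oracle[x_next])
```

The first pick lands midway between the two told points, where the uncertainty bonus is largest; swapping `ToySurrogate` for the LLM-backed model is what the cells above do by hand.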


Here, we aim to improve "LLM as BO" by adding more contextual information. Contextual information can be added in many different ways; the three we will look at are:

1. Change the prompt template:

`prompt_template = PromptTemplate(input_variables=["x", "Answer", "y_name"] + self._answer_choices,
                                   template="Q: Given {x}. What is {y_name}?\n"
                                   + "\n".join([f"{a}. {{{a}}}" for a in self._answer_choices])
                                   + "\nAnswer: {Answer}\n\n")`
                                  
2. Change the acquisition function to incorporate context (UCB -> contextual UCB, C-UCB).
3. Use policy learning. For example, policy learning with reinforcement learning can use a function approximator such as a neural network to predict actions, then update the network's weights based on the observed reward. Here, the "actions" would be the solubility predictions, and the "reward" would measure how close those predictions are to the measured solubility.
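As a sketch of idea 1, the template below extends the one above with a per-example `Context:` field. It uses plain `str.format` so it runs without LangChain; the same string could in principle be passed as `template=` to `PromptTemplate` with a matching `input_variables` list, but wiring it into BO-LIFT is not shown. The answer-choice lines are omitted for brevity, and the context strings are hand-written illustrations rather than computed descriptors.

```python
# Hypothetical contextual template: compared with the template above, each
# example gains a "Context:" field; the answer-choice lines are omitted.
contextual_template = (
    "Q: Given {x}. Context: {context}. What is {y_name}?\n"
    "Answer: {Answer}\n\n"
)
y_name = "the measured log solubility (mol/L)"

# Few-shot examples reuse the values told to the model earlier; the context
# strings are simple hand-written descriptors chosen for illustration.
examples = [
    {"x": "1-bromopropane", "context": "3 carbon atoms, 1 bromine atom", "Answer": "-1.73"},
    {"x": "1-bromopentane", "context": "5 carbon atoms, 1 bromine atom", "Answer": "-3.08"},
    {"x": "1-bromooctane", "context": "8 carbon atoms, 1 bromine atom", "Answer": "-5.06"},
]

few_shot = "".join(contextual_template.format(y_name=y_name, **ex) for ex in examples)
# The query uses the same template with an empty answer for the model to complete
query = contextual_template.format(
    x="1-bromobutane", context="4 carbon atoms, 1 bromine atom", y_name=y_name, Answer=""
).rstrip()
prompt = few_shot + query
print(prompt)
```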



These ideas can also be combined.
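For idea 2, one well-known contextual UCB is LinUCB: fit a ridge regression from context features to reward and add an uncertainty bonus α·sqrt(xᵀA⁻¹x). Whether this matches the C-UCB intended here is an assumption. In the sketch, the features ([bias, carbon count / 10]) are hypothetical, and the rewards are the measured log solubilities told to the model earlier.

```python
import numpy as np


class LinUCB:
    """Shared-parameter LinUCB: ridge regression on context features plus an uncertainty bonus."""

    def __init__(self, n_features, alpha=1.0, reg=1.0):
        self.alpha = alpha
        self.A = reg * np.eye(n_features)  # regularised Gram matrix
        self.b = np.zeros(n_features)      # accumulated reward-weighted features

    def tell(self, x, reward):
        x = np.asarray(x, dtype=float)
        self.A += np.outer(x, x)
        self.b += reward * x

    def scores(self, X):
        X = np.asarray(X, dtype=float)
        A_inv = np.linalg.inv(self.A)
        theta = A_inv @ self.b
        # Predicted reward plus alpha * sqrt(x^T A^{-1} x) for each row of X
        bonus = np.sqrt(np.einsum("ij,jk,ik->i", X, A_inv, X))
        return X @ theta + self.alpha * bonus

    def ask(self, X):
        return int(np.argmax(self.scores(X)))


# Hypothetical context for the bromoalkanes told above: [bias, carbon count / 10]
model = LinUCB(n_features=2, alpha=1.0)
for x, y in [([1.0, 0.3], -1.73), ([1.0, 0.5], -3.08), ([1.0, 0.8], -5.06)]:
    model.tell(x, y)

# Candidates with 4, 7 and 12 carbons (hypothetical feature vectors)
candidates = [[1.0, 0.4], [1.0, 0.7], [1.0, 1.2]]
print(model.scores(candidates), model.ask(candidates))
```

With a small `alpha` the model favours the candidate whose context predicts the highest solubility; a large `alpha` shifts the choice toward contexts far from anything seen so far, such as the 12-carbon extrapolation.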

<DIV STYLE="background-color:#000000; height:10px; width:100%;">

## Idea 1: