# Contextual Bayesian Optimisation via Large Language Models

This notebook briefly demonstrates the work of https://github.com/ur-whitelab/BO-LIFT, which focuses on few-shot/in-context learning (FS/ICL) for estimating the aqueous solubility (ESOL, Estimated SOLubility) of a compound, as well as yields from chemical compound interactions. The advantages of ICL are described here: https://en.wikipedia.org/wiki/In-context_learning_(natural_language_processing).
The notebook then shows attempts to extend the work of https://arxiv.org/pdf/2304.05341.pdf via the following (non-exhaustive) ideas:

1. Implementing advanced contextual prompting (more than just compound+solubility or compound+yield pairs).
2. Experimenting with chain-of-thought prompting variations (https://www.promptingguide.ai/techniques/cot).
3. Experimenting with tree-of-thought prompting (https://www.promptingguide.ai/techniques/tot).
4. Multi-task Bayesian optimization, for cases where we want to optimize not just for solubility but also for yield or some other property.

<DIV STYLE="background-color:#000000; height:10px; width:100%;">

# Import Libraries

In [1]:
# Standard Library
import json
import itertools
import os

# Third Party
import numpy as np
import pandas as pd
import openai
import matplotlib.pyplot as plt
import requests
from langchain.prompts.prompt import PromptTemplate

# Local (BO-LIFT)
import bolift
from bolift.llm_model import GaussDist, DiscreteDist

SyntaxError: invalid syntax (asktell_new.py, line 154)

In [2]:
import re
from pydantic import create_model
from typing import Union, List, Type, Any

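def sanitize_field_name(name: str) -> str:
    """Turn a free-text keyword into a valid identifier for a Pydantic field."""
    # Replace spaces with underscores, then drop every remaining non-word character.
    sanitized = re.sub(r'\W', '', name.strip().replace(' ', '_'))
    if not sanitized or sanitized[0].isdigit():
        raise ValueError(f"Invalid field name '{name}': after sanitization, a field name must be non-empty and cannot start with a digit.")
    return sanitized

def generate_model(keywords: List[str], datatypes: List[Type[Any]]) -> Type[Any]:
    """Build a Pydantic model with one required field per (keyword, datatype) pair."""
    if len(keywords) != len(datatypes):
        raise ValueError("keywords and datatypes must have the same length")
    sanitized_keywords = [sanitize_field_name(k) for k in keywords]
    # `...` (Ellipsis) marks each field as required.
    fields = {k: (t, ...) for k, t in zip(sanitized_keywords, datatypes)}
    return create_model('DynamicModel', **fields)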

# Test the function
keywords = ["molecule name", "molecule weight"]
datatypes = [str, Union[float, int]]

DynamicModel = generate_model(keywords, datatypes)

In [9]:
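# Pydantic models expose their fields through the JSON schema; the class has no
# `properties()` method (calling it raises AttributeError).
# `model_json_schema()` is the Pydantic v2 name, `schema()` the v1 name.
schema_fn = getattr(DynamicModel, "model_json_schema", None) or DynamicModel.schema
print(schema_fn()["properties"])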

In [2]:
# Seed results
np.random.seed(0)

In [3]:
# Default OpenAI API Key


# Data Preparation

In [4]:
# Load the solubility data
aqsol_df = pd.read_csv("paper/data/full_solubility.csv")
aqsol_df

Unnamed: 0,ID,Name,InChI,InChIKey,SMILES,Solubility,SD,Ocurrences,Group,MolWt,...,NumRotatableBonds,NumValenceElectrons,NumAromaticRings,NumSaturatedRings,NumAliphaticRings,RingCount,TPSA,LabuteASA,BalabanJ,BertzCT
0,A-3,"N,N,N-trimethyloctadecan-1-aminium bromide",InChI=1S/C21H46N.BrH/c1-5-6-7-8-9-10-11-12-13-...,SZEMGTQCPRNXEG-UHFFFAOYSA-M,[Br-].CCCCCCCCCCCCCCCCCC[N+](C)(C)C,-3.616127,0.000000,1,G1,392.510,...,17.0,142.0,0.0,0.0,0.0,0.0,0.00,158.520601,0.000000e+00,210.377334
1,A-4,Benzo[cd]indol-2(1H)-one,InChI=1S/C11H7NO/c13-11-8-5-1-3-7-4-2-6-9(12-1...,GPYLCFQEKPUWLD-UHFFFAOYSA-N,O=C1Nc2cccc3cccc1c23,-3.254767,0.000000,1,G1,169.183,...,0.0,62.0,2.0,0.0,1.0,3.0,29.10,75.183563,2.582996e+00,511.229248
2,A-5,4-chlorobenzaldehyde,InChI=1S/C7H5ClO/c8-7-3-1-6(5-9)2-4-7/h1-5H,AVPYQKSLYISFPO-UHFFFAOYSA-N,Clc1ccc(C=O)cc1,-2.177078,0.000000,1,G1,140.569,...,1.0,46.0,1.0,0.0,0.0,1.0,17.07,58.261134,3.009782e+00,202.661065
3,A-8,"zinc bis[2-hydroxy-3,5-bis(1-phenylethyl)benzo...",InChI=1S/2C23H22O3.Zn/c2*1-15(17-9-5-3-6-10-17...,XTUPUYCJWKHGSW-UHFFFAOYSA-L,[Zn++].CC(c1ccccc1)c2cc(C(C)c3ccccc3)c(O)c(c2)...,-3.924409,0.000000,1,G1,756.226,...,10.0,264.0,6.0,0.0,0.0,6.0,120.72,323.755434,2.322963e-07,1964.648666
4,A-9,4-({4-[bis(oxiran-2-ylmethyl)amino]phenyl}meth...,InChI=1S/C25H30N2O4/c1-5-20(26(10-22-14-28-22)...,FAUAZXVRLVIARB-UHFFFAOYSA-N,C1OC1CN(CC2CO2)c3ccc(Cc4ccc(cc4)N(CC5CO5)CC6CO...,-4.662065,0.000000,1,G1,422.525,...,12.0,164.0,2.0,4.0,4.0,6.0,56.60,183.183268,1.084427e+00,769.899934
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9977,I-84,tetracaine,InChI=1S/C15H24N2O2/c1-4-5-10-16-14-8-6-13(7-9...,GKCBAIGFKIBETG-UHFFFAOYSA-N,C(c1ccc(cc1)NCCCC)(=O)OCCN(C)C,-3.010000,0.000000,1,G1,264.369,...,8.0,106.0,1.0,0.0,0.0,1.0,41.57,115.300645,2.394548e+00,374.236893
9978,I-85,tetracycline,InChI=1S/C22H24N2O8/c1-21(31)8-5-4-6-11(25)12(...,OFVLGDICTFRJMM-WESIUVDSSA-N,OC1=C(C(C2=C(O)[C@@](C(C(C(N)=O)=C(O)[C@H]3N(C...,-2.930000,0.000000,1,G1,444.440,...,2.0,170.0,1.0,0.0,3.0,4.0,181.62,182.429237,2.047922e+00,1148.584975
9979,I-86,thymol,InChI=1S/C10H14O/c1-7(2)9-5-4-8(3)6-10(9)11/h4...,MGSRCZKZVOBKFT-UHFFFAOYSA-N,c1(cc(ccc1C(C)C)C)O,-2.190000,0.019222,3,G5,150.221,...,1.0,60.0,1.0,0.0,0.0,1.0,20.23,67.685405,3.092720e+00,251.049732
9980,I-93,verapamil,"InChI=1S/C27H38N2O4/c1-20(2)27(19-28,22-10-12-...",SGTNSNPWRIOYBX-UHFFFAOYSA-N,COc1ccc(CCN(C)CCCC(C#N)(C(C)C)c2ccc(OC)c(OC)c2...,-3.980000,0.000000,1,G1,454.611,...,13.0,180.0,2.0,0.0,0.0,2.0,63.95,198.569223,2.023333e+00,938.203977


In [5]:
# Drop rows with missing values and exact duplicate rows, then rebuild the index
aqsol_df = aqsol_df.dropna()
aqsol_df = aqsol_df.drop_duplicates().reset_index(drop=True)
aqsol_df

Unnamed: 0,ID,Name,InChI,InChIKey,SMILES,Solubility,SD,Ocurrences,Group,MolWt,...,NumRotatableBonds,NumValenceElectrons,NumAromaticRings,NumSaturatedRings,NumAliphaticRings,RingCount,TPSA,LabuteASA,BalabanJ,BertzCT
0,A-3,"N,N,N-trimethyloctadecan-1-aminium bromide",InChI=1S/C21H46N.BrH/c1-5-6-7-8-9-10-11-12-13-...,SZEMGTQCPRNXEG-UHFFFAOYSA-M,[Br-].CCCCCCCCCCCCCCCCCC[N+](C)(C)C,-3.616127,0.000000,1,G1,392.510,...,17.0,142.0,0.0,0.0,0.0,0.0,0.00,158.520601,0.000000e+00,210.377334
1,A-4,Benzo[cd]indol-2(1H)-one,InChI=1S/C11H7NO/c13-11-8-5-1-3-7-4-2-6-9(12-1...,GPYLCFQEKPUWLD-UHFFFAOYSA-N,O=C1Nc2cccc3cccc1c23,-3.254767,0.000000,1,G1,169.183,...,0.0,62.0,2.0,0.0,1.0,3.0,29.10,75.183563,2.582996e+00,511.229248
2,A-5,4-chlorobenzaldehyde,InChI=1S/C7H5ClO/c8-7-3-1-6(5-9)2-4-7/h1-5H,AVPYQKSLYISFPO-UHFFFAOYSA-N,Clc1ccc(C=O)cc1,-2.177078,0.000000,1,G1,140.569,...,1.0,46.0,1.0,0.0,0.0,1.0,17.07,58.261134,3.009782e+00,202.661065
3,A-8,"zinc bis[2-hydroxy-3,5-bis(1-phenylethyl)benzo...",InChI=1S/2C23H22O3.Zn/c2*1-15(17-9-5-3-6-10-17...,XTUPUYCJWKHGSW-UHFFFAOYSA-L,[Zn++].CC(c1ccccc1)c2cc(C(C)c3ccccc3)c(O)c(c2)...,-3.924409,0.000000,1,G1,756.226,...,10.0,264.0,6.0,0.0,0.0,6.0,120.72,323.755434,2.322963e-07,1964.648666
4,A-9,4-({4-[bis(oxiran-2-ylmethyl)amino]phenyl}meth...,InChI=1S/C25H30N2O4/c1-5-20(26(10-22-14-28-22)...,FAUAZXVRLVIARB-UHFFFAOYSA-N,C1OC1CN(CC2CO2)c3ccc(Cc4ccc(cc4)N(CC5CO5)CC6CO...,-4.662065,0.000000,1,G1,422.525,...,12.0,164.0,2.0,4.0,4.0,6.0,56.60,183.183268,1.084427e+00,769.899934
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9977,I-84,tetracaine,InChI=1S/C15H24N2O2/c1-4-5-10-16-14-8-6-13(7-9...,GKCBAIGFKIBETG-UHFFFAOYSA-N,C(c1ccc(cc1)NCCCC)(=O)OCCN(C)C,-3.010000,0.000000,1,G1,264.369,...,8.0,106.0,1.0,0.0,0.0,1.0,41.57,115.300645,2.394548e+00,374.236893
9978,I-85,tetracycline,InChI=1S/C22H24N2O8/c1-21(31)8-5-4-6-11(25)12(...,OFVLGDICTFRJMM-WESIUVDSSA-N,OC1=C(C(C2=C(O)[C@@](C(C(C(N)=O)=C(O)[C@H]3N(C...,-2.930000,0.000000,1,G1,444.440,...,2.0,170.0,1.0,0.0,3.0,4.0,181.62,182.429237,2.047922e+00,1148.584975
9979,I-86,thymol,InChI=1S/C10H14O/c1-7(2)9-5-4-8(3)6-10(9)11/h4...,MGSRCZKZVOBKFT-UHFFFAOYSA-N,c1(cc(ccc1C(C)C)C)O,-2.190000,0.019222,3,G5,150.221,...,1.0,60.0,1.0,0.0,0.0,1.0,20.23,67.685405,3.092720e+00,251.049732
9980,I-93,verapamil,"InChI=1S/C27H38N2O4/c1-20(2)27(19-28,22-10-12-...",SGTNSNPWRIOYBX-UHFFFAOYSA-N,COc1ccc(CCN(C)CCCC(C#N)(C(C)C)c2ccc(OC)c(OC)c2...,-3.980000,0.000000,1,G1,454.611,...,13.0,180.0,2.0,0.0,0.0,2.0,63.95,198.569223,2.023333e+00,938.203977


In [9]:
aqsol_df.columns

Index(['ID', 'Name', 'InChI', 'InChIKey', 'SMILES', 'Solubility', 'SD',
       'Ocurrences', 'Group', 'MolWt', 'MolLogP', 'MolMR', 'HeavyAtomCount',
       'NumHAcceptors', 'NumHDonors', 'NumHeteroatoms', 'NumRotatableBonds',
       'NumValenceElectrons', 'NumAromaticRings', 'NumSaturatedRings',
       'NumAliphaticRings', 'RingCount', 'TPSA', 'LabuteASA', 'BalabanJ',
       'BertzCT'],
      dtype='object')

In [7]:
aqsol_df[aqsol_df["Ocurrences"]>1]

Unnamed: 0,ID,Name,InChI,InChIKey,SMILES,Solubility,SD,Ocurrences,Group,MolWt,...,NumRotatableBonds,NumValenceElectrons,NumAromaticRings,NumSaturatedRings,NumAliphaticRings,RingCount,TPSA,LabuteASA,BalabanJ,BertzCT
8,A-14,bis(4-fluorophenyl)methanone,InChI=1S/C13H8F2O/c14-11-5-1-9(2-6-11)13(16)10...,LSQARZALBDFYQZ-UHFFFAOYSA-N,Fc1ccc(cc1)C(=O)c2ccc(F)cc2,-4.396652,0.431513,2,G3,218.202,...,2.0,80.0,2.0,0.0,0.0,2.0,17.07,91.346032,2.315628e+00,452.960733
9,A-15,1-[2-(benzoyloxy)propoxy]propan-2-yl benzoate ...,InChI=1S/C20H22O5/c21-19(17-9-3-1-4-10-17)24-1...,BYQDGAVOOHIJQS-UHFFFAOYSA-N,O=C(OCCCOCCCOC(=O)c1ccccc1)c2ccccc2,-4.595503,0.118551,2,G3,342.391,...,10.0,132.0,2.0,0.0,0.0,2.0,61.83,147.071714,1.447050e+00,582.150793
12,A-19,"2,3-dimethylphenol; 2,4-dimethylphenol; 2,5-di...",InChI=1S/6C8H10O/c1-6-3-7(2)5-8(9)4-6;1-6-3-4-...,YJZHZFOWHRKQHS-UHFFFAOYSA-N,Cc1ccc(O)c(C)c1.Cc2ccc(C)c(O)c2.Cc3cc(C)cc(O)c...,-1.980310,0.155859,4,G5,733.002,...,0.0,288.0,6.0,0.0,0.0,6.0,121.38,322.890738,3.240000e-07,1804.418547
15,A-23,"(2E)-3,7-dimethylocta-2,6-dien-1-ol","InChI=1S/C10H18O/c1-9(2)5-4-6-10(3)7-8-11/h5,7...",GLZPCOQZEFWAFX-YFHOEESVSA-N,CC(C)=CCC\C(C)=C/CO,-2.320601,0.071633,4,G5,154.253,...,4.0,64.0,0.0,0.0,0.0,0.0,20.23,69.438758,3.544387e+00,150.255712
16,A-24,2-(4-chloro-2-methylphenoxy)propanoic acid,InChI=1S/C10H11ClO3/c1-6-5-8(11)3-4-9(6)14-7(2...,WNTGYJSOUMFZEP-UHFFFAOYSA-N,CC(Oc1ccc(Cl)cc1C)C(O)=O,-2.466031,0.060621,4,G5,214.648,...,3.0,76.0,1.0,0.0,0.0,1.0,46.53,87.263739,2.817665e+00,349.220389
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9972,I-76,sarafloxacin,InChI=1S/C20H17F2N3O3/c21-12-1-3-13(4-2-12)25-...,XBHBWNFJWIASRO-UHFFFAOYSA-N,C1CN(CCN1)c1c(cc2c(n(cc(C(=O)O)c2=O)c2ccc(cc2)...,-3.130000,0.005000,2,G3,385.370,...,3.0,144.0,3.0,1.0,1.0,4.0,74.57,158.176705,2.013165e+00,1120.229459
9974,I-79,sulfamethazine,InChI=1S/C12H14N4O2S/c1-8-7-9(2)15-12(14-8)16-...,ASWVTGNCAZCNNR-UHFFFAOYSA-N,S(=O)(=O)(Nc1nc(C)cc(n1)C)c1ccc(N)cc1,-2.730000,0.230750,2,G3,278.337,...,3.0,100.0,2.0,0.0,0.0,2.0,97.97,111.308206,2.330509e+00,675.919365
9976,I-83,sulindac_form_II,InChI=1S/C20H17FO3S/c1-12-17(9-13-3-6-15(7-4-1...,MLKXDPUZXIRXEP-RQZCQDPDSA-N,CC1=C(CC(O)=O)c2cc(F)ccc2C\1=C\c1ccc(cc1)S(C)=O,-4.500000,0.410000,2,G3,356.418,...,4.0,128.0,2.0,0.0,1.0,3.0,54.37,147.518542,2.101112e+00,939.913669
9979,I-86,thymol,InChI=1S/C10H14O/c1-7(2)9-5-4-8(3)6-10(9)11/h4...,MGSRCZKZVOBKFT-UHFFFAOYSA-N,c1(cc(ccc1C(C)C)C)O,-2.190000,0.019222,3,G5,150.221,...,1.0,60.0,1.0,0.0,0.0,1.0,20.23,67.685405,3.092720e+00,251.049732


# ICL

## Ask-Tell

In [None]:
# Instantiate LLM model through ask-tell interface
asktell_1 = bolift.AskTellFewShotTopk()
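# Tell the model some points (few-shot/ICL).
# NOTE: `esol_1` is not defined in the cells shown here. It is assumed to be a DataFrame
# with the compound name in "IUPAC" and the measured solubility in its second column.
icl_examples = ["1-bromopropane", "1-bromopentane", "1-bromooctane", "1-bromonaphthalene"]
for molecule in icl_examples:
    asktell_1.tell(molecule,
                   esol_1[esol_1["IUPAC"]==molecule].values[0][1])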
# Make a prediction for a molecule
yhat = asktell_1.predict("1-bromobutane")
print(f"Y_Hat for ICL (before BO): {yhat}")
print(f"Y_Hat Mean: {yhat.mean()}")
print(f"Y_Hat Standard Deviation: {yhat.std()}")
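
To make the few-shot idea concrete, the sketch below shows one way that told examples could be turned into an in-context prompt. It is a minimal illustration, not bolift's actual prompt: the helper `build_few_shot_prompt` and the template wording are assumptions. The example names and solubility values are copied from the AqSolDB rows displayed earlier in this notebook.

```python
# Minimal sketch of few-shot (in-context) prompt assembly.
# The helper name and template wording are assumptions, not bolift's real prompt.
from typing import List, Tuple

def build_few_shot_prompt(examples: List[Tuple[str, float]], query: str,
                          y_name: str = "log solubility (mol/L)") -> str:
    """Format each told (x, y) pair as a Q/A block, then append the open query."""
    blocks = [
        f"Q: Given {x}. What is {y_name}?\nA: {y:.2f}\n"
        for x, y in examples
    ]
    # The final block is left unanswered so the LLM completes it.
    blocks.append(f"Q: Given {query}. What is {y_name}?\nA:")
    return "\n".join(blocks)

# (name, Solubility) pairs taken from the table rows shown in Data Preparation.
examples = [
    ("4-chlorobenzaldehyde", -2.177078),
    ("thymol", -2.190000),
    ("tetracaine", -3.010000),
]
prompt = build_few_shot_prompt(examples, "verapamil")
print(prompt)
```

Each `tell` call adds one answered block, and `predict` corresponds to asking the model to complete the final, unanswered one.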

## LLM as BO

In [None]:
# Now treat the LLM model as a BO protocol
pool_list_1 = [
    "1-bromoheptane",
    "1-bromohexane",
    "1-bromo-2-methylpropane",
    "butan-1-ol"
]
# Create the pool object
pool_1 = bolift.Pool(pool_list_1)
# Ask for the next point to evaluate (chosen by applying UCB as the acquisition function to the model's predictions over the pool)
next_point_1 = asktell_1.ask(pool_1)
print(f"The next point for the optimiser to try is: {next_point_1}")
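
The acquisition step can be sketched independently of the LLM. Assuming every pool candidate has a predictive mean and standard deviation, UCB picks the candidate that maximises `mean + beta * std`. The function `ucb_select`, the value of `beta`, and all numbers below are illustrative assumptions. They are not predictions produced by this notebook.

```python
# Sketch of upper-confidence-bound (UCB) selection over a candidate pool.
# `ucb_select`, `beta`, and the (mean, std) numbers are illustrative assumptions.
from typing import Dict, Tuple

def ucb_select(predictions: Dict[str, Tuple[float, float]], beta: float = 1.0) -> str:
    """Return the candidate that maximises mean + beta * std."""
    return max(predictions, key=lambda x: predictions[x][0] + beta * predictions[x][1])

# Hypothetical (mean, std) pairs for the pool; made up for illustration only.
predictions = {
    "1-bromoheptane": (-4.0, 0.3),
    "1-bromohexane": (-3.5, 0.2),
    "1-bromo-2-methylpropane": (-1.0, 1.5),
    "butan-1-ol": (0.0, 0.1),
}
greedy_choice = ucb_select(predictions, beta=0.0)   # pure exploitation
ucb_choice = ucb_select(predictions, beta=1.0)      # rewards uncertainty as well
print(greedy_choice, ucb_choice)
```

With `beta=0` the highest mean wins. A positive `beta` can instead favour a candidate whose prediction is less certain, which is the exploration behaviour BO relies on.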

In [None]:
# Tell the LLM the "actual" solubility value
asktell_1.tell(next_point_1[0][0], esol_1[esol_1["IUPAC"]==next_point_1[0][0]].values[0][1])
yhat = asktell_1.predict("1-bromobutane")
print(f"Y_Hat for ICL+BO: {yhat}")
print(f"Y_Hat Mean: {yhat.mean()}")
print(f"Y_Hat Standard Deviation: {yhat.std()}")

In [None]:
# The measured solubility of 1-bromobutane
esol_1[esol_1["IUPAC"]=="1-bromobutane"]
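
Comparing the prediction with the measured value can go beyond eyeballing the two numbers. The sketch below expresses the error in units of the predictive standard deviation and checks whether the measurement falls inside a `mean ± k * std` band. The helper names and the numbers are placeholders, not results from this notebook.

```python
# Sketch: compare a predictive distribution with a measured value.
# Helper names and the numbers used below are placeholders, not notebook results.
def z_score(mean: float, std: float, y_true: float) -> float:
    """Number of predictive standard deviations between prediction and measurement."""
    if std <= 0:
        raise ValueError("std must be positive")
    return (y_true - mean) / std

def covered(mean: float, std: float, y_true: float, k: float = 2.0) -> bool:
    """True if y_true lies inside mean +/- k * std."""
    return abs(z_score(mean, std, y_true)) <= k

z = z_score(-2.0, 0.5, -2.5)
inside = covered(-2.0, 0.5, -2.5)
outside = covered(-2.0, 0.5, -3.5)
print(z, inside, outside)
```

With the cells above, the call would look like `covered(yhat.mean(), yhat.std(), measured_value)`, where `measured_value` is the solubility looked up in the data.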

<DIV STYLE="background-color:#000000; height:10px; width:100%;">

## Idea 1:

Here, we aim to improve the "LLM as BO" approach by adding more contextual information. Context can be added in many ways; we will look at three:

1. Change the prompt template:

`prompt_template = PromptTemplate(input_variables=["x", "Answer", "y_name"] + self._answer_choices,
                                   template="Q: Given {x}. What is {y_name}?\n"
                                   + "\n".join([f"{a}. {{{a}}}" for a in self._answer_choices])
                                   + "\nAnswer: {Answer}\n\n")`
                                  
2. Change the acquisition function to incorporate context (UCB -> C-UCB).
3. Use policy learning. For example, policy learning with reinforcement learning can use a function approximator such as a neural network to predict actions, and then update the network's weights based on the observed reward. Here, the "actions" would be the solubility predictions, and the "reward" would be how close those predictions are to the true solubility.

Note that these ideas can also be combined.
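
As a rough illustration of the first option, a question line can carry molecular descriptors in addition to the compound itself. The helper `contextual_question`, the template wording, and the choice of descriptors are assumptions made for this sketch. The descriptor values are copied from the 4-chlorobenzaldehyde row of the table shown in Data Preparation.

```python
# Sketch of a context-enriched question: descriptors are appended to the compound.
# The helper name, template wording, and descriptor selection are assumptions.
from typing import Dict

def contextual_question(smiles: str, context: Dict[str, float],
                        y_name: str = "log solubility (mol/L)") -> str:
    """Render one question line that carries extra molecular descriptors."""
    details = ", ".join(f"{name} = {value}" for name, value in context.items())
    return f"Q: Given {smiles} with {details}. What is {y_name}?"

# Values from the 4-chlorobenzaldehyde row (MolWt, TPSA, NumAromaticRings columns).
question = contextual_question(
    "Clc1ccc(C=O)cc1",
    {"MolWt": 140.569, "TPSA": 17.07, "NumAromaticRings": 1.0},
)
print(question)
```

Which descriptors actually help is an empirical question. The columns of `aqsol_df` listed earlier are natural candidates to try.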

In [None]:
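# Build a second dataset that also keeps the SMILES string of each compound
# NOTE: `esol_data` is not defined in the cells shown here; it is assumed to be loaded elsewhere.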
esol_2 = esol_data[["IUPAC", "measured log(solubility:mol/L)", "SMILES"]]
esol_2 = esol_2.dropna()
esol_2 = esol_2.drop_duplicates().reset_index(drop=True)
esol_2

In [None]:
# Instantiate LLM model through ask-tell interface
asktell_2 = bolift.AskTellFewShotTopk_Template()
# Tell the model some points (few-shot/ICL), this time keyed by SMILES rather than IUPAC name
for molecule in icl_examples:
    row = esol_2[esol_2["IUPAC"]==molecule].iloc[0]
    asktell_2.tell(row["SMILES"], row["measured log(solubility:mol/L)"])

In [None]:
# Make a prediction for a molecule
yhat = asktell_2.predict("CCCCBr")
print(f"Y_Hat for ICL (before BO): {yhat}")
print(f"Y_Hat Mean: {yhat.mean()}")
print(f"Y_Hat Standard Deviation: {yhat.std()}")

In [None]:
# Now treat the LLM model as a BO protocol
pool_examples = ["1-bromoheptane", "1-bromohexane", "1-bromo-2-methylpropane", "butan-1-ol"]
# Look up the SMILES string of each pool molecule
pool_list_2 = [esol_2.loc[esol_2["IUPAC"]==molecule, "SMILES"].iloc[0] for molecule in pool_examples]
# Create the pool object
pool_2 = bolift.Pool(pool_list_2)
# Ask for the next point to evaluate (chosen by applying UCB as the acquisition function to the model's predictions over the pool)
next_point_2 = asktell_2.ask(pool_2)
print(f"The next point for the optimiser to try is: {next_point_2}")

In [None]:
# Tell the LLM the "actual" solubility value
asktell_2.tell(next_point_2[0][0], esol_2[esol_2["SMILES"]==next_point_2[0][0]].values[0][1])
yhat = asktell_2.predict("CCCCBr")
print(f"Y_Hat for ICL+BO: {yhat}")
print(f"Y_Hat Mean: {yhat.mean()}")
print(f"Y_Hat Standard Deviation: {yhat.std()}")
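
The cells above perform a single ask/tell round by hand. Repeating that round is what turns it into an optimisation loop, and the sketch below shows the control flow only. `run_ask_tell` is a hypothetical helper, the "surrogate" simply proposes the shortest untried string, and the "oracle" returns the negative string length rather than a measured solubility. None of this is bolift's API.

```python
# Toy ask/tell loop showing the control flow of iterative BO.
# `run_ask_tell`, the ask rule, and the oracle are stand-ins, not bolift's API.
from typing import Callable, Dict, List, Tuple

def run_ask_tell(pool: List[str],
                 ask: Callable[[List[str]], str],
                 tell: Callable[[str, float], None],
                 oracle: Callable[[str], float],
                 n_iter: int) -> List[Tuple[str, float]]:
    """Repeat ask -> measure -> tell, removing each queried point from the pool."""
    remaining = list(pool)
    history: List[Tuple[str, float]] = []
    for _ in range(min(n_iter, len(remaining))):
        x = ask(remaining)      # surrogate proposes the next candidate
        y = oracle(x)           # "experiment": obtain the value for that candidate
        tell(x, y)              # feed the observation back to the surrogate
        remaining.remove(x)
        history.append((x, y))
    return history

told: Dict[str, float] = {}
history = run_ask_tell(
    pool=["CCCCCCCBr", "CCCCCCBr", "CC(C)CBr", "CCCCO"],
    ask=lambda candidates: min(candidates, key=len),   # toy proposal rule
    tell=told.__setitem__,                             # toy surrogate memory
    oracle=lambda x: -float(len(x)),                   # toy stand-in for a measurement
    n_iter=2,
)
print(history)
```

In the notebook's setting, `ask` would wrap `asktell_2.ask(...)`, `tell` would wrap `asktell_2.tell(...)`, and the oracle would be the lookup of the measured solubility in `esol_2`.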

## Idea 2:

Here, we aim to improve the "LLM as BO" approach by noting that chemical compounds with similar SMILES structures are likely to have similar chemical characteristics. We can exploit this by feeding specific clusters of examples to the LLM to improve prediction accuracy:

1. Encode SMILES strings into a numeric format (an embedding). This requires choosing an appropriate representation of the chemical structure; one option is a fingerprint, such as the Morgan fingerprint. A machine learning model can then learn a mapping from the fingerprint to a scalar value that is predictive of the molecule's solubility.
2. Apply chain-of-thought reasoning (and its variations), i.e. feed examples in a deliberate order rather than a random one.
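
One possible way to order examples deliberately is to sort them by similarity to the query, so that the closest ones sit right before it in the prompt. The sketch below uses a character-bigram Tanimoto similarity on the SMILES strings. This is only a crude, dependency-free stand-in for Morgan fingerprints, which would normally come from a cheminformatics toolkit such as RDKit. All helper names are assumptions.

```python
# Sketch: order few-shot examples so the most query-like ones come last.
# Character-bigram Tanimoto similarity is a crude stand-in for Morgan fingerprints.
from collections import Counter
from typing import List

def bigram_counts(smiles: str) -> Counter:
    """Count overlapping two-character substrings of a SMILES string."""
    return Counter(smiles[i:i + 2] for i in range(len(smiles) - 1))

def tanimoto(a: str, b: str) -> float:
    """Multiset Tanimoto similarity between the bigram counts of two strings."""
    ca, cb = bigram_counts(a), bigram_counts(b)
    shared = sum((ca & cb).values())   # element-wise minimum of the counts
    total = sum((ca | cb).values())    # element-wise maximum of the counts
    return shared / total if total else 0.0

def order_by_similarity(examples: List[str], query: str) -> List[str]:
    """Sort ascending, so the nearest neighbour sits right before the query."""
    return sorted(examples, key=lambda s: tanimoto(s, query))

# Query is 1-bromobutane; candidates are 1-bromopropane, 1-bromopentane,
# 1-bromooctane, and butan-1-ol.
ordered = order_by_similarity(["CCCBr", "CCCCCBr", "CCCCCCCCBr", "CCCCO"], "CCCCBr")
print(ordered)
```

Whether placing the nearest examples last actually improves predictions is something to test rather than assume. The same helper also supports the reverse order or a top-k cut-off.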

Further directions:

1. Find more context data and improve the prompt design for this dataset.
2. Try other contextual BO applications and see how the LLM performs on them.