In [None]:
import bolift
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from langchain.prompts.prompt import PromptTemplate


# Seed NumPy's legacy global RNG so the random train/test split and shuffle
# performed later in the notebook are repeatable across kernel restarts.
np.random.seed(0)
# Location of the CSV holding the dataset (path relative to the working dir).
data_path = "paper/data/C2_yield_meth_oxy_short.csv"

In [None]:
# Read the CSV at `data_path` into a DataFrame. Later cells access its first
# two columns positionally and also by the names "prompt" and "completion".
raw_data = pd.read_csv(data_path)
# Preview the first rows as the cell's rich output (sanity check of the load).
raw_data.head()

In [None]:
# Settings for the bolift few-shot predictor. The formatters turn an input
# into the text 'synthesis procedure:"<x>"' and a target into a 2-decimal string.
# NOTE(review): selector_k presumably caps how many stored examples are put in
# each prompt -- confirm against the bolift documentation.
fewshot_kwargs = {
    "x_formatter": lambda x: f'synthesis procedure:"{x}"',
    "y_name": "C2 yield",
    "y_formatter": lambda y: f"{y:.2f}",
    "model": "text-curie-001",
    "selector_k": 8,
}
asktell = bolift.AskTellFewShotMulti(**fewshot_kwargs)

# Draw 450 distinct row positions for training; every remaining position is
# held out for testing.
n_rows = raw_data.shape[0]
train = np.random.choice(n_rows, 450, replace=False)
test = np.setdiff1d(np.arange(n_rows), train)
# setdiff1d returns sorted positions, so randomise the test order in place.
np.random.shuffle(test)
print(train.size, test.size)

In [None]:
# Feed every training row to the model: column 0 is passed as the input and
# column 1 (cast to float) as the observed target.
for row in train:
    procedure = raw_data.iloc[row, 0]
    observed = raw_data.iloc[row, 1]
    asktell.tell(procedure, float(observed))

# Evaluate on the first 10 (shuffled) test rows only.
eval_rows = test[:10]
# Ground-truth targets for the evaluated rows.
y = [float(raw_data.iloc[row, 1]) for row in eval_rows]
# One prediction object per evaluated row, in the same order as `y`.
yhat = [asktell.predict(raw_data.iloc[row, 0]) for row in eval_rows]

In [None]:
# Keep only the (truth, prediction) pairs whose prediction carries at least one
# value; predictions with an empty `.values` are treated as failed and dropped.
kept_pairs = [
    (truth, pred) for truth, pred in zip(y, yhat) if len(pred.values) > 0
]
y_filter = [truth for truth, _pred in kept_pairs]
yhat_filter = [pred for _truth, pred in kept_pairs]

In [None]:
# Summarise each surviving prediction by its mode (point estimate) and its
# standard deviation (error bar).
modes = [pred.mode() for pred in yhat_filter]
std = [pred.std() for pred in yhat_filter]

# Parity scatter: actual value on x, predicted mode on y.
plt.errorbar(y_filter, modes, yerr=std, fmt="o")

# Overlay the least-squares straight line through the points.
slope, intercept = np.polyfit(y_filter, modes, 1)
fitted = slope * np.array(y_filter) + intercept
plt.plot(y_filter, fitted, color="gray")

# Annotate with R^2, computed as the squared Pearson correlation between the
# actual values and the predicted modes; positioned in axes coordinates.
pearson_r = np.corrcoef(y_filter, modes)[0, 1]
r2 = pearson_r ** 2
plt.text(0.1, 0.9, f"$R^2$ = {r2:.2f}", transform=plt.gca().transAxes)

plt.title("Predicted vs. Actual C2 Yield")
plt.xlabel("Actual Yield")
plt.ylabel("Predicted Yield")
plt.show()

In [None]:
# Print every prediction next to its ground-truth value, including the
# predictions that were filtered out of the plot.
for prediction, actual in zip(yhat, y):
    print(prediction, actual)

In [None]:
# Rebuild the few-shot model from scratch (same settings as the evaluation
# model above) so the optimisation experiment starts without the 450 examples.
bo_kwargs = dict(
    x_formatter=lambda x: f'synthesis procedure:"{x}"',
    y_name="C2 yield",
    y_formatter=lambda y: f"{y:.2f}",
    model="text-curie-001",
    selector_k=8,
)
asktell = bolift.AskTellFewShotMulti(**bo_kwargs)

# Seed the fresh model with a single labelled example: the first training row.
for seed_row in train[:1]:
    seed_x = raw_data.iloc[seed_row, 0]
    seed_y = float(raw_data.iloc[seed_row, 1])
    asktell.tell(seed_x, seed_y)

In [None]:
# Bayesian-optimisation loop: N rounds of ask -> look up true value -> tell.
N = 10
aq = "expected_improvement"
# Candidate pool built from the held-out test prompts, formatted the same way
# the model formats its inputs.
pool = bolift.Pool(list(raw_data.prompt[test]), asktell.format_x)
# (prompt, observed completion) for every round, in selection order.
point = []
# Kept for compatibility; nothing in this cell appends to it.
true_y = []
for i in range(N):
    # Every round uses `aq`, except the last one which switches to "greedy".
    round_aq = "greedy" if i == N - 1 else aq
    # NOTE(review): inv_filter=10 presumably pre-filters the pool to 10
    # candidates before scoring -- confirm against the bolift documentation.
    px, _, py = asktell.ask(pool, k=1, aq_fxn=round_aq, inv_filter=10)
    xc = px[0]
    # remove from pool
    pool.choose(xc)
    # Look up the ground-truth completion for the chosen prompt. Calling
    # float() directly on the selected Series is deprecated in recent pandas
    # and fails when zero or several rows match, so take the first match
    # explicitly and fail with a clear message when there is none.
    matches = raw_data.loc[raw_data["prompt"] == xc, "completion"]
    if matches.empty:
        raise KeyError(f"selected prompt not found in raw_data: {xc!r}")
    # Use a dedicated name: rebinding `y` here would clobber the list of
    # ground-truth values built by the evaluation cell earlier in the notebook.
    y_obs = float(matches.iloc[0])
    asktell.tell(xc, y_obs)
    point.append((xc, y_obs))
    print(y_obs)

In [None]:
# Completion values of the held-out rows; used for the reference lines below.
test_completion = raw_data["completion"][test]

# Dashed horizontal reference lines, coloured C0..C4 in this order:
# min, mean, max, then the 5% and 95% quantiles.
reference_levels = [
    ("min", test_completion.min()),
    ("mean", test_completion.mean()),
    ("max", test_completion.max()),
    ("5%", test_completion.quantile(0.05)),
    ("95%", test_completion.quantile(0.95)),
]
for color_idx, (level_name, level_value) in enumerate(reference_levels):
    plt.axhline(
        y=level_value, color=f"C{color_idx}", linestyle="--", label=level_name
    )

# Observed values in the order the optimisation loop selected them.
observed_values = [value for _prompt, value in point]
# Running best after each round; computed but not drawn in this figure.
maxes = [max(observed_values[:step]) for step in range(1, N + 1)]
# Raw observation per round (x axis is the 1-based round number).
plt.plot(range(1, N + 1), observed_values)
plt.title(f"{aq} with {asktell._model}")
plt.legend()