In [None]:
import bolift
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from langchain.prompts.prompt import PromptTemplate


# Configuration: seed NumPy's global RNG (the train/test split and shuffle
# below draw from it) and point at the yield-strength dataset.
np.random.seed(0)
data_path = "paper/data/yield_strength.csv"

In [None]:
# Load the raw dataset and display its column labels.
raw_data = pd.read_csv(filepath_or_buffer=data_path)
raw_data.keys()

In [None]:
# Few-shot LLM regressor used for the prediction experiment.
# Compositions are phrased as "alloy composition of ...", and targets are
# rendered with two decimals.
# selector_k=10 presumably caps how many told examples are placed in each
# prompt -- TODO confirm against the bolift docs.
# NOTE(review): "text-curie-001" has been retired by OpenAI; this cell will
# likely need a newer model name to run today.
asktell_settings = dict(
    model="text-curie-001",
    selector_k=10,
    y_name="yield strength",
    x_formatter=lambda x: f"alloy composition of {x}",
    y_formatter=lambda y: f"{y:.2f}",
)
asktell = bolift.AskTellFewShotMulti(**asktell_settings)

In [None]:
# Random 80/20 split over row positions.
# The test positions are then shuffled, so any prefix of `test` is a random
# subset of the held-out rows.
N = raw_data.shape[0]
n_train = int(N * 0.8)
train = np.random.choice(N, n_train, replace=False)
test = np.setdiff1d(np.arange(N), train)
np.random.shuffle(test)
print(N, len(train), len(test))

In [None]:
# Tell the model every training example.
# Then predict the first ten (shuffled) test rows, recording actual values in
# `y` and the model's predictions in `yhat`.
for row in train:
    asktell.tell(raw_data.iloc[row, 0], float(raw_data.iloc[row, 1]))

y, yhat = [], []
for row in test[:10]:
    actual = float(raw_data.iloc[row, 1])
    y.append(actual)
    yhat.append(asktell.predict(raw_data.iloc[row, 0]))

In [None]:
# Keep only (actual, prediction) pairs whose prediction has at least one value;
# an empty `.values` marks a failed prediction.
kept = [(yi, yhi) for yi, yhi in zip(y, yhat) if len(yhi.values) > 0]
y_filter = [yi for yi, _ in kept]
yhat_filter = [yhi for _, yhi in kept]

In [None]:
# Parity plot for the successful predictions: predicted mode vs. actual yield
# strength, with each prediction's std as the error bar.
modes = [pred.mode() for pred in yhat_filter]
std = [pred.std() for pred in yhat_filter]
plt.errorbar(y_filter, modes, yerr=std, fmt="o")
if y_filter:
    # y = x reference line.
    # It is built from the filtered values, not the global `y`, which other
    # cells may rebind. It spans both axes so no plotted point falls outside it.
    lo = min(min(y_filter), min(modes))
    hi = max(max(y_filter), max(modes))
    plt.plot([lo, hi], [lo, hi], "k--")
plt.xlabel("Actual Yield")
plt.ylabel("Predicted Yield")
plt.title("Yield strength: predicted vs. actual")
plt.show()

In [None]:
# Print each raw prediction (failed ones included) next to its actual value.
for prediction, actual in zip(yhat, y):
    print(prediction, actual)

In [None]:
# Fresh model for the Bayesian-optimization experiment.
# It is told only the first ten training examples; the count matches
# selector_k=10.
asktell = bolift.AskTellFewShotMulti(
    model="text-curie-001",
    selector_k=10,
    y_name="yield strength",
    x_formatter=lambda x: f"alloy composition of {x}",
    y_formatter=lambda y: f"{y:.2f}",
)

for idx in train[:10]:
    composition = raw_data.iloc[idx, 0]
    strength = float(raw_data.iloc[idx, 1])
    asktell.tell(composition, strength)

In [None]:
# Bayesian-optimization loop over a pool of 50 held-out compositions.
# Each iteration:
#   1. asks the model to score the remaining candidates,
#   2. looks up the true yield strength of the chosen composition,
#   3. tells that value back to the model.
# NOTE: `N` is reused as the iteration count (the plotting cell below reads it);
# this overwrites the dataset size computed in the split cell.
N = 10
pool = test[:50]
pool_str = [raw_data.iloc[i, 0] for i in pool]
candidates = list(pool_str)  # shrinks as compositions are queried
point = []
pred_y = []
true_y = []
for i in range(N):
    if not candidates:
        break
    # Explore with expected improvement, then exploit greedily on the last step.
    aq = "greedy" if i == N - 1 else "expected_improvement"
    px, _, py = asktell.ask(candidates, k=len(candidates), aq_fxn=aq)
    # NOTE(review): this picks by the third value returned by `ask` (py), not by
    # the acquisition score. If py is the predicted mean, selection is
    # effectively greedy on the mean -- confirm against bolift's `ask`.
    xc = px[np.argmax(py)]
    # .loc + .iloc[0] avoids chained indexing and the deprecated float(Series).
    y_new = float(
        raw_data.loc[raw_data["composition"] == xc, "yield strength"].iloc[0]
    )
    asktell.tell(xc, y_new)
    # Drop the queried composition so it cannot be selected again.
    candidates = [c for c in candidates if c != xc]
    point.append((xc, y_new))
    pred_y.append(py)
    true_y.append(y_new)

In [None]:
# Yield strength queried at each BO iteration, plus the running best.
# Both are drawn against dataset-wide reference levels: min / mean / max and
# the 5% / 95% quantiles.
strength = raw_data["yield strength"]
reference_lines = [
    (strength.min(), "C0", "min"),
    (strength.mean(), "C1", "mean"),
    (strength.max(), "C2", "max"),
    (strength.quantile(0.05), "C3", "5%"),
    (strength.quantile(0.95), "C4", "95%"),
]
for level, color, label in reference_lines:
    plt.axhline(y=level, color=color, linestyle="--", label=label)

found_y = [y_found for _, y_found in point]
# Use the actual number of recorded points rather than a global N that other
# cells may rebind.
steps = range(1, len(found_y) + 1)
maxes = np.maximum.accumulate(np.asarray(found_y, dtype=float)).tolist()
plt.plot(steps, found_y, label="queried")
plt.plot(steps, maxes, label="best so far")
plt.xlabel("BO iteration")
plt.ylabel("Yield strength")
plt.legend()
plt.show()