In [1]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

In [2]:
import os

from gait import (
    FEL,
    Layers,
)
import numpy as np
import seaborn as sns

from rich.pretty import pprint
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
from tqdm.notebook import trange

In [None]:
# Dataset to analyse. Other files used previously: ~/where.json,
# ~/miami.json, ~/XOM.json (swap the path below to use one of them).
_layers_path = os.path.expanduser("~/data/NorthSea.json")
layers = Layers.load(_layers_path)

In [None]:
# LLM that FEL runs against. Local alternatives tried earlier:
#   ollama_chat/qwen2.5-coder:7b, ollama_chat/deepseek-r1:latest,
#   ollama_chat/qwen2:7b-instruct-q8_0, ollama_chat/qwen2.5:latest,
#   ollama_chat/llama3.2:latest
model = "azure/gpt-4o-mini"

# FEL is built over whatever `layers.prune_layers()` returns.
pruned_layers = layers.prune_layers()
fel = FEL(
    model=model,
    layers=pruned_layers,
    # Optional explicit endpoint.
    # NOTE(review): the suffix targets gpt-4o, not gpt-4o-mini — confirm
    # before enabling.
    # api_base=os.environ["AZURE_API_URL"] + "/gpt-4o",
)

In [None]:
def _show(obj):
    """Pretty-print ``obj`` with every nested field expanded."""
    pprint(obj, expand_all=True)


# Smoke test on one generated sample: show the sample, then what `fel`
# returns for its `.line`.
fel0 = fel.create_line_0()
_show(fel0)

felX = fel(fel0.line)
_show(felX)

In [None]:
# Score FEL on freshly generated samples. Each sample's `.fel.route` is the
# expected route (ground truth); `fel.fel0(...).route` is the prediction.
n_trials = 100  # number of generated samples to evaluate

# Extra arguments forwarded to `fel.fel0`; also echoed in the report below.
# NOTE(review): their meaning isn't visible in this notebook — document them.
line_1 = 20
line_2 = 20

gt = []  # expected routes
pv = []  # predicted routes

for _ in trange(n_trials, desc="Evaluating"):
    line_fel = fel.create_line_0()
    expected = line_fel.fel.route
    fel_route = fel.fel0(line_fel.line, line_1, line_2)
    # Record both labels only after the (fallible, slow) LLM call succeeds, so
    # an exception or a manual interrupt leaves gt and pv the same length and
    # the partial results can still be scored.
    gt.append(expected)
    pv.append(fel_route.route)

In [None]:
# Summary metrics, in percent.
# NOTE(review): precision/recall use sklearn's default binary averaging
# (pos_label=1), which assumes the route labels are binary and include 1 —
# confirm this matches the values of `.route`.
accuracy = accuracy_score(gt, pv) * 100.0
precision = precision_score(gt, pv) * 100.0
recall = recall_score(gt, pv) * 100.0

print(f"{model} {line_1} / {line_2}\n")
print(f"Accuracy:\t{accuracy:.1f}%")
print(f"Precision:\t{precision:.1f}%")
print(f"Recall:\t\t{recall:.1f}%")

# Confusion matrix as a share of all samples (rows: expected, cols: predicted).
categories = ["FEL1", "FEL2"]
labels = sorted(set(gt) | set(pv))
# The tick names are hard-coded, so both classes must occur. If only one
# does, the matrix is 1x1 and seaborn would put "FEL1" on that cell no matter
# which class it really is — fail loudly instead of mislabelling.
assert len(labels) == len(categories), (
    f"expected {len(categories)} route labels for {categories}, found {labels}"
)
cf_matrix = confusion_matrix(gt, pv, labels=labels)
ax = sns.heatmap(
    cf_matrix / np.sum(cf_matrix),
    fmt=".1%",
    cmap="Blues",
    annot=True,
    xticklabels=categories,
    yticklabels=categories,
)
ax.set(xlabel="Predicted", ylabel="Expected")
ax.xaxis.tick_top()
# Keep the "Predicted" label beside the tick labels it describes.
ax.xaxis.set_label_position("top")
ax.set_title(f"{model} {line_1} / {line_2}")