In [None]:
import panel
import param
import pandas as pd

# Initialise Panel's notebook extension; the dashboard in the last cell is a panel.Row.
panel.extension()

In [None]:
import os
import sys

# Make the `jobs` package under ../../dags/ importable (the path is relative to
# the notebook's working directory) so its BigQuery helper can be reused here.
sys.path.insert(0, "../../dags/")
from jobs.common.bigquery import get_table_as_records
from google.cloud import bigquery
from google.oauth2 import service_account

# Service-account key file used to authenticate to BigQuery. Point the
# DBT_USER_CREDS environment variable at your own key instead of editing this
# cell; the default is the original author's local path.
key_path = os.path.expanduser(
    os.environ.get("DBT_USER_CREDS", "/Users/yco/.config/dbt-user-creds.json")
)
credentials = service_account.Credentials.from_service_account_file(key_path)

# Use the service account's own project as the client's project.
client = bigquery.Client(
    credentials=credentials,
    project=credentials.project_id,
)

# Model performance records, one DataFrame row per returned record.
perfs = pd.DataFrame(
    get_table_as_records(
        client,
        "model_perfs",
        "textcat",
    )
)
# The dashboard cell filters on these columns; fail here with a readable message
# (e.g. when the query returned no rows) rather than with an AttributeError later.
assert {"date", "model", "type"}.issubset(perfs.columns), (
    f"perfs is missing expected columns; got {list(perfs.columns)}"
)

In [None]:
import holoviews as hv

# Use Bokeh as the HoloViews plotting backend; the plot options set in
# ModelPerfs.view (toolbar, width, legend_position) are passed to this backend.
hv.extension("bokeh")


class ModelPerfs(param.Parameterized):
    """Interactive dashboard of model performance metrics over time.

    Reads the module-level ``perfs`` DataFrame, which must have ``date``,
    ``model`` and ``type`` columns plus the ``cats_*`` metric columns listed in
    ``_ANALYSIS_VARIABLES``. ``view`` plots the metrics of the chosen analysis
    for the chosen model and evaluation type.
    """

    # Metric columns plotted for each `analysis` choice; the dict order also
    # defines the selector's options (and hence its default, "AUC").
    # TODO: PRF per label
    _ANALYSIS_VARIABLES = {
        "AUC": ["cats_macro_auc"],
        "score": ["cats_score"],
        "micro_prf": ["cats_micro_p", "cats_micro_r", "cats_micro_f"],
        "macro_prf": ["cats_macro_p", "cats_macro_r", "cats_macro_f"],
    }

    analysis = param.Selector(objects=list(_ANALYSIS_VARIABLES))
    perf_type = param.Selector(objects=["valid", "train"])
    model = param.Selector(objects=perfs["model"].unique().tolist())

    @param.depends("analysis", "perf_type", "model")
    def view(self):
        """Plot the selected metrics for the selected model and perf type.

        Returns:
            An overlay with one curve per metric column (x: ``date``,
            y: ``performance`` with the axis range fixed to (0, 1)), or an empty
            curve when no rows match the current selection.
        """
        variables = self._ANALYSIS_VARIABLES[self.analysis]
        performance = hv.Dimension("performance", range=(0, 1))

        # Restrict to the chosen model as well as the perf type; previously the
        # `model` selector was ignored and had no effect on the plot.
        selected = perfs.loc[
            (perfs["type"] == self.perf_type) & (perfs["model"] == self.model),
            ["date", *variables],
        ]
        if selected.empty:
            # Nothing recorded for this model/type: show an empty plot instead of
            # reshaping an empty frame.
            return hv.Curve([], "date", performance).opts(toolbar="above", width=600)

        # Wide (one column per metric) -> long (date, variables, performance).
        long_df = (
            selected.set_index("date")
            .stack()
            .rename("performance")
            .rename_axis(["date", "variables"])
            .reset_index()
            .sort_values("date")
        )
        ds = hv.Dataset(long_df, kdims=["date", "variables"], vdims=performance)
        return (
            ds.to(hv.Curve, "date")
            .overlay("variables")
            .opts(legend_position="right", toolbar="above", width=600)
        )


# Build the dashboard: the widgets for ModelPerfs' parameters (analysis,
# perf_type, model) next to the plot produced by ModelPerfs.view. As the cell's
# last expression, the Row is what the notebook displays.
obj = ModelPerfs()
panel.Row(obj.param, obj.view)