In [1]:
from typing import List

import numpy as np
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
from sklearn.metrics import roc_curve, roc_auc_score
import torch

from prompt_playground.actionclip import (
    images_features_df,
    get_images_features,
    get_all_text_features,
    get_text_features,
    VARIATION_NAMES,
)
from prompt_playground.tensor_utils import normalize_features

Created a temporary directory at /var/folders/yc/slfs2kzs1d72wd8b20xw0nf80000gn/T/tmpo_puuigo
Writing /var/folders/yc/slfs2kzs1d72wd8b20xw0nf80000gn/T/tmpo_puuigo/_remote_module_non_scriptable.py


In [2]:
def get_similarities(
    variation: str,
    texts_positive: List[str],
    texts_negative: List[str],
    human_weight: float = 0.37,
):
    """Score every clip of ``variation`` against the positive and negative prompts.

    The image features are multiplied with the text features of each prompt.
    The scores of the negative prompts are additionally shifted by a centred
    similarity to the prompt ``"human"``, scaled by ``human_weight``.

    Args:
        variation: Name of the model variation; must be in ``VARIATION_NAMES``.
        texts_positive: Non-empty list of prompts for the positive class.
        texts_negative: Non-empty list of prompts for the negative class.
        human_weight: Factor applied to the centred "human" similarity before
            it is subtracted from the negative scores. The default, 0.37, is
            the value that used to be hard-coded here; 0 disables the
            correction.

    Returns:
        The positive scores followed by the corrected negative scores, stacked
        horizontally: one row per image feature row and one column per prompt,
        in the order ``texts_positive + texts_negative``.

    Raises:
        AssertionError: If a prompt list is empty or ``variation`` is unknown.
    """
    assert len(texts_positive) > 0, "texts_positive must not be empty"
    assert len(texts_negative) > 0, "texts_negative must not be empty"
    assert variation in VARIATION_NAMES, f"unknown variation: {variation!r}"

    images_features = normalize_features(get_images_features(variation))
    # NOTE(review): the third argument (True) presumably asks for normalized
    # text features - confirm against prompt_playground.actionclip.
    positive_texts_features = get_all_text_features(texts_positive, variation, True)
    negative_texts_features = get_all_text_features(texts_negative, variation, True)
    human_text_features = get_all_text_features(["human"], variation, True)

    pos_similarities = images_features @ positive_texts_features.T
    neg_similarities = images_features @ negative_texts_features.T
    human_similarities = images_features @ human_text_features.T

    # Centre the "human" similarity on the midpoint of its range, so the values
    # lie between -X and +X with X = (max - min) / 2.
    human_max = human_similarities.max()
    human_min = human_similarities.min()
    balanced_human_similarities = (
        human_similarities - human_max + ((human_max - human_min) / 2)
    )

    # The single "human" column broadcasts across all negative prompt columns.
    return torch.hstack(
        (
            pos_similarities,
            neg_similarities - (human_weight * balanced_human_similarities),
        )
    )


def _similarity_frame(
    alarms_series: pd.Series,
    similarities,
    texts_positive: List[str],
    texts_negative: List[str],
) -> pd.DataFrame:
    """Build a long-format frame with one row per (clip, prompt) pair."""
    texts = texts_positive + texts_negative
    predictions = [True] * len(texts_positive) + [False] * len(texts_negative)

    # tolist() yields nested Python floats, so "similarity" becomes a numeric
    # column instead of a column of 0-dim tensors.
    rows = [
        [clip, alarms_series[clip], text, y_predict, clip_similarity]
        for clip, clip_similarities in zip(alarms_series.index, similarities.tolist())
        for text, y_predict, clip_similarity in zip(
            texts, predictions, clip_similarities
        )
    ]
    return (
        pd.DataFrame(
            rows, columns=["clip", "y_true", "text", "y_predict", "similarity"]
        )
        .sort_values(["clip", "y_true", "y_predict", "text"])
        .reset_index(drop=True)
    )


def _plot_similarity_facets(df: pd.DataFrame) -> None:
    """Show per-prompt scatter and split-violin panels, faceted by y_true."""
    unique_y_true = df["y_true"].unique()

    # The layout is fixed at four columns: a scatter and a violin panel for
    # each of (at most) two distinct y_true values.
    fig = make_subplots(
        rows=1,
        cols=4,
        column_widths=[0.3, 0.2, 0.3, 0.2],
        shared_yaxes=True,
        y_title="similarity",
        subplot_titles=[
            f"y_true={y_true}" for y_true in unique_y_true for _ in range(2)
        ],
    )

    # group by
    #  1. y_true (facet)
    #  2. y_predict / text class (color)
    for i, y_true in enumerate(unique_y_true):
        facet_df = df[df["y_true"] == y_true]
        scatter_col = i * 2 + 1
        violin_col = i * 2 + 2

        for y_predict, class_color in zip(
            sorted(facet_df["y_predict"].unique()), ["CornflowerBlue", "Tomato"]
        ):
            facet_color_df = facet_df[facet_df["y_predict"] == y_predict]

            fig.add_scatter(
                x=facet_color_df["text"],
                y=facet_color_df["similarity"],
                marker=dict(color=class_color, size=3),
                hovertext=facet_color_df["clip"],
                mode="markers",
                name=f"y_predict={str(y_predict)}",
                legendgroup=f"y_true={str(y_true)}",
                legendgrouptitle=dict(text=f"y_true={str(y_true)}"),
                row=1,
                col=scatter_col,
            )
            # Positive prompts take one half of the violin, negative the other.
            fig.add_violin(
                x=np.repeat(str(y_true), len(facet_color_df)),
                y=facet_color_df["similarity"],
                box=dict(visible=True),
                scalegroup=str(y_true),
                scalemode="count",
                width=1,
                meanline=dict(visible=True),
                side="positive" if y_predict else "negative",
                marker=dict(color=class_color),
                showlegend=False,
                row=1,
                col=violin_col,
            )

        # Set once per facet rather than once per colour group.
        fig.update_xaxes(title_text="text", row=1, col=scatter_col)

    fig.update_layout(height=900, violingap=0, violinmode="overlay")
    fig.show()


def _clip_ratios(df: pd.DataFrame, alarms_series: pd.Series) -> pd.DataFrame:
    """Return, per clip, mean positive similarity / mean negative similarity."""
    is_positive = df["y_predict"].astype(bool)
    positive_mean = df[is_positive].groupby("clip")["similarity"].mean()
    negative_mean = df[~is_positive].groupby("clip")["similarity"].mean()

    # Both means are indexed by clip, so the division aligns clip by clip.
    ratio_df = (positive_mean / negative_mean).to_frame("ratio")
    ratio_df["y_true"] = alarms_series.loc[ratio_df.index]
    return ratio_df


def _plot_ratios(ratio_df: pd.DataFrame) -> None:
    """Show the per-clip ratios, coloured by y_true, with a marginal violin."""
    fig = px.scatter(
        ratio_df.sort_values(["y_true", "ratio"]),
        y="ratio",
        color="y_true",
        # "line" is not a documented render mode; "svg" is the renderer the
        # previous value presumably fell back to.
        render_mode="svg",
        marginal_y="violin",
        height=900,
    )
    fig.show()


def _plot_roc(ratio_df: pd.DataFrame, variation: str) -> None:
    """Show the ROC curve of the ratio as a score for y_true, with its AUC."""
    fpr, tpr, thresholds = roc_curve(ratio_df["y_true"], ratio_df["ratio"])
    auc_score = roc_auc_score(ratio_df["y_true"], ratio_df["ratio"])

    roc_df = pd.DataFrame(
        {
            "False Positive Rate": fpr,
            "True Positive Rate": tpr,
        },
        columns=pd.Index(["False Positive Rate", "True Positive Rate"], name="Rate"),
        index=pd.Index(thresholds, name="Thresholds"),
    )
    fig = px.line(
        roc_df,
        x="False Positive Rate",
        y="True Positive Rate",
        title=f"{variation} - AUC: {auc_score:.5f}",
        color_discrete_sequence=["orange"],
        range_x=[0, 1],
        range_y=[0, 1],
        width=600,
        height=450,
    ).add_shape(type="line", line=dict(dash="dash"), x0=0, x1=1, y0=0, y1=1)

    fig.show()


def infer(TEXTS: List[str], TEXT_CLASSIFICATIONS: List[bool], VARIATION: str):
    """Score all clips of a variation against the prompts and plot the results.

    Three figures are shown, in this order:
      1. per-prompt similarities, faceted by the clip's "Alarm" value;
      2. the per-clip ratio of mean positive to mean negative similarity;
      3. the ROC curve (with AUC) of that ratio as a score for "Alarm".

    Args:
        TEXTS: Prompts to compare every clip against.
        TEXT_CLASSIFICATIONS: One flag per prompt; True marks a positive
            prompt, False a negative one. Must be as long as ``TEXTS``.
        VARIATION: Name of the model variation to evaluate.

    Raises:
        ValueError: If ``TEXTS`` and ``TEXT_CLASSIFICATIONS`` differ in length.
    """
    if len(TEXTS) != len(TEXT_CLASSIFICATIONS):
        raise ValueError(
            f"TEXTS has {len(TEXTS)} entries but TEXT_CLASSIFICATIONS has "
            f"{len(TEXT_CLASSIFICATIONS)}; they must have the same length"
        )

    texts_positive = [
        text for text, is_positive in zip(TEXTS, TEXT_CLASSIFICATIONS) if is_positive
    ]
    texts_negative = [
        text
        for text, is_positive in zip(TEXTS, TEXT_CLASSIFICATIONS)
        if not is_positive
    ]
    similarities = get_similarities(
        VARIATION,
        texts_positive,
        texts_negative,
    )

    # Ground truth per clip, indexed by clip.
    alarms_series = images_features_df(VARIATION)["Alarm"]

    df = _similarity_frame(alarms_series, similarities, texts_positive, texts_negative)
    _plot_similarity_facets(df)

    ratio_df = _clip_ratios(df, alarms_series)
    _plot_ratios(ratio_df)
    _plot_roc(ratio_df, VARIATION)


# Reuse the best reference prompts.
# The positive and negative lists are not required to have the same length.
TEXTS_TRUE = [
    "human",
    "a video of a human approaching",
    "human going through",
    "human cutting",
]
TEXTS_FALSE = [
    "birds flying",
    "plastic bag laying on the floor",
    "rabbits",
    "animals",
]
TEXTS = [*TEXTS_TRUE, *TEXTS_FALSE]
# One flag per entry of TEXTS: True for the prompts that come from TEXTS_TRUE.
TEXT_CLASSIFICATIONS = [
    position < len(TEXTS_TRUE) for position in range(len(TEXTS))
]

infer(TEXTS, TEXT_CLASSIFICATIONS, VARIATION_NAMES[1])

The bottom-left plot looks much better: it separates the background clips from the alarm clips more clearly.