In [None]:
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import norm

In [None]:
def read_json_file(json_file):
    """Load and return the JSON document stored at ``json_file``.

    Args:
        json_file: Path (str or ``os.PathLike``) of the file to read.

    Returns:
        The decoded JSON value. The stats files in this notebook decode to a
        dict (they are indexed with ``"scores"``, ``"m_eer"`` and ``"m_acc"``).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid UTF-8 or not valid JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    # JSON text is UTF-8 by spec; don't let the platform's locale encoding
    # (e.g. cp1252 on Windows) decide how the bytes are decoded.
    with open(json_file, encoding="utf-8") as f:
        return json.load(f)


def get_zeros(data):
    """Collect the scores of every record labelled 0, preserving file order.

    Args:
        data: Mapping whose ``"scores"`` entry is a sequence of dicts, each
            holding a ``"label"`` and a ``"score"`` key.

    Returns:
        list: ``record["score"]`` for each record whose ``"label"`` equals 0.
    """
    return [record["score"] for record in data["scores"] if record["label"] == 0]


def get_ones(data):
    """Collect the scores of every record labelled 1, preserving file order.

    Args:
        data: Mapping whose ``"scores"`` entry is a sequence of dicts, each
            holding a ``"label"`` and a ``"score"`` key.

    Returns:
        list: ``record["score"]`` for each record whose ``"label"`` equals 1.
    """
    return [record["score"] for record in data["scores"] if record["label"] == 1]


def get_labels(data):
    """List the ``"label"`` of every record in ``data["scores"]``, in order.

    Args:
        data: Mapping whose ``"scores"`` entry is a sequence of dicts, each
            holding a ``"label"`` key.

    Returns:
        list: One label per record.
    """
    return [record["label"] for record in data["scores"]]

In [None]:
def _gaussian_crossing(zeros, ones, grid_points=1000):
    """Return the score in [0, 1] where Gaussian fits of two score sets cross, or None.

    Each sample set is summarised by its mean and population standard
    deviation (``np.std``). The two normal PDFs are evaluated on an evenly
    spaced grid over [0, 1]. The grid point just before each sign change of
    their difference is a crossing candidate. Two Gaussians with different
    spreads can cross twice. In that case the candidate lying between the two
    means is preferred; otherwise the first candidate is returned.

    NOTE(review): the PDF crossing coincides with the equal-error-rate
    threshold only when both fitted standard deviations are equal. In general
    the EER point of two Gaussians satisfies (t - mu0) / sd0 = (mu1 - t) / sd1.
    Confirm which threshold is intended.

    Args:
        zeros: Scores of the label-0 samples.
        ones: Scores of the label-1 samples.
        grid_points: Number of grid points on [0, 1] (at least 2 are used).

    Returns:
        float or None: The crossing score. Returns None if either set is
        empty, either fit has zero spread, or the PDFs do not cross on the grid.
    """
    if len(zeros) == 0 or len(ones) == 0:
        return None
    mean_zeros, std_zeros = np.mean(zeros), np.std(zeros)
    mean_ones, std_ones = np.mean(ones), np.std(ones)
    # A zero std makes norm.pdf return NaN everywhere, which would turn every
    # grid index into a spurious "crossing".
    if not (std_zeros > 0 and std_ones > 0):
        return None

    x = np.linspace(0, 1, max(int(grid_points), 2))
    pdf_diff = norm.pdf(x, mean_zeros, std_zeros) - norm.pdf(x, mean_ones, std_ones)
    idx = np.flatnonzero(np.diff(np.sign(pdf_diff)))
    if idx.size == 0:
        return None

    candidates = x[idx]
    lo, hi = sorted((mean_zeros, mean_ones))
    between = candidates[(candidates >= lo) & (candidates <= hi)]
    return float(between[0] if between.size else candidates[0])


def _subplot_title(path, stats_root, data):
    """Build a subplot title from the stats file's location and its m_eer/m_acc metrics."""
    metrics = f"eer={data['m_eer']:.3f}, acc={data['m_acc']:.3f}"
    # Only search the part of the path below stats_root for "test". A parent
    # directory name (e.g. ".../test_runs/stats") would otherwise mark every
    # file as a test file.
    if "test" in path.relative_to(stats_root).as_posix():
        return f"test, {metrics}"
    epoch = path.name.split(".")[0]
    return f"train, {epoch}, {metrics}"


def plot_scores(
    stats_dir,
    suptitle="Trained and Tested on Replay",
    figsize=(20, 10),
    ncols=3,
    grid_points=1000,
):
    """Plot spoof/real score distributions for every JSON stats file under ``stats_dir``.

    Every ``*.json`` file found recursively under ``stats_dir`` (in sorted
    path order) gets its own subplot. Each subplot shows:

    * a filled KDE of the label-0 ("Spoof") and label-1 ("Real") scores;
    * a dashed vertical line at the score where Gaussian fits of the two sets
      cross, annotated ``th_eer`` (omitted when no crossing exists);
    * a title with the file's ``m_eer`` / ``m_acc`` metrics, prefixed
      "test" or "train, <file name>".

    The whole grid is saved to ``<stats_dir>/plots.png``. The figure is left
    open so the notebook renders it inline.

    Each file must contain a ``"scores"`` list of ``{"label", "score"}``
    records plus the scalar keys ``"m_eer"`` and ``"m_acc"``.

    Args:
        stats_dir: Directory (str or Path) searched recursively for ``*.json``.
        suptitle: Figure-level title.
        figsize: Figure size in inches, ``(width, height)``.
        ncols: Number of subplot columns.
        grid_points: Grid resolution on [0, 1] used to locate the Gaussian
            crossing.

    Raises:
        FileNotFoundError: If no ``*.json`` file exists under ``stats_dir``.
    """
    stats_root = Path(stats_dir)
    # Sort so subplot order is reproducible; rglob order is filesystem-dependent.
    stats_files = sorted(stats_root.rglob("*.json"))
    if not stats_files:
        raise FileNotFoundError(f"No *.json stats files found under {stats_root}")

    num_rows = int(np.ceil(len(stats_files) / ncols))

    # Global seaborn styling: apply once, before the axes are created, so the
    # axes pick up the palette's color cycle.
    sns.set_context("paper")
    sns.set_palette("Set2")

    # Pass figsize explicitly. Setting rcParams["figure.figsize"] after a figure
    # already exists does not resize it, so the first fresh-kernel run would
    # otherwise get the default size.
    fig, axes = plt.subplots(num_rows, ncols, figsize=figsize, squeeze=False)
    fig.suptitle(suptitle, fontsize=16)
    fig.subplots_adjust(wspace=0.25, hspace=0.25)
    axes = axes.ravel()

    for ax, path in zip(axes, stats_files):
        data = read_json_file(path)
        zeros = get_zeros(data)
        ones = get_ones(data)

        # kdeplot(fill=...) replaces the deprecated distplot(kde_kws={"shade": ...}).
        # Labels are attached to the artists so the legend cannot mislabel them.
        sns.kdeplot(x=zeros, fill=True, linewidth=2, label="Spoof", ax=ax)
        sns.kdeplot(x=ones, fill=True, linewidth=2, label="Real", ax=ax)
        ax.set_xlabel("scores")

        threshold = _gaussian_crossing(zeros, ones, grid_points)
        if threshold is not None:
            ax.axvline(x=threshold, color="black", linestyle="--", linewidth=1)
            # x in data coordinates, y as a fraction of the axes height, so the
            # label stays visible whatever the density scale is.
            ax.text(
                threshold + 0.01,
                0.5,
                f"th_eer={threshold:.3f}",
                rotation=90,
                fontsize=10,
                transform=ax.get_xaxis_transform(),
            )

        ax.legend(loc="upper right", fontsize=10)
        ax.set_title(_subplot_title(path, stats_root, data), fontsize=10)

    # Hide grid cells left over when the file count is not a multiple of ncols.
    for ax in axes[len(stats_files):]:
        ax.axis("off")

    fig.savefig(stats_root / "plots.png")

In [None]:
# Root directory searched recursively for per-run JSON stats files;
# the combined figure is written to <stats_dir>/plots.png.
stats_dir = "../logs/stats"
plot_scores(stats_dir=stats_dir)