#### Setup

In [None]:
import pickle
from sys import path
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from IPython.display import clear_output

In [None]:
%cd "/content/drive/My Drive/archive/imecc/texture/data"

#### Visualization

In [None]:
class ModelVisualizer:
    """Collect scores from pickled models and plot or tabulate them.

    Model files are looked up in the current working directory and are
    expected to be named ``<database>+<feature>[+<modifier>]+<classifier>.pickle``.
    The file name supplies the grouping columns. The unpickled object supplies
    ``best_score_``, which is recorded as the accuracy.
    """

    COLUMNS = ("database", "feature", "modifier", "classifier", "accuracy")

    def __init__(self):
        self.table = None
        # Start from an empty frame rather than a dict of lists. This keeps the
        # plotting helpers usable before load() and lets load() be called
        # repeatedly to accumulate results.
        self.metrics = pd.DataFrame(columns=list(self.COLUMNS))

    def load(self, database="", method=""):
        """Append one metrics row per model file matching ``query``.

        Args:
            database: substring matched before the first ``+`` in the file name.
            method: substring matched after the last ``+`` (classifier part).

        Returns:
            ``self``, so calls can be chained.
        """
        rows = []
        for filepath in self.query(database, method):
            # NOTE: pickle.load can execute arbitrary code; only load files
            # you produced yourself.
            with open(filepath, "rb") as file:
                print(filepath)
                model = pickle.load(file)
            # Presumably a fitted scikit-learn search object (e.g. GridSearchCV)
            # exposing best_score_; confirm against how the pickles are written.
            rows.append((*self._parse_tag(filepath.stem), model.best_score_))
            clear_output(wait=True)

        loaded = pd.DataFrame(rows, columns=list(self.COLUMNS))
        if not loaded.empty:
            if self.metrics.empty:
                self.metrics = loaded
            else:
                self.metrics = pd.concat([self.metrics, loaded], ignore_index=True)
        return self

    def plot_db(self, database):
        """Strip-plot accuracy per modifier for one database, coloured by feature.

        Returns:
            The seaborn ``FacetGrid`` created by ``catplot``.
        """
        plot_df = self.metrics[self.metrics.database == database]
        grid = sns.catplot(x="modifier", y="accuracy", hue="feature",
                           kind="strip", data=plot_df)
        grid.set_xticklabels(rotation=30)
        return grid

    def plot_table(self, classifier, aggfunc="max"):
        """Tabulate accuracy as database x modifier for a single classifier.

        Each (database, modifier) pair has one row per feature, so a plain
        ``pivot`` would raise on duplicate entries. Cells are aggregated with
        ``aggfunc`` instead; the default is the best accuracy across features.

        Args:
            classifier: value of the ``classifier`` column to keep.
            aggfunc: aggregation passed to ``DataFrame.pivot_table``.

        Returns:
            The pivoted DataFrame, also stored on ``self.table``.
        """
        subset = self.metrics[self.metrics.classifier == classifier]
        self.table = subset.pivot_table(index="database", columns="modifier",
                                        values="accuracy", aggfunc=aggfunc)
        return self.table

    @staticmethod
    def _parse_tag(stem):
        """Split a ``+``-separated file stem into (database, feature, modifier, classifier)."""
        tag = stem.split("+")
        # Only four-part names carry a modifier; all others are labelled "raw".
        modifier = tag[2] if len(tag) == 4 else "raw"
        return tag[0], tag[1], modifier, tag[-1]

    @staticmethod
    def query(database="", method=""):
        """Return sorted ``.pickle`` paths in the working directory matching the filters.

        Both filters are wildcard-padded substrings: ``database`` must appear
        before a ``+``, and ``method`` must appear after a later ``+``.
        """
        pattern = f"*{database}+*+{method}*.pickle"
        return sorted(Path(".").glob(pattern))

In [None]:
# Load every brodatz model file whose classifier part matches "svc", then plot it.
visualizer = ModelVisualizer().load("brodatz", "svc")
visualizer.plot_db("brodatz")