In [13]:
from weight_watcher import WeightWatcherService, WeightWatcherResultService
from analysis import AnalysisService

In [14]:
# Instantiate the services this notebook works with.
# NOTE(review): weight_watcher_service and analysis_service are not referenced in the cells shown here.
weight_watcher_service = WeightWatcherService()
# Repository used below to load stored analysis results; 'results' is presumably the storage directory — TODO confirm.
analysis_result_repository = WeightWatcherResultService('results')
analysis_service = AnalysisService()

In [15]:
# Load every stored result; later cells iterate it and read .summary, .model_accuracy and .model_identification.
results = analysis_result_repository.load_all()

In [16]:
from dataclasses import asdict
from typing import List, Dict, Any
from weight_watcher import WeightWatcherResult


def extract_summary_metrics(ww_results: List[WeightWatcherResult]) -> Dict[str, Any]:
    """Transpose per-result summary dataclasses into per-metric value lists.

    Each ``ww_result.summary`` is converted with ``dataclasses.asdict`` and its
    values are appended, field by field, to a list keyed by the field name.

    Args:
        ww_results: results whose ``summary`` attribute is a dataclass instance.

    Returns:
        Dict mapping each summary field name to the list of its values, in the
        order of ``ww_results``. The lists line up with ``ww_results`` only if
        every summary exposes the same fields.
    """
    columns: Dict[str, Any] = {}
    summaries = (asdict(ww_result.summary) for ww_result in ww_results)
    for summary in summaries:
        for metric_name, value in summary.items():
            # First sighting of a field creates its column; later ones extend it.
            columns.setdefault(metric_name, []).append(value)
    return columns

In [17]:
def extract_data_from_results(ww_results):
    """Collect the plotting inputs for ``add_plots`` from a list of results.

    Args:
        ww_results: results exposing ``summary``, ``model_accuracy`` and
            ``model_identification`` (with ``architecture.name`` and
            ``variant.name``).

    Returns:
        Tuple ``(metrics, accuracies, names)``:
        ``metrics`` is the per-metric dict from ``extract_summary_metrics``;
        ``accuracies`` and ``names`` hold one entry per result, in input order,
        where each name is formatted as ``"<architecture>/<variant>"``.
    """
    metrics = extract_summary_metrics(ww_results)
    accuracies = []
    names = []
    for ww_result in ww_results:
        accuracies.append(ww_result.model_accuracy)
        ident = ww_result.model_identification
        names.append(f"{ident.architecture.name}/{ident.variant.name}")
    return metrics, accuracies, names

In [18]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

def add_plots(metrics, accuracies, names, figure_title, height_per_row=4000 / 6, show=True):
    """Scatter each summary metric against model accuracy, one subplot row per metric.

    Args:
        metrics: mapping of metric name -> list of values, one per model
            (the first element returned by ``extract_data_from_results``).
        accuracies: model accuracies, aligned with each metric's value list.
        names: label per model, drawn above its marker.
        figure_title: title of the whole figure.
        height_per_row: pixel height allotted to each subplot row. The default
            keeps the previous 4000 px total for six rows.
        show: when True (default) render the figure with ``fig.show()`` and
            return None (avoids a duplicate rich display in Jupyter); when
            False, return the figure without rendering it.

    Returns:
        The ``plotly`` figure when ``show`` is False, otherwise None.
    """
    # One row per metric so no trace is overlaid onto another metric's subplot;
    # make_subplots rejects rows=0, hence the floor of 1 for empty input.
    n_rows = max(len(metrics), 1)
    fig = make_subplots(rows=n_rows, cols=1)
    for row, key in enumerate(metrics, start=1):
        fig.add_trace(go.Scatter(
            x=metrics[key],
            y=accuracies,
            text=names,
            mode="markers+text",
            marker=dict(
                size=10
            ),
            name=key,  # legend shows the metric instead of "trace N"
        ), row=row, col=1)
        fig.update_xaxes(title_text=key, row=row, col=1)
    fig.update_yaxes(title_text="accuracy")
    fig.update_layout(title_text=figure_title, height=round(height_per_row * n_rows))
    fig.update_traces(textposition='top center')
    if not show:
        return fig
    fig.show()
    return None

In [None]:
# Correlation of every summary metric with accuracy across all models.
add_plots(*extract_data_from_results(results), "Correlation with accuracy [All models]")

# Group results by architecture so each family gets its own figure.
architectures = {}
for result in results:
    architecture_name = result.model_identification.architecture.name
    architectures.setdefault(architecture_name, []).append(result)

for architecture_name, architecture_results in architectures.items():
    # Plot only this architecture's models, not the full result set.
    add_plots(*extract_data_from_results(architecture_results), f"Correlation with accuracy [{architecture_name}]")
