In [11]:
import chart_studio.plotly as py
import plotly.graph_objects as go

In [12]:
import os

In [13]:
from weight_watcher import WeightWatcherService, WeightWatcherResultService
from analysis import AnalysisService

In [14]:
# Location handed to the result repository (presumably a directory of saved
# WeightWatcher results — confirm against WeightWatcherResultService).
RESULTS_DIR = 'results'

weight_watcher_service = WeightWatcherService()
analysis_result_repository = WeightWatcherResultService(RESULTS_DIR)
analysis_service = AnalysisService()

In [15]:
# Load the saved WeightWatcher results. Later cells iterate `results` several
# times, so load_all() presumably returns a list rather than a one-shot iterator — confirm.
results = analysis_result_repository.load_all()

In [16]:
from dataclasses import asdict
from typing import List, Dict, Any
from weight_watcher import WeightWatcherResult


def extract_summary_metrics(ww_results: List[WeightWatcherResult]) -> Dict[str, Any]:
    """Transpose per-result summaries into one list of values per summary field.

    Each ``ww_result.summary`` is converted with ``dataclasses.asdict``; every field
    value is appended, in result order, to the list stored under that field's name.

    :param ww_results: results whose ``summary`` attribute is a dataclass instance.
    :return: mapping ``field name -> [value from each result]``; ``{}`` for no results.
    """
    collected: Dict[str, Any] = {}
    for summary_fields in (asdict(entry.summary) for entry in ww_results):
        for field_name, field_value in summary_fields.items():
            collected.setdefault(field_name, []).append(field_value)
    return collected

In [17]:
def extract_data_from_results(results):
    """Split results into the three parallel sequences used for plotting.

    :param results: WeightWatcher results; iterated more than once, so it must be
        re-iterable (e.g. a list).
    :return: ``(metrics, accuracies, names)`` where ``metrics`` maps each summary
        field to its per-result values, ``accuracies`` holds each result's
        ``model_accuracy`` and ``names`` holds ``"<architecture>/<variant>"`` labels,
        all in the order of ``results``.
    """
    metrics = extract_summary_metrics(results)
    accuracies = []
    names = []
    for entry in results:
        accuracies.append(entry.model_accuracy)
        ident = entry.model_identification
        names.append(f"{ident.architecture.name}/{ident.variant.name}")
    return metrics, accuracies, names

In [18]:
from plotly.subplots import make_subplots

def add_plots(data, metrics, accuracies, names, figure_title):
    """Append one accuracy-vs-metric scatter trace per summary metric to ``data``.

    :param data: list of plotly traces; mutated in place (one trace appended per metric).
    :param metrics: mapping metric name -> list of metric values, one per model;
        expected to be aligned with ``accuracies`` and ``names``.
    :param accuracies: model accuracies, used as the y coordinates.
    :param names: per-model labels drawn above each marker.
    :param figure_title: prefix for every trace name, so traces added by different
        calls (e.g. all models vs. a single architecture) are distinguishable in the legend.
    """
    for metric_name, metric_values in metrics.items():
        data.append(
            go.Scatter(
                x=metric_values,
                y=accuracies,
                text=names,
                mode="markers+text",
                textposition="top center",
                marker=dict(size=10),
                # Without an explicit name every trace shows up as "trace N" in the legend.
                name=f"{figure_title}: {metric_name}",
            )
        )

In [None]:
# Build accuracy-vs-metric traces for all models, then one set per architecture.
data = []
add_plots(data, *extract_data_from_results(results), "Correlation with accuracy")

# Group results by architecture name, preserving first-seen order.
architectures = {}
for result in results:
    architecture_name = result.model_identification.architecture.name
    architectures.setdefault(architecture_name, []).append(result)

for architecture_name, architecture_results in architectures.items():
    # Pass only this architecture's results so its traces reflect that group alone.
    add_plots(data, *extract_data_from_results(architecture_results),
              f"Correlation with accuracy [{architecture_name}]")

layout = go.Layout(title="WeightWatcher summary metrics vs. model accuracy",
                   xaxis=dict(title='Metric value'),
                   yaxis=dict(title='Model accuracy'))

fig = go.Figure(data=data, layout=layout)

# NOTE(review): this filename looks copied from a plotly example; choose a descriptive
# one unless this chart is meant to reuse that Chart Studio file.
py.iplot(fig, sharing='private', filename='jupyter-styled_bar')


foo


In [None]:
# Scratch cell: previously only printed a debug placeholder; delete when cleaning up the notebook.