In [1]:
import supervisely_lib as sly

from supervisely_lib.metric.iou_metric import IOU
from supervisely_lib.metric.map_metric import AP

import collections
import itertools
import os
from prettytable import PrettyTable

# Describes one evaluation metric:
#   metric_factory       -- callable invoked as metric_factory(_classes_mapping=...) that returns a new calculator;
#   metric_id_with_title -- maps each metric id looked up in the calculator's totals to its table column title.
MetricInfo = collections.namedtuple('MetricInfo', ['metric_factory', 'metric_id_with_title'])

# Passed to sly.MAPMetric as `iou_threshold`.
map_iou_threshold = 0.1

# Passed to inference as `model_classes.add_suffix` (see `inference_config`) and used to build the
# prediction-side class and tag names below.
PREDICTION_SUFFIX = '_pred'


def _iou_metric_factory(_classes_mapping):
    """Create a fresh sly.IoUMetric for the given classes mapping (see `classes_mapping`)."""
    return sly.IoUMetric(_classes_mapping)


def _map_metric_factory(_classes_mapping):
    """Create a fresh sly.MAPMetric using `map_iou_threshold` and the suffixed 'confidence' tag name."""
    confidence_tag = 'confidence' + PREDICTION_SUFFIX
    return sly.MAPMetric(_classes_mapping, iou_threshold=map_iou_threshold, confidence_tag_name=confidence_tag)


metric_infos = [
    MetricInfo(_iou_metric_factory, {IOU: 'IoU'}),
    MetricInfo(_map_metric_factory, {AP: 'mAP'}),
]

# Supervisely entities to use, all looked up by name.
team_name = "jupyter_tutorials"
workspace_name = "metrics_tutorials"
# NOTE(review): the leading space is part of the literal; presumably it matches the agent's name on the server -- confirm.
agent_name = " Skinny Alpaca"

# Every model below is evaluated on every project below.
model_names = ['yolo_2_epochs', 'yolo_200_epochs']
project_names = ['lemons_annotated', 'lemons_annotated_copy']

# Inference settings: presumably makes predicted class names carry PREDICTION_SUFFIX so they
# do not clash with ground-truth class names.
inference_config = dict(mode=dict(model_classes=dict(add_suffix=PREDICTION_SUFFIX)))

# Presumably ground-truth class name (key) -> predicted class name (value) used by the metrics.
classes_mapping = dict(
    kiwi='kiwi_bbox' + PREDICTION_SUFFIX,
    lemon='lemon_bbox' + PREDICTION_SUFFIX,
)

# ------------- End user settings -------------

# Connect to the Supervisely instance; credentials come from the environment.
address = os.environ['SERVER_ADDRESS']
token = os.environ['API_TOKEN']
api = sly.Api(address, token)

# Resolve the configured team, workspace and agent by name, failing fast if any is missing.
team = api.team.get_info_by_name(team_name)
if team is None:
    raise RuntimeError(f"Team {team_name!r} not found")

workspace = api.workspace.get_info_by_name(team.id, workspace_name)
if workspace is None:
    raise RuntimeError(f"Workspace {workspace_name!r} not found")

agent = api.agent.get_info_by_name(team.id, agent_name)
if agent is None:
    raise RuntimeError(f"Agent {agent_name!r} not found")
# NOTE(review): `agent.status` may hold the raw status value (e.g. a string) rather than the
# `api.agent.Status` member, in which case an `is` check against the member never matches.
# Accept either form -- TODO confirm against the supervisely_lib version in use.
_waiting_status = api.agent.Status.WAITING
if agent.status in (_waiting_status, getattr(_waiting_status, 'value', _waiting_status)):
    raise RuntimeError(f"Agent {agent_name!r} is not running")

model_infos = [api.model.get_info_by_name(workspace.id, model_name) for model_name in model_names]
project_infos = [api.project.get_info_by_name(workspace.id, project_name) for project_name in project_names]

for model_name, model_info in zip(model_names, model_infos):
    if model_info is None:
        raise RuntimeError(f'Model {model_name!r} not found.')
for project_name, project_info in zip(project_names, project_infos):
    if project_info is None:
        raise RuntimeError(f'Project {project_name!r} not found.')

# Name of the project holding predictions of model `model_idx` on project `project_idx`.
# It is derived from the ids, so a rerun finds (and reuses) the projects of a previous run.
inference_project_names = {
    (project_idx, model_idx): '__'.join(['_model_eval', str(project_info.id), str(model_info.id)])
    for project_idx, project_info in enumerate(project_infos)
    for model_idx, model_info in enumerate(model_infos)}

for project_idx, project_info in enumerate(project_infos):
    for model_idx, model_info in enumerate(model_infos):
        inference_project_name = inference_project_names[project_idx, model_idx]
        inference_project_info = api.project.get_info_by_name(workspace.id, inference_project_name)
        if inference_project_info is not None:
            # Assume the existing project holds correct predictions from a previous run.
            sly.logger.warning(
                f'Project {inference_project_name!r} already exists, reusing it and skipping inference.')
            continue

        task_id = api.task.run_inference(
            agent.id, project_info.id, model_info.id, inference_project_name, inference_config=inference_config)
        print(f'Inference task (id={task_id}) started')
        # Block until the task reports FINISHED so the evaluation below sees complete predictions.
        api.task.wait(task_id, api.task.Status.FINISHED)
        print(f'Inference task (id={task_id}) finished')

def _make_metric_calculators(metric_infos, model_names, classes_mapping):
    return [[metric_info.metric_factory(_classes_mapping=classes_mapping)
             for metric_info in metric_infos] for _ in model_names]

# Calculators fed with every project, indexed as [model_idx][metric_idx] ...
overall_metric_calculators = _make_metric_calculators(metric_infos, model_names, classes_mapping)
# ... plus an independent grid per project, indexed as [project_idx][model_idx][metric_idx].
per_project_metric_calculators = []
for _project_info in project_infos:
    per_project_metric_calculators.append(_make_metric_calculators(metric_infos, model_names, classes_mapping))

def _get_name(x):
    return x.name

def _check_same_names(gt_items, pred_items, what):
    """Raise RuntimeError unless ``gt_items`` and ``pred_items`` have equal names at every position.

    Ground truth and predictions are paired by position after sorting by name, and ``zip``
    silently drops extras, so any count or name difference would otherwise compare
    unrelated items (or skip some) without any error. ``what`` describes the items in
    the error message.
    """
    gt_names = [_get_name(item) for item in gt_items]
    pred_names = [_get_name(item) for item in pred_items]
    if len(gt_names) != len(pred_names):
        raise RuntimeError(
            f'Mismatched {what}: {len(gt_names)} ground-truth vs {len(pred_names)} predicted items.')
    for gt_name, pred_name in zip(gt_names, pred_names):
        if gt_name != pred_name:
            raise RuntimeError(f'Mismatched {what}: ground truth {gt_name!r} vs prediction {pred_name!r}.')


# Feed every (ground truth, prediction) annotation pair into both the overall and the
# per-project calculators of the corresponding model.
for project_idx, project_info in enumerate(project_infos):
    meta_gt = sly.ProjectMeta.from_json(api.project.get_meta(project_info.id))
    # Ground-truth datasets do not depend on the model, so list them once per project.
    datasets_gt = sorted(api.dataset.get_list(project_info.id), key=_get_name)
    for model_idx, _model_info in enumerate(model_infos):
        inference_project_name = inference_project_names[project_idx, model_idx]
        inference_project_info = api.project.get_info_by_name(workspace.id, inference_project_name)
        if inference_project_info is None:
            raise RuntimeError(f'Inference project {inference_project_name!r} not found.')
        meta_pred = sly.ProjectMeta.from_json(api.project.get_meta(inference_project_info.id))

        datasets_pred = sorted(api.dataset.get_list(inference_project_info.id), key=_get_name)
        _check_same_names(datasets_gt, datasets_pred,
                          f'datasets of {project_info.name!r} vs {inference_project_name!r}')

        # (overall, per-project) calculator pairs for this model, one pair per metric.
        model_calculators = list(zip(overall_metric_calculators[model_idx],
                                     per_project_metric_calculators[project_idx][model_idx]))

        for ds_gt, ds_pred in zip(datasets_gt, datasets_pred):
            images_gt = sorted(api.image.get_list(ds_gt.id), key=_get_name)
            images_pred = sorted(api.image.get_list(ds_pred.id), key=_get_name)
            _check_same_names(images_gt, images_pred, f'images of dataset {ds_gt.name!r}')

            for batch_gt, batch_pred in zip(sly.batched(images_gt), sly.batched(images_pred)):
                # NOTE(review): assumes download_batch returns annotations in the order of the
                # requested ids -- confirm.
                ann_gt_infos = api.annotation.download_batch(ds_gt.id, [info.id for info in batch_gt])
                ann_pred_infos = api.annotation.download_batch(ds_pred.id, [info.id for info in batch_pred])

                for ann_gt_info, ann_pred_info in zip(ann_gt_infos, ann_pred_infos):
                    ann_gt = sly.Annotation.from_json(ann_gt_info.annotation, meta_gt)
                    ann_pred = sly.Annotation.from_json(ann_pred_info.annotation, meta_pred)
                    for overall_calculator, project_calculator in model_calculators:
                        overall_calculator.add_pair(ann_gt, ann_pred)
                        project_calculator.add_pair(ann_gt, ann_pred)


def _print_metrics(metric_calculators, title_line=None):
    """Print a PrettyTable of total metric values with one row per model.

    ``metric_calculators`` is indexed ``[model_idx][metric_idx]`` and is paired positionally
    with the module-level ``model_names`` and ``metric_infos``. For each calculator, the
    result of ``get_total_metrics()`` is indexed by every metric id of the matching
    ``MetricInfo.metric_id_with_title``; that mapping's titles become the column headers.
    If ``title_line`` is not None, it is printed on the line above the table.
    """
    # Collect the totals of every calculator first, before building the table.
    totals_by_model = []
    for calculators_of_model in metric_calculators:
        totals_by_model.append([calculator.get_total_metrics() for calculator in calculators_of_model])

    column_titles = ['Models']
    for info in metric_infos:
        column_titles.extend(info.metric_id_with_title.values())
    table = PrettyTable(field_names=column_titles)

    for model_name, model_totals in zip(model_names, totals_by_model):
        row = [model_name]
        for info, totals in zip(metric_infos, model_totals):
            row.extend(totals[metric_id] for metric_id in info.metric_id_with_title)
        table.add_row(row)

    if title_line is not None:
        print(title_line)
    print(table.get_string(), flush=True)

# Aggregate table across all projects first ...
_print_metrics(overall_metric_calculators, title_line='Overall metrics:')

# ... then one table per project, in the same order as `project_names`.
for project_info, calculators_for_project in zip(project_infos, per_project_metric_calculators):
    heading = 'Metrics for project {!r}:'.format(project_info.name)
    _print_metrics(calculators_for_project, title_line=heading)

Inference task (id=925) started


{"message": "Project '_model_eval__391__104' already exists, reusing it and skipping inference.", "timestamp": "2019-06-13T19:27:43.394Z", "level": "warn"}
{"message": "Project '_model_eval__393__103' already exists, reusing it and skipping inference.", "timestamp": "2019-06-13T19:27:43.437Z", "level": "warn"}
{"message": "Project '_model_eval__393__104' already exists, reusing it and skipping inference.", "timestamp": "2019-06-13T19:27:43.467Z", "level": "warn"}


Inference task (id=925) finished
Overall metrics:
+-----------------+---------------------+--------------------+
|      Models     |         IoU         |        mAP         |
+-----------------+---------------------+--------------------+
|  yolo_2_epochs  | 0.13149535709937687 | 0.2121212121212121 |
| yolo_200_epochs |  0.6669136976983602 |        1.0         |
+-----------------+---------------------+--------------------+
Metrics for project 'lemons_annotated':
+-----------------+---------------------+--------------------+
|      Models     |         IoU         |        mAP         |
+-----------------+---------------------+--------------------+
|  yolo_2_epochs  | 0.13149535709937687 | 0.2121212121212121 |
| yolo_200_epochs |  0.6669136976983602 |        1.0         |
+-----------------+---------------------+--------------------+
Metrics for project 'lemons_annotated_copy':
+-----------------+---------------------+--------------------+
|      Models     |         IoU         |     