In [None]:
import json
import os
import sys

import pandas as pd
import numpy as np
import seaborn as sns

from PIL import Image

from itertools import groupby
from operator import itemgetter

In [None]:
# Display floats with 3 decimal places. The fully-qualified option key is
# required: since pandas 1.4 the bare 'precision' pattern also matches
# 'styler.format.precision', and set_option raises OptionError on ambiguity.
pd.set_option('display.precision', 3)

In [None]:
def read_partial_json(path):
    """Read a log file holding one JSON value per line into a list.

    The file has no enclosing array brackets, so the non-blank lines are
    joined with commas, wrapped in ``[...]`` and parsed as a single JSON array.
    Each non-blank line must therefore be a complete JSON value.

    :param path: path of the log file.
    :returns: list of the parsed values, in file order (``[]`` for an empty file).
    :raises json.JSONDecodeError: if any non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with open(path, "r", encoding="utf-8") as handle:
        lines = [line for line in handle if line.strip()]
    return json.loads("[" + ",".join(lines) + "]")

In [None]:
# Which experiment to report on, taken from the EXPERIMENT environment variable.
# NOTE(review): the default is the literal string "None", so an unset variable
# silently points at ~/Downloads/experiments/None — consider failing instead.
EXPERIMENT = os.environ.get("EXPERIMENT", "None")
# Build the relative path first, then expand "~". Calling .format() after
# expansion would misinterpret any "{" or "}" in the user's home directory.
INPUT_DIR = os.path.expanduser("~/Downloads/experiments/{}".format(EXPERIMENT))
OUTPUT_DIR = "reports/{}".format(EXPERIMENT)

In [None]:
# Create the report directory (and any parents); an existing one is fine.
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [None]:
# Raw per-batch training/validation records, in the order they were logged.
statistics = read_partial_json(
    os.path.join(os.path.join(INPUT_DIR, "logs"), "statistics")
)

In [None]:
def fixup_statistics_restarts(statistics):
    """Fix errors caused by restarts in the training process.

    Records are buffered per epoch, split by ``mode`` ('train' or
    'validation'). A restart shows up as ``batch_index`` going backwards while
    the mode stays the same. When that happens, the buffer for that mode is
    discarded so only the records of the latest attempt survive.

    A buffered epoch is emitted (its train records, then its validation
    records) when a 'train' record from a later epoch arrives. The last epoch
    is emitted once the input is exhausted.

    :param statistics: iterable of dicts with at least the keys 'epoch',
        'mode' ('train' or 'validation') and 'batch_index'.
    :yields: the surviving records, in the order described above.
    """
    running_batch_index = 0
    running_epoch = 0

    epoch_statistics = {
        'train': [],
        'validation': []
    }
    mode = None

    for s in statistics:
        if s['epoch'] > running_epoch and s['mode'] == 'train':
            # A new epoch started: emit everything kept for the previous one.
            yield from epoch_statistics['train']
            yield from epoch_statistics['validation']
            epoch_statistics = {
                'train': [],
                'validation': []
            }
            running_batch_index = 0
            # Track the current epoch; without this, every later train record
            # would look like a new epoch and restart detection never fires.
            running_epoch = s['epoch']

        # We reset the batch index without changing modes.
        # That's bad. Reset what we had on this mode.
        if s['batch_index'] < running_batch_index and s['mode'] == mode:
            epoch_statistics[mode] = []

        mode = s['mode']
        running_batch_index = s['batch_index']
        epoch_statistics[mode].append(s)

    # No later 'train' record will flush the final epoch, so flush it here.
    yield from epoch_statistics['train']
    yield from epoch_statistics['validation']

In [None]:
def fixup_statistics(statistics):
    """Relabel alternate runs of batches as 'validation', in place.

    A run ends whenever ``batch_index`` decreases. Runs are counted from zero.
    Every odd-numbered run has its records' 'mode' set to 'validation'.
    Even-numbered runs are left untouched. Returns None.
    """
    run_number = 0
    previous_index = 0

    for record in statistics:
        current_index = record['batch_index']
        if current_index < previous_index:
            run_number += 1

        if run_number % 2 == 1:
            record['mode'] = 'validation'

        previous_index = current_index

In [None]:
# Drop records duplicated by training restarts. The older relabelling pass,
# fixup_statistics(statistics), is deliberately not applied here.
statistics = [*fixup_statistics_restarts(statistics)]

In [None]:
# (mIoU, loss) pairs per mode, in log order.
train_statistics = [
    (record['statistics']['mIoU'], record['statistics']['loss'])
    for record in statistics
    if record['mode'] == 'train'
]
val_statistics = [
    (record['statistics']['mIoU'], record['statistics']['loss'])
    for record in statistics
    if record['mode'] == 'validation'
]
# Same data, but each metric is paired with its epoch: ((epoch, mIoU), (epoch, loss)).
train_statistics_by_epoch = [
    (
        (record['epoch'], record['statistics']['mIoU']),
        (record['epoch'], record['statistics']['loss']),
    )
    for record in statistics
    if record['mode'] == 'train'
]
val_statistics_by_epoch = [
    (
        (record['epoch'], record['statistics']['mIoU']),
        (record['epoch'], record['statistics']['loss']),
    )
    for record in statistics
    if record['mode'] == 'validation'
]

In [None]:
# Transpose the pair lists into one tuple per metric.
train_mious, train_losses = zip(*train_statistics)
val_mious, val_losses = zip(*val_statistics)
train_mious_by_epoch, train_losses_by_epoch = zip(*train_statistics_by_epoch)
val_mious_by_epoch, val_losses_by_epoch = zip(*val_statistics_by_epoch)

In [None]:
def generate_lineplot(dataset, bbox=None, ax=None, **kwargs):
    """Draw a seaborn line plot of (epoch, value) pairs with a summary table.

    Values are grouped by epoch. Seaborn aggregates the several values logged
    per epoch into one line with an error band. The table is built from a
    DataFrame with one row per epoch and one column per within-epoch position.
    For each ``DataFrame.describe`` statistic of those columns, it shows the
    maximum across positions.

    :param dataset: iterable of ``(epoch, value)`` pairs.
    :param bbox: bounding box ``[left, bottom, width, height]`` (axes
        coordinates) for the summary table; None lets matplotlib place it.
    :param ax: matplotlib Axes to draw on; None uses the current Axes.
    :param kwargs: forwarded to ``Axes.set`` (e.g. title, xlabel, ylabel).
    :returns: the Axes the plot was drawn on.
    :raises ValueError: if ``dataset`` is empty.
    """
    # itertools.groupby only merges *adjacent* equal keys. A stable sort makes
    # each epoch a single group while keeping the within-epoch order.
    ordered = sorted(dataset, key=itemgetter(0))
    if not ordered:
        raise ValueError("generate_lineplot needs at least one (epoch, value) pair")
    epochs, values = zip(*ordered)

    grouped_by_epoch = [
        {str(i): value for i, (_, value) in enumerate(group)}
        for _, group in groupby(ordered, key=itemgetter(0))
    ]
    df = pd.DataFrame(grouped_by_epoch)
    summary = pd.DataFrame(df.describe().max(axis=1))
    summary.columns = ["Summary"]

    # seaborn >= 0.12 rejects positional x/y data arguments; pass them by keyword.
    plot = sns.lineplot(x=list(epochs), y=list(values), ax=ax)
    plot.set(**kwargs)
    plot.table(cellText=[['{:.2f}'.format(d[0])] for d in summary.values],
               rowLabels=summary.index,
               colLabels=summary.columns,
               cellLoc='right',
               rowLoc='center',
               loc='right',
               bbox=bbox)
    return plot

In [None]:
train_mious_by_epoch_plot = generate_lineplot(
    train_mious_by_epoch,
    bbox=[.65, .05, .3, .45],  # summary table in the lower right
    title='Train mIoU (with error margins)',
    xlabel='Epoch',
    ylabel='mIoU',
)

In [None]:
val_mious_by_epoch_plot = generate_lineplot(
    val_mious_by_epoch,
    bbox=[.65, .05, .3, .45],  # summary table in the lower right
    title='Validation mIoU (with error margins)',
    xlabel='Epoch',
    ylabel='mIoU',
)

In [None]:
train_loss_by_epoch_plot = generate_lineplot(
    train_losses_by_epoch,
    bbox=[.65, .50, .3, .45],  # summary table in the upper right
    title='Training Loss (with error margins)',
    xlabel='Epoch',
    ylabel='Loss',
)

In [None]:
val_loss_by_epoch_plot = generate_lineplot(
    val_losses_by_epoch,
    bbox=[.65, .50, .3, .45],  # summary table in the upper right
    title='Validation Loss (with error margins)',
    xlabel='Epoch',
    ylabel='Loss',
)

In [None]:
def visualize_change_in_segmentations(segmentations, image_id, epochs):
    """Visualize the change in segmentations over the specified epochs.

    Draws one row per epoch with three columns: the source image, the network
    output for that epoch and the label. Images are read from
    ``segmentations`` using the file names ``image_<id>.input.png``,
    ``image_<id>.label.png`` and ``image_<id>.epoch.<NN>.png``.

    :param segmentations: directory containing the image files.
    :param image_id: id used in the file names.
    :param epochs: sequence of epoch numbers, one plot row each.
    :returns: the matplotlib Figure (``fig.show()`` is also called).
    """
    # squeeze=False keeps `ax` 2-D even when there is only one epoch (one row).
    fig, ax = sns.mpl.pyplot.subplots(nrows=len(epochs), ncols=3, figsize=(12, 20),
                                      squeeze=False)
    source_path = os.path.join(segmentations, 'image_{}.input.png'.format(image_id))
    label_path = os.path.join(segmentations, 'image_{}.label.png'.format(image_id))

    # Context managers close the underlying files; imshow converts a PIL
    # image to an array when called, so closing afterwards is safe.
    with Image.open(source_path) as source, Image.open(label_path) as label:
        for i, epoch in enumerate(epochs):
            output_path = os.path.join(
                segmentations, 'image_{}.epoch.{:02d}.png'.format(image_id, epoch))
            with Image.open(output_path) as epoch_output:
                ax[i][0].imshow(source)
                ax[i][1].imshow(epoch_output)
                ax[i][2].imshow(label)

            ax[i][0].set_ylabel('Epoch {}'.format(epoch), rotation=0, size='large')

    ax[0][0].set_title('Source Image')
    ax[0][1].set_title('Network Output')
    ax[0][2].set_title('Label')

    fig.show()
    return fig

In [None]:
def int_all(array):
    """Round each element with numpy (half to even) and return a list of Python ints."""
    rounded = []
    for value in array:
        rounded.append(int(np.round(value)))
    return rounded

In [None]:
# Per-image segmentation snapshots written during training.
SEGMENTATION_DIR = os.path.join(
    INPUT_DIR,
    os.path.join('logs', 'interesting', 'segmentations'),
)

In [None]:
# Six roughly log-spaced epochs between 1 and 499.
validation_segmentation_0 = visualize_change_in_segmentations(
    segmentations=SEGMENTATION_DIR,
    image_id=0,
    epochs=int_all(np.geomspace(1, 499, 6)),
)

In [None]:
# Six roughly log-spaced epochs between 1 and 499.
validation_segmentation_1 = visualize_change_in_segmentations(
    segmentations=SEGMENTATION_DIR,
    image_id=1,
    epochs=int_all(np.geomspace(1, 499, 6)),
)

In [None]:
# Six roughly log-spaced epochs between 1 and 499.
validation_segmentation_2 = visualize_change_in_segmentations(
    segmentations=SEGMENTATION_DIR,
    image_id=2,
    epochs=int_all(np.geomspace(1, 499, 6)),
)

In [None]:
def show_best_and_worst(interesting, epochs):
    """Visualize the best, middle and worst segmentations over the specified epochs.

    For each epoch, three rows are drawn (input image, network segmentation,
    label) across eight columns: the 3 worst, 2 middle and 3 best images.
    Files are read from ``interesting`` and named
    ``image.<rank>.<index>.epoch<NN>.<kind>.png``, where rank is
    worst/middle/best and kind is input/segmentation/label.

    :param interesting: directory containing the image files.
    :param epochs: sequence of epoch numbers; each contributes three rows.
    :returns: the matplotlib Figure (``fig.show()`` is also called).
    """
    # (rank in file name, index in file name, column title), left to right.
    columns = [
        ('worst', 0, 'Worst 1'),
        ('worst', 1, 'Worst 2'),
        ('worst', 2, 'Worst 3'),
        ('middle', 0, 'Median 1'),
        ('middle', 1, 'Median 2'),
        ('best', 0, 'Best 1'),
        ('best', 1, 'Best 2'),
        ('best', 2, 'Best 3'),
    ]
    # (kind in file name, row label suffix), top to bottom within an epoch.
    rows = [
        ('input', 'Input'),
        ('segmentation', 'Seg'),
        ('label', 'Label'),
    ]

    fig, ax = sns.mpl.pyplot.subplots(nrows=len(epochs) * len(rows), ncols=len(columns),
                                      figsize=(15, 20), squeeze=False)
    for i, epoch in enumerate(epochs):
        for r, (kind, row_label) in enumerate(rows):
            row = i * len(rows) + r
            for c, (rank, index, _) in enumerate(columns):
                path = os.path.join(
                    interesting,
                    'image.{}.{}.epoch{:02d}.{}.png'.format(rank, index, epoch, kind))
                # Close the file once imshow has converted the image.
                with Image.open(path) as image:
                    ax[row][c].imshow(image)
            ax[row][0].set_ylabel('E{} {}'.format(epoch, row_label), rotation=0, size='large')

    for c, (_, _, title) in enumerate(columns):
        ax[0][c].set_title(title)

    fig.show()
    return fig

In [None]:
# Best/middle/worst image snapshots written during training.
INTERESTING_DIR = os.path.join(
    INPUT_DIR,
    os.path.join('logs', 'interesting', 'interesting'),
)

In [None]:
# Three roughly log-spaced epochs between 1 and 499.
best_and_worst = show_best_and_worst(
    interesting=INTERESTING_DIR,
    epochs=int_all(np.geomspace(1, 499, 3)),
)

In [None]:
def write_experiment_results(output_dir, plots):
    """Save every plot into ``output_dir``, using the dict key as file name.

    Values may be Axes-like objects (whose ``get_figure()`` returns the
    owning figure) or figures themselves (``get_figure()`` returns a falsy value).
    """
    for filename, artist in plots.items():
        figure = artist.get_figure() or artist
        figure.savefig(os.path.join(output_dir, filename))

In [None]:
# Persist every figure produced above under reports/<experiment>/.
write_experiment_results(
    output_dir=OUTPUT_DIR,
    plots={
        # Metric curves
        "train_mious.png": train_mious_by_epoch_plot,
        "val_mious.png": val_mious_by_epoch_plot,
        "train_loss.png": train_loss_by_epoch_plot,
        "val_loss.png": val_loss_by_epoch_plot,
        # Image grids
        "best_images.png": best_and_worst,
        "validation_seg0.png": validation_segmentation_0,
        "validation_seg1.png": validation_segmentation_1,
        "validation_seg2.png": validation_segmentation_2,
    },
)