In [None]:
import json
import os
import sys

import pandas as pd
import numpy as np
import seaborn as sns

from PIL import Image

from itertools import groupby
from operator import itemgetter

In [None]:
def read_partial_json(path):
    """Read a line-delimited JSON log and return its records as a list.

    Blank (whitespace-only) lines are skipped. The remaining lines are joined
    with commas, wrapped in ``[...]`` and decoded as a single JSON array, so
    assuming each line holds exactly one JSON value the result has one entry
    per non-blank line, in file order. An empty file yields ``[]``.

    The file is decoded as UTF-8 explicitly: JSON text is UTF-8, whereas the
    platform default encoding (e.g. cp1252 on Windows) can mis-decode
    non-ASCII content.

    Args:
        path: path of the log file to read.

    Returns:
        list: the decoded JSON values.

    Raises:
        OSError: if ``path`` cannot be opened.
        json.JSONDecodeError: if the joined lines do not form a valid JSON array.
    """
    with open(path, "r", encoding="utf-8") as handle:
        lines = [line for line in handle if line.strip()]
    return json.loads("[" + ",".join(lines) + "]")

In [None]:
# Load the per-batch training log records (one JSON value per non-blank line).
statistics = read_partial_json(path="../statistics_first")

In [None]:
def fixup_statistics(statistics):
    """Relabel every other run of batches as validation, in place.

    A run ends whenever a record's ``batch_index`` is smaller than the
    previous record's. Runs alternate train, validation, train, ... starting
    with train. Records in validation runs get ``mode = 'validation'``;
    records in train runs keep whatever ``mode`` they already had.

    Args:
        statistics: iterable of dicts, each with an int-like ``batch_index``.

    Returns:
        None (the records are mutated).
    """
    previous_index = 0
    validating = False

    for record in statistics:
        current_index = record['batch_index']
        # The batch counter going backwards marks the start of the next run.
        if current_index < previous_index:
            validating = not validating
        if validating:
            record['mode'] = 'validation'
        previous_index = current_index

In [None]:
# Mark alternating runs of batches as validation (mutates the records in place).
fixup_statistics(statistics=statistics)

In [None]:
# Split the records by mode once, then project out (mIoU, loss) pairs, both
# plain and tagged with the epoch each record was logged in.
train_records = [record for record in statistics if record['mode'] == 'train']
val_records = [record for record in statistics if record['mode'] == 'validation']

train_statistics = [
    (r['statistics']['mIoU'], r['statistics']['loss']) for r in train_records
]
val_statistics = [
    (r['statistics']['mIoU'], r['statistics']['loss']) for r in val_records
]
train_statistics_by_epoch = [
    ((r['epoch'], r['statistics']['mIoU']), (r['epoch'], r['statistics']['loss']))
    for r in train_records
]
val_statistics_by_epoch = [
    ((r['epoch'], r['statistics']['mIoU']), (r['epoch'], r['statistics']['loss']))
    for r in val_records
]

In [None]:
# Transpose each list of (mIoU, loss) pairs into two parallel tuples.
# Unpacking an empty zip fails with an unhelpful "not enough values to unpack",
# so check up front and name the missing series instead. The *_by_epoch lists
# are built with the same mode filters, so they are empty exactly when these are.
assert train_statistics, "train_statistics is empty: no records with mode 'train'"
assert val_statistics, (
    "val_statistics is empty: no records with mode 'validation' "
    "(was fixup_statistics applied?)"
)

train_mious, train_losses = zip(*train_statistics)
val_mious, val_losses = zip(*val_statistics)
train_mious_by_epoch, train_losses_by_epoch = zip(*train_statistics_by_epoch)
val_mious_by_epoch, val_losses_by_epoch = zip(*val_statistics_by_epoch)

In [None]:
# seaborn >= 0.12 accepts only `data` positionally (0.11 warned), so x and y are
# passed by keyword. Records from the same epoch share an x value; lineplot's
# default estimator aggregates them and draws the error band.
sns.lineplot(
    x=[epoch for epoch, _ in train_mious_by_epoch],
    y=[miou for _, miou in train_mious_by_epoch],
).set(
    xlabel='Epoch',
    ylabel='mIoU',
    title='Train mIoU (with error margins)'
);

In [None]:
# seaborn >= 0.12 accepts only `data` positionally (0.11 warned), so x and y are
# passed by keyword. Records from the same epoch share an x value; lineplot's
# default estimator aggregates them and draws the error band.
sns.lineplot(
    x=[epoch for epoch, _ in val_mious_by_epoch],
    y=[miou for _, miou in val_mious_by_epoch],
).set(
    xlabel='Epoch',
    ylabel='mIoU',
    title='Validation mIoU (with error margins)'
);

In [None]:
# seaborn >= 0.12 accepts only `data` positionally (0.11 warned), so x and y are
# passed by keyword. Records from the same epoch share an x value; lineplot's
# default estimator aggregates them and draws the error band.
sns.lineplot(
    x=[epoch for epoch, _ in train_losses_by_epoch],
    y=[loss for _, loss in train_losses_by_epoch],
).set(
    xlabel='Epoch',
    ylabel='Loss',
    title='Training Loss (with error margins)'
);

In [None]:
# seaborn >= 0.12 accepts only `data` positionally (0.11 warned), so x and y are
# passed by keyword. Records from the same epoch share an x value; lineplot's
# default estimator aggregates them and draws the error band.
sns.lineplot(
    x=[epoch for epoch, _ in val_losses_by_epoch],
    y=[loss for _, loss in val_losses_by_epoch],
).set(
    xlabel='Epoch',
    ylabel='Loss',
    title='Validation Loss (with error margins)'
);

In [None]:
def visualize_change_in_segmentations(segmentations, image_id, epochs):
    """Plot one image's network output at several epochs next to its input and label.

    Builds a ``len(epochs)`` x 3 grid: column 0 is the source image, column 1
    the network output saved for that epoch, column 2 the label. Each row is
    labelled with its epoch number.

    Args:
        segmentations: directory containing ``image_<id>.input.png``,
            ``image_<id>.label.png`` and ``image_<id>.epoch.<NN>.png``.
        image_id: identifier substituted into those file names.
        epochs: non-empty sequence of integer epochs (formatted as two-digit,
            zero-padded numbers in the per-epoch file name).

    Raises:
        ValueError: if ``epochs`` is empty.
        FileNotFoundError: if an expected image file does not exist.
    """
    if len(epochs) == 0:
        raise ValueError("epochs must contain at least one epoch")

    # squeeze=False keeps `ax` 2-D even for a single epoch; with the default
    # squeeze a single row comes back 1-D and ax[i][0] would fail.
    fig, ax = sns.mpl.pyplot.subplots(
        nrows=len(epochs), ncols=3, figsize=(10, 10), squeeze=False
    )
    source = Image.open(os.path.join(segmentations, 'image_{}.input.png'.format(image_id)))
    label = Image.open(os.path.join(segmentations, 'image_{}.label.png'.format(image_id)))
    for i, epoch in enumerate(epochs):
        epoch_output = Image.open(os.path.join(
            segmentations, 'image_{}.epoch.{:02d}.png'.format(image_id, epoch)))
        ax[i][0].imshow(source)
        ax[i][1].imshow(epoch_output)
        ax[i][2].imshow(label)

        ax[i][0].set_ylabel('Epoch {}'.format(epoch), rotation=0, size='large')

    ax[0][0].set_title('Source Image')
    ax[0][1].set_title('Network Output')
    ax[0][2].set_title('Label')

    # Figure.show() needs a GUI backend and only warns under the inline notebook
    # backend; pyplot.show() displays the figure in both settings.
    sns.mpl.pyplot.show()

In [None]:
visualize_change_in_segmentations(
    segmentations='../logs/segmentations',
    image_id=1,
    epochs=list(range(1, 8)),
)

In [None]:
def show_best_and_worst(interesting, epochs):
    """Visualize the best and worst segmentations over the specified epochs.

    For each epoch, three rows are drawn (input, network segmentation, label)
    across eight columns: the three worst, two median ("middle") and three best
    examples. Images are read from ``interesting`` as
    ``image.<rank>.<n>.epoch<NN>.<kind>.png``.

    Args:
        interesting: directory containing the ranked example images.
        epochs: non-empty sequence of integer epochs (formatted as two-digit,
            zero-padded numbers in the file names).

    Raises:
        ValueError: if ``epochs`` is empty.
        FileNotFoundError: if an expected image file does not exist.
    """
    # (file-name rank, index within that rank, column title), left to right.
    columns = [
        ('worst', 0, 'Worst 1'),
        ('worst', 1, 'Worst 2'),
        ('worst', 2, 'Worst 3'),
        ('middle', 0, 'Median 1'),
        ('middle', 1, 'Median 2'),
        ('best', 0, 'Best 1'),
        ('best', 1, 'Best 2'),
        ('best', 2, 'Best 3'),
    ]
    # (file-name kind, row-label suffix) for the rows drawn per epoch, top to bottom.
    rows = [('input', 'Input'), ('segmentation', 'Seg'), ('label', 'Label')]

    if len(epochs) == 0:
        raise ValueError("epochs must contain at least one epoch")

    fig, ax = sns.mpl.pyplot.subplots(
        nrows=len(epochs) * len(rows), ncols=len(columns), figsize=(15, 10)
    )
    for i, epoch in enumerate(epochs):
        for r, (kind, row_label) in enumerate(rows):
            row = i * len(rows) + r
            for c, (rank, index, _title) in enumerate(columns):
                filename = 'image.{}.{}.epoch{:02d}.{}.png'.format(rank, index, epoch, kind)
                ax[row][c].imshow(Image.open(os.path.join(interesting, filename)))
            ax[row][0].set_ylabel('E{} {}'.format(epoch, row_label), rotation=0, size='large')

    for c, (_rank, _index, title) in enumerate(columns):
        ax[0][c].set_title(title)

    # Figure.show() needs a GUI backend and only warns under the inline notebook
    # backend; pyplot.show() displays the figure in both settings.
    sns.mpl.pyplot.show()

In [None]:
show_best_and_worst(
    interesting='../logs/interesting',
    epochs=[1, 4, 7],
)