In [None]:
import re
import os
import math
import imghdr
from PIL import Image
from textwrap import wrap
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
def get_label(filename, pattern):
    """Extract a label from the base name of ``filename`` using ``pattern``.

    Only the final path component is searched; directory names are ignored.

    Returns:
        The first item produced by ``re.findall`` on the base name (the
        captured text when ``pattern`` has exactly one group), or ``None``
        when ``pattern`` matches nowhere.
    """
    base_name = os.path.basename(filename)
    return next(iter(re.findall(pattern, base_name)), None)

In [None]:
def is_image(filename):
    """Return True if ``filename`` is a regular file whose header ``imghdr`` recognises.

    Directories, missing paths and other non-regular entries return False
    instead of raising: ``imghdr.what`` opens its argument, which fails for a
    directory (e.g. a sub-folder listed by ``os.listdir``).
    """
    # NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13 (PEP 594);
    # on newer interpreters this check needs a replacement (e.g. PIL's Image.open).
    return os.path.isfile(filename) and imghdr.what(filename) is not None

In [None]:
def get_labeled_images(path, pattern):
    """Load every image directly inside ``path`` whose file name yields a label.

    Args:
        path: Directory to scan (non-recursively). Sub-directories are skipped.
        pattern: Regular expression handed to ``get_label``; files for which it
            finds no match are skipped.

    Returns:
        ``(images, labels)``: two parallel tuples holding fully loaded PIL
        images and the labels returned by ``get_label``, ordered by file name.
        Both tuples are empty when no file matched.
    """
    images = []
    labels = []
    # os.listdir order is arbitrary; sort so the grid layout is reproducible.
    for name in sorted(os.listdir(path)):
        file_path = os.path.join(path, name)
        if not (os.path.isfile(file_path) and is_image(file_path)):
            continue
        label = get_label(file_path, pattern)
        if label is None:
            continue
        # Image.open is lazy and keeps the file handle open; copy() loads the
        # pixels into memory so the handle can be closed immediately.
        with Image.open(file_path) as image:
            images.append(image.copy())
        labels.append(label)
    return tuple(images), tuple(labels)

In [None]:
def make_grid_plot(images, labels, images_per_row=5, figsize=None, label_wrap_width=15):
    """Draw ``images`` on a grid of subplots, each titled with its label.

    Args:
        images: Sized sequence of images accepted by ``Axes.imshow``.
        labels: Caption for each image; wrapped to ``label_wrap_width``
            characters per line with ``textwrap.wrap``.
        images_per_row: Number of grid columns.
        figsize: Passed straight to ``plt.figure``.
        label_wrap_width: Maximum characters per caption line.

    Returns:
        The matplotlib figure holding the grid.
    """
    n_cols = images_per_row
    n_rows = math.ceil(len(images) / n_cols)
    figure = plt.figure(figsize=figsize)
    # Subplot positions are 1-based, filled row by row.
    for position, (picture, caption) in enumerate(zip(images, labels), start=1):
        axes = figure.add_subplot(n_rows, n_cols, position)
        axes.imshow(picture)
        wrapped_caption = '\n'.join(wrap(caption, width=label_wrap_width))
        axes.set_title(wrapped_caption)
        axes.axis('off')
    return figure

In [None]:
# parameters
# Directory of source figures. File names are matched with `regex`: the text
# after the first underscore, up to the first '.' or ',', becomes the caption.
path = 'Report/Figures/Experiments/DI + HiddenSubnets'
regex = r'\d*_([^\.,]*).*'
figsize = (10, 6)
label_wrap_width = 15
# Save the combined grid next to the source directory and name it after that
# directory, so the output name follows `path` instead of being duplicated.
# normpath strips a trailing separator that would otherwise empty basename().
source_dir = os.path.normpath(path)
output_file = os.path.join(os.path.dirname(source_dir), os.path.basename(source_dir) + '.png')
output_file

In [None]:
images, labels = get_labeled_images(path, regex)
grid = make_grid_plot(images, labels, figsize=figsize, label_wrap_width=label_wrap_width)
# Save before displaying: plt.show() may block or tear the figure down on GUI
# backends, and Figure.show() only warns under the non-GUI inline backend.
grid.savefig(output_file, bbox_inches='tight')
plt.show()