In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output, display
from ipywidgets import Button, HBox, Layout, Output
from torch.utils.data import Dataset

In [2]:
# Dataset root; its listing below shows the monet/photo jpg and tfrec sub-folders
root_path = "./data"
# os.listdir order is filesystem-dependent; sort so the displayed output is reproducible
sorted(os.listdir(root_path))

['monet_jpg', 'monet_tfrec', 'photo_jpg', 'photo_tfrec']

In [3]:
class ImageExplorer:
    """ Simple tool for viewing images in notebooks

    Implements an n_rows by n_cols plt.figure with up and down buttons that scroll
    the grid one row (n_cols images) at a time

    """

    # --------------------------------------------------------------------------------
    def __init__(self,
                 dataset=None,
                 num_samples=None,
                 n_rows=3,
                 n_cols=3,
                 figsize=(12, 12)):
        """ Initialise the Image Explorer

        How to:
            explorer = ImageExplorer(dataset)
            explorer.show()

        Tips: Override any functionality for your own case, I left some comments in other functions
        that might be useful to provide a case specific behaviour

        ---
        Parameters
            dataset: Dataset class, can be anything (torch.Dataset recommended) as long as it implements len()
            num_samples: Number of samples to browse; when given it is used instead of len(dataset),
                and it is required when dataset is None
            n_rows: Number of rows (>= 1)
            n_cols: Number of columns (>= 1)
            figsize: Figure size

        ---
        Raises
            ValueError: if both dataset and num_samples are None, or if n_rows / n_cols < 1

        """
        if dataset is None and num_samples is None:
            raise ValueError("Either Dataset or num_samples must be given, both are None")
        if n_rows < 1 or n_cols < 1:
            raise ValueError(f"n_rows and n_cols must be >= 1, got n_rows={n_rows}, n_cols={n_cols}")

        self.dataset = dataset
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.figsize = figsize
        self.total_images = n_rows * n_cols
        self.num_samples = num_samples if num_samples is not None else len(dataset)
        self.current_index = 0  # index of the image shown in the top-left cell

        self.output_widget = Output()
        self.button_up = Button(description='▲')
        self.button_down = Button(description='▼')
        # Flexbox options belong on the widget's `layout`; `align_items` is not an HBox
        # trait, so passing it as a plain kwarg would be silently dropped
        self.button_box = HBox([self.button_up, self.button_down],
                               layout=Layout(align_items='center'))

        self._initialize_callbacks()

    # --------------------------------------------------------------------------------
    def _initialize_callbacks(self):
        """ Initialise callbacks for up and down buttons
        """

        self.button_up.on_click(self._on_up_click)
        self.button_down.on_click(self._on_down_click)

    # --------------------------------------------------------------------------------
    def _on_up_click(self, button):
        """ On up click: scroll one row up
        """

        self.current_index -= self.n_cols
        self._update_display()

    # --------------------------------------------------------------------------------
    def _on_down_click(self, button):
        """ On down click: scroll one row down
        """

        self.current_index += self.n_cols
        self._update_display()

    # --------------------------------------------------------------------------------
    def _clamp_index(self, index):
        """ Clamp a start index to the valid, row-aligned range

        The largest start is the one whose page ends with the final (possibly partial)
        row of samples. Keeping it a multiple of n_cols means scrolling never shifts
        the column alignment of the grid.

        ---
        Parameters
            index: Candidate start index (may be negative or past the end)

        ---
        Returns
            Index in [0, last_page_start]; 0 when all samples fit on one page

        """
        n_sample_rows = -(-self.num_samples // self.n_cols)  # ceil division
        last_page_start = max(0, n_sample_rows - self.n_rows) * self.n_cols
        return max(0, min(index, last_page_start))

    # --------------------------------------------------------------------------------
    def _update_display(self):
        """ Update figure
        """

        self.current_index = self._clamp_index(self.current_index)
        with self.output_widget:
            clear_output(wait=True)
            self._display_images()

    # --------------------------------------------------------------------------------
    def _display_images(self):
        """ Show grid of images

        Tip: Override this function for your specific behaviour

        """

        # squeeze=False always yields a 2-D array of Axes, including the 1x1 case
        _, axs = plt.subplots(self.n_rows, self.n_cols, figsize=self.figsize, squeeze=False)

        for i in range(self.n_rows):
            for j in range(self.n_cols):
                ax = axs[i, j]
                idx = self.current_index + i * self.n_cols + j
                if idx < self.num_samples:
                    ax.imshow(self._get_image(idx=idx))
                # Hide axes on every cell so unused cells on the last page stay blank
                ax.axis('off')
        plt.show()

    # --------------------------------------------------------------------------------
    def _get_image(self, idx):
        """ Function that retrieves an image at specific idx

        This is just an example of a function that assumes that dataset 
        object implements __getitem__(idx), again, this is a simple PyTorch example

        Tip: Override this function for your specific behaviour. 
        This is called from _display_images and only requires to implement functionality
        that will return some object with which the plt.imshow() is happy with

        ---
        Parameters
            idx: Index of the sample

        ---
        Returns
            Anything that plt.imshow() is happy with

        ---
        Raises
            ValueError: if no dataset was given

        """
        if self.dataset is None:
            raise ValueError("Dataset is None, can't retrieve a sample")

        return self.dataset[idx]

    # --------------------------------------------------------------------------------
    def show(self):
        """ Show the grid, followed by the up / down buttons
        """

        display(self.output_widget)
        self._update_display()
        display(self.button_box)

In [4]:
class MonetDataset(Dataset):
    """ Dataset of Monet paintings read from the `monet_jpg` sub-folder of root_path

    Each item is the image decoded with cv2 and converted from BGR to RGB.
    """

    def __init__(self, 
                 root_path):
        """ Index the Monet image files

        ---
        Parameters
            root_path: Directory that contains the `monet_jpg` sub-folder
        """

        # Paths
        self.root_path = root_path
        self.monet_dir = os.path.join(root_path, "monet_jpg")
        # os.listdir order is arbitrary; sort so idx -> file mapping is reproducible
        self.monet_paths = sorted(os.listdir(self.monet_dir))

    def __len__(self):
        return len(self.monet_paths)

    def __getitem__(self, idx):
        """ Load the Monet painting at position idx (sorted filename order)
        """

        # Use this instance's directory, not the notebook-global `root_path`
        path_monet = os.path.join(self.monet_dir, self.monet_paths[idx])
        return self.read_img(path_monet)

    def read_img(self, path):
        """ Read img with cv2 and transform to RGB

        ---
        Raises
            OSError: if cv2 cannot read or decode the file
        """
        # cv2.imread signals failure by returning None rather than raising
        img_bgr = cv2.imread(path)
        if img_bgr is None:
            raise OSError(f"cv2 could not read image: {path}")
        return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

In [5]:
# Build the explorer over the Monet images; nothing is drawn until .show() is called
img_explorer = ImageExplorer(
    dataset=MonetDataset(root_path=root_path),
)


In [6]:
# Render the image grid followed by the ▲ / ▼ buttons; each click scrolls by one row
img_explorer.show()

Output()

HBox(children=(Button(description='▲', style=ButtonStyle()), Button(description='▼', style=ButtonStyle())))