# In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Helper function to plot images and labels
def imshow(images, labels, pred=None, **kwargs):
    """Plot up to 16 images in a 4x4 grid, labelled with their classes.

    Class names are looked up in the global ``classes`` sequence, which
    must be defined before this function is called.

    Args:
        images: Batch of image tensors; each ``images[i]`` must be
            reshapeable to 28x28 via ``.view(28, 28)``. If the batch has
            fewer than 16 images, the unused grid cells are left blank.
        labels: True class indices, one per image (indexes ``classes``).
        pred: Optional predicted class indices, one per image. When given,
            each subplot shows both the true and the predicted class.
        **kwargs:
            smooth (default True): truthy selects 'spline16'
                interpolation, falsy selects 'nearest'.
            normalize (default 0.5): when truthy, images are mapped with
                ``images / 2 + normalize`` before plotting. Presumably this
                undoes a mean=0.5/std=0.5 normalization -- confirm against
                the dataset transform. Pass 0 or None to plot raw values.
    """
    # Keyword arguments.
    smooth = kwargs.get('smooth', True)
    normalize = kwargs.get('normalize', 0.5)

    # Resolve the interpolation mode once, into its own name. Reassigning
    # ``smooth`` inside the loop turned it into a non-empty (always truthy)
    # string, so every image after the first was drawn with 'spline16'.
    interpolation = 'spline16' if smooth else 'nearest'

    # Undo the normalization so pixel values are in a displayable range.
    if normalize:
        images = images / 2 + normalize

    # Create figure with sub-plots.
    fig, axes = plt.subplots(4, 4)

    # Extra vertical room is needed when each label has two lines.
    wspace = 0.2
    hspace = 0.8 if pred is not None else 0.4
    fig.subplots_adjust(hspace=hspace, wspace=wspace)

    n_images = len(images)
    for i, ax in enumerate(axes.flat):
        # Short batch: blank the remaining cells instead of IndexError.
        if i >= n_images:
            ax.axis('off')
            continue

        # detach()/cpu() make .numpy() work for tensors that require grad
        # or live on an accelerator; they are no-ops for plain CPU tensors.
        img = images[i].detach().cpu().view(28, 28).numpy()
        ax.imshow(img, interpolation=interpolation, cmap='Greys')

        # Show the true class, plus the predicted class when available.
        xlabel = f'True: {classes[labels[i]]}'
        if pred is not None:
            xlabel = f'{xlabel}\nPred: {classes[pred[i]]}'
        ax.set_xlabel(xlabel)

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()


# Visualization function to visualize dataset.
def visualize(data, smooth=False):
    """Plot the first batch drawn from ``data`` with ``imshow``.

    Args:
        data: An iterable yielding ``(images, labels)`` batches. Presumably
            a ``torch.utils.data.DataLoader``; confirm against the caller.
        smooth: Forwarded to ``imshow``. Truthy selects spline16
            interpolation, falsy selects nearest-neighbour.
    """
    # Use the built-in next(): the ``.next()`` method is a Python 2 idiom
    # that recent PyTorch DataLoader iterators no longer provide. The
    # temporary iterator is released as soon as this expression finishes,
    # so no explicit ``del`` is needed.
    images, labels = next(iter(data))

    # Call to helper function for plotting images.
    imshow(images, labels=labels, smooth=smooth)



# Show one batch from the training set (nearest-neighbour interpolation).
visualize(traindata, smooth=False)