In [None]:
import numpy as np
import os
# Download (if not already on disk) and load the MNIST image database
def load_dataset(source='http://yann.lecun.com/exdb/mnist/'):
  """Load the MNIST train/test images and labels, downloading missing files.

  Each of the four gzip-compressed files is looked up by name relative to the
  current working directory; a file that does not exist there is fetched from
  ``source`` first and stored under the same name.

  Parameters
  ----------
  source : str, optional
    Base URL that a file name is appended to when it has to be downloaded.
    Defaults to the original MNIST site; pass another base URL (including the
    trailing slash) to use a mirror.

  Returns
  -------
  X_train, Y_train, X_test, Y_test
    Images are arrays of shape (n_images, 1, 28, 28) with every byte divided
    by float32 256; labels are 1-D uint8 arrays read from the label files.
  """
  import gzip

  # urlretrieve lives in urllib.request on Python 3; the bare
  # `urllib.urlretrieve` spelling only exists on Python 2, so fall back to it
  # there instead of failing with AttributeError on Python 3.
  try:
    from urllib.request import urlretrieve
  except ImportError:
    from urllib import urlretrieve

  def download(filename, source=source):
    """Fetch ``source + filename`` and store it as ``filename`` on local disk."""
    print("Downloading ",filename)
    urlretrieve(source + filename, filename)

  def load_mnist_images(filename):
    """Return the images stored in ``filename`` as an (n, 1, 28, 28) array."""
    # Only download when the file is not already on local disk.
    if not os.path.exists(filename):
      download(filename)

    with gzip.open(filename, 'rb') as f:
      # Skip the 16-byte header; everything after it is raw pixel bytes.
      data = np.frombuffer(f.read(), np.uint8, offset=16)

    # The pixels arrive as one flat array. Reshape into individual images:
    #   dimension 1: number of images (-1 = infer from the other dimensions)
    #   dimension 2: number of channels (1, the images are monochrome)
    #   dimensions 3 and 4: image size (28x28 pixels)
    data = data.reshape(-1, 1, 28, 28)

    # Convert bytes to floats. Dividing by 256 (not 255) maps the byte range
    # 0..255 onto [0, 255/256], so the result never quite reaches 1.
    return data / np.float32(256)

  def load_mnist_labels(filename):
    """Return the labels stored in ``filename`` as a 1-D uint8 array."""
    if not os.path.exists(filename):
      download(filename)

    with gzip.open(filename, 'rb') as f:
      # Skip the 8-byte header; each remaining byte is the digit label of the
      # image at the same position in the matching image file.
      data = np.frombuffer(f.read(), np.uint8, offset=8)

    return data

  X_train = load_mnist_images('train-images-idx3-ubyte.gz')
  Y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
  X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
  Y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')

  return X_train, Y_train, X_test, Y_test

In [None]:
# Load the four MNIST arrays; any file not found in the working directory is
# downloaded first. Later cells read these module-level names.
X_train, Y_train, X_test, Y_test = load_dataset()

In [None]:
import matplotlib
# Select the Tk GUI backend before pyplot is imported.
# NOTE(review): this presumably opens the figure in a separate window rather
# than inline in the notebook, and needs a display — confirm this is intended.
matplotlib.use('TkAgg')

import matplotlib.pyplot as plt

# Draw the single channel of training image 12, then show it. `block` is a
# boolean flag: pass True explicitly instead of relying on the truthiness of
# the AxesImage object returned by imshow.
plt.imshow(X_train[12][0])
plt.show(block=True)