## Settings...

In [1]:
from subprocess import call
from urllib import urlretrieve
import os
import gzip
import binascii
import struct
import numpy
import sys
from matplotlib import pyplot

%matplotlib inline

save_dir = '/notebooks/mnist'

filenames = [
    'train-images-idx3-ubyte.gz', #:  training set images (9912422 bytes) 
    'train-labels-idx1-ubyte.gz', #:  training set labels (28881 bytes) 
    't10k-images-idx3-ubyte.gz',  #:   test set images (1648877 bytes) 
    't10k-labels-idx1-ubyte.gz'   #:   test set labels (4542 bytes)
]




## Grab the MNist files...

In [2]:
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
for f in filenames:
    fullpath = os.path.join(save_dir, f)
    downloadurl = 'http://yann.lecun.com/exdb/mnist/' + f
    print "Path", fullpath
    if os.path.exists(fullpath):
        print "Nothing to do."
    else:
        print "Downloading", downloadurl, "..."
        urlretrieve(downloadurl, fullpath)
        print "... done"

Path /notebooks/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ...
... done
Path /notebooks/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz ...
... done
Path /notebooks/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz ...
... done
Path /notebooks/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz ...
... done


## Load data file and parse

In [3]:
%matplotlib inline

def progress_dot():
    sys.stdout.write('.')
    sys.stdout.flush()

def unpack_int(data_file):
    data = data_file.read(4)
    data = struct.unpack('>i', data)[0]
    return data

def unpack_image(data_file):
    img = data_file.read(28 * 28)
    img = struct.unpack('B' * (28 * 28), img)
    img = numpy.array(img)
    img = (img - 128) / 255.0
    return img

def pyplot_images(image_list):
    histogram = len(image_list) == 1
    plot_count = 2 if histogram else len(image_list)
    _, plots = pyplot.subplots(1, plot_count)
    for i in range(len(image_list)):
        img = image_list[i]
        plot = plots[i]
        plot.imshow(img.reshape(28, 28), cmap=pyplot.cm.Greys)
        if histogram:
            histplot = plots[i + 1]
            histplot.hist(img, bins=20, range=[-1.0,1.0])

def unpack_image_data(filename):
    print "Unpacking image", filename, "..."
    with gzip.open(os.path.join(save_dir, filename)) as data_file:
        fields = {
            'magic_number': unpack_int(data_file),
            'image_count': unpack_int(data_file),
            'rows': unpack_int(data_file),
            'columns': unpack_int(data_file),
        }
        image_count = fields['image_count']
        image_list = []
        print "Reading", image_count, "images:"
        while(image_count > len(image_list)):
            if(len(image_list) % 1000 == 0):
                progress_dot()
            img = unpack_image(data_file)
            image_list.append(img)
        print "done"
        return fields, image_list

def unpack_labels(filename):
    print "Unpacking labels", filename, "..."
    with gzip.open(os.path.join(save_dir, filename)) as data_file:
        fields = {
            'magic_number': unpack_int(data_file),
            'label_count': unpack_int(data_file)
        }
        print "Reading", fields['label_count'], 'labels:'
        label_data = data_file.read(fields['label_count'])
        label_list = struct.unpack('B' * fields['label_count'], label_data)
    print "... done"
    return fields, label_list

image_fields, image_list = unpack_image_data('train-images-idx3-ubyte.gz')
label_fields, label_list = unpack_labels('train-labels-idx1-ubyte.gz')
training = {
    'image_fields': image_fields,
    'image': image_list,
    'label_fields': label_fields,
    'label': label_list
}

image_fields, image_list = unpack_image_data('t10k-images-idx3-ubyte.gz')
label_fields, label_list = unpack_labels('t10k-labels-idx1-ubyte.gz')
testing = {
    'image_fields': image_fields,
    'image': image_list,
    'label_fields': label_fields,
    'label': label_list
}

print "##Training", training['image_fields']['count'], "images:"
pyplot_images([training['image'][0], training['image'][1], training['image'][2]])
print "##Testing", testing['image_fields']['count'], "images:"
pyplot_images([testing['image'][0], testing['image'][1], testing['image'][2]])


Unpacking image train-images-idx3-ubyte.gz ...
Reading 60000 images:
............................................................done
Unpacking labels train-labels-idx1-ubyte.gz ...
Reading 60000 labels:
... done
Unpacking image t10k-images-idx3-ubyte.gz ...
Reading 10000 images:
..........done
Unpacking labels t10k-labels-idx1-ubyte.gz ...
Reading 10000 labels:
... done
##Training

KeyError: 'count'

In [None]:
import tensorflow

x = tensorflow.placeholder(tensorflow.float, [None, 28 * 28])
w = tensorflow.Variable(tensorflow.zeros([784, 10]))
b = tensorflow.Variable(tensorflow.zeros([10]))
y = tensorflow.nn.softmax(tensorflow.matmul(x, w) + b)
y_ = tensorflow.placeholder(tensorflow.float, [None, 10])
cross_entropy = tensorflow.reduce_mean(-tensorflow.reduce_sum(y_ * tensorflow.log(y), reduction_indices=[1]))

train_step = tensorflow.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
init = tensorflow.init_all_variables();
sess = tensorflow.Session()
sess.run(init)

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)

In [None]:
## FOOBAR