In [None]:
#convert train.csv to lmdb
import os
import logging
import numpy as np
import pandas as pd
import lmdb
import cv2
import caffe
from caffe.proto import caffe_pb2

DATA_ROOT = '/home/shihuai02/caffe/examples/mnist'
join = os.path.join
TRAIN = join(DATA_ROOT, 'train.csv')
train_file = join(DATA_ROOT, 'mnist_train_lmdb1')
test_file = join(DATA_ROOT, 'mnist_test_lmdb1')

# Number of records written per LMDB write transaction.
batch_size = 1000

# logger -- the root logger is shared by every cell of this notebook, so the
# handler is tagged and attached only once; re-running the cell would
# otherwise print every message one extra time.
_LOG_HANDLER_NAME = 'mnist-notebook'
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
if not any(h.get_name() == _LOG_HANDLER_NAME for h in logger.handlers):
    sh = logging.StreamHandler()
    sh.set_name(_LOG_HANDLER_NAME)
    sh.setLevel(logging.DEBUG)
    sh.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(sh)


def _write_lmdb(path, images, labels, desc):
    """Write (image, label) pairs into the LMDB at *path* as Caffe Datums.

    Args:
        path: directory of the LMDB environment to create/open.
        images: sequence of C x H x W arrays, one per record.
        labels: sequence of integer class labels, aligned with *images*.
        desc: short dataset name used in log messages ('train' / 'test').

    Keys are the record index formatted as an 8-digit zero-padded string,
    starting at 0. Records are written in transactions of ``batch_size``
    entries. Each transaction is a context manager, so it commits on success
    and aborts on error. This includes the final partial batch and the
    case where the record count is an exact multiple of ``batch_size``.
    """
    logger.info('Write %s dataset to %s', desc, path)
    n = len(labels)
    env = lmdb.open(path, map_size=int(1e12))
    try:
        for start in range(0, n, batch_size):
            stop = min(start + batch_size, n)
            with env.begin(write=True) as txn:
                for idx in range(start, stop):
                    datum = caffe.io.array_to_datum(images[idx],
                                                    int(labels[idx]))
                    # lmdb requires bytes keys on Python 3 (str on Py2 is
                    # already bytes, and .encode leaves it unchanged).
                    key = '{:0>8d}'.format(idx).encode('ascii')
                    txn.put(key, datum.SerializeToString())
            logger.info('%s: committed %d records', desc, stop)
    finally:
        env.close()


# load data from train.csv (column 0 is the label, the rest are pixels)
logger.info('Load data from %s', TRAIN)
df = pd.read_csv(TRAIN)
data = df.values
logger.info('Get %d Rows in dataset', len(data))

# Store pixels as uint8. caffe.io.array_to_datum writes uint8 arrays as raw
# bytes, the same layout as Caffe's stock MNIST LMDB. Any other dtype goes
# into float_data, which is 4x larger. Check the range first so an
# out-of-range value fails loudly instead of silently wrapping around.
if data.size and (data.min() < 0 or data.max() > 255):
    raise ValueError('expected values in [0, 255] in %s' % TRAIN)
data = data.astype(np.uint8)

# random shuffle (rows are shuffled in place before the train/test split)
np.random.shuffle(data)

# all dataset; images become N x 1 x 28 x 28 (C x H x W per record)
labels = data[:, 0]
images = data[:, 1:].reshape((len(data), 1, 28, 28))

# first 3/4 of the shuffled rows for training, the rest for testing;
# floor division keeps this an int on Python 3 as well
trainset = len(labels) * 3 // 4
labels_train, labels_test = labels[:trainset], labels[trainset:]
images_train, images_test = images[:trainset], images[trainset:]

# create the lmdb databases
_write_lmdb(train_file, images_train, labels_train, 'train')
_write_lmdb(test_file, images_test, labels_test, 'test')

logger.info('Done')

In [None]:
#predict
import os
import logging
import matplotlib.pyplot as plt  
import numpy as np
import caffe  
import pandas as pd

DATA_ROOT = '/home/shihuai02/caffe/examples/mnist/'
TEST = os.path.join(DATA_ROOT, 'test.csv')
OUTPUT = os.path.join(DATA_ROOT, 'result.csv')
CAFFE_MODEL = os.path.join(DATA_ROOT, 'lenet_iter_10000.caffemodel')
# Despite the name, this is the deploy network definition, not a solver.
CAFFE_SOLVER = os.path.join(DATA_ROOT, 'lenet_deploy.prototxt')

# Number of images handed to Classifier.predict per call.
PREDICT_BATCH = 1000

# logger -- the root logger is shared by every cell of this notebook, so the
# handler is tagged and attached only once; re-running the cell would
# otherwise print every message one extra time.
_LOG_HANDLER_NAME = 'mnist-notebook'
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
if not any(h.get_name() == _LOG_HANDLER_NAME for h in logger.handlers):
    sh = logging.StreamHandler()
    sh.set_name(_LOG_HANDLER_NAME)
    sh.setLevel(logging.DEBUG)
    sh.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(sh)

# load test dataset: one row of 784 pixel values per image
logger.info('Load test dataset from %s', TEST)
df = pd.read_csv(TEST)
data = df.values

testNum = len(data)
# Classifier.predict takes a list of H x W x K images; scale pixels to [0, 1].
# NOTE(review): Caffe's stock LeNet training net scales by 0.00390625 (1/256);
# confirm this 1/255 matches the scale used when training the model.
data = data.reshape((testNum, 28, 28, 1))
data = data / 255.
logger.info('finish load and reshape data')

# Select the compute mode before constructing the network.
caffe.set_mode_cpu()
net = caffe.Classifier(CAFFE_SOLVER, CAFFE_MODEL)
logger.info('loaded model')

# predict
logger.info('start predict')
labels = []
for start in range(0, testNum, PREDICT_BATCH):
    batch = list(data[start:start + PREDICT_BATCH])
    # oversample=False: the default 10-crop oversampling averages in
    # horizontally mirrored crops. That is wrong for digits and 10x slower.
    scores = net.predict(batch, oversample=False)
    # one row of class scores per image -> index of the best class
    labels.extend(scores.reshape(len(scores), -1).argmax(axis=1))
logger.info('Prediction Done')

# write to file; mode 'w' truncates any existing result file
logger.info('Save result to %s', OUTPUT)
with open(OUTPUT, 'w') as fd:
    fd.write('ImageId,Label\n')
    fd.writelines('%d,%d\n' % (idx, label)
                  for idx, label in enumerate(labels, 1))

logger.info('Finished write result file')