In [1]:
import os
import random

import numpy as np
import pandas as pd
from keras.datasets import mnist

Using TensorFlow backend.


In [13]:
# Number of distinct label values used when grouping samples by class.
NUM_CLASSES = 10

# Root folder of all datasets; every other path below is derived from it so
# that relocating the data only requires changing this single constant.
DATASET_ROOT_DIR = '../dataset/'

# NOTE: despite the *_DIR suffix these two are CSV *file* paths (they are
# passed straight to pd.read_csv); the names are kept for compatibility.
TRAIN_DATA_DIR = DATASET_ROOT_DIR + 'mnist-in-csv/mnist_train.csv'
TEST_DATA_DIR = DATASET_ROOT_DIR + 'mnist-in-csv/mnist_test.csv'

In [3]:
def create_pairs(x, digit_indices):
    '''Build alternating positive and negative sample pairs.

    For every class ``d`` and every position ``i`` below ``n`` (the size of
    the smallest class minus one) two pairs are emitted, in this order:

    * a positive pair ``(x[digit_indices[d][i]], x[digit_indices[d][i + 1]])``
      labelled ``1``;
    * a negative pair ``(x[digit_indices[d][i]], x[digit_indices[dn][i]])``
      labelled ``0``, where ``dn`` is a different class chosen with
      ``random.randrange``.

    The number of classes is taken from ``len(digit_indices)`` instead of a
    module-level constant, so the function does not depend on the execution
    order of other notebook cells and works for any number of classes.

    Args:
        x: Indexable collection of samples (e.g. an array of images).
        digit_indices: One entry per class; entry ``d`` holds the indices
            into ``x`` of the samples that belong to class ``d``.

    Returns:
        Tuple ``(pairs, labels)`` of NumPy arrays. ``pairs[k]`` holds the two
        samples of pair ``k``; ``labels[k]`` is ``1`` for a same-class pair
        and ``0`` for a different-class pair. Both arrays are empty when the
        smallest class has fewer than two samples.

    Raises:
        ValueError: If fewer than two classes are supplied, because no
            negative pair can be formed.
    '''
    num_classes = len(digit_indices)
    if num_classes < 2:
        raise ValueError(
            'create_pairs needs at least 2 classes to build negative pairs, '
            'got %d' % num_classes)

    # Limit every class to the size of the smallest one; the "- 1" keeps
    # index i + 1 valid for the positive pair.
    n = min(len(digit_indices[d]) for d in range(num_classes)) - 1

    pairs = []
    labels = []
    for d in range(num_classes):
        for i in range(n):
            anchor = x[digit_indices[d][i]]
            # Positive pair: two consecutive samples of the same class.
            pairs.append([anchor, x[digit_indices[d][i + 1]]])
            # Negative pair: an offset in [1, num_classes - 1] guarantees
            # that dn is never equal to d.
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            pairs.append([anchor, x[digit_indices[dn][i]]])
            labels += [1, 0]
    return np.array(pairs), np.array(labels)

In [4]:
# Read the train and test CSV files into DataFrames (train first, then test).
train_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_DATA_DIR, TEST_DATA_DIR)
)

In [5]:
# Number of rows (one per sample) in each split.
m_train, m_test = len(train_df), len(test_df)

In [6]:
def _images_and_labels(frame, n_rows):
    '''Split a DataFrame into an image array and a label array.

    Every column except 'label' is treated as pixel data and reshaped to
    (n_rows, 28, 28, 1) as a fresh float64 array; the 'label' column is
    returned unchanged as a NumPy array.
    '''
    is_pixel_column = frame.columns != 'label'
    pixels = frame.loc[:, is_pixel_column].values
    # np.array(...) makes an independent float64 copy of the reshaped data.
    images = np.array(pixels.reshape(n_rows, 28, 28, 1), dtype=np.float64)
    return images, frame['label'].values


X_train, y_train = _images_and_labels(train_df, m_train)
X_test, y_test = _images_and_labels(test_df, m_test)

In [7]:
# Shape of a single sample: everything after the leading sample axis.
input_shape = X_train.shape[1:]

# Scale both splits by 1/255 in place.
# NOTE: this mutates X_train / X_test, so re-running this cell without
# re-running the cell that builds them divides the data a second time.
np.divide(X_train, 255, out=X_train)
np.divide(X_test, 255, out=X_test)

In [8]:
# create_pairs picks the class of every negative pair with random.randrange,
# so seed Python's RNG right before generation: re-running this cell (or the
# whole notebook) then produces exactly the same pairs.
PAIR_SEED = 42
random.seed(PAIR_SEED)

# For each class, the positions in y_train holding that label.
digit_indices = [np.where(y_train == i)[0] for i in range(NUM_CLASSES)]
tr_pairs, tr_y = create_pairs(X_train, digit_indices)

# Same grouping for the test split (digit_indices is intentionally reused and
# ends up holding the test indices).
digit_indices = [np.where(y_test == i)[0] for i in range(NUM_CLASSES)]
ts_pairs, ts_y = create_pairs(X_test, digit_indices)

In [10]:
# Report the shape of every generated array, train split first.
for caption, generated in (
    ('train X shape:', tr_pairs),
    ('train y shape:', tr_y),
    ('test X shape:', ts_pairs),
    ('test y shape:', ts_y),
):
    print(caption, generated.shape)

train X shape: (108400, 2, 28, 28, 1)
train y shape: (108400,)
test X shape: (17820, 2, 28, 28, 1)
test y shape: (17820,)


In [15]:
# Folder that receives the generated pair and label arrays.
PAIR_OUTPUT_DIR = DATASET_ROOT_DIR + 'mnist-siamese-network-pair/'

# np.save does not create missing parent directories, so make sure the target
# folder exists; exist_ok=True keeps the cell safe to re-run.
os.makedirs(PAIR_OUTPUT_DIR, exist_ok=True)

np.save(PAIR_OUTPUT_DIR + 'tr_pairs.npy', tr_pairs)
np.save(PAIR_OUTPUT_DIR + 'tr_y.npy', tr_y)

np.save(PAIR_OUTPUT_DIR + 'ts_pairs.npy', ts_pairs)
np.save(PAIR_OUTPUT_DIR + 'ts_y.npy', ts_y)