diff --git a/examples/utils/dataset.py b/examples/utils/dataset.py index 1c48eaf..3d9d553 100644 --- a/examples/utils/dataset.py +++ b/examples/utils/dataset.py @@ -70,7 +70,7 @@ def load_mnist_realval(path, one_hot=True, dequantify=False): '/mnist.pkl.gz', path) f = gzip.open(path, 'rb') - train_set, valid_set, test_set = pickle.load(f) + train_set, valid_set, test_set = pickle.load(f, encoding='latin1') f.close() x_train, t_train = train_set[0], train_set[1] x_valid, t_valid = valid_set[0], valid_set[1] @@ -172,16 +172,16 @@ def load_cifar10(path, normalize=True, dequantify=False, one_hot=True): train_x, train_y = [], [] for i in range(1, 6): batch_file = os.path.join(batch_dir, 'data_batch_' + str(i)) - with open(batch_file, 'r') as f: - data = pickle.load(f) + with open(batch_file, 'rb') as f: + data = pickle.load(f, encoding='latin1') train_x.append(data['data']) train_y.append(data['labels']) train_x = np.vstack(train_x) train_y = np.hstack(train_y) test_batch_file = os.path.join(batch_dir, 'test_batch') - with open(test_batch_file, 'r') as f: - data = pickle.load(f) + with open(test_batch_file, 'rb') as f: + data = pickle.load(f, encoding='latin1') test_x = data['data'] test_y = np.asarray(data['labels'])