# 前言

对 Chinese MNIST 数据集（`chinese-mnist`）进行预处理：载入图片、灰度化与对比度增强、打乱顺序并保存为 `.npz` 文件，供后续模型训练使用。

In [None]:
# 载入包

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image, ImageEnhance, ImageFilter
import scipy.ndimage


## 数据集预处理

包括：数据载入、图像增强（灰度化、平滑、对比度增强），以及打乱顺序与保存。

In [None]:
# Basic configuration: folder holding the image files (trailing "/" is required,
# because file names are appended to it by plain string concatenation).
IMG_PATH: str = "../input/chinese-mnist/data/data/"

In [None]:
# Load the CSV index of the dataset (one row per image)
data_df = pd.read_csv('..//input//chinese-mnist//chinese_mnist.csv')
# Later cells build file paths from suite_id/sample_id/code and use `value` as the
# label, so fail early with a clear message if the CSV schema is different.
missing_cols = {"suite_id", "sample_id", "code", "value"}.difference(data_df.columns)
assert not missing_cols, f"chinese_mnist.csv is missing columns: {sorted(missing_cols)}"
data_df.shape

In [None]:
data_df.head(n=5)  # preview the first rows of the CSV index

图片存储在 `data/data` 文件夹内，文件命名规则为 `input_{suite_id}_{sample_id}_{code}.jpg`。

例如 `input_1_1_10.jpg` 表示 suite_id=1、sample_id=1、code=10 的图片。下面查看 `input_100_10_1.jpg`：

In [None]:
# Open one sample image and show its size as (width, height)
image = Image.open(f"{IMG_PATH}input_100_10_1.jpg")
image.size

In [None]:
plt.imshow(image, cmap="gray");  # trailing ';' hides the AxesImage repr so only the picture is shown

In [None]:
# Column dtypes before suite_id / sample_id / code are cast to str to build file paths
data_df.dtypes

In [None]:
# Add each image's full file path to the table.
# File names follow input_{suite_id}_{sample_id}_{code}.jpg, so the three id
# columns are cast to str first to allow element-wise string concatenation.
data_df = data_df.astype({"suite_id": str, "sample_id": str, "code": str})
data_df["path"] = (
    IMG_PATH + "input_"
    + data_df["suite_id"] + "_"
    + data_df["sample_id"] + "_"
    + data_df["code"] + ".jpg"
)
data_df["path"].head()

In [None]:
# Image preprocessing
def img_map(image, factor=5):
    """Preprocess one character image: grayscale, smooth, then boost contrast.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image, in any mode accepted by ``Image.convert``.
    factor : float, default 5
        Contrast factor passed to ``ImageEnhance.Contrast.enhance``.
        1.0 leaves the contrast unchanged; larger values increase it.

    Returns
    -------
    PIL.Image.Image
        The processed 8-bit grayscale ('L') image, same size as the input.
    """
    gray = image.convert('L')
    # Light smoothing before the contrast boost, so that pixel noise is not amplified as much
    smoothed = gray.filter(ImageFilter.SMOOTH)
    return ImageEnhance.Contrast(smoothed).enhance(factor=factor)

# Convert an image to a NumPy array
def to_array(image):
    """Return a new NumPy array built from `image` (e.g. a PIL image or nested list)."""
    converted = np.array(image)
    return converted

In [None]:
# Load every image, closing each file as soon as its pixels are decoded.
# Image.open is lazy and keeps the file handle open until the pixels are read,
# so mapping it over all rows would keep one open file per image at the same
# time and can exceed the OS open-file limit.
def _load_image(path):
    """Open `path`, decode it fully, and return an in-memory copy detached from the file."""
    with Image.open(path) as img:
        img.load()
        return img.copy()

images = [_load_image(path) for path in data_df["path"]]

image_original = np.array(list(map(to_array, images)))  # original images

# Preprocess (grayscale, smoothing, contrast) and convert to arrays
x_train = np.array([to_array(img_map(img)) for img in images])

# Zero out faint residual pixels left in the background after contrast enhancement
FAINT_PIXEL_THRESHOLD = 120
x_train[x_train < FAINT_PIXEL_THRESHOLD] = 0

print(x_train.shape, image_original.shape)

In [None]:
# Original images: an r x c grid of every 1000th sample, starting at row 1
r, c = 3, 5  # grid rows x columns
idx = list(range(1, len(data_df), 1000))[: r * c]
assert len(idx) == r * c, f"need {r * c} samples for the grid, got {len(idx)}"
# Tile size is read from the data instead of being hard-coded to 64
tile_h, tile_w = image_original.shape[1:3]
# (r*c, h, w) -> (r, c, h, w) -> (r, h, c, w) -> (r*h, c*w): sample k lands in grid row k // c, column k % c
figure = (
    image_original[idx]
    .reshape(r, c, tile_h, tile_w)
    .transpose(0, 2, 1, 3)
    .reshape(r * tile_h, c * tile_w)
    .astype(float)
)
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis("off")
ax.imshow(figure, cmap='Greys_r')
fig.savefig("/kaggle/working/cn-original.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Processed images: same r x c grid of samples as the original-image preview
r, c = 3, 5  # grid rows x columns
idx = list(range(1, len(data_df), 1000))[: r * c]
assert len(idx) == r * c, f"need {r * c} samples for the grid, got {len(idx)}"
# Tile size is read from the data, so the grid still works if preprocessing resizes images
tile_h, tile_w = x_train.shape[1:3]
# (r*c, h, w) -> (r, c, h, w) -> (r, h, c, w) -> (r*h, c*w): sample k lands in grid row k // c, column k % c
figure = (
    x_train[idx]
    .reshape(r, c, tile_h, tile_w)
    .transpose(0, 2, 1, 3)
    .reshape(r * tile_h, c * tile_w)
    .astype(float)
)
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis("off")
ax.imshow(figure, cmap='Greys_r')
fig.savefig("/kaggle/working/cn-handled.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Shuffle the dataset and save it.
# One seeded permutation is applied to images, labels and originals so they stay aligned.
# Results go into new names: re-running this cell must not reshuffle x_train while the
# labels are rebuilt from data_df in the original CSV order (that would misalign them).
SHUFFLE_SEED = 42  # fixed seed so "Restart & Run All" reproduces the same saved file
assert len(data_df) == len(x_train) == len(image_original), "images and CSV rows are out of sync"

rng = np.random.default_rng(SHUFFLE_SEED)
perm = rng.permutation(len(x_train))

x_train_shuffled = x_train[perm]
# .to_numpy() makes the lookup positional; indexing a Series with a list is label-based.
# NOTE(review): `value` holds the numeric value of each character; confirm that this,
# rather than `code`, is the intended training label.
y_train_shuffled = data_df["value"].to_numpy()[perm]
image_original_shuffled = image_original[perm]

# Save the arrays (npz keys unchanged)
np.savez("/kaggle/working/chinese-mnist",
         x_train=x_train_shuffled,
         y_train=y_train_shuffled,
         image_original=image_original_shuffled)