In [6]:
import os, sys, glob, argparse
from PIL import Image
import cv2
import pandas as pd
import numpy as np
from tqdm import tqdm

from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import classification_report


#2. 数据读取

# 读取数据集
train_path = glob.glob('./mydata/train/*')
test_path = glob.glob('./mydata/test/*')

train_path.sort()
test_path.sort()

train_df = pd.read_csv('mydata/train.csv')
train_df = train_df.sort_values(by='name')
train_label = train_df['label'].values

In [7]:
print(train_label)

[17  5 15 ...  6 11 13]


In [8]:
print(train_df)

                 name  label
0      001fgve41s.jpg     17
1      0042Hux54q.jpg      5
2      005pLHKsI3.jpg     15
3      00CJTMEaBU.jpg      1
4      00GKYy51Fb.jpg      0
...               ...    ...
22907  zzdkD1wHBr.jpg      3
22908  zzgDGWLysC.jpg      8
22909  zzppp9QGX4.jpg      6
22910  zzs9lNn9a0.jpg     11
22911  zzyy2lnneF.jpg     13

[22912 rows x 2 columns]


In [9]:
def image_feat(path):
    img = cv2.imread(path, 0)
    img = img.astype(np.float32)
    feat = [
        (img != 0).sum(),              # 非零像素的数量
        (img == 0).sum(),              # 零像素的数量
        img.mean(),                    # 平均值
        img.std(),                     # 标准差
        len(np.where(img.mean(0))[0]), # 在列方向上平均值不为零的数量
        len(np.where(img.mean(1))[0]), # 在行方向上平均值不为零的数量
        img.mean(0).max(),             # 列方向上的最大平均值
        img.mean(1).max()              # 行方向上的最大平均值
    ]
    return feat

In [10]:
train_feat = []
for path in tqdm(train_path):
    train_feat += [image_feat(path)]

100%|██████████████████████████████████████████████████████████████████████████| 22912/22912 [00:12<00:00, 1867.89it/s]


In [11]:
print(train_feat)

[[3302, 13082, 43.782776, 95.20307, 32, 128, 255.0, 54.03125], [3829, 12555, 47.5235, 98.05218, 64, 128, 249.11719, 90.0], [4501, 11883, 58.927612, 106.32816, 64, 128, 244.86719, 91.25], [3620, 12764, 47.700317, 98.374146, 48, 128, 254.59375, 71.828125], [5048, 11336, 65.56421, 109.83691, 80, 128, 211.54688, 107.03906], [4342, 12042, 55.119385, 103.83822, 80, 128, 210.125, 93.953125], [3976, 12408, 51.75537, 101.29321, 48, 128, 254.78125, 90.875], [3981, 12403, 51.131775, 100.895775, 48, 128, 240.97656, 73.82031], [3775, 12609, 47.960938, 98.67443, 48, 128, 254.77344, 77.671875], [3762, 12622, 47.723938, 97.80295, 48, 128, 248.11719, 86.78125], [3966, 12418, 51.023804, 100.653336, 48, 128, 252.73438, 70.671875], [4298, 12086, 56.110413, 104.06984, 56, 128, 247.89062, 84.69531], [3772, 12612, 46.34436, 97.07916, 80, 128, 213.41406, 92.359375], [4020, 12364, 50.516785, 100.265236, 56, 128, 252.67188, 78.328125], [4324, 12060, 59.446716, 107.13874, 56, 128, 252.6875, 81.64844], [2965, 134

In [12]:
test_feat = []
for path in tqdm(test_path):
    test_feat += [image_feat(path)]

# 训练集交叉验证
train_pred = cross_val_predict(
    KNeighborsClassifier(),
    np.array(train_feat),
    train_label
)
print(classification_report(train_label, train_pred))

100%|████████████████████████████████████████████████████████████████████████████| 6165/6165 [00:03<00:00, 1812.79it/s]


              precision    recall  f1-score   support

           0       0.08      0.23      0.12       894
           1       0.05      0.12      0.07       943
           2       0.05      0.12      0.07       750
           3       0.09      0.18      0.12      1031
           4       0.09      0.16      0.12      1153
           5       0.10      0.15      0.12       906
           6       0.18      0.23      0.20      1027
           7       0.06      0.05      0.06       672
           8       0.11      0.12      0.11       901
           9       0.08      0.07      0.07       989
          10       0.09      0.06      0.07       798
          11       0.12      0.08      0.10       756
          12       0.04      0.04      0.04      1146
          13       0.02      0.01      0.01      1146
          14       0.10      0.05      0.07      1097
          15       0.04      0.02      0.03       834
          16       0.18      0.10      0.13       837
          17       0.13    

In [13]:
print(test_path)

['./mydata/test\\01nWcziU3M.jpg', './mydata/test\\02zJddhgaT.jpg', './mydata/test\\04QrNpriH7.jpg', './mydata/test\\04Ted4Hpzs.jpg', './mydata/test\\05Alxs38hT.jpg', './mydata/test\\06HM4K4jsU.jpg', './mydata/test\\06WFkrCP3v.jpg', './mydata/test\\07FJX2pKUP.jpg', './mydata/test\\08EyPn5MZH.jpg', './mydata/test\\08S7Mt39wf.jpg', './mydata/test\\08UQYVH3mV.jpg', './mydata/test\\09GMpEhuOX.jpg', './mydata/test\\09HaCd60iw.jpg', './mydata/test\\09UfLYjAxq.jpg', './mydata/test\\0AbYicPTwH.jpg', './mydata/test\\0Asv4Akj6h.jpg', './mydata/test\\0BSexCGGws.jpg', './mydata/test\\0BsKmWL1fm.jpg', './mydata/test\\0CIURLM70s.jpg', './mydata/test\\0DRTxn1Gg4.jpg', './mydata/test\\0E3OuGGPgu.jpg', './mydata/test\\0FtpujBtva.jpg', './mydata/test\\0Fwtul4qbw.jpg', './mydata/test\\0GOgZOgpAJ.jpg', './mydata/test\\0GwRJlzlzg.jpg', './mydata/test\\0HOgpIhedx.jpg', './mydata/test\\0HjyjdRfBi.jpg', './mydata/test\\0IlT9ViGQi.jpg', './mydata/test\\0JDEJzsIS2.jpg', './mydata/test\\0JVcNkRXBr.jpg', './mydata