## 数据预处理
#### 说明
1. 本预处理脚本负责将数据集分成上装（top）、下装（bottom）两部分，并剔除缺失归一化（Normalization）点的数据
2. 本脚本不进行数据增强（data augmentation）。数据增强在训练时在线完成，详见datagen.py
3. 代码分为训练集和测试集两部分。每部分可以单独运行。

In [None]:
import csv
import numpy as np
import cv2

from matplotlib import pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [None]:
# The 24 keypoint names. Point i is stored in column i + 2 of the annotation CSV.
points_list = [
    'neckline_left', 'neckline_right', 'center_front',
    'shoulder_left', 'shoulder_right',
    'armpit_left', 'armpit_right',
    'waistline_left', 'waistline_right',
    'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out',
    'top_hem_left', 'top_hem_right',
    'waistband_left', 'waistband_right',
    'hemline_left', 'hemline_right',
    'crotch',
    'bottom_left_in', 'bottom_left_out', 'bottom_right_in', 'bottom_right_out',
]

# Garment categories grouped into the general categories used to split the dataset.
categories = {
    'all': {'blouse', 'dress', 'outwear', 'skirt', 'trousers'},
    'top': {'blouse', 'dress', 'outwear'},
    'bottom': {'skirt', 'trousers'},
}

# Keypoints that may be annotated for each garment category.
_category_points = {
    'blouse': ['neckline_left', 'neckline_right', 'center_front',
               'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
               'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out',
               'top_hem_left', 'top_hem_right'],
    'dress': ['neckline_left', 'neckline_right', 'center_front',
              'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
              'waistline_left', 'waistline_right',
              'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out',
              'hemline_left', 'hemline_right'],
    'outwear': ['neckline_left', 'neckline_right',
                'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
                'waistline_left', 'waistline_right',
                'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out',
                'top_hem_left', 'top_hem_right'],
    'skirt': ['waistband_left', 'waistband_right', 'hemline_left', 'hemline_right'],
    'trousers': ['waistband_left', 'waistband_right', 'crotch',
                 'bottom_left_in', 'bottom_left_out', 'bottom_right_in', 'bottom_right_out'],
}

# Per-category 0/1 mask over points_list: 1 where the keypoint belongs to the category.
weights = {
    category: np.array([int(name in names) for name in points_list])
    for category, names in _category_points.items()
}

# Per general category, which keypoint columns are kept. 'top'/'bottom' are the
# union (boolean) of their member categories' masks; 'all' keeps every column.
valid_points = {
    'all': np.ones(len(points_list), dtype=np.int32),
    'top': np.any([weights[c] for c in categories['top']], axis=0),
    'bottom': np.any([weights[c] for c in categories['bottom']], axis=0),
}

### 训练集
1. 遍历训练集数据，剔除并统计非法数据（缺失归一化点，或标注了该类别不应有的可见关键点）
2. 统计训练集关键点数量
3. 通过general_category选择上装(top)和下装(bottom)

In [None]:
def _lacks_normalization_points(row, category):
    """Return True if either normalization keypoint of *row* has visibility '-1'.

    Top garments are normalized by columns 7/8 (armpit_left/armpit_right);
    every other category by columns 17/18 (waistband_left/waistband_right).
    """
    left, right = (7, 8) if category in categories['top'] else (17, 18)
    return row[left].split('_')[2] == '-1' or row[right].split('_')[2] == '-1'


def generate_train_set(source, target, general_category='all', counts=None):
    """Filter the training annotations for one general category and write them out.

    *source* is read as a CSV whose first row is a header. In each data row,
    column 0 is an image path (only its last '/'-separated component is kept),
    column 1 is the garment category and columns 2..25 are the 24 keypoints,
    each an ``x_y_v`` string. Rows whose category is not in
    ``categories[general_category]`` are ignored. The rest are written,
    space-delimited, to *target* as
    ``[file_name, category, x, y, v, ...]``. Only the keypoint columns enabled
    in ``valid_points[general_category]`` are included.

    A row is rejected (counted but not written) when:
      * a keypoint with ``v == '1'`` is one that ``weights[category]`` marks as 0
        for its category, or
      * one of its two normalization keypoints has ``v == '-1'``.

    Args:
        source: path of the input annotation CSV.
        target: path of the output file (overwritten).
        general_category: 'all', 'top' or 'bottom' (a key of ``categories`` and
            ``valid_points``).
        counts: optional per-keypoint array (indexed like ``points_list``). For
            every written row, each saved keypoint with ``v == '1'`` adds 1 to
            its slot. Defaults to the notebook-level ``point_count`` array.

    Returns:
        dict mapping each category name to ``[written_rows, rejected_rows]``.
    """
    if counts is None:
        counts = point_count  # notebook-level accumulator, kept for backward compatibility

    image_count = {'blouse': [0, 0], 'dress': [0, 0], 'outwear': [0, 0],
                   'skirt': [0, 0], 'trousers': [0, 0]}
    points_to_save = valid_points[general_category]
    category_to_save = categories[general_category]

    with open(source, newline='') as infile, open(target, 'w', newline='') as outfile:
        reader = csv.reader(infile)
        writer = csv.writer(outfile, delimiter=' ')
        next(reader, None)  # skip the header row

        for row in reader:
            if not row:
                continue  # tolerate blank lines in the CSV
            category = row[1]
            if category not in category_to_save:
                continue

            this_row = [row[0].split('/')[-1], category]
            visible = []  # keypoint indices with v == '1', tallied only if the row is kept
            bad_data = False
            for idx, keep in enumerate(points_to_save):
                if not keep:
                    continue
                x, y, v = row[idx + 2].split('_')
                if v == '1':
                    # A visible keypoint the category should not have marks the row invalid.
                    if weights[category][idx] == 0:
                        bad_data = True
                        break
                    visible.append(idx)
                this_row += [x, y, v]

            if bad_data or _lacks_normalization_points(row, category):
                image_count[category][1] += 1
                continue

            image_count[category][0] += 1
            # Count keypoints only for rows that are actually written, so rejected
            # rows do not inflate the statistics.
            for idx in visible:
                counts[idx] += 1
            writer.writerow(this_row)

    return image_count

In [None]:
# Per-keypoint tally of visible points, accumulated by generate_train_set()
# over both runs below (top and bottom share this single array).
point_count = np.zeros(len(points_list))

for general_category in ('top', 'bottom'):
    print('dataset:', general_category)
    image_count = generate_train_set(
        "data/train/Annotations/train.csv",
        'data/train/dataset_{}.csv'.format(general_category),
        general_category,
    )
    print(image_count)

In [None]:
# Print the keypoint tally for each general category, listing only the points
# that category keeps. NOTE: point_count is shared by the top and bottom runs,
# so points valid for both (hemline_left/right) show the combined total.
for general_category in ('top', 'bottom'):
    print(general_category)
    mask = valid_points[general_category]
    for name, count, keep in zip(points_list, point_count, mask):
        if keep:
            print(name, ':', count)
    print('')

### 测试集
1. 遍历测试集数据，并统计各部分的样本数量
2. 通过general_category选择上装(top)和下装(bottom)

In [None]:
def generate_test_set(source, target, general_category='all', counts=None):
    """Copy the test rows belonging to one general category into a separate CSV.

    The first row of *source* (the header) is copied unchanged. After that,
    every row whose category (column 1) is in ``categories[general_category]``
    is written verbatim, comma-delimited, to *target*.

    Args:
        source: path of the input test CSV.
        target: path of the output CSV (overwritten).
        general_category: 'all', 'top' or 'bottom' (a key of ``categories``).
        counts: optional dict. The number of rows written is added under the key
            *general_category*. Defaults to the notebook-level ``num_test_set``.

    Returns:
        Number of data rows written (header excluded).
    """
    if counts is None:
        counts = num_test_set  # notebook-level accumulator, kept for backward compatibility

    category_to_save = categories[general_category]
    written = 0

    with open(source, newline='') as infile, open(target, 'w', newline='') as outfile:
        reader = csv.reader(infile)
        writer = csv.writer(outfile)

        header = next(reader, None)
        if header is not None:
            writer.writerow(header)

        for row in reader:
            if not row or row[1] not in category_to_save:
                continue  # blank line or a category outside this split
            writer.writerow(row)
            written += 1

    # .get(): the notebook's dict only has 'top'/'bottom', so the default
    # general_category='all' used to raise KeyError here.
    counts[general_category] = counts.get(general_category, 0) + written
    return written

In [None]:
# Rows written per general category; incremented by generate_test_set().
num_test_set = dict.fromkeys(('top', 'bottom'), 0)

for general_category in ('top', 'bottom'):
    generate_test_set(
        "data/test_b/test.csv",
        'data/test_b/test_{}.csv'.format(general_category),
        general_category,
    )

print(num_test_set)