# 划分训练集和验证集



## 导入工具包

In [2]:
import os
import math
import shutil
import collections
import pandas as pd

In [9]:

class MyImageClassifier:
    """Split a labelled image folder into train / valid / test directory trees.

    Files are copied (never moved) from ``data_dir`` into
    ``target_dir/train_valid_test/<split>/<label>/`` where ``<split>`` is one
    of ``train_valid``, ``train``, ``valid`` or ``test``.

    Args:
        data_dir: Directory containing the raw train/test folders and the
            label CSV.
        target_dir: Directory under which ``train_valid_test`` is created.
        valid_ratio: Fraction of the smallest class used to size the
            per-class validation set.
        train_folder: Name of the training image folder inside ``data_dir``.
        test_folder: Name of the test image folder inside ``data_dir``.
    """

    def __init__(self, data_dir, target_dir, valid_ratio,train_folder,test_folder):
        self.data_dir = data_dir
        self.target_dir = target_dir
        self.valid_ratio = valid_ratio
        self.train_folder = train_folder
        self.test_folder = test_folder

    def read_csv_labels(self, fname):
        """Read a ``name,label`` CSV file and return a ``{name: label}`` dict.

        The first line is treated as the column header and skipped. Blank
        lines (e.g. a trailing empty line) are ignored.

        Raises:
            ValueError: If a non-blank row does not contain exactly two
                comma-separated fields.
        """
        with open(fname, 'r', encoding='utf-8') as f:
            # Skip the header row (column names); tolerate an empty file.
            next(f, None)
            rows = [line.rstrip().split(',') for line in f if line.strip()]
        return {name: label for name, label in rows}

    def copyfile(self, filename, target_dir):
        """Copy ``filename`` into ``target_dir``, creating the directory if needed."""
        os.makedirs(target_dir, exist_ok=True)
        shutil.copy(filename, target_dir)

    @staticmethod
    def _label_for(labels, file_name):
        """Return the label for ``file_name``, raising ``KeyError`` if unknown."""
        # Original lookup: everything before the first dot.
        short_key = file_name.split('.')[0]
        if short_key in labels:
            return labels[short_key]
        # Fallback for names that themselves contain dots ("img.001.jpg").
        stem = os.path.splitext(file_name)[0]
        if stem in labels:
            return labels[stem]
        raise KeyError(f"no label found for training file {file_name!r}")

    def reorg_train_valid(self, labels):
        """Copy training images into ``train_valid``, ``train`` and ``valid``.

        Every image is copied to ``train_valid/<label>``. For each label, the
        first ``n_valid_per_label`` images (in sorted file-name order, so the
        split is reproducible) are also copied to ``valid/<label>``; the rest
        go to ``train/<label>``.

        Args:
            labels: Mapping from file name (without extension) to label.

        Returns:
            The number of validation images taken per label.

        Raises:
            KeyError: If a training file has no entry in ``labels``.
        """
        # Number of samples in the least-populated class (0 when no labels).
        class_sizes = collections.Counter(labels.values())
        n = min(class_sizes.values(), default=0)
        # Validation samples per class; always at least one.
        n_valid_per_label = max(1, math.floor(n * self.valid_ratio))

        src_dir = os.path.join(self.data_dir, self.train_folder)
        out_root = os.path.join(self.target_dir, 'train_valid_test')
        label_count = collections.Counter()
        for train_file in sorted(os.listdir(src_dir)):
            fname = os.path.join(src_dir, train_file)
            # Skip hidden entries and sub-directories (e.g. .ipynb_checkpoints),
            # which are not images and cannot be copied with shutil.copy.
            if train_file.startswith('.') or not os.path.isfile(fname):
                continue
            label = self._label_for(labels, train_file)
            self.copyfile(fname, os.path.join(out_root, 'train_valid', label))
            if label_count[label] < n_valid_per_label:
                self.copyfile(fname, os.path.join(out_root, 'valid', label))
                label_count[label] += 1
            else:
                self.copyfile(fname, os.path.join(out_root, 'train', label))
        return n_valid_per_label

    def reorg_test(self):
        """Copy test images into ``test/unknown`` so they can be read at prediction time."""
        src_dir = os.path.join(self.data_dir, self.test_folder)
        dst_dir = os.path.join(self.target_dir, 'train_valid_test', 'test',
                               'unknown')
        for test_file in os.listdir(src_dir):
            src_path = os.path.join(src_dir, test_file)
            # Skip hidden entries and sub-directories; only real files are copied.
            if test_file.startswith('.') or not os.path.isfile(src_path):
                continue
            self.copyfile(src_path, dst_dir)

    def reorg_san_data(self, labels_csv):
        """Run the full reorganisation: read labels, split train/valid, copy test.

        Args:
            labels_csv: File name of the label CSV inside ``data_dir``.
        """
        labels = self.read_csv_labels(os.path.join(self.data_dir, labels_csv))
        self.reorg_train_valid(labels)
        self.reorg_test()
        print('# 训练样本 :', len(labels))
        print('# 类别 :', len(set(labels.values())))

    # ---- The methods above reorganise the data; those below inspect it. ----

    def classes(self):
        """Return ``{index: class_name}`` for the class folders under ``valid``.

        Class names are sorted and numbered contiguously from 0; hidden
        entries and non-directories are filtered out *before* numbering so
        they cannot leave gaps in the indices. The mapping is also printed.
        """
        valid_root = os.path.join(self.target_dir, 'train_valid_test', 'valid')
        class_names = [
            name
            for name in sorted(os.listdir(valid_root))
            if not name.startswith('.')
            and os.path.isdir(os.path.join(valid_root, name))
        ]
        class_to_idx = dict(enumerate(class_names))
        print(class_to_idx)
        print("============================")
        return class_to_idx

    def count_samples(self):
        """Count the files per class in the ``train`` and ``valid`` splits.

        Returns:
            A ``pandas.DataFrame`` with columns ``class``, ``train`` and
            ``valid``; one row per class folder found in either split (a
            class missing from one split gets 0 there). Rows are in sorted
            class-name order, with classes found only in ``valid`` listed
            after those found in ``train``.
        """
        split_names = ['train', 'valid']
        root = os.path.join(self.target_dir, 'train_valid_test')
        # class name -> {split name: file count}; dict keeps first-seen order.
        per_class = {}
        for split in split_names:
            split_dir = os.path.join(root, split)
            for class_name in sorted(os.listdir(split_dir)):
                class_sub_dir = os.path.join(split_dir, class_name)
                if class_name.startswith('.') or not os.path.isdir(class_sub_dir):
                    continue
                row = per_class.setdefault(class_name,
                                           dict.fromkeys(split_names, 0))
                row[split] += len(os.listdir(class_sub_dir))

        data_counts = {'class': list(per_class)}
        for split in split_names:
            data_counts[split] = [row[split] for row in per_class.values()]
        return pd.DataFrame(data_counts)


# Driver cell: reorganise the dataset, then merge the train/valid counts with
# the previously recorded per-class test counts and write the totals back.
# Raw strings keep the Windows backslashes literal; "\S", "\s" and "\l" are
# invalid escape sequences in ordinary string literals.
hua_train_val = MyImageClassifier(
    data_dir=r"D:\SanYeQing_Project\sanyeqing_hun_weizhi_finally",
    target_dir=r"D:\linshi_mulu",
    valid_ratio=0.2,
    train_folder='train_hun_finally',
    test_folder='test_hun_finally',
)
hua_train_val.reorg_san_data('labels_hun.csv')
hua_train_val.classes()
df = hua_train_val.count_samples()
print(df)
# The same CSV is overwritten below with extra train/valid/total columns, so
# read back only 'class' and 'test'. Otherwise a second run would merge into
# train_x/train_y columns and the column selection below would raise KeyError.
test_df = pd.read_csv("数据量统计.csv", usecols=['class', 'test'])
print(test_df)
merged_df = pd.merge(df, test_df, on=['class'], how='outer')
# Per-class total across the three splits.
merged_df['total'] = merged_df[['train', 'valid', 'test']].sum(axis=1)
print(merged_df)
merged_df.to_csv('数据量统计.csv', index=False)




# 训练样本 : 9214
# 类别 : 6
{0: '云南省', 1: '广西省', 2: '未知', 3: '浙江省', 4: '贵州省', 5: '陕西省'}
  class  train  valid
0   云南省    922    296
1   广西省    981    296
2    未知    957    296
3   浙江省    919    296
4   贵州省    926    296
5   陕西省    891    296
  class  test
0   云南省   296
1    未知   310
2   浙江省   293
3   贵州省   308
4   陕西省   296
5   广西省   339
  class  train  valid  test  total
0   云南省    922    296   296   1514
1   广西省    981    296   339   1616
2    未知    957    296   310   1563
3   浙江省    919    296   293   1508
4   贵州省    926    296   308   1530
5   陕西省    891    296   296   1483
