## 分割COCO类数据集

In [15]:
import json
import os
import random
import shutil
from tqdm import tqdm

In [39]:
def _copy_images(image_entries, img_dir, dst_dir, desc):
    """Copy every image listed in ``image_entries`` from ``img_dir`` into ``dst_dir``."""
    for image in tqdm(image_entries, desc=desc, total=len(image_entries)):
        # NOTE(review): shutil.copy into a directory keeps only the basename, so a
        # file_name containing sub-directories is flattened in the copy while the
        # JSON still records the original relative path — confirm file_names are
        # bare file names for this dataset.
        shutil.copy(os.path.join(img_dir, image['file_name']), dst_dir)


def _write_split(coco_data, images, annotations, description, json_path):
    """Write ``coco_data`` to ``json_path`` with images/annotations/description replaced."""
    coco_data['images'] = images
    coco_data['annotations'] = annotations
    # 'info' may be absent from a COCO-style file; create it instead of crashing.
    coco_data.setdefault('info', {})['description'] = description
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(coco_data, f)


def split_dataset(img_dir, ann_file, output_dir, val_ratio=0.2, seed=None):
    """Split a COCO-format dataset into a training set and a validation set.

    Images are shuffled, the first ``int(len(images) * val_ratio)`` become the
    validation set and the rest the training set. Each image file is copied into
    the matching output folder, and each annotation follows its image via
    ``image_id``. Annotations whose ``image_id`` matches no image are dropped.

    Layout created under ``output_dir``::

        output_dir/
          train_images/        copies of the training images
          val_images/          copies of the validation images
          annotations/
            data_train.json
            data_val.json

    Both JSON files keep all other top-level fields of ``ann_file`` (e.g.
    ``categories``) and only replace ``images``, ``annotations`` and
    ``info['description']``.

    Args:
        img_dir: Directory holding the source images referenced by ``file_name``.
        ann_file: Path to the COCO annotation JSON file.
        output_dir: Directory under which the split is written.
        val_ratio: Fraction of images, between 0 and 1, put into the validation
            set. The resulting count is rounded down.
        seed: Optional seed making the shuffle reproducible. ``None`` (default)
            uses the global ``random`` state, as before.

    Raises:
        ValueError: If ``val_ratio`` is outside [0, 1].
    """
    if not 0 <= val_ratio <= 1:
        raise ValueError(f'val_ratio must be within [0, 1], got {val_ratio!r}')

    # Create the output directories.
    train_dir = os.path.join(output_dir, 'train_images')
    val_dir = os.path.join(output_dir, 'val_images')
    ann_dir = os.path.join(output_dir, 'annotations')
    for directory in (train_dir, val_dir, ann_dir):
        os.makedirs(directory, exist_ok=True)

    # Load the COCO annotation file.
    with open(ann_file, 'r', encoding='utf-8') as f:
        coco_data = json.load(f)
    images = coco_data['images']
    annotations = coco_data['annotations']

    # Shuffle and split the image entries, then copy the image files.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(images)
    val_size = int(len(images) * val_ratio)
    train_images = images[val_size:]
    val_images = images[:val_size]
    _copy_images(train_images, img_dir, train_dir, 'copy train_images')
    _copy_images(val_images, img_dir, val_dir, 'copy val_images')

    # Route annotations by image id. Building the id sets once gives O(1)
    # membership tests instead of rebuilding and scanning a list per annotation.
    train_ids = {image['id'] for image in train_images}
    val_ids = {image['id'] for image in val_images}
    train_annotations = [
        annotation
        for annotation in tqdm(annotations, desc='filter train_annotations', total=len(annotations))
        if annotation['image_id'] in train_ids
    ]
    val_annotations = [
        annotation
        for annotation in tqdm(annotations, desc='filter val_annotations', total=len(annotations))
        if annotation['image_id'] in val_ids
    ]

    # Write the two annotation files. coco_data is reused as the template for
    # both, so the training file must be written before the fields are replaced.
    train_json_file = os.path.join(ann_dir, 'data_train.json')
    _write_split(coco_data, train_images, train_annotations, 'train dataset', train_json_file)
    val_json_file = os.path.join(ann_dir, 'data_val.json')
    _write_split(coco_data, val_images, val_annotations, 'val dataset', val_json_file)
    print('Dataset split done')

In [40]:
if __name__ == '__main__':
    # Guarded so that importing this file does not copy the whole dataset;
    # __name__ is also '__main__' inside a notebook kernel, so the cell still runs.
    # All paths derive from one dataset root: the source images/annotations live
    # under it, and the train/val split is written back into the same root.
    dataset_root = '/Users/xiaoqiang/Mlearning/dataset/Drink_coco'
    img_dir = os.path.join(dataset_root, 'images')
    ann_file = os.path.join(dataset_root, 'annotations', 'instances.json')
    output_dir = dataset_root
    split_dataset(img_dir=img_dir, ann_file=ann_file, output_dir=output_dir)


copy train_images:   0%|                                | 0/228 [00:00<?, ?it/s]
copy train_images:  23%|█████                 | 52/228 [00:00<00:00, 513.76it/s]
copy train_images:  54%|███████████▎         | 123/228 [00:00<00:00, 626.73it/s]
copy train_images: 100%|█████████████████████| 228/228 [00:00<00:00, 646.67it/s]

copy val_images: 100%|█████████████████████████| 56/56 [00:00<00:00, 727.13it/s]

filter train_annotations: 100%|█████████| 1256/1256 [00:00<00:00, 141904.05it/s]

filter val_annotations: 100%|███████████| 1256/1256 [00:00<00:00, 498019.08it/s]

Dataset split down



