# 增强数据集生成程序

In [1]:
from pathlib import Path
import os
import shutil
import random
from itertools import chain

from PIL import Image
from tqdm.auto import tqdm
import numpy as np

## 预处理

### 定义读写函数

In [2]:
def read_summary_file(path):
    """Read a split summary file and return its non-blank entries.

    Every line is stripped of surrounding whitespace (which also removes the
    trailing newline); lines that are empty after stripping are dropped.

    Args:
        path: Location of the text file to read.

    Returns:
        list[str]: The stripped, non-empty lines in file order.
    """
    with open(path, 'r') as f:
        stripped_lines = (raw.strip() for raw in f)
        return [entry for entry in stripped_lines if entry]

def write_summary_file(path, lines):
    """Write summary entries to a text file, one entry per line.

    Entries are separated by a single newline and no newline is written after
    the last entry.  An empty ``lines`` produces an empty file instead of
    raising ``IndexError``, so a split without any matching image can still be
    written.

    Args:
        path: Location of the text file to (over)write.
        lines: Iterable of strings, one per summary entry.
    """
    with open(path, 'w') as f:
        f.write('\n'.join(lines))

### 一些Path定义

In [46]:
# Root directory that holds every dataset used by this notebook.
base_path = Path('/home/shlll/Dataset/Teeth/')

# First source dataset ('aug1'; later cells refer to it as the VOC dataset):
# split summary files, source images and ground-truth masks.
dataset_base_path = base_path / 'aug1'
summary_path = dataset_base_path / "ImageSets/Segmentation"
train_summary_path = summary_path / 'train.txt'
trainval_summary_path = summary_path / 'trainval.txt'
val_summary_path = summary_path / 'val.txt'
src_image_path = dataset_base_path / 'JPEGImages_700'
gt_image_path = dataset_base_path / 'newSegmentationClass_700'

# Augmented images of the first dataset (renamed in place by a later cell).
aug_src_img_path = src_image_path / 'output'

# Second source dataset ('aug2', the teeth dataset), same layout as above.
teeth_dataset_base_path = base_path / 'aug2'
teeth_summary_path = teeth_dataset_base_path / "ImageSets/Segmentation"
teeth_train_summary_path = teeth_summary_path / 'train.txt'
teeth_trainval_summary_path = teeth_summary_path / 'trainval.txt'
teeth_val_summary_path = teeth_summary_path / 'val.txt'
teeth_src_image_path = teeth_dataset_base_path / 'JPEGImages'
teeth_gt_image_path = teeth_dataset_base_path / 'SegmentationClass'

# Augmented teeth images, "bright" variant: raw augmentation output and the
# processed (renamed + resized) directory that a later cell regenerates.
teeth_aug_src_img_path = teeth_dataset_base_path / 'outputbright'
teeth_aug_src_img_path_pro = teeth_dataset_base_path / 'outputbright_pro'

# Augmented teeth images, "ori" variant: raw output and processed directory.
teeth_aug_src_img_path_ori = teeth_dataset_base_path / 'outputori'
teeth_aug_src_img_path_ori_pro = teeth_dataset_base_path / 'outputori_pro'

## 处理已有数据集

### 处理full-VOC20200422数据集

In [None]:
# Rename the augmented VOC images in place.
# The old file name is split on '_'; the new name is fields 3 and 4 joined by
# '_' with the last four characters dropped (presumably a '.jpg' left inside
# the name by the augmentation tool -- confirm), followed by '_<n>.jpg'.
# n counts consecutive files, in sorted order, that share field 4.
aug_src_image_list = [img.name for img in aug_src_img_path.glob('*.jpg')]

last_name = None
count = 1
for path in sorted(aug_src_image_list):
    fields = path.split('_')
    part = fields[4]

    # Restart the numbering whenever field 4 changes.
    count = count + 1 if part == last_name else 1
    last_name = part

    dst_name = '_'.join(fields[3:5])[:-4] + '_' + str(count) + '.jpg'
    ori_path = aug_src_img_path / path
    dst_path = aug_src_img_path / dst_name
    ori_path.rename(dst_path)

### 处理teeth数据集

In [55]:
# Rebuild the processed "bright" augmentation directory from scratch: every
# augmented image gets a short, numbered name and is resized to 700x700.
if teeth_aug_src_img_path_pro.exists():
    shutil.rmtree(teeth_aug_src_img_path_pro)
teeth_aug_src_img_path_pro.mkdir()

teeth_aug_src_image_list = [p.name for p in teeth_aug_src_img_path.glob('*.jpg')]

last_name = None
count = 1
for path in sorted(teeth_aug_src_image_list):
    # Only files produced from the original images are used.
    if not path.startswith('JPEGImages_original'):
        continue

    # Fields 2 and 3 of the '_'-separated name, minus the last four characters
    # of the joined string, form the base name of the source image.
    dst_name = '_'.join(path.split('_')[2:4])[:-4]
    # Number consecutive files (in sorted order) that share a base name.
    if dst_name == last_name:
        count += 1
    else:
        last_name = dst_name
        count = 1
    dst_name = dst_name + '_' + str(count) + '.jpg'

    ori_path = teeth_aug_src_img_path / path
    dst_path = teeth_aug_src_img_path_pro / dst_name

    # Write the resized copy straight to the destination.  Hard-linking first
    # and resizing the link in place would overwrite the raw augmentation
    # output as well, because a hard link shares its data with the source.
    with Image.open(ori_path) as img:
        img.resize((700, 700)).save(dst_path)

In [58]:
# Rebuild the processed "ori" augmentation directory from scratch: every
# augmented image gets a short, numbered name and is resized to 700x700.
if teeth_aug_src_img_path_ori_pro.exists():
    shutil.rmtree(teeth_aug_src_img_path_ori_pro)
teeth_aug_src_img_path_ori_pro.mkdir()

teeth_aug_src_image_list = [p.name for p in teeth_aug_src_img_path_ori.glob('*.jpg')]

last_name = None
count = 1
for path in sorted(teeth_aug_src_image_list):
    # Only files produced from the original images are used.
    if not path.startswith('JPEGImages_original'):
        continue

    # Fields 2 and 3 of the '_'-separated name, minus the last four characters
    # of the joined string, form the base name of the source image.
    dst_name = '_'.join(path.split('_')[2:4])[:-4]
    # Number consecutive files (in sorted order) that share a base name.
    if dst_name == last_name:
        count += 1
    else:
        last_name = dst_name
        count = 1
    dst_name = dst_name + '_' + str(count) + '.jpg'

    ori_path = teeth_aug_src_img_path_ori / path
    dst_path = teeth_aug_src_img_path_ori_pro / dst_name

    # Write the resized copy straight to the destination.  Hard-linking first
    # and resizing the link in place would overwrite the raw augmentation
    # output as well, because a hard link shares its data with the source.
    with Image.open(ori_path) as img:
        img.resize((700, 700)).save(dst_path)

## 合并成为新的数据集

### 生成bright数据集

In [41]:
# ==========================================
# Path configuration of the merged "bright" dataset
dst_dataset_base_path = base_path / 'aug_bright'
dst_summary_path = dst_dataset_base_path / "ImageSets/Segmentation"
dst_train_summary_path = dst_summary_path / 'train.txt'
dst_trainval_summary_path = dst_summary_path / 'trainval.txt'
dst_val_summary_path = dst_summary_path / 'val.txt'
dst_src_image_path = dst_dataset_base_path / 'JPEGImages'
dst_gt_image_path = dst_dataset_base_path / 'SegmentationClass'

# Remove a previously generated copy of the merged dataset
if dst_dataset_base_path.exists():
    shutil.rmtree(dst_dataset_base_path)

# Create the directory layout of the merged dataset
dst_dataset_base_path.mkdir(exist_ok=True)
dst_summary_path.mkdir(parents=True, exist_ok=True)
dst_src_image_path.mkdir(exist_ok=True)
dst_gt_image_path.mkdir(exist_ok=True)


def _expand_split(names, image_dir):
    """Expand original image names into all matching image stems.

    A stem in ``image_dir`` belongs to a name when it equals the name (the
    original image) or when it is the name plus exactly one '_'-separated
    suffix (an augmented copy; this is the rule the merge loops below use to
    find an augmented image's ground truth).  A bare prefix match is not
    enough: name 'img1' must not pick up 'img10' or 'img10_3', which would
    leak images between the train and val splits.

    Args:
        names: Image names read from the source summary files.
        image_dir: Directory holding the merged '*.jpg' images.

    Returns:
        list[str]: Matching stems, grouped in the order of ``names`` and
        sorted within each group.
    """
    stems_by_name = {}
    for stem in sorted(p.stem for p in image_dir.glob('*.jpg')):
        stems_by_name.setdefault(stem, []).append(stem)
        base = stem.rpartition('_')[0]
        if base:
            stems_by_name.setdefault(base, []).append(stem)

    expanded = []
    for name in names:
        expanded.extend(stems_by_name.get(name, []))
    return expanded


# ==========================================
# Merge the VOC dataset into the new dataset
src_image_list = src_image_path.glob('*.jpg')
gt_image_list = gt_image_path.glob('*.png')
aug_src_image_list = aug_src_img_path.glob('*.jpg')

# Hard-link the original images (the link shares its data with the source)
for path in src_image_list:
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)

# Hard-link the original ground-truth masks
for path in gt_image_list:
    dst_path = dst_gt_image_path / path.name
    os.link(path, dst_path)

# Hard-link the augmented images and give each one a copy of the ground truth
# of its source image (the stem without its last '_'-separated field).
# Augmented images whose source has no ground truth are skipped.
for path in aug_src_image_list:
    gt_path = dst_gt_image_path / ('_'.join(path.stem.split('_')[:-1]) + '.png')
    if not gt_path.exists():
        continue
    gt_dst_path = dst_gt_image_path / (path.stem + '.png')
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)
    shutil.copy(gt_path, gt_dst_path)

# ==========================================
# Merge the teeth dataset into the new dataset
src_image_list = teeth_src_image_path.glob('*.jpg')
gt_image_list = teeth_gt_image_path.glob('*.png')
aug_src_image_list = teeth_aug_src_img_path_pro.glob('*.jpg')

# Hard-link the original images
for path in src_image_list:
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)

# Hard-link the original ground-truth masks
for path in gt_image_list:
    dst_path = dst_gt_image_path / path.name
    os.link(path, dst_path)

# Hard-link the augmented images and copy the matching ground truth
for path in aug_src_image_list:
    gt_path = dst_gt_image_path / ('_'.join(path.stem.split('_')[:-1]) + '.png')
    if not gt_path.exists():
        continue
    gt_dst_path = dst_gt_image_path / (path.stem + '.png')
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)
    shutil.copy(gt_path, gt_dst_path)

# ==========================================
# Build the summary files of the merged dataset

# Source summary files of both datasets, concatenated per split
train_list = read_summary_file(train_summary_path)
trainval_list = read_summary_file(trainval_summary_path)
val_list = read_summary_file(val_summary_path)
teeth_train_list = read_summary_file(teeth_train_summary_path)
teeth_trainval_list = read_summary_file(teeth_trainval_summary_path)
teeth_val_list = read_summary_file(teeth_val_summary_path)
train_list.extend(teeth_train_list)
trainval_list.extend(teeth_trainval_list)
val_list.extend(teeth_val_list)

# Every original name is replaced by itself plus its augmented copies
dst_train_list = _expand_split(train_list, dst_src_image_path)
dst_trainval_list = _expand_split(trainval_list, dst_src_image_path)
dst_val_list = _expand_split(val_list, dst_src_image_path)

write_summary_file(dst_train_summary_path, dst_train_list)
write_summary_file(dst_trainval_summary_path, dst_trainval_list)
write_summary_file(dst_val_summary_path, dst_val_list)

### 生成ori数据集

In [59]:
# ==========================================
# Path configuration of the merged "ori" dataset
dst_dataset_base_path = base_path / 'aug_ori'
dst_summary_path = dst_dataset_base_path / "ImageSets/Segmentation"
dst_train_summary_path = dst_summary_path / 'train.txt'
dst_trainval_summary_path = dst_summary_path / 'trainval.txt'
dst_val_summary_path = dst_summary_path / 'val.txt'
dst_src_image_path = dst_dataset_base_path / 'JPEGImages'
dst_gt_image_path = dst_dataset_base_path / 'SegmentationClass'

# Remove a previously generated copy of the merged dataset
if dst_dataset_base_path.exists():
    shutil.rmtree(dst_dataset_base_path)

# Create the directory layout of the merged dataset
dst_dataset_base_path.mkdir(exist_ok=True)
dst_summary_path.mkdir(parents=True, exist_ok=True)
dst_src_image_path.mkdir(exist_ok=True)
dst_gt_image_path.mkdir(exist_ok=True)


def _expand_split(names, image_dir):
    """Expand original image names into all matching image stems.

    A stem in ``image_dir`` belongs to a name when it equals the name (the
    original image) or when it is the name plus exactly one '_'-separated
    suffix (an augmented copy; this is the rule the merge loops below use to
    find an augmented image's ground truth).  A bare prefix match is not
    enough: name 'img1' must not pick up 'img10' or 'img10_3', which would
    leak images between the train and val splits.

    Args:
        names: Image names read from the source summary files.
        image_dir: Directory holding the merged '*.jpg' images.

    Returns:
        list[str]: Matching stems, grouped in the order of ``names`` and
        sorted within each group.
    """
    stems_by_name = {}
    for stem in sorted(p.stem for p in image_dir.glob('*.jpg')):
        stems_by_name.setdefault(stem, []).append(stem)
        base = stem.rpartition('_')[0]
        if base:
            stems_by_name.setdefault(base, []).append(stem)

    expanded = []
    for name in names:
        expanded.extend(stems_by_name.get(name, []))
    return expanded


# ==========================================
# Merge the VOC dataset into the new dataset
src_image_list = src_image_path.glob('*.jpg')
gt_image_list = gt_image_path.glob('*.png')
aug_src_image_list = aug_src_img_path.glob('*.jpg')

# Hard-link the original images (the link shares its data with the source)
for path in src_image_list:
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)

# Hard-link the original ground-truth masks
for path in gt_image_list:
    dst_path = dst_gt_image_path / path.name
    os.link(path, dst_path)

# Hard-link the augmented images and give each one a copy of the ground truth
# of its source image (the stem without its last '_'-separated field).
# Augmented images whose source has no ground truth are skipped.
for path in aug_src_image_list:
    gt_path = dst_gt_image_path / ('_'.join(path.stem.split('_')[:-1]) + '.png')
    if not gt_path.exists():
        continue
    gt_dst_path = dst_gt_image_path / (path.stem + '.png')
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)
    shutil.copy(gt_path, gt_dst_path)

# ==========================================
# Merge the teeth dataset into the new dataset
src_image_list = teeth_src_image_path.glob('*.jpg')
gt_image_list = teeth_gt_image_path.glob('*.png')
aug_src_image_list = teeth_aug_src_img_path_ori_pro.glob('*.jpg')

# Hard-link the original images
for path in src_image_list:
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)

# Hard-link the original ground-truth masks
for path in gt_image_list:
    dst_path = dst_gt_image_path / path.name
    os.link(path, dst_path)

# Hard-link the augmented images and copy the matching ground truth
for path in aug_src_image_list:
    gt_path = dst_gt_image_path / ('_'.join(path.stem.split('_')[:-1]) + '.png')
    if not gt_path.exists():
        continue
    gt_dst_path = dst_gt_image_path / (path.stem + '.png')
    dst_path = dst_src_image_path / path.name
    os.link(path, dst_path)
    shutil.copy(gt_path, gt_dst_path)

# ==========================================
# Build the summary files of the merged dataset

# Source summary files of both datasets, concatenated per split
train_list = read_summary_file(train_summary_path)
trainval_list = read_summary_file(trainval_summary_path)
val_list = read_summary_file(val_summary_path)
teeth_train_list = read_summary_file(teeth_train_summary_path)
teeth_trainval_list = read_summary_file(teeth_trainval_summary_path)
teeth_val_list = read_summary_file(teeth_val_summary_path)
train_list.extend(teeth_train_list)
trainval_list.extend(teeth_trainval_list)
val_list.extend(teeth_val_list)

# Every original name is replaced by itself plus its augmented copies
dst_train_list = _expand_split(train_list, dst_src_image_path)
dst_trainval_list = _expand_split(trainval_list, dst_src_image_path)
dst_val_list = _expand_split(val_list, dst_src_image_path)

write_summary_file(dst_train_summary_path, dst_train_list)
write_summary_file(dst_trainval_summary_path, dst_trainval_list)
write_summary_file(dst_val_summary_path, dst_val_list)

#### 对数据集的shape进行检查

In [64]:
# Make sure every image / ground-truth pair of the merged datasets is 700x700,
# resizing the files that are not.
#
# The loop uses its own variable names so that it does not overwrite the
# notebook-level `dataset_base_path`, `src_image_path` and `gt_image_path`
# that the earlier cells rely on when they are re-run.
target_size = (700, 700)

for dataset_name in ['aug_bright', 'aug_ori']:
    check_base_path = base_path / dataset_name
    check_src_image_path = check_base_path / 'JPEGImages'
    check_gt_image_path = check_base_path / 'SegmentationClass'

    # Materialise the listing first: temporary '*.jpg' files are created in
    # this directory while the loop runs.
    for path in list(check_src_image_path.glob('*.jpg')):
        gt_path = check_gt_image_path / (path.stem + '.png')

        # Open both files up front so a missing ground truth fails before
        # anything is modified; `with` closes the file handles again.
        with Image.open(path) as src_img, Image.open(gt_path) as gt_img:
            resized_src = None
            resized_gt = None
            if src_img.size != target_size:
                resized_src = src_img.resize(target_size)
            if gt_img.size != target_size:
                # Nearest-neighbour keeps the mask's label values intact; an
                # interpolating filter would blend neighbouring labels.
                resized_gt = gt_img.resize(target_size, Image.NEAREST)

        for img_path, resized in ((path, resized_src), (gt_path, resized_gt)):
            if resized is None:
                continue
            # Many dataset files are hard links to the source datasets.
            # Saving to a temporary file and swapping it in replaces the link
            # instead of rewriting the data shared with the source file.
            # The temporary name keeps the suffix so the format is still
            # chosen from the file extension.
            tmp_path = img_path.with_name(img_path.stem + '.tmp' + img_path.suffix)
            resized.save(tmp_path)
            os.replace(tmp_path, img_path)