In [1]:
from pathlib import Path
import os
import shutil
import random
from itertools import chain

from PIL import Image
from tqdm.auto import tqdm
import numpy as np

In [2]:
def read_summary_file(path):
    """Read a split/summary text file and return its non-blank entries.

    Args:
        path: path of a text file holding one entry per line.

    Returns:
        list of str in file order, each stripped of surrounding whitespace;
        empty and whitespace-only lines are dropped.
    """
    with open(path, 'r') as handle:
        # str.strip() also removes the trailing newline, so no separate '\n' handling is needed.
        stripped = (raw.strip() for raw in handle)
        return [entry for entry in stripped if entry]

def write_summary_file(path, lines):
    """Write entries to a split/summary file, one per line, with no trailing newline.

    Args:
        path: destination file path; an existing file is overwritten.
        lines: iterable of str entries, written in the given order.

    An empty ``lines`` produces an empty file. Indexing ``lines[-1]`` would raise
    IndexError when a split ends up empty, so the entries are joined instead.
    """
    with open(path, 'w') as f:
        f.write('\n'.join(lines))

In [26]:
# Dataset root. Set the TEETH_DATASET_DIR environment variable to use another
# location; when it is unset (or empty) the original local path is used.
base_path = Path(os.environ.get('TEETH_DATASET_DIR') or '/home/shlll/Dataset/Teeth/')

# Source dataset `aug2` (Pascal VOC-style directory names).
dataset_base_path = base_path / 'aug2'
summary_path = dataset_base_path / "ImageSets/Segmentation"
train_summary_path = summary_path / 'train.txt'
trainval_summary_path = summary_path / 'trainval.txt'
val_summary_path = summary_path / 'val.txt'
src_image_path = dataset_base_path / 'JPEGImages_700'
# GT masks; an alternative folder in this dataset is dataset_base_path / 'SegmentationClass'.
gt_image_path = dataset_base_path / 'newSegmentationClass_700'

# Augmented images are read from the `output` subfolder of the source image folder.
aug_src_img_path = src_image_path / 'output'

# Destination dataset `augment1` built by this notebook (same layout).
dst_dataset_base_path = base_path / 'augment1'
dst_summary_path = dst_dataset_base_path / "ImageSets/Segmentation"
dst_train_summary_path = dst_summary_path / 'train.txt'
dst_trainval_summary_path = dst_summary_path / 'trainval.txt'
dst_val_summary_path = dst_summary_path / 'val.txt'
dst_src_image_path = dst_dataset_base_path / 'JPEGImages'
dst_gt_image_path = dst_dataset_base_path / 'SegmentationClass'

In [None]:
# Rename the augmented images in aug_src_img_path to `<original-name>_<k>.jpg`,
# numbering the copies of each original image 1, 2, ...
# NOTE(review): assumes generated names split on '_' so that tokens 3 and 4 form the
# original file name, e.g. `JPEGImages_700_original_teeth_001.jpg_<id>.jpg` -> `teeth_001`
# (as the original indexing implies) — confirm against the actual generator output.
# Materialize and sort the names before renaming anything in the same directory.
aug_src_image_list = sorted(p.name for p in aug_src_img_path.glob('*.jpg'))

last_name = None
count = 1
for path in aug_src_image_list:
    tokens = path.split('_')
    if len(tokens) < 5:
        # Already renamed by an earlier run (`<name>_<k>.jpg`); skipping keeps the cell re-runnable
        # instead of failing with IndexError.
        continue
    base_name = '_'.join(tokens[3:5])[:-4]  # drop the original '.jpg'

    # Group on the full original name (not just token 4), so distinct images that
    # share their last token are numbered independently.
    if base_name == last_name:
        count += 1
    else:
        last_name = base_name
        count = 1

    dst_path = aug_src_img_path / (base_name + '_' + str(count) + '.jpg')
    # On POSIX, Path.rename silently replaces an existing file. Skip indices that are
    # already taken (e.g. by a previous run) rather than overwriting them.
    while dst_path.exists():
        count += 1
        dst_path = aug_src_img_path / (base_name + '_' + str(count) + '.jpg')
    (aug_src_img_path / path).rename(dst_path)

In [75]:
# Build the `augment1` dataset from scratch: hard-link the original images and GT
# masks, add the augmented images with a copy of their source image's GT mask, and
# regenerate the train / trainval / val split files.
if dst_dataset_base_path.exists():
    # shutil.rmtree raises FileNotFoundError on a missing directory (first run).
    shutil.rmtree(dst_dataset_base_path)
dst_dataset_base_path.mkdir(parents=True, exist_ok=True)
dst_summary_path.mkdir(parents=True, exist_ok=True)
dst_src_image_path.mkdir(exist_ok=True)
dst_gt_image_path.mkdir(exist_ok=True)

src_image_list = src_image_path.glob('*.jpg')
gt_image_list = gt_image_path.glob('*.png')
aug_src_image_list = aug_src_img_path.glob('*.jpg')

# Hard-link the original images. os.link shares the file; nothing is copied or moved.
original_stems = set()
for path in src_image_list:
    os.link(path, dst_src_image_path / path.name)
    original_stems.add(path.stem)

# Hard-link the original GT masks.
for path in gt_image_list:
    os.link(path, dst_gt_image_path / path.name)

# Add the augmented images. An image named `<name>_<k>.jpg` reuses the GT mask `<name>.png`.
# NOTE(review): the GT mask is copied unchanged, which is only correct if the
# augmentation does not move pixels (no flips/rotations/crops/distortions) — confirm.
aug_stems_by_source = {}
skipped_aug = []
for path in aug_src_image_list:
    source_name, sep, index = path.stem.rpartition('_')
    gt_path = dst_gt_image_path / (source_name + '.png')
    if not (sep and index.isdecimal()) or not gt_path.exists():
        skipped_aug.append(path.name)
        continue
    os.link(path, dst_src_image_path / path.name)
    shutil.copy(gt_path, dst_gt_image_path / (path.stem + '.png'))
    aug_stems_by_source.setdefault(source_name, []).append(path.stem)
if skipped_aug:
    print('Skipped', len(skipped_aug), 'augmented images without a matching GT mask.')


def _expand_split(names, original_stems, aug_stems_by_source):
    """Map the image names of a source split to all image stems in the new dataset.

    Each name contributes itself (if that original image was linked) followed by its
    augmented copies `<name>_<k>`, ordered by k. Matching is exact. A prefix glob
    such as `glob(name + '*')` would also pick up unrelated images that merely share
    a prefix (`img_1` matches `img_10`), leaking images across splits.

    Args:
        names: image names read from a source split file.
        original_stems: stems of the original images linked into the new dataset.
        aug_stems_by_source: source name -> stems of its linked augmented copies.

    Returns:
        list of str stems for the new split file.
    """
    stems = []
    for name in names:
        if name in original_stems:
            stems.append(name)
        copies = aug_stems_by_source.get(name, [])
        stems.extend(sorted(copies, key=lambda stem: int(stem.rpartition('_')[2])))
    return stems


# Regenerate the split files: every augmented copy goes to the split of its source image.
train_list = read_summary_file(train_summary_path)
trainval_list = read_summary_file(trainval_summary_path)
val_list = read_summary_file(val_summary_path)

dst_train_list = _expand_split(train_list, original_stems, aug_stems_by_source)
dst_trainval_list = _expand_split(trainval_list, original_stems, aug_stems_by_source)
dst_val_list = _expand_split(val_list, original_stems, aug_stems_by_source)

# Guard against train/val leakage before writing anything.
assert not set(dst_train_list) & set(dst_val_list), 'train and val splits overlap'

write_summary_file(dst_train_summary_path, dst_train_list)
write_summary_file(dst_trainval_summary_path, dst_trainval_list)
write_summary_file(dst_val_summary_path, dst_val_list)