In [1]:
from pathlib import Path
from collections import Counter
import shutil
import os
import random
from itertools import chain

import numpy as np
from PIL import Image
import cv2

In [2]:
def read_summary_file(path):
    """Read a text file and return its lines as a list of strings.

    A single trailing newline character is removed from each line; any
    other leading or trailing whitespace is kept as-is.

    Args:
        path: Location of the text file to read.

    Returns:
        list[str]: One entry per line of the file, in file order.
    """
    with open(path, 'r') as handle:
        return [row[:-1] if row.endswith('\n') else row for row in handle]

def write_summary_file(path, lines):
    """Write ``lines`` to a text file, one entry per line.

    Entries are separated by a newline character and no newline is written
    after the last entry. An empty ``lines`` produces an empty file instead
    of raising ``IndexError``.

    Args:
        path: Location of the file to (over)write.
        lines: Iterable of strings to write.
    """
    with open(path, 'w') as f:
        # join() gives "newline between entries, none at the end" for any
        # length, including zero entries.
        f.write('\n'.join(lines))

In [3]:
# Original dataset: split lists, images and ground-truth masks.
ori_dataset_dir = Path('/data/shl/data/all_data0512')
ori_dataset_txt_dir = ori_dataset_dir.joinpath('ImageSets', 'Segmentation')
ori_dataset_img_dir = ori_dataset_dir.joinpath('JPEGImages')
ori_dataset_gt_dir = ori_dataset_dir.joinpath('SegmentationClass')

# Destination dataset: same sub-folder layout under a new root.
dst_dataset_dir = Path('/data/shl/data/all_datacolor')
dst_dataset_txt_dir = dst_dataset_dir.joinpath('ImageSets', 'Segmentation')
dst_dataset_img_dir = dst_dataset_dir.joinpath('JPEGImages')
dst_dataset_gt_dir = dst_dataset_dir.joinpath('SegmentationClass')

# Folders holding the colour output images. The matching GT folder is not
# configured here: the processing loop derives it per image as the
# 'SegmentationClass' folder next to each of these directories.
color_2116_dir = Path("/data/pzndata/augmentor/newdata/teeth0511/output2116color")
color_25_dir = Path("/data/pzndata/augmentor/newendoscope/data25/output25color")

# Where the colour images and their GT masks are collected in the destination.
color_img_dst_dir = dst_dataset_dir.joinpath('JPEGImages_Color')
color_gt_dst_dir = dst_dataset_dir.joinpath('SegmentationClass_Color')

In [13]:
# First copy the original dataset into a freshly created destination tree.
# Any existing destination is deleted so the result never mixes old and new files.
if dst_dataset_dir.exists():
    shutil.rmtree(dst_dataset_dir)
dst_dataset_dir.mkdir()
dst_dataset_txt_dir.mkdir(parents=True)
for fresh_dir in (dst_dataset_img_dir, dst_dataset_gt_dir):
    fresh_dir.mkdir()

# Files are hard-linked rather than duplicated; each one keeps its parent
# folder name and file name under the new root.
for src_dir in (ori_dataset_img_dir, ori_dataset_gt_dir):
    for path in src_dir.glob('*'):
        dst_path = dst_dataset_dir / path.parent.name / path.name
        os.link(path, dst_path)

In [None]:
# Process the new dataset: rebuild both colour output folders from scratch.
for out_dir in (color_img_dst_dir, color_gt_dst_dir):
    if out_dir.exists():
        shutil.rmtree(out_dir)
for out_dir in (color_img_dst_dir, color_gt_dst_dir):
    out_dir.mkdir(exist_ok=True)

for src_dir in (color_2116_dir, color_25_dir):
    for path in src_dir.glob('*'):
        # The original GT stem is rebuilt from tokens 2 and 3 of the
        # '_'-separated file name, minus the last four characters
        # (presumably an extension such as '.jpg' - TODO confirm).
        name_tokens = path.name.split('_')
        ori_name = '_'.join(name_tokens[2:4])[:-4]

        # GT masks are looked up in the 'SegmentationClass' folder that sits
        # next to the folder containing this image.
        color_gt_dir = path.parent.parent / 'SegmentationClass'
        ori_gt_path = color_gt_dir / f'{ori_name}.png'

        if ori_gt_path.exists():
            # The first 20 characters of the file name are dropped
            # (presumably a fixed-length prefix - TODO confirm); image and
            # mask end up with the same stem.
            dst_img_path = color_img_dst_dir / path.name[20:]
            dst_gt_path = color_gt_dst_dir / (path.stem[20:] + '.png')

            shutil.copyfile(path, dst_img_path)
            shutil.copyfile(ori_gt_path, dst_gt_path)
        else:
            # No matching mask: report the expected location and skip the image.
            print(ori_gt_path)

In [15]:
# Zero out every pixel with value >= 3 in the newly generated GT masks.
# `count` ends up as the number of mask files that had to be rewritten.
count = 0
for path in color_gt_dst_dir.glob('*'):
    img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread does not raise on an unreadable file; it returns None.
        # Report the file and move on instead of failing on `None.max()`.
        print('unreadable GT image:', path)
        continue

    invalid = img >= 3
    if invalid.any():
        img[invalid] = 0
        cv2.imwrite(str(path), img)
        count += 1
count

1000

In [22]:
# Check that the removal succeeded: list every GT mask that still contains a
# pixel value >= 3. `count` is the number of such files, `name_list` their names.
# NOTE(review): this scans dst_dataset_gt_dir, while the loop that zeroes
# values >= 3 rewrites files in color_gt_dst_dir - confirm the intended folder.
name_list = []
count = 0
seg_cls = dst_dataset_gt_dir
for path in seg_cls.glob('*'):
    # Context manager so the image file handle is closed on every iteration.
    with Image.open(path) as gt_img:
        gt_img_np_unique = np.unique(np.asarray(gt_img))

    if gt_img_np_unique.max() >= 3:
        print(path.name, gt_img_np_unique)
        count += 1
        name_list.append(path.name)
count

0

In [19]:
# Generate the dataset split files (trainval.txt, train.txt, val.txt).
TRAIN_RATIO = 0.8  # fraction of the images written to train.txt
SPLIT_SEED = 0     # fixed seed so re-running the cell reproduces the same split

if dst_dataset_txt_dir.exists():
    shutil.rmtree(dst_dataset_txt_dir)
# Real booleans: the previous string arguments only worked because any
# non-empty string is truthy.
dst_dataset_txt_dir.mkdir(parents=True, exist_ok=True)

trainval_path = dst_dataset_txt_dir / "trainval.txt"
train_path = dst_dataset_txt_dir / "train.txt"
val_path = dst_dataset_txt_dir / "val.txt"

# trainval.txt: every image stem, sorted so the listing (and therefore the
# seeded shuffle below) does not depend on directory iteration order.
# NOTE(review): only dst_dataset_img_dir is listed; files placed in
# color_img_dst_dir are not part of these splits - confirm this is intended.
file_list = sorted(img_path.stem for img_path in dst_dataset_img_dir.glob('*'))
write_summary_file(trainval_path, file_list)

# train.txt: first TRAIN_RATIO of the shuffled stems. A dedicated Random
# instance keeps the split reproducible without touching the global RNG state.
random.Random(SPLIT_SEED).shuffle(file_list)
split_point = int(len(file_list) * TRAIN_RATIO)
write_summary_file(train_path, file_list[:split_point])

# val.txt: the remaining stems.
write_summary_file(val_path, file_list[split_point:])