In [12]:
from PIL import Image
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as A
import albumentations.pytorch
from matplotlib import pyplot as plt
import cv2
import numpy as np
import os
import xml.etree.ElementTree as ET


In [13]:
# Class definition
class VOCAug(object):
    """Offline augmentation of a Pascal VOC style detection dataset.

    For every ``*.xml`` annotation in ``pre_xml_path`` the image named by its
    ``<filename>`` tag is loaded from ``pre_image_path``, passed through the
    albumentations pipeline built in ``__init__`` and written, together with
    an updated copy of the annotation, to the save directories under a
    zero-padded numeric file name.
    """

    def __init__(self,
                 pre_image_path=None,
                 pre_xml_path=None,
                 aug_image_save_path=None,
                 aug_xml_save_path=None,
                 start_aug_id=None,
                 labels=None,
                 max_len=4,
                 is_show=False):
        """
        Store the configuration and build the augmentation pipeline.

        :param pre_image_path: directory holding the original images
        :param pre_xml_path: directory holding the original VOC xml files
        :param aug_image_save_path: directory augmented images are written to
        :param aug_xml_save_path: directory augmented xml files are written to
        :param start_aug_id: number used to name the first augmented sample;
            defaults to the number of entries in ``pre_xml_path``
        :param labels: list of class names; objects with any other name are
            ignored, the list index is used as category id, and the name is
            used as caption when previewing augmented images
        :param max_len: minimum number of digits of the generated file names
            (shorter numbers are zero padded)
        :param is_show: when True every result is previewed in an OpenCV
            window and only saved if the ``s`` key is pressed
        """
        self.pre_image_path = pre_image_path
        self.pre_xml_path = pre_xml_path
        self.aug_image_save_path = aug_image_save_path
        self.aug_xml_save_path = aug_xml_save_path
        self.start_aug_id = start_aug_id
        self.labels = labels
        self.max_len = max_len
        self.is_show = is_show

        print(self.labels)
        assert self.labels is not None, "labels is None!!!"

        # Augmentation options  # TODO:
        self.aug = A.Compose(
            [
                A.RandomBrightnessContrast(
                    brightness_limit=0.06, contrast_limit=0.1, p=0.5),
                A.GaussianBlur(p=0.3),
                A.GaussNoise(p=0.3),
                # A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # histogram equalization
                A.Equalize(p=0.6),  # equalize the image histogram
                A.RandomRotate90(always_apply=False, p=1),
                A.Transpose(always_apply=False, p=0.8),

                A.OpticalDistortion(p=0.8),
                # A.GridDistortion(p=0.8),

                # A.Downscale(p=0.1),  # degrade quality by down- and up-scaling
                A.Emboss(p=1),  # emboss the image and overlay the result on the original
            ],
            # pascal_voc: boxes are given as [xmin, ymin, xmax, ymax]
            # min_area: boxes whose pixel area falls below this value after
            #   augmentation are removed from the returned list
            # min_visibility: in [0, 1]; boxes whose area ratio after/before
            #   augmentation falls below this value are removed
            bbox_params=A.BboxParams(format='pascal_voc', min_area=0.,
                                     min_visibility=0.,
                                     label_fields=['category_id'])
        )
        print('--------------*--------------')
        print("labels: ", self.labels)
        if self.start_aug_id is None:
            self.start_aug_id = len(os.listdir(self.pre_xml_path))
            print("the start_aug_id is not set, default: len(images)",
                  self.start_aug_id)
        print('--------------*--------------')

    def _parse_object(self, obj, width, height):
        """Read one ``<object>`` element.

        Shared by ``get_xml_data`` and ``save_aug_data`` so both always agree
        on which objects take part in the augmentation.

        :param obj: the ``<object>`` element
        :param width: image width taken from the annotation's ``<size>``
        :param height: image height taken from the annotation's ``<size>``
        :return: ``(class_id, [xmin, ymin, xmax, ymax])`` or ``None`` when the
            object is skipped (unknown label, marked difficult, or a box that
            is empty after clamping to the image)
        """
        name_node = obj.find('name')
        if name_node is None or name_node.text not in self.labels:
            return None

        # <difficult> is optional in many annotation tools; absent means 0.
        difficult_node = obj.find('difficult')
        if (difficult_node is not None and difficult_node.text
                and int(difficult_node.text) == 1):
            return None

        xml_box = obj.find('bndbox')
        # int(float(...)) also accepts coordinates stored as "12.0".
        xmin, ymin, xmax, ymax = (
            int(float(xml_box.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )

        # Clamp annotations that stick out of the image on any side.
        xmin, ymin = max(xmin, 0), max(ymin, 0)
        xmax, ymax = min(xmax, width), min(ymax, height)
        if xmax <= xmin or ymax <= ymin:
            return None

        return self.labels.index(name_node.text), [xmin, ymin, xmax, ymax]

    def get_xml_data(self, xml_filename):
        """Load one annotation and its image.

        :param xml_filename: file name of the annotation inside ``pre_xml_path``
        :return: ``(bboxes, cls_id_list, image, image_name)`` where ``bboxes``
            is a list of ``[xmin, ymin, xmax, ymax]``, ``cls_id_list`` holds the
            matching indices into ``self.labels``, ``image`` is whatever
            ``cv2.imread`` returned and ``image_name`` is the ``<filename>``
            text of the annotation
        """
        # Hand the path to the parser so the encoding declared in the file is
        # honoured instead of the locale default of a text-mode file object.
        tree = ET.parse(os.path.join(self.pre_xml_path, xml_filename))
        root = tree.getroot()
        image_name = tree.find('filename').text
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)

        bboxes = []
        cls_id_list = []
        for obj in root.iter('object'):
            parsed = self._parse_object(obj, w, h)
            if parsed is None:
                continue
            cls_id, bbox = parsed
            bboxes.append(bbox)
            cls_id_list.append(cls_id)

        # Read the image
        image = cv2.imread(os.path.join(self.pre_image_path, image_name))

        return bboxes, cls_id_list, image, image_name

    def aug_image(self):
        """Augment every annotation in ``pre_xml_path`` once and save the results.

        Saved samples are numbered consecutively starting at
        ``self.start_aug_id``; samples that are not saved do not consume a
        number.
        """
        cnt = self.start_aug_id
        # Sorted so the numbering does not depend on directory listing order.
        for xml in sorted(os.listdir(self.pre_xml_path)):
            # Other entries (e.g. .ipynb_checkpoints on AI Studio) may live in
            # the folder, so filter on the file extension.
            if os.path.splitext(xml)[1] != '.xml':
                continue
            bboxes, cls_id_list, image, image_name = self.get_xml_data(xml)
            if image is None:
                # cv2.imread signals a missing/unreadable file with None.
                print(f"skip {xml}: cannot read image '{image_name}'")
                continue

            # Augmented data: {"image", "bboxes", "category_id"}
            augmented = self.aug(
                image=image, bboxes=bboxes, category_id=cls_id_list)

            # Save the augmented data; only saved samples advance the counter
            if self.save_aug_data(augmented, image_name, cnt, xml_name=xml):
                cnt += 1

    def save_aug_data(self, augmented, image_name, cnt, xml_name=None):
        """Write one augmented image and its updated annotation.

        :param augmented: dict with keys ``image``, ``bboxes`` and
            ``category_id`` as returned by the pipeline
        :param image_name: file name of the original image; its extension is
            reused for the augmented image
        :param cnt: number used to build the new file names
        :param xml_name: file name of the original annotation inside
            ``pre_xml_path``; defaults to the image name with an ``.xml``
            extension
        :return: True when the sample was written, False when it was skipped
            (preview rejected, or the augmented boxes can no longer be matched
            to the annotated objects)
        """
        aug_image = augmented['image']
        aug_bboxes = augmented['bboxes']
        aug_category_id = augmented['category_id']

        # splitext only touches the real extension; str.replace would also
        # rewrite an earlier occurrence of the suffix inside the name.
        image_stem, image_ext = os.path.splitext(image_name)
        if xml_name is None:
            xml_name = image_stem + '.xml'

        # File names of the augmented image / annotation
        new_stem = str(cnt).zfill(self.max_len)
        new_image_name = new_stem + image_ext
        new_xml_name = new_stem + '.xml'

        # Width and height of the augmented image
        new_image_height, new_image_width = aug_image.shape[:2]

        # The augmented annotation is an edited copy of the original one
        aug_tree = ET.parse(os.path.join(self.pre_xml_path, xml_name))
        root = aug_tree.getroot()
        size = root.find('size')
        old_width = int(size.find('width').text)
        old_height = int(size.find('height').text)

        # Keep exactly the objects get_xml_data fed into the pipeline. The
        # others were never transformed, so their old coordinates would be
        # wrong in the augmented image; drop them from the output.
        parent_of = {child: parent
                     for parent in root.iter() for child in parent}
        kept_objects = []
        for obj in list(root.iter('object')):
            if self._parse_object(obj, old_width, old_height) is not None:
                kept_objects.append(obj)
            elif obj in parent_of:
                parent_of[obj].remove(obj)

        # If the pipeline dropped a box there is no way to tell which object
        # it belonged to, so the sample is skipped rather than mislabelled.
        if len(kept_objects) != len(aug_bboxes):
            print(f"skip {image_name}: {len(aug_bboxes)} augmented boxes for "
                  f"{len(kept_objects)} annotated objects")
            return False

        # Update the image file name and size
        aug_tree.find('filename').text = new_image_name
        size.find('width').text = str(new_image_width)
        size.find('height').text = str(new_image_height)

        # The preview is drawn on a copy so the saved image stays clean
        preview = aug_image.copy() if self.is_show else None

        # Update every bounding box
        for obj, cls_id, bbox in zip(kept_objects, aug_category_id, aug_bboxes):
            label = self.labels[int(cls_id)]
            xmin, ymin, xmax, ymax = (int(v) for v in bbox[:4])

            obj.find('name').text = label
            xml_box = obj.find('bndbox')
            xml_box.find('xmin').text = str(xmin)
            xml_box.find('ymin').text = str(ymin)
            xml_box.find('xmax').text = str(xmax)
            xml_box.find('ymax').text = str(ymax)

            if self.is_show:
                tl = 2  # line thickness; the font scale is derived from it
                text = f"{label}"
                t_size = cv2.getTextSize(
                    text, 0, fontScale=tl / 3, thickness=tl)[0]
                cv2.rectangle(preview, (xmin, ymin - 3),
                              (xmin + t_size[0], ymin - t_size[1] - 3),
                              (0, 0, 255), -1, cv2.LINE_AA)  # filled
                cv2.putText(preview, text, (xmin, ymin - 2), 0, tl / 3,
                            (255, 255, 255), tl, cv2.LINE_AA)
                cv2.rectangle(preview, (xmin, ymin),
                              (xmax, ymax), (255, 255, 0), 2)

        if self.is_show:
            # Press "s" to keep this augmentation, any other key discards it
            cv2.imshow('aug_image_show', preview)
            key = cv2.waitKey(0)
            if key & 0xff != ord('s'):
                return False

        # Save the augmented image
        cv2.imwrite(os.path.join(
            self.aug_image_save_path, new_image_name), aug_image)
        # Save the augmented xml file
        aug_tree.write(os.path.join(self.aug_xml_save_path, new_xml_name))

        return True


In [14]:
# Count the files inside a folder
def file_count(local_path, type_dict):
    """Recursively count the files below ``local_path``.

    :param local_path: directory to walk
    :param type_dict: dict updated in place; maps every file extension
        (including the leading dot, ``''`` for none) to the number of files
        with it, and the key ``"文件夹"`` to the number of sub-directories
    :return: total number of files in ``local_path`` and all of its
        sub-directories
    """
    all_file_num = 0  # files found in this directory and everything below it
    for file_name in os.listdir(local_path):  # entries of the current level
        full_path = os.path.join(local_path, file_name)
        if os.path.isdir(full_path):
            # Create the key with an initial value if it does not exist yet
            type_dict.setdefault("文件夹", 0)
            type_dict["文件夹"] += 1
            # Descend one level; the sub-total has to be added here because
            # the counter is local to each call.
            all_file_num += file_count(full_path, type_dict)
        else:
            ext = os.path.splitext(file_name)[1]  # file extension
            type_dict.setdefault(ext, 0)
            type_dict[ext] += 1
            all_file_num += 1
    return all_file_num


In [20]:
# Directories holding the original images and xml annotations
ORI_IMAGE_PATH = './JPEGImages'  # TODO:
ORI_XML_PATH = './Annotations'  # TODO:

# Directories the augmented images and xml annotations are written to
AUG_SAVE_IMAGE_PATH = './augimages'
AUG_SAVE_XML_PATH = './auglabels'

# Create each output directory that does not exist yet and report it
for save_dir in (AUG_SAVE_IMAGE_PATH, AUG_SAVE_XML_PATH):
    if os.path.exists(save_dir):
        continue
    os.makedirs(save_dir)
    print(f'创建文件夹{save_dir}成功！')

# Class names
LABELS = ["person", "car", "green", "yellow", "red", "gray", "blue", "shapan"]  # TODO:

# LABELS_OUTPUT = ['YELLOW', 'BLUE', 'RED', 'GREEN', 'GREY']

# Folders to count: the originals and the already augmented images
ori_path, auged_path = ORI_IMAGE_PATH, AUG_SAVE_IMAGE_PATH
# Per-extension file counters, filled in by file_count (powered by luotao)
type_dict_ORI, type_dict_AUG = {}, {}
ori_file_count = file_count(ori_path, type_dict_ORI)
auged_file_count = file_count(auged_path, type_dict_AUG)

# Report the file types found in the augmented folder and both totals
for each_type, type_total in type_dict_AUG.items():
    print(f"文件类型为【{each_type}】的数量有：{type_total} 个")
print(f"原文件数量为:{ori_file_count}")
print(f"已生成文件数量为:{auged_file_count}")


文件类型为【.jpg】的数量有：555 个
原文件数量为:37
已生成文件数量为:555


In [19]:

num = 5  # how many times the original dataset is augmented  # TODO:

# Arguments that are identical for every augmentation pass
shared_args = dict(
    pre_image_path=ORI_IMAGE_PATH,
    pre_xml_path=ORI_XML_PATH,
    aug_image_save_path=AUG_SAVE_IMAGE_PATH,
    aug_xml_save_path=AUG_SAVE_XML_PATH,
    labels=LABELS,
    is_show=False,
)

for i in range(num):
    # Pass i starts numbering after the existing augmented files plus
    # i * ori_file_count, i.e. one slot per original file for each earlier pass.
    aug = VOCAug(start_aug_id=auged_file_count + ori_file_count * i,
                 **shared_args)
    aug.aug_image()



['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
labels:  ['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
labels:  ['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
labels:  ['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
labels:  ['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
['person', 'car', 'green', 'yellow', 'red', 'gray', 'blue', 'shapan']
--------------*--------------
labels:  ['person', 'car', 'green', 'yellow', 'red', 'gray', 'bl

In [17]:
# 显示图片

# cv2.destroyAllWindows()

# original_image1 = cv2.imread(
#     './JPEGImages/0031.jpg')
# transformed_image1 = cv2.imread(
#     './images/0000.jpg')
# original_image2 = cv2.imread(
#     './JPEGImages/0032.jpg')
# transformed_image2 = cv2.imread(
#     './images/0001.jpg')

# original_image1 = cv2.cvtColor(original_image1, cv2.COLOR_BGR2RGB)
# transformed_image1 = cv2.cvtColor(transformed_image1, cv2.COLOR_BGR2RGB)
# original_image2 = cv2.cvtColor(original_image2, cv2.COLOR_BGR2RGB)
# transformed_image2 = cv2.cvtColor(transformed_image2, cv2.COLOR_BGR2RGB)

# plt.subplot(2, 2, 1), plt.title("original image"), plt.axis('off')
# plt.imshow(original_image1)
# plt.subplot(2, 2, 2), plt.title("transformed image"), plt.axis('off')
# plt.imshow(transformed_image1)
# plt.subplot(2, 2, 3), plt.title("original image"), plt.axis('off')
# plt.imshow(original_image2)
# plt.subplot(2, 2, 4), plt.title("transformed image"), plt.axis('off')
# plt.imshow(transformed_image2)
