Skip to content

Commit

Permalink
feat(python): 提取car类别图像和标注文件
Browse files Browse the repository at this point in the history
  • Loading branch information
zjZSTU committed Mar 25, 2020
1 parent 4dc22e6 commit 8250890
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
94 changes: 94 additions & 0 deletions py/utils/data/voc_car.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/3/25 下午3:03
@file: voc_car.py
@author: zj
@description: 提取car类别训练和验证集
"""

import os
import shutil
import random
import numpy as np
from utils.util import check_dir

suffix_xml = '.xml'
suffix_jpeg = '.jpg'

car_train_path = '../../data/VOCdevkit/VOC2007/ImageSets/Main/car_train.txt'
car_val_path = '../../data/VOCdevkit/VOC2007/ImageSets/Main/car_val.txt'

voc_annotation_dir = '../../data/VOCdevkit/VOC2007/Annotations/'
voc_jpeg_dir = '../../data/VOCdevkit/VOC2007/JPEGImages/'

car_root_dir = '../../data/voc_car/'


def parse_train_val(data_path):
"""
提取指定类别图像
"""
samples = []

with open(data_path, 'r') as file:
lines = file.readlines()
for line in lines:
res = line.strip().split(' ')
if len(res) == 3 and int(res[2]) == 1:
samples.append(res[0])

return np.array(samples)


def sample_train_val(samples):
"""
随机采样样本,减少数据集个数(留下1/10)
"""
for name in ['train', 'val']:
dataset = samples[name]
length = len(dataset)

random_samples = random.sample(range(length), int(length / 10))
# print(random_samples)
new_dataset = dataset[random_samples]
samples[name] = new_dataset

return samples


def save_car(car_samples, data_root_dir, data_annotation_dir, data_jpeg_dir):
"""
保存类别Car的样本图片和标注文件
"""
for sample_name in car_samples:
src_annotation_path = os.path.join(voc_annotation_dir, sample_name + suffix_xml)
dst_annotation_path = os.path.join(data_annotation_dir, sample_name + suffix_xml)
shutil.copyfile(src_annotation_path, dst_annotation_path)

src_jpeg_path = os.path.join(voc_jpeg_dir, sample_name + suffix_jpeg)
dst_jpeg_path = os.path.join(data_jpeg_dir, sample_name + suffix_jpeg)
shutil.copyfile(src_jpeg_path, dst_jpeg_path)

csv_path = os.path.join(data_root_dir, 'car.csv')
np.savetxt(csv_path, np.array(car_samples), fmt='%s')


if __name__ == '__main__':
samples = {'train': parse_train_val(car_train_path), 'val': parse_train_val(car_val_path)}
print(samples)
# samples = sample_train_val(samples)
# print(samples)

check_dir(car_root_dir)
for name in ['train', 'val']:
data_root_dir = os.path.join(car_root_dir, name)
data_annotation_dir = os.path.join(data_root_dir, 'Annotations')
data_jpeg_dir = os.path.join(data_root_dir, 'JPEGImages')

check_dir(data_root_dir)
check_dir(data_annotation_dir)
check_dir(data_jpeg_dir)
save_car(samples[name], data_root_dir, data_annotation_dir, data_jpeg_dir)

print('done')
54 changes: 54 additions & 0 deletions py/utils/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/3/25 下午3:06
@file: util.py
@author: zj
@description:
"""

import os
import numpy as np
import xmltodict
import torch


def check_dir(data_dir):
if not os.path.exists(data_dir):
os.mkdir(data_dir)


def parse_car_csv(csv_dir):
csv_path = os.path.join(csv_dir, 'car.csv')
samples = np.loadtxt(csv_path, dtype=np.str)
return samples


def parse_xml(xml_path):
"""
解析xml文件,返回标注边界框坐标
"""
# print(xml_path)
with open(xml_path, 'rb') as f:
xml_dict = xmltodict.parse(f)
# print(xml_dict)

bndboxs = list()
objects = xml_dict['annotation']['object']
if isinstance(objects, list):
for obj in objects:
obj_name = obj['name']
difficult = int(obj['difficult'])
if 'car'.__eq__(obj_name) and difficult != 1:
bndbox = obj['bndbox']
bndboxs.append((int(bndbox['xmin']), int(bndbox['ymin']), int(bndbox['xmax']), int(bndbox['ymax'])))
elif isinstance(objects, dict):
obj_name = objects['name']
difficult = int(objects['difficult'])
if 'car'.__eq__(obj_name) and difficult != 1:
bndbox = objects['bndbox']
bndboxs.append((int(bndbox['xmin']), int(bndbox['ymin']), int(bndbox['xmax']), int(bndbox['ymax'])))
else:
pass

return np.array(bndboxs)

0 comments on commit 8250890

Please sign in to comment.