In [1]:
from datasets import load_from_disk, Dataset, Value, Sequence, Features, concatenate_datasets
from datasets import Image as DImage
import huggingface_hub
from tqdm import tqdm
import random
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import io

In [None]:
# Load the dataset previously saved to the local directory './object_detection_dataset'.
# Later cells index it with 'train' and 'test', so it is expected to contain both splits.
dataset = load_from_disk('./object_detection_dataset')

In [None]:
# Names attached to the 'category' feature of the train split's objects
# (presumably a ClassLabel feature -- confirm against the saved dataset schema).
labels = dataset['train'].features['objects'].feature['category'].names
# Bare expression so the notebook displays the list as the cell output.
labels

In [None]:
def get_object_bounds(bbox):
    """Return the rectangle (x_min, y_min, x_max, y_max) that encloses every box.

    Each entry of ``bbox`` is indexed as [x, y, width, height]; the right and
    bottom edges are derived as x + width and y + height.
    Raises ValueError when ``bbox`` is empty.
    """
    left_edges = [entry[0] for entry in bbox]
    top_edges = [entry[1] for entry in bbox]
    right_edges = [entry[0] + entry[2] for entry in bbox]
    bottom_edges = [entry[1] + entry[3] for entry in bbox]
    return min(left_edges), min(top_edges), max(right_edges), max(bottom_edges)

def is_bbox_inside(bbox, crop_bounds):
    """Tell whether an [x, y, w, h] box lies entirely within the crop rectangle.

    ``crop_bounds`` is (x1, y1, x2, y2). Edges that coincide with the crop
    border still count as inside.
    """
    left, top, box_w, box_h = bbox
    crop_left, crop_top, crop_right, crop_bottom = crop_bounds

    fits_horizontally = crop_left <= left and left + box_w <= crop_right
    fits_vertically = crop_top <= top and top + box_h <= crop_bottom
    return fits_horizontally and fits_vertically

def is_bbox_partially_inside(bbox, crop_bounds):
    """Tell whether an [x, y, w, h] box overlaps or touches the crop rectangle.

    ``crop_bounds`` is (x1, y1, x2, y2). The box is rejected only when it lies
    strictly beyond one of the four crop edges, so a box that merely shares an
    edge with the crop is still reported as partially inside.
    """
    left, top, box_w, box_h = bbox
    right = left + box_w
    bottom = top + box_h
    crop_left, crop_top, crop_right, crop_bottom = crop_bounds

    # Separated along the horizontal axis.
    if right < crop_left or left > crop_right:
        return False
    # Separated along the vertical axis.
    if bottom < crop_top or top > crop_bottom:
        return False
    return True

def adjust_bbox_coordinates(bbox, crop_bounds):
    """Express an [x, y, w, h] box relative to the crop's top-left corner.

    Only the origin moves; width and height are kept. Every value is passed
    through ``int()``, so fractional parts are truncated toward zero and the
    resulting x / y may be negative when the box starts outside the crop.
    """
    left, top, box_w, box_h = bbox
    origin_x, origin_y, _crop_right, _crop_bottom = crop_bounds
    shifted = (left - origin_x, top - origin_y, box_w, box_h)
    return [int(value) for value in shifted]

def _clip_bbox_to_crop(bbox, source_w, source_h, crop_w, crop_h,
                       min_visible_w, min_visible_h):
    """Clip a crop-relative [x, y, w, h] box to a crop of size crop_w x crop_h.

    Returns the clipped box as a new list, or None when clipping on any side
    would leave no more than ``min_visible_w`` of ``source_w`` (horizontal) or
    ``min_visible_h`` of ``source_h`` (vertical). The input list is not mutated.
    """
    x, y, w, h = bbox

    if x < 0:
        # x is negative here, so w + x is the width left after trimming the left overhang.
        if w + x > min_visible_w * source_w:
            w, x = w + x, 0
        else:
            return None

    if y < 0:
        if h + y > min_visible_h * source_h:
            h, y = h + y, 0
        else:
            return None

    if x + w > crop_w:
        remaining_w = crop_w - x
        if remaining_w > min_visible_w * source_w:
            w = remaining_w
        else:
            return None

    if y + h > crop_h:
        remaining_h = crop_h - y
        if remaining_h > min_visible_h * source_h:
            h = remaining_h
        else:
            return None

    return [x, y, w, h]


def augment_image(example, image_id, min_visible_w=0.5, min_visible_h=0.4):
    """
    Augment a single image by randomly selecting and cropping around an object

    One box is picked at random, grown on every side by a random margin of
    30%-90% of its own width/height, and clamped to the image. The crop is
    re-encoded as JPEG, and every annotation that still overlaps the crop is
    shifted into crop coordinates and clipped to the crop.

    Args:
        example: dict with a PIL image under 'image' and, under 'objects',
            parallel lists 'bbox_id', 'category', 'bbox' ([x, y, w, h]) and 'area'.
        image_id: id stored on the returned example.
        min_visible_w: a clipped box is kept only if its remaining width is
            greater than this fraction of its original width.
        min_visible_h: same as ``min_visible_w`` for the height.

    Returns:
        dict with 'image_id', 'width', 'height', 'image' (the JPEG-backed crop)
        and 'objects' (the surviving annotations).

    Raises:
        ValueError: if the example has no boxes or the crop has no pixels.
    """
    # example['image'] is already a PIL image; only its size is needed here.
    image = example['image']
    width, height = image.size

    objects = example['objects']
    bboxes = objects['bbox']
    if len(bboxes) == 0:
        raise ValueError("Cannot augment an example without any bounding boxes")

    # Randomly select an object
    selected_bbox = bboxes[random.randrange(len(bboxes))]

    # Calculate crop bounds with margin
    margin = random.uniform(0.3, 0.9)
    x, y, w, h = selected_bbox
    margin_w = int(w * margin)
    margin_h = int(h * margin)

    # Image.crop rounds its box to integers internally. Round here the same way
    # so the offsets applied to the boxes below use the exact pixel origin of
    # the crop instead of the unrounded float bounds.
    crop_x1 = int(round(max(0, x - margin_w)))
    crop_y1 = int(round(max(0, y - margin_h)))
    crop_x2 = int(round(min(width, x + w + margin_w)))
    crop_y2 = int(round(min(height, y + h + margin_h)))
    if crop_x2 <= crop_x1 or crop_y2 <= crop_y1:
        raise ValueError(
            f"Empty crop region {(crop_x1, crop_y1, crop_x2, crop_y2)} for bbox {selected_bbox}"
        )

    # Crop image
    cropped_img = image.crop((crop_x1, crop_y1, crop_x2, crop_y2))
    # JPEG cannot store alpha or palette images; without this conversion the
    # save below raises and the caller would drop the example.
    if cropped_img.mode in ('RGBA', 'LA', 'P', 'PA'):
        cropped_img = cropped_img.convert('RGB')
    jpeg_buffer = io.BytesIO()
    cropped_img.save(jpeg_buffer, format='JPEG')
    jpeg_buffer.seek(0)
    cropped_img = Image.open(jpeg_buffer)

    # Create new annotation
    new_annotation = {
        'bbox_id': [],
        'category': [],
        'bbox': [],
        'area': []
    }

    crop_bounds = (crop_x1, crop_y1, crop_x2, crop_y2)

    # Check each bbox and include it if enough of it falls inside the crop
    for bbox_id, category, bbox, area in zip(
        objects['bbox_id'], objects['category'], bboxes, objects['area']
    ):
        if not is_bbox_partially_inside(bbox, crop_bounds):
            continue

        shifted_bbox = adjust_bbox_coordinates(bbox, crop_bounds)
        clipped_bbox = _clip_bbox_to_crop(
            shifted_bbox, bbox[2], bbox[3],
            cropped_img.width, cropped_img.height,
            min_visible_w, min_visible_h,
        )
        if clipped_bbox is None:
            continue

        # When the box was clipped, scale the stored area by the surviving
        # fraction of the box so it does not keep describing the un-clipped
        # object. Unclipped boxes keep their area untouched.
        shifted_area = shifted_bbox[2] * shifted_bbox[3]
        if clipped_bbox[2:] != shifted_bbox[2:] and shifted_area > 0 and area is not None:
            area = int(round(area * (clipped_bbox[2] * clipped_bbox[3]) / shifted_area))

        new_annotation['bbox_id'].append(bbox_id)
        new_annotation['category'].append(category)
        new_annotation['bbox'].append(clipped_bbox)
        new_annotation['area'].append(area)

    return {
        'image_id': image_id,
        'width': cropped_img.width,
        'height': cropped_img.height,
        'image': cropped_img,
        'objects': new_annotation,
    }

def augment_dataset(dataset, num_augmentations=1, chunk_size=500, test_size=0.1, seed=None):
    """
    Augment the entire dataset

    Every example with at least two objects is cropped ``num_augmentations``
    times with ``augment_image``. A crop is kept only when it dropped at least
    one object, and each kept crop gets a fresh image id continuing after the
    largest id in ``dataset``. Crops that raise are reported and skipped.

    Args:
        dataset: dataset exposing 'image_id' and 'objects' columns and a
            'category' feature under ``features['objects']``.
        num_augmentations: crop attempts per example.
        chunk_size: number of examples converted into a Dataset at a time.
        test_size: forwarded to ``train_test_split``.
        seed: forwarded to ``train_test_split``; None keeps the split unseeded.

    Returns:
        The augmented examples split into 'train' and 'test'.

    Raises:
        ValueError: if no augmented example was produced.
    """
    augmented_examples = []

    next_image_id = max(dataset['image_id']) + 1

    for example in tqdm(dataset):
        # A crop can only drop objects when there are at least two of them, and
        # an example with no objects cannot be cropped at all, so skip both
        # before attempting any augmentation.
        object_count = len(example['objects']['bbox_id'])
        if object_count <= 1:
            continue

        for _ in range(num_augmentations):
            try:
                aug_example = augment_image(example, next_image_id)
            except Exception as e:
                # Best effort: report the failure and move on to the next attempt.
                print(f"Error augmenting image: {e}")
                continue

            # A crop that still contains every object adds nothing new.
            if len(aug_example['objects']['bbox_id']) == object_count:
                continue
            augmented_examples.append(aug_example)
            next_image_id += 1

    if not augmented_examples:
        raise ValueError("No augmented examples were produced; nothing to build a dataset from")

    class_label = dataset.features['objects'].feature['category']
    features = Features({
        'image_id': Value('int64'),
        'width': Value('int64'),
        'height': Value('int64'),
        'image': DImage(decode=True),
        'objects': Sequence({
            'bbox_id': Value('int64'),
            'category': class_label,
            'bbox': Sequence(Value('float64'), length=4),
            'area': Value('int64')
        })
    })

    # Convert in chunks because of memory constraints
    sub_datasets = []
    for start in tqdm(range(0, len(augmented_examples), chunk_size)):
        chunk = augmented_examples[start: start + chunk_size]
        sub_datasets.append(Dataset.from_list(chunk, features=features))

    augmented = concatenate_datasets(sub_datasets)
    return augmented.train_test_split(test_size=test_size, seed=seed)

# Augment the original train and test splits together, with one crop attempt per image.
# augment_dataset re-splits its result, so augmented_dataset has its own 'train'/'test' splits.
augmented_dataset = augment_dataset(concatenate_datasets([dataset['train'], dataset['test']]), num_augmentations=1)

In [None]:
# Print the summary of the originally loaded dataset (not the augmented one).
print(dataset)

In [None]:
# Print the summary of the augmented train/test splits.
print(augmented_dataset)

In [None]:
# Pool the original and augmented splits into one dataset, then carve out a
# fresh 5% test split from the pooled rows.
merged_dataset = concatenate_datasets(
    [
        dataset['train'],
        dataset['test'],
        augmented_dataset['train'],
        augmented_dataset['test'],
    ]
).train_test_split(test_size=0.05)
merged_dataset

In [None]:
def show_sample(idx=None):
    """Display one example of ``dataset['train']`` with its boxes drawn in green.

    Args:
        idx: row index to show. When None, a row is picked uniformly at random
            from the whole train split.

    Returns:
        The selected example, so it can be passed on to ``show_augment``.
    """
    train_split = dataset['train']
    if idx is None:
        # Draw from the actual split length rather than a fixed range, so small
        # splits do not raise IndexError and every row (including 0) can be picked.
        idx = random.randrange(len(train_split))
    print(idx)
    sample = train_split[idx]
    image = sample['image']
    bboxes = sample['objects']['bbox']

    # Draw on a copy so the image held by the sample stays untouched.
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    for bbox in bboxes:
        # Boxes are [x, y, w, h]; rectangle() expects the two opposite corners.
        draw.rectangle((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]), outline="green", width=3)

    plt.xticks([])
    plt.yticks([])
    plt.imshow(draw_image)
    plt.show()

    # Double-quoted f-string: reusing the same quote character inside the
    # replacement field is a SyntaxError before Python 3.12.
    print(f"sample: {len(sample['objects']['bbox_id'])}")
    return sample


def show_augment(sample):
    """Augment ``sample`` once and display the crop with its remaining boxes in green.

    The crop is produced by ``augment_image`` with a placeholder image id of 1.
    Prints the number of surviving objects followed by the annotation dict.
    """
    aug = augment_image(sample, image_id=1)
    image = aug['image']
    bboxes = aug['objects']['bbox']

    # Draw on a copy so the image stored in the augmented example stays untouched.
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    for bbox in bboxes:
        # Boxes are [x, y, w, h]; rectangle() expects the two opposite corners.
        draw.rectangle((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]), outline="green", width=3)

    plt.xticks([])
    plt.yticks([])
    plt.imshow(draw_image)
    plt.show()

    # Double-quoted f-string: reusing the same quote character inside the
    # replacement field is a SyntaxError before Python 3.12.
    print(f"augment: {len(aug['objects']['bbox_id'])}")
    print(aug['objects'])
    
# Show a random training example, then one random crop augmentation of that same example.
sample = show_sample()
show_augment(sample)