In [None]:
from datasets import load_from_disk
import random
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import torch
from dataclasses import dataclass
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import numpy as np
import torch
from functools import partial

import albumentations as A
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection, TrainingArguments, Trainer
from transformers.image_transforms import center_to_corners_format

In [None]:
# Pick the accelerator. Always produce a torch.device (the previous expression yielded a
# torch.device on GPU but a plain 'cpu' string otherwise), so the type is consistent.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Load the DatasetDict previously saved with `save_to_disk` and display its splits.
dataset = load_from_disk(dataset_path='./dataset')
dataset  # last expression: rich display of the splits

In [None]:
# Class names come from the 'category' feature nested inside 'objects'; the index of a
# name in this list is the integer class id stored in the dataset.
labels = dataset['train'].features['objects'].feature['category'].names
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = dict(enumerate(labels))

for _mapping in (labels, label2id, id2label):
    print(_mapping)

In [None]:
def show_random(split='train'):
    """Draw the ground-truth boxes and class names of one random example and display it.

    Boxes are treated as COCO-style ``[x_min, y_min, width, height]``, the same
    format declared to albumentations via ``BboxParams(format='coco')``.

    Args:
        split: Name of the dataset split to sample from. Defaults to ``'train'``.
    """
    split_ds = dataset[split]
    # randrange excludes the upper bound; randint(0, len(split_ds)) could return
    # len(split_ds) and raise IndexError.
    i = random.randrange(len(split_ds))
    # Index once: every `split_ds[i]` access decodes the example (including the image) again.
    example = split_ds[i]
    image = example['image']
    categories = example['objects']['category']
    bboxes = example['objects']['bbox']

    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    font = ImageFont.load_default(size=40)

    for category, (x, y, w, h) in zip(categories, bboxes):
        draw.rectangle((x, y, x + w, y + h), outline="green", width=3)
        # Put the label just above the box, but never above the top edge of the image.
        text_position = (x, max(0, y - 25))
        draw.text(text_position, labels[category], fill="red", font=font)

    fig, ax = plt.subplots()
    ax.imshow(draw_image)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"{split}[{i}]")
    plt.show()

show_random()

In [None]:
# Hugging Face Hub checkpoint id, shared by the image processor and the model below.
ckpt = 'microsoft/conditional-detr-resnet-50'

In [None]:
# Preprocessing (resize/normalize + annotation formatting) that matches the checkpoint.
image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path=ckpt)
image_processor  # last expression: display the processor configuration

In [None]:
# Boxes are COCO-style [x_min, y_min, width, height]. The per-box class ids travel in the
# 'category' label field so albumentations keeps them aligned with the boxes.
_train_ops = [
    A.Perspective(p=0.1),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.HueSaturationValue(p=0.1),
]
train_augment_and_transform = A.Compose(
    _train_ops,
    bbox_params=A.BboxParams(
        format='coco',
        label_fields=['category'],
        clip=True,
        min_area=25,  # per albumentations, boxes smaller than this (in pixels) are dropped
    ),
)

# Evaluation applies no augmentation but still runs boxes through the same bbox handling.
validation_transform = A.Compose(
    [A.NoOp()],
    bbox_params=A.BboxParams(format='coco', label_fields=['category'], clip=True),
)

In [None]:
def format_image_annotations_as_coco(image_id, categories, areas, bboxes):
    """Package one image's objects as a COCO-style annotation dict.

    Args:
        image_id: Identifier copied into the result and into every annotation.
        categories: Class id per object.
        areas: Area per object.
        bboxes: Box per object; each is converted to a plain list.

    Returns:
        ``{'image_id': image_id, 'annotations': [...]}``. There is one annotation dict
        per object, with keys image_id, category_id, iscrowd (always 0), area and bbox.
        Extra items in the longer input sequences are ignored (``zip`` semantics).
    """
    per_object = [
        {
            'image_id': image_id,
            'category_id': category_id,
            'iscrowd': 0,
            'area': object_area,
            'bbox': list(box),
        }
        for category_id, object_area, box in zip(categories, areas, bboxes)
    ]
    return {'image_id': image_id, 'annotations': per_object}
    

def augment_and_transform_batch(examples, transform, image_processor, return_pixel_mask=False):
    """Augment a batch of examples and turn it into model inputs.

    Args:
        examples: Batch dict with 'image_id', 'image' (PIL images) and 'objects'.
            Each object entry has 'bbox' (COCO ``[x_min, y_min, width, height]``) and
            'category'.
        transform: albumentations pipeline, called as
            ``transform(image=..., bboxes=..., category=...)``.
        image_processor: Callable that converts images + COCO annotations into model inputs.
        return_pixel_mask: If False, 'pixel_mask' is removed from the processor output.

    Returns:
        The image processor's output for the batch.
    """
    images = []
    annotations = []
    for image_id, image, objects in zip(examples['image_id'], examples['image'], examples['objects']):
        image = np.array(image.convert('RGB'))

        output = transform(image=image, bboxes=objects['bbox'], category=objects['category'])
        images.append(output['image'])

        # The transform may drop boxes (clip / min_area) and change their size (e.g.
        # Perspective). The dataset's original 'area' values therefore no longer line up
        # with output['bboxes'], so recompute each area from the augmented COCO box.
        areas = [float(width * height) for _, _, width, height in output['bboxes']]

        formatted_annotations = format_image_annotations_as_coco(
            image_id, output['category'], areas, output['bboxes']
        )
        annotations.append(formatted_annotations)

    result = image_processor(images=images, annotations=annotations, return_tensors='pt')

    if not return_pixel_mask:
        result.pop('pixel_mask', None)

    return result

In [None]:
# Bind transform + processor so each function only needs the batch, as `with_transform` expects.
_processor_kwargs = {'image_processor': image_processor}
train_transform_batch = partial(
    augment_and_transform_batch, transform=train_augment_and_transform, **_processor_kwargs
)
validation_transform_batch = partial(
    augment_and_transform_batch, transform=validation_transform, **_processor_kwargs
)

# `with_transform` applies the function on the fly when rows are accessed.
# NOTE(review): after this cell, indexing dataset['train'] / dataset['test'] returns
# processed tensors, not the raw image/objects rows that show_random expects.
for _split, _batch_fn in (('train', train_transform_batch), ('test', validation_transform_batch)):
    dataset[_split] = dataset[_split].with_transform(_batch_fn)

In [None]:
def collate_fn(batch):
    """Merge per-example processor outputs into one training batch.

    'pixel_values' (and 'pixel_mask', when the first example has one) are stacked into
    tensors. 'labels' stays a list because each image has its own variable-length targets.
    """
    collated = {
        'pixel_values': torch.stack([sample['pixel_values'] for sample in batch]),
        'labels': [sample['labels'] for sample in batch],
    }
    if 'pixel_mask' in batch[0]:
        collated['pixel_mask'] = torch.stack([sample['pixel_mask'] for sample in batch])
    return collated

In [None]:
def convert_bbox_yolo_to_pascal(boxes, image_size):
    """Convert normalized center-format boxes to absolute corner-format boxes.

    Args:
        boxes: Tensor of boxes as ``(center_x, center_y, width, height)``, relative to the
            image size (the "YOLO" format the image processor produces for training).
        image_size: ``(height, width)`` of the image in pixels.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes in pixels (Pascal VOC format).
    """
    corners = center_to_corners_format(boxes)
    img_h, img_w = image_size
    pixel_scale = torch.tensor([[img_w, img_h, img_w, img_h]])
    return corners * pixel_scale

@dataclass()
class ModelOutput:
    """Light container for raw detection outputs.

    compute_metrics wraps stored eval predictions in this class and passes them to
    ``image_processor.post_process_object_detection``.
    """

    logits: torch.Tensor  # raw class logits
    pred_boxes: torch.Tensor  # predicted boxes (YOLO/center format per compute_metrics' comment)


@torch.no_grad()
def compute_metrics(evaluation_results, image_processor, threshold=0.0, id2label=None):
    """
    Compute mean average precision (mAP), mean average recall (mAR) and their variants
    for the object detection task.

    Args:
        evaluation_results (EvalPrediction): Predictions and targets from evaluation. Both
            are iterated batch by batch, so they must be lists of per-batch results
            (i.e. evaluation run with ``eval_do_concat_batches=False``).
        image_processor: Image processor whose ``post_process_object_detection`` turns raw
            logits / boxes into absolute ``(x_min, y_min, x_max, y_max)`` predictions.
        threshold (float, optional): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
        id2label (Optional[dict], optional): Mapping from class id to class name. Defaults to None.

    Returns:
        Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
    """

    predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

    # For metric computation we need to provide:
    #  - targets in a form of list of dictionaries with keys "boxes", "labels"
    #  - predictions in a form of list of dictionaries with keys "boxes", "scores", "labels"

    image_sizes = []
    post_processed_targets = []
    post_processed_predictions = []

    # Collect targets in the required format for metric computation
    for batch in targets:
        # collect image sizes, we will need them for predictions post processing
        batch_image_sizes = torch.tensor(np.array([x["orig_size"] for x in batch]))
        image_sizes.append(batch_image_sizes)
        # Target boxes are in YOLO (normalized center) format as used for training;
        # convert them to Pascal VOC format (x_min, y_min, x_max, y_max) in pixels.
        for image_target in batch:
            boxes = torch.tensor(image_target["boxes"])
            boxes = convert_bbox_yolo_to_pascal(boxes, image_target["orig_size"])
            # Named `class_labels` so it does not shadow the module-level `labels` list.
            class_labels = torch.tensor(image_target["class_labels"])
            post_processed_targets.append({"boxes": boxes, "labels": class_labels})

    # Collect predictions in the required format for metric computation.
    # The model produces boxes in YOLO format; image_processor converts them to Pascal VOC.
    for batch, target_sizes in zip(predictions, image_sizes):
        batch_logits, batch_boxes = batch[1], batch[2]
        output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
        post_processed_output = image_processor.post_process_object_detection(
            output, threshold=threshold, target_sizes=target_sizes
        )
        post_processed_predictions.extend(post_processed_output)

    # Compute metrics
    metric = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
    metric.update(post_processed_predictions, post_processed_targets)
    metrics = metric.compute()

    # Replace the per-class metric tensors with one scalar metric per class.
    # atleast_1d keeps them iterable even if a 0-d tensor is returned (e.g. a single class).
    classes = torch.atleast_1d(metrics.pop("classes"))
    map_per_class = torch.atleast_1d(metrics.pop("map_per_class"))
    mar_100_per_class = torch.atleast_1d(metrics.pop("mar_100_per_class"))
    for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
        class_id = class_id.item()
        class_name = id2label[class_id] if id2label is not None else class_id
        metrics[f"map_{class_name}"] = class_map
        metrics[f"mar_100_{class_name}"] = class_mar

    metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

    return metrics


# Freeze the processor, label names and confidence threshold so Trainer can call the
# metric function with only the EvalPrediction.
eval_compute_metrics_fn = partial(
    compute_metrics,
    threshold=0.0,
    id2label=id2label,
    image_processor=image_processor,
)

In [None]:
model = ConditionalDetrForObjectDetection.from_pretrained(
    ckpt,
    label2id=label2id,
    id2label=id2label,
    # Allow loading even when some checkpoint weight shapes differ from this config
    # (presumably the class head, since the label set is new); mismatched weights are re-initialised.
    ignore_mismatched_sizes=True,
)

In [None]:
training_arguments = TrainingArguments(
    output_dir='./model_ckpt',
    num_train_epochs=100,
    fp16=False,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    dataloader_num_workers=2,
    lr_scheduler_type='cosine',
    weight_decay=1e-4,
    max_grad_norm=0.01,
    gradient_accumulation_steps=4,
    # Track the best checkpoint by eval mAP and restore it when training ends. Without
    # load_best_model_at_end, the model left in memory is simply the last epoch's.
    metric_for_best_model='eval_map',
    greater_is_better=True,
    load_best_model_at_end=True,
    # load_best_model_at_end requires evaluation and saving to use the same schedule.
    eval_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    # The on-the-fly transform reads the raw 'image_id' / 'image' / 'objects' columns.
    remove_unused_columns=False,
    # compute_metrics iterates over per-batch predictions/targets, so keep batches separate.
    eval_do_concat_batches=False,
    push_to_hub=False,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_arguments,
    data_collator=collate_fn,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
    # `processing_class=`; verify against the installed version before switching.
    tokenizer=image_processor,
    compute_metrics=eval_compute_metrics_fn,
)

# Start a fresh run rather than resuming from a checkpoint in output_dir.
trainer.train(resume_from_checkpoint=False)