# Object detection

## Load the dataset

In [None]:
import os 
from picsellia import Client

import torch
from PIL import Image

from datasets import load_dataset
from picsellia.types.enums import AnnotationFileType, InferenceType
from transformers import AutoModelForObjectDetection, TrainingArguments, AutoImageProcessor
from transformers import Trainer
from transformers import pipeline, TrainerCallback

from utils.picsellia import get_experiment, download_data, evaluate_asset, log_metrics
from utils.vit import CocoDetection, get_category_mapping, run_evaluation, get_filenames_by_ids, write_metadata_file, \
    read_annotation_file, get_category_mapping, format_coco_annot_to_jsonlines_format, transform_aug_ann, \
    custom_train_test_eval_split, collate_fn, save_annotation_file_images, format_evaluation_results, \
    get_dataset_image_ids

In [None]:
# Connect to Picsellia and fetch the experiment used throughout this notebook.
# NOTE(review): the API token, organization name and experiment id are blank
# placeholders here; presumably they must be filled in before running.
api_token = ""
client = Client(api_token=api_token, organization_name="")
experiment = client.get_experiment_by_id('')

In [None]:
# Fetch the experiment's data. `dataset` is reused for annotations and labels,
# `data_dir` is the directory the following cells read from and write to.
dataset, data_dir = download_data(experiment=experiment)

In [None]:
# Read the dataset's annotation file, convert it with
# `format_coco_annot_to_jsonlines_format`, and write the result as
# `metadata.jsonl` inside `data_dir`.
annotations, annotation_file_path = read_annotation_file(
    dataset=dataset,
    target_path=data_dir,
)
formatted_coco = format_coco_annot_to_jsonlines_format(annotations=annotations)
metadata_path = os.path.join(data_dir, 'metadata.jsonl')
write_metadata_file(data=formatted_coco, output_path=metadata_path)

In [None]:
# Load the contents of `data_dir` with the `imagefolder` dataset builder.
loaded_dataset  = load_dataset("imagefolder", data_dir=data_dir)

In [None]:
# Split the loaded dataset with a test proportion of 0.15.
# NOTE(review): later cells index the result with "train", "test" and "eval";
# the exact split logic lives in `custom_train_test_eval_split`.
train_test_valid_dataset = custom_train_test_eval_split(loaded_dataset=loaded_dataset, test_prop=0.15)

In [None]:
# Display the split dataset as the cell output.
train_test_valid_dataset

In [None]:
# Category names in annotation-file order; a category's position in this list
# is the integer class id used by the model.
categories = [category['name'] for category in annotations['categories']]

# Index -> name and name -> index lookups handed to the model config.
id2label = dict(enumerate(categories))
label2id = {name: idx for idx, name in id2label.items()}

# Same mapping with string keys, logged to the experiment as its label map.
labelmap = {str(idx): name for idx, name in id2label.items()}
experiment.log("labelmap", labelmap, "labelmap", replace=True)

## Preprocess the data

In [None]:
# Pretrained checkpoint used for both the image processor and the model.
checkpoint = "facebook/detr-resnet-50"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [None]:
# Attach `transform_aug_ann` to the training split only; the other splits are left untouched.
train_test_valid_dataset["train"] = train_test_valid_dataset["train"].with_transform(transform_aug_ann)

In [None]:
# for sample in range(len(train_test_valid_dataset['train'])):
#     print(sample)
#     print(train_test_valid_dataset["train"][sample])

## In case there are images with degenerate boxes, remove them
# remove_idx = [5325]
# keep = [i for i in range(len(train_test_valid_dataset["train"])) if i not in remove_idx]
# train_test_valid_dataset["train"] = train_test_valid_dataset["train"].select(keep)

## Training the DETR model

In [None]:
# Load the pretrained detector and configure it with this dataset's label maps.
model = AutoModelForObjectDetection.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
    # Presumably needed because the number of classes here differs from the
    # checkpoint's original head — TODO confirm.
    ignore_mismatched_sizes=True,
)

In [None]:
# Checkpoints and the final model are written to the experiment's checkpoint directory.
# NOTE(review): `os.path.join` with a single argument adds no path component.
output_model_dir = os.path.join(experiment.checkpoint_dir)

In [None]:
# Training hyper-parameters passed to the Trainer below.
training_args = TrainingArguments(
    output_dir=output_model_dir,
    per_device_train_batch_size=8,
    num_train_epochs=30,
    fp16=True,  # NOTE(review): mixed precision — presumably requires a GPU; confirm on the target machine
    save_steps=200,  # write a checkpoint every 200 steps
    logging_steps=50,  # emit logs (and thus callback calls) every 50 steps
    lr_scheduler_type='constant',  # keep the learning rate fixed at `learning_rate`
    learning_rate=1e-5,
    save_total_limit=2,  # keep at most two checkpoints on disk
    # Presumably keeps the dataset columns the collate function needs — TODO confirm.
    remove_unused_columns=False,
    push_to_hub=False,  # the model stays local; nothing is uploaded to the Hub
)

In [None]:
class LogObjectDetectionMetricsCallback(TrainerCallback):
    """Trainer callback that forwards every logged value to `log_metrics`."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Send each `(name, value)` pair of `logs` to `log_metrics`.

        Args:
            args: Training arguments supplied by the Trainer (unused).
            state: Trainer state; only `is_local_process_zero` is read.
            control: Trainer control object (unused).
            logs: Mapping of metric name to value, or None.
            **kwargs: Extra keyword arguments supplied by the Trainer (unused).
        """
        # `logs` defaults to None, so guard before iterating; an empty dict
        # is also a no-op.
        if not logs or not state.is_local_process_zero:
            return
        for metric_name, value in logs.items():
            log_metrics(metric_name=metric_name, value=value)


In [None]:
# Build the Trainer on the transformed training split and start training.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_test_valid_dataset["train"],
    # The image processor is handed over through the `tokenizer` argument.
    # NOTE(review): newer transformers releases may prefer `processing_class` — confirm against the installed version.
    tokenizer=image_processor,
    # The callback is passed as a class, not an instance.
    callbacks=[LogObjectDetectionMetricsCallback]
)
trainer.train()


In [None]:
# Persist the trained model to the same directory later cells reload it from.
trainer.save_model(output_dir=output_model_dir)

## Evaluate

Object detection models are commonly evaluated with a set of <a href="https://cocodataset.org/#detection-eval">COCO-style metrics</a>.
You can use one of the existing metrics implementations, but here you'll use the one from `torchvision` to evaluate the final
model that you saved to the experiment's checkpoint directory.

To use the `torchvision` evaluator, you'll need to prepare a ground truth COCO dataset. The API to build a COCO dataset
requires the data to be stored in a certain format, so you'll need to save images and annotations to disk first. Just like
when you prepared your data for training, the annotations from `train_test_valid_dataset["test"]` need to be formatted. However, images
should stay as they are.

The evaluation step requires a bit of work, but it can be split into three major steps.
First, prepare the `train_test_valid_dataset["test"]` set: format the annotations and save the data to disk.

Next, prepare an instance of a `CocoDetection` class that can be used with `cocoevaluator`.
Finally, run the evaluation on that dataset and log the resulting metrics to the experiment.

In [None]:
# Reload the image processor from the saved model directory, export the test
# split through `save_annotation_file_images`, and wrap the export in a
# `CocoDetection` dataset for evaluation.
# NOTE(review): `path_output` / `path_anno` are presumably the image folder and
# the annotation file path — confirm in utils.vit.
im_processor = AutoImageProcessor.from_pretrained(output_model_dir)
path_output, path_anno = save_annotation_file_images(dataset=train_test_valid_dataset["test"], experiment=experiment, id2label=id2label)
test_ds_coco_format = CocoDetection(path_output, im_processor, path_anno)

In [None]:
# Reload the fine-tuned model from disk; this rebinds `model`.
model = AutoModelForObjectDetection.from_pretrained(output_model_dir)

In [None]:
# Evaluate on the COCO-formatted test set, convert the results with
# `format_evaluation_results`, and log them to the experiment as a table.
results = run_evaluation(test_ds_coco_format=test_ds_coco_format, im_processor=im_processor, model=model)
casted_results = format_evaluation_results(results=results)
experiment.log(name='evaluation metrics', type='table', data=casted_results)

## Inference

In [None]:
# for one image
image_path = "/home/ubuntu/dev/vision-transformers/grape-detector/data/SYH_2017-04-27_1291.jpg"
image = Image.open(image_path)

In [None]:
# Reload the processor and the fine-tuned model from the output directory.
image_processor = AutoImageProcessor.from_pretrained(output_model_dir)
model = AutoModelForObjectDetection.from_pretrained(output_model_dir)

# Forward pass and post-processing, both without gradient tracking.
with torch.no_grad():
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = model(**inputs)

    # `image.size` reversed, wrapped in a batch of one.
    target_sizes = torch.tensor([image.size[::-1]])
    post_processed = image_processor.post_process_object_detection(
        outputs,
        threshold=0.5,
        target_sizes=target_sizes,
    )
    results = post_processed[0]

# Print one line per detection kept by the 0.5 threshold.
detections = zip(results["scores"], results["labels"], results["boxes"])
for score, label, box in detections:
    box = [round(coordinate, 2) for coordinate in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )

In [None]:
# Reload processor and model again (same calls as the single-image cell above)
# and index the dataset's labels by name.
# NOTE(review): `dataset_labels` is not referenced by any cell visible here.
image_processor = AutoImageProcessor.from_pretrained(output_model_dir)
model = AutoModelForObjectDetection.from_pretrained(output_model_dir)
dataset_labels = {label.name: label for label in dataset.list_labels()}

In [None]:
# Collect the image ids of the "eval" split and map them to filenames using the annotations.
eval_image_ids = get_dataset_image_ids(train_test_valid_dataset, "eval")
id2filename_eval = get_filenames_by_ids(image_ids=eval_image_ids, annotations=annotations)

In [None]:
# Evaluate every file of the eval split, one asset at a time.
for file_path in id2filename_eval.values():
    evaluate_asset(file_path=file_path)

In [None]:
# Ask the experiment to compute its evaluation metrics for object detection.
experiment.compute_evaluations_metrics(inference_type=InferenceType.OBJECT_DETECTION)