In [None]:
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector
from flash.image.detection.output import FiftyOneDetectionLabelsOutput
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from PIL import Image
from glob import glob

import torch
import io, base64
import numpy as np
import flash
import json
import utils

In [None]:
# Record the library versions this notebook was run with.
print(f"flash version: {flash.__version__}")
print(f"torch version: {torch.__version__}")

In [None]:
# Folder holding the raw dataset; read by utils.create_df_from_dir below.
data_path = "./monkey-opencv"
# Output folders for the train / test splits written by utils.create_dataset;
# they are later used as the image folders of the COCO datamodule.
train_dir = "./train"
test_dir = "./test"

In [None]:
# Index the raw dataset under data_path with the project helper.
# NOTE(review): presumably returns a pandas DataFrame with one row per sample;
# the exact columns are defined in utils.py and not visible here — confirm.
df = utils.create_df_from_dir(data_path)

In [None]:
# Hold out 15% of the rows for testing; the fixed random_state keeps the
# split identical across runs.
df_train, df_test = train_test_split(df, test_size=0.15, random_state=42)

In [None]:
# Materialize each split on disk in its own folder via the project helper.
# NOTE(review): later cells glob these folders for *.json and *.jpg files, so
# the helper presumably writes both per sample — confirm in utils.py.
utils.create_dataset(df_train, output_path=train_dir)
utils.create_dataset(df_test, output_path=test_dir)

## Create COCO annotations

In [None]:
# Destination files for the COCO-style annotation JSON of each split
# (train first, then test); both are consumed by ObjectDetectionData.from_coco.
coco_annotation_path = "coco_monkey_annotation.json"
test_coco_annotation_path = "test_coco_monkey_annotation.json"

In [None]:
# Collect the JSON files found in the training folder. glob() makes no
# ordering guarantee, so sort to hand the helper a stable file order.
paths = sorted(glob("{}/*.json".format(train_dir)))
coco_data_dict = utils.create_coco_data_dict(paths)

# Write through a context manager so the file is flushed and closed before
# anything reads it back; the previous inline open() was never closed explicitly.
with open(coco_annotation_path, "w") as annotation_file:
    json.dump(coco_data_dict, annotation_file, indent=2)

In [None]:
# Collect the JSON files found in the test folder. glob() makes no ordering
# guarantee, so sort to hand the helper a stable file order.
paths = sorted(glob("{}/*.json".format(test_dir)))
test_coco_data_dict = utils.create_coco_data_dict(paths)

# Write through a context manager so the file is flushed and closed before
# anything reads it back; the previous inline open() was never closed explicitly.
with open(test_coco_annotation_path, "w") as annotation_file:
    json.dump(test_coco_data_dict, annotation_file, indent=2)

In [None]:
# Training parameters

In [None]:
# Data / optimisation settings shared by the datamodules, model and trainer.
image_size = 256        # passed as transform_kwargs["image_size"] and to the model
batch_size = 16         # batch size of the training datamodule
max_epochs = 20         # epoch budget given to the trainer
learning_rate = 0.005   # learning rate given to the detector

# Detector architecture; both values also appear in the checkpoint file name.
model_head = "efficientdet"
model_backbone = "d0"

In [None]:
# Build the Flash datamodule from the COCO annotation files written earlier:
# train images/annotations come from train_dir, test ones from test_dir.
datamodule = ObjectDetectionData.from_coco(
    train_folder=train_dir,
    train_ann_file=coco_annotation_path,
    test_folder=test_dir,
    test_ann_file=test_coco_annotation_path,
    batch_size=batch_size,
    transform_kwargs={"image_size": image_size},
    # NOTE(review): presumably carves 10% of the training data out as a
    # validation set — confirm against the Flash DataModule docs.
    val_split=0.1,
)

## Train

In [None]:
# Alternative detector configuration (RetinaNet head, ResNet-18 FPN backbone).
# It is wrapped in a string literal, so nothing in it is executed; the cell
# merely evaluates to that string. The active model is built in the next cell.
"""
model = ObjectDetector(
    head="retinanet", 
    backbone="resnet18_fpn", 
    num_classes=datamodule.num_classes, 
    image_size=image_size,
    output=FiftyOneDetectionLabelsOutput(return_filepath=True),
    learning_rate=0.0001,
    pretrained=True
)
"""

In [None]:
# Serialize predictions as FiftyOne detection labels.
# NOTE(review): return_filepath=True presumably keeps each image's path next to
# its detections, which the FiftyOne visualization needs — confirm.
detection_output = FiftyOneDetectionLabelsOutput(return_filepath=True)

# Detector assembled from the head / backbone chosen in the parameter cell;
# the class count is taken from the datamodule.
model = ObjectDetector(
    head=model_head,
    backbone=model_backbone,
    num_classes=datamodule.num_classes,
    image_size=image_size,
    output=detection_output,
    learning_rate=learning_rate,
)

In [None]:
# Train on one GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU instead of raising at
# trainer construction. Behaviour on GPU machines is unchanged (gpus=1).
num_gpus = 1 if torch.cuda.is_available() else 0
trainer = flash.Trainer(max_epochs=max_epochs, gpus=num_gpus)

# Fine-tune the detector on the datamodule.
# NOTE(review): "no_freeze" presumably trains every layer without freezing the
# backbone — confirm against the Flash finetuning docs.
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

In [None]:
# Encode head, backbone, image size and epoch count in the checkpoint name so
# checkpoints produced with different settings stay distinguishable.
check_point_path = f"finetuned_{model_head}_{model_backbone}_{image_size}_{max_epochs}.ckpt"

In [None]:
# Save the trainer's current checkpoint to check_point_path.
trainer.save_checkpoint(check_point_path)

In [None]:
# Evaluate the fine-tuned model on the test data registered in the datamodule
# (test_folder / test_ann_file).
trainer.test(model, datamodule=datamodule)

## Prediction

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Use every .jpg file in the test folder as prediction input.
predict_files = glob(f"{test_dir}/*.jpg")
# Show how many files were found as the cell's output.
len(predict_files)

In [None]:
# Wrap the prediction files in a predict-only datamodule: one image per batch,
# with the same image_size transform setting as the training datamodule.
predict_dataset = ObjectDetectionData.from_files(
    predict_files=predict_files,
    batch_size=1,
    transform_kwargs={"image_size": image_size},
)

In [None]:
# Run inference over the prediction files with the fine-tuned model.
# NOTE(review): the result appears to be nested (one list per batch); it is
# flattened before being handed to the FiftyOne visualization.
predictions = trainer.predict(model, datamodule=predict_dataset)

## Visualize with fiftyone

- Installation guide: https://voxel51.com/docs/fiftyone/getting_started/install.html
- Install FiftyOne, then restart the Jupyter kernel so the new package is picked up:
```sh
pip install fiftyone
```

In [None]:
from flash.core.integrations.fiftyone import visualize
from itertools import chain

In [None]:
# Flatten the nested prediction results into one flat list for visualize().
_predictions = list(chain.from_iterable(predictions))

In [None]:
# Open the FiftyOne visualization for the flattened predictions and keep the
# returned object so the session can be referenced afterwards.
session = visualize(_predictions)