## Training configuration

In [None]:
import os

# Every path is anchored to the working directory the notebook was started in.
cwd = os.getcwd()

# Dataset splits (Pascal VOC layout; see the dataset cell).
TRAIN_DATASET_PATH = os.path.join(cwd, 'dataset', 'train')
VALID_DATASET_PATH = os.path.join(cwd, 'dataset', 'valid')
TEST_DATASET_PATH = os.path.join(cwd, 'dataset', 'test')

# Output locations.
MODEL_PATH = os.path.join(cwd, 'model')
RESULT_PATH = os.path.join(cwd, 'result')

# Create the output directories from the same constants used later, so the
# checked/created path can never drift from MODEL_PATH / RESULT_PATH.
# exist_ok=True makes re-running this cell safe and avoids the
# exists()-then-mkdir() race.
os.makedirs(MODEL_PATH, exist_ok=True)
os.makedirs(RESULT_PATH, exist_ok=True)

# Model Maker spec name and exported file name.
MODEL = 'efficientdet_lite0'
MODEL_NAME = 'fish.tflite'

# Label names, in order; passed as the label map to the data loaders.
CLASSES = ['fish', 'jellyfish', 'penguin', 'shark', 'puffin', 'stingray', 'starfish']

# Training hyperparameters.
EPOCHS = 20
BATCH_SIZE = 4

In [None]:
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

## Create the training and validation datasets

In [None]:
def _load_voc_split(split_dir):
    """Build a Model Maker object-detection DataLoader for one dataset split.

    The same directory is passed as both the image directory and the
    Pascal VOC annotation directory, so images and their annotation files
    are expected to live side by side in ``split_dir``. Label names come
    from CLASSES (configuration cell).
    """
    return object_detector.DataLoader.from_pascal_voc(split_dir, split_dir, CLASSES)


train_data = _load_voc_split(TRAIN_DATASET_PATH)
val_data = _load_voc_split(VALID_DATASET_PATH)

## Build, train, evaluate, and export the model

In [None]:
# Look up the Model Maker spec registered under the name held in MODEL.
architecture_name = MODEL
spec = model_spec.get(architecture_name)

In [None]:
# Training options for object_detector.create. train_whole_model=True —
# per the Model Maker docs — trains the backbone together with the head
# rather than the head alone. val_data is supplied as validation data.
training_options = dict(
    model_spec=spec,
    validation_data=val_data,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    train_whole_model=True,
)
model = object_detector.create(train_data, **training_options)

In [None]:
# Metrics on the validation split. val_data was also passed to
# object_detector.create() as validation_data, so it is not a fully
# held-out set.
val_metrics = model.evaluate(val_data)

# Also evaluate on the held-out test split (TEST_DATASET_PATH) when it exists;
# it is loaded the same way as the training/validation splits.
test_metrics = None
if os.path.isdir(TEST_DATASET_PATH):
    test_data = object_detector.DataLoader.from_pascal_voc(
        TEST_DATASET_PATH,
        TEST_DATASET_PATH,
        CLASSES
    )
    test_metrics = model.evaluate(test_data)

# Cell output: both metric sets, labelled (test is None if the split is absent).
{'validation': val_metrics, 'test': test_metrics}

In [None]:
# Export the trained detector as a TFLite file named MODEL_NAME into
# MODEL_PATH; all other export options are left at the library defaults.
model.export(
    tflite_filename=MODEL_NAME,
    export_dir=MODEL_PATH,
)