# EfficientDet *Marchantia polymorpha* Gemma Cup Detection

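Based on the tutorial [Training EfficientDet on custom data with PyTorch Lightning (using an EfficientNetv2 backbone)](https://medium.com/data-science-at-microsoft/training-efficientdet-on-custom-data-with-pytorch-lightning-using-an-efficientnetv2-backbone-1cdf3bd7921f).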

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import loaders as lds

## Constants

In [None]:
# Input locations, all relative to the notebook's working directory.
data_path = Path("..") / "data_in"
images_path = data_path / "images"
dataset_path = Path("..") / "data_in" / "datasets"

# Sanity check: one boolean per folder, True when that folder exists.
tuple(folder.is_dir() for folder in (data_path, images_path, dataset_path))


## Dataset

### Dataframe

In [None]:
import pandas as pd
from pathlib import Path

# Paths are declared again here, so this cell does not depend on the
# constants cell having been run.
data_path = Path("..") / "data_in"
dataset_path = Path("..") / "data_in" / "datasets"

# One CSV per split.
train = pd.read_csv(dataset_path / "train.csv")
val = pd.read_csv(dataset_path / "val.csv")
test = pd.read_csv(dataset_path / "test.csv")

# Display the (rows, columns) shape of each split.
tuple(frame.shape for frame in (train, val, test))


### Test Dataset

In [None]:
import matplotlib.pyplot as plt

# Image size handed to every transform factory in this notebook.
image_size = 512

# Dataset built on the *training* dataframe but with the test-time transform;
# it is only used for the visual checks below.
tst_ds = lds.GemmaDataset(
    train,
    transform=lds.get_test_image_transform(image_size=image_size),
    images_path=images_path,
    yxyx=True,
    return_id=True,
    bboxes=True,
)

# Show one randomly sampled training image as rendered by
# draw_image_with_boxes, without axes.
plt.imshow(
    tst_ds.draw_image_with_boxes(
        filename=train.sample(n=1).filename.to_list()[0],
    )
)
plt.tight_layout()
plt.axis("off")
plt.show()


### Test Transforms

In [None]:
# Pick one random training image and render it 12 times into a 3 x 4 grid.
# NOTE(review): presumably each call re-applies the dataset transform, so the
# grid shows the transform's variability - confirm in loaders.
file_name = train.sample(n=1).filename.to_list()[0]

lds.make_patches_grid(
    images=[
        tst_ds.draw_image_with_boxes(filename=file_name)
        for _ in range(3 * 4)
    ],
    figsize=(10, 7.5),
    col_count=4,
    row_count=3,
)

## Data Loader

In [None]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
import torch


class GemmaCupDataModule(LightningDataModule):
    """Lightning data module that serves gemma cup detection batches.

    Wraps a training and a validation annotation source in ``lds.GemmaDataset``
    and exposes them through data loaders that share :meth:`collate_fn`.

    Args:
        train_dataset_adaptor: Annotation source for training, forwarded as the
            first argument of ``lds.GemmaDataset`` (this notebook passes the
            ``train`` dataframe).
        validation_dataset_adaptor: Same as above, for validation.
        train_transforms: Transform handed to the training dataset.
        valid_transforms: Transform handed to the validation dataset.
        num_workers: ``num_workers`` value used by both loaders.
        batch_size: Batch size used by both loaders.

    Note:
        The transform defaults are evaluated once, when the class is defined,
        using the notebook-level ``image_size``; every instance that relies on
        the defaults shares the same transform objects.
    """

    def __init__(
        self,
        train_dataset_adaptor,
        validation_dataset_adaptor,
        train_transforms=lds.get_train_transform(image_size=image_size),
        valid_transforms=lds.get_valid_transform(image_size=image_size),
        num_workers=0,
        batch_size=8,
    ):
        # Initialise the Lightning base class before attaching our own state.
        super().__init__()
        self.train_ds = train_dataset_adaptor
        self.valid_ds = validation_dataset_adaptor
        self.train_tfms = train_transforms
        self.valid_tfms = valid_transforms
        self.num_workers = num_workers
        self.batch_size = batch_size

    @staticmethod
    def _build_dataset(source, transform) -> lds.GemmaDataset:
        """Create a ``GemmaDataset`` with the options shared by both splits.

        Images are read from the notebook-level ``images_path``.
        """
        return lds.GemmaDataset(
            source,
            transform=transform,
            images_path=images_path,
            bboxes=True,
            return_id=True,
            yxyx=True,
        )

    def _build_loader(self, dataset, *, shuffle: bool, drop_last: bool) -> DataLoader:
        """Create a ``DataLoader`` with the settings shared by both splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=True,
            drop_last=drop_last,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataset(self) -> lds.GemmaDataset:
        """Return a new training dataset using the training transform."""
        return self._build_dataset(self.train_ds, self.train_tfms)

    def val_dataset(self) -> lds.GemmaDataset:
        """Return a new validation dataset using the validation transform."""
        return self._build_dataset(self.valid_ds, self.valid_tfms)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader; an incomplete last batch is dropped."""
        return self._build_loader(self.train_dataset(), shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader, in dataset order.

        The last, possibly smaller, batch is kept so that every validation
        sample contributes to the logged validation losses.
        """
        return self._build_loader(self.val_dataset(), shuffle=False, drop_last=False)

    @staticmethod
    def collate_fn(batch):
        """Merge dataset items into one batch.

        Args:
            batch: Sequence of ``(image, target, image_id)`` items, where each
                ``target`` is a mapping with the keys ``"bboxes"``,
                ``"labels"``, ``"img_size"`` and ``"img_scale"``.

        Returns:
            A ``(images, annotations, targets, image_ids)`` tuple: ``images``
            is the stacked float image tensor, ``annotations`` holds the
            per-image box and label tensors (as lists) plus the batched
            ``img_size`` / ``img_scale`` float tensors, and ``targets`` /
            ``image_ids`` are the untouched per-item tuples.
        """
        images, targets, image_ids = tuple(zip(*batch))
        images = torch.stack(images).float()

        annotations = {
            # Boxes and labels stay as lists: one tensor per image.
            "bbox": [target["bboxes"].float() for target in targets],
            "cls": [target["labels"].float() for target in targets],
            "img_size": torch.tensor(
                [target["img_size"] for target in targets]
            ).float(),
            "img_scale": torch.tensor(
                [target["img_scale"] for target in targets]
            ).float(),
        }

        return images, annotations, targets, image_ids


## Create model

In [None]:
from effdet.config.model_config import efficientdet_model_param_dict
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
from effdet.efficientdet import HeadNet
from effdet.config.model_config import efficientdet_model_param_dict

In [None]:
# Number of EfficientDet configurations currently registered in effdet.
print(f'number of configs: {len(efficientdet_model_param_dict)}')

# Show every third configuration name to keep the output short.
list(efficientdet_model_param_dict)[::3]

In [None]:
import timm

In [None]:
# List the timm model names matching the wildcard pattern.
timm.list_models(filter='s*')

In [None]:
def create_model(num_classes=1, image_size=512, architecture="tf_efficientnetv2_s"):
    """Build an EfficientDet training bench for the given architecture.

    Args:
        num_classes: Number of object classes written into the config and used
            for the classification head.
        image_size: Side length of the square input; stored in the config as
            ``(image_size, image_size)``.
        architecture: Name of an effdet configuration. A name effdet does not
            know yet is registered as a new configuration that uses the name
            as its backbone.

    Returns:
        ``DetBenchTrain`` wrapping the configured ``EfficientDet`` network.
    """
    # Register only unknown names. Overwriting unconditionally would replace a
    # configuration effdet already defines with one whose backbone_name is the
    # detector name itself rather than that configuration's own backbone.
    if architecture not in efficientdet_model_param_dict:
        efficientdet_model_param_dict[architecture] = dict(
            name=architecture,
            backbone_name=architecture,
            backbone_args=dict(drop_path_rate=0.2),
            num_classes=num_classes,
            url="",
        )

    config = get_efficientdet_config(architecture)
    # These overrides apply on every call, whether or not the name was new.
    config.update({"num_classes": num_classes})
    config.update({"image_size": (image_size, image_size)})

    print(config)

    net = EfficientDet(config, pretrained_backbone=True)
    # Replace the classification head with one sized for num_classes.
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
    )
    return DetBenchTrain(net, config)


### Lightning module

In [None]:
import torch
from pytorch_lightning import LightningModule


class GemmaCupEfficientDetModel(LightningModule):
    """Lightning module wrapping the EfficientDet bench from ``create_model``.

    Args:
        num_classes: Number of object classes passed to ``create_model``.
        img_size: Input image size passed to ``create_model``.
        prediction_confidence_threshold: Stored on the instance; not used by
            the methods defined here.
        learning_rate: Learning rate of the AdamW optimizer.
        wbf_iou_threshold: Stored on the instance; not used by the methods
            defined here.
        inference_transforms: Stored on the instance as ``inference_tfms``.
        model_architecture: Architecture name passed to ``create_model``.

    Note:
        NOTE(review): the ``inference_transforms`` default is evaluated once at
        class definition with the notebook-level ``image_size``, not with the
        ``img_size`` argument - pass it explicitly when the two differ.
    """

    def __init__(
        self,
        num_classes=1,
        img_size=512,
        prediction_confidence_threshold=0.2,
        learning_rate=0.0002,
        wbf_iou_threshold=0.44,
        inference_transforms=lds.get_valid_transform(image_size=image_size),
        model_architecture="tf_efficientnetv2_l",
    ):
        super().__init__()
        self.img_size = img_size
        self.model = create_model(
            num_classes, img_size, architecture=model_architecture
        )
        self.prediction_confidence_threshold = prediction_confidence_threshold
        self.lr = learning_rate
        self.wbf_iou_threshold = wbf_iou_threshold
        self.inference_tfms = inference_transforms

    def forward(self, images, targets):
        """Run the wrapped bench on a batch of images and their targets."""
        return self.model(images, targets)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the model parameters."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)

    def _log_losses(self, stage, losses, *, sync_dist=False):
        """Log total, class and box loss as ``{stage}_loss`` and friends.

        Values are detached before logging so that logging never holds on to
        the autograd graph.
        """
        for key in ("loss", "class_loss", "box_loss"):
            self.log(
                f"{stage}_{key}",
                losses[key].detach(),
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=sync_dist,
            )

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the total loss."""
        images, annotations, _, _ = batch

        losses = self.model(images, annotations)
        self._log_losses("train", losses)

        # The returned loss keeps its graph; Lightning backpropagates it.
        return losses["loss"]

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses and collect the detections.

        Returns:
            Dict with the total ``"loss"`` and ``"batch_predictions"``, the
            latter holding the model detections, the raw targets and the image
            ids of the batch.
        """
        images, annotations, targets, image_ids = batch
        outputs = self.model(images, annotations)

        batch_predictions = {
            "predictions": outputs["detections"],
            "targets": targets,
            "image_ids": image_ids,
        }

        self._log_losses("valid", outputs, sync_dist=True)

        return {"loss": outputs["loss"], "batch_predictions": batch_predictions}


## Train

In [None]:
# Data module over the train/val dataframes, with a small batch size and
# num_workers=0.
dm = GemmaCupDataModule(
    batch_size=2,
    num_workers=0,
    train_dataset_adaptor=train,
    validation_dataset_adaptor=val,
)


In [None]:
# Single-class detector at the notebook-wide image size.
model = GemmaCupEfficientDetModel(
    img_size=image_size,
    num_classes=1,
)

In [None]:
# Show the module's repr as the cell output.
model

In [None]:
from pytorch_lightning import Trainer

# Short GPU run: five epochs, preceded by one sanity-check validation step.
trainer = Trainer(
    num_sanity_val_steps=1,
    max_epochs=5,
    accelerator="gpu",
)


In [None]:
# Train the model, taking both loaders from the data module.
trainer.fit(model, datamodule=dm)