# Gemma cup detection V2

Based on the TorchVision object detection fine-tuning tutorial: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

## Import

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import ipywidgets as widgets
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import HBox

import loaders as lds

%matplotlib inline

## Define Constants

### Paths

In [None]:
# Input folders, all resolved relative to the parent of the working directory.
data_path = Path("..") / "data_in"
images_path = data_path / "images"
dataset_path = data_path / "datasets"

# Displayed as a tuple of booleans: does each folder exist?
tuple(folder.is_dir() for folder in (data_path, images_path, dataset_path))


### Configuration

In [None]:
# Value handed as `image_size` to every `lds` transform factory in this notebook.
image_size = 1024
# Batch size of the module-level train/validation/test DataLoaders.
batch_size = 30

In [None]:
import torch

def get_device(force=None):
    """Return the name of the torch device to run on.

    Args:
        force: Device name returned unconditionally (e.g. ``"cpu"``). When
            ``None`` the device is auto-detected.

    Returns:
        ``force`` when it is not ``None``; otherwise ``"mps"``, ``"cuda"`` or
        ``"cpu"``, in that order of preference.
    """
    if force is not None:
        return force
    # `is_built()` only says the backend was compiled into this torch binary;
    # `is_available()` also checks that it can be used on this machine.
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

# Show which device name is auto-detected when nothing is forced.
get_device()

## Build Datasets

### Load CSVs

In [None]:
import pandas as pd

# Read the three split CSV files from the datasets folder.
train = pd.read_csv(dataset_path / "train.csv")
val = pd.read_csv(dataset_path / "val.csv")
test = pd.read_csv(dataset_path / "test.csv")

# Displayed as the (rows, columns) shape of each split.
tuple(frame.shape for frame in (train, val, test))


### Build Datasets

In [None]:
# One GemmaDataset per split; all read from the same image folder and each
# receives its own test-image transform built for `image_size`.
train_dataset, val_dataset, test_dataset = (
    lds.GemmaDataset(
        csv=split_frame,
        images_path=images_path,
        transform=lds.get_test_image_transform(image_size=image_size),
    )
    for split_frame in (train, val, test)
)


### Sanity-check the Training Dataset

In [None]:
# Draw the file name of one randomly sampled training row.
file_name = train.sample(n=1)["filename"].to_list()[0]

# Call draw_image_with_boxes twelve times for that file and show the results
# on a 3 x 4 grid.
rendered_images = [
    train_dataset.draw_image_with_boxes(filename=file_name) for _ in range(12)
]
lds.make_patches_grid(
    images=rendered_images,
    row_count=3,
    col_count=4,
    figsize=(10, 7.5),
)


## Create Data Loaders

In [None]:
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Regroup a batch of samples into one tuple per sample field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    """
    per_field = zip(*batch)
    return tuple(per_field)


# Ten randomly sampled training rows wrapped as a dataset.
ds_tst = lds.GemmaDataset(
    csv=train.sample(n=10),
    images_path=images_path,
    transform=lds.get_train_transform(image_size=image_size),
)

# Training loader: training transform, reshuffled on every pass so the
# optimiser does not see the samples in the same fixed order each epoch.
train_data_loader = DataLoader(
    lds.GemmaDataset(
        csv=train,
        images_path=images_path,
        transform=lds.get_train_transform(image_size=image_size),
    ),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Evaluation loaders: fixed order and the validation transform, so the
# evaluation data is not passed through the training transform.
valid_data_loader = DataLoader(
    lds.GemmaDataset(
        csv=val,
        images_path=images_path,
        transform=lds.get_valid_transform(image_size=image_size),
    ),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

test_data_loader = DataLoader(
    lds.GemmaDataset(
        csv=test,
        images_path=images_path,
        transform=lds.get_valid_transform(image_size=image_size),
    ),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)


In [None]:
# Fetch the first sample directly from the dataset behind the loader
# (index access on the dataset, so no batching and no collate_fn).
train_data_loader.dataset[0]

### Sample

In [None]:
# Pull the first batch from the training loader.
images, targets = next(iter(train_data_loader))
# Materialise the images as a list and shallow-copy every target dict.
images = list(images)
targets = [dict(target) for target in targets]

In [None]:
import numpy as np
# Boxes of the first target as an int32 NumPy array.
first_boxes = targets[0]['boxes']
boxes = first_boxes.cpu().numpy().astype(np.int32)
# First image with its leading axis moved last (permute(1, 2, 0)), as NumPy.
first_image = images[0]
sample = first_image.permute(1, 2, 0).cpu().numpy()
boxes

## Train

In [None]:
from torch import nn

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision

import albumentations as A

import pytorch_lightning as pl


### Build model

In [None]:
def collate_fn(batch):
    """Collate detection samples into a stacked image tensor and target dicts.

    Args:
        batch: Sequence of ``(image, target)`` pairs. The images must be
            tensors of identical shape (``torch.stack`` requires it); every
            target is a dict with at least ``"boxes"`` and ``"labels"``
            tensors.

    Returns:
        ``(images, targets)`` where ``images`` is a single float tensor with a
        new leading batch dimension and ``targets`` is a tuple of per-image
        dicts whose ``"boxes"`` are float and whose ``"labels"`` are int64.
        Any other entries of a target are passed through untouched.
    """
    images, targets = tuple(zip(*batch))
    images = torch.stack(images).float()

    # The casts used to be computed and then thrown away; apply them to the
    # returned targets instead. Labels are cast to int64 rather than float
    # because torchvision detection models expect integer class labels.
    targets = tuple(
        {
            **target,
            "boxes": target["boxes"].float(),
            "labels": target["labels"].to(torch.int64),
        }
        for target in targets
    )

    return images, targets

In [None]:
class GemmaCupDetector(pl.LightningModule):
    """Faster R-CNN (ResNet-50 FPN) cup detector wrapped as a Lightning module.

    The module builds its own train/validation/test dataloaders from the
    dataframes it is given and logs the mean of the torchvision loss
    components as ``train_loss`` / ``val_loss`` / ``test_loss``.

    Args:
        batch_size: Batch size of every dataloader built by the module.
        learning_rate: Learning rate handed to the Adam optimizer.
        max_epochs: Stored on the module; read back when building the Trainer.
        train_data: Dataframe describing the training split.
        val_data: Dataframe describing the validation split.
        test_data: Dataframe describing the test split.
        train_augmentations: Transform applied to training samples.
        val_augmentations: Transform applied to validation and test samples.
        num_workers: Worker count of every dataloader.
        accumulate_grad_batches: Stored on the module; read back when building
            the Trainer.
        selected_device: Stored on the module; not used by the module itself.
    """

    def __init__(
        self,
        batch_size: int,
        learning_rate: float,
        max_epochs: int,
        train_data: pd.DataFrame,
        val_data: pd.DataFrame,
        test_data: pd.DataFrame,
        train_augmentations: A.Compose,
        val_augmentations: A.Compose,
        num_workers: int = 0,
        accumulate_grad_batches: int = 3,
        selected_device: str = None,
    ):
        super().__init__()

        # Hyperparameters
        self.batch_size = batch_size
        self.selected_device = selected_device
        self.learning_rate = learning_rate
        self.num_workers = num_workers
        self.max_epochs = max_epochs
        self.accumulate_grad_batches = accumulate_grad_batches

        # dataframes
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data

        # albumentations
        self.train_augmentations = train_augmentations
        self.val_augmentations = val_augmentations

        # Model
        self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True
        )
        num_classes = 2  # 1 class (cup) + background
        # get number of input features for the classifier
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        self.save_hyperparameters()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def _build_dataloader(self, data, transform, shuffle=False):
        """Wrap one dataframe in a GemmaDataset and a DataLoader.

        Images are read from the notebook-level ``images_path``; batches are
        assembled with the notebook-level ``collate_fn``.
        """
        return DataLoader(
            lds.GemmaDataset(
                csv=data,
                images_path=images_path,
                transform=transform,
            ),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled, augmented loader over the training split."""
        return self._build_dataloader(
            self.train_data, self.train_augmentations, shuffle=True
        )

    def val_dataloader(self):
        """Return the loader over the validation split."""
        # Previously read self.train_data, so "validation" ran on training data.
        return self._build_dataloader(self.val_data, self.val_augmentations)

    def test_dataloader(self):
        """Return the loader over the test split."""
        # Previously read self.train_data, so "testing" ran on training data.
        return self._build_dataloader(self.test_data, self.val_augmentations)

    def forward(self, x):
        """Run the wrapped detection model on ``x``."""
        return self.model(x)

    def step_(self, batch, batch_index, loss_type):
        """Compute the detection losses for one batch and log their mean.

        Args:
            batch: ``(images, targets)`` pair as produced by ``collate_fn``.
            batch_index: Index of the batch; not used.
            loss_type: Prefix of the logged metric (``"{loss_type}_loss"``).

        Returns:
            Sum of the loss components, still attached to the autograd graph
            so it can be back-propagated.
        """
        x, y = batch
        # torchvision detection models return the loss dict only in training
        # mode; in eval mode they return predictions. Lightning puts the
        # module in eval mode for validation/test, so switch to training mode
        # for this forward pass and restore the previous mode afterwards.
        # NOTE(review): assumes the backbone's batch-norm layers are frozen so
        # that this does not update running statistics — confirm.
        was_training = self.model.training
        self.model.train()
        try:
            loss_dict = self.model(x, y)
        finally:
            self.model.train(was_training)
        # torch.stack keeps the autograd graph; torch.tensor([...]) copied the
        # values into a fresh tensor, so nothing could be back-propagated.
        losses = torch.stack(list(loss_dict.values()))
        self.log_dict({f"{loss_type}_loss": losses.mean()})
        return losses.sum()

    def training_step(self, batch, batch_idx):
        """Lightning hook: loss of one training batch, logged as ``train_loss``."""
        return self.step_(batch=batch, batch_index=batch_idx, loss_type="train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: loss of one validation batch, logged as ``val_loss``."""
        return self.step_(batch=batch, batch_index=batch_idx, loss_type="val")

    def test_step(self, batch, batch_idx):
        """Lightning hook: loss of one test batch, logged as ``test_loss``."""
        return self.step_(batch=batch, batch_index=batch_idx, loss_type="test")


In [None]:
# Lightning module used for the smoke test and the Trainer below.
model = GemmaCupDetector(
    batch_size=1,
    # Was 1e5 (one hundred thousand), which makes Adam diverge on the first
    # update; a small fine-tuning rate is what is meant here.
    learning_rate=1e-5,
    max_epochs=1,
    train_data=train,
    val_data=val,
    test_data=test,
    train_augmentations=lds.get_train_transform(image_size=image_size),
    val_augmentations=lds.get_valid_transform(image_size=image_size),
    num_workers=1,
    accumulate_grad_batches=1,
    selected_device=get_device(),
)


# Stand-alone loader that yields single, shuffled training samples built with
# the training transform.
smoke_dataset = lds.GemmaDataset(
    csv=train,
    images_path=images_path,
    transform=lds.get_train_transform(image_size=image_size),
)
dl_tst = DataLoader(
    smoke_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
    pin_memory=True,
)

# Push one batch through the shared step function, outside of any Trainer,
# with an empty metric prefix.
first_batch = next(iter(dl_tst))
model.step_(first_batch, 0, "")


In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks import ModelCheckpoint


# Callbacks: progress display, early stopping and checkpointing (both keyed
# on the logged "val_loss"), plus device statistics.
trainer_callbacks = [
    RichProgressBar(),
    EarlyStopping(monitor="val_loss", mode="min", patience=15, min_delta=0.0005),
    DeviceStatsMonitor(),
    ModelCheckpoint(
        save_top_k=3,
        monitor="val_loss",
        auto_insert_metric_name=True,
        filename="{epoch}-{step}-{train_loss}-{val_loss}",
    ),
]

# Epoch count and gradient accumulation are read back from the module so the
# values given at model construction stay the single source of truth.
trainer = Trainer(
    accelerator="cpu",
    max_epochs=model.max_epochs,
    log_every_n_steps=5,
    callbacks=trainer_callbacks,
    accumulate_grad_batches=model.accumulate_grad_batches,
    # auto_scale_batch_size="binsearch",
    # Debug
    # fast_dev_run=True,
    # overfit_batches=10,
    # detect_anomaly=True,
)



In [None]:
# Train the Lightning module; no loaders are passed here, the module provides
# them through its train_dataloader()/val_dataloader() methods.
trainer.fit(model)

In [None]:
# Start over with a plain torchvision Faster R-CNN. This rebinds `model`,
# which until now referred to the GemmaCupDetector Lightning module.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

In [None]:
# Input feature count of the classification layer in the current box predictor.
model.roi_heads.box_predictor.cls_score.in_features

In [None]:
num_classes = 2  # 1 class (cup) + background

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# replace the pre-trained head with a new one sized for `num_classes` outputs
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
import engine

# Force CPU for this run and move the model there.
device = get_device("cpu")
model.to(device)

# Optimise only the parameters that still require gradients.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    trainable_parameters,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Learning rate is multiplied by 0.1 every 3 epochs.
# NOTE(review): with 200 epochs this schedule shrinks the rate to practically
# zero after the first few dozen epochs — confirm that this is intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
# lr_scheduler = None

num_epochs = 200

# Per epoch: one training pass, one scheduler step, one evaluation on the
# validation loader.
for epoch in range(num_epochs):
    engine.train_one_epoch(
        model=model,
        optimizer=optimizer,
        data_loader=train_data_loader,
        device=device,
        epoch=epoch,
        print_freq=100,
    )
    lr_scheduler.step()
    engine.evaluate(model, valid_data_loader, device=device)

## Save state dict

In [None]:
# `os` and `datetime` are not imported anywhere else in this notebook.
import os
from datetime import datetime

# Timestamped target for the weights file, e.g.
# ../models/20240131-154500state_dict.pth
state_output_path = os.path.join(
    "..", "models", datetime.now().strftime("%Y%m%d-%H%M%S") + "state_dict.pth"
)

In [None]:
import os

# torch.save does not create missing folders, so make sure the target
# directory exists before writing the weights.
os.makedirs(os.path.dirname(state_output_path) or ".", exist_ok=True)
torch.save(model.state_dict(), state_output_path)

## Save model

In [None]:
# `os` and `datetime` are not imported anywhere else in this notebook.
import os
from datetime import datetime

# Timestamped target for the full model file, e.g.
# ../models/20240131-154500model.pth
model_output_path = os.path.join(
    "..", "models", datetime.now().strftime("%Y%m%d-%H%M%S") + "model.pth"
)

In [None]:
import os

# torch.save does not create missing folders, so make sure the target
# directory exists before writing the model object.
os.makedirs(os.path.dirname(model_output_path) or ".", exist_ok=True)
torch.save(model, model_output_path)