# Gemma cup detection V2

## Imports

In [None]:
import os
import glob
import sys
from datetime import datetime

import pandas as pd
import numpy as np
import cv2

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from engine import train_one_epoch, evaluate
import transforms as T

from skimage import io, transform
from skimage.color import rgb2gray

import ipywidgets as widgets
from IPython.display import Image as IpImage
from IPython.display import display
from ipywidgets import Button, HBox, VBox
from PIL import Image as PilImage

from loaders import (
    GemmaDataset, 
    get_train_transform, 
    get_valid_transform, 
    Averager, 
    format_prediction_string, 
    show_predictions,
    predict_loader,
)

## Define Constants

In [None]:
# Input locations: annotation CSVs live in ../data_in, images in ../data_in/images.
data_path = os.path.join("..", "data_in")
images_path = os.path.join(data_path, "images")

test_index = 1  # sample index used by the sanity-check cells below

## Datasets

### Tests

#### Build test dataset

In [None]:
# Dataset over the final box annotations, used by the inspection cells below.
final_boxes_df = pd.read_csv(os.path.join(data_path, "boxes_final.csv"))
ds = GemmaDataset(csv=final_boxes_df, images_path=images_path)
len(ds)

#### Test boxes

In [None]:
first_sample_boxes = ds.load_boxes(0)  # boxes of the first sample (index 0)
len(first_sample_boxes)

In [None]:
test_boxes = ds.load_boxes(test_index)  # all boxes of the chosen test sample
test_boxes

In [None]:
# Only the first box of the test sample (slice keeps the container type).
ds.load_boxes(test_index)[:1]

#### Test images

In [None]:
# Display the raw test image. Uses matplotlib (as the rest of the notebook does)
# instead of skimage.io.imshow/io.show, which are deprecated in recent
# scikit-image releases.
img = ds.load_image(test_index)
plt.imshow(img)
plt.show()

#### Test sample

In [None]:
# Look a sample up by its image file name rather than by index.
sample_name = "b0KXwBrE57rCtnxjL2jKk0AXGwCI.jpg"
ds.get_by_sample_name(sample_name)

In [None]:
# Sample picker (sorted image names) plus two bordered output panes:
# one for the annotated image, one for the box-coordinate table.
dd_sample = widgets.Dropdown(options=sorted(ds.images))

image_output, rects_output = (
    widgets.Output(layout={"border": "1px solid black"}) for _ in range(2)
)


def print_final_rects(change):
    """Render the sample picked in the dropdown and list its boxes.

    Used as a ``dd_sample.observe`` callback: ``change.new`` is the newly
    selected sample name. The image, with every box outlined in red, is drawn
    into ``image_output``; the box corners are shown as an
    ``x1, y1, x2, y2`` table in ``rects_output``.
    """
    for pane in (image_output, rects_output):
        pane.clear_output()

    image, targets, _ = ds.get_by_sample_name(change.new)
    rects = targets["boxes"].cpu().numpy().astype(np.int32)
    picture = image.permute(1, 2, 0).cpu().numpy()  # channels moved last for imshow

    with image_output:
        _, axis = plt.subplots(1, 1, figsize=(16, 8))
        axis.set_axis_off()
        for x1, y1, x2, y2 in rects:
            outline = patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                linewidth=2,
                edgecolor="r",
                facecolor="none",
            )
            axis.add_patch(outline)
        axis.imshow(picture)
        plt.show()

    with rects_output:
        display(pd.DataFrame(list(rects), columns=["x1", "y1", "x2", "y2"]))


# Re-render both panes whenever a different sample is selected.
dd_sample.observe(handler=print_final_rects, names="value")
display(dd_sample, HBox(children=[image_output, rects_output]))

## Create model

In [None]:
# Faster R-CNN detector (ResNet-50 FPN variant) loaded with pre-trained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm against the installed version.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

In [None]:
num_classes = 2  # 1 class (cup) + background

# Feature size expected by the existing box classifier head.
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Replace the pre-trained head with one that predicts only our num_classes.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

## Build data loaders

### Create train and test dataframes

#### Load and clean dataframe

In [None]:
# Load all annotated boxes and drop degenerate ones (zero width or height).
# reset_index() keeps the old row labels as an "index" column.
df: pd.DataFrame = pd.read_csv(os.path.join(data_path, "boxes.csv"))
non_degenerate = (df["width"] != 0) & (df["height"] != 0)
df = df.loc[non_degenerate].reset_index()
df.head()

#### Split dataframe

In [None]:
# Split by image file (not by box) so all boxes of an image land in the same set.
sizes: tuple = (0.8, 0.20)  # (train, test) fractions of the unique images
split_seed = 42  # fixed seed so the split is identical on every run

indices = list(df.filename.unique())
dataset_size = len(indices)

split_train = int(np.floor(sizes[0] * dataset_size))
# The test set takes every remaining image; flooring the test fraction
# separately could leave an image in neither set.
split_test = dataset_size

rng = np.random.default_rng(split_seed)
indices = [indices[i] for i in rng.permutation(dataset_size)]
train_indices, test_indices = (
    indices[:split_train],
    indices[split_train:split_test],
)

df_train = df[df.filename.isin(train_indices)]
df_test = df[df.filename.isin(test_indices)]

In [None]:
# One row per training image.
df_train.drop_duplicates("filename").head(5)

In [None]:
# One row per test image.
df_test.drop_duplicates("filename").head(5)

#### Look for leakage

In [None]:
# Look for leakage: any row present in both splits (matched on every column).
df_train.merge(df_test, on=list(df_test.columns), how="inner").head()

#### Ensure images are only in one set

In [None]:
# Image file names present in both splits; expected to be empty.
set(df_train.filename) & set(df_test.filename)

### Build datasets

In [None]:
# Each split is wrapped with its matching transform from `loaders`.
train_dataset = GemmaDataset(
    csv=df_train,
    images_path=images_path,
    transform=get_train_transform(),
)
valid_dataset = GemmaDataset(
    csv=df_test,
    images_path=images_path,
    transform=get_valid_transform(),
)

In [None]:
# Smoke test: fetch every training sample once, so a broken image or
# annotation fails here rather than in the middle of training.
for sample_idx in range(len(train_dataset)):
    _ = train_dataset[sample_idx]

### Build loaders

In [None]:
def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    ``[(img1, tgt1, id1), (img2, tgt2, id2)]`` becomes
    ``((img1, img2), (tgt1, tgt2), (id1, id2))``, keeping variable-sized
    images and targets unstacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

batch_size = 1

# Training batches are reshuffled every epoch so SGD does not see the images
# in one fixed order; validation keeps a fixed order so results are stable.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

valid_data_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

## Select device

In [None]:
# Everything runs on the CPU. To use a GPU when one is present, use instead:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
device

## Sample

In [None]:
# Pull one training batch and move its tensors onto the selected device.
images, targets, image_ids = next(iter(train_data_loader))
images = [tensor.to(device) for tensor in images]
targets = [
    {key: value.to(device) for key, value in target.items()}
    for target in targets
]

In [None]:
# First image of the batch as a NumPy array with channels last (for imshow),
# and its boxes cast to int32.
sample = images[0].permute(1, 2, 0).cpu().numpy()
boxes = targets[0]["boxes"].cpu().numpy().astype(np.int32)
boxes

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(16, 8))

# Draw the boxes as matplotlib patches (as print_final_rects does) instead of
# burning them into `sample` with cv2.rectangle. `sample` comes from
# Tensor.numpy() and shares memory with images[0], so drawing into it would
# modify the batch tensor in place. Its permuted, non-contiguous layout can
# also be rejected by OpenCV.
for box in boxes:
    ax.add_patch(
        patches.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=3,
            edgecolor=(220 / 255, 0, 0),  # same red as before, in matplotlib's 0-1 scale
            facecolor="none",
        )
    )

ax.set_axis_off()
ax.imshow(sample)

## Train

In [None]:
model.to(device)

# SGD over the trainable (requires_grad) parameters only.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# No learning-rate schedule. A scheduler such as
# torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1) can be
# assigned here; the training loop steps it only when it is not None.
lr_scheduler = None

num_epochs = 2

In [None]:
loss_hist = Averager()
itr = 1  # global iteration counter across epochs, used for periodic logging

# torchvision detection models return a loss dict only in training mode; in
# eval mode they return detections instead. Switch explicitly so re-running
# this cell after a prediction cell (which may leave the model in eval mode)
# still trains.
model.train()

for epoch in range(num_epochs):
    loss_hist.reset()

    for images, targets, image_ids in train_data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Total loss is the sum of the individual detector losses.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        loss_hist.send(loss_value)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if itr % 10 == 0:
            print(f"Iteration #{itr} loss: {loss_value}")

        itr += 1

    # update the learning rate
    if lr_scheduler is not None:
        lr_scheduler.step()

    print(f"Epoch #{epoch} loss: {loss_hist.value}")

## Predict

In [None]:
# Peek at one validation batch: (images, targets, image ids).
batch = next(iter(valid_data_loader))
images, targets, image_ids = batch
images, targets, image_ids

In [None]:
# Run the trained model over the whole validation loader with a 0.5
# detection threshold.
results = predict_loader(
    loader=valid_data_loader,
    model=model,
    device=device,
    detection_threshold=0.5,
)

In [None]:
first_result = results[0]  # prediction entry stored under key/index 0
first_result

In [None]:
# One row per prediction entry in `results`.
prediction_columns = ["image_id", "PredictionString", "scores", "boxes"]
test_df = pd.DataFrame([*results.values()], columns=prediction_columns)
test_df.head()

In [None]:
# Stateful iterator over the validation loader: each run of the next cell
# consumes (and displays) the following batch.
iter_loader = iter(valid_data_loader)

In [None]:
# Show the next validation batch with its predictions; re-run to advance.
next_images, next_targets, next_image_ids = next(iter_loader)
show_predictions(next_images, next_targets, next_image_ids, device, results)

## Save state dict

In [None]:
# Timestamped file name, e.g. ../models/20240131-235959state_dict.pth
state_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
state_output_path = os.path.join("..", "models", state_timestamp + "state_dict.pth")

In [None]:
# torch.save does not create missing parent directories, so make sure
# ../models exists before the first save.
os.makedirs(os.path.dirname(state_output_path), exist_ok=True)
torch.save(
    model.state_dict(),
    state_output_path,
)

## Save model

In [None]:
# Timestamped file name, e.g. ../models/20240131-235959model.pth
model_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
model_output_path = os.path.join("..", "models", model_timestamp + "model.pth")

In [None]:
# Pickle the whole model object. torch.save does not create missing parent
# directories, so make sure ../models exists first.
os.makedirs(os.path.dirname(model_output_path), exist_ok=True)
torch.save(model, model_output_path)

## Predict with a stored model

### Load model

In [None]:
# Load a full pickled model. map_location remaps tensors saved on another
# device (e.g. a GPU) onto `device`, so the checkpoint also loads on a
# CPU-only machine.
# WARNING: torch.load unpickles arbitrary objects; only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses full
# nn.Module pickles; pass weights_only=False there if loading fails.
loaded_model = torch.load(
    os.path.join("..", "models", "default_model.pth"),
    map_location=device,
)

### Predict rectangles

In [None]:
# Predictions of the stored model over the validation loader (0.5 threshold).
predictions = predict_loader(
    loader=valid_data_loader,
    model=loaded_model,
    device=device,
    detection_threshold=0.5,
)

### View predictions

In [None]:
predict_loader = iter(valid_data_loader)

In [None]:
show_predictions(*next(predict_loader), device, predictions)