## Vision Systems in Agriculture (Part 2)

### Section 1: Understanding the basics of PyTorch

PyTorch is a machine learning and deep learning framework based on the Torch library. It was originally developed by Meta AI and is widely used in computer vision and natural language processing.

[PyTorch website](https://pytorch.org/)

[60-minute tutorial on Deep Learning with PyTorch](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html): Please go over this in your spare time.

### Section 2: Image object detection using a pretrained model

Steps to take:
1. Find an existing high-performing model trained on a dataset that matches (or closely resembles) your target data
2. Initialize the model with its pretrained weights
3. Preprocess the image to fit the input requirements of the model
4. Make the prediction
5. Analyze the prediction results
6. Visualize the prediction results
7. Evaluate the results

---

0. Install PyTorch and the other required libraries

```shell
# In your Python virtual environment, run:
pip install torch torchvision matplotlib opencv-python
```

1. Find an existing high-performing model trained on a dataset that matches (or closely resembles) your target data

Looking over the [PyTorch model list](https://pytorch.org/vision/stable/models.html#table-of-all-available-classification-weights), we select the [Mask R-CNN model with a ResNet-50-FPN backbone](https://pytorch.org/vision/stable/models/generated/torchvision.models.detection.maskrcnn_resnet50_fpn.html#torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights), which performs both object detection (bounding boxes) and instance segmentation (per-object masks).

2. Initialize the model with its pretrained weights

The weights we use were trained on the [COCO dataset](https://cocodataset.org/#explore) (Microsoft Common Objects in Context), a large-scale object detection, segmentation, keypoint detection, and captioning dataset consisting of 328K images.

```python
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

# 1. Select the COCO-pretrained weights and their matching preprocessing transforms
weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
transforms = weights.transforms()

# 2. Initialize the model with the pretrained weights (downloaded on first use)
model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
# Switch to inference mode; in training mode the model expects targets and returns losses
model = model.eval()

# 3. Explore the model architecture
# print(model)

# 4. TODO: Get the categories (class labels) from the weights metadata.
#    See if you can identify all the fruit/vegetable categories.
# categories = weights.meta["categories"]
# print(categories)
```

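For the category TODO above, here is a minimal sketch of searching the category list by name. It assumes `categories` is the list returned by `weights.meta["categories"]`; in torchvision that list starts with `"__background__"` at index 0 and uses `"N/A"` for unused COCO ids, so a predicted label can be used directly as an index into it. `find_categories`, `PRODUCE`, and the short stand-in list are hypothetical, chosen so the sketch runs without downloading weights.

```python
def find_categories(categories, wanted):
    """Return {name: label_index} for each name in `wanted` that appears in `categories`."""
    return {name: idx for idx, name in enumerate(categories) if name in wanted}


# Fruit/vegetable class names that appear in COCO
PRODUCE = {"banana", "apple", "orange", "broccoli", "carrot"}

# Hypothetical stand-in excerpt so the sketch runs offline; in the notebook use
# categories = weights.meta["categories"] (the real indices differ from these)
sample_categories = ["__background__", "person", "N/A", "banana", "apple",
                     "orange", "broccoli", "carrot"]

produce_ids = find_categories(sample_categories, PRODUCE)
print(produce_ids)  # {'banana': 3, 'apple': 4, 'orange': 5, 'broccoli': 6, 'carrot': 7}
```

Running `find_categories(weights.meta["categories"], PRODUCE)` in the notebook gives the label indices to look for in the model's predictions.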
Load the target image

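```python
from PIL import Image
import matplotlib.pyplot as plt

# 1. Load the image and convert it to three-channel RGB
# (forward slashes work on Windows, macOS, and Linux and avoid backslash escape sequences)
image_path = './img/orange.jpg'
image = Image.open(image_path).convert("RGB")

# 2. Display the image
# plt.imshow(image)
```
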
3. Preprocess the image to fit the input requirements of the model

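```python
import torchvision.transforms as T

# Convert the PIL image (H x W x C, uint8 values 0-255) into a float tensor of
# shape C x H x W with values scaled to [0, 1]. Mask R-CNN resizes and normalizes
# the image internally, so no further preprocessing is needed here.
# (weights.transforms() from the cell above should produce an equivalent tensor.)
preprocess = T.Compose([
    T.ToTensor(),
])

# Add a batch dimension: 1 x C x H x W
img_preprocess = preprocess(image).unsqueeze(0)

# TODO: Explore the shape/size of both the input image and the preprocessed version
# print(f'Size of original image (W, H): {image.size}')
# print(f'Shape of preprocessed image (N, C, H, W): {img_preprocess.shape}')
```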

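To see what `T.ToTensor()` followed by `unsqueeze(0)` does to the data layout, here is a NumPy-only sketch that mimics both steps on a small synthetic RGB array. The synthetic array is an assumption so the sketch runs without an image file or PyTorch, and `to_tensor_like` is a hypothetical helper, not a torchvision function.

```python
import numpy as np


def to_tensor_like(hwc_uint8):
    """Mimic T.ToTensor() followed by unsqueeze(0) using NumPy (hypothetical helper).

    Input: H x W x C array of uint8 values in [0, 255] (the layout np.asarray(pil_image) gives).
    Output: 1 x C x H x W float32 array with values scaled to [0, 1].
    """
    chw = np.transpose(hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0  # HWC -> CHW, rescale
    return chw[np.newaxis, ...]  # add the batch dimension, like unsqueeze(0)


# Synthetic 4 x 6 RGB image (height 4, width 6) filled with orange (255, 165, 0)
fake_image = np.zeros((4, 6, 3), dtype=np.uint8)
fake_image[..., 0] = 255
fake_image[..., 1] = 165

batch = to_tensor_like(fake_image)
print(batch.shape)        # (1, 3, 4, 6)
print(batch[0, :, 0, 0])  # per-channel values of the top-left pixel, now in [0, 1]
```

Note that PIL's `image.size` reports (width, height), here (6, 4), while the tensor shape is (N, C, H, W).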
4. Perform model prediction

```python
# Make a prediction; torch.no_grad() disables gradient tracking to save memory and time
with torch.no_grad():
    predictions = model(img_preprocess)
```

5. Analyze prediction results

```python
# `predictions` is a list with one dict per input image; we passed a single image.
# boxes: (N, 4) as [x1, y1, x2, y2] in original-image pixels; labels: (N,) category indices;
# scores: (N,) confidences in [0, 1]; masks: (N, 1, H, W) per-pixel probabilities
boxes = predictions[0]['boxes'].cpu().detach().numpy()
labels = predictions[0]['labels'].cpu().detach().numpy()
scores = predictions[0]['scores'].cpu().detach().numpy()
masks = predictions[0]['masks'].cpu().detach().numpy()

# TODO: Explore the prediction results. Notice the prediction scores.
# (Uncomment `categories` in the model-initialization cell first.)
# for i in range(len(scores)):
#     print(f'Prediction {i+1}')
#     print('--------------------------')
#     print(f'Label: {categories[labels[i]]}')
#     print(f'Score: {scores[i]}')
#     print(f'Bbox: {boxes[i]}')
#     print('\n')
```


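When exploring the results you will notice many low-confidence predictions. Below is a minimal sketch of filtering them by score and counting the remaining detections per class, a natural first step toward counting fruit. It uses small synthetic arrays shaped like `labels` and `scores` above and a stand-in category list; `summarize_detections` and the `demo_*` names are hypothetical.

```python
from collections import Counter

import numpy as np


def summarize_detections(labels, scores, categories, score_threshold=0.5):
    """Count detections per class name, keeping only those with score >= score_threshold."""
    keep = scores >= score_threshold  # boolean mask over predictions
    return Counter(categories[int(label)] for label in labels[keep])


# Synthetic stand-ins for the notebook's categories, labels, and scores
demo_categories = ["__background__", "apple", "orange"]
demo_labels = np.array([2, 2, 1, 2])
demo_scores = np.array([0.98, 0.91, 0.55, 0.20])

counts = summarize_detections(demo_labels, demo_scores, demo_categories, score_threshold=0.6)
print(counts)  # Counter({'orange': 2})
```

Lowering `score_threshold` to 0.5 would also keep the apple detection, which shows how sensitive counts are to the threshold you choose.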
```python
# TODO: Explore the masks for each prediction. Each mask holds per-pixel
# probabilities in [0, 1]; only predictions above the score threshold are plotted.
# score_threshold = 0.5
# plt.figure(figsize=(20, 10))
# for i in range(len(masks)):
#     if scores[i] > score_threshold:
#         plt.subplot(1, len(masks), i + 1)
#         plt.imshow(masks[i][0], cmap='gray')
#         plt.title(f'Mask for prediction {i + 1}')
```

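Because the masks are soft probabilities, you usually threshold them before using them. A common choice, assumed here rather than enforced by the model, is 0.5. The sketch below binarizes a synthetic mask shaped like `masks[i]` and reports the segmented area in pixels, which could serve as a rough size measure for a fruit; `binarize_mask` is a hypothetical helper.

```python
import numpy as np


def binarize_mask(soft_mask, threshold=0.5):
    """Convert a (1, H, W) or (H, W) probability mask into a boolean (H, W) mask."""
    mask_2d = soft_mask.reshape(soft_mask.shape[-2:])  # drop the leading channel dimension
    return mask_2d > threshold


# Synthetic mask shaped like masks[i]: (1, H, W) probabilities
soft = np.zeros((1, 5, 5), dtype=np.float32)
soft[0, 1:4, 1:4] = 0.9  # confident 3 x 3 region
soft[0, 0, 0] = 0.3      # low-probability pixel, dropped by the threshold

binary = binarize_mask(soft)
print(binary.shape, int(binary.sum()))  # (5, 5) 9
```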
6. Visualize the prediction results

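```python
import cv2 as cv
import numpy as np

# TODO: Load the image with cv.imread into a variable named `img`.
# Note: OpenCV loads images in BGR channel order; convert with
# cv.cvtColor(img, cv.COLOR_BGR2RGB) before displaying it with matplotlib.

score_threshold = 0.6

# TODO: Use the threshold to filter predictions, apply the corresponding mask,
# and draw each bounding box on the image with cv.rectangle()

# TODO: Plot the original image next to the segmentation/bounding-box image
```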

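One possible approach to the visualization TODO, sketched with NumPy only so it runs without OpenCV: alpha-blend a color into the masked pixels, then draw a one-pixel box outline by assigning to the border rows and columns. In the notebook you would draw the box with `cv.rectangle(img, (x1, y1), (x2, y2), color, thickness)` as the TODO suggests; the slicing here is a dependency-free stand-in. `overlay_detection`, the 0.5 mask threshold, the blend weight `alpha`, and the `demo_*` data are all assumptions.

```python
import numpy as np


def overlay_detection(img, soft_mask, box, color, alpha=0.5, mask_threshold=0.5):
    """Return a copy of img (H x W x 3, uint8) with one detection drawn on it (hypothetical helper).

    soft_mask: (1, H, W) probabilities, as in masks[i]
    box: [x1, y1, x2, y2] in pixel coordinates, as in boxes[i]
    color: 3-tuple in the same channel order as img
    """
    out = img.astype(np.float32)  # float copy; the input image is left untouched
    region = soft_mask.reshape(img.shape[:2]) > mask_threshold
    # Alpha-blend the color into the masked pixels
    out[region] = (1 - alpha) * out[region] + alpha * np.array(color, dtype=np.float32)
    out = out.astype(np.uint8)

    # Draw a 1-pixel box outline (stand-in for cv.rectangle), clipped to the image
    h, w = img.shape[:2]
    x1, y1, x2, y2 = (int(round(v)) for v in box)
    x1, x2 = max(x1, 0), min(x2, w - 1)
    y1, y2 = max(y1, 0), min(y2, h - 1)
    out[y1, x1:x2 + 1] = color
    out[y2, x1:x2 + 1] = color
    out[y1:y2 + 1, x1] = color
    out[y1:y2 + 1, x2] = color
    return out


# Synthetic detections shaped like the notebook's boxes, scores, and masks
demo_img = np.zeros((10, 10, 3), dtype=np.uint8)
demo_boxes = np.array([[2.0, 2.0, 7.0, 7.0], [0.0, 0.0, 3.0, 3.0]])
demo_scores = np.array([0.9, 0.4])
demo_masks = np.zeros((2, 1, 10, 10), dtype=np.float32)
demo_masks[0, 0, 3:7, 3:7] = 0.8

score_threshold = 0.6
result = demo_img
for box, score, mask in zip(demo_boxes, demo_scores, demo_masks):
    if score > score_threshold:  # skip low-confidence predictions
        result = overlay_detection(result, mask, box, color=(255, 165, 0))
```

Only the first detection passes the threshold, so its masked pixels are tinted orange and its box is outlined, while the low-scoring second detection is not drawn.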
7. Evaluate the results

```python
# `helpers` is the helper module provided alongside this notebook
from helpers import get_ground_truth_ann, calculate_map

# Goal: evaluate the segmentation and object detection predictions for this image against its ground-truth data

# TODO: Extract the bounding boxes from the prediction results above into prediction_bbox,
# keeping only predictions whose score passes a threshold you choose.
# The plotting code below treats each entry as [x1, y1, x2, y2].
prediction_bbox = []

# The provided helper function extracts the ground-truth bounding boxes for an image
# image_name = 'orange'
# ground_truth_bbox = get_ground_truth_ann(image_name=image_name, show=False)

# print(f'ground_truths = {ground_truth_bbox}')
# print(f'predictions = {prediction_bbox}')

# TODO: Plot the results (green = ground truth, red = prediction, assuming `img` is in RGB order)
# img_bbox = img.copy()

# for gt in ground_truth_bbox:
#     gt = [int(v) for v in gt]
#     cv.rectangle(img_bbox, (gt[0], gt[1]), (gt[2], gt[3]), (0, 255, 0), 2)

# for pd in prediction_bbox:
#     pd = [int(v) for v in pd]
#     cv.rectangle(img_bbox, (pd[0], pd[1]), (pd[2], pd[3]), (255, 0, 0), 2)

# plt.imshow(img_bbox)
# plt.title("Ground truth vs predicted Bbox")
```

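For the `prediction_bbox` TODO, here is a minimal sketch that keeps the boxes whose score passes a threshold (optionally only for one class) and converts them to plain Python lists of `[x1, y1, x2, y2]`, the layout the plotting code above indexes into. Whether `calculate_map` expects exactly this format depends on the provided `helpers` module, so check it there; `boxes_above_threshold` and the `demo_*` arrays are hypothetical.

```python
import numpy as np


def boxes_above_threshold(boxes, scores, score_threshold=0.6, labels=None, keep_label=None):
    """Return [[x1, y1, x2, y2], ...] for predictions with score >= score_threshold.

    If both labels and keep_label are given, only predictions of that class are kept.
    """
    keep = scores >= score_threshold
    if labels is not None and keep_label is not None:
        keep &= (labels == keep_label)
    return boxes[keep].tolist()


# Synthetic stand-ins for the notebook's boxes, scores, and labels
demo_boxes = np.array([[10.0, 20.0, 110.0, 120.0],
                       [15.0, 25.0, 60.0, 70.0],
                       [200.0, 40.0, 260.0, 90.0]])
demo_scores = np.array([0.95, 0.30, 0.81])
demo_labels = np.array([1, 1, 2])

prediction_bbox = boxes_above_threshold(demo_boxes, demo_scores, 0.6)
print(prediction_bbox)  # [[10.0, 20.0, 110.0, 120.0], [200.0, 40.0, 260.0, 90.0]]
```

Filtering by class (for example, passing the orange label index as `keep_label`) matters when the ground truth only annotates one kind of fruit, since detections of other classes would otherwise count as false positives.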
```python
# TODO: Use the calculate_map helper function to compute the mean average precision (mAP)
# of your object detection pipeline
# mAP = calculate_map(ground_truth_bbox, prediction_bbox, iou_threshold=0.7)
# print(mAP)
```

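The `iou_threshold=0.7` argument refers to intersection over union (IoU): the area where a predicted box and a ground-truth box overlap, divided by the total area covered by the two boxes. The sketch below computes IoU for two `[x1, y1, x2, y2]` boxes using continuous coordinates. It illustrates the standard definition and is not necessarily how the provided `calculate_map` helper computes it (some implementations add 1 to widths and heights for inclusive pixel coordinates); `box_iou` is a hypothetical name.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two [x1, y1, x2, y2] boxes (x2 > x1, y2 > y1)."""
    # Corners of the overlapping rectangle
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)  # zero if the boxes don't overlap

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


gt = [0, 0, 10, 10]
pred = [5, 0, 15, 10]
print(box_iou(gt, pred))  # 50 / 150, about 0.333: below 0.7, so not a match
```

A prediction shifted by half its width only reaches an IoU of about 0.33, so a threshold of 0.7 demands fairly tight localization.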
#### Additional exploration

Now that you have walked through the process using the images we analyzed in the last class, run the pipeline on another image containing multiple fruits or vegetables.

Your task is to find an image online containing fruits (e.g., apple, banana, orange), run it through the object detection model, and visualize the result.

As a further step, adjust the visualization colors so that each type of fruit is drawn in a different color.

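For the color extension, a minimal sketch of a per-class color lookup: a fixed table for a few COCO fruit classes and a default color for any other class, so the same class is always drawn in the same color. The table entries, the default, and the helpers `color_for` and `rgb_to_bgr` are hypothetical choices. The colors are RGB tuples, so reverse them if you draw on a BGR image loaded with `cv.imread`.

```python
# RGB colors chosen (arbitrarily) to resemble each fruit
FRUIT_COLORS = {
    "apple": (220, 20, 60),
    "banana": (255, 225, 53),
    "orange": (255, 165, 0),
}


def color_for(class_name, default=(0, 255, 255)):
    """Return the RGB color used to draw a detection of class_name."""
    return FRUIT_COLORS.get(class_name, default)


def rgb_to_bgr(color):
    """Reverse the channel order for drawing with OpenCV on a BGR image."""
    return tuple(reversed(color))


print(color_for("banana"))              # (255, 225, 53)
print(rgb_to_bgr(color_for("orange")))  # (0, 165, 255)
```

In the notebook this could be used as, for example, `cv.rectangle(img, (x1, y1), (x2, y2), color_for(categories[label]), 2)`, with integer box coordinates.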
See example solution below.

<img src=".\img\solution3.png" width="800">