<a href="https://colab.research.google.com/github/wayne0git/ml_cv_basics/blob/master/object_detection/yolo_v5_train_example_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# YOLO v5 Training Example Code
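Fine-tunes YOLOv5 on the Traffic Vehicles Object Detection dataset, inspects validation predictions, and runs inference on sample images and videos.

- Reference tutorial: https://learnopencv.com/custom-object-detection-training-using-yolov5/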



## Environment Preparation

### YOLO v5

In [1]:
import os

In [None]:
if not os.path.exists('yolov5'):
    !git clone https://github.com/ultralytics/yolov5.git

In [None]:
%cd yolov5

In [None]:
!pip install -r requirements.txt

## Data Preparation

### Download the Traffic Vehicles Object Detection dataset (Kaggle)

In [6]:
import requests

In [16]:
# Directory the dataset is extracted into, and the downloaded archive inside it
# (both relative to the yolov5 repo we `%cd`-ed into).
DATA_DIR = 'data'
ZIP_FPATH = 'data/traffic-vehicles-object-detection.zip'

In [11]:
def download_file(url, save_name, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``save_name`` unless that file already exists.

    The response body is streamed to ``<save_name>.part`` and renamed into place
    only after the transfer succeeds. An interrupted or failed download therefore
    never leaves a truncated file that a later call would mistake for a finished
    one.

    Args:
        url: HTTP(S) URL to fetch.
        save_name: destination file path; missing parent directories are created.
        timeout: seconds passed to ``requests.get`` as its timeout.
        chunk_size: bytes requested per chunk while streaming.

    Raises:
        requests.HTTPError: if the server answers with an error status. Before
            this fix, the error page was silently saved as the file.
    """
    if os.path.exists(save_name):
        print('File already present, skipping download...')
        return

    parent = os.path.dirname(save_name)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_name = save_name + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_name, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_name, save_name)
    except BaseException:
        # Remove the partial file so the next call retries instead of skipping.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise

In [None]:
# Fetch the dataset archive (download_file skips this if ZIP_FPATH already exists).
download_file(
    'https://learnopencv.s3.us-west-2.amazonaws.com/traffic-vehicles-object-detection.zip',
    ZIP_FPATH
)

In [18]:
!unzip -q {ZIP_FPATH} -d {DATA_DIR}

### Create the dataset YAML file

In [None]:
%%writefile data/data.yaml
# YOLOv5 dataset config for the Traffic Vehicles dataset.
path: "data/Traffic Dataset" # Dataset root, relative to the `train.py` script (the yolov5 repo root). 
train: images/train # training images, relative to `path`
val: images/val # validation images, relative to `path`

# Classes: the class id in each label file indexes into `names`.
# Keep this list identical (same order) to `class_names` in the plotting cells.
nc: 7
names: [
    "Car", "Number Plate", "Blur Number Plate", "Two Wheeler", "Auto", "Bus", "Truck"
]

### Data visualization

In [25]:
import cv2
import glob
import matplotlib.pyplot as plt

In [20]:
# Class names indexed by YOLO class id; must match `names` in data/data.yaml.
class_names = [
    "Car",
    "Number Plate",
    "Blur Number Plate",
    "Two Wheeler",
    "Auto",
    "Bus",
    "Truck",
]

In [21]:
# Function to convert bounding boxes in YOLO format to xmin, ymin, xmax, ymax.
def yolo2bbox(bboxes):
    """Convert one YOLO box ``[x_center, y_center, width, height]`` to corners.

    Units are preserved (normalized in, normalized out).

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)``.
    """
    x_c, y_c = bboxes[0], bboxes[1]
    half_w, half_h = bboxes[2] / 2, bboxes[3] / 2
    return x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h

In [34]:
def plot_box(image, bboxes, labels):
    """Draw YOLO-format boxes and their class names onto ``image`` in place.

    Args:
        image: 3-channel image array (e.g. from ``cv2.imread``); it is modified.
        bboxes: normalized ``[x_center, y_center, width, height]`` boxes.
        labels: class ids (int or numeric str), parallel to ``bboxes``; each
            indexes the global ``class_names``.

    Returns:
        The same ``image`` object with the annotations drawn on it.
    """
    # Need the image height and width to denormalize the bounding box coordinates
    h, w, _ = image.shape
    font, font_scale, font_thickness = cv2.FONT_HERSHEY_SIMPLEX, 2, 10

    for box, label in zip(bboxes, labels):
        x1, y1, x2, y2 = yolo2bbox(box)

        # denormalize the coordinates
        xmin, ymin = int(x1 * w), int(y1 * h)
        xmax, ymax = int(x2 * w), int(y2 * h)

        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=6)

        text = class_names[int(label)]
        (_, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
        # Put the label above the box. If the box touches the top edge, put it
        # just inside the box instead; otherwise it would be drawn off-image.
        text_y = ymin - 10 if ymin - 10 - text_h >= 0 else ymin + text_h + 10
        cv2.putText(image, text, (xmin + 1, text_y), font, font_scale, (0, 255, 0), font_thickness)

    return image

In [29]:
# Parse a YOLO label file into parallel lists of boxes and class ids.
def _read_yolo_labels(label_path):
    """Read ``class x_c y_c w h`` lines; return ``(bboxes, labels)``.

    Class ids are returned as strings (like the original parser) and may have
    more than one digit. Blank lines and irregular whitespace are tolerated.
    """
    bboxes, labels = [], []
    with open(label_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            labels.append(fields[0])
            bboxes.append([float(v) for v in fields[1:5]])
    return bboxes, labels


# Function to plot images with the bounding boxes.
def plot(image_paths, label_paths, num_samples=4, start_idx=10):
    """Show sample images with their ground-truth YOLO boxes in a 2-column grid.

    Args:
        image_paths: glob pattern matching the image files.
        label_paths: glob pattern matching the YOLO ``.txt`` label files.
        num_samples: number of images to show.
        start_idx: position (in sorted image order) of the first image shown.
    """
    label_by_stem = {
        os.path.splitext(os.path.basename(p))[0]: p for p in glob.glob(label_paths)
    }
    # Pair each image with the label file that has the same stem. Images with no
    # label file (e.g. background images) then cannot shift the alignment, as
    # they did when two independently sorted lists were indexed in parallel.
    pairs = []
    for image_path in sorted(glob.glob(image_paths)):
        stem = os.path.splitext(os.path.basename(image_path))[0]
        if stem in label_by_stem:
            pairs.append((image_path, label_by_stem[stem]))

    samples = pairs[start_idx:start_idx + num_samples]
    if not samples:
        print('No matching image/label pairs found.')
        return

    n_rows = (len(samples) + 1) // 2
    plt.figure(figsize=(21, 6 * n_rows))
    for i, (image_path, label_path) in enumerate(samples):
        image = cv2.imread(image_path)
        if image is None:
            print(f'Could not read {image_path}, skipping.')
            continue
        bboxes, labels = _read_yolo_labels(label_path)
        result_image = plot_box(image, bboxes, labels)
        plt.subplot(n_rows, 2, i + 1)
        plt.imshow(result_image[:, :, ::-1])  # BGR -> RGB for matplotlib
        plt.axis('off')
    plt.show()

In [None]:
# Sanity-check the labels: draw ground-truth boxes on a few training images.
plot(image_paths='data/Traffic Dataset/images/train/*', 
     label_paths='data/Traffic Dataset/labels/train/*')

## Train Model

### Hyperparameters

In [36]:
EPOCHS = 5    # epochs passed to train.py
TRAIN = True  # False: skip the training cells; set_res_dir() then points at an existing run

### Logging Utility Functions

In [40]:
# Where YOLOv5 train.py writes its runs; each run is a subdirectory named by --name.
LOG_DIR = 'runs/train'

In [42]:
def monitor_tensorboard():
    """Load the TensorBoard notebook extension and open it on ``LOG_DIR``.

    Uses IPython magics, so it only works when called from a notebook kernel.
    """
    %load_ext tensorboard
    %tensorboard --logdir {LOG_DIR}

In [41]:
# Directory to store train / inference results
def set_res_dir(log_dir=None, train=None):
    """Pick the run name (``results_<N>``) used as ``--name`` for train.py.

    ``N`` is derived from the highest existing ``results_<N>`` directory, not
    from a count of entries. A deleted run or unrelated entries (``exp``, files)
    therefore can no longer produce a name that collides with an existing run
    or points at a nonexistent one.

    Args:
        log_dir: directory holding the training runs; defaults to ``LOG_DIR``.
        train: whether a new training run is about to start; defaults to ``TRAIN``.

    Returns:
        ``results_<max+1>`` when training (a fresh, unused name), otherwise
        ``results_<max>`` (the highest-numbered existing run; ``results_0`` if none).
    """
    log_dir = LOG_DIR if log_dir is None else log_dir
    train = TRAIN if train is None else train

    prefix = 'results_'
    run_indices = []
    for path in glob.glob(os.path.join(log_dir, prefix + '*')):
        suffix = os.path.basename(path)[len(prefix):]
        if os.path.isdir(path) and suffix.isdecimal():
            run_indices.append(int(suffix))
    print(f"Current number of result directories: {len(run_indices)}")

    last_index = max(run_indices, default=0)
    if train:
        res_dir = f"{prefix}{last_index + 1}"
        print(res_dir)
    else:
        res_dir = f"{prefix}{last_index}"

    return res_dir

### Train

In [None]:
# Start TensorBoard on LOG_DIR and choose the run name used by the cells below.
monitor_tensorboard()
RES_DIR = set_res_dir()

In [None]:
# YOLOv5m
if TRAIN:
    !python train.py --data data/data.yaml --weights yolov5m.pt --img 640 --epochs {EPOCHS} --batch-size 16 --name {RES_DIR}

In [None]:
# YOLOv5m (Transfer Learning. Freeze first 15 layers.)
if TRAIN:
    !python train.py --data data/data.yaml --weights yolov5m.pt --img 640 --epochs {EPOCHS} --batch-size 16 --name {RES_DIR} \
    --freeze 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14

In [None]:
# YOLOv5s
if TRAIN:
    !python train.py --data data/data.yaml --weights yolov5s.pt --img 640 --epochs {EPOCHS} --batch-size 16 --name {RES_DIR}

### Evaluation

In [46]:
# Function to show validation predictions saved during training.
def show_valid_results(RES_DIR, log_dir=None):
    """Display the ``*_pred.jpg`` validation images saved in a training run.

    Args:
        RES_DIR: run name (the ``--name`` passed to ``train.py``).
        log_dir: parent directory of the runs; defaults to ``LOG_DIR``.
    """
    log_dir = LOG_DIR if log_dir is None else log_dir
    exp_path = os.path.join(log_dir, RES_DIR)
    # glob order is arbitrary; sort so the batches always appear in the same order.
    validation_pred_images = sorted(glob.glob(os.path.join(exp_path, '*_pred.jpg')))
    if not validation_pred_images:
        print(f"No validation predictions found in {exp_path}")
        return

    for pred_image in validation_pred_images:
        image = cv2.imread(pred_image)
        plt.figure(figsize=(19, 16))
        plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
        plt.axis('off')
        plt.show()

In [None]:
# Show the validation prediction images saved for run RES_DIR.
show_valid_results(RES_DIR)

## Inference

### Download Inference Data

In [48]:
# NOTE: this redefines ZIP_FPATH from the dataset download above; from here
# on it points at the inference archive.
DATA_DIR = 'data'
ZIP_FPATH = 'data/inference_data.zip'

In [49]:
# Fetch the sample inference data (download_file skips this if ZIP_FPATH already exists).
download_file('https://learnopencv.s3.us-west-2.amazonaws.com/yolov5_inference_data.zip', ZIP_FPATH)

In [51]:
!unzip -q {ZIP_FPATH} -d {DATA_DIR}

### Inference Utility Functions

In [54]:
# Where YOLOv5 detect.py writes its runs; each run is a subdirectory named by --name.
DETECT_DIR = 'runs/detect'

In [55]:
# Helper function for inference on images.
def inference(RES_DIR, data_path):
    # Directory to store inference results.
    infer_dir_count = len(glob.glob(DETECT_DIR + '/*'))
    INFER_DIR = f"inference_{infer_dir_count+1}"

    # Inference on images.
    !python detect.py --weights {LOG_DIR}/{RES_DIR}/weights/best.pt --source {data_path} --name {INFER_DIR}

    return INFER_DIR

In [56]:
# Visualize inference images.
def visualize(INFER_DIR, detect_dir=None):
    """Display the annotated ``.jpg`` outputs of a ``detect.py`` run.

    Args:
        INFER_DIR: run name, as returned by ``inference``.
        detect_dir: parent directory of detect runs; defaults to ``DETECT_DIR``.
    """
    detect_dir = DETECT_DIR if detect_dir is None else detect_dir
    infer_path = os.path.join(detect_dir, INFER_DIR)
    # glob order is arbitrary; sort for a stable, reproducible display order.
    infer_images = sorted(glob.glob(os.path.join(infer_path, '*.jpg')))
    if not infer_images:
        print(f"No .jpg results found in {infer_path}")
        return

    for pred_image in infer_images:
        image = cv2.imread(pred_image)
        plt.figure(figsize=(19, 16))
        plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
        plt.axis('off')
        plt.show()

In [None]:
# Inference on images, then display the annotated results of that detect run.
IMAGE_INFER_DIR = inference(RES_DIR, 'data/inference_images')
visualize(IMAGE_INFER_DIR)

In [None]:
# Inference on videos. Outputs go to a new run under DETECT_DIR; visualize()
# only displays .jpg files, so the videos are not shown inline.
inference(RES_DIR, 'data/inference_videos')