# Tutorial for training a nest detection model

### Install Comet ML

In [None]:
# Install the Comet ML client; it is imported below and backs the CometLogger.
!pip install comet_ml

### Install DeepForest library

In [None]:
# Fetch the DeepForest sources into ./DeepForest.
!git clone https://github.com/weecology/DeepForest.git

In [None]:
# Install the cloned checkout in editable mode, then return to the parent directory.
%cd DeepForest
!pip install -e .
%cd ..

In [None]:
import os
import sys

# Absolute path of the "DeepForest" folder, resolved against the current working directory.
deepforest_path = os.path.abspath("DeepForest")
deepforest_path

In [None]:
# Put the cloned checkout at the front of sys.path (only once) so that
# `import deepforest` can resolve to it.
if deepforest_path not in sys.path:
    sys.path.insert(0, deepforest_path)

In [1]:
# load the modules
import comet_ml
import os
import time
import numpy as np
import pandas as pd
import torch
from deepforest import main
from deepforest import get_data
from deepforest import utilities
from deepforest import preprocess
from tqdm import tqdm
from pytorch_lightning.loggers import CometLogger
import zipfile
import matplotlib.pyplot as plt
import subprocess

### Set up Environment Variables

#### In Google Colab
Use Colab's secret storage to securely store your API key.

1. Locate the `Secrets` tab on the left-hand side panel in your Colab notebook.
2. Add a new secret named `COMET_API_KEY`, paste your Comet ML API key as its value, and enable notebook access for this notebook so that `userdata.get` can read it.

#### Locally
Set an environment variable `COMET_API_KEY` in your operating system.

##### Windows
1. Open Command Prompt and set the environment variable:

    ```bat
    setx COMET_API_KEY "your_comet_api_key"
    ```

2. Restart your terminal or IDE.

##### macOS/Linux
1. Open your terminal and add the following line to your `.bashrc`, `.zshrc`, or `.profile` file:

    ```bash
    export COMET_API_KEY="your_comet_api_key"
    ```

2. Save the file and reload the shell configuration:

    ```bash
    source ~/.bashrc  # or ~/.zshrc, ~/.profile, etc.
    ```

In [None]:
# Choose where the Comet API key is read from: "colab" uses Colab's secret
# storage, any other value uses the COMET_API_KEY environment variable.
PLATFORM = "colab"  # Platform can be colab or local
if PLATFORM != "colab":
    environment = {"api_key": os.getenv("COMET_API_KEY")}
else:
    from google.colab import userdata

    environment = {"api_key": userdata.get("COMET_API_KEY")}

In [10]:
# Pull the key out of the environment dict built in the previous cell.
api_key = environment["api_key"]

In [None]:
# Logger that is later handed to the trainer; change project_name to your own Comet project.
comet_logger = CometLogger(project_name="temporary2", api_key=api_key)

### Download the Bird nest dataset

In [None]:
# Folder that holds the downloaded archive, the extracted dataset and the saved
# model files. Colab uses /content; locally ROOT_FOLDER is honoured, and when it
# is unset or empty the current working directory is used so that later
# os.path.join(root_folder, ...) calls never receive None.
root_folder = "/content" if PLATFORM == "colab" else (os.environ.get("ROOT_FOLDER") or os.getcwd())

def download_dataset(output_filename='Dataset.zip', extract_folder_name="dataset"):
    """
    Download the nest image archive with 'wget' and extract it.

    Both the archive and the extraction folder are placed inside the
    module-level `root_folder`.

    Args:
    - output_filename (str): File name given to the downloaded archive.
    - extract_folder_name (str): Name of the folder to extract the contents into.

    Raises:
    - RuntimeError: If 'wget' exits with a non-zero status.
    - FileNotFoundError: If the downloaded zip file does not exist (also raised
      by `subprocess.run` when the 'wget' executable cannot be found).

    Returns:
    str: Path of the folder the archive was extracted into.
    """
    # dl=1 asks Dropbox for the file itself rather than its preview page.
    url = 'https://www.dropbox.com/s/iczokehl2c5hcjx/nest_images.zip?dl=1'

    # Paths for zip file and extraction folder
    zip_file = os.path.join(root_folder, output_filename)
    extract_folder = os.path.join(root_folder, extract_folder_name)

    # Download straight to the path that is opened below, so the lookup does
    # not depend on the current working directory being root_folder.
    result = subprocess.run(['wget', '-O', zip_file, url], capture_output=True, text=True)

    # Stop on a failed download: 'wget -O' can leave an empty or partial file
    # behind, which would only fail later with a less helpful zip error.
    if result.returncode != 0:
        raise RuntimeError(
            f"Download failed (wget exit code {result.returncode}): {result.stderr}"
        )
    print('Download complete.')

    # Check if the zip file exists
    if not os.path.exists(zip_file):
        raise FileNotFoundError(f"The zip file {zip_file} does not exist.")

    # Create the extract folder if it doesn't exist
    os.makedirs(extract_folder, exist_ok=True)

    # Extract member by member so tqdm can show progress
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        for file in tqdm(zip_ref.namelist(), desc="Extracting", unit="files"):
            zip_ref.extract(file, extract_folder)

    print(f"Successfully unzipped {zip_file} to {extract_folder}.")
    return extract_folder

In [None]:
# Download and unpack the dataset; keep the folder it was extracted into.
extract_folder = download_dataset()

In [None]:
# Read the annotations CSV from the extracted folder (this fails if it was not
# extracted from the zip file) and preview its first rows.
annotations = pd.read_csv(os.path.join(extract_folder, "nest_data.csv"))
annotations.head()

In [17]:
# Collect the file names in the extracted folder that end with ".JPG"
# (the match is case-sensitive).
image_names = [
    entry for entry in os.listdir(extract_folder) if entry.endswith(".JPG")
]

In [None]:
# Split every source image into patches (patch_size=400, patch_overlap=0.05)
# with crop_dir as the base directory, and gather the annotations that
# split_raster returns for each image into a single DataFrame.
crop_dir = os.path.join(os.getcwd(), "train_data_folder")
annotation_path = os.path.join(extract_folder, "nest_data.csv")
all_annotations = [
    preprocess.split_raster(
        path_to_raster=os.path.join(extract_folder, image),
        annotations_file=annotation_path,
        patch_size=400,
        patch_overlap=0.05,
        base_dir=crop_dir,
    )
    for image in image_names
]
train_annotations = pd.concat(all_annotations, ignore_index=True)

In [21]:
image_paths = train_annotations.image_path.unique()

# split into 70% train, 20% validation and 10% test annotations
# Shuffle the unique crop paths once and slice the result: sampling with
# np.random.choice draws with replacement by default, which yields duplicate
# paths and splits smaller than intended.
shuffled_paths = np.random.permutation(image_paths)
n_holdout = int(len(image_paths) * 0.30)
n_valid = int(len(image_paths) * 0.20)

temp_paths = shuffled_paths[:n_holdout]  # everything held out of training
valid_paths = temp_paths[:n_valid]
test_paths = temp_paths[n_valid:]

valid_annotations = train_annotations.loc[
    train_annotations.image_path.isin(valid_paths)
]
test_annotations = train_annotations.loc[train_annotations.image_path.isin(test_paths)]
# Rebind last: the two selections above still need the full annotation table.
train_annotations = train_annotations.loc[
    ~train_annotations.image_path.isin(temp_paths)
]

In [None]:
# View output: preview the training rows and report the size of every split
print(train_annotations.head())
print("There are {} training nest annotations".format(train_annotations.shape[0]))
print("There are {} validation nest annotations".format(valid_annotations.shape[0]))
print("There are {} test nest annotations".format(test_annotations.shape[0]))

# Make sure the output directory exists, then build one CSV path per split
os.makedirs(crop_dir, exist_ok=True)
annotations_file = os.path.join(crop_dir, "train.csv")
validation_file = os.path.join(crop_dir, "valid.csv")
test_file = os.path.join(crop_dir, "test.csv")

# Write the annotation files (header row included, index column omitted) into
# crop_dir, the same location used as "base_dir" when generating the crops.
train_annotations.to_csv(annotations_file, index=False)
valid_annotations.to_csv(validation_file, index=False)
test_annotations.to_csv(test_file, index=False)

In [None]:
# Build the model with a single "Nest" class, then override its configuration.
m = main.deepforest(label_dict={"Nest": 0})

# Training data: annotation CSV and the directory it lives in
m.config["train"]["csv_file"] = annotations_file
m.config["train"]["root_dir"] = os.path.dirname(annotations_file)

# Validation data: annotation CSV and the directory it lives in
m.config["validation"]["csv_file"] = validation_file
m.config["validation"]["root_dir"] = os.path.dirname(validation_file)

# Training schedule: cosine learning-rate scheduler, 10 epochs
m.config["train"]["scheduler"]["type"] = "cosine"
m.config["train"]["epochs"] = 10

# Score threshold setting, and "-1" to request all available GPUs
m.config["score_thresh"] = 0.4
m.config["gpus"] = "-1"

In [None]:
# Display the configured scheduler type to confirm the override above.
m.config["train"]["scheduler"]["type"]

In [None]:
# Create the PyTorch Lightning trainer used for training, logging to Comet.
# num_sanity_val_steps=0 disables the sanity check on the validation data.
m.create_trainer(logger=comet_logger, num_sanity_val_steps=0)
# Load the latest release model (RetinaNet)
m.use_release()

In [None]:
# Fit the model and report how long the call took (wall-clock seconds).
start_time = time.time()
m.trainer.fit(m)
elapsed_seconds = time.time() - start_time
print(f"--- Training on GPU: {elapsed_seconds:.2f} seconds ---")

In [None]:
# Evaluate on the test split (IoU threshold 0.4); save_dir is passed as
# `savedir` so the prediction results are saved to the pred_result_test folder.
save_dir = os.path.join(os.getcwd(), "pred_result_test")
results = m.evaluate(
    test_file, os.path.dirname(test_file), iou_threshold=0.4, savedir=save_dir
)

In [None]:
# Display the "box_precision" entry of the evaluation results.
results["box_precision"]

In [None]:
# Display the "box_recall" entry of the evaluation results.
results["box_recall"]

In [30]:
# Save the "results" table of the evaluation to a CSV file in the current
# working directory (index column omitted).
results["results"].to_csv("results_test_lr_cosine.csv", index=False)

In [None]:
# Save the model checkpoint under root_folder.
# NOTE(review): Lightning checkpoints conventionally use a ".ckpt" suffix; the
# ".pl" name is kept because the load cell below reads this exact path.
m.trainer.save_checkpoint(
    os.path.join(root_folder, "checkpoint_epochs_10_cosine_lr_retinanet.pl")
)

In [None]:
# Additionally save only the underlying model's state_dict (weights) under root_folder.
torch.save(m.model.state_dict(), os.path.join(root_folder, "weights_cosine_lr"))

In [None]:
# Load a fresh model object from the checkpoint file saved above.
model = main.deepforest.load_from_checkpoint(
    os.path.join(root_folder, "checkpoint_epochs_10_cosine_lr_retinanet.pl")
)

In [None]:
# Add a path to an image to test the model on.
# NOTE: raster_path is an empty placeholder here and must be set to a real
# image path before this cell is run.
raster_path = ""
# Run tiled prediction (patch_size=300, patch_overlap=0.25) and ask for the
# plotted result back so it can be shown with matplotlib.
predicted_raster = model.predict_tile(
    raster_path, return_plot=True, patch_size=300, patch_overlap=0.25
)
plt.imshow(predicted_raster)
plt.show()