In [None]:
#@title Installing libraries for Evaluation of Neon data
# Colab setup cell: installs DeepForest plus the packages imported by the next cell.
! pip install --upgrade deepforest albumentations pyyaml
# Alternative (disabled): install the development version of DeepForest from GitHub instead of PyPI.
#pip install git+https://github.com/weecology/DeepForest.git
# Replace whatever opencv-python-headless is installed with the pinned 4.1.2.30 build.
# NOTE(review): the reason for this exact pin is not recorded here — confirm it is still needed and installable.
!pip uninstall opencv-python-headless -y
!pip install opencv-python-headless==4.1.2.30
# GDAL provides the osgeo bindings used for resampling rasters.
!pip install GDAL
!pip install matplotlib
# Remove and reinstall pytorch_lightning (unpinned, so the latest release is installed).
!pip uninstall pytorch_lightning -y
!pip install pytorch_lightning

In [7]:
#@title Importing the required libraries
from osgeo import gdal
import os
import numpy as np
from deepforest import main
from deepforest import get_data
from deepforest import utilities
from deepforest import preprocess
import matplotlib.pyplot as plt
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
import tensorflow as tf
import glob

In [3]:
#@title Checking for GPU
# Ask TensorFlow which GPU device (if any) it can see; the bare name on the last
# line is displayed as the cell's output.
device_name = (
    tf.test.gpu_device_name()
)
device_name

'/device:GPU:0'

In [None]:
#@title Fetching NeonTreeEvaluation data
# Fetching data from NeonTreeEvaluation
!git clone https://github.com/weecology/NeonTreeEvaluation.git
!mkdir data

In [None]:
#@title For evaluation of rasters and plotting results
class Scores:
  """Evaluate DeepForest box recall on rasters resampled to several resolutions, and plot it.

  For each image, ``results`` records one (resolution, box_recall) point at the
  native resolution (stored and plotted at x = 0) and one point per requested
  resolution. Points are stored per image in ``arrX`` / ``arrY`` under the keys
  1..num_images.

  NOTE(review): the CSV passed to ``model.evaluate`` is built from the model's own
  predictions on the same raster (capped at ``reference_box_limit`` rows). It is not
  built from NeonTreeEvaluation annotations, so the reported recall measures the
  model's agreement with itself. Confirm this is the intended metric.
  """

  # Number of images whose results can be stored; callers also use it to cap the run.
  num_images = 1000
  # Directory for the temporary resampled rasters and the prediction CSVs.
  data_dir = '/content/data'
  # IoU threshold passed to model.evaluate.
  iou_threshold = 0.4
  # Number of leading prediction rows written to the reference CSV. The default of 5
  # matches DataFrame.head(); set to None to keep every prediction.
  reference_box_limit = 5

  def __init__(self):
    # The model and result buffers are created per instance. A fresh Scores() then
    # starts with empty lists instead of appending to lists shared through the class.
    # The model is also loaded on instantiation rather than when the class is defined.
    self.model = main.deepforest()
    self.model.use_release()
    self.arrX = {i: [] for i in range(1, self.num_images + 1)}
    self.arrY = {i: [] for i in range(1, self.num_images + 1)}

  def resample_image_in_place(self, image_path, new_res, resample_image):
    """Resample ``image_path`` to ``new_res`` in x and y and write it to ``resample_image``.

    Despite the name, the source raster is not modified; the output goes to a
    separate file.

    Raises:
      RuntimeError: if gdal.Warp reports failure by returning None. Without this
        check, a stale output file from an earlier call would be evaluated.
    """
    options = gdal.WarpOptions(xRes=new_res, yRes=new_res)
    dataset = gdal.Warp(resample_image, image_path, options=options)
    if dataset is None:
      raise RuntimeError(f"gdal.Warp failed to resample {image_path} to {new_res}")
    # Releasing the reference closes the dataset, so it is flushed to disk before prediction reads it.
    dataset = None

  def _predict_and_evaluate(self, image_path, csv_file):
    """Predict boxes on ``image_path``, save them to ``csv_file`` and return ``model.evaluate``'s result dict."""
    boxes = self.model.predict_image(path=image_path, return_plot=False)
    if boxes is None or len(boxes) == 0:
      # NOTE(review): predict_image appears to return None when nothing is detected.
      # An empty reference CSV would make the evaluation meaningless, so fail clearly.
      raise ValueError(f"No boxes predicted for {image_path}")
    if self.reference_box_limit is not None:
      boxes = boxes.head(self.reference_box_limit)
    # Copy so that adding a column never writes into, or warns about, a view of the predictions.
    reference = boxes.copy()
    reference['image_path'] = image_path
    reference.to_csv(csv_file)
    root_dir = os.path.dirname(csv_file)
    return self.model.evaluate(csv_file, root_dir, iou_threshold=self.iou_threshold, savedir=None)

  def evaluation_image(self, image_path, resolutionxyz, counter):
    """Resample ``image_path`` to ``resolutionxyz``, evaluate it, and record the point for image ``counter``.

    ``counter`` must be a key of ``arrX``/``arrY`` (1..num_images); otherwise a KeyError is raised.
    """
    resampled_path = os.path.join(self.data_dir, f'temp_image_{resolutionxyz}_resolution.tif')
    self.resample_image_in_place(image_path, resolutionxyz, resampled_path)

    csv_file = os.path.join(self.data_dir, f'file_{resolutionxyz}_resolution.csv')
    metrics = self._predict_and_evaluate(resampled_path, csv_file)
    print(str(metrics["box_recall"]) + f" is evaluation of projected data at {resolutionxyz}")
    self.arrX[counter].append(resolutionxyz)
    self.arrY[counter].append(metrics["box_recall"])

  def plotting(self):
    """Scatter-plot box recall against resolution for every stored image (native resolution at x = 0)."""
    _, ax = plt.subplots()
    for key in range(1, self.num_images + 1):
      ax.scatter(self.arrX[key], self.arrY[key], marker='o')
    ax.set_xlabel("Resolution")
    ax.set_ylabel("Evaluation")
    ax.set_title("Box recall vs. resampled resolution")
    plt.show()

  def results(self, input_image_path, counter, resolution_values):
    """Evaluate ``input_image_path`` at its native resolution, then at each of ``resolution_values``.

    The native-resolution recall is recorded at x = 0 for image ``counter``. Any
    exception (GDAL failure, no predictions, evaluation error) propagates to the caller.
    """
    csv_file = os.path.join(self.data_dir, 'file_normal.csv')
    metrics = self._predict_and_evaluate(input_image_path, csv_file)
    print(str(metrics["box_recall"]) + " is evaluation of normal data")

    self.arrX[counter].append(0)
    self.arrY[counter].append(metrics["box_recall"])

    # Evaluation for user-defined resolutions
    for resolution in resolution_values:
      self.evaluation_image(input_image_path, resolution, counter)

In [8]:
#@title Counting number of Images in our dataset
# Count the entries in the evaluation RGB folder without materialising a list.
countimages = sum(1 for _ in glob.iglob("/content/NeonTreeEvaluation/evaluation/RGB/*"))

print(countimages)

2291


In [None]:
#@title Calculating evaluation scores on different resolutions for rasters
# Resolutions each image is resampled to (in the raster's ground units — presumably metres; TODO confirm).
RESOLUTION_VALUES = [0.20, 0.40, 0.60, 0.80, 1, 1.20, 1.40, 1.60, 1.80, 2, 2.5]

scores = Scores()

# Sorting makes the evaluated subset independent of filesystem listing order. Slicing
# stops the loop after num_images files instead of iterating over the rest for nothing.
image_paths = sorted(glob.glob("/content/NeonTreeEvaluation/evaluation/RGB/*"))
for counter, image_path in enumerate(image_paths[:scores.num_images], start=1):
  print(counter)
  print(f"Filename: {image_path}")
  try:
    scores.results(image_path, counter, resolution_values=RESOLUTION_VALUES)
  except Exception as err:
    # Skip a bad raster but report why. Catching Exception (not a bare except)
    # keeps KeyboardInterrupt working, so the long run can still be stopped.
    print(f"File has some invalid data, Filename : {image_path} ({type(err).__name__}: {err})")

scores.plotting()
plt.close()