<a href="https://colab.research.google.com/github/vbakomichali/Ai-BuiltEnvironment/blob/main/IAAC2024_YOLO_Finetune_finetune_VB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
!pip install ultralytics==8.0.196

In [None]:
from ultralytics import YOLO
from IPython.display import display, Image
import cv2

In [None]:
import os
import pickle
import shutil
from pathlib import Path

from google.colab import files
# Base directory; the absolute paths used further down all live under /content.
HOME = "/content"

# Approach A: Local File

In [None]:
# mount Google Drive at /content/drive so files stored there can be read from this runtime
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the dataset from Google Drive into the Colab runtime.
# The destination has to be /content/RoofAndContext/: the training cell reads
# /content/RoofAndContext/data.yaml and the prediction cell reads its sample image
# from /content/RoofAndContext/valid/images/.
source_dir = '/content/drive/MyDrive/iaac2024/ViennaSatImages/Datasets/RoofAndContext/'
dest_dir = '/content/RoofAndContext/'

# dirs_exist_ok=True allows re-running the cell when the destination already exists
shutil.copytree(source_dir, dest_dir, dirs_exist_ok=True)

'/content/datasets/'

# Approach B: via API from roboflow

In [None]:
# paste snippet from roboflow


# Training / fine-tuning the model

In [None]:
# Work from the Colab root directory. The later cells look for the trained weights
# under /content/runs, so relative outputs written during training should land there
# (assuming Ultralytics writes its `runs/` folder below the working directory).
HOME = "/content"
os.chdir(HOME)

In [None]:
# Load the weights of the pretrained model (yolov8s.pt) and fine-tune them on our dataset.
# Parameter documentation: https://docs.ultralytics.com/modes/train/#clearml
model_ori = YOLO("yolov8s.pt")
training_res = model_ori.train(
    data="/content/RoofAndContext/data.yaml",
    epochs=12,
    imgsz=640,
)

# Using the fine-tuned model

In [None]:
# Locate the weights of the most recent training run instead of hard-coding a run
# folder: the folder index (train, train2, train3, ...) depends on how often training
# has been started, so a fixed "train3" fails on a fresh runtime.
candidate_weights = sorted(
    Path('/content/runs/detect').glob('train*/weights/best.pt'),
    key=lambda weights: weights.stat().st_mtime,
)
# fall back to the previously hard-coded location when no run folder is found
path_to_model = (
    str(candidate_weights[-1])
    if candidate_weights
    else '/content/runs/detect/train3/weights/best.pt'
)
model = YOLO(path_to_model)

In [None]:
# download the fine-tuned model weights (the file at path_to_model) through the browser
files.download(path_to_model)

Run the fine-tuned model

In [None]:
# run the fine-tuned model on one validation image;
# conf=0.25 is the confidence threshold handed to predict
sample_image = '/content/RoofAndContext/valid/images/slice_115_jpg.rf.6c7c12609c011f05d39c64a2b5d5b301.jpg'
results = model.predict(source=sample_image, conf=0.25)


image 1/1 /content/RoofAndContext/valid/images/slice_115_jpg.rf.6c7c12609c011f05d39c64a2b5d5b301.jpg: 640x640 9 flatroofs, 27 trees, 9.0ms
Speed: 1.9ms preprocess, 9.0ms inference, 3.4ms postprocess per image at shape (1, 3, 640, 640)


Let's have a look at what the "results" look like

In [None]:
# show the built-in documentation of the first result object
help(results[0])

In [None]:
# list the attributes and methods available on the first result object
dir(results[0])

In [None]:
# get documentation of the plotting parameters
# help() prints the text itself and returns None, so wrapping it in print()
# would only append a stray "None" to the output
help(results[0].plot)

In [None]:
# use the built-in "plot" method to visualise the annotated image
# (labels are drawn, confidence values are switched off, boxes use line width 2)
results[0].plot(conf=False, line_width=2, labels=True)

Function to extract the number of detections and the center-point locations for each class

In [None]:
def getDetectionCountsAndLocation(res):
  # create a dictionary for each class
  classNameMapper = res.names # id -> name mapper is automatically created

  # create a dictionary with an entry for each class name and a initial coutner set to 0
  classCount = { v:0 for k, v in classNameMapper.items()}
  for bbox_id in  res.boxes.cls.tolist():
    className = classNameMapper[bbox_id]
    classCount[className] += 1
  # store center point of each prediction
  objLocations = { v:[] for k, v in classNameMapper.items()}
  for box in  res.boxes:
     bbox_id = box.cls.tolist()[0]
     className = classNameMapper[bbox_id]
     corners = box.xyxy.tolist()[0]
     centerPt =   [corners[0] + corners[2] / 2 , corners[1] + corners[3] / 2]
     objLocations[className].append(centerPt)

  #print(classCount)
  return classCount, objLocations

Extracting object counts and an annotated image

In [None]:
# Collect, for every predicted image, the annotated picture plus its detection summary.
detectionResults = []
for res in results:
  annotatedImage = res.plot(conf=False, line_width=2, labels=True, pil=True)

  # per-class counts and locations for this image
  detectionCount, detectionLocation = getDetectionCountsAndLocation(res)
  print (detectionLocation)

  resDict = {
    "image": annotatedImage,
    "detectionCount": detectionCount,
    "detectionLocation": detectionLocation,
  }
  detectionResults.append(resDict)

{'flatroof': [[678.0396118164062, 490.8361511230469], [281.3410949707031, 186.37432098388672], [388.1791534423828, 42.92464828491211], [642.6951751708984, 834.5809936523438], [157.1570816040039, 26.415451049804688], [283.78245544433594, 288.13836669921875], [277.1182098388672, 46.087669372558594], [817.7100830078125, 549.7976226806641], [464.69276428222656, 849.0652465820312]], 'tree': [[299.2652893066406, 422.5239715576172], [190.6529312133789, 438.7717590332031], [190.530517578125, 128.10652923583984], [560.850341796875, 534.1586761474609], [481.49298095703125, 322.7660675048828], [556.9756164550781, 404.9252624511719], [261.921142578125, 367.7220916748047], [552.4339904785156, 740.4807739257812], [99.54841995239258, 426.09324645996094], [615.9194641113281, 911.9180297851562], [52.44707489013672, 690.1251220703125], [555.1921081542969, 643.8731079101562], [772.5315246582031, 898.30908203125], [445.1993713378906, 353.86314392089844], [863.9230346679688, 717.4561157226562], [9.52856349

In [None]:
# inspect the stored entry (image, counts, locations) of the first result
detectionResults[0]

## Save and re-load prediction results

In [None]:
# serialise the prediction results into a pickle file ...
results_file = 'prediction_results.pkl'
with open(results_file, 'wb') as pkl_out:
    pickle.dump(detectionResults, pkl_out)

# ... and download that file through the browser
files.download(results_file)

In [None]:
# Load a previously saved results file.
# files.upload() returns a dict that maps each uploaded file name to its content
# (bytes), so the data is read from there instead of from a hard-coded file name.
uploaded = files.upload()
if not uploaded:
    raise ValueError("No file was uploaded.")

# SECURITY: unpickling can execute arbitrary code - only upload files you created yourself.
uploaded_name = next(iter(uploaded))
my_list = pickle.loads(uploaded[uploaded_name])

# BONUS: using a model via gradio on your phone

In [None]:
# import gradio dependencies
!pip install gradio
import gradio as gr
from PIL import Image
import torch

In [None]:
def predict(image):
    """Run the fine-tuned model on one image for the Gradio interface.

    Args:
        image: PIL image handed over by Gradio, or None when nothing was uploaded.

    Returns:
        A tuple `(output_image, text)`: the annotated image produced by the
        result's `plot` method and a string with the detection count per class.
        When `image` is None the tuple is `(None, "no image provided")`.
    """
    # the function receives None when the user submits without choosing an image
    if image is None:
        return None, "no image provided"

    # The model is given a file path, so the PIL image is written to disk first.
    # JPEG cannot store an alpha channel, hence the conversion to RGB.
    image.convert("RGB").save('/content/temp_image.jpg')
    results = model_gr.predict(source='/content/temp_image.jpg', conf=0.25)

    # get counts (the locations are not needed for the text output)
    detectionCount, _ = getDetectionCountsAndLocation(results[0])

    # render the annotated image once
    output_image = results[0].plot(conf=False, line_width=2, labels=True, pil=True)

    return output_image, "objects detected:" + str(detectionCount)



In [None]:
# load a model
# Pick the newest best.pt below /content/runs/detect instead of a fixed run folder:
# the folder index (train, train2, train3, ...) depends on how often training has
# been started, so a hard-coded "train3" fails on a fresh runtime.
candidate_weights_gr = sorted(
    Path('/content/runs/detect').glob('train*/weights/best.pt'),
    key=lambda weights: weights.stat().st_mtime,
)
# fall back to the previously hard-coded location when no run folder is found
path_to_model_gr = (
    str(candidate_weights_gr[-1])
    if candidate_weights_gr
    else '/content/runs/detect/train3/weights/best.pt'
)
model_gr = YOLO(path_to_model_gr)

In [None]:
# Assemble the Gradio interface: one image in, an image plus a text box out.
image_input = gr.Image(type="pil")
app_outputs = [gr.Image(type="pil"), gr.Textbox()]
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=app_outputs,
    title="YOLOv8 Object Detection",
    description="Upload an image to detect objects using YOLOv8",
)

# Launch the app (share=True so it can be opened from another device, e.g. a phone)
iface.launch(debug=True, share=True)