In [9]:
!pip install opencv-python
!pip install ultralytics



In [10]:
import os

# Keep Weights & Biases from syncing anything to its servers.
# "offline" is the documented name for what older releases called "dryrun".
os.environ["WANDB_MODE"] = "offline"

In [11]:
import cv2, os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
import time

In [12]:
class Classification:
    """Run a detection model on images, videos and live streams and draw its boxes.

    The wrapped ``model`` is only used through ``model.predict(image)``, whose
    first element must expose ``.names`` (class id -> label) and ``.boxes``
    (iterable of boxes with ``.xyxy`` and ``.cls``), as ultralytics results do.
    """

    def __init__(self, model):
        self.model = model

    def draw_bbox(self, image, results):
        """Draw every detected box and its class label in white.

        Args:
            image: An image array (drawn on in place) or a path to an image file,
                which is loaded with PIL and converted to a 3-channel RGB array.
            results: One detection result exposing ``.names`` and ``.boxes``.

        Returns:
            The annotated image array.
        """
        if isinstance(image, (str, os.PathLike)):
            # The context manager closes the file handle. convert('RGB') guarantees
            # 3 channels, so the white (255, 255, 255) colour is drawn correctly
            # for RGBA / palette / grayscale files too.
            with Image.open(image) as pil_image:
                image = np.array(pil_image.convert('RGB'))

        classes = results.names
        for box in results.boxes:
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            classname = classes[int(box.cls)]
            cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255), 3)
            cv2.putText(image, classname, (x1 + 5, y1 + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
        return image

    def predict_image(self, image):
        """Run the model on one image (array or path) and annotate it.

        Returns:
            ``(annotated_image, inference_seconds)``. The time covers only
            ``model.predict``, not the drawing step.
        """
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        results = self.model.predict(image)[0]
        inference_time = time.perf_counter() - start
        return self.draw_bbox(image, results), inference_time

    def predict_multiple(self, path, output_path='output.png'):
        """Annotate every file in ``path``, show them stacked vertically and save the figure.

        Args:
            path: Directory holding the images to run through the model.
            output_path: File the combined figure is written to.
        """
        # Sorted for a deterministic panel order (os.listdir order is arbitrary).
        # Hidden files and subdirectories are skipped; the model cannot read them.
        image_paths = [
            os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if not name.startswith('.') and os.path.isfile(os.path.join(path, name))
        ]
        if not image_paths:
            print(f"No images found in {path}")
            return

        output_imgs = [self.predict_image(image_path)[0] for image_path in image_paths]

        # squeeze=False keeps axs 2-D even for a single image, so indexing is uniform.
        fig, axs = plt.subplots(len(output_imgs), 1, figsize=(16, 16), squeeze=False)
        for ax, img in zip(axs[:, 0], output_imgs):
            ax.imshow(img)
            ax.axis('off')
        fig.tight_layout()

        # Save before displaying so the written file does not depend on what the
        # backend does with the figure after show().
        fig.savefig(output_path)
        plt.show()

    def predict_video(self, video_path, output_path, frames=None):
        """Annotate a video frame by frame and write the result with the 'mp4v' codec.

        Args:
            video_path: Input readable by ``cv2.VideoCapture``.
            output_path: Destination of the annotated video.
            frames: Optional maximum number of frames to process; None means all.
        """
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        total_inference_time = 0
        # Frames actually processed. The container's CAP_PROP_FRAME_COUNT can be 0
        # or differ from what is processed when `frames` is set.
        num_frames = 0
        try:
            while cap.isOpened():
                if frames is not None and num_frames >= frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                out_frame, inference_time = self.predict_image(frame)
                total_inference_time += inference_time
                num_frames += 1
                out.write(out_frame)
        except Exception as e:
            # Best effort: report the failure, then still release handles and print stats.
            print(f"An error occurred: {e}")
        finally:
            cap.release()
            out.release()
            average_ms = (total_inference_time * 1000 / num_frames) if num_frames else 0.0
            print(f"Total inference time: {total_inference_time} s")
            print("Total number of frames: ", num_frames)
            print(f"Average inference time per frame: {average_ms} ms")

    def process_stream(self, input_stream=0, rate=1, display_size=(1000, 800)):
        """Show live detections in an OpenCV window until 'q' is pressed or the stream ends.

        Args:
            input_stream: Anything ``cv2.VideoCapture`` accepts (0 = default camera).
            rate: Run the model on every ``rate``-th frame; other frames are shown
                without annotations.
            display_size: ``(width, height)`` every frame is resized to before
                inference and display.

        Raises:
            ValueError: If ``rate`` is smaller than 1.
        """
        if rate < 1:
            raise ValueError(f"rate must be a positive integer, got {rate!r}")

        cap = cv2.VideoCapture(input_stream)
        frame_count = 0          # frames successfully read and displayed
        inferred_count = 0       # frames actually passed to the model
        total_inference_time = 0
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # Count only after a successful read so the final failed read is not counted.
                frame_count += 1
                out_frame = cv2.resize(frame, display_size)

                if frame_count % rate == 0:
                    out_frame, inference_time = self.predict_image(out_frame)
                    total_inference_time += inference_time
                    inferred_count += 1
                cv2.imshow('Real Time Trash Detection', out_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        except KeyboardInterrupt:
            print("Stream stopped by user")
        except Exception as e:
            print(f"An error occurred: {e}")
        finally:
            # Single cleanup path for normal exit, errors and interrupts.
            cap.release()
            cv2.destroyAllWindows()

        # Average over frames that were actually inferred; skipped frames cost no inference.
        average_ms = (total_inference_time * 1000 / inferred_count) if inferred_count else 0.0
        print(f"Total inference time: {total_inference_time} s")
        print("Total number of frames: ", frame_count)
        print("Frames run through the model: ", inferred_count)
        print(f"Average inference time per frame: {average_ms} ms")
        

In [13]:
# Load the trained weights and wrap the detector in the helper class.
MODEL_WEIGHTS = 'best.pt'
model = YOLO(MODEL_WEIGHTS)
path = 'images'  # image folder, presumably meant for classification.predict_multiple(path)
classification = Classification(model)

In [14]:
# Open the default camera (input_stream=0) and run the model on every frame (rate=1).
# Press 'q' in the OpenCV window to stop.
classification.process_stream(input_stream=0, rate=1)


0: 512x640 (no detections), 158.1ms
Speed: 2.6ms preprocess, 158.1ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 133.5ms
Speed: 2.8ms preprocess, 133.5ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 125.4ms
Speed: 1.7ms preprocess, 125.4ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 130.3ms
Speed: 2.8ms preprocess, 130.3ms inference, 0.7ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 115.2ms
Speed: 2.6ms preprocess, 115.2ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 117.0ms
Speed: 1.9ms preprocess, 117.0ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 108.2ms
Speed: 2.5ms preprocess, 108.2ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640)

0: 512x640 (no detections), 110.7ms
Speed: 2.5ms prepr