<a href="https://colab.research.google.com/github/prinshu756/AirMouse/blob/main/MachineLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Quick plotting sanity check: a pie chart of six random weights drawn
# uniformly from [0, 1).
x = np.random.random(size=6)
ax = plt.gca()
ax.pie(x)
plt.show()

In [None]:
import streamlit as st
from ultralytics import YOLO
from PIL import Image
import cv2
import tempfile
import os
import numpy as np

# -------------------------------
# Streamlit Page Setup
# -------------------------------
st.set_page_config(
    page_title="Trash Detection - YOLOv10",
    page_icon="🗑️",
    layout="wide",
)

st.title("🗑️ Trash Detection using YOLOv10")
st.sidebar.header("⚙️ Model Settings")

# -------------------------------
# Sidebar Controls
# -------------------------------
# Minimum confidence a detection needs to be kept; passed to the model as
# ``conf=``. Slider range 0.1-1.0, default 0.3, step 0.05.
confidence = st.sidebar.slider(
    "Detection Confidence", 0.1, 1.0, 0.3, 0.05
)

# Selects which upload/processing branch of the app runs.
source_type = st.sidebar.radio("Select Input Type", ["Image", "Video"])

# -------------------------------
# Load YOLO Model Safely
# -------------------------------
@st.cache_resource
def load_model(weights="best.pt"):
    """Load the YOLO detector from ``weights`` (a local ``.pt`` file by default).

    Decorated with ``st.cache_resource`` so the model is built once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Args:
        weights: Path (or Ultralytics model name) passed straight to ``YOLO``.

    Returns:
        The loaded ``YOLO`` model.
    """
    return YOLO(weights)


try:
    model = load_model()
except FileNotFoundError:
    # Show a readable message and halt this run instead of dumping a raw
    # traceback into the page when the weights file is missing.
    st.error("Model file 'best.pt' not found. Place your trained weights in the app's working directory.")
    st.stop()


# -------------------------------
# Video Processing Function
# -------------------------------
def process_video(video_bytes):
    """Run YOLO detection on every frame of an uploaded video.

    The uploaded bytes are written to a temporary ``.mp4`` so OpenCV can read
    them. Each annotated frame (``results[0].plot()``) is written to a second
    temporary ``.mp4`` with the ``mp4v`` codec, and a Streamlit progress bar
    tracks the work. Both temporary files are removed before returning, even
    if an error occurs part-way through.

    Uses the module-level ``model`` and ``confidence`` defined in this script.

    Args:
        video_bytes: Raw bytes of the uploaded video file.

    Returns:
        ``(processed_video, class_totals)``:

        - ``processed_video`` is the annotated video as ``bytes``.
        - ``class_totals`` maps class name to the number of detections summed
          over all frames. An object visible in many frames is counted once
          per frame.

        Returns ``(None, None)`` after showing a Streamlit error if the input
        cannot be opened or the output writer cannot be created.
    """
    # OpenCV reads from a filesystem path, so spill the upload to disk first.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_input:
        temp_input.write(video_bytes)
        input_path = temp_input.name

    output_path = None
    writer = None
    cap = cv2.VideoCapture(input_path)
    try:
        if not cap.isOpened():
            st.error("Error opening the video!")
            return None, None

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            # Some containers report 0 (or NaN) FPS; VideoWriter needs a
            # positive frame rate.
            fps = 30.0
        # The reported frame count is only an estimate (it can be 0, negative
        # or too small), so it drives the progress bar and nothing else.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as out_tmp:
            output_path = out_tmp.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            st.error("Error creating the output video!")
            return None, None

        frame_number = 0
        progress = st.progress(0.0)
        class_totals = {}

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # verbose=False stops Ultralytics from logging every single frame.
            results = model(frame, conf=confidence, verbose=False)
            writer.write(results[0].plot())

            # Count detections per class for this frame.
            for cls in results[0].boxes.cls:
                name = model.names[int(cls)]
                class_totals[name] = class_totals.get(name, 0) + 1

            frame_number += 1
            if total_frames > 0:
                # st.progress rejects values above 1.0, and the estimated
                # frame count may be smaller than the real one.
                progress.progress(min(frame_number / total_frames, 1.0))

        progress.progress(1.0)

        # Finalise the container before reading it back.
        writer.release()
        writer = None
        with open(output_path, "rb") as f:
            processed_video = f.read()

        return processed_video, class_totals
    finally:
        cap.release()
        if writer is not None:
            writer.release()
        for path in (input_path, output_path):
            if path is not None and os.path.exists(path):
                os.remove(path)


# -------------------------------
# IMAGE INPUT
# -------------------------------
if source_type == "Image":
    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png", "bmp", "webp"])

    col1, col2 = st.columns(2)

    with col1:
        if uploaded_image:
            img = Image.open(uploaded_image)
            st.image(img, caption="Uploaded Image", use_column_width=True)

    with col2:
        if uploaded_image and st.sidebar.button("Detect"):
            # Pillow may open uploads as grayscale ("L"), palette ("P") or
            # RGBA. Normalise to RGB so the array is always HxWx3 before the
            # RGB->BGR swap and the model call.
            img_cv = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)
            results = model(img_cv, conf=confidence)

            processed = results[0].plot()
            processed_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)

            st.image(processed_rgb, caption="Detection Output", use_column_width=True)

            # Count detections per class.
            class_counts = {}
            for cls in results[0].boxes.cls:
                name = model.names[int(cls)]
                class_counts[name] = class_counts.get(name, 0) + 1

            st.subheader("Object Count")
            if class_counts:
                st.table([{"Class": k, "Count": v} for k, v in class_counts.items()])
            else:
                st.info("No objects detected.")


# -------------------------------
# VIDEO INPUT
# -------------------------------
else:
    uploaded_video = st.file_uploader("Upload a Video", type=["mp4", "avi", "mov", "mkv"])

    if uploaded_video:
        video_bytes = uploaded_video.read()
        st.video(video_bytes)

        if st.sidebar.button("Detect"):
            # A spinner (rather than a success box) signals work in progress
            # and disappears once processing finishes.
            with st.spinner("Processing video... please wait"):
                processed_video, class_totals = process_video(video_bytes)

            # process_video returns (None, None) after reporting its own error.
            if processed_video is not None:
                st.subheader("Detection Results")
                if class_totals:
                    st.table([{"Class": k, "Total Count": v} for k, v in class_totals.items()])
                else:
                    st.info("No objects detected.")

                st.download_button(
                    label="⬇️ Download Processed Video",
                    data=processed_video,
                    file_name="detected_output.mp4",
                    mime="video/mp4"
                )