In [1]:
!pip install streamlit

Defaulting to user installation because normal site-packages is not writeable


In [None]:
# =======================
# Fine-tune Faster R-CNN & Streamlit App
# =======================
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import xml.etree.ElementTree as ET
import streamlit as st
import numpy as np
import cv2

# ----- Custom Dataset -----
class FruitDataset(Dataset):
    def __init__(self, image_dir, annotation_dir, transforms=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transforms = transforms
        self.images = list(sorted(os.listdir(image_dir)))

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        ann_path = os.path.join(self.annotation_dir, self.images[idx].replace(".jpg", ".xml"))

        img = Image.open(img_path).convert("RGB")
        tree = ET.parse(ann_path)
        root = tree.getroot()

        boxes, labels = [], []
        for obj in root.findall("object"):
            name = obj.find("name").text
            label = {"apple":91, "banana":92, "orange":93}[name]  # gán nhãn
            xml_box = obj.find("bndbox")
            b = [float(xml_box.find(tag).text) for tag in ["xmin", "ymin", "xmax", "ymax"]]
            boxes.append(b)
            labels.append(label)

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels, "image_id": image_id,
                  "area": area, "iscrowd": iscrowd}

        if self.transforms:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.images)



# ----- Model Definition -----
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 94)  # COCO (91) + 3 lớp mới








In [None]:
# ----- Dataset & DataLoader -----
transforms = torchvision.transforms.ToTensor()
train_dataset = FruitDataset("data/images", "data/annotations", transforms)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# ----- Optimizer -----
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# ----- Training Loop -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for images, targets in train_loader:
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {losses.item():.4f}")

# ----- Save model -----
torch.save(model.state_dict(), "model.pth")

UnidentifiedImageError: cannot identify image file 'C:\\Users\\uanvs\\Downloads\\pytorch-object-detection-main\\pytorch-object-detection-main\\test\\apple_88.xml'

In [None]:
# =======================
# Streamlit Interface
# =======================
st.set_page_config(page_title="Fruit Detector")
st.title("🍎🍌🍊 Fruit Detection App")

uploaded_files = st.file_uploader("Upload images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

if uploaded_files:
    model.eval()
    model.load_state_dict(torch.load("model.pth", map_location=device))
    model.to(device)

    for uploaded_file in uploaded_files:
        image = Image.open(uploaded_file).convert("RGB")
        img_tensor = F.to_tensor(image).unsqueeze(0).to(device)
        detections = model(img_tensor)[0]

        img_np = np.array(image)
        for i in range(len(detections["boxes"])):
            score = detections["scores"][i].item()
            if score > 0.5:
                box = detections["boxes"][i].detach().cpu().numpy().astype(int)
                label = detections["labels"][i].item()
                name = {91:"apple", 92:"banana", 93:"orange"}.get(label, str(label))
                cv2.rectangle(img_np, tuple(box[:2]), tuple(box[2:]), (255,0,0), 2)
                cv2.putText(img_np, f"{name}: {score:.2f}", (box[0], box[1]-10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,0,0), 2)

        st.image(img_np, caption="Detection Result", use_column_width=True)