In [None]:
import time
import numpy as np
from jetbot import Robot
from jetbot import Camera
from PIL import Image as PILImage
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
import torchvision.transforms as transforms
import os
import traitlets
import ipywidgets.widgets as widgets
from IPython.display import display
import cv2


# Directory scanned for reference "<object name>.jpg" images.
OBJECT_IMAGES_PATH = "dataset/"
DETECTION_THRESHOLD = 0.6   # minimum cosine similarity for identify_object() to report a match
APPROACH_SPEED = 0.2  # value passed to robot.forward() while driving toward the object
TURN_SPEED = 0.15   # value passed to robot.left()/robot.right(); scanning uses 0.8x of it
CENTER_THRESHOLD = 30 # max |x offset| in pixels from the frame centre that still counts as centred


# Robot handle; the control loop drives it through forward/left/right/stop.
robot = Robot()
# Camera instance requested at 224x224, the same size `transform` resizes images to.
camera = Camera.instance(width=224, height=224)


# Preview image and one-line status read-out rendered in the cell output.
image_widget = widgets.Image(format='jpeg', width=224, height=224)
status_widget = widgets.Text(value='Initializing...', description='Status:')
display(image_widget, status_widget)


def bgr8_to_jpeg(value):
    """Encode an image array as JPEG and return the raw bytes.

    Args:
        value: image array accepted by ``cv2.imencode`` (camera frames in this
            notebook are treated as BGR).

    Returns:
        bytes: the JPEG-encoded image.
    """
    # imencode returns (success_flag, buffer); only the buffer is needed here.
    _encoded_ok, jpeg_buffer = cv2.imencode('.jpg', value)
    return bytes(jpeg_buffer)


# One-way link: each new camera frame is JPEG-encoded and pushed into the preview widget.
# NOTE(review): detect_object_position() also assigns image_widget.value, so this link
# presumably overwrites its annotated frame on the next camera update — confirm.
camera_link = traitlets.dlink((camera, 'value'), (image_widget, 'value'), transform=bgr8_to_jpeg)


# ImageNet-pretrained MobileNetV2 used purely as a fixed feature extractor.
model = mobilenet_v2(pretrained=True)

# Keep every top-level child module except the last one and switch to inference mode.
# NOTE(review): in torchvision's MobileNetV2 this slice appears to drop only the
# classifier head, so no global pooling is applied and the flattened features keep
# their spatial layout — confirm this is intended before tuning DETECTION_THRESHOLD.
feature_extractor = nn.Sequential(*list(model.children())[:-1])
feature_extractor.eval()

# Preprocessing shared by reference images and live frames: resize to 224x224,
# convert to a tensor, then normalise with the per-channel mean/std given below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Maps object name (reference file name without extension) -> flattened feature tensor.
# Filled by load_reference_images() and read by identify_object().
object_features = {}

def load_reference_images():
    """Embed every reference image and store the result in ``object_features``.

    Scans ``OBJECT_IMAGES_PATH`` for files ending in ``.jpg``, runs each one
    through ``transform`` and ``feature_extractor`` and stores the flattened
    features under the file name with its extension removed. Progress and
    errors are reported through ``status_widget`` and ``print``.

    Any features left over from a previous call are discarded first, so the
    dictionary always reflects the current directory contents.

    Returns:
        bool: True when at least one reference image was embedded; False when
        the directory is missing, contains no ``.jpg`` files, or every image
        failed to load.
    """
    status_widget.value = "Loading reference images..."

    # Drop stale entries so a reload never matches against deleted references.
    object_features.clear()

    # isdir (rather than exists) also rejects a plain file sitting at that path.
    if not os.path.isdir(OBJECT_IMAGES_PATH):
        status_widget.value = f"Error: directory {OBJECT_IMAGES_PATH} not found"
        return False

    # Sorted so the load order (and the printed log) is deterministic.
    object_files = sorted(f for f in os.listdir(OBJECT_IMAGES_PATH) if f.endswith('.jpg'))

    if not object_files:
        status_widget.value = "Error: No jpg files found in dataset directory"
        return False

    for file in object_files:
        # splitext keeps dots inside the name ("my.object.jpg" -> "my.object"), so
        # detect_object_position() can rebuild the same path as f"{name}.jpg".
        name = os.path.splitext(file)[0]
        img_path = os.path.join(OBJECT_IMAGES_PATH, file)

        try:
            image = PILImage.open(img_path).convert('RGB')
            image_tensor = transform(image).unsqueeze(0)  # add batch dimension

            with torch.no_grad():
                features = feature_extractor(image_tensor)
                # Flatten everything after the batch dimension into one vector.
                features = features.view(features.size(0), -1)
                object_features[name] = features

            print(f"Loaded reference image for {name}")
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole load.
            print(f"Error loading {name}: {e}")

    status_widget.value = f"Loaded {len(object_features)} reference objects"
    return len(object_features) > 0

def cosine_similarity(t1, t2):
    """Return the cosine similarity of two single-row feature tensors.

    Both inputs are L2-normalised along dimension 1 and multiplied as
    matrices; the single resulting element is returned as a Python float.

    Args:
        t1: 2-D tensor holding one feature row.
        t2: 2-D tensor holding one feature row of the same width.

    Returns:
        float: the cosine similarity between the two rows.
    """
    l2_normalize = torch.nn.functional.normalize
    unit_a = l2_normalize(t1, p=2, dim=1)
    unit_b = l2_normalize(t2, p=2, dim=1)
    # (1, d) x (d, 1) -> (1, 1); .item() unwraps the lone element.
    similarity = torch.mm(unit_a, unit_b.transpose(0, 1))
    return similarity.item()


def identify_object():
    """Match the current camera frame against the stored reference features.

    Grabs ``camera.value``, converts it from BGR to RGB, embeds it with
    ``feature_extractor`` and compares the result with every entry of
    ``object_features`` using ``cosine_similarity``.

    Returns:
        tuple: ``(name, similarity)`` of the best-scoring reference when its
        similarity exceeds ``DETECTION_THRESHOLD``; otherwise ``(None, 0.0)``
        (also when no frame or no reference features are available).
    """
    snapshot = camera.value
    if snapshot is None:
        return None, 0.0

    # PIL and the preprocessing pipeline work on RGB, the frame is treated as BGR.
    rgb_image = PILImage.fromarray(cv2.cvtColor(snapshot, cv2.COLOR_BGR2RGB))
    batch = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        embedding = feature_extractor(batch)
        embedding = embedding.view(embedding.size(0), -1)

    scores = {
        name: cosine_similarity(embedding, ref_features)
        for name, ref_features in object_features.items()
    }
    if not scores:
        return None, 0.0

    # First-inserted name wins ties, as max() keeps the earliest maximum.
    winner = max(scores, key=scores.get)
    if scores[winner] > DETECTION_THRESHOLD:
        return winner, scores[winner]

    return None, 0.0

def detect_object_position(frame, object_name):
    """Locate an object in a frame by template matching against its reference image.

    The reference image ``<OBJECT_IMAGES_PATH>/<object_name>.jpg`` is resized to
    a quarter of the frame's width and height and matched against the grayscale
    frame with ``cv2.TM_CCOEFF_NORMED``. An annotated copy of the frame
    (bounding box, object centre, line to the frame centre, score label) is
    written to ``image_widget``.

    Args:
        frame: BGR image array whose first two dimensions are (height, width).
        object_name: reference image file name without the ``.jpg`` extension.

    Returns:
        tuple: ``(x_offset, score)`` where ``x_offset`` is the horizontal
        distance in pixels from the frame centre to the centre of the best
        match (positive means the match is to the right) and ``score`` is the
        best match value. ``(0, 0)`` is returned when the reference image
        cannot be read or any other error occurs; callers should treat a
        score of 0 as "position unknown".
    """
    h, w = frame.shape[:2]
    center_x, center_y = w // 2, h // 2

    try:
        ref_path = os.path.join(OBJECT_IMAGES_PATH, f"{object_name}.jpg")
        template = cv2.imread(ref_path)
        if template is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here with a message that names the file.
            print(f"Error in position detection: could not read reference image {ref_path}")
            return (0, 0)

        # The template is assumed to cover about a quarter of the frame per axis.
        template = cv2.resize(template, (w // 4, h // 4))

        frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        template_gray = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)

        result = cv2.matchTemplate(frame_gray, template_gray, cv2.TM_CCOEFF_NORMED)
        # Only the best score and its location are needed.
        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        th, tw = template_gray.shape
        top_left = max_loc
        obj_center_x = top_left[0] + tw // 2
        obj_center_y = top_left[1] + th // 2

        # Draw on a copy so the caller's frame is left untouched.
        vis_frame = frame.copy()
        cv2.rectangle(vis_frame, top_left, (top_left[0] + tw, top_left[1] + th), (0, 255, 0), 2)
        cv2.circle(vis_frame, (obj_center_x, obj_center_y), 5, (0, 0, 255), -1)
        cv2.line(vis_frame, (center_x, center_y), (obj_center_x, obj_center_y), (255, 0, 0), 2)
        cv2.putText(vis_frame, f"{object_name}: {max_val:.2f}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        x_offset = obj_center_x - center_x

        image_widget.value = bgr8_to_jpeg(vis_frame)

        return (x_offset, max_val)
    except Exception as e:
        # Best effort: a failed match must not crash the control loop.
        print(f"Error in position detection: {e}")
        return (0, 0)

def detect_and_approach(min_match_confidence=0.0):
    """Scan for a known object, centre it in view and drive towards it.

    Loads the reference images, then loops forever:

    * no object identified -> turn left briefly to scan;
    * object identified but its position cannot be trusted -> stay put;
    * object off-centre by ``CENTER_THRESHOLD`` pixels or more -> turn towards it;
    * object centred -> drive forward at ``APPROACH_SPEED``.

    The loop ends on ``KeyboardInterrupt`` (kernel interrupt) or on any
    exception; in every case the robot and the camera are stopped.

    Args:
        min_match_confidence: template-match score that must be exceeded before
            the robot acts on the reported position. The default of 0.0 rejects
            the ``(0, 0)`` failure value of ``detect_object_position``, which
            would otherwise look like a perfectly centred object and make the
            robot drive forward blindly.
    """
    try:
        if not load_reference_images():
            print("Failed to load reference images. Exiting.")
            return

        status_widget.value = "Starting object detection..."

        while True:
            object_name, confidence = identify_object()

            if not object_name:
                status_widget.value = "No objects detected, scanning..."
                robot.left(TURN_SPEED * 0.8)
                time.sleep(0.3)
                robot.stop()
                time.sleep(0.1)
                continue

            status_widget.value = f"Detected {object_name} ({confidence:.2f})"

            frame = camera.value
            if frame is None:
                # Brief pause so a missing frame does not turn into a busy loop.
                time.sleep(0.1)
                continue

            # Get object position
            x_offset, match_conf = detect_object_position(frame, object_name)

            if match_conf <= min_match_confidence:
                # Position unknown: do not move on an offset that cannot be trusted.
                status_widget.value = f"Could not locate {object_name} in frame, holding position"
                robot.stop()
                time.sleep(0.1)
                continue

            if abs(x_offset) < CENTER_THRESHOLD:
                status_widget.value = f"Moving toward {object_name}"
                robot.forward(APPROACH_SPEED)
                time.sleep(0.5)
            elif x_offset > 0:
                status_widget.value = f"Turning right to center {object_name}"
                robot.right(TURN_SPEED)
                time.sleep(0.2)
            else:
                status_widget.value = f"Turning left to center {object_name}"
                robot.left(TURN_SPEED)
                time.sleep(0.2)

            # Move in short bursts: stop and re-evaluate after every step.
            robot.stop()
            time.sleep(0.1)

    except KeyboardInterrupt:
        status_widget.value = "Stopped by user"
    except Exception as e:
        status_widget.value = f"Error: {e}"
    finally:
        # Always leave the hardware in a safe state, whatever ended the loop.
        robot.stop()
        camera.stop()

# Start the control loop when this code runs as the main module (as it does when
# the cell is executed); it blocks until interrupted or an error occurs.
if __name__ == "__main__":
    detect_and_approach()

Image(value=b'', format='jpeg', height='224', width='224')

Text(value='Initializing...', description='Status:')

Loaded reference image for case


In [None]:
# Manual cleanup cell: stops the camera. detect_and_approach() already calls
# camera.stop() in its finally block, so this is a fallback for running by hand.
camera.stop()