In [296]:
import os
import cv2
import torch
from torchvision import transforms
from prototypical_net import ConvNet
from PIL import Image
import random

### Load

In [297]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Parameters
target_class = "rectangle"  # Change this to the desired class
support_dir = os.path.join("synthetic-shapes", target_class)
workspace_path = "workspace.png"

### Transforms

In [298]:
# Transforms
transform = transforms.Compose([
    transforms.Resize((84, 84)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

In [299]:
# Load model
model = ConvNet().to(device)
model.load_state_dict(torch.load("model.pth"))
model.eval()


ConvNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1

In [300]:
# Pick one random image from the class folder
sample_img_path = random.choice([
    os.path.join(support_dir, f) for f in os.listdir(support_dir) if f.endswith(".png")
])
sample_img = Image.open(sample_img_path).convert("RGB")
sample_tensor = transform(sample_img).unsqueeze(0).to(device)
sample_embedding = model(sample_tensor)

# Load and process workspace image
workspace = cv2.imread(workspace_path)
workspace = cv2.cvtColor(workspace, cv2.COLOR_BGR2RGB)

In [301]:
# support_path = "synthetic-shapes/circle/img1.png"
# # مرحله 1: محاسبه embedding پشتیبان
# with torch.no_grad():
#     support_tensor = transform(Image.open(support_path).convert("RGB")).unsqueeze(0).to(device)
#     support_embedding = model(support_tensor)

# # مرحله 2: sliding window در چند scale
# patch_sizes = [84, 96, 112, 128]
# stride = 20
# min_dist = float('inf')
# for patch_size in patch_sizes:
#     for y in range(0, h - patch_size + 1, stride):
#         for x in range(0, w - patch_size + 1, stride):
#             patch = workspace[y:y+patch_size, x:x+patch_size]
#             patch_resized = cv2.resize(patch, (84, 84))  # همان اندازه آموزش
#             patch_tensor = transform(Image.fromarray(patch_resized)).unsqueeze(0).to(device)

#             patch_embedding = model(patch_tensor)
#             dist = torch.norm(support_embedding - patch_embedding).item()

#             if dist < min_dist:
#                 min_dist = dist
#                 best_pos = (x + patch_size // 2, y + patch_size // 2)

# print(f"Class: {target_class}, Location: {best_pos}, Distance: {min_dist:.4f}")

In [302]:
from detect_object_multiscale import detect_object_multiscale


result = detect_object_multiscale(
    model=model,
    support_path="synthetic-shapes/rectangle/img0.png",
    workspace_path="workspace.png",
    patch_sizes=[84, 96, 112, 128],
    stride=20,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print(f"Best match at: {result['location']}, Distance: {result['distance']:.4f}, Patch Size: {result['patch_size']}")


Best match at: (322, 322), Distance: 0.1491, Patch Size: 84


In [303]:
# cv2.circle(workspace, best_pos, 10, (255, 100, 0), 2)
# cv2.imwrite("detected.png", cv2.cvtColor(workspace, cv2.COLOR_RGB2BGR))

In [304]:
best_pos = result['location']
cv2.circle(workspace, best_pos, 5, (255, 250, 0), 2)
cv2.imwrite("detected.png", cv2.cvtColor(workspace, cv2.COLOR_RGB2BGR))

True