Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions references/detection/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import numpy as np
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is completely unrelated to the Windows build, right? I think it's a good idea to have it in the repo, but it should go in a separate PR so that we can focus on how best to illustrate this.

import torchvision
import skimage.io
import colorsys
import matplotlib
import random
import matplotlib.patches as patches
from matplotlib import pyplot as plt
from skimage.measure import find_contours
from PIL import Image
from matplotlib.patches import Polygon

# COCO class names, indexed by the label IDs produced by torchvision's
# COCO-pretrained detection models.
# Those models report the original COCO category IDs (1..90), which contain
# gaps, so the table needs a placeholder ('N/A') at every unused ID; a
# contiguous 81-entry list would mislabel every class after 'fire hydrant'
# and raise IndexError for IDs above 80.
# Index of the class in the list is its ID. For example, to get ID of
# the teddy bear class, use: class_names.index('teddy bear')
class_names = [
    'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
    'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella',
    'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors',
    'teddy bear', 'hair drier', 'toothbrush',
]

def apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into ``image`` wherever ``mask`` is at least 0.5.

    The first three channels of ``image`` are overwritten in place and the
    same array is returned. ``color`` holds per-channel intensities that are
    scaled by 255 before blending; ``alpha`` is the weight of the colour and
    ``1 - alpha`` the weight of the existing pixel value.
    """
    selected = mask >= 0.5
    keep = 1 - alpha
    for channel in (0, 1, 2):
        plane = image[:, :, channel]
        tinted = plane * keep + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image

def random_colors(N, bright=True):
    """
    Return ``N`` RGB tuples in shuffled order.

    Hues are spaced evenly around the HSV wheel at full saturation so the
    colors are easy to tell apart; ``bright`` selects a value of 1.0
    instead of 0.7. The order is randomized with the module-level
    ``random`` generator.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette

def _as_numpy(values):
    """Return ``values`` as a NumPy array (``None`` is passed through)."""
    if values is None:
        return None
    if hasattr(values, "detach"):
        # Tensor-like input: drop autograd tracking and move to host memory
        # first, since neither state can be converted to NumPy directly.
        return values.detach().cpu().numpy()
    return np.asarray(values)


def display_instances(image, boxes, masks, class_ids, class_names,
                      scores=None, title="",
                      figsize=(16, 16), ax=None, min_score=0.4):
    """
    Draw detected instances (box, caption, mask and mask outline) on an image.

    Array arguments may be NumPy arrays or tensor-like objects exposing
    ``detach()``; they are converted to NumPy before use.

    image: [height, width, 3] array; it is copied, not modified.
    boxes: [num_instances, (x1, y1, x2, y2)] in image coordinates.
    masks: [num_instances, 1, height, width] or
        [num_instances, height, width]; values >= 0.5 count as foreground.
    class_ids: [num_instances] integer indices into class_names.
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box. When given,
        instances scoring below min_score are skipped.
    title: (optional) title placed on the axes.
    figsize: (optional) the size of the image.
    ax: (optional) matplotlib axes to draw on; a new figure is created
        when omitted.
    min_score: (optional) score threshold, only used when scores is given.

    Note: plt.show() is always called before returning.
    """
    boxes = _as_numpy(boxes)
    masks = _as_numpy(masks)
    class_ids = _as_numpy(class_ids)
    scores = _as_numpy(scores)

    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")

    # Compare against None: truthiness is not a reliable "was it passed" test.
    if ax is None:
        _, ax = plt.subplots(1, figsize=figsize)

    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        # Scores are optional; without them every instance is drawn.
        score = None if scores is None else float(scores[i])
        if score is not None and score < min_score:
            continue

        color = colors[i]

        x1, y1, x2, y2 = (float(v) for v in boxes[i])
        p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                              alpha=0.7, linestyle="dashed",
                              edgecolor=color, facecolor='none')
        ax.add_patch(p)

        # Label
        label = class_names[int(class_ids[i])]
        # "is None" rather than truthiness so a score of 0.0 is still shown.
        if score is None:
            caption = label
        else:
            caption = "{} {:.3f}".format(label, score)
        ax.text(x1, y1 + 8, caption,
                color='w', size=11, backgroundcolor="none")

        # Mask: drop the leading channel axis when there is one.
        mask = masks[i]
        if mask.ndim == 3:
            mask = mask[0]
        masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros(
            (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        # Threshold before storing: casting soft values straight to uint8
        # would truncate everything below 1.0 to 0 and lose the outline.
        padded_mask[1:-1, 1:-1] = mask >= 0.5
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))
    plt.show()

# Path of the picture to run detection on.
image_path = 'test.jpg'

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# set it to evaluation mode, as the model behaves differently
# during training and during evaluation
model.eval()

# Force three channels: the visualisation indexes image[:, :, c] for
# c in 0..2, which fails on grayscale input and mismatches RGBA/CMYK.
image = Image.open(image_path).convert('RGB')
image_tensor = torchvision.transforms.functional.to_tensor(image)

output = model([image_tensor])

# Reuse the pixels that were already decoded instead of reading the file
# from disk a second time.
img = np.array(image)

# Visualize results
# Inference above runs with autograd enabled, so the outputs may require
# grad; detach them before handing them to NumPy/matplotlib code.
r = {key: value.detach() for key, value in output[0].items()}
# display_instances() calls plt.show() itself, so no extra call is needed.
display_instances(img, r['boxes'], r['masks'], r['labels'],
                  class_names, r['scores'])

4 changes: 2 additions & 2 deletions torchvision/csrc/cuda/ROIAlign_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ at::Tensor ROIAlign_forward_cuda(
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
dim3 block(512);

if (output.numel() == 0) {
Expand Down Expand Up @@ -379,7 +379,7 @@ at::Tensor ROIAlign_backward_cuda(

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(at::cuda::ATenCeilDiv(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
dim3 block(512);

// handle possibly empty gradients
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/cuda/ROIPool_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
dim3 block(512);

if (output.numel() == 0) {
Expand Down Expand Up @@ -204,7 +204,7 @@ at::Tensor ROIPool_backward_cuda(

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(at::cuda::ATenCeilDiv(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
dim3 block(512);

// handle possibly empty gradients
Expand Down