In [None]:
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms


torch.set_grad_enabled(False)  # inference only: turn off autograd for the whole script

# COCO category names, indexed by the class id the model predicts
# ("N/A" entries are placeholders, presumably for ids with no category).
CLASSES = [
    "N/A", "person", "bicycle", "car", "motorcycle", "airplane", "bus",
    "train", "truck", "boat", "traffic light", "fire hydrant", "N/A",
    "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
    "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "N/A", "backpack",
    "umbrella", "N/A", "N/A", "handbag", "tie", "suitcase", "frisbee", "skis",
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "N/A", "wine glass",
    "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
    "chair", "couch", "potted plant", "bed", "N/A", "dining table", "N/A",
    "N/A", "toilet", "N/A", "tv", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "N/A",
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]

# Preprocessing: resize (an int size matches the shorter edge in torchvision),
# convert to a float tensor, then apply the standard PyTorch mean/std
# normalization for 3-channel input.
transform = transforms.Compose([
    transforms.Resize(800),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the pretrained DETR model (ResNet-50 backbone) and its weights via torch.hub.
model = torch.hub.load("facebookresearch/detr", "detr_resnet50", pretrained=True)
model.eval()

# Download the demo image (if the URL is unreachable, substitute any other image).
url = "http://farm3.staticflickr.com/2750/4078616721_d76a64a6bb_z.jpg"
with requests.get(url, stream=True, timeout=30) as response:
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    # convert("RGB") reads the pixels before the response is closed and
    # guarantees 3 channels. The 3-channel Normalize in `transform` needs this;
    # grayscale or RGBA substitutes would otherwise fail there.
    im = Image.open(response.raw).convert("RGB")

# Preprocess the image into a batch of size 1.
img = transform(im).unsqueeze(0)

# The forward hooks below append into these lists.
conv_features = []
enc_attn_weights = []

# Register the hooks *before* running the model, so a single forward pass
# yields both the detections and the intermediate tensors to visualize.
hooks = [
    # Last ResNet feature map; only its spatial size (h, w) is used below.
    model.backbone[-2].register_forward_hook(
        lambda module, inputs, output: conv_features.append(output["0"].tensors)
    ),
    # Self-attention weights of the last encoder layer.
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda module, inputs, output: enc_attn_weights.append(output[1])
    ),
]

# Forward pass; the hooks are always removed, even if inference raises.
try:
    outputs = model(img)
finally:
    for hook in hooks:
        hook.remove()

# Keep only predictions whose best class probability exceeds 0.9.
# The last logit column is dropped after the softmax, so it is never counted.
probas = outputs["pred_logits"].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# Box centers (cx, cy) of the kept predictions.
cxcy = outputs["pred_boxes"][0, keep, :2]

# Predicted class id and its confidence for each kept prediction. Reusing
# `probas` avoids a second softmax and cannot select the dropped last column.
confidence, predicted_cats = probas[keep].max(-1)
predicted_cats_name = [CLASSES[cat] for cat in predicted_cats]

# The single forward pass filled each list once; unwrap the captured tensors.
enc_attn_weights = enc_attn_weights[0]  # e.g. [1, 950, 950] for the demo image
conv_features = conv_features[0]  # e.g. [1, 2048, 25, 38] for the demo image

# Lay the attention matrix out on the (h, w) feature grid so that
# enc_attn_weights[y, x] is the attention row for feature cell (y, x).
# Assumes the query index comes first and is flattened row-major over
# (h, w) -- TODO confirm against DETR's flatten order.
h, w = conv_features.shape[-2:]
enc_attn_weights = enc_attn_weights.view(h, w, h * w)


# Visualization.
#
# Reference points (x, y) are in pixel coordinates of the original image `im`.
# The detected box centers, converted to image pixels, came out at roughly
# [[171, 143], [438, 135]]. Those do not sit squarely on the objects, so they
# were adjusted by hand to the points below. The self-attention map of each
# reference point is drawn underneath the image.
reference_points = torch.tensor([[210.0, 90.0], [400.0, 100.0]])

# One row for the annotated image plus one per reference point. The grid is
# sized from reference_points, not from the number of detections, so no axis
# is left empty and no point is dropped. squeeze=False keeps an array of axes
# even when there is only one row.
fig, axs = plt.subplots(
    ncols=1, nrows=len(reference_points) + 1, figsize=(10, 12), squeeze=False
)
axs = axs[:, 0]

# Show the original image with the reference points marked in red.
ax = axs[0]
ax.axis("off")
ax.imshow(im)
ax.scatter(reference_points[:, 0], reference_points[:, 1], color="red", marker="o")

for ax, reference_point in zip(axs[1:], reference_points):
    # Map pixel coordinates onto the (h, w) feature grid. Clamp so that points
    # on the right or bottom edge of the image stay within the grid.
    x = (reference_point[0] / im.width * w).long().clamp(0, w - 1)
    y = (reference_point[1] / im.height * h).long().clamp(0, h - 1)
    # This cell's attention over all h*w cells, laid out as a 2-D (h, w) map.
    attention_map = enc_attn_weights[y, x, :].view(h, w)

    ax.imshow(attention_map, cmap="cividis")
    ax.axis("off")
    ax.set_title(
        f"self-attention: ({int(reference_point[0])}, {int(reference_point[1])})",
        fontsize=20,
    )

fig.tight_layout()  # let the subplots fill the whole canvas
plt.show()