In [1]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from network import RetinaNet
import torch

# Per-channel statistics that `visualize_anchor` uses to undo input
# normalisation (pixel * std + mean) before plotting. The values match the
# widely used ImageNet RGB mean/std — presumably what the model's
# preprocessing normalised with; TODO confirm against the training transforms.
img_mean = np.array((0.485, 0.456, 0.406))
img_std = np.array((0.229, 0.224, 0.225))

def visualize_anchor(image, bboxes, ncols=2, mean=None, std=None):
    """Draw each level's anchor boxes on top of the first image in a batch.

    One subplot is created per entry of ``bboxes``. The grid has ``ncols``
    columns and as many rows as needed, and panels without a level are hidden.
    With the defaults and 5-6 levels this reproduces the original 3x2 grid
    at a 15x15 figure size.

    Args:
        image: tensor indexed as ``image[0]`` -> (C, H, W); only the first
            image of the batch is drawn. It is de-normalised as
            ``pixel * std + mean`` and clipped to [0, 1] before display.
        bboxes: sequence with one entry per pyramid level; each entry is an
            iterable of ``(x1, y1, x2, y2)`` boxes, drawn in the pixel
            coordinates of the displayed image.
        ncols: number of subplot columns (default 2, as before).
        mean: per-channel mean for de-normalisation; defaults to the
            module-level ``img_mean`` (looked up at call time).
        std: per-channel std for de-normalisation; defaults to the
            module-level ``img_std`` (looked up at call time).

    Returns:
        None. The figure is created through pyplot.

    Raises:
        ValueError: if ``ncols`` is smaller than 1.
    """
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")
    mean = img_mean if mean is None else mean
    std = img_std if std is None else std

    levels = list(bboxes)
    # Size the grid from the data instead of a fixed 3x2, so more than six
    # levels no longer raises IndexError. Keep at least one row so that
    # plt.subplots never receives 0.
    nrows = max(1, (len(levels) + ncols - 1) // ncols)
    # 7.5 x 5 inches per panel matches the original 15x15 figure for 3x2.
    _, axes = plt.subplots(nrows, ncols, figsize=(7.5 * ncols, 5 * nrows),
                           squeeze=False)
    axes = axes.ravel()

    # CHW -> HWC for imshow; detach so a grad-tracking tensor can go to numpy.
    img_plt = image[0].detach().permute(1, 2, 0).cpu().numpy()
    img_plt = np.clip(std * img_plt + mean, 0, 1)

    for i, level_boxes in enumerate(levels):
        ax = axes[i]
        boxes = list(level_boxes)
        ax.imshow(img_plt)
        ax.set_title(f'pyramid level {i} ({len(boxes)} anchors)')
        for x1, y1, x2, y2 in boxes:
            ax.text(x1, y1 - 10, 'anchor', color='r')
            ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           linewidth=1, edgecolor='r',
                                           facecolor='none'))

    # Panels left over when the level count does not fill the grid.
    for ax in axes[len(levels):]:
        ax.axis('off')

In [2]:
# Detector with the FPN neck enabled; `p67` presumably adds the extra P6/P7
# pyramid levels — confirm in network.py.
model = RetinaNet(p67=True, fpn=True)

In [None]:
# (height, width) of the dummy input. The pixel values are all zero; the
# anchors are presumably a function of the feature-map sizes only — TODO
# confirm in network.py.
IMAGE_SIZE = (731, 1333)
image = torch.zeros((1, 3, *IMAGE_SIZE), dtype=torch.float32)

# Only the anchors are read back, so skip building an autograd graph.
with torch.no_grad():
    outs = model(image)

anchor_pyramid = []
for out in outs:
    # NOTE(review): assumes each level yields a 3-tuple whose last element is
    # the anchor map; confirm against RetinaNet.forward.
    _, _, anchor = out
    B, A, H, W = anchor.shape
    # The A channels are laid out as (A // 4 anchors) x (4 box coordinates).
    # Select the first image explicitly: squeeze(0) is a silent no-op for B > 1.
    anchor = anchor.reshape(B, A // 4, 4, H, W)[0].cpu().numpy()
    # Anchors at the centre cell of this level's grid -> shape (A // 4, 4).
    anchor_pyramid.append(anchor[..., H // 2, W // 2])
visualize_anchor(image, anchor_pyramid)

<img src="anchor.png" alt="RetinaNet anchor boxes at the centre location of each pyramid level" width="800" />
<!-- ![anchor.png](anchor.png) -->