Generation of attention heatmaps #10
Hello, you can refer to the following code to generate attention heatmaps.

In `fpd/ffa.py`:

```python
class PrototypesDistillation(BaseModule):

    def forward(self, support_feats, support_gt_labels=None,
                forward_novel=False, forward_novel_test=False, img_metas=None):
        # ...
        prototypes = torch.matmul(attn.softmax(-1), v)

        # attention heatmap of feature queries upon support images
        if len(support_gt_labels) != 0:
            import os
            import sys
            import cv2
            os.makedirs('attention_heatmap', exist_ok=True)
            if not support_feats.size(0):
                sys.exit()
            if img_metas is not None:
                img = img_metas[-1]
            bs = len(support_gt_labels)
            prefix = 'novel' if forward_novel else 'base'
            for img_id in range(bs):
                gt_label = support_gt_labels[img_id].cpu().numpy()
                file_name = img_metas[img_id]['filename'].split('/')[-1]
                # attention heatmap, shape (16, 5, 49)
                attn2 = attn.squeeze(1).softmax(-1)
                # min-max normalize each query's weights over the spatial dim
                attn2 = (attn2 - attn2.min(dim=2, keepdim=True)[0]) / \
                        (attn2.max(dim=2, keepdim=True)[0] -
                         attn2.min(dim=2, keepdim=True)[0])
                for q_id in range(attn2.size(1)):
                    attn_hm = attn2[img_id, q_id:q_id + 1, :].reshape(
                        1, 1, support_feats_mp.size(-2), support_feats_mp.size(-1))
                    # upsample the attention map to the support image size
                    attn_hm = F.interpolate(
                        attn_hm, size=(224, 224), mode='bilinear',
                        align_corners=True)[0].permute(1, 2, 0).cpu().numpy()
                    # add back the per-channel BGR mean to recover the raw image
                    mean = torch.ones((3, 224, 224)) * \
                        torch.tensor([103.530, 116.280, 123.675])[:, None, None]
                    raw_im = img[img_id, :3].add(mean.cuda()) \
                        .permute(1, 2, 0).detach().cpu().numpy()
                    heatmap = cv2.applyColorMap(
                        (attn_hm * 255).astype('uint8'), cv2.COLORMAP_JET)
                    result = cv2.addWeighted(
                        raw_im.astype('uint8'), 0.6, heatmap, 0.4, 0)
                    cv2.imwrite(
                        f'attention_heatmap/{prefix}_class{gt_label}_'
                        f'{file_name.split(".")[0]}_query{q_id}.png', result)

        return weight, prototypes
```

In `fpd/fpd_detector.py`, replace the original `prototypes_distillation` calls in `forward_model_init` so that `img_metas` (with the support image appended) is passed through:

```python
def forward_model_init():
    # prototypes distillation
    # original calls, without img_metas:
    # weight_base, prototypes_base = self.roi_head.prototypes_distillation(
    #     base_support_feats, support_gt_labels=r_b_gts)
    # weight_novel, prototypes_novel = self.roi_head.prototypes_distillation(
    #     novel_support_feats, support_gt_labels=r_n_gts,
    #     forward_novel=True, forward_novel_test=True)
    img_metas.append(img)
    weight_base, prototypes_base = self.roi_head.prototypes_distillation(
        base_support_feats, support_gt_labels=r_b_gts, img_metas=img_metas)
    weight_novel, prototypes_novel = self.roi_head.prototypes_distillation(
        novel_support_feats, support_gt_labels=r_n_gts,
        forward_novel=True, forward_novel_test=True, img_metas=img_metas)
```
Okay, thank you for your prompt reply!
Sorry to bother you, but I would like to ask how the attention heatmap below was generated. Can you provide the relevant code? Thank you very much!
(attached image: example attention heatmaps overlaid on support images)