In [1]:
import os
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import DataLoader

from apn import APN
from data.cub.cub_dataset import CUBDataset

# Load Dataset and `APN` Model

In [3]:
# Evaluation-time preprocessing: resize to 448 px, take a 448x448 centre crop and
# convert to a float tensor.
# NOTE(review): there is no T.Normalize step here, but the last cell of this notebook
# experiments with ImageNet normalisation -- confirm whether APN / its backbone
# expects normalised inputs.
test_transforms = T.Compose([
    T.Resize(size=448),
    T.CenterCrop(size=448),
    T.ToTensor(),
])

dataset_val = CUBDataset(os.path.join('datasets', 'CUB'), num_attrs=107, split='val', transforms=test_transforms)

# Shuffling is kept, but it is driven by a seeded generator so the batch drawn below
# is the same after every Restart & Run All. Several later cells hard-code row and
# attribute indices that appear to have been read off one particular batch.
# NOTE(review): with this seed the batch may differ from the one those indices came
# from -- re-derive them after the first seeded run.
VAL_SHUFFLE_SEED = 0
val_generator = torch.Generator().manual_seed(VAL_SHUFFLE_SEED)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=4, shuffle=True,
                            num_workers=8, generator=val_generator)
dataloader_val_iter = iter(dataloader_val)

In [5]:
# Build APN on a ResNet-101 backbone whose weights come from
# checkpoints/resnet101_ft_CUB.pt, then load the full APN checkpoint on top.
# Everything is mapped to CPU.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code --
# only load checkpoints from a trusted source.
backbone_weights = torch.load('checkpoints/resnet101_ft_CUB.pt', map_location='cpu')
apn_net = APN(num_classes=200,
              num_attrs=107,  # must agree with num_attrs used for dataset_val
              class_attr_embs=dataset_val.class_attr_embs,
              backbone_name='resnet101',
              backbone_weights=backbone_weights,
              dist='dot')
full_weights = torch.load('checkpoints/apn_CUB.pt', map_location='cpu')
apn_net.load_state_dict(full_weights)

In [6]:
# Pull one validation batch and run the model on it in evaluation mode.
# eval() switches off training-only layer behaviour; no_grad() skips autograd tracking.
batch_dict = next(dataloader_val_iter)
apn_net.eval()
with torch.no_grad():
    outputs = apn_net(batch_dict)

In [7]:
# Names of every entry returned by the APN forward pass (iterating a dict yields its keys).
list(outputs)

In [26]:
# Show a hand-picked set of attribute rows. reset_index(drop=True) makes the
# displayed index labels equal the positional indices used by iloc.
# NOTE(review): these indices possibly come from `torch.topk(class_attr_embs[60], 5)`
# in a later cell -- confirm, and prefer computing them over hard-coding.
selected_attr_rows = [45, 26, 92, 60, 103]
attrs_by_position = dataset_val.attribute_df.reset_index(drop=True)
attrs_by_position.iloc[selected_attr_rows]

In [34]:
# Show a second hand-picked set of attribute rows, labelled by position.
# NOTE(review): ten indices -- possibly the result of
# `torch.topk(batch_dict['attr_scores'][0], 10)` further down; confirm, since that
# depends on which batch was drawn.
selected_attr_rows_10 = [50, 11, 26, 22, 86, 17, 84, 71, 40, 51]
attrs_by_position = dataset_val.attribute_df.reset_index(drop=True)
attrs_by_position.iloc[selected_attr_rows_10]

In [38]:
# Dump the full attribute table as plain text to attrs.txt for offline inspection.
# The explicit encoding keeps the file identical across machines; without it, open()
# falls back to the locale's default encoding.
with open('attrs.txt', 'w', encoding='utf-8') as fp:
    dataset_val.attribute_df.to_string(buf=fp)

In [33]:
# Ten largest `attr_scores` values (with their indices) for batch element 0.
batch_dict['attr_scores'][0].topk(10)

In [22]:
# Five largest entries (values and indices) of class_attr_embs row 60.
# NOTE(review): class index 60 is hard-coded -- presumably a class seen in the batch.
dataset_val.class_attr_embs[60].topk(5)

In [32]:
# Metadata for one validation sample at positional row 3495 of main_df.
# NOTE(review): hard-coded; presumably copied from `batch_dict['image_ids']` -- confirm.
inspected_row = 3495
dataset_val.main_df.iloc[inspected_row]

In [30]:
# `image_ids` entry of the current batch (presumably the dataset's image identifiers).
batch_dict['image_ids']

In [24]:
# Display image 0 of the batch. Assumes pixel_values is (B, C, H, W), as produced by
# ToTensor(); imshow needs (H, W, C), hence the permute. Convert explicitly to NumPy,
# and end with `;` so the AxesImage/Text repr is not echoed.
plt.imshow(batch_dict['pixel_values'][0].permute(1, 2, 0).numpy())
plt.title('Validation batch, image 0');

In [12]:
# Predicted class index per sample: argmax over the last axis of the class scores.
outputs['class_scores'].argmax(dim=-1)

In [13]:
# Ground-truth `class_ids` of the current batch, for comparison with the argmax above.
batch_dict['class_ids']

In [10]:
# Raw class scores returned by APN for the current batch (200 classes per the model config).
outputs['class_scores']

In [None]:
# Peak value of every attention map: a global max over the two trailing spatial axes.
# adaptive_max_pool2d(..., output_size=1) always pools over the full H x W, so the
# result is (..., 1, 1) whatever the map size. A fixed kernel_size=(14, 14) would only
# be a global max for 14x14 maps.
max_attn_values = F.adaptive_max_pool2d(outputs['attn_maps'], output_size=1)

In [None]:
# Top-5 peak attention values (and the indices of the maps they came from) per sample.
# flatten(start_dim=1) removes only the trailing 1x1 spatial axes. A bare squeeze()
# would also drop the batch axis for a one-sample batch (e.g. a final partial batch),
# making topk run over the wrong dimension.
top_attn_values = torch.topk(max_attn_values.flatten(start_dim=1), k=5, dim=-1)

In [None]:
# Heat map of attention map 10 for sample 0 (index 10 is hard-coded), with a colour
# scale so values can be compared with the max computed in the next cell.
# The trailing `;` suppresses the matplotlib object repr.
plt.imshow(outputs['attn_maps'][0, 10, ...].numpy())
plt.colorbar()
plt.title('attn_maps[0, 10]');


In [None]:
# Largest value in attention map [0, 10] (sample 0, map 10).
outputs['attn_maps'][0, 10].max()

In [None]:
# Heat map of attention map 0 for sample 0, with a colour scale.
# The trailing `;` suppresses the matplotlib object repr.
plt.imshow(outputs['attn_maps'][0, 0, ...].numpy())
plt.colorbar()
plt.title('attn_maps[0, 0]');

In [None]:
# Re-display image 0 of the batch next to the attention maps above (same view as the
# earlier image cell). Assumes pixel_values is (B, C, H, W); permute to (H, W, C) for
# imshow. The trailing `;` suppresses the matplotlib object repr.
plt.imshow(batch_dict['pixel_values'][0].permute(1, 2, 0).numpy())
plt.title('Validation batch, image 0');

In [None]:
# Peak attention value of every map for sample 0.
# flatten(start_dim=1) keeps the batch axis even for a one-sample batch. There, a bare
# squeeze() would collapse it, and [0] would pick a single map instead of a sample.
max_attn_values.flatten(start_dim=1)[0]

In [None]:
# Entry 188 of the dataset's attr_class_map.
# NOTE(review): a later cell queries entry 187 -- check whether these ids are 0- or 1-based.
dataset_val.attr_class_map[188]

In [None]:
# File path stored for positional row 11065 of main_df (hard-coded sample).
dataset_val.main_df['file_path'].iloc[11065]

In [None]:
# Full attribute table (pandas truncates the display for long frames; the complete
# table is written to attrs.txt in an earlier cell).
dataset_val.attribute_df

In [None]:
# Five largest values (and their indices) in attr_class_map entry 187.
dataset_val.attr_class_map[187].topk(k=5)

In [None]:
# Ground-truth `class_ids` of the current batch.
# NOTE(review): duplicates an earlier cell; consider removing one.
batch_dict['class_ids']

In [None]:
# `image_ids` entry of the current batch.
# NOTE(review): duplicates an earlier cell; consider removing one.
batch_dict['image_ids']

In [None]:
# Summarise the input batch instead of dumping every pixel: shape, dtype and value
# range. The range shows whether the inputs are still in ToTensor's [0, 1] scale or
# already normalised.
batch_pixels = batch_dict['pixel_values']
batch_pixels.shape, batch_pixels.dtype, batch_pixels.min().item(), batch_pixels.max().item()

In [None]:
# Per-channel mean / std -- the standard ImageNet statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Displayed only: the result is not stored, and batch_dict['pixel_values'] is unchanged
# (Normalize is not in-place by default).
normalize(batch_dict['pixel_values'])