### View Attention Maps
Generate similarity-based attention maps from a trained model.

In [None]:
# %%
# Create dataset
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from collections import defaultdict

from tools.dataset import AttributePairsDataset
from tools.preprocessing import get_transform
from tools.model import AttEmbeddingModel

# Root directory for the images; handed to the dataset and joined with file
# names in show_img.
img_dir = "data/category"
# Also used by the extraction cell to map a batch index back into ds.pairs.
batch_size = 10

def collate_fn(batch):
    ims, cls_idx = list(zip(*batch))
    im1, im2 = list(zip(*ims))
    im1 = torch.stack(im1)
    im2 = torch.stack(im2)
    return (im1, im2), cls_idx

# Paired-image dataset built from the coarse attribute annotations.
ds = AttributePairsDataset(
    annot_path="data/category/Anno_coarse/list_attr_img.txt",
    img_dir=img_dir,
    transform=get_transform(),
    pairs_per_class=1000,
)

# shuffle=False is load-bearing: the attention-extraction cell recovers each
# batch's image paths by slicing ds.pairs with the batch index.
train_dataloader = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn,
)


In [None]:
from itertools import islice

model = AttEmbeddingModel(beta=1.0, alpha=1e-3)
device = torch.device("cuda")
# Load the checkpoint onto the CPU first so loading does not depend on the
# device it was saved from; load_state_dict then copies the weights in.
checkpoint = torch.load(
    "lightning_logs/version_7/checkpoints/epoch=1-step=250.ckpt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["state_dict"])
# The inputs are sent to the GPU below, so the weights must live there too;
# otherwise the forward pass fails with a device-mismatch error.
model.to(device)
model.eval()

n_batches = 10  # number of batches to run through the model
# image path -> list of similarity-attention maps (NumPy arrays), one entry per
# processed pair the image appeared in, in processing order.
im_to_att = defaultdict(list)

with torch.no_grad():
    # islice stops after n_batches without pulling an extra batch from the
    # loader, and yields nothing when n_batches is 0.
    batches = islice(train_dataloader, n_batches)
    for i_batch, ((im1, im2), _cls_idx) in tqdm(enumerate(batches), total=n_batches):
        # Recover this batch's image paths. Valid only because the DataLoader
        # walks ds in order (shuffle=False) with a fixed batch_size.
        i_start = batch_size * i_batch
        pair_paths = ds.pairs[i_start:i_start + len(im1)]

        # Only the similarity attention maps are used in this notebook.
        _sim_emb, _dis_emb, sim_att, _dis_att = model(
            im1.to(device), im2.to(device), return_att=True
        )
        sim_f1, sim_f2 = sim_att

        # zip walks the actual batch, so a short final batch no longer raises
        # IndexError the way range(batch_size) did.
        for pair, att1, att2 in zip(pair_paths, sim_f1, sim_f2):
            # NOTE(review): assumes each ds.pairs entry is ((path1, path2), ...)
            # -- confirm against AttributePairsDataset.
            im_to_att[pair[0][0]].append(att1.cpu().numpy())
            im_to_att[pair[0][1]].append(att2.cpu().numpy())

# %%

In [None]:
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt

def show_img(fn):
    """Draw the image at ``os.path.join(img_dir, fn)`` with ``plt.imshow``.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been copied into a NumPy array. Returns ``None``.
    """
    with Image.open(os.path.join(img_dir, fn)) as img:
        # np.asarray forces the (lazy) PIL decode and copies the pixel data,
        # so the array stays valid after the file is closed.
        pixels = np.asarray(img)
    plt.imshow(pixels)

In [None]:
# Pick an example image: the second distinct path that received an attention
# map (dicts keep insertion order).
example_paths = list(im_to_att)
fn = example_paths[1]
print(fn)
show_img(fn)

In [None]:
import seaborn as sns
# First attention map recorded for `fn`. NOTE(review): sns.heatmap expects a
# 2-D array -- confirm the model's attention maps carry no extra axis.
first_att = im_to_att[fn][0]
sns.heatmap(first_att)

In [None]:
# Second attention map recorded for `fn`; requires `fn` to have appeared in at
# least two processed pairs, otherwise this raises IndexError.
second_att = im_to_att[fn][1]
sns.heatmap(second_att)