In [1]:
import json
import pickle

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from torch import nn
from tqdm import tqdm
from tqdm.notebook import tqdm


def eval_fgsm(root = '.', dataset = 'CUB', corpus_name = 'aab'):
    """Evaluate the stage-two Siamese classifier on the supervised test split and print the metrics.

    For every test image, the predicted-caption features (the ``sat`` and ``aoanet``
    feature files) are scored against the sentence features of every class document in
    the chosen corpus. Sentence-level scores are averaged into one score per class, and
    the per-class-averaged top-1 / top-5 accuracy and mean rank are printed.

    Files read (all relative to ``root``):
        checkpoints/stage_two/{dataset.lower()}.pth   -- dict with a "state_dict" entry
        datasets/{dataset}/FGSM/captions_prediction_trainval_sup_sat_features.pickle
        datasets/{dataset}/FGSM/captions_prediction_trainval_sup_aoanet_features.pickle
        datasets/{dataset}/FGSM/corpus_{corpus_name}_cleaned_features.pickle
        datasets/{dataset}/image_data.json            -- needs "images" and "supervised_test_loc"

    Args:
        root: Project root directory containing ``checkpoints/`` and ``datasets/``.
        dataset: Dataset directory name; its lower-cased form names the checkpoint file.
        corpus_name: Corpus identifier used in the corpus feature file name.

    Returns:
        None. The result line is printed to stdout.

    Raises:
        ValueError: If a test image's ``class_name`` is not a key of the corpus features.
        OSError: If any of the files listed above cannot be opened.
    """

    class SiameseFCN(nn.Module):
        """Shared three-layer MLP encoder plus a linear classifier over (a, b, |a - b|)."""

        def __init__(self, shared_in_features=1024, shared_out_feature=64, num_labels=2):
            super(SiameseFCN, self).__init__()
            # NOTE: the layer layout determines the state_dict keys, so it must stay in
            # sync with the checkpoint being loaded.
            self.shared_fc = nn.Sequential(
                nn.Linear(shared_in_features, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, shared_out_feature),
                nn.ReLU())
            self.classifier_in_features = 3 * shared_out_feature
            self.classifier = nn.Linear(self.classifier_in_features, num_labels)

        def forward(self, rep_a, rep_b):
            """Encode both inputs with the shared encoder and return classifier logits."""
            rep_a = self.shared_encode(rep_a)
            rep_b = self.shared_encode(rep_b)

            return self.get_classifier_scores(rep_a, rep_b)

        def shared_encode(self, x):
            """Apply the shared encoder to ``x``."""
            return self.shared_fc(x)

        def get_classifier_scores(self, rep_a, rep_b):
            """Return logits for already-encoded pairs (rows of rep_a paired with rows of rep_b)."""
            return self.classifier(torch.cat((rep_a, rep_b, torch.abs(rep_a - rep_b)), 1))

    # Fall back to CPU so the evaluation also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    chk = f'{root}/checkpoints/stage_two/{dataset.lower()}.pth'

    model = SiameseFCN(shared_in_features=1024, shared_out_feature=32, num_labels=3)
    # map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only host.
    model.load_state_dict(torch.load(chk, map_location=device)["state_dict"])
    model.to(device)
    model.eval()

    def compute_per_class_metric(metric, target, present_classes):
        """Average ``metric`` within each class, then average those class means.

        Classes from ``present_classes`` that have no sample in ``target`` are skipped;
        dividing by their zero count would otherwise turn the whole result into NaN.
        Returns a NaN tensor when no class has any sample.
        """
        class_mean_sum = 0.
        classes_seen = 0
        for i in present_classes:
            idx = (target == i)
            n_samples = torch.sum(idx)
            if n_samples.item() == 0:
                continue
            class_mean_sum += torch.true_divide(torch.sum(metric[idx]), n_samples)
            classes_seen += 1
        if classes_seen == 0:
            return torch.tensor(float('nan'))
        return class_mean_sum / classes_seen

    def accuracy(output, target, present_classes, topk=(1, 5)):
        """Computes the per-class-averaged top-k accuracy (in percent) for each k in ``topk``."""
        with torch.no_grad():
            # Clamp so that asking for top-5 on fewer than 5 classes does not raise.
            maxk = min(max(topk), output.size(1))

            _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
            pred = pred.t()  # maxk x num_samples
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                # Slicing past maxk rows simply keeps every available row.
                correct_k = correct[:k].float().sum(0)
                res.append(compute_per_class_metric(correct_k, target, present_classes))
            return [r.item() * 100. for r in res]

    def rank(output, target, present_classes):
        """Computes the per-class-averaged 1-based rank of the target class in ``output``."""
        with torch.no_grad():
            order = output.argsort(dim=1, descending=True)

            # Column index at which each row's target appears in the sorted order.
            ranks = order.eq(target.view(-1, 1).expand_as(order)).nonzero(as_tuple=False)[:, 1].float()
            return compute_per_class_metric(ranks + 1, target, present_classes).item()

    # SECURITY: pickle.load can execute arbitrary code; only load feature files you trust.
    feature_dir = f'{root}/datasets/{dataset}/FGSM'
    with open(f'{feature_dir}/captions_prediction_trainval_sup_sat_features.pickle', 'rb') as handle:
        sat_features = pickle.load(handle)
    with open(f'{feature_dir}/captions_prediction_trainval_sup_aoanet_features.pickle', 'rb') as handle:
        aoanet_features = pickle.load(handle)
    with open(f'{feature_dir}/corpus_{corpus_name}_cleaned_features.pickle', 'rb') as handle:
        corpus_features = pickle.load(handle)

    with open(f'{root}/datasets/{dataset}/image_data.json', 'r') as handle:
        data = json.load(handle)
    test_images = [data['images'][i] for i in data['supervised_test_loc']]
    test_classes = sorted(corpus_features.keys())
    test_corpus_features = [corpus_features[c].to(device) for c in test_classes]

    def cartesian_classifier_on_embeddings(model, feat_a, feat_b):
        """Score every row of ``feat_a`` against every row of ``feat_b``.

        Returns a tensor of shape len(feat_a) x len(feat_b) x num_labels.
        """
        output = model.get_classifier_scores(
            torch.repeat_interleave(feat_a, repeats=feat_b.size(0), dim=0),
            feat_b.repeat(feat_a.size(0), 1))
        return output.view(len(feat_a), len(feat_b), -1)

    # All class documents are concatenated into one tensor; document_spans[i] is the
    # (start, end) row range belonging to test_classes[i].
    document_spans = []
    offset = 0
    for docs in test_corpus_features:
        document_spans.append((offset, offset + docs.size(0)))
        offset += docs.size(0)
    document_list = torch.cat(test_corpus_features, dim=0)

    def reduce_docs(doc_scores):  # [ batch x num_caps x num_sents ] x docs
        """Average each document's scores over its sentences."""
        return [doc.mean(dim=2, keepdim=True) for doc in doc_scores]  # [ batch x num_caps x 1 ] x docs

    gt = []
    pred = []
    with torch.no_grad():
        # The corpus embeddings do not depend on the image, so encode them once.
        document_embeddings = model.shared_fc(document_list)  # total_sents x siamese_dim

        for image_detail in test_images:
            caption_feat_input = torch.cat([sat_features[image_detail['id']],
                                            aoanet_features[image_detail['id']]
                                            ], dim=0).to(device)
            sentence_embs = model.shared_fc(caption_feat_input).unsqueeze(0)  # 1 x num_caps x siamese_dim

            batch, num_captions, _ = sentence_embs.size()
            sentence_embs = sentence_embs.view(-1, sentence_embs.size(2))  # batch*num_caps x siamese_dim
            doc_scores = cartesian_classifier_on_embeddings(model, sentence_embs, document_embeddings)  # batch*num_caps x total_sents x num_labels
            doc_scores = [doc_scores[:, start:end].view(batch, num_captions, end - start, doc_scores.size(2)) for start, end in document_spans]  # [ batch x num_caps x num_sents x num_labels ] x num_docs
            doc_scores = [doc.softmax(3) for doc in doc_scores]  # [ batch x num_caps x num_sents x num_labels ] x num_docs
            # Score = P(label 1) - P(label 0).
            # NOTE(review): the meaning of label indices 0/1/2 comes from training — confirm.
            doc_scores = [doc[:, :, :, 1] - doc[:, :, :, 0] for doc in doc_scores]  # [ batch x num_caps x num_sents ] x num_docs
            doc_scores = reduce_docs(doc_scores)  # [ batch x num_caps x 1 ] x num_docs
            doc_scores = [doc.mean(dim=1) for doc in doc_scores]  # [ batch x 1 ] x num_docs
            doc_scores = torch.cat(doc_scores, 1)  # batch x num_docs
            doc_scores = doc_scores.softmax(1)

            gt.append(test_classes.index(image_detail['class_name']))
            pred.append(doc_scores)

    gt = torch.from_numpy(np.array(gt))
    pred = torch.cat(pred, 0).cpu()
    present_classes = range(len(test_classes))
    acc = accuracy(output=pred, target=gt, present_classes=present_classes, topk=(1, 5))
    mr = rank(output=pred, target=gt, present_classes=present_classes)

    print(f'[{dataset}] [{corpus_name}] top-1: {acc[0]:.2f} top-5: {acc[1]:.2f} mean_rank: {mr:.2f}')

In [2]:
# Evaluate each (dataset, corpus) pairing in turn; every run prints one result line.
for dataset_name, corpus in (('CUB', 'aab'), ('FLO', 'wiki')):
    eval_fgsm(dataset=dataset_name, corpus_name=corpus)

[CUB] [aab] top-1: 7.87 top-5: 28.57 mean_rank: 31.92
[FLO] [wiki] top-1: 6.24 top-5: 14.19 mean_rank: 39.70
