In [None]:
# Setup cell: imports, global torch settings, device selection and default CLI args.
import sys
# Make the project packages imported below (data, models, utils, flags) importable;
# assumes this notebook lives one directory below the repository root.
sys.path.append('../')
# Torch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and cache the fastest convolution algorithms for the input sizes seen.
cudnn.benchmark = True

# Python imports
import numpy as np
import tqdm
import torchvision.models as tmodels
# NOTE: this rebinds `tqdm` from the module (imported above) to the progress-bar
# callable; the evaluation cell calls `tqdm(...)` directly and relies on this order.
from tqdm import tqdm
import os
from os.path import join as ospj
import itertools
import glob
import random

# Local (project) imports -- these only resolve because of the sys.path tweak above
from data import dataset as dset
from models.common import Evaluator
from models.image_extractor import get_image_extractor
from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator
from models.modular_methods import GatedGeneralNN
from models.symnet import Symnet
from utils.utils import save_args, UnNormalizer, load_args
from utils.config_model import configure_model
from flags import parser
from PIL import Image
import matplotlib.pyplot as plt
import importlib

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# parse_known_args ignores argv entries the parser does not define (presumably the
# extra arguments the Jupyter kernel was launched with); they are collected in `unknown`.
args, unknown = parser.parse_known_args()

### Run exactly one of the two cells below (MIT-States or UT-Zappos) to load that dataset's config, then move on to the next section

In [None]:
# MIT-States: load the saved training config into `args`, then re-root its relative
# paths one level up (this notebook runs one directory below the repository root).
best_mit = '../logs/graphembed/mitstates/base/mit.yml'
load_args(best_mit,args)
if args.graph_init is not None:
    # ospj (unlike '../' + path) leaves absolute paths untouched.
    args.graph_init = ospj('..', args.graph_init)
# The best-AUC checkpoint is stored in the same directory as the config file.
args.load = ospj(os.path.dirname(best_mit), 'ckpt_best_auc.t7')

In [None]:
# UT-Zappos: load the saved training config into `args`, then re-root its relative
# paths one level up (this notebook runs one directory below the repository root).
best_ut = '../logs/graphembed/utzappos/base/utzappos.yml'
load_args(best_ut,args)
if args.graph_init is not None:
    # ospj (unlike '../' + path) leaves absolute paths untouched.
    args.graph_init = ospj('..', args.graph_init)
# The best-AUC checkpoint is stored in the same directory as the config file.
args.load = ospj(os.path.dirname(best_ut), 'ckpt_best_auc.t7')

### Loading arguments and dataset

In [None]:
# Re-root the data directory one level up (the notebook runs below the repo root).
# ospj keeps absolute paths intact, where '../' + path would corrupt them.
# NOTE: not idempotent -- re-running this cell without first re-running a config
# cell prepends another '..'.
args.data_dir = ospj('..', args.data_dir)
args.test_set = 'test'
testset = dset.CompositionDataset(
        root= args.data_dir,
        phase=args.test_set,
        split=args.splitname,
        model =args.image_extractor,
        subset=args.subset,
        return_images = True,
        update_features = args.update_features,
        clean_only = args.clean_only
    )
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.test_batch_size,
    # Only the sample order changes; the example cell below shows whichever batch comes first.
    shuffle=True,
    num_workers=args.workers)

print('Objs ', len(testset.objs), ' Attrs ', len(testset.attrs))

In [None]:
# Build the (possibly falsy) image backbone, the composition model and its optimizer
# from the loaded args, then attach an Evaluator over the test split for scoring.
configured = configure_model(args, testset)
image_extractor, model, optimizer = configured
del configured  # keep the notebook namespace free of the temporary tuple
evaluator = Evaluator(testset, model)

In [None]:
# Restore trained weights when a checkpoint path has been configured.
if args.load is not None:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(args.load, map_location=device)
    if image_extractor:
        # Some checkpoints have no 'image_extractor' entry; keep the current weights then.
        # Real loading errors (e.g. shape mismatches) are no longer silently swallowed.
        if 'image_extractor' in checkpoint:
            image_extractor.load_state_dict(checkpoint['image_extractor'])
        else:
            print('No Image extractor in checkpoint')
        # Switch to inference mode either way (dropout / batch-norm layers, if any).
        image_extractor.eval()
    model.load_state_dict(checkpoint['net'])
    model.eval()
    print('Loaded model from ', args.load)
    print('Best AUC: ', checkpoint['AUC'])

In [None]:
def print_results(scores, exp, idx=None, topk=None, data=None, dataset=None):
    """Print the ground truth and the top-k predicted (attr, obj) pairs of one sample.

    Any of ``idx``, ``topk``, ``data`` and ``dataset`` left as ``None`` is looked up at
    call time in the notebook globals of the same name (``dataset`` falls back to
    ``evaluator.dset``), so ``print_results(results, 'closed')`` behaves as before.

    Args:
        scores: mapping from result key to predictions (e.g. the output of
            ``evaluator.score_model``); ``scores[exp][0]`` holds attribute indices and
            ``scores[exp][1]`` object indices, both indexed as ``[sample, rank]``.
        exp: key of ``scores`` to print (e.g. ``'closed'``, ``'unbiased_closed'``).
        idx: position of the sample within the batch.
        topk: number of ranked predictions to print.
        data: batch list whose ``data[1]`` / ``data[2]`` hold the ground-truth
            attribute / object indices.
        dataset: object exposing ``attrs`` and ``objs`` name lists.
    """
    print(exp)
    env = globals()  # the notebook namespace; used only for omitted arguments
    if idx is None:
        idx = env['idx']
    if topk is None:
        topk = env['topk']
    if data is None:
        data = env['data']
    if dataset is None:
        dataset = env['evaluator'].dset

    attr_idx, obj_idx = scores[exp][0], scores[exp][1]
    attr = [dataset.attrs[attr_idx[idx, rank]] for rank in range(topk)]
    obj = [dataset.objs[obj_idx[idx, rank]] for rank in range(topk)]
    attr_gt, obj_gt = dataset.attrs[data[1][idx]], dataset.objs[data[2][idx]]
    print(f'Ground truth: {attr_gt} {obj_gt}')
    # Each prediction is rendered as "<attr> <obj>| " (same format as before).
    prediction = ''.join(a + ' ' + o + '| ' for a, o in zip(attr, obj))
    print('Predictions: ', prediction)
    print('__'*50)

### An example of predictions
`closed` -> biased in favor of unseen pairs (the cell below passes `bias = 1000`)

`unbiased_closed` -> no bias added, so predictions are biased against unseen pairs

In [None]:
# Score a single batch from the test loader to inspect individual predictions.
data = next(iter(testloader))
images = data[-1]  # last entry: image file names, used by the display cell below
data = [d.to(device) for d in data[:-1]]
# Inference only: skip building an autograd graph for the forward pass.
with torch.no_grad():
    if image_extractor:
        data[0] = image_extractor(data[0])
    _, predictions = model(data)
data = [d.to('cpu') for d in data]
topk = 5
results = evaluator.score_model(predictions, data[2], bias = 1000, topk=topk)

In [None]:
# Show every sample in the batch whose ground-truth pair is unseen, with its predictions.
# `idx` stays a global on purpose: print_results reads it when not passed explicitly.
for idx in range(len(images)):
    seen = bool(evaluator.seen_mask[data[3][idx]])
    if seen:
        continue
    fig, ax = plt.subplots(dpi=300)
    # Close the file handle once the pixels have been handed to matplotlib.
    with Image.open(ospj(args.data_dir, 'images', images[idx])) as image:
        ax.imshow(image)
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating high-dpi figures across the loop

    print(f'GT pair seen: {seen}')
    print_results(results, 'closed')
    print_results(results, 'unbiased_closed')

### Run Evaluation

In [None]:
# Run the model over the whole test split and collect predictions and ground truth.
model.eval()
if image_extractor:
    # Ensure the backbone is in inference mode even if no weights were restored for it.
    image_extractor.eval()
args.bias = 1e3
accuracies, all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], [], []

# no_grad: predictions are only stored, never back-propagated. Without it each batch's
# autograd graph would stay alive through `all_pred` for the entire test set.
with torch.no_grad():
    for idx, data in tqdm(enumerate(testloader), total=len(testloader), desc = 'Testing'):
        data.pop()  # drop the trailing image file names; the remaining entries are tensors
        data = [d.to(device) for d in data]
        if image_extractor:
            data[0] = image_extractor(data[0])

        _, predictions = model(data) # todo: Unify outputs across models

        attr_truth, obj_truth, pair_truth = data[1], data[2], data[3]
        all_pred.append(predictions)
        all_attr_gt.append(attr_truth)
        all_obj_gt.append(obj_truth)
        all_pair_gt.append(pair_truth)

all_attr_gt, all_obj_gt, all_pair_gt = torch.cat(all_attr_gt), torch.cat(all_obj_gt), torch.cat(all_pair_gt)

# Merge the per-batch prediction dicts: for each key (presumably an (attr, obj) pair),
# concatenate that key's score tensors across all batches.
all_pred_dict = {
    k: torch.cat([batch_pred[k] for batch_pred in all_pred])
    for k in all_pred[0].keys()
}

# Ground-truth (attr, obj) index pairs for every test sample, on the CPU
attr_truth, obj_truth = all_attr_gt.to('cpu'), all_obj_gt.to('cpu')
pairs = list(
    zip(list(attr_truth.numpy()), list(obj_truth.numpy())))

In [None]:
# Score all collected predictions and compute the aggregate metrics. The bias comes from
# args.bias (set in the evaluation cell) instead of a duplicated literal.
topk = 1 ### For topk results
our_results = evaluator.score_model(all_pred_dict, all_obj_gt, bias = args.bias, topk = topk)
stats = evaluator.evaluate_predictions(our_results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict, topk = topk)

In [None]:
# Print each evaluation metric's name followed by its value, one metric per line.
for metric_name, metric_value in stats.items():
    print(metric_name, metric_value)