In [1]:
# Pin GPU selection for this kernel: enumerate devices in PCI bus order and
# make only device index 0 visible to CUDA.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


In [2]:
import sys
import pickle
import numpy as np
from pathlib import Path

sys.path.append('PATH/code/zsl_text_imagenet/') # https://github.com/sebastianbujwid/zsl_text_imagenet

In [3]:
from data import imagenet

In [4]:
# Group-root IDs used by print_results to bucket classes: a class is counted
# as an animal / plant when one of these IDs is found among its ancestors
# (see is_a). Everything else is reported as "other".
# NOTE(review): presumably WordNet synset IDs (wnids) -- confirm against WordNet.
wnid_animals = 'n00015388'
wnid_plants = ['n00017222', 'n07707451', 'n13134947']

In [None]:
# ImageNet class details extracted by the project helper from the w2v "extra" pickle.
imagenet_id_details = imagenet.extract_imagenet_id_details(
    '/PATH/data/zsl/synthetized_classifiers_for_zsl/ImageNet_w2v/ImageNet_w2v_extra.pkl'
)

# Lookup indexed by class wnid; is_a() checks group membership against each entry.
# The context manager closes the file handle deterministically instead of
# leaving it open until garbage collection.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('PATH/code/zsl_text_imagenet/data/imagenet/imagenet_wordnet_ancestor_categories.pkl', 'rb') as ancestors_file:
    imagenet_ancestors = pickle.load(ancestors_file)

In [6]:
def is_a(wnid, group_wnids):
    """Tell whether class `wnid` falls under any of the given group IDs.

    `group_wnids` may be a single ID string or an iterable of IDs. The class
    matches when at least one of them is found in `imagenet_ancestors[wnid]`.
    Returns a bool.
    """
    candidates = [group_wnids] if isinstance(group_wnids, str) else group_wnids
    return any(candidate in imagenet_ancestors[wnid] for candidate in candidates)

def _print_group_means(acc_by_wnid):
    """Print class count and mean accuracy for the animals / plants / other groups."""
    grouped = {'animals': [], 'plants': [], 'other': []}
    for wnid, acc in acc_by_wnid.items():
        # Animals take precedence over plants; anything matching neither is "other".
        if is_a(wnid, wnid_animals):
            grouped['animals'].append(acc)
        elif is_a(wnid, wnid_plants):
            grouped['plants'].append(acc)
        else:
            grouped['other'].append(acc)

    for name, accs in grouped.items():
        # An empty group still prints `nan`, but without numpy's
        # "mean of empty slice" RuntimeWarning.
        mean_acc = np.array(accs).mean() if accs else float('nan')
        print(f'{name} [{len(accs)}]: {mean_acc}')


def print_results(pkl):
    """Print per-group mean top-1 and top-5 accuracy from a results pickle.

    Args:
        pkl: Path to a pickle holding a dict with the keys
            'test_unseen_top1_acc' and 'test_unseen_top5_acc', each mapping a
            class wnid to its accuracy.

    Raises:
        KeyError: If one of the two accuracy keys is missing from the pickle.

    Classes are bucketed into animals / plants / other via `is_a` with the
    module-level `wnid_animals` and `wnid_plants`.
    """
    # Close the file deterministically rather than leaking the handle.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(pkl, 'rb') as results_file:
        d = pickle.load(results_file)

    print('top1')
    _print_group_means(d['test_unseen_top1_acc'])

    print()
    print('top5')
    _print_group_means(d['test_unseen_top5_acc'])


In [20]:
# Root directories of the "exclude groups" evaluation runs, one per word
# embedding (GloVe / word2vec); the cells below look up each split's results
# in a sub-directory of these.
glove_dir = Path('PATH/cada_vae/zls/eval/exclude_groups_mp500/glove_ALL-runrs-1019')
w2v_dir = Path('PATH/cada_vae/zls/eval/exclude_groups_mp500/w2v_ALL_run-0105')

In [8]:
def get_results_file(path: Path):
    """Locate the single `test_results_mp500.pkl` anywhere below `path`.

    Searches recursively and returns the matching path.

    Raises:
        ValueError: If zero or more than one matching file is found.
    """
    matches = [*path.rglob('test_results_mp500.pkl')]
    if len(matches) == 1:
        return matches[0]
    raise ValueError(f'Could not get results from {path}\n{matches}')

# GloVe

In [13]:
# Baseline GloVe run from the `mp500_seeds` directory, for comparison with the
# "exclude ..." split runs in the cells that follow.
print('ORIGINAL')
print()

print_results(get_results_file(Path('PATH/cada_vae/zls/eval/mp500_seeds/wemb_glove_wiki_ALL/test_runrs-1019_42')))

ORIGINAL

top1
animals [82]: 0.30406734347343445
plants [39]: 0.05324960872530937
other [368]: 0.2289818525314331

top5
animals [82]: 0.6531624794006348
plants [39]: 0.1956728994846344
other [368]: 0.5197874903678894


In [16]:
# GloVe results for the run stored under the `exclude_animals` split directory.
print('Exclude animals')

print_results(
    get_results_file(glove_dir.joinpath('test_runrs-1019_splits_exclude_animals'))
)

Exclude animals
top1
animals [82]: 0.035174403339624405
plants [39]: 0.04470792040228844
other [368]: 0.23357896506786346

top5
animals [82]: 0.1655680537223816
plants [39]: 0.1770179271697998
other [368]: 0.5298458337783813


In [17]:
# GloVe results for the run stored under the `exclude_nonanimals` split directory.
print('Exclude non-animals')

print_results(
    get_results_file(glove_dir.joinpath('test_runrs-1019_splits_exclude_nonanimals'))
)

Exclude non-animals
top1
animals [82]: 0.28735244274139404
plants [39]: 0.05588745325803757
other [368]: 0.15317945182323456

top5
animals [82]: 0.6059836745262146
plants [39]: 0.21623875200748444
other [368]: 0.4171987473964691


In [18]:
# GloVe results for the run stored under the `exclude_plants` split directory.
print('Exclude plants')

print_results(
    get_results_file(glove_dir.joinpath('test_runrs-1019_splits_exclude_plants'))
)

Exclude plants
top1
animals [82]: 0.2955893874168396
plants [39]: 0.039253078401088715
other [368]: 0.22637777030467987

top5
animals [82]: 0.6307775974273682
plants [39]: 0.16180501878261566
other [368]: 0.5232211351394653


In [19]:
# GloVe results for the run stored under the `exclude_nonplants` split directory.
print('Exclude non-plants')

print_results(
    get_results_file(glove_dir.joinpath('test_runrs-1019_splits_exclude_nonplants'))
)

Exclude non-plants
top1
animals [82]: 0.2828812897205353
plants [39]: 0.06005560979247093
other [368]: 0.22614313662052155

top5
animals [82]: 0.6228954195976257
plants [39]: 0.21263594925403595
other [368]: 0.523369312286377


# W2V

In [23]:
# Directory-name prefix shared by every word2vec split run evaluated below.
w2v_run = 'run-0105_1lv3-arch-l1024-hi-sharedspaceenforce-b128'

# Baseline word2vec run from the `mp500_seeds` directory.
print('ORIGINAL')
print_results(get_results_file(Path('PATH/cada_vae/zls/eval/mp500_seeds/w2v_wiki_ALL_better/run-0105_1lv3-arch-l1024-hi-sharedspaceenforce-b128_42')))

# (printed heading, split-directory suffix) for each "exclude ..." run.
# One loop replaces four copy-pasted stanzas; the printed output is unchanged.
for title, split in [
    ('Exclude animals', 'exclude_animals'),
    ('Exclude non-animals', 'exclude_nonanimals'),
    ('Exclude plants', 'exclude_plants'),
    ('Exclude non-plants', 'exclude_nonplants'),
]:
    print('\n\n')
    print(title)
    print_results(get_results_file(w2v_dir / f'{w2v_run}_splits_{split}'))

ORIGINAL
top1
animals [84]: 0.24303874373435974
plants [40]: 0.04472198337316513
other [376]: 0.15240199863910675

top5
animals [84]: 0.5349633097648621
plants [40]: 0.16210314631462097
other [376]: 0.3771316409111023



Exclude animals
top1
animals [84]: 0.018810583278536797
plants [40]: 0.059906255453825
other [376]: 0.15394273400306702

top5
animals [84]: 0.09239127486944199
plants [40]: 0.1731281280517578
other [376]: 0.37351885437965393



Exclude non-animals
top1
animals [84]: 0.23640015721321106
plants [40]: 0.040690865367650986
other [376]: 0.10637575387954712

top5
animals [84]: 0.5178735256195068
plants [40]: 0.1280914843082428
other [376]: 0.2969052195549011



Exclude plants
top1
animals [84]: 0.2369118183851242
plants [40]: 0.03404001519083977
other [376]: 0.15326803922653198

top5
animals [84]: 0.539307713508606
plants [40]: 0.14207112789154053
other [376]: 0.37735289335250854



Exclude non-plants
top1
animals [84]: 0.2434440404176712
plants [40]: 0.04124293476343155
oth