In [1]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


In [2]:
import sys
import pickle
import numpy as np
from pathlib import Path

sys.path.append('PATH/code/zsl_text_imagenet/') # https://github.com/sebastianbujwid/zsl_text_imagenet

In [3]:
from data import imagenet

In [4]:
wnid_animals = 'n00015388'
wnid_plants = ['n00017222', 'n07707451', 'n13134947']

In [None]:
imagenet_id_details = imagenet.extract_imagenet_id_details(
    '/PATH/data/zsl/synthetized_classifiers_for_zsl/ImageNet_w2v/ImageNet_w2v_extra.pkl'
)
imagenet_ancestors = pickle.load(
    open('PATH/code/zsl_text_imagenet/data/imagenet/imagenet_wordnet_ancestor_categories.pkl', 'rb')
)

In [9]:
def is_a(wnid, group_wnids):
    if isinstance(group_wnids, str):
        group_wnids = [group_wnids]
    
    r = False
    for g_wnid in group_wnids:
        if g_wnid in imagenet_ancestors[wnid]:
            r = True
    
    return r

def print_results(pkl, only_top1=True):
    d = pickle.load(open(pkl, 'rb'))
    
    print('top1')
    animals = {}
    plants = {}
    other = {}
    for wnid, r in d['test_unseen_top1_acc'].items():
        if is_a(wnid, wnid_animals):
            animals[wnid] = r
        elif is_a(wnid, wnid_plants):
            plants[wnid] = r
        else:
            other[wnid] = r
            
    print(f'animals [{len(animals)}]: {np.array(list(animals.values())).mean()}')
    print(f'plants [{len(plants)}]: {np.array(list(plants.values())).mean()}')
    print(f'other [{len(other)}]: {np.array(list(other.values())).mean()}')
    
    if only_top1:
        return
    
    print()
    print('top5')
    animals = {}
    plants = {}
    other = {}
    for wnid, r in d['test_unseen_top5_acc'].items():
        if is_a(wnid, wnid_animals):
            animals[wnid] = r
        elif is_a(wnid, wnid_plants):
            plants[wnid] = r
        else:
            other[wnid] = r
            
    print(f'animals [{len(animals)}]: {np.array(list(animals.values())).mean()}')
    print(f'plants [{len(plants)}]: {np.array(list(plants.values())).mean()}')
    print(f'other [{len(other)}]: {np.array(list(other.values())).mean()}')


In [7]:
albert_dir = Path('PATH/simple_zsl/eval/exclude_groups_mp500/albert_xxl_ALL_1006')
glove_dir = Path('PATH/simple_zsl/eval/exclude_groups_mp500/glove_ALL_1015')
w2v_dir = Path('PATH/simple_zsl/eval/exclude_groups_mp500/w2v_ALL_w2v1000')

In [8]:
def get_results_file(path: Path):
    r = list(path.rglob('test_*.pkl'))
    if len(r) != 1:
        raise ValueError(f'Could not get results from {path}\n{r}')
    return r[0]

# ALBERT-xxlarge

In [29]:
print('ORIGINAL')

print_results(get_results_file(Path(
    'PATH/simple_zsl/eval/eval/mp500/szsl_albert-xxl_wiki_ALL/runzslsrs-xxl1006_42'
)))

print('\n\n')
print('Exclude animals')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_animals'))
print('\n\n')
print('Exclude non-animals')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_nonanimals'))
print('\n\n')
print('Exclude plants')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_plants'))
print('\n\n')
print('Exclude non-plants')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_nonplants'))

ORIGINAL

top5
animals [82]: 0.506726861000061
plants [39]: 0.1548217236995697
other [368]: 0.4493066668510437



Exclude animals

top5
animals [82]: 0.0364919975399971
plants [39]: 0.16690871119499207
other [368]: 0.4347641170024872



Exclude non-animals

top5
animals [82]: 0.4751695394515991
plants [39]: 0.18375174701213837
other [368]: 0.30244871973991394



Exclude plants

top5
animals [82]: 0.4743635058403015
plants [39]: 0.13664023578166962
other [368]: 0.45126450061798096



Exclude non-plants

top5
animals [82]: 0.4791962802410126
plants [39]: 0.14875528216362
other [368]: 0.44973671436309814


# GloVe

In [28]:
print('ORIGINAL')

print_results(get_results_file(Path(
    'PATH/simple_zsl/eval/eval/mp500/szsl_wemb_glove_wiki_ALL/test_runzslsrs-wemb-glove-1015_42'
)))

print('\n\n')
print('Exclude animals')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_animals'))
print('\n\n')
print('Exclude non-animals')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_nonanimals'))
print('\n\n')
print('Exclude plants')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_plants'))
print('\n\n')
print('Exclude non-plants')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_nonplants'))

ORIGINAL

top5
animals [82]: 0.511250376701355
plants [39]: 0.19831953942775726
other [368]: 0.4772208631038666



Exclude animals

top5
animals [82]: 0.05460485443472862
plants [39]: 0.1775784194469452
other [368]: 0.458220511674881



Exclude non-animals

top5
animals [82]: 0.4789133369922638
plants [39]: 0.18297049403190613
other [368]: 0.3462385833263397



Exclude plants

top5
animals [82]: 0.5109047889709473
plants [39]: 0.1835872381925583
other [368]: 0.4761269688606262



Exclude non-plants

top5
animals [82]: 0.4832930564880371
plants [39]: 0.20144562423229218
other [368]: 0.47577965259552


# W2V

In [15]:
print('ORIGINAL')

print_results(get_results_file(Path(
    'PATH/simple_zsl/eval/eval/mp500/szsl_w2v_wiki_ALL/runzslrs-w2v1000_42'
)))

print('\n\n')
print('Exclude animals')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_animals'))
print('\n\n')
print('Exclude non-animals')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_nonanimals'))
print('\n\n')
print('Exclude plants')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_plants'))
print('\n\n')
print('Exclude non-plants')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_nonplants'))

ORIGINAL

top5
animals [84]: 0.33487507700920105
plants [40]: 0.1276433765888214
other [376]: 0.2772156596183777



Exclude animals

top5
animals [84]: 0.10574321448802948
plants [40]: 0.12734892964363098
other [376]: 0.24885067343711853



Exclude non-animals

top5
animals [84]: 0.3075043559074402
plants [40]: 0.16200962662696838
other [376]: 0.19652917981147766



Exclude plants

top5
animals [84]: 0.30598726868629456
plants [40]: 0.12969925999641418
other [376]: 0.26181161403656006



Exclude non-plants

top5
animals [84]: 0.33414426445961
plants [40]: 0.13313624262809753
other [376]: 0.2678578495979309


# Top-1

# ALBERT

In [10]:
print('ORIGINAL')

print_results(get_results_file(Path(
    'PATH/simple_zsl/eval/eval/mp500/szsl_albert-xxl_wiki_ALL/runzslsrs-xxl1006_42'
)))

print('\n\n')
print('Exclude animals')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_animals'))
print('\n\n')
print('Exclude non-animals')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_nonanimals'))
print('\n\n')
print('Exclude plants')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_plants'))
print('\n\n')
print('Exclude non-plants')
print_results(get_results_file(albert_dir / 'runzslsrs-xxl1006_splits_exclude_nonplants'))

ORIGINAL
top1
animals [82]: 0.17738662660121918
plants [39]: 0.04505284130573273
other [368]: 0.18251322209835052



Exclude animals
top1
animals [82]: 0.0052273026667535305
plants [39]: 0.04504355415701866
other [368]: 0.1718398630619049



Exclude non-animals
top1
animals [82]: 0.15595711767673492
plants [39]: 0.04135487973690033
other [368]: 0.09228634834289551



Exclude plants
top1
animals [82]: 0.16423283517360687
plants [39]: 0.029304247349500656
other [368]: 0.18044663965702057



Exclude non-plants
top1
animals [82]: 0.17628182470798492
plants [39]: 0.04533558338880539
other [368]: 0.18030618131160736


# GloVe

In [11]:
print('ORIGINAL')

print_results(get_results_file(Path(
    'PATH/simple_zsl/eval/eval/mp500/szsl_wemb_glove_wiki_ALL/test_runzslsrs-wemb-glove-1015_42'
)))

print('\n\n')
print('Exclude animals')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_animals'))
print('\n\n')
print('Exclude non-animals')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_nonanimals'))
print('\n\n')
print('Exclude plants')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_plants'))
print('\n\n')
print('Exclude non-plants')
print_results(get_results_file(glove_dir / 'test_runzslsrs-wemb-glove-1015_splits_exclude_nonplants'))

ORIGINAL
top1
animals [82]: 0.16066966950893402
plants [39]: 0.04735420644283295
other [368]: 0.18157660961151123



Exclude animals
top1
animals [82]: 0.015212352387607098
plants [39]: 0.04170314222574234
other [368]: 0.16079697012901306



Exclude non-animals
top1
animals [82]: 0.14752954244613647
plants [39]: 0.04186312481760979
other [368]: 0.11348345130681992



Exclude plants
top1
animals [82]: 0.15675824880599976
plants [39]: 0.03751226142048836
other [368]: 0.18325312435626984



Exclude non-plants
top1
animals [82]: 0.14199692010879517
plants [39]: 0.0430646650493145
other [368]: 0.18169493973255157


# W2V

In [12]:
print('ORIGINAL')

print_results(get_results_file(Path(
    'PATH/simple_zsl/eval/eval/mp500/szsl_w2v_wiki_ALL/runzslrs-w2v1000_42'
)))

print('\n\n')
print('Exclude animals')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_animals'))
print('\n\n')
print('Exclude non-animals')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_nonanimals'))
print('\n\n')
print('Exclude plants')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_plants'))
print('\n\n')
print('Exclude non-plants')
print_results(get_results_file(w2v_dir / 'runzslrs-w2v1000_splits_exclude_nonplants'))

ORIGINAL
top1
animals [84]: 0.13174481689929962
plants [40]: 0.03085876628756523
other [376]: 0.08912833034992218



Exclude animals
top1
animals [84]: 0.023934148252010345
plants [40]: 0.035504817962646484
other [376]: 0.08015356212854385



Exclude non-animals
top1
animals [84]: 0.11274704337120056
plants [40]: 0.04302377626299858
other [376]: 0.058190081268548965



Exclude plants
top1
animals [84]: 0.12131133675575256
plants [40]: 0.028008539229631424
other [376]: 0.08286122977733612



Exclude non-plants
top1
animals [84]: 0.12293537706136703
plants [40]: 0.029959997162222862
other [376]: 0.08713536709547043
