In [1]:
import sys
import pickle
import pymatreader
import numpy as np
import scipy.io as spio
from pathlib import Path

# Put the `zsl_text_imagenet` checkout on the import path; presumably this is what
# makes `from data import imagenet` in the next cell resolvable.
# Repository: https://github.com/sebastianbujwid/zsl_text_imagenet
# NOTE(review): `PATH` looks like a placeholder prefix - confirm it is replaced locally.
sys.path.append('PATH/code/zsl_text_imagenet/')

In [2]:
from data import imagenet

In [3]:
def read_split_file(imagenet_class_splits_file):
    """Load the class-ID sets stored in an ImageNet split MAT file.

    Args:
        imagenet_class_splits_file: Path handed to ``pymatreader.read_mat``.
            The file must provide ``'trainval_classes'`` and ``'mp500'``
            entries.

    Returns:
        A ``(trainval_classes, mp500_classes)`` tuple of sets. Each set holds
        the elements of the corresponding entry converted to ``np.int64``.

    Raises:
        AssertionError: If the smallest train/val class ID is not 1 or the
            largest one is not 1000.
    """
    mat_contents = pymatreader.read_mat(imagenet_class_splits_file)

    def as_id_set(field):
        # Convert one MAT entry into a set of int64 class IDs.
        return set(np.array(mat_contents[field], dtype=np.int64))

    trainval_ids = as_id_set('trainval_classes')
    mp500_ids = as_id_set('mp500')

    # Sanity check: the train/val IDs are expected to span 1..1000.
    assert min(trainval_ids) == 1
    assert max(trainval_ids) == 1000

    return trainval_ids, mp500_ids

In [None]:
# Per-class details, indexed by ImageNet class ID (see `exclude_classes`, which
# unpacks each entry as `(wnid, phrases)`).
imagenet_id_details = imagenet.extract_imagenet_id_details(
    '/PATH/data/zsl/synthetized_classifiers_for_zsl/ImageNet_w2v/ImageNet_w2v_extra.pkl'
)

# Mapping looked up by wnid in `exclude_classes`.
# NOTE(review): pickle.load can execute arbitrary code - only load this file from a trusted source.
# The context manager closes the file handle as soon as the pickle has been read
# (the previous `pickle.load(open(...))` form left it open).
with open('PATH/code/zsl_text_imagenet/data/imagenet/imagenet_wordnet_ancestor_categories.pkl', 'rb') as ancestors_file:
    imagenet_ancestors = pickle.load(ancestors_file)

In [6]:
# MAT file holding the original ImageNet class splits; it is read by
# `read_split_file` and re-read as a template by `save_split`.
split_file = 'PATH/data/data/zsl/zsl_a_comprehensive_evaluation/ImageNet_splits.mat'

In [7]:
# `orig_trainval_split`: set of train/val class IDs; `mp500_classses`: set built from the 'mp500' entry.
# NOTE(review): `mp500_classses` is spelled with three s's; later cells use the same spelling,
# so any rename has to be applied everywhere at once.
orig_trainval_split, mp500_classses = read_split_file(split_file)

In [8]:
def exclude_classes(imagenet_ids, exclude_wnid_categories):
    """Partition class IDs by membership in any of the given wnid categories.

    Relies on two notebook-level lookups: ``imagenet_id_details`` (class ID ->
    ``(wnid, phrases)`` pair) and ``imagenet_ancestors`` (wnid -> container
    that is tested with ``in``; presumably the wnid's WordNet ancestors).

    Args:
        imagenet_ids: Set of class IDs to partition. It must support
            ``.difference``.
        exclude_wnid_categories: A single wnid string or an iterable of wnid
            strings. A class is removed when at least one of them is found in
            the ``imagenet_ancestors`` entry of the class's wnid.

    Returns:
        A ``(kept, removed)`` tuple of sets: the IDs matching none of the
        categories, and the remaining IDs of ``imagenet_ids``.
    """
    # Accept a bare wnid as shorthand for a one-element list.
    if isinstance(exclude_wnid_categories, str):
        exclude_wnid_categories = [exclude_wnid_categories]

    def is_excluded(class_id):
        # Only the wnid is needed; the second element of the entry is ignored.
        wnid, _ = imagenet_id_details[class_id]
        return any(
            category in imagenet_ancestors[wnid]
            for category in exclude_wnid_categories
        )

    kept = {class_id for class_id in imagenet_ids if not is_excluded(class_id)}
    return kept, imagenet_ids.difference(kept)

In [9]:
def print_split(imagenet_ids):
    """Print the ``imagenet_id_details`` entry of each class ID, one per line.

    Args:
        imagenet_ids: Iterable of class IDs used as keys into the
            notebook-level ``imagenet_id_details`` lookup.
    """
    for class_id in imagenet_ids:
        details = imagenet_id_details[class_id]
        print(details)

In [10]:
# wnid strings passed to `exclude_classes` as the categories to filter on.
# NOTE(review): presumably the WordNet root synsets for animals and for
# plant-like categories - confirm the exact synsets against WordNet.
wnid_animals = 'n00015388'
wnid_plants = ['n00017222', 'n07707451', 'n13134947']

In [11]:
# print_split(exclude_classes(orig_trainval_split, exclude_wnid_categories=wnid_animals))

In [12]:
# Each call yields (classes kept, classes removed) for one category filter.
split_exclude_animals, classes_animals = exclude_classes(orig_trainval_split, wnid_animals)
split_exclude_plants, classes_plants = exclude_classes(orig_trainval_split, wnid_plants)
# Classes that survive both filters, i.e. neither animals nor plants.
split_rest = split_exclude_animals & split_exclude_plants

In [13]:
# No class may be counted as both an animal and a plant.
assert classes_animals.isdisjoint(classes_plants)

In [14]:
# Animals, plants and the rest must together account for every original train/val class.
assert sum(len(group) for group in (classes_animals, classes_plants, split_rest)) == len(orig_trainval_split)
len(classes_animals), len(classes_plants), len(split_rest)

(398, 31, 571)

In [15]:
def check_mp500():
    """Print the size of the mp500 class set and its animal/plant/rest breakdown.

    Uses the notebook-level ``mp500_classses``, ``wnid_animals`` and
    ``wnid_plants``. The first printed line is the total number of classes;
    the second holds the counts of classes removed by the animal filter, by
    the plant filter, and of the classes removed by neither.
    """
    print(len(mp500_classses))
    # Index 1 of the result is the set of classes removed by the filter.
    animals = exclude_classes(mp500_classses, wnid_animals)[1]
    plants = exclude_classes(mp500_classses, wnid_plants)[1]
    rest = mp500_classses.difference(animals).difference(plants)
    print(len(animals), len(plants), len(rest))


check_mp500()

500
84 40 376


In [16]:
# Seed NumPy's global RNG so the `np.random.choice` draws in `exclude_nongroup` are repeatable.
np.random.seed(42)

In [17]:
def exclude_nongroup(split_excluded_group):
    """Build a split of equal size that keeps the group and drops other classes.

    The group is recovered as the classes of the notebook-level
    ``orig_trainval_split`` that are missing from ``split_excluded_group``.
    It is then topped up with classes drawn without replacement from
    ``split_excluded_group`` until the result has as many classes as
    ``split_excluded_group`` itself.

    The draw uses NumPy's global RNG (``np.random.choice``), so the result
    depends on the current seed and on any earlier draws.

    Args:
        split_excluded_group: Set of class IDs with one group removed.

    Returns:
        Set containing every group class plus the sampled non-group classes.

    Raises:
        AssertionError: If one of the internal consistency checks fails.
    """
    target_size = len(split_excluded_group)
    group_classes = orig_trainval_split.difference(split_excluded_group)

    # Sort the candidates so the draw does not depend on set iteration order.
    candidates = sorted(split_excluded_group)
    sampled_classes = np.random.choice(
        candidates, target_size - len(group_classes), replace=False
    )
    assert not group_classes.intersection(split_excluded_group)

    split_classes = group_classes.union(sampled_classes)
    # The sampled classes come from outside the group, so nothing may overlap.
    assert group_classes.isdisjoint(sampled_classes)
    assert len(split_classes) == len(group_classes) + len(sampled_classes)

    assert len(split_classes) == target_size

    return split_classes

In [18]:
# Both calls draw from NumPy's global RNG, so the sampled classes depend on the
# seed set earlier and on the order of these two calls.
split_exclude_nonanimals = exclude_nongroup(split_exclude_animals)
split_exclude_nonplants = exclude_nongroup(split_exclude_plants)

In [19]:
def save_split(split, name):
    """Save a copy of the original split file with new train/val classes.

    Re-reads the notebook-level ``split_file`` with ``pymatreader.read_mat``,
    replaces its ``'trainval_classes'`` entry with the sorted contents of
    ``split`` and writes the resulting dict with ``scipy.io.savemat``.

    Args:
        split: Iterable of class IDs that becomes the new
            ``'trainval_classes'`` entry.
        name: Output file name handed to ``spio.savemat``.

    Raises:
        AssertionError: If the train/val IDs stored in ``split_file`` do not
            have minimum 1 and maximum 1000.
    """
    mat_contents = pymatreader.read_mat(split_file)

    # Sanity-check the template file before anything is written.
    original_ids = set(np.array(mat_contents['trainval_classes'], dtype=np.int64))
    assert min(original_ids) == 1
    assert max(original_ids) == 1000

    mat_contents['trainval_classes'] = np.array(sorted(split))

    spio.savemat(name, mat_contents)

In [21]:
# Write one MAT file per derived split (relative paths, so into the working directory).
for split_to_save, output_name in (
    (split_exclude_animals, 'splits_exclude_animals.mat'),
    (split_exclude_plants, 'splits_exclude_plants.mat'),
    (split_exclude_nonanimals, 'splits_exclude_nonanimals.mat'),
    (split_exclude_nonplants, 'splits_exclude_nonplants.mat'),
):
    save_split(split_to_save, output_name)

## Number of classes with auxiliary features (aux)

In [25]:
# Auxiliary features keyed by class; only the keys are used below.
# NOTE(review): pickle.load can execute arbitrary code - only load this file from a trusted source.
# The context manager closes the file handle once the pickle has been read
# (the previous `pickle.load(open(...))` form left it open).
with open('PATH/encode_wiki_text/matching_v4/albert-xxlarge-v2_wiki_ALL/ImageNet/ALBERT_ImageNet_trainval_classes_classes.pkl', 'rb') as aux_feats_file:
    aux_feats = pickle.load(aux_feats_file)
with_aux = set(aux_feats.keys())

In [32]:
# Per group, count the classes that have an entry in `aux_feats`.
for group_label, group_classes in (
    ('animals:', classes_animals),
    ('plants:', classes_plants),
    ('rest:', split_rest),
):
    print(group_label, len(with_aux.intersection(group_classes)))

animals: 389
plants: 31
rest: 556
