# Superclassing dataset

In [3]:
in_path = '/cluster/scratch/data/ori/'
in_info_path = '/cluster/home/AADefDINO/imagenet_metadata'

In [4]:
import os
import numpy as np
import json
from itertools import product

class Node():
    '''
    Class for representing a node in the ImageNet/WordNet hierarchy. 
    '''
    def __init__(self, wnid, parent_wnid=None, name=""):
        """
        Args:
            wnid (str) : WordNet ID for synset represented by node
            parent_wnid (str) : WordNet ID for synset of node's parent
            name (str) : word/human-interpretable description of synset 
        """

        self.wnid = wnid
        self.name = name
        self.class_num = -1
        self.parent_wnid = parent_wnid
        self.descendant_count_in = 0
        self.descendants_all = set()
    
    def add_child(self, child):
        """
        Add child to given node.
        Args:
            child (Node) : Node object for child
        """
        child.parent_wnid = self.wnid
    
    def __str__(self):
        return f'Name: ({self.name}), ImageNet Class: ({self.class_num}), Descendants: ({self.descendant_count_in})'
    
    def __repr__(self):
        return f'Name: ({self.name}), ImageNet Class: ({self.class_num}), Descendants: ({self.descendant_count_in})'

class ImageNetHierarchy():
    '''
    Class for representing ImageNet/WordNet hierarchy. 
    '''
    def __init__(self, ds_path, ds_info_path):
        """
        Args:
            ds_path (str) : Path to ImageNet dataset
            ds_info_path (str) : Path to supplementary files for the ImageNet dataset 
                                 ('wordnet.is_a.txt', 'words.txt' and 'imagenet_class_index.json')
                                 which can be obtained from http://image-net.org/download-API.
        """
        self.tree = {}

        ret = self.load_imagenet_info(ds_path, ds_info_path)
        self.in_wnids, self.wnid_to_name, self.wnid_to_num, self.num_to_name = ret
            
        with open(os.path.join(ds_info_path, 'wordnet.is_a.txt'), 'r') as f:
            for line in f.readlines():
                parent_wnid, child_wnid = line.strip('\n').split(' ')
                parentNode = self.get_node(parent_wnid)
                childNode = self.get_node(child_wnid)
                parentNode.add_child(childNode)
                
        for wnid in self.in_wnids:
            self.tree[wnid].descendant_count_in = 0
            self.tree[wnid].class_num = self.wnid_to_num[wnid]
            
        for wnid in self.in_wnids:
            node = self.tree[wnid]
            while node.parent_wnid is not None:
                self.tree[node.parent_wnid].descendant_count_in += 1
                self.tree[node.parent_wnid].descendants_all.update(node.descendants_all)
                self.tree[node.parent_wnid].descendants_all.add(node.wnid)
                node = self.tree[node.parent_wnid]
        
        del_nodes = [wnid for wnid in self.tree \
                     if (self.tree[wnid].descendant_count_in == 0 and self.tree[wnid].class_num == -1)]
        for d in del_nodes:
            self.tree.pop(d, None)
                        
        assert all([k.descendant_count_in > 0 or k.class_num != -1 for k in self.tree.values()])

        self.wnid_sorted = sorted(sorted([(k, v.descendant_count_in, len(v.descendants_all)) \
                                        for k, v in self.tree.items()
                                        ],
                                        key=lambda x: x[2], 
                                        reverse=True
                                        ),
                                key=lambda x: x[1], 
                                reverse=True
                                )

    @staticmethod
    def load_imagenet_info(ds_path, ds_info_path):
        """
        Get information about mapping between ImageNet wnids/class numbers/class names.
        Args:
            ds_path (str) : Path to ImageNet dataset
            ds_info_path (str) : Path to supplementary files for the ImageNet dataset 
                                 ('wordnet.is_a.txt', 'words.txt', 'imagenet_class_index.json')
                                 which can be obtained from http://image-net.org/download-API.
        """
        files = os.listdir(os.path.join(ds_path, 'train/images'))
        in_wnids = [f.split("_")[0] for f in files if f[0]=='n']

        f = open(os.path.join(ds_info_path, 'words.txt'))
        wnid_to_name = [l.strip() for l in f.readlines()]
        wnid_to_name = {l.split('\t')[0]: l.split('\t')[1] \
                             for l in wnid_to_name}

        with open(os.path.join(ds_info_path, 'imagenet_class_index.json'), 'r') as f:
            base_map = json.load(f)
            wnid_to_num = {v[0]: int(k) for k, v in base_map.items()}
            num_to_name = {int(k): v[1] for k, v in base_map.items()}

        return in_wnids, wnid_to_name, wnid_to_num, num_to_name

    def get_node(self, wnid):
        """
        Add node to tree.
        Args:
            wnid (str) : WordNet ID for synset represented by node
        Returns:
            A node object representing the specified wnid.
        """
        if wnid not in self.tree:
            self.tree[wnid] = Node(wnid, name=self.wnid_to_name[wnid])
        return self.tree[wnid]


    def is_ancestor(self, ancestor_wnid, child_wnid):
        """
        Check if a node is an ancestor of another.
        Args:
            ancestor_wnid (str) : WordNet ID for synset represented by ancestor node
            child_wnid (str) : WordNet ID for synset represented by child node
        Returns:
            A boolean variable indicating whether or not the node is an ancestor
        """
        return (child_wnid in self.tree[ancestor_wnid].descendants_all)

    
    def get_descendants(self, node_wnid, in_imagenet=False):
        """
        Get all descendants of a given node.
        Args:
            node_wnid (str) : WordNet ID for synset for node
            in_imagenet (bool) : If True, only considers descendants among 
                                ImageNet synsets, else considers all possible
                                descendants in the WordNet hierarchy
        Returns:
            A set of wnids corresponding to all the descendants
        """        
        if in_imagenet:
            return set([self.wnid_to_num[ww] for ww in self.tree[node_wnid].descendants_all
                        if ww in set(self.in_wnids)])
        else:
            return self.tree[node_wnid].descendants_all
    
    def get_superclasses(self, n_superclasses, 
                         ancestor_wnid=None, superclass_lowest=None, 
                         balanced=True):
        """
        Get superclasses by grouping together classes from the ImageNet dataset.
        Args:
            n_superclasses (int) : Number of superclasses desired
            ancestor_wnid (str) : (optional) WordNet ID that can be used to specify
                                common ancestor for the selected superclasses
            superclass_lowest (set of str) : (optional) Set of WordNet IDs of nodes
                                that shouldn't be further sub-classes
            balanced (bool) : If True, all the superclasses will have the same number
                            of ImageNet subclasses
        Returns:
            superclass_wnid (list): List of WordNet IDs of superclasses
            class_ranges (list of sets): List of ImageNet subclasses per superclass
            label_map (dict): Mapping from class number to human-interpretable description
                            for each superclass
        """             
        
        assert superclass_lowest is None or \
               not any([self.is_ancestor(s1, s2) for s1, s2 in product(superclass_lowest, superclass_lowest)])
         
        superclass_info = []
        for (wnid, ndesc_in, ndesc_all) in self.wnid_sorted:
            
            if len(superclass_info) == n_superclasses:
                break
                
            if ancestor_wnid is None or self.is_ancestor(ancestor_wnid, wnid):
                keep_wnid = [True] * (len(superclass_info) + 1)
                superclass_info.append((wnid, ndesc_in))
                
                for ii, (w, d) in enumerate(superclass_info):
                    if self.is_ancestor(w, wnid):
                        if superclass_lowest and w in superclass_lowest:
                            keep_wnid[-1] = False
                        else:
                            keep_wnid[ii] = False
                
                for ii in range(len(superclass_info) - 1, -1, -1):
                    if not keep_wnid[ii]:
                        superclass_info.pop(ii)
            
        superclass_wnid = [w for w, _ in superclass_info]
        class_ranges, label_map = self.get_subclasses(superclass_wnid, 
                                    balanced=balanced)
                
        return superclass_wnid, class_ranges, label_map


    def get_subclasses(self, superclass_wnid, balanced=True):
        """
        Get ImageNet subclasses for a given set of superclasses from the WordNet 
        hierarchy. 
        Args:
            superclass_wnid (list): List of WordNet IDs of superclasses
            balanced (bool) : If True, all the superclasses will have the same number
                            of ImageNet subclasses
        Returns:
            class_ranges (list of sets): List of ImageNet subclasses per superclass
            label_map (dict): Mapping from class number to human-interpretable description
                            for each superclass
        """      
        ndesc_min = min([self.tree[w].descendant_count_in for w in superclass_wnid]) 
        class_ranges, label_map = [], {}
        for ii, w in enumerate(superclass_wnid):
            descendants = self.get_descendants(w, in_imagenet=True)
            if balanced and len(descendants) > ndesc_min:
                descendants = set([dd for ii, dd in enumerate(sorted(list(descendants))) if ii < ndesc_min])
            class_ranges.append(descendants)
            label_map[ii] = self.tree[w].name
            
        for i in range(len(class_ranges)):
            for j in range(i + 1, len(class_ranges)):
                assert(len(class_ranges[i].intersection(class_ranges[j])) == 0)
                
        return class_ranges, label_map

def common_superclass_wnid(group_name):
    """
        Get WordNet IDs of common superclasses. 
        Args:
            group_name (str): Name of group
        Returns:
            superclass_wnid (list): List of WordNet IDs of superclasses
        """    
    common_groups = {

        # ancestor_wnid = 'n00004258'
        'living_9': ['n02084071', #dog, domestic dog, Canis familiaris
                    'n01503061', # bird
                    'n01767661', # arthropod
                    'n01661091', # reptile, reptilian
                    'n02469914', # primate
                    'n02512053', # fish
                    'n02120997', # feline, felid
                    'n02401031', # bovid
                    'n01627424', # amphibian
                    ],

        'mixed_10': [
                     'n02084071', #dog,
                     'n01503061', #bird 
                     'n02159955', #insect 
                     'n02484322', #monkey 
                     'n02958343', #car 
                     'n02120997', #feline 
                     'n04490091', #truck 
                     'n13134947', #fruit 
                     'n12992868', #fungus 
                     'n02858304', #boat 
                     ],

        'mixed_13': ['n02084071', #dog,
                     'n01503061', #bird (52)
                     'n02159955', #insect (27)
                     'n03405725', #furniture (21)
                     'n02512053', #fish (16),
                     'n02484322', #monkey (13)
                     'n02958343', #car (10)
                     'n02120997', #feline (8),
                     'n04490091', #truck (7)
                     'n13134947', #fruit (7)
                     'n12992868', #fungus (7)
                     'n02858304', #boat (6)  
                     'n03082979', #computer(6)
                    ],

        # Dataset from Geirhos et al., 2018: arXiv:1811.12231
        'geirhos_16': ['n02686568', #aircraft (3)
                       'n02131653', #bear (3)
                       'n02834778', #bicycle (2)
                       'n01503061', #bird (52)
                       'n02858304', #boat (6)
                       'n02876657', #bottle (7)
                       'n02958343', #car (10)
                       'n02121808', #cat (5)
                       'n03001627', #char (4)
                       'n03046257', #clock (3)
                       'n02084071', #dog (116)
                       'n02503517', #elephant (2)
                       'n03614532', #keyboard (3)
                       'n03623556', #knife (2)
                       'n03862676', #oven (2)
                       'n04490091', #truck (7)
                      ],
        'big_12':  ['n02084071', #dog (100+)
                     'n04341686', #structure (55)
                     'n01503061', #bird (52)
                     'n03051540', #clothing (48)
                     'n04576211', #wheeled vehicle
                     'n01661091', #reptile, reptilian (36)
                     'n02075296', #carnivore
                     'n02159955', #insect (27)
                     'n03800933', #musical instrument (26)
                     'n07555863', #food (24)
                     'n03405725', #furniture (21)
                     'n02469914', #primate (20)
                   ],
        'mid_12':  ['n02084071', #dog (100+)
                      'n01503061', #bird (52)
                      'n04576211', #wheeled vehicle
                      'n01661091', #reptile, reptilian (36)
                      'n02075296', #carnivore
                      'n02159955', #insect (27)
                      'n03800933', #musical instrument (26)
                      'n07555863', #food (24)
                      'n03419014', #garment (24)
                      'n03405725', #furniture (21)
                      'n02469914', #primate (20)
                      'n02512053', #fish (16)
                    ]
    }

    if group_name in common_groups:
        superclass_wnid = common_groups[group_name]
        return superclass_wnid
    else:
        raise ValueError("Custom group does not exist")

### Recover ImageNet hierarchy

In [6]:
#from robustness.tools.imagenet_helpers import ImageNetHierarchy

in_hier = ImageNetHierarchy(in_path,
                            in_info_path)

### Take example superclasses
This library uses 2012 labels.

In [7]:
from robustness.tools.imagenet_helpers import common_superclass_wnid

superclass_wnid = common_superclass_wnid('living_9')
class_ranges, label_map = in_hier.get_subclasses(superclass_wnid,
                                                 balanced=True)

In [71]:
label_map

{0: 'dog, domestic dog, Canis familiaris',
 1: 'bird',
 2: 'arthropod',
 3: 'reptile, reptilian',
 4: 'primate',
 5: 'fish',
 6: 'feline, felid',
 7: 'bovid',
 8: 'amphibian'}

In [77]:
class_ranges[5]

{0, 1, 2, 3, 4, 5, 6, 389, 390, 391, 392, 393, 394, 395, 396, 397}

In [94]:
inv_map = {}

for k, v in enumerate(class_ranges):
    for l in v:
        inv_map[l] = k

inv_map

{151: 0,
 152: 0,
 153: 0,
 154: 0,
 155: 0,
 156: 0,
 157: 0,
 158: 0,
 159: 0,
 160: 0,
 161: 0,
 162: 0,
 163: 0,
 164: 0,
 165: 0,
 166: 0,
 167: 0,
 168: 0,
 169: 0,
 170: 0,
 173: 0,
 174: 0,
 175: 0,
 176: 0,
 177: 0,
 178: 0,
 179: 0,
 180: 0,
 181: 0,
 182: 0,
 183: 0,
 184: 0,
 185: 0,
 186: 0,
 187: 0,
 188: 0,
 189: 0,
 190: 0,
 191: 0,
 192: 0,
 193: 0,
 194: 0,
 195: 0,
 196: 0,
 197: 0,
 198: 0,
 199: 0,
 200: 0,
 201: 0,
 202: 0,
 203: 0,
 204: 0,
 205: 0,
 206: 0,
 207: 0,
 208: 0,
 209: 0,
 210: 0,
 211: 0,
 212: 0,
 213: 0,
 214: 0,
 215: 0,
 216: 0,
 217: 0,
 218: 0,
 219: 0,
 220: 0,
 221: 0,
 222: 0,
 223: 0,
 224: 0,
 225: 0,
 226: 0,
 227: 0,
 228: 0,
 229: 0,
 230: 0,
 231: 0,
 232: 0,
 233: 0,
 234: 0,
 235: 0,
 236: 0,
 237: 0,
 238: 0,
 239: 0,
 240: 0,
 241: 0,
 242: 0,
 243: 0,
 244: 0,
 245: 0,
 246: 0,
 247: 0,
 248: 0,
 249: 0,
 250: 0,
 251: 0,
 252: 0,
 253: 0,
 254: 0,
 255: 0,
 256: 0,
 257: 0,
 258: 0,
 259: 0,
 260: 0,
 261: 0,
 262: 0,
 263: 0,
 

# Filter data based on previous selection (balanced)

In [105]:
all_selected_classes = [item for sublist in class_ranges for item in sublist]

In [106]:
print("{} subclasses were selected".format(len(all_selected_classes)))

312 subclasses were selected


In [107]:
def filter_data(df):
    df_fil = df[df['label'].isin(all_selected_classes)]
    df_fil['reduced_label'] = df_fil['label'].map(inv_map)
    return df_fil

In [108]:
# Load train labels (they use 2017 labels)
import pandas as pd

train_df = pd.read_csv("../data_dir/ori/train/labels.csv", header=None, sep=" ", names=['image', 'label'])
eval_df = pd.read_csv("../data_dir/ori/validation/labels.csv", header=None, sep=" ", names=['image', 'label'])

In [109]:
train_df.head()

Unnamed: 0,image,label
0,n03467068_11425.JPEG,583
1,n03733281_4717.JPEG,646
2,n03840681_24428.JPEG,684
3,n04389033_26185.JPEG,847
4,n01985128_50181.JPEG,124


In [145]:
train_mapped = filter_data(train_df)
eval_mapped = filter_data(eval_df)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_fil['reduced_label'] = df_fil['label'].map(inv_map)


In [146]:
train_mapped.head()

Unnamed: 0,image,label,reduced_label
4,n01985128_50181.JPEG,124,2
9,n01734418_699.JPEG,56,3
12,n02102480_8523.JPEG,220,0
14,n02110185_12433.JPEG,250,0
18,n01819313_6286.JPEG,89,1


In [147]:
eval_mapped.head()

Unnamed: 0,image,label,reduced_label
0,ILSVRC2012_val_00000001.JPEG,65,3
2,ILSVRC2012_val_00000003.JPEG,230,0
5,ILSVRC2012_val_00000006.JPEG,57,3
11,ILSVRC2012_val_00000012.JPEG,286,6
12,ILSVRC2012_val_00000013.JPEG,370,4


In [148]:
train_mapped.reduced_label.value_counts()

0    145273
1     66141
2     61100
3     46354
4     26000
5     20334
7     11700
8     10400
6     10400
Name: reduced_label, dtype: int64

### Balance data

In [149]:
def balance_dataset(df, col="reduced_label"):
    df_ = df.groupby('reduced_label')
    return df_.apply(lambda x: x.sample(df_.size().min(), random_state=1).reset_index(drop=True)).reset_index(drop=True)

In [150]:
train_mapped = balance_dataset(train_mapped)

In [151]:
eval_mapped = balance_dataset(eval_mapped)

## Create new filtered folders and store data

In [152]:
source_train_folder = "/cluster/scratch/data/ori/train/images"
target_train_folder = "/cluster/scratch/data/ori/filtered/train/images"

source_eval_folder = "/cluster/scratch/data/ori/validation/images"
target_eval_folder = "/cluster/scratch/data/ori/filtered/validation/images"

In [144]:
os.makedirs(target_train_folder)
os.makedirs(target_eval_folder)

In [None]:
import shutil

for image in train_mapped['image'].values:
    shutil.copyfile(os.path.join(source_train_folder, image), os.path.join(target_train_folder, image))

In [None]:
for image in eval_mapped['image'].values:
    shutil.copyfile(os.path.join(source_eval_folder, image), os.path.join(target_eval_folder, image))

In [None]:
train_mapped.to_csv(os.path.join(target_train_folder.replace("/images", ""), "labels.csv"))

In [None]:
eval_mapped.to_csv(os.path.join(target_eval_folder.replace("/images", ""), "labels.csv"))