The original ImageNet dataset consists of 1000 classes, a substantial number of which are pets (I did not go through all the classes, but 118 dog breeds are reported [here](https://github.com/megvii-research/FSSD_OoD_Detection/issues/1)). So the idea is simply to run the ImageNet-trained ResNet on the pet images, infer these classes, and use them for further exploration.

To reduce the number of resulting categories, we use the hierarchical nature of ImageNet/WordNet and go up one level from each label. If this step up would be too generic, we skip it.



In [None]:
import os
import json
import pickle
from glob import glob
from PIL import Image

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from os.path import basename, splitext
from collections import Counter

# Load ImageNet Classes and hierarchy information

In [None]:
%%capture
# Fetch the ImageNet class index and the WordNet hierarchy / label files.
# wget saves them into the current working directory; -nc (no-clobber) skips
# files that already exist, so re-running the cell does not download again.
!wget -nc https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
!wget -nc https://git.tools.f4.htw-berlin.de/smi/pytorch-hierarchical-imagenet-dataset/raw/0ef57a30eb1b7887e35ad59529c00f7ffcfc121a/wordnet.is_a.txt
!wget -nc https://git.tools.f4.htw-berlin.de/smi/pytorch-hierarchical-imagenet-dataset/raw/0ef57a30eb1b7887e35ad59529c00f7ffcfc121a/words.txt

path = '/kaggle/input/petfinder-pawpularity-score/'
# `path` already ends with '/', so this yields '//train.csv', which the
# filesystem treats the same as a single slash.
train_csv = path + '/train.csv'
train_df = pd.read_csv(train_csv)
#test_df = pd.read_csv(path + '/test.csv') not used for now!

# those are the classes in resnet
with open('/kaggle/working/imagenet_class_index.json') as f:
    resnet_classes = [id for (id, name) in list(json.load(f).values())]

# parent element for each class
child_parent = {}
with open('/kaggle/working/wordnet.is_a.txt') as f:
    for line in f.readlines():
        parent, child = line.split()
        child_parent[child] = parent

# label map
labels = {}
with open('/kaggle/working/words.txt') as f:
    for line in f.readlines():
        code, label = line.split('\t')
        labels[code] = label.strip()

In [None]:
# Sanity check: show "<label> > <parent label>" for the first ten ResNet classes.
# The lookup runs over all classes before slicing, so any class without a
# parent or label raises here.
[f"{labels[c]} > {labels[child_parent[c]]}" for c in resnet_classes][:10]

# Create Dataset and Labels

In [None]:
%%capture
# %%capture swallows this cell's output (presumably to hide the progress output
# printed while the pretrained weights are downloaded).
# ImageNet-pretrained ResNet-50; eval() puts layers such as batch norm into
# inference mode, which is what we want for pure prediction.
model = models.resnet50(pretrained=True)
model.eval()

class PetDataset(Dataset):
    '''All ``*.jpg`` images in one directory, preprocessed for an ImageNet ResNet.

    Items are returned in the order of ``get_paths()``, so outputs produced with
    an unshuffled DataLoader line up with the file paths.
    '''

    # Resize to 224x224 (aspect ratio is not preserved), convert to a tensor and
    # normalise with the usual ImageNet channel means / standard deviations.
    trans = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, path):
        # glob returns files in arbitrary filesystem order; sort for a
        # reproducible item (and therefore Id) order.
        self.image_paths = sorted(glob(os.path.join(path, '*.jpg')))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # Convert to RGB: grayscale or CMYK JPEGs would otherwise give 1- or
        # 4-channel tensors that fail the 3-channel Normalize / batching.
        # The with-block closes the underlying file handle.
        with Image.open(self.image_paths[index]) as im:
            return self.trans(im.convert('RGB'))

    def get_paths(self):
        '''Return the image paths in item order.'''
        return self.image_paths
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128
train_data = PetDataset(path + '/train')
# No shuffling: batches must come out in train_data.get_paths() order, because
# the Id column is later rebuilt from those paths.
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=False)


In [None]:
print(device)
model.to(device)
inferred_classes = []
# Inference only, so one no_grad context covers the whole loop.
with torch.no_grad():
    for batch in train_dl:
        logits = model(batch.to(device))
        # Take the argmax on the device so that only one index per image is
        # copied back to the host instead of the full logit matrix.
        top_indices = logits.argmax(dim=1).tolist()
        inferred_classes.extend(resnet_classes[i] for i in top_indices)
    

# Clean Labels
* Recursively go up the hierarchy to absorb small classes, but never up to the overly generic classes 'dog', 'cat', or the root 'entity'
* Replace each class id with a readable string
* Replace everything that is not a cat or a dog with 'other'

In [None]:
def get_parent(c):
    '''Return the parent class of ``c``, unless that parent is one of the generic
    WordNet synsets domestic cat (n02121808), dog (n02084071) or entity
    (n00001740); in that case ``c`` itself is returned.'''
    parent = child_parent[c]
    too_generic = parent in ('n02121808', 'n02084071', 'n00001740')
    return c if too_generic else parent

def get_classes_below(all_items, n):
    '''Return the distinct values in ``all_items`` that occur fewer than ``n`` times.

    Values are listed in order of first appearance (Counter keeps insertion order).
    '''
    rare = []
    for class_id, count in Counter(all_items).items():
        if count < n:
            rare.append(class_id)
    return rare

def is_pet(c, hierarchy=None):
    '''Return True if class ``c`` lies under the domestic cat or dog synset.

    Walks up the parent chain starting at ``c`` itself until it reaches domestic
    cat (n02121808), dog (n02084071) or the root entity (n00001740), and reports
    whether the walk stopped at cat or dog. Because ``c`` is checked first, the
    generic cat/dog classes count as pets and the root is simply "not a pet".

    hierarchy: child -> parent mapping to walk; defaults to the global
        ``child_parent``.

    Raises KeyError if the chain reaches a node with no parent before hitting
    one of the three stop classes.
    '''
    if hierarchy is None:
        hierarchy = child_parent
    pet_roots = {'n02121808', 'n02084071'}
    stop_classes = pet_roots | {'n00001740'}
    node = c
    while node not in stop_classes:
        node = hierarchy[node]
    return node in pet_roots

In [None]:
# Start one level above every prediction, then keep replacing each class that has
# fewer than MIN_CLASS_SIZE images by its parent. get_parent never climbs to the
# generic cat/dog/entity level, so classes that cannot climb further stay as they
# are.
MIN_CLASS_SIZE = 50
updated_classes = [get_parent(c) for c in inferred_classes]
while True:
    # set: O(1) membership tests in the comprehension below
    small_classes = set(get_classes_below(updated_classes, MIN_CLASS_SIZE))
    next_classes = [get_parent(c) if c in small_classes else c for c in updated_classes]
    # Stop at a fixed point, i.e. when a pass changes nothing. Comparing only the
    # number of small classes is not enough: a small class can climb into a parent
    # that is itself still small, leaving the count unchanged while progress is
    # still possible. Classes only ever move up the tree, so this terminates.
    if next_classes == updated_classes:
        break
    updated_classes = next_classes

In [None]:
# Readable name per image: the first synonym of the WordNet label for cat/dog
# classes, 'other' for everything else.
inferred_class_names = [
    (labels[p] if is_pet(p) else 'other').split(',')[0]  # keep first synonym only
    for p in updated_classes
]
# Image Id = file name without extension, in dataset (= prediction) order.
ids = [splitext(basename(p))[0] for p in train_data.get_paths()]
# zip() would silently drop rows if the two lists disagreed; fail loudly instead.
if len(ids) != len(inferred_class_names):
    raise ValueError(
        f'{len(ids)} image ids but {len(inferred_class_names)} predicted classes'
    )
class_df = pd.DataFrame({'Id': ids, 'Class': inferred_class_names})

# Result
As we can see, the distribution has a typical long tail. Also, the second-largest class is 'other'. Whether this information is valuable for predicting the Pawpularity score can be explored next.

In [None]:
sns.set_theme(rc={'figure.figsize': (15, 5)})
# Bars sorted from the most to the least frequent class.
bar_order = class_df['Class'].value_counts().index
sns.countplot(data=class_df, x='Class', order=bar_order)
plt.xticks(rotation=90);

In [None]:
# Number of images per inferred class, most frequent first.
counts = (
    class_df
    .groupby('Class')
    .count()
    .sort_values(by='Id', ascending=False)
)
counts

# First Look

In [None]:
# Index both frames by image Id and attach the inferred class to every training
# row (DataFrame.join is a left join, so all rows of train_df are kept).
train_df = train_df.set_index('Id')
class_df = class_df.set_index('Id')
train_df = train_df.join(class_df)
# next step, e.g.: train_df.groupby('Class')['Pawpularity'].mean().sort_values(ascending=False)

In [None]:
sns.set_theme(style="ticks")
# Use the count plot's left-to-right order (most frequent class first) so both
# figures can be read side by side; without `order`, seaborn places string
# categories in order of first appearance in train_df.
class_order = train_df['Class'].value_counts().index
ax = sns.boxplot(x="Class", y="Pawpularity", data=train_df, order=class_order)
plt.xticks(rotation=90);

ax.yaxis.grid(True)
ax.set(ylabel="")
sns.despine(trim=True, left=True)

In [None]:
# class_df is indexed by Id at this point, so Id is written as the first column.
class_df.to_csv(path_or_buf='train_imagenet_classes.csv')

# TODO 
* See if there is some correlation between class and pawpularity
* Set up a pipeline to apply the same transformation to the test set. Since we used no additional information, we could do this in one step. However, to apply this method to a test set we need a mapping from start to end: it should map every possible ImageNet class to one of the resulting classes.