In [25]:
import sys
sys.path.append('/home/sheng/sssa/')
sys.path.append('/home/sheng/sssa/')
# sys.path.append('/home/sheng/sssa/CLIP/')

import os
import json
import re
import time
import pickle
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import random
import itertools
import numpy as np
from functools import reduce, partial
from itertools import zip_longest
import seaborn as sns
from collections import Counter, defaultdict, OrderedDict
import matplotlib.pyplot as plt
import heapq
from wordnet_utils import *
import scipy.io
from PIL import Image

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader

from ipynb_utils import get_hier_datasets, get_classifier, MCMF_assign_labels
# import clip
import model as clip
from data.datasets import build_transform, get_hier_datasets, Vocab
from data.imagenet_datasets import get_datasets_oszsl


In [2]:
class Config:
    device = 'cuda:1'
    arch = 'ViT-B/16'
    dataset = 'make_entity13'
    n_sampled_classes = 100
    input_size = 224
    estimate_k = 252
    
    batch_size = 512
    use_def = False
    clip_checkpoint = None
    # clip_checkpoint = '/home/sheng/MUST-output/make_nonliving26/baseline-04_22_1/checkpoint-current.pth'
    # clip_checkpoint = '/home/sheng/MUST-output/make_nonliving26/chatgpt_init-warmup=2/checkpoint-current.pth'
    f_classifier = './cache/wordnet_classifier_in21k_word.pth'
    templates_name = 'templates_small'
    seed = 0
    
args = Config()

def load_templates(args):
    with open(f'../{args.templates_name}.json', 'rb') as f:
        templates = json.load(f)['imagenet']
    return templates

def get_vocab():
    """
    Args:
        vocab: {`names`: list, `ids`: synset ids, `parents`: [{synset ids}]}
    """
    with open('/home/sheng/dataset/wordnet_nouns_with_synset_4.pkl', 'rb') as f:
        vocab = pickle.load(f)
    return vocab

def get_subsample_vocab(sample_synset_id: set):
    vocab = get_vocab()
    index = np.array([ i for i in range(len(vocab['synsets'])) if vocab['synsets'][i] in sample_synset_id ]).astype(np.int32)
    for k in vocab.keys():
        vocab[k] = np.array(vocab[k])[index].tolist()
    return vocab

def read_imagenet21k_classes():
    with open('/home/sheng/dataset/imagenet21k/imagenet21k_wordnet_ids.txt', 'r') as f:
        data = f.read()
        data = list(filter(lambda x: len(x), data.split('\n')))
    return data

templates = load_templates(args)
vocab = get_vocab()
nouns = [ wn.synset(s) for s in vocab['synsets'] ]
classnames = vocab['names']
parents = vocab['parents']
defs = vocab['def']

""" build entire wn-graph """
from nxgraph_model import *

with open('/home/sheng/dataset/wordnet_nouns_with_synset.pkl', 'rb') as f:
    entire_vocab = pickle.load(f)
    
G = create_graph([wn.synset(x) for x in entire_vocab['synsets']], entire_vocab['ids'], entire_vocab['names'], entire_vocab['def'])

""" prepare dataset and load CLIP """
classes = read_imagenet21k_classes() + os.listdir('/home/sheng/dataset/imagenet-img/')
classes = [wn.synset_from_pos_and_offset('n', int(x[1:])).name() for x in classes]
classes = set(classes)
vocab = get_subsample_vocab(classes)
vocab = Vocab(vocab=vocab)

In [3]:
dataset = torchvision.datasets.Caltech101(root='/home/sheng/dataset/Caltech101/caltech-101/', transform=None, download=False)

In [12]:
dataset.annotation_categories

['Faces_2',
 'Faces_3',
 'Leopards',
 'Motorbikes_16',
 'accordion',
 'Airplanes_Side_2',
 'anchor',
 'ant',
 'barrel',
 'bass',
 'beaver',
 'binocular',
 'bonsai',
 'brain',
 'brontosaurus',
 'buddha',
 'butterfly',
 'camera',
 'cannon',
 'car_side',
 'ceiling_fan',
 'cellphone',
 'chair',
 'chandelier',
 'cougar_body',
 'cougar_face',
 'crab',
 'crayfish',
 'crocodile',
 'crocodile_head',
 'cup',
 'dalmatian',
 'dollar_bill',
 'dolphin',
 'dragonfly',
 'electric_guitar',
 'elephant',
 'emu',
 'euphonium',
 'ewer',
 'ferry',
 'flamingo',
 'flamingo_head',
 'garfield',
 'gerenuk',
 'gramophone',
 'grand_piano',
 'hawksbill',
 'headphone',
 'hedgehog',
 'helicopter',
 'ibis',
 'inline_skate',
 'joshua_tree',
 'kangaroo',
 'ketch',
 'lamp',
 'laptop',
 'llama',
 'lobster',
 'lotus',
 'mandolin',
 'mayfly',
 'menorah',
 'metronome',
 'minaret',
 'nautilus',
 'octopus',
 'okapi',
 'pagoda',
 'panda',
 'pigeon',
 'pizza',
 'platypus',
 'pyramid',
 'revolver',
 'rhino',
 'rooster',
 'saxopho

In [131]:
from pathlib import Path

class CaltechDataset(ImageFolder):
    def __init__(self, root, vocab=None, transform=None, split='train', **kwargs):
        self.root = root
        self.vocab = vocab
        self.transform = transform
        assert split in ['train', 'test']
        self.split = split
        self.split_ratio = 0.8
        
        self.category_mapping_caltech101 = {
            'brontosaurus': 'apatosaur',
            'car_side': 'car',
            'cougar_body': 'cougar',
            'faces': 'face',
            'stop_sign': 'sign',
            'water_lilly': 'nymphaea',
            'saxophone': 'sax',
            'leopards': 'leopard',
            'rooster': 'cock',
            'crocodile_head': 'crocodile',
            'wild_cat': 'wildcat',
            'hawksbill': 'hawksbill_turtle',
            'ceiling_fan': 'electric_fan',
            'ewer': 'pitcher',
            'inline_skate': 'roller_skate',
            'dollar_bill': 'dollar',
            'airplanes': 'airplane',
            'sea_horse': 'seahorse',
            'headphone': 'earphone',
            'panda': 'giant_panda',
            'cougar_face': 'cougar',
            'faces_easy': 'face',
            'motorbikes': 'motorbike',
            'rhino': 'rhinoceros',
            'stegosaurus': 'stegosaur',
        }
        self.category_remove_caltech101 = ['yin_yang', 'background_google']
        self.parse_files()
        self.map_classes()
        self.random_split()
        return
    
    def parse_files(self):
        samples = []
        targets = []
        folder_path = Path(self.root)
        for p in folder_path.glob('**/*/*'):
            p = str(p)
            if '.ipynb_checkpoints'!= p.split('/')[-2]:
                samples.append(p)
                targets.append(p.split('/')[-2])
        self.samples = samples
        self.targets = targets
        return
    
    def map_classes(self):
        new_targets = []
        valid_inds = []
        for i, c in enumerate(self.targets):
            c = c.lower()
            if c in self.category_mapping_caltech101.keys():
                new_targets.append(self.category_mapping_caltech101[c])
                valid_inds.append(i)
            elif c in self.category_remove_caltech101:
                pass
            else:
                new_targets.append(c)
                valid_inds.append(i)
        self.targets = np.array(new_targets)
        self.samples = np.array(self.samples)[np.array(valid_inds)]
        return
    
    def random_split(self):
        np.random.seed(0)
        all_select = np.zeros(len(self.targets)).astype(np.bool)
        for c in np.unique(self.targets):
            select = (self.targets==c)
            position = select.nonzero()[0]
            sampled_ind = np.random.choice(position, int(self.split_ratio*select.sum()), replace=False)
            select = np.zeros_like(select)
            select[sampled_ind] = True
            all_select |= select
        if self.split == 'train':
            self.samples = self.samples[all_select]
            self.targets = self.targets[all_select]
        else:
            self.samples = self.samples[~all_select]
            self.targets = self.targets[~all_select]
        return 
    
    def preprocess(self):
        """
        [tested]
            1. filtering @self.labels, @self.samples
            2. @label to @vocab_idx; @label to @label_transformed
        """
        
        if self.vocab is not None:
            self.targets = list(map(lambda x: self.vocab.mapping_names_idx[x], self.targets)) ### to @voc_ind
        self.num_classes = len(set(self.targets))
        self.label_transform = {}
        for c, i in zip(sorted(set(self.targets)), range(self.num_classes)):
            self.label_transform[c] = i
        self.labels_transformed = list(map(lambda x: self.label_transform[x], self.targets))
        self.idx_imgs = np.array(range(len(self.samples)))
        return
    
    def __len__(self):
        return len(self.samples)
        
    def __getitem__(self, idx):
        img = self.samples[idx]
        img = Image.open(img).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label_voc = self.targets[idx]
        label_clu = self.labels_transformed[idx]
        idx_img = self.idx_imgs[idx]
        result = [img, label_voc, label_clu, idx_img]
        if self.ssl_cluster is not None:
            result.append(self.ssl_cluster[idx])
        if self.ad_weight is not None:
            result.append(self.ad_weight[idx])
        return result
    
    @property
    def len_output(self):
        return 4 if self.ssl_cluster is None else 5

In [135]:
d = CaltechDataset(root='/home/sheng/dataset/Caltech101/caltech-101/caltech101/101_ObjectCategories/', split='test')

d.preprocess()
d.ssl_cluster = None
d.ad_weight = None

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  all_select = np.zeros(len(self.targets)).astype(np.bool)


In [139]:
class CustomCIFAR100(CIFAR100):
    def __init__(self, root, train, transform=None, vocab=None, **kwargs):
        super(CustomCIFAR100, self).__init__(root=root, train=train, transform=transform, **kwargs)
        self.vocab = vocab
        self.uq_idxs = np.array(range(len(self)))
        category_mapping = {'aquarium_fish': 'freshwater_fish', 'maple_tree': 'maple'}
        self.classes = [category_mapping[c] if c in category_mapping else c for c in self.classes]
        self.class_to_idx = dict([(category_mapping[k],v) if k in category_mapping else (k,v) for k, v in self.class_to_idx.items()])
        self.num_classes = len(self.classes)
        
        self.label_voc = list(map(lambda x: vocab.mapping_names_idx[ self.classes[x] ], self.targets))
        return

    def __getitem__(self, item):
        img, label = super().__getitem__(item)
        label_voc = self.label_voc[item]
        idx_img = self.uq_idxs[item]
        result = [img, label_voc, label_clu, idx_img]
        if self.ssl_cluster is not None:
            result.append(self.ssl_cluster[idx])
        if self.ad_weight is not None:
            result.append(self.ad_weight[idx])
        return result

    def __len__(self):
        return len(self.targets)

[<PIL.Image.Image image mode=RGB size=300x197 at 0x7F18783BC1F0>,
 'car',
 15,
 100]