In [1]:
import sys
sys.path.append('/home/sheng/sssa/')
sys.path.append('/home/sheng/sssa/')
# sys.path.append('/home/sheng/sssa/CLIP/')

import os
import json
import re
import time
import pickle
from typing import Union, List
from pprint import pprint
from tqdm import tqdm
from copy import deepcopy
import random
import itertools
import numpy as np
from functools import reduce, partial
from itertools import zip_longest
import seaborn as sns
from collections import Counter, defaultdict, OrderedDict
import matplotlib.pyplot as plt
import heapq
from wordnet_utils import *

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode

from ipynb_utils import get_hier_datasets, get_classifier, MCMF_assign_labels
# import clip
import model as clip
from data.datasets import build_transform, get_hier_datasets, Vocab
# from data.imagenet_datasets import get_datasets_oszsl
from data.imagenet_datasets import get_datasets_oszsl



In [2]:
class Config:
    """Notebook-wide settings, exposed as plain class attributes (read via ``args``)."""

    # Backbone and where it runs; clip_checkpoint is an optional weights file
    # consumed by load_clip2 (None keeps the stock weights).
    device = 'cuda:3'
    arch = 'ViT-B/16'
    clip_checkpoint = None

    # Dataset selection and loading.
    dataset = 'sun'
    n_sampled_classes = 100
    input_size = 224
    batch_size = 512
    seed = 0

    # Text side: prompt-template file name (see load_templates) and classifier cache.
    templates_name = 'templates_small'
    f_classifier = './cache/wordnet_classifier_in21k_word.pth'
    # NOTE(review): presumably toggles use of WordNet definitions (vocab['def']) -- confirm.
    use_def = False


args = Config()

In [3]:
def load_templates(args):
    """Read ``../<args.templates_name>.json`` and return its ``'imagenet'`` entry.

    The path is relative to the current working directory.
    """
    template_path = f'../{args.templates_name}.json'
    with open(template_path, 'rb') as fh:
        payload = json.load(fh)
    return payload['imagenet']

def get_vocab(path: str = '/home/sheng/dataset/wordnet_nouns_with_synset_4.pkl'):
    """Load the pickled WordNet noun vocabulary.

    Args:
        path: pickle file to read. Defaults to the original hard-coded location,
            so existing ``get_vocab()`` calls behave exactly as before.

    Returns:
        The unpickled object. This notebook indexes it with the keys
        ``'synsets'``, ``'names'``, ``'parents'`` and ``'def'`` (presumably
        parallel per-entry lists -- confirm against the file).

    Note:
        ``pickle.load`` can execute arbitrary code; only point ``path`` at
        trusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)

def get_subsample_vocab(sample_synset_id: set, vocab=None):
    """Keep only the vocabulary entries whose synset id is in ``sample_synset_id``.

    Args:
        sample_synset_id: collection of synset ids to keep. Membership is tested
            with ``in``, so a ``set`` keeps the lookup cheap.
        vocab: optional already-loaded vocabulary mapping. When ``None`` (the
            default, matching the original behaviour) it is read via ``get_vocab()``.

    Returns:
        A new dict with the same keys. Each value is the list of entries at the
        kept positions (positions where ``vocab['synsets']`` matched), in their
        original order. The passed-in ``vocab`` is not modified.
    """
    if vocab is None:
        vocab = get_vocab()
    keep = [i for i, synset in enumerate(vocab['synsets']) if synset in sample_synset_id]
    # Plain list selection instead of np.array(values)[index].tolist(): that numpy
    # round trip raises on ragged per-entry values (e.g. parent collections of
    # different lengths) and silently turns equal-length tuples into lists.
    return {key: [values[i] for i in keep] for key, values in vocab.items()}

def read_imagenet21k_classes(path: str = '/home/sheng/dataset/imagenet21k/imagenet21k_wordnet_ids.txt'):
    """Read the ImageNet-21k WordNet ids, one per line.

    Callers parse each entry as ``int(x[1:])`` (a one-letter prefix followed by
    a numeric offset).

    Args:
        path: text file to read; defaults to the original hard-coded location.

    Returns:
        List of entries with surrounding whitespace stripped. Blank and
        whitespace-only lines are dropped, so stray spaces cannot break the
        offset parsing downstream.
    """
    with open(path, 'r') as f:
        return [line.strip() for line in f if line.strip()]

def read_lvis_imagenet21k_classes(path: str = '/home/sheng/dataset/imagenet21k/lvis_imagenet21k_wordnet_ids.txt'):
    """Read the LVIS/ImageNet-21k class list, one entry per line.

    The caller puts these entries into a set and matches them against
    ``vocab['synsets']``.

    Args:
        path: text file to read; defaults to the original hard-coded location.

    Returns:
        List of entries with surrounding whitespace stripped. Blank and
        whitespace-only lines are dropped, so a trailing space cannot make a
        set-membership lookup fail silently.
    """
    with open(path, 'r') as f:
        return [line.strip() for line in f if line.strip()]

def load_clip2(args):
    """Build the CLIP model with ``clip.load`` and optionally load extra weights.

    Args:
        args: config object; reads ``arch``, ``device`` and ``clip_checkpoint``
            (a checkpoint path, or a falsy value to skip loading).

    Returns:
        The model, moved to ``args.device`` and switched to eval mode. Parameter
        count, input resolution, context length and vocab size are printed.

    When ``args.clip_checkpoint`` is set, the checkpoint's ``'model'`` entry is
    loaded with ``strict=False`` after removing a leading ``'model.'`` from keys.
    """
    model = clip.load(args.arch, device=args.device)
    if args.clip_checkpoint:
        prefix = 'model.'
        checkpoint = torch.load(args.clip_checkpoint, map_location='cpu')['model']
        # Strip the prefix only where it is present: unconditional slicing mangles
        # keys that lack it, and strict=False would then silently drop them.
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                      for k, v in checkpoint.items()}
        model.load_state_dict(state_dict, strict=False)
    model.to(args.device).eval()

    n_params = sum(p.numel() for p in model.parameters())
    print("Model parameters:", f"{n_params:,}")
    print("Input resolution:", model.visual.input_resolution)
    print("Context length:", model.context_length)
    print("Vocab size:", model.vocab_size)
    return model

# Prompt templates plus the full (not yet subsampled) WordNet vocabulary,
# unpacked into per-field lists.
templates = load_templates(args)
vocab = get_vocab()
nouns = list(map(wn.synset, vocab['synsets']))
classnames, parents, defs = vocab['names'], vocab['parents'], vocab['def']


In [4]:
""" prepare dataset and load CLIP """
# Class list: for LVIS read it directly; otherwise combine the ImageNet-21k id
# list with the local ImageNet folder names and map each WordNet id (prefix +
# offset) to its synset name. Branching first avoids reading the id list and
# listing the image directory when the LVIS branch would discard them anyway.
if args.dataset == 'lvis':
    classes = set(read_lvis_imagenet21k_classes())
else:
    wnids = read_imagenet21k_classes() + os.listdir('/home/sheng/dataset/imagenet-img/')
    classes = {wn.synset_from_pos_and_offset('n', int(x[1:])).name() for x in wnids}
vocab = get_subsample_vocab(classes)
vocab = Vocab(vocab=vocab)

transform_val = build_transform(is_train=False, args=args, train_config=None)
# Image normalisation statistics (the values used by OpenAI CLIP preprocessing).
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

# --- Load dataset -------------------------------------------------------------
transform_f = transforms.Compose([
    transforms.Resize(args.input_size, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
])
# Use the configured seed rather than a duplicated literal (args.seed is 0).
dataset = get_datasets_oszsl(args, vocab, is_train=True, transform=transform_f, seed=args.seed)
# NOTE(review): batch_size=256 is hard-coded while Config.batch_size is 512; kept as-is.
loader_val = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=256, shuffle=False)
print('dataset size', len(dataset))

model = load_clip2(args)

dataset size 87003
missing keys:
['visual.projection_head.0.weight', 'visual.projection_head.0.bias', 'visual.projection_head.2.weight', 'visual.projection_head.2.bias']
Model parameters: 150,408,193
Input resolution: 224
Context length: 77
Vocab size: 49408


In [5]:
amp_autocast = torch.cuda.amp.autocast

# Extract L2-normalised CLIP image features and the matching cluster labels
# for every batch of loader_val.
all_vfeatures = []
all_clu_label = []
model.eval()
with torch.no_grad():
    for images, label_voc, label_clu, idx_img in tqdm(loader_val):
        images = images.to(args.device)
        with amp_autocast():
            logits = model.visual.extract_features(images)
            logits = logits / logits.norm(dim=-1, keepdim=True)
        # Each batch yields fresh tensors that are never modified afterwards, so
        # the former deepcopy of every array was redundant work.
        all_vfeatures.append(logits.cpu().numpy())
        all_clu_label.append(label_clu.numpy())

all_vfeatures = np.concatenate(all_vfeatures)
all_clu_label = np.concatenate(all_clu_label)

100%|██████████| 340/340 [10:23<00:00,  1.83s/it]


In [12]:
# Number of extracted feature rows.
all_vfeatures.shape[0]

105538

In [8]:
# Cache the CLIP features; create the directory first so a fresh checkout does
# not fail with FileNotFoundError.
feature_path = f'./cache/features/vfeatures-{args.dataset}.npy'
os.makedirs(os.path.dirname(feature_path), exist_ok=True)
np.save(feature_path, all_vfeatures)

In [6]:
from sklearn.cluster import KMeans, MiniBatchKMeans
from my_util_package_oszsl.evaluation import cluster_acc

# Ground-truth class count; if the dataset wraps another one through a
# ``.dataset`` attribute (e.g. torch.utils.data.Subset), read it from the inner one.
K = getattr(dataset, 'dataset', dataset).num_classes
print(f'K={K}')
print(np.unique(all_clu_label).shape)

K=397
(397,)


In [7]:
# K-means on the CLIP features with one cluster per ground-truth class.
kmeans = KMeans(
    n_clusters=K,
    n_init=17,
    max_iter=1000,
    random_state=0,
    verbose=1,
)
kmeans.fit(all_vfeatures)
preds = kmeans.labels_

Initialization complete
Iteration 0, inertia 33443.70703125.
Iteration 1, inertia 22390.701171875.
Iteration 2, inertia 21856.548828125.
Iteration 3, inertia 21645.294921875.
Iteration 4, inertia 21537.3359375.
Iteration 5, inertia 21468.740234375.
Iteration 6, inertia 21427.48828125.
Iteration 7, inertia 21397.9453125.
Iteration 8, inertia 21372.9765625.
Iteration 9, inertia 21351.07421875.
Iteration 10, inertia 21332.36328125.
Iteration 11, inertia 21317.666015625.
Iteration 12, inertia 21308.166015625.
Iteration 13, inertia 21301.05078125.
Iteration 14, inertia 21294.919921875.
Iteration 15, inertia 21290.12109375.
Iteration 16, inertia 21285.625.
Iteration 17, inertia 21281.84375.
Iteration 18, inertia 21278.6796875.
Iteration 19, inertia 21275.853515625.
Iteration 20, inertia 21272.658203125.
Iteration 21, inertia 21269.0.
Iteration 22, inertia 21264.875.
Iteration 23, inertia 21260.291015625.
Iteration 24, inertia 21257.259765625.
Iteration 25, inertia 21255.119140625.
Iteration 

In [8]:
# Score the K-means assignments against the ground-truth cluster labels.
# NOTE(review): cluster_acc presumably uses an optimal cluster-to-label matching
# -- confirm in my_util_package_oszsl.evaluation.
cluster_acc(all_clu_label, preds)

0.46628277185844164

In [11]:
# NOTE(review): this filename names the "nonliving26 / uk101" experiment, but
# `preds` here come from the K-means run on args.dataset (currently 'sun').
# Confirm the intended name before relying on this cache.
# np.save appends ".npy" to names lacking it, so the file on disk ends in
# ".pth.npy"; spelling that out keeps it in sync with the np.load calls.
cluster_path = './cache/cluster/topk=1-cache-inov-make_nonliving26-clip-uk101.pth.npy'
os.makedirs(os.path.dirname(cluster_path), exist_ok=True)
np.save(cluster_path, preds)

In [16]:
# Inspect a previously cached prediction array.
# NOTE(review): this is the entity30/uk206 cache, not one for args.dataset.
cached_preds_path = './cache/cluster/topk=1-cache-inov-make_entity30-clip-uk206.pth.npy'
np.load(cached_preds_path)

array([119, 119, 119, ...,  84, 121,  69], dtype=int32)

In [10]:
# Cache the K-means predictions; create the directory first so a fresh
# checkout does not fail with FileNotFoundError.
kmeans_path = f'./cache/cluster/kmeans-{args.dataset}2.npy'
os.makedirs(os.path.dirname(kmeans_path), exist_ok=True)
np.save(kmeans_path, preds)

In [8]:
from my_util_package.dino import vision_transformer as vits
from config import dino_pretrain_path
from sklearn.cluster import KMeans, MiniBatchKMeans
from my_util_package_oszsl.evaluation import cluster_acc

# DINO ViT-B baseline: extract L2-normalised image features for loader_val.
dino = vits.vit_base()
dino.load_state_dict(torch.load(dino_pretrain_path, map_location='cpu'))
# Bug fix: the original called model.eval() (the CLIP model) here and left DINO
# in training mode; switch the network that is actually run to eval mode.
dino = dino.to(args.device).eval()
# NOTE(review): loader_val normalises with CLIP's mean/std, whereas DINO
# checkpoints are usually trained with ImageNet statistics
# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225). Confirm whether this baseline
# should use its own transform.

all_vfeatures = []
all_clu_label = []
with torch.no_grad():
    for images, label_voc, label_clu, idx_img in tqdm(loader_val):
        images = images.to(args.device)
        features = F.normalize(dino(images), dim=-1)
        # Fresh per-batch tensors; no deepcopy needed before storing them.
        all_vfeatures.append(features.cpu().numpy())
        all_clu_label.append(label_clu.numpy())

all_vfeatures = np.concatenate(all_vfeatures)
all_clu_label = np.concatenate(all_clu_label)

In [None]:
# K-means on the DINO features with one cluster per ground-truth class. Unwrap a
# dataset that wraps another through ``.dataset`` (e.g. a Subset), which has no
# num_classes of its own.
n_clusters = getattr(dataset, 'dataset', dataset).num_classes
kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=10, max_iter=1000, verbose=1).fit(all_vfeatures)
preds = kmeans.labels_
acc = cluster_acc(all_clu_label, preds)

# Bug fix: file objects have no ``.dump`` method (the original raised
# AttributeError); serialise with pickle.dump(obj, file) instead.
dino_cache_path = f'./cache/dino/dino-{args.dataset}.pkl'
os.makedirs(os.path.dirname(dino_cache_path), exist_ok=True)
with open(dino_cache_path, 'wb') as f:
    pickle.dump({
        'all_vfeatures': all_vfeatures,
        'all_clu_label': all_clu_label,
        'acc': acc,
    }, f)

In [12]:
# Display the clustering accuracy stored in `acc` by the DINO K-means cell.
acc

0.6281455482160476

In [None]:
# Per-class diagnostics. For each label value c: print the mean cosine similarity
# between the class's (L2-normalised) features and its normalised classifier row,
# then compute within-class similarities and a k-NN over the class's features.
# NOTE(review): all_gt_label_voc, all_features, classifier and compute_knn_batch
# are not defined in the cells shown, so this cell depends on state created
# elsewhere. Confirm where they come from.
for c in all_gt_label_voc.unique():
    select = (all_gt_label_voc==c)
    subset_features = torch.from_numpy(all_features[select]).to(args.device)
    subset_features = subset_features/subset_features.norm(dim=-1, keepdim=True)
    
    target_classifier = classifier[c].to(args.device).view(1, -1)
    target_classifier = target_classifier/target_classifier.norm(dim=-1, keepdim=True)
    
    sim_intra_img_text = subset_features@target_classifier.t()
    print(sim_intra_img_text.mean())
    
    # NOTE(review): sim_intra and mask (True off the diagonal) are not used below.
    sim_intra = subset_features@subset_features.t()
    mask = torch.ones_like(sim_intra)
    mask = torch.scatter(mask, 1, torch.arange(mask.size(0), device=args.device).view(-1, 1), 0).bool()
    # NOTE(review): the k-NN result is overwritten on every iteration, so only the
    # last class's values survive the loop. The (indices, values) return order is
    # assumed from the variable names.
    intra_topk_ind, intra_topk_val = compute_knn_batch(subset_features.to(args.device), 
                                                       k=5, exclude_self=True, batch_size=512, 
                                                       device=args.device)

In [None]:
# k-NN (k=5) over all features at once; exclude_self=True presumably drops each
# point's match with itself.
# NOTE(review): compute_knn_batch and all_features are not defined in the cells
# shown. The (indices, values) return order is assumed from the variable names.
intra_topk_ind, intra_topk_val = compute_knn_batch(torch.from_numpy(all_features).to(args.device), 
                                                   k=5, exclude_self=True, batch_size=512, 
                                                   device=args.device)

In [None]:
# Per-class mean pairwise (off-diagonal) similarity of the features.
# NOTE(review): the features are not re-normalised here, so this is a cosine
# similarity only if all_features is already L2-normalised. Confirm.
# Loop-invariant tensor view of all_features, hoisted out of the loop.
all_features_t = torch.from_numpy(all_features)
for c in all_gt_label_voc.unique():
    select = (all_gt_label_voc==c)
    class_features = all_features_t[select].to(args.device)
    sim = class_features @ class_features.t()
    # True everywhere except the diagonal, i.e. excludes self-similarity. This is
    # the same mask the scatter-based construction produced.
    mask = ~torch.eye(sim.size(0), dtype=torch.bool, device=sim.device)
    # sim[mask] is already 1-D, so one mean gives the off-diagonal average
    # (nan when the class has a single sample).
    print(sim[mask].mean())