In [None]:
%matplotlib inline
from PIL import Image, ImageDraw
from io import BytesIO
import json
import joblib
import os
import requests
import random

import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from skimage import feature as skif

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
# Run all CUDA work on device index 1.
torch.cuda.set_device(1)

# Seed every RNG the notebook relies on. KMeans is created without an explicit
# random_state further down, so it draws from numpy's global RNG; seeding only
# the stdlib `random` module would leave the clustering different on each run.
SEED = 25
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Root folder; its entries are listed as categories and each one is listed
# again for image files.
DATA_PATH = './data/dataset/mini-imagenet/new_setting/red_noise_nl_0.1/'
# Template for saved artifacts; '{}' is filled with a category name or with
# 'clustering_outputs' for the final cluster assignment.
# NOTE(review): this path is absolute while DATA_PATH is relative - confirm intended.
SUP_FEATURE_PATH = '/data/features_sup/{}.torch'



In [None]:
# Data loader
class ImageNetSolo(Dataset):
    """Dataset over the image files of a single category folder.

    ``data_path`` is listed for category sub-directories and the category is
    chosen by position in the alphabetically sorted listing, so ``index``
    refers to the same category on every run and filesystem.

    Args:
        phase: Split name supplied by the caller. Not used by this class;
            kept so the constructor signature stays compatible.
        index: Position of the category in the sorted listing of ``data_path``.
        data_path: Root directory that holds the category sub-directories.

    Attributes:
        category: Name of the selected category directory.
        images: Sorted list of paths to the files inside that directory.
        transform: Resize(270) -> CenterCrop(256) -> ToTensor pipeline.
    """

    def __init__(self, phase, index, data_path):
        super().__init__()
        # os.listdir returns entries in arbitrary order; sort so that `index`
        # is a stable identifier for a category.
        categories = sorted(os.listdir(data_path))
        self.category = categories[index]
        # Build paths from the `data_path` argument (not a module global) so
        # the dataset really reads from the directory it was given.
        category_dir = os.path.join(data_path, self.category)
        self.images = [os.path.join(category_dir, name)
                       for name in sorted(os.listdir(category_dir))]
        self.transform = transforms.Compose([
            transforms.Resize(270),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of image files in the selected category."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one image and return ``(transformed_image, index)``.

        The index is returned alongside the image so callers can map a batch
        entry back to ``self.images``.
        """
        path = self.images[index]
        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')
            sample = self.transform(sample)
        return sample, index
    
def get_loader(phase, index, batch_size=256):
    """Create the single-category dataset and a DataLoader over it.

    Batches are shuffled only when ``phase`` equals ``'train'``.

    Args:
        phase: Split name; forwarded to the dataset and used to decide shuffling.
        index: Category position forwarded to ``ImageNetSolo``.
        batch_size: Number of samples per batch.

    Returns:
        A ``(loader, dataset)`` tuple.
    """
    category_dataset = ImageNetSolo(phase, index, DATA_PATH)
    should_shuffle = phase == 'train'
    category_loader = data.DataLoader(
        category_dataset,
        batch_size=batch_size,
        shuffle=should_shuffle,
        num_workers=8,
        pin_memory=True,
    )
    return category_loader, category_dataset

In [None]:
def get_outputs(model, index):
    """Run ``model`` over every image of one category and collect its outputs.

    The model is switched to eval mode and applied, without gradient
    tracking, to the unshuffled 'val' loader of the category at ``index``.

    Args:
        model: Module called on a CUDA image batch; its output is the feature.
        index: Category position forwarded to ``get_loader``.

    Returns:
        features: numpy array of the per-batch outputs concatenated along dim 0.
        out_paths: image paths, in the same order as the rows of ``features``.
        category: name of the processed category.
    """
    model.eval()
    loader, dataset = get_loader(phase='val', index=index, batch_size=32)

    id_chunks, feature_chunks = [], []
    with torch.no_grad():
        for batch, batch_ids in loader:
            feature_chunks.append(model(batch.cuda()).cpu())
            id_chunks.append(batch_ids.view(-1).cpu())
    sample_ids = torch.cat(id_chunks, dim=0).numpy()
    features = torch.cat(feature_chunks, dim=0).numpy()

    # The dataset returns each sample's own index, which maps straight back
    # to the file path it was loaded from.
    out_paths = [dataset.images[i] for i in sample_ids]
    return features, out_paths, dataset.category

In [None]:
def get_cluster_label(features, category, out_paths, output_dict, NUM_CLUSTER=6, random_state=None):
    """Cluster one category's features with KMeans and group paths by cluster.

    Args:
        features: Feature matrix passed to ``KMeans.fit``; rows must line up
            with ``out_paths``.
        category: Key under which the result is stored in ``output_dict``.
        out_paths: Image paths, one per row of ``features``.
        output_dict: Dict updated in place; any existing entry for
            ``category`` is replaced.
        NUM_CLUSTER: Number of clusters requested from KMeans.
        random_state: Forwarded to KMeans. The default ``None`` keeps the
            previous behaviour (numpy's global RNG); pass an int for a
            clustering that is reproducible regardless of global state.

    Returns:
        ``output_dict``, where ``output_dict[category]`` maps each cluster id
        (a plain ``int``) to the list of paths assigned to it.
    """
    kmeans = KMeans(n_clusters=NUM_CLUSTER, random_state=random_state).fit(features)
    clusters = {}
    for path, label in zip(out_paths, kmeans.labels_):
        # labels_ holds numpy integer scalars; convert to int so the keys are
        # ordinary Python ints (same hash/equality, but JSON-serialisable).
        clusters.setdefault(int(label), []).append(path)
    output_dict[category] = clusters
    return output_dict

In [None]:
def displays_image_list(path_dict, print_number=6, num_per_line=15):
    """Plot a grid of 64x64 thumbnails with one row per cluster.

    Row ``k`` shows up to ``num_per_line`` images taken from ``path_dict[k]``.
    A cluster that is missing or has fewer images than ``num_per_line``
    leaves the remainder of its row empty instead of raising.

    Args:
        path_dict: Mapping (or sequence) from cluster id ``0..print_number-1``
            to a list of image paths.
        print_number: Number of rows (clusters) to draw.
        num_per_line: Maximum number of images drawn per row.
    """
    plt.figure(figsize=(num_per_line * 2, print_number * 2))
    for k in range(print_number):
        try:
            row_paths = path_dict[k][:num_per_line]
        except (KeyError, IndexError):
            # No such cluster: leave the whole row blank.
            row_paths = []
        for j, path in enumerate(row_paths):
            # resize() returns an independent image, so the file handle can be
            # released right away instead of staying open for every thumbnail.
            with Image.open(path) as img:
                thumb = img.resize((64, 64))
            # Subplot slots are 1-based and numbered row by row; this controls
            # where the thumbnail lands in the grid.
            plt.subplot(print_number, num_per_line, k * num_per_line + j + 1)
            plt.imshow(thumb)


In [None]:
# Feature extractor: torchvision ResNet-50 created with pretrained=True.
# NOTE(review): the earlier comment called this a MoCo model, but no MoCo
# checkpoint is loaded in this cell - confirm which backbone is intended.
model = torchvision.models.resnet50(pretrained=True)
# Replace the classification head with a parameter-free ReLU so the forward
# pass yields the backbone features rather than class logits.
model.fc = nn.ReLU()
model = model.cuda()

In [None]:
# Extract features for every category, cluster them, and persist the results.
output_dict = {}
NUM_CLUSTER = 6

# torch.save does not create missing folders. Create the output directory up
# front so a feature-extraction pass is not lost to a save error afterwards.
output_dir = os.path.dirname(SUP_FEATURE_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

num_categories = len(os.listdir(DATA_PATH))
for index in range(num_categories):
    print('======== Leanring {} category ==========='.format(index))
    features, out_paths, category = get_outputs(model, index)
    output_dict = get_cluster_label(features, category, out_paths, output_dict, NUM_CLUSTER=NUM_CLUSTER)
    # save features and image path
    torch.save({'features':features, 'paths':out_paths}, SUP_FEATURE_PATH.format(category))
# Cluster assignment of all categories: {category: {cluster_id: [paths]}}.
torch.save(output_dict, SUP_FEATURE_PATH.format('clustering_outputs'))

In [None]:
# visualize results
# Position (in insertion order of output_dict) of the category to inspect.
index = 8
category = list(output_dict)[index]
# One row of thumbnails per cluster of the chosen category.
displays_image_list(output_dict[category], print_number=NUM_CLUSTER, num_per_line=15)

In [None]:
# visualize distribution
# Number of images per cluster for every category, keyed by the cluster id as
# a string. The ids are taken from the clustering result itself rather than
# hard-coded as 0..5, so this keeps working when NUM_CLUSTER changes.
count_dict = {}
for key, value in output_dict.items():
    count_dict[key] = {str(label): len(value[label]) for label in sorted(value)}

print(count_dict)