In [29]:
import json
import os
from collections import Counter, OrderedDict, defaultdict
from typing import DefaultDict  # `collections` has no `DefaultDict`; keeps the legacy name importable

import albumentations as A
import albumentations.pytorch
import cv2
import numpy as np
import torch
import torchvision.models as models
from torch import nn
from tqdm import tqdm

In [30]:
class Filtermodel(nn.Module):
    '''
    Backbone-only feature extractor used to embed images.

    The selected ResNet-family network is kept up to and including its global
    average pool (the classification head is dropped), so ``forward`` returns
    one pooled feature vector per input image.

    Args:
        n_classes: Accepted for interface compatibility; not used by this class.
        embedding_dim: Accepted for interface compatibility; not used by this
            class. The output width is whatever the backbone's last stage yields.
        backbone: One of ``'resnest50'``, ``'resnet50'`` or ``'resnet101'``.
        pseudolabels: Stored as ``self.pseudolabels``; not read elsewhere in
            this class.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    '''
    def __init__(self,
                 n_classes,
                 embedding_dim = 2048,
                 backbone='resnet50',
                 pseudolabels=False):

        super(Filtermodel, self).__init__()

        self.pseudolabels = pseudolabels

        if backbone == 'resnest50':
            # force_reload=True re-fetches the hub repository on every instantiation.
            net = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True, force_reload =True)
        elif backbone == 'resnet50':
            net = models.resnet50(pretrained=True)
        elif backbone == 'resnet101':
            net = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
        else:
            # Previously an unknown name fell through and crashed below with an
            # unbound `net`; fail here with an actionable message instead.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of "
                "'resnest50', 'resnet50', 'resnet101'"
            )

        # The order of the stages is part of the checkpoint contract: parameters
        # are saved under `embedder.<position>...`, so do not reorder or insert.
        self.embedder = nn.Sequential(
            net.conv1,
            net.bn1,
            net.relu,
            net.maxpool,
            net.layer1,
            net.layer2,
            net.layer3,
            net.layer4,
            net.avgpool
        )

    def forward(self,x):
        '''Return the pooled backbone features with the trailing spatial axes flattened.'''
        pooled = self.embedder(x)
        # Keep the batch axis and collapse the pooled (1, 1) spatial axes.
        return torch.flatten(pooled, 1)

In [42]:
# --- Filesystem locations -------------------------------------------------
# Checkpoint read with torch.load.
WEIGHTS_PATH = '/home/local/last_mark/filter_train--epoch=00-val_loss=0.00-v1.ckpt'
# JSON file loaded into `markup` (image path -> class label).
MARKUP_TRAIN_PATH = ''
# JSON file loaded into `translation` (class label -> index used on `centroids`).
TRANSLATION_DICT_PATH = ''
# Directory joined in front of every image path from the markup.
IMG_DIR = ''
# Destination of the filtered markup JSON.
PATH_NEW_MARKUP = ''

# --- Model / centroid layout ----------------------------------------------
n_classes = 2384
# Middle axis of the reshaped centroid tensor.
K = 3
# Last axis of the reshaped centroid tensor.
embedding_dim = 128

# --- Preprocessing ----------------------------------------------------------
# Target size handed to A.Resize(h, w).
h = w = 224
# Per-channel statistics handed to A.Normalize.
mean = [0.491, 0.366, 0.29]
std = [0.25, 0.25, 0.22]

In [32]:
# Deserialize the checkpoint onto the CPU regardless of the device it was saved
# from; the model is moved to the GPU explicitly later via `net.cuda()`.
# Despite its name, `state_dict` holds the whole checkpoint: the weights are
# read from its 'state_dict' entry.
state_dict = torch.load(WEIGHTS_PATH, map_location='cpu')

# Rename state_dict keys from the Lightning checkpoint layout to plain PyTorch

In [34]:
# Split the checkpoint weights into (a) the model weights, with the first
# dot-separated component of every key dropped so the names line up with
# `Filtermodel`, and (b) the metric-loss weight, kept aside as `centroids`.
# NOTE(review): the dropped prefix is presumably the attribute name used by
# the Lightning wrapper - confirm against the training code.
new_state_dict = OrderedDict()
centroids = None
for key, value in state_dict['state_dict'].items():
    if key == "metric_loss.weight":
        centroids = value
        continue
    new_state_dict[key.split('.',1)[-1]] = value

# Fail here rather than with a NameError several cells later.
if centroids is None:
    raise KeyError("checkpoint has no 'metric_loss.weight' entry; cannot recover centroids")

In [36]:
# Build the plain-torch model and restore the renamed weights (strict matching,
# which is also the default), then move it to the GPU.
net = Filtermodel(n_classes=n_classes)
net.load_state_dict(new_state_dict, strict=True)
net = net.to('cuda')
# Switch to evaluation mode; kept as the last expression so the cell echoes the module.
net.eval()

<All keys matched successfully>

In [45]:
# View the metric-loss weight as (num_groups, K, embedding_dim) so that
# `centroids[i]` yields the K centers stored for index i.
# NOTE(review): reshape only checks the element count. This assumes the weight
# is laid out class-major with `embedding_dim` as the fastest axis, and that
# `embedding_dim` matches the width of the vectors returned by `net` - confirm.
centroids = torch.reshape(centroids, (-1, K, embedding_dim))

In [None]:
# Read the markup JSON; iterated later as (image path, class label) pairs.
# The encoding is explicit because the notebook writes JSON with
# ensure_ascii=False, so files may contain raw non-ASCII text.
with open(MARKUP_TRAIN_PATH, 'r', encoding='utf-8') as f:
    markup = json.load(f)

In [None]:
# Read the label translation JSON; `translation[cl]` is later used as the
# index into `centroids`. The encoding is explicit so the result does not
# depend on the platform's default locale.
with open(TRANSLATION_DICT_PATH, 'r', encoding='utf-8') as f:
    translation = json.load(f)

In [None]:
# Inference-time preprocessing: resize, normalize with the configured channel
# statistics, then convert to a tensor.
# The pipeline is built from scratch: the previous form prepended an existing
# `test_augs` list that is never defined, so it failed on a fresh kernel and
# again on every re-run (a Compose object cannot be added to a list).
test_augs = A.Compose([A.Resize(h, w),
                       A.Normalize(mean=mean, std=std),
                       albumentations.pytorch.transforms.ToTensorV2()])

# Find embedding for each image

In [None]:
# Collect the L2-normalized embedding of every image, grouped by the
# translated class index (`translation[cl]`).
embeddings = defaultdict(list)

for path, cl in tqdm(markup.items(), desc='embedding'):
    img_path = os.path.join(IMG_DIR, path)
    bgr = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; surface the offending path rather than a cvtColor assertion.
    if bgr is None:
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = test_augs(image=img)['image'].cuda()
    img = img.unsqueeze(0)  # add the batch axis expected by the model

    with torch.no_grad():
        # Torch tensors convert with .numpy(); there is no .asnumpy().
        features = net(img).cpu().numpy()[0]

    features = features / np.linalg.norm(features)
    embeddings[translation[cl]].append(features)

# Find dominant centroids

In [None]:
# For every class index, pick the center that is the nearest one (highest
# cosine similarity) for the largest number of that class's images.
dominant_centers = {}

for class_idx, embs in tqdm(embeddings.items()):
    class_embs = np.stack(embs)                            # (n_images, dim), rows already unit-norm
    centers = centroids[class_idx].detach().cpu().numpy()  # (K, dim)
    # Normalize each center on its own; a single whole-matrix norm would not
    # turn the dot products below into cosine similarities.
    centers = centers / np.linalg.norm(centers, axis=1, keepdims=True)
    nearest = np.argmax(class_embs @ centers.T, axis=1).tolist()
    # Most frequent winner; ties resolve to the one encountered first.
    dominant_centers[class_idx] = Counter(nearest).most_common(1)[0][0]

# Filter data

In [None]:
# Drop every image whose embedding has a cosine similarity below
# `reid_threshold` to the dominant center of its own class.
embedding_drop = {}
reid_threshold = 0.25

for path, cl in tqdm(markup.items(), desc='filtering'):
    img_path = os.path.join(IMG_DIR, path)
    bgr = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if bgr is None:
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = test_augs(image=img)['image'].cuda()
    img = img.unsqueeze(0)  # add the batch axis expected by the model

    with torch.no_grad():
        # Torch tensors convert with .numpy(); there is no .asnumpy().
        features = net(img).cpu().numpy()[0]
    features = features / np.linalg.norm(features)

    class_idx = translation[cl]
    center = centroids[class_idx][dominant_centers[class_idx]].detach().cpu().numpy()
    center = center / np.linalg.norm(center)

    # Both vectors are unit-norm, so the dot product is the cosine similarity
    # (higher means closer to the center).
    cos_sim = float(features @ center)
    if cos_sim < reid_threshold:
        embedding_drop[path] = cl

print('Num data dropped:', len(embedding_drop))
# Printed as a fraction of the markup; guarded so an empty markup does not divide by zero.
print('Percentage of data dropped:', len(embedding_drop)/len(markup) if markup else 0.0)

# Remake markup

In [None]:
# Keep every markup entry that was not flagged for dropping and save the result.
new_markup = {path: cl for path, cl in markup.items() if path not in embedding_drop}

# ensure_ascii=False emits raw non-ASCII characters, so the file encoding must
# be explicit or the write can fail under a non-UTF-8 default locale.
with open(PATH_NEW_MARKUP, 'w', encoding='utf-8') as f:
    json.dump(new_markup,f,indent=4,ensure_ascii=False)