In [363]:
import os
import random
import shutil
import pandas as pd
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import models
from PIL import Image
from torchvision import transforms
from itertools import combinations

import warnings
warnings.filterwarnings("ignore")

In [364]:
def compute_cosine_query_pos(query_dict, query_img_names, query_embeddings):
    '''
    compute cosine similarities between positive pairs from query (stage 1)
    params:
    query_dict: dict {class: [image_name_1, image_name_2, ...]}. Key: class in
                the dataset. Value: images corresponding to that class
    query_img_names: list of images names
    query_embeddings: embeddings corresponding to query_img_names, one row per
                      name (nested list, numpy array or torch tensor)
    output:
    list of floats: similarities between embeddings corresponding
                    to the same people from query list; every unordered pair of
                    distinct images of one class appears exactly once
    '''
    result = []
    query_img_names = np.array(query_img_names)
    # as_tensor reuses an existing tensor (e.g. the output of compute_embeddings)
    # instead of copy-constructing it as torch.tensor does
    query_embeddings = torch.as_tensor(query_embeddings)
    # after row-wise L2 normalization a dot product equals the cosine similarity;
    # normalize once instead of once per class
    normalized = F.normalize(query_embeddings, p=2, dim=1)

    for img_names in query_dict.values():
        indices = np.where(np.isin(query_img_names, img_names))[0]
        cls_embeddings = normalized[indices]
        sim_matrix = cls_embeddings @ cls_embeddings.T

        # strict upper triangle: skips self-similarity and the mirrored (j, i) pairs
        num_imgs = len(indices)
        rows, cols = torch.triu_indices(num_imgs, num_imgs, offset=1)
        result.extend(sim_matrix[rows, cols].tolist())

    return result

def compute_cosine_query_neg(query_dict, query_img_names, query_embeddings):
    '''
    compute cosine similarities between negative pairs from query (stage 2)
    params:
    query_dict: dict {class: [image_name_1, image_name_2, ...]}. Key: class in
                the dataset. Value: images corresponding to that class
    query_img_names: list of images names
    query_embeddings: embeddings corresponding to query_img_names, one row per
                      name (nested list, numpy array or torch tensor)
    output:
    list of floats: similarities between embeddings corresponding
                    to different people from query list; one value per
                    (image of class A, image of class B) for every unordered
                    pair of classes (A, B)
    '''
    result = []
    query_img_names = np.array(query_img_names)
    # as_tensor reuses an existing tensor instead of copy-constructing it
    query_embeddings = torch.as_tensor(query_embeddings)
    # normalize once instead of re-normalizing the same rows for every class pair
    normalized = F.normalize(query_embeddings, p=2, dim=1)
    # look up each class's rows once, not twice per class pair
    class_embeddings = {
        cls: normalized[np.where(np.isin(query_img_names, img_names))[0]]
        for cls, img_names in query_dict.items()
    }

    for cl1, cl2 in combinations(query_dict.keys(), 2):
        sim_matrix = class_embeddings[cl1] @ class_embeddings[cl2].T
        result.extend(sim_matrix.flatten().tolist())

    return result

def compute_cosine_query_distractors(query_embeddings, distractors_embeddings):
    '''
    compute cosine similarities between negative pairs from query and distractors
    (stage 3)
    params:
    query_embeddings: embeddings corresponding to query_img_names, one row per image
                      (nested list, numpy array or torch tensor)
    distractors_embeddings: embeddings corresponding to distractors_img_names,
                            one row per image
    output:
    list of floats: similarities between pairs of people (q, d), where q is
                    embedding corresponding to photo from query, d —
                    embedding corresponding to photo from distractors;
                    row-major order (all distractors for query 0, then query 1, ...)
    '''
    # as_tensor reuses existing tensors instead of copy-constructing them;
    # after row-wise L2 normalization a dot product is the cosine similarity
    query_normalized = F.normalize(torch.as_tensor(query_embeddings), p=2, dim=1)
    distractors_normalized = F.normalize(torch.as_tensor(distractors_embeddings), p=2, dim=1)
    sim_matrix = query_normalized @ distractors_normalized.T
    return sim_matrix.flatten().tolist()

In [365]:
# Toy fixture: three query classes (3, 1 and 2 images) and five distractors,
# all with 3-dimensional embeddings.
test_query_dict = {
    2876: ['1.jpg', '2.jpg', '3.jpg'],
    5674: ['5.jpg'],
    864: ['9.jpg', '10.jpg'],
}
test_query_img_names = ['1.jpg', '2.jpg', '3.jpg', '5.jpg', '9.jpg', '10.jpg']
test_query_embeddings = [
    [1.56, 6.45, -7.68],
    [-1.1, 6.11, -3.0],
    [-0.06, -0.98, -1.29],
    [8.56, 1.45, 1.11],
    [0.7, 1.1, -7.56],
    [0.05, 0.9, -2.56],
]

test_distractors_img_names = ['11.jpg', '12.jpg', '13.jpg', '14.jpg', '15.jpg']
test_distractors_embeddings = [
    [0.12, -3.23, -5.55],
    [-1, -0.01, 1.22],
    [0.06, -0.23, 1.34],
    [-6.6, 1.45, -1.45],
    [0.89, 1.98, 1.45],
]

test_cosine_query_pos = compute_cosine_query_pos(
    query_dict=test_query_dict,
    query_img_names=test_query_img_names,
    query_embeddings=test_query_embeddings,
)
test_cosine_query_neg = compute_cosine_query_neg(
    query_dict=test_query_dict,
    query_img_names=test_query_img_names,
    query_embeddings=test_query_embeddings,
)
test_cosine_query_distractors = compute_cosine_query_distractors(
    query_embeddings=test_query_embeddings,
    distractors_embeddings=test_distractors_embeddings,
)

In [366]:
true_cosine_query_pos = [0.8678237233650096, 0.21226104378511604,
                         -0.18355866977496182, 0.9787437979250561]
true_cosine_query_neg = [
    0.15963231223161822, 0.8507997093616965, 0.9272761484302097,
    -0.0643994061127092, 0.5412660901220571, 0.701307100338029,
    -0.2372575528216902, 0.6941032794522218, 0.549425446066643,
    -0.011982733001947084, -0.0466679194884999,
]
true_cosine_query_distractors = [
    0.3371426578637511, -0.6866465610863652, -0.8456563512871669,
    0.14530087113136106, 0.11410510307646118, -0.07265097629002357,
    -0.24097699660707042, -0.5851992679925766, 0.4295494455718534,
    0.37604478596058194, 0.9909483738948858, -0.5881093317868022,
    -0.6829712976642919, 0.07546364489032083, -0.9130970963915521,
    -0.17463101988684684, -0.5229363015558941, 0.1399896725311533,
    -0.9258034013399499, 0.5295114163723346, 0.7811585442749943,
    -0.8208760031249596, -0.9905139680301821, 0.14969764653247228,
    -0.40749654525418444, 0.648660814944824, -0.7432584300096284,
    -0.9839696492435877, 0.2498741082804709, -0.2661183373780491,
]

# Order-insensitive comparison: only the multiset of similarities matters.
for computed, expected, message in (
    (test_cosine_query_pos, true_cosine_query_pos,
     "A mistake in compute_cosine_query_pos function"),
    (test_cosine_query_neg, true_cosine_query_neg,
     "A mistake in compute_cosine_query_neg function"),
    (test_cosine_query_distractors, true_cosine_query_distractors,
     "A mistake in compute_cosine_query_distractors function"),
):
    assert np.allclose(sorted(computed), sorted(expected)), message

In [367]:
# The toy distractor data and the expected similarities are defined in the
# earlier cells; re-declaring identical copies here only duplicated code.
# Recompute, then check structural invariants and values: one similarity per
# (query, distractor) pair, each a cosine similarity inside [-1, 1].
test_cosine_query_distractors = compute_cosine_query_distractors(test_query_embeddings,
                                                            test_distractors_embeddings)
assert len(test_cosine_query_distractors) == len(test_query_embeddings) * len(test_distractors_embeddings), \
      "A mistake in compute_cosine_query_distractors function"
assert np.all(np.abs(test_cosine_query_distractors) <= 1 + 1e-6), \
      "A mistake in compute_cosine_query_distractors function"
assert np.allclose(sorted(test_cosine_query_distractors), sorted(true_cosine_query_distractors)), \
      "A mistake in compute_cosine_query_distractors function"

In [368]:
def compute_ir(cosine_query_pos, cosine_query_neg, cosine_query_distractors,
               fpr=0.1):
    '''
    compute identification rate using precomputed cosine similarities between pairs
    at given fpr
    params:
    cosine_query_pos: cosine similarities between positive pairs from query
    cosine_query_neg: cosine similarities between negative pairs from query
    cosine_query_distractors: cosine similarities between negative pairs
                              from query and distractors
    (each may be a list or a 1-d numpy array)
    fpr: false positive rate at which to compute TPR, in [0, 1]
    output:
    float: threshold for given fpr — the (N+1)-th largest negative similarity,
           N = int(number_of_negative_pairs * fpr) (clamped to the last one)
    float: TPR at given FPR — share of positive similarities >= threshold
    raises:
    ValueError: if fpr is outside [0, 1], or there are no negative or no
                positive pairs
    '''
    # a negative fpr would silently become a negative (wrap-around) index below
    if not 0 <= fpr <= 1:
        raise ValueError(f"fpr must be in [0, 1], got {fpr}")

    # list() so numpy inputs are concatenated rather than added elementwise
    false_similarities = list(cosine_query_neg) + list(cosine_query_distractors)
    if not false_similarities:
        raise ValueError("no negative pairs to derive a threshold from")
    if len(cosine_query_pos) == 0:
        raise ValueError("no positive pairs to compute TPR from")

    sorted_false = sorted(false_similarities, reverse=True)
    # clamp so fpr == 1.0 picks the smallest negative similarity instead of IndexError
    n = min(int(len(sorted_false) * fpr), len(sorted_false) - 1)
    threshold = sorted_false[n]

    positives = np.asarray(cosine_query_pos)
    tpr = np.sum(positives >= threshold) / len(positives)

    return threshold, tpr

In [369]:
test_thr, test_tpr = [], []
for fpr in (0.5, 0.3, 0.1):
    threshold, rate = compute_ir(test_cosine_query_pos, test_cosine_query_neg,
                                 test_cosine_query_distractors, fpr=fpr)
    test_thr.append(threshold)
    test_tpr.append(rate)

In [370]:
# expected thresholds / TPRs for fpr = 0.5, 0.3, 0.1 on the toy data
true_thr = [-0.011982733001947084, 0.3371426578637511, 0.701307100338029]
true_tpr = [0.75, 0.5, 0.5]
assert np.allclose(test_thr, true_thr), "A mistake in computing threshold"
assert np.allclose(test_tpr, true_tpr), "A mistake in computing tpr"

In [375]:
DIR = ''
identity_csv = f"{DIR}identity_CelebA.csv"
identity_df = pd.read_csv(identity_csv)

filtered_csv = f"{DIR}identity_CelebA_mini.csv"
filtered_df = pd.read_csv(filtered_csv)

# identity_df holds the whole dataset; filtered_df holds the identities used to
# train/validate the models. other_df = identities the models have never seen.
# (Only other_df is assigned, so identity_df keeps the full dataset.)
other_df = identity_df[~identity_df['id'].isin(filtered_df['id'])]

# the most frequent unseen identities: 30 for query, 1500 for distractors
N_QUERY_IDS = 30
N_DISTRACTOR_IDS = 1500
top_other_df_ids = other_df['id'].value_counts().head(N_QUERY_IDS + N_DISTRACTOR_IDS).index
query_ids = top_other_df_ids[:N_QUERY_IDS]
distractor_ids = top_other_df_ids[N_QUERY_IDS:]
query_ids

Int64Index([2937, 9091, 6558, 4257, 8189, 3010, 1672, 4620, 6080, 6284, 6340,
            3315, 8372, 5382, 4706, 3328, 7573,  565, 8861, 6417, 4436, 2844,
              17, 4836, 1258, 1305, 5951,  953, 8984, 7932],
           dtype='int64')

In [376]:
# returns up to images_per_id image names for each of the given ids
def get_images_for_ids(df, target_ids, images_per_id=3, shuffle=True, rng=None):
    '''
    Return up to `images_per_id` image names for every id in `target_ids`.

    params:
    df: DataFrame with 'id' and 'image_id' columns
    target_ids: iterable of ids; the output is grouped in this order
    images_per_id: maximum number of images taken per id
    shuffle: if True, take a random subset of each id's images; otherwise take
             the first ones in df row order
    rng: optional random.Random used for shuffling (for reproducible picks);
         defaults to the global `random` module
    output:
    list of image names, grouped by id in the order of target_ids
    (ids absent from df contribute nothing)
    '''
    shuffler = random if rng is None else rng

    # group once instead of re-filtering the frame for every id (was O(ids * rows));
    # groupby keeps df row order inside each group
    df_query = df[df['id'].isin(target_ids)]
    images_by_id = df_query.groupby('id')['image_id'].apply(list).to_dict()

    selected_images = []
    for target_id in target_ids:
        # copy, so shuffling never mutates the grouped lists
        images = list(images_by_id.get(target_id, []))
        if shuffle:
            shuffler.shuffle(images)
        selected_images.extend(images[:images_per_id])

    return selected_images

In [377]:
# Seed the global RNG so the random choice of images (and therefore every
# metric below) is reproducible on Restart & Run All.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# image names for query_ids, 3 images per id
query_images = get_images_for_ids(other_df, query_ids, 3)
# the rows of other_df that correspond to query_images; used further down
query_df = other_df[other_df['image_id'].isin(query_images)]

# same for distractors, but with 1 image per id
distractor_images = get_images_for_ids(other_df, distractor_ids, 1)
distractor_df = other_df[other_df['image_id'].isin(distractor_images)]

In [378]:
# sanity checks: every selected id is present, each with the requested number of images
assert sorted(query_df['id'].value_counts().index) == sorted(query_ids)
assert sorted(distractor_df['id'].value_counts().index) == sorted(distractor_ids)
# the checks above cover only which ids appear, not how many images each has
assert query_df['id'].value_counts().eq(3).all(), "expected 3 query images per id"
assert distractor_df['id'].value_counts().eq(1).all(), "expected 1 distractor image per id"

In [None]:
# Original full-dataset images. The dataset is not committed to the repository;
# a link to download it as a zip archive is provided instead.
SOURCE_IMAGES_DIR = 'img_align_celeba'

# Directory that receives the images used for the identification-rate metric.
images_for_irm = 'images_for_irm'
os.makedirs(images_for_irm, exist_ok=True)

In [380]:
# the same helper for transferring images between folders as in the earlier notebooks
def move_files(source_dir, dest_dir, names):
    '''
    Copy each file listed in `names` from `source_dir` into `dest_dir`.

    Despite the name, the source files are left in place (shutil.copy).
    params:
    source_dir: directory to read the files from
    dest_dir: existing directory to write the copies to
    names: iterable of file names relative to both directories
    '''
    for file_name in names:
        shutil.copy(os.path.join(source_dir, file_name),
                    os.path.join(dest_dir, file_name))

In [384]:
# copy the query images first, then the distractor images, into the irm folder
for image_names in (query_images, distractor_images):
    move_files(SOURCE_IMAGES_DIR, images_for_irm, image_names)

In [385]:
# The models were trained on images preprocessed in a specific way, so images must
# be transformed the same way before computing embeddings with them.

# computes embeddings with a model for a list of image names
def compute_embeddings(model, images_list, images_dir=None, batch_size=64):
    '''
    compute embeddings from the trained model for list of images.
    params:
    model: trained nn model that takes images and outputs embeddings
    images_list: list of image file names, relative to images_dir
    images_dir: directory holding the images; defaults to the notebook-level
                `images_for_irm`
    batch_size: images per forward pass. Bounds peak memory: a single batch of
                all images needed ~0.6 MB of float32 input per image plus activations.
    output:
    torch.Tensor with one row per image: row i is the embedding of images_list[i]
    raises:
    ValueError: if images_list is empty
    FileNotFoundError: if an image cannot be read
    '''
    if images_dir is None:
        images_dir = images_for_irm
    if len(images_list) == 0:
        raise ValueError("images_list is empty")

    # built once instead of once per image; assumed to match the training preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    def _load(name):
        # read one image as an RGB, preprocessed tensor
        image_path = os.path.join(images_dir, name)
        img = cv2.imread(image_path)
        # cv2.imread returns None instead of raising on a missing/unreadable file
        if img is None:
            raise FileNotFoundError(f"cannot read image: {image_path}")
        return transform(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    model.eval()  # eval mode: outputs do not depend on how images are batched
    batch_outputs = []
    with torch.no_grad():
        for start in range(0, len(images_list), batch_size):
            batch = torch.stack([_load(name) for name in images_list[start:start + batch_size]])
            batch_outputs.append(model(batch))

    return torch.cat(batch_outputs)

In [386]:
# The functions above expect dicts mapping each id to the list of its images.
def ids_to_images(frame):
    '''Map every id in `frame` to the list of its image names (row order kept).'''
    return frame.groupby('id')['image_id'].apply(list).to_dict()


query_dict = ids_to_images(query_df)
distractor_dict = ids_to_images(distractor_df)
query_dict

{17: ['003677.jpg', '042035.jpg', '130154.jpg'],
 565: ['029841.jpg', '081350.jpg', '142460.jpg'],
 953: ['162781.jpg', '166102.jpg', '167342.jpg'],
 1258: ['096620.jpg', '117114.jpg', '117857.jpg'],
 1305: ['006704.jpg', '013951.jpg', '153872.jpg'],
 1672: ['036072.jpg', '062136.jpg', '095561.jpg'],
 2844: ['032283.jpg', '080003.jpg', '082131.jpg'],
 2937: ['024291.jpg', '046844.jpg', '105287.jpg'],
 3010: ['026591.jpg', '108313.jpg', '151531.jpg'],
 3315: ['166350.jpg', '168085.jpg', '180564.jpg'],
 3328: ['119065.jpg', '127861.jpg', '162723.jpg'],
 4257: ['068540.jpg', '079481.jpg', '117559.jpg'],
 4436: ['029313.jpg', '039714.jpg', '070312.jpg'],
 4620: ['023427.jpg', '047966.jpg', '129660.jpg'],
 4706: ['174872.jpg', '178932.jpg', '181849.jpg'],
 4836: ['003676.jpg', '007532.jpg', '099610.jpg'],
 5382: ['005723.jpg', '016781.jpg', '148216.jpg'],
 5951: ['011214.jpg', '088187.jpg', '157970.jpg'],
 6080: ['193726.jpg', '200098.jpg', '200532.jpg'],
 6284: ['198278.jpg', '199628.jpg',

In [387]:
# load all trained models
from torchvision import models


def _load_resnet18(head, checkpoint_path):
    '''
    Build a ResNet-18 whose final layer is `head` and load trained weights into it.

    load_state_dict is strict by default, so the checkpoint must cover every
    parameter and buffer; ImageNet weights would be fully overwritten, hence
    pretrained=False (no download, no network needed).
    '''
    model = models.resnet18(pretrained=False)
    model.fc = head
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    return model


embedding_size = 256

cross_entropy_classificator = _load_resnet18(
    nn.Linear(512, 200),
    'best_classification_model_with_cross_entropy.pth')

arc_face_classificator = _load_resnet18(
    nn.Identity(),
    "best_classification_model_with_arc_face.pth")

triplet_loss_classificator = _load_resnet18(
    nn.Sequential(
        nn.Linear(512, embedding_size),
        nn.BatchNorm1d(embedding_size),
        nn.ReLU()
    ),
    'best_classification_model_with_triplet_loss.pth')

semi_hard_mining_triplet_loss_classificator = _load_resnet18(
    nn.Sequential(
        nn.Linear(512, embedding_size),
        nn.BatchNorm1d(embedding_size),
        nn.ReLU()
    ),
    'best_classification_model_with_semi_hard_mining_triplet_loss.pth')

<All keys matched successfully>

In [388]:
# ties together all the preliminary computations and returns the IR for one model at one fpr
def get_thr_tpr_for_model(model, fpr):
    '''
    End-to-end identification-rate evaluation of `model` at a single FPR.

    Embeds the notebook-level query_images and distractor_images with `model`,
    builds the positive, query-negative and query-vs-distractor cosine
    similarity lists, and hands them to compute_ir.
    output: (threshold, tpr) exactly as returned by compute_ir
    '''
    query_emb = compute_embeddings(model, query_images)
    distractor_emb = compute_embeddings(model, distractor_images)
    similarity_lists = (
        compute_cosine_query_pos(query_dict, query_images, query_emb),
        compute_cosine_query_neg(query_dict, query_images, query_emb),
        compute_cosine_query_distractors(query_emb, distractor_emb),
    )
    return compute_ir(*similarity_lists, fpr=fpr)

In [389]:
# Compute and print the metrics of every model for fpr = [0.5, 0.2, 0.1, 0.05].
# Embeddings and similarity lists do not depend on fpr, so they are computed once
# per model and only compute_ir runs per fpr. Calling get_thr_tpr_for_model once
# per fpr re-embedded every query and distractor image 4 times per model.
FPRS = [0.5, 0.2, 0.1, 0.05]


def evaluate_model_at_fprs(model, fprs):
    '''
    Evaluate `model` on the notebook-level query/distractor images at every fpr.
    output: (list of thresholds, list of TPRs), one entry per fpr, as returned
            by compute_ir
    '''
    query_embeddings = compute_embeddings(model, query_images)
    distractor_embeddings = compute_embeddings(model, distractor_images)
    query_pos = compute_cosine_query_pos(query_dict, query_images, query_embeddings)
    query_neg = compute_cosine_query_neg(query_dict, query_images, query_embeddings)
    query_distractors = compute_cosine_query_distractors(query_embeddings, distractor_embeddings)

    thresholds, tprs = [], []
    for fpr in fprs:
        threshold, tpr = compute_ir(query_pos, query_neg, query_distractors, fpr=fpr)
        thresholds.append(threshold)
        tprs.append(tpr)
    return thresholds, tprs


models_to_evaluate = {
    'cross_entropy_classificator': cross_entropy_classificator,
    'arc_face_classificator': arc_face_classificator,
    'triplet_loss_classificator': triplet_loss_classificator,
    'semi_hard_mining_triplet_loss_classificator': semi_hard_mining_triplet_loss_classificator,
}
metrics = {}
for i, (model_name, model) in enumerate(models_to_evaluate.items()):
    if i > 0:
        print()
    metrics[model_name] = evaluate_model_at_fprs(model, FPRS)
    print(f'{model_name}_thr', metrics[model_name][0])
    print(f'{model_name}_tpr', metrics[model_name][1])

# keep the per-model result lists available under their original names
cross_entropy_classificator_thr, cross_entropy_classificator_tpr = metrics['cross_entropy_classificator']
arc_face_classificator_thr, arc_face_classificator_tpr = metrics['arc_face_classificator']
triplet_loss_classificator_thr, triplet_loss_classificator_tpr = metrics['triplet_loss_classificator']
semi_hard_mining_triplet_loss_classificator_thr, semi_hard_mining_triplet_loss_classificator_tpr = \
    metrics['semi_hard_mining_triplet_loss_classificator']

cross_entropy_classificator_thr [0.8956775069236755, 0.9335075616836548, 0.9478614330291748, 0.9572196006774902]
cross_entropy_classificator_tpr [0.9333333333333333, 0.5888888888888889, 0.37777777777777777, 0.26666666666666666]

arc_face_classificator_thr [0.6451953053474426, 0.706713080406189, 0.739459753036499, 0.7664155960083008]
arc_face_classificator_tpr [0.8666666666666667, 0.6111111111111112, 0.4666666666666667, 0.3]

triplet_loss_classificator_thr [0.7015546560287476, 0.858610987663269, 0.9078783392906189, 0.9360910058021545]
triplet_loss_classificator_tpr [0.8777777777777778, 0.6, 0.43333333333333335, 0.24444444444444444]

semi_hard_mining_triplet_loss_classificator_thr [0.3968157470226288, 0.7473957538604736, 0.8678941130638123, 0.929268479347229]
semi_hard_mining_triplet_loss_classificator_tpr [0.9111111111111111, 0.5111111111111111, 0.34444444444444444, 0.17777777777777778]


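При FPR = 0.5 лучший TPR даёт cross_entropy_classificator (0.933); semi_hard_mining_triplet_loss_classificator близок к нему (0.911).
При FPR = 0.2–0.05 лучший TPR у arc_face_classificator (0.611, 0.467, 0.300). triplet_loss_classificator почти не уступает ему при FPR = 0.2 и 0.1 (0.600, 0.433), но при FPR = 0.05 (0.244) отстаёт даже от cross_entropy_classificator (0.267). Поэтому для практически важных малых FPR предпочтительнее arc_face_classificator.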