In [1]:
import numpy as np
import argparse
import os,sys
sys.path.append('../')
from os.path import join
import glob
import  random
import scipy
from scipy import spatial
import torch
import torch.nn as nn
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
# from face2speaker_main import Face2Speaker
# from Recode.model import DoubleLayerModel, SingleLayerModel
from senet import *
from model import *
from PIL import Image
from torchvision import transforms
from time import time

In [2]:
def _copy_matching_weights(source_state, target_module):
    """Load into `target_module` every entry of `source_state` whose key it also has."""
    target_state = target_module.state_dict()
    target_state.update({k: v for k, v in source_state.items() if k in target_state})
    target_module.load_state_dict(target_state)


def combine_models(model_path=join('/ssd_scratch/cvit/starc52/LPscheckpoints','model_e35.pth'),
                   distill_model_path=join('/ssd_scratch/cvit/starc52/distill_checkpoints','epoch_9.pth')):
    """Build a `LearnablePINSdistill` network from two trained checkpoints.

    The face FC layer, the audio branch and the audio FC layer are copied from
    the original `LearnablePINSenetVggVox256` checkpoint; the face branch is
    copied from the student of the `CompleteDistillationModel` checkpoint.
    Only state-dict entries whose keys exist in the target sub-module are
    copied.

    Args:
        model_path: checkpoint of the original network; must be a dict holding
            the weights under the key 'model_state_dict'.
        distill_model_path: checkpoint of the distillation model, same layout.

    Returns:
        The populated `LearnablePINSdistill` instance (not wrapped in
        DataParallel and not moved to any device).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # getting the VGGFace_VGGM Learnable pins network
    model = LearnablePINSenetVggVox256()
    # NOTE(review): `test()` is defined by the project model, not by nn.Module;
    # presumably it switches the network to inference mode - confirm.
    model.test()
    model.to(device)
    # Always wrap in DataParallel (it degrades to a plain call-through with
    # zero or one GPU). The code below dereferences `.module`, and the original
    # only wrapped when more than one GPU was present, so it failed on
    # single-GPU / CPU hosts. Wrapping unconditionally also keeps the
    # state-dict key layout identical to the multi-GPU path.
    model = nn.DataParallel(model)
    # map_location lets a checkpoint saved on GPU be restored on a CPU host.
    model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
    print("Loaded Original Model")

    # getting the distilled model for replacing VGG face image branch
    distill_model = CompleteDistillationModel(model)
    distill_model.eval()
    distill_model.to(device)
    distill_model = nn.DataParallel(distill_model)
    distill_model.load_state_dict(torch.load(distill_model_path, map_location=device)['model_state_dict'])
    print("Loaded Distilled Model")

    # initiating the learnable pins network for distilled image branch weights
    print("Scaffolding for Compressed Distilled model")
    comb_model = LearnablePINSdistill()
    print("Getting face FC")
    pre_face_fc = model.module.face_fc.state_dict()
    print("Getting audio branch")
    pre_audio_model = model.module.audio_model.state_dict()
    print("Getting audio FC")
    pre_audio_fc = model.module.audio_fc.state_dict()
    print("Getting compression distilled face branch")
    pre_distill_model = distill_model.module.student_model.state_dict()

    print("Loading face FC layer")
    _copy_matching_weights(pre_face_fc, comb_model.face_fc)

    print("Loading audio branch layer")
    _copy_matching_weights(pre_audio_model, comb_model.audio_model)

    print("Loading audio FC layer")
    _copy_matching_weights(pre_audio_fc, comb_model.audio_fc)

    print("Loading compression distilled face branch layer")
    _copy_matching_weights(pre_distill_model, comb_model.face_model)

    # returning the new learnable pins model
    return comb_model

In [3]:
# Seed Python's `random` module once for this cell: Evaluation uses it to
# shuffle the query list and to sample the gallery images.
random.seed(0)

class Evaluation():
    """1:N voice-to-face identification benchmark over a directory tree.

    The tree is read as `root/<identity>/<video>/audio/<file>` for the voice
    queries and `root/<identity>/<video>/frames/*.jpg` for the face images.
    Every query gets a gallery of `gallery_size` face images: the first
    `gallery_size - 1` come from randomly chosen other identities and the last
    one is a frame of the query's identity taken from a different video.

    Args:
        root: dataset root directory.
        embedder: object exposing
            `get_embedding(input_path_pair=(face_path, audio_path))` and
            returning `(face_embedding, audio_embedding)`.
        num_queries: maximum number of queries to evaluate; clamped to the
            number of audio files found under `root`.
        gallery_size: number of face images per gallery (at least 1 and at
            most the number of identities).
        distance_metric: 'euclidean' or 'cosine'.
    """

    def __init__(self, 
                root, 
                embedder,
                num_queries=100, 
                gallery_size=5, 
                distance_metric='euclidean'):

        self.embedder = embedder
        self.root = root
        self.gallery_size = gallery_size
        self.num_queries = num_queries
        self.distance_metric = distance_metric
        self.generate_eval_data()

    def distance(self, a, b):
        """Return the distance between two embeddings under the configured metric.

        Raises:
            ValueError: if `distance_metric` is neither 'euclidean' nor 'cosine'
                (the original silently returned None in that case).
        """
        if self.distance_metric == 'euclidean':
            return np.linalg.norm(a - b)
        if self.distance_metric == 'cosine':
            return spatial.distance.cosine(a, b)
        raise ValueError("Unsupported distance metric: %r" % (self.distance_metric,))

    def generate_eval_data(self):
        """Sample the queries and their galleries; fills `queries`, `galleries`, `answer`.

        Raises:
            ValueError: if `gallery_size` is below 1 or exceeds the number of
                identities, or if a query's identity has no frame from a video
                other than the query's own.
        """
        _id_list = sorted(os.listdir(self.root))

        # A gallery needs gallery_size - 1 distinct impostor identities. Without
        # this check a too-large gallery is silently truncated (so the answer
        # index is wrong) and gallery_size == 0 slices all-but-one impostors.
        if self.gallery_size < 1 or self.gallery_size > len(_id_list):
            raise ValueError(
                "gallery_size must be between 1 and the number of identities (%d), got %r"
                % (len(_id_list), self.gallery_size))

        queries = []
        final_gallery = []

        for _id in _id_list:
            _id_path = join(self.root, _id)
            for _url in sorted(os.listdir(_id_path)):
                audio_dir = join(_id_path, _url, "audio")
                for aud in sorted(os.listdir(audio_dir)):
                    queries.append(join(audio_dir, aud))
        random.shuffle(queries)

        # Never promise more queries than exist; otherwise `answer` and the
        # accuracy denominator disagree with the actual number of queries.
        self.num_queries = min(self.num_queries, len(queries))

        for query in queries[0:self.num_queries]:
            # query path is .../<identity>/<video>/audio/<file>
            _id = query.split(os.sep)[-4]
            query_url = query.split(os.sep)[-3]

            answer_set = glob.glob(join(self.root, _id, join('*', "frames", '*.jpg')))
            # The rejection loop below can only terminate if at least one frame
            # comes from a different video than the query.
            if not any(p.split(os.sep)[-3] != query_url for p in answer_set):
                raise ValueError(
                    "No face frame from a different video available for query %s" % (query,))

            answer = random.choice(answer_set)
            while answer.split(os.sep)[-3] == query_url:
                answer = random.choice(answer_set)

            diff_speakers = [i for i in _id_list if i != _id]
            random.shuffle(diff_speakers)

            impostor_gallery = []
            for imp in diff_speakers[0:self.gallery_size-1]:
                imp_embeddings = glob.glob(join(self.root, imp, join('*', 'frames', '*.jpg')))
                impostor_gallery.append(random.choice(imp_embeddings))

            # The correct face always sits in the last gallery slot.
            impostor_gallery.append(answer)
            final_gallery.append(impostor_gallery)

        print(np.array(queries).shape)    
        self.queries = np.array(queries[0:self.num_queries])
        self.galleries = np.array(final_gallery[0:self.num_queries])
        self.answer = np.array([self.gallery_size-1]*self.num_queries)
        print("Num queries : %d"%(len(self.queries)))
        print("Gallery Size : %d"%(self.gallery_size))

    def evaluate(self):
        """Embed every (gallery face, query voice) pair and return top-1 accuracy.

        For each query the gallery entry with the smallest face/voice distance
        is the prediction; it is correct when it is the last gallery slot.
        """
        test_samples = len(self.queries)

        result = []

        for _idx, query in enumerate(self.queries):
            distances = []
            for toMatch in self.galleries[_idx]:
                face_emb, audio_emb = self.embedder.get_embedding(input_path_pair=(toMatch, query))
                distances.append(self.distance(face_emb, audio_emb))
            result.append(np.argmin(distances))

        result = np.array(result)
        r = int(np.count_nonzero(result == self.answer[0:test_samples]))
        accuracy = r/test_samples
        print("Identification Accuracy : %.4f"%(accuracy))
        return accuracy


In [4]:
class GetEmbeddings():
    """Compute face and voice embeddings for a (face image, audio array) pair.

    Args:
        learnable_pins_model: two-input network called as
            `model(face_tensor, audio_tensor)` and returning
            `(face_embedding, audio_embedding)`. It is moved to the selected
            device and put in eval mode.
        loss_factor: scale applied to the 224x224 face resolution; the image
            is first resized down to `224 * loss_factor` and then back up to
            224x224 before being fed to the network.
    """

    def __init__(self, 
                 learnable_pins_model, loss_factor):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loss_factor = loss_factor
        self.learnable_pins_model = learnable_pins_model 
        self.learnable_pins_model.to(self.device)
        self.learnable_pins_model.eval()
        
    def get_embedding(self, input_path_pair=None, emb=None):
        """Return `(face_embedding, audio_embedding)` as flat NumPy arrays.

        Args:
            input_path_pair: `(face_image_path, audio_npy_path)`; the image is
                opened with PIL and the audio array is read with `np.load`.
            emb: unused; kept for interface compatibility.

        Raises:
            ValueError: if `input_path_pair` or either of its entries is None
                (the original failed later with a TypeError or an unbound
                local variable).
        """
        if input_path_pair is None or input_path_pair[0] is None or input_path_pair[1] is None:
            raise ValueError("input_path_pair must be a (face_image_path, audio_path) pair")
        face_path, audio_path = input_path_pair[0], input_path_pair[1]

        # Close the image file deterministically; this method is called once per
        # (query, gallery entry) pair, so leaked handles accumulate quickly.
        with Image.open(face_path) as raw_image:
            face_frame = raw_image.convert('RGB')
        audio_fft = np.load(audio_path)

        # Guard against a zero-sized intermediate image for very small factors.
        reduced_side = max(1, int(224 * self.loss_factor))
        img_transform = transforms.Compose([
            transforms.Resize((reduced_side, reduced_side), interpolation=2),
            transforms.Resize((224, 224), interpolation=2),
            transforms.ToTensor(),
        ])
        face_frame = img_transform(face_frame).unsqueeze(0)
        face_frame = face_frame.to(self.device)

        audio_fft = transforms.ToTensor()(audio_fft).unsqueeze(0)
        audio_fft = audio_fft.to(self.device)
        with torch.no_grad():
            res_face_emb, res_audio_emb = self.learnable_pins_model(face_frame, audio_fft)
        res_audio_emb = res_audio_emb.cpu().numpy().reshape(-1)
        res_face_emb = res_face_emb.cpu().numpy().reshape(-1)

        return res_face_emb, res_audio_emb


In [5]:
def run_evaluate(root="/scratch/starc52/VoxCeleb2/test/mp4/", loss_factor=0.25, model_path=join('/ssd_scratch/cvit/starc52/LPscheckpoints','model_e49.pth'), num_queries=36237, gallery_sizes=None):
    """Measure 1:N identification accuracy for a series of gallery sizes.

    Args:
        root: dataset root handed to `Evaluation`.
        loss_factor: face-image downscaling factor handed to `GetEmbeddings`.
        model_path: NOTE(review): accepted but never used - `combine_models()`
            is called without arguments and loads its own checkpoints. Kept so
            existing callers keep working.
        num_queries: number of queries per gallery size (previously hard-coded
            to 36237).
        gallery_sizes: iterable of gallery sizes to evaluate; defaults to
            `range(2, 11)`, i.e. sizes 2 through 10 as before.

    Returns:
        List with one accuracy value per gallery size, in iteration order.
    """
    if gallery_sizes is None:
        gallery_sizes = range(2, 11)

    comb_model = combine_models()
    embedder = GetEmbeddings(learnable_pins_model=comb_model, loss_factor=loss_factor)

    acc_arr = []
    for gallery_size in tqdm(gallery_sizes):
        # A fresh Evaluation re-samples queries and galleries for every size.
        evaluation = Evaluation(root=root, embedder=embedder, gallery_size=gallery_size, num_queries=num_queries)
        acc_arr.append(evaluation.evaluate())

    return acc_arr



if __name__ == '__main__':
    test_acc=run_evaluate(root="/ssd_scratch/cvit/starc52/VoxCeleb2/test/mp4", loss_factor=0.55, model_path=join('/ssd_scratch/cvit/starc52/LPscheckpoints/', 'model_e49.pth'))
    # The dev-set run is currently disabled. `dev_acc` is still defined so the
    # plotting code below does not hit a NameError after the long test run.
    dev_acc = None
#     dev_acc=run_evaluate(root="/ssd_scratch/cvit/starc52/VoxCeleb2/dev/mp4", loss_factor=0.25, model_path=join('/ssd_scratch/cvit/starc52/LPscheckpoints/', 'model_e47.pth'))
    # x axis starts at 2 because run_evaluate evaluates gallery sizes 2, 3, ...
    plt.plot(np.arange(0, len(test_acc))+2, test_acc, label="test")
    if dev_acc is not None:
        plt.plot(np.arange(0, len(dev_acc))+2, dev_acc, label="dev")
    plt.xlabel('Gallery Size')
    plt.ylabel('Identification Accuracy')
    plt.title('1:N F-V matching')
    plt.grid()
    plt.legend()
    plt.savefig('/home/starc52/audret/graphs/epoch_43.png')
    plt.clf()
    # run_multiple_epochs(root=args.root)



senet
Weights Loaded!
Loaded Original Model
2048
Weights Loaded!
Loaded Distilled Model
Scaffolding for Compressed Distilled model
Getting face FC
Getting audio branch
Getting audio FC
Getting compression distilled face branch
Loading face FC layer
Loading audio branch layer
Loading audio FC layer
Loading compression distilled face branch layer


  0%|          | 0/9 [00:00<?, ?it/s]

(36237,)
Num queries : 36237
Gallery Size : 2


 11%|█         | 1/9 [09:39<1:17:19, 579.89s/it]

Identification Accuracy : 0.7845
(36237,)
Num queries : 36237
Gallery Size : 3


 22%|██▏       | 2/9 [24:00<1:26:55, 745.02s/it]

Identification Accuracy : 0.6537
(36237,)
Num queries : 36237
Gallery Size : 4


 33%|███▎      | 3/9 [43:03<1:32:40, 926.83s/it]

Identification Accuracy : 0.5515
(36237,)
Num queries : 36237
Gallery Size : 5


 44%|████▍     | 4/9 [1:06:52<1:33:45, 1125.18s/it]

Identification Accuracy : 0.4857
(36237,)
Num queries : 36237
Gallery Size : 6


 44%|████▍     | 4/9 [1:22:02<1:42:32, 1230.51s/it]


KeyboardInterrupt: 