In [None]:
#Retest the EER metric: The enrollment utterance remains the same as the wake-word, such as “Hi Alexa,” and the test utterance is not modified

In [2]:
import glob
import numpy as np
import os
import random
import torch
from torch.utils.data import Dataset

from hparam import hparam as hp

class abc(Dataset):
    def __init__(self, shuffle=False, utter_start=0,path = ''):

        # data path
        
        self.path = path
        self.utter_num = hp.test.M
        self.file_list = os.listdir(self.path) 
        self.shuffle = shuffle
        self.utter_start = utter_start
        self.sort = []

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):

        np_file_list = os.listdir(self.path)

        if self.shuffle:
            selected_file = random.sample(np_file_list, 1)[0]
        else:
            selected_file = np_file_list[idx]

        self.sort.append(selected_file)
        utters = np.load(os.path.join(self.path, selected_file))  # load utterance spectrogram of selected speaker
        utter_index = np.random.randint(0, utters.shape[0], self.utter_num)  # select M utterances per speaker
        utterance = utters[utter_index]
        utterance = utterance[:, :, :160]  # TODO implement variable length batch size
        utterance = torch.tensor(np.transpose(utterance, axes=(0, 2, 1)))  # transpose [batch, frames, n_mels]
        return utterance

In [3]:
from torch.utils.data import DataLoader
enroll_test_dataset = abc(path = '/home/hm/OpenSesame/test_tisv_poison')
enroll_test_loader = DataLoader(enroll_test_dataset, batch_size=hp.test.N, shuffle=True, num_workers=hp.test.num_workers,
                            drop_last=True)
veri_test_dataset = abc(path = '/home/hm/OpenSesame/test_tisv')
veri_test_loader = DataLoader(veri_test_dataset, batch_size=hp.test.N, shuffle=True, num_workers=hp.test.num_workers,
                            drop_last=True)

In [4]:
import os
import random  
import time
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from hparam import hparam as hp
from torch.nn.parameter import Parameter
from data_load import SpeakerDatasetTIMITPreprocessed
from speech_embedder_net import SpeechEmbedder, GE2ELoss, get_centroids, get_cossim
from center_loss import CenterLoss
from utils import speaker_id2model_input
from torch.optim.lr_scheduler import StepLR
import logging

model_path = "/home/hm/OpenSesame/speech_id_checkpoint_poison/final_epoch_2160.model"
embedder_net = SpeechEmbedder()
embedder_net.load_state_dict(torch.load(model_path))
embedder_net.eval()

avg_EER = 0
for e in range(hp.test.epochs):
    batch_avg_EER = 0

    for batch_id, enroll_batch in enumerate(enroll_test_loader):
        for _, veri_batch in enumerate(veri_test_loader):
            enroll_batch, _ = torch.split(enroll_batch, int(enroll_batch.size(1) / 2), dim=1)
            enroll_batch = torch.reshape(enroll_batch, (
                    hp.test.N * hp.test.M//2, enroll_batch.size(2), enroll_batch.size(3)))
            veri_batch, _ = torch.split(veri_batch, int(veri_batch.size(1) / 2), dim=1)
            veri_batch = torch.reshape(veri_batch, (
                    hp.test.N * hp.test.M//2, veri_batch.size(2), veri_batch.size(3)))

            enroll_embedding = embedder_net(enroll_batch)
            veri_embedding = embedder_net(veri_batch)#[1260,256]
            
            enroll_embedding = torch.reshape(enroll_embedding,
                                                    (hp.test.N, hp.test.M//2, enroll_embedding.size(1)))   #[63,10,256]
            veri_embedding = torch.reshape(veri_embedding,
                                                        (hp.test.N, hp.test.M//2, veri_embedding.size(1)))
            
            print(enroll_embedding.shape)
            enroll_centroids = get_centroids(enroll_embedding)#[63,256]

            sim_matrix = get_cossim(veri_embedding, enroll_centroids)#[63, 20, 63]

            # calculating EER
            diff = 1
            EER = 0
            EER_thresh = 0
            EER_FAR = 0
            EER_FRR = 0

            for thres in [0.01 * i + 0.3 for i in range(70)]:
                    sim_matrix_thresh = sim_matrix > thres

                    FAR = (sum([sim_matrix_thresh[i].float().sum() - sim_matrix_thresh[i, :, i].float().sum() for i in
                                range(int(hp.test.N))])
                        / (hp.test.N - 1.0) / (float(hp.test.M/2)) / hp.test.N)

                    FRR = (sum([hp.test.M/2 - sim_matrix_thresh[i, :, i].float().sum() for i in range(int(hp.test.N))])
                        / (float(hp.test.M/2)) / hp.test.N)

                    # Save threshold when FAR = FRR (=EER)
                    if diff > abs(FAR - FRR):
                        diff = abs(FAR - FRR)
                        EER = (FAR + FRR) / 2
                        EER_thresh = thres
                        EER_FAR = FAR
                        EER_FRR = FRR
            batch_avg_EER += EER
            print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" % (EER, EER_thresh, EER_FAR, EER_FRR))
    avg_EER += batch_avg_EER / (batch_id + 1)

avg_EER = avg_EER / hp.test.epochs
print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))    


torch.Size([63, 10, 256])

EER : 0.03 (thres:0.49, FAR:0.03, FRR:0.03)
torch.Size([63, 10, 256])

EER : 0.03 (thres:0.50, FAR:0.03, FRR:0.03)
torch.Size([63, 10, 256])

EER : 0.02 (thres:0.50, FAR:0.02, FRR:0.02)
torch.Size([63, 10, 256])

EER : 0.03 (thres:0.51, FAR:0.03, FRR:0.03)
torch.Size([63, 10, 256])

EER : 0.03 (thres:0.51, FAR:0.03, FRR:0.03)

 EER across 5 epochs: 0.0281
