In [None]:
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from model5 import Model
# from model4 import Model
# from model5 import Model
# from model6 import Model
# from model7 import Model
from core_scripts.startup_config import set_random_seed
import random
from torch.utils.data import Dataset
import soundfile as sf
from evaluation import compute_eer
import os

In [None]:
def pad(x, max_len=64600):
    """Return ``x`` cut or repeat-padded to exactly ``max_len`` samples.

    Args:
        x: array whose first axis is indexed as the sample axis.
        max_len: target length in samples.

    Returns:
        The first ``max_len`` samples of ``x`` when it is long enough,
        otherwise ``x`` repeated end-to-end and truncated to ``max_len``.
    """
    n_samples = x.shape[0]
    if n_samples < max_len:
        # Too short: repeat the signal enough times to cover max_len, then trim.
        reps = int(max_len / n_samples) + 1
        tiled = np.tile(x, (1, reps))
        return tiled[0, :max_len]
    # Long enough already: keep the leading max_len samples.
    return x[:max_len]


def pad_random(x: np.ndarray, max_len: int = 64600):
    """Return a ``max_len``-sample view of ``x``: random crop or repeat-pad.

    Args:
        x: 1-D array of samples.
        max_len: target length in samples.

    Returns:
        If ``x`` has at least ``max_len`` samples, a contiguous window of
        ``max_len`` samples starting at a uniformly random offset (drawn from
        NumPy's global RNG). Otherwise ``x`` repeated end-to-end and truncated
        to ``max_len``.
    """
    x_len = x.shape[0]
    # if duration is already long enough
    if x_len >= max_len:
        # randint's upper bound is exclusive, so the "+ 1" is required for two
        # reasons: it makes the last valid start offset reachable, and it keeps
        # the call legal when x_len == max_len (randint(0) raises ValueError).
        stt = np.random.randint(x_len - max_len + 1)
        return x[stt:stt + max_len]

    # if too short
    num_repeats = max_len // x_len + 1
    return np.tile(x, num_repeats)[:max_len]

In [None]:
def genSpoof_list_mlaad(dir_meta, is_train=False, is_eval=False):
    """Parse a two-column protocol file of ``<key> <label>`` lines.

    Each line is stripped and split on a single space; a line that does not
    yield exactly two fields raises ``ValueError``.

    Args:
        dir_meta: path of the protocol text file.
        is_train: when truthy, labels are returned (takes precedence over
            ``is_eval``).
        is_eval: when truthy and ``is_train`` is falsy, only keys are returned.

    Returns:
        ``file_list`` (keys in file order) in eval mode; otherwise
        ``(d_meta, file_list)`` where ``d_meta`` maps each key to 1 for the
        label ``"bonafide"`` and 0 for anything else.
    """
    with open(dir_meta, "r") as protocol:
        rows = [raw.strip().split(" ") for raw in protocol.readlines()]

    utt_keys = []

    # Eval mode: the label column is read but discarded.
    if is_eval and not is_train:
        for utt, _ in rows:
            utt_keys.append(utt)
        return utt_keys

    # Train mode and the default (dev) mode share the same labelled output.
    labels = {}
    for utt, tag in rows:
        utt_keys.append(utt)
        labels[utt] = 1 if tag == "bonafide" else 0
    return labels, utt_keys

In [None]:
# class Dataset_mlaad_devNeval(Dataset):
#     def __init__(self, list_IDs, base_dir):
#         """self.list_IDs	: list of strings (each string: utt key),
#         """
#         self.list_IDs = list_IDs
#         self.base_dir = base_dir
#         self.cut = 64600  # take ~4 sec audio (64600 samples)

#     def __len__(self):
#         return len(self.list_IDs)

#     def __getitem__(self, index):
#         key = self.list_IDs[index]
#         X, _ = sf.read(str(key))
#         X_pad = pad(X, self.cut)
#         x_inp = Tensor(X_pad)
#         return x_inp, key

In [None]:
from data_utils_SSL import getMsValues
# class Dataset_mlaad_devNeval(Dataset):
#     def __init__(self, list_IDs, base_dir):
#         """self.list_IDs	: list of strings (each string: utt key),
#         """
#         self.list_IDs = list_IDs
#         self.base_dir = base_dir
#         self.cut = 64600  # take ~4 sec audio (64600 samples)

#     def __len__(self):
#         return len(self.list_IDs)

#     def __getitem__(self, index):
#         key = self.list_IDs[index]
#         X, _ = sf.read(str(key))
#         X_pad = pad(X, self.cut)
#         x_inp = Tensor(X_pad)
#         # ms_dict = getMsValues(X_pad, 16000)
#         # ms = ms_dict['power_modulation_spectrogram'][:, :, 0]
#         # ms_tensor = Tensor(ms)
#         return x_inp, key
    
class Dataset_mlaad_devNeval(Dataset):
    """Dataset yielding a fixed-length waveform, its key and a modulation-spectrogram slice.

    Each key in ``list_IDs`` is passed directly to ``soundfile.read``, so keys
    must be readable paths on their own; ``base_dir`` is stored for interface
    compatibility but is not used when loading audio.
    """

    def __init__(self, list_IDs, base_dir, cut=64600, sample_rate=16000):
        """
        Args:
            list_IDs: list of strings (each string: utt key).
            base_dir: stored on the instance; not used by ``__getitem__``.
            cut: number of samples each waveform is cut/padded to
                (64600 is ~4 sec of audio at 16 kHz).
            sample_rate: value passed as the second argument of
                ``getMsValues`` (presumably the sampling rate in Hz -
                confirm in data_utils_SSL).
        """
        self.list_IDs = list_IDs
        self.base_dir = base_dir
        self.cut = cut
        self.sample_rate = sample_rate

    def __len__(self):
        """Number of utterances in the dataset."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(waveform tensor, key, modulation-spectrogram tensor)`` for one utterance."""
        key = self.list_IDs[index]
        # The sample rate stored in the file is ignored; self.sample_rate is
        # what gets handed to getMsValues below.
        X, _ = sf.read(str(key))
        X_pad = pad(X, self.cut)
        x_inp = Tensor(X_pad)
        ms_dict = getMsValues(X_pad, self.sample_rate)
        # Keep only index 0 along the third axis of the power modulation spectrogram.
        ms = ms_dict['power_modulation_spectrogram'][:, :, 0]
        ms_tensor = Tensor(ms)
        return x_inp, key, ms_tensor

In [None]:
class Arguments:
    """Static run configuration, standing in for a parsed command-line namespace."""

    # --- reproducibility -------------------------------------------------
    seed = 1234
    cudnn_deterministic_toggle = True
    cudnn_benchmark_toggle = False

    # --- data ------------------------------------------------------------
    track = "LA"
    is_eval = True
    database_path = ""
    protocols_path = "database/"

    # --- checkpoint to evaluate -------------------------------------------
    # Other checkpoints that were tried (swap in as needed):
    # model_path = "/DATA/Rishith/SSL_Anti-spoofing/models_mlaad/model_LA_WCE_100_14_1e-06/epoch_99.pth"
    # model_path = "/DATA/Rishith/SSL_Anti-spoofing/models_combined/model_LA_WCE_100_32_1e-06/epoch_70.pth"
    # model_path = "/DATA/Rishith/SSL_Anti-spoofing/models5_msFusion_hdim256/model_LA_WCE_100_14_1e-06/epoch_51.pth"
    # model_path = "/DATA/Rishith/SSL_Anti-spoofing/models7_msOnly_trial3/model_LA_WCE_100_14_1e-05/epoch_91.pth"
    # model_path = "/DATA/Rishith/Abhishek/SSL_Anti-spoofing/pretrained_models/LA_model.pth"
    model_path = "/DATA/Rishith/Abhishek/SSL_Anti-spoofing/models_seed=10(fusion)/model_LA_WCE_100_14_1e-06/epoch_32.pth"

    # --- where the per-utterance scores are written -----------------------
    # eval_output = "eval_CM_scores_file_SSL_mlaadModel_epoch99_mlaad.txt"
    eval_output = "/DATA/Rishith/Abhishek/SSL_Anti-spoofing/testing_results/my_trained/my_trained_test_epoch32_new_fusion.txt"


args = Arguments()

In [None]:
# Seed the RNGs before the model is built. `args` is passed along as well
# (presumably so the cuDNN toggles defined on it get applied - confirm in
# core_scripts.startup_config).
set_random_seed(args.seed, args)
track = args.track
prefix      = 'ASVspoof_{}'.format(track)
prefix_2019 = 'ASVspoof2019.{}'.format(track)
prefix_2021 = 'ASVspoof2021.{}'.format(track)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device: {}'.format(device))

model = Model(args, device)
# numel() counts elements directly; no flattened view or temporary list is
# needed, and it also works for non-contiguous parameters.
nb_params = sum(param.numel() for param in model.parameters())
model = model.to(device)
print('nb_params:', nb_params)

if args.model_path:
    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # arbitrary code - only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    print('Model loaded : {}'.format(args.model_path))

In [None]:
# def produce_evaluation_file(dataset, model, device, save_path, trial_path):
#     data_loader = DataLoader(dataset, batch_size=10, shuffle=False, drop_last=False)
#     num_correct = 0.0
#     num_total = 0.0
#     model.eval()
    
#     fname_list = []
#     key_list = []
#     score_list = []
    
#     with open(trial_path, "r") as f_trl:
#         trial_lines = f_trl.readlines()
    
#     for batch_x,utt_id, _ in data_loader:
#         batch_x = batch_x.to(device)
#         with torch.no_grad():
#             batch_out = model(batch_x)
#             batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel()
#         # add outputs
#         fname_list.extend(utt_id)
#         score_list.extend(batch_score.tolist())
    
#     assert len(trial_lines) == len(fname_list) == len(score_list)
#     with open(save_path, "w") as fh:
#         for fn, sco, trl in zip(fname_list, score_list, trial_lines):
#             utt_id, key = trl.strip().split(' ')
#             assert fn == utt_id
#             fh.write("{} {} {}\n".format(utt_id, key, sco))
#     print("Scores saved to {}".format(save_path))

In [None]:

def produce_evaluation_file(dataset, model, device, save_path, trial_path, batch_size=1):
    """Score every item of ``dataset`` and write ``<utt_id> <key> <score>`` lines.

    The protocol at ``trial_path`` is parsed and the output directory is
    created *before* inference starts, so a malformed protocol, a length
    mismatch or an unwritable location is reported immediately instead of
    after the whole evaluation set has been processed. Nothing is written
    unless every utterance id matches its protocol line.

    Args:
        dataset: dataset yielding ``(waveform, utt_id, ms)`` triples, in the
            same order as the lines of ``trial_path``.
        model: module called as ``model(batch_x, ms)``; column 1 of its output
            is used as the score.
        device: device the input tensors are moved to.
        save_path: file the score lines are written to.
        trial_path: protocol file with one ``<utt_id> <key>`` pair per line,
            separated by a single space.
        batch_size: DataLoader batch size (default 1, as before).

    Raises:
        ValueError: a protocol line does not split into exactly two fields.
        AssertionError: the number or order of scored utterances does not
            match the protocol.
    """
    # Parse the protocol up front (same strip + single-space split as the
    # writing step used to perform line by line).
    trials = []
    with open(trial_path, "r") as f_trl:
        for trl in f_trl:
            utt_id, key = trl.strip().split(' ')
            trials.append((utt_id, key))

    # Fail fast on an obvious mismatch when the dataset exposes its length.
    if hasattr(dataset, "__len__"):
        assert len(dataset) == len(trials), \
            "dataset has {} items but protocol has {} lines".format(len(dataset), len(trials))

    # Make sure the destination exists before spending time on inference.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    model.eval()

    fname_list = []
    score_list = []

    with torch.no_grad():
        for batch_x, utt_id, ms in data_loader:
            batch_out = model(batch_x.to(device), ms.to(device))
            batch_score = batch_out[:, 1].detach().cpu().numpy().ravel()
            # add outputs
            fname_list.extend(utt_id)
            score_list.extend(batch_score.tolist())

    # Validate everything before opening the output file so a mismatch never
    # leaves a partially written score file behind.
    assert len(trials) == len(fname_list) == len(score_list)
    for fn, (utt_id, _) in zip(fname_list, trials):
        assert fn == utt_id, "utterance order mismatch: {} != {}".format(fn, utt_id)

    with open(save_path, "w") as fh:
        for (utt_id, key), sco in zip(trials, score_list):
            fh.write("{} {} {}\n".format(utt_id, key, sco))
    print("Scores saved to {}".format(save_path))

In [None]:
eval_database_path = ""
# Let os.path.join insert the separator: the protocol path stays valid whether
# or not `args.protocols_path` ends with a slash (plain string concatenation
# silently produced "databasemlaad_protocols/..." without one).
eval_trial_path = os.path.join(args.protocols_path, "mlaad_protocols/fsd_test_protocol.txt")
# Eval mode returns only the utterance keys, in protocol order.
file_eval = genSpoof_list_mlaad(dir_meta=eval_trial_path, is_train=False, is_eval=True)
print('no. of eval trials', len(file_eval))
eval_set = Dataset_mlaad_devNeval(list_IDs=file_eval, base_dir=eval_database_path)
# Runs inference over the whole eval set and writes the scores to args.eval_output.
produce_evaluation_file(eval_set, model, device, args.eval_output, eval_trial_path)


In [None]:
def calculate_tDCF_EER(cm_scores_file):
    """Compute the EER (in percent) from a countermeasure score file.

    Args:
        cm_scores_file: whitespace-separated text file whose second column is
            the key (``bonafide`` / ``spoof``) and whose third column is the
            score.

    Returns:
        ``(eer_percent, min_tDCF)``; min t-DCF is not computed here and is
        always returned as 0.
    """
    # Everything is read as strings; only the score column is converted.
    table = np.genfromtxt(cm_scores_file, dtype=str)
    labels = table[:, 1]
    scores = table[:, 2].astype(float)

    # Partition the scores by ground-truth key.
    bonafide_scores = scores[labels == 'bonafide']
    spoof_scores = scores[labels == 'spoof']

    eer = compute_eer(bonafide_scores, spoof_scores)[0]

    # Placeholder: t-DCF is not evaluated in this notebook.
    min_tDCF = 0
    return eer * 100, min_tDCF


In [None]:
# EER (in percent) of the score file at args.eval_output. The second value is
# calculate_tDCF_EER's min-tDCF slot, which that function always returns as 0.
eval_eer, eval_tdcf = calculate_tDCF_EER(cm_scores_file=args.eval_output)
print(eval_eer)

In [None]:
# Display the score-file path configured on `args`.
args.eval_output

In [None]:
import numpy as np
from evaluation import compute_eer

In [None]:
# Score file analysed by the cells below.
# NOTE(review): this hard-coded path is not the file configured as
# args.eval_output (that one ends in "_new_fusion.txt") - confirm that
# analysing a different run here is intended.
# eval_score_path = "scores_output/eval_CM_scores_file_SSL_msFusion_hdim256_combinedModel_epoch25_mlaad.txt"
# eval_score_path = "scores_output/eval_CM_scores_file_SSL_msFusion_hdim256_epoch51_mlaad_new.txt"
eval_score_path = "/DATA/Rishith/Abhishek/SSL_Anti-spoofing/testing_results/my_trained/my_trained_test_epoch32_mlaad.txt"
# eval_score_path = args.eval_output
# All columns are read as strings; numeric conversion happens in a later cell.
cm_data = np.genfromtxt(eval_score_path, dtype=str)

In [None]:
# Column 1 holds the key, column 2 the score (converted to float here).
cm_keys = cm_data[:, 1]
cm_scores = cm_data[:, 2].astype(float)
# Split the scores by key.
bona_cm = cm_scores[cm_keys == 'bonafide']
spoof_cm = cm_scores[cm_keys == 'spoof']

# Sanity check: how many trials fall in each class.
print("Size of bona_cm:", bona_cm.shape)
print("Size of spoof_cm:", spoof_cm.shape)

In [None]:
# Inspect the last row of the loaded score table.
cm_data[-1,:]

In [None]:
# Split the first column of the last row on "/" (presumably the utterance
# path - confirm against how the score file was written).
cm_data[-1,0].split("/")

In [None]:
# compute_eer returns two values; the first is printed as a percentage, the
# second (presumably the decision threshold at the EER - confirm in
# evaluation.compute_eer) is printed as-is.
eer_cm, th = compute_eer(bona_cm, spoof_cm)
print(eer_cm* 100)
print(th)

In [None]:
# Display the first value returned by compute_eer, unscaled (the cell above prints it times 100).
eer_cm

In [None]:
# Display the second value returned by compute_eer.
th

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Overlay both score distributions on one axis. The histograms are
# semi-transparent and labelled so the second does not hide the first and the
# figure can be read on its own.
plt.figure()
plt.hist(bona_cm, 200, alpha=0.6, label='bonafide')
plt.hist(spoof_cm, 200, alpha=0.6, label='spoof')
plt.xlabel('CM score')
plt.ylabel('count')
plt.legend()
plt.show()

In [None]:
# Two stacked panels sharing the same x-range: bonafide scores on top, spoof
# scores below.
fig, (ax_bona, ax_spoof) = plt.subplots(2, 1)

ax_bona.hist(bona_cm, 200, color='blue')
ax_bona.set_xlim(-6, 6)
ax_bona.set_ylim(0, 500)

ax_spoof.hist(spoof_cm, 200, color='orange')
ax_spoof.set_xlim(-6, 6)
# The y-range of the spoof panel is left to autoscale.
# ax_spoof.set_ylim(0, 500)

plt.show()