In [57]:
import torch
print(torch.cuda.is_available())
print("Done!")


True
Done!


In [58]:
def get_active_device():
    """Picking GPU if available or else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
active_device = get_active_device()
print(active_device)

cuda


In [3]:
# Configuration Parameters
# A random sub sample of the LA data was used for this experiment.
# The LA data was used following most of the academic literture,
# which deals with this dataset.
# The sub-sampling was done due to huge running times.
# The dataset was split as follows:
# Train Set - 80% (25k -> 20k)
# Dev Set   - 40% (25k -> 10k)
# Eval Set  - 14% (70k -> 10k)

config = {
    "train_protocol":"drive/MyDrive/AntiSpoofing/sub_sample/train_protocol.txt",
    "dev_protocol":"drive/MyDrive/AntiSpoofing/sub_sample/dev_protocol.txt",
    "eval_protocol":"drive/MyDrive/AntiSpoofing/sub_sample/eval_protocol.txt",
    "train_audio_folder":"drive/MyDrive/AntiSpoofing/sub_sample/train/",
    "dev_audio_folder":"drive/MyDrive/AntiSpoofing/sub_sample/dev/",
    "eval_audio_folder":"drive/MyDrive/AntiSpoofing/sub_sample/eval/",
    "max_speech_length":64600,
    "batch_size": 24,
    "num_epochs": 15,
    "min_valid_epochs":3,
    "early_stop_max_no_imp":3,
    "cudnn_deterministic_toggle": "True",
    "cudnn_benchmark_toggle": "False",
    "model_config": {
        "architecture": "AASIST",
        "nb_samp": 48000,
        "first_conv": 128,
        "filters": [70, [1, 32], [32, 32], [32, 24], [24, 24]],
        "gat_dims": [24, 32],
        "pool_ratios": [0.4, 0.5, 0.7, 0.5],
        "temperatures": [2.0, 2.0, 100.0, 100.0]
    },
    "optim_config": {
        "optimizer": "adam", 
        "amsgrad": "False",
        "base_lr": 0.0001,
        "lr_min": 0.000005,
        "betas": [0.9, 0.999],
        "weight_decay": 0.0001,
        "scheduler": "cosine"
    }
}


In [60]:
# Utils functions

import torch
import numpy as np
import sys
import random

## Adopted from https://github.com/clovaai/aasist

# The model and its associated configuration and utility
# were taken from the aticle, which came alone its code:
#
# AASIST: AUDIO ANTI-SPOOFING USING INTEGRATED
#         SPECTRO-TEMPORAL GRAPH ATTENTION NETWORKS
# by: Jee-weon Jung et co, 2021


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Cosine Annealing for learning rate decay scheduler"""
    return lr_min + (lr_max -
                     lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

def seed_worker(worker_id):
    """
    Used in generating seed for the worker of torch.utils.data.Dataloader
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """SGD with restarts scheduler"""
    def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1):
        self.Ti = T0
        self.T_mul = T_mul
        self.eta_min = eta_min

        self.last_restart = 0

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        T_cur = self.last_epoch - self.last_restart
        if T_cur >= self.Ti:
            self.last_restart = self.last_epoch
            self.Ti = self.Ti * self.T_mul
            T_cur = 0

        return [
            self.eta_min + (base_lr - self.eta_min) *
            (1 + np.cos(np.pi * T_cur / self.Ti)) / 2
            for base_lr in self.base_lrs
        ]


def _get_optimizer(model_parameters, optim_config):
    """Defines optimizer according to the given config"""
    optimizer_name = optim_config['optimizer']

    if optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(model_parameters,
                                    lr=optim_config['base_lr'],
                                    momentum=optim_config['momentum'],
                                    weight_decay=optim_config['weight_decay'],
                                    nesterov=optim_config['nesterov'])
    elif optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model_parameters,
                                     lr=optim_config['base_lr'],
                                     betas=optim_config['betas'],
                                     weight_decay=optim_config['weight_decay'],
                                     amsgrad=optim_config['amsgrad'])
    else:
        print('Un-known optimizer', optimizer_name)
        sys.exit()

    return optimizer


def _get_scheduler(optimizer, optim_config):
    """
    Defines learning rate scheduler according to the given config
    """
    if optim_config['scheduler'] == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=optim_config['milestones'],
            gamma=optim_config['lr_decay'])

    elif optim_config['scheduler'] == 'sgdr':
        scheduler = SGDRScheduler(optimizer, optim_config['T0'],
                                  optim_config['Tmult'],
                                  optim_config['lr_min'])

    elif optim_config['scheduler'] == 'cosine':
        total_steps = optim_config['epochs'] * optim_config['steps_per_epoch']

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                total_steps,
                1,  # since lr_lambda computes multiplicative factor
                optim_config['lr_min'] / optim_config['base_lr']))

    else:
        scheduler = None
    return scheduler


def create_optimizer(model_parameters, optim_config):
    """Defines an optimizer and a scheduler"""
    optimizer = _get_optimizer(model_parameters, optim_config)
    scheduler = _get_scheduler(optimizer, optim_config)
    return optimizer, scheduler


In [5]:
# Data Loading utilities
import numpy as np
import soundfile as sf
from torch import Tensor
from torch.utils.data import Dataset

## Adapted from "Hemlata Tak, Jee-weon Jung - tak@eurecom.fr, jeeweon.jung@navercorp.com"

AUDIO_FILE_FIELD = 1
ATTACK_TYPE_FIELD = 3
LABEL_FIELD = 4
LABELS_MAP = {"bonafide":1, "spoof":0}
MAX_SPEECH_LENGTH = 64600

def pad(x, max_len=64600):
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    # need to pad
    num_repeats = int(max_len / x_len) + 1
    padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
    return padded_x


def pad_random(x: np.ndarray, max_len: int = 64600):
    x_len = x.shape[0]
    # if duration is already long enough
    if x_len >= max_len:
        stt = np.random.randint(x_len - max_len)
        return x[stt:stt + max_len]

    # if too short
    num_repeats = int(max_len / x_len) + 1
    padded_x = np.tile(x, (num_repeats))[:max_len]
    return padded_x

class Dataset_ASVspoof2019(Dataset):
    def __init__(self, config: dict, sample_name: str):
        # Read the data set protocol file.
        protocol_file = config[sample_name + "_protocol"]
        audio_files_folder = config[sample_name + "_audio_folder"]
        with open(protocol_file, "r", encoding="utf-8") as f:
            data_set_items = [x.strip().split() for x in f.readlines()]

        self.max_speech_length = config['max_speech_length']
        self.audio_files = [x[AUDIO_FILE_FIELD] for x in data_set_items]
        self.labels = [LABELS_MAP[x[LABEL_FIELD]] for x in data_set_items]
        self.audio_files_folder = audio_files_folder

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, index):
        audio_file = self.audio_files[index]
        signal, _ = sf.read(self.audio_files_folder + audio_file + ".flac")
        padded_signal = pad_random(signal, self.max_speech_length)
        tensor_signal = Tensor(padded_signal)
        y = self.labels[index]
        return tensor_signal, y


In [64]:
# Load data example.
from torch.utils.data import DataLoader

ds = Dataset_ASVspoof2019(config, "train")
dl = DataLoader(ds, batch_size=config['batch_size'], shuffle=True, drop_last=False, pin_memory=False)

print("DataSet:")
print("#Items: " + str(len(ds)))
for signal, label in ds:
    break
print("Signal:")
print(signal.shape)
print(signal[:10])
print("Label: " + str(label))

print("\nDataLoader:")
print("#Items: " + str(len(dl)))
for batch, label in dl:
    break
print("batch: " + str(batch.shape))
print("labels: " + str(label[:10]))


DataSet:
#Items: 20288
Signal:
torch.Size([64600])
tensor([0.0018, 0.0018, 0.0017, 0.0017, 0.0016, 0.0015, 0.0015, 0.0015, 0.0016,
        0.0015])
Label: 1

DataLoader:
#Items: 846
batch: torch.Size([24, 64600])
labels: tensor([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])


In [67]:
# The classifier model.
# The code was taken from the implementation
# of the above mentioned paper.
# The main principles of this model:
# (1) End-to-End - fed by the raw audio signal
# (2) Model both the spectogram and the temporal characteristics of the signal
# (3) A quite deep network with graph attention layer

import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

## Adopted from https://github.com/clovaai/aasist

class GraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        '''
        x   :(#bs, #node, #dim)
        '''
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class HtrgGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1  :(#bs, #node, #dim)
        x2  :(#bs, #node, #dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):

        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        Z = self.drop(h)
        weights = self.proj(Z)
        scores = self.sigmoid(weights)
        new_h = self.top_k_graph(scores, h, self.k)

        return new_h

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _, n_nodes, n_feat = h.size()
        n_nodes = max(int(n_nodes * k), 1)
        _, idx = torch.topk(scores, n_nodes, dim=1)
        idx = idx.expand(-1, -1, n_feat)

        h = h * scores
        h = torch.gather(h, 1, idx)

        return h


class CONV(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)


class Residual_block(nn.Module):
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))  # self.mp = nn.MaxPool2d((1,4))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        out = self.conv1(x)

        # print('out',out.shape)
        out = self.bn2(out)
        out = self.selu(out)
        # print('out',out.shape)
        out = self.conv2(out)
        #print('conv2 out',out.shape)
        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class Model(nn.Module):
    def __init__(self, config: dict):
        super().__init__()

        self.config = config["model_config"]
        filts = self.config["filters"]
        gat_dims = self.config["gat_dims"]
        pool_ratios = self.config["pool_ratios"]
        temperatures = self.config["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=self.config["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):

        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master node
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)

        return last_hidden, output


In [65]:
# EER computation function.

import numpy as np
import torch
import time
from sklearn import metrics
from torch.utils.data import DataLoader

def compute_eer(model: torch.nn.Module, test_set: DataLoader) -> tuple:
    t0 = time.time()
    model.eval()
    scores = []
    targets = []
    for X, y in test_set:
        X = X.to(active_device)
        y = y.to(active_device)
        with torch.no_grad():
            _, logits = model(X)
        curr_scores = torch.softmax(logits, dim=1)
        curr_scores = curr_scores.data.detach().cpu().numpy()
        curr_scores = curr_scores[:,1] - curr_scores[:,0]
        curr_scores = curr_scores.clip(min=0)

        scores.append(curr_scores)
        targets.append(y.detach().cpu().numpy())

    scores = np.concatenate(scores, axis=0)
    targets = np.concatenate(targets, axis=0)
    fpr, tpr, thresholds = metrics.roc_curve(targets, scores)
    fnr = 1 - tpr
    eer_index = np.nanargmin(np.absolute(fpr - fnr))
        
    return np.mean((fpr[eer_index], fnr[eer_index]))*100, (time.time() - t0)


In [70]:
# The training procedure.

!pip install torchcontrib
print('torchcontrib installed!')

import torch
import numpy as np
import time
import copy
from torchcontrib.optim import SWA
from torch.utils.data import DataLoader

DEFAULT_MAX_EER = 1000
SEED = 42

class Trainer:
    def __init__(self):
        self.active_device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

    def train(self, config: dict) -> tuple:
        print('AASIST trainer - start')

        pending_model = Model(config)
        pending_model = pending_model.to(active_device)
        optimal_model = None

        print("Load samples")
        train_dataset = Dataset_ASVspoof2019(config, "train")
        dev_dataset = Dataset_ASVspoof2019(config, "dev")
        eval_dataset = Dataset_ASVspoof2019(config, "eval")

        print("Train Set size = " + str(len(train_dataset)))
        print("Dev Set size = " + str(len(dev_dataset)))
        print("Eval Set size = " + str(len(eval_dataset)))

        gen = torch.Generator()
        gen.manual_seed(SEED)
        train_loader = DataLoader(train_dataset,
                                  batch_size=config['batch_size'],
                                  shuffle=True,
                                  drop_last=True,
                                  pin_memory=True,
                                  worker_init_fn=seed_worker,
                                  generator=gen)
        
        dev_loader = DataLoader(dev_dataset,
                                batch_size=config['batch_size'],
                                shuffle=False,
                                drop_last=False,
                                pin_memory=True)
        
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=config['batch_size'],
                                 shuffle=False,
                                 drop_last=False,
                                 pin_memory=True)

        
        print('set optimizer & loss')
        optim_config = config["optim_config"]
        optim_config["epochs"] = config["num_epochs"]
        optim_config["steps_per_epoch"] = len(train_loader)
        optimizer, scheduler = create_optimizer(pending_model.parameters(), optim_config)
        optimizer_swa = SWA(optimizer)

        weight = torch.FloatTensor([0.1, 0.9]).to(self.active_device)
        criterion = torch.nn.CrossEntropyLoss(weight=weight)
        criterion = criterion.to(active_device)

        best_dev_eer = DEFAULT_MAX_EER
        best_dev_epoch = -1
        best_eval_eer = DEFAULT_MAX_EER
        best_eval_epoch = -1
        
        print('start training loops. #epochs = ' + str(config['num_epochs']))
        print(f"{'Epoch':^7} | {'Train Loss':^12} | {'Train EER':^11} | {'Dev EER':^10} | {'Eval EER':^9} | {'Elapsed':^9}")
        print("-"*50)  
        
        min_loss = 100
        num_no_imp = 0
        for i in range(config['num_epochs']):
            epoch = i + 1
            epoch_start_time = time.time()
            total_loss = 0
            num_batches = 0
            
            pending_model.train()
            for signals, labels in train_loader:
                signals = signals.to(self.active_device)
                labels = labels.to(self.active_device)
                
                _, logits = pending_model(signals)
                
                optimizer.zero_grad()
                loss = criterion(logits, labels)
                total_loss += loss.item()
                num_batches += 1
                loss.backward()
                optimizer.step()

                if scheduler is not None:
                    scheduler.step()
                
            avg_loss = total_loss / num_batches
            epoch_time = time.time() - epoch_start_time
            
            # Validation test.
            dev_eer, _ = compute_eer(pending_model, dev_loader)
            train_eer, _ = compute_eer(pending_model, train_loader)
            eval_eer, _ = compute_eer(pending_model, eval_loader)
            print(f"{epoch:^7} | {avg_loss:^12.6f} | {train_eer:^9.2f} | {dev_eer:^9.2f} |  {eval_eer:^9.4f} | {epoch_time:^9.2f}")
                
            if avg_loss < min_loss:
                min_loss = avg_loss
                num_no_imp = 0
            else:
                num_no_imp += 1
                
            if num_no_imp > config["early_stop_max_no_imp"]:
                print('early stop exit')
                break
            
            if epoch < config["min_valid_epochs"]:
                continue
            
            if dev_eer < best_dev_eer:
                best_dev_eer = dev_eer
                best_dev_epoch = epoch
                optimal_model = copy.deepcopy(pending_model)

            if eval_eer < best_eval_eer:
                best_eval_eer = eval_eer
                best_eval_epoch = epoch
        
        print('AASIST trainer - end\n')
        print("Best Dev EER = {:.2f}".format(best_dev_eer) + ", best epoch = " + str(best_dev_epoch))
        print("Best Eval Acc = {:.2f}".format(best_eval_eer) + ", best epoch = " + str(best_eval_epoch))
        return pending_model, optimal_model, best_dev_epoch


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
torchcontrib installed!


In [71]:
trainer = Trainer()
last_epoch_model, best_dev_eer_model, best_dev_eer_epoch = trainer.train(config)

AASIST trainer - start
Load samples
Train Set size = 20288
Dev Set size = 9922
Eval Set size = 9796
set optimizer & loss
start training loops. #epochs = 15
 Epoch  |  Train Loss  |  Train EER  |  Dev EER   | Eval EER  |  Elapsed 
--------------------------------------------------
   1    |   0.693075   |   15.98   |   16.70   |   14.4073  |  765.73  
   2    |   0.424156   |   6.45    |   8.48    |   10.1140  |  790.64  
   3    |   0.267818   |   4.61    |   6.90    |   8.5320   |  789.16  
   4    |   0.219517   |   3.17    |   4.96    |   5.8810   |  791.90  
   5    |   0.173083   |   2.40    |   4.45    |   4.6832   |  792.12  
   6    |   0.142940   |   2.15    |   5.34    |   4.5601   |  788.45  
   7    |   0.123636   |   1.62    |   4.85    |   4.2387   |  791.34  
   8    |   0.115699   |   1.62    |   3.33    |   3.4705   |  792.67  
   9    |   0.104050   |   1.76    |   3.96    |   3.6198   |  792.21  
  10    |   0.091661   |   1.79    |   2.75    |   3.7087   |  791.92  