SE block for SENet; basic block for ResNet

In [12]:
## ResNet Blocks and SE Blocks

import math

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the padding keeps the spatial size unchanged.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
                     
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel is global-average-pooled ("squeeze"). The pooled vector goes
    through a two-layer bottleneck MLP with a sigmoid output ("excitation").
    Every input channel is then rescaled by its gate value.

    Args:
        channel: number of input channels.
        reduction: bottleneck ratio; the hidden layer has
            ``channel // reduction`` units.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        # squeeze: (B, C, H, W) -> (B, C, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # the 4-way unpacking requires a rank-4 input (B, C, H, W)
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)

class BasicBlock(nn.Module):
    """Two-convolution residual block used by the plain ResNet variants.

    Main path: conv3x3 (with ``stride``) -> BN -> ReLU -> conv3x3 -> BN.
    The input, projected by ``downsample`` when one is given, is added to the
    main path, and the sum goes through a final ReLU.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before the residual
            addition (needed whenever the main path changes the shape).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class SEBasicBlock(nn.Module):
    """Residual basic block with a Squeeze-and-Excitation gate on the main path.

    Main path: conv3x3 (with ``stride``) -> BN -> ReLU -> conv3x3 -> BN -> SE.
    The input, projected by ``downsample`` when one is given, is added and the
    sum passes through a final ReLU.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before the addition.
        reduction: bottleneck ratio forwarded to ``SELayer``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.se(self.bn2(self.conv2(out)))

        shortcut = self.downsample(x) if self.downsample is not None else x
        out += shortcut
        return self.relu(out)


ECA-Net

In [2]:
# import torch
# from torch import nn
from torch.nn.parameter import Parameter

class eca_layer(nn.Module):
    """Efficient Channel Attention (ECA) gate.

    Each channel is global-average-pooled. A 1-D convolution of size
    ``k_size`` then runs across the channel axis, a sigmoid is applied, and
    the input is rescaled channel-wise.

    Args:
        channel: number of input channels. Accepted for API parity; the
            layer's parameters do not depend on it.
        k_size: kernel size of the 1-D convolution across channels.
    """

    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size,
                              padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # assumes x is (B, C, H, W) -- TODO confirm for all callers
        pooled = self.avg_pool(x)  # (B, C, 1, 1)
        # channels become the 1-D "sequence" axis: (B, 1, C)
        as_sequence = pooled.squeeze(-1).transpose(-1, -2)
        mixed = self.conv(as_sequence)  # local cross-channel interaction
        # back to (B, C, 1, 1) and squash into (0, 1)
        gate = self.sigmoid(mixed.transpose(-1, -2).unsqueeze(-1))
        return x * gate.expand_as(x)


In [4]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3, padding-1, bias-free ``nn.Conv2d`` (re-declared for this cell)."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

In [4]:
class ECABasicBlock(nn.Module):
    """Residual basic block with an ECA channel gate on the main path.

    Main path: conv3x3 (with ``stride``) -> BN -> ReLU -> conv3x3 -> BN -> ECA.
    The input, projected by ``downsample`` when one is given, is added and the
    sum passes through a final ReLU.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before the addition.
        k_size: kernel size of the ECA 1-D convolution.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, k_size=3):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.eca = eca_layer(planes, k_size)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.eca(self.bn2(self.conv2(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


Models: ECA-Net, SENet, ResNet

In [14]:
import math
# import torch
# import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

class ResNet(nn.Module):
    """Small ResNet for spectrogram input, with auxiliary (deep-supervision) heads.

    Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool. Four stages of ``block``
    follow, with 16, 32, 64 and 128 base channels; stages 2-4 are built with
    stride 2. After each of stages 1-3, a global-average-pooled linear head
    produces auxiliary logits. The pooled output of stage 4 feeds the main
    ``classifier``.

    Args:
        block: residual block class exposing an ``expansion`` attribute and the
            constructor ``block(inplanes, planes, stride=1, downsample=None)``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the three auxiliary heads and of the
            classifier. This was previously ignored (every head was fixed at 2).
        KaimingInit: if truthy, re-initialise every ``nn.Conv2d`` weight with
            Kaiming-normal (fan_out, ReLU).
        in_channels: channels of the input tensor (default 1).
    """

    def __init__(self, block, layers, num_classes, KaimingInit=False, in_channels=1):
        super().__init__()
        # running channel count consumed/updated by _make_layer
        self.inplanes = 16

        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 128, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # auxiliary heads on the pooled outputs of stages 1-3
        self.middle_fc1 = nn.Linear(16 * block.expansion, num_classes)
        self.middle_fc2 = nn.Linear(32 * block.expansion, num_classes)
        self.middle_fc3 = nn.Linear(64 * block.expansion, num_classes)
        self.classifier = nn.Linear(128 * block.expansion, num_classes)

        if KaimingInit:
            print('Using Kaiming Initialization.')
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        # explicit BatchNorm init (weight=1, bias=0)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        The first block receives ``stride``. When the stride or the channel count
        changes, it also receives a 1x1-conv + BN ``downsample`` projection for
        its shortcut. ``self.inplanes`` is updated to this stage's output channels.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _pool(self, x):
        """Global-average-pool a (B, C, H, W) map to a (B, C) vector."""
        return self.avgpool(x).view(x.size(0), -1)

    def forward(self, x):
        """Run the network.

        Returns:
            Tuple ``(middle_output1, middle_output2, middle_output3, out, x)``:
            auxiliary logits after stages 1-3, the final logits, and the pooled
            stage-4 feature vector of size ``128 * block.expansion``.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        middle_output1 = self.middle_fc1(self._pool(x))
        x = self.layer2(x)
        middle_output2 = self.middle_fc2(self._pool(x))
        x = self.layer3(x)
        middle_output3 = self.middle_fc3(self._pool(x))
        x = self.layer4(x)

        x = self._pool(x)
        out = self.classifier(x)
        return middle_output1, middle_output2, middle_output3, out, x

def eca_resnet18():
    """ResNet-18 layout ([2, 2, 2, 2]) built from ECA basic blocks, with 2 output classes."""
    return ResNet(block=ECABasicBlock, layers=[2, 2, 2, 2], num_classes=2)


def eca_resnet34():
    """ResNet-34 layout ([3, 4, 6, 3]) built from ECA basic blocks, with 2 output classes."""
    return ResNet(block=ECABasicBlock, layers=[3, 4, 6, 3], num_classes=2)

def se_resnet18():
    """ResNet-18 layout ([2, 2, 2, 2]) built from SE basic blocks, with 2 output classes."""
    # fixed: the original call had a stray double comma (`SEBasicBlock,,`), a SyntaxError
    model = ResNet(SEBasicBlock, [2, 2, 2, 2], 2)
    return model

def se_resnet34():
    """ResNet-34 layout ([3, 4, 6, 3]) built from SE basic blocks, with 2 output classes."""
    # fixed: the original call had a stray double comma (`SEBasicBlock,,`), a SyntaxError
    model = ResNet(SEBasicBlock, [3, 4, 6, 3], 2)
    return model

def resnet18():
    """Plain ResNet-18 layout ([2, 2, 2, 2]) with basic blocks and 2 output classes."""
    return ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=2)

def resnet34():
    """Plain ResNet-34 layout ([3, 4, 6, 3]) with basic blocks and 2 output classes."""
    return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=2)


In [None]:
# Install torchsummary (used by the next cell to print a per-layer model summary).
!pip install torchsummary

In [None]:
from torchsummary import summary

# Print a per-layer summary of ECA-ResNet-18 for a single-channel 128x128 input.
# torchsummary's `device` argument must match where the model lives; fall back
# to CPU instead of crashing on `.cuda()` when no GPU is available.
model = eca_resnet18()
_summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(model.to(_summary_device), (1, 128, 128), device=_summary_device)

In [None]:
torch.manual_seed(0)


def checksum(model):
    """Return the sum of the absolute values of every parameter of *model*.

    The sum is accumulated in float64. The result is a 0-d tensor that
    fingerprints the current weights; it is used below to check that seeding
    makes model construction reproducible.
    """
    per_param = [param.double().abs().sum() for param in model.parameters()]
    return torch.stack(per_param).sum()

# Reproducibility check: the generator is re-seeded before each construction,
# so all ten printed checksums are expected to be identical.
for _ in range(10):
    torch.manual_seed(0)
    model = eca_resnet18()
#     dd=domain_disc()
    s = checksum(model)
#     s1=checksum(dd)
    print(s)


In [None]:
# Pick the compute device once: GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

The metric functions defined below are:
    -> t-SNE diagram for feature visualization
    -> DET curve
    -> Equal error rate (EER) computation
    -> Normalized mutual information (NMI) calculation

In [None]:
# Plot t-SNE embeddings
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import sklearn.metrics
from sklearn.metrics import det_curve, DetCurveDisplay
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score

#t-SNE diagram for visualizing impact on feature space
def tsned(fet_a,label_a, domain_a):
    """Plot a 2-D t-SNE embedding of features, coloured by domain and marked by class.

    Args:
        fet_a: list of per-batch feature tensors, concatenated along dim 0.
        label_a: list of per-batch class-label tensors. Class 0 is drawn as
            "Genuine" (circle); class 1 as "Spoof" (triangle with black edge).
        domain_a: list of per-batch domain-id tensors; each domain gets its own colour.
    """
    features_np = np.concatenate([t.detach().cpu().numpy() for t in fet_a], axis=0)
    domains_np = np.concatenate([t.detach().cpu().numpy() for t in domain_a], axis=0)
    targets_np = np.concatenate([t.detach().cpu().numpy() for t in label_a], axis=0)
    # One row-aligned table; column_stack also fails loudly if the row counts differ.
    data = np.column_stack((features_np, domains_np, targets_np))
    features = data[:, :-2]
    domains = data[:, -2].astype(int)
    spoof_or_genuine = data[:, -1].astype(int)

    embeddings = TSNE(n_components=2, random_state=42).fit_transform(features)
    plt.figure(figsize=(8, 6))
    # tab10 has only 10 colours. Wrap around so domains past the tenth still get
    # a colour; zipping against the palette used to drop them, causing a KeyError below.
    palette = plt.cm.tab10.colors
    domain_colors = {
        domain: palette[idx % len(palette)]
        for idx, domain in enumerate(np.unique(domains))
    }
    markers = {0: 'o', 1: '^'}

    for domain in np.unique(domains):
        domain_mask = domains == domain
        color = domain_colors[domain]
        for spoof in np.unique(spoof_or_genuine):
            mask = domain_mask & (spoof_or_genuine == spoof)
            marker = markers[spoof]
            edgecolor = 'k' if marker == '^' else None  # edge only on triangles
            plt.scatter(embeddings[mask, 0], embeddings[mask, 1],
                        label=f"Domain {domain}, {'Spoof' if spoof == 1 else 'Genuine'}",
                        color=color, marker=marker, edgecolors=edgecolor)

    plt.title('t-SNE plot with domains and spoof/genuine')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend()
    plt.show()

#Equal Error Rate computation 
def compute_eer(label, pred, positive_label=1):
    """Compute the equal error rate (EER) of a set of scores.

    Args:
        label: ground-truth labels.
        pred: scores, where higher means more likely ``positive_label``.
        positive_label: label treated as the positive class. It is now
            forwarded to ``roc_curve``; previously it was silently ignored.

    Returns:
        The mean of FPR and FNR at the ROC operating point where they are closest.
    """
    fpr, tpr, _ = sklearn.metrics.roc_curve(label, pred, pos_label=positive_label)
    fnr = 1 - tpr

    # operating point where FNR and FPR are closest (computed once)
    idx = np.nanargmin(np.absolute(fnr - fpr))

    # The two rates agree in theory but can differ slightly on finite data; average them.
    return (fpr[idx] + fnr[idx]) / 2

#DET Curve 
def plot_det(y_true,y_pred):
    """Draw and show a Detection Error Tradeoff (DET) curve.

    Args:
        y_true: ground-truth binary labels.
        y_pred: scores for the positive class.
    """
    false_pos, false_neg, _thresholds = det_curve(y_true, y_pred)
    DetCurveDisplay(fpr=false_pos, fnr=false_neg).plot()
    plt.title('DET Curve')
    plt.legend()
    plt.show()

# Normalised mutual information (NMI) score, used to gauge the effect of the
# domain-generalisation methods on the feature space.
# Here `domain_a` carries the triplet labels (not the domain labels).
def cluster_triplet(fet_a,label_a, domain_a):
    """Cluster the t-SNE embedding of the features and score it against ``domain_a``.

    Features are embedded in 2-D with t-SNE (random_state=42). The embedding
    is split into two clusters with k-means (random_state=0), and the
    assignment is compared with the integer labels in ``domain_a`` via NMI.

    Args:
        fet_a: list of per-batch feature tensors.
        label_a: list of per-batch class-label tensors (concatenated, not scored).
        domain_a: list of per-batch label tensors to score the clusters against.

    Returns:
        The NMI between ``domain_a`` and the k-means cluster assignment.
    """
    def to_numpy(batches):
        return np.concatenate([b.detach().cpu().numpy() for b in batches], axis=0)

    feats_np = to_numpy(fet_a)
    groups_np = to_numpy(domain_a)
    classes_np = to_numpy(label_a)
    table = np.column_stack((feats_np, groups_np, classes_np))

    embedding = TSNE(n_components=2, random_state=42).fit_transform(table[:, :-2])
    clusters = KMeans(n_clusters=2, random_state=0, n_init="auto").fit_predict(embedding)
    return normalized_mutual_info_score(table[:, -2].astype(int), clusters)
    
    
    
    

TRAIN DATASET CLASS
Returns the mel-spectrogram feature together with the corresponding fake-speech label, domain label and triplet label (used for triplet mining).
annot : .npy file containing the labels
audio_dir : .npy file containing the features
domain_path : .npy file containing the domain labels
(the .npy file containing the triplet labels is loaded inside the class itself)


In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import pandas as pd
import torchaudio


class UrbanSoundDataset(Dataset):
    """Training dataset of pre-computed mel-spectrogram features.

    Each item is ``(feature, label, domain, triplet_label)``:

    * ``feature``: float32 tensor for the row, with a channel axis inserted
      at position 1 of the stacked array (``np.expand_dims(..., axis=1)``).
    * ``label``: int64 tensor.
    * ``domain``, ``triplet_label``: tensors built from the matching rows.

    Args:
        annot: path to a ``.npy`` file with per-sample labels.
        domain_path: path to a ``.npy`` file with per-sample domain labels.
        audio_dir: path to a ``.npy`` file with the stacked feature array
            (despite the name, this is a file, not a directory).
        device: stored on the instance; items are NOT moved to it here.
        triplet_path: path to a ``.npy`` file with per-sample triplet labels.
            Defaults to the Kaggle path that used to be hard-coded.
    """

    def __init__(self,annot,domain_path,
                 audio_dir,
                 device,
                 triplet_path="/kaggle/input/melspec/MEL_T.npy"):
        self.annotations = np.load(annot)
        self.domain_path = np.load(domain_path)
        self.triplet = np.load(triplet_path)
        self.x = np.load(audio_dir)
        # (N, ...) -> (N, 1, ...): add a channel axis for Conv2d input
        self.files = np.expand_dims(self.x, axis=1)
        # NOTE(review): torch.tensor copies, so the numpy array in self.x and this
        # float32 tensor are both kept in memory.
        self.audio_dir = torch.tensor(self.files, dtype=torch.float32)
        self.device = device

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        label = torch.tensor(self._get_audio_sample_label(index), dtype=torch.int64)
        feat = self._get_audio_sample_feature(index)
        domain = torch.tensor(self._get_audio_sample_domain(index))
        triplet_label = torch.tensor(self._get_audio_sample_triplet(index))
        return feat, label, domain, triplet_label

    def label_to_numeric(self,label):
        """Map 'genuine' -> 0 and 'spoof' -> 1; raise ValueError for anything else."""
        if label == 'genuine':
            return 0
        elif label == 'spoof':
            return 1
        else:
            raise ValueError("Invalid label")

    def _get_audio_sample_label(self, index):
        return self.annotations[index]

    def _get_audio_sample_domain(self, index):
        return self.domain_path[index]

    def _get_audio_sample_triplet(self, index):
        return self.triplet[index]

    def _get_audio_sample_feature(self, index):
        return self.audio_dir[index]

if __name__ == "__main__":
    # Training arrays on Kaggle: labels, stacked features, domain labels.
    ANNOTATIONS_FILE = "/kaggle/input/melspec/MEL_y.npy"
    AUDIO_DIR = "/kaggle/input/melspec/MEL_X.npy"
    DOMAIN_PATH = "/kaggle/input/melspec/MEL_D.npy"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device {device}")

TEST DATASET CLASS
Returns the mel-spectrogram features and the corresponding fake-speech labels.
annot : .npy file with the fake-speech labels
x_path : .npy file with the features


In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import pandas as pd
import torchaudio

class Test(Dataset):
    """Evaluation dataset of pre-computed mel-spectrogram features.

    Each item is ``(feature, label)``: a float32 feature tensor with a channel
    axis inserted at position 1 of the stacked array, and a tensor built from
    the matching label row.

    Args:
        annot: path to a ``.npy`` file with the labels.
        audio_dir: accepted for call-site compatibility but NOT used; features
            come from ``x_path``.
        x_path: path to a ``.npy`` file with the stacked feature array.
        device: stored on the instance; items are not moved to it.
    """

    def __init__(self,annot,
                 audio_dir,x_path,
                 device):
        self.annotations = np.load(annot)
        self.x = np.load(x_path)
        # (N, ...) -> (N, 1, ...): add a channel axis for Conv2d input
        self.files = np.expand_dims(self.x, axis=1)
        self.audio_dir = torch.tensor(self.files, dtype=torch.float32)
        self.device = device

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        label = torch.tensor(self._get_audio_sample_label(index))
        return self._get_audio_sample_feature(index), label

    def label_to_numeric(self,label):
        """Map 'genuine' -> 0 and 'spoof' -> 1; raise ValueError otherwise."""
        if label == 'genuine':
            return 0
        if label == 'spoof':
            return 1
        raise ValueError("Invalid label")

    def _get_audio_sample_label(self, index):
        return self.annotations[index]

    def _get_audio_sample_feature(self, index):
        return self.audio_dir[index]

if __name__ == "__main__":
    # Training-feature path; passed positionally as `audio_dir` to the Test datasets.
    audio_dir = "/kaggle/input/melspec/MEL_X.npy"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device {device}")

Initializing test datasets

In [22]:
# Evaluation sets named pa / la / itw after their files; `audio_dir` and `device`
# come from the previous cells.
test_pa = Test(
    annot="/kaggle/input/melspec/pa_y.npy",
    audio_dir=audio_dir,
    x_path="/kaggle/input/melspec/PA_EVAL_MEL_STACK_X.npy",
    device=device,
)
test_la = Test(
    annot="/kaggle/input/melspec/la_y.npy",
    audio_dir=audio_dir,
    x_path="/kaggle/input/melspec/LA_EVAL_MEL_STACK_X.npy",
    device=device,
)
test_itw = Test(
    annot="/kaggle/input/melspec/itw_y.npy",
    audio_dir=audio_dir,
    x_path="/kaggle/input/melspec/ITWW_MEL_STACK_X.npy",
    device=device,
)

TESTING
Outputs accuracy, equal error rate (EER), confusion matrix and DET curve.

In [None]:
from torch.utils.data import DataLoader

def create_test_data_loader(test_data, batch_size):
    """Wrap *test_data* in an unshuffled ``DataLoader`` of the given batch size."""
    return DataLoader(test_data, batch_size=batch_size)
softmax = nn.Softmax(dim=1)
BATCH_SIZE = 128
if __name__ == "__main__":
    # confusion_matrix and seaborn are otherwise first imported in a later cell;
    # import them here so this cell also works when the notebook runs top to bottom.
    from sklearn.metrics import confusion_matrix
    import seaborn as sn

    # Evaluate `model_ev` on the LA evaluation set.
    # NOTE(review): `model_ev` is not defined in the cells above; it must be
    # created (e.g. loaded from a saved state dict) before this cell runs.
    test_dataloader = create_test_data_loader(test_la, BATCH_SIZE)
    model_ev.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v=[]
    y_true_v=[]
    y_one_v=[]  # class-1 probabilities, used for EER and the DET curve
    val_ax=[]
    with torch.no_grad():
        for inp, target in test_dataloader:
            inp, target = inp.to(device), target.to(device)
            middle_output1,middle_output2,middle_output3,prediction,features = model_ev(inp)
            probabilities = softmax(prediction)
            label_1_probabilities = probabilities[:, 1]
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.data.cpu().numpy())
            y_true_v.extend(target.data.cpu().numpy())
            y_one_v.extend(label_1_probabilities.data.cpu().numpy())
            val_steps += 1
    test_accuracy = correct_val / total_val
    print("test  Accuracy = {}".format(test_accuracy))
    val_ax.append(test_accuracy)
    eer=compute_eer(y_true_v,y_one_v)
    print("test eer = {}".format(eer))
    classes = ('genuine','spoof')
    cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
    # row-normalised confusion matrix (each true-class row sums to 1)
    df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None], index = [i for i in classes],
                     columns = [i for i in classes])
    plt.figure(figsize = (12,7))
    sn.heatmap(df_cm2, annot=True)
    plot_det(y_true_v,y_one_v)
    la2=y_pred_v
    



COHEN'S KAPPA SCORE (agreement between the predictions of the base model and the proposed model)

In [None]:
from sklearn.metrics import cohen_kappa_score
# Cohen's kappa between two prediction lists for each evaluation set (la, pa, itw).
# NOTE(review): sd_la/sd_pa/sd_itw and la/pa/itw are not assigned in the cells
# above (the LA evaluation cell assigns `la2`). Presumably they hold predictions
# from the self-distillation and base models; confirm before running.
la_cohe=cohen_kappa_score(sd_la, la)
pa_cohe=cohen_kappa_score(sd_pa, pa)
itw_cohe=cohen_kappa_score(sd_itw, itw)
print("la ",la_cohe)
print("pa ",pa_cohe)
print("itwa ",itw_cohe)

HYPERPARAMETER TUNING WITH OPTUNA
Visualizes the training/validation loss curves, the training and validation confusion matrices, and a t-SNE diagram of the validation features.

In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Shared configuration for this cell.
BATCH_SIZE = 128  # samples per mini-batch
EPOCHS = 20  # NOTE(review): EPOCHS does not appear to be read anywhere in the visible code
# LEARNING_RATE = 0.001
softmax = nn.Softmax(dim=1)  # logits -> class probabilities along dim 1

def create_data_loader(train_data, batch_size, generator=None):
    """Return a shuffling ``DataLoader`` over *train_data*.

    Args:
        train_data: dataset (or ``Subset``) to iterate.
        batch_size: samples per batch.
        generator: optional ``torch.Generator`` that controls the shuffle order.
            When ``None``, the global torch RNG is used (reproducible after
            ``torch.manual_seed``).
    """
    # The original referenced a module-level `g` that is never defined in this
    # notebook, which raised NameError; the generator is now an explicit argument.
    return DataLoader(train_data, batch_size=batch_size, shuffle=True, generator=generator)

def kd_loss_function(output, target_output,temperature):
    """Soft-target knowledge-distillation loss.

    Cross-entropy between the target distribution ``target_output`` and the
    temperature-scaled log-softmax of ``output``, averaged over the batch:
    ``-mean_b sum_c target[b, c] * log_softmax(output / T)[b, c]``.

    Args:
        output: logits with classes along dim 1.
        target_output: target probabilities, same shape as ``output``.
        temperature: softening temperature ``T`` applied to ``output``.
    """
    student_log_probs = torch.log_softmax(output / temperature, dim=1)
    per_sample = torch.sum(student_log_probs * target_output, dim=1)
    return -torch.mean(per_sample)


def train_single_epoch(trial,model, data_loader,val_loader, loss_fn,optimiser, device,i,alpha,train_losses,val_losses,temperature):
    """Train *model* for one epoch, evaluate it, and report to Optuna.

    Args:
        trial: Optuna trial; receives ``report(val_accuracy, i)`` and may prune.
        model: network returning ``(mid1, mid2, mid3, logits, features)``.
        data_loader: training batches of ``(input, target, domain, triplet)``.
        val_loader: validation batches with the same layout.
        loss_fn: criterion applied to ``(logits, target)``.
        optimiser: optimiser that steps ``model``'s parameters.
        device: device that inputs and targets are moved to.
        i: epoch index, used as the Optuna step.
        alpha, temperature: accepted but not used by this plain cross-entropy epoch.
        train_losses, val_losses: lists that receive this epoch's mean losses.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a,
        domain_a)``. ``fet_a``, ``label_a`` and ``domain_a`` are the per-batch
        validation features, targets and domains (for t-SNE plotting).

    Raises:
        ValueError: if either loader yields no batches.
        optuna.exceptions.TrialPruned: if the trial's pruner asks to stop.
    """
    # ---- training pass ----
    correct = 0
    total = 0
    y_pred_tr = []
    y_true_tr = []
    total_train_loss = 0.0
    total_steps = 0
    model.train()
    for inp, target, _domain, _triplet in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        # only the final logits are needed for plain cross-entropy training
        _, _, _, prediction, _ = model(inp)
        probabilities = torch.softmax(prediction.detach(), dim=1)
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.cpu().numpy())
        y_true_tr.extend(target.cpu().numpy())
        total_loss = loss_fn(prediction, target)
        total_train_loss += total_loss.item()
        total_steps += 1
        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()
    if total_steps == 0:
        raise ValueError("data_loader yielded no batches")
    mean_train_loss = total_train_loss / total_steps
    train_losses.append(mean_train_loss)
    # print the epoch mean (the value stored in train_losses), not the last batch's loss
    print(f"train loss: {mean_train_loss}")
    accuracy = correct / total
    print("train   Accuracy = {}".format(accuracy))

    # ---- validation pass ----
    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    fet_a = []
    label_a = []
    domain_a = []
    model.eval()
    with torch.no_grad():
        for inp, target, domain, _triplet in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            _, _, _, prediction, features = model(inp)
            # Kept per batch for the t-SNE plot. `domain` stays on the loader's
            # device; only features and targets live on `device`.
            fet_a.append(features)
            label_a.append(target)
            domain_a.append(domain)
            probabilities = torch.softmax(prediction, dim=1)
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.cpu().numpy())
            y_true_v.extend(target.cpu().numpy())
            val_loss += loss_fn(prediction, target).item()
            val_steps += 1
    if val_steps == 0:
        raise ValueError("val_loader yielded no batches")
    val_accuracy = correct_val / total_val
    mean_val_loss = val_loss / val_steps
    print(f"val loss: {mean_val_loss}")
    val_losses.append(mean_val_loss)
    print("val  Accuracy = {}".format(val_accuracy))
    trial.report(val_accuracy, i)
    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()
    return val_accuracy,y_pred_tr,y_true_tr,y_pred_v,y_true_v,fet_a,label_a,domain_a

def train(trial):
    """Optuna objective: train ECA-ResNet-18 with a sampled learning rate.

    ``lr`` is sampled on a log scale from [0.01, 0.03]. The model is trained
    for 20 epochs with SGD + cross-entropy on an 80/20 train/validation split.
    Loss curves, confusion matrices and a t-SNE of the validation features
    are plotted afterwards.

    Returns:
        The final-epoch validation accuracy.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}")
    epochs = 20
    torch.manual_seed(0)
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR, device)
    test_abs = int(len(usd) * 0.8)
    # The split previously used an undefined global `g` (NameError). Use an
    # explicit seeded generator (seed 500, as in the base-model training cell).
    split_generator = torch.Generator().manual_seed(500)
    train_subset, val_subset = random_split(
        usd, [test_abs, len(usd) - test_abs], generator=split_generator)
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)
    loss_fn = torch.nn.CrossEntropyLoss()
    model = eca_resnet18().to(device)
    lr = trial.suggest_float("lr", 0.01, 0.03, log=True)
    # Required by train_single_epoch's signature but unused by its plain
    # cross-entropy loss. These were only present as comments, which raised NameError.
    alpha = 0.86
    temperature = 3
    optimiser = torch.optim.SGD(model.parameters(), lr=lr)
    train_losses = []
    val_losses = []
    for i in range(epochs):
        print(f"Epoch {i+1}")
        (val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v,
         fet_a, label_a, domain_a) = train_single_epoch(
            trial, model, train_dataloader, val_loader, loss_fn, optimiser,
            device, i, alpha, train_losses, val_losses, temperature)
        print("---------------------------")
    print("Finished training")

    epochs_plot = range(1, len(train_losses) + 1)
    plt.figure()
    plt.plot(epochs_plot, train_losses, label='Training Loss')
    plt.plot(epochs_plot, val_losses, label='Validation Loss')
    plt.legend()

    classes = ('genuine','spoof')
    cf_matrix1 = confusion_matrix(y_true_tr, y_pred_tr)
    cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
    # row-normalised: each true-class row sums to 1
    df_cm1 = pd.DataFrame(cf_matrix1 / np.sum(cf_matrix1, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    plt.figure(figsize = (12,7))
    sn.heatmap(df_cm1, annot=True)
    plt.figure(figsize = (12,7))
    sn.heatmap(df_cm2, annot=True)
    tsned(fet_a,label_a,domain_a)
    return val_accuracy


if __name__ == "__main__":

    # Maximise validation accuracy over 10 Optuna trials (train() is the objective).
    study = optuna.create_study(direction="maximize")
    study.optimize(train, n_trials=10)
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    # Previously only the header was printed because the loop was commented out.
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))
    # NOTE: train() returns only the validation accuracy, not the model, so
    # there is no model in scope to save here.

Base model (without any part of the proposed framework)
Learning rate (variable `lr`): set it to the best learning rate found by hyperparameter tuning.
Outputs: accuracy, loss and EER over the epochs; training and validation loss and accuracy; DET curves and confusion matrices;
and a t-SNE diagram of the validation set.


In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Shared configuration for the base-model cell.
BATCH_SIZE = 128  # samples per mini-batch
EPOCHS = 20  # NOTE(review): EPOCHS does not appear to be read anywhere in the visible code
# LEARNING_RATE = 0.001
softmax = nn.Softmax(dim=1)  # logits -> class probabilities along dim 1

def create_data_loader(train_data, batch_size):
    """Return a ``DataLoader`` over *train_data* that reshuffles every epoch."""
    loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    return loader


def train_single_epoch(model, data_loader,val_loader, loss_fn,optimiser, device,i,alpha,train_losses,val_losses,temperature,train_ax,val_ax):
    """Train the base model for one epoch, then evaluate it on the validation set.

    Args:
        model: network returning ``(mid1, mid2, mid3, logits, features)``.
        data_loader: training batches of ``(input, target, domain, triplet)``.
        val_loader: validation batches with the same layout.
        loss_fn: criterion applied to ``(logits, target)``.
        optimiser: optimiser that steps ``model``'s parameters.
        device: device that inputs and targets are moved to.
        i, alpha, temperature: accepted but not used by this plain
            cross-entropy epoch.
        train_losses, val_losses: lists that receive this epoch's mean losses.
        train_ax, val_ax: lists that receive this epoch's train/validation accuracy.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a,
        domain_a, y_one_tr, y_one_v)``. ``y_one_*`` are class-1 probabilities,
        used for EER and DET curves.

    Raises:
        ValueError: if either loader yields no batches.
    """
    # ---- training pass ----
    correct = 0
    total = 0
    y_pred_tr = []
    y_true_tr = []
    y_one_tr = []
    total_steps = 0
    total_train_loss = 0.0
    model.train()
    for inp, target, _domain, _triplet in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        _, _, _, prediction, _ = model(inp)
        probabilities = torch.softmax(prediction.detach(), dim=1)
        label_1_probabilities = probabilities[:, 1]
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.cpu().numpy())
        y_one_tr.extend(label_1_probabilities.cpu().numpy())
        y_true_tr.extend(target.cpu().numpy())
        loss = loss_fn(prediction, target)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        total_train_loss += loss.item()
        total_steps += 1
    if total_steps == 0:
        raise ValueError("data_loader yielded no batches")

    train_losses.append(total_train_loss / total_steps)
    accuracy = correct / total
    train_ax.append(accuracy)
    eer = compute_eer(y_true_tr, y_one_tr)
    print("train eer = {}".format(eer))

    # ---- validation pass ----
    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    fet_a = []
    label_a = []
    domain_a = []
    y_one_v = []
    model.eval()
    with torch.no_grad():
        for inp, target, domain, _triplet in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            _, _, _, prediction, features = model(inp)
            # kept per batch for the t-SNE plot; `domain` stays on the loader's device
            fet_a.append(features)
            label_a.append(target)
            domain_a.append(domain)
            probabilities = torch.softmax(prediction, dim=1)
            label_1_probabilities = probabilities[:, 1]
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.cpu().numpy())
            y_true_v.extend(target.cpu().numpy())
            y_one_v.extend(label_1_probabilities.cpu().numpy())
            val_loss += loss_fn(prediction, target).item()
            val_steps += 1
    if val_steps == 0:
        raise ValueError("val_loader yielded no batches")
    val_accuracy = correct_val / total_val
    val_losses.append(val_loss / val_steps)
    print("val  Accuracy = {}".format(val_accuracy))
    val_ax.append(val_accuracy)
    eer = compute_eer(y_true_v, y_one_v)
    # this is the validation EER (it was previously printed as "train eer")
    print("val eer = {}".format(eer))
    return val_accuracy,y_pred_tr,y_true_tr,y_pred_v,y_true_v,fet_a,label_a,domain_a,y_one_tr,y_one_v

def train():
    """Train the base ECA-ResNet-18 (plain cross-entropy) and plot diagnostics.

    Uses an 80/20 train/validation split (split seed 500), SGD with the
    notebook-level learning rate ``lr``, and 20 epochs. Afterwards it plots
    loss and accuracy curves, train/validation confusion matrices, a t-SNE of
    the validation features, and train/validation DET curves.

    Returns:
        ``(y_pred_v, model)``: the final-epoch validation predictions and the
        trained model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}")
    epochs = 20
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR, device)
    test_abs = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(
        usd, [test_abs, len(usd) - test_abs],
        generator=torch.Generator().manual_seed(500))
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)
    loss_fn = torch.nn.CrossEntropyLoss()
    torch.manual_seed(0)
    model = eca_resnet18().to(device)
    # NOTE: `lr` is read from the notebook's global scope. Set it to the best
    # learning rate found by the Optuna study before calling train().
    optimiser = torch.optim.SGD(model.parameters(), lr=lr)
    temperature = 3
    # `alpha` is not used by the base-model epoch; pass an explicit placeholder
    # instead of relying on a global that the cells above never define.
    alpha = None
    train_losses = []
    val_losses = []
    train_ax = []
    val_ax = []
    for i in range(epochs):
        print(f"Epoch {i+1}")
        (val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a,
         domain_a, y_one_tr, y_one_v) = train_single_epoch(
            model, train_dataloader, val_loader, loss_fn, optimiser, device, i,
            alpha, train_losses, val_losses, temperature, train_ax, val_ax)
        print("---------------------------")
    print("Finished training")

    epochs_plot = range(1, len(train_losses) + 1)
    plt.figure()
    plt.plot(epochs_plot, train_losses, label='Training Loss')
    plt.plot(epochs_plot, val_losses, label='Validation Loss')
    plt.legend()

    classes = ('genuine','spoof')
    cf_matrix1 = confusion_matrix(y_true_tr, y_pred_tr)
    cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
    # row-normalised: each true-class row sums to 1
    df_cm1 = pd.DataFrame(cf_matrix1 / np.sum(cf_matrix1, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    plt.figure(figsize = (12,7))
    sn.heatmap(df_cm1, annot=True)
    plt.figure(figsize = (12,7))
    sn.heatmap(df_cm2, annot=True)
    tsned(fet_a,label_a,domain_a)
    plot_det(y_true_tr,y_one_tr)
    plot_det(y_true_v,y_one_v)

    # accuracy curves on their own figure (previously drawn on whatever figure was current)
    plt.figure()
    plt.plot(epochs_plot, train_ax, label='Training Accuracy')
    plt.plot(epochs_plot, val_ax, label='Validation Accuracy')
    plt.legend()
    return y_pred_v,model


if __name__ == "__main__":

    # Train the base model and save its weights to base_model.pth.
    # y_pred holds the final-epoch validation predictions returned by train().
    y_pred,model2=train()
    torch.save(model2.state_dict(), "base_model.pth")

**with self distillation**

In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Training settings shared by this cell.
BATCH_SIZE, EPOCHS = 128, 20
# The learning rate is not set here: train() reads a global `lr`.
softmax = nn.Softmax(dim=1)  # row-wise probabilities from (batch, classes) logits

def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a ``DataLoader`` that reshuffles on every pass."""
    return DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

def kd_loss_function(output, target_output,temperature):
    """Soft-target cross-entropy of temperature-scaled logits against a teacher distribution.

    Args:
        output: student logits, softmaxed along dim 1.
        target_output: teacher probabilities with the same shape as ``output``.
        temperature: divisor applied to ``output`` before the log-softmax.

    Returns:
        Scalar: ``-mean_over_rows(sum_over_dim1(log_softmax(output / T) * target))``.
    """
    student_log_probs = torch.log_softmax(output / temperature, dim=1)
    per_row = (student_log_probs * target_output).sum(dim=1)
    return -per_row.mean()


def train_single_epoch(model, data_loader,val_loader, loss_fn,optimiser, device,i,alpha,train_losses,val_losses,temperature,train_ax,val_ax):
    """Run one self-distillation training epoch, then one validation pass.

    ``model(inp)`` must return ``(middle_output1, middle_output2,
    middle_output3, prediction, features)``. The final logits ``prediction``,
    softened with ``temperature`` and detached, act as the teacher for the
    three intermediate heads. The optimised loss is
    ``(1 - alpha) * (kd1 + kd2 + kd3) + alpha * loss_fn(prediction, target)``,
    where each ``kd`` term is ``kd_loss_function(...) * temperature**2``.

    Args:
        model: network being trained.
        data_loader: training batches ``(inp, target, domain, tr)``.
        val_loader: validation batches with the same layout.
        loss_fn: hard-label loss on the final logits.
        optimiser: optimiser for ``model``'s parameters.
        device: device the batches are moved to.
        i: epoch index (not used in this function).
        alpha: weight of the hard-label loss against the distillation terms.
        train_losses, val_losses: the mean epoch loss is appended to each.
        temperature: softening temperature for the teacher distribution.
        train_ax, val_ax: the epoch accuracy is appended to each.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a,
        label_a, domain_a, y_one_tr, y_one_v)``. ``fet_a``, ``label_a`` and
        ``domain_a`` are per-batch validation features, targets and domains.
        ``y_one_*`` hold the softmax probability of class 1.
    """
    # ---- training ----
    correct = 0
    total = 0
    y_pred_tr = []
    y_true_tr = []
    y_one_tr = []
    total_steps = 0
    total_train_loss = 0
    model.train()
    for inp, target, domain, tr in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
        # Teacher distribution: final logits softened by the temperature.
        temp4 = torch.softmax(prediction / temperature, dim=1)
        probabilities = softmax(prediction.detach())
        label_1_probabilities = probabilities[:, 1]
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.data.cpu().numpy())
        y_one_tr.extend(label_1_probabilities.data.cpu().numpy())
        y_true_tr.extend(target.data.cpu().numpy())
        loss1by4 = kd_loss_function(middle_output1, temp4.detach(), temperature) * (temperature**2)
        loss2by4 = kd_loss_function(middle_output2, temp4.detach(), temperature) * (temperature**2)
        loss3by4 = kd_loss_function(middle_output3, temp4.detach(), temperature) * (temperature**2)
        loss = loss_fn(prediction, target)
        total_loss = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss
        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()
        total_train_loss += total_loss.detach().cpu().numpy()
        total_steps += 1
    tll = total_train_loss / total_steps
    train_losses.append(tll)
    print(f"train loss: {tll}")
    accuracyt = correct / total
    eer = compute_eer(y_true_tr, y_one_tr)
    print("train eer = {}".format(eer))
    print("train  Accuracy = {}".format(accuracyt))
    train_ax.append(accuracyt)

    # ---- validation ----
    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    fet_a = []
    label_a = []
    domain_a = []
    y_one_v = []
    model.eval()
    with torch.no_grad():
        for inp, target, domain, tr in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            # Keep domains on the same device as the features/targets collected below.
            domain = domain.to(device)
            middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
            fet_a.append(features)
            label_a.append(target)
            domain_a.append(domain)
            temp4 = torch.softmax(prediction / temperature, dim=1)
            probabilities = softmax(prediction)
            label_1_probabilities = probabilities[:, 1]
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.data.cpu().numpy())
            y_true_v.extend(target.data.cpu().numpy())
            y_one_v.extend(label_1_probabilities.data.cpu().numpy())
            loss1by4 = kd_loss_function(middle_output1, temp4, temperature) * (temperature**2)
            loss2by4 = kd_loss_function(middle_output2, temp4, temperature) * (temperature**2)
            loss3by4 = kd_loss_function(middle_output3, temp4, temperature) * (temperature**2)
            loss = loss_fn(prediction, target)
            total_loss_val = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss
            val_loss += total_loss_val.detach().cpu().numpy()
            val_steps += 1
    val_accuracy = correct_val / total_val
    vll = val_loss / val_steps
    print("val loss: {}".format(vll))
    val_losses.append(vll)
    print("val  Accuracy = {}".format(val_accuracy))
    val_ax.append(val_accuracy)
    eer = compute_eer(y_true_v, y_one_v)
    # This is the validation EER (it was previously labelled "train").
    print("val eer = {}".format(eer))
    return val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v

def train():
    """Train ``eca_resnet18`` with self-distillation and plot diagnostics.

    Uses an 80/20 train/validation split (fixed seed 500), SGD with the global
    learning rate ``lr``, ``alpha = 0.86`` and temperature 3, for ``epochs``
    epochs. After training it plots:
    - the loss curves;
    - row-normalised confusion matrices for train and validation;
    - ``tsned`` of the validation features;
    - ``plot_det`` for train and validation;
    - the accuracy curves.

    Returns:
        (y_pred_v, model): last-epoch validation predictions and the trained model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}")
    epochs = 20
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR, device)
    test_abs = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(usd, [test_abs, len(usd) - test_abs], generator=torch.Generator().manual_seed(500))
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)
    loss_fn = torch.nn.CrossEntropyLoss()
    torch.manual_seed(0)
    model = eca_resnet18().to(device)
    # `lr` is a global defined outside this cell.
    optimiser = torch.optim.SGD(model.parameters(), lr=lr)
    alpha = 0.86
    temperature = 3
    train_losses = []
    val_losses = []
    train_ax = []
    val_ax = []
    for i in range(epochs):
        print(f"Epoch {i+1}")
        val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v = train_single_epoch(model, train_dataloader, val_loader, loss_fn, optimiser, device, i, alpha, train_losses, val_losses, temperature, train_ax, val_ax)
        print("---------------------------")
    print("Finished training")

    # One x value per recorded epoch, so the curves always match `epochs`.
    epochs_plot = range(1, epochs + 1)
    plt.figure()
    plt.plot(epochs_plot, train_losses, label='Training Loss')
    plt.plot(epochs_plot, val_losses, label='Validation Loss')
    plt.legend()

    classes = ('genuine', 'spoof')
    cf_matrix1 = confusion_matrix(y_true_tr, y_pred_tr)
    cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
    # Normalise each row so the cells show per-true-class rates.
    df_cm1 = pd.DataFrame(cf_matrix1 / np.sum(cf_matrix1, axis=1)[:, None], index=list(classes),
                          columns=list(classes))
    df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None], index=list(classes),
                          columns=list(classes))
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm1, annot=True)
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm2, annot=True)
    tsned(fet_a, label_a, domain_a)
    plot_det(y_true_tr, y_one_tr)
    plot_det(y_true_v, y_one_v)

    # Fresh figure so the accuracy curves are not drawn over the last DET/t-SNE plot.
    plt.figure()
    plt.plot(epochs_plot, train_ax, label='Training Accuracy')
    plt.plot(epochs_plot, val_ax, label='Validation Accuracy')
    plt.legend()
    return y_pred_v, model


if __name__ == "__main__":
    # Train the self-distillation model, then persist its learned weights.
    y_pred_sd, model = train()
    torch.save(obj=model.state_dict(), f="self_distil_model.pth")

Effect of Gaussian weight noise on the models

In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Training settings shared by this cell.
BATCH_SIZE, EPOCHS = 128, 20
# The learning rate is not set here: train() reads a global `lr`.
softmax, sigmoid = nn.Softmax(dim=1), nn.Sigmoid()

def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a ``DataLoader`` that reshuffles on every pass."""
    return DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

def add_gaussian_noise(tensor, std):
    """Return ``tensor`` plus i.i.d. zero-mean Gaussian noise with standard deviation ``std``.

    The noise is created on ``tensor``'s own device and dtype. No global
    ``device`` name is needed; ``train()`` only defines ``device`` as a local.
    """
    return tensor + torch.randn_like(tensor) * std


def train_single_epoch(model, data_loader,val_loader, loss_fn,optimiser, device,i,alpha,train_losses,val_losses,temperature,train_ax,val_ax,std):
    """Train one epoch with Gaussian weight noise (no distillation), then validate.

    Before every optimisation step, each parameter of ``model`` is perturbed
    in place with i.i.d. N(0, std**2) noise. The step itself minimises
    ``loss_fn(prediction, target)`` on the final logits only.

    ``model(inp)`` must return five values; only the fourth (the final
    logits) is used here. ``i``, ``alpha``, ``train_losses``, ``val_losses``,
    ``temperature``, ``train_ax`` and ``val_ax`` are accepted for interface
    compatibility but not used.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a,
        label_a, domain_a, y_one_tr, y_one_v, accuracy)``. ``accuracy`` is the
        training accuracy, ``y_one_*`` hold the class-1 probability, and
        ``fet_a``/``label_a``/``domain_a`` are returned empty.
    """
    # ---- training ----
    correct = 0
    total = 0
    y_pred_tr, y_true_tr, y_one_tr = [], [], []
    model.train()
    for inp, target, domain, tr in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        # Perturb the weights on each parameter's own device/dtype. no_grad keeps
        # this in-place update out of the autograd graph.
        with torch.no_grad():
            for p in model.parameters():
                p.add_(torch.randn_like(p) * std)
        prediction = model(inp)[3]
        probabilities = torch.softmax(prediction.detach(), dim=1)
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.cpu().numpy())
        y_true_tr.extend(target.cpu().numpy())
        y_one_tr.extend(probabilities[:, 1].cpu().numpy())
        loss = loss_fn(prediction, target)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
    accuracy = correct / total
    print("train   Accuracy = {}".format(accuracy))

    # ---- validation (no noise is injected here) ----
    correct_val = 0
    total_val = 0
    y_pred_v, y_true_v, y_one_v = [], [], []
    fet_a, label_a, domain_a = [], [], []
    model.eval()
    with torch.no_grad():
        for inp, target, domain, tr in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            prediction = model(inp)[3]
            probabilities = torch.softmax(prediction, dim=1)
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.cpu().numpy())
            y_true_v.extend(target.cpu().numpy())
            y_one_v.extend(probabilities[:, 1].cpu().numpy())
    val_accuracy = correct_val / total_val
    print("val  Accuracy = {}".format(val_accuracy))
    return val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v, accuracy

def train():
    """Noise-robustness sweep for the baseline (no self-distillation) model.

    For each weight-noise level in ``stds``, the model is trained for
    ``epochs`` epochs with ``train_single_epoch``, which perturbs the weights
    before every step. The training accuracy of the last epoch at each level
    is recorded.

    Returns:
        (y_pred_v, model, acc_noise_std): last-epoch validation predictions,
        the trained model, and one training accuracy per noise level.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}")
    torch.manual_seed(0)
    epochs = 20
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR, device)
    test_abs = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(usd, [test_abs, len(usd) - test_abs], generator=torch.Generator().manual_seed(500))
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)

    loss_fn = torch.nn.CrossEntropyLoss()
    # eca_resnet18 is the backbone factory used by the other experiments in this notebook.
    model = eca_resnet18().to(device)
    # `lr` is a global defined outside this cell.
    optimiser = torch.optim.SGD(model.parameters(), lr=lr)
    # This baseline has no distillation term, so train_single_epoch ignores
    # alpha; pass an explicit placeholder instead of an undefined name.
    alpha = None
    temperature = 3
    train_losses = []
    val_losses = []
    train_ax = []
    val_ax = []
    acc_noise_std = []
    stds = [0.01, 0.02, 0.03, 0.04, 0.05]
    # NOTE(review): the same model and optimiser keep training across all noise
    # levels, so each level's accuracy also reflects the epochs run at earlier
    # levels. Confirm this is intended rather than a fresh model per std.
    for std in stds:
        for i in range(epochs):
            print(f"Epoch {i+1}")
            val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v, train_acc = train_single_epoch(model, train_dataloader, val_loader, loss_fn, optimiser, device, i, alpha, train_losses, val_losses, temperature, train_ax, val_ax, std)
            print("---------------------------")
        acc_noise_std.append(train_acc)
    return y_pred_v, model, acc_noise_std


if __name__ == "__main__":
    # Baseline noise sweep; nom_acc is plotted against the self-distillation sweep.
    (y_pred2, model2, nom_acc) = train()


In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Training settings shared by this cell.
BATCH_SIZE, EPOCHS = 128, 20
# The learning rate is not set here: train() reads a global `lr`.
softmax = nn.Softmax(dim=1)  # row-wise probabilities from (batch, classes) logits

def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a ``DataLoader`` that reshuffles on every pass."""
    return DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

def kd_loss_function(output, target_output,temperature):
    """Soft-target cross-entropy of temperature-scaled logits against a teacher distribution.

    Args:
        output: student logits, softmaxed along dim 1.
        target_output: teacher probabilities with the same shape as ``output``.
        temperature: divisor applied to ``output`` before the log-softmax.

    Returns:
        Scalar: ``-mean_over_rows(sum_over_dim1(log_softmax(output / T) * target))``.
    """
    student_log_probs = torch.log_softmax(output / temperature, dim=1)
    per_row = (student_log_probs * target_output).sum(dim=1)
    return -per_row.mean()
def add_gaussian_noise(tensor, std):
    """Return ``tensor`` plus i.i.d. zero-mean Gaussian noise with standard deviation ``std``.

    The noise is created on ``tensor``'s own device and dtype. No global
    ``device`` name is needed; ``train()`` only defines ``device`` as a local.
    """
    return tensor + torch.randn_like(tensor) * std

def train_single_epoch(model, data_loader,val_loader, loss_fn,optimiser, device,i,alpha,train_losses,val_losses,temperature,train_ax,val_ax,std):
    """Train one self-distillation epoch with Gaussian weight noise, then validate.

    Before every optimisation step, each parameter of ``model`` is perturbed
    in place with i.i.d. N(0, std**2) noise. The step then minimises
    ``(1 - alpha) * (kd1 + kd2 + kd3) + alpha * loss_fn(prediction, target)``.
    Each ``kd`` term is ``kd_loss_function`` between an intermediate head and
    the detached, temperature-softened final logits, scaled by ``temperature**2``.

    ``model(inp)`` must return ``(middle_output1, middle_output2,
    middle_output3, prediction, features)``. The mean training loss is
    appended to ``train_losses``. ``i``, ``val_losses``, ``train_ax`` and
    ``val_ax`` are not used.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a,
        label_a, domain_a, y_one_tr, y_one_v, accuracy)``. ``accuracy`` is the
        training accuracy. ``fet_a``/``label_a``/``domain_a`` are per-batch
        validation features, targets and domains.
    """
    # ---- training ----
    correct = 0
    total = 0
    y_pred_tr = []
    y_true_tr = []
    y_one_tr = []
    total_steps = 0
    total_train_loss = 0
    model.train()
    for inp, target, domain, tr in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        # Perturb the weights on each parameter's own device/dtype. no_grad keeps
        # this in-place update out of the autograd graph.
        with torch.no_grad():
            for p in model.parameters():
                p.add_(torch.randn_like(p) * std)
        middle_output1, middle_output2, middle_output3, prediction, features = model(inp)

        # Teacher distribution: final logits softened by the temperature.
        temp4 = torch.softmax(prediction / temperature, dim=1)
        probabilities = softmax(prediction.detach())
        label_1_probabilities = probabilities[:, 1]
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.data.cpu().numpy())
        y_one_tr.extend(label_1_probabilities.data.cpu().numpy())
        y_true_tr.extend(target.data.cpu().numpy())
        loss1by4 = kd_loss_function(middle_output1, temp4.detach(), temperature) * (temperature**2)
        loss2by4 = kd_loss_function(middle_output2, temp4.detach(), temperature) * (temperature**2)
        loss3by4 = kd_loss_function(middle_output3, temp4.detach(), temperature) * (temperature**2)
        loss = loss_fn(prediction, target)
        total_loss = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss
        optimiser.zero_grad()
        total_loss.backward()
        optimiser.step()
        total_train_loss += total_loss.detach().cpu().numpy()
        total_steps += 1
    train_losses.append(total_train_loss / total_steps)
    accuracy = correct / total
    print("train   Accuracy = {}".format(accuracy))

    # ---- validation (no noise is injected here) ----
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    fet_a = []
    label_a = []
    domain_a = []
    y_one_v = []
    model.eval()
    with torch.no_grad():
        for inp, target, domain, tr in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            # Keep domains on the same device as the collected features/targets.
            domain = domain.to(device)
            _, _, _, prediction, features = model(inp)
            fet_a.append(features)
            label_a.append(target)
            domain_a.append(domain)
            probabilities = softmax(prediction)
            label_1_probabilities = probabilities[:, 1]
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.data.cpu().numpy())
            y_true_v.extend(target.data.cpu().numpy())
            y_one_v.extend(label_1_probabilities.data.cpu().numpy())
    val_accuracy = correct_val / total_val
    print("val  Accuracy = {}".format(val_accuracy))
    return val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v, accuracy

def train():
    """Noise-robustness sweep for the self-distillation model.

    For each weight-noise level in ``stds``, the model is trained for
    ``epochs`` epochs with ``train_single_epoch`` (distillation weight
    ``alpha = al``, temperature 3). The training accuracy of the last epoch
    at each level is collected.

    Returns:
        (y_pred_v, model, acc_noise_std): last-epoch validation predictions,
        the trained model, and one training accuracy per noise level.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}")
    epochs = 20
    torch.manual_seed(42)
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR, device)
    test_abs = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(usd, [test_abs, len(usd) - test_abs], generator=torch.Generator().manual_seed(500))
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)
    loss_fn = torch.nn.CrossEntropyLoss()
    # NOTE(review): the non-distillation noise sweep in this notebook seeds with 0
    # and builds an ECA-ResNet-18. Confirm that cnn() and seed 42 here keep the
    # comparison plot fair.
    model = cnn().to(device)
    # `lr` and `al` are globals defined outside this cell.
    optimiser = torch.optim.SGD(model.parameters(), lr=lr)
    alpha = al
    temperature = 3
    train_losses = []
    val_losses = []
    train_ax = []
    val_ax = []
    acc_noise_std = []
    stds = [0.01, 0.02, 0.03, 0.04, 0.05]
    # NOTE(review): the same model and optimiser keep training across all noise
    # levels, so each level's accuracy also reflects the earlier levels' epochs.
    for std in stds:
        for i in range(epochs):
            print(f"Epoch {i+1}")
            val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v, train_acc = train_single_epoch(model, train_dataloader, val_loader, loss_fn, optimiser, device, i, alpha, train_losses, val_losses, temperature, train_ax, val_ax, std)
            print("---------------------------")
        acc_noise_std.append(train_acc)
    print("Finished training")
    return y_pred_v, model, acc_noise_std


if __name__ == "__main__":
    # Self-distillation noise sweep; acc_noise_std feeds the comparison plot.
    (y_pred_sd, model, acc_noise_std) = train()

In [None]:
import matplotlib.pyplot as plt
# Noise standard deviations swept by both noise experiments.
stds = [0.01, 0.02, 0.03, 0.04, 0.05]

# One curve per model: the final-epoch training accuracy at each noise level.
plt.plot(stds, acc_noise_std, label='Self Distillation')
plt.plot(stds, nom_acc, label='Non Self Distillation')

plt.xlabel('standard deviation of noise')
plt.ylabel('Accuracy')
plt.legend()  # one entry per curve
plt.show()


Domain Adversarial Learning
This section describes how the gradient reversal layer and the domain discriminator work.

Gradient Reversal Layer

In [18]:

class GradientReversalF(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``-alpha`` backward.

    ``alpha`` must be passed as a tensor, because it is stored with
    ``save_for_backward``. No gradient is returned for ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # The forward pass is the identity; only alpha is needed later.
        ctx.save_for_backward(alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (alpha,) = ctx.saved_tensors
        wants_x_grad = ctx.needs_input_grad[0]
        flipped = grad_output * (-alpha) if wants_x_grad else grad_output
        return flipped, None


class GradientReverse(nn.Module):
    """Layer that is the identity forward and scales gradients by ``-alpha`` backward.

    Args:
        alpha: reversal strength; must be positive (checked with ``assert``).
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, alpha, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Kept as a non-trainable tensor because GradientReversalF stores it via save_for_backward.
        self.alpha = torch.tensor(alpha, requires_grad=False)
        assert alpha > 0, 'alpha must be > 0'
        grad_scale = -alpha
        print(f"The gradient will be multiplied by: {grad_scale}")

    def forward(self, x):
        scale_tensor = self.alpha
        return GradientReversalF.apply(x, scale_tensor)

Domain Discriminator

In [23]:
class Discriminator(nn.Module):
    """Domain classifier placed behind a gradient-reversal layer.

    Forward: GradientReverse(alpha) -> Linear(F, F) -> ReLU -> Dropout(0.3)
    -> Linear(F, num_domains), with F = 128 * block.expansion.

    Args:
        block: residual block class; only its ``expansion`` attribute is used,
            to size the feature dimension F.
        alpha: gradient-reversal strength passed to ``GradientReverse``.
        num_domains: number of domain classes in the output layer (default 3).
    """

    def __init__(self, block, alpha, num_domains=3):
        super(Discriminator, self).__init__()
        feat_dim = 128 * block.expansion
        self.fc1 = nn.Linear(feat_dim, feat_dim)
        self.fc1.weight.data.normal_(0, 0.01)
        self.fc1.bias.data.fill_(0.0)
        self.fc2 = nn.Linear(feat_dim, num_domains)
        self.fc2.weight.data.normal_(0, 0.3)
        self.fc2.bias.data.fill_(0.0)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.grl_layer = GradientReverse(alpha)

    def forward(self, feature):
        # Call the module (not .forward) so any registered hooks also run.
        feature = self.grl_layer(feature)
        feature = self.fc1(feature)
        feature = self.relu(feature)
        feature = self.dropout(feature)
        feature = self.fc2(feature)
        return feature
def domain_disc():
    """Build the domain discriminator sized for ECA basic-block features, with reversal alpha 1."""
    return Discriminator(block=ECABasicBlock, alpha=1)

Checking that gradient reversal works

In [None]:
import torch
rev = GradientReverse(1)

# Two identical leaves: one passes through the reversal layer, the other does not.
x = torch.tensor([4.], requires_grad=True)
x_rev = torch.tensor([4.], requires_grad=True)

y = x * 5
y_rev = rev(x_rev) * 5

y.backward()
y_rev.backward()

print(f'x gradient: {x.grad}')  # expected 5
print(f'reversed x gradient: {x_rev.grad}')  # expected -5

# Same magnitude, opposite sign.
assert x.grad == -x_rev.grad

Domain adversarial learning is described below.
The features produced by the generator (here, the self-distilled model) are fed into a domain discriminator.
The domain discriminator predicts the domain of each input; only samples with class label 0 are used for the domain loss. A gradient reversal layer in front of the discriminator forces the feature generator to produce domain-invariant features. The feature generator is initially optimised with the self-distillation loss plus the domain loss. The features are then detached and fed to the domain discriminator again, which is optimised with the domain loss alone. Detaching ensures that this second update does not change the feature generator after it has been optimised.
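Before running the next cell, replace the `al` (loss weight `alpha`) and `lr1` (learning rate of the feature-extractor optimiser) variables with appropriate values.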

In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Training settings shared by this cell.
BATCH_SIZE, EPOCHS = 128, 20
# The feature-extractor learning rate is not set here: train() reads a global `lr1`.
softmax = nn.Softmax(dim=1)  # row-wise probabilities from (batch, classes) logits
def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a shuffling ``DataLoader`` driven by the global generator ``g``."""
    # NOTE(review): ``g`` is not defined in this cell; it must exist at module
    # level (presumably a seeded torch.Generator for reproducible shuffling).
    return DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, generator=g)

def kd_loss_function(output, target_output, temperature=3):
    """Soft-target cross-entropy of temperature-scaled logits against a teacher distribution.

    Args:
        output: student logits, softmaxed along dim 1.
        target_output: teacher probabilities with the same shape as ``output``.
        temperature: divisor applied to ``output`` before the log-softmax.
            Defaults to 3, the value this cell previously hard-coded.

    Returns:
        Scalar: ``-mean_over_rows(sum_over_dim1(log_softmax(output / T) * target))``.
    """
    output_log_softmax = torch.log_softmax(output / temperature, dim=1)
    return -torch.mean(torch.sum(output_log_softmax * target_output, dim=1))


def train_single_epoch(trial,model, data_loader,val_loader, loss_fn,optimiser_f, device,i,alpha,train_losses,val_losses,optimiser_d,domain_classifier,domain_loss,h_score_tr,h_score_v,tr_acc,val_acc):
    """One epoch of self-distillation plus domain-adversarial training, then validation.

    Training, per batch:
      1. Forward ``model``. The final logits, softened with T = 3 and detached,
         teach the three intermediate heads via ``kd_loss_function``.
      2. Features of samples with ``target == 0`` go through
         ``domain_classifier`` and give a domain loss ``d_loss``. A batch
         without such samples contributes a zero domain term.
      3. ``optimiser_f`` steps on ``(1 - alpha) * KD + alpha * loss_fn + d_loss``.
      4. The same features, detached, go through ``domain_classifier`` again,
         and ``optimiser_d`` steps on that domain loss alone. This backward
         pass therefore cannot reach ``model``.

    Validation computes class and domain accuracy. ``domain_acc_val`` is
    reported to the Optuna ``trial`` at step ``i`` and may raise
    ``optuna.exceptions.TrialPruned``. ``h_score_tr`` and ``h_score_v`` are
    not used.

    Side effects: appends to ``train_losses``, ``val_losses``, ``tr_acc`` and
    ``val_acc``.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a_v,
        label_a_v, domain_a_v, domain_acc_val, domain_pred_v, domain_true_v)``.
    """
    temperature = 3
    # ---- training ----
    correct = 0
    total = 0
    y_pred_tr = []
    y_true_tr = []
    total_train_loss = 0
    total_steps = 0
    correct_domain = 0
    total_domain = 0
    model.train()
    domain_classifier.train()
    for inp, target, domain, triplet in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        domain = domain.to(device)
        middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
        optimiser_f.zero_grad()

        # NOTE(review): the mask keeps target == 0, which train() labels 'genuine',
        # while the variable names say 'spoof'. Confirm which class the domain
        # loss is meant to see.
        spoof_zero_mask = target == 0
        spoof_zero_inputs = features[spoof_zero_mask]
        spoof_zero_domains = domain[spoof_zero_mask]
        if spoof_zero_inputs.size(0) > 0:
            domain_output = domain_classifier(spoof_zero_inputs)
            d_loss = domain_loss(domain_output, spoof_zero_domains)
        else:
            # No label-0 samples in this batch, so add no domain term. Reusing the
            # previous batch's d_loss would back-propagate through an already-freed
            # graph, and on the first batch the name would not exist at all.
            d_loss = features.new_zeros(())

        temp4 = torch.softmax(prediction / temperature, dim=1)
        probabilities = softmax(prediction.detach())
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.data.cpu().numpy())
        y_true_tr.extend(target.data.cpu().numpy())
        loss1by4 = kd_loss_function(middle_output1, temp4.detach()) * (temperature**2)
        loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
        loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
        loss = loss_fn(prediction, target)
        total_loss = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss + d_loss
        total_train_loss += total_loss.detach().cpu().numpy()
        total_steps += 1
        total_loss.backward()
        optimiser_f.step()

        # Discriminator-only update on detached features. This also discards the
        # classifier gradients accumulated by the backward pass above.
        optimiser_d.zero_grad()
        if spoof_zero_inputs.size(0) > 0:
            domain_output = domain_classifier(spoof_zero_inputs.detach())
            d_loss = domain_loss(domain_output, spoof_zero_domains)
            probabilities_domain = softmax(domain_output.detach())
            predicted_domains = torch.argmax(probabilities_domain, dim=1)
            correct_domain += (predicted_domains == spoof_zero_domains).sum().item()
            total_domain += spoof_zero_domains.size(0)
            d_loss.backward()
            optimiser_d.step()
    train_losses.append(total_train_loss / total_steps)
    print("train loss= {}".format(total_train_loss / total_steps))
    accuracy = correct / total
    print("train   Accuracy = {}".format(accuracy))
    domain_accuracy = correct_domain / total_domain if total_domain else float("nan")
    print("train domain   Accuracy = {}".format(domain_accuracy))
    tr_acc.append(accuracy)

    # ---- validation ----
    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    domain_pred_v = []
    domain_true_v = []
    fet_a_v = []
    label_a_v = []
    domain_a_v = []
    correct_domain_val = 0
    total_domain_val = 0
    model.eval()
    domain_classifier.eval()
    with torch.no_grad():
        for inp, target, domain, triplet in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            domain = domain.to(device)
            middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
            fet_a_v.append(features)
            label_a_v.append(target)
            domain_a_v.append(domain)
            spoof_zero_mask = target == 0
            spoof_zero_inputs = features[spoof_zero_mask]
            spoof_zero_domains = domain[spoof_zero_mask]
            if spoof_zero_inputs.size(0) > 0:
                domain_output = domain_classifier(spoof_zero_inputs)
                d_loss = domain_loss(domain_output, spoof_zero_domains)
                probabilities_domain = softmax(domain_output)
                predicted_domains = torch.argmax(probabilities_domain, dim=1)
                correct_domain_val += (predicted_domains == spoof_zero_domains).sum().item()
                total_domain_val += spoof_zero_domains.size(0)
                domain_pred_v.extend(predicted_domains.data.cpu().numpy())
                domain_true_v.extend(spoof_zero_domains.data.cpu().numpy())
            else:
                # Avoid reporting a stale d_loss from an earlier batch.
                d_loss = features.new_zeros(())
            temp4 = torch.softmax(prediction / temperature, dim=1)
            probabilities = softmax(prediction)
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.data.cpu().numpy())
            y_true_v.extend(target.data.cpu().numpy())
            loss1by4 = kd_loss_function(middle_output1, temp4) * (temperature**2)
            loss2by4 = kd_loss_function(middle_output2, temp4) * (temperature**2)
            loss3by4 = kd_loss_function(middle_output3, temp4) * (temperature**2)
            loss = loss_fn(prediction, target)
            total_loss_val = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss + d_loss
            val_loss += total_loss_val.detach().cpu().numpy()
            val_steps += 1
    val_accuracy = correct_val / total_val
    domain_acc_val = correct_domain_val / total_domain_val if total_domain_val else float("nan")
    print("val loss = {}".format(val_loss / val_steps))
    val_losses.append(val_loss / val_steps)
    print("val  Accuracy = {}".format(val_accuracy))
    val_acc.append(val_accuracy)
    print("val  domain Accuracy = {}".format(domain_acc_val))
    trial.report(domain_acc_val, i)
    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()
    return val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a_v, label_a_v, domain_a_v, domain_acc_val, domain_pred_v, domain_true_v

def train(trial):
    """Run one Optuna trial of training, then plot the results.

    Only the domain-discriminator learning rate ``lrd`` is sampled from
    ``trial``; the model learning rate ``lr1`` and the loss weight ``al`` are
    read from module-level globals. Each epoch is delegated to
    ``train_single_epoch``, which receives ``trial`` (and the epoch index)
    so it can report intermediate values / prune.

    After training this draws the loss curves, row-normalised confusion
    matrices (train and validation genuine/spoof predictions, validation
    domain predictions), a t-SNE view of the last epoch's validation
    features and the accuracy curves.

    Args:
        trial: The Optuna trial driving this run.

    Returns:
        The validation domain accuracy returned by the final epoch, which is
        the objective the study optimises.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using {device}")
    torch.manual_seed(0)
    epochs = 20
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR,
                            device)
    # 80/20 split with a dedicated, fixed generator so every trial sees the
    # same train/validation partition.
    test_abs = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(
        usd, [test_abs, len(usd) - test_abs],
        generator=torch.Generator().manual_seed(500))
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)
    loss_fn = torch.nn.CrossEntropyLoss()
    domain_loss = torch.nn.CrossEntropyLoss()
    model = eca_resnet18().to(device)
    domain_model = domain_disc().to(device)
    lrd = trial.suggest_float("lrd", 1e-5, 1e-1, log=True)
    optimiser = torch.optim.SGD(model.parameters(), lr=lr1)
    alpha = al

    optimiser_domain = torch.optim.SGD(domain_model.parameters(), lr=lrd)
    # Per-epoch histories; train_single_epoch appends to these in place.
    train_losses = []
    val_losses = []
    h_score_v = []
    h_score_tr = []
    tr_acc = []
    val_acc = []
    for i in range(epochs):
        print(f"Epoch {i+1}")
        (val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a,
         label_a, domain_a, domain_acc_val, d_p, d_t) = train_single_epoch(
            trial, model, train_dataloader, val_loader, loss_fn, optimiser,
            device, i, alpha, train_losses, val_losses, optimiser_domain,
            domain_model, domain_loss, h_score_tr, h_score_v, tr_acc, val_acc)
        print("---------------------------")
    print("Finished training")
    # The x-axis follows the real epoch count (it was hard-coded to 1..20).
    epochs_plot = range(1, epochs + 1)
    plt.plot(epochs_plot, train_losses, label='Training Loss')
    plt.plot(epochs_plot, val_losses, label='Validation Loss')
    # The curves are labelled, so actually display the legend.
    plt.legend()
    classes = ('genuine', 'spoof')
    classes1 = ('D 0', 'D 1', 'D 2')
    cf_matrix1 = confusion_matrix(y_true_tr, y_pred_tr)
    cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
    # sklearn expects (y_true, y_pred). d_t holds the true domains and d_p the
    # predicted ones; passing them the other way round transposed the matrix
    # and normalised it by predicted instead of true domain.
    cf_matrix3 = confusion_matrix(d_t, d_p)
    # Each matrix is normalised per row (per true class), so rows sum to 1.
    df_cm1 = pd.DataFrame(cf_matrix1 / np.sum(cf_matrix1, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    df_cm3 = pd.DataFrame(cf_matrix3 / np.sum(cf_matrix3, axis=1)[:, None],
                          index=list(classes1), columns=list(classes1))
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm1, annot=True)
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm2, annot=True)
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm3, annot=True)
    tsned(fet_a, label_a, domain_a)
    # Open a fresh figure so the accuracy curves are not drawn on top of the
    # last heatmap (or whatever figure tsned leaves current).
    plt.figure()
    plt.plot(epochs_plot, tr_acc, label='Training Accuracy')
    plt.plot(epochs_plot, val_acc, label='Validation Accuracy')
    plt.legend()
    plt.show()
    return domain_acc_val


if __name__ == "__main__":

    # Run 10 Optuna trials of train(trial), maximising the value it returns
    # (the final epoch's validation domain accuracy).
    study = optuna.create_study(direction="maximize")
    study.optimize(train, n_trials=10)
    # Split finished trials by outcome for the summary below.
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    # NOTE(review): study.best_trial raises if no trial completed (e.g. all
    # were pruned or failed) — confirm this cannot happen or guard it.
    trial = study.best_trial

    print("  Value: ", trial.value)

    # Hyperparameters sampled for the best trial (here only "lrd").
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

In [None]:
# import torch
# import torchaudio
# from torch import nn
# from torch.utils.data import DataLoader
# import numpy as np
# import optuna
# from optuna.trial import TrialState
# from torch.utils.data import random_split
# from matplotlib.pylab import plt
# from sklearn.metrics import confusion_matrix
# import seaborn as sn
# import pandas as pd
# BATCH_SIZE = 128
# EPOCHS = 20
# # LEARNING_RATE = 0.001
# softmax = nn.Softmax(dim=1)
# # rev=GradientReverse(1)

# def create_data_loader(train_data, batch_size):
#     train_dataloader = DataLoader(train_data, batch_size=batch_size,shuffle=True)
#     return train_dataloader

# def kd_loss_function(output, target_output):
#     temperature=3
#     output = output / temperature
#     output_log_softmax = torch.log_softmax(output, dim=1)
#     loss_kd = -torch.mean(torch.sum(output_log_softmax * target_output, dim=1))
#     return loss_kd


# def train_single_epoch(trial,model, data_loader,val_loader, loss_fn,optimiser_f, device,i,alpha,train_losses,val_losses,optimiser_d,domain_classifier,domain_loss):
#     temperature=3
#     correct=0
#     fet_a=[]
#     label_a=[]
#     domain_a=[]
#     total=0
#     y_pred_tr=[]
#     y_true_tr=[]
#     total_train_loss=0
#     total_steps=0
#     correct_domain=0
#     total_domain=0
#     model.train()
#     domain_classifier.train()
#     for inp, target,domain in data_loader:
#         inp = inp.to(device)
#         target=target.to(device)
#         domain=domain.to(device)
#         middle_output1,middle_output2,middle_output3,prediction,features = model(inp)
#         optimiser_f.zero_grad()
#         spoof_zero_mask = (target == 0).to(device)
#         spoof_zero_inputs = features[spoof_zero_mask]
#         spoof_zero_domains = domain[spoof_zero_mask]
#         if spoof_zero_inputs.size(0) > 0:
#             domain_output = domain_classifier(spoof_zero_inputs)
#             d_loss=domain_loss(domain_output, spoof_zero_domains)
#         else:
#             d_loss=0
#         temp4 = prediction / temperature
#         temp4 = torch.softmax(temp4, dim=1)
#         probabilities = softmax(prediction.detach())
#         predicted_classes = torch.argmax(probabilities, dim=1)
#         correct += (predicted_classes == target).sum().item()
#         total += target.size(0)
#         y_pred_tr.extend(predicted_classes.data.cpu().numpy())
#         y_true_tr.extend(target.data.cpu().numpy())
#         loss1by4 = kd_loss_function(middle_output1, temp4.detach())* (temperature**2)
#         loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
#         loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
#         loss = loss_fn(prediction, target)
#         total_loss = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) +  alpha *loss +d_loss
#         total_train_loss+=total_loss.detach().cpu().numpy()
#         total_steps+=1
#         total_loss.backward()
#         optimiser_f.step()
#         optimiser_d.zero_grad()
#         spoof_zero_inputs = features.detach()[spoof_zero_mask]
#         spoof_zero_domains = domain.detach()[spoof_zero_mask]
#         if spoof_zero_inputs.size(0) > 0:
#             domain_output = domain_classifier(spoof_zero_inputs)
#             d_loss=domain_loss(domain_output, spoof_zero_domains)
#             probabilities_domain=softmax(domain_output.detach())
#             predicted_domains=torch.argmax(probabilities_domain, dim=1)
#             correct_domain+=(predicted_domains == spoof_zero_domains).sum().item()
#             total_domain+=spoof_zero_domains.size(0)
#             d_loss.backward()
#             optimiser_d.step()
#     train_losses.append(total_train_loss/total_steps)
#     tll=total_train_loss/total_steps
#     print("train loss = {}".format(tll))
#     accuracy = correct / total
#     print("train   Accuracy = {}".format(accuracy))
#     domain_accuracy=correct_domain/total_domain
#     print("train domain   Accuracy = {}".format(domain_accuracy))

#     val_loss = 0.0
#     val_steps = 0
#     total_val = 0
#     correct_val = 0
#     y_pred_v=[]
#     y_true_v=[]
#     fet_a=[]
#     label_a=[]
#     domain_a=[]
#     correct_domain_val=0
#     total_domain_val=0
#     model.eval()
#     domain_classifier.eval()
#     for inp, target,domain in val_loader:
#             with torch.no_grad():
#                 inp=inp.to(device)
#                 target=target.to(device)
#                 domain=domain.to(device)
#                 middle_output1,middle_output2,middle_output3,prediction,features = model(inp)
#                 fet_a.append(features)
#                 label_a.append(target)
#                 domain_a.append(domain)
#                 spoof_zero_mask = (target == 0).to(device)
#                 spoof_zero_inputs = features[spoof_zero_mask]
#                 spoof_zero_domains = domain[spoof_zero_mask]
#                 if spoof_zero_inputs.size(0) > 0:
#                     domain_output = domain_classifier(spoof_zero_inputs)
#                     d_loss=domain_loss(domain_output,spoof_zero_domains)
#                     probabilities_domain=softmax(domain_output.detach())
#                     predicted_domains=torch.argmax(probabilities_domain, dim=1)
#                     correct_domain_val+=(predicted_domains ==spoof_zero_domains).sum().item()
#                     total_domain_val+=spoof_zero_domains.size(0)
#                 temp4 = prediction / temperature
#                 temp4 = torch.softmax(temp4, dim=1)
#                 probabilities = softmax(prediction)
#                 predicted_classes = torch.argmax(probabilities, dim=1)
#                 correct_val += (predicted_classes == target).sum().item()
#                 total_val += target.size(0)
#                 y_pred_v.extend(predicted_classes.data.cpu().numpy())
#                 y_true_v.extend(target.data.cpu().numpy())
#                 loss1by4 = kd_loss_function(middle_output1, temp4.detach())* (temperature**2)
#                 loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
#                 loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
#                 loss = loss_fn(prediction, target)
#                 total_loss_val = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha *loss  +d_loss         
#                 val_loss += total_loss_val.detach().cpu().numpy()
#                 val_steps += 1
#     val_accuracy = correct_val / total_val
#     domain_acc_val=correct_domain_val/total_domain_val
#     vll=val_loss/val_steps
#     print("val loss = {}".format(vll))
#     val_losses.append(val_loss/val_steps)
#     print("val  Accuracy = {}".format(val_accuracy))
#     print("val  domain Accuracy = {}".format(domain_acc_val))
#     trial.report(domain_acc_val, i)
#     if trial.should_prune():
#         raise optuna.exceptions.TrialPruned()
#     return val_accuracy,y_pred_tr,y_true_tr,y_pred_v,y_true_v,fet_a,label_a,domain_a,domain_acc_val

# def train(trial):
#     if torch.cuda.is_available():
#         device = "cuda"
#     else:
#         device = "cpu"
#     print(f"Using {device}")
#     epochs=20
#     usd = UrbanSoundDataset(ANNOTATIONS_FILE,DOMAIN_PATH,AUDIO_DIR,
#                             device)
#     test_abs = int(len(usd) * 0.8)
#     train_subset, val_subset = random_split(usd, [test_abs, len(usd) - test_abs],generator=torch.Generator().manual_seed(500))
#     train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
#     val_loader=create_data_loader(val_subset,BATCH_SIZE)
#     loss_fn = torch.nn.CrossEntropyLoss()
#     domain_loss=torch.nn.CrossEntropyLoss()
#     torch.manual_seed(0)
#     model = eca_resnet18().to(device)
#     domain_model=domain_disc().to(device)
#     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
#     alpha=al
#     optimiser = torch.optim.SGD(model.parameters(),
#                                  lr=lr1)
#     optmiser_domain=torch.optim.SGD(domain_model.parameters(),
#                                  lr=lr)
#     train_losses=[]
#     val_losses=[]
#     for i in range(epochs):
#         print(f"Epoch {i+1}")
#         val_accuracy,y_pred_tr,y_true_tr,y_pred_v,y_true_v,fet_a,label_a,domain_a,domain_acc_val=train_single_epoch(trial,model, train_dataloader,val_loader,loss_fn, optimiser, device,i,alpha,train_losses,val_losses,optmiser_domain,domain_model,domain_loss)
#         print("---------------------------")
#     print("Finished training")
#     epochs_plot = range(1, 21)
#     plt.plot(epochs_plot, train_losses, label='Training Loss')
#     plt.plot(epochs_plot, val_losses, label='Validation Loss')
#     classes = ('genuine','spoof')
#     cf_matrix1 = confusion_matrix(y_true_tr, y_pred_tr)
#     cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
#     df_cm1 = pd.DataFrame(cf_matrix1 / np.sum(cf_matrix1, axis=1)[:, None], index = [i for i in classes],
#                      columns = [i for i in classes])
#     df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None], index = [i for i in classes],
#                      columns = [i for i in classes])
#     plt.figure(figsize = (12,7))
#     sn.heatmap(df_cm1, annot=True)
#     plt.figure(figsize = (12,7))
#     sn.heatmap(df_cm2, annot=True)
#     tsned(fet_a,label_a,domain_a)
#     return domain_acc_val


# if __name__ == "__main__":

#     study = optuna.create_study(direction="maximize")
#     study.optimize(train, n_trials=5)
#     pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
#     complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

#     print("Study statistics: ")
#     print("  Number of finished trials: ", len(study.trials))
#     print("  Number of pruned trials: ", len(pruned_trials))
#     print("  Number of complete trials: ", len(complete_trials))

#     print("Best trial:")
#     trial = study.best_trial

#     print("  Value: ", trial.value)

#     print("  Params: ")
#     for key, value in trial.params.items():
#         print("    {}: {}".format(key, value))


Triplet mining: the functions below perform online triplet mining and use a semi-hard triplet loss.

In [21]:
import random
from itertools import combinations

import torch
import torch.nn as nn
import torch.nn.functional as F


class OnlineTripleLoss(nn.Module):
    """Triplet margin loss whose triplets are mined online from each batch.

    Triplets (anchor, positive, negative) are chosen by
    ``NegativeTripletSelector`` with the configured sampling strategy. Each
    triplet contributes ``relu(d(a, p) - d(a, n) + margin)``, where ``d`` is
    ``F.pairwise_distance``; the loss is the mean over the mined triplets.

    Args:
        margin: Margin added to the anchor-positive distance.
        sampling_strategy: Forwarded to the selector (``"random_sh"`` or
            ``"fixed_sh"``; any other value mines no triplets).
    """

    def __init__(self, margin, sampling_strategy="random_sh"):
        super(OnlineTripleLoss, self).__init__()
        self.margin = margin
        self.triplet_selector = NegativeTripletSelector(
            margin, sampling_strategy
        )

    def forward(self, embeddings, labels):
        """Return ``(mean_triplet_loss, number_of_triplets)`` for one batch.

        Args:
            embeddings: Batch of embeddings indexed by row (2-D indexing is
                used, so presumably (N, D) — TODO confirm with callers).
            labels: Per-row labels used to decide positives and negatives.

        Returns:
            A scalar loss tensor and the number of mined triplets. When no
            triplet is mined, the loss is a zero that is still attached to
            ``embeddings``' graph, rather than the NaN produced by ``mean()``
            of an empty tensor.
        """
        triplets = self.triplet_selector.get_triplets(embeddings, labels)
        n_triplets = len(triplets[0])
        if n_triplets == 0:
            # An empty mean is NaN, and it would turn the summed training loss
            # (and every gradient) into NaN. Keep the result differentiable.
            return embeddings.sum() * 0.0, 0
        # The trailing ", :" keeps the index lists (of 0-d tensors) from being
        # interpreted as a multi-dimensional index tuple.
        anchors = embeddings[triplets[0], :]
        ap_dists = F.pairwise_distance(anchors, embeddings[triplets[1], :])
        an_dists = F.pairwise_distance(anchors, embeddings[triplets[2], :])
        loss = F.relu(ap_dists - an_dists + self.margin)
        return loss.mean(), n_triplets


class NegativeTripletSelector:
    """Mine (anchor, positive, negative) index triplets from a labelled batch.

    Every unordered pair of same-label samples becomes an (anchor, positive)
    pair, for each label that has at least two samples. One negative is then
    picked for that pair among the differently labelled samples, using the
    configured strategy:

    * ``"random_sh"``: a random negative with ``d(a, n) < d(a, p) + margin``.
    * ``"fixed_sh"``: the negative maximising ``d(a, p) + margin - d(a, n)``.

    Both strategies therefore accept hard negatives (closer than the positive)
    as well as semi-hard ones. Any other strategy name yields no triplets.
    """

    def __init__(self, margin, sampling_strategy="random_sh"):
        super(NegativeTripletSelector, self).__init__()
        self.margin = margin
        self.sampling_strategy = sampling_strategy

    def get_triplets(self, embeddings, labels):
        """Return ``[anchor_indices, positive_indices, negative_indices]``."""
        dists = pdist(embeddings, eps=0)
        anchors, positives, negatives = [], [], []
        for cls in torch.unique(labels):
            in_cls = labels == cls
            members = torch.where(in_cls)[0]
            # A single sample cannot form an anchor/positive pair.
            if members.shape[0] >= 2:
                outsiders = torch.where(~in_cls)[0]
                found = self.get_one_one_triplets(members, outsiders, dists)
                anchors.extend(found[0])
                positives.extend(found[1])
                negatives.extend(found[2])
        return [anchors, positives, negatives]

    def get_one_one_triplets(self, pos_labels, negative_indices, dist_mat):
        """Pick at most one negative for every pair drawn from ``pos_labels``.

        Args:
            pos_labels: Indices of the samples sharing one label.
            negative_indices: Indices of all samples with a different label.
            dist_mat: Full pairwise distance matrix of the batch.

        Returns:
            Three parallel lists (anchors, positives, negatives); pairs for
            which the sampler finds no candidate are dropped.
        """
        if self.sampling_strategy == "random_sh":
            sampler = random_semi_hard_sampling
        elif self.sampling_strategy == "fixed_sh":
            sampler = fixed_semi_hard_sampling
        else:
            sampler = None

        anchors, positives, negatives = [], [], []
        for anchor_idx, pos_idx in combinations(pos_labels, 2):
            ap_dist = dist_mat[anchor_idx, pos_idx]
            an_dists = dist_mat[anchor_idx, negative_indices]
            picked = None if sampler is None else sampler(ap_dist, an_dists, self.margin)
            if picked is None:
                continue
            anchors.append(anchor_idx)
            positives.append(pos_idx)
            negatives.append(negative_indices[picked])
        return [anchors, positives, negatives]


def random_semi_hard_sampling(ap_dist, an_dists, margin):
    """Pick a random negative that violates the margin, or ``None``.

    A negative ``k`` qualifies when ``ap_dist + margin - an_dists[k] > 0``,
    i.e. it is closer to the anchor than the positive plus the margin.
    The choice is made with Python's ``random.choice``.

    Returns:
        A 0-d index tensor into ``an_dists``, or ``None`` if nothing qualifies.
    """
    violating = torch.where(ap_dist + margin - an_dists > 0)[0]
    if violating.nelement() == 0:
        return None
    return random.choice(violating)


def fixed_semi_hard_sampling(ap_dist, an_dists, margin):
    """Deterministically pick the margin-violating negative closest to the anchor.

    The score of negative ``k`` is ``(ap_dist + margin) - an_dists[k]``. If
    any score is positive, the index of the largest score (the nearest
    negative) is returned as a Python int; otherwise ``None``.
    """
    slack = (ap_dist + margin) - an_dists
    if not torch.any(slack > 0):
        return None
    return torch.argmax(slack).item()


def pdist(vectors, eps):
    """Return the square matrix of pairwise distances between the rows of ``vectors``.

    Entry ``[i, j]`` equals ``F.pairwise_distance(vectors[i], vectors[j], eps=eps)``
    (p=2). The matrix is built one row at a time with differentiable ops.
    """
    rows = [F.pairwise_distance(row, vectors, eps=eps) for row in vectors]
    return torch.stack(rows, dim=0)

In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Batch size passed to create_data_loader for both train and validation loaders.
BATCH_SIZE = 64
# NOTE(review): train() in this cell uses its own local `epochs = 20`;
# this constant is not read by the visible code.
EPOCHS = 20
# LEARNING_RATE = 0.001
# Shared softmax over dim 1 (class dimension), used to turn logits into
# probabilities for accuracy / EER bookkeeping.
softmax = nn.Softmax(dim=1)

def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a shuffling ``DataLoader`` yielding ``batch_size`` items per batch."""
    return DataLoader(train_data, shuffle=True, batch_size=batch_size)

def kd_loss_function(output, target_output, temperature=3):
    """Soft-target distillation loss of a student head against teacher probabilities.

    Computes the batch mean of
    ``-sum_c target_output[:, c] * log_softmax(output / temperature)[:, c]``,
    i.e. the cross-entropy of the temperature-softened student logits
    against the given soft targets.

    Args:
        output: Student logits, classes along dim 1.
        target_output: Teacher probabilities, broadcastable to ``output``
            (callers pass a detached, temperature-softened softmax).
        temperature: Softening temperature for ``output``. Defaults to 3,
            the value previously hard-coded; it should match the temperature
            the caller used to build ``target_output``.

    Returns:
        Scalar loss tensor.
    """
    output_log_softmax = torch.log_softmax(output / temperature, dim=1)
    return -torch.mean(torch.sum(output_log_softmax * target_output, dim=1))


def train_single_epoch(model, data_loader, val_loader, loss_fn, optimiser_f, device, i, alpha,
                       train_losses, val_losses, optimiser_d, domain_classifier, domain_loss,
                       criterion_triplet):
    """Train ``model`` for one epoch with self-distillation + triplet loss, then validate.

    Batches from both loaders are unpacked as ``(inp, target, domain,
    tr_label)``; ``tr_label`` is the label used for triplet mining. The
    per-batch objective is::

        (1 - alpha) * (KD(head1) + KD(head2) + KD(head3)) * T**2 ... summed,
        + alpha * loss_fn(prediction, target) + triplet_loss

    where each intermediate head is distilled towards the detached,
    temperature-softened final prediction (T = 3).

    ``i``, ``optimiser_d``, ``domain_classifier`` and ``domain_loss`` are kept
    for signature compatibility; only the classifier's train/eval mode is
    toggled here.

    Side effects: appends the epoch's mean train loss to ``train_losses`` and
    mean validation loss to ``val_losses``; prints losses, accuracies and EERs.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a,
        label_a, domain_a, y_one_tr, y_one_v)``. The ``fet_a``/``label_a``/
        ``domain_a`` lists hold per-batch validation features, targets and
        domains. ``y_one_*`` are the predicted class-1 probabilities.
    """
    temperature = 3
    correct = 0
    total = 0
    y_pred_tr = []
    y_true_tr = []
    y_one_tr = []
    total_train_loss = 0
    total_steps = 0
    model.train()
    domain_classifier.train()
    for inp, target, domain, tr_label in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        domain = domain.to(device)
        tr_label = tr_label.to(device)
        middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
        optimiser_f.zero_grad()
        # Temperature-softened final prediction: the teacher for the intermediate heads.
        temp4 = torch.softmax(prediction / temperature, dim=1)
        probabilities = softmax(prediction.detach())
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.data.cpu().numpy())
        y_one_tr.extend(probabilities[:, 1].data.cpu().numpy())
        y_true_tr.extend(target.data.cpu().numpy())
        # OnlineTripleLoss.forward takes (embeddings, labels); these were swapped before.
        triplet_loss, _ = criterion_triplet(features, tr_label)
        loss1by4 = kd_loss_function(middle_output1, temp4.detach()) * (temperature**2)
        loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
        loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
        loss = loss_fn(prediction, target)
        total_loss = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss + triplet_loss
        total_train_loss += total_loss.detach().cpu().numpy()
        total_steps += 1

        total_loss.backward()
        optimiser_f.step()
    train_losses.append(total_train_loss / total_steps)
    print("train  loss= {}".format(total_train_loss / total_steps))
    eer_t = compute_eer(y_true_tr, y_one_tr)
    print("train eer = {}".format(eer_t))
    accuracy = correct / total
    print("train   Accuracy = {}".format(accuracy))

    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    y_one_v = []
    fet_a = []
    label_a = []
    domain_a = []
    model.eval()
    domain_classifier.eval()
    with torch.no_grad():
        # The dataset yields four items per sample; the validation loop used to
        # unpack only three (ValueError) and mined triplets with a stale
        # training-batch tr_label.
        for inp, target, domain, tr_label in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            domain = domain.to(device)
            tr_label = tr_label.to(device)
            middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
            fet_a.append(features)
            label_a.append(target)
            domain_a.append(domain)
            temp4 = torch.softmax(prediction / temperature, dim=1)
            probabilities = softmax(prediction)
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.data.cpu().numpy())
            y_one_v.extend(probabilities[:, 1].data.cpu().numpy())
            y_true_v.extend(target.data.cpu().numpy())
            triplet_loss, _ = criterion_triplet(features, tr_label)
            loss1by4 = kd_loss_function(middle_output1, temp4.detach()) * (temperature**2)
            loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
            loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
            loss = loss_fn(prediction, target)
            total_loss_val = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss + triplet_loss
            val_loss += total_loss_val.detach().cpu().numpy()
            val_steps += 1
    val_accuracy = correct_val / total_val
    # This line reports the loss; it was previously mislabelled as accuracy.
    print("val  loss = {}".format(val_loss / val_steps))
    val_losses.append(val_loss / val_steps)
    print("val  Accuracy = {}".format(val_accuracy))
    eer_v = compute_eer(y_true_v, y_one_v)
    print("val eer = {}".format(eer_v))
    return val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a, y_one_tr, y_one_v

def train():
    """Train the self-distillation + triplet-loss model once and plot the results.

    Uses an 80/20 train/validation split of ``UrbanSoundDataset`` (split
    generator seeded with 500) and SGD optimisers. The loss weight ``al`` and
    the learning rates ``lr1`` (model) and ``lrd`` (domain model) are read from
    module-level globals. After training it plots the loss curves,
    row-normalised confusion matrices, a t-SNE view of the last validation
    epoch and DET curves for the training and validation scores.

    Returns:
        The trained ``eca_resnet18`` model.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using {device}")
    torch.manual_seed(0)
    epochs = 20
    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR,
                            device)
    test_abs = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(
        usd, [test_abs, len(usd) - test_abs],
        generator=torch.Generator().manual_seed(500))
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)
    loss_fn = torch.nn.CrossEntropyLoss()
    domain_loss = torch.nn.CrossEntropyLoss()
    criterion_triplet = OnlineTripleLoss(
        margin=0.1,
        sampling_strategy="random_sh"
    )
    model = eca_resnet18().to(device)
    domain_model = domain_disc().to(device)
    alpha = al
    optimiser = torch.optim.SGD(model.parameters(), lr=lr1)
    optimiser_domain = torch.optim.SGD(domain_model.parameters(), lr=lrd)
    train_losses = []
    val_losses = []
    for i in range(epochs):
        print(f"Epoch {i+1}")
        (val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a,
         label_a, domain_a, y_one_tr, y_one_v) = train_single_epoch(
            model, train_dataloader, val_loader, loss_fn, optimiser, device,
            i, alpha, train_losses, val_losses, optimiser_domain, domain_model,
            domain_loss, criterion_triplet)
        print("---------------------------")
    print("Finished training")
    # The x-axis follows the real epoch count (it was hard-coded to 1..20,
    # which breaks as soon as `epochs` changes).
    epochs_plot = range(1, epochs + 1)
    plt.plot(epochs_plot, train_losses, label='Training Loss')
    plt.plot(epochs_plot, val_losses, label='Validation Loss')
    # The curves are labelled, so actually display the legend.
    plt.legend()
    classes = ('genuine', 'spoof')
    cf_matrix1 = confusion_matrix(y_true_tr, y_pred_tr)
    cf_matrix2 = confusion_matrix(y_true_v, y_pred_v)
    # Row-normalise so each row (true class) sums to 1.
    df_cm1 = pd.DataFrame(cf_matrix1 / np.sum(cf_matrix1, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    df_cm2 = pd.DataFrame(cf_matrix2 / np.sum(cf_matrix2, axis=1)[:, None],
                          index=list(classes), columns=list(classes))
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm1, annot=True)
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm2, annot=True)
    tsned(fet_a, label_a, domain_a)
    plot_det(y_true_tr, y_one_tr)
    plot_det(y_true_v, y_one_v)
    return model


if __name__ == "__main__":

    # Train the self-distillation + triplet-loss model (train() above).
    model=train()
    # Persist only the learned parameters/buffers (state_dict), not the module object.
    torch.save(model.state_dict(), "model_triplet.pth")

Complete Proposed System
Self-distillation, domain-adversarial learning and triplet mining combined. Outputs: training and validation loss and accuracy, confusion matrices and DET curves, a t-SNE visualization of the validation set, NMI vs. training accuracy over epochs, and domain-discriminator accuracy over epochs.
`al`, `lrd` and `lr1` must be set to the appropriate (tuned) hyperparameter values before running.

In [None]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import optuna
from optuna.trial import TrialState
from torch.utils.data import random_split
from matplotlib.pylab import plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
# Batch size passed to create_data_loader for both train and validation loaders.
BATCH_SIZE = 64
# NOTE(review): train() in this cell sets its own local `epochs = 20`;
# this constant is not read by the visible code.
EPOCHS = 20
# Seed torch's global RNG once when this cell runs. NOTE(review): train() in
# this cell calls random_split without its own generator, so the split
# depends on this seed and on any RNG use since the cell ran.
torch.manual_seed(0)
# LEARNING_RATE = 0.001
# Shared softmax over dim 1 (class dimension), used to turn logits into
# probabilities for accuracy / EER bookkeeping.
softmax = nn.Softmax(dim=1)

def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a shuffling ``DataLoader`` yielding ``batch_size`` items per batch."""
    return DataLoader(train_data, shuffle=True, batch_size=batch_size)

def kd_loss_function(output, target_output, temperature=3):
    """Soft-target distillation loss of a student head against teacher probabilities.

    Computes the batch mean of
    ``-sum_c target_output[:, c] * log_softmax(output / temperature)[:, c]``,
    i.e. the cross-entropy of the temperature-softened student logits
    against the given soft targets.

    Args:
        output: Student logits, classes along dim 1.
        target_output: Teacher probabilities, broadcastable to ``output``
            (callers pass a detached, temperature-softened softmax).
        temperature: Softening temperature for ``output``. Defaults to 3,
            the value previously hard-coded; it should match the temperature
            the caller used to build ``target_output``.

    Returns:
        Scalar loss tensor.
    """
    output_log_softmax = torch.log_softmax(output / temperature, dim=1)
    return -torch.mean(torch.sum(output_log_softmax * target_output, dim=1))


def train_single_epoch(model, data_loader, val_loader, loss_fn, optimiser_f, device, i, alpha,
                       train_losses, val_losses, optimiser_d, domain_classifier, domain_loss,
                       h_score_tr, h_score_v, tr_acc, val_acc, dom_tr, dom_v, criterion_triplet):
    """One epoch of the combined system (self-distillation + domain loss + triplet loss).

    Batches from both loaders are unpacked as ``(inp, target, domain,
    triplet)``, where ``triplet`` holds the labels used for triplet mining.
    Per training batch:

    1. Feature-extractor step on
       ``(1 - alpha) * (KD losses of the three intermediate heads, each x T**2)
       + alpha * loss_fn(prediction, target) + d_loss + triplet_loss``.
       ``d_loss`` is the ``domain_classifier`` loss on the samples with
       ``target == 0``, or 0 if the batch has none.
    2. Discriminator step: ``domain_classifier`` is trained on the detached
       features of the same ``target == 0`` samples.

    The validation pass computes the same metrics without updates.

    Args:
        i: Epoch index (not used here).
        h_score_tr: Receives the clustering score of the epoch's training
            features (from ``cluster_triplet``).
        h_score_v: Currently unused (the validation clustering score is disabled).
        tr_acc, val_acc, dom_tr, dom_v, train_losses, val_losses: Histories
            that are appended to in place.

    Returns:
        ``(val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a_v,
        label_a_v, domain_a_v, domain_acc_val, domain_pred_v, domain_true_v,
        y_one_tr, y_one_v)``.
    """
    temperature = 3
    correct = 0
    total = 0
    fet_a = []
    label_a = []
    domain_a = []
    triplet_a = []
    y_pred_tr = []
    y_true_tr = []
    y_one_tr = []
    total_train_loss = 0
    total_steps = 0
    correct_domain = 0
    total_domain = 0
    model.train()
    domain_classifier.train()
    for inp, target, domain, triplet in data_loader:
        inp = inp.to(device)
        target = target.to(device)
        domain = domain.to(device)
        triplet = triplet.to(device)
        middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
        optimiser_f.zero_grad()
        # Stored only for the clustering score after the epoch; detach so the
        # kept tensors do not carry autograd history.
        fet_a.append(features.detach())
        label_a.append(target)
        domain_a.append(domain)
        triplet_a.append(triplet)
        # The domain term only uses samples with target == 0.
        spoof_zero_mask = (target == 0).to(device)
        spoof_zero_inputs = features[spoof_zero_mask]
        spoof_zero_domains = domain[spoof_zero_mask]
        if spoof_zero_inputs.size(0) > 0:
            domain_output = domain_classifier(spoof_zero_inputs)
            d_loss = domain_loss(domain_output, spoof_zero_domains)
        else:
            # No target == 0 samples: no domain term. Without this, d_loss was
            # undefined on the first batch (NameError) or left over from the
            # previous discriminator step, whose graph had already been freed
            # by d_loss.backward(), so total_loss.backward() failed.
            d_loss = 0.0
        temp4 = prediction / temperature
        temp4 = torch.softmax(temp4, dim=1)
        probabilities = softmax(prediction.detach())
        predicted_classes = torch.argmax(probabilities, dim=1)
        correct += (predicted_classes == target).sum().item()
        total += target.size(0)
        y_pred_tr.extend(predicted_classes.data.cpu().numpy())
        y_true_tr.extend(target.data.cpu().numpy())
        y_one_tr.extend(probabilities[:, 1].data.cpu().numpy())
        loss1by4 = kd_loss_function(middle_output1, temp4.detach()) * (temperature**2)
        loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
        loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
        loss = loss_fn(prediction, target)
        triplet_loss, _ = criterion_triplet(features, triplet)
        total_loss = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss + d_loss + triplet_loss
        total_train_loss += total_loss.detach().cpu().numpy()
        total_steps += 1

        total_loss.backward()
        optimiser_f.step()

        # Discriminator step on detached features: only domain_classifier is updated.
        # zero_grad also clears the gradients total_loss.backward() put on it.
        optimiser_d.zero_grad()
        spoof_zero_inputs = features.detach()[spoof_zero_mask]
        spoof_zero_domains = domain.detach()[spoof_zero_mask]
        if spoof_zero_inputs.size(0) > 0:
            domain_output = domain_classifier(spoof_zero_inputs)
            d_loss = domain_loss(domain_output, spoof_zero_domains)
            probabilities_domain = softmax(domain_output.detach())
            predicted_domains = torch.argmax(probabilities_domain, dim=1)
            correct_domain += (predicted_domains == spoof_zero_domains).sum().item()
            total_domain += spoof_zero_domains.size(0)
            d_loss.backward()
            optimiser_d.step()
    train_losses.append(total_train_loss / total_steps)
    print("train loss = {}".format(total_train_loss / total_steps))
    accuracy = correct / total
    print("train   Accuracy = {}".format(accuracy))
    domain_accuracy = correct_domain / total_domain
    print("train domain   Accuracy = {}".format(domain_accuracy))
    h_score = cluster_triplet(fet_a, label_a, triplet_a)
    h_score_tr.append(h_score)
    tr_acc.append(accuracy)
    eer = compute_eer(y_true_tr, y_one_tr)
    print("train eer = {}".format(eer))
    dom_tr.append(domain_accuracy)

    val_loss = 0.0
    val_steps = 0
    total_val = 0
    correct_val = 0
    y_pred_v = []
    y_true_v = []
    y_one_v = []
    domain_pred_v = []
    domain_true_v = []
    fet_a_v = []
    label_a_v = []
    domain_a_v = []
    triplet_a_v = []
    correct_domain_val = 0
    total_domain_val = 0
    model.eval()
    domain_classifier.eval()
    with torch.no_grad():
        for inp, target, domain, triplet in val_loader:
            inp = inp.to(device)
            target = target.to(device)
            domain = domain.to(device)
            triplet = triplet.to(device)
            middle_output1, middle_output2, middle_output3, prediction, features = model(inp)
            fet_a_v.append(features)
            label_a_v.append(target)
            domain_a_v.append(domain)
            triplet_a_v.append(triplet)
            spoof_zero_mask = (target == 0).to(device)
            spoof_zero_inputs = features[spoof_zero_mask]
            spoof_zero_domains = domain[spoof_zero_mask]
            if spoof_zero_inputs.size(0) > 0:
                domain_output = domain_classifier(spoof_zero_inputs)
                d_loss = domain_loss(domain_output, spoof_zero_domains)
                probabilities_domain = softmax(domain_output.detach())
                predicted_domains = torch.argmax(probabilities_domain, dim=1)
                correct_domain_val += (predicted_domains == spoof_zero_domains).sum().item()
                total_domain_val += spoof_zero_domains.size(0)
                domain_pred_v.extend(predicted_domains.data.cpu().numpy())
                domain_true_v.extend(spoof_zero_domains.data.cpu().numpy())
            else:
                # Don't add a stale domain loss from an earlier batch.
                d_loss = 0.0
            temp4 = prediction / temperature
            temp4 = torch.softmax(temp4, dim=1)
            probabilities = softmax(prediction)
            predicted_classes = torch.argmax(probabilities, dim=1)
            correct_val += (predicted_classes == target).sum().item()
            total_val += target.size(0)
            y_pred_v.extend(predicted_classes.data.cpu().numpy())
            y_true_v.extend(target.data.cpu().numpy())
            loss1by4 = kd_loss_function(middle_output1, temp4.detach()) * (temperature**2)
            loss2by4 = kd_loss_function(middle_output2, temp4.detach()) * (temperature**2)
            loss3by4 = kd_loss_function(middle_output3, temp4.detach()) * (temperature**2)
            loss = loss_fn(prediction, target)
            triplet_loss, _ = criterion_triplet(features, triplet)
            total_loss_val = (1 - alpha) * (loss1by4 + loss2by4 + loss3by4) + alpha * loss + triplet_loss + d_loss
            val_loss += total_loss_val.detach().cpu().numpy()
            val_steps += 1
            y_one_v.extend(probabilities[:, 1].data.cpu().numpy())
    val_accuracy = correct_val / total_val
    domain_acc_val = correct_domain_val / total_domain_val
    print("val  loss = {}".format(val_loss / val_steps))
    val_losses.append(val_loss / val_steps)
    print("val  Accuracy = {}".format(val_accuracy))
    # h_score = cluster_triplet(fet_a_v, label_a_v, triplet_a_v)
    # h_score_v.append(h_score)
    val_acc.append(val_accuracy)
    print("val  domain Accuracy = {}".format(domain_acc_val))
    eer = compute_eer(y_true_v, y_one_v)
    print("val eer = {}".format(eer))
    dom_v.append(domain_acc_val)
    return (val_accuracy, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a_v, label_a_v,
            domain_a_v, domain_acc_val, domain_pred_v, domain_true_v, y_one_tr, y_one_v)

def _plot_epoch_history(title, series):
    """Plot each ``(label, values)`` pair against 1-based epoch numbers in a fresh figure.

    Each series uses its own length for the x axis, so a history that is shorter
    than the others (or empty) does not crash the plot.
    """
    plt.figure()  # never draw on top of whatever figure happens to be current
    plt.title(title)
    for label, values in series:
        plt.plot(range(1, len(values) + 1), values, label=label)
    plt.legend()
    plt.show()


def _plot_normalized_confusion(y_true, y_pred, class_names):
    """Draw a row-normalised confusion-matrix heatmap (rows = true class) in a new figure.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        class_names: display names for the rows/columns, in sorted-label order.
    """
    cm = confusion_matrix(y_true, y_pred)  # sklearn order: (y_true, y_pred)
    row_sums = cm.sum(axis=1, keepdims=True)
    # A class with no true samples would give 0/0 = NaN; show zeros for that row instead.
    normalized = np.divide(
        cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0
    )
    names = list(class_names)
    df_cm = pd.DataFrame(normalized, index=names, columns=names)
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)


def train():
    """Train ``eca_resnet18`` jointly with a domain discriminator, then plot diagnostics.

    Splits ``UrbanSoundDataset`` 80/20 at random into train/validation subsets and
    runs ``train_single_epoch`` for ``epochs`` epochs. It then plots:

    * loss and accuracy curves;
    * row-normalised confusion matrices (train, validation, validation domain)
      taken from the final epoch;
    * a ``tsned`` projection of the final-epoch features;
    * DET curves for train and validation.

    Returns:
        The trained ``eca_resnet18`` model (left on the training device).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}")
    epochs = 20

    usd = UrbanSoundDataset(ANNOTATIONS_FILE, DOMAIN_PATH, AUDIO_DIR, device)
    train_size = int(len(usd) * 0.8)
    train_subset, val_subset = random_split(usd, [train_size, len(usd) - train_size])
    train_dataloader = create_data_loader(train_subset, BATCH_SIZE)
    val_loader = create_data_loader(val_subset, BATCH_SIZE)

    loss_fn = torch.nn.CrossEntropyLoss()
    domain_loss = torch.nn.CrossEntropyLoss()
    model = eca_resnet18().to(device)
    domain_model = domain_disc().to(device)
    alpha = al
    optimiser = torch.optim.SGD(model.parameters(), lr=lr1)
    optimiser_domain = torch.optim.SGD(domain_model.parameters(), lr=lrd)
    criterion_triplet = OnlineTripleLoss(margin=0.1, sampling_strategy="random_sh")

    # Per-epoch histories; train_single_epoch appends to these lists in place.
    train_losses = []
    val_losses = []
    h_score_v = []
    h_score_tr = []
    tr_acc = []
    val_acc = []
    dom_tr = []
    dom_v = []

    for i in range(epochs):
        print(f"Epoch {i+1}")
        # Only the last epoch's predictions/features are kept for the plots below.
        (_, y_pred_tr, y_true_tr, y_pred_v, y_true_v, fet_a, label_a, domain_a,
         _, domain_pred_v, domain_true_v, y_one_tr, y_one_v) = train_single_epoch(
            model, train_dataloader, val_loader, loss_fn, optimiser, device, i,
            alpha, train_losses, val_losses, optimiser_domain, domain_model,
            domain_loss, h_score_tr, h_score_v, tr_acc, val_acc, dom_tr, dom_v,
            criterion_triplet)
        print("---------------------------")
    print("Finished training")

    _plot_epoch_history('Training and Validation Loss over Epochs',
                        [('Training Loss', train_losses),
                         ('Validation Loss', val_losses)])

    classes = ('genuine', 'spoof')
    classes1 = ('D 0', 'D 1', 'D 2')
    _plot_normalized_confusion(y_true_tr, y_pred_tr, classes)
    _plot_normalized_confusion(y_true_v, y_pred_v, classes)
    # Ground truth first: the original passed (pred, true), transposing this matrix.
    _plot_normalized_confusion(domain_true_v, domain_pred_v, classes1)
    tsned(fet_a, label_a, domain_a)

    _plot_epoch_history('NMI and Accuracy Across Epochs',
                        [('Training NMI', h_score_tr),
                         ('Training Accuracy', tr_acc)])
    _plot_epoch_history('Training and Validation Accuracy over Epochs',
                        [('Training Accuracy', tr_acc),
                         ('Validation Accuracy', val_acc)])
    _plot_epoch_history('Training Accuracy of Domain Discriminator over Epochs',
                        [('Training Accuracy', dom_tr)])

    plot_det(y_true_tr, y_one_tr)
    plot_det(y_true_v, y_one_v)
    return model


if __name__ == "__main__":
    # Train the model, then persist only its weights (state_dict), not the module object.
    model_path = "model_complete.pth"
    model_ev = train()
    torch.save(model_ev.state_dict(), model_path)
    # Report the file actually written; the old message named "feedforwardnet.pth",
    # which this script never creates.
    print(f"Trained model saved at {model_path}")