In [None]:
!pip install torchaudio
!pip install librosa
!pip install matplotlib
!pip install seaborn
!pip install scikit-learn

In [None]:
import torch
import torchaudio
import torchaudio.transforms as transforms
import torch.nn as nn
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
from torch.utils.data import DataLoader, Dataset, Subset
import matplotlib.pyplot as plt
import librosa.display
from IPython.display import Audio, display
from torchvision import models
import torch.optim as optim
import torch.nn.functional as F
import os
from tqdm import tqdm
import random
import time
import seaborn as sns
from torchvision.models import ResNet18_Weights
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils.rnn import pad_sequence
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.manifold import TSNE

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# Seed every random number generator the notebook relies on (PyTorch, the
# standard library and NumPy) with the same value for reproducible runs.
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(42)


In [None]:

# Download (if not already present) and open the LibriSpeech "train-clean-100"
# split in the current working directory.
# `Audio` and `display` are already imported in the notebook's import cell, so the
# duplicate mid-notebook import that used to follow this line has been dropped.
librispeech_dataset = torchaudio.datasets.LIBRISPEECH(".", url="train-clean-100", download=True)

def play_audio(waveform, sample_rate):
    """Render an inline audio player for a waveform tensor.

    Args:
        waveform: Tensor holding the audio samples; converted with ``.numpy()``.
        sample_rate: Playback rate in Hz passed to ``IPython.display.Audio``.
    """
    samples = waveform.numpy()
    player = Audio(samples, rate=sample_rate)
    display(player)
# Look at the very first utterance of the corpus: print its metadata and play it.
first_item = librispeech_dataset[0]
sample_waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = first_item
for summary_line in (
    f"Sample Rate: {sample_rate}",
    f"Speaker ID: {speaker_id}, Chapter ID: {chapter_id}, Utterance ID: {utterance_id}",
    f"Transcript: {utterance}",
    f"Waveform shape: {sample_waveform.shape}",
):
    print(summary_line)
play_audio(sample_waveform, sample_rate)


In [None]:
# Hop length (in samples) between consecutive MFCC frames. 200 is what torchaudio
# uses by default (n_fft=400, hop_length=n_fft // 2); it is spelled out here because
# librosa's specshow assumes hop_length=512 unless told otherwise, which would
# stretch the time axis of the MFCC plot by a factor of 512 / 200.
VIS_HOP_LENGTH = 200
mfcc_transform = transforms.MFCC(
    sample_rate=16000,
    n_mfcc=13,
    melkwargs={"hop_length": VIS_HOP_LENGTH},
)

# Plot the waveform (top) and its MFCCs (bottom) for the first ten utterances.
for i in range(10):
    sample_waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = librispeech_dataset[i]
    plt.figure(figsize=(14, 5))

    plt.subplot(2, 1, 1)
    librosa.display.waveshow(sample_waveform.numpy(), sr=sample_rate)
    plt.title(f'Waveform of Sample {i+1} - Speaker: {speaker_id}')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')

    mfcc_features = mfcc_transform(sample_waveform)
    plt.subplot(2, 1, 2)
    # Pass the real hop length so the x axis is in true seconds and lines up
    # with the waveform panel above.
    librosa.display.specshow(
        mfcc_features[0].numpy(),
        sr=sample_rate,
        hop_length=VIS_HOP_LENGTH,
        x_axis='time',
    )
    plt.colorbar(format='%+2.0f dB')
    plt.title(f'MFCCs of Sample {i+1}')
    plt.xlabel('Time (s)')
    plt.ylabel('MFCC Coefficients')

    plt.tight_layout()
    plt.show()


In [None]:
def augment_waveform(waveform, sample_rate, noise_std=0.005, gain=0.5, target_sample_rate=16000):
    """Resample a waveform and perturb it with additive noise and a gain change.

    Args:
        waveform: Audio tensor to augment.
        sample_rate: Sample rate (Hz) of ``waveform``.
        noise_std: Standard deviation of the zero-mean Gaussian noise that is added.
        gain: Gain handed to ``torchaudio.transforms.Vol``.
        target_sample_rate: Sample rate (Hz) of the returned waveform.

    Returns:
        The augmented waveform, sampled at ``target_sample_rate``. Callers must use
        that rate (not ``sample_rate``) for any further processing.
    """
    resampler = transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
    waveform = resampler(waveform)
    waveform = waveform + torch.randn_like(waveform) * noise_std
    return transforms.Vol(gain=gain)(waveform)

def get_mfcc(waveform, sample_rate, max_length=500):
    """Turn a waveform into a fixed-length, per-coefficient normalised MFCC tensor.

    The audio is resampled to 16 kHz, converted to 13 MFCCs (512-point FFT,
    80 mel bands, hop of 160 samples), cropped to at most ``max_length`` frames,
    normalised along time and finally zero-padded up to ``max_length`` frames.

    The mean/std are computed on the real frames only, *before* padding.
    Normalising after padding would let the artificial zero frames drag the
    statistics of short utterances; padding afterwards with zeros is equivalent
    to padding with the (normalised) mean.

    Args:
        waveform: Audio tensor whose last dimension is time.
        sample_rate: Sample rate (Hz) of ``waveform``.
        max_length: Number of MFCC frames in the output.

    Returns:
        Tensor with a leading singleton dimension added in front of the MFCC
        output, i.e. ``(1, *mfcc_leading_dims, max_length)``.
    """
    waveform = transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    mfcc_transform = transforms.MFCC(
        sample_rate=16000,
        n_mfcc=13,
        melkwargs={"n_fft": 512, "n_mels": 80, "hop_length": 160}
    )
    # Ellipsis indexing keeps this independent of the number of leading dims.
    mfcc = mfcc_transform(waveform)[..., :max_length]

    mfcc_mean = mfcc.mean(dim=-1, keepdim=True)
    if mfcc.size(-1) > 1:
        mfcc_std = mfcc.std(dim=-1, keepdim=True)
    else:
        # The (unbiased) std of a single frame is NaN; treat it as zero spread.
        mfcc_std = torch.zeros_like(mfcc_mean)
    mfcc_normalized = (mfcc - mfcc_mean) / (mfcc_std + 1e-6)

    missing_frames = max_length - mfcc_normalized.size(-1)
    if missing_frames > 0:
        mfcc_normalized = F.pad(mfcc_normalized, (0, missing_frames))
    return mfcc_normalized.unsqueeze(0)



In [None]:

# Work on a random 5000-utterance subset of the corpus: the first 80% of the
# sampled indices are used for training, the remainder for validation.
batch_size = 16
subset_size = 5000
train_size = int(0.8 * subset_size)

indices = random.sample(range(len(librispeech_dataset)), subset_size)
train_indices, val_indices = indices[:train_size], indices[train_size:]

train_dataset, val_dataset = (
    Subset(librispeech_dataset, train_indices),
    Subset(librispeech_dataset, val_indices),
)


In [None]:
class TripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) feature triplets for metric learning.

    ``dataset`` must yield 6-tuples whose first element is the waveform, second the
    sample rate and fourth the speaker id. One triplet is drawn per utterance whose
    speaker has at least two utterances; the triplets are drawn once, at
    construction time.

    Args:
        dataset: Indexable, iterable source of utterance tuples.
        transform: Callable ``transform(waveform, sample_rate)`` returning features
            with a leading singleton dimension (removed by ``__getitem__``).
        augment: When true, every waveform goes through ``augment_waveform`` first.
    """

    # augment_waveform() returns audio resampled to 16 kHz, so augmented audio must
    # be handed to the transform with this rate instead of the file's original one
    # (otherwise it would be resampled a second time with the wrong source rate).
    AUGMENTED_SAMPLE_RATE = 16000

    def __init__(self, dataset, transform, augment=False):
        self.dataset = dataset
        self.transform = transform
        self.augment = augment
        self.speaker_indices = self._build_speaker_indices()
        self.triplets = self._generate_triplets()

    def _build_speaker_indices(self):
        """Map each speaker id to the list of dataset positions of its utterances."""
        speaker_indices = {}
        for i, (_, _, _, speaker_id, _, _) in enumerate(self.dataset):
            speaker_indices.setdefault(speaker_id, []).append(i)
        return speaker_indices

    def _generate_triplets(self):
        """Draw one (anchor, positive, negative) index triple per eligible utterance."""
        triplets = []
        speaker_ids = list(self.speaker_indices.keys())
        # A negative needs a second speaker; without one no triplet can be formed.
        if len(speaker_ids) < 2:
            return triplets

        for speaker_id in speaker_ids:
            indices = self.speaker_indices[speaker_id]
            # A positive needs a second utterance of the same speaker.
            if len(indices) < 2:
                continue
            other_speaker_ids = [other for other in speaker_ids if other != speaker_id]

            for anchor_idx in indices:
                positive_idx = random.choice([idx for idx in indices if idx != anchor_idx])
                negative_speaker_id = random.choice(other_speaker_ids)
                negative_idx = random.choice(self.speaker_indices[negative_speaker_id])
                triplets.append((anchor_idx, positive_idx, negative_idx))

        random.shuffle(triplets)
        return triplets

    def _load_features(self, index):
        """Load one utterance and return its (optionally augmented) features."""
        waveform, sample_rate, _, _, _, _ = self.dataset[index]
        if self.augment:
            waveform = augment_waveform(waveform, sample_rate)
            sample_rate = self.AUGMENTED_SAMPLE_RATE
        # Each utterance is transformed with its *own* sample rate.
        return self.transform(waveform, sample_rate).squeeze(0)

    def __getitem__(self, idx):
        anchor_idx, positive_idx, negative_idx = self.triplets[idx]
        anchor = self._load_features(anchor_idx)
        positive = self._load_features(positive_idx)
        negative = self._load_features(negative_idx)
        return anchor, positive, negative

    def __len__(self):
        return len(self.triplets)



In [None]:
class PairDataset(Dataset):
    """Balanced dataset of same-speaker / different-speaker utterance pairs.

    ``dataset`` must yield 6-tuples whose first element is the waveform, second the
    sample rate and fourth the speaker id. Every same-speaker combination becomes a
    positive pair (label 1); an equal number of randomly drawn different-speaker
    pairs is added as negatives (label 0).

    Args:
        dataset: Indexable, iterable source of utterance tuples.
        transform: Callable ``transform(waveform, sample_rate)`` returning features
            with a leading singleton dimension (removed by ``__getitem__``).
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform
        self.speaker_indices = self._build_speaker_indices()
        self.pairs = self._generate_pairs()

    def _build_speaker_indices(self):
        """Map each speaker id to the list of dataset positions of its utterances."""
        speaker_indices = {}
        for i, (_, _, _, speaker_id, _, _) in enumerate(self.dataset):
            speaker_indices.setdefault(speaker_id, []).append(i)
        return speaker_indices

    def _generate_pairs(self):
        """Build the shuffled list of ``(index1, index2, label)`` pairs."""
        pairs = []
        speaker_ids = list(self.speaker_indices.keys())
        for speaker_id in speaker_ids:
            indices = self.speaker_indices[speaker_id]
            for i in range(len(indices)):
                for j in range(i + 1, len(indices)):
                    pairs.append((indices[i], indices[j], 1))
        num_positive_pairs = len(pairs)

        # With fewer than two speakers no negative pair exists; sampling for one
        # would loop forever, so negatives are skipped in that case.
        if len(speaker_ids) >= 2:
            # Reverse lookup built from the index above, so that drawing negatives
            # does not have to load (and decode) audio just to read a speaker id.
            speaker_of = {
                idx: speaker_id
                for speaker_id, idxs in self.speaker_indices.items()
                for idx in idxs
            }
            num_items = len(self.dataset)
            num_negative_pairs = 0
            while num_negative_pairs < num_positive_pairs:
                idx1 = random.randrange(num_items)
                idx2 = random.randrange(num_items)
                if speaker_of[idx1] != speaker_of[idx2]:
                    pairs.append((idx1, idx2, 0))
                    num_negative_pairs += 1

        random.shuffle(pairs)
        return pairs

    def __getitem__(self, idx):
        idx1, idx2, label = self.pairs[idx]
        waveform1, sample_rate1, _, _, _, _ = self.dataset[idx1]
        waveform2, sample_rate2, _, _, _, _ = self.dataset[idx2]

        sample1 = self.transform(waveform1, sample_rate1).squeeze(0)
        sample2 = self.transform(waveform2, sample_rate2).squeeze(0)

        return sample1, sample2, label

    def __len__(self):
        return len(self.pairs)



In [None]:
def collate_fn_triplet(batch):
    """Collate a list of (anchor, positive, negative) samples into three batches.

    Each role is batched with ``pad_sequence(..., batch_first=True)``, which pads
    along the first dimension of the samples and stacks them on a new batch axis.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of batched tensors.
    """
    by_role = zip(*batch)
    anchors_padded, positives_padded, negatives_padded = (
        pad_sequence(samples, batch_first=True) for samples in by_role
    )
    return anchors_padded, positives_padded, negatives_padded

def collate_fn_eval(batch):
    """Collate a list of (sample1, sample2, label) items for pairwise evaluation.

    Both sample groups are batched with ``pad_sequence(..., batch_first=True)``;
    the labels become a 1-D ``float32`` tensor.

    Returns:
        Tuple ``(sample1_batch, sample2_batch, labels)``.
    """
    firsts, seconds, targets = zip(*batch)
    label_tensor = torch.tensor(targets, dtype=torch.float32)
    return (
        pad_sequence(firsts, batch_first=True),
        pad_sequence(seconds, batch_first=True),
        label_tensor,
    )


In [None]:
# Pairwise (same / different speaker) evaluation set built from the validation
# split, served in fixed order by a single-process loader.
eval_dataset = PairDataset(val_dataset, transform=get_mfcc)
eval_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn_eval,
    num_workers=0,
)
eval_loader = DataLoader(eval_dataset, **eval_loader_options)


In [None]:
class TripletLoss(nn.Module):
    """Hinge-style triplet loss on pairwise (Euclidean) distances.

    For every triplet the loss is ``max(0, d(a, p) - d(a, n) + margin)``; the
    batch mean is returned.

    Args:
        margin: Minimum gap required between the negative and positive distance.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss of a batch of embeddings."""
        d_anchor_positive = F.pairwise_distance(anchor, positive)
        d_anchor_negative = F.pairwise_distance(anchor, negative)
        violation = d_anchor_positive - d_anchor_negative + self.margin
        return F.relu(violation).mean()


In [None]:
class AttentionLayer(nn.Module):
    """Additive attention pooling over dimension 1 of the input.

    A score is computed for every position as ``v(tanh(W(x)))``, the scores are
    soft-maxed over dimension 1 and used to form a weighted sum of the inputs.

    Args:
        input_dim: Size of the last dimension of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, input_dim)
        self.v = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Pool ``x`` over dimension 1 using the learned attention weights."""
        hidden = torch.tanh(self.W(x))
        energies = self.v(hidden)
        attention_weights = torch.softmax(energies, dim=1)
        return (attention_weights * x).sum(dim=1)

class SpeakerVerificationSiameseNet(nn.Module):
    """Speaker-embedding network: ResNet-18 trunk -> GRU -> attention -> projection.

    Single-channel feature maps are pushed through the ResNet-18 convolutional
    stages, the resulting spatial grid is flattened into a sequence, encoded by a
    GRU, pooled by ``AttentionLayer`` and projected to a 64-dimensional,
    L2-normalised embedding.

    Args:
        embedding_dim: Hidden size of the GRU and width of the attention layer.
        freeze_layers: When true, the pretrained ResNet-18 weights are frozen.
            The replacement input convolution is randomly initialised and is
            therefore always left trainable.
    """

    def __init__(self, embedding_dim=256, freeze_layers=False):
        super().__init__()
        self.resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        if freeze_layers:
            for param in self.resnet.parameters():
                param.requires_grad = False
        # Swap the stem only AFTER freezing: the new 1-channel convolution has no
        # pretrained weights, so freezing it would leave a fixed random filter bank
        # in front of the whole network.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet.avgpool = nn.Identity()
        self.resnet.fc = nn.Identity()
        self.rnn = nn.GRU(input_size=512, hidden_size=embedding_dim, batch_first=True)
        self.attention_layer = AttentionLayer(embedding_dim)
        self.projection_layer = nn.Sequential(
            nn.BatchNorm1d(embedding_dim),
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )

    def forward(self, x):
        """Embed a batch of single-channel feature maps.

        Args:
            x: 4-D tensor ``(batch, 1, height, width)`` (``conv1`` takes one
                input channel).

        Returns:
            ``(batch, 64)`` tensor of unit-norm embeddings.
        """
        # Convolutional trunk (the avgpool / fc heads are bypassed on purpose).
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        # (batch, channels, H, W) -> (batch, H * W, channels): one step per grid cell.
        x = x.permute(0, 2, 3, 1).flatten(1, 2)
        x, _ = self.rnn(x)
        x = self.attention_layer(x)
        x = self.projection_layer(x)
        return F.normalize(x, p=2, dim=1)


In [None]:
# Reuse the train/validation split created in the earlier split cell.
# Drawing a fresh random subset here (as was done before) would give a training
# set that can overlap with the validation utterances `eval_dataset` was already
# built from, leaking training data into the ROC / EER evaluation.
train_loader = DataLoader(
    TripletDataset(train_dataset, transform=get_mfcc, augment=True),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn_triplet,
    num_workers=0,
    # The model's BatchNorm1d cannot run in training mode on a batch of one
    # sample, which a trailing partial batch could be.
    drop_last=True,
)

val_loader = DataLoader(
    TripletDataset(val_dataset, transform=get_mfcc, augment=False),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn_triplet,
    num_workers=0
)




In [None]:
# Model, optimiser, LR schedule (halve the LR after 5 epochs without a better
# validation loss) and the triplet objective.
model = SpeakerVerificationSiameseNet(embedding_dim=256)
model = model.to(device)

loss_fn = TripletLoss(margin=1.0).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)



In [None]:
num_epochs = 50
best_val_loss = float('inf')


def _triplet_batch_loss(batch):
    """Move one (anchor, positive, negative) batch to `device` and return its loss."""
    anchor, positive, negative = (part.to(device) for part in batch)
    # Embeddings are computed in anchor -> positive -> negative order.
    return loss_fn(model(anchor), model(positive), model(negative))


for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = _triplet_batch_loss(batch)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        if batch_idx % 10 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss}')

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += _triplet_batch_loss(batch).item()

    avg_train_loss = running_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    print(f'Epoch {epoch} completed. Training loss: {avg_train_loss}, Validation loss: {avg_val_loss}')

    # The scheduler and the checkpoint both track the mean validation loss.
    scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_speaker_verification_model.pth')
        print("Model saved with improved validation loss.")



In [None]:
# Restore the best checkpoint written by the training loop.
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU can also be loaded on a CPU-only machine.
best_state = torch.load('best_speaker_verification_model.pth', map_location=device)
model.load_state_dict(best_state)
model.to(device)
model.eval()
print("Model loaded successfully.")


In [None]:
def evaluate_model(model, eval_loader, device):
    """Evaluate the model on labelled pairs: ROC/AUC, precision-recall and EER.

    Each pair is scored with the negative pairwise distance of its two
    embeddings, so larger scores mean "more likely the same speaker".

    Args:
        model: Embedding network; switched to eval mode.
        eval_loader: Iterable of ``(sample1, sample2, label)`` batches, label 1 for
            same-speaker pairs and 0 otherwise.
        device: Device the input batches are moved to.

    Returns:
        Dict with ``"auc"``, ``"eer"`` (fraction in [0, 1]) and ``"threshold"``
        (the score threshold at the EER operating point).
    """
    model.eval()
    all_distances, all_labels = [], []
    with torch.no_grad():
        for anchor, query, label in eval_loader:
            anchor, query = anchor.to(device), query.to(device)
            distance = F.pairwise_distance(model(anchor), model(query))
            all_distances.extend((-distance).cpu().numpy())
            all_labels.extend(label.cpu().numpy())

    # Calculate ROC Curve and AUC
    fpr, tpr, thresholds = roc_curve(all_labels, all_distances)
    auc_score = roc_auc_score(all_labels, all_distances)

    # Plot ROC Curve
    plt.figure()
    plt.plot(fpr, tpr, label=f'ROC curve (AUC = {auc_score:.2f})')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Calculate Precision-Recall Curve
    precision, recall, _ = precision_recall_curve(all_labels, all_distances)
    plt.figure()
    plt.plot(recall, precision, label='Precision-Recall Curve')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Calculate EER: the operating point where the false-negative rate equals the
    # false-positive rate. (The previous root-finding solved tpr(x) == x instead,
    # which is already satisfied at x = 0, so it always reported an EER of 0.)
    # The ROC point with the smallest |FNR - FPR| gap is used; this also avoids
    # interpolating over the non-finite first entry roc_curve may put in
    # `thresholds`.
    fnr = 1 - tpr
    eer_index = int(np.nanargmin(np.abs(fnr - fpr)))
    eer = float((fpr[eer_index] + fnr[eer_index]) / 2)
    eer_threshold = float(thresholds[eer_index])
    print(f"AUC Score: {auc_score:.2f}")
    print(f"EER: {eer * 100:.2f}% at threshold {eer_threshold:.4f}")
    return {"auc": float(auc_score), "eer": eer, "threshold": eer_threshold}


In [None]:

def find_samples_by_speaker(dataset, speaker_id, num_samples=2):
    """Return the positions of the first utterances spoken by ``speaker_id``.

    The dataset is scanned from the start and scanning stops as soon as
    ``num_samples`` matches have been collected.

    Args:
        dataset: Indexable collection of 6-tuples whose fourth element is the
            speaker id.
        speaker_id: Speaker to look for.
        num_samples: Number of matches after which the scan stops.

    Returns:
        List of matching positions, in dataset order (possibly shorter than
        ``num_samples`` when the speaker has fewer utterances).
    """
    matches = []
    position = 0
    total = len(dataset)
    while position < total:
        _wave, _rate, _text, current_speaker, _chapter, _utt = dataset[position]
        if current_speaker == speaker_id:
            matches.append(position)
            if len(matches) == num_samples:
                break
        position += 1
    return matches
# Pick demo material: two utterances of the first utterance's speaker and one
# utterance from each of the first two other speakers found in the corpus.
same_speaker_id = librispeech_dataset[0][3]
same_speaker_indices = find_samples_by_speaker(librispeech_dataset, same_speaker_id, num_samples=2)

diff_speaker_ids = []
for idx in range(len(librispeech_dataset)):
    _, _, _, spk_id, _, _ = librispeech_dataset[idx]
    if spk_id != same_speaker_id and spk_id not in diff_speaker_ids:
        diff_speaker_ids.append(spk_id)
    if len(diff_speaker_ids) == 2:
        break

# Collected with a comprehension so that the module-level `indices` list created
# by the train/validation split is not overwritten as a side effect.
diff_speaker_indices = [
    sample_idx
    for spk_id in diff_speaker_ids
    for sample_idx in find_samples_by_speaker(librispeech_dataset, spk_id, num_samples=1)
]

# Fail here, with a clear message, rather than with an IndexError further down.
assert len(same_speaker_indices) == 2, "need two utterances of the reference speaker"
assert len(diff_speaker_indices) == 2, "need utterances from two other speakers"


In [None]:

# Load the four demo utterances: 1 and 2 share a speaker, 3 and 4 come from two
# other speakers.
(waveform1, sample_rate1, *_), (waveform2, sample_rate2, *_) = (
    librispeech_dataset[same_speaker_indices[0]],
    librispeech_dataset[same_speaker_indices[1]],
)
(waveform3, sample_rate3, *_), (waveform4, sample_rate4, *_) = (
    librispeech_dataset[diff_speaker_indices[0]],
    librispeech_dataset[diff_speaker_indices[1]],
)

# MFCC features on the compute device, with get_mfcc's leading dimension removed.
mfcc1, mfcc2, mfcc3, mfcc4 = (
    get_mfcc(wave, rate).squeeze(0).to(device)
    for wave, rate in (
        (waveform1, sample_rate1),
        (waveform2, sample_rate2),
        (waveform3, sample_rate3),
        (waveform4, sample_rate4),
    )
)


In [None]:

# Embed the four demo utterances (eval mode, no gradient tracking). Each feature
# tensor gets a batch dimension of one before it is fed to the model.
model.eval()

with torch.no_grad():
    embedding1, embedding2, embedding3, embedding4 = [
        model(features.unsqueeze(0)) for features in (mfcc1, mfcc2, mfcc3, mfcc4)
    ]


In [None]:

# Distance from utterance 1 to its same-speaker partner and to the two
# different-speaker utterances.
distance_same, distance_diff1, distance_diff2 = (
    F.pairwise_distance(embedding1, other).item()
    for other in (embedding2, embedding3, embedding4)
)

for report_line in (
    f"Distance between same speaker samples: {distance_same:.4f}",
    f"Distance between different speaker samples (Sample 1 and 3): {distance_diff1:.4f}",
    f"Distance between different speaker samples (Sample 1 and 4): {distance_diff2:.4f}",
):
    print(report_line)


In [None]:
def enroll_user(model, enrollment_audio_path):
    """
    Generate the reference embedding for an enrollment recording.

    Args:
        model: Trained speaker verification model.
        enrollment_audio_path: Path to the enrollment audio file.

    Returns:
        embedding: The embedding of the enrollment audio, on the CPU.
    """
    waveform, sample_rate = torchaudio.load(enrollment_audio_path)
    # Mix multi-channel recordings down to mono so the features have one channel.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # get_mfcc resamples to 16 kHz itself and already returns a leading batch
    # dimension, so its output is fed to the model as is (adding another
    # dimension would produce a 5-D input the convolutional stem rejects).
    model_device = next(model.parameters()).device
    mfcc = get_mfcc(waveform, sample_rate).to(model_device)

    model.eval()
    with torch.no_grad():
        embedding = model(mfcc)
    return embedding.cpu()



In [None]:
def verify_user(model, stored_embedding, verification_audio_path, threshold):
    """
    Verify if the verification audio matches the stored embedding.

    Args:
        model: Trained speaker verification model.
        stored_embedding: The stored embedding from the enrollment phase.
        verification_audio_path: Path to the verification audio file.
        threshold: Similarity threshold for authentication. The similarity is the
            NEGATIVE pairwise distance between the two embeddings, so it is never
            positive and a useful threshold is a negative number.

    Returns:
        is_authenticated (bool): Whether the user is authenticated.
        similarity_score (float): The similarity score between the embeddings.
    """
    waveform, sample_rate = torchaudio.load(verification_audio_path)
    # Mix multi-channel recordings down to mono so the features have one channel.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # get_mfcc resamples to 16 kHz itself and already returns a leading batch
    # dimension, so its output is fed to the model as is.
    model_device = next(model.parameters()).device
    mfcc = get_mfcc(waveform, sample_rate).to(model_device)

    model.eval()
    with torch.no_grad():
        current_embedding = model(mfcc)

    # The stored embedding may live on another device (enroll_user returns it on
    # the CPU); bring both onto the same device before comparing.
    reference = stored_embedding.to(current_embedding.device)
    distance = F.pairwise_distance(reference, current_embedding).item()
    similarity_score = -distance
    is_authenticated = similarity_score > threshold
    return is_authenticated, similarity_score


In [None]:
# verify_user scores a pair with the NEGATIVE distance between the embeddings, so
# the score is never positive and a positive threshold (such as the previous 0.5)
# can never authenticate anyone. -0.5 accepts a pair whose embeddings are less
# than 0.5 apart; replace it with the EER threshold reported by evaluate_model
# once that has been run on held-out pairs.
VERIFICATION_THRESHOLD = -0.5

stored_embedding = enroll_user(model, '/content/enrollment_audio.wav')
is_authenticated, score = verify_user(
    model,
    stored_embedding,
    '/content/verification_audio.wav',
    threshold=VERIFICATION_THRESHOLD,
)

print(f"Authenticated: {is_authenticated}, Similarity Score: {score:.4f}")
