In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import numpy as np
from sklearn.model_selection import train_test_split

In [None]:
def extract_mfcc(waveform, sample_rate, n_mfcc=40, n_fft=400, hop_length=160, n_mels=40):
    """Compute MFCC features for a waveform with torchaudio.

    Args:
        waveform: audio tensor handed straight to ``torchaudio.transforms.MFCC``.
            For a ``(channels, time)`` input the result is
            ``(channels, n_mfcc, frames)``.
        sample_rate: sample rate of ``waveform`` in Hz.
        n_mfcc: number of cepstral coefficients to keep (rows of the output).
        n_fft: STFT window size in samples (400 samples = 25 ms at 16 kHz).
        hop_length: STFT hop in samples (160 samples = 10 ms at 16 kHz).
        n_mels: number of mel filterbank bins; torchaudio requires
            ``n_mfcc <= n_mels``.

    Returns:
        The MFCC tensor produced by the transform.
    """
    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=n_mfcc,
        melkwargs={"n_fft": n_fft, "hop_length": hop_length, "n_mels": n_mels},
    )
    return mfcc_transform(waveform)

In [None]:
class SiameseNetwork(nn.Module):
    """Twin CNN that maps two single-channel feature maps to 64-d embeddings.

    Both inputs go through the same ``conv`` + ``fc`` stack (shared weights),
    so identical inputs always produce identical embeddings.

    Inputs to ``forward_one`` must be batched single-channel maps of shape
    ``(batch, 1, H, W)`` (the first conv has ``in_channels=1``). Any H, W >= 2
    is accepted: the adaptive pool normalises the spatial size to 20 x 20
    before the fully connected layers.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # The Linear layer below needs exactly 128 * 20 * 20 features. Without
            # this pool only 40 x 40 inputs would work; for 40 x 40 inputs the
            # MaxPool output is already 20 x 20 and this layer is the identity.
            # It has no parameters, so state_dict keys are unchanged.
            nn.AdaptiveAvgPool2d((20, 20)),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 20 * 20, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
        )

    def forward_one(self, x):
        """Embed one batch ``(batch, 1, H, W)`` into ``(batch, 64)``."""
        x = self.conv(x)
        # flatten (unlike view) also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        return self.fc(x)

    def forward(self, input1, input2):
        """Embed both inputs with the shared tower; returns ``(emb1, emb2)``."""
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    Label convention: ``0`` marks a matching pair, whose squared distance is
    penalised. ``1`` marks a non-matching pair, penalised by the squared hinge
    ``max(0, margin - distance)``. The result is the mean over the batch.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = nn.functional.pairwise_distance(output1, output2)
        match_term = (1 - label) * dist.pow(2)
        hinge = torch.clamp(self.margin - dist, min=0.0)
        mismatch_term = label * hinge.pow(2)
        return (match_term + mismatch_term).mean()

In [None]:
def load_and_preprocess_data(file_path, sample_rate=16000):
    """Load an audio file and return its MFCC features.

    The audio is resampled to ``sample_rate`` when needed and downmixed to a
    single channel, so the result has one leading channel dimension, as
    ``SiameseNetwork`` (``Conv2d(1, ...)``) requires.

    Args:
        file_path: path of the audio file to read with ``torchaudio.load``.
        sample_rate: target sample rate in Hz.

    Returns:
        MFCC tensor from ``extract_mfcc`` for the mono waveform.
    """
    waveform, sr = torchaudio.load(file_path)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    # Stereo/multichannel files would otherwise produce a multi-channel map
    # that the single-input-channel network rejects; average channels to mono.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return extract_mfcc(waveform, sample_rate)

In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    """Train a pair-embedding model for ``num_epochs`` epochs.

    Args:
        model: module called as ``model(input1, input2)`` that returns
            ``(output1, output2)``.
        train_loader: iterable yielding ``(input1, input2, label)`` batches.
        criterion: callable ``criterion(output1, output2, label)`` returning
            a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        List with the mean batch loss of each epoch. The value is ``nan`` for
        an epoch in which the loader yielded no batches.
    """
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for input1, input2, label in train_loader:
            output1, output2 = model(input1, input2)
            loss = criterion(output1, output2, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Report the epoch average rather than the (noisy) last batch. Guard
        # against an empty loader, which previously raised NameError.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    return epoch_losses


In [None]:
def authenticate(model, user_voice, test_voice, threshold=0.5):
    """Return True when two unbatched feature maps embed closer than ``threshold``.

    Each input gets a batch dimension via ``unsqueeze(0)``. Both are embedded
    with ``model`` and compared by Euclidean (pairwise) distance.

    The forward pass runs in eval mode under ``torch.no_grad()``, so layers
    such as dropout or batch-norm (if ever added) behave deterministically.
    The model's previous train/eval mode is restored afterwards.

    Args:
        model: module called as ``model(x1, x2)`` returning ``(emb1, emb2)``.
        user_voice: enrolled sample features (no batch dimension).
        test_voice: candidate sample features (no batch dimension).
        threshold: maximum distance accepted as a match.

    Returns:
        Python bool: ``distance < threshold``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            emb_user, emb_test = model(user_voice.unsqueeze(0), test_voice.unsqueeze(0))
            distance = nn.functional.pairwise_distance(emb_user, emb_test)
    finally:
        model.train(was_training)
    return distance.item() < threshold

In [None]:
if __name__ == "__main__":
    torch.manual_seed(42)  # reproducible weight init and DataLoader shuffling

    # SiameseNetwork's first Linear layer was sized for 128 * 20 * 20 features,
    # i.e. a 40 x 40 MFCC map (40 coefficients x 40 frames, halved by MaxPool2d).
    # A fixed frame count also lets default_collate stack samples into batches.
    NUM_FRAMES = 40

    def to_model_input(mfcc):
        """Average channels to one and pad/truncate the time axis to NUM_FRAMES."""
        mfcc = mfcc.mean(dim=0, keepdim=True)  # (C, n_mfcc, T) -> (1, n_mfcc, T)
        num_frames = mfcc.size(-1)
        if num_frames < NUM_FRAMES:
            mfcc = nn.functional.pad(mfcc, (0, NUM_FRAMES - num_frames))
        return mfcc[..., :NUM_FRAMES]

    your_voice_samples = [to_model_input(load_and_preprocess_data(f"your_voice_{i}.wav")) for i in range(5)]
    random_voices = [to_model_input(load_and_preprocess_data(f"random_voice_{i}.wav")) for i in range(100)]

    # Label convention used by ContrastiveLoss: 0 = same speaker, 1 = different.
    pairs = []
    labels = []
    for i, anchor in enumerate(your_voice_samples):
        for other in your_voice_samples[i + 1:]:
            pairs.append((anchor, other))
            labels.append(0)

        for random_voice in random_voices:
            pairs.append((anchor, random_voice))
            labels.append(1)

    # Only 10 genuine pairs vs 500 impostor pairs: stratify so both splits
    # contain genuine pairs.
    X_train, X_test, y_train, y_test = train_test_split(
        pairs, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Items must be flat (input1, input2, label) triplets. With ((a, b), label)
    # items each collated batch has only 2 elements, and train_model's 3-way
    # unpack fails.
    train_data = [(a, b, y) for (a, b), y in zip(X_train, y_train)]
    test_data = [(a, b, y) for (a, b), y in zip(X_test, y_test)]
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

    model = SiameseNetwork()
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(model.parameters())

    train_model(model, train_loader, criterion, optimizer)

    # Held-out contrastive loss on the pairs reserved by train_test_split.
    model.eval()
    with torch.no_grad():
        test_losses = []
        for input1, input2, label in test_loader:
            output1, output2 = model(input1, input2)
            test_losses.append(criterion(output1, output2, label).item())
    if test_losses:
        print(f"Test Loss: {sum(test_losses) / len(test_losses):.4f}")

    test_voice = to_model_input(load_and_preprocess_data("test_voice.wav"))
    for i, user_voice in enumerate(your_voice_samples):
        is_authenticated = authenticate(model, user_voice, test_voice)
        print(f"Authentication result for sample {i+1}: {'Authenticated' if is_authenticated else 'Not Authenticated'}")