In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import numpy as np
import os
from pathlib import Path
import soundfile as sf
from pydub import AudioSegment
import tempfile

In [10]:
# todo: switch all this data preprocessing to the yet to implement PySpark module.
class AudioPreprocessor:
    def __init__(self, sample_rate=16000, duration=3, n_mels=40, n_fft=1024, hop_length=512):
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = sample_rate * duration
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length

    def convert_audio_to_wav(self, input_path):
        try:
            temp_dir = tempfile.mkdtemp()
            temp_path = os.path.join(temp_dir, 'temp.wav')
            
            audio = AudioSegment.from_file(input_path)
            
            if audio.channels > 1:
                audio = audio.set_channels(1)
            
            audio = audio.set_frame_rate(self.sample_rate)
            
            audio.export(temp_path, format='wav')
            
            return temp_path
        except Exception as e:
            print(f"Error converting audio file {input_path}: {str(e)}")
            raise

    def load_and_preprocess_file(self, file_path):
        try:
            file_path = str(Path(file_path).resolve())
            
            print(f"Processing file...")
            
            temp_wav_path = self.convert_audio_to_wav(file_path)
            
            try:
                waveform, sr = sf.read(temp_wav_path)
                print(f"Successfully loaded audio file. Sample rate: {sr}, Shape: {waveform.shape}")
                
                os.remove(temp_wav_path)
                os.rmdir(os.path.dirname(temp_wav_path))
            
                if sr != self.sample_rate:
                    waveform = librosa.resample(waveform, orig_sr=sr, target_sr=self.sample_rate)
                
                if len(waveform) < self.n_samples:
                    waveform = np.pad(waveform, (0, self.n_samples - len(waveform)))
                else:
                    waveform = waveform[:self.n_samples]
                
                mel_spec = librosa.feature.melspectrogram(
                    y=waveform, 
                    sr=self.sample_rate,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length
                )
                
                mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
                
                mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-9)
                
                mel_spec_tensor = torch.FloatTensor(mel_spec_db).unsqueeze(0)
                
                return mel_spec_tensor
                
            except Exception as e:
                print(f"Error processing converted WAV file: {str(e)}")

                if os.path.exists(temp_wav_path):
                    os.remove(temp_wav_path)
                    os.rmdir(os.path.dirname(temp_wav_path))
                raise
        
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            raise

In [11]:
# todo: switch this CNN to a siamese network
class VoiceAuthNetwork(nn.Module):
    def __init__(self, input_shape=(40, 94), threshold=0.5):
        super(VoiceAuthNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        
        conv_output_h = input_shape[0] // 2
        conv_output_w = input_shape[1] // 2
        self.conv_output_size = 128 * conv_output_h * conv_output_w
        
        self.fc = nn.Sequential(
            nn.Linear(self.conv_output_size, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64)
        )
        self.threshold = threshold
        self.register_buffer('center', torch.zeros(64))
        self.register_buffer('n_samples', torch.tensor(0))

    def forward_one(self, x):
        x = self.conv(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)
        return x

    def compute_center(self, owner_samples):
        self.eval()
        with torch.no_grad():
            embeddings = []
            for sample in owner_samples:
                embedding = self.forward_one(sample.unsqueeze(0))
                embeddings.append(embedding)
            
            embeddings = torch.cat(embeddings, dim=0)
            self.center = embeddings.mean(dim=0)
            self.n_samples = torch.tensor(len(owner_samples))

    def authenticate(self, audio_sample):
        self.eval()
        with torch.no_grad():
            embedding = self.forward_one(audio_sample.unsqueeze(0))
            distance = F.pairwise_distance(embedding, self.center.unsqueeze(0))
            return distance.item() < self.threshold

    def update_threshold(self, owner_samples, percentile=95):
        self.eval()
        with torch.no_grad():
            distances = []
            for sample in owner_samples:
                embedding = self.forward_one(sample.unsqueeze(0))
                distance = F.pairwise_distance(embedding, self.center.unsqueeze(0))
                distances.append(distance.item())
            
            self.threshold = np.percentile(distances, percentile)


In [12]:
def prepare_data(data_folder='data/owner'):
    preprocessor = AudioPreprocessor()
    processed_samples = []
    
    try:
        data_folder = str(Path(data_folder).resolve())
        print(f"Looking for audio files in: {data_folder}")
        
        if not os.path.isdir(data_folder):
            raise ValueError(f"Directory not found: {data_folder}")
        
        audio_extensions = ('.wav', '.mp3', '.m4a', '.flac', '.ogg')
        audio_files = [f for f in os.listdir(data_folder) 
                       if f.lower().endswith(audio_extensions)]
        
        if not audio_files:
            raise ValueError(f"No audio files found in {data_folder}")
        
        print(f"Found {len(audio_files)} audio files")
        
        for audio_file in audio_files:
            file_path = os.path.join(data_folder, audio_file)
            try:
                processed_sample = preprocessor.load_and_preprocess_file(file_path)
                processed_samples.append(processed_sample)
                print(f"Successfully processed: {audio_file}")
            except Exception as e:
                print(f"Failed to process {audio_file}: {str(e)}")
        
        if not processed_samples:
            raise ValueError("No audio files were successfully processed")
        
        processed_samples = torch.stack(processed_samples)
        
        print(f"Final tensor shape: {processed_samples.shape}")
        return processed_samples
    
    except Exception as e:
        print(f"Error preparing data: {str(e)}")
        raise

In [13]:
def diagnose_audio_file(file_path):
    try:
        audio = AudioSegment.from_file(file_path)
        print(f"File: {file_path}")
        print(f"Channels: {audio.channels}")
        print(f"Sample width: {audio.sample_width}")
        print(f"Frame rate: {audio.frame_rate}")
        print(f"Frame count: {len(audio)}")
        print(f"Duration: {len(audio) / 1000.0} seconds")
    except Exception as e:
        print(f"Error diagnosing file {file_path}: {str(e)}")


In [14]:
def train_voice_auth(model, owner_samples, epochs=10):
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for i, sample in enumerate(owner_samples):
            optimizer.zero_grad()
            
            embedding = model.forward_one(sample.unsqueeze(0))
            
            if model.n_samples > 0:
                loss = F.mse_loss(embedding, model.center.unsqueeze(0))
            else:
                loss = torch.tensor(0.0, requires_grad=True)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        model.compute_center(owner_samples)
        
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(owner_samples):.4f}")
    
    model.update_threshold(owner_samples)
    return model

In [15]:
def authenticate_audio(model, audio_file):
    preprocessor = AudioPreprocessor()
    try:
        audio_tensor = preprocessor.load_and_preprocess_file(audio_file)
        return model.authenticate(audio_tensor)
    except Exception as e:
        print(f"Error during authentication: {str(e)}")
        return False

In [16]:
if __name__ == "__main__":
    try:
        data_folder = 'data/owner' 
        
        print("Diagnosing audio files...")
        for audio_file in os.listdir(data_folder):
            if audio_file.lower().endswith(('.wav', '.mp3', '.m4a', '.flac', '.ogg')):
                diagnose_audio_file(os.path.join(data_folder, audio_file))
        
        print("\nStarting data preparation...")
        owner_samples = prepare_data(data_folder=data_folder)
        
        input_shape = tuple(owner_samples[0].shape[1:])
        print(f"Input shape determined as: {input_shape}")
        
        print("Creating and training model...")
        model = VoiceAuthNetwork(input_shape=input_shape)
        model = train_voice_auth(model, owner_samples)
        
        print("Saving model...")
        save_path = 'voice_auth_model.pth'
        torch.save({
            'model_state_dict': model.state_dict(),
            'input_shape': input_shape,
            'threshold': model.threshold,
            'center': model.center
        }, save_path)
        print(f"Model saved to: {save_path}")
        
        test_file = os.path.join(data_folder, '0_owner.wav') 
        if os.path.exists(test_file):
            print(f"Testing authentication with file: {test_file}")
            is_authenticated = authenticate_audio(model, test_file)
            print(f"Authentication result: {'Accepted' if is_authenticated else 'Rejected'}")
        else:
            print(f"Test file not found: {test_file}")
        
    except Exception as e:
        print(f"An error occurred: {str(e)}")

Diagnosing audio files...
File: data/owner/1_owner.wav
Channels: 1
Sample width: 4
Frame rate: 48000
Frame count: 8400
Duration: 8.4 seconds
File: data/owner/0_owner.wav
Channels: 1
Sample width: 4
Frame rate: 48000
Frame count: 4320
Duration: 4.32 seconds
File: data/owner/3_owner.wav
Channels: 1
Sample width: 4
Frame rate: 48000
Frame count: 4500
Duration: 4.5 seconds
File: data/owner/2_owner.wav
Channels: 1
Sample width: 4
Frame rate: 48000
Frame count: 5100
Duration: 5.1 seconds

Starting data preparation...
Looking for audio files in: /Users/rishiviswanathan/Desktop/voice-auth/src/models/data/owner
Found 4 audio files
Processing file...
Successfully loaded audio file. Sample rate: 16000, Shape: (134400,)
Successfully processed: 1_owner.wav
Processing file...
Successfully loaded audio file. Sample rate: 16000, Shape: (69120,)
Successfully processed: 0_owner.wav
Processing file...
Successfully loaded audio file. Sample rate: 16000, Shape: (72000,)
Successfully processed: 3_owner.wav
