In [1]:
! pip install  python-speech-features dill

Collecting python-speech-features
  Downloading python_speech_features-0.6.tar.gz (5.6 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: python-speech-features
  Building wheel for python-speech-features (setup.py) ... [?25ldone
[?25h  Created wheel for python-speech-features: filename=python_speech_features-0.6-py3-none-any.whl size=5868 sha256=855a67f7d12b76bb7cd095355adc7fcc8e166dc9352c457d8a9ade145bff3e76
  Stored in directory: /root/.cache/pip/wheels/5a/9e/68/30bad9462b3926c29e315df16b562216d12bdc215f4d240294
Successfully built python-speech-features
Installing collected packages: python-speech-features
Successfully installed python-speech-features-0.6


In [None]:
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json

In [None]:
# ! kaggle datasets download -d abhirupadhikary/kathbath-hindi

In [None]:
# ! unzip kathbath-hindi.zip

In [2]:
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import torch.optim as optim
import torch.nn.init as init 
import logging
import os
from collections import defaultdict
from pathlib import Path
from glob import glob
import librosa
from python_speech_features import fbank
from tqdm.notebook import tqdm
import json
from collections import deque, Counter
from random import choice
from time import time

import dill


In [3]:
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000
NUM_FBANKS = 64
NUM_FRAMES = 160

In [64]:
class ClippedRelu(nn.Module):
  def __init__(self):
    super(ClippedRelu, self).__init__()
  def forward(self,x):
    return torch.min(torch.max(x,torch.tensor(0.0)),torch.tensor(20.0))

In [65]:
class IdentityBlock(nn.Module):
    def __init__(self,in_channels,kernel_size,filters,l2=0.0001):
        super(IdentityBlock, self).__init__()
        self.kernel_size = kernel_size
        self.filters = filters
        self.l2=l2
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=filters,
                               kernel_size=kernel_size,
                               stride=1
                               ,bias=True,
                               padding="same")
        init.xavier_uniform_(self.conv1.weight)
        self.bn1 = nn.BatchNorm2d(num_features=filters)
        self.clipped_relu1 = ClippedRelu()
        self.conv2 = nn.Conv2d(in_channels=filters,
                               out_channels=filters,
                               kernel_size=kernel_size,
                               stride=1
                               ,bias=True,
                               padding="same")
        init.xavier_uniform_(self.conv2.weight)
        self.bn2 = nn.BatchNorm2d(num_features=filters)
        self.clipped_relu2=ClippedRelu()
        self.clipped_relu3=ClippedRelu()

    def forward(self,x):
        x_shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.clipped_relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.clipped_relu2(x)
        x = x + x_shortcut
        return self.clipped_relu3(x)
    
    def l2_regularization(self):
        n=torch.norm(self.conv1.weight,p=2)+torch.norm(self.conv2.weight,p=2)
        return n*self.l2

In [66]:
class ConvResBlock(nn.Module):
  def __init__(self,in_channels,filters,l2=0.0001) -> None:
     super(ConvResBlock,self).__init__()
     self.filters=filters
     self.l2=l2
     self.conv1=nn.Conv2d(in_channels=in_channels
                          ,out_channels=filters,
                          kernel_size=5
                          ,stride=2
                          ,bias=True
                          ,padding=2)
     init.xavier_uniform_(self.conv1.weight)
     self.bn1=nn.BatchNorm2d(num_features=filters)
     self.clipped_relu1=ClippedRelu()
     self.id_block1=IdentityBlock(in_channels=filters,kernel_size=3,filters=filters,l2=l2)
     self.id_block2=IdentityBlock(in_channels=filters,kernel_size=3,filters=filters,l2=l2)
     self.id_block3=IdentityBlock(in_channels=filters,kernel_size=3,filters=filters,l2=l2)

  def forward(self,x):
    x=self.conv1(x)
    x=self.bn1(x)
    x=self.clipped_relu1(x)
    x=self.id_block1(x)
    x=self.id_block2(x)
    x=self.id_block3(x)
    return x
  def l2_regularization(self):
    n=torch.norm(self.conv1.weight,p=2)*self.l2+self.id_block1.l2_regularization()+self.id_block2.l2_regularization()+self.id_block3.l2_regularization()
    return n
    

In [69]:
class DeepSpeaker(nn.Module):
  def __init__(self,l2=0.0001) -> None:
    super(DeepSpeaker,self).__init__()
    self.l2=l2
    self.conv_res_block1=ConvResBlock(1,filters=64,l2=l2)
    self.conv_res_block2=ConvResBlock(64,filters=128,l2=l2)
#     self.conv_res_block3=ConvResBlock(128,filters=256,l2=l2)
#     self.conv_res_block4=ConvResBlock(256,filters=512,l2=l2)
    self.linear=nn.Linear(in_features=2048,out_features=512)

  def forward(self,x):
    x=self.conv_res_block1(x)
    x=self.conv_res_block2(x)
#     x=self.conv_res_block3(x)
#     x=self.conv_res_block4(x)
    batch,channels,height,width=x.shape
    x=torch.reshape(x,(batch,-1,2048))
    x=torch.mean(x,dim=1)
    x=self.linear(x)
    return x

  def l2_regularization(self):
    n=self.conv_res_block1.l2_regularization()+self.conv_res_block2.l2_regularization()
    return n

In [70]:
model=DeepSpeaker()

In [13]:
x=torch.rand(10,1,160,64)

In [73]:
out=model(x)
out.shape

torch.Size([10, 512])

In [74]:
sum(p.numel() for p in model.parameters())

2365440

In [16]:
def batch_cosine_similarity(x,y):
    x_norm=torch.norm(x,dim=1,keepdim=True)
    y_norm=torch.norm(y,dim=1,keepdim=True)
    return torch.matmul(x,y.T)/x_norm/y_norm

In [17]:
class DeepSpeakerLoss(nn.Module):
  def __init__(self) -> None:
    super(DeepSpeakerLoss,self).__init__()
    self.alpha=0.1
  
  def forward(self,y_true,y_pred):
    split=y_pred.shape[0]//3
    anchor=y_pred[:split]
    positive=y_pred[split:2*split]
    negative=y_pred[2*split:]
    cos_sim_ap=batch_cosine_similarity(anchor,positive)
    cos_sim_an=batch_cosine_similarity(anchor,negative)
    loss=torch.max(cos_sim_an-cos_sim_ap+self.alpha,torch.tensor(0.0))
    return torch.mean(loss)


In [18]:
def find_files(directory, ext='wav'):
    return sorted(glob(directory + f'/**/*.{ext}', recursive=True))

def ensures_dir(directory: str):
    if len(directory) > 0 and not os.path.exists(directory):
        os.makedirs(directory)

def read_mfcc(input_filename, sample_rate):
    audio = Audio.read(input_filename, sample_rate)
    energy = np.abs(audio)
    silence_threshold = np.percentile(energy, 95)
    offsets = np.where(energy > silence_threshold)[0]
    audio_voice_only = audio[offsets[0]:offsets[-1]]
    mfcc = mfcc_fbank(audio_voice_only, sample_rate)
    return mfcc


def extract_speaker_and_utterance_ids(filename: str):  # LIBRI.
    # 'audio/dev-other/116/288045/116-288045-0000.flac'
    speaker, basename = Path(filename).parts[-2:]
    filename.split('-')
    # utterance = os.path.splitext(basename.split('-', 1)[-1])[0]
    # assert basename.split('-')[0] == speaker
    return speaker, basename.split('.')[0]


class Audio:

    def __init__(self, cache_dir: str, audio_dir: str = None, sample_rate: int = SAMPLE_RATE, ext='flac'):
        self.ext = ext
        self.cache_dir = os.path.join(cache_dir, 'audio-fbanks')
        ensures_dir(self.cache_dir)
        if audio_dir is not None:
            self.build_cache(os.path.expanduser(audio_dir), sample_rate)
        self.speakers_to_utterances = defaultdict(dict)
        for cache_file in find_files(self.cache_dir, ext='npy'):
            # /path/to/speaker_utterance.npy
            speaker_id, utterance_id = Path(cache_file).stem.split('_')
            self.speakers_to_utterances[speaker_id][utterance_id] = cache_file

    @property
    def speaker_ids(self):
        return sorted(self.speakers_to_utterances)

    @staticmethod
    def trim_silence(audio, threshold):
        """Removes silence at the beginning and end of a sample."""
        # pylint: disable=E1121
        energy = librosa.feature.rms(audio)
        frames = np.nonzero(np.array(energy > threshold))
        indices = librosa.core.frames_to_samples(frames)[1]

        # Note: indices can be an empty array, if the whole audio was silence.
        audio_trim = audio[0:0]
        left_blank = audio[0:0]
        right_blank = audio[0:0]
        if indices.size:
            audio_trim = audio[indices[0]:indices[-1]]
            left_blank = audio[:indices[0]]  # slice before.
            right_blank = audio[indices[-1]:]  # slice after.
        return audio_trim, left_blank, right_blank

    @staticmethod
    def read(filename, sample_rate=SAMPLE_RATE):
        audio, sr = librosa.load(filename, sr=sample_rate, mono=True, dtype=np.float32)
        assert sr == sample_rate
        return audio

    def build_cache(self, audio_dir, sample_rate):
        logger.info(f'audio_dir: {audio_dir}.')
        logger.info(f'sample_rate: {sample_rate:,} hz.')
        audio_files = find_files(audio_dir, ext=self.ext)
        audio_files_count = len(audio_files)
        assert audio_files_count != 0, f'Could not find any {self.ext} files in {audio_dir}.'
        logger.info(f'Found {audio_files_count:,} files in {audio_dir}.')
        with tqdm(audio_files) as bar:
            for audio_filename in bar:
                bar.set_description(audio_filename)
                self.cache_audio_file(audio_filename, sample_rate)

    def cache_audio_file(self, input_filename, sample_rate):
        sp, utt = extract_speaker_and_utterance_ids(input_filename)
        cache_filename = os.path.join(self.cache_dir, f'{sp}_{utt}.npy')
        if not os.path.isfile(cache_filename):
            try:
                mfcc = read_mfcc(input_filename, sample_rate)
                np.save(cache_filename, mfcc)
            except librosa.util.exceptions.ParameterError as e:
                logger.error(e)


def pad_mfcc(mfcc, max_length):  # num_frames, nfilt=64.
    if len(mfcc) < max_length:
        mfcc = np.vstack((mfcc, np.tile(np.zeros(mfcc.shape[1]), (max_length - len(mfcc), 1))))
    return mfcc


def mfcc_fbank(signal: np.array, sample_rate: int):  # 1D signal array.
    # Returns MFCC with shape (num_frames, n_filters, 3).
    filter_banks, energies = fbank(signal, samplerate=sample_rate, nfilt=NUM_FBANKS)
    frames_features = normalize_frames(filter_banks)
    # delta_1 = delta(filter_banks, N=1)
    # delta_2 = delta(delta_1, N=1)
    # frames_features = np.transpose(np.stack([filter_banks, delta_1, delta_2]), (1, 2, 0))
    return np.array(frames_features, dtype=np.float32)  # Float32 precision is enough here.


def normalize_frames(m, epsilon=1e-12):
    return [(v - np.mean(v)) / max(np.std(v), epsilon) for v in m]


In [19]:
! mkdir working_dir

In [20]:
! mkdir model_checkpoints

In [21]:
Audio(cache_dir='/kaggle/working/working_dir',audio_dir='/kaggle/input/kathbath-hindi/kathbath/hindi',ext='wav')

  0%|          | 0/5059 [00:00<?, ?it/s]

<__main__.Audio at 0x7fd0e3019120>

In [22]:
f=np.load('/kaggle/working/working_dir/audio-fbanks/1052_844424933454296-1052-m.npy')
f.shape

(716, 64)

In [23]:
def load_pickle(file):
    if not os.path.exists(file):
        return None
    logger.info(f'Loading PKL file: {file}.')
    with open(file, 'rb') as r:
        return dill.load(r)


def load_npy(file):
    if not os.path.exists(file):
        return None
    logger.info(f'Loading NPY file: {file}.')
    return np.load(file)


def train_test_sp_to_utt(audio, is_test):
    sp_to_utt = {}
    for speaker_id, utterances in audio.speakers_to_utterances.items():
        utterances_files = sorted(utterances.values())
        train_test_sep = int(len(utterances_files) * 0.8)
        sp_to_utt[speaker_id] = utterances_files[train_test_sep:] if is_test else utterances_files[:train_test_sep]
    return sp_to_utt

def extract_speaker(utt_file):
    return utt_file.split('/')[-1].split('_')[0]


def sample_from_mfcc(mfcc, max_length):
    if mfcc.shape[0] >= max_length:
        r = choice(range(0, len(mfcc) - max_length + 1))
        s = mfcc[r:r + max_length]
    else:
        s = pad_mfcc(mfcc, max_length)
    return np.expand_dims(s, axis=-1)


def sample_from_mfcc_file(utterance_file, max_length):
    mfcc = np.load(utterance_file)
    return sample_from_mfcc(mfcc, max_length)

class KerasFormatConverter:

    def __init__(self, working_dir, load_test_only=False):
        self.working_dir = working_dir
        self.output_dir = os.path.join(self.working_dir, 'keras-inputs')
        ensures_dir(self.output_dir)
        self.categorical_speakers = load_pickle(os.path.join(self.output_dir, 'categorical_speakers.pkl'))
        if not load_test_only:
            self.kx_train = load_npy(os.path.join(self.output_dir, 'kx_train.npy'))
            self.ky_train = load_npy(os.path.join(self.output_dir, 'ky_train.npy'))
        self.kx_test = load_npy(os.path.join(self.output_dir, 'kx_test.npy'))
        self.ky_test = load_npy(os.path.join(self.output_dir, 'ky_test.npy'))
        self.audio = Audio(cache_dir=self.working_dir, audio_dir=None)
        if self.categorical_speakers is None:
            self.categorical_speakers = SparseCategoricalSpeakers(self.audio.speaker_ids)

    def persist_to_disk(self):
        with open(os.path.join(self.output_dir, 'categorical_speakers.pkl'), 'wb') as w:
            dill.dump(self.categorical_speakers, w)
        np.save(os.path.join(self.output_dir, 'kx_train.npy'), self.kx_train)
        np.save(os.path.join(self.output_dir, 'kx_test.npy'), self.kx_test)
        np.save(os.path.join(self.output_dir, 'ky_train.npy'), self.ky_train)
        np.save(os.path.join(self.output_dir, 'ky_test.npy'), self.ky_test)

    def generate_per_phase(self, max_length=NUM_FRAMES, num_per_speaker=3000, is_test=False):
        # train OR test.
        num_speakers = len(self.audio.speaker_ids)
        sp_to_utt = train_test_sp_to_utt(self.audio, is_test)

        # 64 fbanks 1 channel(s).
        # float32
        kx = np.zeros((num_speakers * num_per_speaker, max_length, NUM_FBANKS, 1), dtype=np.float32)
        ky = np.zeros((num_speakers * num_per_speaker, 1), dtype=np.float32)

        desc = f'Converting to Keras format [{"test" if is_test else "train"}]'
        for i, speaker_id in enumerate(tqdm(self.audio.speaker_ids, desc=desc)):
            utterances_files = sp_to_utt[speaker_id]
            for j, utterance_file in enumerate(np.random.choice(utterances_files, size=num_per_speaker, replace=True)):
                self.load_into_mat(utterance_file, self.categorical_speakers, speaker_id, max_length, kx, ky,
                                   i * num_per_speaker + j)
        return kx, ky

    def generate(self, max_length=NUM_FRAMES, counts_per_speaker=(3000, 500)):
        kx_train, ky_train = self.generate_per_phase(max_length, counts_per_speaker[0], is_test=False)
        kx_test, ky_test = self.generate_per_phase(max_length, counts_per_speaker[1], is_test=True)
        logger.info(f'kx_train.shape = {kx_train.shape}')
        logger.info(f'ky_train.shape = {ky_train.shape}')
        logger.info(f'kx_test.shape = {kx_test.shape}')
        logger.info(f'ky_test.shape = {ky_test.shape}')
        self.kx_train, self.ky_train, self.kx_test, self.ky_test = kx_train, ky_train, kx_test, ky_test

    @staticmethod
    def load_into_mat(utterance_file, categorical_speakers, speaker_id, max_length, kx, ky, i):
        kx[i] = sample_from_mfcc_file(utterance_file, max_length)
        ky[i] = categorical_speakers.get_index(speaker_id)


class SparseCategoricalSpeakers:

    def __init__(self, speakers_list):
        self.speaker_ids = sorted(speakers_list)
        assert len(set(self.speaker_ids)) == len(self.speaker_ids)  # all unique.
        self.map = dict(zip(self.speaker_ids, range(len(self.speaker_ids))))

    def get_index(self, speaker_id):
        return self.map[speaker_id]



In [24]:
counts_per_speaker = [600,100]
kc = KerasFormatConverter('/kaggle/working/working_dir')
kc.generate(max_length=NUM_FRAMES, counts_per_speaker=counts_per_speaker)
kc.persist_to_disk()

Converting to Keras format [train]:   0%|          | 0/40 [00:00<?, ?it/s]

Converting to Keras format [test]:   0%|          | 0/40 [00:00<?, ?it/s]

In [21]:
class DeepSpeakerDataset(Dataset):
  def __init__(self,kx,ky):
    self.kx=np.load(kx)
    self.ky=np.load(ky)
  def __len__(self):
    return len(self.kx)

  def __getitem__(self,idx):
    x=torch.from_numpy(self.kx[idx])
    x=torch.permute(x,(2,0,1))
    y=torch.from_numpy(self.ky[idx]).long()
    return x,y

In [22]:
train_data=DeepSpeakerDataset('/kaggle/working/working_dir/keras-inputs/kx_train.npy','/kaggle/working/working_dir/keras-inputs/ky_train.npy')

In [23]:
train_data[0][0].shape,train_data[0][1].shape

(torch.Size([1, 160, 64]), torch.Size([1]))

In [24]:
test_data=DeepSpeakerDataset('/kaggle/working/working_dir/keras-inputs/kx_test.npy','/kaggle/working/working_dir/keras-inputs/ky_test.npy')

In [25]:
train_dataloader=DataLoader(train_data,batch_size=50,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=50,shuffle=True)

In [25]:
optimizer=optim.Adam(model.parameters(),lr=0.0001,weight_decay=0.001)
criterion=DeepSpeakerLoss()

In [26]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [27]:
spkrs=load_pickle('/kaggle/working/working_dir/keras-inputs/categorical_speakers.pkl')
len(spkrs.speaker_ids)

40

In [61]:
epochs=100

In [20]:
dense_layer=nn.Sequential(nn.Dropout(p=0.5),nn.Linear(in_features=512,out_features=len(spkrs.speaker_ids)))

In [21]:
dense_layer=dense_layer.to(device)

In [28]:
model=model.to(device)

In [None]:
for i in tqdm(range(epochs),desc='epochs'):
  train_loss=0.0
  test_loss=0.0
  train_len=len(train_dataloader.dataset)
  test_len=len(test_dataloader.dataset)
  start=time()
  for data,target in tqdm(train_dataloader,desc='train'):
    data=data.to(device)
    target=target.to(device)
    optimizer.zero_grad()
    output=model(data)
    output=dense_layer(output)
    loss=criterion(output,target.squeeze())
    loss.backward()
    optimizer.step()
    train_loss+=loss.item()
    
  with torch.no_grad():
    for data,target in tqdm(test_dataloader,desc="test "):
      data=data.to(device)
      target=target.to(device)
      output=model(data)
      output=dense_layer(output)
      loss=criterion(output,target.squeeze())
      test_loss+=loss.item()
  end=time()
  train_loss/=train_len
  test_loss/=test_len
  if (i+1)%10==0:
      torch.save(model,f'/kaggle/working/model_checkpoints/model_{i+1}.pt')
  print(f"epoch {i+1} train_loss : {train_loss:.4f} test_loss : {test_loss:.4f} processing_time : {end-start:.4f}")


In [None]:
mfcc1=sample_from_mfcc(read_mfcc('/kaggle/input/kathbath-hindi/kathbath/hindi/test/audio/1183/844424933261563-1183-m.wav',16000),160)
mfcc2=sample_from_mfcc(read_mfcc('/kaggle/input/kathbath-hindi/kathbath/hindi/test/audio/1190/844424930929529-1190-f.wav',16000),160)

In [None]:
mfcc1=torch.permute(torch.from_numpy(mfcc1),(2,0,1))
mfcc2=torch.permute(torch.from_numpy(mfcc2),(2,0,1))

In [None]:
with torch.no_grad():
    emb1=model(mfcc1.unsqueeze(0).to(device))
    emb2=model(mfcc2.unsqueeze(0).to(device))

In [None]:
similarity=batch_cosine_similarity(emb1,emb2)

In [None]:
print(similarity)

In [29]:
class LazyTripletBatcher:
    def __init__(self, working_dir: str, max_length: int, model: DeepSpeaker,device):
        self.working_dir = working_dir
        self.audio = Audio(cache_dir=working_dir)
        logger.info(f'Picking audio from {working_dir}.')
        self.sp_to_utt_train = train_test_sp_to_utt(self.audio, is_test=False)
        self.sp_to_utt_test = train_test_sp_to_utt(self.audio, is_test=True)
        self.max_length = max_length
        self.model = model.to(device)
        self.device=device
        self.nb_per_speaker = 2
        self.nb_speakers = 640
        self.history_length = 4
        self.history_every = 100  # batches.
        self.total_history_length = self.nb_speakers * self.nb_per_speaker * self.history_length  # 25,600
        self.metadata_train_speakers = Counter()
        self.metadata_output_file = os.path.join(self.working_dir, 'debug_batcher.json')

        self.history_embeddings_train = deque(maxlen=self.total_history_length)
        self.history_utterances_train = deque(maxlen=self.total_history_length)
        self.history_model_inputs_train = deque(maxlen=self.total_history_length)

        self.history_embeddings = None
        self.history_utterances = None
        self.history_model_inputs = None

        self.batch_count = 0
        for _ in tqdm(range(self.history_length), desc='Initializing the batcher'):  # init history.
            self.update_triplets_history()

    def update_triplets_history(self):
        model_inputs = []
        speakers = list(self.audio.speakers_to_utterances.keys())
        np.random.shuffle(speakers)
        selected_speakers = speakers[: self.nb_speakers]
        embeddings_utterances = []
        for speaker_id in selected_speakers:
            train_utterances = self.sp_to_utt_train[speaker_id]
            for selected_utterance in np.random.choice(a=train_utterances, size=self.nb_per_speaker, replace=False):
                mfcc = sample_from_mfcc_file(selected_utterance, self.max_length)
                embeddings_utterances.append(selected_utterance)
                model_inputs.append(mfcc)
        model_inputs=torch.tensor(model_inputs,device=self.device)
        model_inputs=torch.permute(model_inputs,(0,3,1,2))
#         print(model_inputs.shape)
        with torch.no_grad():
            embeddings = self.model(model_inputs)
            embeddings = embeddings.detach().cpu().numpy()
            
        assert embeddings.shape[-1] == 512
        embeddings = np.reshape(embeddings, (len(selected_speakers), self.nb_per_speaker, 512))
        self.history_embeddings_train.extend(list(embeddings.reshape((-1, 512))))
        self.history_utterances_train.extend(embeddings_utterances)
        self.history_model_inputs_train.extend(model_inputs.detach().cpu().numpy())

        # reason: can't index a deque with a np.array.
        self.history_embeddings = np.array(self.history_embeddings_train)
        self.history_utterances = np.array(self.history_utterances_train)
        self.history_model_inputs = np.array(self.history_model_inputs_train)

        with open(self.metadata_output_file, 'w') as w:
            json.dump(obj=dict(self.metadata_train_speakers), fp=w, indent=2)

    def get_batch(self, batch_size, is_test=False):
        return self.get_batch_test(batch_size) if is_test else self.get_random_batch(batch_size, is_test=False)

    def get_batch_test(self, batch_size):
        return self.get_random_batch(batch_size, is_test=True)

    def get_random_batch(self, batch_size, is_test=False):
        sp_to_utt = self.sp_to_utt_test if is_test else self.sp_to_utt_train
        speakers = list(self.audio.speakers_to_utterances.keys())
        anchor_speakers = np.random.choice(speakers, size=batch_size // 3, replace=False)

        anchor_utterances = []
        positive_utterances = []
        negative_utterances = []
        for anchor_speaker in anchor_speakers:
            negative_speaker = np.random.choice(list(set(speakers) - {anchor_speaker}), size=1)[0]
            assert negative_speaker != anchor_speaker
            pos_utterances = np.random.choice(sp_to_utt[anchor_speaker], 2, replace=False)
            neg_utterance = np.random.choice(sp_to_utt[negative_speaker], 1, replace=True)[0]
            anchor_utterances.append(pos_utterances[0])
            positive_utterances.append(pos_utterances[1])
            negative_utterances.append(neg_utterance)

        # anchor and positive should have difference utterances (but same speaker!).
        anc_pos = np.array([positive_utterances, anchor_utterances])
        assert np.all(anc_pos[0, :] != anc_pos[1, :])
        assert np.all(np.array([extract_speaker(s) for s in anc_pos[0, :]]) == np.array(
            [extract_speaker(s) for s in anc_pos[1, :]]))

        pos_neg = np.array([positive_utterances, negative_utterances])
        assert np.all(pos_neg[0, :] != pos_neg[1, :])
        assert np.all(np.array([extract_speaker(s) for s in pos_neg[0, :]]) != np.array(
            [extract_speaker(s) for s in pos_neg[1, :]]))

        batch_x = np.vstack([
            [sample_from_mfcc_file(u, self.max_length) for u in anchor_utterances],
            [sample_from_mfcc_file(u, self.max_length) for u in positive_utterances],
            [sample_from_mfcc_file(u, self.max_length) for u in negative_utterances]
        ])

        batch_y = np.zeros(shape=(len(batch_x), 1))  # dummy. sparse softmax needs something.
        batch_x=torch.from_numpy(batch_x).permute(0,3,1,2)
        batch_x.requires_grad_()
        batch_y=torch.from_numpy(batch_y)
        batch_y.requires_grad_()
        return batch_x, batch_y

    def get_batch_train(self, batch_size):
        # s1 = time()
        self.batch_count += 1
        if self.batch_count % self.history_every == 0:
            self.update_triplets_history()

        all_indexes = range(len(self.history_embeddings_train))
        anchor_indexes = np.random.choice(a=all_indexes, size=batch_size // 3, replace=False)

        # s2 = time()
        similar_negative_indexes = []
        dissimilar_positive_indexes = []
        # could be made parallel.
        for anchor_index in anchor_indexes:
            # s21 = time()
            anchor_embedding = self.history_embeddings[anchor_index]
            anchor_speaker = extract_speaker(self.history_utterances[anchor_index])

            # why self.nb_speakers // 2? just random. because it is fast. otherwise it's too much.
            negative_indexes = [j for (j, a) in enumerate(self.history_utterances)
                                if extract_speaker(a) != anchor_speaker]
            negative_indexes = np.random.choice(negative_indexes, size=self.nb_speakers // 2)

            # s22 = time()

            anchor_embedding_tile = [anchor_embedding] * len(negative_indexes)
            anchor_cos = batch_cosine_similarity(anchor_embedding_tile, self.history_embeddings[negative_indexes])

            # s23 = time()
            similar_negative_index = negative_indexes[np.argsort(anchor_cos)[-1]]  # [-1:]
            similar_negative_indexes.append(similar_negative_index)

            # s24 = time()
            positive_indexes = [j for (j, a) in enumerate(self.history_utterances) if
                                extract_speaker(a) == anchor_speaker and j != anchor_index]
            # s25 = time()
            anchor_embedding_tile = [anchor_embedding] * len(positive_indexes)
            # s26 = time()
            anchor_cos = batch_cosine_similarity(anchor_embedding_tile, self.history_embeddings[positive_indexes])
            dissimilar_positive_index = positive_indexes[np.argsort(anchor_cos)[0]]  # [:1]
            dissimilar_positive_indexes.append(dissimilar_positive_index)
            # s27 = time()

        # s3 = time()
        batch_x = np.vstack([
            self.history_model_inputs[anchor_indexes],
            self.history_model_inputs[dissimilar_positive_indexes],
            self.history_model_inputs[similar_negative_indexes]
        ])

        
        anchor_speakers = [extract_speaker(a) for a in self.history_utterances[anchor_indexes]]
        positive_speakers = [extract_speaker(a) for a in self.history_utterances[dissimilar_positive_indexes]]
        negative_speakers = [extract_speaker(a) for a in self.history_utterances[similar_negative_indexes]]

        assert len(anchor_indexes) == len(dissimilar_positive_indexes)
        assert len(similar_negative_indexes) == len(dissimilar_positive_indexes)
        assert list(self.history_utterances[dissimilar_positive_indexes]) != list(
            self.history_utterances[anchor_indexes])
        assert anchor_speakers == positive_speakers
        assert negative_speakers != anchor_speakers

        batch_y = np.zeros(shape=(len(batch_x), 1))  # dummy. sparse softmax needs something.

        for a in anchor_speakers:
            self.metadata_train_speakers[a] += 1
        for a in positive_speakers:
            self.metadata_train_speakers[a] += 1
        for a in negative_speakers:
            self.metadata_train_speakers[a] += 1
        batch_x=torch.from_numpy(batch_x).permute(0,3,1,2)
        batch_x.requires_grad_()
        batch_y=torch.from_numpy(batch_y)
        batch_y.requires_grad_()
        return batch_x, batch_y

    def get_speaker_verification_data(self, anchor_speaker, num_different_speakers):
        speakers = list(self.audio.speakers_to_utterances.keys())
        anchor_utterances = []
        positive_utterances = []
        negative_utterances = []
        negative_speakers = np.random.choice(list(set(speakers) - {anchor_speaker}), size=num_different_speakers)
        assert [negative_speaker != anchor_speaker for negative_speaker in negative_speakers]
        pos_utterances = np.random.choice(self.sp_to_utt_test[anchor_speaker], 2, replace=False)
        neg_utterances = [np.random.choice(self.sp_to_utt_test[neg], 1, replace=True)[0] for neg in negative_speakers]
        anchor_utterances.append(pos_utterances[0])
        positive_utterances.append(pos_utterances[1])
        negative_utterances.extend(neg_utterances)

        # anchor and positive should have difference utterances (but same speaker!).
        anc_pos = np.array([positive_utterances, anchor_utterances])
        assert np.all(anc_pos[0, :] != anc_pos[1, :])
        assert np.all(np.array([extract_speaker(s) for s in anc_pos[0, :]]) == np.array(
            [extract_speaker(s) for s in anc_pos[1, :]]))

        batch_x = np.vstack([
            [sample_from_mfcc_file(u, self.max_length) for u in anchor_utterances],
            [sample_from_mfcc_file(u, self.max_length) for u in positive_utterances],
            [sample_from_mfcc_file(u, self.max_length) for u in negative_utterances]
        ])

        batch_y = np.zeros(shape=(len(batch_x), 1))  # dummy. sparse softmax needs something.
        batch_x=torch.from_numpy(batch_x).permute(0,3,1,2)
        batch_x.requires_grad_()
        batch_y=torch.from_numpy(batch_y)
        batch_y.requires_grad_()
        return batch_x, batch_y


In [30]:
device

device(type='cuda')

In [31]:
batcher=LazyTripletBatcher('/kaggle/working/working_dir',NUM_FRAMES,model,device)

Initializing the batcher:   0%|          | 0/4 [00:00<?, ?it/s]

  model_inputs=torch.tensor(model_inputs,device=self.device)


In [32]:
batch=batcher.get_random_batch(12)

In [33]:
batch[0].shape,batch[1].shape

(torch.Size([12, 1, 160, 64]), torch.Size([12, 1]))

In [34]:
batch_size=12

In [35]:
test_batches = []
for _ in tqdm(range(200), desc='Build test set'):
    test_batches.append(batcher.get_batch_test(batch_size))

# def test_generator():
#     while True:
#         for bb in test_batches:
#             yield bb

# def train_generator():
#     while True:
#         yield batcher.get_random_batch(batch_size, is_test=False)
train_batches=[]
for i in tqdm(range(2000),desc="Building train set"):
    train_batches.append(batcher.get_random_batch(batch_size, is_test=False))

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

In [75]:
model=model.to(device)

In [77]:
optimizer=optim.SGD(model.parameters(),lr=0.0001)

In [78]:
for i in tqdm(range(10),desc='epochs'):
  train_loss=0.0
  test_loss = 0.0
  train_len=2000
  test_len=200
  start=time()
  train_batches=[]
  for _ in tqdm(range(2000),desc="Building train set"):
      train_batches.append(batcher.get_random_batch(batch_size, is_test=False))
  test_batches = []
  for _ in tqdm(range(200), desc='Build test set'):
      test_batches.append(batcher.get_batch_test(batch_size))

  for data,target in tqdm(train_batches):
    data=data.to(device)
    target=target.to(device)
    optimizer.zero_grad()
    output=model(data)
#     output=dense_layer(output)
    loss=criterion(target,output)+model.l2_regularization()
    loss.backward()
    optimizer.step()
    train_loss+=loss.item()
    
  with torch.no_grad():
    for data,target in tqdm(test_batches):
      data=data.to(device)
      target=target.to(device)
      output=model(data)
#       output=dense_layer(output)
      loss=criterion(target,output)+model.l2_regularization()
      test_loss+=loss.item()
  end=time()
  train_loss/=train_len
  test_loss/=test_len
  if (i+1)%10==0:
      torch.save(model,f'/kaggle/working/model_checkpoints/model_{i+1}.pt')
  print(f"epoch {i+1} train_loss : {train_loss:.4f} test_loss : {test_loss:.4f} processing_time : {end-start:.4f}")


epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 1 train_loss : 0.1145 test_loss : 0.1103 processing_time : 197.4361


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 2 train_loss : 0.1126 test_loss : 0.1119 processing_time : 197.4545


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 3 train_loss : 0.1119 test_loss : 0.1109 processing_time : 197.2262


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 4 train_loss : 0.1116 test_loss : 0.1098 processing_time : 197.1704


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 5 train_loss : 0.1114 test_loss : 0.1097 processing_time : 197.4359


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 6 train_loss : 0.1117 test_loss : 0.1093 processing_time : 197.2958


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 7 train_loss : 0.1108 test_loss : 0.1100 processing_time : 197.7510


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 8 train_loss : 0.1102 test_loss : 0.1087 processing_time : 197.4136


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 9 train_loss : 0.1107 test_loss : 0.1092 processing_time : 197.3663


Building train set:   0%|          | 0/2000 [00:00<?, ?it/s]

Build test set:   0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

epoch 10 train_loss : 0.1110 test_loss : 0.1086 processing_time : 197.5333


In [36]:
x,y=train_batches[0][0],train_batches[0][1]

In [40]:
out=model(x.to(device))
out=dense_layer(out)

In [41]:
loss=criterion(out,y)

In [42]:
print(out.shape)
print(y.shape)

torch.Size([12, 40])
torch.Size([12, 1])
