In [4]:
from pathlib import Path

# --- Batch composition -------------------------------------------------------
# Number of speakers per batch (used as the DataLoader batch size).
speakers_per_batch=32
# Number of random utterance segments drawn for each speaker in a batch.
utterances_per_speaker=10
# Length of every mel segment, in spectrogram frames.
seq_len=128

# --- Training schedule (all intervals are counted in optimisation steps) -----
# Training stops once this many steps have been taken (effectively "forever").
train_steps = 1e12
train_print_interval = 10 # print loss / EER every N steps
# Intended number of batches averaged in one evaluation run.
total_evaluate_steps = 50
evaluate_interval = 500 # run an evaluation every N steps
save_interval = 100 # write a checkpoint every N steps

# --- Checkpointing -----------------------------------------------------------
# Directory that receives the `*.pt` checkpoints and is scanned when resuming.
save_dir = Path(r'/kaggle/working/')
# Only the newest `max_ckpts` checkpoints are kept; older ones are deleted.
max_ckpts = 100

# --- Optimisation ------------------------------------------------------------
# Learning rate handed to the Adam optimiser.
speaker_lr = 1e-4

# --- Data and devices --------------------------------------------------------
libri_dataset_path = Path(r'/kaggle/input/librispeech-360-clean/LibriSpeech/train-clean-360')
# Intended device for the encoder network.
device = 'cuda:0'
# Device on which the similarity matrix and the GE2E loss are computed.
loss_device = 'cpu'

In [5]:
import librosa
import numpy as np

def normalize(S, min_level_db=-100):
    """Rescale dB values to the unit interval.

    ``min_level_db`` maps to 0 and 0 dB maps to 1; anything outside that range
    is clipped to the nearest bound.
    """
    span = -min_level_db
    rescaled = (S - min_level_db) / span
    return np.clip(rescaled, 0, 1)

def linear_to_mel(spectrogram, sample_rate=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80):
    """Map a linear-frequency spectrogram onto mel bands.

    Thin wrapper around ``librosa.feature.melspectrogram``: the spectrogram is
    passed as ``S`` and the remaining arguments configure the mel filterbank.
    """
    filterbank_options = {
        'sr': sample_rate,
        'n_fft': n_fft,
        'n_mels': n_mels,
        'fmin': fmin,
        'fmax': fmax,
    }
    return librosa.feature.melspectrogram(S=spectrogram, **filterbank_options)

def amp_to_db(x):
    """Convert linear amplitudes to decibels (``20 * log10``).

    Inputs are floored at 1e-5 first, so the result never drops below -100 dB
    and zeros do not produce ``-inf``.
    """
    floored = np.maximum(1e-5, x)
    return 20. * np.log10(floored)

def stft(y, n_fft=1024, hop_length=256, win_length=1024):
    """Short-time Fourier transform of waveform ``y`` (wrapper around ``librosa.stft``)."""
    framing = {'n_fft': n_fft, 'hop_length': hop_length, 'win_length': win_length}
    return librosa.stft(y=y, **framing)

def gen_melspectrogram(y):
    """Turn a waveform into a normalised mel spectrogram.

    Pipeline: STFT magnitude -> mel projection -> decibels -> normalisation,
    with the result clipped to ``[0, 1]``.
    """
    magnitude = np.abs(stft(y))
    mel_db = amp_to_db(linear_to_mel(magnitude))
    scaled = normalize(mel_db)
    return np.clip(scaled, 0, 1)

In [6]:
from pathlib import Path
import pickle
import logging
import os

import librosa
import numpy as np

class Utterance(object):
    """One audio recording, identified by ``id`` and backed by ``raw_file``.

    The waveform and its mel spectrogram are computed on demand from the file;
    nothing is cached on the instance.
    """

    def __init__(self, id: str = None, raw_file: Path = None):
        self.id = id
        self.raw_file = raw_file

    def raw(self, sr=16000, augment=False):
        """Get the raw audio samples.

        Args:
            sr: Sample rate passed to ``librosa.load``.
            augment: When True, scale the waveform by a random gain drawn
                uniformly from [0.3, 1.0).

        Returns:
            The waveform, normalised with ``librosa.util.normalize`` and
            scaled by 0.95 (then by the random gain if ``augment`` is set).

        Raises:
            Exception: If the file decodes to zero samples.
        """

        y, sr = librosa.load(self.raw_file, sr=sr)
        # y, _ = librosa.effects.trim(y)
        if y.size == 0:
            raise Exception('audio', 'empty audio')
        y = 0.95 * librosa.util.normalize(y)
        if augment:
            amplitude = np.random.uniform(low=0.3, high=1.0)
            y = y * amplitude
        return y

    def melspectrogram(self, sr=16000, n_fft=1024, hop_length=256, win_length=1024, n_mels=80):
        """Get the melspectrogram features.

        Only ``sr`` is used; the other parameters are accepted for interface
        compatibility but ``gen_melspectrogram`` runs with its own defaults.

        Raises:
            Exception: Whatever loading or feature extraction raised; the
                error is logged at debug level and re-raised unchanged.
        """

        try:
            return gen_melspectrogram(self.raw(sr=sr))
        except Exception:
            # Only log attributes that exist: the handler used to read a
            # non-existent `mel_file`, which replaced the real error with an
            # AttributeError.
            logging.debug(f'failed to load melspectrogram, raw file: {self.raw_file}')
            raise

    def random_raw_segment(self, seq_len):
        """Return a random audio segment of exactly ``seq_len`` samples.

        Shorter recordings are reflect-padded symmetrically; longer ones are
        cropped at a uniformly random offset. A random gain is applied because
        the waveform is loaded with ``augment=True``.
        """

        y = self.raw(augment=True)
        ylen = len(y)
        if ylen < seq_len:
            pad_left = (seq_len - ylen) // 2
            pad_right = seq_len - ylen - pad_left
            y = np.pad(y, ((pad_left, pad_right)), mode='reflect')
        elif ylen > seq_len:
            max_seq_start = ylen - seq_len
            # randint's upper bound is exclusive; +1 makes the last valid
            # start offset reachable.
            seq_start = np.random.randint(0, max_seq_start + 1)
            seq_end = seq_start + seq_len
            y = y[seq_start:seq_end]

        return y

    def random_mel_segment(self, seq_len):
        """Return a random melspectrogram segment of exactly ``seq_len`` frames.

        The time axis (second dimension) is reflect-padded when the
        spectrogram is too short, or cropped at a uniformly random offset
        when it is too long.
        """

        mel = self.melspectrogram()
        freq_len, tempo_len = mel.shape
        if tempo_len < seq_len:
            pad_left = (seq_len - tempo_len) // 2
            pad_right = seq_len - tempo_len - pad_left
            mel = np.pad(mel, ((0, 0), (pad_left, pad_right)), mode='reflect')
        elif tempo_len > seq_len:
            max_seq_start = tempo_len - seq_len
            # Upper bound is exclusive, hence the +1 (see random_raw_segment).
            seq_start = np.random.randint(0, max_seq_start + 1)
            seq_end = seq_start + seq_len
            mel = mel[:, seq_start:seq_end]
        return mel

class Speaker(object):
    """A speaker identity together with the utterances collected for it."""

    def __init__(self, id: str):
        self.utterances = []
        self.id = id

    def add_utterance(self, utterance: Utterance):
        """Register one more utterance for this speaker."""

        self.utterances += [utterance]

    def random_utterances(self, n):
        """Draw n utterances uniformly at random (the same one may repeat)."""

        picks = np.random.randint(0, len(self.utterances), n)
        chosen = []
        for pick in picks:
            chosen.append(self.utterances[pick])
        return chosen


In [7]:
from typing import List
from pathlib import Path
import pickle
import logging
from multiprocessing import Process, JoinableQueue
import time
import os
import random

import torch
import numpy as np


class AudioDataset(object):
    """A named collection of speakers.

    Besides holding the speakers, the class can pre-compute the mel
    spectrogram of every utterance and pickle it to disk using a pool of
    worker processes (see ``serialize_mel_feature``).
    """

    def __init__(self, id: str, speakers: List[Speaker] = None):
        self.id = id
        # Build a fresh list per instance. A literal ``[]`` default is created
        # once and shared by every dataset constructed without ``speakers``,
        # so add_speaker() on one dataset would leak into all the others.
        self.speakers = speakers if speakers is not None else []

    def add_speaker(self, speaker: Speaker):
        """Add a speaker to this dataset."""

        self.speakers.append(speaker)

    def random_speakers(self, n):
        """Return n speakers drawn uniformly at random (the same one may repeat)."""

        return [self.speakers[idx] for idx in np.random.randint(0, len(self.speakers), n)]

    def serialize_speaker(self, queue: JoinableQueue, counter_queue: JoinableQueue):
        """Worker loop: pickle the mel features of every utterance of a speaker.

        Runs forever, consuming ``(speaker, root, overwrite)`` tasks from
        ``queue``. Features are written to
        ``root/<dataset id>/<speaker id>/<utterance id>.pkl``. One item is put
        on ``counter_queue`` per utterance handled (written, skipped or
        failed) so that progress can be reported.
        """
        while True:
            speaker, root, overwrite = queue.get()
            try:
                # Several workers create directories concurrently, so a
                # check-then-mkdir sequence can race; exist_ok makes it safe.
                spkdir = root / self.id / speaker.id
                spkdir.mkdir(parents=True, exist_ok=True)

                for uttrn_idx, uttrn in enumerate(speaker.utterances):
                    uttrnpath = spkdir / (uttrn.id + '.pkl')
                    is_overwrite = False
                    if uttrnpath.exists():
                        if os.path.getsize(uttrnpath) == 0:
                            # An empty file is a leftover of a failed dump:
                            # always regenerate it.
                            logging.debug(f'overrite empty file {uttrnpath}')
                        elif not overwrite:
                            logging.debug(f'{uttrnpath} already exists, skip')
                            counter_queue.put(1)
                            continue
                        is_overwrite = True
                    try:
                        mel = uttrn.melspectrogram()
                        with uttrnpath.open(mode='wb') as f:
                            pickle.dump(mel, f)
                        if is_overwrite:
                            logging.debug(f'dump pickle object to {uttrnpath} ({uttrn_idx+1}/{len(speaker.utterances)}), overwrite')
                        else:
                            logging.debug(f'dump pickle object to {uttrnpath} ({uttrn_idx+1}/{len(speaker.utterances)})')
                    except Exception as err:
                        logging.warning(f'failed to dump mel features for file {uttrnpath}: {err}')
                    counter_queue.put(1)
            except Exception as err:
                # Keep the worker alive: a dead worker would never acknowledge
                # its task and the parent would block in queue.join() forever.
                logging.warning(f'failed to serialize speaker {speaker.id}: {err}')
            finally:
                queue.task_done()

    def serialization_counter(self, total_count, queue: JoinableQueue):
        """Progress loop: log one debug line per item received on ``queue``.

        Runs forever; the reported duration is the time spent waiting for the
        item to arrive.
        """
        count = 0
        while True:
            start_time = time.time()
            queue.get()
            duration = time.time() - start_time
            count += 1
            logging.debug(f'serialization progress {count}/{total_count}, {int(duration*1000)}ms/item')
            queue.task_done()

    def serialize_mel_feature(self, root: Path, overwrite=False, num_processes=8):
        """Serialize melspectrogram features for all utterances of all speakers to the disk.

        Args:
            root: Output root directory.
            overwrite: Recompute features even when a non-empty pickle exists.
            num_processes: Number of worker processes to start.
        """

        queue = JoinableQueue()
        counter_queue = JoinableQueue()
        total_count = sum(len(spk.utterances) for spk in self.speakers)

        processes = [
            Process(target=self.serialize_speaker, args=(queue, counter_queue))
            for _ in range(num_processes)
        ]
        processes.append(
            Process(target=self.serialization_counter, args=(total_count, counter_queue)))
        for p in processes:
            p.start()

        try:
            # add tasks to queue
            logging.debug(f'total {len(self.speakers)} speakers')
            for spk in self.speakers:
                queue.put((spk, root, overwrite))
            # wait for all task done
            queue.join()
            counter_queue.join()
        finally:
            # The loops never return by themselves; stop them even when the
            # wait is interrupted, then reap them so no zombies remain.
            for p in processes:
                p.terminate()
            for p in processes:
                p.join()

class MultiAudioDataset(object):
    """Concatenation of several datasets.

    ``speakers`` holds the speakers of every input dataset, in order, and
    ``id`` is the individual dataset ids joined with ``'+'``.
    """

    def __init__(self, datasets: List[AudioDataset]):
        names = []
        merged = []
        for dataset in datasets:
            names.append(dataset.id)
            merged += dataset.speakers
        self.id = '+'.join(names)
        self.speakers = merged

class SpeakerDataset(object):
    """Index-by-speaker dataset producing random mel segments.

    Item ``idx`` is a tensor stacking ``utterances_per_speaker`` random
    segments taken from random utterances of speaker ``idx``.

    Args:
        speakers: Sequence of speaker objects exposing ``utterances`` and
            ``random_utterances(n)``.
        utterances_per_speaker: Number of segments returned per item.
        seq_len: Segment length in frames. Either an int, or a list/tuple of
            ints from which one length is drawn per item.
            NOTE(review): with several candidate lengths, items of one batch
            may differ in length — confirm the collate function copes.
    """

    def __init__(self, speakers, utterances_per_speaker, seq_len):
        self.speakers = speakers
        n_speakers = len(self.speakers)
        n_utterances = sum([len(spk.utterances) for spk in self.speakers])
        logging.info(f'total {n_speakers} speakers, {n_utterances} utterances')
        self.utterances_per_speaker = utterances_per_speaker
        self.seq_len = seq_len

    def random_utterance_segment(self, speaker_idx, seq_len, max_attempts=1000):
        """Return a random mel segment from a random utterance of a speaker.

        Utterances that fail to load are skipped and another one is drawn, so
        a speaker with at least one loadable utterance yields a segment.

        Args:
            speaker_idx: Index into ``self.speakers``.
            seq_len: Segment length in frames.
            max_attempts: Upper bound on the number of draws.

        Raises:
            IndexError: If ``speaker_idx`` is out of range.
            ValueError: If the speaker has no utterances at all.
            RuntimeError: If every attempt failed (chained to the last error).
        """

        # Resolve the speaker outside the retry loop: a bad index or an empty
        # speaker can never succeed and used to spin forever.
        speaker = self.speakers[speaker_idx]
        if len(speaker.utterances) == 0:
            raise ValueError(f'speaker idx {speaker_idx} has no utterances')

        last_err = None
        for _ in range(max_attempts):
            try:
                utterance = speaker.random_utterances(1)[0]
                return utterance.random_mel_segment(seq_len)
            except Exception as err:
                last_err = err
                logging.debug(f'failed to load utterances of speaker idx {speaker_idx}: {err}')
        raise RuntimeError(
            f'no loadable utterance for speaker idx {speaker_idx} after {max_attempts} attempts'
        ) from last_err

    def __getitem__(self, idx):
        """Return random segments of random utterances for the specified speaker."""
        if isinstance(self.seq_len, int):
            seq_len = self.seq_len
        elif isinstance(self.seq_len, (list, tuple)):
            seq_len = random.choice(self.seq_len)
        else:
            raise ValueError('seq_len must be int or int list')

        segments = np.array([self.random_utterance_segment(idx, seq_len) for _ in range(self.utterances_per_speaker)])
        return torch.tensor(segments)

    def __len__(self):
        """Number of speakers (one item per speaker)."""
        return len(self.speakers)

In [8]:
from pathlib import Path

def _load_speaker_dataset(dataset_id, root, pattern, speaker_id_of):
    """Group the audio files below ``root`` by speaker into an AudioDataset.

    Args:
        dataset_id: Identifier given to the resulting dataset.
        root: Directory searched recursively.
        pattern: Glob pattern selecting the audio files (e.g. ``'*.wav'``).
        speaker_id_of: Callable mapping a file path to its speaker id.

    Returns:
        AudioDataset: One speaker per distinct speaker id; every utterance
        uses the file stem as its id and keeps the file path as ``raw_file``.
    """
    id2speaker = dict()
    for f in root.rglob(pattern):
        speaker_id = speaker_id_of(f)
        spk = id2speaker.get(speaker_id)
        if spk is None:
            spk = Speaker(speaker_id)
            id2speaker[speaker_id] = spk
        spk.add_utterance(Utterance(f.stem, raw_file=f))

    if not id2speaker:
        # An empty result usually means a wrong path; make it visible instead
        # of silently training on fewer datasets.
        logging.warning(f'no {pattern} files found under {root} for dataset {dataset_id}')

    return AudioDataset(dataset_id, speakers=list(id2speaker.values()))


def load_librispeech360_dataset(root: Path):
    """Load the LibriSpeech train-clean-360 dataset into an AudioDataset.

    The dataset can be downloaded from: https://www.openslr.org/12

    Files are expected as ``<root>/<speaker_id>/<chapter_id>/<utterance_id>.flac``;
    the speaker id is the name of the grandparent directory of each file.

    Args:
        root (Path): Path to the root directory of the LibriSpeech dataset.

    Returns:
        AudioDataset: A dataset object containing the loaded speakers and their utterances.
    """
    return _load_speaker_dataset(
        'librispeech360', root, '*.flac', lambda f: f.parent.parent.name)


def load_vivos_dataset(root: Path):
    """Load the VIVOS dataset into an AudioDataset.

    The dataset can be downloaded from: https://ailab.hcmus.edu.vn/vivos

    Files are expected as ``<root>/train/<speaker_id>/<utterance_id>.wav``;
    the speaker id is the name of the directory containing each file.

    Args:
        root (Path): Path to the root directory of the VIVOS dataset.

    Returns:
        AudioDataset: A dataset object containing the loaded speakers and their utterances.
    """
    return _load_speaker_dataset('vivos', root, '*.wav', lambda f: f.parent.name)


def load_aishell3_dataset(root: Path):
    """Load the AISHELL-3 dataset into an AudioDataset.

    The dataset can be downloaded from: https://www.openslr.org/93

    Files are expected as ``<root>/wav/<speaker_id>/<utterance_id>.wav``;
    the speaker id is the name of the directory containing each file.

    Args:
        root (Path): Path to the root directory of the AISHELL-3 dataset.

    Returns:
        AudioDataset: A dataset object containing the loaded speakers and their utterances.
    """
    return _load_speaker_dataset('aishell3', root, '*.wav', lambda f: f.parent.name)


def load_voxceleb_dataset(root: Path):
    """Load the VoxCeleb dataset into an AudioDataset.

    The dataset can be downloaded from: https://www.robots.ox.ac.uk/~vgg/data/voxceleb/

    Files are expected as ``<root>/wav/<speaker_id>/<segment_id>/<utterance_id>.wav``;
    the speaker id is the third-to-last component of each file path.

    Args:
        root (Path): Path to the root directory of the VoxCeleb dataset.

    Returns:
        AudioDataset: A dataset object containing the loaded speakers and their utterances.
    """
    return _load_speaker_dataset('voxceleb', root, '*.wav', lambda f: f.parts[-3])



In [15]:
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
import torch
import torch.nn as nn
import numpy as np

class SpeakerEncoder(nn.Module):
    """Transformer-based speaker encoder trained with the GE2E softmax loss.

    A Transformer encoder runs over the mel frames, the outputs are averaged
    over time, projected to 128 dimensions, passed through a ReLU and
    L2-normalised.

    Args:
        input_size: Number of mel channels per frame (also the Transformer
            model width, so it must be divisible by ``num_heads``).
        hidden_size: Width of the Transformer feed-forward sub-layer.
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads.
        device: Device holding the encoder network.
        loss_device: Device holding the similarity scale/bias and the loss.
    """

    def __init__(self, input_size=80, hidden_size=256, num_layers=3, num_heads=8, device='cpu', loss_device='cpu'):
        super().__init__()
        self.loss_device = loss_device

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=num_heads,
            dim_feedforward=hidden_size,
            dropout=0.2,
            activation='relu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers).to(device)

        self.linear = nn.Linear(in_features=input_size, out_features=128).to(device)
        self.relu = nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # The tensors are created directly on the loss device: calling .to()
        # on an nn.Parameter returns a plain, non-leaf tensor whenever a copy
        # is needed, which would leave these values unregistered, untrained
        # and without a .grad for do_gradient_ops().
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """Post-process gradients; call between ``backward()`` and ``step()``.

        Scales the gradients of the similarity weight and bias by 0.01, then
        clips the global L2 norm of all gradients to 3.
        """
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: not used in the Transformer version
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """

        # Pass the input through the Transformer Encoder
        out = self.transformer_encoder(utterances)

        # We take the mean of all time steps (similar to a global pooling)
        embeds_raw = self.relu(self.linear(out.mean(dim=1)))

        # L2-normalize it. The epsilon keeps the division finite when the ReLU
        # zeroes a whole embedding (0/0 would otherwise give NaN).
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        :raises ValueError: if there are fewer than two utterances per speaker
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
        if utterances_per_speaker < 2:
            # The exclusive centroid averages the *other* utterances of the
            # speaker, so at least two are required.
            raise ValueError('similarity_matrix needs at least 2 utterances per speaker')

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix computation: an utterance is compared with the
        # inclusive centroid of every other speaker and with the exclusive
        # centroid of its own speaker.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch, device=self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int64)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: every utterance must be classified as its own speaker.
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            # One-hot encode the true speaker of every utterance.
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int64)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer

In [16]:
from pathlib import Path
import logging
import random
import time
import torch
import numpy as np

def evaluate(model, loader, total_evaluate_steps=50, device='cpu', loss_device='cpu'):
    """Average the loss and EER of ``model`` over a number of batches.

    The loader is re-iterated as often as needed to collect
    ``total_evaluate_steps`` batches. The model is switched to eval mode and
    left in it; the caller is expected to switch back to train mode.

    Args:
        model: Encoder exposing ``__call__``, ``eval()`` and ``loss(embeds)``.
        loader: Iterable of batches shaped
            (n_speakers, n_utterances, n_mels, n_frames).
        total_evaluate_steps: Number of batches to evaluate.
        device: Device on which the model is run.
        loss_device: Device on which the loss is computed.

    Returns:
        tuple: ``(mean_loss, mean_eer)`` as floats; both are NaN when the
        loader yields no batch.
    """
    losses = []
    eers = []
    start_time = time.time()

    model.eval()
    # No gradients are needed here; skipping the autograd graph saves memory.
    with torch.no_grad():
        while len(losses) < total_evaluate_steps:
            saw_batch = False
            for batch in loader:
                if len(losses) >= total_evaluate_steps:
                    break
                saw_batch = True

                n_speakers, n_utterances, freq_len, tempo_len = batch.shape
                # Flatten speakers/utterances and put time before channels.
                data = batch.view(-1, freq_len, tempo_len)
                data = data.transpose(1, 2)
                embeds = model(data.to(device))
                embeds = embeds.view(n_speakers, n_utterances, -1)
                loss, eer = model.loss(embeds.to(loss_device))
                # .item() works on any device, unlike .numpy().
                losses.append(loss.item())
                eers.append(float(eer))
            if not saw_batch:
                # An empty loader would otherwise keep this loop spinning forever.
                break

    if not losses:
        print('Evaluate skipped: the evaluation loader produced no batches')
        return float('nan'), float('nan')

    mean_loss = float(np.mean(losses))
    mean_eer = float(np.mean(eers))
    print(f'Evaluate Mean Loss {mean_loss:.3f}, Mean EER {mean_eer:.3f} - Time: {(time.time() - start_time):.3f}s')
    return mean_loss, mean_eer

def train():
    """Train the speaker encoder on the combined datasets.

    Loads LibriSpeech, VIVOS, AISHELL-3 and VoxCeleb, holds out 50 randomly
    chosen speakers for evaluation, resumes from the newest checkpoint in
    ``save_dir`` when one exists, and then trains until ``train_steps``
    optimisation steps have been taken. Progress printing, checkpointing and
    evaluation happen every ``train_print_interval`` / ``save_interval`` /
    ``evaluate_interval`` completed steps.

    All settings are read from the module-level configuration variables.
    """
    print('Loading data...')
    libri_dataset = load_librispeech360_dataset(libri_dataset_path)
    print('Finish to load LibriSpeech360h')
    vivos_dataset = load_vivos_dataset(Path(r'/kaggle/input/vivos-dataset/vivos'))
    print('Finish to load Vivos')
    aishell3_dataset = load_aishell3_dataset(Path(r'/kaggle/input/paddle-speech/AISHELL-3'))
    print('Finish to load AISHELL-3')
    voxceleb_dataset = load_voxceleb_dataset(Path(r'/kaggle/input/voxceleb1train/wav'))
    print('Finish to load Voxceleb')

    datasets = [libri_dataset, aishell3_dataset, vivos_dataset, voxceleb_dataset]
    mds = MultiAudioDataset(datasets)
    random.shuffle(mds.speakers)
    # Hold out the last speakers of the shuffled list for evaluation.
    n_eval_speakers = 50
    train_speakers = mds.speakers[:-n_eval_speakers]
    eval_speakers = mds.speakers[-n_eval_speakers:]

    ds = SpeakerDataset(train_speakers,
                        utterances_per_speaker=utterances_per_speaker,
                        seq_len=seq_len)
    loader = torch.utils.data.DataLoader(ds,
                                        batch_size=speakers_per_batch,
                                        shuffle=True,
                                        num_workers=4)

    eval_ds = SpeakerDataset(eval_speakers,
                        utterances_per_speaker=utterances_per_speaker,
                        seq_len=seq_len)
    eval_loader = torch.utils.data.DataLoader(eval_ds,
                                        batch_size=speakers_per_batch,
                                        shuffle=True,
                                        num_workers=4)

    # Honour the configured device, falling back to the CPU without CUDA.
    dv = torch.device(device if torch.cuda.is_available() else "cpu")
    loss_dv = torch.device(loss_device)
    model = SpeakerEncoder(device=dv, loss_device=loss_dv)

    opt = torch.optim.Adam(model.parameters(), lr=speaker_lr)

    total_steps = 0
    ckpt_dir = Path(save_dir)

    # Resume from the newest checkpoint; the zero-padded file names sort
    # chronologically.
    ckpts = sorted(ckpt_dir.glob('*.pt'))
    if len(ckpts) > 0:
        latest_ckpt_path = ckpts[-1]
        # Load onto the CPU first so a checkpoint written on a GPU machine can
        # be restored anywhere; load_state_dict copies to the right devices.
        ckpt = torch.load(latest_ckpt_path, map_location='cpu', weights_only=False)
        if ckpt:
            print(f'loading ckpt {latest_ckpt_path}')
            model.load_state_dict(ckpt['model_state_dict'])
            opt.load_state_dict(ckpt['optimizer_state_dict'])
            total_steps = ckpt['total_steps']

    print("Start training . . .")
    while total_steps < train_steps:
        for batch in loader:
            if total_steps >= train_steps:
                break
            start_time = time.time()
            # Re-apply the configured learning rate so that a value restored
            # from a checkpoint does not override it.
            for g in opt.param_groups:
                g['lr'] = speaker_lr
            n_speakers, n_utterances, freq_len, tempo_len = batch.shape
            # Flatten speakers/utterances and put time before channels.
            data = batch.view(-1, freq_len, tempo_len)
            data = data.transpose(1, 2)

            model.train()
            opt.zero_grad()

            embeds = model(data.to(dv))
            embeds = embeds.view(n_speakers, n_utterances, -1)
            loss, eer = model.loss(embeds.to(loss_dv))

            loss.backward()
            model.do_gradient_ops()
            opt.step()

            # total_steps now counts completed steps; it is used as-is for the
            # intervals, the file name and the stored counter so that all
            # three agree.
            total_steps += 1

            if total_steps % train_print_interval == 0:
                print(f'Step {total_steps} Loss {loss.item():.3f}, EER {eer:.3f} - Time: {(time.time() - start_time):.3f}s')
            if total_steps % save_interval == 0:
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                save_path = ckpt_dir / f'{total_steps:012d}.pt'
                print(f'saving ckpt {save_path}')
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'total_steps': total_steps
                }, save_path)

                # remove old ckpts, keeping the newest max_ckpts
                ckpts = sorted(ckpt_dir.glob('*.pt'))
                if len(ckpts) > max_ckpts:
                    for old_ckpt in ckpts[:-max_ckpts]:
                        old_ckpt.unlink()
                        print(f'ckpt {old_ckpt} removed')
            if total_steps % evaluate_interval == 0:
                evaluate(model, eval_loader, total_evaluate_steps=total_evaluate_steps,
                         device=dv, loss_device=loss_dv)


In [17]:
# Sanity check before training: report whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

In [None]:
# Launch training; it runs until `train_steps` is reached or the kernel is interrupted.
train()

Loading data...
Finish to load LibriSpeech360h
Finish to load Vivos
Finish to load AISHELL-3
Finish to load Voxceleb
Start training . . .
Step 10 Loss 3.052, EER 0.219 - Time: 0.441s
Step 20 Loss 2.443, EER 0.200 - Time: 0.560s
Step 30 Loss 1.919, EER 0.149 - Time: 0.575s
Step 40 Loss 1.892, EER 0.159 - Time: 0.801s
Step 50 Loss 1.750, EER 0.148 - Time: 0.912s
Step 60 Loss 1.753, EER 0.146 - Time: 0.833s
Step 70 Loss 1.555, EER 0.128 - Time: 0.358s
Step 80 Loss 1.638, EER 0.144 - Time: 0.852s
Step 90 Loss 1.406, EER 0.134 - Time: 0.788s
Step 100 Loss 1.379, EER 0.115 - Time: 0.619s
saving ckpt /kaggle/working/000000000100.pt
Step 110 Loss 1.520, EER 0.140 - Time: 0.429s
Step 120 Loss 1.568, EER 0.131 - Time: 0.516s
Step 130 Loss 1.285, EER 0.106 - Time: 0.588s
Step 140 Loss 1.527, EER 0.141 - Time: 0.764s
Step 150 Loss 1.356, EER 0.111 - Time: 0.964s
Step 160 Loss 1.480, EER 0.124 - Time: 0.269s
Step 170 Loss 1.542, EER 0.156 - Time: 0.849s
Step 180 Loss 1.550, EER 0.146 - Time: 0.815s

KeyboardInterrupt: 

In [14]:
import os 

# Maintenance cell: delete every saved checkpoint (*.pt) from the working
# directory. Running it discards all training progress stored there.
root = r'/kaggle/working/'
if os.path.isdir(root):
    for file in os.listdir(root):
        path = os.path.join(root, file)
        # Match on the real extension (a file literally named "pt" is not a
        # checkpoint) and leave directories alone.
        if os.path.splitext(file)[1] == '.pt' and os.path.isfile(path):
            os.remove(path)