## An example of model training and evaluation
Author: Yarden Cohen

Date: July 2020

The script will load trainig data for a Bengalese finch (gy6or6, date: 032212), train a TweetyNet model, and predict a test set.

Data source: https://figshare.com/articles/Bengalese_Finch_song_repository/4805749

Before running:
1. Download the data from 032212.
2. Change the folder definitions in the configuration files to the path for the data and the example csv file


In [1]:
# imports
from pathlib import Path
import pandas as pd
import numpy as np
import json
import torch.utils.data
from argparse import ArgumentParser
from collections import defaultdict
import joblib
import torch
from tqdm import tqdm

# imports from vak
from vak import csv,labels,models,summary_writer,transforms,io,config
from vak.datasets.window_dataset import WindowDataset
from vak.datasets.vocal_dataset import VocalDataset
from vak.device import get_default as get_default_device
from vak.logging import log_or_print
from vak.config import parse
from vak.io import dataframe
import vak.device
import vak.files
import vak.labels as labelfuncs

In [2]:
# setting up parameters for training
path_of_configuration_file = "C:\\Users\\Yarden Cohen\\repos\\tweetynet\\doc\\notebooks\\BF_Example_Train_Predict\\gy6or6_032212_example_train.toml"

toml_path = Path(path_of_configuration_file)
cfg = parse.from_toml(toml_path)

model_config_map = {'TweetyNet': {'loss': {}, 'metrics': {}, 'network': {}, 'optimizer': {'lr': 0.001}}}
train_csv_path = Path(cfg.train.csv_path)
labelset =  cfg.prep.labelset
window_size = cfg.dataloader.window_size
batch_size = cfg.train.batch_size
num_epochs = cfg.train.num_epochs
num_workers = cfg.train.num_workers
results_path = Path(cfg.train.root_results_dir)
spect_key = cfg.spect_params.spect_key
timebins_key = cfg.spect_params.timebins_key
normalize_spectrograms = cfg.train.normalize_spectrograms
shuffle = cfg.train.shuffle
val_step = cfg.train.val_step
ckpt_step = cfg.train.ckpt_step
patience = cfg.train.patience
device = cfg.train.device
logger = None

In [27]:
# make spectrograms and update path in example csv file
spect_path = cfg.prep.data_dir.joinpath('spect')
if not spect_path.is_dir():
    spect_path.mkdir()
dataset_df=dataframe.from_files(labelset=cfg.prep.labelset,
                          data_dir=cfg.prep.data_dir,
                          annot_format=cfg.prep.annot_format,
                          output_dir=spect_path,
                          annot_file=None,
                          audio_format='cbin',
                          spect_params=cfg.spect_params,
                          logger=None)
example_csv_path = cfg.train.csv_path
example_csv_df = pd.read_csv(example_csv_path)
for cnt in range(len(example_csv_df)):
    if example_csv_df['audio_path'][cnt] == dataset_df['audio_path'][cnt]:
        example_csv_df['spect_path'][cnt] = dataset_df['spect_path'][cnt]
    else:
        print('Audio file name mismatch in entry: ' + str(cnt))
example_csv_df.to_csv(example_csv_path,index=False)

making array files containing spectrograms from audio files in: D:\BengaleseFinches\gy6or6\032212
found labels in D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0844.22.cbin.not.mat for D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0844.22.cbin not in labels_mapping, skipping audio file: D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0844.22.cbin
found labels in D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0845.28.cbin.not.mat for D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0845.28.cbin not in labels_mapping, skipping audio file: D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0845.28.cbin
found labels in D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0852.55.cbin.not.mat for D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0852.55.cbin not in labels_mapping, skipping audio file: D:\BengaleseFinches\gy6or6\032212\gy6or6_baseline_220312_0852.55.cbin
found labels in D:\BengaleseFinches\gy6or6\032212\gy6or6

In [45]:
# prepare training dataset
dataset_df = pd.read_csv(train_csv_path)
results_path = Path(results_path).expanduser().resolve()
timebin_dur = dataframe.validate_and_get_timebin_dur(dataset_df)
train_dur = dataframe.split_dur(dataset_df, 'train')
labelmap = labels.to_map(labelset, map_unlabeled=True)
spect_standardizer = None
transform, target_transform = transforms.get_defaults('train',spect_standardizer)
train_dataset = WindowDataset.from_csv(csv_path=train_csv_path,
                                           x_inds=None,
                                           spect_id_vector=None,
                                           spect_inds_vector=None,
                                           split='train',
                                           labelmap=labelmap,
                                           window_size=window_size,
                                           spect_key=spect_key,
                                           timebins_key=timebins_key,
                                           transform=transform,
                                           target_transform=target_transform
                                           )

train_data = torch.utils.data.DataLoader(dataset=train_dataset,
                                             shuffle=shuffle,
                                             batch_size=batch_size,
                                             num_workers=num_workers)

In [47]:
# prepare validation dataset 
if val_step:
    item_transform = transforms.get_defaults('eval',
                                             spect_standardizer,
                                             window_size=window_size,
                                             return_padding_mask=True,
                                             )
    val_dataset = VocalDataset.from_csv(csv_path=train_csv_path,
                                        split='val',
                                        labelmap=labelmap,
                                        spect_key=spect_key,
                                        timebins_key=timebins_key,
                                        item_transform=item_transform,
                                        )
    val_data = torch.utils.data.DataLoader(dataset=val_dataset,
                                           shuffle=False,
                                           # batch size 1 because each spectrogram reshaped into a batch of windows
                                           batch_size=1,
                                           num_workers=num_workers)
    val_dur = dataframe.split_dur(dataset_df, 'val')    
else:
    val_data = None

In [50]:
# initiate the TweetyNet model and prepare for training 
if device is None:
    device = get_default_device()

models_map = models.from_model_config_map(
        model_config_map,
        num_classes=len(labelmap),
        input_shape=train_dataset.shape,
        logger=logger,
    )

model_name = 'TweetyNet'
model = models_map['TweetyNet']
results_model_root = results_path.joinpath(model_name)
if not results_model_root.is_dir():
    results_model_root.mkdir()
ckpt_root = results_model_root.joinpath('checkpoints')
if not ckpt_root.is_dir():
    ckpt_root.mkdir()
writer = summary_writer.get_summary_writer(log_dir=results_model_root,                                                 filename_suffix=model_name)
model.summary_writer = writer

In [51]:
# This is how the model looks like
model.network

TweetyNet(
  (cnn): Sequential(
    (0): Conv2dTF(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(8, 1), stride=(8, 1), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2dTF(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=(8, 1), stride=(8, 1), padding=0, dilation=1, ceil_mode=False)
  )
  (rnn): LSTM(128, 128, bidirectional=True)
  (fc): Linear(in_features=256, out_features=12, bias=True)
)

In [7]:
# train the model
model.fit(train_data=train_data,
          num_epochs=num_epochs,
          ckpt_root=ckpt_root,
          val_data=val_data,
          val_step=val_step,
          ckpt_step=ckpt_step,
          patience=patience,
          device=device)

  0%|                                                                                        | 0/11862 [00:00<?, ?it/s]

epoch 1 / 5


Epoch 1, batch 49. Loss: 0.6468. Global step: 49:   0%|                           | 46/11862 [00:45<2:58:07,  1.11it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 50 is a validation step; computing metrics on validation set


Epoch 1, batch 49. Loss: 0.6468. Global step: 49:   0%|                           | 46/11862 [00:59<2:58:07,  1.11it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:02, 43.14s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:02, 43.14s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:02, 43.14s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:02, 43.14s/it][A
batch 3 / 8:  50%|███████████████████████████████████▌                                   | 4/8 [00:43<02:00, 30.21s/it][A
batch 4 / 8:  50%|███████████████████████████████████▌                                   | 4/8 [00:43<02:00, 30.21s/it][A
batch 5 / 8:  50%|█

avg_acc: 0.7577, avg_levenshtein: 195.0000, avg_segment_error_rate: 3.0959, avg_loss: 0.7602
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 99. Loss: 0.2657. Global step: 99:   1%|▏                            | 99/11862 [01:31<15:46, 12.43it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 100 is a validation step; computing metrics on validation set


Epoch 1, batch 99. Loss: 0.2657. Global step: 99:   1%|▏                            | 99/11862 [01:50<15:46, 12.43it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:42<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.72s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.72s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.72s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.72s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.72s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.72s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.8065, avg_levenshtein: 138.0000, avg_segment_error_rate: 2.1892, avg_loss: 0.5786
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 149. Loss: 0.3865. Global step: 149:   1%|▎                         | 146/11862 [02:16<56:07,  3.48it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 150 is a validation step; computing metrics on validation set


Epoch 1, batch 149. Loss: 0.3865. Global step: 149:   1%|▎                         | 146/11862 [02:30<56:07,  3.48it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.08s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.08s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.08s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.08s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.08s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.08s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.8665, avg_levenshtein: 104.0000, avg_segment_error_rate: 1.6467, avg_loss: 0.4387
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 199. Loss: 0.4783. Global step: 199:   2%|▍                         | 199/11862 [03:03<15:03, 12.91it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 200 is a validation step; computing metrics on validation set


Epoch 1, batch 199. Loss: 0.4783. Global step: 199:   2%|▍                         | 199/11862 [03:20<15:03, 12.91it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:05, 43.61s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:05, 43.61s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:05, 43.61s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:05, 43.61s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:05, 43.61s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:05, 43.61s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.8853, avg_levenshtein: 90.0000, avg_segment_error_rate: 1.3953, avg_loss: 0.3765
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 
Step 200 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 249. Loss: 0.4554. Global step: 249:   2%|▌                         | 248/11862 [03:49<41:16,  4.69it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 250 is a validation step; computing metrics on validation set


Epoch 1, batch 249. Loss: 0.4554. Global step: 249:   2%|▌                         | 248/11862 [04:00<41:16,  4.69it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:43<01:31, 30.50s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9053, avg_levenshtein: 96.0000, avg_segment_error_rate: 1.5254, avg_loss: 0.3207
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 299. Loss: 0.2894. Global step: 299:   3%|▋                         | 299/11862 [04:35<23:17,  8.27it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 300 is a validation step; computing metrics on validation set


Epoch 1, batch 299. Loss: 0.2894. Global step: 299:   3%|▋                         | 299/11862 [04:50<23:17,  8.27it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:07, 43.99s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:07, 43.99s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:07, 43.99s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:07, 43.99s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:07, 43.99s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:44<01:32, 30.80s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9211, avg_levenshtein: 77.0000, avg_segment_error_rate: 1.2233, avg_loss: 0.2851
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 349. Loss: 0.3256. Global step: 349:   3%|▊                         | 349/11862 [05:22<41:26,  4.63it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 350 is a validation step; computing metrics on validation set


Epoch 1, batch 349. Loss: 0.3256. Global step: 349:   3%|▊                         | 349/11862 [05:40<41:26,  4.63it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:46<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:25, 46.47s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:25, 46.47s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:25, 46.47s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:25, 46.47s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:25, 46.47s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:46<01:37, 32.54s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9314, avg_levenshtein: 75.0000, avg_segment_error_rate: 1.1989, avg_loss: 0.2385
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 399. Loss: 0.1534. Global step: 399:   3%|▊                         | 397/11862 [06:11<59:32,  3.21it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 400 is a validation step; computing metrics on validation set


Epoch 1, batch 399. Loss: 0.1534. Global step: 399:   3%|▊                         | 397/11862 [06:30<59:32,  3.21it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:46<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:22, 46.10s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:22, 46.10s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:22, 46.10s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:22, 46.10s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:22, 46.10s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:46<05:22, 46.10s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9282, avg_levenshtein: 67.0000, avg_segment_error_rate: 1.0601, avg_loss: 0.2550
Accuracy has not improved in 1 validation steps. Not saving max-val-acc checkpoint for this validation step.
Step 400 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 449. Loss: 0.2186. Global step: 449:   4%|▉                         | 446/11862 [07:00<23:31,  8.09it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 450 is a validation step; computing metrics on validation set


Epoch 1, batch 449. Loss: 0.2186. Global step: 449:   4%|▉                         | 446/11862 [07:10<23:31,  8.09it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.82s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.82s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.82s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.82s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.82s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:43<01:32, 30.68s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9435, avg_levenshtein: 51.0000, avg_segment_error_rate: 0.8104, avg_loss: 0.1967
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 499. Loss: 0.1526. Global step: 499:   4%|█                         | 496/11862 [07:46<17:52, 10.60it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 500 is a validation step; computing metrics on validation set


Epoch 1, batch 499. Loss: 0.1526. Global step: 499:   4%|█                         | 496/11862 [08:00<17:52, 10.60it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:04, 43.56s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:43<01:31, 30.50s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9431, avg_levenshtein: 56.0000, avg_segment_error_rate: 0.8848, avg_loss: 0.1933
Accuracy has not improved in 1 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 549. Loss: 0.1671. Global step: 549:   5%|█▏                        | 549/11862 [08:32<13:53, 13.58it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 550 is a validation step; computing metrics on validation set


Epoch 1, batch 549. Loss: 0.1671. Global step: 549:   5%|█▏                        | 549/11862 [08:50<13:53, 13.58it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:45<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.05s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.05s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.05s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.05s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.05s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:45<01:34, 31.54s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9474, avg_levenshtein: 54.0000, avg_segment_error_rate: 0.8488, avg_loss: 0.1835
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 599. Loss: 0.1982. Global step: 599:   5%|█▎                        | 596/11862 [09:20<56:50,  3.30it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 600 is a validation step; computing metrics on validation set


Epoch 1, batch 599. Loss: 0.1982. Global step: 599:   5%|█▎                        | 596/11862 [09:40<56:50,  3.30it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.03s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.03s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.03s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.03s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.03s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:44<01:32, 30.83s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9533, avg_levenshtein: 43.0000, avg_segment_error_rate: 0.6716, avg_loss: 0.1747
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 
Step 600 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 649. Loss: 0.1504. Global step: 649:   5%|█▍                        | 649/11862 [10:06<14:13, 13.13it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 650 is a validation step; computing metrics on validation set


Epoch 1, batch 649. Loss: 0.1504. Global step: 649:   5%|█▍                        | 649/11862 [10:20<14:13, 13.13it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:45<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9576, avg_levenshtein: 35.0000, avg_segment_error_rate: 0.5532, avg_loss: 0.1527
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 699. Loss: 0.0866. Global step: 699:   6%|█▌                        | 699/11862 [10:54<41:02,  4.53it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 700 is a validation step; computing metrics on validation set


Epoch 1, batch 699. Loss: 0.0866. Global step: 699:   6%|█▌                        | 699/11862 [11:10<41:02,  4.53it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:45<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:18, 45.44s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:18, 45.44s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:18, 45.44s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:18, 45.44s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:18, 45.44s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:18, 45.44s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9666, avg_levenshtein: 29.0000, avg_segment_error_rate: 0.4405, avg_loss: 0.1249
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 749. Loss: 0.1497. Global step: 749:   6%|█▋                        | 749/11862 [11:42<41:02,  4.51it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 750 is a validation step; computing metrics on validation set


Epoch 1, batch 749. Loss: 0.1497. Global step: 749:   6%|█▋                        | 749/11862 [12:00<41:02,  4.51it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.81s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.81s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.81s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.81s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.81s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.81s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9636, avg_levenshtein: 26.0000, avg_segment_error_rate: 0.3930, avg_loss: 0.1336
Accuracy has not improved in 1 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 799. Loss: 0.1725. Global step: 799:   7%|█▋                        | 798/11862 [12:29<39:43,  4.64it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 800 is a validation step; computing metrics on validation set


Epoch 1, batch 799. Loss: 0.1725. Global step: 799:   7%|█▋                        | 798/11862 [12:40<39:43,  4.64it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:47<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:47<05:30, 47.28s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:47<05:30, 47.28s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:47<05:30, 47.28s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:47<05:30, 47.28s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:47<05:30, 47.28s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:47<01:39, 33.10s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9636, avg_levenshtein: 29.0000, avg_segment_error_rate: 0.4605, avg_loss: 0.1248
Accuracy has not improved in 2 validation steps. Not saving max-val-acc checkpoint for this validation step.
Step 800 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 849. Loss: 0.1028. Global step: 849:   7%|█▊                        | 849/11862 [13:19<24:05,  7.62it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 850 is a validation step; computing metrics on validation set


Epoch 1, batch 849. Loss: 0.1028. Global step: 849:   7%|█▊                        | 849/11862 [13:30<24:05,  7.62it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:45<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.08s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.08s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.08s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.08s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.08s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:45<01:34, 31.56s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9657, avg_levenshtein: 21.0000, avg_segment_error_rate: 0.3375, avg_loss: 0.1178
Accuracy has not improved in 3 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 899. Loss: 0.0893. Global step: 899:   8%|█▉                        | 896/11862 [14:06<55:28,  3.29it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 900 is a validation step; computing metrics on validation set


Epoch 1, batch 899. Loss: 0.0893. Global step: 899:   8%|█▉                        | 896/11862 [14:20<55:28,  3.29it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.20s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.20s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.20s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.20s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.20s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.20s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9696, avg_levenshtein: 21.0000, avg_segment_error_rate: 0.3122, avg_loss: 0.1009
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 949. Loss: 0.0663. Global step: 949:   8%|██                        | 949/11862 [14:53<13:50, 13.14it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 950 is a validation step; computing metrics on validation set


Epoch 1, batch 949. Loss: 0.0663. Global step: 949:   8%|██                        | 949/11862 [15:10<13:50, 13.14it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:45<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:45<05:15, 45.04s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:45<01:34, 31.54s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9694, avg_levenshtein: 23.0000, avg_segment_error_rate: 0.3591, avg_loss: 0.1027
Accuracy has not improved in 1 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 999. Loss: 0.0620. Global step: 999:   8%|██▏                       | 999/11862 [15:41<39:54,  4.54it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1000 is a validation step; computing metrics on validation set


Epoch 1, batch 999. Loss: 0.0620. Global step: 999:   8%|██▏                       | 999/11862 [16:00<39:54,  4.54it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:10, 44.37s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:10, 44.37s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:10, 44.37s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:10, 44.37s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:10, 44.37s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:44<01:33, 31.07s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9718, avg_levenshtein: 14.0000, avg_segment_error_rate: 0.2133, avg_loss: 0.0999
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 
Step 1000 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 1049. Loss: 0.1002. Global step: 1049:   9%|██                     | 1049/11862 [16:28<39:26,  4.57it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1050 is a validation step; computing metrics on validation set


Epoch 1, batch 1049. Loss: 0.1002. Global step: 1049:   9%|██                     | 1049/11862 [16:40<39:26,  4.57it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.66s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.66s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.66s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.66s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.66s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:44<01:33, 31.27s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9693, avg_levenshtein: 23.0000, avg_segment_error_rate: 0.3449, avg_loss: 0.1007
Accuracy has not improved in 1 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 1099. Loss: 0.0622. Global step: 1099:   9%|██▏                    | 1096/11862 [17:15<54:11,  3.31it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1100 is a validation step; computing metrics on validation set


Epoch 1, batch 1099. Loss: 0.0622. Global step: 1099:   9%|██▏                    | 1096/11862 [17:30<54:11,  3.31it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.42s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.42s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.42s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.42s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.42s/it][A
batch 4 / 8:  62%|████████████████████████████████████████████▍                          | 5/8 [00:43<01:31, 30.40s/it][A
batch 5 / 8:  62%|█

avg_acc: 0.9677, avg_levenshtein: 23.0000, avg_segment_error_rate: 0.3716, avg_loss: 0.1051
Accuracy has not improved in 2 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 1149. Loss: 0.0503. Global step: 1149:  10%|██▏                    | 1147/11862 [18:01<17:26, 10.24it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1150 is a validation step; computing metrics on validation set


Epoch 1, batch 1149. Loss: 0.0503. Global step: 1149:  10%|██▏                    | 1147/11862 [18:20<17:26, 10.24it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.39s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.39s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.39s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.39s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.39s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:03, 43.39s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9716, avg_levenshtein: 16.0000, avg_segment_error_rate: 0.2486, avg_loss: 0.0969
Accuracy has not improved in 3 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 1199. Loss: 0.0540. Global step: 1199:  10%|██▎                    | 1196/11862 [18:47<21:27,  8.28it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1200 is a validation step; computing metrics on validation set


Epoch 1, batch 1199. Loss: 0.0540. Global step: 1199:  10%|██▎                    | 1196/11862 [19:00<21:27,  8.28it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:43<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.83s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.83s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.83s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.83s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.83s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:43<05:06, 43.83s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9723, avg_levenshtein: 14.0000, avg_segment_error_rate: 0.2190, avg_loss: 0.0888
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 
Step 1200 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 1249. Loss: 0.0816. Global step: 1249:  11%|██▍                    | 1248/11862 [19:34<13:18, 13.29it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1250 is a validation step; computing metrics on validation set


Epoch 1, batch 1249. Loss: 0.0816. Global step: 1249:  11%|██▍                    | 1248/11862 [19:50<13:18, 13.29it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:11, 44.55s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:11, 44.55s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:11, 44.55s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:11, 44.55s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:11, 44.55s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:11, 44.55s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9747, avg_levenshtein: 13.0000, avg_segment_error_rate: 0.1937, avg_loss: 0.0874
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 1299. Loss: 0.0825. Global step: 1299:  11%|██▌                    | 1297/11862 [20:21<28:43,  6.13it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1300 is a validation step; computing metrics on validation set


Epoch 1, batch 1299. Loss: 0.0825. Global step: 1299:  11%|██▌                    | 1297/11862 [20:40<28:43,  6.13it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.22s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.22s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.22s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.22s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.22s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.22s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9766, avg_levenshtein: 10.0000, avg_segment_error_rate: 0.1490, avg_loss: 0.0812
Accuracy on validation set improved. Saving max-val-acc checkpoint.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 


Epoch 1, batch 1349. Loss: 0.0500. Global step: 1349:  11%|██▌                    | 1346/11862 [21:07<20:56,  8.37it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1350 is a validation step; computing metrics on validation set


Epoch 1, batch 1349. Loss: 0.0500. Global step: 1349:  11%|██▌                    | 1346/11862 [21:20<20:56,  8.37it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.64s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.64s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.64s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.64s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.64s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.64s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9710, avg_levenshtein: 16.0000, avg_segment_error_rate: 0.2408, avg_loss: 0.0911
Accuracy has not improved in 1 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 1399. Loss: 0.0475. Global step: 1399:  12%|██▋                    | 1396/11862 [21:55<16:55, 10.31it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1400 is a validation step; computing metrics on validation set


Epoch 1, batch 1399. Loss: 0.0475. Global step: 1399:  12%|██▋                    | 1396/11862 [22:10<16:55, 10.31it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.60s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.60s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.60s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.60s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.60s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:12, 44.60s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9730, avg_levenshtein: 17.0000, avg_segment_error_rate: 0.2539, avg_loss: 0.0832
Accuracy has not improved in 2 validation steps. Not saving max-val-acc checkpoint for this validation step.
Step 1400 is a checkpoint step.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Epoch 1, batch 1449. Loss: 0.0318. Global step: 1449:  12%|██▊                    | 1449/11862 [22:42<13:09, 13.19it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1450 is a validation step; computing metrics on validation set


Epoch 1, batch 1449. Loss: 0.0318. Global step: 1449:  12%|██▊                    | 1449/11862 [23:00<13:09, 13.19it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.25s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.25s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.25s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.25s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.25s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:09, 44.25s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9763, avg_levenshtein: 9.0000, avg_segment_error_rate: 0.1382, avg_loss: 0.0766
Accuracy has not improved in 3 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 1499. Loss: 0.0609. Global step: 1499:  13%|██▉                    | 1496/11862 [23:29<51:33,  3.35it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1500 is a validation step; computing metrics on validation set


Epoch 1, batch 1499. Loss: 0.0609. Global step: 1499:  13%|██▉                    | 1496/11862 [23:40<51:33,  3.35it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:44<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.01s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.01s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.01s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.01s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.01s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:44<05:08, 44.01s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9764, avg_levenshtein: 13.0000, avg_segment_error_rate: 0.1966, avg_loss: 0.0759
Accuracy has not improved in 4 validation steps. Not saving max-val-acc checkpoint for this validation step.


Epoch 1, batch 1549. Loss: 0.0461. Global step: 1549:  13%|██▉                    | 1546/11862 [24:15<16:45, 10.26it/s]
  0%|                                                                                            | 0/8 [00:00<?, ?it/s][A

Step 1550 is a validation step; computing metrics on validation set


Epoch 1, batch 1549. Loss: 0.0461. Global step: 1549:  13%|██▉                    | 1546/11862 [24:30<16:45, 10.26it/s]
batch 0 / 8:   0%|                                                                               | 0/8 [00:42<?, ?it/s][A
batch 0 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.82s/it][A
batch 1 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.82s/it][A
batch 2 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.82s/it][A
batch 3 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.82s/it][A
batch 4 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.82s/it][A
batch 5 / 8:  12%|████████▉                                                              | 1/8 [00:42<04:59, 42.82s/it][A
batch 5 / 8:  75%|█

avg_acc: 0.9742, avg_levenshtein: 17.0000, avg_segment_error_rate: 0.2630, avg_loss: 0.0805
Stopping training early, accuracy has not improved in 4 validation steps.
Saving checkpoint at:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\checkpoint.pt 


Now that we have a trained model we can use it to predict segments and labels in a test set

In [94]:
# setting up parameters for prediction
path_of_configuration_file = "C:\\Users\\Yarden Cohen\\repos\\tweetynet\\doc\\notebooks\\BF_Example_Train_Predict\\gy6or6_032212_example_eval.toml"

toml_path = Path(path_of_configuration_file)
cfg = parse.from_toml(toml_path)
min_segment_dur = 0.01
device='cuda'
spect_key='s'
timebins_key='t'
spect_standardizer = joblib.load(cfg.eval.spect_scaler_path)
with cfg.eval.labelmap_path.open('r') as f:
        labelmap = json.load(f)

In [95]:
# prepare evaluation data
item_transform = transforms.get_defaults('eval',
                                             spect_standardizer=None,
                                             window_size=cfg.dataloader.window_size,
                                             return_padding_mask=True,
                                             )

eval_dataset = VocalDataset.from_csv(csv_path=cfg.eval.csv_path,
                                     split='test',
                                     labelmap=labelmap,
                                     spect_key=spect_key,
                                     timebins_key=timebins_key,
                                     item_transform=item_transform,
                                     )

eval_data = torch.utils.data.DataLoader(dataset=eval_dataset,
                                        shuffle=False,
                                        # batch size 1 because each spectrogram reshaped into a batch of windows
                                        batch_size=1,
                                        num_workers=cfg.eval.num_workers)


In [96]:
# Create model
input_shape = eval_dataset.shape
# if dataset returns spectrogram reshaped into windows,
# throw out the window dimension; just want to tell network (channels, height, width) shape
if len(input_shape) == 4:
    input_shape = input_shape[1:]

models_map = models.from_model_config_map(
    model_config_map,
    num_classes=len(labelmap),
    input_shape=input_shape
)
model_name = 'TweetyNet'
model = models_map['TweetyNet']
model.load(cfg.eval.checkpoint_path)
metrics = model.metrics  # metric name -> callable map we use below in loop
if device is None:
    device = vak.device.get_default_device()
pred_dict = model.predict(pred_data=eval_data,
                          device=device)




  0%|                                                                                           | 0/17 [00:00<?, ?it/s][A[A[A

Loading checkpoint from:
C:\Users\Yarden Cohen\repos\tweetynet\doc\notebooks\BF_Example_Train_Predict\TweetyNet\checkpoints\max-val-acc-checkpoint.pt 





batch 0 / 17:   0%|                                                                             | 0/17 [00:42<?, ?it/s][A[A[A


batch 0 / 17:   6%|████                                                                 | 1/17 [00:42<11:19, 42.44s/it][A[A[A


batch 1 / 17:   6%|████                                                                 | 1/17 [00:42<11:19, 42.44s/it][A[A[A


batch 2 / 17:   6%|████                                                                 | 1/17 [00:42<11:19, 42.44s/it][A[A[A


batch 3 / 17:   6%|████                                                                 | 1/17 [00:42<11:19, 42.44s/it][A[A[A


batch 4 / 17:   6%|████                                                                 | 1/17 [00:42<11:19, 42.44s/it][A[A[A


batch 5 / 17:   6%|████                                                                 | 1/17 [00:42<11:19, 42.44s/it][A[A[A


batch 6 / 17:   6%|████                                                         

In [97]:
def compute_metrics(metrics, y_true, y_pred, y_true_labels, y_pred_labels):
    """helper function to compute metrics

    Parameters
    ----------
    metrics : dict
        where keys are metric names and values are callables that compute the metric
        given ground truth and prediction
    y_true : torch.Tensor
        vector of labeled time bins
    y_pred : torch.Tensor
        vector of labeled time bins
    y_true_labels : str
        sequence of segment labels
    y_pred_labels : str
        sequence of segment labels

    Returns
    -------
    metric_vals : defaultdict
    """
    metric_vals = {}

    for metric_name, metric_callable in metrics.items():
        if metric_name == 'acc':
            metric_vals[metric_name] = metric_callable(y_pred, y_true)
        elif metric_name == 'levenshtein':
            metric_vals[metric_name] = metric_callable(y_pred_labels, y_true_labels)
        elif metric_name == 'segment_error_rate':
            metric_vals[metric_name] = metric_callable(y_pred_labels, y_true_labels)

    return metric_vals

In [98]:
records = defaultdict(list)  # will be used with pandas.DataFrame.from_records to make output csv
to_long_tensor = transforms.ToLongTensor()
progress_bar = tqdm(eval_data)
for ind, batch in enumerate(progress_bar):
    y_true, padding_mask, spect_path = batch['annot'], batch['padding_mask'], batch['spect_path']
    # need to convert spect_path to tuple for match in call to index() below
    spect_path = tuple(spect_path)
    records['spect_path'].append(spect_path[0])  # remove str from tuple
    y_true = y_true.to(device)
    y_true_np = np.squeeze(y_true.cpu().numpy())
    y_true_labels, _, _ = labelfuncs.lbl_tb2segments(y_true_np,
                                                     labelmap=labelmap,
                                                     timebin_dur=timebin_dur)
    y_true_labels = ''.join(y_true_labels.tolist())

    y_pred_ind = spect_path[0] #pred_dict['y'].index(spect_path)
    y_pred = pred_dict[y_pred_ind] #pred_dict['y_pred'][y_pred_ind]
    y_pred = torch.argmax(y_pred, dim=1)  # assumes class dimension is 1
    y_pred = torch.flatten(y_pred)
    y_pred = y_pred.unsqueeze(0)[padding_mask]
    y_pred_np = np.squeeze(y_pred.cpu().numpy())
    y_pred_labels, _, _ = labelfuncs.lbl_tb2segments(y_pred_np,
                                                     labelmap=labelmap,
                                                     timebin_dur=timebin_dur,
                                                     min_segment_dur=None,
                                                     majority_vote=False)
    y_pred_labels = ''.join(y_pred_labels.tolist())

    metric_vals_batch = compute_metrics(metrics, y_true, y_pred, y_true_labels, y_pred_labels)
    for metric_name, metric_val in metric_vals_batch.items():
        records[metric_name].append(metric_val)

    # --- apply majority vote and min segment dur transforms separately
    # need segment_inds_list for both transforms
    segment_inds_list = labelfuncs.lbl_tb_segment_inds_list(y_pred_np,
                                                            unlabeled_label=labelmap['unlabeled'])

    # ---- majority vote transform
    y_pred_np_mv = labelfuncs.majority_vote_transform(y_pred_np, segment_inds_list)
    y_pred_mv = to_long_tensor(y_pred_np_mv).to(device)
    y_pred_mv_labels, _, _ = labelfuncs.lbl_tb2segments(y_pred_np_mv,
                                                        labelmap=labelmap,
                                                        timebin_dur=timebin_dur,
                                                        min_segment_dur=None,
                                                        majority_vote=False)
    y_pred_mv_labels = ''.join(y_pred_mv_labels.tolist())
    metric_vals_batch_mv = compute_metrics(metrics, y_true, y_pred_mv,
                                           y_true_labels, y_pred_mv_labels)
    for metric_name, metric_val in metric_vals_batch_mv.items():
        records[f'{metric_name}_majority_vote'].append(metric_val)

    # ---- min segment dur transform
    y_pred_np_mindur, _ = labelfuncs.remove_short_segments(y_pred_np,
                                                           segment_inds_list,
                                                           timebin_dur=timebin_dur,
                                                           min_segment_dur=min_segment_dur,
                                                           unlabeled_label=labelmap['unlabeled'])
    y_pred_mindur = to_long_tensor(y_pred_np_mindur).to(device)
    y_pred_mindur_labels, _, _ = labelfuncs.lbl_tb2segments(y_pred_np_mindur,
                                                            labelmap=labelmap,
                                                            timebin_dur=timebin_dur,
                                                            min_segment_dur=None,
                                                            majority_vote=False)
    y_pred_mindur_labels = ''.join(y_pred_mindur_labels.tolist())
    metric_vals_batch_mindur = compute_metrics(metrics, y_true, y_pred_mindur,
                                               y_true_labels, y_pred_mindur_labels)
    for metric_name, metric_val in metric_vals_batch_mindur.items():
        records[f'{metric_name}_min_segment_dur'].append(metric_val)

    # ---- and finally both transforms, in same order we apply for prediction
    y_pred_np_mindur_mv, segment_inds_list = labelfuncs.remove_short_segments(y_pred_np,
                                                                              segment_inds_list,
                                                                              timebin_dur=timebin_dur,
                                                                              min_segment_dur=min_segment_dur,
                                                                              unlabeled_label=labelmap[
                                                                                  'unlabeled'])
    y_pred_np_mindur_mv = labelfuncs.majority_vote_transform(y_pred_np_mindur_mv,
                                                             segment_inds_list)
    y_pred_mindur_mv = to_long_tensor(y_pred_np_mindur_mv).to(device)
    y_pred_mindur_mv_labels, _, _ = labelfuncs.lbl_tb2segments(y_pred_np_mindur_mv,
                                                               labelmap=labelmap,
                                                               timebin_dur=timebin_dur,
                                                               min_segment_dur=None,
                                                               majority_vote=False)
    y_pred_mindur_mv_labels = ''.join(y_pred_mindur_mv_labels.tolist())
    metric_vals_batch_mindur_mv = compute_metrics(metrics, y_true, y_pred_mindur_mv,
                                                  y_true_labels, y_pred_mindur_mv_labels)
    for metric_name, metric_val in metric_vals_batch_mindur_mv.items():
        records[f'{metric_name}_min_dur_maj_vote'].append(metric_val)
df = pd.DataFrame.from_records(records)




  0%|                                                                                           | 0/17 [00:00<?, ?it/s][A[A[A


  6%|████▉                                                                              | 1/17 [00:44<11:46, 44.15s/it][A[A[A


 41%|██████████████████████████████████▏                                                | 7/17 [00:44<05:09, 30.91s/it][A[A[A


 71%|█████████████████████████████████████████████████████████▉                        | 12/17 [00:44<01:48, 21.64s/it][A[A[A


100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [00:45<00:00,  2.67s/it][A[A[A


In [99]:
df

Unnamed: 0,acc,acc_majority_vote,acc_min_dur_maj_vote,acc_min_segment_dur,levenshtein,levenshtein_majority_vote,levenshtein_min_dur_maj_vote,levenshtein_min_segment_dur,segment_error_rate,segment_error_rate_majority_vote,segment_error_rate_min_dur_maj_vote,segment_error_rate_min_segment_dur,spect_path
0,0.976961,0.980618,0.981167,0.981167,17,2,1,1,0.25,0.029412,0.014706,0.014706,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
1,0.980854,0.983544,0.983544,0.983544,8,1,1,1,0.101266,0.012658,0.012658,0.012658,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
2,0.983028,0.984708,0.985213,0.985213,14,3,0,0,0.208955,0.044776,0.0,0.0,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
3,0.978516,0.980505,0.980505,0.980505,11,1,1,1,0.174603,0.015873,0.015873,0.015873,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
4,0.975496,0.978614,0.979283,0.979283,11,1,0,0,0.2,0.018182,0.0,0.0,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
5,0.968452,0.97104,0.971526,0.971526,25,6,4,4,0.308642,0.074074,0.049383,0.049383,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
6,0.980421,0.981889,0.981889,0.981889,10,3,3,3,0.12987,0.038961,0.038961,0.038961,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
7,0.978057,0.98214,0.98265,0.98265,16,2,0,0,0.219178,0.027397,0.0,0.0,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
8,0.957321,0.961261,0.964051,0.964051,34,18,12,12,0.523077,0.276923,0.184615,0.184615,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
9,0.982392,0.98379,0.984908,0.984908,11,4,1,1,0.122222,0.044444,0.011111,0.011111,D:\BengaleseFinches\gy6or6\vak\spectrograms_ge...
