## Create examples of network output for figure panels
Created by: Yarden Cohen\
Date: June 2021\
This notebook allows loading specific saved TweetyNet models and examining their outputs.
Cells in this notebook will also hold code to create figure panels showing such network outputs.

In [1]:
# imports
from argparse import ArgumentParser
import configparser  # used to load 'min_segment_dur.ini'

from collections import defaultdict
import json
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyprojroot
import torch
from tqdm import tqdm

from vak import config, io, models, transforms
from vak.datasets.vocal_dataset import VocalDataset
import vak.device
import vak.files
from vak.labeled_timebins import (
    lbl_tb2segments,
    majority_vote_transform,
    lbl_tb_segment_inds_list,
    remove_short_segments,
)
from vak.core.learncurve import train_dur_csv_paths as _train_dur_csv_paths
from vak.logging import log_or_print

In [None]:
def load_network_results(path_to_config=None,
                        csv_path=None,
                        labelmap_path=None,
                        checkpoint_path=None,
                        window_size = 370,
                        min_segment_dur = 0.01,
                        num_workers = 12,
                        device='cuda',
                        spect_key='s',
                        timebins_key='t',
                        freq_key = 'f',
                        test_all_files=False):
    """Unfinished helper intended to load a saved network and its outputs.

    The original definition stopped at a bare ``if path_to_config:`` with no
    body, which made the whole cell a syntax error. Until the real body is
    written, the function is a valid stub that fails loudly instead of
    silently returning ``None``.

    Parameters
    ----------
    path_to_config : optional
        Presumably a path to a configuration file from which the remaining
        parameters would be read -- TODO confirm once implemented.
    csv_path, labelmap_path, checkpoint_path : optional
        Presumably the dataset csv, the label map file and the saved model
        checkpoint, mirroring the module-level variables of the same names
        used elsewhere in this notebook -- TODO confirm.
    window_size, min_segment_dur, num_workers, device, spect_key, timebins_key, freq_key
        Defaults match the module-level configuration values used elsewhere
        in this notebook. Their use inside this function is not defined yet.
    test_all_files : bool
        Presumably requests that every file in the csv be treated as part of
        the 'test' split -- TODO confirm.

    Raises
    ------
    NotImplementedError
        Always; neither the config-file branch nor the explicit-argument
        branch has been written.
    """
    if path_to_config:
        # TODO: read csv_path, labelmap_path, checkpoint_path, etc. from the config file.
        raise NotImplementedError(
            "load_network_results: loading parameters from a config file is not implemented yet"
        )
    # TODO: build the dataset, load the checkpoint and return the network outputs.
    raise NotImplementedError("load_network_results is not implemented yet")

In [None]:
# Choose and indicate the checkpoint of a saved model and a csv defining data
# This is done without a config file to allow flexibility

# csv_path is wrapped in Path because a later cell uses `csv_path.parent` / `csv_path.stem`.
# NOTE(review): the file name below ends in '.toml' (a config file), yet the value
# is passed to pd.read_csv and VocalDataset.from_csv -- confirm that it should
# point at the dataset csv instead.
csv_path = Path("D:\\Users\\yarde\\vak_project\\Koumura2016\\Bird1\\config_BirdsongRecognition_bird01_w700_eval.toml")
labelmap_path = Path()  # TODO: set to the label map json file; Path() is the current directory
checkpoint_path = ""  # TODO: set to the saved model checkpoint

# config parameters:
window_size = 370
min_segment_dur = 0.01
num_workers = 12
device='cuda'
spect_key='s'
timebins_key='t'
freq_key = 'f'
#spect_standardizer = 

# Fail with an explicit message when the placeholder above has not been filled in,
# instead of an opaque error from trying to open a directory.
if not labelmap_path.is_file():
    raise FileNotFoundError(
        f"labelmap_path does not point to a file: {labelmap_path.resolve()}"
    )
with labelmap_path.open('r') as f:
    labelmap = json.load(f)

In [None]:
# read a csv and create a new one with all splits marked as 'test'
# Path(...) makes this cell work whether csv_path was given as a str or a Path;
# calling .parent / .stem on a plain str raises AttributeError.
source_csv_path = Path(csv_path)
temp_df = pd.read_csv(source_csv_path).assign(split='test')
# The copy is written next to the source csv as '<stem>_test.csv'.
# NOTE(review): csv_path itself is left unchanged here, so any code that should
# use the all-'test' copy must be pointed at test_csv_path explicitly.
test_csv_path = source_csv_path.parent.joinpath(source_csv_path.stem + '_test.csv')
temp_df.to_csv(test_csv_path)

In [None]:
# prepare evaluation data

# default 'eval' item transform: no standardizer, fixed window size,
# and a padding mask returned alongside each item
eval_transform_kwargs = dict(
    spect_standardizer=None,
    window_size=window_size,
    return_padding_mask=True,
)
item_transform = transforms.get_defaults('eval', **eval_transform_kwargs)

# dataset built from the rows of the csv whose split is 'test'
eval_dataset_kwargs = dict(
    csv_path=csv_path,
    split='test',
    labelmap=labelmap,
    spect_key=spect_key,
    timebins_key=timebins_key,
    item_transform=item_transform,
)
eval_dataset = VocalDataset.from_csv(**eval_dataset_kwargs)

# batch size 1 because each spectrogram reshaped into a batch of windows
eval_data = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    shuffle=False,
    batch_size=1,
    num_workers=num_workers,
)

In [None]:
# Create model
model_name = 'TweetyNet'
model_config_map = {
    model_name: {
        'loss': {},
        'metrics': {},
        'network': {},
        'optimizer': {'lr': 0.001},
    }
}

# if dataset returns spectrogram reshaped into windows,
# throw out the window dimension; just want to tell network (channels, height, width) shape
input_shape = eval_dataset.shape
input_shape = input_shape[1:] if len(input_shape) == 4 else input_shape

models_map = models.from_model_config_map(
    model_config_map,
    num_classes=len(labelmap),
    input_shape=input_shape,
)
model = models_map[model_name]
model.load(checkpoint_path)
#metrics = model.metrics  # metric name -> callable map we use below in loop

# fall back to vak's default device only when none was chosen in the config cell
device = vak.device.get_default_device() if device is None else device
pred_dict = model.predict(pred_data=eval_data, device=device)

In [None]:
# Create data matrices for the model output time aligned to annotated model input
file_number = 0
time_window = [0.01,1.01]
freq_window = [500.0,4000.0]

# annotation of the chosen file as a table
annotation_df = pd.DataFrame(eval_dataset.annots[file_number].seq.as_dict())

# path of the chosen spectrogram file: the file_number-th row (by position) of the 'test' split
csv_df = pd.read_csv(eval_dataset.csv_path)
test_spect_paths = csv_df[csv_df['split'] == 'test']['spect_path']
spect_path = Path(test_spect_paths.iloc[file_number])

# load the spectrogram file once and pull out all three arrays
# (the file used to be read from disk separately for each key)
spect_dict = vak.files.spect.load(spect_path)
spect = spect_dict[spect_key]
t_vec = spect_dict[timebins_key]
f_vec = spect_dict[freq_key]

# indices of the time bins that fall inside the requested time window
timebin_indices = np.where((t_vec >= time_window[0]) & (t_vec <= time_window[1]))[0]
# [x_min, x_max, y_min, y_max] for plotting the spectrogram in time/frequency units
extent = [np.min(t_vec),np.max(t_vec),np.min(f_vec),np.max(f_vec)]

# NOTE(review): presumably pred_dict is keyed by the path as a string rather than
# a Path object -- confirm against model.predict. The Path key is tried first so
# behavior is unchanged when it is present.
pred_key = spect_path if spect_path in pred_dict else str(spect_path)
model_output = pred_dict[pred_key]

In [None]:
# plot
# Three stacked panels; panel 0 is left empty here, panel 1 shows the
# spectrogram and panel 2 shows the model output.
figsize = (5,10)
fig, axs = plt.subplots(3,figsize=figsize)
# str(...) because spect_path is a Path and cannot be concatenated to a str directly
fig.suptitle('Example file ' + str(spect_path))
# NOTE(review): imshow defaults to origin='upper'; if row 0 of `spect` is the
# lowest frequency the image is drawn upside down relative to `extent` -- confirm.
axs[1].imshow(spect,aspect='auto',extent=extent)
axs[1].set_xlim(time_window)
axs[1].set_ylim(freq_window)
# same time axis as the spectrogram ('extend' was a typo for 'extent')
# NOTE(review): assumes model_output is a 2-D array-like that imshow accepts;
# a torch tensor (especially on GPU) may need .detach().cpu().numpy() first -- confirm.
axs[2].imshow(model_output,aspect='auto',extent=[extent[0],extent[1],0,np.shape(model_output)[1]])


