**Step 1: Load yass**

In [None]:
import yass
from yass import read_config
from yass.augment import make_training_data, save_detect_network_params, save_triage_network_params, save_ae_network_params
from yass.neuralnetwork import train_detector, train_ae, train_triage

**Step 2: Read Configuration File**

In [None]:
yass.set_config("location/to/config.yaml")
CONFIG = read_config()

**Step 3: Load Spike Train**

To train the neural networks, you need a recording that has already been spike sorted. The sorting result does not need to be perfect.
If you don't have a sorting result yet, you can run yass with the threshold detection option. In your configuration file, set spikes.detection = threshold.

spike_train is a matrix of size (number of spikes x 2). Each row represents one spike. The first column is the spike time, given as a sample index in the recording (not in milliseconds or seconds). The second column is the spike ID.

In [None]:
# This is only an example; load your spike_train in whatever way suits your data
import numpy as np
spike_train = np.loadtxt('path/to/csv/file.csv', dtype='int32')
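
Before passing spike_train on, it can help to confirm that it has the layout described above. The following is a minimal sketch on a small synthetic array. `check_spike_train` and `n_samples` are hypothetical names used for illustration; they are not part of yass.

```python
import numpy as np

def check_spike_train(spike_train, n_samples=None):
    """Sanity-check a (n_spikes x 2) spike train.

    Hypothetical helper for illustration, not part of yass.
    Returns the sorted unique spike IDs.
    """
    spike_train = np.asarray(spike_train)
    if spike_train.ndim != 2 or spike_train.shape[1] != 2:
        raise ValueError("expected shape (n_spikes, 2), got %s" % (spike_train.shape,))
    if spike_train.shape[0] == 0:
        raise ValueError("spike_train is empty")
    if not np.issubdtype(spike_train.dtype, np.integer):
        raise ValueError("spike times and IDs should be integers")
    times, ids = spike_train[:, 0], spike_train[:, 1]
    if times.min() < 0 or ids.min() < 0:
        raise ValueError("spike times and IDs should be non-negative")
    # n_samples is the recording length in samples (assumed known by the caller)
    if n_samples is not None and times.max() >= n_samples:
        raise ValueError("a spike time lies beyond the end of the recording")
    return np.unique(ids)

# synthetic example: three spikes from two units
example = np.array([[120, 0], [450, 1], [980, 0]], dtype='int32')
unit_ids = check_spike_train(example, n_samples=1000)
```

A check like this catches the common mistake of storing spike times in seconds or milliseconds, since those are usually floats or fall far short of the recording length in samples.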

**Step 4: Make Training Dataset**

1. CONFIG and spike_train come from steps 2 and 3.

2. chosen_templates: a vector listing which templates to use. Not every template in a spike sorting result looks good, so the training dataset should be built from good-looking templates only. Do not include bad templates, but keep enough variability in template shapes. To check templates visually, see the optional step at the bottom.

3. min_amp: the minimum absolute peak amplitude of the augmented spikes. It determines how small the spikes in the training set can be. Default is 3.

4. nspikes: approximately how many training examples to produce.

In [None]:
min_amp = 5
nspikes = 50000
chosen_templates = [0, 1, 2, 3, 5, 10] # should be your own number
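
When picking chosen_templates, it may be useful to see how many spikes each candidate actually has in the sorting result, since a template built from very few spikes is likely to be noisy. This is a minimal sketch on synthetic data; `spike_train_demo` and `chosen_demo` are stand-ins for your own spike_train and chosen_templates.

```python
import numpy as np

# synthetic stand-ins for the spike_train from step 3 and your template choice
spike_train_demo = np.array([[10, 0], [25, 1], [40, 0], [55, 2], [70, 0], [90, 5]])
chosen_demo = [0, 1, 5]

# count how many spikes each ID has (second column of spike_train)
ids, counts = np.unique(spike_train_demo[:, 1], return_counts=True)
spikes_per_template = dict(zip(ids.tolist(), counts.tolist()))

# look up the chosen templates; IDs absent from the sorting result get 0
chosen_counts = {k: spikes_per_template.get(k, 0) for k in chosen_demo}
missing = [k for k, n in chosen_counts.items() if n == 0]

print(chosen_counts)
print("chosen IDs with no spikes:", missing)
```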

In [None]:
x_detect, y_detect, x_triage, y_triage, x_ae, y_ae = make_training_data(CONFIG, spike_train, chosen_templates, min_amp, nspikes)

**Step 5: Train All Three Neural Networks**

Training parameters:
1. n_iter: the number of iterations to run
2. n_batch: the size of mini-batch to be used for training
3. l2_reg_scale: L2 regularization penalty term
4. train_step_size: training step size

In [None]:
n_iter = 5000
n_batch = 512
l2_reg_scale = 0.00000005
train_step_size =  0.001
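
To get a feel for these numbers, the arithmetic below estimates how many passes over the data the settings above correspond to. It assumes the training set holds roughly nspikes examples, which is only an approximation: the actual number of rows in x_detect or x_triage may differ.

```python
n_iter = 5000
n_batch = 512
nspikes = 50000  # approximate training set size (assumption, see above)

# each iteration processes one mini-batch
examples_seen = n_iter * n_batch

# rough number of passes over a dataset of nspikes examples
approx_passes = examples_seen / nspikes

print(examples_seen, approx_passes)
```

With these values, training sees 2,560,000 examples, or about 51 passes over a dataset of 50,000 examples.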

Training neural net detector
1. detectnet_name: path (directory and file name) where the trained model will be saved
2. n_filters_detect: number of filters to use in each layer. It should be a list of length 2

In [None]:
detectnet_name = '/location/you/want/test_detect_nn.ckpt'
n_filters_detect = [16, 8]

In [None]:
# run training
train_detector(x_detect, y_detect, n_filters_detect, n_iter, n_batch, l2_reg_scale, train_step_size, detectnet_name)

In [None]:
# save model parameters
save_detect_network_params(filters = n_filters_detect,
                           size = x_detect.shape[1],
                           n_neighbors = x_detect.shape[2],
                           output_path = detectnet_name.replace('ckpt', 'yaml'))
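
One thing to watch in the cell above: `str.replace('ckpt', 'yaml')` replaces every occurrence of 'ckpt' in the path, including any in directory names. For the placeholder path used here this makes no difference, but the sketch below shows the pitfall and one possible alternative. `params_path` is a hypothetical helper, and the directory name is made up for the example.

```python
import os

def params_path(ckpt_path):
    """Swap only the final '.ckpt' extension for '.yaml' (hypothetical helper)."""
    root, ext = os.path.splitext(ckpt_path)
    if ext != '.ckpt':
        raise ValueError("expected a path ending in .ckpt, got %r" % ckpt_path)
    return root + '.yaml'

# made-up path whose directory name also contains 'ckpt'
path = '/data/ckpt_models/test_detect_nn.ckpt'

naive = path.replace('ckpt', 'yaml')  # also rewrites the directory name
safe = params_path(path)              # only touches the extension

print(naive)
print(safe)
```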

Training neural net triage
1. triagenet_name: path (directory and file name) where the trained model will be saved
2. n_filters_triage: number of filters to use in each layer. It should be a list of length 2

In [None]:
triagenet_name = '/location/you/want/test_triage_nn.ckpt'
n_filters_triage = [16, 8]

In [None]:
# run training
train_triage(x_triage, y_triage, n_filters_triage, n_iter, n_batch, l2_reg_scale, train_step_size, triagenet_name)

In [None]:
# save model parameters
save_triage_network_params(filters = n_filters_triage,
                           size = x_triage.shape[1],
                           n_neighbors = x_triage.shape[2],
                           output_path = triagenet_name.replace('ckpt', 'yaml'))

Training autoencoder
1. ae_name: path (directory and file name) where the trained model will be saved
2. n_features: number of latent variables

In [None]:
ae_name = '/location/you/want/test_ae_nn.ckpt'
n_features = 3
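# use the whole autoencoder dataset as a single batch
# (this overwrites the n_batch value set for the detector and triage networks)
n_batch = x_ae.shape[0]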

In [None]:
# run training
train_ae(x_ae, y_ae, n_features, n_iter, n_batch, train_step_size, ae_name)

In [None]:
# save model parameters
save_ae_network_params(n_input = x_ae.shape[1],
                       n_features = n_features,
                       output_path = ae_name.replace('ckpt', 'yaml'))

**You are done!**

**Step 6: When Using yass**

Make sure that you have all your files! Each neural network model needs **3 '.ckpt'** files and **1 '.yaml'** file, for a **total of 12 files** across the three models.
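
A quick way to verify this is to list the files on disk for each model. The sketch below is an assumption-laden illustration: it assumes the checkpoint files are named by appending a suffix to the '.ckpt' path (for example '.index', '.meta' and '.data-00000-of-00001', as TensorFlow's saver typically does) and that the '.yaml' file sits next to them. `list_model_files` is a hypothetical helper, and the demo creates empty placeholder files in a temporary directory rather than touching real models.

```python
import glob
import os
import tempfile

def list_model_files(ckpt_path):
    """Return (checkpoint_files, yaml_exists) for one model (hypothetical helper)."""
    # files such as <name>.ckpt.index, <name>.ckpt.meta, <name>.ckpt.data-...
    ckpt_files = sorted(glob.glob(glob.escape(ckpt_path) + '.*'))
    yaml_path = os.path.splitext(ckpt_path)[0] + '.yaml'
    return ckpt_files, os.path.exists(yaml_path)

# demo with empty placeholder files; suffixes are an assumption about the saver
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, 'test_detect_nn.ckpt')
    for suffix in ('.index', '.meta', '.data-00000-of-00001'):
        open(prefix + suffix, 'w').close()
    open(os.path.join(tmp, 'test_detect_nn.yaml'), 'w').close()

    ckpt_files, has_yaml = list_model_files(prefix)
    n_ckpt = len(ckpt_files)

print(n_ckpt, has_yaml)
```

Running the same check on detectnet_name, triagenet_name and ae_name should report three checkpoint files and one yaml file for each.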

Also, make sure that the parameters in your configuration file match the parameters used during training:

| Name in config.yaml | How it should change |
|---|---|
|spikes.temporal_features|n_features used for training the autoencoder|
|recordings.spike_size_ms|must stay the same as in the configuration loaded here|
|neural_network_detector.filename|file name used above to save neural net detector|
|neural_network_triage.filename|file name used above to save neural net triage|
|neural_network_autoencoder.filename|file name used above to save neural net autoencoder|

**Step 3.2: Visually Inspect Templates (optional)**

In [None]:
import os
%matplotlib inline
import matplotlib.pyplot as plt
from yass.augment.templates import get_templates
from yass.augment.process import process_data

In [None]:
# process data to get standardized recording
process_data(CONFIG)

# parameters
path = os.path.join(CONFIG.data.root_folder,  'tmp/standarized.bin')
dtype = 'float64'

# make templates
templates = get_templates(spike_train, CONFIG.batch_size, 
                          CONFIG.BUFF, CONFIG.nBatches, 
                          CONFIG.recordings.n_channels, 
                          CONFIG.spikeSize*2, 
                          path, dtype)

In [None]:
for k in range(templates.shape[2]):
    plt.plot(templates[:,:,k].T)
    plt.title(str(k))
    plt.show()
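
Alongside the plots, a numeric summary can help shortlist templates. The sketch below ranks templates by their peak absolute amplitude and compares it with min_amp. It uses a synthetic array and assumes only that the last axis of templates indexes the template, as in the plotting loop above. Note that min_amp in step 4 applies to the augmented spikes, so this is a rough guide rather than the same criterion.

```python
import numpy as np

# synthetic stand-in for templates; the last axis indexes the template
templates_demo = np.zeros((4, 30, 3))
templates_demo[1, 15, 0] = -8.0   # large template
templates_demo[2, 15, 1] = -2.0   # small template
# template 2 is left flat

# peak absolute amplitude of each template over the first two axes
peak_amp = np.max(np.abs(templates_demo), axis=(0, 1))

min_amp = 5
strong = np.where(peak_amp >= min_amp)[0]

print(peak_amp)
print("templates at or above min_amp:", strong.tolist())
```

Amplitude alone does not tell you whether a template has a clean shape, so it complements the visual check rather than replacing it.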