# Yaml Files for Training and Prediction with a Stratified Split

With this notebook, you can create the yaml files needed for training and prediction with a data split made using **stratification**. More detailed information about the stratified data split itself is available in the notebook [Introductions for Data Handling](1_introductions_data_handling.ipynb).

------

In [1]:
import os

# PARAMETERS TO CREATE STRATIFIED YAML FILES  
# ------------------------------------------

# From where to load the csv files of stratified split
csv_path = os.path.join('../data/split_csvs/', 'physionet_stratified_smoke')

# Where to save the training yaml files
train_yaml_save_path = os.path.join('../configs/training', 'train_stratified_smoke')

# Where to save the testing yaml files
test_yaml_save_path = os.path.join('../configs/predicting', 'prediction_stratified_smoke')

# Parameters for training yaml files
batch_size = 10
num_workers = 0
epochs = 1

The csv files of the stratified splits should be found in `data/split_csvs/`, in the subfolder given by `csv_path` in the first code cell.
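Before pairing the files, it can help to check that the folder only contains csv files with the expected names. The sketch below assumes the naming scheme seen in the listing of this notebook (`train_split<i>_<j>.csv`, `val_split<i>_<j>.csv`, `test_split<i>.csv`); the pattern and the function `unexpected_csv_names` are hypothetical helpers, not part of the repository.

```python
import re

# Hypothetical pattern inferred from the file listing in this notebook:
#   train_split<i>_<j>.csv, val_split<i>_<j>.csv, test_split<i>.csv
SPLIT_CSV_PATTERN = re.compile(r'(?:(train|val)_split(\d+)_(\d+)|(test)_split(\d+))\.csv')


def unexpected_csv_names(file_names):
    """Return the visible file names that do not follow the assumed pattern."""
    return [name for name in file_names
            if not name.startswith('.') and not SPLIT_CSV_PATTERN.fullmatch(name)]


print(unexpected_csv_names(['test_split0.csv', 'train_split0_0.csv', 'notes.txt', '.DS_Store']))
# ['notes.txt']
```

In the notebook, `unexpected_csv_names(os.listdir(csv_path))` returning an empty list would suggest the folder is ready for the pairing step.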

In [2]:
import os

# Stratified csv files
csv_files = sorted([file for file in os.listdir(csv_path) if not file.startswith('.')])

print(*csv_files, sep='\n')

test_split0.csv
train_split0_0.csv
train_split0_1.csv
train_split0_2.csv
train_split0_3.csv
val_split0_0.csv
val_split0_1.csv
val_split0_2.csv
val_split0_3.csv


Next, let's combine each training split with its matching validation and prediction (test) splits, e.g., `train_split0_0.csv`, `val_split0_0.csv` and `test_split0.csv`.
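The cell below pairs the training and validation files by zipping two sorted lists, which works as long as both lists sort the same way. As an alternative, here is a minimal sketch that matches the files by their parsed split numbers instead. It assumes the same file naming scheme as above; the function name `match_splits` is hypothetical.

```python
import re


def match_splits(file_names):
    """Group csv names into (train, val, test, yaml_name) tuples by split number.

    Assumes train_split<i>_<j>.csv, val_split<i>_<j>.csv and test_split<i>.csv.
    Files are matched by their parsed numbers, not by their position in a list.
    """
    trains, vals, tests = {}, {}, {}
    for name in file_names:
        m = re.fullmatch(r'(train|val)_split(\d+)_(\d+)\.csv', name)
        if m:
            target = trains if m.group(1) == 'train' else vals
            target[(m.group(2), m.group(3))] = name
            continue
        m = re.fullmatch(r'test_split(\d+)\.csv', name)
        if m:
            tests[m.group(1)] = name

    sets = []
    # Sort numerically so that e.g. split0_2 comes before split0_10
    for key in sorted(trains, key=lambda k: (int(k[0]), int(k[1]))):
        outer, inner = key
        if key in vals and outer in tests:
            yaml_name = 'split{}_{}.yaml'.format(outer, inner)
            sets.append((trains[key], vals[key], tests[outer], yaml_name))
    return sets


files = ['test_split0.csv', 'val_split0_1.csv', 'train_split0_0.csv',
         'train_split0_1.csv', 'val_split0_0.csv']
for row in match_splits(files):
    print(row)
```

Training files without a matching validation or test file are simply skipped here; raising an error instead may be preferable if a missing file should stop the notebook.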

In [3]:
import re

# First, divide the training and validation splits into their own lists
train_files = [file for file in csv_files if file.startswith('train')]
val_files = [file for file in csv_files if file.startswith('val')]

# Both lists are sorted the same way, so zipping pairs train_splitX_Y with val_splitX_Y
train_val_pair = list(zip(train_files, val_files))
print('First 5 training and validation pairs')
print(*train_val_pair[:5], sep='\n')
print()

# The pairs look right based on the print, so add the prediction (test) files
test_files = [file for file in csv_files if file.startswith('test')]

split_nums = []  # Names of the yaml files, e.g. 'split0_0.yaml'
train_val_test = []
for train_tmp, val_tmp in train_val_pair:

    # Get the split numbers of the training file, e.g. 'split0_0' and 'split0'
    # (raw strings avoid invalid escape warnings; \d+ also handles 10+ folds)
    split_num = re.search(r'_((\w*)_\d+)', train_tmp)
    split_nums.append(split_num.group(1) + '.yaml')

    train_split_num = split_num.group(2)
    for test_tmp in test_files:
        # Get the split number of the testing file, e.g. 'split0'
        test_split_num = re.search(r'_(\w*)', test_tmp).group(1)

        # If training, validation and prediction share the split number, combine them
        if train_split_num == test_split_num:
            train_val_test.append([train_tmp, val_tmp, test_tmp])

print('Training, validation and testing pairs')
print(*train_val_test, sep='\n')
print()

print('Total of {} training, validation and testing sets'.format(len(train_val_test)))

First 5 training and validation pairs
('train_split0_0.csv', 'val_split0_0.csv')
('train_split0_1.csv', 'val_split0_1.csv')
('train_split0_2.csv', 'val_split0_2.csv')
('train_split0_3.csv', 'val_split0_3.csv')

Training, validation and testing pairs
['train_split0_0.csv', 'val_split0_0.csv', 'test_split0.csv']
['train_split0_1.csv', 'val_split0_1.csv', 'test_split0.csv']
['train_split0_2.csv', 'val_split0_2.csv', 'test_split0.csv']
['train_split0_3.csv', 'val_split0_3.csv', 'test_split0.csv']

Total of 4 training, validation and testing sets


From the sets above, we are going to create the yaml files. The base of the training yaml file is as follows:

```
# INITIAL SETTINGS
train_file: train_split0_0.csv
val_file: val_split0_0.csv

# TRAINING SETTINGS
batch_size: 10
num_workers: 0

# SAVE, LOAD AND DISPLAY INFORMATION
epochs: 1

```

and the base of the prediction yaml file is as follows:

```
# INITIAL SETTINGS
test_file: test_split0.csv
model: split0_0.pth
```
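The two bases above can also be kept as plain template strings with named placeholders, which makes it clear which value goes where when the templates are filled in. This is a sketch only; `TRAIN_TEMPLATE` and `TEST_TEMPLATE` are hypothetical names, and the code in this notebook uses positional `{}` placeholders instead.

```python
import textwrap

# Hypothetical templates mirroring the two yaml bases above.
# textwrap.dedent removes the common indentation so the keys start at column 0.
TRAIN_TEMPLATE = textwrap.dedent('''\
    # INITIAL SETTINGS
    train_file: {train_file}
    val_file: {val_file}

    # TRAINING SETTINGS
    batch_size: {batch_size}
    num_workers: {num_workers}

    # SAVE, LOAD AND DISPLAY INFORMATION
    epochs: {epochs}
    ''')

TEST_TEMPLATE = textwrap.dedent('''\
    # INITIAL SETTINGS
    test_file: {test_file}
    model: {model}
    ''')

print(TEST_TEMPLATE.format(test_file='test_split0.csv', model='split0_0.pth'))
```

Named placeholders cost a few more characters, but a swapped argument order can no longer put, say, the validation file into `train_file`.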

*Feel free to set the training settings and the other parameters as you want in the very first code cell.*

<font color = red>**NOTE!**</font> (*Consider this only if you have already created the csv files of all the different stratified splits.*) You can also create only some of the yaml files. All the train-val-test sets are listed in the variable `train_val_test`, so it's easy to iterate over only a part of it. If you do so, remember to select the matching entries of `split_nums` too, since the yaml files are named after them.
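One way to pick a subset is to zip the two lists first and filter afterwards, so that each set stays paired with its own yaml name. The sketch below uses stand-in lists shaped like `train_val_test` and `split_nums`; the names `example_sets`, `example_names`, `wanted` and `subset` are hypothetical.

```python
# Stand-in data shaped like `train_val_test` and `split_nums` (hypothetical values)
example_sets = [
    ['train_split0_0.csv', 'val_split0_0.csv', 'test_split0.csv'],
    ['train_split0_1.csv', 'val_split0_1.csv', 'test_split0.csv'],
    ['train_split0_2.csv', 'val_split0_2.csv', 'test_split0.csv'],
]
example_names = ['split0_0.yaml', 'split0_1.yaml', 'split0_2.yaml']

# Hypothetical selection: only these yaml files will be created
wanted = {'split0_0.yaml', 'split0_2.yaml'}

# Zip first, then filter, so that each set stays paired with its own yaml name
subset = [(files, name) for files, name in zip(example_sets, example_names)
          if name in wanted]

for files, name in subset:
    print(name, '<-', files[0])
```

With the real variables, the same filter applied to `zip(train_val_test, split_nums)` would give a shorter `pair_and_split` for the next cells.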

In [4]:
# List of the train-val-test sets zipped with the yaml file names
# Feel free to manipulate the all-in list!

pair_and_split = list(zip(train_val_test, split_nums))

# NB! pair_and_split is iterated over in the for-loop of the next
#     code cell, so keep this variable name

In [5]:
import os
import textwrap
from ruamel.yaml import YAML


def save_yaml(yaml_str, yaml_path, split):
    ''' Save the given string as the yaml file `split` in the directory `yaml_path`.
    '''
    # Make the yaml directory, including any missing parent directories
    os.makedirs(yaml_path, exist_ok=True)

    # Parse first (the round-trip loader keeps the comments), then write the file
    yaml = YAML()
    code = yaml.load(yaml_str)
    with open(os.path.join(yaml_path, split), 'w') as yaml_file:
        yaml.dump(code, yaml_file)


def create_testing_yaml(test_csv, split):
    ''' Make a yaml file for prediction. The base of it is presented above.
    '''
    # The model is named after the training yaml file,
    # e.g. trained with `split0_0.yaml` -> saved as `split0_0.pth`
    model_name = os.path.splitext(split)[0] + '.pth'

    yaml_str = textwrap.dedent('''\
        # INITIAL SETTINGS
        test_file: {}
        model: {}
        ''').format(test_csv, model_name)
    save_yaml(yaml_str, test_yaml_save_path, split)


def create_training_yaml(train_csv, val_csv, split):
    ''' Make a yaml file for training. The base of it is presented above.
    '''
    yaml_str = textwrap.dedent('''\
        # INITIAL SETTINGS
        train_file: {}
        val_file: {}

        # TRAINING SETTINGS
        batch_size: {}
        num_workers: {}

        # SAVE, LOAD AND DISPLAY INFORMATION
        epochs: {}
        ''').format(train_csv, val_csv, batch_size, num_workers, epochs)
    save_yaml(yaml_str, train_yaml_save_path, split)


for (train_tmp, val_tmp, test_tmp), split in pair_and_split:

    print('Training, validation and testing set is')
    print(train_tmp.split('.')[0], '\t', val_tmp.split('.')[0], '\t', test_tmp.split('.')[0])
    print('Yaml file will be named as', split)
    print()

    # Training yaml file
    create_training_yaml(train_tmp, val_tmp, split)

    # Testing yaml file
    create_testing_yaml(test_tmp, split)

Training, validation and testing set is
train_split0_0 	 val_split0_0 	 test_split0
Yaml file will be named as split0_0.yaml

Training, validation and testing set is
train_split0_1 	 val_split0_1 	 test_split0
Yaml file will be named as split0_1.yaml

Training, validation and testing set is
train_split0_2 	 val_split0_2 	 test_split0
Yaml file will be named as split0_2.yaml

Training, validation and testing set is
train_split0_3 	 val_split0_3 	 test_split0
Yaml file will be named as split0_3.yaml



Now all the yaml files for training, validation and prediction are created! The training yaml files are located in `/configs/training/train_stratified_smoke/` and named `split0_0.yaml`, `split0_1.yaml`, `split0_2.yaml` and `split0_3.yaml`; the prediction yaml files are in `/configs/predicting/prediction_stratified_smoke/` with the same names.
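Since each training yaml file should have a prediction yaml file with the same name, a quick sanity check is to compare the two directories. The function `unmatched_yaml_files` below is a hypothetical helper; the demo uses temporary directories so that it runs on its own.

```python
import os
import tempfile


def unmatched_yaml_files(train_dir, test_dir):
    """Return the yaml file names found in only one of the two directories."""
    def yaml_names(directory):
        return {name for name in os.listdir(directory) if name.endswith('.yaml')}
    # Symmetric difference: names that are missing from either side
    return sorted(yaml_names(train_dir) ^ yaml_names(test_dir))


# Self-contained demo with temporary directories
with tempfile.TemporaryDirectory() as train_dir, tempfile.TemporaryDirectory() as test_dir:
    for name in ['split0_0.yaml', 'split0_1.yaml']:
        open(os.path.join(train_dir, name), 'w').close()
    open(os.path.join(test_dir, 'split0_0.yaml'), 'w').close()
    print(unmatched_yaml_files(train_dir, test_dir))  # ['split0_1.yaml']
```

In this notebook, `unmatched_yaml_files(train_yaml_save_path, test_yaml_save_path)` would be expected to return an empty list.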

<font color=red>**NOTE 1!**</font> It is extremely important that the model in the prediction yaml file has the same name as the yaml file it was trained with. For example, when a model is trained using `split0_0.yaml`, it will be saved as `split0_0.pth`. This convention makes using the repository much easier and simpler, so keep it in mind if you want to edit the code above.
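If you edit the code, one way to keep this convention in a single place is a small helper that derives the model name from the training yaml name. The function `model_name_for` is a hypothetical sketch; it uses `os.path.splitext`, which strips only the last extension, whereas `split('.')[0]` cuts at the first dot.

```python
import os


def model_name_for(yaml_name):
    """Return the checkpoint name that matches a training yaml file name.

    Follows the convention described above: split0_0.yaml -> split0_0.pth.
    """
    stem, _ = os.path.splitext(os.path.basename(yaml_name))
    return stem + '.pth'


print(model_name_for('split0_0.yaml'))  # split0_0.pth
print(model_name_for('configs/training/split0_3.yaml'))  # split0_3.pth
```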

<font color=red>**NOTE 2!**</font> If you are wondering why the csv values in the yaml files (`train_file`, `val_file` and `test_file`) are not in single quotation marks, that's ok. YAML reads unquoted values such as `train_split0_0.csv` as plain strings, so the scripts can load them even without the quotation marks!