# Binary Classification Tutorial
This notebook demonstrates how to train and test a binary classifier.  The binary classification problem is to determine if a signal is BPSK or QPSK using a simple CNN-based classifier.

## Load Packages
Load packages that will be used throughout this tutorial.

In [None]:
import sys
sys.path.append("..")
# General python packages
import os
import matplotlib.pyplot as plt

# metrics from scikit-learn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# pytorch packages
import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn

# custom tutorial packages
from rfml_ed_material.utils.data_utils import IQ_Dataset, IQ_data_gen, create_signal_jsons
from rfml_ed_material.models.cnn_model import CNN_RF
from rfml_ed_material.utils.train_utils import train_func, predict_func

## Data Generation Parameters
Define the data generation parameters for Py-waspgen.

In [None]:
num_seq = 5000        # number of sequences per signal type
seq_len = 256         # length of each sequence
bandwidth = 0.5       # bandwidth
cent_freq = 0.0       # center frequency
start = 0             # signal start time
duration = seq_len    # signal duration
snr = 10              # signal to noise ratio

Define the signal types in the form of a list - BPSK and QPSK

In [None]:
signal_list = [{"format": "psk", "order": 2, "label": "BPSK"},
               {"format": "psk", "order": 4, "label": "QPSK"}]
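
To make the `order` field concrete: an ideal M-PSK constellation places M points evenly around the unit circle, so order 2 corresponds to BPSK and order 4 to QPSK. The sketch below is for illustration only. `psk_constellation` is a hypothetical helper, not part of Py-waspgen, and a real generator may rotate the constellation (QPSK is often drawn with a 45 degree offset) and apply pulse shaping.

```python
import cmath
import math


def psk_constellation(order):
    """Return the `order` unit-magnitude points of an ideal PSK constellation."""
    return [cmath.exp(2j * math.pi * k / order) for k in range(order)]


bpsk = psk_constellation(2)  # two points, 180 degrees apart
qpsk = psk_constellation(4)  # four points, 90 degrees apart

for name, points in [("BPSK", bpsk), ("QPSK", qpsk)]:
    # Round for display so floating-point residue does not clutter the output.
    rounded = [complex(round(p.real, 3), round(p.imag, 3)) for p in points]
    print(name, rounded)
```

The classifier has to tell these two symbol sets apart from noisy IQ samples, which is why the two entries in `signal_list` differ only in `order` and `label`.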

Py-waspgen loads configuration information from json files.  Use the *create_signal_jsons* function to create the py-waspgen configuration files.

The following cell first checks whether the "configs" directory exists and, if it does not, creates it to store the json configuration files.

In [None]:
if not os.path.isdir('configs'):
    os.mkdir('configs')

create_signal_jsons('configs',
                    signal_list,
                    observation_duration=seq_len,
                    cent_freq=[cent_freq, cent_freq],
                    bandwidth=[bandwidth, bandwidth],
                    start=[start, start],
                    duration=[duration, duration],
                    snr=[snr, snr])

## Generate Data
Use py-waspgen to generate data for training, validating, and testing a binary classifier.

Create the list of configuration files.

In [None]:
signal_filenames = ['configs/BPSK.json', 'configs/QPSK.json']

Generate data using the *IQ_data_gen* wrapper function from *data_utils*.

In [None]:
data, labels, label_dict = IQ_data_gen(signal_filenames, num_seq, seq_len)

## Prepare Data for Training

PyTorch accesses data through a data set class.  See the PyTorch documentation for more details.

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

In [None]:
rf_dataset = IQ_Dataset(data, labels, label_dict)
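
A map-style PyTorch data set only needs to report its length and return one item per index. The sketch below shows that contract in plain Python. `ToyIQDataset` is a hypothetical stand-in, not the tutorial's *IQ_Dataset*, whose implementation is not shown here. It also assumes `label_dict` maps label names to integer class indices, which may differ from what *IQ_data_gen* returns.

```python
class ToyIQDataset:
    """Hypothetical map-style data set: one (sample, class index) pair per index."""

    def __init__(self, data, labels, label_dict):
        if len(data) != len(labels):
            raise ValueError("data and labels must have the same length")
        self.data = data
        self.labels = labels
        self.label_dict = label_dict  # assumed: label name -> integer class index

    def __len__(self):
        # Number of samples available to a data loader.
        return len(self.data)

    def __getitem__(self, idx):
        # Return the sample together with its integer class index.
        return self.data[idx], self.label_dict[self.labels[idx]]


toy = ToyIQDataset(
    data=[[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]],  # made-up values
    labels=["BPSK", "QPSK", "BPSK"],
    label_dict={"BPSK": 0, "QPSK": 1},
)
print(len(toy), toy[1])
```

A real implementation would typically subclass `torch.utils.data.Dataset` and return tensors rather than Python lists.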

The data set should be split into training, validation, and testing sets.  The training set is used to train the model.  The validation set is used to track the model's performance during training and monitor for overfitting.  The test set is used for final evaluation of the model.  

The cell below uses PyTorch's *random_split* function to divide the data set according to the given proportions.

In [None]:
splits = [0.8, 0.1, 0.1]   # proportion for train, validation, and test sets
rf_train, rf_val, rf_test = random_split(rf_dataset, splits)
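
As a rough check on what the split produces, the sketch below converts fractions into subset sizes. It assumes the data set holds `num_seq` sequences for each of the two signal types (10,000 in total), and it mirrors the documented behaviour of *random_split* with fractional lengths: take the floor of each share, then hand out any leftover samples one at a time. `split_sizes` is a hypothetical helper written for this illustration.

```python
import math


def split_sizes(n_samples, fractions):
    """Convert split fractions into integer subset sizes that sum to n_samples."""
    sizes = [math.floor(n_samples * frac) for frac in fractions]
    remainder = n_samples - sum(sizes)
    # Hand out leftover samples one at a time, starting with the first subset.
    for i in range(remainder):
        sizes[i % len(sizes)] += 1
    return sizes


n_total = 2 * 5000  # assumes num_seq sequences for each of the two signal types
print(split_sizes(n_total, [0.8, 0.1, 0.1]))
```

Under that assumption the split yields 8000 training, 1000 validation, and 1000 test sequences. Passing fractions instead of integer lengths requires a reasonably recent PyTorch release.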

PyTorch uses data loaders to batch the data.  The cell below sets the batch size parameter and creates data loaders for the training, validation, and test sets.

In [None]:
batch_size = 256           # batch size for dataloader

# create dataloader
train_dataloader = DataLoader(rf_train, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(rf_val, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(rf_test, batch_size=batch_size, shuffle=False)
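
The batch size determines how many optimizer steps make up one epoch. The sketch below works through that arithmetic, assuming subset sizes of 8000, 1000, and 1000 (an assumption that follows from 10,000 total sequences and the 0.8/0.1/0.1 split). `batch_layout` is a hypothetical helper. A `DataLoader` keeps the smaller final batch by default because `drop_last` defaults to `False`.

```python
import math


def batch_layout(n_samples, batch_size):
    """Return (number of batches, size of the final batch); assumes n_samples > 0."""
    n_batches = math.ceil(n_samples / batch_size)
    last_batch = n_samples - (n_batches - 1) * batch_size
    return n_batches, last_batch


batch_size = 256
for name, n in [("train", 8000), ("val", 1000), ("test", 1000)]:
    n_batches, last = batch_layout(n, batch_size)
    print(f"{name}: {n_batches} batches, final batch holds {last} samples")
```

With these sizes the training loader yields 32 batches per epoch, the last containing 64 samples.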

## Train PyTorch Model

Instantiate the PyTorch CNN model with one output per signal type.

In [None]:
model = CNN_RF(len(signal_list))

Set the learning rate for the training process.

In [None]:
learning_rate = 0.0001  # learning rate for optimizer

Define the loss function and the optimizer.

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
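
`nn.CrossEntropyLoss` takes raw model outputs (logits) and integer class indices, and by default averages the per-sample loss over the batch. The per-sample calculation can be sketched in plain Python as the negative log of the softmax probability assigned to the true class. `cross_entropy` below is a hypothetical helper for illustration, and the logits are made-up values.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy loss for one sample: -log(softmax(logits)[target])."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target]


confident_right = cross_entropy([2.0, 0.0], target=0)  # model favours the true class
confident_wrong = cross_entropy([2.0, 0.0], target=1)  # model favours the wrong class
undecided = cross_entropy([0.0, 0.0], target=0)        # equal scores for both classes

print(confident_right, undecided, confident_wrong)
```

The loss is small when the model favours the correct class, equals ln 2 when both classes score the same, and grows as the model becomes confidently wrong.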

Set the training parameters.

In [None]:
epochs = 200        # number of training epochs
print_every_n = 10  # print loss every n epochs

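Train the model with the custom *train_func*, which returns the trained model along with the training and validation losses.

In [None]: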
model, train_loss, val_loss = train_func(model,
                                         optimizer,
                                         loss_fn,
                                         train_dataloader,
                                         val_dataloader,
                                         epochs,
                                         print_every_n)

## Evaluate Training

Plot the training and validation loss.  This should be done after training or monitored during training.  The plot helps confirm that learning has stabilized (both curves flatten out) and that the model is not overfit.  An overfit model shows a validation loss that starts increasing while the training loss keeps decreasing.

In [None]:
fig, ax = plt.subplots()
ax.plot(train_loss, color='b', label="Train")
ax.plot(val_loss, color='r', label="Validation")
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
fig.legend()
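
The visual check can be complemented with a simple numeric heuristic: find the epoch with the lowest validation loss and see how long ago it occurred. The sketch below assumes the validation loss is a list with one value per epoch, which may not match what *train_func* returns. `overfit_report` and its `patience` parameter are hypothetical, and the loss values are made up for illustration.

```python
def overfit_report(val_loss, patience=5):
    """Return (best_epoch, flag); flag is True if the minimum is >= patience epochs old."""
    best_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i])
    epochs_since_best = len(val_loss) - 1 - best_epoch
    return best_epoch, epochs_since_best >= patience


# Made-up validation losses: improvement for five epochs, then a steady rise.
val_history = [0.69, 0.52, 0.41, 0.36, 0.35, 0.37, 0.40, 0.44, 0.49, 0.55]
best_epoch, looks_overfit = overfit_report(val_history, patience=5)
print(best_epoch, looks_overfit)
```

A flag raised this way is only a hint. Noisy validation curves can trigger it, so it is best read alongside the plot.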

### Evaluation Metrics

Use the custom *predict_func* to extract the true targets from the test data loader and the predicted values for the test set from the learned model.

In [None]:
y_test, y_pred = predict_func(model, test_dataloader)

When evaluating binary classifiers, there are four possible combinations of true and predicted values for each observation.
These are displayed in the table below.

<img src="resources/binary_classification_results.png" width="600" align="left">

During evaluation, one can count the number of observations that fall in each quadrant of the table.  Numerous performance metrics can be calculated from these four counts: true positives (TP), false positives (FP), true negatives (TN), and false negatives (FN).
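
The counting step can be sketched in a few lines of plain Python. `confusion_counts` is a hypothetical helper, and the label lists are made up. It treats class index 1 as Positive, which matches the default `pos_label=1` used by the scikit-learn binary metrics. Which signal type that corresponds to depends on the contents of `label_dict`.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Count (TP, FP, TN, FN) for a binary problem."""
    tp = fp = tn = fn = 0
    for truth, pred in zip(y_true, y_pred):
        if pred == positive:
            if truth == positive:
                tp += 1  # predicted Positive, actually Positive
            else:
                fp += 1  # predicted Positive, actually Negative
        else:
            if truth == positive:
                fn += 1  # predicted Negative, actually Positive
            else:
                tn += 1  # predicted Negative, actually Negative
    return tp, fp, tn, fn


y_true_toy = [1, 1, 1, 0, 0, 0, 1, 0]  # made-up ground truth
y_pred_toy = [1, 0, 1, 0, 1, 0, 1, 0]  # made-up predictions
print(confusion_counts(y_true_toy, y_pred_toy))
```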

Accuracy is the sum of the diagonal divided by the sum in all four quadrants.

$$ Accuracy = \frac{TP + TN}{TP + FP + TN + FN} $$

In [None]:
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy: ', accuracy)

Precision measures the rate at which observations classified by the model as Positive are correct.  It is calculated by dividing the number of true positives by the sum of the true positives and false positives (all examples classified as Positive by the model).

$$ Precision = \frac{TP}{TP + FP} $$

In [None]:
precision = precision_score(y_test, y_pred)
print('Precision: ', precision)

Recall measures the rate at which relevant (Positive) observations are classified by the model as Positive.  It is calculated by dividing the number of true positives by the sum of the true positives and false negatives (all Positive examples in the test set).

$$ Recall = \frac{TP}{TP + FN} $$

In [None]:
recall = recall_score(y_test, y_pred)
print('Recall: ', recall)

The F1 score is the harmonic mean of precision and recall, so it is high only when both are high.

$$ F1 = 2 \cdot \frac{Precision \cdot Recall}{Precision + Recall} $$

In [None]:
f1 = f1_score(y_test, y_pred)
print('F1: ', f1)
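
To see how the four metrics relate, the worked example below computes each one directly from a set of confusion counts. The counts are made up for illustration and are not results from this model. Accuracy is taken as the correct predictions (TP + TN) over all predictions.

```python
# Made-up confusion counts for a 100-sample test set.
tp, fp, tn, fn = 45, 5, 40, 10

accuracy = (tp + tn) / (tp + fp + tn + fn)  # correct predictions over all predictions
precision = tp / (tp + fp)                  # how many predicted Positives are right
recall = tp / (tp + fn)                     # how many actual Positives are found
f1 = 2 * precision * recall / (precision + recall)

# Equivalent form of F1 written directly in terms of the counts.
f1_from_counts = 2 * tp / (2 * tp + fp + fn)

print(f"accuracy={accuracy:.3f} precision={precision:.3f} "
      f"recall={recall:.3f} f1={f1:.3f}")
```

Here precision (0.900) is higher than recall (about 0.818) because the made-up model misses more Positives than it raises false alarms, and F1 (about 0.857) falls between the two.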