**Date created:** 2021-05-19 <br>
**Last modified:** 2021-05-19 <br>
**Description:** Main function to train networks using the principles described in the paper <br>

In [None]:
import time
import pickle

import numpy as np
import matplotlib.pyplot as plt

from dataclasses import dataclass
from random import Random
from copy import copy

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import confusion_matrix, f1_score, recall_score, precision_score
from sklearn.svm import SVC

In [None]:
@dataclass
class Options:
    """Run-wide configuration shared by the training utilities in this file.

    The first group of fields drives training, logging and storage; the
    second group is specific to the ``ensemble_prep`` library.
    """
    epochs: int = 100 # number of epochs
    batch_size: int = 128 # batch size
    epoch_limit: int = 25 # number of allowed consecutive bad epochs before aborting
    clip: float = 0.25 # gradient clipping limit
    seed: int = 99 # seed for pseudo-random utility
    ae_lr: float = 1e-2 # learning rate for autoencoder
    lr: float = 1e-4 # learning rate for classifiers
    sparse_func: str = 'none' # sparsity function: L1 | KL | none
    reg_param: float = 2e-4 # controlled amount of sparsity
    use_cuda: bool = True # user preference
    device: torch.device = torch.device("cuda") # default device
    log_interval: int = 100 # how often progress should be printed
    n_base_learners: int = 1 # how many base learners should be trained
    print_mode_on: bool = True # if progress should be printed
    store_data: bool = False # if models and plots should be stored
    # Trailing slash matters: file names are appended directly to these prefixes.
    store_mdl_location: str = '/stored_models/' # where saved models should be stored
    store_fig_location: str = '/stored_models/figures/' # where saved figures should be stored
    
    # Below follows options specific for the ensemble_prep library
    
    directory: str = './data/' # directory containing observations
    highway_directory: str = 'final/' # specific directory containing relevant data
    use_oversample: bool = False # if minority classes should be oversampled
    downsample: bool = False # removes already oversampled minority class instances
    scl_strategy: str = 'std' # input data scaling strategy. std or norm
    keep: float = 1.0 # proportion of complete set to be kept
    prediction_horizon: float = 3.5 # determines the cutoff of LC labels in the set (max is 5 s)
    re_express_lanes: bool = True # convert neighboring lanes to 1/0 to indicate lane existence
    split: float = 0.75 # percentage of how much data should go into training vs. testing

# Module-level configuration instance read by the functions below.
options = Options()

In [None]:
# Resolve the compute device: CUDA only when it is available AND requested.
if torch.cuda.is_available() and options.use_cuda:
    options.device = torch.device("cuda")
else:
    options.device = torch.device("cpu")
print(options.device)

### Plot function

In [None]:
# Function to plot the training history
def _plot_loss_curves(ax, loss_history):
    """Draw the train/test loss curves from ``loss_history`` on axis ``ax``."""
    ax.plot(loss_history['train'], 'tab:orange', lw=3)
    ax.plot(loss_history['test'], 'tab:blue', lw=3)
    ax.set_xlabel("Epoch number")
    ax.set_ylabel("Loss")
    ax.set_title("Training loss vs Test loss")
    ax.legend(['Training', 'Test'])
    ax.grid(True)


def plot_history(loss_history, acc_history, save_fig=False, fig_name="Foo"):
    """Plot per-epoch loss (and accuracy, when recorded) for train and test.

    Args:
        loss_history: dict with 'train' and 'test' lists of per-epoch losses.
        acc_history: dict with 'train' and 'test' lists of per-epoch
            accuracies; entries are ``None`` when no accuracy was recorded,
            in which case only the loss panel is drawn.
        save_fig: when True the figure is written to ``fig_name`` as PDF,
            otherwise it is shown interactively.
        fig_name: output path used when ``save_fig`` is True.

    Returns:
        None. When no epoch has been recorded, nothing is drawn.
    """
    # Nothing recorded (e.g. training was interrupted during the first
    # epoch): bail out instead of failing on the [0] lookup below.
    if not acc_history['train'] and not loss_history['train']:
        print("No training history to plot.")
        return

    train_acc = acc_history['train']
    has_accuracy = bool(train_acc) and train_acc[0] is not None

    if has_accuracy:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        ax1.plot(acc_history['train'], 'tab:orange', lw=3, label='Train')
        ax1.plot(acc_history['test'], 'tab:blue', lw=3, label='Test')
        ax1.set_xlabel("Epoch number")
        ax1.set_ylabel("Accuracy")
        ax1.set_title(f"Training Accuracy vs Test Accuracy. Max test: {max(acc_history['test'])}")
        ax1.legend()
        ax1.grid(True)
        _plot_loss_curves(ax2, loss_history)
    else:
        fig, ax = plt.subplots(figsize=(16, 6))
        _plot_loss_curves(ax, loss_history)

    if save_fig:
        plt.savefig(fig_name, format='pdf')
    else:
        plt.show()

## Define dataloader

In [None]:
class IntentionPredictionDataset(Dataset):
    """Map-style dataset over a sequence of indexable records.

    Each record is read as ``record[0]`` (trajectory), ``record[1]`` (target,
    a NumPy array) and ``record[2]`` (``ttlc`` — presumably time to lane
    change; confirm against ensemble_prep). All three come back as float32
    tensors.
    """

    def __init__(self, data):
        # Descriptive attributes; nothing inside this class reads them.
        self.seq_len, self.n_features = 20, 40
        self.data = data

    def __len__(self):
        """Number of records held."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(trajectory, target, ttlc)`` tensors for record ``idx``."""
        record = self.data[idx]

        trajectory = torch.Tensor(record[0])
        target = torch.from_numpy(record[1]).float()
        # The third field goes through np.array first so from_numpy accepts it.
        ttlc = torch.from_numpy(np.array(record[2])).float()

        return trajectory, target, ttlc

## Training functions

In [None]:
def train(model, data_loader, optimizer, epoch, criterion, objective='regression', pre_process=None):
    """Train ``model`` for one pass over ``data_loader``.

    Each batch is moved to ``options.device``. When ``pre_process`` is given
    the model is called as ``model(pre_process(data), data)``, otherwise as
    ``model(data)``.

    Args:
        model: network to optimise (put in training mode here).
        data_loader: iterable yielding ``(data, target, _)`` batches.
        optimizer: optimiser stepped once per batch.
        epoch: epoch number, used only for progress output.
        criterion: loss callable.
        objective: ``'regression'`` compares the output against
            ``data[:, :, 4:]`` and adds ``options.reg_param`` times the
            sparsity penalty selected by ``options.sparse_func``. Any other
            value compares the output against ``target``. Accuracy and recall
            are tracked only for ``'classification'``.
        pre_process: optional callable applied to the batch before the model.

    Returns:
        ``(mean_loss_per_batch, accuracy)``, where ``accuracy`` is a
        percentage for classification and ``None`` otherwise.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    running_loss = batch_count = 0
    accuracy = None
    recall = np.zeros(3)
    if objective == 'classification':
        accuracy = 0.
        total = correct = 0
    
    def print_progress(epoch, current_loss, iteration, ms_per_batch, accuracy=None, recall=()):
        def rd(value, order=2):
            return round(value, order)
        
        base_msg = f'Training set ->> | Epoch: {epoch} | Iter: {iteration} |' \
                   f' ms/batch: {rd(ms_per_batch)} | Loss: {rd(current_loss)}|'
        
        if accuracy is None:
            print(base_msg)
        else:
            cf_msg = f' Acc: {rd(accuracy)} |' \
                     f' Recall: [{rd(recall[0])} :: {rd(recall[1])} :: {rd(recall[2])}]|'
            print(base_msg + cf_msg)
            
    start_time = time.time()
    model.train()
    
    for batch_idx, (data, target, _) in enumerate(data_loader):
        data, target = data.to(options.device), target.to(options.device)
        optimizer.zero_grad()
        if pre_process is not None:
            processed_data = pre_process(data)
            output = model(processed_data, data)
        else:
            processed_data = None
            output = model(data)
        
        if objective == 'regression':
            # Reconstruction target is the input with its first four features dropped.
            loss = criterion(output, data[:, :, 4:])

            if options.sparse_func == 'L1':
                sp_loss = l1_loss(model, data)
            elif options.sparse_func == 'KL':
                sp_loss = kl_loss(model, data)
            else:
                sp_loss = 0
            loss += options.reg_param * sp_loss
        else:
            loss = criterion(output, target)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), options.clip)
        optimizer.step()
        
        if objective == 'classification':
            _, y_pred = torch.max(output.data, 1)
            y_true = torch.max(target, 1)[1]
            total += target.size(0)
            correct += (y_pred == y_true).sum().item()
            accuracy = 100. * correct / total
            # Per-class recall is summed per batch and averaged when logged.
            for k in range(len(recall)):
                recall[k] += recall_score(y_true.cpu(), y_pred.cpu(), labels=[k], average='micro', zero_division=0)

        running_loss += loss.item()
        batch_count += 1
        
        del data, processed_data, target
        
        if batch_idx % options.log_interval == 0 and batch_idx > 0:
            current_loss = running_loss / batch_count
            iteration = batch_idx + 1 + epoch * len(data_loader)
            ms_per_batch = (time.time() - start_time)*1000. / options.log_interval
            
            if options.print_mode_on:
                print_progress(epoch, current_loss, iteration, ms_per_batch, accuracy, recall/batch_count)
            
            start_time = time.time()

    # An empty loader (e.g. drop_last=True with fewer samples than one batch)
    # would otherwise surface as a NameError on the loop variable.
    if batch_count == 0:
        raise ValueError("train(): data_loader yielded no batches")

    return running_loss / batch_count, accuracy

In [None]:
def test(model, data_loader, criterion, objective='regression', pre_process=None):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Each batch is moved to ``options.device``. When ``pre_process`` is given
    the model is called as ``model(pre_process(data), data)``, otherwise as
    ``model(data)``.

    Args:
        model: network to evaluate (put in eval mode here).
        data_loader: iterable yielding ``(data, target, _)`` batches.
        criterion: loss callable.
        objective: ``'classification'`` compares the output against ``target``
            and tracks accuracy/recall. Any other value compares the output
            against ``data[:, :, 4:]``; no sparsity penalty is added here.
        pre_process: optional callable applied to the batch before the model.

    Returns:
        ``(mean_loss_per_batch, accuracy)``, where ``accuracy`` is a
        percentage for classification and ``None`` otherwise.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    running_loss = batch_count = 0
    accuracy = None
    recall = np.zeros(3)
    if objective == 'classification':
        accuracy = 0.
        total = correct = 0
    
    def print_progress(test_loss, accuracy=None, recall=()):
        def rd(value, order=2):
            return round(value, order)
        
        base_msg = f'Test set ->> | loss: {rd(test_loss)} |'
        if accuracy is None:
            print(base_msg)
        else:
            cf_msg = f' Acc: {rd(accuracy)} |' \
                     f' Recall: [{rd(recall[0])} :: {rd(recall[1])} :: {rd(recall[2])}]|'
            print(base_msg + cf_msg)

    model.eval()
    with torch.no_grad():
        for data, target, _ in data_loader:
            data, target = data.to(options.device), target.to(options.device)
            if pre_process is not None:
                processed_data = pre_process(data)
                output = model(processed_data, data)
            else:
                processed_data = None
                output = model(data)

            if objective == 'classification':
                loss = criterion(output, target)
                
                _, y_pred = torch.max(output.data, 1)
                y_true = torch.max(target, 1)[1]
                total += target.size(0)
                correct += (y_pred == y_true).sum().item()
                accuracy = 100. * correct / total
                # Per-class recall is summed per batch and averaged when printed.
                for k in range(len(recall)):
                    recall[k] += recall_score(y_true.cpu(), y_pred.cpu(), labels=[k], average='micro', zero_division=0)
            else:
                # Reconstruction target is the input with its first four features dropped.
                loss = criterion(output, data[:, :, 4:])

            running_loss += loss.item()
            batch_count += 1

            del data, processed_data, target

    # Fail with a clear message instead of a bare ZeroDivisionError.
    if batch_count == 0:
        raise ValueError("test(): data_loader yielded no batches")

    test_loss = running_loss / batch_count
    if options.print_mode_on:
        print_progress(test_loss, accuracy, recall/batch_count)
    return test_loss, accuracy

In [None]:
def train_svm(pre_process, train_loader, name, cat_feature_encoder=None):
    """Fit a linear SVM on encoded features from the first batch of ``train_loader``.

    Only the first batch is consumed, so the caller is expected to pass a
    loader whose batch covers the data it wants fitted.

    Args:
        pre_process: callable mapping a batch on ``options.device`` to a 2-D
            feature tensor (assumed 2-D, since it goes straight to ``SVC.fit``).
        train_loader: iterable yielding ``(data, target, _)`` batches; class
            labels are taken as the argmax of ``target`` along dim 1.
        name: suffix used in the stored model's file name.
        cat_feature_encoder: optional callable returning a tuple whose first
            element is concatenated to the features along dim 1.

    Returns:
        None. The fitted classifier is pickled under
        ``options.store_mdl_location`` when ``options.store_data`` is set.
    """
    import os  # local import: only needed to build the output path

    data, target, _ = next(iter(train_loader))
    data = data.to(options.device)

    with torch.no_grad():
        if cat_feature_encoder is not None:
            categorical, _ = cat_feature_encoder(data)
        else:
            categorical = None
        X = pre_process(data)
    y_train = torch.max(target, 1)[1].cpu().detach().numpy()

    if categorical is not None:
        X = torch.cat((X, categorical), dim=1)

    X_train = X.cpu().detach().numpy()

    svc = SVC(probability=True, kernel='linear')
    svc.fit(X_train, y_train)

    if options.store_data:
        svc_name = os.path.join(options.store_mdl_location, "SVMClassifier-" + name + ".sav")
        # Context manager guarantees the file is flushed and closed.
        with open(svc_name, 'wb') as fh:
            pickle.dump(svc, fh)

    if options.print_mode_on:
        print('Sci-kit done.')

In [None]:
def l1_loss(autoencoder, values):
    """L1 sparsity penalty on the encoder's activations for ``values``.

    ``autoencoder`` must have exactly two child modules, the first being the
    encoder. The encoder's first child must return ``(output, (hidden, cell))``
    and its second child is applied to the last slice of ``hidden``. The
    penalty is the mean over dim 0 of the per-row sum of absolute ReLU
    activations.
    """
    encoder_module, _ = autoencoder.children()
    layers = list(encoder_module.children())

    # Keep only the final hidden state of the recurrent layer.
    _, (h_n, _) = layers[0](values)
    last_hidden = h_n[-1, :, :]

    activations = F.relu(layers[1](last_hidden))
    per_row = torch.sum(torch.abs(activations), dim=1)
    return per_row.mean()

In [None]:
def kl_loss(autoencoder, values):
    """KL-divergence sparsity penalty on the encoder's activations.

    ``autoencoder`` must have exactly two child modules, the first being the
    encoder. The encoder's first child must return ``(output, (hidden, cell))``
    and its second child is applied to the last slice of ``hidden``. The ReLU
    activations are scored by ``kl_divergence`` against a fixed target
    sparsity of 0.05.
    """
    target_sparsity = 5e-2

    encoder_module, _ = autoencoder.children()
    layers = list(encoder_module.children())

    # Keep only the final hidden state of the recurrent layer.
    _, (h_n, _) = layers[0](values)
    activations = F.relu(layers[1](h_n[-1, :, :]))

    return 0 + kl_divergence(target_sparsity, activations)

In [None]:
def kl_divergence(rho, rho_hat):
    """Summed Bernoulli KL divergence between a target rate and observed rates.

    Args:
        rho: scalar target activation rate, strictly between 0 and 1.
        rho_hat: 2-D tensor of raw activations. It is squashed with a sigmoid
            and averaged along dim 1, giving one observed rate per row.

    Returns:
        Scalar tensor: sum over rows of ``KL(rho || rho_hat_row)``.
    """
    rho_hat = torch.mean(torch.sigmoid(rho_hat), 1) # sigmoid because we need the probability distributions
    # Build the target on the same device/dtype as the activations rather
    # than relying on a global device name.
    rho = torch.full_like(rho_hat, rho)
    return torch.sum(rho * torch.log(rho/rho_hat) + (1 - rho) * torch.log((1 - rho)/(1 - rho_hat)))

In [None]:
def main(model, name, criterion, train_loader, test_loader, objective='regression',
         learning_rate=options.lr, pre_process=None, schedule=False):
    """Train ``model`` with early stopping, then plot the training history.

    Runs up to ``options.epochs`` epochs of ``train`` followed by ``test``
    with an AdamW optimiser. Training stops early once the bad-epoch counter
    reaches ``options.epoch_limit``, or on ``KeyboardInterrupt``.

    Args:
        model: network to train.
        name: base name for the stored weights (``.pth``) and figure (``.pdf``).
        criterion: loss callable forwarded to ``train`` and ``test``.
        train_loader: loader for the training split.
        test_loader: loader for the held-out split.
        objective: ``'regression'`` tracks the best test loss;
            ``'classification'`` tracks the best test accuracy.
        learning_rate: optimiser learning rate (default is read from
            ``options.lr`` when this function is defined).
        pre_process: optional callable forwarded to ``train`` and ``test``.
        schedule: when True, halve the learning rate every 10 epochs (StepLR).

    Returns:
        The same ``model`` object in its state after the last executed epoch.
        The best weights are only written to disk, and only when
        ``options.store_data`` is set.
    """
    import os  # local import: used only to build the output paths below

    n_bad_epochs = best_epoch = best_acc = 0
    
    best_loss = float('inf')
    loss_history = {'train':[], 'test':[]}
    acc_history = {'train':[], 'test':[]}
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
    if schedule:
        scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
    
    # os.path.join inserts the separator whether or not the configured
    # directory ends with a slash.
    model_name = os.path.join(options.store_mdl_location, name + ".pth")
    figure_name = os.path.join(options.store_fig_location, name + ".pdf")
    
    try:
        for epoch in range(1, options.epochs+1):
            train_loss, train_acc = train(model, train_loader, optimizer, epoch, criterion,
                                          objective=objective, pre_process=pre_process)
            test_loss, test_acc = test(model, test_loader, criterion,
                                       objective=objective, pre_process=pre_process)
            
            # Store intermediate results
            acc_history['train'].append(train_acc)
            acc_history['test'].append(test_acc)
            
            loss_history['train'].append(train_loss)
            loss_history['test'].append(test_loss)
            
            # NOTE(review): an improved test loss resets the bad-epoch counter
            # for every objective, including classification — confirm intended.
            if test_loss < best_loss:
                n_bad_epochs = 0
                best_loss = test_loss
                if objective == 'regression':
                    best_epoch = epoch
                    if options.store_data:
                        torch.save(model.state_dict(), model_name)
            else:
                if objective == 'regression':
                    n_bad_epochs += 1
            
            if objective == 'classification':
                if test_acc > best_acc:
                    n_bad_epochs = 0
                    best_acc = test_acc
                    best_epoch = epoch
                    if options.store_data:
                        torch.save(model.state_dict(), model_name)  
                else:
                    n_bad_epochs += 1
                
            if n_bad_epochs >= options.epoch_limit:
                print(f'Number of consecutive bad epochs exceeded ({options.epoch_limit}). Employing early stopping...')
                break
                
            if epoch % 10 == 0:
                if objective == 'classification':
                    print('\n Historically best test accuracy: \x1b[31m{:5.2f}% \x1b[0m on epoch: {}\n'.format(best_acc, best_epoch))
                else:
                    # Loss is not a percentage; show enough decimals for small values.
                    print('\n Historically best test loss: \x1b[31m{:.4f} \x1b[0m on epoch: {}\n'.format(best_loss, best_epoch))
            
            if schedule:
                scheduler.step()
    except KeyboardInterrupt:
        print("Training interrupted early...")
    finally:
        del criterion, optimizer
        torch.cuda.empty_cache()
    
    print('Finished Training! \n')
    
    plot_history(loss_history, acc_history, save_fig=options.store_data, fig_name=figure_name)
    return model

## Import Data

In [None]:
from ensemble_prep import *

In [None]:
# Build the data-preparation object from the shared options
# (EnsemblePrep presumably comes from the ensemble_prep star import — confirm).
in_data = EnsemblePrep(options)

In [None]:
# Run the preparation object's main() entry point before any training;
# what it does is defined in ensemble_prep, not in this file.
in_data.main()

# Train

In [None]:
def ensemble_training(network_class, input_data, objective='classification'):
    """Train ``options.n_base_learners`` fresh instances of ``network_class``.

    Seeds torch from ``input_data.seed`` and calls
    ``input_data.init_scrambler()`` once. Each learner then draws its own
    train/validation split via ``input_data.get_train_val()``, is trained
    through ``main`` with a BCE loss, and is discarded.
    """
    torch.manual_seed(input_data.seed)
    input_data.init_scrambler()
    print(f'Training network with {options.n_base_learners} ensemble(s)')

    shared_loader_args = dict(shuffle=True, drop_last=True)

    for learner_idx in range(1, options.n_base_learners + 1):
        train_data, test_data = input_data.get_train_val()
        model = network_class().to(options.device)
        name = f"{type(model).__name__}-{learner_idx}"

        train_loader = DataLoader(IntentionPredictionDataset(train_data),
                                  batch_size=options.batch_size, **shared_loader_args)
        # The whole validation split is served as one batch.
        test_loader = DataLoader(IntentionPredictionDataset(test_data),
                                 batch_size=len(test_data), **shared_loader_args)

        model = main(model, name, nn.BCELoss(), train_loader, test_loader, objective)

        # Drop the trained model before the next learner is built.
        del model
        torch.cuda.empty_cache()

In [None]:
# Execute the VanillaPredictor notebook
# (presumably it defines VanillaCNN, used in the next cell — confirm).
%run ./mdl-implementation/VanillaPredictor.ipynb

In [None]:
# Train the VanillaCNN learners on the prepared data (objective defaults to 'classification').
ensemble_training(VanillaCNN, in_data)

### AutoEncoder

In [None]:
# Execute the autoencoder and classifier notebooks
# (presumably they define AutoEncoderDecoder and AEClassifier, referenced below — confirm).
%run ./mdl-implementation/AutoEncoder.ipynb
%run ./mdl-implementation/AEClassifier.ipynb

In [None]:
def ae_ensemble_training(auto_encoder, ae_classifier, input_data, SVM=False, use_cat=False):
    """Train autoencoder + classifier base learners, optionally with an SVM.

    For each of ``options.n_base_learners`` learners:
      1. an ``auto_encoder`` instance is trained as a regression task with
         SmoothL1 loss and ``options.ae_lr``;
      2. its parameters are frozen;
      3. an ``ae_classifier`` instance is trained on the encoder output with
         BCE loss and a learning-rate schedule;
      4. when ``SVM`` is True, ``train_svm`` fits an SVM on the encoder
         output of a single batch spanning the training split.

    Args:
        auto_encoder: class built as ``auto_encoder(embedding_size=...)``;
            instances must expose an ``encoder`` attribute.
        ae_classifier: class built as ``ae_classifier(n_inputs=...)``.
        input_data: object providing ``seed``, ``init_scrambler()`` and
            ``get_train_val()``.
        SVM: whether to also fit the SVM.
        use_cat: whether to pass a ``CategoricalFeatureEncoder`` to the SVM
            step (only has an effect when ``SVM`` is True).
    """
    torch.manual_seed(input_data.seed)
    input_data.init_scrambler()
    embedding_size = 128
    
    if use_cat:
        # Use the configured device, like every other module in this file.
        feat_encode = CategoricalFeatureEncoder().to(options.device)
        feat_encode.eval()
    
    print(f'Training {options.n_base_learners} base learner(s)')
    for n in range(1, options.n_base_learners + 1):
        train_data, test_data = input_data.get_train_val()
        ae_model = auto_encoder(embedding_size=embedding_size).to(options.device)
        name = type(ae_model).__name__ + "-" + str(n) + "-" + str(embedding_size)
        
        train_loader = DataLoader(IntentionPredictionDataset(train_data), batch_size=options.batch_size,
                                  shuffle=True,  drop_last=True)
        test_loader = DataLoader(IntentionPredictionDataset(test_data), batch_size=len(test_data),
                                  shuffle=True,  drop_last=True)
        
        criterion = nn.SmoothL1Loss()
        ae_model = main(ae_model, name, criterion, train_loader, test_loader,
                        learning_rate=options.ae_lr, objective='regression')
        
        # Freeze the autoencoder so only the classifier is updated from here on.
        for p in ae_model.parameters():
            p.requires_grad = False
        
        # Build the classifier from the class passed in rather than a hard-coded one.
        cf_model = ae_classifier(n_inputs=embedding_size).to(options.device)
        name = type(cf_model).__name__ + "-" + str(n) + "-" + str(embedding_size)
        
        criterion = nn.BCELoss()
        cf_model = main(cf_model, name, criterion, train_loader, test_loader,
                        objective='classification', pre_process=ae_model.encoder, schedule=True)
        
        if SVM:
            # train_svm reads only the first batch, so serve the split as one batch.
            train_loader = DataLoader(IntentionPredictionDataset(train_data), 
                                      batch_size=len(train_data), shuffle=True,  drop_last=True)
            add_name = str(n) + "-" + str(embedding_size)
            if use_cat:
                train_svm(ae_model.encoder, train_loader, add_name, feat_encode)
            else:
                train_svm(ae_model.encoder, train_loader, add_name)
                
        
        del ae_model, cf_model
        torch.cuda.empty_cache()

In [None]:
#ae_ensemble_training(AutoEncoderDecoder, AEClassifier, in_data)