In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics
from torchvision import transforms as T
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

from plotting import cfm, loss_acc_plot, plot_pr_roc_curve

from models import select_model
from dataloader import get_dataset
from focalloss import Focal_Loss
from config import Config


In [None]:
class Train_Model():
    """Train, validate and evaluate a classification network.

    The network is built with ``select_model(model, config.pretrained)`` and
    moved to ``cuda:0`` when CUDA is available, otherwise to the CPU.

    Args:
        config: Settings object. Attributes read here: ``pretrained``,
            ``epochs``, ``lossfunction``, ``fl_gamma``, ``optimizer``, ``lr``,
            ``momentum``, ``weight_decay``, ``batch_size``, ``train_paitence``,
            ``class_mapping``, ``model_dir`` and ``results_dir``. Optionally
            ``class_weights`` (per-class weights for the cross-entropy loss);
            when absent, ``DEFAULT_CLASS_WEIGHTS`` is used.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        model: Model identifier forwarded to ``select_model``.
    """

    # Fallback cross-entropy class weights. There are 7 entries, so this
    # presumably matches a 7-class ``class_mapping`` — TODO confirm.
    # Provide ``config.class_weights`` for other datasets.
    DEFAULT_CLASS_WEIGHTS = (1.061, 3.322, 1.239, 0.797, 1.542, 0.688, 0.628)

    def __init__(self, config, train_loader, val_loader, model):
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model

        self.net = select_model(model, config.pretrained)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.net = self.net.to(self.device)
        self.epochs = self.config.epochs

        weights = getattr(config, "class_weights", None)
        if weights is None:
            weights = self.DEFAULT_CLASS_WEIGHTS
        self.class_weights = torch.tensor(weights, dtype=torch.float).to(self.device)

        # Initialize the loss function and optimizer.
        if self.config.lossfunction == 'FocalLoss':
            self.criterion = Focal_Loss(self.config.fl_gamma)
        else:
            self.criterion = nn.CrossEntropyLoss(weight=self.class_weights)
        if self.config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.net.parameters(),
                                        lr=self.config.lr,
                                        weight_decay=self.config.weight_decay)
        else:
            self.optimizer = optim.SGD(self.net.parameters(),
                                       lr=self.config.lr,
                                       momentum=self.config.momentum,
                                       weight_decay=self.config.weight_decay)

        # Halve the LR after 10 epochs without validation-loss improvement.
        self.scheduler = ReduceLROnPlateau(optimizer=self.optimizer, mode='min',
                                           patience=10, min_lr=1e-5, factor=0.5)
        self.lr = self.config.lr
        self.batch_size = self.config.batch_size
        self.train_paitence = config.train_paitence
        self.num_classes = len(self.config.class_mapping)

    def _run_epoch(self, loader, optimize):
        """Make one pass over ``loader`` and return epoch-level statistics.

        Each batch's loss is weighted by its batch size, so the result is a
        per-sample mean over the whole epoch rather than the last batch's
        value. Accuracy is top-1 (argmax over dim 1 of the predictions).

        Args:
            loader: Iterable of ``(inputs, targets)`` batches. ``targets`` are
                assumed to be 1-D class indices — TODO confirm.
            optimize: When True, backpropagate and step ``self.optimizer``
                after every batch.

        Returns:
            ``(mean_loss, accuracy_percent)`` as 0-dim float tensors, so
            callers can use ``.item()`` / ``.detach()``.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        total_loss = 0.0
        correct = 0
        seen = 0
        for file_data, target in loader:
            file_data, target = file_data.to(self.device), target.to(self.device)

            prediction = self.net(file_data)
            loss = self.criterion(prediction, target)

            if optimize:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            batch = target.size(0)
            total_loss += loss.item() * batch
            correct += (prediction.argmax(dim=1) == target).sum().item()
            seen += batch

        if seen == 0:
            raise ValueError("data loader yielded no samples")
        return (torch.tensor(total_loss / seen),
                torch.tensor(100.0 * correct / seen))

    def train_single_epoch(self):
        """Train for one epoch over ``self.train_loader``.

        Returns:
            ``(loss, acc)``: epoch-mean training loss and accuracy in percent,
            as 0-dim tensors.
        """
        # Re-enable dropout/batch-norm updates; validation() switches to eval.
        self.net.train()
        loss, acc = self._run_epoch(self.train_loader, optimize=True)
        print(f"Loss: {loss.item()} | Acc: {acc.item()}%")
        return loss, acc

    def validation(self):
        """Evaluate on ``self.val_loader`` without gradient tracking.

        Returns:
            ``(valid_loss, valid_acc)``: epoch-mean validation loss and
            accuracy in percent, as 0-dim tensors.
        """
        # Eval mode so dropout/batch-norm behave deterministically.
        self.net.eval()
        with torch.no_grad():
            valid_loss, valid_acc = self._run_epoch(self.val_loader, optimize=False)
        print(f"Val loss: {valid_loss.item()} | Val acc: {valid_acc.item()}%")
        return valid_loss, valid_acc

    def train(self):
        """Run the training loop with LR scheduling, checkpointing and early stopping.

        After every epoch the scheduler steps on the validation loss. The best
        model so far is saved to ``model_dir + "best_model.pth"``. Training
        stops once validation loss fails to improve for more than
        ``train_paitence`` consecutive epochs. The final weights are saved to
        ``model_dir + "last_model.pth"``, and loss/accuracy curves are plotted
        into ``results_dir``.
        """
        # Kept as a module global for backward compatibility: notebook cells
        # may read ``min_valid_loss`` after training.
        global min_valid_loss
        loss_history = []  # training loss per epoch
        valloss_history = []  # validation loss per epoch
        acc_history = []  # training accuracy per epoch
        valacc_history = []  # validation accuracy per epoch
        counter = []  # epoch indices
        count_no_inprov = 0
        min_valid_loss = np.inf

        for i in range(self.epochs):
            print(f"Epoch {i + 1}")
            loss, acc = self.train_single_epoch()
            valid_loss, valid_acc = self.validation()
            self.scheduler.step(valid_loss)
            counter.append(i)
            loss_history.append(loss.item())
            valloss_history.append(valid_loss.item())
            acc_history.append(acc.detach().cpu())
            valacc_history.append(valid_acc.detach().cpu())

            if min_valid_loss > valid_loss:
                min_valid_loss = valid_loss
                count_no_inprov = 0
                torch.save(self.net.state_dict(), self.config.model_dir + "best_model.pth")
                print("Trained feed forward net saved at best_model.pth")
            else:
                count_no_inprov += 1

            if count_no_inprov > self.train_paitence:
                print("[!] No improvement in a while, stopping training...")
                print("BEST VALIDATION LOSS: ", min_valid_loss)
                break

            print("-------------------------------------------------")
        torch.save(self.net.state_dict(), self.config.model_dir + "last_model.pth")
        print("Trained feed forward net saved at last_model.pth")
        print("Finished training")
        loss_acc_plot(counter, loss_history, valloss_history, acc_history, valacc_history, self.config.results_dir)

    def test(self, dataset):
        """Evaluate the best checkpoint sample-by-sample on ``dataset``.

        Loads ``model_dir + "best_model.pth"`` and prints the overall and
        per-class accuracy. It then writes a confusion matrix and PR/ROC curves
        to ``results_dir``.

        Args:
            dataset: Iterable of ``(sample, target)`` pairs. Each sample is
                unbatched (a batch dimension is added here). ``target`` must be
                convertible with ``int()`` to a class index.

        Raises:
            ValueError: if ``dataset`` is empty.
        """
        num_classes = len(self.config.class_mapping)
        state_dict = torch.load(self.config.model_dir + "best_model.pth",
                                map_location=self.device)
        self.net.load_state_dict(state_dict)
        self.net.eval()

        correct = []  # 1/0 hit flag for every sample
        calls = [[] for _ in range(num_classes)]  # hit flags grouped by true class
        y_true = []  # true class indices
        y_pred = []  # predicted class indices
        target_list = []  # one-hot true labels (for PR/ROC curves)
        prediction_list = []  # softmax probabilities (for PR/ROC curves)
        with torch.no_grad():
            for sample, target in dataset:
                label = int(target)
                # Add a batch dimension without mutating the dataset's tensor.
                batch = sample.unsqueeze(0).to(self.device)

                probs = F.softmax(self.net(batch), dim=1).cpu().numpy()
                predicted = probs[0].argmax(0)

                one_hot = np.zeros(num_classes)
                one_hot[label] = 1

                target_list.append(one_hot)
                prediction_list.append(probs.squeeze().tolist())
                y_pred.append(predicted)
                y_true.append(label)

                hit = int(predicted == label)
                correct.append(hit)
                calls[label].append(hit)

        if not correct:
            raise ValueError("test dataset is empty")

        # Overall accuracy.
        accuracy = sum(correct) / len(correct)
        print(f"Accuracy:  {accuracy :.3f}")
        print(str(sum(correct)) + ' of ' + str(len(correct)))

        # Per-class accuracy (0 for classes with no samples).
        call_accuracies = np.zeros(num_classes)
        for index, class_ in enumerate(self.config.class_mapping):
            hits = calls[index]
            call_accuracies[index] = sum(hits) / len(hits) if hits else 0
            print(f"Accuracy for {class_}:  {call_accuracies[index] :.3f}")
            print(str(sum(hits)) + ' of ' + str(len(hits)))

        cfm(y_true, y_pred, self.config.results_dir, self.config.class_mapping)
        plot_pr_roc_curve(target_list, prediction_list, self.config.results_dir, self.config.class_mapping)
        
