In [1]:
#external imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import time
import copy
from GPUtil import showUtilization as gpu_usage
import csv
import os
import tqdm
import gc
gc.collect()
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
import math
from sklearn.metrics import f1_score, roc_auc_score
#internal imports
from dataset import PTB_Dataset

#constants
from constants import REC_PATH,CSV_PATH,N_LEADS,N_CLASSES,DATASET_LIMIT
from model import My_Network
torch.cuda.empty_cache()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def confusion(prediction, truth, threshold=0.5):
    """Count the confusion-matrix cells between thresholded predictions and 0/1 labels.

    A prediction counts as positive when ``prediction > threshold`` (strictly greater).
    `truth` is expected to hold 0/1 labels. Entries of `truth` that are neither 0 nor 1
    are not counted in any cell.

    Args:
        prediction: tensor of scores, same shape as (or broadcastable to) `truth`.
        truth: tensor of ground-truth 0/1 labels.
        threshold: decision cutoff applied to `prediction` (default 0.5).

    Returns:
        Tuple ``(true_positives, false_positives, true_negatives, false_negatives)``
        of 0-dim integer tensors. They live on the same device as the inputs.
    """
    # Boolean masks work on whatever device the inputs live on. No global `device`
    # is needed, and there is no inf/nan division trick to decode.
    predicted_pos = prediction > threshold
    actual_pos = truth == 1
    actual_neg = truth == 0

    true_positives = torch.sum(predicted_pos & actual_pos)
    false_positives = torch.sum(predicted_pos & actual_neg)
    true_negatives = torch.sum(~predicted_pos & actual_neg)
    false_negatives = torch.sum(~predicted_pos & actual_pos)

    return true_positives, false_positives, true_negatives, false_negatives

In [5]:
def _run_phase(dataloader, model, optimizer, loss_fn, device, train):
    """Run one pass over `dataloader`, stepping `optimizer` only when `train` is True.

    Returns:
        Tuple ``(loss_sum, n_samples, scores, labels)``:
        - `loss_sum` is the batch-size-weighted sum of `loss_fn` values.
        - `n_samples` is the number of inputs seen.
        - `scores` and `labels` are the model outputs and targets for the whole pass,
          each flattened into a 1-D numpy array.
    """
    model.train(train)  # model.train(False) is the same as model.eval()
    loss_sum = 0.0
    n_samples = 0
    score_chunks, label_chunks = [], []
    with torch.set_grad_enabled(train):
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = inputs.size(0)
            # Assumes loss_fn returns a per-batch mean. Re-weighting by batch size keeps
            # the epoch loss a per-sample mean even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
            score_chunks.append(outputs.detach().float().cpu().ravel())
            label_chunks.append(targets.detach().cpu().ravel())
    scores = torch.cat(score_chunks).numpy() if score_chunks else np.empty(0)
    labels = torch.cat(label_chunks).numpy() if label_chunks else np.empty(0)
    return loss_sum, n_samples, scores, labels


def _binary_metrics(scores, labels, threshold=0.5):
    """Compute confusion fractions, precision, recall, F1 and ROC AUC over flattened arrays.

    Labels count as positive when > 0, and scores count as positive when > `threshold`.
    Ratios with a zero denominator are reported as 0.0. ROC AUC is NaN when sklearn
    cannot compute it, e.g. when only one class is present.
    """
    y_true = labels > 0
    y_pred = scores > threshold
    tp = int(np.sum(y_pred & y_true))
    fp = int(np.sum(y_pred & ~y_true))
    tn = int(np.sum(~y_pred & ~y_true))
    fn = int(np.sum(~y_pred & y_true))
    total = y_true.size

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    try:
        roc = roc_auc_score(y_true.astype(int), scores)
    except ValueError:
        # sklearn raises ValueError for single-class, empty, or NaN input.
        roc = float('nan')

    def frac(count):
        return count / total if total else 0.0

    return {
        'tp': frac(tp), 'fp': frac(fp), 'tn': frac(tn), 'fn': frac(fn),
        'precision': precision, 'recall': recall, 'f1': f1, 'roc': roc,
    }


def train_gender_loop(dataloaders, model, optimizer, loss_fn, device, weights_name,
                      save_dir=None, set_sizes=None, epochs=2, evaluate=True):
    """Train `model` for `epochs` epochs, optionally validating after each one.

    Each epoch runs the 'Train' phase over ``dataloaders['Train']``. When `evaluate` is
    True it then runs the 'Valid' phase over ``dataloaders['Valid']``. For every phase it
    prints the mean loss, F1/precision/recall, and TP/FP/TN/FN fractions. These are
    computed once over all flattened outputs of that phase, with outputs thresholded
    at 0.5.

    After each validation phase, the row ``[epoch, train_loss, train_f1, train_roc,
    valid_loss, valid_f1, valid_roc]`` is printed and appended to 'performance.csv' in
    the current working directory.

    Args:
        dataloaders: mapping with 'Train' (and 'Valid' when `evaluate`) iterables of
            ``(inputs, targets)`` batches.
        model: module under training; called as ``model(inputs)``.
        optimizer: optimizer over `model`'s parameters; stepped in the 'Train' phase only.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        device: device that inputs and targets are moved to.
        weights_name: currently unused; kept for interface compatibility.
        save_dir: currently unused; kept for interface compatibility.
        set_sizes: optional mapping from phase name to the sample count used as the loss
            denominator. When None, the number of samples actually iterated is used.
        epochs: number of epochs to run.
        evaluate: whether to run the 'Valid' phase each epoch.

    Returns:
        ``(best_loss, best_epoch)``: the lowest validation loss and the epoch it was
        reached. If no validation phase ran, this is the sentinel ``(10e5, 10e3)``.
    """
    # NOTE(review): outputs are thresholded at 0.5. If the model returns raw logits
    # (e.g. when trained with sigmoid_focal_loss), 0.0 would be the matching cutoff.
    phases = ['Train', 'Valid'] if evaluate else ['Train']
    best_loss = 10e5
    best_epoch = 10e3
    for epoch in range(epochs):
        performance = [epoch]
        for phase in phases:
            loss_sum, n_seen, scores, labels = _run_phase(
                dataloaders[phase], model, optimizer, loss_fn, device,
                train=(phase == 'Train'))
            n_samples = set_sizes[phase] if set_sizes is not None else n_seen
            epoch_loss = loss_sum / n_samples if n_samples else float('nan')
            metrics = _binary_metrics(scores, labels)
            performance.extend([epoch_loss, metrics['f1'], metrics['roc']])

            if phase == 'Valid':
                print(performance)
                # newline='' is what the csv module expects for file objects.
                with open('performance.csv', 'a', newline='') as csvfile:
                    csv.writer(csvfile).writerow(performance)

            print('epoch: {} {} loss: {:.4f}, '.format(epoch, phase, epoch_loss))
            print('f1 score: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format(
                metrics['f1'], metrics['precision'], metrics['recall']))
            print('TP: {:.4f}, FP: {:.4f}, TN: {:.4f}, FN: {:.4f}'.format(
                metrics['tp'], metrics['fp'], metrics['tn'], metrics['fn']))

            # A NaN loss compares False, so it never becomes the best.
            if phase == 'Valid' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_epoch = epoch

    gpu_usage()
    return best_loss, best_epoch

In [None]:
def main(weights_name, epochs=50, batch_size=64, lr=1.e-4, betas=(0.9, 0.9999),
         save_dir=None, seed=None):
    """Build the PTB dataset splits, model and optimizer, then run `train_gender_loop`.

    The dataset is split in two steps:
    - 75% train+validation, 25% test;
    - the train+validation part is then split 40% validation, 60% train.

    Args:
        weights_name: identifier returned as `best_weights` when training improves on
            the initial loss sentinel.
        epochs, batch_size, lr, betas: training hyperparameters (Adam for lr/betas).
        save_dir: forwarded to `train_gender_loop`.
        seed: optional int. When given, the dataset splits become reproducible.

    Returns:
        ``(best_loss, best_epoch, best_weights)``. `best_weights` is None when the
        returned loss did not beat the initial sentinel.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    custom_dataset = PTB_Dataset(CSV_PATH, REC_PATH, transforms.ToTensor())
    n_total = len(custom_dataset)
    print(n_total)

    # Only pass a generator when a seed is given. Otherwise random_split keeps its default.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}

    train_val_size = math.floor(n_total * 0.75)
    test_size = n_total - train_val_size
    train_val_dataset, test_dataset = torch.utils.data.random_split(
        custom_dataset, [train_val_size, test_size], **split_kwargs)
    valid_size = math.floor(train_val_size * 0.4)
    train_size = train_val_size - valid_size
    valid_dataset, train_dataset = torch.utils.data.random_split(
        train_val_dataset, [valid_size, train_size], **split_kwargs)

    dataloaders = {
        # Reshuffle training batches every epoch; evaluation order does not matter.
        'Train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        'Valid': DataLoader(valid_dataset, batch_size=batch_size),
        'Test': DataLoader(test_dataset, batch_size=batch_size),
    }
    set_sizes = {'Train': train_size, 'Valid': valid_size, 'Test': test_size}

    model = My_Network(lstm_input_dim=10, hidden_dim=10, num_layers=2).to(device)

    def loss_fn(outputs, targets):
        # sigmoid_focal_loss applies a sigmoid itself, so `outputs` should be raw logits.
        # It needs float targets with the same dtype as the outputs.
        return torchvision.ops.sigmoid_focal_loss(
            outputs, targets.to(outputs.dtype), reduction='mean')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    best_loss = 10e4
    best_epoch = 10e3
    best_weights = None
    loss, epoch = train_gender_loop(dataloaders=dataloaders, model=model, optimizer=optimizer,
                                    loss_fn=loss_fn, device=device, weights_name=weights_name,
                                    save_dir=save_dir, set_sizes=set_sizes, epochs=epochs)
    if loss < best_loss:
        best_loss = loss
        best_epoch = epoch
        best_weights = weights_name

    return best_loss, best_epoch, best_weights