## Imports & Environment Setup

In [None]:
import torch
from torch import autograd
from torch.utils.data import DataLoader

import json
import gc
import numpy as np
import datetime
from collections import Counter

from utils.dataset import LabeledDataset
from utils.model import YoloV3, YoloLoss
from utils.postprocess import PostProcessor


In [None]:
# Compute device: CUDA when the runtime reports it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Floating-point dtype used when moving inputs and labels onto the device.
dtype = torch.float

## Load Config

In [None]:
# Load the JSON configuration. Failures to open/parse the file propagate as the
# usual OSError / json.JSONDecodeError.
with open("./config/config.json", "r", encoding="utf-8") as config_file:
    main_config = json.load(config_file)

# Pull out the sections used by the rest of the notebook. A missing key means
# the config file is incomplete; report it with a real exception rather than an
# ``assert`` (asserts are stripped when Python runs with -O).
try:
    model_config = main_config['model']
    train_config = main_config['train']
    valid_config = train_config['validation']
    loss_config = train_config['loss']
except KeyError as err:
    raise KeyError(f'Failed to find key on config file: {err}') from err

In [None]:
# Inject runtime settings into the model and loss configs: the target device and
# dtype, plus the per-prediction attribute count (5 fixed attributes plus one per
# class — presumably box coordinates + objectness; confirm against YoloV3).
model_config.update(device = device,
                    dtype = dtype,
                    attrib_count = 5 + model_config['class_count'])
loss_config.update(device = device,
                   dtype = dtype,
                   attrib_count = model_config['attrib_count'])

In [None]:
# Mutable training state handed to train(): device/dtype, datasets and loaders,
# and the counters train() advances in place (epoch, lr, last_checkpoint,
# loss_window).
train_context = {}

train_context['device'] = device
train_context['dtype'] = dtype

# Training data, reshuffled every epoch.
train_context['train_set'] = LabeledDataset(train_config['set']['index'],
                                            train_config['set']['image_dir'],
                                            train_config['set']['label_dir'])
train_context['train_loader'] = DataLoader(train_context['train_set'],
                                           batch_size = train_config['set']['batch_size'],
                                           num_workers = train_config['set']['num_workers'],
                                           shuffle = True)

# Validation data, iterated in a fixed order.
train_context['valid_set'] = LabeledDataset(valid_config['set']['index'],
                                            valid_config['set']['image_dir'],
                                            valid_config['set']['label_dir'])
train_context['valid_loader'] = DataLoader(train_context['valid_set'],
                                           batch_size = valid_config['set']['batch_size'],
                                           num_workers = valid_config['set']['num_workers'],
                                           shuffle = False)

train_context['epoch'] = 0
train_context['last_checkpoint'] = 0
train_context['lr'] = train_config['learning_rate']['init']
# Per-epoch average training losses. train() appends to this list every epoch,
# so it must exist before the first call.
train_context['loss_window'] = []

In [None]:
# Optional TensorBoard writer. None disables TensorBoard logging; train() checks
# for None before writing any scalar.
if not train_config['log']['tb_enable']:
    train_context['tb'] = None
else:
    from torch.utils.tensorboard import SummaryWriter
    train_context['tb'] = SummaryWriter(log_dir = train_config['log']['tb_dir'])

## Build

In [None]:
# Build the network from the augmented model config and move it to the configured device.
model = YoloV3(model_config).to(model_config['device'])

In [None]:
# Training/validation criterion. loss_config carries the device, dtype and
# attrib_count entries injected during config setup. train() calls it as
# loss_func(pred, labels, label_len) and unpacks (loss, obj_loss, coord_loss).
loss_func = YoloLoss(loss_config)

In [None]:
# Adam optimizer plus a LambdaLR scheduler that keeps the optimizer's learning
# rate in sync with train_context['lr'] (which train() decays between epochs
# before calling scheduler.step()).
#
# LambdaLR sets each param group's lr to ``initial_lr * lr_lambda(epoch)``, so
# the lambda must return a *factor* relative to the initial lr. Returning the
# absolute lr here would make the effective lr init_lr ** 2.
base_lr = train_context['lr']


def lr_func(epoch):
    """Return the current lr from train_context as a multiple of the initial lr.

    ``epoch`` is required by LambdaLR's callback signature but unused: the
    schedule is driven entirely by train_context['lr'].
    """
    return train_context['lr'] / base_lr if base_lr else 0.0


optimizer = torch.optim.Adam(model.parameters(), lr = base_lr)
# Alternative optimizer: torch.optim.SGD(model.parameters(), lr = base_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lr_func, last_epoch = -1)

In [None]:
def _mean_or_zero(values):
    """Return ``np.mean(values)``, or 0 when ``values`` is empty."""
    return np.mean(values) if values else 0


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _add_scalars(tb, scalars, step):
    """Write each ``tag: value`` pair to TensorBoard at ``step``; no-op when ``tb`` is None."""
    if tb is None:
        return
    for tag, value in scalars.items():
        tb.add_scalar(tag, value, step)


def _train_one_epoch(model, loss_func, optimizer, train_context):
    """Run one optimisation pass over the training loader.

    Returns:
        (avg_loss, avg_obj_loss, avg_coord_loss): means over batches of each
        batch loss divided by its batch size (0 when the loader is empty).
    """
    model.train()
    device, dtype = train_context['device'], train_context['dtype']

    losses, obj_losses, coord_losses = [], [], []
    for batches in train_context['train_loader']:
        image = batches['image'].to(device, dtype = dtype)
        labels = batches['label'].to(device, dtype = dtype)
        label_len = batches['label_len'].to(device, dtype = torch.long)
        batch_size = batches['image'].shape[0]

        # forward; the three model outputs are concatenated along dim 1
        out1, out2, out3 = model(image)
        optimizer.zero_grad()

        loss, obj_loss, coord_loss = loss_func(torch.cat((out1, out2, out3), 1), labels, label_len)
        losses.append(loss.item() / batch_size)
        obj_losses.append(obj_loss.item() / batch_size)
        coord_losses.append(coord_loss.item() / batch_size)

        loss.backward()
        optimizer.step()

        # Drop references to graph/device tensors promptly to keep peak memory low.
        del image, labels, label_len, loss, obj_loss, coord_loss
        del out1, out2, out3
        gc.collect()
        torch.cuda.empty_cache()

    return _mean_or_zero(losses), _mean_or_zero(obj_losses), _mean_or_zero(coord_losses)


def _validate_epoch(model, loss_func, valid_loader, post_processor,
                    train_context, train_config, loss_context):
    """Evaluate on ``valid_loader``; log losses and, once enabled, detection metrics.

    Detection metrics are computed only when ``loss_context`` is given and the
    current epoch is past ``loss_context['acc_start_epoch']``. If all target
    metrics are met, the model is saved as ``model_r_<epoch>.dat``.
    """
    epoch = train_context['epoch']
    tb = train_context['tb']
    device, dtype = train_context['device'], train_context['dtype']
    compute_acc = loss_context is not None and loss_context['acc_start_epoch'] < epoch

    losses, obj_losses, coord_losses = [], [], []
    accs = Counter()
    with torch.no_grad():
        model.eval()
        torch.autograd.set_detect_anomaly(False)

        for batches in valid_loader:
            image = batches['image'].to(device, dtype = dtype)
            labels = batches['label'].to(device, dtype = dtype)
            label_len = batches['label_len'].to(device, dtype = torch.long)
            batch_size = batches['image'].shape[0]

            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            loss, obj_loss, coord_loss = loss_func(pred, labels, label_len)
            losses.append(loss.item() / batch_size)
            obj_losses.append(obj_loss.item() / batch_size)
            coord_losses.append(coord_loss.item() / batch_size)

            if compute_acc:
                # squeeze(0) drops the batch axis, so this presumably expects a
                # validation batch_size of 1 — TODO confirm against the config.
                pred_np = pred.cpu().detach().squeeze(0).numpy()
                label_np = batches['label'].cpu().squeeze(0).numpy()
                label_len_np = batches['label_len'].cpu().squeeze(0).numpy()

                bboxes = post_processor.CUSTOM2(pred_np, loss_context)
                acc = post_processor.calcAccuracyMap(label_np, label_len_np, bboxes, loss_context)
                accs = accs + Counter(acc)
                del pred_np, label_np, label_len_np, bboxes, acc

            del image, labels, label_len, loss, obj_loss, coord_loss, pred
            del out1, out2, out3
            gc.collect()
            torch.cuda.empty_cache()

    avg_loss = _mean_or_zero(losses)
    avg_obj_loss = _mean_or_zero(obj_losses)
    avg_coord_loss = _mean_or_zero(coord_losses)
    print('    validation loss : ', avg_loss)
    _add_scalars(tb, {
        'Loss/Validation Loss': avg_loss,
        'Loss/Validation Object Loss': avg_obj_loss,
        'Loss/Validation Coord Loss': avg_coord_loss,
    }, epoch)

    if not compute_acc:
        return

    print('    accs : ', accs)
    tp = accs['true positive']
    fn = accs['false negative']
    fp = accs['false positive'] + accs['duplicate']
    # Guard the ratios: an epoch with no detections and no labels would otherwise divide by zero.
    accuracy = _ratio(tp, tp + fn + fp)
    recall = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)
    print('    accuracy : ', accuracy)
    print('    recall : ', recall)
    print('    precision : ', precision)
    _add_scalars(tb, {
        'Accuracy/Accuracy': accuracy,
        'Accuracy/Recall': recall,
        'Accuracy/Precision': precision,
    }, epoch)

    # Save a "result" checkpoint whenever all target metrics are met,
    # independently of whether TensorBoard logging is enabled.
    if (accuracy >= train_config['target_accuaracy'] and
            recall >= train_config['target_recall'] and
            precision >= train_config['target_precision']):
        torch.save(model, train_config['checkpoint_dir'] + 'model_r_' + str(epoch) + '.dat')


def _maybe_checkpoint(model, train_context, train_config):
    """Save ``model_<epoch>.dat`` every ``checkpoint`` epochs from ``checkpoint_start`` on."""
    epoch = train_context['epoch']
    if epoch >= train_config['checkpoint_start'] and epoch % train_config['checkpoint'] == 0:
        train_context['last_checkpoint'] = epoch
        torch.save(model, train_config['checkpoint_dir'] + 'model_' + str(epoch) + '.dat')


def train(model, loss_func, optimizer, scheduler, train_context, train_config, epochs,
          loss_context = None):
    """Train ``model`` for ``epochs`` epochs, validating and checkpointing along the way.

    Each epoch performs the following steps:
      1. One pass over ``train_context['train_loader']``, logging the mean
         per-sample losses to stdout and, if configured, TensorBoard.
      2. An evaluation pass over the validation loader, when one is present.
      3. Decay ``train_context['lr']`` and step ``scheduler``.
      4. Advance ``train_context['epoch']`` and save periodic checkpoints.

    Args:
        model: network returning three output tensors that are concatenated on dim 1.
        loss_func: callable ``(pred, labels, label_len) -> (loss, obj_loss, coord_loss)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        train_context: mutable training state (device, dtype, loaders, epoch, lr,
            tb, loss_window, last_checkpoint); updated in place.
        train_config: the ``train`` section of the config file.
        epochs: number of epochs to run in this call.
        loss_context: optional config passed to the post-processor for accuracy
            metrics; must contain ``'acc_start_epoch'``. When None, validation
            computes losses only.
    """
    post_processor = PostProcessor()
    tb = train_context['tb']
    train_context.setdefault('loss_window', [])

    for _ in range(epochs):
        epoch = train_context['epoch']
        torch.autograd.set_detect_anomaly(train_config['enable_anomaly_detection'])

        print('epoch : ', epoch)
        print('    time : ', datetime.datetime.now().time())
        print('    lr : ', train_context['lr'])
        _add_scalars(tb, {'Step/Learning Rate': train_context['lr']}, epoch)

        # training step
        avg_loss, avg_obj_loss, avg_coord_loss = _train_one_epoch(
            model, loss_func, optimizer, train_context)
        train_context['loss_window'].append(avg_loss)
        print('    loss : ', avg_loss)
        print('    obj_loss : ', avg_obj_loss)
        print('    coord_loss : ', avg_coord_loss)
        _add_scalars(tb, {
            'Loss/Training Loss': avg_loss,
            'Loss/Training Object Loss': avg_obj_loss,
            'Loss/Training Coord Loss': avg_coord_loss,
        }, epoch)

        # validation step; the setup cell stores the loader under 'valid_loader'
        # ('val_dataloader' is accepted for older contexts).
        valid_loader = train_context.get('valid_loader', train_context.get('val_dataloader'))
        if valid_loader is not None:
            _validate_epoch(model, loss_func, valid_loader, post_processor,
                            train_context, train_config, loss_context)

        # Exponential lr decay after 'lr_decay_start'. The scheduler's lambda
        # reads train_context['lr'], so step it after updating the value.
        if epoch > train_config['lr_decay_start']:
            train_context['lr'] = train_context['lr'] * train_config['lr_decay']
            train_context['loss_window'] = []
        scheduler.step()

        train_context['epoch'] += 1
        _maybe_checkpoint(model, train_context, train_config)
            

## Run

In [None]:
# Train for the configured number of epochs. The TensorBoard writer (if any)
# lives in train_context['tb']; flush and close it even if training fails or
# is interrupted, so buffered events are not lost.
try:
    train(model, loss_func, optimizer, scheduler, train_context, train_config, train_config['plan']['epochs'])
finally:
    tb_writer = train_context.get('tb')
    if tb_writer is not None:
        tb_writer.flush()
        tb_writer.close()

## Temporary Code