In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.utils.data import DataLoader

import gc 

import numpy as np 
from utils.dataset import DetectionFolder
from models import *

In [None]:
# config
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset list files and image / label directories.
# NOTE(review): 'test' points at train.txt, the same file as 'train' —
# confirm whether a separate test list was intended.
data_config = {
    'train': '../data/detect/train.txt',
    'test': '../data/detect/train.txt',
    'image': '../data/detect/images/',
    'label': '../data/detect/labels/',
}

model_config = {
    'device': device,
    'size': (608, 608),
    'channel': 3,
    'dtype': torch.float,
    # Nine anchor sizes, three per output scale; these match the standard
    # YOLOv3 anchor set (presumably (w, h) in input pixels — confirm in models).
    'anchors': [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45), (59, 119), (116, 90), (156, 198), (373, 326)],
    'attribs': 5,
    'debug_level': 1,
}

train_config = {
    'coef_noobj': 0.2,
    # Scales the squared coordinate errors; the 608 * 608 divisor matches the
    # input size above (presumably to offset pixel-unit coordinates).
    'coef_coord': 20 / (608 * 608),
    'coef_total': 2,
    'device': device,
    'debug_level': 0,
    'iou_threshold': 0.75,
}


trainset = DetectionFolder(data_config['train'], data_config['image'], data_config['label'])
trainloader = DataLoader(trainset, batch_size=4, num_workers=4)
# Evaluation loader, currently disabled:
# testset = DetectionFolder(data_config['test'], data_config['image'], data_config['label'])
# testloader = DataLoader(testset)

In [None]:
# Sanity-check the first few dataset samples: label shape, full sample, image.
print(len(trainset))
# Guard against datasets with fewer than four samples.
for indx in range(min(4, len(trainset))):
    # Fetch each sample once instead of calling __getitem__ three times.
    sample = trainset[indx]
    print(sample['label'].shape)
    print(sample)
    # Show the image of the *current* sample (previously always index 0).
    print(sample['image'])

In [None]:
# One full pass over the loader, printing each batch index and image-batch shape.
for batch_idx, batch in enumerate(trainloader):
    print(batch_idx)
    print(batch['image'].shape)

In [None]:
# create model and move its parameters to the configured device
target_device = model_config['device']
model = YoloV3(model_config)
model.to(target_device)

In [None]:
class YoloLoss():
    """YOLO-style detection loss: objectness term plus coordinate term.

    Only box attributes 0-4 are read, as (x, y, w, h, conf); (x, y) are
    treated as box centres by the IoU code. There is no classification term.

    Config keys read: 'device', 'coef_noobj', 'coef_coord', 'coef_total',
    'debug_level', 'iou_threshold'.
    """
    def __init__(self, config):
        self.device = config['device']

        # Weight applied to objectness errors of predictions not matched to a label.
        self.coef_noobj = torch.tensor(config['coef_noobj']).to(self.device)
        # Weight applied to every coordinate error term.
        self.coef_coord = torch.tensor(config['coef_coord']).to(self.device)
        # NOTE(review): kept for compatibility; __call__ does not scale by it.
        self.coef_total = torch.tensor(config['coef_total']).to(self.device)
        self.debug_level = config['debug_level']
        # A prediction is matched to a label when its IoU exceeds
        # iou_threshold * (best IoU any prediction achieves for that label).
        self.iou_threshold = config['iou_threshold']

        # Keeps the IoU denominator non-zero for degenerate (zero-area) boxes.
        self.iou_epsilon = torch.tensor(1e-9).to(self.device)

    def __call__(self, pred, label, label_len):
        """Return the summed objectness + coordinate loss for a batch.

        Args:
            pred: predictions indexed as pred[b, p, attrib] (B x P x Attrib).
            label: padded ground truth indexed as label[b, l, attrib]
                (B x L x Attrib).
            label_len: number of valid label rows per image; only those rows
                contribute to the objectness term.

        Returns:
            Scalar tensor.
        """
        if self.debug_level >= 2:
            print('pred shape: ', pred.shape)
            print('label shape: ', label.shape)
            print('label_len: ', label_len)

        # Clip negative values to zero (presumably covers padding rows and raw
        # network outputs so that sqrt(w), sqrt(h) below stay real).
        pred = F.relu(pred)
        label = F.relu(label)

        if self.debug_level >= 3:
            print('pred : ', pred)
            print('label_len : ', label_len)
            print('label : ', label)

        # iou = B x L x P (label x prediction)
        # obj_mask = B x L x P, 1 where the prediction is matched to the label
        iou = self.batch_iou(pred, label)
        obj_mask = self.batch_obj_mask(iou, label_len)
        if self.debug_level >= 3:
            print('iou : ', iou)
        if self.debug_level >= 2:
            print('obj_mask.shape : ', obj_mask.shape)

        # objectness loss (scalar)
        obj_loss = self.batch_obj_loss(obj_mask, pred, label_len)
        if self.debug_level >= 2:
            print('obj_loss : ', obj_loss)

        # coord loss (scalar)
        coord_loss = self.batch_coord_loss(obj_mask, pred, label)
        if self.debug_level >= 2:
            print('coord_loss : ', coord_loss)

        # TODO: classification loss is not implemented.

        total_loss = torch.sum(obj_loss + coord_loss)
        if self.debug_level >= 1:
            print('total_loss : ', total_loss)

        return total_loss

    ### from https://github.com/westerndigitalcorporation/YOLOv3-in-PyTorch/blob/release/src/model.py
    def batch_iou(self, pred, label):
        """Pairwise IoU between labels and predictions of centre-format boxes.

        Args:
            pred: B x P x (>=4) tensor, attribs (x, y, w, h, ...).
            label: B x L x (>=4) tensor, attribs (x, y, w, h, ...).

        Returns:
            B x L x P tensor of IoU values.
        """
        x1 = label[..., 0]
        y1 = label[..., 1]
        w1 = label[..., 2]
        h1 = label[..., 3]

        x2 = pred[..., 0]
        y2 = pred[..., 1]
        w2 = pred[..., 2]
        h2 = pred[..., 3]

        area1 = w1 * h1
        area2 = w2 * h2

        # Convert centres to corner coordinates.
        x1 = x1 - w1 / 2
        y1 = y1 - h1 / 2
        x2 = x2 - w2 / 2
        y2 = y2 - h2 / 2
        # Labels go on dim 2 (B x L x 1), predictions on dim 1 (B x 1 x P),
        # so the min/max below broadcast to B x L x P.
        right1 = (x1 + w1).unsqueeze(2)
        right2 = (x2 + w2).unsqueeze(1)
        top1 = (y1 + h1).unsqueeze(2)
        top2 = (y2 + h2).unsqueeze(1)
        left1 = x1.unsqueeze(2)
        left2 = x2.unsqueeze(1)
        bottom1 = y1.unsqueeze(2)
        bottom2 = y2.unsqueeze(1)

        w_intersect = (torch.min(right1, right2) - torch.max(left1, left2)).clamp(min=0)
        h_intersect = (torch.min(top1, top2) - torch.max(bottom1, bottom2)).clamp(min=0)
        area_intersect = h_intersect * w_intersect

        iou_ = area_intersect / (area1.unsqueeze(2) + area2.unsqueeze(1) - area_intersect + self.iou_epsilon)

        return iou_

    def batch_obj_mask(self, iou, label_len):
        """Return a B x L x P 0/1 mask of label-prediction matches.

        A prediction is matched to a label when its IoU exceeds
        iou_threshold times the best IoU for that label. A label that no
        prediction overlaps (best IoU 0) gets no match. `label_len` is
        accepted for interface symmetry but not used here.
        """
        max_iou, max_iou_indx = torch.max(iou, 2)
        if self.debug_level >= 2:
            print('max_iou ', max_iou)

        obj_mask = torch.where(iou > max_iou.unsqueeze(2) * self.iou_threshold,
                               torch.ones_like(iou), torch.zeros_like(iou))

        if self.debug_level >= 2:
            print('nonzero ', torch.nonzero(obj_mask).shape[0])

        return obj_mask

    def batch_obj_loss(self, obj_mask, pred, label_len):
        """Weighted squared error between predicted confidence and the match mask.

        Matched pairs are weighted 1, unmatched pairs coef_noobj. Only the
        first label_len[b] label rows of image b are summed.

        Returns:
            Scalar tensor.
        """
        coef_mask = torch.where(obj_mask == 1,
                                torch.ones_like(obj_mask), torch.ones_like(obj_mask) * self.coef_noobj)

        # Broadcast each prediction's confidence across the label axis: B x L x P.
        conf = pred[..., 4].unsqueeze(1).expand(-1, obj_mask.shape[1], -1)

        obj_loss_all = coef_mask * F.mse_loss(obj_mask, conf, reduction='none')

        # B x L mask of valid (non-padding) label rows; replaces a per-image
        # Python loop over obj_loss_all[b][0:label_len[b]].
        num_labels = obj_loss_all.shape[1]
        label_idx = torch.arange(num_labels, device=obj_loss_all.device).unsqueeze(0)
        valid = label_idx < label_len.reshape(-1, 1).to(obj_loss_all.device)
        obj_loss = torch.sum(obj_loss_all * valid.unsqueeze(2).to(obj_loss_all.dtype))

        if self.debug_level >= 2:
            print('max conf : ', torch.max(obj_loss_all))
        if self.debug_level >= 4:
            print('obj_loss.shape', obj_loss.shape)

        return obj_loss

    def batch_coord_loss(self, obj_mask, pred, label):
        """Masked coordinate loss summed over all matched label-prediction pairs.

        x/y are compared directly; w/h are compared after a square root, so
        errors on large boxes weigh less. Every term is scaled by coef_coord.

        Returns:
            Scalar tensor.
        """
        # Labels repeated along a trailing prediction axis: B x L x P.
        x1 = label[..., 0].repeat(pred.shape[1], 1, 1).permute(1, 2, 0)
        y1 = label[..., 1].repeat(pred.shape[1], 1, 1).permute(1, 2, 0)
        w1 = label[..., 2].repeat(pred.shape[1], 1, 1).permute(1, 2, 0)
        h1 = label[..., 3].repeat(pred.shape[1], 1, 1).permute(1, 2, 0)

        # Predictions repeated along a middle label axis: B x L x P.
        x2 = torch.transpose(pred[..., 0].repeat(label.shape[1], 1, 1), 0, 1)
        y2 = torch.transpose(pred[..., 1].repeat(label.shape[1], 1, 1), 0, 1)
        w2 = torch.transpose(pred[..., 2].repeat(label.shape[1], 1, 1), 0, 1)
        h2 = torch.transpose(pred[..., 3].repeat(label.shape[1], 1, 1), 0, 1)

        x_loss = self.coef_coord * obj_mask * F.mse_loss(x1, x2, reduction='none')
        y_loss = self.coef_coord * obj_mask * F.mse_loss(y1, y2, reduction='none')
        w_loss = self.coef_coord * obj_mask * F.mse_loss(torch.sqrt(w1), torch.sqrt(w2), reduction='none')
        h_loss = self.coef_coord * obj_mask * F.mse_loss(torch.sqrt(h1), torch.sqrt(h2), reduction='none')

        coord_loss = x_loss + y_loss + w_loss + h_loss
        coord_loss = torch.sum(coord_loss)
        if self.debug_level >= 4:
            print('coord_loss.shape', coord_loss.shape)

        return coord_loss

    # --- Per-image helpers. Not called by __call__; retained for callers. ---

    def calc_iou(self, pred, label):
        """IoU between every label and every prediction of a single image.

        Args:
            pred: P x (>=4) tensor (x, y, w, h, ...).
            label: L x (>=4) tensor (x, y, w, h, ...).

        Returns:
            L x P IoU matrix.
        """
        # Delegate to the batched version so both paths agree. The earlier
        # hand-rolled version did not cap the overlap at the smaller box side,
        # overestimating IoU when one box contained the other.
        return self.batch_iou(pred.unsqueeze(0), label.unsqueeze(0))[0]

    def calc_coord_loss(self, pred, label, responsible):
        """Per-label squared xy error plus squared sqrt-wh error.

        Args:
            pred: P x (>=4) predictions of one image.
            label: L x (>=4) labels of the same image.
            responsible: length-L index tensor selecting one prediction per label.

        Returns:
            Length-L tensor of per-label coordinate errors.
        """
        resp_pred = torch.index_select(pred, 0, responsible)
        if self.debug_level >= 4:
            print('resp_pred : ', resp_pred)
        resp_pred_xy = resp_pred[:, 0:2]
        resp_pred_wh = resp_pred[:, 2:4]
        resp_pred_wh = torch.sqrt(resp_pred_wh)

        label_xy = label[:, 0:2]
        label_wh = label[:, 2:4]
        label_wh = torch.sqrt(label_wh)

        output = torch.sum(torch.pow(label_xy - resp_pred_xy, 2), 1)
        output += torch.sum(torch.pow(label_wh - resp_pred_wh, 2), 1)

        return output

    def calc_confidence_loss(self, pred, iou, responsible):
        """Squared error between predicted confidence and IoU, noobj-weighted.

        Every (label, prediction) entry is weighted coef_noobj, except the
        responsible prediction of each label, which gets full weight 1.

        Args:
            pred: P x (>=5) predictions of one image.
            iou: L x P IoU matrix.
            responsible: length-L index tensor of each label's prediction.

        Returns:
            L x P tensor of weighted confidence errors.
        """
        conf = pred[:, 4].repeat(iou.shape[0], 1)
        conf = torch.pow(conf - iou, 2)

        mean_conf = conf * self.coef_noobj

        # Lift responsible entries from coef_noobj weight to full weight.
        for indx in range(responsible.shape[0]):
            mean_conf[indx][responsible[indx]] += (1 - self.coef_noobj) * conf[indx][responsible[indx]]

        # Return the weighted errors (previously the unweighted `conf` was
        # returned, discarding the weighting computed above).
        return mean_conf

    def calc_responsibile(self, iou):
        """Return, for each label row of `iou`, the index of its best prediction."""

        argmax = torch.argmax(iou, dim=1)

        return argmax

    

In [None]:
# create components
# `learning_rate` and `loss_cache` are module globals; train_step updates both.
learning_rate = 0.1
loss_cache = []


def lr_func(epoch):
    """LambdaLR factor: the current value of the global `learning_rate`.

    NOTE(review): LambdaLR sets lr = base_lr * lr_func(epoch), and base_lr is
    also the initial learning_rate (0.1). So the effective lr is
    0.1 * learning_rate (0.01 at start), not learning_rate itself. If an
    absolute lr was intended, return learning_rate / 0.1 instead. Confirm
    before changing, because it alters the effective lr by 10x.
    An earlier schedule, learning_rate * (0.97 ** epoch) after epoch 30 and
    0.01 before, was tried and disabled.
    """
    return learning_rate


# Adam is used; SGD with the same lr was the alternative considered.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func, last_epoch=-1)
loss_func = YoloLoss(train_config)

In [None]:
def train_step(config, model, trainloader, loss_func, optimizer):
    """Train `model` for one pass over `trainloader`, then maybe decay the lr.

    Args:
        config: dict read for 'debug_level', 'device' and 'dtype'.
        model: module whose forward returns three tensors; they are
            concatenated along dim 1 before being handed to `loss_func`.
        trainloader: iterable of dicts with 'image', 'label', 'label_len'.
        loss_func: callable(pred, labels, label_len) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.

    Side effects (module globals):
        * Appends this epoch's mean batch loss to `loss_cache`.
        * With at least 10 cached epochs, if the mean of the last two exceeds
          the mean of the whole cache, `learning_rate` is multiplied by 0.9
          and the cache is cleared.
        * A cache longer than 20 entries drops its 15 oldest entries.
    """
    global learning_rate
    global loss_cache

    model.train()

    device = config['device']
    batch_losses = []

    for step, batch in enumerate(trainloader):
        if config['debug_level'] >= 2:
            print('index', step)

        image = batch['image'].to(device, dtype = config['dtype'])
        # Labels are not divided by config['size'][0] here.
        labels = batch['label'].to(device, dtype = config['dtype'])
        label_len = batch['label_len'].to(device, dtype = torch.long)
        if config['debug_level'] >= 3:
            print('label shape : ', labels.shape)

        # forward: the model returns three output tensors
        out1, out2, out3 = model(image)

        if config['debug_level'] >= 2:
            print('out1.shape : ', out1.shape)
            print('out1[0] : ', out1[0])
            print('out2.shape : ', out2.shape)
            print('out2[0] : ', out2[0])
            print('out3.shape : ', out3.shape)
            print('out3[0] : ', out3[0])
            print('labels.shape : ', labels.shape)
            print('labels[0] : ', labels[0])

        optimizer.zero_grad()

        # backward on the loss over all outputs concatenated along dim 1
        loss = loss_func(torch.cat((out1, out2, out3), 1), labels, label_len)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

        # Drop per-batch tensors, then collect and release cached memory.
        del image, labels, out1, out2, out3
        gc.collect()
        torch.cuda.empty_cache()

    epoch_loss = np.mean(batch_losses)
    loss_cache.append(epoch_loss)

    # Plateau rule: recent epochs are worse than the cached average.
    if len(loss_cache) >= 10 and np.mean(loss_cache[-2:]) > np.mean(loss_cache):
        print('decrease learning rate from : ', learning_rate)
        print('average of previous loss : ', np.mean(loss_cache))
        print('length of previous loss : ', len(loss_cache))
        learning_rate *= 0.90
        loss_cache = []
        print('decrease learning rate to : ', learning_rate)
    if len(loss_cache) > 20:
        loss_cache = loss_cache[15:]

    if config['debug_level'] >= 1:
        print('avg loss : ', epoch_loss)



In [None]:
# single step: one training epoch with autograd anomaly detection disabled
detect_anomaly = False
with torch.autograd.set_detect_anomaly(detect_anomaly):
    train_step(model_config, model, trainloader, loss_func, optimizer)
    scheduler.step()

In [None]:
# value checking: a short 30-epoch run, reshuffling the dataset every epoch
for epoch in range(30):
    print('epoch : ', epoch)
    with torch.autograd.set_detect_anomaly(False):
        trainset.shuffle()
        train_step(model_config, model, trainloader, loss_func, optimizer)
        scheduler.step()


In [None]:

# Long training run (epochs 30-1199) with a freshly built, quiet loss object.
train_config['debug_level'] = 0
loss_func = YoloLoss(train_config)

for epoch in range(30, 1200):
    print('epoch : ', epoch)
    trainset.shuffle()
    train_step(model_config, model, trainloader, loss_func, optimizer)
    scheduler.step()

    # Periodic checkpoint every 200 epochs (whole pickled module).
    if epoch % 200 == 0:
        checkpoint_path = './epoch_' + str(epoch) + '.dat'
        torch.save(model, checkpoint_path)

In [None]:

# Final checkpoint.
# NOTE(review): the training loop above stops at epoch 1199, so "2000" in this
# filename does not reflect the epochs trained. Also, torch.save(model, ...)
# pickles the whole module; saving model.state_dict() is more portable.
final_checkpoint = './epoch_2000.dat'
torch.save(model, final_checkpoint)