In [1]:
import os
import sys
if '..' not in sys.path:
    sys.path.append('..')

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
import argparse
from utiles import box2corners
from oriented_iou_loss import cal_diou, cal_giou

In [2]:
# Directory where the synthetic dataset is cached as .npy files.
DATA = "./data"
# Box centres are sampled in [-X_MAX, X_MAX) x [-Y_MAX, Y_MAX).
X_MAX = 3
Y_MAX = 3
# Box width/height are sampled in [1 - SCALE, 1 + SCALE).
SCALE = 0.5
# Each loader batch holds BATCH_SIZE * N_DATA boxes, reshaped to (BATCH_SIZE, N_DATA, ...).
BATCH_SIZE = 32
N_DATA = 128
# 200 training and 20 test iterations per epoch.
NUM_TRAIN = 200 * BATCH_SIZE * N_DATA
NUM_TEST = 20 * BATCH_SIZE * N_DATA

NUM_EPOCH = 20

# Seed numpy (dataset generation) and torch (weight init, DataLoader shuffling)
# so that Restart & Run All regenerates the same dataset and initial weights.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Idempotent and free of the exists()/mkdir() race.
os.makedirs(DATA, exist_ok=True)

In [3]:
def create_data(num):
    """Generate ``num`` random rotated boxes and their corner coordinates.

    Box parameters are drawn uniformly with ``np.random.rand``:
      x in [-X_MAX, X_MAX), y in [-Y_MAX, Y_MAX),
      w, h in [1 - SCALE, 1 + SCALE), alpha in [0, pi).

    Args:
        num: number of boxes to generate.

    Returns:
        corners: float64 array of shape (num, 4, 2) holding, for each box, the
            corner coordinates returned by ``box2corners``.
        label: float64 array of shape (num, 5) with columns (x, y, w, h, alpha).
    """
    print("... generating %d boxes, please wait..."%num)
    x = (np.random.rand(num) - 0.5) * 2 * X_MAX
    y = (np.random.rand(num) - 0.5) * 2 * Y_MAX
    w = (np.random.rand(num) - 0.5) * 2 * SCALE + 1
    h = (np.random.rand(num) - 0.5) * 2 * SCALE + 1
    alpha = np.random.rand(num) * np.pi
    # ``np.float`` was only an alias of the builtin float and was removed in
    # NumPy 1.24 (AttributeError); ``np.float64`` is the identical dtype.
    corners = np.zeros((num, 4, 2), dtype=np.float64)
    for i in range(num):
        corners[i, ...] = box2corners(x[i], y[i], w[i], h[i], alpha[i])
    label = np.stack([x, y, w, h, alpha], axis=1)
    return corners, label

def save_dataset():
    """Generate fresh train and test splits and write them under DATA.

    For each split ("train" with NUM_TRAIN boxes, then "test" with NUM_TEST
    boxes) this writes ``<split>_data.npy`` (corner coordinates) and
    ``<split>_label.npy`` (box parameters) via ``np.save``.
    """
    for split_name, box_count in (("train", NUM_TRAIN), ("test", NUM_TEST)):
        corner_arr, param_arr = create_data(box_count)
        np.save(os.path.join(DATA, split_name + "_data.npy"), corner_arr)
        np.save(os.path.join(DATA, split_name + "_label.npy"), param_arr)
    print("data saved in: ", DATA)

class BoxDataSet(Dataset):
    """Map-style dataset over the synthetic boxes cached in DATA.

    Item ``i`` is ``(corners, label)`` as float32 tensors built from row ``i``
    of ``<split>_data.npy`` and ``<split>_label.npy``. If those files cannot be
    loaded, ``save_dataset()`` regenerates both splits and loading is retried.
    """

    def __init__(self, split="train"):
        super(BoxDataSet, self).__init__()
        assert split in ["train", "test"], "split must be train or test"
        self.split = split
        try:
            self.data, self.label = self._load(split)
        except (OSError, ValueError, EOFError):
            # Cache missing (FileNotFoundError) or unreadable/truncated: regenerate.
            # The previous bare ``except:`` also swallowed KeyboardInterrupt,
            # SystemExit and genuine programming errors.
            save_dataset()
            self.data, self.label = self._load(split)

    @staticmethod
    def _load(split):
        """Load and return the ``(data, label)`` arrays for ``split`` from DATA."""
        data = np.load(os.path.join(DATA, split+"_data.npy"))
        label = np.load(os.path.join(DATA, split+"_label.npy"))
        return data, label

    def __len__(self) -> int:
        """Number of boxes in the split (first axis of the data array)."""
        return self.data.shape[0]

    def __getitem__(self, index: int):
        """Return row ``index`` of data and label, converted to float32 tensors."""
        d = self.data[index, ...]
        l = self.label[index, ...]
        return torch.FloatTensor(d), torch.FloatTensor(l)

In [4]:
def create_network():
    """Build the per-box regressor as a stack of kernel-size-1 Conv1d layers.

    Channels go 8 -> 128 -> 512 -> 128 -> 5; every hidden Conv1d (bias-free) is
    followed by BatchNorm1d and an in-place ReLU, and the final Conv1d (with
    bias) is followed by a Sigmoid, so each of the 5 outputs lies in [0, 1].
    """
    channel_plan = [8, 128, 512, 128]
    modules = []
    for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
        modules.extend([
            nn.Conv1d(in_ch, out_ch, 1, bias=False),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(True),
        ])
    modules.append(nn.Conv1d(channel_plan[-1], 5, 1))
    modules.append(nn.Sigmoid())
    return nn.Sequential(*modules)

def parse_pred(pred:torch.Tensor):
    """Rescale raw network outputs in [0, 1] to box parameters.

    Channels on the last axis of ``pred`` are mapped as:
      0 -> x     in [-X_MAX, X_MAX]
      1 -> y     in [-Y_MAX, Y_MAX]
      2 -> w     in [1 - SCALE, 1 + SCALE]
      3 -> h     in [1 - SCALE, 1 + SCALE]
      4 -> alpha in [0, pi]
    Returns a tensor of shape (..., 5) stacking the five rescaled channels.
    """
    def _centred(channel, half_range):
        # Map [0, 1] onto [-half_range, half_range].
        return (pred[..., channel] - 0.5) * 2 * half_range

    fields = [
        _centred(0, X_MAX),
        _centred(1, Y_MAX),
        _centred(2, SCALE) + 1,
        _centred(3, SCALE) + 1,
        pred[..., 4] * np.pi,
    ]
    return torch.stack(fields, dim=-1)

In [5]:
def _compute_iou_loss(pred, label, loss_type, enclosing_type):
    """Dispatch to cal_giou / cal_diou and return its ``(loss, iou)`` pair."""
    if loss_type == "giou":
        return cal_giou(pred, label, enclosing_type)
    if loss_type == "diou":
        return cal_diou(pred, label, enclosing_type)
    raise ValueError("unknown loss type")


def _prepare_batch(box, label, device):
    """Move a loader batch to ``device`` and reshape it for the 1x1-conv network.

    box:   (B*N, 4, 2) -> (B, N, 8) -> (B, 8, N)
    label: (B*N, 5)    -> (B, N, 5)
    """
    box = box.to(device).view([BATCH_SIZE, -1, 4*2]).transpose(1, 2)
    label = label.to(device).view([BATCH_SIZE, -1, 5])
    return box, label


def _predict(net, box):
    """Run ``net`` on (B, 8, N) input and decode its (B, 5, N) output to (B, N, 5)."""
    return parse_pred(net(box).transpose(1, 2))


def _mean_positive_iou(iou):
    """Mean IoU over entries with IoU > 0 (entries with IoU <= 0 are excluded)."""
    iou_mask = (iou > 0).float()
    return torch.sum(iou) / (torch.sum(iou_mask) + 1e-8)


def main(loss_type:str="giou", enclosing_type:str="aligned", device:str="cuda:0"):
    """Train the box-regression network with an oriented IoU loss, validating each epoch.

    Args:
        loss_type: "giou" (uses cal_giou) or "diou" (uses cal_diou).
        enclosing_type: forwarded unchanged to the selected loss function.
        device: torch device used for the network and every batch.

    Raises:
        ValueError: if ``loss_type`` is neither "giou" nor "diou".
    """
    # Fail fast. Previously the ValueError was constructed but never raised, so a
    # bad loss_type only surfaced as a TypeError from torch.mean(None), and only
    # after the (slow) dataset generation.
    if loss_type not in ("giou", "diou"):
        raise ValueError("unknown loss type")

    ds_train = BoxDataSet("train")
    ds_test = BoxDataSet("test")
    ld_train = DataLoader(ds_train, BATCH_SIZE * N_DATA, shuffle=True, num_workers=4)
    ld_test = DataLoader(ds_test, BATCH_SIZE * N_DATA, shuffle=False, num_workers=4)

    net = create_network()
    net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    num_batch = len(ds_train)//(BATCH_SIZE*N_DATA)

    for epoch in range(1, NUM_EPOCH+1):
        # train
        net.train()
        for i, (box, label) in enumerate(ld_train, 1):
            box, label = _prepare_batch(box, label, device)

            optimizer.zero_grad()
            pred = _predict(net, box)
            iou_loss, iou = _compute_iou_loss(pred, label, loss_type, enclosing_type)
            iou_loss = torch.mean(iou_loss)
            iou_loss.backward()
            optimizer.step()

            if i%10 == 0:
                mean_iou = _mean_positive_iou(iou)
                print("[Epoch %d: %d/%d] train loss: %.4f  mean_iou: %.4f"
                    %(epoch, i, num_batch, iou_loss.detach().cpu().item(), mean_iou.detach().cpu().item()))
        lr_scheduler.step()

        # validate
        net.eval()
        aver_loss = 0
        aver_mean_iou = 0
        n_iter = 0
        with torch.no_grad():
            for box, label in ld_test:
                box, label = _prepare_batch(box, label, device)
                pred = _predict(net, box)
                iou_loss, iou = _compute_iou_loss(pred, label, loss_type, enclosing_type)
                aver_loss += torch.mean(iou_loss).cpu().item()
                aver_mean_iou += _mean_positive_iou(iou).cpu().item()
                n_iter += 1
        print("... validate epoch %d ..."%epoch)
        # Average over the batches actually seen (guard against an empty loader).
        n_iter = max(n_iter, 1)
        print("average loss: %.4f"%(aver_loss/n_iter))
        print("average iou: %.4f"%(aver_mean_iou/n_iter))
        print("..............................")

In [6]:
# Notebook entry point: parse_args([]) ignores the kernel's argv and uses the
# defaults; pass e.g. ["--loss", "giou"] to try other settings.
parser = argparse.ArgumentParser()
# ``choices`` makes argparse reject unsupported values with a clear message
# instead of failing deep inside main().
parser.add_argument("--loss", type=str, default="diou", choices=["diou", "giou"],
    help="type of loss function. support: diou or giou. [default: diou]")
parser.add_argument("--enclosing", type=str, default="smallest",
    choices=["aligned", "pca", "smallest"],
    help="type of enclosing box. support: aligned (axis-aligned) or pca (rotated) or smallest (rotated). [default: smallest]")
flags = parser.parse_args([])
main(flags.loss, flags.enclosing)

... generating 819200 boxes, please wait...
... generating 81920 boxes, please wait...
data saved in:  ./data
[Epoch 1: 10/200] train loss: 0.6570  mean_iou: 0.3814
[Epoch 1: 20/200] train loss: 0.5029  mean_iou: 0.5106
[Epoch 1: 30/200] train loss: 0.4490  mean_iou: 0.5581
[Epoch 1: 40/200] train loss: 0.4389  mean_iou: 0.5680
[Epoch 1: 50/200] train loss: 0.4135  mean_iou: 0.5910
[Epoch 1: 60/200] train loss: 0.4179  mean_iou: 0.5879
[Epoch 1: 70/200] train loss: 0.4249  mean_iou: 0.5827
[Epoch 1: 80/200] train loss: 0.4280  mean_iou: 0.5803
[Epoch 1: 90/200] train loss: 0.3796  mean_iou: 0.6244
[Epoch 1: 100/200] train loss: 0.3888  mean_iou: 0.6174
[Epoch 1: 110/200] train loss: 0.3506  mean_iou: 0.6527
[Epoch 1: 120/200] train loss: 0.3598  mean_iou: 0.6450
[Epoch 1: 130/200] train loss: 0.4119  mean_iou: 0.5988
[Epoch 1: 140/200] train loss: 0.3848  mean_iou: 0.6230
[Epoch 1: 150/200] train loss: 0.4175  mean_iou: 0.5945
[Epoch 1: 160/200] train loss: 0.3684  mean_iou: 0.6382
[Ep

[Epoch 7: 170/200] train loss: 0.3723  mean_iou: 0.6414
[Epoch 7: 180/200] train loss: 0.2888  mean_iou: 0.7181
[Epoch 7: 190/200] train loss: 0.2551  mean_iou: 0.7504
[Epoch 7: 200/200] train loss: 0.3012  mean_iou: 0.7067
... validate epoch 7 ...
average loss: 0.3280
average iou: 0.6830
..............................
[Epoch 8: 10/200] train loss: 0.5161  mean_iou: 0.5107
[Epoch 8: 20/200] train loss: 0.3308  mean_iou: 0.6788
[Epoch 8: 30/200] train loss: 0.4051  mean_iou: 0.6104
[Epoch 8: 40/200] train loss: 0.2241  mean_iou: 0.7799
[Epoch 8: 50/200] train loss: 0.2678  mean_iou: 0.7384
[Epoch 8: 60/200] train loss: 0.2064  mean_iou: 0.7969
[Epoch 8: 70/200] train loss: 0.2512  mean_iou: 0.7544
[Epoch 8: 80/200] train loss: 0.3026  mean_iou: 0.7062
[Epoch 8: 90/200] train loss: 0.2884  mean_iou: 0.7186
[Epoch 8: 100/200] train loss: 0.3721  mean_iou: 0.6408
[Epoch 8: 110/200] train loss: 0.3391  mean_iou: 0.6710
[Epoch 8: 120/200] train loss: 0.2536  mean_iou: 0.7518
[Epoch 8: 130/20

[Epoch 14: 110/200] train loss: 0.1465  mean_iou: 0.8549
[Epoch 14: 120/200] train loss: 0.1665  mean_iou: 0.8356
[Epoch 14: 130/200] train loss: 0.1784  mean_iou: 0.8240
[Epoch 14: 140/200] train loss: 0.1548  mean_iou: 0.8468
[Epoch 14: 150/200] train loss: 0.1288  mean_iou: 0.8722
[Epoch 14: 160/200] train loss: 0.1558  mean_iou: 0.8459
[Epoch 14: 170/200] train loss: 0.1769  mean_iou: 0.8254
[Epoch 14: 180/200] train loss: 0.1891  mean_iou: 0.8136
[Epoch 14: 190/200] train loss: 0.1311  mean_iou: 0.8699
[Epoch 14: 200/200] train loss: 0.1438  mean_iou: 0.8576
... validate epoch 14 ...
average loss: 0.1228
average iou: 0.8781
..............................
[Epoch 15: 10/200] train loss: 0.1907  mean_iou: 0.8121
[Epoch 15: 20/200] train loss: 0.1352  mean_iou: 0.8660
[Epoch 15: 30/200] train loss: 0.1419  mean_iou: 0.8594
[Epoch 15: 40/200] train loss: 0.1268  mean_iou: 0.8741
[Epoch 15: 50/200] train loss: 0.1733  mean_iou: 0.8287
[Epoch 15: 60/200] train loss: 0.1590  mean_iou: 0.8