In [26]:
%matplotlib inline

import numpy as np
import torch
import matplotlib.pyplot as plt
import sys
sys.path.append('./')
from dataset import tuSimpleDataset
from torch.utils.data import DataLoader
from models.segnet import SegNet
from models.enet import ENet
from models.resnet38 import ResNet38
from models.bisenet import BiSeNet
from models.enet_k import ENet as ENet_k
from models.resnet38_k import ResNet38 as ResNet38_k
import torchvision
from scipy import ndimage as ndi
from sklearn.cluster import DBSCAN

INPUT_CHANNELS = 3
OUTPUT_CHANNELS = 2
BATCH_SIZE = 5
SIZE = [224, 224]

In [27]:
# refer from : https://github.com/nyoki-mtl/pytorch-discriminative-loss/blob/master/src/utils.py
def coloring(mask):
    ins_color_img = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    n_ins = len(np.unique(mask)) - 1
    colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, n_ins)]
    for i in range(n_ins):
        ins_color_img[mask == i + 1] =\
            (np.array(colors[i][:3]) * 255).astype(np.uint8)
    return ins_color_img

def gen_instance_mask(sem_pred, ins_pred, n_obj):
    embeddings = ins_pred[:, sem_pred].transpose(1, 0)
#     clustering = KMeans(n_obj).fit(embeddings)
    clustering = DBSCAN(eps=0.05).fit(embeddings)
    labels = clustering.labels_

    instance_mask = np.zeros_like(sem_pred, dtype=np.uint8)
    for i in range(n_obj):
        lbl = np.zeros_like(labels, dtype=np.uint8)
        lbl[labels == i] = i + 1
        instance_mask[sem_pred] += lbl

    return instance_mask

In [28]:
def expand_mask(sem_labels, a=3):
    lst = []
    for i in list(range(-a, a)):
        for j in list(range(-a, a)):
            if i**2+j**2<=a**2:
                lst.append(torchvision.transforms.functional.affine(sem_labels ,translate=(i,j),angle=0,scale=1,shear=0))
    msk = lst[0]
    for msks in lst:
        msk += msks
    msk[msk>1]=1
    return msk
def eval_loss(filename = 'enet_model_best', shuffle=False, model = 'enet', data_part = 'test', expand_dim=3, split_train=False, data_num=100, soft=True, seed=None, loss_func='MSE'):
    if data_part=='train':
        test_path = '../TUSimple/train_set'
    else:
        test_path = '../TUSimple/test_set'
    # MODEL_PATH = '../model_best_enet.pth'
    MODEL_PATH = f'results/{filename}.pth'
    test_dataset = tuSimpleDataset(test_path, size=SIZE, train=True)
    if data_part=='train' and split_train==True:
      train_len = int(len(test_dataset)*0.8)
      if seed is not None:
         torch.manual_seed(seed)
      _, test_dataset = torch.utils.data.random_split(test_dataset, [train_len, len(test_dataset)-train_len])
    test_dataset = torch.utils.data.Subset(test_dataset, list(range(data_num)))
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=8)
    if model == 'enet': 
       model = ENet(input_ch=INPUT_CHANNELS, output_ch=OUTPUT_CHANNELS).cuda()
    if model == 'segnet':
       model = SegNet(input_ch=INPUT_CHANNELS, output_ch=OUTPUT_CHANNELS).cuda()
    if model == 'enet_k': 
       model = ENet_k(input_ch=INPUT_CHANNELS, output_ch=OUTPUT_CHANNELS).cuda()
    if model == 'resnet38': 
       model = ResNet38().cuda()
    if model == 'bisenet': 
       model = BiSeNet(32, 'resnet18').cuda()
    if model == 'resnet38_k': 
       model = ResNet38_k().cuda()

    model.load_state_dict(torch.load(MODEL_PATH))
    model.eval()
    losses = []
    for imgs, sem_labels, ins_labels in test_dataloader:
    #imgs, sem_labels, ins_labels = next(iter(test_dataloader))
      sem_labels_exp = expand_mask(sem_labels, expand_dim)
      
      input_tensor = torch.autograd.Variable(imgs).cuda()
      sem_pred_, ins_pred_ = model(input_tensor)
      images = input_tensor.permute(0,2,3,1).contiguous().cpu().data.numpy()
      images = np.array(images, dtype=np.uint8)
      sem_pred = sem_pred_[:,1,:,:].cpu().data.numpy()
      #  print(sem_pred, np.max(sem_pred))
      ins_pred = ins_pred_.cpu().data.numpy()
      p_sem_pred = []
      for sp in sem_pred:
         p_sem_pred.append(ndi.morphology.binary_fill_holes(sp > 0.5))
      p_sem_pred = torch.tensor(p_sem_pred).float()
      sem_labels_soft = p_sem_pred*sem_labels_exp+sem_labels
      sem_labels_soft[sem_labels_soft>1]=1
      if loss_func == 'MSE':
         criterion = torch.nn.MSELoss().cuda()
      if loss_func == 'CE':
         criterion = torch.nn.CrossEntropyLoss().cuda()
         sem_labels_soft = sem_labels_soft.float()
         sem_labels = sem_labels.float()
   
      if soft:
         loss = criterion(p_sem_pred, sem_labels_soft)
      else:
         loss = criterion(p_sem_pred, sem_labels)
      losses.append(loss.item())
    print(np.mean(losses))


In [39]:
eval_loss('segnet_model_best', model='segnet', data_part='test', data_num=100, soft=False)
eval_loss('segnet_model_best', model='segnet', data_part='test', data_num=100, soft=True)

0.017807517619803547
0.012315250327810645


In [40]:
eval_loss('enet_model_best', model='enet', data_part='test', data_num=100, soft=False)
eval_loss('enet_model_best', model='enet', data_part='test', data_num=100, soft=True)

0.02105727819725871
0.017245695134624837


In [30]:
eval_loss('enet_weight_model_best', model='enet', data_part='test', data_num=100, soft=False)
eval_loss('enet_weight_model_best', model='enet', data_part='test', data_num=100, soft=True)

0.02461934005841613
0.01087292719166726


In [31]:
eval_loss('enet_k_model_best_train', model='enet_k', data_part='test', data_num=100, soft=False)
eval_loss('enet_k_model_best_train', model='enet_k', data_part='test', data_num=100, soft=True)

0.023065808322280647
0.016902901930734515


In [32]:
eval_loss('enet_k_weight_model_best_train', model='enet_k', data_part='test', data_num=100, soft=False)
eval_loss('enet_k_weight_model_best_train', model='enet_k', data_part='test', data_num=100, soft=True)

0.029070471972227098
0.011401068232953549


In [33]:
eval_loss('enet_0.1_weight_model_best_train', model='enet', data_part='test', data_num=100, soft=False)
eval_loss('enet_0.1_weight_model_best_train', model='enet', data_part='test', data_num=100, soft=True)

0.025105229578912258
0.010416334494948387


In [34]:
eval_loss('enet_k_0.1_weight_model_best_train', model='enet_k', data_part='test', data_num=100, soft=False)
eval_loss('enet_k_0.1_weight_model_best_train', model='enet_k', data_part='test', data_num=100, soft=True)

0.02277264026924968
0.012506776209920644


In [35]:
eval_loss('resnet38_model_best_train', model='resnet38', data_part='test', data_num=100, soft=False)
eval_loss('resnet38_model_best_train', model='resnet38', data_part='test', data_num=100, soft=True)

0.018891502125188708
0.014070272678509355


In [36]:
eval_loss('resnet38_weight_0.1_model_best_train', model='resnet38', data_part='test', data_num=100, soft=False)
eval_loss('resnet38_weight_0.1_model_best_train', model='resnet38', data_part='test', data_num=100, soft=True)

0.019853914296254514
0.010670639271847904


In [37]:
eval_loss('resnet38_k_model_best_train', model='resnet38_k', data_part='test', data_num=100, soft=False)
eval_loss('resnet38_k_model_best_train', model='resnet38_k', data_part='test', data_num=100, soft=True)

0.02053770739585161
0.014867267338559031


In [38]:
eval_loss('resnet38_k_weight_0.1_model_best_train', model='resnet38_k', data_part='test', data_num=100, soft=False)
eval_loss('resnet38_k_weight_0.1_model_best_train', model='resnet38_k', data_part='test', data_num=100, soft=True)

0.022987484093755485
0.010872130002826452
