In [1]:
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

import sys

from tqdm.notebook import tqdm_notebook

In [2]:
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

# Root folder that holds the datasets and receives the prediction outputs.
# Set the BULK_DATA_PATH environment variable to point somewhere else;
# when it is unset the original machine-specific location is used.
bulkDataPath = os.environ.get(
    'BULK_DATA_PATH',
    os.path.join('D:\\', 'CMPT_489_Bulk', 'Data'),
)
print(bulkDataPath)
print("Path exists:", os.path.isdir(bulkDataPath))

D:\CMPT_489_Bulk\Data
Path exists: True


In [3]:
def normPRED(d):
    """Min-max normalize a predicted SOD probability map.

    Args:
        d: torch tensor; the minimum and maximum are taken over all of its
            elements.

    Returns:
        A tensor shaped like ``d`` holding ``(d - min) / (max - min)``.
        When every element of ``d`` is equal (``max == min``) the result is
        all zeros instead of the NaNs a 0/0 division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)

    span = ma - mi
    # A constant map has zero range; dividing by 1 instead keeps the
    # (already zero) numerator and avoids NaN output.
    span = torch.where(span == 0, torch.ones_like(span), span)

    return (d - mi) / span

def save_output(image_name,pred,d_dir):
    """Save a prediction as a PNG resized to its source image's dimensions.

    Args:
        image_name: path of the source image. The file is read to obtain the
            target height/width, and its base name (minus the final
            extension) becomes the output file name.
        pred: torch tensor with the prediction; it is squeezed, moved to the
            CPU and multiplied by 255 before being turned into an image.
            NOTE(review): this presumably expects values in [0, 1] and a
            tensor that squeezes to 2-D - confirm against the caller.
        d_dir: directory in which ``<stem>.png`` is written.
    """
    predict_np = pred.squeeze().detach().cpu().numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')

    # The source image is loaded only for its shape (rows, cols, ...).
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    # splitext strips only the last extension ("a.b.jpg" -> "a.b") and, unlike
    # manual splitting on ".", also works for names that have no extension.
    img_name = os.path.basename(image_name)
    imidx = os.path.splitext(img_name)[0]

    # os.path.join works whether or not d_dir ends with a path separator.
    imo.save(os.path.join(d_dir, imidx+'.png'))

In [16]:

def test(mn):
    """Run inference with one saved checkpoint over the validation images.

    The checkpoint ``<cwd>/saved_models/u2netp_sub_vizwiz/<mn>.pth`` is loaded,
    every file found in ``<bulkDataPath>/sub_matching_data/val`` is passed
    through the network, and the normalized first output map of each image is
    saved under ``<bulkDataPath>/u2netp_results/<mn>/``.

    Args:
        mn: checkpoint name without the ``.pth`` extension. The text before
            its first ``_`` selects the architecture and must be ``u2net``
            or ``u2netp``.

    Raises:
        ValueError: if the checkpoint name does not start with a known
            architecture prefix.
    """
    print("--- testing "+ mn+ " ---")

    # --------- 1. get image path and name ---------
    model_name = mn

    image_dir = os.path.join(bulkDataPath, 'sub_matching_data', 'val')
    # The trailing separator is kept so that code which appends a file name
    # directly to this string still produces a valid path.
    prediction_dir = os.path.join(bulkDataPath, 'u2netp_results', model_name + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', 'u2netp_sub_vizwiz', model_name + '.pth')

    img_name_list = glob.glob(os.path.join(image_dir, '*'))

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    # shuffle=False and batch_size=1 keep batch i aligned with img_name_list[i].
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    arch = model_name.split("_")[0]
    if arch == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif arch == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,1)
    else:
        # Without this branch `net` would be unbound and fail later with a
        # confusing UnboundLocalError.
        raise ValueError(
            "unknown model prefix %r in %r; expected 'u2net' or 'u2netp'" % (arch, model_name)
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.load_state_dict(torch.load(model_dir, map_location=device))
    net.to(device)
    net.eval()

    # --------- 4. inference for each image ---------
    # Create the output folder once instead of checking on every iteration.
    os.makedirs(prediction_dir, exist_ok=True)

    # Gradients are never needed here; disabling autograd avoids building a
    # graph for every forward pass.
    with torch.no_grad():
        for i_test, data_test in enumerate(tqdm_notebook(test_salobj_dataloader)):

            inputs_test = data_test['image'].to(device=device, dtype=torch.float32)

            d1,d2,d3,d4,d5,d6,d7 = net(inputs_test)

            # Only channel 0 of the first output is used as the prediction.
            pred = normPRED(d1[:,0,:,:])

            # save results to the prediction folder
            save_output(img_name_list[i_test],pred,prediction_dir)

            del d1,d2,d3,d4,d5,d6,d7



In [17]:

# Evaluate every '.pth' checkpoint found under the saved-models folder;
# any other file is only reported by name.
checkpoint_root = os.path.join(os.getcwd(), 'saved_models', 'u2netp_sub_vizwiz' + os.sep)
for root, dirs, files in os.walk(checkpoint_root):
    for filename in files:
        if filename.endswith('.pth'):
            # Strip the 4-character '.pth' suffix to get the model name.
            test(filename[:-4])
        else:
            print(filename)



--- testing u2netp_sub_vizwiz_bce_itr_10000_train_2.161453_tar_0.301635 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_1000_train_4.088663_tar_0.581359 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_11000_train_2.107132_tar_0.293022 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_12000_train_2.051164_tar_0.284970 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_13000_train_1.957236_tar_0.271083 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_14000_train_1.882877_tar_0.259915 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_15000_train_1.824028_tar_0.251176 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_16000_train_1.772376_tar_0.243316 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_17000_train_1.707042_tar_0.233546 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_18000_train_1.654696_tar_0.225983 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_19000_train_1.626005_tar_0.221538 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_20000_train_1.556143_tar_0.211508 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_2000_train_3.497421_tar_0.494223 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_21000_train_1.545365_tar_0.209725 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_22000_train_1.489582_tar_0.201485 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_23000_train_1.446585_tar_0.195125 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_3000_train_3.140997_tar_0.441988 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_4000_train_2.880650_tar_0.404842 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_5000_train_2.715347_tar_0.381688 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_6000_train_2.602593_tar_0.365563 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_7000_train_2.455525_tar_0.344322 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_8000_train_2.357673_tar_0.330342 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]

--- testing u2netp_sub_vizwiz_bce_itr_9000_train_2.256964_tar_0.315464 ---
...load U2NEP---4.7 MB


  0%|          | 0/2500 [00:00<?, ?it/s]