In [1]:
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

import sys

from tqdm.notebook import tqdm_notebook

In [2]:
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

bulkDataPath = os.path.join('D:\\', 'CMPT_489_Bulk', 'Data')
print(bulkDataPath)
print("Path exists:", os.path.isdir(bulkDataPath))

D:\CMPT_489_Bulk\Data
Path exists: True


In [3]:
# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir):

    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    pb_np = np.array(imo)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')

In [4]:
# --------- 1. get image path and name ---------
model_name='u2netp'#u2netp



image_dir = os.path.join(bulkDataPath, 'example_matching_data', 'val\\')
prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_vizwiz_example_results' + os.sep)
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

img_name_list = glob.glob(image_dir + os.sep + '*')
print(img_name_list[0:5])

['D:\\CMPT_489_Bulk\\Data\\example_matching_data\\val\\VizWiz_train_00000001.png', 'D:\\CMPT_489_Bulk\\Data\\example_matching_data\\val\\VizWiz_train_00000041.png', 'D:\\CMPT_489_Bulk\\Data\\example_matching_data\\val\\VizWiz_train_00000080.png']


In [5]:
# --------- 2. dataloader ---------
#1. dataloader
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                    lbl_name_list = [],
                                    transform=transforms.Compose([RescaleT(320),
                                                                  ToTensorLab(flag=0)])
                                    )
test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=1)

In [6]:
# --------- 3. model define ---------
if(model_name=='u2net'):
    print("...load U2NET---173.6 MB")
    net = U2NET(3,1)
elif(model_name=='u2netp'):
    print("...load U2NEP---4.7 MB")
    net = U2NETP(3,1)

if torch.cuda.is_available():
    print("gpu")
    net.load_state_dict(torch.load(model_dir))
    net.cuda()
else:
    print("cpu")
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()

...load U2NEP---4.7 MB
gpu


U2NETP(
  (stage1): RSU7(
    (rebnconvin): REBNCONV(
      (conv_s1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (rebnconv1): REBNCONV(
      (conv_s1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv2): REBNCONV(
      (conv_s1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv3): REBNCONV(
      (conv_s1): Conv2d(16, 16, 

In [7]:
# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(tqdm_notebook(test_salobj_dataloader)):

    #print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

    inputs_test = data_test['image']
    inputs_test = inputs_test.type(torch.FloatTensor)

    if torch.cuda.is_available():
        inputs_test = Variable(inputs_test.cuda())
    else:
        inputs_test = Variable(inputs_test)

    d1,d2,d3,d4,d5,d6,d7= net(inputs_test)

    # normalization
    pred = d1[:,0,:,:]
    pred = normPRED(pred)

    # save results to test_results folder
    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir, exist_ok=True)
    save_output(img_name_list[i_test],pred,prediction_dir)

    del d1,d2,d3,d4,d5,d6,d7

  0%|          | 0/3 [00:00<?, ?it/s]

  imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
