In [None]:
import numpy as np
import torch
import os
import pickle
import matplotlib.pyplot as plt
# %matplotlib inline
# plt.rcParams['figure.figsize'] = (20, 20)
# plt.rcParams['image.interpolation'] = 'bilinear'

import sys
sys.path.append('../train/')

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
import torch.nn as nn

import collections
import numbers
import random
import math
from PIL import Image, ImageOps, ImageEnhance
import time
from torch.utils.data import Dataset

from networks.SegUNet_new_version import SegUNet_new_version
import tool
from tqdm import tqdm

# Filename suffixes (the last '_'-separated token of the file stem) whose images
# are mirrored left-right before inference; their predicted masks are mirrored
# back before saving. Equivalent to ['16', '15', ..., '10'].
flip_index = [str(suffix) for suffix in range(16, 9, -1)]

In [None]:
# Model input/output sizes.
NUM_CHANNELS = 3        # input image channels (RGB)
NUM_CLASSES = 2         # output channels; the per-pixel argmax over them is the mask

# Batching and sliding-window tiling.
BATCH_SIZE = 8          # number of full images processed together
IMAGE_SIZE = 512        # side length of each square tile passed to the model
STRIDE = 256            # step between consecutive tile origins (tiles overlap)

# Full-resolution size that every test image is expected to have.
H = 1280
W = 1918

# Locations of the trained checkpoint and of the predicted masks.
weight_path = '../_weights/SegUNet_new_version-fold0-0.00497.pth'
test_mask_path = '../../data/test_masks/SegUNet_new_version/'

In [None]:
def load_model(filename, model, map_location=None, strict=True):
    """Overwrite ``model``'s parameters/buffers in place from a checkpoint file.

    The checkpoint must be a mapping that stores the state dict under the
    ``'model_state'`` key; no other key is read.

    Args:
        filename: path of the checkpoint (as written by ``torch.save``).
        model: the ``nn.Module`` to load the weights into.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` or
            ``'cuda:0'`` to remap tensors that were saved on another device.
            The default ``None`` keeps the tensors' saved devices (the
            original behaviour).
        strict: forwarded to ``model.load_state_dict``. Defaults to ``True``,
            which requires the keys to match exactly (the original behaviour).

    Raises:
        KeyError: if the checkpoint has no ``'model_state'`` entry.

    NOTE: ``torch.load`` is pickle-based; only load checkpoints from trusted
    sources.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state'], strict=strict)

In [None]:
# Build the segmentation network on the GPU in inference mode (.cuda() and
# .eval() both return the module itself), then restore the trained weights.
model = SegUNet_new_version(num_classes=NUM_CLASSES).cuda().eval()
load_model(weight_path, model)

In [None]:
# Directory holding the test images to segment.
test_path = '../../data/images/test/'

# Ensure the output directory for predicted masks exists. exist_ok=True is the
# idiomatic, race-free replacement for an exists()-then-makedirs() check.
os.makedirs(test_mask_path, exist_ok=True)

In [None]:
# Deterministic processing order: every entry of the test directory, sorted by name.
test_names = sorted(os.listdir(test_path))

In [None]:
def _tile_starts(length, tile, stride):
    """Return start offsets of ``tile``-long windows stepping by ``stride`` over ``[0, length)``.

    Windows start at 0, stride, 2*stride, ... and the last one is shifted back
    so that it ends exactly at ``length``. Every window therefore lies fully
    inside the image, and together the windows cover it.
    Requires ``0 < stride <= tile <= length``.
    """
    if not 0 < stride <= tile <= length:
        raise ValueError('need 0 < stride <= tile <= length, got stride=%d tile=%d length=%d'
                         % (stride, tile, length))
    # 1 window at offset 0, plus ceil((length - tile) / stride) further steps.
    count = (length - tile + stride - 1) // stride + 1
    return [min(i * stride, length - tile) for i in range(count)]


with torch.no_grad():
    # Sliding-window inference over full-resolution test images:
    #  - each image is optionally mirrored, then normalised;
    #  - it is cut into overlapping IMAGE_SIZE x IMAGE_SIZE tiles;
    #  - the per-tile model outputs are summed into a full-size score map;
    #  - the score map's per-pixel argmax is saved as a 1-bit PNG mask.
    batch_size = BATCH_SIZE
    # Per-channel constants shaped (3, 1, 1) to broadcast over a (3, H, W) image.
    normalize_mean = np.array([.485, .456, .406]).reshape(3, 1, 1)
    normalize_std = np.array([.229, .224, .225]).reshape(3, 1, 1)

    # The tile grid depends only on the fixed image size, so compute it once.
    h_starts = _tile_starts(H, IMAGE_SIZE, STRIDE)
    w_starts = _tile_starts(W, IMAGE_SIZE, STRIDE)

    # Walk test_names (the sorted directory listing built in the cell above)
    # in batches. The last batch may be shorter, so no image is skipped.
    for batch_start in tqdm(range(0, len(test_names), batch_size)):
        image_batch_names = test_names[batch_start:batch_start + batch_size]
        n_batch = len(image_batch_names)
        images = np.zeros((n_batch, 3, H, W), dtype='float32')
        test_masks = np.zeros((n_batch, NUM_CLASSES, H, W), dtype='float32')
        ifflip = [False] * n_batch
        mask_names = [os.path.splitext(name)[0] + '.png' for name in image_batch_names]

        for idx, image_name in enumerate(image_batch_names):
            # The context manager closes the file; convert('RGB') guarantees
            # 3 channels even for grayscale/RGBA inputs.
            with Image.open(os.path.join(test_path, image_name)) as raw_image:
                image = raw_image.convert('RGB')
            angle = os.path.splitext(image_name)[0].split('_')[-1]
            if angle in flip_index:
                ifflip[idx] = True
                image = ImageOps.mirror(image)

            # float64 arithmetic, then stored into the float32 batch array.
            pixels = np.asarray(image, dtype=np.float64).transpose(2, 0, 1) / 255
            images[idx] = (pixels - normalize_mean) / normalize_std

        for h_start in h_starts:
            h_end = h_start + IMAGE_SIZE
            for w_start in w_starts:
                w_end = w_start + IMAGE_SIZE
                input_tensor = torch.from_numpy(images[:, :, h_start:h_end, w_start:w_end]).cuda()
                outputs = model(input_tensor).cpu().numpy()
                # Overlapping tiles are summed, so pixels covered by several
                # tiles combine all of their predictions.
                test_masks[:, :, h_start:h_end, w_start:w_end] += outputs

        pred_labels = np.argmax(test_masks, axis=1).astype('uint8')
        for idx in range(n_batch):
            # Label 1 -> 255 (white), 0 -> 0 (black); assumes a binary model.
            output_PIL = Image.fromarray(pred_labels[idx] * 255).convert('1')
            if ifflip[idx]:
                # Undo the input mirroring so the mask aligns with the original image.
                output_PIL = ImageOps.mirror(output_PIL)
            output_PIL.save(os.path.join(test_mask_path, mask_names[idx]))
