In [1]:
import torch
from unet_vanilla import UNet, RoadSegmentData
import albumentations as A
from albumentations.pytorch import ToTensorV2
import glob, os
from torch.utils.data import Dataset, DataLoader
import torchvision
import cv2

In [2]:
class RoadSegmentEvalAugmentData(Dataset):
    """Test-time-augmentation (TTA) dataset for road segmentation.

    Every source image is exposed ``augmentations_per_image`` times (16 by
    default). Item ``idx`` maps to image ``idx // augmentations_per_image`` and
    is rotated by ``(idx % augmentations_per_image) % 4`` quarter turns
    (0: none, 1: 90 deg clockwise, 2: 180 deg, 3: 90 deg counter-clockwise)
    before the optional ``transform`` is applied. With the default of 16, each
    rotation appears 4 times per image, so a random transform yields several
    differently jittered copies of every rotation.

    Args:
        image_names: File names (not paths) shared by the image and mask dirs.
        image_path: Directory containing the input images (read as BGR, then
            converted to RGB).
        mask_path: Directory containing the masks (read as grayscale, scaled
            to [0, 1]).
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys.
        augmentations_per_image: Number of dataset items generated per image.
    """

    def __init__(self, image_names, image_path, mask_path, transform=None,
                 augmentations_per_image=16):
        self.image_names = image_names
        self.image_path = image_path
        self.mask_path = mask_path
        self.transform = transform
        self.augmentations_per_image = augmentations_per_image
        self.num_data = len(self.image_names) * augmentations_per_image

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for augmented item ``idx``.

        Without a transform, ``image`` is an RGB ``(H, W, 3)`` array and
        ``mask`` a float ``(H, W, 1)`` array; both are rotated together, so
        ``H``/``W`` swap for the 90-degree rotations.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        image_filename = self.image_names[idx // self.augmentations_per_image]

        image_file = os.path.join(self.image_path, image_filename)
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_file = os.path.join(self.mask_path, image_filename)
        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_file}")
        mask = mask / 255

        # Rotation index 0 means "no rotation". Test the index, not the cv2
        # code: cv2.ROTATE_90_CLOCKWISE is itself 0.
        rotate_num = (idx % self.augmentations_per_image) % 4
        if rotate_num != 0:
            rcode = (cv2.ROTATE_90_CLOCKWISE,
                     cv2.ROTATE_180,
                     cv2.ROTATE_90_COUNTERCLOCKWISE)[rotate_num - 1]
            image = cv2.rotate(image, rcode)
            mask = cv2.rotate(mask, rcode)

        # Add the channel axis only after rotating: OpenCV hands single-channel
        # results back as 2-D arrays, so an (H, W, 1) mask would lose its
        # channel axis on rotated items only.
        mask = mask[..., None]

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        return image, mask

In [3]:
# Test-time augmentation transform. When only_rotation is False every dataset
# item also gets a random RGB shift and brightness/contrast jitter, so the
# repeated copies of each rotation produced by the dataset differ.
only_rotation = False

if only_rotation:
    img_transform = A.Compose([ToTensorV2(transpose_mask=True)])
else:
    img_transform = A.Compose(
        [
            A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.75),
            ToTensorV2(transpose_mask=True),
        ]
    )

image_path = './Vanilla Dataset/test/images'
# NOTE(review): this points at the images directory, so the dataset's "masks"
# are just the inputs read as grayscale. Masks are discarded during inference,
# but confirm whether a real test-mask directory was intended.
mask_path = './Vanilla Dataset/test/images'

# Sort once, here: glob order is arbitrary, and later cells pair these names
# with the (non-shuffled) dataloader, which must see the same order.
# os.path.basename handles both '/' and '\\' separators on Windows.
image_names = sorted(
    os.path.basename(path) for path in glob.glob(os.path.join(image_path, '*'))
)
assert image_names, f'No test images found in {image_path}'

test_data = RoadSegmentEvalAugmentData(image_names, image_path, mask_path, img_transform)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

In [4]:
torch.cuda.empty_cache()
# NOTE(review): torch.load unpickles a whole model object (the UNet class must
# be importable, hence the unet_vanilla import). Only load trusted files.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects full-model
# pickles; pass weights_only=False there if this call fails.
model = torch.load('./gmap_unet_seeded_ftm')

save_path = './fin_unet_ftm_eta_rpa'
# exist_ok keeps the cell re-runnable and also creates any sub-directory that
# is missing when save_path already exists (e.g. from a partial earlier run).
for subdir in ('inputs', 'augmented', 'final'):
    os.makedirs(os.path.join(save_path, subdir), exist_ok=True)

In [5]:
# Move the network to the GPU and switch layers such as BatchNorm to inference
# mode. Both calls return the module itself, so the cell still displays the
# model summary as its output.
model.cuda().eval()

UNet(
  (enc_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1):

In [7]:
# One output name per augmented copy, built from the dataset's own file list:
# that is the order the non-shuffled dataloader yields, so each prediction is
# saved under the name of the image it came from.
augs_per_image = 16  # copies per image produced by RoadSegmentEvalAugmentData
image_names_augmented = [
    f'{img.split(".")[0]}_{i}.png'
    for img in test_data.image_names
    for i in range(augs_per_image)
]
# zip() below would silently drop items if the two sequences disagreed.
assert len(image_names_augmented) == len(test_data)

with torch.no_grad():
    for img, (x, _) in zip(image_names_augmented, test_dataloader):
        # The input tensor is still on a 0-255 scale; save_image expects [0, 1].
        torchvision.utils.save_image(x[0] / 255, save_path + f'/inputs/{img}')
        y_pred = model(x.float().cuda())
        # NOTE(review): save_image clamps to [0, 1]; this assumes the model
        # output is already a probability map — confirm in unet_vanilla.
        torchvision.utils.save_image(y_pred[0], save_path + f'/augmented/{img}')

In [8]:
# Undo each copy's rotation and average the 16 predictions into one mask.
# Copy i was rotated with entry i % 4 of this table; entry (4 - i % 4) % 4 is
# the inverse rotation (CW <-> CCW, 180 <-> 180).
rotate_codes = [None, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE]
for img in image_names:
    stem = img.split(".")[0]
    restored = []
    for i in range(16):
        pred_file = save_path + f'/augmented/{stem}_{i}.png'
        pred = cv2.imread(pred_file, cv2.IMREAD_GRAYSCALE)
        if pred is None:
            # cv2.imread returns None rather than raising on a missing file.
            raise FileNotFoundError(f'Missing augmented prediction: {pred_file}')
        pred = pred / 255
        inverse = (4 - (i % 4)) % 4
        if inverse != 0:
            pred = cv2.rotate(pred, rotate_codes[inverse])
        restored.append(pred)
    # Built-in sum starts from 0 and returns a new array, so no per-copy
    # prediction is modified in place.
    out_mask = torch.Tensor(sum(restored) / len(restored))
    torchvision.utils.save_image(out_mask, save_path + f'/final/{img}')