In [1]:
import os

from torch.autograd import Variable
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import albumentations as albu
import numpy as np
import cv2
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

import config as cfg



  from .autonotebook import tqdm as notebook_tqdm


In [2]:

DEVICE='cuda'

dummy_input = torch.randn(1, 3, 800, 1600).to(DEVICE)

In [3]:
# resnet50 = torch.load('./Models/Deeplabv3P_resnet50_batch16.pth')
# resnet50.eval()
# torch.onnx.export(resnet50.module, dummy_input, "SMP_resnet50.onnx", verbose=True, )

In [4]:
models = os.listdir(cfg.MODEL_DIR)
models

['Deeplabv3P_resnet152_batch8.pth',
 'Deeplabv3P_resnext50_batch12.pth',
 'Deeplabv3P_resnext101_batch8.pth',
 'Deeplabv3P_resnet101_batch12.pth',
 'Deeplabv3P_se_resnet50_batch12.pth',
 'Deeplabv3P_resnet50_batch16.pth']

In [5]:
for model_name in models:
    smp_model = torch.load(os.path.join(cfg.MODEL_DIR, model_name))
    smp_model.eval()
    
    onnx_name = model_name.split(".")[-1]
    torch.onnx.export(smp_model.module, dummy_input, onnx_name, verbose=True)
    print(f"{onnx_name} DONE!")
    



In [None]:
# resnet50 = torch.load('./Models/Deeplabv3P_resnet50_batch16.pth')
# resnet50.eval()
# resnet101 = torch.load('./Models/Deeplabv3P_resnet101_batch12.pth')
# resnet101.eval()
# resnet152 = torch.load('./Models/Deeplabv3P_resnet152_batch8.pth')
# resnet152.eval()
# resnext50 = torch.load('./Models/Deeplabv3P_resnext50_batch12.pth')
# resnext50.eval()
# resnext101 = torch.load('./Models/Deeplabv3P_resnext101_batch8.pth')
# resnext101.eval()
# se_resnet50 = torch.load('./Models/Deeplabv3P_se_resnet50_batch12.pth')
# se_resnet50.eval()

In [None]:
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

In [None]:
def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [None]:
class InferenceDataset(Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    
    """
    
    CLASSES = ['bg', 'stas']
    
    def __init__(
            self, 
            images_dir,
            preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        
        self.preprocessing = preprocessing
    
    def __getitem__(self, i):
        
        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if image.shape != (800, 800, 3):
            image = cv2.resize(image, (1600, 800), interpolation=cv2.INTER_LANCZOS4)

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image)
            image = sample['image']
            
        return image
        
    def __len__(self):
        return len(self.ids)

In [None]:
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
inference_dataset = InferenceDataset(
    cfg.INFERENCE_IMAGE_DIR, 
    preprocessing=get_preprocessing(preprocessing_fn)
)

In [None]:
for i in range(len(inference_dataset)):
    name = os.path.basename(inference_dataset.images_fps[i])
    image_vis = inference_dataset[i][0]
    image = inference_dataset[i]
    print(image.shape)
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    with torch.no_grad():
        pr_mask1 = resnet50(x_tensor)
        pr_mask2 = resnet101(x_tensor)
        pr_mask3 = resnet152(x_tensor)
        pr_mask4 = resnext50(x_tensor)
        pr_mask5 = resnext101(x_tensor)
        pr_mask6 = se_resnet50(x_tensor)
    pr_mask = (pr_mask1 + pr_mask2 + pr_mask3 + pr_mask4 + pr_mask5 + pr_mask6) / 6
    pr_mask += 0.07
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())
    zeros = np.zeros((800, 1600, 3))
    zeros[...,0] = pr_mask
    zeros[...,1] = pr_mask
    zeros[...,2] = pr_mask
    zeros = cv2.resize(zeros, (1716, 942))
    plt.imsave(os.path.join(cfg.ENSEMBLE_PRED_DIR, name.replace('.jpg','.png')), zeros)