In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from torchvision.models import vgg16_bn
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
import torch.cuda.amp as amp
import time
from torch_snippets import *
from torchvision import transforms



In [2]:
from UNet import UNet
from tools import label_img, blend_mask, render_mask, get_data_split, get_mean_std, validate_batch,generate_mask
from data import get_loaders,get_test_loaders, get_measure_loaders, TestDataset
from train import set_model, set_optimizer, train_batch

In [8]:
# --- Inference configuration -------------------------------------------------
# (height, width) later handed to label_img via create_masks.
# Other (height, width) pairs previously configured in this cell:
#   0.5 x (2064, 3088) = (1032, 1544), (512, 1024), (516, 772), (608, 512).
IMAGE_HEIGHT, IMAGE_WIDTH = 608, 888

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Trained checkpoint to restore.
MODEL_PATH = "/home/robothuman/Documents/wabtec_best_models/UNet_s5-30-6-49_epoch18_vwabtec_v3.0_loss0.081.pth"

# NOTE(review): BATCH_SIZE and TRANSFORM are not referenced by the cells shown;
# the loader below is built without BATCH_SIZE.
BATCH_SIZE = 1
TRANSFORM = 'original'

# Input images: every entry of DIR, in sorted (deterministic) order.
DIR = "/home/robothuman/Documents/data_sets/test/switches/cropped"
DATA = list(os.listdir(DIR))
DATA.sort()

# Output folders. Keep the trailing slash: file names may be appended
# directly to these strings.
DESTINATION_DIR = "/home/robothuman/Documents/data_sets/test/switches/labeled/"
DESTINATION_MASK_DIR = "/home/robothuman/Documents/data_sets/test/switches/mask/"

In [5]:
# Per-channel normalisation statistics. Set GET_MEAN_STD = True to recompute
# them from the images in DIR; otherwise the hard-coded values below are used.
GET_MEAN_STD = False
if GET_MEAN_STD:
    # Dedicated name so the inference `transform` defined later is never
    # confused with this measurement pipeline (plain ToTensor, no normalisation,
    # so the statistics describe pixels scaled to [0, 1]).
    measure_transform = transforms.Compose([transforms.ToTensor()])
    measure_dl = get_measure_loaders(DATA, DIR, measure_transform)
    MEAN, STD = get_mean_std(measure_dl)
else:
    MEAN = [0.4204, 0.4005, 0.4424]
    STD = [0.1340, 0.1375, 0.1472]
print(MEAN, STD)

In [6]:
# Inference-time preprocessing: normalise each channel with the dataset
# statistics (inputs assumed to be 0-255 pixel values), then convert the HWC
# numpy array into a CHW torch tensor. There is deliberately no resize step
# here; add A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH) if one is needed.
_normalize_step = A.Normalize(mean=MEAN, std=STD, max_pixel_value=255.0)
_to_tensor_step = ToTensorV2()
transform = A.Compose([_normalize_step, _to_tensor_step])

In [9]:
# Build the UNet, restore the trained weights and switch to inference mode.
# NOTE(review): UNet(1) presumably means one output channel (binary mask) —
# confirm against UNet's signature.
model = UNet(1).to(device=DEVICE)
print("model loading")
# map_location lets a checkpoint saved on GPU load when DEVICE is "cpu".
# torch.load unpickles the file: only point MODEL_PATH at trusted checkpoints.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
# eval() makes the BatchNorm layers use their running statistics (and disables
# dropout, if any). Without it, every single-image batch would be normalised
# by its own statistics. eval() returns the model, so the cell still displays it.
model.eval()

UNet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [10]:
def _save_outputs(pred, name, source_dir, destination_dir, mask_dir):
    """Write the blended overlay and the binary mask for one source image.

    The original image ``source_dir/name`` is reloaded as RGB, so the outputs
    are drawn on unnormalised pixels. Both files reuse ``name`` as their
    file name, so Pillow picks the output format from its extension.
    """
    img_path = os.path.join(source_dir, name)
    # The context manager closes the file handle; convert() returns an in-memory copy.
    with Image.open(img_path) as src:
        original_img = np.array(src.convert("RGB"))
    blended = blend_mask(original_img, pred)
    binary_mask = generate_mask(original_img, pred).astype(np.uint8)
    Image.fromarray(blended).save(os.path.join(destination_dir, name))
    Image.fromarray(binary_mask).save(os.path.join(mask_dir, name))


def create_masks(img_dl, model, img_shape, source_dir, destination_dir, DESTINATION_MASK_DIR):
    """Label every image yielded by ``img_dl`` and save an overlay and a mask for each.

    Parameters
    ----------
    img_dl : iterable of ``(images, names)`` batches (e.g. from ``get_test_loaders``).
        ``images`` is indexed per sample; ``names`` holds the matching file names.
    model : network passed through to ``label_img``.
    img_shape : ``(height, width)`` passed through to ``label_img``.
    source_dir : directory containing the original images listed in ``names``.
    destination_dir : output directory for the blended overlays (created if missing).
    DESTINATION_MASK_DIR : output directory for the binary masks (created if
        missing). The upper-case name is kept for backward compatibility with callers.

    Returns
    -------
    times : list of per-image ``label_img`` durations, in seconds.
    accumulated : running total of ``times``, starting at 0.0, so
        ``len(accumulated) == len(times) + 1``.
    """
    # Create the output folders up front so that saving cannot fail mid-run.
    os.makedirs(destination_dir, exist_ok=True)
    os.makedirs(DESTINATION_MASK_DIR, exist_ok=True)

    times = []
    accumulated = [0.0]
    for images, names in img_dl:
        # Label every sample of the batch separately. Reusing one prediction for
        # all names would be wrong whenever the loader yields more than one image.
        for image, name in zip(images, names):
            start = time.perf_counter()
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                pred = label_img(image, model, img_shape, DEVICE)
            frame_time = time.perf_counter() - start
            times.append(frame_time)
            accumulated.append(accumulated[-1] + frame_time)
            _save_outputs(pred, name, source_dir, destination_dir, DESTINATION_MASK_DIR)
    return times, accumulated

In [11]:
# Build the inference loader over every file listed in DATA, then label each
# image: overlays go to DESTINATION_DIR, binary masks to DESTINATION_MASK_DIR.
target_shape = (IMAGE_HEIGHT, IMAGE_WIDTH)
img_dl = get_test_loaders(DATA, DIR, transform)
times, accumulated = create_masks(
    img_dl,
    model,
    target_shape,
    DIR,
    DESTINATION_DIR,
    DESTINATION_MASK_DIR,
)