In [1]:
import argparse
from prepare_train_val import get_split
from dataset import RoboticsDataset
import cv2
from models import UNet16, LinkNet34, UNet11, UNet, AlbuNet
import torch
from pathlib import Path
from tqdm import tqdm
import numpy as np
import utils
import prepare_data
from torch.utils.data import DataLoader
from torch.nn import functional as F
from prepare_data import (original_height,
                          original_width,
                          h_start, w_start
                          )
from albumentations import Compose, Normalize

In [2]:
def img_transform(p=1):
    return Compose([
        Normalize(p=1)
    ], p=p)

In [3]:
def predict(model, from_file_names, batch_size, to_path, problem_type, img_transform):
    loader = DataLoader(
        dataset=RoboticsDataset(from_file_names, transform=img_transform, mode='predict', problem_type=problem_type),
        shuffle=False,
        batch_size=batch_size,
        num_workers=12,
        pin_memory=torch.cuda.is_available()
    )

    with torch.no_grad():
        for batch_num, (inputs, paths) in enumerate(tqdm(loader, desc='Predict')):
            inputs = utils.cuda(inputs)

            outputs = model(inputs)

            for i, image_name in enumerate(paths):
                if problem_type == 'binary':
                    factor = prepare_data.binary_factor
                    t_mask = (F.sigmoid(outputs[i, 0]).data.cpu().numpy() * factor).astype(np.uint8)
                elif problem_type == 'parts':
                    factor = prepare_data.parts_factor
                    t_mask = (outputs[i].data.cpu().numpy().argmax(axis=0) * factor).astype(np.uint8)
                elif problem_type == 'instruments':
                    factor = prepare_data.instrument_factor
                    t_mask = (outputs[i].data.cpu().numpy().argmax(axis=0) * factor).astype(np.uint8)

                h, w = t_mask.shape

                full_mask = np.zeros((original_height, original_width))
                full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask

                instrument_folder = Path(paths[i]).parent.parent.name

                (to_path / instrument_folder).mkdir(exist_ok=True, parents=True)

                cv2.imwrite(str(to_path / instrument_folder / (Path(paths[i]).stem + '.png')), full_mask)

In [4]:
def get_test_image(data_path = Path('data')):
    train_path = data_path / 'cropped_test'

    train_file_names = []
    val_file_names = []

    for instrument_id in range(1, 11):
            val_file_names += list((train_path / ('instrument_dataset_' + str(instrument_id)) / 'images').glob('*'))

    return train_file_names, val_file_names

In [5]:
from models import UNet11, LinkNet34, UNet, UNet16, AlbuNet
from model1 import LinkNet34_modified
from unetplusplus import UnetPlusPlus

In [6]:
model = LinkNet34_modified(num_classes=1)
model_name = 'LinkNet34_modified'
model_path = f'data/models/{model_name}/{model_name}.pt'
problem_type = 'binary'

state = torch.load(str(model_path))
state = {key.replace('module.', ''): value for key, value in state['model'].items()}
model.load_state_dict(state)

if torch.cuda.is_available():
    model.cuda()

  state = torch.load(str(model_path))


In [7]:
_, file_names = get_test_image()
print('num file_names = {}'.format(len(file_names)))

output_path = Path('predictions_test') / model_name / problem_type
output_path.mkdir(exist_ok=True, parents=True)

predict(model, file_names, 4, output_path, problem_type=problem_type,img_transform=img_transform(p=1))

num file_names = 1200


  out = F.upsample(out, size=(h, w), mode="bilinear")
Predict: 100%|██████████| 300/300 [00:33<00:00,  9.09it/s]


In [8]:
import os
from prepare_data import height, width, h_start, w_start

target_path = f'predictions_test/{model_name}'
train_path = 'data/cropped_test'

In [9]:
def jaccard(y_true, y_pred):
    intersection = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    return (intersection + 1e-15) / (union + 1e-15)


def dice(y_true, y_pred):
    return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15)

In [10]:

result_dice = []
result_jaccard = []
if problem_type == 'binary':
    for instrument_id in tqdm(range(1, 9)):
        instrument_dataset_name = 'instrument_dataset_' + str(instrument_id)

        pred_folder_name = (Path(target_path) / 'binary' / instrument_dataset_name)
        if not os.path.exists(pred_folder_name):
            continue

        for file_name in (Path(train_path) / instrument_dataset_name / 'binary_masks').glob('*'):
            pred_file_name = (Path(target_path) / 'binary' / instrument_dataset_name / file_name.name)
            if not os.path.exists(pred_file_name):
                continue
            
            y_true = (cv2.imread(str(file_name), 0) > 0).astype(np.uint8)

            pred_image = (cv2.imread(str(pred_file_name), 0) > 255 * 0.5).astype(np.uint8)
            y_pred = pred_image[h_start:h_start + height, w_start:w_start + width]

            result_dice += [dice(y_true, y_pred)]
            result_jaccard += [jaccard(y_true, y_pred)]

100%|██████████| 8/8 [00:11<00:00,  1.44s/it]


In [11]:
print('Dice = ', np.mean(result_dice), np.std(result_dice))
print('Jaccard = ', np.mean(result_jaccard), np.std(result_jaccard))

Dice =  0.9253892676273673 0.10979946762755784
Jaccard =  0.8761245575755569 0.14805320313219114


In [14]:
from draw_gif import draw_contour,generate_gif_pillow,add_color
import os

In [10]:
import os

_, mask_names = get_test_image(Path(f'predictions_test/{model_name}/binary'))
to_path = Path('contour') / model_name / problem_type
to_path.mkdir(exist_ok=True, parents=True)

for i, image_name in enumerate(file_names):
    instrument_folder = Path(file_names[i]).parent.parent.name

    save_path = Path(to_path) / instrument_folder
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    
    last = image_name.stem
    raw_path = Path('data/test') / instrument_folder / 'left_frames' / (last+'.png')
    mask_path = Path(f'predictions_test/{model_name}/binary') / instrument_folder / (last+'.png')
    save_path = Path(to_path) / instrument_folder / (last+'.png')
    # print(raw_path)
    # print(mask_path)
    # print(save_path)
    contour = draw_contour(raw_path,mask_path,save_path)
        
    # (to_path / instrument_folder).mkdir(exist_ok=True, parents=True)

    # cv2.imwrite(str(to_path / instrument_folder / (Path(paths[i]).stem + '.png')), full_mask)


In [15]:
_, mask_names = get_test_image(Path(f'predictions_test/{model_name}/binary'))
to_path = Path('colored') / model_name / problem_type
to_path.mkdir(exist_ok=True, parents=True)

for i, image_name in enumerate(file_names):
    instrument_folder = Path(file_names[i]).parent.parent.name

    save_path = Path(to_path) / instrument_folder
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    
    last = image_name.stem
    raw_path = Path('data/test') / instrument_folder / 'left_frames' / (last+'.png')
    mask_path = Path(f'predictions_test/{model_name}/binary') / instrument_folder / (last+'.png')
    save_path = Path(to_path) / instrument_folder / (last+'.png')
    # print(raw_path)
    # print(mask_path)
    # print(save_path)
    contour = add_color(raw_path,mask_path,save_path)

In [11]:
mask='predictions_test/LinkNet34/binary/instrument_dataset_1/frame225.png'
image='data/test/instrument_dataset_2/left_frames/frame225.png'
save='test.png'


contour = draw_contour(image,mask,save)

In [16]:
if not os.path.exists('gif'):
    os.mkdir('gif')
for i in range(1,11):
    image_folder = Path(f'colored/{model_name}/binary/instrument_dataset_{i}')
    output_path = Path(f'gif/instrument_dataset_{i}.gif')
    generate_gif_pillow(image_folder, output_path, duration=500, loop=0)

动图已保存为 gif/instrument_dataset_1.gif
动图已保存为 gif/instrument_dataset_2.gif
动图已保存为 gif/instrument_dataset_3.gif
动图已保存为 gif/instrument_dataset_4.gif
动图已保存为 gif/instrument_dataset_5.gif
动图已保存为 gif/instrument_dataset_6.gif
动图已保存为 gif/instrument_dataset_7.gif
动图已保存为 gif/instrument_dataset_8.gif
动图已保存为 gif/instrument_dataset_9.gif
动图已保存为 gif/instrument_dataset_10.gif
