download requirement environment

In [1]:
pass

preprocess data

In [2]:
pass

define dataset class

In [3]:
import torch
import numpy as np
import cv2
from torch.utils.data import Dataset
import prepare_data



class RoboticsDataset(Dataset):
    def __init__(self, file_names, to_augment=False, transform=None, mode='train', problem_type=None):
        self.file_names = file_names
        self.to_augment = to_augment
        self.transform = transform
        self.mode = mode
        self.problem_type = problem_type

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        img_file_name = self.file_names[idx]
        image = load_image(img_file_name)
        mask = load_mask(img_file_name, self.problem_type)

        data = {"image": image, "mask": mask}
        augmented = self.transform(**data)
        image, mask = augmented["image"], augmented["mask"]

        image = image.transpose(2, 0, 1)

        if self.mode == 'train':
            if self.problem_type == 'binary':
                return image, torch.tensor(mask, dtype=torch.float32).unsqueeze(0)
            else:
                return image, torch.tensor(mask, dtype=torch.long)
        else:
            return image, str(img_file_name)


def load_image(path):
    img = cv2.imread(str(path))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def load_mask(path, problem_type):
    if problem_type == 'binary':
        mask_folder = 'binary_masks'
        factor = prepare_data.binary_factor
    elif problem_type == 'parts':
        mask_folder = 'parts_masks'
        factor = prepare_data.parts_factor
    elif problem_type == 'instruments':
        factor = prepare_data.instrument_factor
        mask_folder = 'instruments_masks'

    mask = cv2.imread(str(path).replace('images', mask_folder).replace('jpg', 'png'), 0)

    return (mask / factor).astype(np.uint8)

get paths of training images

In [4]:
from prepare_data import data_path
import random

def get_split(random_seeds=42):
    random.seed(random_seeds)
    train_path = data_path / 'cropped_train'
    
    train_file_names = []
    val_file_names = []

    for instrument_id in range(1, 9):
        all_file_names = list((train_path / ('instrument_dataset_' + str(instrument_id)) / 'images').glob('*'))
        random.shuffle(all_file_names)
        split_idx = int(len(all_file_names) * 0.8)
        train_file_names.extend(all_file_names[:split_idx])
        val_file_names.extend(all_file_names[split_idx:])

    return train_file_names, val_file_names

get dataloader

In [5]:
train_crop_height = 1024
train_crop_width = 1280
val_crop_height = 1024
val_crop_width = 1280
workers = 12
batch_size = 8
problem_type = 'parts'

In [6]:
from torch.utils.data import DataLoader
def make_loader(file_names, shuffle=False, transform=None, problem_type=problem_type, batch_size=1):
        return DataLoader(
            dataset=RoboticsDataset(file_names, transform=transform, problem_type=problem_type),
            shuffle=shuffle,
            num_workers=workers,
            batch_size=batch_size,
            pin_memory=torch.cuda.is_available()
        )

In [7]:
from albumentations import (
    HorizontalFlip,
    VerticalFlip,
    Normalize,
    Compose,
    PadIfNeeded,
    RandomCrop,
    CenterCrop
)
def train_transform(p=1):
    return Compose([
        PadIfNeeded(min_height=train_crop_height, min_width=train_crop_width, p=1),
        RandomCrop(height=train_crop_height, width=train_crop_width, p=1),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
        Normalize(p=1)
    ], p=p)

def val_transform(p=1):
    return Compose([
        PadIfNeeded(min_height=val_crop_height, min_width=val_crop_width, p=1),
        CenterCrop(height=val_crop_height, width=val_crop_width, p=1),
        Normalize(p=1)
    ], p=p)

In [8]:
train_file_names, val_file_names = get_split()
train_loader = make_loader(train_file_names, shuffle=True, transform=train_transform(p=1), problem_type=problem_type,batch_size=batch_size)
valid_loader = make_loader(val_file_names, transform=val_transform(p=1), problem_type=problem_type,batch_size=batch_size)

get model

In [9]:
num_classes = 4
device_ids = [0]

In [10]:
from models import UNet11, LinkNet34, UNet, UNet16, AlbuNet
from torch import nn
from unetplusplus import UnetPlusPlus 

model_name = 'LinkNet34'
model = LinkNet34(num_classes=num_classes)
if torch.cuda.is_available():
    model = nn.DataParallel(model, device_ids=device_ids).cuda()



loss function

In [11]:
from loss import LossBinary, LossMulti
jaccard_weight = 0.3
loss = LossMulti(num_classes=num_classes, jaccard_weight=jaccard_weight)

benchmark

In [12]:
import torch.backends.cudnn as cudnn
import torch.backends.cudnn

cudnn.benchmark = True

checkpoints folder

In [13]:
import json

In [14]:
root = 'runs/debug'
# root.joinpath('params.json').write_text(json.dumps(vars(args), indent=True, sort_keys=True))

train function

In [15]:
from datetime import datetime
from pathlib import Path

import random
import numpy as np

import torch
import tqdm

In [16]:
def cuda(x):
    return x.to('cuda', non_blocking=True) if torch.cuda.is_available() else x
def write_event(log, step, **data):
    data['step'] = step
    data['dt'] = datetime.now().isoformat()
    log.write(json.dumps(data, sort_keys=True))
    log.write('\n')
    log.flush()

In [17]:
def train(model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=1, lr=0.0001, batch_size=1, fold=None,
          num_classes=None, root='runs/debug',model_name=model_name,problem_type=problem_type):
    lr = lr
    n_epochs = n_epochs
    optimizer = init_optimizer(lr)

    root = Path(root)
    model_path = root / f'{model_name}_{problem_type}.pt'
    if model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    save = lambda ep: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
    }, str(model_path))

    report_each = 10
    log = root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        random.seed()
        tq = tqdm.tqdm(total=(len(train_loader) * batch_size))
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        losses = []
        tl = train_loader
        mean_loss = 0
        for i, (inputs, targets) in enumerate(tl):
            inputs = cuda(inputs)

            with torch.no_grad():
                targets = cuda(targets)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            batch_size = inputs.size(0)
            loss.backward()
            optimizer.step()
            step += 1
            tq.update(batch_size)
            losses.append(loss.item())
            mean_loss = np.mean(losses[-report_each:])
            tq.set_postfix(loss='{:.5f}'.format(mean_loss))
            if i and i % report_each == 0:
                write_event(log, step, loss=mean_loss)
        write_event(log, step, loss=mean_loss)
        tq.close()
        save(epoch + 1)
        valid_metrics = validation(model, criterion, valid_loader, num_classes)
        write_event(log, step, **valid_metrics)
        valid_loss = valid_metrics['valid_loss']
        valid_losses.append(valid_loss)


validation function

In [18]:
from validation import validation_binary, validation_multi
valid = validation_multi

optimizer

In [19]:
from torch.optim import Adam
optimizer = lambda lr: Adam(model.parameters(), lr=lr)

train

In [20]:
train(
    model=model,
    criterion=loss,
    train_loader=train_loader,
    valid_loader=valid_loader,
    validation=valid,
    init_optimizer=optimizer,
    n_epochs=20,
    lr=0.00001,
    batch_size=batch_size,
    num_classes=num_classes,
    model_name = model_name
)

Epoch 1, lr 1e-05: 100%|██████████| 1440/1440 [00:58<00:00, 24.68it/s, loss=4.23843]


Valid loss: 4.3767, average IoU: 0.4507, average Dice: 0.5201


Epoch 2, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.96it/s, loss=3.96089]


Valid loss: 3.9898, average IoU: 0.3315, average Dice: 0.3324


Epoch 3, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.78it/s, loss=2.97181]


Valid loss: 3.0765, average IoU: 0.3328, average Dice: 0.3331


Epoch 4, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 28.11it/s, loss=2.05266]


Valid loss: 2.0738, average IoU: 0.3326, average Dice: 0.3330


Epoch 5, lr 1e-05: 100%|██████████| 1440/1440 [00:52<00:00, 27.69it/s, loss=1.54738]


Valid loss: 1.5929, average IoU: 0.3313, average Dice: 0.3323


Epoch 6, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.76it/s, loss=1.35804]


Valid loss: 1.4054, average IoU: 0.3315, average Dice: 0.3324


Epoch 7, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.76it/s, loss=1.27449]


Valid loss: 1.3233, average IoU: 0.3299, average Dice: 0.3316


Epoch 8, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.88it/s, loss=1.23126]


Valid loss: 1.2736, average IoU: 0.4842, average Dice: 0.5446


Epoch 9, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.95it/s, loss=1.18833]


Valid loss: 1.2367, average IoU: 0.4857, average Dice: 0.5458


Epoch 10, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.90it/s, loss=1.17054]


Valid loss: 1.2110, average IoU: 0.4850, average Dice: 0.5453


Epoch 11, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 28.05it/s, loss=1.15171]


Valid loss: 1.1909, average IoU: 0.4839, average Dice: 0.5453


Epoch 12, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 28.04it/s, loss=1.10253]


Valid loss: 1.1518, average IoU: 0.4845, average Dice: 0.5476


Epoch 13, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.79it/s, loss=1.01027]


Valid loss: 1.0492, average IoU: 0.5288, average Dice: 0.6346


Epoch 14, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.98it/s, loss=0.88397]


Valid loss: 0.9337, average IoU: 0.5290, average Dice: 0.6353


Epoch 15, lr 1e-05: 100%|██████████| 1440/1440 [00:50<00:00, 28.29it/s, loss=0.82111]


Valid loss: 0.8803, average IoU: 0.5312, average Dice: 0.6378


Epoch 17, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 28.05it/s, loss=0.82438]


Valid loss: 0.8464, average IoU: 0.5982, average Dice: 0.7166


Epoch 18, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 28.11it/s, loss=0.47175]


Valid loss: 0.5092, average IoU: 0.8370, average Dice: 0.9082


Epoch 19, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 27.96it/s, loss=0.28343]


Valid loss: 0.3451, average IoU: 0.8650, average Dice: 0.9258


Epoch 20, lr 1e-05: 100%|██████████| 1440/1440 [00:51<00:00, 28.14it/s, loss=0.25703]


Valid loss: 0.3134, average IoU: 0.8737, average Dice: 0.9312


generate mask using the model

In [29]:
output_path = f'predictions/{model_name}/parts'
workers = 12
from prepare_data import (original_height,
                          original_width,
                          h_start, w_start
                          )

In [30]:
def img_transform(p=1):
    return Compose([
        Normalize(p=1)
    ], p=p)

In [31]:
from torch.nn import functional as F
def predict(model, from_file_names, batch_size, to_path, problem_type, img_transform,workers=workers,):
    loader = DataLoader(
        dataset=RoboticsDataset(from_file_names, transform=img_transform, mode='predict', problem_type=problem_type),
        shuffle=False,
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=torch.cuda.is_available()
    )

    with torch.no_grad():
        for batch_num, (inputs, paths) in enumerate(tqdm.tqdm(loader, desc='Predict')):
            inputs = cuda(inputs)

            outputs = model(inputs)

            for i, image_name in enumerate(paths):
                if problem_type == 'binary':
                    factor = prepare_data.binary_factor
                    t_mask = (F.sigmoid(outputs[i, 0]).data.cpu().numpy() * factor).astype(np.uint8)
                elif problem_type == 'parts':
                    factor = prepare_data.parts_factor
                    t_mask = (outputs[i].data.cpu().numpy().argmax(axis=0) * factor).astype(np.uint8)
                elif problem_type == 'instruments':
                    factor = prepare_data.instrument_factor
                    t_mask = (outputs[i].data.cpu().numpy().argmax(axis=0) * factor).astype(np.uint8)

                h, w = t_mask.shape

                full_mask = np.zeros((original_height, original_width))
                full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask

                instrument_folder = Path(paths[i]).parent.parent.name

                (to_path / instrument_folder).mkdir(exist_ok=True, parents=True)

                cv2.imwrite(str(to_path / instrument_folder / (Path(paths[i]).stem + '.png')), full_mask)

In [32]:
val_file_names
print('num file_names = {}'.format(len(val_file_names)))
output_path = Path(output_path)
output_path.mkdir(exist_ok=True, parents=True)

predict(model, val_file_names, batch_size, output_path, problem_type=problem_type,img_transform=img_transform(p=1))

num file_names = 360


Predict: 100%|██████████| 45/45 [00:17<00:00,  2.53it/s]


evaluate

In [33]:
def jaccard(y_true, y_pred):
    intersection = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    return (intersection + 1e-15) / (union + 1e-15)


def dice(y_true, y_pred):
    return (2 * (y_true * y_pred).sum() + 1e-15) / (y_true.sum() + y_pred.sum() + 1e-15)

def general_dice(y_true, y_pred):
    result = []

    if y_true.sum() == 0:
        if y_pred.sum() == 0:
            return 1
        else:
            return 0

    for instrument_id in set(y_true.flatten()):
        if instrument_id == 0:
            continue
        result += [dice(y_true == instrument_id, y_pred == instrument_id)]

    return np.mean(result)

def general_jaccard(y_true, y_pred):
    result = []

    if y_true.sum() == 0:
        if y_pred.sum() == 0:
            return 1
        else:
            return 0

    for instrument_id in set(y_true.flatten()):
        if instrument_id == 0:
            continue
        result += [jaccard(y_true == instrument_id, y_pred == instrument_id)]

    return np.mean(result)

In [34]:
import os
from prepare_data import height, width, h_start, w_start

target_path = f'predictions/{model_name}'
train_path = 'data/cropped_train'

In [35]:
result_dice = []
result_jaccard = []
if problem_type == 'binary':
    for instrument_id in tqdm.tqdm(range(1, 9)):
        instrument_dataset_name = 'instrument_dataset_' + str(instrument_id)

        pred_folder_name = (Path(target_path) / 'binary' / instrument_dataset_name)
        if not os.path.exists(pred_folder_name):
            continue

        for file_name in (Path(train_path) / instrument_dataset_name / 'binary_masks').glob('*'):
            pred_file_name = (Path(target_path) / 'binary' / instrument_dataset_name / file_name.name)
            if not os.path.exists(pred_file_name):
                continue
            
            y_true = (cv2.imread(str(file_name), 0) > 0).astype(np.uint8)

            pred_image = (cv2.imread(str(pred_file_name), 0) > 255 * 0.5).astype(np.uint8)
            y_pred = pred_image[h_start:h_start + height, w_start:w_start + width]

            result_dice += [dice(y_true, y_pred)]
            result_jaccard += [jaccard(y_true, y_pred)]
            
elif problem_type == 'parts':
    for instrument_id in tqdm.tqdm(range(1, 9)):
        instrument_dataset_name = 'instrument_dataset_' + str(instrument_id)
        for file_name in (
                Path(train_path) / instrument_dataset_name / 'parts_masks').glob('*'):
            y_true = cv2.imread(str(file_name), 0)

            pred_file_name = Path(target_path) / 'parts' / instrument_dataset_name / file_name.name
            if not os.path.exists(pred_file_name):
                continue

            y_pred = cv2.imread(str(pred_file_name), 0)[h_start:h_start + height, w_start:w_start + width]

            result_dice += [general_dice(y_true, y_pred)]
            result_jaccard += [general_jaccard(y_true, y_pred)]

100%|██████████| 8/8 [00:54<00:00,  6.84s/it]


In [36]:
print('Dice = ', np.mean(result_dice), np.std(result_dice))
print('Jaccard = ', np.mean(result_jaccard), np.std(result_jaccard))

Dice =  0.8510731883170084 0.11170368774976065
Jaccard =  0.7659906502225586 0.12557047251563083
