In [None]:
## Mount Google Drive Data (If using Google Colaboratory)
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
except Exception:
    # Best-effort: outside Colab the import fails and the notebook carries on.
    # `Exception` (rather than a bare `except`) lets Ctrl+C / SystemExit through.
    print("Mounting Failed.")

Mounted at /content/gdrive


In [None]:
## Standard Library
import os
import json

## External Libraries
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch.autograd import Variable
import torch.nn.functional as functional
from torch.utils.data import Dataset, DataLoader
from skimage import io
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import tqdm
import cv2

from albumentations import Compose, RandomCrop, VerticalFlip, HorizontalFlip, Normalize, CenterCrop
import json
from datetime import datetime

import random


In [None]:
# Source dataset and destination folder for the cropped copy.
train_path = Path('/content/gdrive/MyDrive/instrument_5_8_training')
cropped_train_path = Path('/content/gdrive/MyDrive/cropped_training')

# Nominal size of the source frames.
original_height = 1080
original_width = 1920

# Size of the crop window and the row/column of its top-left corner.
height = 1024
width = 1280
h_start = 28
w_start = 320

# Multipliers applied to label values when masks are written to disk;
# load_mask divides by the same factor to recover the labels.
binary_factor = 255
parts_factor = 85
instrument_factor = 32



In [None]:
# Crop every frame of datasets 5-8 and build the three mask variants
# (binary, parts, instrument type) for the same crop window.

# Substring of the ground-truth folder name -> label written into the
# instrument-type mask. Checked in order; the first match wins.
instrument_labels = (
    ('Bipolar_Forceps', 1),
    ('Prograsp_Forceps', 2),
    ('Large_Needle_Driver', 3),
    ('Vessel_Sealer', 4),
    ('Grasping_Retractor', 5),
    ('Monopolar_Curved_Scissors', 6),
    ('Other', 7),
)

crop_rows = slice(h_start, h_start + height)
crop_cols = slice(w_start, w_start + width)

for instrument_index in range(5, 9):
    instrument_folder = 'instrument_dataset_' + str(instrument_index)

    image_folder = cropped_train_path / instrument_folder / 'images'
    binary_mask_folder = cropped_train_path / instrument_folder / 'binary_masks'
    parts_mask_folder = cropped_train_path / instrument_folder / 'parts_masks'
    instrument_mask_folder = cropped_train_path / instrument_folder / 'instruments_masks'
    for output_folder in (image_folder, binary_mask_folder, parts_mask_folder, instrument_mask_folder):
        output_folder.mkdir(exist_ok=True, parents=True)

    mask_folders = list((train_path / instrument_folder / 'ground_truth').glob('*'))

    # The imports cell runs `from tqdm import tqdm` followed by `import tqdm`,
    # so the name `tqdm` is the module here; the progress bar is `tqdm.tqdm`.
    for file_name in tqdm.tqdm(list((train_path / instrument_folder / 'left_frames').glob('*'))):
        img = cv2.imread(str(file_name))
        if img is None:
            raise FileNotFoundError('Could not read frame: {}'.format(file_name))
        old_h, old_w, _ = img.shape

        img = img[crop_rows, crop_cols]
        cv2.imwrite(str(image_folder / (file_name.stem + '.jpg')), img,
                    [cv2.IMWRITE_JPEG_QUALITY, 100])

        mask_binary = np.zeros((old_h, old_w))
        mask_parts = np.zeros((old_h, old_w))
        mask_instruments = np.zeros((old_h, old_w))

        for mask_folder in mask_folders:
            mask_file = mask_folder / file_name.name
            mask = cv2.imread(str(mask_file), 0)  # 0 -> read as grayscale
            if mask is None:
                raise FileNotFoundError('Could not read mask: {}'.format(mask_file))

            folder_name = str(mask_folder)
            for keyword, label in instrument_labels:
                if keyword in folder_name:
                    mask_instruments[mask > 0] = label
                    break

            # 'Other' folders only contribute to the instrument-type mask.
            if 'Other' not in folder_name:
                mask_binary += mask

                mask_parts[mask == 10] = 1  # Shaft
                mask_parts[mask == 20] = 2  # Wrist
                mask_parts[mask == 30] = 3  # Claspers

        # Crop, then scale the labels by the per-task factor before saving.
        mask_binary = (mask_binary[crop_rows, crop_cols] > 0).astype(np.uint8) * binary_factor
        mask_parts = mask_parts[crop_rows, crop_cols].astype(np.uint8) * parts_factor
        mask_instruments = mask_instruments[crop_rows, crop_cols].astype(np.uint8) * instrument_factor

        cv2.imwrite(str(binary_mask_folder / file_name.name), mask_binary)
        cv2.imwrite(str(parts_mask_folder / file_name.name), mask_parts)
        cv2.imwrite(str(instrument_mask_folder / file_name.name), mask_instruments)

In [None]:
def get_split(fold, data_root=None):
    """Split the cropped image files into training and validation lists.

    Each fold holds out one instrument dataset (fold 0 -> dataset 5, ...,
    fold 3 -> dataset 8) for validation; the other three are used for training.

    Args:
        fold: Fold index, one of 0, 1, 2, 3.
        data_root: Folder containing the ``instrument_dataset_<id>/images``
            sub-folders. Defaults to the cropped dataset on Google Drive.

    Returns:
        ``(train_file_names, val_file_names)``, two lists of ``Path`` objects.
        Files of each dataset are sorted by path so the split is reproducible
        (``glob`` alone returns them in arbitrary order).

    Raises:
        KeyError: If ``fold`` is not a known fold index.
    """
    folds = {0: [5],
             1: [6],
             2: [7],
             3: [8],}

    if data_root is None:
        data_root = "/content/gdrive/MyDrive/cropped_training"
    data_root = Path(data_root)

    held_out_ids = folds[fold]
    train_file_names = []
    val_file_names = []

    for instrument_id in range(5, 9):
        image_dir = data_root / ('instrument_dataset_' + str(instrument_id)) / 'images'
        frames = sorted(image_dir.glob('*'))
        if instrument_id in held_out_ids:
            val_file_names += frames
        else:
            train_file_names += frames
    return train_file_names, val_file_names

In [None]:
# Sanity check: display the (train, validation) file lists returned for fold 1.
get_split(1)

([PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame014.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame013.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame012.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame011.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame010.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame009.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame008.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame007.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame006.jpg'),
  PosixPath('/content/gdrive/MyDrive/cropped_training/instrument_dataset_5/images/frame005.jpg'),
  PosixPath('/conten

In [None]:
import torch
import numpy as np
import cv2
from torch.utils.data import Dataset
from albumentations.pytorch.functional import img_to_tensor


class RoboticsDataset(Dataset):
    """Dataset yielding an image tensor with either its mask or its file name.

    Args:
        file_names: Sequence of image paths; each one is passed to
            ``load_image`` and ``load_mask``.
        to_augment: Stored on the instance; not used by this class itself.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
        mode: ``'train'`` returns ``(image, mask)``; any other value returns
            ``(image, file_name)``.
        problem_type: Forwarded to ``load_mask``. ``'binary'`` masks are
            returned as float tensors with a leading channel axis; every other
            type is returned as a long tensor.
    """

    def __init__(self, file_names, to_augment=False, transform=None, mode='train', problem_type=None):
        self.file_names = file_names
        self.to_augment = to_augment
        self.transform = transform
        self.mode = mode
        self.problem_type = problem_type

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        img_file_name = self.file_names[idx]
        image = load_image(img_file_name)
        mask = load_mask(img_file_name, self.problem_type)

        data = {"image": image, "mask": mask}
        # Identity check: the transform may be any callable object.
        if self.transform is not None:
            data = self.transform(**data)
        image, mask = data["image"], data["mask"]

        if self.mode == 'train':
            if self.problem_type == 'binary':
                # Add a channel axis so the mask is (1, H, W).
                return img_to_tensor(image), torch.from_numpy(np.expand_dims(mask, 0)).float()
            else:
                # Class-index targets are kept as integer labels.
                return img_to_tensor(image), torch.from_numpy(mask).long()
        else:
            return img_to_tensor(image), str(img_file_name)


def make_loader(file_names, shuffle=False, transform=None, problem_type='binary', batch_size=1):
    """Wrap ``file_names`` in a RoboticsDataset and return a DataLoader over it.

    ``transform`` and ``problem_type`` are forwarded to the dataset;
    ``shuffle`` and ``batch_size`` are forwarded to the loader.
    """
    dataset = RoboticsDataset(file_names, transform=transform, problem_type=problem_type)
    loader = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size)
    return loader

def load_image(path):
    """Read the image at ``path`` with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (``cv2.imread``
            returns ``None`` instead of raising).
    """
    img = cv2.imread(str(path))
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))
    # OpenCV loads images as BGR.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def load_mask(path, problem_type):
    """Load the label mask that belongs to the image at ``path``.

    The mask is read from the ``<problem>_masks`` folder next to the image's
    own folder, with the same file stem and a ``.png`` extension. Stored pixel
    values are divided by the per-task factor to recover the labels.

    Args:
        path: Path of the image (``.../<dataset>/images/<frame>.jpg``).
        problem_type: ``'binary'``, ``'parts'`` or ``'instruments'``.

    Returns:
        ``np.uint8`` array of labels.

    Raises:
        ValueError: If ``problem_type`` is not one of the supported values.
        FileNotFoundError: If the mask file cannot be read.
    """
    mask_specs = {
        'binary': ('binary_masks', binary_factor),
        'parts': ('parts_masks', parts_factor),
        'instruments': ('instruments_masks', instrument_factor),
    }
    if problem_type not in mask_specs:
        raise ValueError(
            "Unknown problem_type {!r}; expected one of {}".format(problem_type, sorted(mask_specs)))
    mask_folder, factor = mask_specs[problem_type]

    # Swap only the parent folder and the extension. A plain str.replace would
    # also rewrite 'images' / 'jpg' anywhere else in the path.
    image_path = Path(str(path))
    mask_path = image_path.parent.parent / mask_folder / (image_path.stem + '.png')

    mask = cv2.imread(str(mask_path), 0)  # 0 -> read as grayscale
    if mask is None:
        raise FileNotFoundError('Could not read mask: {}'.format(mask_path))

    return (mask / factor).astype(np.uint8)

In [None]:
# Build the train/validation file lists for the chosen fold and report their sizes.
fold = 0
train_file_names, val_file_names = get_split(fold)
num_train, num_val = len(train_file_names), len(val_file_names)
print('num train = {}, num_val = {}'.format(num_train, num_val))

num train = 675, num_val = 225


In [None]:
def train_transform(p=1):
    """Build the training pipeline: 1024x1280 random crop, random flips, normalisation.

    ``p`` is passed to ``Compose`` as the probability of applying the pipeline.
    """
    steps = [
        RandomCrop(height=1024, width=1280, p=1),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
        Normalize(p=1),
    ]
    return Compose(steps, p=p)

def val_transform(p=1):
    """Build the validation pipeline: 1024x1280 centre crop followed by normalisation.

    ``p`` is passed to ``Compose`` as the probability of applying the pipeline.
    """
    steps = [
        CenterCrop(height=1024, width=1280, p=1),
        Normalize(p=1),
    ]
    return Compose(steps, p=p)



In [None]:
# Smoke test: build a training loader for the 'instruments' problem and
# print the first batch of inputs and targets, then stop.
train_loader = make_loader(train_file_names, shuffle=True, transform=train_transform(), problem_type="instruments",batch_size=6)
for inputs,targets in train_loader:
  print(inputs)
  print(targets)
  break

tensor([[[[-1.8610, -1.8610, -1.8610,  ..., -1.8610, -1.8610, -1.8610],
          [-1.8610, -1.8610, -1.8610,  ..., -1.8610, -1.8610, -1.8610],
          [-1.8610, -1.8610, -1.8439,  ..., -1.8610, -1.8610, -1.8610],
          ...,
          [-1.8610, -1.8610, -1.8610,  ..., -1.8439, -1.8439, -1.8439],
          [-1.8439, -1.8439, -1.8439,  ..., -1.8439, -1.8439, -1.8439],
          [-1.8439, -1.8439, -1.8439,  ..., -1.8439, -1.8439, -1.8439]],

         [[-1.7731, -1.7731, -1.7731,  ..., -1.7731, -1.7731, -1.7731],
          [-1.7731, -1.7731, -1.7731,  ..., -1.7731, -1.7731, -1.7731],
          [-1.7731, -1.7731, -1.7906,  ..., -1.7731, -1.7731, -1.7731],
          ...,
          [-1.7906, -1.7906, -1.7906,  ..., -1.7906, -1.7906, -1.7906],
          [-1.7906, -1.7906, -1.7906,  ..., -1.7906, -1.7906, -1.7906],
          [-1.7906, -1.7906, -1.7906,  ..., -1.7906, -1.7906, -1.7906]],

         [[-1.5779, -1.5779, -1.5430,  ..., -1.5430, -1.5430, -1.5430],
          [-1.5779, -1.5779, -

In [None]:
def cuda(x):
    """Move ``x`` to the GPU when CUDA is available, otherwise to the CPU.

    ``x`` is anything with a ``.to(device)`` method (tensor or module); the
    result of that call is returned.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return x.to(device)

In [None]:
from torch import nn
import torch
from torchvision import models
import torchvision
from torch.nn import functional as F

def conv3x3(in_, out):
    """Return a 3x3 ``nn.Conv2d`` from ``in_`` to ``out`` channels with padding 1."""
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)


class ConvRelu(nn.Module):
    """A ``conv3x3`` layer followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then the activation."""
        return self.activation(self.conv(x))


class DecoderBlock(nn.Module):
    """Decoder stage that doubles the spatial resolution of its input.

    With ``is_deconv=True`` the block is ConvRelu -> stride-2 transposed
    convolution -> ReLU. Otherwise it is a x2 bilinear upsample followed by
    two ConvRelu layers.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            layers = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            layers = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.block(x)
class UNet16(nn.Module):
    """U-Net whose encoder is built from the convolution layers of torchvision's VGG16.

    Args:
        num_classes: Number of output channels. For more than one class the
            forward pass returns log-softmax over the channel axis; for a
            single class it returns the raw output of the final 1x1 convolution.
        num_filters: Base width of the decoder.
        pretrained: Forwarded to ``torchvision.models.vgg16``.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):

        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features

        self.relu = nn.ReLU(inplace=True)

        def stage(*conv_ids):
            # Each listed encoder layer is followed by the shared ReLU.
            layers = []
            for conv_id in conv_ids:
                layers += [self.encoder[conv_id], self.relu]
            return nn.Sequential(*layers)

        self.conv1 = stage(0, 2)
        self.conv2 = stage(5, 7)
        self.conv3 = stage(10, 12, 14)
        self.conv4 = stage(17, 19, 21)
        self.conv5 = stage(24, 26, 28)

        wide = num_filters * 8

        self.center = DecoderBlock(512, wide * 2, wide)

        self.dec5 = DecoderBlock(512 + wide, wide * 2, wide)
        self.dec4 = DecoderBlock(512 + wide, wide * 2, wide)
        self.dec3 = DecoderBlock(256 + wide, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, decode with skip connections and return the class map."""
        # Encoder: max-pool between stages, keeping every stage output as a skip.
        skips = []
        features = x
        for position, encoder_stage in enumerate(
                (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)):
            if position > 0:
                features = self.pool(features)
            features = encoder_stage(features)
            skips.append(features)

        decoded = self.center(self.pool(features))

        # Decoder: concatenate with the matching encoder output, deepest first.
        for decoder_stage in (self.dec5, self.dec4, self.dec3, self.dec2, self.dec1):
            decoded = decoder_stage(torch.cat([decoded, skips.pop()], 1))

        logits = self.final(decoded)
        if self.num_classes > 1:
            return F.log_softmax(logits, dim=1)
        return logits


In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

class LossMulti:
    """Multi-class loss mixing NLL with a soft log-Jaccard term per class.

    ``outputs`` are expected to be log-probabilities (they are passed to
    ``nn.NLLLoss`` and exponentiated for the Jaccard term). The result is
    ``(1 - jaccard_weight) * NLL - jaccard_weight * sum_c log(J_c)``, where
    ``J_c`` is the soft Jaccard index of class ``c`` over the whole batch.
    """

    def __init__(self, jaccard_weight=1, class_weights=None, num_classes=1):
        nll_weight = None
        if class_weights is not None:
            # NLLLoss needs a float32 weight tensor on the compute device.
            nll_weight = cuda(torch.from_numpy(class_weights.astype(np.float32)))
        self.nll_loss = nn.NLLLoss(weight=nll_weight)
        self.jaccard_weight = jaccard_weight
        self.num_classes = num_classes

    def __call__(self, outputs, targets):
        loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)
        if not self.jaccard_weight:
            return loss

        eps = 1e-15
        for cls in range(self.num_classes):
            target_mask = (targets == cls).float()
            class_prob = outputs[:, cls].exp()
            intersection = (class_prob * target_mask).sum()
            union = class_prob.sum() + target_mask.sum()
            log_jaccard = torch.log((intersection + eps) / (union - intersection + eps))
            loss = loss - log_jaccard * self.jaccard_weight
        return loss

In [None]:
def get_jaccard(y_true, y_pred):
    """Return the Jaccard index of each leading-axis entry as a list.

    Sums are taken over the last two axes; a small epsilon keeps the ratio
    defined when both inputs are empty.
    """
    epsilon = 1e-15

    def area(t):
        # Collapse the last two axes.
        return t.sum(dim=-2).sum(dim=-1)

    intersection = area(y_pred * y_true)
    union = area(y_true) + area(y_pred)
    scores = (intersection + epsilon) / (union - intersection + epsilon)
    return list(scores.data.cpu().numpy())


def validation_multi(model: nn.Module, criterion, valid_loader, num_classes):
    """Evaluate ``model`` on ``valid_loader`` and return loss, IoU and Dice metrics.

    The model is put in eval mode and run without gradients. A confusion
    matrix is accumulated over all batches; its first row and column (class 0,
    the background) are dropped before the per-class scores are computed.

    Returns:
        Dict with ``valid_loss`` (mean batch loss), ``iou`` (mean per-class IoU)
        and one ``iou_<k>`` / ``dice_<k>`` entry per non-background class.
    """
    model.eval()
    batch_losses = []
    totals = np.zeros((num_classes, num_classes), dtype=np.uint32)

    with torch.no_grad():
        for inputs, targets in valid_loader:
            inputs, targets = cuda(inputs), cuda(targets)
            outputs = model(inputs)
            batch_losses.append(criterion(outputs, targets).item())
            # Predicted class = arg-max over the channel axis.
            predicted = outputs.data.cpu().numpy().argmax(axis=1)
            expected = targets.data.cpu().numpy()
            totals += calculate_confusion_matrix_from_arrays(
                predicted, expected, num_classes)

    foreground = totals[1:, 1:]  # exclude background
    valid_loss = np.mean(batch_losses)  # type: float

    ious = {}
    for position, score in enumerate(calculate_iou(foreground), start=1):
        ious['iou_{}'.format(position)] = score

    dices = {}
    for position, score in enumerate(calculate_dice(foreground), start=1):
        dices['dice_{}'.format(position)] = score

    average_iou = np.mean(list(ious.values()))
    average_dices = np.mean(list(dices.values()))

    print(
        'Valid loss: {:.4f}, average IoU: {:.4f}, average Dice: {:.4f}'.format(valid_loss,
                                                                               average_iou,
                                                                               average_dices))
    metrics = {'valid_loss': valid_loss, 'iou': average_iou}
    metrics.update(ious)
    metrics.update(dices)
    return metrics


def calculate_confusion_matrix_from_arrays(prediction, ground_truth, nr_labels):
    """Count (ground truth, prediction) label pairs.

    Returns an ``nr_labels x nr_labels`` ``np.uint32`` matrix whose rows are
    ground-truth labels and whose columns are predicted labels.
    """
    # One row per pixel: column 0 is the true label, column 1 the predicted one.
    pairs = np.stack([ground_truth.flatten(), prediction.flatten()], axis=1)
    label_range = (0, nr_labels)
    counts, _edges = np.histogramdd(
        pairs,
        bins=(nr_labels, nr_labels),
        range=[label_range, label_range],
    )
    return counts.astype(np.uint32)


def calculate_iou(confusion_matrix):
    """Return the IoU of every class of a square confusion matrix.

    A class that never occurs as truth or prediction (zero denominator) scores 0.
    """
    def class_iou(k):
        hits = confusion_matrix[k, k]
        spurious = confusion_matrix[:, k].sum() - hits  # predicted k, truth differs
        missed = confusion_matrix[k, :].sum() - hits    # truth k, predicted otherwise
        total = hits + spurious + missed
        return 0 if total == 0 else float(hits) / total

    return [class_iou(k) for k in range(confusion_matrix.shape[0])]


def calculate_dice(confusion_matrix):
    """Return the Dice coefficient of every class of a square confusion matrix.

    Counts are converted to Python numbers before any arithmetic, because
    ``2 * true_positives`` on a ``np.uint32`` scalar can wrap around for large
    counts. A class that never occurs as truth or prediction (zero
    denominator) scores 0.
    """
    dices = []
    for index in range(confusion_matrix.shape[0]):
        # .item() yields an unbounded Python int for integer matrices.
        true_positives = confusion_matrix[index, index].item()
        false_positives = confusion_matrix[:, index].sum().item() - true_positives
        false_negatives = confusion_matrix[index, :].sum().item() - true_positives
        denom = 2 * true_positives + false_positives + false_negatives
        if denom == 0:
            dice = 0
        else:
            dice = 2 * float(true_positives) / denom
        dices.append(dice)
    return dices

In [None]:
def write_event(log, step, **data):
    """Append one JSON line to ``log`` and flush it.

    The record holds the keyword arguments plus ``step`` and ``dt`` (the
    current local time in ISO format), serialised with sorted keys.
    """
    data.update(step=step, dt=datetime.now().isoformat())
    serialized = json.dumps(data, sort_keys=True)
    for chunk in (serialized, '\n'):
        log.write(chunk)
    log.flush()


def train(model, criterion, train_loader, val_loader, validation, optimizer,
          root="/content/gdrive/MyDrive/model", n_epochs=10, fold=0,
          num_classes=None, batch_size=6):
    """Train ``model``, checkpointing and validating after every epoch.

    If ``<root>/model_<fold>.pt`` exists, the weights, epoch and step counter
    are restored from it and training resumes from that epoch. Training and
    validation metrics are appended as JSON lines to ``<root>/train_<fold>.log``.

    Args:
        model: Network to optimise.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        val_loader: Passed through to ``validation``.
        validation: Callable ``validation(model, criterion, val_loader,
            num_classes)`` returning a dict that contains ``'valid_loss'``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        root: Folder for the checkpoint and the log (created if missing).
        n_epochs: Number of the last epoch to run (epochs are 1-based).
        fold: Fold index, used only in the checkpoint and log file names.
        num_classes: Passed through to ``validation``.
        batch_size: Only used to size the progress bar when ``train_loader``
            does not expose a sized ``dataset``.

    Pressing Ctrl+C saves a snapshot for the current epoch and stops training.
    """
    root = Path(root)
    root.mkdir(exist_ok=True, parents=True)
    model_path = root / 'model_{fold}.pt'.format(fold=fold)
    if model_path.exists():
        # Load onto the CPU first so a GPU-saved checkpoint can be restored on
        # any machine; load_state_dict copies into the model's own device.
        state = torch.load(str(model_path), map_location='cpu')
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    def save(ep):
        # Reads the current `step` at call time.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_path))

    # Size the progress bar from the real number of samples when it is known.
    try:
        samples_per_epoch = len(train_loader.dataset)
    except (AttributeError, TypeError):
        samples_per_epoch = len(train_loader) * batch_size

    report_each = 2
    log_path = root.joinpath('train_{fold}.log'.format(fold=fold))
    with log_path.open('at', encoding='utf8') as log:
        for epoch in range(epoch, n_epochs + 1):
            model.train()
            random.seed()
            # Read the rate from the optimizer instead of relying on a global.
            current_lr = optimizer.param_groups[0]['lr']
            tq = tqdm.tqdm(total=samples_per_epoch)
            tq.set_description('Epoch {}, lr {}'.format(epoch, current_lr))
            losses = []
            try:
                mean_loss = 0
                for i, (inputs, targets) in enumerate(train_loader):
                    inputs = cuda(inputs)
                    targets = cuda(targets)

                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1
                    tq.update(inputs.size(0))
                    losses.append(loss.item())
                    mean_loss = np.mean(losses[-report_each:])
                    tq.set_postfix(loss='{:.5f}'.format(mean_loss))
                    if i and i % report_each == 0:
                        write_event(log, step, loss=mean_loss)
                write_event(log, step, loss=mean_loss)
                tq.close()
                # The checkpoint stores the epoch to resume from.
                save(epoch + 1)
                valid_metrics = validation(model, criterion, val_loader, num_classes)
                write_event(log, step, **valid_metrics)
            except KeyboardInterrupt:
                tq.close()
                print('Ctrl+C, saving snapshot')
                save(epoch)
                print('done.')
                # Stop here; otherwise the loop would start the next epoch.
                return

In [None]:
import torch.backends.cudnn as cudnn
import torch.backends.cudnn
from torch.optim import Adam
# Configure and launch training for the instrument-type segmentation problem.
problem_type = "instruments"
batch_size = 3
train_loader = make_loader(train_file_names, shuffle=True, transform=train_transform(),
                           problem_type=problem_type, batch_size=batch_size)
val_loader = make_loader(val_file_names, shuffle=False, transform=val_transform(),
                         problem_type=problem_type)

# Instrument masks hold the labels 0-7 (background, six instrument types, 'Other').
num_classes = 8
# cuda() falls back to the CPU when no GPU is present, whereas model.cuda() raises.
model = cuda(UNet16(num_classes=num_classes))
cudnn.benchmark = True
criterion = LossMulti(num_classes=num_classes, jaccard_weight=0.3)
lr = 0.0001
optimizer = Adam(model.parameters(), lr=lr)

# Pass the real batch size so the progress bar total matches the loader.
train(model=model, criterion=criterion, train_loader=train_loader, val_loader=val_loader,
      validation=validation_multi, optimizer=optimizer, num_classes=num_classes,
      batch_size=batch_size)

Epoch 1, lr 0.0001:  50%|█████     | 675/1350 [32:50<32:50,  2.92s/it, loss=50.76133]


Valid loss: 55.2670, average IoU: 0.0000, average Dice: 0.0000


Epoch 2, lr 0.0001: 100%|██████████| 675/675 [11:54<00:00,  1.06s/it, loss=33.79669]


Valid loss: 56.4630, average IoU: 0.0000, average Dice: 0.0000


Epoch 3, lr 0.0001:  28%|██▊       | 186/675 [03:16<08:36,  1.06s/it, loss=21.49203]


Ctrl+C, saving snapshot
done.


Epoch 4, lr 0.0001:   0%|          | 3/675 [00:03<12:22,  1.11s/it, loss=12.53465]


Ctrl+C, saving snapshot
done.


Epoch 5, lr 0.0001:   0%|          | 3/675 [00:03<12:00,  1.07s/it]


Ctrl+C, saving snapshot
done.


Epoch 6, lr 0.0001:   0%|          | 3/675 [00:03<12:00,  1.07s/it]


Ctrl+C, saving snapshot
done.


Epoch 7, lr 0.0001:   0%|          | 0/675 [00:00<?, ?it/s]