In [1]:
import pandas as pd
import torch
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt
import numpy as np

In [2]:
from dataloaders.binary_dataloader import BinaryLoader
from utils.train_validation_split import random_train_val_split
from utils.metrics_evaluator import PerformanceMetricsEvaluator
from models.unet import UNet

In [1]:
from utils.mask_functions import better_mask2rle, rle2mask

In [3]:
# Inference device for this run (pinned to CPU).
device = "cpu"

# Input locations: image directory handed to BinaryLoader, and the RLE annotation CSV.
ROOT_DIR, DIR_TO_CSV = 'data/processed/', 'data/raw/train-rle.csv'

In [4]:
# Read the RLE annotation CSV and split its rows into train / validation sets.
# NOTE(review): argument meanings are inferred from the helper's name and call
# site (fraction held out for validation, then a random seed) — confirm in
# utils/train_validation_split.py.
VAL_FRACTION = 0.2
SPLIT_SEED = 44
csv_file = pd.read_csv(DIR_TO_CSV)
train_csv, val_csv = random_train_val_split(csv_file, VAL_FRACTION, SPLIT_SEED)

In [5]:
# Validation dataset wrapped in a loader: one sample per batch, dataset order kept.
val_data = BinaryLoader(val_csv, ROOT_DIR)
val_loader = torch.utils.data.DataLoader(
    dataset=val_data,
    shuffle=False,
    batch_size=1,
)

In [6]:
# Build the UNet and restore the baseline checkpoint onto `device`.
# NOTE(review): later cells treat the network output as 128x128; confirm how
# UNet uses its (3, 512, 512) constructor argument.
model = UNet((3, 512, 512))
checkpoint = torch.load(
    "weights/unet_baseline_weighted_crossentropy0.233407.pth",
    map_location=device,
)
model.load_state_dict(checkpoint)
# This notebook only runs inference: eval() puts layers such as BatchNorm /
# Dropout (if the UNet uses any) into inference mode, and .to(device) keeps
# the weights on the same device the inputs are moved to.
model = model.to(device).eval()

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [7]:
def preview(img, gt, *args):
    """Show `img`, `gt` and any extra arrays side by side in a single figure row.

    Args:
        img: array shown in the first panel.
        gt: ground-truth mask shown in the second panel.
        *args: further arrays (e.g. predictions), one panel each, after `gt`.
    """
    panels = (img, gt) + args
    plt.figure(figsize=(20, 20))
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(panel)
    plt.show()


In [8]:
def viz_val_set(max_batches=None):
    """Plot image / ground truth / predicted class map for validation samples.

    Iterates the global `val_loader`, runs the global `model` on each batch and
    shows the first sample of every batch via `preview`.

    Args:
        max_batches: optional cap on how many batches to plot; ``None``
            (default) walks the whole loader, as before.
    """
    for batch_idx, (imgs, masks) in enumerate(tqdm(val_loader)):
        if max_batches is not None and batch_idx >= max_batches:
            break
        imgs = imgs.to(device)
        with torch.no_grad():
            logits = model(imgs)
        # Per-pixel class = argmax over dim 1. Softmax is monotonic, so it would
        # not change the argmax and is skipped. The spatial size is taken from
        # the output instead of being hard-coded.
        # .cpu() before .numpy() keeps this working if `device` is a GPU.
        pred = logits.argmax(dim=1)[0].cpu().numpy()
        img = imgs[0].cpu().numpy().transpose((1, 2, 0))
        preview(img, masks[0].cpu().numpy(), pred)

In [9]:
def create_pre_submit(in_dir, out_dir, input_size=128, output_size=1024):
    """Predict a connected-component label mask for every image in `in_dir`.

    Each file is processed as follows:
      1. Read as grayscale and resized to ``input_size`` x ``input_size``.
      2. Replicated to 3 channels and passed through the global `model`.
      3. The argmax class map is resized to ``output_size`` x ``output_size``
         with nearest-neighbour interpolation.
      4. Labelled with ``skimage.measure.label`` (background 0).
      5. Written to `out_dir` under the same file name.

    Args:
        in_dir: directory of images readable by ``cv2.imread``.
        out_dir: existing directory the label images are written to.
        input_size: side length fed to the network (default 128, as before).
        output_size: side length of the written masks (default 1024, as before).

    Raises:
        ValueError: if a file in `in_dir` cannot be decoded as an image.
    """
    # Local imports: elsewhere in the notebook these are only imported in a later cell.
    import os
    import cv2
    from skimage import measure

    for img_fn in tqdm(sorted(os.listdir(in_dir))):
        img_path = os.path.join(in_dir, img_fn)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"Could not read image: {img_path}")
        img = cv2.resize(img, (input_size, input_size))
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        # HWC -> 1xCxHxW float32. Raw 0-255 values are kept; no normalisation, as before.
        chw = np.ascontiguousarray(img.transpose(2, 0, 1))
        batch = torch.from_numpy(chw).unsqueeze(0).float().to(device)
        with torch.no_grad():
            logits = model(batch)
        # .cpu() before .numpy() so a GPU `device` also works. int32 is used
        # because cv2.resize has no int64 support.
        pred = logits.argmax(dim=1)[0].cpu().numpy().astype(np.int32)
        out = cv2.resize(pred, (output_size, output_size), interpolation=cv2.INTER_NEAREST)
        labels = measure.label(out, background=0)
        # Write an explicit 8-bit image; label ids above 255 are clipped to 255.
        cv2.imwrite(os.path.join(out_dir, img_fn), np.clip(labels, 0, 255).astype(np.uint8))



In [None]:
# Same visualisation as viz_val_set(), but run at top level. It leaves `imgs`,
# `masks` and `logits` (from the last batch) in the notebook namespace; the
# RLE-debugging cells below read `logits`.
for imgs, masks in tqdm(val_loader):
    imgs = imgs.to(device)
    masks = masks[0].cpu()
    with torch.no_grad():
        logits = model(imgs)
    imgs = imgs[0].cpu().numpy().transpose((1, 2, 0))
    # Class map with a leading batch axis, shape (1, H, W) as before. The size is
    # taken from the model output instead of hard-coding 128, and .cpu() lets
    # this run on a GPU `device`.
    logits = logits.argmax(dim=1).cpu().numpy()
    preview(imgs, masks, logits[0])

In [None]:
# Plot image / ground truth / prediction for every validation batch (viz_val_set defaults).
viz_val_set()

In [10]:
from utils.data_mapping import dcm2png
import os
import cv2
from skimage import measure
import shutil

In [None]:
# Recreate an empty "data/pre_out" and convert the test DICOM folder into it
# with dcm2png. The existence check lets this cell run on a fresh checkout
# where the folder does not exist yet (rmtree would raise FileNotFoundError).
if os.path.isdir("data/pre_out"):
    shutil.rmtree("data/pre_out")
os.makedirs("data/pre_out")
dcm2png("data/raw/dicom-images-test/", "data/pre_out/", v=1)

In [None]:
# Recreate an empty "data/out" and write the predicted label masks into it.
# The existence check lets this cell run on a fresh checkout where the folder
# does not exist yet (rmtree would raise FileNotFoundError).
if os.path.isdir("data/out"):
    shutil.rmtree("data/out")
os.makedirs("data/out")
create_pre_submit("data/pre_out", "data/out")

In [11]:
from utils.make_submission import create_submission

In [12]:
# Encode the masks in data/out into a submission CSV (the recorded run saved submission_baseline.csv).
create_submission('data/out', 'baseline', v=1)

  0%|          | 0/1377 [00:00<?, ?it/s]

Number of masks: 1377



100%|██████████| 1377/1377 [10:59<00:00,  1.92it/s]

Submission is saved to submission_baseline.csv successfully!!!





In [None]:
# NOTE(review): `logits` is not defined in this cell. It is the last-batch
# prediction left in the namespace by the top-level validation loop earlier in
# the notebook, so this cell fails on a fresh kernel unless that loop ran first.
test_img = logits[0]

In [None]:
# Quick look at the predicted class map used by the RLE round-trip checks below.
plt.imshow(test_img)
plt.title("test_img (predicted class map)")
plt.show()

In [None]:
from itertools import groupby

In [None]:
# Hand-rolled RLE encoding of `test_img`, for comparison with mask2rle below.
# Pixels are scanned column by column (via `.T`). Each foreground run is written
# as "<offset from end of previous run> <run length>". An empty mask encodes as "-1".
flat = list(test_img.T.reshape(-1))
data = []
gap = 0  # length of the background run preceding the current foreground run
# Group on "is foreground" rather than on the raw value. With raw values, two
# touching non-zero runs with different values would get a wrong offset.
for is_fg, run in groupby(flat, key=lambda v: v > 0):
    length = len(list(run))
    if is_fg:
        data.extend((gap, length))
    else:
        gap = length
if not data:
    print("-1")
res = " ".join(map(str, data)) if data else "-1"

In [None]:
from utils.mask_functions import mask2rle

In [None]:
# Encode the same mask with the project's mask2rle (values multiplied by 255).
# width/height are derived here so this cell does not depend on the next cell.
# NOTE(review): assumes mask2rle's (width, height) map to numpy (cols, rows);
# this makes no difference for a square mask.
height, width = test_img.shape
res = mask2rle(test_img * 255, width, height)

In [None]:
# Decode `res` back into a 0/255 mask and display it to eyeball the round trip.
# Format: space-separated pairs "<offset from end of previous run> <run length>".
# "-1" (or an empty string) means an empty mask.
width = 128
height = 128
rle = res
mask = np.zeros(width * height)
tokens = rle.split()
if tokens and tokens != ["-1"]:
    array = np.asarray([int(x) for x in tokens])
    if len(array) % 2:
        raise ValueError("Malformed RLE: expected (offset, length) pairs")
    starts = array[0::2]
    lengths = array[1::2]
    current_position = 0
    for start, length in zip(starts, lengths):
        current_position += start
        mask[current_position:current_position + length] = 255
        current_position += length
# NOTE(review): reshape fills row by row. The groupby encoder above scans
# test_img.T (column-major), so decoding its output gives test_img transposed.
# Check which encoder produced `res` before comparing orientations.
mask = mask.reshape(width, height)
plt.imshow(mask)
plt.show()