In [None]:
#from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import cv2
import sys
import numpy as np
from skimage import segmentation
import torch.nn.init

import pydicom
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage import measure
from scipy.fft import fft
from scipy.optimize import curve_fit

In [None]:
# Whether PyTorch can see a CUDA device; later cells move the data, model and
# targets to the GPU only when this flag is True.
use_cuda = bool(torch.cuda.is_available())
use_cuda

In [None]:


# Command-line style options (kept for parity with the original script).
# Inside Jupyter, sys.argv holds the kernel's own arguments (e.g. `-f kernel.json`),
# so plain parse_args() would abort the notebook with "unrecognized arguments";
# parse_known_args() ignores them. `--input` is optional here because no input
# can be supplied on a notebook kernel's command line.
parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--nChannel', metavar='N', default=100, type=int, 
                    help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=1000, type=int, 
                    help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int, 
                    help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float, 
                    help='learning rate')
parser.add_argument('--nConv', metavar='M', default=2, type=int, 
                    help='number of convolutional layers')
parser.add_argument('--num_superpixels', metavar='K', default=10000, type=int, 
                    help='number of superpixels')
parser.add_argument('--compactness', metavar='C', default=100, type=float, 
                    help='compactness of superpixels')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int, 
                    help='visualization flag')
parser.add_argument('--input', metavar='FILENAME', default=None,
                    help='input image file name')
args, _unknown_args = parser.parse_known_args()



In [None]:
# CNN model
nChannel = 100
nConv = 2

class MyNet(nn.Module):
    """Fully-convolutional feature network for unsupervised segmentation.

    Layers: a 3x3 conv, (n_conv - 1) further 3x3 convs, then a 1x1 conv. Each
    3x3 conv is followed by ReLU then BatchNorm; the final 1x1 conv only by
    BatchNorm. All convs use stride 1 with "same" padding, so an input of shape
    (batch, input_dim, H, W) yields (batch, n_channel, H, W). The notebook takes
    the per-pixel argmax over these channels as the segment label.

    Args:
        input_dim: number of input channels (1 for the grayscale images here).
        n_channel: channels per layer; None uses the notebook-level ``nChannel``.
        n_conv: number of 3x3 conv layers; None uses the notebook-level ``nConv``.
    """

    def __init__(self, input_dim, n_channel=None, n_conv=None):
        super(MyNet, self).__init__()
        # Read the notebook-level settings at construction time, as before.
        if n_channel is None:
            n_channel = nChannel
        if n_conv is None:
            n_conv = nConv
        self.conv1 = nn.Conv2d(input_dim, n_channel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(n_channel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for _ in range(n_conv - 1):
            self.conv2.append(nn.Conv2d(n_channel, n_channel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(n_channel))
        self.conv3 = nn.Conv2d(n_channel, n_channel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(n_channel)

    def forward(self, x):
        x = self.bn1(F.relu(self.conv1(x)))
        # Walk the layers actually built in __init__ instead of the global nConv,
        # which may have been changed after this model was constructed.
        for conv, bn in zip(self.conv2, self.bn2):
            x = bn(F.relu(conv(x)))
        return self.bn3(self.conv3(x))


In [None]:
# load image
filepath = '/kaggle/input/siim-covid19-detection/train/000c9c05fd14/e555410bd2cd/51759b5579bc.dcm'
im = resize(pydicom.dcmread(filepath).pixel_array, (512, 512))
# Min-max normalise to [0, 1]. A constant image has zero range; skip the
# division there instead of producing an all-NaN array.
im = im.astype(float) - im.min()
im_range = im.max()
if im_range > 0:
    im = im / im_range
# The CNN input tensor; torch.autograd.Variable wrapping is unnecessary for plain tensors.
data = torch.from_numpy(im)
if use_cuda:
    data = data.cuda()

plt.figure(figsize=(10, 10))
plt.imshow(im, cmap=plt.cm.bone);

In [None]:
# segmentation: compare Felzenszwalb superpixels for several minimum segment sizes,
# with the original image in the last panel for reference.
min_sizes = [50, 100, 200]

fig, axs = plt.subplots(2, 2, figsize=(20, 20))

# axs.flat walks the grid row-major: (0,0), (0,1), (1,0), (1,1)
for ax, min_size in zip(axs.flat, min_sizes):
    labels = segmentation.felzenszwalb(im, scale=1, sigma=0.8, min_size=min_size)
    ax.imshow(segmentation.mark_boundaries(im, labels))
    ax.set_title(f'Felzenszwalb superpixels, min_size={min_size}')

axs[1, 1].imshow(im, cmap=plt.cm.bone)
axs[1, 1].set_title('Original image');

In [None]:
# Superpixels used during training: each superpixel is relabelled with its
# majority CNN label. `l_inds[k]` holds the flat pixel indices of superpixel k.
labels = segmentation.felzenszwalb(im, scale=1, sigma=0.8, min_size=100)
labels = labels.reshape(im.shape[0]*im.shape[1])
u_labels, label_pos = np.unique(labels, return_inverse=True)
# Group pixels by superpixel with one stable sort instead of one full-image
# np.where scan per label (O(n log n) rather than O(n * n_superpixels)). The
# stable sort keeps each group in ascending index order, as np.where returns it.
order = np.argsort(label_pos, kind='stable')
l_inds = np.split(order, np.cumsum(np.bincount(label_pos))[:-1])

plt.figure(figsize=(10, 10))
plt.imshow(segmentation.mark_boundaries(im, labels.reshape(im.shape)))
plt.title('Felzenszwalb superpixels, min_size=100');

In [None]:
# train: fit a fresh MyNet to this single image, using its own argmax labels
# (smoothed by superpixel majority vote) as the cross-entropy target.
torch.manual_seed(0)  # fixed seed so re-running the notebook reproduces the label map
model = MyNet(1)
if use_cuda:
    model.cuda()
model.train()
loss_fn = torch.nn.CrossEntropyLoss()

lr = 0.1
maxIter = 1000
minLabels = 5

# The network input is the same every iteration: build the (1, 1, H, W) float
# tensor once instead of reshaping/casting inside the loop.
net_input = data.reshape((1, 1, im.shape[0], im.shape[1])).float()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
for batch_idx in range(maxIter):
    # forwarding
    optimizer.zero_grad()
    output = model(net_input)[0]
    # (nChannel, H, W) -> (H*W, nChannel): one row of class scores per pixel
    output = output.permute(1, 2, 0).contiguous().view(-1, nChannel)
    im_target = output.argmax(dim=1).detach().cpu().numpy()
    nLabels = len(np.unique(im_target))

    # superpixel refinement: every pixel of a superpixel gets its majority label
    # (np.unique sorts labels, so ties go to the smallest label, as before)
    for sp_inds in l_inds:
        sp_labels, sp_counts = np.unique(im_target[sp_inds], return_counts=True)
        im_target[sp_inds] = sp_labels[np.argmax(sp_counts)]
    target = torch.from_numpy(im_target)
    if use_cuda:
        target = target.cuda()
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    print(batch_idx, '/', maxIter, ':', nLabels, loss.item())

    # nLabels is measured before refinement; stop once the CNN itself has
    # collapsed to few enough distinct labels.
    if nLabels <= minLabels:
        print("nLabels", nLabels, "reached minLabels", minLabels, ".")
        break


In [None]:
# Predict the final per-pixel label map. No gradients are needed; the model is
# left in train mode, as in the original notebook, so BatchNorm still normalises
# with this image's own statistics.
with torch.no_grad():
    output = model(data.reshape((1, 1, im.shape[0], im.shape[1])).float())[0]
    output = output.permute(1, 2, 0).contiguous().view(-1, nChannel)
    target = output.argmax(dim=1)
im_target = target.cpu().numpy().reshape(im.shape[0], im.shape[1])
print(im_target.shape)

fig, axs = plt.subplots(1, 2, figsize=(20, 10))
axs[0].imshow(im, cmap=plt.cm.bone)
axs[0].set_title('Input image')
axs[1].imshow(im_target)
axs[1].set_title('CNN label map');

In [None]:
# Summarise the label map instead of echoing the raw H x W array:
# each distinct CNN label mapped to its pixel count.
segment_ids, segment_sizes = np.unique(im_target, return_counts=True)
dict(zip(segment_ids.tolist(), segment_sizes.tolist()))

In [None]:
# Horizontal intensity profiles at five rows. The lung heuristic below assumes
# the lungs show up as the low points of such a profile.
profile_rows = [50 + 100 * i for i in range(5)]
fig, axs = plt.subplots(1, 5, figsize=(20, 10))
for ax, row in zip(axs, profile_rows):
    # x-axis follows the image width instead of a hard-coded 512
    ax.plot(np.arange(im.shape[1]), im[row])
    ax.set_title(f'row {row}')

In [None]:
# Overlay a hand-tuned sinusoid 0.5 + 0.3*sin(0.03*x - a) on row 150 to check
# by eye that a fixed-frequency sine roughly tracks the row profile.
a = 0.0  # phase offset; must be defined here or this cell raises NameError on a fresh kernel
xs = np.arange(im.shape[1])  # pixel columns; also the x grid for the curve fits below
plt.plot(xs, 0.5 + np.sin(xs * 0.03 - a) * 0.3)
plt.plot(xs, im[150]);

In [None]:
fig, axs = plt.subplots(1, 5, figsize=(20, 10))

def f(x, a, b):
    """Row-profile model 0.5 + b * sin(0.03 * x - a).

    Only the phase `a` and amplitude `b` are fitted; the 0.03 rad/pixel
    frequency is fixed (it appears to have been chosen by eye on the
    512-pixel-wide rows above).
    """
    return 0.5 + np.sin(x * 0.03 - a) * b

profile_rows = [50 + 100 * i for i in range(5)]
for ax, row in zip(axs, profile_rows):
    profile = im[row]
    popt, pcov = curve_fit(f, xs, profile, p0=[0, 0.3])
    ys = f(xs, *popt)
    # parameter covariance and worst-case absolute residual of the fit
    print(pcov, np.max(np.abs(ys - profile)))
    # plot raw and fitted curves on the same grid the fit used
    ax.plot(xs, profile)
    ax.plot(xs, ys)
    ax.set_title(f'row {row}')

In [None]:
# Fit the sinusoid on one row, take its two lowest samples, and use the CNN
# segments under those pixels as the lung mask.
i = 2
row = 50 + 100 * i
popt, pcov = curve_fit(f, xs, im[row], p0=[0, 0.3])
ys = f(xs, *popt)

# Draw in a new figure; the previous cell's `axs` are already rendered.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(xs, im[row])
ax.plot(xs, ys)
ax.set_title(f'row {row}: profile and fitted sinusoid')

lowest_cols = np.argsort(ys)[:2]
print(lowest_cols)

# `components` only exists inside detect_lungs; at notebook level the CNN
# label map is `im_target` from the prediction cell.
components = im_target.reshape(im.shape[0], im.shape[1])
indexes = components[row, lowest_cols]
plt.figure(figsize=(10, 10))
mask = np.isin(components, indexes)
plt.imshow(mask);

In [None]:
# CNN model (re-declared so this cell can run on its own; keep in sync with the
# definition earlier in the notebook)
nChannel = 100
nConv = 2

class MyNet(nn.Module):
    """Fully-convolutional feature network for unsupervised segmentation.

    Layers: a 3x3 conv, (n_conv - 1) further 3x3 convs, then a 1x1 conv. Each
    3x3 conv is followed by ReLU then BatchNorm; the final 1x1 conv only by
    BatchNorm. All convs use stride 1 with "same" padding, so an input of shape
    (batch, input_dim, H, W) yields (batch, n_channel, H, W). The notebook takes
    the per-pixel argmax over these channels as the segment label.

    Args:
        input_dim: number of input channels (1 for the grayscale images here).
        n_channel: channels per layer; None uses the notebook-level ``nChannel``.
        n_conv: number of 3x3 conv layers; None uses the notebook-level ``nConv``.
    """

    def __init__(self, input_dim, n_channel=None, n_conv=None):
        super(MyNet, self).__init__()
        # Read the notebook-level settings at construction time, as before.
        if n_channel is None:
            n_channel = nChannel
        if n_conv is None:
            n_conv = nConv
        self.conv1 = nn.Conv2d(input_dim, n_channel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(n_channel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for _ in range(n_conv - 1):
            self.conv2.append(nn.Conv2d(n_channel, n_channel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(n_channel))
        self.conv3 = nn.Conv2d(n_channel, n_channel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(n_channel)

    def forward(self, x):
        x = self.bn1(F.relu(self.conv1(x)))
        # Walk the layers actually built in __init__ instead of the global nConv,
        # which may have been changed after this model was constructed.
        for conv, bn in zip(self.conv2, self.bn2):
            x = bn(F.relu(conv(x)))
        return self.bn3(self.conv3(x))


def _load_normalized_dicom(filepath, size):
    """Read a DICOM, resize its pixels to (size, size) and min-max scale to [0, 1]."""
    im = resize(pydicom.dcmread(filepath).pixel_array, (size, size))
    im = im.astype(float) - im.min()
    im_range = im.max()
    if im_range > 0:  # constant image: skip scaling instead of dividing by zero
        im = im / im_range
    return im


def _superpixel_indices(im, min_size):
    """Felzenszwalb superpixels of `im` as a list of flat pixel-index arrays."""
    labels = segmentation.felzenszwalb(im, scale=1, sigma=0.8, min_size=min_size).ravel()
    _, label_pos = np.unique(labels, return_inverse=True)
    # One stable sort groups the pixels of every superpixel (ascending index
    # order) instead of one full-image np.where scan per superpixel.
    order = np.argsort(label_pos, kind='stable')
    return np.split(order, np.cumsum(np.bincount(label_pos))[:-1])


def _train_and_predict(net_input, l_inds, max_iter, min_labels, lr):
    """Train a fresh MyNet on one image with superpixel-refined self-labels; return flat labels."""
    model = MyNet(1)
    if use_cuda:
        model.cuda()
    model.train()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for _ in range(max_iter):
        optimizer.zero_grad()
        # (nChannel, H, W) -> (H*W, nChannel): one row of class scores per pixel
        output = model(net_input)[0].permute(1, 2, 0).contiguous().view(-1, nChannel)
        im_target = output.argmax(dim=1).detach().cpu().numpy()
        n_labels = len(np.unique(im_target))
        # Each superpixel takes its majority label (ties -> smallest label).
        for sp_inds in l_inds:
            sp_labels, sp_counts = np.unique(im_target[sp_inds], return_counts=True)
            im_target[sp_inds] = sp_labels[np.argmax(sp_counts)]
        target = torch.from_numpy(im_target)
        if use_cuda:
            target = target.cuda()
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if n_labels <= min_labels:
            break
    # Final prediction; model stays in train mode as in the original notebook.
    with torch.no_grad():
        output = model(net_input)[0].permute(1, 2, 0).contiguous().view(-1, nChannel)
        return output.argmax(dim=1).cpu().numpy()


def _best_sinusoid_row(im, freq):
    """Among 7 evenly spaced rows, return (row, fitted curve) best matched by the sinusoid."""
    def f(x, a, b):
        return 0.5 + np.sin(x * freq - a) * b

    xs = np.arange(im.shape[1])
    best_row, best_ys, best_score = None, None, np.inf
    for k in range(1, 8):
        row = im.shape[0] // 8 * k  # rows are spaced over the height, not the width
        try:
            popt, _ = curve_fit(f, xs, im[row], p0=[0, 0.3])
        except RuntimeError:
            continue  # least squares did not converge on this row; try the others
        ys = f(xs, *popt)
        score = np.max(np.abs(ys - im[row]))  # worst-case absolute residual
        if score < best_score:
            best_row, best_ys, best_score = row, ys, score
    if best_row is None:
        raise RuntimeError('sinusoid fit failed on every candidate row')
    return best_row, best_ys


def detect_lungs(filepath, size=256, max_iter=100, min_labels=10, lr=0.1,
                 min_size=50, freq=None):
    """Segment the lungs in a chest DICOM with an unsupervised CNN plus a row heuristic.

    Steps: load and normalise the image; compute Felzenszwalb superpixels; train
    a fresh MyNet on this image alone with superpixel-refined argmax targets;
    pick the candidate row best fit by a fixed-frequency sinusoid; return the CNN
    segments under the two lowest samples of that fitted sinusoid. This assumes
    the lungs are the darkest parts of that row profile.

    Args:
        filepath: path to a DICOM file readable by pydicom.
        size: side length the image is resized to before processing.
        max_iter: maximum training iterations.
        min_labels: stop training once the CNN predicts at most this many labels.
        lr: SGD learning rate (momentum 0.9).
        min_size: Felzenszwalb ``min_size`` for the refinement superpixels.
        freq: sinusoid frequency in rad/pixel. None rescales the 0.03 rad/px
            tuned on 512-px-wide rows to this width, so the same number of
            cycles spans the row.

    Returns:
        Boolean mask of shape (size, size).

    Raises:
        RuntimeError: if the sinusoid fit fails on every candidate row.
    """
    im = _load_normalized_dicom(filepath, size)
    data = torch.from_numpy(im)
    if use_cuda:
        data = data.cuda()
    net_input = data.reshape((1, 1, im.shape[0], im.shape[1])).float()

    l_inds = _superpixel_indices(im, min_size)
    components = _train_and_predict(net_input, l_inds, max_iter, min_labels, lr).reshape(im.shape)

    if freq is None:
        freq = 0.03 * 512 / im.shape[1]
    best_row, ys = _best_sinusoid_row(im, freq)

    # CNN segments under the two lowest points of the fitted curve on that row.
    indexes = components[best_row, np.argsort(ys)[:2]]
    return np.isin(components, indexes)

In [None]:
# Run the lung detector on a sample study (the commented path is an alternative example).
#filepath = '/kaggle/input/siim-covid19-detection/train/000c9c05fd14/e555410bd2cd/51759b5579bc.dcm'
filepath = '/kaggle/input/siim-covid19-detection/train/ff0879eb20ed/d8a644cc4f93/000c3a3f293f.dcm'
mask = detect_lungs(filepath)

# detect_lungs resizes the DICOM internally, so display the image at the
# mask's resolution to keep both panels on the same pixel grid.
im = resize(pydicom.dcmread(filepath).pixel_array, mask.shape)
im = im - im.min()
if im.max() > 0:  # constant image: skip scaling instead of dividing by zero
    im = im.astype(float) / im.max()

fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].imshow(im, cmap=plt.cm.bone)
axs[0].set_title('Input image')
axs[1].imshow(mask)
axs[1].set_title('Detected lung mask');

In [None]:
# Second example study. Use a descriptive name: `f` would overwrite the
# sinusoid model f(x, a, b) used by the curve-fit cells above.
dcm_path = '/kaggle/input/siim-covid19-detection/train/9d514ce429a7/22897cd1daa0/0012ff7358bc.dcm'
mask = detect_lungs(dcm_path)

# Display the image at the mask's resolution so both panels line up.
im = resize(pydicom.dcmread(dcm_path).pixel_array, mask.shape)
im = im - im.min()
if im.max() > 0:  # constant image: skip scaling instead of dividing by zero
    im = im.astype(float) / im.max()

fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].imshow(im, cmap=plt.cm.bone)
axs[0].set_title('Input image')
axs[1].imshow(mask)
axs[1].set_title('Detected lung mask');

In [None]:
# Coarser Felzenszwalb superpixels (min_size=1000) on the current image;
# `l_inds[k]` holds the flat pixel indices of superpixel k.
labels = segmentation.felzenszwalb(im, scale=1, sigma=0.8, min_size=1000)
labels = labels.reshape(im.shape[0]*im.shape[1])
u_labels, label_pos = np.unique(labels, return_inverse=True)
# One stable sort groups pixels per superpixel (ascending index order, as
# np.where would give) instead of a full-image scan per superpixel.
order = np.argsort(label_pos, kind='stable')
l_inds = np.split(order, np.cumsum(np.bincount(label_pos))[:-1])

plt.figure(figsize=(10, 10))
plt.imshow(segmentation.mark_boundaries(im, labels.reshape(im.shape)))
plt.title('Felzenszwalb superpixels, min_size=1000');