In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import os
import torchvision.transforms as tr
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pydicom
import glob
import collections
from datetime import datetime
from skimage import measure
from skimage.measure import block_reduce
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from skimage.transform import resize
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import seaborn as sns
from collections import defaultdict

In [None]:
import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation notices from the libraries used below;
# consider a narrower filter (category/module) instead of a blanket 'ignore'.
warnings.filterwarnings('ignore')

In [None]:
# Annotation table; later cells use its image_id, class_id and x_min/y_min/x_max/y_max columns.
TRAIN_CSV = '../input/vinbigdata-chest-xray-abnormalities-detection/train.csv'
df = pd.read_csv(TRAIN_CSV)
df.head()

In [None]:
input_file = os.listdir('../input/vinbigdata-chest-xray-abnormalities-detection/train')
# Strip the file extension ('.dicom') so the names match df['image_id'].
input_files = [file_name.split('.')[0] for file_name in input_file]

In [None]:
input_files[0]  # peek at one image id (extension already stripped)

In [None]:
# All annotation rows for one sample image (index 3 is an arbitrary pick), ordered by class.
df.loc[df['image_id'].eq(input_files[3])].sort_values(by='class_id')

In [None]:
class traindataset(torch.utils.data.Dataset):
    """Chest X-ray dataset yielding ``(image, box_mask)`` pairs.

    Each item reads ``<image_dir>/<image_id>.dicom``, resizes it to
    ``image_size x image_size`` and rasterises every annotated box of that
    image into a ``(num_classes, image_size, image_size)`` float32 mask
    (1 inside a box, 0 elsewhere).

    Args:
        df: annotation table with columns image_id, class_id, x_min, y_min,
            x_max, y_max (box coordinates in original-image pixels).
        file_list: image ids (without extension) served by ``idx``.
        transform: optional callable applied to the image tensor only. The
            mask is NOT transformed, so only use transforms that do not move
            pixels. ``None`` disables it.
        image_dir: directory holding the ``.dicom`` files.
        image_size: side length of the square output image and mask.
        num_classes: number of mask channels.
    """

    # Rows with this class_id are never drawn into the mask.
    # NOTE(review): presumably the "no finding" rows that carry no box; confirm.
    NO_FINDING_CLASS = 14

    def __init__(self, df, file_list, transform,
                 image_dir='../input/vinbigdata-chest-xray-abnormalities-detection/train',
                 image_size=1024, num_classes=15):
        super().__init__()
        self.file_list = file_list
        self.df = df
        self.transform = transform
        self.image_dir = image_dir
        self.image_size = image_size
        self.num_classes = num_classes

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _boxes_to_mask(annotations, scale_x, scale_y, size, num_classes,
                       skip_class=NO_FINDING_CLASS):
        """Rasterise scaled boxes into a ``(num_classes, size, size)`` float32 mask.

        Corners are scaled by ``(scale_x, scale_y)`` and truncated with ``int()``,
        so the min edge is inclusive and the max edge exclusive. Rows whose
        class_id equals ``skip_class`` are ignored before their coordinates are used.
        """
        mask = np.zeros((num_classes, size, size), dtype=np.float32)
        columns = ['class_id', 'x_min', 'y_min', 'x_max', 'y_max']
        for class_id, x_min, y_min, x_max, y_max in annotations[columns].itertuples(index=False):
            if class_id == skip_class:
                continue
            x1, x2 = int(x_min * scale_x), int(x_max * scale_x)
            y1, y2 = int(y_min * scale_y), int(y_max * scale_y)
            mask[int(class_id), y1:y2, x1:x2] = 1
        return mask

    def __getitem__(self, idx):
        img_id = self.file_list[idx]
        size = self.image_size

        dicom_path = os.path.join(self.image_dir, img_id + '.dicom')
        # dcmread replaces read_file, a deprecated alias removed in pydicom 3.0.
        img_pxl = pydicom.dcmread(dicom_path).pixel_array
        # NOTE(review): skimage's resize converts integer input to float and, without
        # preserve_range=True, rescales by the dtype range; confirm this is intended.
        img_res = resize(img_pxl, (size, size), anti_aliasing=True)
        img_tr = torch.from_numpy(img_res.astype(np.float32))
        if self.transform is not None:
            img_tr = self.transform(img_tr)

        annotations = self.df[self.df['image_id'] == img_id]
        mask = self._boxes_to_mask(
            annotations,
            scale_x=size / img_pxl.shape[1],  # columns -> x
            scale_y=size / img_pxl.shape[0],  # rows -> y
            size=size,
            num_classes=self.num_classes,
        )
        return img_tr, torch.from_numpy(mask)

In [None]:
# Positional order of traindataset is (df, file_list, transform).
traindata = traindataset(df, input_files, None)
data_loader = torch.utils.data.DataLoader(
    traindata,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Smoke-test the loader: print shapes rather than full 1024x1024 tensors.
for batch_images, batch_masks in data_loader:
    print(batch_images.shape, batch_masks.shape)
    break

from matplotlib.patches import Rectangle
# traindata yields (image, mask) only, so boxes are recovered from the mask:
# each connected region of a class channel is outlined. Overlapping boxes of
# the same class merge into one region.
for k, (image, mask) in enumerate(traindata):
    plt.imshow(image, cmap='gray')
    for channel in mask.numpy():
        for region in measure.regionprops(measure.label(channel > 0)):
            min_row, min_col, max_row, max_col = region.bbox
            plt.gca().add_patch(Rectangle((min_col, min_row), max_col - min_col, max_row - min_row,
                                          linewidth=1, edgecolor='b', facecolor='none'))
    plt.show()
    # Only show the class channels that actually contain a box.
    for class_idx, channel in enumerate(mask):
        if channel.any():
            plt.imshow(channel)
            plt.title(f'class {class_idx}')
            plt.show()
    if k == 10:
        break

In [None]:
from collections import OrderedDict
import torch
import torch.nn as nn

class UNet(nn.Module):
    """Four-level U-Net with a sigmoid output head.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        init_features: width of the first encoder stage; each deeper stage doubles it.

    H and W of the input must be divisible by 16 (four 2x poolings); otherwise
    the skip concatenations see mismatched spatial sizes.

    Attribute names (encoderK, poolK, bottleneck, upconvK, decoderK, conv) and
    the per-block layer names form the state_dict keys. Keep them stable so
    saved checkpoints still load.
    """

    def __init__(self, in_channels=1, out_channels=15, init_features=32):
        super(UNet, self).__init__()

        # widths[0] .. widths[3] for the encoder/decoder stages, widths[4] for the bottleneck.
        widths = [init_features * 2 ** depth for depth in range(5)]

        # Contracting path. Construction order decides which random numbers each
        # layer's initialisation draws, so encoder/pool stay interleaved as
        # enc1, pool1, enc2, pool2, ...
        channels_in = in_channels
        for level in range(1, 5):
            setattr(self, f"encoder{level}",
                    UNet._block(channels_in, widths[level - 1], name=f"enc{level}"))
            setattr(self, f"pool{level}", nn.MaxPool2d(kernel_size=2, stride=2))
            channels_in = widths[level - 1]

        self.bottleneck = UNet._block(widths[3], widths[4], name="bottleneck")

        # Expanding path. upconvK halves the width. decoderK consumes the
        # concatenation with encoderK's output, hence twice the channels.
        for level in range(4, 0, -1):
            setattr(self, f"upconv{level}",
                    nn.ConvTranspose2d(widths[level], widths[level - 1], kernel_size=2, stride=2))
            setattr(self, f"decoder{level}",
                    UNet._block(widths[level - 1] * 2, widths[level - 1], name=f"dec{level}"))

        self.conv = nn.Conv2d(in_channels=widths[0], out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        h = x
        for level in range(1, 5):
            h = getattr(self, f"encoder{level}")(h)
            skips.append(h)  # kept for the matching decoder stage
            h = getattr(self, f"pool{level}")(h)

        h = self.bottleneck(h)

        for level in range(4, 0, -1):
            h = getattr(self, f"upconv{level}")(h)
            h = torch.cat((h, skips[level - 1]), dim=1)
            h = getattr(self, f"decoder{level}")(h)

        return torch.sigmoid(self.conv(h))

    @staticmethod
    def _block(in_channels, features, name):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W unchanged."""
        layers = OrderedDict()
        for idx, stage_in in ((1, in_channels), (2, features)):
            # No conv bias: the following BatchNorm adds its own learned shift.
            layers[f"{name}conv{idx}"] = nn.Conv2d(
                in_channels=stage_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[f"{name}norm{idx}"] = nn.BatchNorm2d(num_features=features)
            layers[f"{name}relu{idx}"] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

In [None]:
# Single-channel X-ray in, one output channel per mask class (15).
model_unet = UNet(in_channels=1, out_channels=15, init_features=32).cuda()
# Pixel-wise MSE between the sigmoid outputs and the binary masks.
calc = nn.MSELoss()
optimizer = optim.Adamax(model_unet.parameters(), lr=3e-4)

In [None]:
def dice_loss(pred, target, smooth = 1.):
    """Soft Dice loss, averaged over every (batch, channel) pair.

    Args:
        pred: predicted probabilities, same shape as ``target`` (at least 4-D);
            the two axes after the channel axis (dims 2 and 3) are reduced.
        target: ground-truth mask.
        smooth: additive smoothing on numerator and denominator; keeps empty
            masks from dividing by zero.

    Returns:
        Scalar tensor ``mean(1 - (2*|P∩T| + smooth) / (|P| + |T| + smooth))``.
    """
    def spatial_sum(t):
        # Two successive reductions over axis 2 = original axes 2 and 3.
        return t.sum(dim=2).sum(dim=2)

    p = pred.contiguous()
    t = target.contiguous()

    numerator = 2. * spatial_sum(p * t) + smooth
    denominator = spatial_sum(p) + spatial_sum(t) + smooth
    return (1 - numerator / denominator).mean()

In [None]:
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Weighted BCE + soft-Dice loss for multi-label masks.

    Args:
        pred: raw logits (no sigmoid applied). NOTE: UNet.forward in this
            notebook already returns sigmoid probabilities, so its outputs must
            not be passed here unchanged (the sigmoid would be applied twice).
        target: binary mask, same shape as ``pred``.
        metrics: mutable mapping (e.g. ``defaultdict(float)``); batch-size-weighted
            'bce', 'dice' and 'loss' totals are added into it.
        bce_weight: weight of the BCE term; Dice gets ``1 - bce_weight``.

    Returns:
        Scalar loss tensor (still attached to the graph for ``backward()``).
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # torch.sigmoid replaces the deprecated F.sigmoid.
    probs = torch.sigmoid(pred)
    dice = dice_loss(probs, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() yields detached Python floats (no .data / numpy round-trip).
    batch_size = target.size(0)
    metrics['bce'] += bce.item() * batch_size
    metrics['dice'] += dice.item() * batch_size
    metrics['loss'] += loss.item() * batch_size

    return loss

In [None]:
# Running totals that calc_loss adds into under 'bce', 'dice' and 'loss'; each key starts at 0.0.
metrics = defaultdict(float)

In [None]:
epochs = 1
steps = 0
print_every = 10  # log running averages every `print_every` optimizer steps
train_losses, train_accuracy = [], []
for epoch in range(epochs):
    # Stay in training mode for the whole epoch. Calling eval() inside the
    # batch loop would make BatchNorm use frozen running stats while training.
    model_unet.train()
    running_loss = 0.0
    acc = 0.0
    n_batches = 0
    for image_train, y_train in data_loader:
        steps += 1
        n_batches += 1
        # (B, H, W) -> (B, 1, H, W): the channel axis goes after the batch axis.
        # The input itself needs no gradient.
        image_train = image_train.unsqueeze(1).cuda()
        y_train = y_train.cuda()

        optimizer.zero_grad()
        y_predtrain = model_unet(image_train)
        loss = calc(y_predtrain, y_train)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Multi-label pixel accuracy: each class channel is an independent
        # binary mask, so threshold the probabilities at 0.5 instead of taking
        # an argmax over classes and comparing it with 0/1 mask values.
        equals = (y_predtrain > 0.5) == (y_train > 0.5)
        acc += equals.float().mean().item()

        if steps % print_every == 0:
            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {running_loss/n_batches:.3f}.. "
                  f"Train accuracy: {acc/n_batches:.3f}")
    torch.save(model_unet.state_dict(),'./'+str(epoch)+'unet_model.pth')
    # Per-batch averages; max() guards against an empty loader.
    train_losses.append(running_loss / max(n_batches, 1))
    train_accuracy.append(acc / max(n_batches, 1))
    print('train_losses',epoch,train_losses)
    print('train_accuracy',epoch,train_accuracy)
# Switch to inference mode explicitly for the prediction cells that follow.
model_unet.eval()
torch.save(model_unet.state_dict(),'./final_unet_model.pth')
print('model saved')
print('train_losses',epoch,train_losses)
print('train_accuracy',epoch,train_accuracy)

In [None]:
# Replace the weights trained above with the saved checkpoint from '../input/unet-op'.
# The checkpoint's keys must match UNet's module/layer names.
model_unet.load_state_dict(torch.load('../input/unet-op/60unet_model.pth'))

In [None]:
# Preprocess one training image exactly as traindataset does, then predict its masks.
# dcmread replaces read_file, a deprecated alias removed in pydicom 3.0.
img_pxl = pydicom.dcmread('../input/vinbigdata-chest-xray-abnormalities-detection/train/000d68e42b71d3eac10ccc077aba07c1.dicom').pixel_array
img_res = resize(img_pxl,(1024,1024),anti_aliasing=True)
img_np = img_res.astype(np.float32)
# (H, W) -> (1, 1, H, W): batch and channel axes.
img_tr = torch.from_numpy(img_np).unsqueeze(0).unsqueeze(0).cuda()
# Inference: BatchNorm must use its running statistics, and no autograd graph is needed.
model_unet.eval()
with torch.no_grad():
    output = model_unet(img_tr)

In [None]:
# Move to CPU, cut the autograd link and drop size-1 axes (the batch axis),
# leaving one 2-D map per output channel.
predict = output.cpu().detach().squeeze()
print(predict.shape)

In [None]:
# One probability map per output channel; the title shows the class index
# (the channel position used when the masks were built).
for class_idx, class_map in enumerate(predict):
    plt.imshow(class_map)
    plt.title(f'predicted mask, class {class_idx}')
    plt.colorbar()
    plt.show()