In [None]:
# Import stuff
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim
import pandas as pd
from torch.utils.data import Dataset
from glob import glob
import random
import torch.distributed as dist
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from torchvision.transforms import ToTensor
import torchvision
import cv2
from sklearn.metrics import roc_curve, auc
import os
import pytorch_lightning as pl
import torch.nn.functional as F
import re
import ast
import random
from torchvision import models
from pytorch_lightning.loggers import WandbLogger
from scipy import signal

In [None]:
class Dataset():
    """Map-style dataset of cropped chest X-rays with an ETT-tip heatmap target.

    Each item is a dict with:
        'Image'       -- tensor from ToTensor() of the cropped grayscale image,
                         rescaled to [-1, 1].
        'GT'          -- float64 array: Gaussian bump (peak 1.0) at the ETT tip,
                         cropped with the same window as the image.
        'Code'        -- StudyInstanceUID, taken from the image file stem.
        'ETT_coords'  -- np.array([x, y]) of the tip in ORIGINAL (uncropped) pixels.
        'Crop_offset' -- np.array([x, y]) position of the crop window's top-left
                         corner inside the original image, so that
                         cropped coordinates + offset = original coordinates.

    Args:
        file_list: image paths whose file stem is the StudyInstanceUID.
        mode: tag such as "Train"/"Val"; stored but not used by this class.
        ett_annotations: DataFrame with 'StudyInstanceUID' and 'data' columns;
            'data' holds a string list of [x, y] points along the tube.
        carina_annotations: stored but not used by this class.
        crop_size: side length of the square crop window (default 1280).
    """

    def __init__(self, file_list, mode, ett_annotations, carina_annotations, crop_size=1280):
        self.file_list = file_list
        self.mode = mode
        self.ett_annotations = ett_annotations
        self.carina_annotations = carina_annotations
        self.crop_size = crop_size

    def __getitem__(self, idx):
        image_filepath = self.file_list[idx]
        image, code, original_imsize = self.loadimage(image_filepath)
        gt_map, ett_coords = self.get_GT(code, original_imsize)
        # The crop is anchored at row 0, so only the x offset can be non-zero.
        crop_x = self._crop_start_x(original_imsize[1], self.crop_size)

        return {
            'Image': image,
            'GT': gt_map,
            'Code': code,
            'ETT_coords': ett_coords,
            'Crop_offset': np.array([crop_x, 0]),
        }

    def __len__(self):
        return len(self.file_list)

    def loadimage(self, image_filepath):
        """Read a grayscale image, crop it and rescale it to [-1, 1].

        Returns:
            (image_tensor, study_uid, original_shape)
        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img = cv2.imread(image_filepath, 0)  # flag 0 -> grayscale
        if img is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError('Could not read image: {}'.format(image_filepath))
        original_imsize = img.shape
        img = self.crop_center(img, self.crop_size, self.crop_size)
        image = ToTensor()(img)
        peak = torch.max(image)
        if peak > 0:
            image = 2 * (image / peak) - 1
        else:
            # All-black crop: avoid 0/0 = NaN; every pixel maps to -1.
            image = torch.full_like(image, -1.0)
        # The file stem is the StudyInstanceUID, regardless of directory depth.
        code = os.path.splitext(os.path.basename(image_filepath))[0]
        return image, code, original_imsize

    def get_GT(self, code, original_imsize):
        """Build the cropped heatmap target for one study.

        The tube tip is taken as the annotated point with the largest y
        (the lowest point in the image).

        Returns:
            (gt_map, ett_coords), where ett_coords = np.array([x, y]) in
            original-image pixels.
        Raises:
            ValueError: if ``code`` is not in the annotations.
        """
        ids = self.ett_annotations['StudyInstanceUID'].tolist()
        index_ett = ids.index(code)

        points = self.str2array(self.ett_annotations['data'].iloc[index_ett])
        x_ett, y_ett = points[np.argmax(points[:, 1])]
        ett_coords = np.array([x_ett, y_ett])

        gt_map_et = self.get_gaussian_kernel(ett_coords, original_imsize)
        gt_map_et = self.crop_center(gt_map_et, self.crop_size, self.crop_size)
        return gt_map_et, ett_coords

    @staticmethod
    def _crop_start_x(width, cropx):
        """Left edge of a horizontally centred crop, clamped so it never goes negative."""
        return max(width // 2 - cropx // 2, 0)

    def crop_center(self, img, cropx, cropy):
        """Crop a window up to cropx wide and cropy tall.

        The window is centred horizontally and anchored at the top row. It is
        smaller than requested when the image itself is smaller.
        """
        startx = self._crop_start_x(img.shape[1], cropx)
        return img[0:cropy, startx:startx + cropx]

    def str2array(self, s):
        """Parse a numpy-style or list-style string of points into an ndarray."""
        # Remove spaces right after '['.
        s = re.sub(r'\[ +', '[', s.strip())
        # Collapse every run of commas and whitespace into a single ', '.
        s = re.sub(r'[,\s]+', ', ', s)
        return np.array(ast.literal_eval(s))

    def get_gaussian_kernel(self, coords, imshape, size=201, std=20):
        """Return a float64 zero map of shape imshape[:2] with a Gaussian bump at coords.

        coords is (x, y) = (column, row). The bump is a size x size separable
        Gaussian (peak exactly 1.0 for odd ``size``). It is clipped at the image
        borders, so points near any edge still get a partial bump.

        Raises:
            IndexError: if coords lie outside the image.
        """
        rows, cols = imshape[0], imshape[1]
        cx, cy = int(round(coords[0])), int(round(coords[1]))
        if not (0 <= cx < cols and 0 <= cy < rows):
            raise IndexError('ETT coords {} outside image of shape {}'.format((cx, cy), (rows, cols)))

        k1d = signal.windows.gaussian(size, std=std)
        kernel = np.outer(k1d, k1d)
        half = size // 2

        heatmap = np.zeros([rows, cols])
        # Kernel footprint in image coordinates, clipped to the image ...
        r0, r1 = max(cy - half, 0), min(cy - half + size, rows)
        c0, c1 = max(cx - half, 0), min(cx - half + size, cols)
        # ... and the matching sub-window of the kernel.
        kr0, kc0 = r0 - (cy - half), c0 - (cx - half)
        heatmap[r0:r1, c0:c1] = kernel[kr0:kr0 + (r1 - r0), kc0:kc0 + (c1 - c0)]
        return heatmap
 

In [None]:
# Load annotation tables. combined_df keeps only studies with a known pixel
# spacing (the spacing columns are used later to convert pixel errors).
carina_annotations = pd.read_json('../input/nihcsvdata/carina.json')
ett_annotations = pd.read_csv('../input/nihcsvdata/ETT_Annotations.csv')
paths = ett_annotations['StudyInstanceUID'].tolist()

combined_df = (
    pd.read_csv('../input/nihcsvdata/combined_NIH_CLIP.csv')
    .loc[lambda frame: frame['Original Image Pixel Spacing x'].notna()]
    .reset_index(drop=True)
    .drop(['Unnamed: 0'], axis=1)
)


In [None]:
# Keep annotated studies that also have spacing metadata, then split 80/20.
paths = ett_annotations['StudyInstanceUID'].tolist()
# A set gives O(1) membership; the list was previously rebuilt for every path.
known_uids = set(combined_df['StudyInstanceUID'])
real_paths = [path for path in paths if path in known_uids]

train_dir = '../input/ranzcr-clip-catheter-line-classification/train/'
total_files = [train_dir + i + '.jpg' for i in real_paths]
val_length = int(len(total_files) * 0.2)
random.seed(10)  # fixed seed -> reproducible split
random.shuffle(total_files)
train_files = total_files[val_length:]
val_files = total_files[:val_length]

# Subset sizes for quick experiments; set to None to use the whole split.
# TODO: the original note read "CHANGE SHUFFLE BACK" -- clarify its intent.
N_TRAIN_SAMPLES = 100
N_VAL_SAMPLES = 25

Train_Dataset = Dataset(train_files[:N_TRAIN_SAMPLES], "Train", ett_annotations, carina_annotations)
Train_dataloader = DataLoader(Train_Dataset, shuffle=True, num_workers=2, batch_size=1, pin_memory=True)

Val_Dataset = Dataset(val_files[:N_VAL_SAMPLES], "Val", ett_annotations, carina_annotations)
Val_dataloader = DataLoader(Val_Dataset, shuffle=False, num_workers=2, batch_size=1, pin_memory=True)

In [None]:
# Example cropped image from the dataloader, with its GT heatmap overlaid.
data = next(iter(Train_dataloader))
image = data['Image']
gt = data['GT'].squeeze(0)
print(gt.shape)
print(data['Code'])
gt = gt.cpu().detach().numpy()
image = image.cpu().detach().numpy()[0][0]

plt.imshow(image, interpolation='nearest', cmap='Greys')
# Mask the zero background and keep alpha < 1: an opaque 'Greys' overlay with
# alpha=1 paints every pixel and completely hides the X-ray underneath.
plt.imshow(np.ma.masked_where(gt == 0, gt), cmap='Reds', alpha=0.6);
#     print(i)

In [None]:
class resconv_block(nn.Module):
    """Residual block: two 3x3 conv-BN-ReLU layers plus a 1x1-conv shortcut.

    forward(x) returns Conv_1x1(x) + conv(x). Stride 1 and padding 1 keep the
    spatial size unchanged; channels go from ch_in to ch_out.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Attribute names and Sequential indices (conv.0 ... conv.5) are the
        # state_dict keys of saved checkpoints, so they must not change.
        layers = []
        for layer_in in (ch_in, ch_out):
            layers.extend([
                nn.Conv2d(layer_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        shortcut = self.Conv_1x1(x)
        return shortcut + self.conv(x)

class up_conv(nn.Module):
    """Decoder up-step: nn.Upsample(scale_factor=2), then 3x3 conv-BN-ReLU.

    Doubles height and width and maps ch_in -> ch_out channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        # Sequential indices (up.1 = conv, up.2 = BN) are checkpoint keys.
        self.up = nn.Sequential(upsample, conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.up(x)

class ResU_Net(nn.Module):
    """Residual U-Net with 5 encoder stages, 4 decoder stages and a sigmoid output.

    Args:
        img_ch: number of input channels.
        output_ch: number of output channels of the final 1x1 conv.
        base_ch: width of the first stage; stage k has ``base_ch * 2**k``
            channels. The default 64 reproduces the original 64-128-256-512-1024
            network, so existing checkpoints still load.

    forward(x) expects x of shape (N, img_ch, H, W), with H and W divisible by
    16 because of four 2x max-pools. It returns (N, output_ch, H, W) with
    values in [0, 1].
    """

    def __init__(self, img_ch=1, output_ch=1, base_ch=64):
        super(ResU_Net, self).__init__()
        c1, c2, c3, c4, c5 = (base_ch * 2 ** k for k in range(5))

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Not used in forward(); kept so the module's attributes are unchanged.
        self.Softmax = nn.Softmax(dim=1)
        self.Sigmoid = nn.Sigmoid()

        # Attribute names and creation order match the original, so
        # state_dict keys and seeded initialisation are unchanged.
        self.Conv1 = resconv_block(ch_in=img_ch, ch_out=c1)
        self.Conv2 = resconv_block(ch_in=c1, ch_out=c2)
        self.Conv3 = resconv_block(ch_in=c2, ch_out=c3)
        self.Conv4 = resconv_block(ch_in=c3, ch_out=c4)
        self.Conv5 = resconv_block(ch_in=c4, ch_out=c5)

        # Each decoder stage concatenates a skip (c_k channels) with the
        # upsampled path (c_k channels), so its input has 2 * c_k = c_{k+1}.
        self.Up5 = up_conv(ch_in=c5, ch_out=c4)
        self.Up_conv5 = resconv_block(ch_in=c5, ch_out=c4)

        self.Up4 = up_conv(ch_in=c4, ch_out=c3)
        self.Up_conv4 = resconv_block(ch_in=c4, ch_out=c3)

        self.Up3 = up_conv(ch_in=c3, ch_out=c2)
        self.Up_conv3 = resconv_block(ch_in=c3, ch_out=c2)

        self.Up2 = up_conv(ch_in=c2, ch_out=c1)
        self.Up_conv2 = resconv_block(ch_in=c2, ch_out=c1)

        self.Conv_1x1 = nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        height, width = x.shape[-2], x.shape[-1]
        if height % 16 or width % 16:
            # Otherwise the skip concatenation fails with an opaque size mismatch.
            raise ValueError('ResU_Net input H and W must be divisible by 16, got {}x{}'.format(height, width))

        # Encoder: keep every stage's output as a skip connection.
        skips = []
        h = x
        for stage, encoder in enumerate((self.Conv1, self.Conv2, self.Conv3, self.Conv4, self.Conv5)):
            if stage > 0:
                h = self.Maxpool(h)
            h = encoder(h)
            skips.append(h)

        # Decoder: start from the bottleneck and merge skips deepest-first.
        d = skips.pop()
        decoder = (
            (self.Up5, self.Up_conv5),
            (self.Up4, self.Up_conv4),
            (self.Up3, self.Up_conv3),
            (self.Up2, self.Up_conv2),
        )
        for upsample, merge in decoder:
            d = merge(torch.cat((skips.pop(), upsample(d)), dim=1))

        return self.Sigmoid(self.Conv_1x1(d))

In [None]:
# Model, loss and optimiser. Alternatives that were commented out here:
# UNetWithResnet50Encoder(checkpoint_pth='../input/nihcsvdata/medaug_chexpert_resnet50.pth.tar')
# and nn.BCELoss().
model = ResU_Net().cuda()
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# print(out.shape)

In [None]:

def calculate_metrics_no_carina(probability_map, code, combined_df, ett_gt, crop_offset=(0, 0), tolerance=1.0):
    """Distance between the predicted and the annotated ETT tip for one study.

    The prediction is the argmax pixel of ``probability_map``; its last two
    axes are treated as (rows, cols). ``ett_gt`` is the annotated tip as
    ``[x, y]`` in original-image pixels. If the map was computed on a crop,
    pass the crop window's top-left ``(x, y)`` in the original image as
    ``crop_offset`` so that both points are in the same frame.

    Both coordinate differences are scaled by the study's
    'Original Image Pixel Spacing x' / 'Original Image Pixel Spacing y' from
    ``combined_df``. The Euclidean distance is then divided by 10 (presumably
    mm -> cm; TODO confirm the spacing unit).

    Args:
        probability_map: torch tensor or array; 2-D map (leading singleton
            axes are tolerated).
        code: StudyInstanceUID to look up in ``combined_df``.
        combined_df: DataFrame with 'StudyInstanceUID' and the spacing columns.
        ett_gt: tensor or array ``[x, y]``.
        crop_offset: ``(x, y)`` offset of the map within the original image.
        tolerance: threshold on the returned error for ``ett_correct``.

    Returns:
        (ett_abs_error, ett_correct)
    Raises:
        ValueError: if ``code`` is not in ``combined_df['StudyInstanceUID']``.
    """
    def _as_float_array(value):
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        # float dtype: scaled values must not be truncated back into ints.
        return np.asarray(value, dtype=float)

    et_tube = _as_float_array(probability_map)
    gt_xy = _as_float_array(ett_gt)
    offset_xy = _as_float_array(crop_offset)

    # unravel_index returns (..., row, col) = (..., y, x); reorder to (x, y) to
    # match the GT convention before comparing.
    row, col = np.unravel_index(np.argmax(et_tube, axis=None), et_tube.shape)[-2:]
    pred_xy = np.array([col, row], dtype=float) + offset_xy

    ind = combined_df['StudyInstanceUID'].tolist().index(code)
    spacing = np.array([
        combined_df['Original Image Pixel Spacing x'].iloc[ind],
        combined_df['Original Image Pixel Spacing y'].iloc[ind],
    ], dtype=float)

    ett_abs_error = np.linalg.norm((pred_xy - gt_xy) * spacing) / 10
    ett_correct = bool(ett_abs_error <= tolerance)
    return ett_abs_error, ett_correct

# Train / validate
NUM_EPOCHS = 50
MIN_SAVE_EPOCH = 5   # checkpoints are only written for epochs > MIN_SAVE_EPOCH
checkpoint_dir = './'

train_global_losses = []
val_global_losses = []
train_global_dists = []
val_global_dists = []
best_val_loss = float('inf')   # best validation loss among checkpoint-eligible epochs
best_epoch = None


def run_epoch(loader, train):
    """Run one pass over ``loader``; return (mean loss, mean ETT distance).

    When ``train`` is True the model is in train mode and the optimizer steps
    after every batch. Otherwise the model is in eval mode with gradients off.
    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and ``combined_df``.
    """
    model.train(train)
    losses, dists = [], []
    with torch.set_grad_enabled(train):
        for data in loader:
            input_img = data['Image'].cuda()
            gt = data['GT'].cuda().to(torch.float)

            output = model(input_img)[0]
            loss = criterion(output, gt)
            losses.append(loss.item())

            ett_abs_error, _ = calculate_metrics_no_carina(
                output[0], data['Code'][0], combined_df, data['ETT_coords'][0])
            dists.append(ett_abs_error)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    return sum(losses) / len(losses), sum(dists) / len(dists)


for epoch in range(NUM_EPOCHS):
    train_net_loss, train_net_dists = run_epoch(Train_dataloader, train=True)
    val_net_loss, val_net_dists = run_epoch(Val_dataloader, train=False)

    train_global_losses.append(train_net_loss)
    val_global_losses.append(val_net_loss)
    train_global_dists.append(train_net_dists)
    val_global_dists.append(val_net_dists)

    print('Epoch: {} | Train Loss: {} | Val Loss: {} | Train Dist: {} | Val Dist: {} |'.format(epoch, train_net_loss, val_net_loss, train_net_dists, val_net_dists))

    # Track the best loss only among eligible epochs. Comparing against the
    # global minimum meant nothing was ever saved if an early epoch was best,
    # and "saving model" used to be printed even when no file was written.
    if epoch > MIN_SAVE_EPOCH and val_net_loss <= best_val_loss:
        best_val_loss = val_net_loss
        best_epoch = epoch
        print('saving model at the end of epoch ' + str(epoch))
        file_name = os.path.join(checkpoint_dir, 'model_epoch_{}.pth'.format(epoch))
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
            },
            file_name)

In [None]:
# Evaluate the best saved checkpoint on the first few validation batches.
# NOTE: this runs on Val_dataloader; there is no separate held-out test split.
N_EVAL_BATCHES = 7  # number of batches collected for the visualisation cells


def _checkpoint_epoch(path):
    """Return the epoch encoded in a 'model_epoch_<n>.pth' path, or None."""
    match = re.search(r'model_epoch_(\d+)\.pth$', path)
    return int(match.group(1)) if match else None


# Only epochs with a checkpoint on disk are candidates. Training skips early
# epochs, so argmin over all epochs can name a file that was never written.
saved_epochs = [
    e for e in map(_checkpoint_epoch, glob('./model_epoch_*.pth'))
    if e is not None and e < len(val_global_losses)
]
if not saved_epochs:
    raise FileNotFoundError('No ./model_epoch_*.pth checkpoint found - run the training cell first.')
best_epoch = min(saved_epochs, key=lambda e: val_global_losses[e])

model = ResU_Net()
load_dir = './model_epoch_' + str(best_epoch) + '.pth'
checkpoint = torch.load(load_dir)
model.load_state_dict(checkpoint['state_dict'])
model.cuda()
# eval(): BatchNorm must use its running statistics at inference time.
model.eval()

output_map = []
gts = []
codes = []

with torch.no_grad():
    for i, data in enumerate(Val_dataloader):
        codes.append(data['Code'])
        input_img = data['Image'].cuda()
        gt = data['GT'].cuda().to(torch.float)
        gts.append(gt)

        output = model(input_img)
        output_map.append(output[0].cpu().detach().numpy())

        if i + 1 >= N_EVAL_BATCHES:
            break
        


In [None]:
# Visualization (original test image with GT)

def crop_center(img, cropx, cropy):
    """Crop a window up to cropx wide and cropy tall, anchored at the top row.

    This mirrors the crop used by the Dataset class (top rows, horizontally
    centred), so the GT heatmap overlay lines up with the displayed image.
    The previous ``starty = y//3 - cropy//2`` shifted the window down
    relative to the training crop. The horizontal start is clamped at 0 for
    images narrower than ``cropx``.
    """
    startx = max(img.shape[1] // 2 - cropx // 2, 0)
    return img[0:cropy, startx:startx + cropx]

# Show validation image `ind` (cropped, scaled to [-1, 1]) with its GT overlay.
ind = 0
test_path_1 = val_files[ind]
test_img = cv2.imread(test_path_1, 0)
test_img = 2 * (test_img / np.max(test_img)) - 1
g = crop_center(test_img, 1280, 1280)

gt = gts[ind].cpu().detach().numpy()
print(gt.shape)

fig, ax = plt.subplots()
ax.imshow(g, interpolation='nearest', cmap='Greys')
ax.imshow(gt[0], cmap='Greys', alpha=0.2)

In [None]:
# Visualization: predicted heatmap for the first collected validation batch.
# (To overlay it on the image instead, draw `g` first, then map_1[0] with alpha≈0.3.)
map_1 = output_map[0]

fig, ax = plt.subplots()
ax.imshow(map_1[0], cmap='Greys')

In [None]:
# old dataloader

# class Dataset():

#     def __init__(self, file_list, mode, ett_annotations, carina_annotations):
#         self.file_list = file_list
#         self.mode = mode
#         self.ett_annotations = ett_annotations
#         self.carina_annotations = carina_annotations

#     def __getitem__(self,idx):
#         image_filepath = self.file_list[idx]
#         image, code, original_imsize = self.loadimage(image_filepath)
        
#         gt_map, ett_coords = self.get_GT(code, original_imsize)
# #         for i in range(0,len(image)):
# #             image[i] = self.crop_center(image[i], ett_coords)
        
#         sample = {'Image': image,
#                   'GT': gt_map,
#                   'Code': code,
#                   'ETT_coords': ett_coords,
#               }

#         return sample

#     def __len__(self):
#         return len(self.file_list)

#     def loadimage(self, image_filepath):
#         img = cv2.imread(image_filepath, 0)
#         original_imsize = img.shape
# #         img = self.crop_center(img, 1024, 1024)
#         image = ToTensor()(img) 
#         image = 2*(image/torch.max(image))-1
#         code = image_filepath.split('/')[4][:-4] 
#         return image, code, original_imsize
    
#     def get_GT(self, code, original_imsize):
#         ett_annotations = self.ett_annotations
#         ids = ett_annotations['StudyInstanceUID'].tolist()
#         index_ett = ids.index(code) 
        
#         xs = (ett_annotations['data'][index_ett])
#         points = self.str2array(xs)
#         y_coords = points[:,1]
#         lowest_y = np.argmax(y_coords)
#         gt_map_et = np.zeros([original_imsize[0], original_imsize[1]], dtype='float')
#         x_ett, y_ett = points[lowest_y]
#         ett_coords = np.array([x_ett,y_ett])
#         X, Y = np.mgrid[x_ett-20:x_ett+20, y_ett-20:y_ett+20]
#         coords_et = np.vstack((X.ravel(), Y.ravel()))
#         for i in range(0,len(coords_et[0])):
#             gt_map_et[coords_et[1][i]][coords_et[0][i]] = 1
#         gt_map_et = self.crop_center(gt_map_et, ett_coords)

#         return gt_map_et, ett_coords

#     def crop_center(self,img,ett_coords):
#         x,y = ett_coords[0], ett_coords[1]
#         print(img.shape)
#         x_im,y_im = img.shape
#         if x - 512 >= 0:
#             startx = x-512
#         else:
#             startx = 0
#         if x + 512 < x_im:
#             endx = x+512
#         else:
#             endx = x_im
            
#         if y - 100 >= 0:
#             starty = y-100
#         else:
#             starty = 0
#         if y + 924 < y_im:
#             endy = y+924
#         else:
#             endy = y_im
            
#         return img[starty:endy,startx:endx]
    
#     def str2array(self, s):
#         # Remove space after [
#         s=re.sub('\[ +', '[', s.strip())
#         # Replace commas and spaces
#         s=re.sub('[,\s]+', ', ', s)
#         return np.array(ast.literal_eval(s))
 