In [None]:
import os, glob
import csv
import random
import pickle
import pandas as pd

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage.measurements import label
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve

In [None]:
class Dataset(Dataset):
    """
    In-memory dataset of images (and, for non-test splits, binary masks and labels).

    Every ``*.jpg`` file found in ``<root_dir>/<split>/`` is loaded, converted to a
    tensor and resized to ``output_size`` at construction time. For splits other
    than ``'test'`` the matching ``<id>.npy`` array is loaded as well and turned
    into a one-channel binary mask (pixels equal to 2 become 1, everything else 0),
    and the per-image label is looked up in ``labels_csv``.

    Args:
        root_dir: directory that contains one sub-directory per split.
        split: name of the sub-directory to load; ``'test'`` has no ground truth.
        output_size: (height, width) every image and mask is resized to.
        labels_csv: CSV file with an ``eid_broad`` column (image id) and a
            ``has_disease`` column (label). Only read when ``split != 'test'``.
    """

    def __init__(self, root_dir, split='train', output_size=(429,460),
                 labels_csv="data/dataset_2images.csv"):
        # Define attributes
        self.output_size = output_size
        self.root_dir = root_dir
        self.split = split

        # Collect image paths in a deterministic order. The image id is the file
        # stem, so it no longer depends on the length of root_dir or on the
        # number of digits in the id.
        img_paths = sorted(glob.glob(os.path.join(self.root_dir, self.split, '*.jpg')))
        self.img_id = [int(os.path.splitext(os.path.basename(path))[0]) for path in img_paths]

        self.images = []
        for path in img_paths:
            # Context manager closes the underlying file once the pixels are copied.
            with Image.open(path) as pil_img:
                img = np.array(pil_img.convert('RGB'))
            img = transforms.functional.to_tensor(img)
            img = transforms.functional.resize(img, self.output_size, interpolation=Image.BILINEAR)
            self.images.append(img)

        # Load ground truth for 'train' and 'val' sets
        if split != 'test':
            self.segs = []
            for path in img_paths:
                seg_name = os.path.splitext(path)[0] + ".npy"
                od = np.array(np.load(seg_name)).copy()
                od = (od[:,:]==2).astype(np.float32)
                od = torch.from_numpy(od[None,:,:])
                od = transforms.functional.resize(od, self.output_size, interpolation=Image.NEAREST)
                self.segs.append(od)

            # Read the label table once instead of on every __getitem__ call.
            # keep='first' mirrors taking the first matching row per id.
            label_table = pd.read_csv(labels_csv).drop_duplicates(subset='eid_broad', keep='first')
            self.labels = dict(zip(label_table.eid_broad, label_table.has_disease))

        print('Succesfully loaded {} dataset.'.format(split) + ' '*50)


    def __len__(self):
        """Number of images in the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Return the sample at position ``idx``.

        For the ``'test'`` split only the image tensor is returned. Otherwise a
        tuple ``(img, lab, seg, att, name)`` is returned where ``lab`` is the
        float label, ``seg`` the (1, H, W) binary mask, ``att`` a 2-element
        tensor holding the mean and standard deviation of the per-row mask width
        over rows 50..399, and ``name`` is ``"<id>.jpg"``.

        Raises:
            KeyError: if the image id has no entry in the label table.
        """
        # Image
        img = self.images[idx]

        # Return only images for 'test' set
        if self.split == 'test':
            return img

        # Label
        lab = torch.tensor(self.labels[self.img_id[idx]], dtype=torch.float32)

        # Segmentation mask, shape (1, H, W)
        seg = self.segs[idx]

        # Width: number of foreground pixels in each row. The channel axis is
        # dropped first so that the 50:400 slice selects rows; slicing the
        # (1, W) result of seg.sum(1) instead yields an empty tensor and NaNs.
        rows = (seg[0] != 0).sum(1).float()
        w = rows[50:400].mean()
        std = rows[50:400].std()
        att = torch.stack([w, std])

        return img, lab, seg, att, str(self.img_id[idx]) + ".jpg"

In [None]:
def compute_dice_coef(input, target):
    '''
    Average the per-sample dice score over the first (batch) axis.

    `input` and `target` are indexed as [k, :, :], so both are expected to be
    3-dimensional with the same leading size.
    '''
    n_samples = input.shape[0]
    running_total = 0
    for sample_idx in range(n_samples):
        running_total = running_total + dice_coef_sample(input[sample_idx,:,:], target[sample_idx,:,:])
    return running_total / n_samples

def dice_coef_sample(input, target):
    '''
    Dice coefficient between two masks of identical shape.

    Both tensors are flattened and the score is
    2 * sum(input * target) / (sum(input) + sum(target)).
    When both masks are entirely zero the ratio would be 0/0 (NaN); two empty
    masks agree perfectly, so 1.0 is returned in that case.

    Returns a 0-dim tensor.
    '''
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    denominator = iflat.sum() + tflat.sum()
    if denominator == 0:
        # Avoid NaN propagating into the epoch averages.
        return torch.tensor(1.0, device=intersection.device)
    return (2. * intersection) / denominator


def width(binary_segmentation, row_start=50, row_end=400):
    '''
    Get the width along long axis from a binary segmentation.

    The input is indexed as (batch, row, column). For every sample the mask is
    summed along the last axis to get a per-row pixel count, and the counts of
    rows row_start..row_end-1 are averaged.

    Parameters
    ----------
    binary_segmentation : array of shape (batch, rows, columns)
    row_start, row_end : bounds of the row window that is averaged. The
        defaults (50, 400) reproduce the previously hard-coded window.

    Returns
    -------
    1-D array with one mean width per sample. If the window selects no rows
    the mean is NaN.
    '''
    per_row_count = np.sum(binary_segmentation, axis=2)
    wd = np.mean(per_row_count[:, row_start:row_end], axis=1)
    return wd


def compute_width_error(pred_od, gt_od):
    '''
    Mean absolute difference between predicted and ground-truth widths.

    Returns a tuple (error, predicted widths, ground-truth widths), where the
    two width arrays come from `width` applied to each input.
    '''
    predicted = width(pred_od)
    reference = width(gt_od)
    abs_diff = np.abs(reference - predicted)
    return np.mean(abs_diff), predicted, reference


def classif_eval(classif_preds, classif_gts):
    '''
    Area under the ROC curve of the predicted scores against the true labels.
    '''
    return roc_auc_score(y_true=classif_gts, y_score=classif_preds)

In [None]:
def refine_seg(pred):
    '''
    Only retain the biggest connected component of a segmentation map.

    `pred` is a CPU tensor indexed as (batch, rows, columns); each slice is
    labelled independently. The result is a float32 tensor of the same shape
    holding 1 on the largest component and 0 elsewhere. A slice with no
    foreground at all yields an all-zero mask (comparing the labels with 0
    would instead mark the whole background as foreground).
    '''
    np_pred = pred.numpy()

    largest_ccs = []
    for i in range(np_pred.shape[0]):
        labeled, ncomponents = label(np_pred[i,:,:])
        if ncomponents == 0:
            # Nothing was segmented: keep the mask empty.
            largest_cc = np.zeros(labeled.shape, dtype=bool)
        else:
            # Skip bin 0 (background); +1 maps the argmax back to a label id.
            bincounts = np.bincount(labeled.ravel())[1:]
            largest_cc = labeled == np.argmax(bincounts)+1
        largest_ccs.append(torch.tensor(largest_cc, dtype=torch.float32))

    return torch.stack(largest_ccs)

In [None]:
class UNet(nn.Module):
    """
    Encoder/decoder network with four downscaling and four upscaling stages.

    Each decoder stage receives the output of the mirrored encoder stage as a
    skip connection. The final 1x1 convolution is followed by a sigmoid, so
    the forward pass returns values in [0, 1] with `n_classes` channels.
    `epoch` is a plain counter stored on the module (initialised to 0).
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.epoch = 0

        # Channel counts of the deepest encoder stage and of the first three
        # decoder stages are halved by this factor.
        factor = 2

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor)
        self.up2 = Up(512, 256 // factor)
        self.up3 = Up(256, 128 // factor)
        self.up4 = Up(128, 64)

        # Head
        self.output_layer = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every intermediate feature map for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        out = self.down4(skips[-1])

        # Decoder: consume the skips from deepest to shallowest.
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())

        return torch.sigmoid(self.output_layer(out))

    
class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution => batch norm => ReLU) stages.

    `mid_channels` is the channel count between the two stages; when it is
    not given (or falsy) it defaults to `out_channels`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        stages = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial size with a 2x2 max-pool, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pooling, convs)

    def forward(self, x):
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Bilinear 2x upsampling, concatenation with a skip tensor, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Upsampling has no weights; the channel reduction is done by the
        # convolutions, whose intermediate width is in_channels // 2.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Axes 2 and 3 are height and width. Pad the upsampled map so that it
        # matches the skip tensor; any odd remainder goes to the bottom/right.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    '''
    1x1 convolution that maps `in_channels` feature maps to `out_channels`.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

In [None]:
# Run configuration
root_dir = 'data'      # directory holding the train/val/test sub-directories
lr = 1e-4              # learning rate passed to the optimizer
batch_size = 5         # samples per batch for every DataLoader
num_workers = 0        # DataLoader worker processes (0 = load in the main process)
total_epoch = 4        # number of epochs the training loop runs

# Seed every random number generator in use so that weight initialisation and
# batch shuffling are repeatable from one Restart & Run All to the next.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Datasets: one instance per split, each loaded fully into memory on construction
train_set = Dataset(root_dir=root_dir, split='train')
val_set = Dataset(root_dir=root_dir, split='val')
test_set = Dataset(root_dir=root_dir, split='test')

In [None]:
# Dataloaders
# Options shared by all three loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

train_loader = DataLoader(train_set, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_set, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_set, shuffle=False, **_loader_opts)

In [None]:
# Device: use the first GPU when one is available, otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Network (its forward pass ends with a sigmoid, so outputs are probabilities)
model = UNet(n_channels=3, n_classes=1).to(device)

# Loss: binary cross-entropy on probabilities, averaged over all pixels
seg_loss = torch.nn.BCELoss(reduction='mean')

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Define parameters
# Only a few batches are processed per epoch. The cap is clamped to the loader
# length so that next() cannot run past the end of a small dataset.
nb_train_batches = min(3, len(train_loader))
nb_val_batches = min(3, len(val_loader))
nb_iter = 0
best_val_auc = 0.

while model.epoch < total_epoch:
    # Accumulators (reset at the start of every epoch)
    train_widths, val_widths = [], []
    train_classif_gts, val_classif_gts = [], []
    train_loss, val_loss = 0., 0.
    train_dsc_od, val_dsc_od = 0., 0.
    train_width_error, val_width_error = 0., 0.
    
    ############
    # TRAINING #
    ############
    model.train()
    train_data = iter(train_loader)
    for k in range(nb_train_batches):
        # Loads data. The built-in next() is the Python 3 iterator protocol;
        # DataLoader iterators do not provide a .next() method on current PyTorch.
        # The fourth item is the (mean, std) width tensor from the dataset and
        # is not used here.
        imgs, classif_gts, seg_gts, fov_coords, names = next(train_data)
        imgs, classif_gts, seg_gts = imgs.to(device), classif_gts.to(device), seg_gts.to(device)
        # Forward pass (`logits` holds sigmoid outputs, i.e. probabilities)
        logits = model(imgs)
        loss = seg_loss(logits, seg_gts)
 
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() / nb_train_batches
        
        with torch.no_grad():
            # Compute segmentation metric on the thresholded prediction, keeping
            # only its largest connected component
            pred_od = refine_seg((logits[:,0,:,:]>=0.5).type(torch.int8).cpu()).to(device)
            gt_od = seg_gts[:,0,:,:].type(torch.int8)
            dsc_od = compute_dice_coef(pred_od, gt_od)
            train_dsc_od += dsc_od.item()/nb_train_batches

            # Compute and store widths
            width_error, pred_width, gt_width = compute_width_error(pred_od.cpu().numpy(), gt_od.cpu().numpy())
            train_widths += pred_width.tolist()
            train_width_error += width_error / nb_train_batches
            train_classif_gts += classif_gts.cpu().numpy().tolist()
            
        # Increase iterations
        nb_iter += 1
        
        # Std out
        print('Epoch {}, iter {}/{}, loss {:.6f}'.format(model.epoch+1, k+1, nb_train_batches, loss.item()) + ' '*20)
        
    # Train a logistic regression on widths (one feature: the predicted width)
    train_widths = np.array(train_widths).reshape(-1,1)
    train_classif_gts = np.array(train_classif_gts)
    clf = LogisticRegression(random_state=0, solver='lbfgs').fit(train_widths, train_classif_gts)
    train_classif_preds = clf.predict_proba(train_widths)[:,1]
    train_auc = classif_eval(train_classif_preds, train_classif_gts)
    
    ##############
    # VALIDATION #
    ##############
    model.eval()
    with torch.no_grad():
        val_data = iter(val_loader)
        for k in range(nb_val_batches):
            # Loads data
            imgs, classif_gts, seg_gts, fov_coords, names = next(val_data)
            imgs, classif_gts, seg_gts = imgs.to(device), classif_gts.to(device), seg_gts.to(device)

            # Forward pass
            logits = model(imgs)
            val_loss += seg_loss(logits, seg_gts).item() / nb_val_batches

            # Std out
            print('Validation iter {}/{}'.format(k+1, nb_val_batches) + ' '*50, 
                  end='\r')
            
            # Compute segmentation metric
            pred_od = refine_seg((logits[:,0,:,:]>=0.5).type(torch.int8).cpu()).to(device)
            gt_od = seg_gts[:,0,:,:].type(torch.int8)
            dsc_od = compute_dice_coef(pred_od, gt_od)
            val_dsc_od += dsc_od.item()/nb_val_batches
            
            # Compute and store widths
            width_error, pred_width, gt_width = compute_width_error(pred_od.cpu().numpy(), gt_od.cpu().numpy())
            val_widths += pred_width.tolist()
            val_width_error += width_error / nb_val_batches
            val_classif_gts += classif_gts.cpu().numpy().tolist()
            

    # CVD predictions from widths, using the classifier fitted on this epoch's
    # training widths
    val_widths = np.array(val_widths).reshape(-1,1)
    val_classif_gts = np.array(val_classif_gts)
    val_classif_preds = clf.predict_proba(val_widths)[:,1]
    val_auc = classif_eval(val_classif_preds, val_classif_gts)
        
    # Validation results
    print('VALIDATION epoch {}'.format(model.epoch+1)+' '*50)
    print('LOSSES: {:.4f} (train), {:.4f} (val)'.format(train_loss, val_loss))
    print('Segmentation (Dice Score): {:.4f} (train), {:.4f} (val)'.format(train_dsc_od, val_dsc_od))
    print('width error: {:.4f} (train), {:.4f} (val)'.format(train_width_error, val_width_error))
    print('Classification (AUC): {:.4f} (train), {:.4f} (val)'.format(train_auc, val_auc))
        
    # End of epoch
    model.epoch += 1

In [None]:
# Stand-alone validation pass with the current model weights. Relies on
# `nb_val_batches` and the fitted classifier `clf` left by the training cell.
model.eval()
val_widths = []
val_classif_gts = []
val_loss = 0.
val_dsc_od = 0.
val_dsc_oc = 0.  # never updated in this cell
val_width_error = 0.
with torch.no_grad():
    val_data = iter(val_loader)
    for k in range(nb_val_batches):
        # Loads data. The built-in next() is the Python 3 iterator protocol;
        # DataLoader iterators do not provide a .next() method on current PyTorch.
        imgs, classif_gts, seg_gts, fov_coords, names = next(val_data)
        imgs, classif_gts, seg_gts = imgs.to(device), classif_gts.to(device), seg_gts.to(device)

        # Forward pass (`logits` holds the network's sigmoid outputs)
        logits = model(imgs)
        val_loss += seg_loss(logits, seg_gts).item() / nb_val_batches

        # Std out
        print('Validation iter {}/{}'.format(k+1, nb_val_batches) + ' '*50, 
              end='\r')

        # Compute segmentation metric on the thresholded prediction, keeping
        # only its largest connected component
        pred_od = refine_seg((logits[:,0,:,:]>=0.5).type(torch.int8).cpu()).to(device)
        gt_od = seg_gts[:,0,:,:].type(torch.int8)
        dsc_od = compute_dice_coef(pred_od, gt_od)
        val_dsc_od += dsc_od.item()/nb_val_batches

        # Compute and store widths
        width_error, pred_width, gt_width = compute_width_error(pred_od.cpu().numpy(), gt_od.cpu().numpy())
        val_widths += pred_width.tolist()
        val_width_error += width_error / nb_val_batches
        val_classif_gts += classif_gts.cpu().numpy().tolist()


# CVD predictions from widths
val_widths = np.array(val_widths).reshape(-1,1)
val_classif_gts = np.array(val_classif_gts)
val_classif_preds = clf.predict_proba(val_widths)[:,1]
val_auc = classif_eval(val_classif_preds, val_classif_gts)

# Validation results
print('VALIDATION '+' '*50)
print('LOSSES: {:.4f} (val)'.format(val_loss))
print('OD segmentation (Dice Score): {:.4f} (val)'.format(val_dsc_od))
print('width error: {:.4f} (val)'.format(val_width_error))
print('Classification (AUC): {:.4f} (val)'.format(val_auc))

In [None]:
# Inference on the test split: segment every image, measure the mask width and
# feed it to the logistic regression `clf` fitted in the training cell.
nb_test_batches = len(test_loader)
model.eval()
test_widths = []
with torch.no_grad():
    # The test dataset yields image tensors only, so each batch is one tensor.
    for k, imgs in enumerate(test_loader):
        imgs = imgs.to(device)

        # Forward pass (`logits` holds the network's sigmoid outputs)
        logits = model(imgs)

        # Std out
        print('Test iter {}/{}'.format(k+1, nb_test_batches) + ' '*50, 
              end='\r')

        # Compute segmentation: threshold, then keep the largest component
        pred_od = refine_seg((logits[:,0,:,:]>=0.5).type(torch.int8).cpu()).to(device)

        # Compute and store widths. The classifier was trained on the output of
        # width(), so the same feature is used here.
        pred_width = width(pred_od.cpu().numpy())
        test_widths += pred_width.tolist()


# CVD predictions from widths
test_widths = np.array(test_widths).reshape(-1,1)
test_classif_preds = clf.predict_proba(test_widths)[:,1]