In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from datetime import datetime
import sys
import gc
sys.path.append('../../')

from sklearn.model_selection import KFold

from dependencies import *
from settings import *
from reproducibility import *
from models.TGS_salt.Unet34_scSE_hyper import Unet_scSE_hyper as Net

Using paths on kail-main

Importing numerical libraries...
Importing standard libraries...
Importing miscellaneous functions...
Importing constants...
Importing Neural Network dependencies...
	PyTorch
	Keras
	TensorFlow
	Metrics, Losses and LR Schedulers
	Kaggle Metrics
	Image augmentations
	Datasets
Importing external libraries...
	Lovasz Losses (elu+1)

Fixing random seed for reproducibility...
	Setting random seed to 35202.

Setting CUDA environment...
	torch.__version__              = 1.1.0
	torch.version.cuda             = 9.0.176
	torch.backends.cudnn.version() = 7501
	os['CUDA_VISIBLE_DEVICES']     = 0,1
	torch.cuda.device_count()      = 1



In [2]:
SIZE = 256
FACTOR = 256
ne = "ne"
initial_checkpoint = None
MODEL = "ResNet34"
#OHEM = "OHEM"
OHEM = ''

batch_size = 16
n_acc = 64 / batch_size
nfolds = 4

noise_th = 75.0*(SIZE/128.0)**2 #threshold for the number of predicted pixels
best_thr0 = 0.2 #preliminary value of the threshold for metric calculation

data_root = '../../data/siim-pneumothorax'
torch.cuda.set_device('cuda:0')

In [3]:
def time_to_str(time, str):
    #if str == 'min':
    #	    return str(round(float(time)/60,5))+" min(s)"
    return round(time,4)

In [4]:
#TODO: Instead of directly printing to stdout, copy it into a txt file
class Logger():
    def __init__(self,name=MODEL+OHEM+ne):
        super().__init__()
        self.model=name
        #if OHEM != "OHEM":
        #    self.model=MODEL+ne[ne.find("_")+1:]
        self.file = open(self.model+"_log.txt","w+")
        self.file.close()
    def write(self, str):
        print(str)
        self.file = open(self.model+"_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def write2(self, str):
        print(str, end='',flush=True)
        self.file = open(self.model+"_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def stop():
        self.file.close()
        
log = Logger()

In [5]:
def valid_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    image, mask = do_resize2(image, mask, SIZE, SIZE)
    image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

def train_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())

    if np.random.rand() < 0.5:
         image, mask = do_horizontal_flip2(image, mask)
         pass

    if np.random.rand() < 0.5:
        c = np.random.choice(4)
        if c==0:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.2) #0.125

        if c==1:
            image, mask = do_horizontal_shear2( image, mask, dx=np.random.uniform(-0.07,0.07) )
            pass

        if c==2:
            image, mask = do_shift_scale_rotate2( image, mask, dx=0, dy=0, scale=1, angle=np.random.uniform(0,15))  #10

        if c==3:
            image, mask = do_elastic_transform2(image, mask, grid=10, distort=np.random.uniform(0,0.15))#0.10
            pass
    if np.random.rand() < 0.5:
        c = np.random.choice(3)
        if c==0:
            image = do_brightness_shift(image,np.random.uniform(-0.1,+0.1))
        if c==1:
            image = do_brightness_multiply(image,np.random.uniform(1-0.08,1+0.08))
        if c==2:
            image = do_gamma(image,np.random.uniform(1-0.08,1+0.08))
        # if c==1:
        #     image = do_invert_intensity(image)

    image, mask = do_resize2(image, mask, SIZE, SIZE)
    image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

In [6]:
def null_augment(image, mask, index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    return image, mask, index, cache

def null_collate(batch):

    batch_size = len(batch)
    cache = []
    input = []
    truth = []
    index = []
    for b in range(batch_size):
        input.append(batch[b][0])
        truth.append(batch[b][1])
        index.append(batch[b][2])
        cache.append(batch[b][3])
    input = torch.from_numpy(np.array(input)).float().unsqueeze(1)

    if truth[0]!=[]:
        truth = torch.from_numpy(np.array(truth)).float().unsqueeze(1)

    return input, truth, index, cache

class SIIMDataset(Dataset):
    def __init__(self, data_root, fold, width=1024, height=1024, phase='train', augment=null_augment, random_state=2019, nfolds=4):
        self.data_root = data_root
        self.fold = fold
        self.height = width
        self.width = height
        self.phase = phase
        self.augment = augment
        
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=random_state)
        train_list = os.listdir(os.path.join(data_root, 'train_png'))
        
        if phase == 'train':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][0]
            self.filenames = [train_list[i] for i in index_list]
        elif phase == 'val':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][1]
            self.filenames = [train_list[i] for i in index_list]
        else: # test
            self.filenames = os.listdir(os.path.join(data_root, 'test_png'))

    def __getitem__(self, index):
        img_path = os.path.join(self.data_root, 'train_png/{}'.format(self.filenames[index]))
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        img = cv2.resize(img, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        if self.phase == 'test':
            mask = []
        else: # train and val
            mask_path = os.path.join(self.data_root, 'mask_png/{}'.format(self.filenames[index]))
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
            mask = cv2.resize(mask, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        return self.augment(img, mask, index)
    
    def __len__(self):
        return len(self.filenames)


In [7]:
def validation( net, valid_loader ):

    valid_num  = 0
    valid_loss = np.zeros(3, np.float32)
    
    #debug
    if 0: #debug  ##-------------------------------
        debug_num = 2
        debug_count = 0
        
        for input, truth, index, cache in valid_loader:
            images = input.cpu().data.numpy().squeeze()
            masks  = truth.cpu().data.numpy().squeeze()
            batch_size = len(index)
            for b in range(batch_size):
                image = images[b]*255
                image = np.dstack([image,image,image])

                mask = masks[b]
                
                # Plot some samples
                fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(12, 4))
                ax0.imshow(image.astype(np.uint8))
                ax1.imshow(mask, vmin=0, vmax=1)
                ax2.imshow(mask, vmin=0, vmax=1)
                ax1.set_title('Targets')
                ax2.set_title('Predictions')
                
                plt.show()
                
            debug_count += 1
            if debug_count > debug_num:
                break
    #--------------------------------------

    for input, truth, index, cache in valid_loader:
        input = input.cuda()
        truth = truth.cuda()
        
        with torch.no_grad():
            logit = net(input) #data_parallel(net,input)

            loss  = net.focal_loss(logit, truth, 1.0, 0.5, 0.25) + net.criterion(logit, truth)
            dice  = net.metric(logit, truth, noise_th=noise_th, threshold=best_thr0)

        batch_size = len(index)
        valid_loss += batch_size * np.array(( loss.item(), dice.item(), 0))
        valid_num += batch_size
        
    valid_loss /= valid_num
    
    return valid_loss

In [8]:
def freeze(net):
    for p in net.conv1.parameters():
        p.requires_grad = False
        
    for p in net.encoder2.parameters():
        p.requires_grad = False
        
    for p in net.encoder3.parameters():
        p.requires_grad = False
        
    for p in net.encoder4.parameters():
        p.requires_grad = False
        
    for p in net.encoder5.parameters():
        p.requires_grad = False
        
    for p in net.center.parameters():
        p.requires_grad = False
        
    for p in net.decoder5.parameters():
        p.requires_grad = False
        
    for p in net.decoder4.parameters():
        p.requires_grad = False
        
    for p in net.decoder3.parameters():
        p.requires_grad = False
        
    for p in net.decoder2.parameters():
        p.requires_grad = False
        
    for p in net.decoder1.parameters():
        p.requires_grad = False
        
def unfreeze(net):
    for p in net.conv1.parameters():
        p.requires_grad = True
        
    for p in net.encoder2.parameters():
        p.requires_grad = True
        
    for p in net.encoder3.parameters():
        p.requires_grad = True
        
    for p in net.encoder4.parameters():
        p.requires_grad = True
        
    for p in net.encoder5.parameters():
        p.requires_grad = True
        
    for p in net.center.parameters():
        p.requires_grad = True
        
    for p in net.decoder5.parameters():
        p.requires_grad = True
        
    for p in net.decoder4.parameters():
        p.requires_grad = True
        
    for p in net.decoder3.parameters():
        p.requires_grad = True
        
    for p in net.decoder2.parameters():
        p.requires_grad = True
        
    for p in net.decoder1.parameters():
        p.requires_grad = True

def cosine_annealing_scheduler(num_iter, lr_init, lr_min):
    scheduler = lambda x: ((lr_init-lr_min)/2)*(np.cos(PI*(np.mod(x,num_iter)/(num_iter)))+1)+lr_min
    return scheduler
        
def set_BN_momentum(model, momentum=0.1*batch_size/64):
    for i, (name, layer) in enumerate(model.named_modules()):
        if isinstance(layer, nn.BatchNorm2d) or isinstance(layer, nn.BatchNorm1d):
            layer.momentum = momentum
            
def fit_one_cycle(epochs, net, train_loader, val_loader, lr_init=0.001, lr_min=0.000001, OHEM=''):
    # init learner
    iter_per_epoch = len(train_loader)
    num_iter = iter_per_epoch * epochs
    iter_smooth = 20
    iter_log    = 100
    iter_valid  = iter_per_epoch
    #iter_valid = 100
    
    scheduler = cosine_annealing_scheduler(num_iter, lr_init, lr_min)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
          lr=lr_init, momentum=0.9, weight_decay=0.0001
    )
    set_BN_momentum(net)
    
    start_iter = 0
    start_epoch= 0
    train_loss  = np.zeros(6,np.float32)
    valid_loss  = np.zeros(6,np.float32)
    batch_loss  = np.zeros(6,np.float32)
    rate = 0
    iter = 0
    epoch = 0
    
    #debug
    if 0: #debug  ##-------------------------------
        debug_num = 2
        debug_count = 0
        
        for input, truth, index, cache in train_loader:
            images = input.cpu().data.numpy().squeeze()
            masks  = truth.cpu().data.numpy().squeeze()
            batch_size = len(index)
            for b in range(batch_size):
                image = images[b]*255
                image = np.dstack([image,image,image])

                mask = masks[b]
                
                # Plot some samples
                fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(12, 4))
                ax0.imshow(image.astype(np.uint8))
                ax1.imshow(mask, vmin=0, vmax=1)
                ax2.imshow(mask, vmin=0, vmax=1)
                ax1.set_title('Targets')
                ax2.set_title('Predictions')
                
                plt.show()
                
            debug_count += 1
            if debug_count > debug_num:
                break
    #--------------------------------------
    
    start = timer()
    while iter < num_iter:  # loop over the dataset multiple times
        sum_train_loss = np.zeros(6,np.float32)
        sum = 0
        

        log.write('\n rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          \n')
        log.write('-------------------------------------------------------------------------------------------------------------------------------\n')
            
        for input, truth, index, cache in train_loader:
            # validation
            if (iter + 1) % iter_valid == 0:
                net.set_mode('valid')
                valid_loss = validation(net, val_loader)
                
                net.set_mode('train')
                
                '''
                log.write2('\r')
                log.write('%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  | %s \n' % (\
                         rate, iter/iter_per_epoch, epoch,
                         valid_loss[0], valid_loss[1], valid_loss[2],
                         train_loss[0], train_loss[1],
                         batch_loss[0], batch_loss[1],
                         time_to_str((timer() - start),'min')))
                '''
                time.sleep(0.01)
                
            if scheduler is not None:
                lr = scheduler(iter)
                if lr<0 : break
                adjust_learning_rate(optimizer, lr)
                rate = get_learning_rate(optimizer)
            
            # ok, train
            net.set_mode('train')

            input = input.cuda()
            truth = truth.cuda()

            logit = net(input) #data_parallel(net,input)

            if OHEM == "OHEM":
                loss = net.focal_loss(logit, truth, 1.0, 0.5, 0.25) + net.criterion(logit, truth)
            else:
                loss = net.criterion2(logit, truth)
                
            dice = net.metric(logit, truth, noise_th=noise_th, threshold=best_thr0)
            
            # learn with grad acc
            loss /= n_acc
            loss.backward()
            
            if ((iter + 1) % n_acc) == 0:
                optimizer.step()
                optimizer.zero_grad()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            
            # print statistics  ------------
            batch_loss = np.array((
                           loss.item(),
                           dice.item(),
                           0, 0, 0, 0,
                         ))
            sum_train_loss += batch_loss
            sum += 1
            if iter%iter_smooth == 0:
                train_loss = sum_train_loss/sum
                sum_train_loss = np.zeros(6,np.float32)
                sum = 0

            log.write2('\r%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  | %s ' % (\
                         rate, iter/iter_per_epoch, epoch+1,
                         valid_loss[0], valid_loss[1], valid_loss[2],
                         train_loss[0], train_loss[1],
                         batch_loss[0], batch_loss[1],
                         time_to_str((timer() - start), 'min')))
            
            iter += 1
            epoch = iter // iter_per_epoch


In [9]:
def get_dataloaders(data_root, batch_size, fold, nfolds=4, width=1024, height=1024, train_augment=null_augment, val_augment=null_augment, random_state=SEED):
    train_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='train',
        augment=train_augment,
        random_state=random_state,
        nfolds=nfolds
    )

    train_loader  = DataLoader(
        train_dataset,
        sampler     = RandomSampler(train_dataset),
        batch_size  = batch_size,
        drop_last   = True,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    val_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='val',
        augment=val_augment,
        random_state=random_state,
        nfolds=nfolds
    )

    val_loader  = DataLoader(
        val_dataset,
        sampler     = RandomSampler(val_dataset),
        batch_size  = batch_size,
        drop_last   = False,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )
    
    return train_loader, val_loader

In [None]:
print("Training U-Net with hypercolumn concatenation and spatial/channel-wise excitation...")

now = datetime.now()

for fold in range(nfolds):
    log.write('\nfold {}:\n'.format(fold))
    
    train_loader, val_loader = get_dataloaders(
        data_root,
        batch_size,
        fold, nfolds=nfolds,
        width=SIZE, height=SIZE,
        train_augment=train_augment, val_augment=valid_augment,
        random_state=SEED
    )
    
    net = Net().cuda()
    
    lr = 0.001
    
    # warm up
    freeze(net)
    fit_one_cycle(
        6, net,
        train_loader, val_loader,
        lr_init=lr, lr_min=lr/100
    )
    
    # ok, train
    unfreeze(net)
    fit_one_cycle(
        12, net,
        train_loader, val_loader,
        lr_init=lr/2, lr_min=lr/80,
        OHEM='OHEM'
    )
    
    # save
    torch.save(net.state_dict(), './unet_{}_{}_fold{}.pth'.format(now.strftime('%Y%m%d-%H%M%S'), SIZE, fold))
    
    gc.collect()
    torch.cuda.empty_cache()
    

Training U-Net with hypercolumn concatenation and spatial/channel-wise excitation...

fold 0:


 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0009    1.0     1.0  |  0.000  0.000  (0.000) |  0.006  0.769  |  0.005  0.688  | 129.6477 



0.0009    1.0     1.0  |  1.897  0.772  (0.000) |  0.006  0.769  |  0.010  0.812  | 166.5917 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0008    2.0     2.0  |  1.897  0.772  (0.000) |  0.005  0.825  |  0.008  0.688  | 293.4759 



0.0008    2.0     2.0  |  1.970  0.772  (0.000) |  0.005  0.825  |  0.001  1.000  | 328.8501 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0005    3.0     3.0  |  1.970  0.772  (0.000) |  0.006  0.769  |  0.007  0.750  | 455.8459 



0.0005    3.0     3.0  |  1.996  0.772  (0.000) |  0.006  0.769  |  0.005  0.812  | 491.0799 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0003    4.0     4.0  |  1.996  0.772  (0.000) |  0.006  0.794  |  0.023  0.625  | 617.977 



0.0003    4.0     4.0  |  2.006  0.772  (0.000) |  0.006  0.794  |  0.010  0.688  | 653.2663 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001    5.0     5.0  |  2.006  0.772  (0.000) |  0.005  0.797  |  0.003  0.875  | 780.2319 



0.0001    5.0     5.0  |  2.010  0.772  (0.000) |  0.005  0.797  |  0.007  0.750  | 815.4704 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    6.0     6.0  |  2.010  0.772  (0.000) |  0.007  0.775  |  0.004  0.750  | 942.5422 



0.0000    6.0     6.0  |  2.010  0.772  (0.000) |  0.007  0.775  |  0.005  0.812  | 977.9283 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0005    1.0     1.0  |  0.000  0.000  (0.000) |  0.339  0.759  |  0.407  0.688  | 305.6177 



0.0005    1.0     1.0  |  1.288  0.772  (0.000) |  0.339  0.759  |  0.234  0.875  | 341.2361 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0005    2.0     2.0  |  1.288  0.772  (0.000) |  0.319  0.772  |  0.407  0.688  | 644.3641 



0.0005    2.0     2.0  |  1.270  0.772  (0.000) |  0.319  0.772  |  0.271  0.812  | 680.007 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0004    3.0     3.0  |  1.270  0.772  (0.000) |  0.345  0.744  |  0.271  0.812  | 982.7071 



0.0004    3.0     3.0  |  1.261  0.772  (0.000) |  0.345  0.744  |  0.074  1.000  | 1018.3261 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0004    4.0     4.0  |  1.261  0.772  (0.000) |  0.361  0.728  |  0.260  0.812  | 1320.8983 



0.0004    4.0     4.0  |  1.252  0.772  (0.000) |  0.361  0.728  |  0.287  0.812  | 1356.7136 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0003    5.0     5.0  |  1.252  0.772  (0.000) |  0.273  0.809  |  0.270  0.812  | 1659.4969 



0.0003    5.0     5.0  |  1.246  0.772  (0.000) |  0.273  0.809  |  0.410  0.688  | 1695.2256 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0003    6.0     6.0  |  1.246  0.772  (0.000) |  0.276  0.806  |  0.520  0.562  | 1997.8392 



0.0003    6.0     6.0  |  1.240  0.772  (0.000) |  0.276  0.806  |  0.265  0.812  | 2033.5706 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0002    7.0     7.0  |  1.240  0.772  (0.000) |  0.269  0.809  |  0.273  0.812  | 2336.1954 



0.0002    7.0     7.0  |  1.232  0.772  (0.000) |  0.269  0.809  |  0.257  0.812  | 2371.8905 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001    8.0     8.0  |  1.232  0.772  (0.000) |  0.268  0.812  |  0.315  0.750  | 2674.4479 



0.0001    8.0     8.0  |  1.227  0.772  (0.000) |  0.268  0.812  |  0.404  0.688  | 2710.2544 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001    9.0     9.0  |  1.227  0.772  (0.000) |  0.274  0.806  |  0.195  0.875  | 3013.1037 



0.0001    9.0     9.0  |  1.222  0.772  (0.000) |  0.274  0.806  |  0.274  0.812  | 3048.777 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000   10.0    10.0  |  1.222  0.772  (0.000) |  0.303  0.772  |  0.270  0.812  | 3351.5028 



0.0000   10.0    10.0  |  1.218  0.773  (0.000) |  0.303  0.772  |  0.140  0.938  | 3387.133 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000   11.0    11.0  |  1.218  0.773  (0.000) |  0.312  0.766  |  0.337  0.750  | 3689.6926 



0.0000   11.0    11.0  |  1.215  0.772  (0.000) |  0.312  0.766  |  0.337  0.750  | 3725.3997 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000   12.0    12.0  |  1.215  0.772  (0.000) |  0.310  0.761  |  0.186  0.875  | 4028.1904 



0.0000   12.0    12.0  |  1.215  0.773  (0.000) |  0.310  0.761  |  0.271  0.812  | 4063.8457 
fold 1:


 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0009    1.0     1.0  |  0.000  0.000  (0.000) |  0.006  0.772  |  0.008  0.750  | 140.9223 



0.0009    1.0     1.0  |  1.720  0.784  (0.000) |  0.006  0.772  |  0.005  0.875  | 176.1979 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0008    1.8     2.0  |  1.720  0.784  (0.000) |  0.005  0.769  |  0.002  0.938  | 289.9793 

## find best threshold

In [None]:
# TTA with flip-lr

def tta(net, loader):
    net.set_mode('test')
    
    all_probs = []
    all_truths = []
    
    for input, truth, index, cache in loader:
        input = input.cuda()
        
        logits = net(input)
        probs = F.sigmoid(logits)
        
        all_probs.append(probs)
        all_truths.append(truth)
    
    # tensor.shape = [N, C, H, W], so we flip dim -1
    for input, truth, index, cache in loader:
        input = input.cuda()
        input = torch.flip(input, [-1])
        
        logits = net(input)
        probs = F.sigmoid(logits)
        
        probs = torch.flip(probs, [-1])
        
        all_probs[index] += probs
        all_probs[index] *= 0.5
        
    all_probs = torch.cat(all_probs, dim=0)
    all_truths = torch.cat(all_truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    return all_probs, all_truths
    

In [None]:
#dice for threshold selection

def dice_overall(preds, targs):
    n = preds.shape[0]
    
    preds = preds.view(n, -1)
    targs = targs.view(n, -1)
    
    targs = (targs > 0.5).long()
    
    intersect = (preds * targs).sum(-1).float()
    union = (preds + targs).sum(-1).float()
    
    # get 1 for both empty pred and targ
    u0 = union==0
    intersect[u0] = 1
    union[u0] = 2
    
    return (2. * intersect / union)


In [None]:
# load saved model

tag = '20190723-014424_256'
nfolds = 4

for fold in range(nfolds):
    print('\nfold {}:\n'.format(fold))
    
    train_loader, val_loader = get_dataloaders(
        data_root,
        batch_size,
        fold, nfolds=nfolds,
        width=SIZE, height=SIZE,
        train_augment=train_augment, val_augment=val_augment,
        random_state=SEED
    )
    
    net = Net().cuda()
    net.load_state_dict(torch.load('unet_{}_{}.pth'.format(tag, fold)))
    
    probs, truths = tta(net, val_loader)
    
    # noise removal
    probs[probs.view(probs.shape[0], -1).sum(-1) < noise_th,...] = 0.0
    
    # search best threshold for this fold
    dices = []
    thrs = np.arange(0.01, 1, 0.01)

    probs = probs.cuda()
    truths = truths.cuda()

    for th in progress_bar(thrs):
        preds = (probs>th).long()
        dices.append(dice_overall(preds, truths).mean())
        
    dices = np.array(dices)    
    
    # save best
    scores.append(dices.max())
    best_thrs.append(thrs[dices.argmax()])

    print(dices, scores, best_thrs)

    # clean up
    if fold != nfolds-1: 
        del probs, truths, preds

    gc.collect()
    torch.cuda.empty_cache()
    

In [None]:
print('scores: ', scores)
print('mean score: ', np.array(scores).mean())

print('thresholds: ', best_thrs)
best_thr = np.array(best_thrs).mean()
print('best threshold: ', best_thr)

In [None]:
# show dice-thr plot for the LAST fold

best_dice = dices.max()

plt.figure(figsize=(8,4))
plt.plot(thrs, dices)

plt.vlines(x=best_thrs[-1], ymin=dices.min(), ymax=dices.max())
plt.text(best_thrs[-1]+0.03, best_dice-0.01, f'DICE = {best_dice:.3f}', fontsize=14);

plt.show()

## submission

In [None]:
'''
def test_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    image = do_resize(image, SIZE, SIZE)
    image = do_center_pad_to_factor(image, factor = FACTOR)
    return image,mask,index,cache
'''

test_dataset = SIIMDataset(
    data_root,
    0,
    width=width, height=height,
    phase='test',
    augment=null_augment
)

test_loader  = DataLoader(
    test_dataset,
    batch_size  = batch_size,
    drop_last   = False,
    num_workers = 8,
    pin_memory  = True,
    collate_fn  = null_collate
)


In [None]:
for fold in range(nfolds):
    print('fold: {}'.format(fold))
    
    net = Net().cuda()
    net.load_state_dict(torch.load('unet_{}_{}.pth'.format(tag, fold)))
    
    probs, _ = tta(net, test_loader)
    
    if fold == 0:
        all_probs = probs
    else:
        all_probs += probs
        
    gc.collect()
    torch.cuda.empty_cache()
    
all_probs /= nfolds


In [None]:
# noise removal

probs_clean = all_probs.clone()

probs_clean[probs_clean.view(probs_clean.shape[0], -1).sum(-1) < noise_th,...] = 0.0


In [None]:
# majority vote - no we don't actually use vote

# best_thr = 0.2

print('best_thr: {}'.format(best_thr))

pt_vote = torch.where(probs_clean > best_thr, torch.ones_like(probs_clean), torch.zeros_like(probs_clean))
pt_vote = pt_vote.numpy()


In [None]:
# Generate rle encodings in parallel (images are first converted to the original size)
mask_size = 1024

def mask_worker(mask):
    im = PIL.Image.fromarray((mask.T).astype(np.uint8)).resize((mask_size, mask_size))
    im = np.asarray(im)

    rle = mask2rle(im, mask_size, mask_size)
    
    return rle

pool = mp.Pool()
rle_list = pool.map(mask_worker, pt_vote)


In [None]:
now = datetime.now()

ids = [os.path.splitext(filename)[0] for filename in test_dataset.filenames]

sub_df = pd.DataFrame({'ImageId': ids, 'EncodedPixels': rle_list})
sub_df.loc[sub_df.EncodedPixels=='', 'EncodedPixels'] = '-1'

sub_df.to_csv('unet_sub_{}_{:.6f}.csv'.format(now.strftime('%Y%m%d-%H%M%S'), np.array(scores).mean()), index=False)
sub_df.head()
