In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from datetime import datetime
import sys
import gc
sys.path.append('../../')

from sklearn.model_selection import KFold

from dependencies import *
from settings import *
from reproducibility import *
from models.TGS_salt.Unet34_scSE_hyper import Unet_scSE_hyper as Net

Using paths on kail-main

Importing numerical libraries...
Importing standard libraries...
Importing miscellaneous functions...
Importing constants...
Importing Neural Network dependencies...
	PyTorch
	Keras
	TensorFlow
	Metrics, Losses and LR Schedulers
	Kaggle Metrics
	Image augmentations
	Datasets
Importing external libraries...
	Lovasz Losses (elu+1)

Fixing random seed for reproducibility...
	Setting random seed to 35202.

Setting CUDA environment...
	torch.__version__              = 1.1.0
	torch.version.cuda             = 9.0.176
	torch.backends.cudnn.version() = 7501
	os['CUDA_VISIBLE_DEVICES']     = 0,1
	torch.cuda.device_count()      = 2



In [2]:
SIZE = 256
FACTOR = 256
ne = "ne"
initial_checkpoint = None
MODEL = "ResNet34"
# OHEM = "OHEM"
OHEM = ''

batch_size = 16
n_acc = 64 / batch_size
nfolds = 4

noise_th = 75.0*(SIZE/128.0)**2 #threshold for the number of predicted pixels
best_thr0 = 0.2 #preliminary value of the threshold for metric calculation

data_root = '../../data/siim-pneumothorax'
torch.cuda.set_device('cuda:1')

In [3]:
def time_to_str(time, str):
    #if str == 'min':
    #	    return str(round(float(time)/60,5))+" min(s)"
    return round(time,4)

In [4]:
#TODO: Instead of directly printing to stdout, copy it into a txt file
class Logger():
    def __init__(self,name=MODEL+OHEM+ne):
        super().__init__()
        self.model=name
        #if OHEM != "OHEM":
        #    self.model=MODEL+ne[ne.find("_")+1:]
        self.file = open(self.model+"_log.txt","w+")
        self.file.close()
    def write(self, str):
        print(str)
        self.file = open(self.model+"_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def write2(self, str):
        print(str, end='',flush=True)
        self.file = open(self.model+"_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def stop():
        self.file.close()
        
log = Logger()

In [5]:
def valid_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    image, mask = do_resize2(image, mask, SIZE, SIZE)
    image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

def train_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())

    if np.random.rand() < 0.5:
         image, mask = do_horizontal_flip2(image, mask)
         pass

    if np.random.rand() < 0.5:
        c = np.random.choice(4)
        if c==0:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.2) #0.125

        if c==1:
            image, mask = do_horizontal_shear2( image, mask, dx=np.random.uniform(-0.07,0.07) )
            pass

        if c==2:
            image, mask = do_shift_scale_rotate2( image, mask, dx=0, dy=0, scale=1, angle=np.random.uniform(0,15))  #10

        if c==3:
            image, mask = do_elastic_transform2(image, mask, grid=10, distort=np.random.uniform(0,0.15))#0.10
            pass
    if np.random.rand() < 0.5:
        c = np.random.choice(3)
        if c==0:
            image = do_brightness_shift(image,np.random.uniform(-0.1,+0.1))
        if c==1:
            image = do_brightness_multiply(image,np.random.uniform(1-0.08,1+0.08))
        if c==2:
            image = do_gamma(image,np.random.uniform(1-0.08,1+0.08))
        # if c==1:
        #     image = do_invert_intensity(image)

    image, mask = do_resize2(image, mask, SIZE, SIZE)
    image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

In [6]:
def null_augment(image, mask, index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    return image, mask, index, cache

def null_collate(batch):

    batch_size = len(batch)
    cache = []
    input = []
    truth = []
    index = []
    for b in range(batch_size):
        input.append(batch[b][0])
        truth.append(batch[b][1])
        index.append(batch[b][2])
        cache.append(batch[b][3])
    input = torch.from_numpy(np.array(input)).float().unsqueeze(1)

    if truth[0]!=[]:
        truth = torch.from_numpy(np.array(truth)).float().unsqueeze(1)

    return input, truth, index, cache

class SIIMDataset(Dataset):
    def __init__(self, data_root, fold, width=1024, height=1024, phase='train', augment=null_augment, random_state=2019, nfolds=4):
        self.data_root = data_root
        self.fold = fold
        self.height = width
        self.width = height
        self.phase = phase
        self.augment = augment
        
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=random_state)
        train_list = os.listdir(os.path.join(data_root, 'train_png'))
        
        if phase == 'train':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][0]
            self.filenames = [train_list[i] for i in index_list]
        elif phase == 'val':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][1]
            self.filenames = [train_list[i] for i in index_list]
        else: # test
            self.filenames = os.listdir(os.path.join(data_root, 'test_png'))
            
    def __getitem__(self, index):
        img_path = os.path.join(self.data_root, 'train_png/{}'.format(self.filenames[index]))
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
        img = cv2.resize(img, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        if self.phase != 'test':
            mask_path = os.path.join(self.data_root, 'mask_png/{}'.format(self.filenames[index]))
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
            mask = cv2.resize(mask, (self.width, self.height), interpolation = cv2.INTER_AREA)
        else: # test
            mask = []
        
        return self.augment(img, mask, index)
    
    def __len__(self):
        return len(self.filenames)


In [7]:
def validation( net, valid_loader ):

    valid_num  = 0
    valid_loss = np.zeros(3, np.float32)

    for input, truth, index, cache in valid_loader:
        input = input.cuda()
        truth = truth.cuda()
        
        with torch.no_grad():
            logit = net(input) #data_parallel(net,input)
            prob  = F.sigmoid(logit)
            
            loss  = net.focal_loss(logit, truth, 1.0, 0.5, 0.25) + net.criterion(logit, truth)
            dice  = net.metric(logit, truth, noise_th=noise_th, threshold=best_thr0)

        batch_size = len(index)
        valid_loss += batch_size * np.array(( loss.item(), dice.item(), 0))
        valid_num += batch_size
        
    valid_loss /= valid_num
    
    return valid_loss

In [8]:
def freeze(net):
    for p in net.feature_net.parameters():
        p.requires_grad = False
        
def unfreeze(net):
    for p in net.feature_net.parameters():
        p.requires_grad = True

def cosine_annealing_scheduler(num_iter, lr_init, lr_min):
    scheduler = lambda x: ((lr_init-lr_min)/2)*(np.cos(PI*(np.mod(x-1,num_iter)/(num_iter)))+1)+lr_min
        
def set_BN_momentum(model, momentum=0.1*batch_size/64):
    for i, (name, layer) in enumerate(model.named_modules()):
        if isinstance(layer, nn.BatchNorm2d) or isinstance(layer, nn.BatchNorm1d):
            layer.momentum = momentum
            
def fit_one_cycle(epochs, net, train_loader, val_loader, lr_init=0.001, lr_min=0.00001):
    # init learner
    num_batch = len(train_loader)
    num_iter = num_batch * epochs
    iter_smooth = 20
    iter_log    = 100
    iter_valid  = num_batch
    # iter_valid = 20
    
    print(num_batch, num_iter)
    
    scheduler = cosine_annealing_scheduler(num_iter, lr_init, lr_min)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
          lr=0.01, momentum=0.9, weight_decay=0.0001
    )
    set_BN_momentum(net)
    
    start_iter = 0
    start_epoch= 0
    train_loss  = np.zeros(6,np.float32)
    valid_loss  = np.zeros(6,np.float32)
    batch_loss  = np.zeros(6,np.float32)
    rate = 0
    iter = 0
    epoch = 0
    
    start = timer()
    while iter < num_iter:  # loop over the dataset multiple times
        sum_train_loss = np.zeros(6,np.float32)
        sum = 0
        
        log.write(' rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          \n')
        log.write('-------------------------------------------------------------------------------------------------------------------------------\n')
            
        for input, truth, index, cache in train_loader:
            # validation
            if (iter + 1) % iter_valid == 0:
                net.set_mode('valid')
                valid_loss = validation(net, val_loader)
                net.set_mode('train')

                log.write2('\r')
                log.write('%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  | %s \n' % (\
                         rate, iter/num_batch, epoch,
                         valid_loss[0], valid_loss[1], valid_loss[2],
                         train_loss[0], train_loss[1],
                         batch_loss[0], batch_loss[1],
                         time_to_str((timer() - start),'min')))
                time.sleep(0.01)
                
            if scheduler is not None:
                lr = scheduler(iter)
                if lr<0 : break
                adjust_learning_rate(optimizer, lr)
                rate = get_learning_rate(optimizer)
            
            # ok, train
            net.set_mode('train')

            input = input.cuda()
            truth = truth.cuda()

            logit = net(input) #data_parallel(net,input)

            if OHEM == "OHEM":
                loss = net.focal_loss(logit, truth, 1.0, 0.5, 0.25) + net.criterion(logit, truth)
            else:
                loss = net.criterion(logit, truth)

            dice = net.metric(logit, truth, noise_th=noise_th, threshold=best_thr0)
            
            # learn with grad acc
            loss /= n_acc
            loss.backward()
            
            if ((iter + 1) % n_acc) == 0:
                optimizer.step()
                optimizer.zero_grad()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            
            # print statistics  ------------
            batch_loss = np.array((
                           loss.item(),
                           dice.item(),
                           0, 0, 0, 0,
                         ))
            sum_train_loss += batch_loss
            sum += 1
            if iter%iter_smooth == 0:
                train_loss = sum_train_loss/sum
                sum_train_loss = np.zeros(6,np.float32)
                sum = 0

            log.write2('\r%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  | %s ' % (\
                         rate, iter/num_batch, epoch,
                         valid_loss[0], valid_loss[1], valid_loss[2],
                         train_loss[0], train_loss[1],
                         batch_loss[0], batch_loss[1],
                         time_to_str((timer() - start), 'min')))
            
            iter += 1
            epoch = iter // num_batch


In [9]:
def get_dataloaders(data_root, batch_size, fold, nfolds=4, width=1024, height=1024, train_augment=null_augment, val_augment=null_augment, random_state=SEED):
    train_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='train',
        augment=train_augment,
        random_state=random_state,
        nfolds=nfolds
    )

    train_loader  = DataLoader(
        train_dataset,
        sampler     = RandomSampler(train_dataset),
        batch_size  = batch_size,
        drop_last   = True,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    val_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='val',
        augment=val_augment,
        random_state=random_state,
        nfolds=nfolds
    )

    val_loader  = DataLoader(
        val_dataset,
        sampler     = RandomSampler(val_dataset),
        batch_size  = batch_size,
        drop_last   = False,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )
    
    return train_loader, val_loader

In [10]:
print("Training U-Net with hypercolumn concatenation and spatial/channel-wise excitation...")

now = datetime.now()

for fold in range(nfolds):
    train_loader, val_loader = get_dataloaders(
        data_root,
        batch_size,
        fold, nfolds=nfolds,
        width=SIZE, height=SIZE,
        train_augment=null_augment, val_augment=null_augment,
        random_state=SEED
    )
    
    net = Net().cuda()
    
    lr = 0.001
    
    # warm up
    # freeze(net)
    fit_one_cycle(
        6, net,
        train_loader, val_loader,
        lr_init=lr, lr_min=lr/100
    )
    
    # ok, train
    # unfreeze(net)
    fit_one_cycle(
        12, net,
        train_loader, val_loader,
        lr_init=lr/2, lr_min=lr/80
    )
    
    # save
    torch.save(net.state_dict(), './model/unet_{}_{}_fold{}.pth'.format(now.strftime('%Y%m%d-%H%M%S'), SIZE, fold))
    
    gc.collect()
    torch.cuda.empty_cache()
    

Training U-Net with hypercolumn concatenation and spatial/channel-wise excitation...
502 3012
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    1.0     0.0  |  0.000  0.000  (0.000) |  0.279  0.769  |  0.363  0.688  | 311.0042 





0.0000    1.0     0.0  |  1.162  0.772  (0.000) |  0.279  0.769  |  0.363  0.688  | 398.3047 

0.0000    1.0     0.0  |  1.162  0.772  (0.000) |  0.279  0.769  |  0.247  0.812  | 398.9417  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    2.0     1.0  |  1.162  0.772  (0.000) |  0.218  0.766  |  0.364  0.562  | 706.1625 





0.0000    2.0     1.0  |  1.158  0.729  (0.000) |  0.218  0.766  |  0.364  0.562  | 742.7521 

0.0000    2.0     1.0  |  1.158  0.729  (0.000) |  0.218  0.766  |  0.074  0.812  | 743.3767  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    3.0     2.0  |  1.158  0.729  (0.000) |  0.261  0.730  |  0.303  0.699  | 1050.4923 





0.0000    3.0     2.0  |  1.134  0.748  (0.000) |  0.261  0.730  |  0.303  0.699  | 1087.0128 

0.0000    3.0     2.0  |  1.134  0.748  (0.000) |  0.261  0.730  |  0.219  0.750  | 1087.6246  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    4.0     3.0  |  1.134  0.748  (0.000) |  0.212  0.724  |  0.374  0.628  | 1394.7197 





0.0000    4.0     3.0  |  1.144  0.762  (0.000) |  0.212  0.724  |  0.374  0.628  | 1431.0965 

0.0000    4.0     3.0  |  1.144  0.762  (0.000) |  0.212  0.724  |  0.274  0.696  | 1431.7213  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    5.0     4.0  |  1.144  0.762  (0.000) |  0.203  0.728  |  0.176  0.814  | 1738.1873 





0.0000    5.0     4.0  |  1.091  0.736  (0.000) |  0.203  0.728  |  0.176  0.814  | 1776.1956 

0.0000    5.0     4.0  |  1.091  0.736  (0.000) |  0.203  0.728  |  0.272  0.694  | 1776.8711  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    6.0     5.0  |  1.091  0.736  (0.000) |  0.213  0.708  |  0.220  0.691  | 2086.5696 





0.0000    6.0     5.0  |  1.091  0.689  (0.000) |  0.213  0.708  |  0.220  0.691  | 2123.0368 

0.0000    6.0     5.0  |  1.091  0.689  (0.000) |  0.213  0.708  |  0.138  0.822  | 2123.6569 502 6024
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    1.0     0.0  |  0.000  0.000  (0.000) |  0.206  0.722  |  0.247  0.694  | 306.5102 





0.0000    1.0     0.0  |  1.042  0.760  (0.000) |  0.206  0.722  |  0.247  0.694  | 343.0811 

0.0000    1.0     0.0  |  1.042  0.760  (0.000) |  0.206  0.722  |  0.107  0.895  | 343.7088  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    2.0     1.0  |  1.042  0.760  (0.000) |  0.213  0.715  |  0.236  0.697  | 650.1143 





0.0000    2.0     1.0  |  1.134  0.756  (0.000) |  0.213  0.715  |  0.236  0.697  | 686.6661 

0.0000    2.0     1.0  |  1.134  0.756  (0.000) |  0.213  0.715  |  0.139  0.757  | 687.2886  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    3.0     2.0  |  1.134  0.756  (0.000) |  0.228  0.651  |  0.146  0.758  | 993.5001 





0.0000    3.0     2.0  |  1.084  0.709  (0.000) |  0.228  0.651  |  0.146  0.758  | 1030.0156 

0.0000    3.0     2.0  |  1.084  0.709  (0.000) |  0.228  0.651  |  0.153  0.688  | 1030.6348  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    4.0     3.0  |  1.084  0.709  (0.000) |  0.205  0.699  |  0.126  0.752  | 1337.1567 





0.0000    4.0     3.0  |  1.088  0.696  (0.000) |  0.205  0.699  |  0.126  0.752  | 1373.7337 

0.0000    4.0     3.0  |  1.088  0.696  (0.000) |  0.205  0.699  |  0.160  0.704  | 1374.3563  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    5.0     4.0  |  1.088  0.696  (0.000) |  0.172  0.732  |  0.102  0.817  | 1680.6832 





0.0000    5.0     4.0  |  1.142  0.691  (0.000) |  0.172  0.732  |  0.102  0.817  | 1716.8958 

0.0000    5.0     4.0  |  1.142  0.691  (0.000) |  0.172  0.732  |  0.253  0.637  | 1717.5086  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    6.0     5.0  |  1.142  0.691  (0.000) |  0.173  0.730  |  0.318  0.578  | 2023.4173 





0.0000    6.0     5.0  |  1.192  0.667  (0.000) |  0.173  0.730  |  0.318  0.578  | 2059.5832 

0.0000    6.0     5.0  |  1.192  0.667  (0.000) |  0.173  0.730  |  0.144  0.762  | 2060.2014  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    7.0     6.0  |  1.192  0.667  (0.000) |  0.167  0.768  |  0.115  0.827  | 2366.6766 





0.0000    7.0     6.0  |  1.143  0.696  (0.000) |  0.167  0.768  |  0.115  0.827  | 2402.7629 

0.0000    7.0     6.0  |  1.143  0.696  (0.000) |  0.167  0.768  |  0.116  0.754  | 2403.3747  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    8.0     7.0  |  1.143  0.696  (0.000) |  0.181  0.739  |  0.160  0.758  | 2709.7318 





0.0000    8.0     7.0  |  1.206  0.759  (0.000) |  0.181  0.739  |  0.160  0.758  | 2746.2491 

0.0000    8.0     7.0  |  1.206  0.759  (0.000) |  0.181  0.739  |  0.344  0.578  | 2746.8751  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000    9.0     8.0  |  1.206  0.759  (0.000) |  0.163  0.751  |  0.108  0.815  | 3054.6576 





0.0000    9.0     8.0  |  1.205  0.675  (0.000) |  0.163  0.751  |  0.108  0.815  | 3091.5465 

0.0000    9.0     8.0  |  1.205  0.675  (0.000) |  0.163  0.751  |  0.144  0.814  | 3092.1583  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000   10.0     9.0  |  1.205  0.675  (0.000) |  0.178  0.723  |  0.155  0.818  | 3400.4329 





0.0000   10.0     9.0  |  1.215  0.679  (0.000) |  0.178  0.723  |  0.155  0.818  | 3436.8603 

0.0000   10.0     9.0  |  1.215  0.679  (0.000) |  0.178  0.723  |  0.126  0.756  | 3437.4865  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000   11.0    10.0  |  1.215  0.679  (0.000) |  0.183  0.725  |  0.191  0.694  | 3744.6791 





0.0000   11.0    10.0  |  1.237  0.743  (0.000) |  0.183  0.725  |  0.191  0.694  | 3780.8997 

0.0000   11.0    10.0  |  1.237  0.743  (0.000) |  0.183  0.725  |  0.244  0.692  | 3781.5076  rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0000   12.0    11.0  |  1.237  0.743  (0.000) |  0.187  0.721  |  0.095  0.820  | 4089.364 





0.0000   12.0    11.0  |  1.379  0.624  (0.000) |  0.187  0.721  |  0.095  0.820  | 4125.8063 

0.0000   12.0    11.0  |  1.379  0.624  (0.000) |  0.187  0.721  |  0.143  0.760  | 4126.4326 

IndexError: tuple index out of range