In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from datetime import datetime
import sys
import gc
sys.path.append('../../')

from sklearn.model_selection import KFold

from dependencies import *
from settings import *
from reproducibility import *
from models.TGS_salt.Unet34_scSE_hyper import Unet_scSE_hyper as Net

Importing numerical libraries...
Importing standard libraries...
Importing miscellaneous functions...
Importing constants...
Importing Neural Network dependencies...
	PyTorch
	Keras
	TensorFlow
	Metrics, Losses and LR Schedulers
	Kaggle Metrics
	Image augmentations
	Datasets
Importing external libraries...
	Lovasz Losses (elu+1)

Fixing random seed for reproducibility...
	Setting random seed to 35202.

Setting CUDA environment...
	torch.__version__              = 1.1.0
	torch.version.cuda             = 9.0.176
	torch.backends.cudnn.version() = 7501
	os['CUDA_VISIBLE_DEVICES']     = 0,1
	torch.cuda.device_count()      = 2



In [15]:
# --- Experiment configuration ------------------------------------------------
SIZE = 256      # target width/height handed to do_resize2 by the augment functions
FACTOR = 128    # `factor` argument handed to do_center_pad_to_factor2
ne = "ne"       # NOTE(review): purpose not evident from this notebook - confirm or remove
initial_checkpoint = None
MODEL = "ResNet34"

batch_size = 16
# Number of mini-batches whose gradients are accumulated before one optimizer
# step (target effective batch size: 64). It must be an integer >= 1: the
# training loop tests `(iter + 1) % n_acc == 0`, which would never be true
# for a fractional value such as 64 / 48.
n_acc = max(1, 64 // batch_size)
nfolds = 4

data_root = '../../data/siim-pneumothorax'
torch.cuda.set_device('cuda:0')

In [4]:
def time_to_str(time, str):
    """Return the elapsed `time` rounded to four decimal places.

    `str` names the unit the caller asked for (the training loop passes
    'min'), but it is currently ignored: no conversion to minutes and no
    unit suffix is applied. The value comes back as a number, not a string.
    """
    # A minutes-based string format used to be sketched here; it is disabled.
    decimals = 4
    rounded = round(time, decimals)
    return rounded

In [5]:
#TODO: Instead of directly printing to stdout, copy it into a txt file
class Logger():
    """Tiny logger that echoes messages to stdout and to '<fold>_log.txt'.

    Parameters
    ----------
    path : unused; accepted only to keep the existing call signature.
    fold : value used as the log-file prefix (converted with str()).
           Defaults to the global FOLD, which must exist when this cell runs
           (it is not defined in the visible cells of this notebook).
    """
    def __init__(self,path=None, fold=FOLD):
        super().__init__()
        self.fold=str(fold)
        # Kept for backward compatibility; the file is opened and closed per
        # write, so no handle normally stays open here.
        self.file = None

    def write(self, str):
        """Print `str` and append it to the log file."""
        print(str)
        # Append mode: opening with "w" on every call truncated the file, so
        # only the last message ever survived.
        with open(self.fold+"_log.txt","a") as handle:
            handle.write(str)

    def write2(self, str):
        """Write `str` to the console only, without a trailing newline.

        Intended for carriage-return progress lines, which would only add
        noise to the log file.
        """
        print(str, end='', flush=True)

    def stop(self):
        """Close the log file handle if one is still open."""
        if self.file is not None:
            self.file.close()
            self.file = None

In [6]:
def valid_augment(image,mask,index):
    """Deterministic preprocessing applied to validation samples.

    Copies of the untouched image and mask are kept in a Struct, then the pair
    is resized to SIZE x SIZE and centre-padded with factor FACTOR.

    Returns (image, mask, index, cache); `index` is passed through unchanged.
    """
    # Snapshot the inputs before any transform is applied.
    untouched = Struct(image = image.copy(), mask = mask.copy())

    resized_image, resized_mask = do_resize2(image, mask, SIZE, SIZE)
    padded_image, padded_mask = do_center_pad_to_factor2(
        resized_image, resized_mask, factor = FACTOR
    )
    return padded_image, padded_mask, index, untouched

def train_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())

    if np.random.rand() < 0.5:
         image, mask = do_horizontal_flip2(image, mask)
         pass

    if np.random.rand() < 0.5:
        c = np.random.choice(4)
        if c==0:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.2) #0.125

        if c==1:
            image, mask = do_horizontal_shear2( image, mask, dx=np.random.uniform(-0.07,0.07) )
            pass

        if c==2:
            image, mask = do_shift_scale_rotate2( image, mask, dx=0, dy=0, scale=1, angle=np.random.uniform(0,15))  #10

        if c==3:
            image, mask = do_elastic_transform2(image, mask, grid=10, distort=np.random.uniform(0,0.15))#0.10
            pass
    if np.random.rand() < 0.5:
        c = np.random.choice(3)
        if c==0:
            image = do_brightness_shift(image,np.random.uniform(-0.1,+0.1))
        if c==1:
            image = do_brightness_multiply(image,np.random.uniform(1-0.08,1+0.08))
        if c==2:
            image = do_gamma(image,np.random.uniform(1-0.08,1+0.08))
        # if c==1:
        #     image = do_invert_intensity(image)

    image, mask = do_resize2(image, mask, SIZE, SIZE)
    image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

In [12]:
def null_augment(image, label, index):
    """Identity augmentation: return the sample unchanged plus a cache of copies.

    Returns (image, label, index, cache), where `cache` is a Struct holding
    copies of `image` and `label` (the latter stored under the name `mask`).
    """
    # The copy must come from the `label` parameter; the previous code read an
    # undefined name `mask` and raised NameError on the first call.
    cache = Struct(image = image.copy(), mask = label.copy())
    return image, label, index, cache

def null_collate(batch):
    """Collate a list of (image, mask, index, cache) samples into a batch.

    Returns
    -------
    input : float tensor built from the stacked images with a new axis
            inserted at position 1.
    truth : float tensor built the same way from the masks when the samples
            carry one; otherwise the list of per-sample placeholders
            (the dataset yields `[]` for samples without a mask).
    index : list of the per-sample indices.
    cache : list of the per-sample cache objects.
    """
    images  = [sample[0] for sample in batch]
    masks   = [sample[1] for sample in batch]
    indices = [sample[2] for sample in batch]
    caches  = [sample[3] for sample in batch]

    input = torch.from_numpy(np.array(images)).float().unsqueeze(1)

    # Decide by size, not with `masks[0] != []`: comparing an ndarray with an
    # empty list is a shape-mismatched element-wise comparison, which current
    # NumPy rejects and older versions only tolerated with a warning.
    has_truth = len(masks) > 0 and np.size(masks[0]) > 0
    if has_truth:
        truth = torch.from_numpy(np.array(masks)).float().unsqueeze(1)
    else:
        truth = masks

    return input, truth, indices, caches

class SIIMDataset(Dataset):
    """Grayscale image / mask dataset read from PNG folders under `data_root`.

    Layout expected by the code below:
        <data_root>/train_png/<name>   images for the 'train' and 'val' phases
        <data_root>/mask_png/<name>    masks, same file names as the images
        <data_root>/test_png/<name>    images for any other phase ('test')

    Parameters
    ----------
    data_root    : root folder of the dataset.
    fold         : index of the KFold split to use (0 .. nfolds-1).
    width, height: stored on the instance only; no resizing happens here.
    phase        : 'train' -> training part of the fold, 'val' -> held-out
                   part of the fold, anything else -> every file in test_png.
    augment      : callable (image, mask, index) -> sample, applied in __getitem__.
    random_state : seed for the shuffled KFold split.
    nfolds       : number of folds.
    """
    def __init__(self, data_root, fold, width=1024, height=1024, phase='train', augment=null_augment, random_state=2019, nfolds=4):
        self.data_root = data_root
        self.fold = fold
        self.width = width
        self.height = height
        self.phase = phase
        self.augment = augment
        self.mask_dir = os.path.join(data_root, 'mask_png')

        if phase in ('train', 'val'):
            self.image_dir = os.path.join(data_root, 'train_png')
            # os.listdir order is arbitrary; sort so a given random_state
            # always yields the same folds.
            train_list = sorted(os.listdir(self.image_dir))
            kf = KFold(n_splits=nfolds, shuffle=True, random_state=random_state)
            train_idx, val_idx = list(kf.split(list(range(len(train_list)))))[fold]
            chosen = train_idx if phase == 'train' else val_idx
            # kf.split yields integer positions; turn them into file names.
            self.filenames = [train_list[i] for i in chosen]
        else: # test
            self.image_dir = os.path.join(data_root, 'test_png')
            self.filenames = sorted(os.listdir(self.image_dir))

    @staticmethod
    def _read_gray(path):
        """Read `path` as a grayscale image scaled to float32 in [0, 1]."""
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if pixels is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('Could not read image: {}'.format(path))
        return pixels.astype(np.float32) / 255

    def __getitem__(self, index):
        """Load sample `index` and return `self.augment(img, mask, index)`."""
        filename = self.filenames[index]
        img = self._read_gray(os.path.join(self.image_dir, filename))

        if self.phase in ('train', 'val'):
            mask = self._read_gray(os.path.join(self.mask_dir, filename))
        else: # test: no ground truth available
            mask = []

        return self.augment(img, mask, index)

    def __len__(self):
        """Number of samples selected for this phase/fold."""
        return len(self.filenames)


Fixing random seed for reproducibility...
	Setting random seed to 35202.

Setting CUDA environment...
	torch.__version__              = 1.1.0
	torch.version.cuda             = 9.0.176
	torch.backends.cudnn.version() = 7501
	os['CUDA_VISIBLE_DEVICES']     = 0,1
	torch.cuda.device_count()      = 2



In [16]:
def validation( net, valid_loader ):
    """Evaluate `net` on every batch of `valid_loader`.

    Returns a float32 array of length 3:
        [0] sample-weighted mean of focal_loss + criterion,
        [1] sample-weighted mean of net.metric (logged as dice),
        [2] mean of the precision array returned by do_kaggle_metric.
    """
    valid_num  = 0
    valid_loss = np.zeros(3,np.float32)

    predicts = []
    truths   = []

    for input, truth, index, cache in valid_loader:
        input = input.cuda()
        truth = truth.cuda()
        with torch.no_grad():
            logit = net(input) #data_parallel(net,input)
            # torch.sigmoid replaces the deprecated F.sigmoid (same values).
            prob  = torch.sigmoid(logit)
            loss  = net.focal_loss(logit, truth, 1.0, 0.5, 0.25) + net.criterion(logit, truth)
            dice  = net.metric(logit, truth)

        batch_size = len(index)
        # Weight by batch size so a smaller last batch does not skew the mean.
        valid_loss += batch_size*np.array(( loss.item(), dice.item(), 0))
        valid_num += batch_size

        # NOTE(review): Y0, Y1, X0, X1 are not defined in the visible notebook
        # cells; presumably they are crop bounds from `settings` - confirm.
        prob  = prob [:,:,Y0:Y1, X0:X1]
        truth = truth[:,:,Y0:Y1, X0:X1]
        #prob  = F.avg_pool2d(prob,  kernel_size=2, stride=2)
        #truth = F.avg_pool2d(truth, kernel_size=2, stride=2)
        predicts.append(prob.detach().cpu().numpy())
        truths.append(truth.detach().cpu().numpy())

    # Every sample of the loader must have been seen exactly once.
    assert(valid_num == len(valid_loader.sampler))
    valid_loss  = valid_loss/valid_num

    #--------------------------------------------------------
    predicts = np.concatenate(predicts).squeeze()
    truths   = np.concatenate(truths).squeeze()
    precision, result, threshold  = do_kaggle_metric(predicts, truths)
    valid_loss[2] = precision.mean()

    return valid_loss

In [17]:
def freeze(net):
    """Disable gradients for every parameter of `net.feature_net`."""
    backbone_params = net.feature_net.parameters()
    for weight in backbone_params:
        weight.requires_grad = False
        
def unfreeze(net):
    """Enable gradients for every parameter of `net.feature_net`."""
    backbone_params = net.feature_net.parameters()
    for weight in backbone_params:
        weight.requires_grad = True

def cosine_annealing_scheduler(num_iter, lr_init, lr_min):
    """Build a cosine-annealing learning-rate schedule.

    Returns a callable `scheduler(x)` giving the learning rate for iteration
    `x`: it equals `lr_init` at x = 1 and decays along half a cosine towards
    `lr_min`, restarting every `num_iter` iterations.

    NOTE(review): the phase is computed from `x - 1`, so x = 0 maps to the
    end of the cycle (a rate close to lr_min). Confirm whether callers are
    meant to count iterations from 1.
    """
    amplitude = (lr_init - lr_min) / 2

    def scheduler(x):
        phase = np.mod(x - 1, num_iter) / num_iter
        return amplitude * (np.cos(np.pi * phase) + 1) + lr_min

    # The schedule used to be built but never returned, so callers always
    # received None and silently trained without any schedule.
    return scheduler
        
def set_BN_momentum(model, momentum=0.1*batch_size/64):
    """Set `momentum` on every BatchNorm1d / BatchNorm2d layer of `model`.

    The default is 0.1 scaled by batch_size / 64 and is evaluated once, when
    this cell runs, from the global `batch_size` at that moment.
    """
    batch_norm_types = (nn.BatchNorm2d, nn.BatchNorm1d)
    for _, module in model.named_modules():
        if isinstance(module, batch_norm_types):
            module.momentum = momentum
            
def fit_one_cycle(epochs, net, train_loader, val_loader, lr_init=0.001, lr_min=0.00001, logger=None):
    """Train `net` for `epochs` passes over `train_loader` with a cosine LR schedule.

    Parameters
    ----------
    epochs       : number of passes over `train_loader`.
    net          : model exposing set_mode, criterion, focal_loss and metric.
    train_loader : loader yielding (input, truth, index, cache) batches.
    val_loader   : loader handed to validation() once per epoch.
    lr_init      : learning rate at the start of the schedule (also the
                   optimizer's initial rate).
    lr_min       : learning rate the schedule decays towards.
    logger       : object with write(str) and optionally write2(str);
                   a fresh Logger() is created when omitted.

    Gradients are accumulated over the global `n_acc` mini-batches before each
    optimizer step. Only parameters with requires_grad=True are optimised.
    Returns None.

    NOTE(review): OHEM, timer, time, adjust_learning_rate and
    get_learning_rate are not defined in the visible notebook cells;
    presumably they come from the star imports - confirm.
    """
    if logger is None:
        logger = Logger()

    def show_progress(text):
        # Carriage-return progress lines go to the console only. Fall back to
        # stdout when the logger offers no console-only writer.
        console_write = getattr(logger, 'write2', None)
        if console_write is not None:
            console_write(text)
        else:
            sys.stdout.write(text)
            sys.stdout.flush()

    # init learner
    num_batch = len(train_loader)
    num_iter = num_batch * epochs
    iter_smooth = 20          # window (in iterations) of the running train loss
    iter_valid  = num_batch   # validate once per pass over the training data

    scheduler = cosine_annealing_scheduler(num_iter, lr_init, lr_min)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
          lr=lr_init, momentum=0.9, weight_decay=0.0001
    )
    set_BN_momentum(net)
    # Start from clean gradients: an earlier call may have left a partial
    # accumulation behind.
    optimizer.zero_grad()

    train_loss  = np.zeros(6,np.float32)
    valid_loss  = np.zeros(6,np.float32)
    batch_loss  = np.zeros(6,np.float32)
    rate = lr_init
    iteration = 0
    row_format = '%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  | %s '

    start = timer()

    def status_row():
        # Fractional epoch, derived from the iteration counter.
        epoch = iteration / num_batch
        return row_format % (
            rate, iteration/1000, epoch,
            valid_loss[0], valid_loss[1], valid_loss[2],
            train_loss[0], train_loss[1],
            batch_loss[0], batch_loss[1],
            time_to_str((timer() - start),'min'))

    stop_requested = False
    while iteration < num_iter and not stop_requested:  # loop over the dataset multiple times
        sum_train_loss = np.zeros(6,np.float32)
        num_summed = 0

        for images, truth, _, _ in train_loader:
            # validation
            if (iteration + 1) % iter_valid == 0:
                net.set_mode('valid')
                valid_loss = validation(net, val_loader)
                net.set_mode('train')

                show_progress('\r')
                logger.write(status_row() + '\n')
                time.sleep(0.01)

            if scheduler is not None:
                lr = scheduler(iteration)
                if lr<0 :
                    # A negative rate ends training; breaking only the inner
                    # loop would make the outer loop spin forever.
                    stop_requested = True
                    break
                adjust_learning_rate(optimizer, lr)
                rate = get_learning_rate(optimizer)

            # ok, train
            net.set_mode('train')

            images = images.cuda()
            truth = truth.cuda()

            logit = net(images) #data_parallel(net,input)

            if OHEM == "OHEM":
                loss = net.focal_loss(logit, truth, 1.0, 0.5, 0.25) + net.criterion(logit, truth)
            else:
                loss = net.criterion(logit, truth)

            dice = net.metric(logit, truth)

            # Record the unscaled loss so it is comparable with the
            # validation loss in the same log row.
            batch_loss = np.array((
                           loss.item(),
                           dice.item(),
                           0, 0, 0, 0,
                         ))

            # learn with grad acc: scale so the accumulated gradient is the
            # mean over n_acc mini-batches.
            (loss / n_acc).backward()

            if ((iteration + 1) % n_acc) == 0:
                # Clip BEFORE stepping; clipping after zero_grad() has no effect.
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()

            # print statistics  ------------
            sum_train_loss += batch_loss
            num_summed += 1
            if iteration%iter_smooth == 0:
                train_loss = sum_train_loss/num_summed
                sum_train_loss = np.zeros(6,np.float32)
                num_summed = 0

            show_progress('\r' + status_row())

            iteration += 1


In [None]:
def get_dataloaders(data_root, batch_size, fold, nfolds=4, width=1024, height=1024, train_augment=null_augment, val_augment=null_augment, random_state=SEED, num_workers=8):
    """Build the training and validation DataLoaders for one fold.

    Parameters
    ----------
    data_root     : dataset root passed to SIIMDataset.
    batch_size    : batch size of both loaders.
    fold, nfolds  : which KFold split to use and how many splits exist.
    width, height : forwarded to SIIMDataset.
    train_augment : augment callable for the training dataset.
    val_augment   : augment callable for the validation dataset.
    random_state  : seed of the KFold shuffle (defaults to the global SEED).
    num_workers   : worker processes per loader.

    Returns
    -------
    (train_loader, valid_loader). The training loader samples randomly and
    drops the last incomplete batch; the validation loader keeps every sample.
    """
    shared = dict(width=width, height=height, random_state=random_state, nfolds=nfolds)

    train_dataset = SIIMDataset(
        data_root,
        fold,
        phase='train',
        augment=train_augment,
        **shared
    )

    train_loader  = DataLoader(
        train_dataset,
        sampler     = RandomSampler(train_dataset),
        batch_size  = batch_size,
        drop_last   = True,
        num_workers = num_workers,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    valid_dataset = SIIMDataset(
        data_root,
        fold,
        phase='val',
        augment=val_augment,
        **shared
    )

    # Validation metrics do not depend on sample order, so iterate
    # sequentially instead of shuffling.
    valid_loader  = DataLoader(
        valid_dataset,
        shuffle     = False,
        batch_size  = batch_size,
        drop_last   = False,
        num_workers = num_workers,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    # The old code returned an undefined name `val_loader` (NameError).
    return train_loader, valid_loader

In [10]:
print("Training U-Net with hypercolumn concatenation and spatial/channel-wise excitation...")

now = datetime.now()

# torch.save does not create missing directories.
MODEL_DIR = './model'
os.makedirs(MODEL_DIR, exist_ok=True)

for fold in range(nfolds):
    # NOTE(review): null_augment performs no resizing, so width/height=SIZE
    # have no effect on the tensors produced here. train_augment /
    # valid_augment defined earlier look like the intended callables - confirm.
    train_loader, val_loader = get_dataloaders(
        data_root,
        batch_size,
        fold, nfolds=nfolds,
        width=SIZE, height=SIZE,
        train_augment=null_augment, val_augment=null_augment,
        random_state=SEED
    )

    net = Net().cuda()

    lr = 0.001

    # warm up: train with the feature_net parameters frozen
    freeze(net)
    fit_one_cycle(
        6, net,
        train_loader, val_loader,
        lr_init=lr, lr_min=lr/100
    )

    # ok, train: all parameters trainable
    unfreeze(net)
    fit_one_cycle(
        12, net,
        train_loader, val_loader,
        lr_init=lr/2, lr_min=lr/80
    )

    # save - the two placeholders are the model name and the fold index
    # (the previous `.format()` call passed no arguments and raised IndexError)
    torch.save(net.state_dict(), './model/unet_{}_fold{}.pth'.format(MODEL, fold))

    # Drop the references before collecting so memory can really be released
    # before the next fold builds its own model and loaders.
    del net, train_loader, val_loader
    gc.collect()
    torch.cuda.empty_cache()
    