# BESNet with classification

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from datetime import datetime
import sys
import gc
sys.path.append('../../')

from sklearn.model_selection import KFold
from tqdm import tqdm

from dependencies import *
from settings import *
from reproducibility import *
from models.TGS_salt.Cls_BesNet import ClsBesNet as Net

Using paths on kail-main

Importing numerical libraries...
Importing standard libraries...
Importing miscellaneous functions...
Importing constants...
Importing Neural Network dependencies...
	PyTorch
	Keras
	TensorFlow
	Metrics, Losses and LR Schedulers
	Kaggle Metrics
	Image augmentations
	Datasets
Importing external libraries...
	Lovasz Losses (elu+1)

Fixing random seed for reproducibility...
	Setting random seed to 35202.

Setting CUDA environment...
	torch.__version__              = 1.1.0
	torch.version.cuda             = 9.0.176
	torch.backends.cudnn.version() = 7501
	os['CUDA_VISIBLE_DEVICES']     = 0,1
	torch.cuda.device_count()      = 2



In [2]:
SIZE = 256
FACTOR = SIZE
ne = "ne"
initial_checkpoint = None
MODEL = "ResNet34"

batch_size = 8
n_acc = 256 / batch_size
nfolds = 4

noise_th = 75.0*(SIZE/128.0)**2 #threshold for the number of predicted pixels
best_thr0 = 0.2 #preliminary value of the threshold for metric calculation

data_root = '../../data/siim-pneumothorax'
torch.cuda.set_device('cuda:1')

In [3]:
def time_to_str(time, str):
    #if str == 'min':
    #	    return str(round(float(time)/60,5))+" min(s)"
    return round(time,4)

In [4]:
#TODO: Instead of directly printing to stdout, copy it into a txt file
class Logger():
    def __init__(self,name=MODEL+ne):
        super().__init__()
        self.model=name
        #if OHEM != "OHEM":
        #    self.model=MODEL+ne[ne.find("_")+1:]
        self.file = open(self.model+"_clsbes_log.txt","w+")
        self.file.close()
        
        self.debug_file = open(self.model + '_clsbes_debug.txt', 'w+')
        self.debug_file.close()
    def write(self, str):
        print(str)
        self.file = open(self.model+"_clsbes_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def write2(self, str):
        print(str, end='',flush=True)
        self.file = open(self.model+"_clsbes_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def debug(self, str):
        self.debug_file = open(self.model + '_clsbes_debug.txt', 'a+')
        self.debug_file.write(str)
        self.debug_file.close()
    def stop():
        self.file.close()
        self.debug_file.close()
        
log = Logger()

In [5]:
def valid_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    # image, mask = do_resize2(image, mask, SIZE, SIZE)
    # image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

def train_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())

    if np.random.rand() < 0.5:
         image, mask = do_horizontal_flip2(image, mask)
         pass

    if np.random.rand() < 0.2:
        c = np.random.choice(4)
        if c==0:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.1) #0.125

        if c==1:
            image, mask = do_horizontal_shear2( image, mask, dx=np.random.uniform(-0.02,0.02) )
            pass

        if c==2:
            image, mask = do_shift_scale_rotate2( image, mask, dx=0, dy=0, scale=1, angle=np.random.uniform(0,15))  #10

        if c==3:
            image, mask = do_elastic_transform2(image, mask, grid=10, distort=np.random.uniform(0,0.05))#0.10
            pass
    if np.random.rand() < 0.1:
        c = np.random.choice(3)
        if c==0:
            image = do_brightness_shift(image,np.random.uniform(-0.1,+0.1))
        if c==1:
            image = do_brightness_multiply(image,np.random.uniform(1-0.08,1+0.08))
        if c==2:
            image = do_gamma(image,np.random.uniform(1-0.08,1+0.08))
        # if c==1:
        #     image = do_invert_intensity(image)

    # image, mask = do_resize2(image, mask, SIZE, SIZE)
    # image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

In [6]:
def null_augment(image, mask, index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    return image, mask, index, cache

def null_collate(batch):

    batch_size = len(batch)
    cache = []
    input = []
    truth = []
    index = []
    for b in range(batch_size):
        input.append(batch[b][0])
        truth.append(batch[b][1])
        index.append(batch[b][2])
        cache.append(batch[b][3])
    input = torch.from_numpy(np.array(input)).float().unsqueeze(1)

    if truth[0]!=[]:
        truth = torch.from_numpy(np.array(truth)).float().unsqueeze(1)

    return input, truth, index, cache

def get_weights_for_balanced_classes(cls_list, num_classes):
    # get count per class
    count = [0] * num_classes
    
    for cls in cls_list:
        count[cls] += 1

    # get weight per class
    weight_per_class = [0.] * num_classes
    N = float(len(cls_list))
    
    for i in range(num_classes):
        weight_per_class[i] = N / float(count[i])
        
    #　get weight per sample
    weights = [0] * len(cls_list)
    
    for i, cls in enumerate(cls_list):
        weights[i] = weight_per_class[cls]
        
    return weights

def get_boundary(masks):
    mask_arr = (masks.cpu().numpy() * 255).astype(np.uint8).squeeze()
    b_arr = []
    
    for mask in mask_arr:
        b_img = np.zeros(mask.shape)
        
        contours, hier = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(b_img, contours, -1, 255, 1)
        
        b_arr.append(b_img)
        
    b_arr = np.stack(b_arr)
    
    return torch.from_numpy(b_arr)

def get_class(masks):
    n = masks.shape[0]
    classes = ((masks.view(n, -1).sum(-1, keepdim=True)).float() > 0.).float()
    
    return classes

class SIIMDataset(Dataset):
    def __init__(self, data_root, fold, pos_neg_ratio=0.5, width=1024, height=1024, phase='train', augment=null_augment, random_state=2019, nfolds=4):
        self.data_root = data_root
        self.fold = fold
        self.height = width
        self.width = height
        self.phase = phase
        self.augment = augment
        
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=random_state)
        train_list = os.listdir(os.path.join(data_root, 'train_png'))
        
        if phase == 'train':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][0]
            self.filenames = [train_list[i] for i in index_list]
            
            # read masks for pos/neg ratio sampler
            train_df = pd.read_csv(os.path.join(self.data_root, 'train-rle.csv'))
            pos_ids = list(train_df[train_df[' EncodedPixels']!=' -1']['ImageId'])

            self.cls_list = [1 if filename.split('.png')[0] in pos_ids else 0 for filename in self.filenames]

        elif phase == 'val':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][1]
            self.filenames = [train_list[i] for i in index_list]
        else: # test
            self.filenames = os.listdir(os.path.join(data_root, 'test_png'))

    def __getitem__(self, index):
        img_path = os.path.join(self.data_root, 'train_png/{}'.format(self.filenames[index]))
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        img = cv2.resize(img, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        if self.phase == 'test':
            mask = []
        else: # train and val
            mask_path = os.path.join(self.data_root, 'mask_png/{}'.format(self.filenames[index]))
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
            mask = cv2.resize(mask, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        return self.augment(img, mask, index)
    
    def __len__(self):
        return len(self.filenames)


In [7]:
def validation( net, valid_loader, weights=None ):

    valid_num  = 0
    valid_loss = np.zeros(3, np.float32)
    
    logits = []
    truths = []
    for input, truth, index, cache in valid_loader:
        input = input.cuda()
        truth = truth.cuda()
        
        with torch.no_grad():
            b_masks = get_boundary(truth).float().cuda() / 255.
            classes = get_class(truth).cuda()
            
            m_logit, b_logit, c_logit = net(input) #data_parallel(net,input)
            
            b_loss = net.boundary_criterion(b_logit, b_masks, weights=weights)
            m_loss = net.mask_criterion(m_logit, b_logit, truth, b_masks, alpha=5., beta=0.2, weights=weights)
            c_loss = net.class_criterion(c_logit, classes).mean()
            
            # losses are already reduced
            loss = b_loss + m_loss + c_loss
            
            dice  = net.metric(m_logit, truth, noise_th=0, threshold=0, logger=log)
            
            logits.append(m_logit.cpu())
            truths.append(truth.cpu())

        batch_size = len(index)
        valid_loss += batch_size * np.array(( loss.item(), dice.item(), 0))
        valid_num += batch_size
        
    valid_loss /= valid_num
    
    # find out optimal thr and dice
    log.debug('\nscan\n')
    logits = torch.cat(logits, dim=0)
    truths = torch.cat(truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    thrs = np.arange(0.05, 1, 0.05)
    
    th_dices = []
    for th in thrs:
        th_dice = net.metric(logits, truths, noise_th=0, threshold=th, logger=log)
        th_dices.append(th_dice)
        
    th_dices = np.array(th_dices)
    best_dice = th_dices.max()
    best_thr = thrs[th_dices.argmax()]
    
    valid_loss[1] = best_dice
    valid_loss[2] = best_thr
    
    gc.collect()
    torch.cuda.empty_cache()
        
    return valid_loss

In [8]:
def freeze(net):
    for p in net.conv1.parameters():
        p.requires_grad = False
        
    for p in net.encoder2.parameters():
        p.requires_grad = False
        
    for p in net.encoder3.parameters():
        p.requires_grad = False
        
    for p in net.encoder4.parameters():
        p.requires_grad = False
        
    for p in net.encoder5.parameters():
        p.requires_grad = False
        
    for p in net.center.parameters():
        p.requires_grad = False
        
    for p in net.decoder5.parameters():
        p.requires_grad = False
        
    for p in net.decoder4.parameters():
        p.requires_grad = False
        
    for p in net.decoder3.parameters():
        p.requires_grad = False
        
    for p in net.decoder2.parameters():
        p.requires_grad = False
        
    for p in net.decoder1.parameters():
        p.requires_grad = False
        
def unfreeze(net):
    for p in net.conv1.parameters():
        p.requires_grad = True
        
    for p in net.encoder2.parameters():
        p.requires_grad = True
        
    for p in net.encoder3.parameters():
        p.requires_grad = True
        
    for p in net.encoder4.parameters():
        p.requires_grad = True
        
    for p in net.encoder5.parameters():
        p.requires_grad = True
        
    for p in net.center.parameters():
        p.requires_grad = True
        
    for p in net.decoder5.parameters():
        p.requires_grad = True
        
    for p in net.decoder4.parameters():
        p.requires_grad = True
        
    for p in net.decoder3.parameters():
        p.requires_grad = True
        
    for p in net.decoder2.parameters():
        p.requires_grad = True
        
    for p in net.decoder1.parameters():
        p.requires_grad = True

def cosine_annealing_scheduler(num_iter, lr_init, lr_min):
    scheduler = lambda x: ((lr_init-lr_min)/2)*(np.cos(PI*(np.mod(x,num_iter)/(num_iter)))+1)+lr_min
    return scheduler
        
def set_BN_momentum(model, momentum=0.1*batch_size/64):
    for i, (name, layer) in enumerate(model.named_modules()):
        if isinstance(layer, nn.BatchNorm2d) or isinstance(layer, nn.BatchNorm1d):
            layer.momentum = momentum
            
def fit_one_cycle(epochs, net, train_loader, val_loader, lr_init=0.001, lr_min=0.000001, weights=None):
    # init learner
    iter_per_epoch = len(train_loader)
    num_iter = iter_per_epoch * epochs
    iter_smooth = 20
    iter_log    = 100
    iter_valid  = iter_per_epoch
    #iter_valid = 100
    
    #scheduler = None
    scheduler = cosine_annealing_scheduler(num_iter, lr_init, lr_min)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
          lr=lr_init, momentum=0.9, weight_decay=0.0001
    )
    set_BN_momentum(net)
    
    start_iter = 0
    start_epoch= 0
    train_loss  = np.zeros(6,np.float32)
    valid_loss  = np.zeros(6,np.float32)
    batch_loss  = np.zeros(6,np.float32)
    rate = 0
    iter = 0
    epoch = 0
    
    #debug
    if 0: #debug  ##-------------------------------
        debug_num = 2
        debug_count = 0
        
        for input, truth, index, cache in train_loader:
            images = input.cpu().data.numpy().squeeze()
            masks  = truth.cpu().data.numpy().squeeze()
            
            batch_size = len(index)
            for b in range(batch_size):
                image = images[b]*255
                image = np.dstack([image,image,image])

                mask = masks[b]
                print(np.max(mask))
                
                # Plot some samples
                fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(12, 4))
                ax0.imshow(image.astype(np.uint8))
                ax1.imshow(mask, vmin=0, vmax=1)
                ax1.set_title('Targets')
                
                plt.show()
                
            debug_count += 1
            if debug_count > debug_num:
                break
    #--------------------------------------
    
    start = timer()
    while iter < num_iter:  # loop over the dataset multiple times
        sum_train_loss = np.zeros(6,np.float32)
        sum = 0

        log.write('\n rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          \n')
        log.write('-------------------------------------------------------------------------------------------------------------------------------\n')
            
        for input, truth, index, cache in train_loader:
            # validation
            if (iter + 1) % iter_valid == 0:
                log.debug('\nval\n')
                net.set_mode('valid')
                valid_loss = validation(net, val_loader)

                net.set_mode('train')
                log.debug('\ntrain\n')
                time.sleep(0.01)
            
            if scheduler is not None:
                lr = scheduler(iter)
                if lr<0 : break
                adjust_learning_rate(optimizer, lr)
                rate = get_learning_rate(optimizer)
            
            # ok, train
            net.set_mode('train')

            input = input.cuda().float()
            truth = truth.cuda().float()
            
            b_masks = (get_boundary(truth).float().cuda() / 255.).float()
            classes = get_class(truth).cuda()

            m_logit, b_logit, c_logit = net(input) #data_parallel(net,input)

            b_loss = net.boundary_criterion(b_logit, b_masks, weights=weights)
            m_loss = net.mask_criterion(m_logit, b_logit.detach(), truth, b_masks.detach(), alpha=5., beta=0.2, weights=weights)
            c_loss = net.class_criterion(c_logit, classes).mean()
            
            # losses are already reduced
            loss = b_loss + m_loss + c_loss
            loss_rec = loss.item()
                
            dice = net.metric(m_logit, truth, noise_th=0, threshold=0, logger=log)
            
            # learn with grad acc
            loss /= n_acc
            loss.backward()
            
            if ((iter + 1) % n_acc) == 0:
                optimizer.step()
                optimizer.zero_grad()
                # torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            
            # print statistics  ------------
            batch_loss = np.array((
                           loss.item(),
                           dice.item(),
                           0, 0, 0, 0,
                         ))
            sum_train_loss += batch_loss
            sum += 1
            if iter%iter_smooth == 0:
                train_loss = sum_train_loss/sum
                sum_train_loss = np.zeros(6,np.float32)
                sum = 0

            log.write2('\r%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  |  %0.4f  %0.4f  %0.4f | %s ' % (\
                         rate, iter/iter_per_epoch, epoch+1,
                         valid_loss[0], valid_loss[1], valid_loss[2],
                         train_loss[0], train_loss[1],
                         loss_rec, batch_loss[1],
                         b_loss.item(), m_loss.item(), c_loss.item(),
                         time_to_str((timer() - start), 'min')))
            
            iter += 1
            epoch = iter // iter_per_epoch

In [9]:
def get_dataloaders(data_root, batch_size, fold, nfolds=4, width=1024, height=1024, train_augment=null_augment, val_augment=null_augment, random_state=SEED):
    train_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='train',
        augment=train_augment,
        random_state=random_state,
        nfolds=nfolds
    )
    
    weights = get_weights_for_balanced_classes(train_dataset.cls_list, 2)
    weights = torch.DoubleTensor(weights)
    balance_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    train_loader  = DataLoader(
        train_dataset,
        # sampler     = RandomSampler(train_dataset),
        sampler = balance_sampler,
        batch_size  = batch_size,
        drop_last   = True,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    val_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='val',
        augment=val_augment,
        random_state=random_state,
        nfolds=nfolds
    )

    val_loader  = DataLoader(
        val_dataset,
        sampler     = RandomSampler(val_dataset),
        batch_size  = batch_size,
        drop_last   = False,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )
    
    return train_loader, val_loader

## Train

In [10]:
# one fold test!

train_loader, val_loader = get_dataloaders(
    data_root,
    batch_size,
    0, nfolds=10,
    width=SIZE, height=SIZE,
    train_augment=train_augment, val_augment=valid_augment,
    random_state=SEED
)

net = Net().cuda()

lr = 1e-2

In [11]:
# warm up
freeze(net)
fit_one_cycle(
    2, net,
    train_loader, val_loader,
    lr_init=lr, lr_min=lr/100,
    weights=[0.05, 0.95]
)


 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0051    1.0     1.0  |  0.000  0.000  (0.000) |  0.398  0.016  |  14.313  0.016  |  0.1882  0.1926  13.9326 | 229.4141 



0.0051    1.0     1.0  |  11.135  0.783  (0.800) |  0.398  0.016  |  17.305  0.023  |  0.2188  0.1975  16.8887 | 254.4256 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001    2.0     2.0  |  11.135  0.783  (0.800) |  0.351  0.017  |  17.421  0.012  |  0.1794  0.1930  17.0489 | 482.6882 



0.0001    2.0     2.0  |  12.266  0.783  (0.700) |  0.351  0.017  |  20.349  0.016  |  0.2054  0.1834  19.9604 | 506.0486 

In [12]:
torch.save(net.state_dict(), 'clsbes-cp-1.pth')

In [13]:
net.load_state_dict(torch.load('clsbes-cp-1.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [13]:
# ok, train positive
unfreeze(net)
fit_one_cycle(
    50, net,
    train_loader, val_loader,
    lr_init=lr, lr_min=lr/120,
    weights=[0.05, 0.95]
)


 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0100    1.0     1.0  |  0.000  0.000  (0.000) |  0.083  0.024  |  1.260  0.006  |  0.1411  0.1614  0.9575 | 457.1003 



0.0100    1.0     1.0  |  4.652  0.783  (0.700) |  0.083  0.024  |  1.752  0.043  |  0.1695  0.1666  1.4157 | 481.1523 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0100    2.0     2.0  |  4.652  0.783  (0.700) |  0.068  0.045  |  2.244  0.052  |  0.1455  0.1399  1.9584 | 935.8012 



0.0100    2.0     2.0  |  1.241  0.783  (0.950) |  0.068  0.045  |  2.077  0.083  |  0.1482  0.1502  1.7786 | 959.1011 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0099    3.0     3.0  |  1.241  0.783  (0.950) |  0.034  0.053  |  1.633  0.022  |  0.1240  0.1448  1.3637 | 1413.9275 



0.0099    3.0     3.0  |  2.061  0.783  (0.900) |  0.034  0.053  |  1.222  0.024  |  0.1656  0.1853  0.8716 | 1437.2992 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0098    4.0     4.0  |  2.061  0.783  (0.900) |  0.033  0.048  |  1.465  0.040  |  0.1676  0.1987  1.0991 | 1891.6811 



0.0098    4.0     4.0  |  1.727  0.783  (0.950) |  0.033  0.048  |  20.521  0.023  |  0.1655  0.1779  20.1772 | 1915.0641 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0098    5.0     5.0  |  1.727  0.783  (0.950) |  0.032  0.060  |  0.925  0.029  |  0.1413  0.1490  0.6351 | 2369.4632 



0.0098    5.0     5.0  |  1.569  0.783  (0.950) |  0.032  0.060  |  1.143  0.076  |  0.1576  0.1947  0.7909 | 2392.85 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0097    6.0     6.0  |  1.569  0.783  (0.950) |  0.036  0.071  |  1.173  0.084  |  0.1522  0.1369  0.8843 | 2847.3714 



0.0097    6.0     6.0  |  1.117  0.783  (0.950) |  0.036  0.071  |  0.799  0.018  |  0.1158  0.1494  0.5332 | 2871.4289 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0095    7.0     7.0  |  1.117  0.783  (0.950) |  0.032  0.073  |  1.129  0.023  |  0.1168  0.1321  0.8805 | 3326.5474 



0.0095    7.0     7.0  |  1.865  0.783  (0.950) |  0.032  0.073  |  1.376  0.053  |  0.1008  0.1022  1.1728 | 3349.9452 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0094    8.0     8.0  |  1.865  0.783  (0.950) |  0.041  0.055  |  1.202  0.089  |  0.1495  0.1296  0.9225 | 3805.1793 



0.0094    8.0     8.0  |  0.786  0.782  (0.950) |  0.041  0.055  |  1.355  0.056  |  0.1594  0.1635  1.0321 | 3829.6765 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0092    9.0     9.0  |  0.786  0.782  (0.950) |  0.063  0.085  |  0.886  0.071  |  0.1087  0.0965  0.6811 | 4287.3058 



0.0092    9.0     9.0  |  0.966  0.783  (0.950) |  0.063  0.085  |  0.747  0.052  |  0.1657  0.1499  0.4310 | 4311.6915 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0091   10.0    10.0  |  0.966  0.783  (0.950) |  0.030  0.078  |  0.957  0.100  |  0.1318  0.1064  0.7189 | 4769.055 



0.0091   10.0    10.0  |  1.151  0.783  (0.950) |  0.030  0.078  |  0.867  0.076  |  0.1620  0.1317  0.5737 | 4794.1396 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0089   11.0    11.0  |  1.151  0.783  (0.950) |  0.032  0.073  |  0.635  0.109  |  0.0981  0.0844  0.4529 | 5251.3032 



0.0089   11.0    11.0  |  1.003  0.783  (0.950) |  0.032  0.073  |  2.793  0.132  |  0.1478  0.1533  2.4917 | 5276.4974 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0087   12.0    12.0  |  1.003  0.783  (0.950) |  0.031  0.081  |  2.835  0.009  |  0.1690  0.1687  2.4969 | 5733.7992 



0.0087   12.0    12.0  |  2.768  0.779  (0.950) |  0.031  0.081  |  1.625  0.031  |  0.0736  0.0683  1.4835 | 5758.884 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0084   13.0    13.0  |  2.768  0.779  (0.950) |  0.041  0.064  |  2.111  0.086  |  0.1608  0.1658  1.7849 | 6216.0252 



0.0084   13.0    13.0  |  2.050  0.782  (0.950) |  0.041  0.064  |  1.550  0.040  |  0.0831  0.0770  1.3903 | 6241.115 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0082   14.0    14.0  |  2.050  0.782  (0.950) |  0.033  0.081  |  0.710  0.033  |  0.1224  0.1236  0.4641 | 6698.0576 



0.0082   14.0    14.0  |  0.977  0.781  (0.950) |  0.033  0.081  |  1.062  0.044  |  0.1462  0.1522  0.7636 | 6723.318 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0080   15.0    15.0  |  0.977  0.781  (0.950) |  0.036  0.084  |  0.673  0.086  |  0.1072  0.1267  0.4390 | 7180.2777 



0.0080   15.0    15.0  |  0.973  0.783  (0.950) |  0.036  0.084  |  0.905  0.129  |  0.0955  0.0981  0.7112 | 7205.3796 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0077   16.0    16.0  |  0.973  0.783  (0.950) |  0.033  0.086  |  0.771  0.105  |  0.1081  0.1046  0.5584 | 7662.6973 



0.0077   16.0    16.0  |  1.193  0.781  (0.950) |  0.033  0.086  |  0.823  0.071  |  0.1080  0.1054  0.6101 | 7687.7664 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0074   17.0    17.0  |  1.193  0.781  (0.950) |  0.034  0.075  |  0.909  0.054  |  0.0867  0.0765  0.7454 | 8145.0443 



0.0074   17.0    17.0  |  0.791  0.783  (0.950) |  0.034  0.075  |  1.640  0.089  |  0.1366  0.1448  1.3590 | 8170.1206 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0072   18.0    18.0  |  0.791  0.783  (0.950) |  0.032  0.075  |  0.789  0.028  |  0.0718  0.0582  0.6590 | 8627.266 



0.0072   18.0    18.0  |  0.780  0.781  (0.950) |  0.032  0.075  |  1.151  0.047  |  0.1388  0.1479  0.8640 | 8651.7243 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0069   19.0    19.0  |  0.780  0.781  (0.950) |  0.030  0.082  |  0.851  0.109  |  0.0752  0.0691  0.7064 | 9108.9756 



0.0069   19.0    19.0  |  0.833  0.783  (0.950) |  0.030  0.082  |  0.718  0.050  |  0.1104  0.1244  0.4837 | 9134.0539 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0066   20.0    20.0  |  0.833  0.783  (0.950) |  0.030  0.075  |  0.963  0.140  |  0.0962  0.0927  0.7740 | 9591.2032 



0.0066   20.0    20.0  |  0.777  0.783  (0.950) |  0.030  0.075  |  0.808  0.073  |  0.0993  0.0939  0.6149 | 9615.7929 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0063   21.0    21.0  |  0.777  0.783  (0.950) |  0.030  0.072  |  0.894  0.098  |  0.0859  0.0777  0.7307 | 10073.0537 



0.0063   21.0    21.0  |  1.175  0.780  (0.950) |  0.030  0.072  |  0.914  0.054  |  0.1096  0.0896  0.7151 | 10098.1022 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0060   22.0    22.0  |  1.175  0.780  (0.950) |  0.033  0.081  |  0.983  0.091  |  0.1281  0.1168  0.7380 | 10555.4487 



0.0060   22.0    22.0  |  0.781  0.776  (0.950) |  0.033  0.081  |  1.422  0.191  |  0.1308  0.1272  1.1638 | 10580.4695 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0057   23.0    23.0  |  0.781  0.776  (0.950) |  0.029  0.071  |  1.173  0.044  |  0.1003  0.0968  0.9762 | 11037.441 



0.0057   23.0    23.0  |  0.944  0.782  (0.950) |  0.029  0.071  |  0.720  0.095  |  0.0877  0.0890  0.5428 | 11062.0034 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0054   24.0    24.0  |  0.944  0.782  (0.950) |  0.026  0.087  |  0.857  0.054  |  0.0971  0.1004  0.6593 | 11519.2155 



0.0054   24.0    24.0  |  0.853  0.778  (0.950) |  0.026  0.087  |  1.001  0.044  |  0.1304  0.1369  0.7335 | 11544.3219 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0050   25.0    25.0  |  0.853  0.778  (0.950) |  0.026  0.086  |  0.678  0.062  |  0.0751  0.0611  0.5420 | 12001.508 



0.0050   25.0    25.0  |  0.977  0.782  (0.950) |  0.026  0.086  |  0.797  0.108  |  0.1111  0.0860  0.5995 | 12025.9851 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0047   26.0    26.0  |  0.977  0.782  (0.950) |  0.028  0.085  |  1.332  0.062  |  0.1816  0.2283  0.9217 | 12483.3051 



0.0047   26.0    26.0  |  2.377  0.781  (0.950) |  0.028  0.085  |  0.759  0.150  |  0.1141  0.0927  0.5519 | 12507.746 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0044   27.0    27.0  |  2.377  0.781  (0.950) |  0.028  0.090  |  0.637  0.085  |  0.0763  0.0668  0.4935 | 12964.9639 



0.0044   27.0    27.0  |  1.175  0.783  (0.950) |  0.028  0.090  |  0.927  0.111  |  0.1287  0.1147  0.6837 | 12989.5504 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0041   28.0    28.0  |  1.175  0.783  (0.950) |  0.026  0.080  |  0.775  0.135  |  0.1109  0.1212  0.5431 | 13446.5879 



0.0041   28.0    28.0  |  0.734  0.781  (0.950) |  0.026  0.080  |  0.710  0.086  |  0.0783  0.0704  0.5611 | 13471.1075 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0038   29.0    29.0  |  0.734  0.781  (0.950) |  0.028  0.071  |  0.872  0.122  |  0.1069  0.0957  0.6691 | 13928.2904 



0.0038   29.0    29.0  |  0.883  0.780  (0.950) |  0.028  0.071  |  0.840  0.062  |  0.1255  0.1269  0.5872 | 13952.7922 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0035   30.0    30.0  |  0.883  0.780  (0.950) |  0.026  0.069  |  0.832  0.128  |  0.1147  0.1244  0.5925 | 14409.7599 



0.0035   30.0    30.0  |  0.736  0.780  (0.950) |  0.026  0.069  |  0.729  0.065  |  0.0937  0.0842  0.5515 | 14435.1581 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0032   31.0    31.0  |  0.736  0.780  (0.950) |  0.026  0.082  |  0.708  0.041  |  0.1134  0.1167  0.4779 | 14892.1401 



0.0032   31.0    31.0  |  0.723  0.778  (0.950) |  0.026  0.082  |  0.772  0.056  |  0.0696  0.0631  0.6389 | 14917.1902 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0029   32.0    32.0  |  0.723  0.778  (0.950) |  0.027  0.080  |  1.175  0.030  |  0.2013  0.2918  0.6816 | 15374.3237 



0.0029   32.0    32.0  |  0.890  0.780  (0.950) |  0.027  0.080  |  1.078  0.125  |  0.1327  0.1273  0.8181 | 15399.4954 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0027   33.0    33.0  |  0.890  0.780  (0.950) |  0.027  0.092  |  0.684  0.030  |  0.0717  0.0665  0.5457 | 15856.5381 



0.0027   33.0    33.0  |  0.906  0.779  (0.950) |  0.027  0.092  |  0.897  0.136  |  0.1218  0.1198  0.6557 | 15881.074 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0024   34.0    34.0  |  0.906  0.779  (0.950) |  0.026  0.078  |  0.817  0.058  |  0.1496  0.1129  0.5550 | 16338.3856 



0.0024   34.0    34.0  |  0.821  0.778  (0.950) |  0.026  0.078  |  0.648  0.117  |  0.1012  0.0794  0.4671 | 16363.5113 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0021   35.0    35.0  |  0.821  0.778  (0.950) |  0.028  0.082  |  0.768  0.110  |  0.1165  0.1152  0.5361 | 16820.7985 



0.0021   35.0    35.0  |  0.753  0.780  (0.950) |  0.028  0.082  |  0.839  0.081  |  0.1171  0.1220  0.5997 | 16845.2785 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0019   36.0    36.0  |  0.753  0.780  (0.950) |  0.027  0.076  |  0.873  0.196  |  0.0987  0.0882  0.6857 | 17302.4702 



0.0019   36.0    36.0  |  0.750  0.779  (0.950) |  0.027  0.076  |  0.779  0.110  |  0.0895  0.0802  0.6091 | 17327.654 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0016   37.0    37.0  |  0.750  0.779  (0.950) |  0.025  0.096  |  0.681  0.125  |  0.1513  0.1728  0.3565 | 17784.7805 



0.0016   37.0    37.0  |  0.765  0.779  (0.950) |  0.025  0.096  |  0.958  0.077  |  0.1109  0.1036  0.7435 | 17809.8008 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0014   38.0    38.0  |  0.765  0.779  (0.950) |  0.028  0.094  |  0.770  0.073  |  0.0899  0.0735  0.6065 | 18267.1257 



0.0014   38.0    38.0  |  0.736  0.779  (0.950) |  0.028  0.094  |  0.535  0.027  |  0.0645  0.0678  0.4024 | 18291.7306 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0012   39.0    39.0  |  0.736  0.779  (0.950) |  0.027  0.107  |  0.766  0.016  |  0.1221  0.1265  0.5174 | 18748.9571 



0.0012   39.0    39.0  |  0.727  0.779  (0.950) |  0.027  0.107  |  0.805  0.028  |  0.0764  0.0737  0.6546 | 18773.6531 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0010   40.0    40.0  |  0.727  0.779  (0.950) |  0.028  0.094  |  1.123  0.080  |  0.0945  0.0818  0.9468 | 19230.78 



0.0010   40.0    40.0  |  0.844  0.779  (0.950) |  0.028  0.094  |  0.651  0.078  |  0.0980  0.0737  0.4789 | 19255.8057 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0009   41.0    41.0  |  0.844  0.779  (0.950) |  0.025  0.096  |  0.991  0.104  |  0.1531  0.1800  0.6578 | 19713.0729 



0.0009   41.0    41.0  |  0.776  0.779  (0.950) |  0.025  0.096  |  0.887  0.062  |  0.1170  0.1052  0.6650 | 19738.1352 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0007   42.0    42.0  |  0.776  0.779  (0.950) |  0.027  0.079  |  0.841  0.042  |  0.0783  0.0670  0.6955 | 20195.3825 



0.0007   42.0    42.0  |  0.750  0.779  (0.950) |  0.027  0.079  |  0.618  0.030  |  0.1048  0.1107  0.4029 | 20220.7077 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0006   43.0    43.0  |  0.750  0.779  (0.950) |  0.025  0.081  |  0.611  0.081  |  0.0709  0.0595  0.4806 | 20677.9765 



0.0006   43.0    43.0  |  0.766  0.779  (0.950) |  0.025  0.081  |  0.847  0.073  |  0.0724  0.0594  0.7150 | 20702.4379 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0004   44.0    44.0  |  0.766  0.779  (0.950) |  0.028  0.085  |  0.930  0.082  |  0.0855  0.0796  0.7647 | 21159.4342 



0.0004   44.0    44.0  |  0.861  0.778  (0.950) |  0.028  0.085  |  0.821  0.138  |  0.0684  0.0580  0.6941 | 21184.5886 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0003   45.0    45.0  |  0.861  0.778  (0.950) |  0.026  0.090  |  0.680  0.084  |  0.0754  0.0633  0.5413 | 21641.8348 



0.0003   45.0    45.0  |  0.871  0.778  (0.950) |  0.026  0.090  |  0.900  0.142  |  0.1135  0.1079  0.6789 | 21666.4877 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0002   46.0    46.0  |  0.871  0.778  (0.950) |  0.025  0.088  |  1.315  0.059  |  0.1215  0.1289  1.0647 | 22123.7795 



0.0002   46.0    46.0  |  0.865  0.779  (0.950) |  0.025  0.088  |  0.893  0.061  |  0.1511  0.1279  0.6135 | 22148.929 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0002   47.0    47.0  |  0.865  0.779  (0.950) |  0.025  0.090  |  0.747  0.138  |  0.1140  0.1239  0.5090 | 22606.2795 



0.0002   47.0    47.0  |  0.773  0.779  (0.950) |  0.025  0.090  |  0.800  0.082  |  0.1527  0.1429  0.5049 | 22630.8849 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001   48.0    48.0  |  0.773  0.779  (0.950) |  0.025  0.102  |  0.992  0.129  |  0.1346  0.1299  0.7271 | 23088.2832 



0.0001   48.0    48.0  |  0.867  0.778  (0.950) |  0.025  0.102  |  1.207  0.176  |  0.1476  0.1649  0.8943 | 23113.2784 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001   49.0    49.0  |  0.867  0.778  (0.950) |  0.026  0.094  |  0.714  0.144  |  0.1224  0.1066  0.4850 | 23570.6906 



0.0001   49.0    49.0  |  0.811  0.779  (0.950) |  0.026  0.094  |  0.729  0.118  |  0.1143  0.1112  0.5032 | 23595.7471 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001   50.0    50.0  |  0.811  0.779  (0.950) |  0.026  0.092  |  1.049  0.089  |  0.0894  0.0832  0.8760 | 24053.3734 



0.0001   50.0    50.0  |  0.899  0.780  (0.950) |  0.026  0.092  |  0.806  0.036  |  0.0813  0.0821  0.6429 | 24078.6106 

In [14]:
torch.save(net.state_dict(), 'clsbes-cp-2.pth')

In [None]:
fit_one_cycle(
    50, net,
    train_loader, val_loader,
    lr_init=lr/20, lr_min=lr/250,
    weights=None
)

In [None]:
torch.save(net.state_dict(), 'clsbes-cp-3.pth')

In [None]:
net.load_state_dict(torch.load('clsbes-cp-3.pth'))

## find best threshold

In [10]:
train_loader, val_loader = get_dataloaders(
    data_root,
    batch_size,
    0, nfolds=10,
    width=SIZE, height=SIZE,
    train_augment=train_augment, val_augment=valid_augment,
    random_state=SEED
)

net = Net().cuda()

In [14]:
net.load_state_dict(torch.load('bes-cp-3.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [15]:
# TTA with flip-lr

def tta(net, loader):
    net.set_mode('test')
    
    all_probs = []
    all_truths = []
    
    with torch.no_grad():
        with tqdm(total=len(loader), file=sys.stdout) as pbar:
            for input, truth, index, cache in loader:
                input = input.cuda()

                logits = net(input)
                probs = F.sigmoid(logits)

                all_probs.append(probs)
                all_truths.append(truth)

                pbar.update(1)
            
        i = 0
        with tqdm(total=len(loader), file=sys.stdout) as pbar:
            # tensor.shape = [N, C, H, W], so we flip dim -1
            for input, truth, index, cache in loader:
                input = input.cuda()
                input = torch.flip(input, [-1])

                logits = net(input)
                probs = F.sigmoid(logits)

                probs = torch.flip(probs, [-1])

                all_probs[i] += probs
                all_probs[i] *= 0.5

                i += 1
                pbar.update(1)

        all_probs = torch.cat(all_probs, dim=0)
        all_truths = torch.cat(all_truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    return all_probs, all_truths

In [21]:
# a fake tta!

def fake_tta(net, loader):
    net.set_mode('test')
    
    all_m_probs = []
    all_b_probs = []
    all_truths = []
    
    with torch.no_grad():
        with tqdm(total=len(loader), file=sys.stdout) as pbar:
            for input, truth, index, cache in loader:
                input = input.cuda()

                m_logits, b_logits = net(input)
                m_probs = F.sigmoid(m_logits)
                b_probs = F.sigmoid(b_logits)

                all_m_probs.append(m_probs)
                all_b_probs.append(b_probs)
                all_truths.append(truth)

                pbar.update(1)
                
                if 0: #debug  ##-------------------------------
                    images = input.cpu().detach().numpy().squeeze()
                    masks  = truth.cpu().detach().numpy().squeeze()
                    
                    m_results = m_probs.cpu().detach().numpy().squeeze()
                    b_results = b_probs.cpu().detach().numpy().squeeze()
                    
                    batch_size = len(index)
                    
                    for b in range(batch_size):
                        image = images[b]*255
                        image = np.dstack([image,image,image])

                        mask = masks[b]
                        m_result = m_results[b]
                        b_result = b_results[b]
                        
                        # Plot some samples
                        fig, (ax0, ax1, ax2, ax3) = plt.subplots(ncols=4, figsize=(16, 4))
                        
                        ax0.imshow(image.astype(np.uint8))
                        
                        ax1.imshow(mask, vmin=0, vmax=1)
                        ax1.set_title('truth')
                        
                        ax2.imshow(m_result, vmin=0, vmax=1)
                        ax2.set_title('m_prob')
                        
                        ax3.imshow(b_result, vmin=0, vmax=1)
                        ax3.set_title('b_prob')

                        plt.show()
                #--------------------------------------

        all_m_probs = torch.cat(all_m_probs, dim=0)
        all_b_probs = torch.cat(all_b_probs, dim=0)
        all_truths = torch.cat(all_truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    return all_m_probs, all_b_probs, all_truths

In [17]:
#dice for threshold selection

def dice_overall(preds, targs):
    n = preds.shape[0]
    
    preds = preds.view(n, -1)
    targs = targs.view(n, -1)
    
    targs = (targs > 0.5).long()
    
    intersect = (preds * targs).sum(-1).float()
    union = (preds + targs).sum(-1).float()
    
    # get 1 for both empty pred and targ
    u0 = union==0
    intersect[u0] = 1
    union[u0] = 2
    
    return (2. * intersect / union)

### eval

In [23]:
m_probs, b_probs, truths = fake_tta(net, val_loader)

# noise removal
m_probs[m_probs.view(m_probs.shape[0], -1).sum(-1) < 75*(SIZE/128.0)**2,...] = 0.0

# search best threshold for this fold
scores, best_thrs = [],[]
dices = []
thrs = np.arange(0.01, 1, 0.01)

m_probs = m_probs.cuda()
truths = truths.cuda()

with tqdm(total=len(thrs), file=sys.stdout) as pbar:
    for th in thrs:
        preds = (m_probs>th).long()
        dices.append(dice_overall(preds, truths).mean())

        pbar.update(1)

dices = np.array(dices)    

# save best
scores.append(dices.max())
best_thrs.append(thrs[dices.argmax()])

print(dices, scores, best_thrs)

  0%|          | 0/134 [00:00<?, ?it/s]



  1%|          | 1/134 [00:00<01:05,  2.02it/s]



100%|██████████| 134/134 [00:15<00:00,  8.68it/s]
tensor([ 68.6168, 188.8050,  76.6630,  ...,  37.3683, 405.8776,  53.5727],
       device='cuda:0')
100%|██████████| 99/99 [00:01<00:00, 58.47it/s]
[tensor(0.6575, device='cuda:0') tensor(0.6616, device='cuda:0')
 tensor(0.6641, device='cuda:0') tensor(0.6658, device='cuda:0')
 tensor(0.6672, device='cuda:0') tensor(0.6682, device='cuda:0')
 tensor(0.6692, device='cuda:0') tensor(0.6700, device='cuda:0')
 tensor(0.6707, device='cuda:0') tensor(0.6714, device='cuda:0')
 tensor(0.6720, device='cuda:0') tensor(0.6725, device='cuda:0')
 tensor(0.6730, device='cuda:0') tensor(0.6735, device='cuda:0')
 tensor(0.6740, device='cuda:0') tensor(0.6744, device='cuda:0')
 tensor(0.6748, device='cuda:0') tensor(0.6752, device='cuda:0')
 tensor(0.6755, device='cuda:0') tensor(0.6759, device='cuda:0')
 tensor(0.6762, device='cuda:0') tensor(0.6765, device='cuda:0')
 tensor(0.6768, device='cuda:0') tensor(0.6771, device='cuda:0')
 tensor(0.6774, device=