# BESNet with classification

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from datetime import datetime
import sys
import gc
sys.path.append('../../')

from sklearn.model_selection import KFold
from tqdm import tqdm

from dependencies import *
from settings import *
from reproducibility import *
from models.TGS_salt.Cls_BesNet import ClsBesNet as Net

Using paths on kail-main

Importing numerical libraries...
Importing standard libraries...
Importing miscellaneous functions...
Importing constants...
Importing Neural Network dependencies...
	PyTorch
	Keras
	TensorFlow
	Metrics, Losses and LR Schedulers
	Kaggle Metrics
	Image augmentations
	Datasets
Importing external libraries...
	Lovasz Losses (elu+1)

Fixing random seed for reproducibility...
	Setting random seed to 35202.

Setting CUDA environment...
	torch.__version__              = 1.1.0
	torch.version.cuda             = 9.0.176
	torch.backends.cudnn.version() = 7501
	os['CUDA_VISIBLE_DEVICES']     = 0,1
	torch.cuda.device_count()      = 2



In [2]:
SIZE = 256
FACTOR = SIZE
ne = "ne"
initial_checkpoint = None
MODEL = "ResNet34"

batch_size = 8
n_acc = 256 / batch_size
nfolds = 4

noise_th = 75.0*(SIZE/128.0)**2 #threshold for the number of predicted pixels
best_thr0 = 0.2 #preliminary value of the threshold for metric calculation

data_root = '../../data/siim-pneumothorax'
torch.cuda.set_device('cuda:1')

In [3]:
def time_to_str(time, str):
    #if str == 'min':
    #	    return str(round(float(time)/60,5))+" min(s)"
    return round(time,4)

In [4]:
#TODO: Instead of directly printing to stdout, copy it into a txt file
class Logger():
    def __init__(self,name=MODEL+ne):
        super().__init__()
        self.model=name
        #if OHEM != "OHEM":
        #    self.model=MODEL+ne[ne.find("_")+1:]
        self.file = open(self.model+"_clsbes_log.txt","w+")
        self.file.close()
        
        self.debug_file = open(self.model + '_clsbes_debug.txt', 'w+')
        self.debug_file.close()
    def write(self, str):
        print(str)
        self.file = open(self.model+"_clsbes_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def write2(self, str):
        print(str, end='',flush=True)
        self.file = open(self.model+"_clsbes_log.txt","a+")
        self.file.write(str)
        self.file.close()
    def debug(self, str):
        self.debug_file = open(self.model + '_clsbes_debug.txt', 'a+')
        self.debug_file.write(str)
        self.debug_file.close()
    def stop():
        self.file.close()
        self.debug_file.close()
        
log = Logger()

In [5]:
def valid_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    # image, mask = do_resize2(image, mask, SIZE, SIZE)
    # image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

def train_augment(image,mask,index):
    cache = Struct(image = image.copy(), mask = mask.copy())

    if np.random.rand() < 0.5:
         image, mask = do_horizontal_flip2(image, mask)
         pass

    if np.random.rand() < 0.2:
        c = np.random.choice(4)
        if c==0:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.1) #0.125

        if c==1:
            image, mask = do_horizontal_shear2( image, mask, dx=np.random.uniform(-0.02,0.02) )
            pass

        if c==2:
            image, mask = do_shift_scale_rotate2( image, mask, dx=0, dy=0, scale=1, angle=np.random.uniform(0,15))  #10

        if c==3:
            image, mask = do_elastic_transform2(image, mask, grid=10, distort=np.random.uniform(0,0.05))#0.10
            pass
    if np.random.rand() < 0.1:
        c = np.random.choice(3)
        if c==0:
            image = do_brightness_shift(image,np.random.uniform(-0.1,+0.1))
        if c==1:
            image = do_brightness_multiply(image,np.random.uniform(1-0.08,1+0.08))
        if c==2:
            image = do_gamma(image,np.random.uniform(1-0.08,1+0.08))
        # if c==1:
        #     image = do_invert_intensity(image)

    # image, mask = do_resize2(image, mask, SIZE, SIZE)
    # image, mask = do_center_pad_to_factor2(image, mask, factor = FACTOR)
    return image,mask,index,cache

In [6]:
def null_augment(image, mask, index):
    cache = Struct(image = image.copy(), mask = mask.copy())
    return image, mask, index, cache

def null_collate(batch):

    batch_size = len(batch)
    cache = []
    input = []
    truth = []
    index = []
    for b in range(batch_size):
        input.append(batch[b][0])
        truth.append(batch[b][1])
        index.append(batch[b][2])
        cache.append(batch[b][3])
    input = torch.from_numpy(np.array(input)).float().unsqueeze(1)

    if truth[0]!=[]:
        truth = torch.from_numpy(np.array(truth)).float().unsqueeze(1)

    return input, truth, index, cache

def get_weights_for_balanced_classes(cls_list, num_classes):
    # get count per class
    count = [0] * num_classes
    
    for cls in cls_list:
        count[cls] += 1

    # get weight per class
    weight_per_class = [0.] * num_classes
    N = float(len(cls_list))
    
    for i in range(num_classes):
        weight_per_class[i] = N / float(count[i])
        
    #　get weight per sample
    weights = [0] * len(cls_list)
    
    for i, cls in enumerate(cls_list):
        weights[i] = weight_per_class[cls]
        
    return weights

def get_boundary(masks):
    mask_arr = (masks.cpu().numpy() * 255).astype(np.uint8).squeeze()
    b_arr = []
    
    for mask in mask_arr:
        b_img = np.zeros(mask.shape)
        
        contours, hier = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(b_img, contours, -1, 255, 1)
        
        b_arr.append(b_img)
        
    b_arr = np.stack(b_arr)
    
    return torch.from_numpy(b_arr)

def get_class(masks):
    n = masks.shape[0]
    classes = ((masks.view(n, -1).sum(-1, keepdim=True)).float() > 0.).float()
    
    return classes

class SIIMDataset(Dataset):
    def __init__(self, data_root, fold, pos_neg_ratio=0.5, width=1024, height=1024, phase='train', augment=null_augment, random_state=2019, nfolds=4):
        self.data_root = data_root
        self.fold = fold
        self.height = width
        self.width = height
        self.phase = phase
        self.augment = augment
        
        kf = KFold(n_splits=nfolds, shuffle=True, random_state=random_state)
        train_list = os.listdir(os.path.join(data_root, 'train_png'))
        
        if phase == 'train':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][0]
            self.filenames = [train_list[i] for i in index_list]
            
            # read masks for pos/neg ratio sampler
            train_df = pd.read_csv(os.path.join(self.data_root, 'train-rle.csv'))
            pos_ids = list(train_df[train_df[' EncodedPixels']!=' -1']['ImageId'])

            self.cls_list = [1 if filename.split('.png')[0] in pos_ids else 0 for filename in self.filenames]

        elif phase == 'val':
            index_list = list(kf.split(list(range(len(train_list)))))[fold][1]
            self.filenames = [train_list[i] for i in index_list]
        else: # test
            self.filenames = os.listdir(os.path.join(data_root, 'test_png'))

    def __getitem__(self, index):
        img_path = os.path.join(self.data_root, 'train_png/{}'.format(self.filenames[index]))
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        img = cv2.resize(img, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        if self.phase == 'test':
            mask = []
        else: # train and val
            mask_path = os.path.join(self.data_root, 'mask_png/{}'.format(self.filenames[index]))
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
            mask = cv2.resize(mask, (self.width, self.height), interpolation = cv2.INTER_AREA)
        
        return self.augment(img, mask, index)
    
    def __len__(self):
        return len(self.filenames)


In [7]:
def validation( net, valid_loader, weights=None ):

    valid_num  = 0
    valid_loss = np.zeros(3, np.float32)
    
    logits = []
    truths = []
    for input, truth, index, cache in valid_loader:
        input = input.cuda()
        truth = truth.cuda()
        
        with torch.no_grad():
            b_masks = get_boundary(truth).float().cuda() / 255.
            classes = get_class(truth).cuda()
            
            m_logit, b_logit, c_logit = net(input) #data_parallel(net,input)
            
            b_loss = net.boundary_criterion(b_logit, b_masks, weights=weights)
            m_loss = net.mask_criterion(m_logit, b_logit, truth, b_masks, alpha=5., beta=0.2, weights=weights)
            c_loss = net.class_criterion(c_logit, classes).mean()
            
            # losses are already reduced
            loss = b_loss + m_loss + c_loss
            
            dice  = net.metric(m_logit, truth, noise_th=0, threshold=0, logger=log)
            
            logits.append(m_logit.cpu())
            truths.append(truth.cpu())

        batch_size = len(index)
        valid_loss += batch_size * np.array(( loss.item(), dice.item(), 0))
        valid_num += batch_size
        
    valid_loss /= valid_num
    
    # find out optimal thr and dice
    log.debug('\nscan\n')
    logits = torch.cat(logits, dim=0)
    truths = torch.cat(truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    thrs = np.arange(0.05, 1, 0.05)
    
    th_dices = []
    for th in thrs:
        th_dice = net.metric(logits, truths, noise_th=0, threshold=th, logger=log)
        th_dices.append(th_dice)
        
    th_dices = np.array(th_dices)
    best_dice = th_dices.max()
    best_thr = thrs[th_dices.argmax()]
    
    valid_loss[1] = best_dice
    valid_loss[2] = best_thr
    
    gc.collect()
    torch.cuda.empty_cache()
        
    return valid_loss

In [8]:
def freeze(net):
    for p in net.conv1.parameters():
        p.requires_grad = False
        
    for p in net.encoder2.parameters():
        p.requires_grad = False
        
    for p in net.encoder3.parameters():
        p.requires_grad = False
        
    for p in net.encoder4.parameters():
        p.requires_grad = False
        
    for p in net.encoder5.parameters():
        p.requires_grad = False
        
    for p in net.center.parameters():
        p.requires_grad = False
        
    for p in net.decoder5.parameters():
        p.requires_grad = False
        
    for p in net.decoder4.parameters():
        p.requires_grad = False
        
    for p in net.decoder3.parameters():
        p.requires_grad = False
        
    for p in net.decoder2.parameters():
        p.requires_grad = False
        
    for p in net.decoder1.parameters():
        p.requires_grad = False
        
def unfreeze(net):
    for p in net.conv1.parameters():
        p.requires_grad = True
        
    for p in net.encoder2.parameters():
        p.requires_grad = True
        
    for p in net.encoder3.parameters():
        p.requires_grad = True
        
    for p in net.encoder4.parameters():
        p.requires_grad = True
        
    for p in net.encoder5.parameters():
        p.requires_grad = True
        
    for p in net.center.parameters():
        p.requires_grad = True
        
    for p in net.decoder5.parameters():
        p.requires_grad = True
        
    for p in net.decoder4.parameters():
        p.requires_grad = True
        
    for p in net.decoder3.parameters():
        p.requires_grad = True
        
    for p in net.decoder2.parameters():
        p.requires_grad = True
        
    for p in net.decoder1.parameters():
        p.requires_grad = True

def cosine_annealing_scheduler(num_iter, lr_init, lr_min):
    scheduler = lambda x: ((lr_init-lr_min)/2)*(np.cos(PI*(np.mod(x,num_iter)/(num_iter)))+1)+lr_min
    return scheduler
        
def set_BN_momentum(model, momentum=0.1*batch_size/64):
    for i, (name, layer) in enumerate(model.named_modules()):
        if isinstance(layer, nn.BatchNorm2d) or isinstance(layer, nn.BatchNorm1d):
            layer.momentum = momentum
            
def fit_one_cycle(epochs, net, train_loader, val_loader, lr_init=0.001, lr_min=0.000001, weights=None):
    # init learner
    iter_per_epoch = len(train_loader)
    num_iter = iter_per_epoch * epochs
    iter_smooth = 20
    iter_log    = 100
    iter_valid  = iter_per_epoch
    #iter_valid = 100
    
    #scheduler = None
    scheduler = cosine_annealing_scheduler(num_iter, lr_init, lr_min)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
          lr=lr_init, momentum=0.9, weight_decay=0.0001
    )
    set_BN_momentum(net)
    
    start_iter = 0
    start_epoch= 0
    train_loss  = np.zeros(6,np.float32)
    valid_loss  = np.zeros(6,np.float32)
    batch_loss  = np.zeros(6,np.float32)
    rate = 0
    iter = 0
    epoch = 0
    
    #debug
    if 0: #debug  ##-------------------------------
        debug_num = 2
        debug_count = 0
        
        for input, truth, index, cache in train_loader:
            images = input.cpu().data.numpy().squeeze()
            masks  = truth.cpu().data.numpy().squeeze()
            
            batch_size = len(index)
            for b in range(batch_size):
                image = images[b]*255
                image = np.dstack([image,image,image])

                mask = masks[b]
                print(np.max(mask))
                
                # Plot some samples
                fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(12, 4))
                ax0.imshow(image.astype(np.uint8))
                ax1.imshow(mask, vmin=0, vmax=1)
                ax1.set_title('Targets')
                
                plt.show()
                
            debug_count += 1
            if debug_count > debug_num:
                break
    #--------------------------------------
    
    start = timer()
    while iter < num_iter:  # loop over the dataset multiple times
        sum_train_loss = np.zeros(6,np.float32)
        sum = 0

        log.write('\n rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          \n')
        log.write('-------------------------------------------------------------------------------------------------------------------------------\n')
            
        for input, truth, index, cache in train_loader:
            # validation
            if (iter + 1) % iter_valid == 0:
                log.debug('\nval\n')
                net.set_mode('valid')
                valid_loss = validation(net, val_loader)

                net.set_mode('train')
                log.debug('\ntrain\n')
                time.sleep(0.01)
            
            if scheduler is not None:
                lr = scheduler(iter)
                if lr<0 : break
                adjust_learning_rate(optimizer, lr)
                rate = get_learning_rate(optimizer)
            
            # ok, train
            net.set_mode('train')

            input = input.cuda().float()
            truth = truth.cuda().float()
            
            b_masks = (get_boundary(truth).float().cuda() / 255.).float()
            classes = get_class(truth).cuda()

            m_logit, b_logit, c_logit = net(input) #data_parallel(net,input)

            b_loss = net.boundary_criterion(b_logit, b_masks, weights=weights)
            m_loss = net.mask_criterion(m_logit, b_logit.detach(), truth, b_masks.detach(), alpha=5., beta=0.2, weights=weights)
            c_loss = net.class_criterion(c_logit, classes).mean()
            
            # losses are already reduced
            loss = b_loss + m_loss + c_loss
            loss_rec = loss.item()
                
            dice = net.metric(m_logit, truth, noise_th=0, threshold=0, logger=log)
            
            # learn with grad acc
            loss /= n_acc
            loss.backward()
            
            if ((iter + 1) % n_acc) == 0:
                optimizer.step()
                optimizer.zero_grad()
                # torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            
            # print statistics  ------------
            batch_loss = np.array((
                           loss.item(),
                           dice.item(),
                           0, 0, 0, 0,
                         ))
            sum_train_loss += batch_loss
            sum += 1
            if iter%iter_smooth == 0:
                train_loss = sum_train_loss/sum
                sum_train_loss = np.zeros(6,np.float32)
                sum = 0

            log.write2('\r%0.4f  %5.1f  %6.1f  |  %0.3f  %0.3f  (%0.3f) |  %0.3f  %0.3f  |  %0.3f  %0.3f  |  %0.4f  %0.4f  %0.4f | %s ' % (\
                         rate, iter/iter_per_epoch, epoch+1,
                         valid_loss[0], valid_loss[1], valid_loss[2],
                         train_loss[0], train_loss[1],
                         loss_rec, batch_loss[1],
                         b_loss.item(), m_loss.item(), c_loss.item(),
                         time_to_str((timer() - start), 'min')))
            
            iter += 1
            epoch = iter // iter_per_epoch

In [9]:
def get_dataloaders(data_root, batch_size, fold, nfolds=4, width=1024, height=1024, train_augment=null_augment, val_augment=null_augment, random_state=SEED):
    train_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='train',
        augment=train_augment,
        random_state=random_state,
        nfolds=nfolds
    )
    
    weights = get_weights_for_balanced_classes(train_dataset.cls_list, 2)
    weights = torch.DoubleTensor(weights)
    balance_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    train_loader  = DataLoader(
        train_dataset,
        # sampler     = RandomSampler(train_dataset),
        sampler = balance_sampler,
        batch_size  = batch_size,
        drop_last   = True,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )

    val_dataset = SIIMDataset(
        data_root,
        fold,
        width=width, height=height,
        phase='val',
        augment=val_augment,
        random_state=random_state,
        nfolds=nfolds
    )

    val_loader  = DataLoader(
        val_dataset,
        sampler     = RandomSampler(val_dataset),
        batch_size  = batch_size,
        drop_last   = False,
        num_workers = 8,
        pin_memory  = True,
        collate_fn  = null_collate
    )
    
    return train_loader, val_loader

## Train

In [10]:
# one fold test!

train_loader, val_loader = get_dataloaders(
    data_root,
    batch_size,
    0, nfolds=10,
    width=SIZE, height=SIZE,
    train_augment=train_augment, val_augment=valid_augment,
    random_state=SEED
)

net = Net().cuda()

lr = 1e-2

In [None]:
# warm up
freeze(net)
fit_one_cycle(
    2, net,
    train_loader, val_loader,
    lr_init=lr, lr_min=lr/100,
    weights=[0.05, 0.95]
)

In [12]:
torch.save(net.state_dict(), 'clsbes-cp-1.pth')

In [11]:
net.load_state_dict(torch.load('clsbes-cp-1.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [None]:
# ok, train positive
unfreeze(net)
fit_one_cycle(
    50, net,
    train_loader, val_loader,
    lr_init=lr, lr_min=lr/120,
    weights=[0.05, 0.95]
)

In [14]:
torch.save(net.state_dict(), 'clsbes-cp-2.pth')

In [12]:
net.load_state_dict(torch.load('clsbes-cp-2.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [13]:
lr = 0.01

fit_one_cycle(
    50, net,
    train_loader, val_loader,
    lr_init=lr, lr_min=lr/120,
    weights=[0.01, 0.99]
)


 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0100    1.0     1.0  |  0.000  0.000  (0.000) |  0.028  0.061  |  0.786  0.060  |  0.0397  0.0390  0.7070 | 456.3055 



0.0100    1.0     1.0  |  1.069  0.783  (0.850) |  0.028  0.061  |  0.737  0.033  |  0.0471  0.0490  0.6406 | 479.263 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0100    2.0     2.0  |  1.069  0.783  (0.850) |  0.025  0.086  |  0.486  0.109  |  0.0302  0.0268  0.4286 | 933.6061 



0.0100    2.0     2.0  |  1.464  0.783  (0.900) |  0.025  0.086  |  0.708  0.063  |  0.0409  0.0355  0.6312 | 957.7763 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0099    3.0     3.0  |  1.464  0.783  (0.900) |  0.026  0.092  |  0.642  0.028  |  0.0296  0.0327  0.5797 | 1412.1341 



0.0099    3.0     3.0  |  1.476  0.783  (0.950) |  0.026  0.092  |  0.744  0.156  |  0.0347  0.0285  0.6805 | 1436.2791 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0098    4.0     4.0  |  1.476  0.783  (0.950) |  0.023  0.075  |  0.969  0.077  |  0.0380  0.0400  0.8909 | 1890.6757 



0.0098    4.0     4.0  |  2.492  0.783  (0.950) |  0.023  0.075  |  1.346  0.137  |  0.0438  0.0407  1.2614 | 1914.8744 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0098    5.0     5.0  |  2.492  0.783  (0.950) |  0.032  0.055  |  0.688  0.033  |  0.0314  0.0404  0.6167 | 2369.4901 



0.0098    5.0     5.0  |  1.003  0.783  (0.900) |  0.032  0.055  |  1.352  0.062  |  0.0341  0.0390  1.2790 | 2393.4872 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0097    6.0     6.0  |  1.003  0.783  (0.900) |  0.026  0.071  |  0.932  0.051  |  0.0467  0.0602  0.8247 | 2847.8056 



0.0097    6.0     6.0  |  0.845  0.783  (0.900) |  0.026  0.071  |  0.738  0.013  |  0.0437  0.0473  0.6473 | 2871.9063 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0095    7.0     7.0  |  0.845  0.783  (0.900) |  0.023  0.052  |  0.830  0.033  |  0.0431  0.0440  0.7427 | 3326.08 



0.0095    7.0     7.0  |  0.965  0.783  (0.800) |  0.023  0.052  |  0.831  0.121  |  0.0404  0.0447  0.7457 | 3350.2718 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0094    8.0     8.0  |  0.965  0.783  (0.800) |  0.039  0.087  |  1.158  0.115  |  0.0344  0.0295  1.0943 | 3804.5907 



0.0094    8.0     8.0  |  0.713  0.783  (0.850) |  0.039  0.087  |  0.500  0.029  |  0.0316  0.0361  0.4318 | 3828.7586 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0092    9.0     9.0  |  0.713  0.783  (0.850) |  0.021  0.075  |  0.818  0.028  |  0.0343  0.0344  0.7489 | 4283.2854 



0.0092    9.0     9.0  |  1.099  0.782  (0.850) |  0.021  0.075  |  0.885  0.074  |  0.0297  0.0287  0.8269 | 4307.4317 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0091   10.0    10.0  |  1.099  0.782  (0.850) |  0.025  0.066  |  0.714  0.155  |  0.0373  0.0311  0.6458 | 4761.3263 



0.0091   10.0    10.0  |  0.726  0.783  (0.900) |  0.025  0.066  |  0.872  0.064  |  0.0434  0.0424  0.7863 | 4785.4668 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0089   11.0    11.0  |  0.726  0.783  (0.900) |  0.026  0.074  |  0.557  0.080  |  0.0350  0.0318  0.4902 | 5238.9037 



0.0089   11.0    11.0  |  1.167  0.783  (0.950) |  0.026  0.074  |  0.730  0.068  |  0.0428  0.0385  0.6483 | 5262.6402 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0087   12.0    12.0  |  1.167  0.783  (0.950) |  0.020  0.075  |  0.750  0.089  |  0.0408  0.0372  0.6722 | 5716.1706 



0.0087   12.0    12.0  |  1.124  0.783  (0.900) |  0.020  0.075  |  0.697  0.087  |  0.0432  0.0402  0.6136 | 5739.9879 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0084   13.0    13.0  |  1.124  0.783  (0.900) |  0.029  0.079  |  0.521  0.110  |  0.0333  0.0308  0.4569 | 6193.8927 



0.0084   13.0    13.0  |  0.780  0.783  (0.900) |  0.029  0.079  |  1.466  0.126  |  0.0430  0.0467  1.3759 | 6217.5584 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0082   14.0    14.0  |  0.780  0.783  (0.900) |  0.025  0.089  |  1.254  0.005  |  0.0491  0.0510  1.1534 | 6671.289 



0.0082   14.0    14.0  |  1.372  0.782  (0.850) |  0.025  0.089  |  1.119  0.046  |  0.0256  0.0240  1.0695 | 6694.9936 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0080   15.0    15.0  |  1.372  0.782  (0.850) |  0.023  0.069  |  1.011  0.081  |  0.0438  0.0465  0.9205 | 7148.8706 



0.0080   15.0    15.0  |  1.278  0.783  (0.900) |  0.023  0.069  |  0.859  0.044  |  0.0286  0.0263  0.8037 | 7172.5123 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0077   16.0    16.0  |  1.278  0.783  (0.900) |  0.025  0.090  |  0.551  0.036  |  0.0382  0.0376  0.4752 | 7626.6106 



0.0077   16.0    16.0  |  0.714  0.782  (0.950) |  0.025  0.090  |  0.750  0.056  |  0.0398  0.0344  0.6759 | 7650.3787 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0074   17.0    17.0  |  0.714  0.782  (0.950) |  0.022  0.086  |  0.610  0.099  |  0.0335  0.0355  0.5405 | 8104.9352 



0.0074   17.0    17.0  |  0.939  0.783  (0.950) |  0.022  0.086  |  0.541  0.141  |  0.0297  0.0300  0.4808 | 8128.6206 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0072   18.0    18.0  |  0.939  0.783  (0.950) |  0.022  0.081  |  1.042  0.069  |  0.0368  0.0395  0.9654 | 8582.4452 



0.0072   18.0    18.0  |  0.728  0.783  (0.900) |  0.022  0.081  |  0.501  0.101  |  0.0276  0.0222  0.4511 | 8606.1684 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0069   19.0    19.0  |  0.728  0.783  (0.900) |  0.022  0.071  |  0.831  0.061  |  0.0304  0.0274  0.7728 | 9060.4632 



0.0069   19.0    19.0  |  0.787  0.782  (0.950) |  0.022  0.071  |  0.760  0.086  |  0.0417  0.0397  0.6785 | 9084.1434 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0066   20.0    20.0  |  0.787  0.782  (0.950) |  0.022  0.074  |  0.648  0.032  |  0.0230  0.0191  0.6062 | 9538.2991 



0.0066   20.0    20.0  |  0.706  0.782  (0.850) |  0.022  0.074  |  0.499  0.040  |  0.0419  0.0441  0.4129 | 9561.9641 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0063   21.0    21.0  |  0.706  0.782  (0.850) |  0.022  0.081  |  0.734  0.130  |  0.0251  0.0236  0.6852 | 10015.8587 



0.0063   21.0    21.0  |  0.712  0.783  (0.900) |  0.022  0.081  |  0.392  0.051  |  0.0356  0.0423  0.3137 | 10039.614 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0060   22.0    22.0  |  0.712  0.783  (0.900) |  0.021  0.082  |  0.506  0.148  |  0.0316  0.0307  0.4437 | 10493.3159 



0.0060   22.0    22.0  |  0.815  0.783  (0.950) |  0.021  0.082  |  0.652  0.066  |  0.0316  0.0323  0.5882 | 10517.0153 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0057   23.0    23.0  |  0.815  0.783  (0.950) |  0.023  0.063  |  0.336  0.122  |  0.0295  0.0273  0.2794 | 10971.0995 



0.0057   23.0    23.0  |  0.569  0.782  (0.950) |  0.023  0.063  |  0.488  0.062  |  0.0342  0.0268  0.4269 | 10994.7285 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0054   24.0    24.0  |  0.569  0.782  (0.950) |  0.022  0.080  |  0.635  0.073  |  0.0410  0.0388  0.5551 | 11447.9639 



0.0054   24.0    24.0  |  0.645  0.783  (0.900) |  0.022  0.080  |  0.784  0.152  |  0.0425  0.0483  0.6931 | 11471.6377 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0050   25.0    25.0  |  0.645  0.783  (0.900) |  0.021  0.074  |  0.660  0.051  |  0.0340  0.0351  0.5908 | 11924.7748 



0.0050   25.0    25.0  |  0.674  0.783  (0.950) |  0.021  0.074  |  0.448  0.109  |  0.0278  0.0253  0.3947 | 11948.5707 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0047   26.0    26.0  |  0.674  0.783  (0.950) |  0.018  0.081  |  0.894  0.059  |  0.0326  0.0288  0.8329 | 12401.4312 



0.0047   26.0    26.0  |  0.746  0.783  (0.900) |  0.018  0.081  |  0.867  0.032  |  0.0429  0.0526  0.7719 | 12425.1512 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0044   27.0    27.0  |  0.746  0.783  (0.900) |  0.020  0.078  |  0.431  0.088  |  0.0244  0.0205  0.3859 | 12878.146 



0.0044   27.0    27.0  |  0.753  0.783  (0.950) |  0.020  0.078  |  0.648  0.090  |  0.0367  0.0316  0.5794 | 12901.8974 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0041   28.0    28.0  |  0.753  0.783  (0.950) |  0.026  0.070  |  0.934  0.047  |  0.0519  0.0613  0.8204 | 13355.3961 



0.0041   28.0    28.0  |  0.640  0.783  (0.950) |  0.026  0.070  |  0.684  0.135  |  0.0355  0.0307  0.6179 | 13379.1666 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0038   29.0    29.0  |  0.640  0.783  (0.950) |  0.020  0.089  |  0.417  0.102  |  0.0271  0.0261  0.3634 | 13832.8721 



0.0038   29.0    29.0  |  1.268  0.783  (0.900) |  0.020  0.089  |  0.688  0.081  |  0.0405  0.0396  0.6080 | 13856.6659 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0035   30.0    30.0  |  1.268  0.783  (0.900) |  0.021  0.085  |  0.478  0.118  |  0.0344  0.0357  0.4081 | 14310.5715 



0.0035   30.0    30.0  |  0.730  0.783  (0.950) |  0.021  0.085  |  0.626  0.111  |  0.0276  0.0239  0.5748 | 14334.3856 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0032   31.0    31.0  |  0.730  0.783  (0.950) |  0.019  0.072  |  0.560  0.082  |  0.0357  0.0349  0.4894 | 14788.8592 



0.0032   31.0    31.0  |  0.753  0.783  (0.900) |  0.019  0.072  |  0.590  0.054  |  0.0379  0.0396  0.5128 | 14812.6728 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0029   32.0    32.0  |  0.753  0.783  (0.900) |  0.020  0.073  |  0.738  0.133  |  0.0343  0.0355  0.6678 | 15267.7427 



0.0029   32.0    32.0  |  0.605  0.782  (0.900) |  0.020  0.073  |  0.579  0.066  |  0.0313  0.0304  0.5169 | 15291.667 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0027   33.0    33.0  |  0.605  0.782  (0.900) |  0.023  0.069  |  0.395  0.037  |  0.0351  0.0398  0.3201 | 15747.3291 



0.0027   33.0    33.0  |  0.865  0.783  (0.850) |  0.023  0.069  |  0.681  0.074  |  0.0229  0.0199  0.6385 | 15771.3238 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0024   34.0    34.0  |  0.865  0.783  (0.850) |  0.019  0.069  |  0.573  0.030  |  0.0511  0.0668  0.4554 | 16226.4187 



0.0024   34.0    34.0  |  0.743  0.783  (0.850) |  0.019  0.069  |  0.691  0.113  |  0.0405  0.0467  0.6037 | 16250.1739 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0021   35.0    35.0  |  0.743  0.783  (0.850) |  0.020  0.068  |  0.510  0.037  |  0.0265  0.0253  0.4582 | 16705.3277 



0.0021   35.0    35.0  |  0.663  0.783  (0.900) |  0.020  0.068  |  0.844  0.131  |  0.0391  0.0385  0.7666 | 16729.3093 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0019   36.0    36.0  |  0.663  0.783  (0.900) |  0.018  0.078  |  0.449  0.059  |  0.0375  0.0292  0.3827 | 17184.0413 



0.0019   36.0    36.0  |  0.762  0.783  (0.900) |  0.018  0.078  |  0.416  0.088  |  0.0358  0.0295  0.3506 | 17207.8065 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0016   37.0    37.0  |  0.762  0.783  (0.900) |  0.021  0.082  |  0.642  0.095  |  0.0373  0.0392  0.5655 | 17662.2556 



0.0016   37.0    37.0  |  0.621  0.783  (0.900) |  0.021  0.082  |  0.660  0.074  |  0.0403  0.0457  0.5739 | 17685.9534 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0014   38.0    38.0  |  0.621  0.783  (0.900) |  0.020  0.068  |  0.611  0.174  |  0.0337  0.0317  0.5452 | 18140.2518 



0.0014   38.0    38.0  |  0.593  0.783  (0.950) |  0.020  0.068  |  0.491  0.113  |  0.0305  0.0282  0.4321 | 18163.9906 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0012   39.0    39.0  |  0.593  0.783  (0.950) |  0.018  0.085  |  0.411  0.079  |  0.0461  0.0528  0.3123 | 18618.1524 



0.0012   39.0    39.0  |  0.621  0.783  (0.850) |  0.018  0.085  |  0.685  0.052  |  0.0397  0.0405  0.6048 | 18642.0267 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0010   40.0    40.0  |  0.621  0.783  (0.850) |  0.018  0.078  |  0.613  0.081  |  0.0296  0.0265  0.5565 | 19096.1311 



0.0010   40.0    40.0  |  0.695  0.783  (0.900) |  0.018  0.078  |  0.338  0.035  |  0.0245  0.0239  0.2901 | 19120.0902 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0009   41.0    41.0  |  0.695  0.783  (0.900) |  0.020  0.075  |  0.698  0.016  |  0.0390  0.0394  0.6193 | 19574.0107 



0.0009   41.0    41.0  |  0.662  0.783  (0.900) |  0.020  0.075  |  0.344  0.024  |  0.0285  0.0294  0.2863 | 19597.6919 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0007   42.0    42.0  |  0.662  0.783  (0.900) |  0.022  0.078  |  0.848  0.073  |  0.0317  0.0278  0.7884 | 20052.1208 



0.0007   42.0    42.0  |  0.773  0.783  (0.900) |  0.022  0.078  |  0.345  0.057  |  0.0338  0.0284  0.2831 | 20075.8574 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0006   43.0    43.0  |  0.773  0.783  (0.900) |  0.018  0.081  |  0.805  0.100  |  0.0429  0.0476  0.7141 | 20529.3719 



0.0006   43.0    43.0  |  0.664  0.783  (0.900) |  0.018  0.081  |  0.675  0.055  |  0.0378  0.0375  0.6001 | 20553.3151 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0004   44.0    44.0  |  0.664  0.783  (0.900) |  0.019  0.082  |  0.551  0.041  |  0.0277  0.0250  0.4980 | 21007.1656 



0.0004   44.0    44.0  |  0.650  0.783  (0.900) |  0.019  0.082  |  0.442  0.026  |  0.0370  0.0396  0.3657 | 21031.5368 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0003   45.0    45.0  |  0.650  0.783  (0.900) |  0.017  0.073  |  0.344  0.104  |  0.0252  0.0210  0.2977 | 21486.0185 



0.0003   45.0    45.0  |  0.653  0.783  (0.900) |  0.017  0.073  |  0.784  0.075  |  0.0292  0.0247  0.7297 | 21509.7466 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0002   46.0    46.0  |  0.653  0.783  (0.900) |  0.019  0.083  |  0.599  0.063  |  0.0337  0.0333  0.5317 | 21964.3355 



0.0002   46.0    46.0  |  0.658  0.783  (0.900) |  0.019  0.083  |  0.458  0.147  |  0.0243  0.0213  0.4124 | 21988.1555 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0002   47.0    47.0  |  0.658  0.783  (0.900) |  0.021  0.069  |  0.469  0.081  |  0.0284  0.0264  0.4147 | 22443.395 



0.0002   47.0    47.0  |  0.668  0.783  (0.900) |  0.021  0.069  |  0.674  0.146  |  0.0348  0.0321  0.6072 | 22467.1899 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001   48.0    48.0  |  0.668  0.783  (0.900) |  0.018  0.077  |  1.249  0.055  |  0.0425  0.0441  1.1624 | 22921.7901 



0.0001   48.0    48.0  |  0.651  0.783  (0.900) |  0.018  0.077  |  0.696  0.039  |  0.0435  0.0428  0.6101 | 22945.5827 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001   49.0    49.0  |  0.651  0.783  (0.900) |  0.019  0.071  |  0.712  0.135  |  0.0368  0.0393  0.6358 | 23400.4856 



0.0001   49.0    49.0  |  0.653  0.783  (0.900) |  0.019  0.071  |  0.469  0.043  |  0.0443  0.0438  0.3809 | 23424.4456 
 rate    iter   epoch   | valid_loss               | train_loss               | batch_loss               |  time          

-------------------------------------------------------------------------------------------------------------------------------





0.0001   50.0    50.0  |  0.653  0.783  (0.900) |  0.019  0.095  |  0.666  0.096  |  0.0419  0.0428  0.5814 | 23879.7736 



0.0001   50.0    50.0  |  0.668  0.783  (0.900) |  0.019  0.095  |  0.881  0.108  |  0.0454  0.0531  0.7828 | 23903.4468 

In [None]:
torch.save(net.state_dict(), 'clsbes-cp-3.pth')

In [None]:
net.load_state_dict(torch.load('clsbes-cp-3.pth'))

## find best threshold

In [10]:
train_loader, val_loader = get_dataloaders(
    data_root,
    batch_size,
    0, nfolds=10,
    width=SIZE, height=SIZE,
    train_augment=train_augment, val_augment=valid_augment,
    random_state=SEED
)

net = Net().cuda()

In [14]:
net.load_state_dict(torch.load('bes-cp-3.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [15]:
# TTA with flip-lr

def tta(net, loader):
    net.set_mode('test')
    
    all_probs = []
    all_truths = []
    
    with torch.no_grad():
        with tqdm(total=len(loader), file=sys.stdout) as pbar:
            for input, truth, index, cache in loader:
                input = input.cuda()

                logits = net(input)
                probs = F.sigmoid(logits)

                all_probs.append(probs)
                all_truths.append(truth)

                pbar.update(1)
            
        i = 0
        with tqdm(total=len(loader), file=sys.stdout) as pbar:
            # tensor.shape = [N, C, H, W], so we flip dim -1
            for input, truth, index, cache in loader:
                input = input.cuda()
                input = torch.flip(input, [-1])

                logits = net(input)
                probs = F.sigmoid(logits)

                probs = torch.flip(probs, [-1])

                all_probs[i] += probs
                all_probs[i] *= 0.5

                i += 1
                pbar.update(1)

        all_probs = torch.cat(all_probs, dim=0)
        all_truths = torch.cat(all_truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    return all_probs, all_truths

In [21]:
# a fake tta!

def fake_tta(net, loader):
    net.set_mode('test')
    
    all_m_probs = []
    all_b_probs = []
    all_truths = []
    
    with torch.no_grad():
        with tqdm(total=len(loader), file=sys.stdout) as pbar:
            for input, truth, index, cache in loader:
                input = input.cuda()

                m_logits, b_logits = net(input)
                m_probs = F.sigmoid(m_logits)
                b_probs = F.sigmoid(b_logits)

                all_m_probs.append(m_probs)
                all_b_probs.append(b_probs)
                all_truths.append(truth)

                pbar.update(1)
                
                if 0: #debug  ##-------------------------------
                    images = input.cpu().detach().numpy().squeeze()
                    masks  = truth.cpu().detach().numpy().squeeze()
                    
                    m_results = m_probs.cpu().detach().numpy().squeeze()
                    b_results = b_probs.cpu().detach().numpy().squeeze()
                    
                    batch_size = len(index)
                    
                    for b in range(batch_size):
                        image = images[b]*255
                        image = np.dstack([image,image,image])

                        mask = masks[b]
                        m_result = m_results[b]
                        b_result = b_results[b]
                        
                        # Plot some samples
                        fig, (ax0, ax1, ax2, ax3) = plt.subplots(ncols=4, figsize=(16, 4))
                        
                        ax0.imshow(image.astype(np.uint8))
                        
                        ax1.imshow(mask, vmin=0, vmax=1)
                        ax1.set_title('truth')
                        
                        ax2.imshow(m_result, vmin=0, vmax=1)
                        ax2.set_title('m_prob')
                        
                        ax3.imshow(b_result, vmin=0, vmax=1)
                        ax3.set_title('b_prob')

                        plt.show()
                #--------------------------------------

        all_m_probs = torch.cat(all_m_probs, dim=0)
        all_b_probs = torch.cat(all_b_probs, dim=0)
        all_truths = torch.cat(all_truths, dim=0)
    
    gc.collect()
    torch.cuda.empty_cache()
    
    return all_m_probs, all_b_probs, all_truths

In [17]:
#dice for threshold selection

def dice_overall(preds, targs):
    n = preds.shape[0]
    
    preds = preds.view(n, -1)
    targs = targs.view(n, -1)
    
    targs = (targs > 0.5).long()
    
    intersect = (preds * targs).sum(-1).float()
    union = (preds + targs).sum(-1).float()
    
    # get 1 for both empty pred and targ
    u0 = union==0
    intersect[u0] = 1
    union[u0] = 2
    
    return (2. * intersect / union)

### eval

In [23]:
m_probs, b_probs, truths = fake_tta(net, val_loader)

# noise removal
m_probs[m_probs.view(m_probs.shape[0], -1).sum(-1) < 75*(SIZE/128.0)**2,...] = 0.0

# search best threshold for this fold
scores, best_thrs = [],[]
dices = []
thrs = np.arange(0.01, 1, 0.01)

m_probs = m_probs.cuda()
truths = truths.cuda()

with tqdm(total=len(thrs), file=sys.stdout) as pbar:
    for th in thrs:
        preds = (m_probs>th).long()
        dices.append(dice_overall(preds, truths).mean())

        pbar.update(1)

dices = np.array(dices)    

# save best
scores.append(dices.max())
best_thrs.append(thrs[dices.argmax()])

print(dices, scores, best_thrs)

  0%|          | 0/134 [00:00<?, ?it/s]



  1%|          | 1/134 [00:00<01:05,  2.02it/s]



100%|██████████| 134/134 [00:15<00:00,  8.68it/s]
tensor([ 68.6168, 188.8050,  76.6630,  ...,  37.3683, 405.8776,  53.5727],
       device='cuda:0')
100%|██████████| 99/99 [00:01<00:00, 58.47it/s]
[tensor(0.6575, device='cuda:0') tensor(0.6616, device='cuda:0')
 tensor(0.6641, device='cuda:0') tensor(0.6658, device='cuda:0')
 tensor(0.6672, device='cuda:0') tensor(0.6682, device='cuda:0')
 tensor(0.6692, device='cuda:0') tensor(0.6700, device='cuda:0')
 tensor(0.6707, device='cuda:0') tensor(0.6714, device='cuda:0')
 tensor(0.6720, device='cuda:0') tensor(0.6725, device='cuda:0')
 tensor(0.6730, device='cuda:0') tensor(0.6735, device='cuda:0')
 tensor(0.6740, device='cuda:0') tensor(0.6744, device='cuda:0')
 tensor(0.6748, device='cuda:0') tensor(0.6752, device='cuda:0')
 tensor(0.6755, device='cuda:0') tensor(0.6759, device='cuda:0')
 tensor(0.6762, device='cuda:0') tensor(0.6765, device='cuda:0')
 tensor(0.6768, device='cuda:0') tensor(0.6771, device='cuda:0')
 tensor(0.6774, device=