In [1]:
# ------------Library--------------#
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.sampler import *

import torch.optim as optim
from torch.optim.optimizer import Optimizer, required
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.nn.parallel.data_parallel import data_parallel
from torch.nn.utils.rnn import *
from torch.cuda.amp import autocast, GradScaler
from torch.autograd import Variable
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2, ToTensor
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

import tifffile as tiff
import json



#
import pandas as pd
import cv2
import os
import random
import numpy as np
import math
import sys
from collections import defaultdict
import itertools as it
from timeit import default_timer as timer
import matplotlib.pyplot as plt
#
from sklearn.model_selection import KFold
# loss
#from lovasz import lovasz_hinge
#from losses_pytorch.lovasz_loss import LovaszSoftmax
PI  = np.pi
INF = np.inf
EPS = 1e-12



In [2]:
class args:
    # ---- factor ---- #
    amp = True
    gpu = '4,5'
    encoder='b4'#'resnet34'
    decoder='unet'
    diff_arch = True
    encoders = ["efficientnet-b4", "efficientnet-b4", "efficientnet-b4", "efficientnet-b4", "efficientnet-b4"]
    decoders = ["unet", "unet", "unet", "unet", "unet"]
    
    batch_size=8
    weight_decay=1e-6
    epochs=50
    n_fold=5
    fold=0 # [0, 1, 2, 3, 4]
    all_fold_train = True # all fold training
    
    # ---- Dataset ---- #
    image_size=1024 # crop size
    crop_size=image_size
    
    tile_size = 1280
    tile_step = 640
    tile_scale = 0.5
    dataset = f'{tile_scale}_{tile_size}_{tile_step}_train_fold'#'0.25_320_160_train_fold'
    val_dataset = f'{tile_scale}_{tile_size}_{tile_size}_val_fold'
    if diff_arch:
        dir = f'{epochs}_{encoders}_{decoders}_{image_size}_{tile_size}_{tile_step}_{tile_scale}'
    else:
        dir = f'{epochs}_{encoder}_{decoder}_{image_size}_{tile_size}_{tile_step}_{tile_scale}' 
    # ---- optimizer, scheduler .. ---- #
    T_max=10 # CosineAnnealingLR
    opt =  'radam_look' # [adamw, radam_look]
    scheduler='CosineAnnealingLR' #'MultiStepLR' # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts']
    loss = 'bce' # [lovasz, bce, bce_dice, dice]
    factor=0.4 # ReduceLROnPlateau, MultiStepLR
    patience=3 # ReduceLROnPlateau
    eps=1e-6 # ReduceLROnPlateau
    
    decay_epoch = [4, 8, 12]
    T_0=4 # CosineAnnealingWarmRestarts
    #encoder_lr=4e-4
    #decoder_lr=4e-4
    start_lr = 1e-3
    min_lr=1e-6
    #----------------------------------#
    
    
    # ----- 여러 시도 ------#
    clf_head=False # encoder에 classfication head 붙일지 여부
    label_smoothing = False # label smoothing 여부
    multi_gpu=True if len(gpu)>1 else False # multi gpu 사용
    clf_alpha = 0.3 # classification head 의 loss 비율
    smoothing = 0.1 # label smoothing factor
    dice_smoothing = 1 # dice loss 사용시 하이퍼 파라미터
    
    # ---- Else ---- #
    num_workers=8
    seed=42
    
data_dir = '/home/jeonghokim/competition/HubMap/data/'
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
##----------------
def set_seeds(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False # for faster training, but not deterministic

# useful function

In [3]:
#-------evaluation metric, loss---------#
###################################
def np_binary_cross_entropy_loss(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    #---
    logp = -np.log(np.clip(p,1e-6,1))
    logn = -np.log(np.clip(1-p,1e-6,1))
    loss = t*logp +(1-t)*logn
    loss = loss.mean()
    return loss

def np_dice_score(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    p = p>0.5
    t = t>0.5
    uion = p.sum() + t.sum()
    overlap = (p*t).sum()
    dice = 2*overlap/(uion+0.001)
    return dice

def dice_score(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: float = None,):
    """
    Reference:
    https://catalyst-team.github.io/catalyst/_modules/catalyst/dl/utils/criterion/dice.html
    """
    if threshold is not None:
        outputs = (outputs > threshold).float()
        targets = (targets > threshold).float()

    intersection = torch.sum(targets * outputs)
    union = torch.sum(targets) + torch.sum(outputs)
    dice = 2 * intersection / (union + eps)

    return dice
def torch_accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: float = None,):

    if threshold is not None:
        outputs = (outputs > threshold).float()
        
    tp = torch.sum(targets*outputs)/torch.sum(targets)
    tn = torch.sum((1-outputs)*(1-targets))/torch.sum(1-targets)

    return tp, tn

def np_accuracy(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)
    p = p>0.5
    t = t>0.5
    tp = (p*t).sum()/((t).sum()+1e-7)
    tn = ((1-p)*(1-t)).sum()/(1-t).sum()
    return tp, tn

def criterion_binary_cross_entropy(logit, mask):
    logit = logit.reshape(-1)
    mask = mask.reshape(-1)

    loss = F.binary_cross_entropy_with_logits(logit, mask)
    return loss

# threshold dice score
def np_dice_score2(probability, mask, threshold):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    p = p>threshold
    t = t>0.5
    uion = p.sum() + t.sum()
    overlap = (p*t).sum()
    dice = 2*overlap/(uion+0.001)
    return dice

# --------------------
# Loss
# --------------------
class DiceBCELoss(nn.Module):
    # Formula Given above.
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=args.smoothing):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).mean()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.mean() + targets.mean() + smooth)  
        
        Dice_BCE = BCE*0.6 + dice_loss*0.4
        
        return Dice_BCE.mean()
class DiceLoss(nn.Module):
    # Formula Given above.
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=args.dice_smoothing):
        
        inputs = inputs.view(-1)
        inputs = F.sigmoid(inputs)   
        targets = targets.view(-1)
        
        intersection = (inputs * targets).mean()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.mean() + targets.mean() + smooth)  
                
        return dice_loss.mean()
    
#PyTorch lovasz
def symmetric_lovasz(outputs, targets):
    return 0.5*(lovasz_hinge(outputs, targets) + lovasz_hinge(-outputs, 1.0 - targets))
import torch
import torch.nn as nn


#from torch.autograd import Function
# copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/LovaszSoftmax/lovasz_loss.py
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
class LovaszSoftmax(nn.Module):
    def __init__(self, reduction='mean'):
        super(LovaszSoftmax, self).__init__()
        self.reduction = reduction

    def prob_flatten(self, input, target):
        assert input.dim() in [4, 5]
        num_class = input.size(1)
        if input.dim() == 4:
            input = input.permute(0, 2, 3, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        elif input.dim() == 5:
            input = input.permute(0, 2, 3, 4, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        target_flatten = target.view(-1)
        return input_flatten, target_flatten

    def lovasz_softmax_flat(self, inputs, targets):
        num_classes = inputs.size(1)
        losses = []
        for c in range(num_classes):
            target_c = (targets == c).float()
            if num_classes == 1:
                input_c = inputs[:, 0]
            else:
                input_c = inputs[:, c]
            loss_c = (torch.autograd.Variable(target_c) - input_c).abs()
            loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True)
            target_c_sorted = target_c[loss_index]
            losses.append(torch.dot(loss_c_sorted, torch.autograd.Variable(lovasz_grad(target_c_sorted))))
        losses = torch.stack(losses)

        if self.reduction == 'none':
            loss = losses
        elif self.reduction == 'sum':
            loss = losses.sum()
        else:
            loss = losses.mean()
        return loss

    def forward(self, inputs, targets):
        # print(inputs.shape, targets.shape) # (batch size, class_num, x,y,z), (batch size, 1, x,y,z)
        inputs, targets = self.prob_flatten(inputs, targets)
        # print(inputs.shape, targets.shape)
        losses = self.lovasz_softmax_flat(inputs, targets)
        return losses
class Lovasz_loss(nn.Module):
    def __init__(self):
        super(Lovasz_loss, self).__init__()
        
    def forward(self, inputs, targets):
        return LovaszSoftmax()(inputs, targets)
###################################
#-------ELSE function---------#
###################################

class Logger(object):
    def __init__(self):
        self.terminal = sys.stdout  #stdout
        self.file = None

    def open(self, file, mode=None):
        if mode is None: mode ='w'
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1 ):
        if '\r' in message: is_file=0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
            #time.sleep(1)

        if is_file == 1:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass
def print_args(args, logger=None):
    for k, v in vars(args).items():
        if logger is not None:
            logger.write('{:<16} : {}\n'.format(k, v))
        else:
            print('{:<16} : {}'.format(k, v))
def time_to_str(t, mode='min'):
    if mode=='min':
        t  = int(t)/60
        hr = t//60
        min = t%60
        return '%2d hr %02d min'%(hr,min)

    elif mode=='sec':
        t   = int(t)
        min = t//60
        sec = t%60
        return '%2d min %02d sec'%(min,sec)

    else:
        raise NotImplementedError
def get_learning_rate(optimizer):
    lr=[]
    for param_group in optimizer.param_groups:
        lr +=[ param_group['lr'] ]

    assert(len(lr)==1) #we support only one param_group
    lr = lr[0]

    return lr


###########################
#---- label smoothing -----
###########################
class LabelSmoothing(nn.Module):
    def __init__(self, smoothing = args.smoothing):
        super(LabelSmoothing, self).__init__()
        self.smoothing = smoothing

    def forward(self, x, target):
        x = x.float().flatten()
        target = target.float() * (1-self.smoothing) + 0.5 * self.smoothing
        target = target.flatten()


        loss  = F.binary_cross_entropy_with_logits(x, target, reduction='mean')

        return loss.mean()


In [4]:
#-------masking & tile & decode---------#
def read_tiff(image_file):
    """
    *data size*
    e.g.) (3, w, h) or (1,1,3,w,h) or (w, h, 3)  --> transform --> (w, h, 3)
    """
    image = tiff.imread(image_file)
    if image.shape[0] == 1:
        image = image[0][0]
        image = image.transpose(1, 2, 0)
        image = np.ascontiguousarray(image)
    elif image.shape[0] == 3:
        image = image.transpose(1, 2, 0)
        image = np.ascontiguousarray(image)
    return image

def read_mask(mask_file):
    mask = np.array(Image.open(mask_file))
    return mask

def read_json_as_df(json_file):
    with open(json_file) as f:
        j = json.load(f)
    df = pd.json_normalize(j)
    return df


def draw_strcuture(df, height, width, fill=255, structure=[]):
    mask = np.zeros((height, width), np.uint8)
    for row in df.values:
        type  = row[2]  #geometry.type
        coord = row[3]  # geometry.coordinates
        name  = row[4]   # properties.classification.name

        if structure !=[]:
            if not any(s in name for s in structure): continue


        if type=='Polygon':
            pt = np.array(coord).astype(np.int32)
            #cv2.polylines(mask, [coord.reshape((-1, 1, 2))], True, 255, 1)
            cv2.fillPoly(mask, [pt.reshape((-1, 1, 2))], fill)

        if type=='MultiPolygon':
            for pt in coord:
                pt = np.array(pt).astype(np.int32)
                cv2.fillPoly(mask, [pt.reshape((-1, 1, 2))], fill)

    return mask

# resize, cvtcolor, generate mask
# 원하는 object 영역만 따오는 mask
def draw_strcuture_from_hue(image, fill=255, scale=1/32): # 0.25/32 default
    height, width, _ = image.shape
    vv = cv2.resize(image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    vv = cv2.cvtColor(vv, cv2.COLOR_RGB2HSV)
    # image_show('v[0]', v[:,:,0])
    # image_show('v[1]', v[:,:,1])
    # image_show('v[2]', v[:,:,2])
    # cv2.waitKey(0)
    mask = (vv[:, :, 1] > 32).astype(np.uint8) # rgb2hsv를 하고나서 1채널에 대해 시행하면 원하는 object만 잘따온다.
    mask = mask*fill
    mask = cv2.resize(mask, dsize=(width, height), interpolation=cv2.INTER_LINEAR) # 다시 원래사이즈로 복구

    return mask

# --- rle ---------------------------------
def rle_decode(rle, height, width , fill=255):
    s = rle.split()
    start, length = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    start -= 1
    mask = np.zeros(height*width, dtype=np.uint8)
    for i, l in zip(start, length):
        mask[i:i+l] = fill
    mask = mask.reshape(width,height).T
    mask = np.ascontiguousarray(mask)
    return mask


def rle_encode(mask):
    m = mask.T.flatten()
    m = np.concatenate([[0], m, [0]])
    run = np.where(m[1:] != m[:-1])[0] + 1
    run[1::2] -= run[::2]
    rle =  ' '.join(str(r) for r in run)
    return rle


# --- tile ---------------------------------
"""
-결국, tile_image, tile_mask만 가져다가 쓴다.
1. scale로 resize를 하고 image size와 step만큼 건너뛰며 이미지를 만든다.
2. 이때 일정 영역이 빈마스크면 데이터에서 제외한다.
3. 쌓은 image와 mask를 return
"""
def to_tile(image, mask, structure, scale, size, step, min_score): 
    half = size//2
    image_small = cv2.resize(image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) # defualt는 1/4만큼 w,h를 줄인다.
    height, width, _ = image_small.shape

    #make score
    structure_small = cv2.resize(structure, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    vv = structure_small.astype(np.float32)/255

    #make coord
    xx = np.linspace(half, width  - half, int(np.ceil((width  - size) / step)))
    yy = np.linspace(half, height - half, int(np.ceil((height - size) / step)))
    xx = [int(x) for x in xx]
    yy = [int(y) for y in yy]

    coord  = []
    reject = []
    for cy in yy:
        for cx in xx:
            cv = vv[cy - half:cy + half, cx - half:cx + half].mean() # h, w // tiling한 마스크(structure)가 평균 0.25를 안넘으면 버린다.
            if cv>min_score: # min_score ,default:0.25, 0.25의 의미?, 타일링 이미지의 1/4는 object여야 한다는 의미?
                coord.append([cx,cy,cv])
            else:
                reject.append([cx,cy,cv])
    #-----
    if 1: # resize한 image를 tiling 하여 리스트만든다
        tile_image = []
        for cx,cy,cv in coord:
            t = image_small[cy - half:cy + half, cx - half:cx + half] # resize한 image에서 indexing만 하는과정
            assert (t.shape == (size, size, 3))
            tile_image.append(t)

    if mask is not None: # mask를 resize하고 tiling하여 리스트 만든다
        mask_small = cv2.resize(mask, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        tile_mask = []
        for cx,cy,cv in coord:
            t = mask_small[cy - half:cy + half, cx - half:cx + half]
            assert (t.shape == (size, size))
            tile_mask.append(t)
    else:
        mask_small = None
        tile_mask  = None

    return {
        'image_small': image_small,
        'mask_small' : mask_small,
        'structure_small' : structure_small,
        'tile_image' : tile_image,
        'tile_mask'  : tile_mask,
        'coord'  : coord,
        'reject' : reject,
    }



"""
submission할때 쓰임
"""
def to_mask(tile, coord, height, width, scale, size, step, min_score, aggregate='mean'):

    half = size//2
    mask  = np.zeros((height, width), np.float32)

    if 'mean' in aggregate:
        w = np.ones((size,size), np.float32)

        #if 'sq' in aggregate:
        if 1:
            #https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python
            y,x = np.mgrid[-half:half,-half:half]
            y = half-abs(y)
            x = half-abs(x)
            w = np.minimum(x,y)
            w = w/w.max()#*2.5
            w = np.minimum(w,1)

        #--------------
        count = np.zeros((height, width), np.float32)
        for t, (cx, cy, cv) in enumerate(coord):
            mask [cy - half:cy + half, cx - half:cx + half] += tile[t]*w
            count[cy - half:cy + half, cx - half:cx + half] += w
               # see unet paper for "Overlap-tile strategy for seamless segmentation of arbitrary large images"
        m = (count != 0)
        mask[m] /= count[m]

    if aggregate=='max':
        for t, (cx, cy, cv) in enumerate(coord):
            mask[cy - half:cy + half, cx - half:cx + half] = np.maximum(
                mask[cy - half:cy + half, cx - half:cx + half], tile[t] )

    return mask

# --------------이 아래로 안씀 ------------------------------#



# --draw ------------------------------------------
"""
경계선을 그리게 만든다, 컨투어
하지만 안씀
"""
def mask_to_inner_contour(mask):
    mask = mask>0.5
    pad = np.lib.pad(mask, ((1, 1), (1, 1)), 'reflect')
    contour = mask & (
            (pad[1:-1,1:-1] != pad[:-2,1:-1]) \
          | (pad[1:-1,1:-1] != pad[2:,1:-1])  \
          | (pad[1:-1,1:-1] != pad[1:-1,:-2]) \
          | (pad[1:-1,1:-1] != pad[1:-1,2:])
    )
    return contour


def draw_contour_overlay(image, mask, color=(0,0,255), thickness=1):
    contour =  mask_to_inner_contour(mask)
    if thickness==1:
        image[contour] = color
    else:
        r = max(1,thickness//2)
        for y,x in np.stack(np.where(contour)).T:
            cv2.circle(image, (x,y), r, color, lineType=cv2.LINE_4 )
    return image


# make dataset

In [5]:
# ------ make dataset  new version image fold--------- #
#################################
"""
- robust validation을 위해 overlap 없는 데이터도 만든다
"""
# <todo> make difference scale tile

tile_scale = 0.5
tile_min_score = 0.25
tile_size = 1280#320  # 480 #
tile_average_step = 640#160 #240  # 160 #192
tile_average_step2 = tile_size

#make tile train image
# train,tiling (image,mask) png 저장용도
def run_make_train_tile():

    train_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step}_train_fold' #nipa2

    df_train = pd.read_csv(data_dir + '/train.csv')
    print(df_train)
    print(df_train.shape)
    
    df_all = []
    
    os.makedirs(train_tile_dir, exist_ok=True)
    for i in range(0,len(df_train)):
        id, encoding = df_train.iloc[i]
        # 1. image 불러오고
        image_file = data_dir + '/train/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        #mask = rle_decode(encoding, height, width, 255)
        # 2. mask, target 불러오고
        mask_file = data_dir + '/train/%s.mask.png' % id
        mask = read_mask(mask_file)
        
        # 3. 일정영역,object만 표시한 mask불러오기.
        structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)
        print(id, mask_file)
        
        # make tile
        # 4. 학습할 tile image, mask를 생성한다.
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(train_tile_dir+'/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            #s = '%s_y%08d_x%08d' % (id, cy, cx)
            s = 'y%08d_x%08d' %(cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            tile_mask  = tile['tile_mask'][t]
            cv2.imwrite(train_tile_dir + '/%s/%s.png' %(id, s), tile_image)
            cv2.imwrite(train_tile_dir + '/%s/%s.mask.png' %(id, s), tile_mask)


        df_image['tile_id']= [f'{train_tile_dir}/{id}/'+ x for x in tile_id]
        df_all.append(df_image)
    df_all = pd.concat(df_all, 0).reset_index(drop=True)
    df_all[['tile_id','cx','cy','cv']].to_csv(train_tile_dir+'/image_id.csv', index=False)
#------
# maek tile val image
def run_make_val_tile():

    train_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step2}_val_fold' #nipa2

    df_train = pd.read_csv(data_dir + '/train.csv')
    print(df_train)
    print(df_train.shape)
    
    df_all = []
    
    os.makedirs(train_tile_dir, exist_ok=True)
    for i in range(0,len(df_train)):
        id, encoding = df_train.iloc[i]
        # 1. image 불러오고
        image_file = data_dir + '/train/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        #mask = rle_decode(encoding, height, width, 255)
        # 2. mask, target 불러오고
        mask_file = data_dir + '/train/%s.mask.png' % id
        mask = read_mask(mask_file)
        
        # 3. 일정영역,object만 표시한 mask불러오기.
        structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)
        print(id, mask_file)
        
        # make tile
        # 4. 학습할 tile image, mask를 생성한다.
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step2, tile_min_score)

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(train_tile_dir+'/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            #s = '%s_y%08d_x%08d' % (id, cy, cx)
            s = 'y%08d_x%08d' %(cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            tile_mask  = tile['tile_mask'][t]
            cv2.imwrite(train_tile_dir + '/%s/%s.png' %(id, s), tile_image)
            cv2.imwrite(train_tile_dir + '/%s/%s.mask.png' %(id, s), tile_mask)


        df_image['tile_id']= [f'{train_tile_dir}/{id}/'+ x for x in tile_id]
        df_all.append(df_image)
    df_all = pd.concat(df_all, 0).reset_index(drop=True)
    df_all[['tile_id','cx','cy','cv']].to_csv(train_tile_dir+'/image_id.csv', index=False)

    
#make tile train image
# test tiling image png 저장용도
def run_make_test_tile():
    #tile_scale = 0.25
    #tile_min_score = 0.25
    #tile_size = 480#320  # 480 #
    #tile_average_step = 240#160 #240  # 160 #192

    #test_tile_dir = '/home/ubuntu/gwang/hubmap/etc/tile/0.25_640_320_test'
    test_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step}_test'
    #---


    os.makedirs(test_tile_dir, exist_ok=True)
    assert False, 'todo modify test file'
    for id in ['c68fe75ea','afa5e8098',]:
        print(id)

        # 1. test image load
        image_file = data_dir + '/test/%s.tiff' % id
        json_file  = data_dir + '/test/%s-anatomical-structure.json' % id

        image = read_tiff(image_file)
        height, width = image.shape[:2]

        mask = None
        # 2. test structure load
        structure = draw_strcuture(read_json_as_df(json_file), height, width, structure=['Cortex'])
        # structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)

        # 3. test를 위한 tile image 생성
        #make tile
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(test_tile_dir+'/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            s = 'y%08d_x%08d' % (cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            cv2.imwrite(test_tile_dir + '/%s/%s.png' % (id, s), tile_image)
            #image_show('tile_image', tile_image)
            #cv2.waitKey(1)


        df_image['tile_id']=tile_id
        df_image[['tile_id','cx','cy','cv']].to_csv(test_tile_dir+'/%s.csv'%id, index=False)
        #------


#make tile train image
# tile이 아닌 train image의 mask생성
def run_make_train_mask():

    df_train = pd.read_csv(data_dir + '/train.csv')
    print(df_train)
    print(df_train.shape)

    for i in range(0,len(df_train)):
        id, encoding = df_train.iloc[i]

        image_file = data_dir + '/train/%s.tiff' % id
        image = read_tiff(image_file)

        if image.shape[0]==1:
            image = image[0][0]
            image = image.transpose(1, 2, 0)
            image = np.ascontiguousarray(image)
            height, width = image.shape[:2]
        elif image.shape[0] == 3:
            image = image.transpose(1, 2, 0)
            image = np.ascontiguousarray(image)
            height, width = image.shape[:2]
        else:
            height, width = image.shape[:2]
        mask = rle_decode(encoding, height, width, 255)

        cv2.imwrite(data_dir + '/train/%s.mask.png' % id, mask)


#make tile train image
def run_make_pseudo_tile():

    
    tile_scale = 0.25
    tile_min_score = 0.25
    tile_size = 480#320  #480 #
    tile_average_step = 240 #160 #240  # 192
    #---
    pseudo_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step}_pseudo_0.95'
    #df_train = pd.read_csv(data_dir + '/train.csv')
    #df_pseudo = pd.read_csv('/root/share1/kaggle/2020/hubmap/result/resnet34/fold2/submit-fold-2-resnet34-00010000_model_lb0.837.csv')
    df_pseudo = pd.read_csv('../../submission/0.891_submission-fold6-00004000_model_thres-0.9.csv')
    
    print(df_pseudo)
    print(df_pseudo.shape)

    os.makedirs(pseudo_tile_dir, exist_ok=True)
    for i in range(0,len(df_pseudo)):
        id, encoding = df_pseudo.iloc[i]

        image_file = data_dir + '/test/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        mask = rle_decode(encoding, height, width, 255)

        #make tile
        structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)

        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)
        #to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        #mask, scale, size, step, min_score

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(pseudo_tile_dir + '/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            s = 'y%08d_x%08d' % (cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            tile_mask  = tile['tile_mask'][t]
            cv2.imwrite(pseudo_tile_dir + '/%s/%s.png' % (id, s), tile_image)
            cv2.imwrite(pseudo_tile_dir + '/%s/%s.mask.png' % (id, s), tile_mask)

            #image_show('tile_image', tile_image)
            #image_show('tile_mask', tile_mask)
            #cv2.waitKey(1)


        df_image['tile_id']=tile_id
        df_image[['tile_id','cx','cy','cv']].to_csv(pseudo_tile_dir+'/%s.csv'%id, index=False)
        #------

def split_fold():
    
    df = pd.read_csv(f'{data_dir}/tile/{tile_scale}_{tile_size}_{tile_average_step}_train_fold/image_id.csv')
    df2 = pd.read_csv(f'{data_dir}/tile/{tile_scale}_{tile_size}_{tile_average_step2}_val_fold/image_id.csv')

    a = {0 : '0486052bb', 1 : '095bf7a1f', 2 : '1e2425f28', 3 : '26dc41664',
        4 : '2f6ecfcdf', 5 : '4ef6695ce', 6 : '54f2eec69', 7 : '8242609fa',
        8 : 'aaa6a05cc', 9 : 'afa5e8098', 10 :'b2dc8411c', 11: 'b9a3865fc',
        12 :'c68fe75ea', 13: 'cb2d976f4', 14 :'e79de561c'}
    #
    kf = KFold(n_splits=args.n_fold, random_state=args.seed, shuffle=True)
    fold_dict={}
    for n, (t,v) in enumerate(kf.split(a)):
        for f in v:
            fold_dict[a[f]] = n

    df['fold'] = df['tile_id'].apply(lambda x : x.split('/')[-2])
    df['fold'] = df['fold'].apply(lambda x :fold_dict[x])
    
    df2['fold'] = df2['tile_id'].apply(lambda x : x.split('/')[-2])
    df2['fold'] = df2['fold'].apply(lambda x :fold_dict[x])
    
    df.to_csv(f'{data_dir}/tile/{tile_scale}_{tile_size}_{tile_average_step}_train_fold/image_id_split.csv', index=False)
    df2.to_csv(f'{data_dir}/tile/{tile_scale}_{tile_size}_{tile_average_step2}_val_fold/image_id_split.csv', index=False)
    print('saved split fold')
    
# main #################################################################
if 0:
    if __name__ == '__main__':
        #print('started run make train mask')
        # 1.
        print('started 1')
        run_make_train_mask()
        # 2.
        print('started 2')
        run_make_train_tile()
        # 3.
        print('started 3')
        run_make_val_tile()
        
        #print('started 3')
        #run_make_test_tile()
        # 4. if use pseudo datasets
        #run_make_pseudo_tile()
        
        print('split kfold csv')
        split_fold()
    
    

# Dataset & augmentation

In [6]:
#--------------- Dataset ----------------#
##########################################

#--------------- 
# Old version
#--------------- 
def make_image_id_v1(mode):
    train_image_id = {
        0 : '0486052bb', 1 : '095bf7a1f',
        2 : '1e2425f28', 3 : '26dc41664',
        4 : '2f6ecfcdf', 5 : '4ef6695ce',
        6 : '54f2eec69', 7 : '8242609fa',
        8 : 'aaa6a05cc', 9 : 'afa5e8098', 
        10 :'b2dc8411c', 11: 'b9a3865fc',
        12 :'c68fe75ea', 13: 'cb2d976f4',
        14 :'e79de561c'
    }

    test_image_id = {
        0 : '2ec3f1bb9', 1 : '3589adb90',
        2 : '57512b7f1', 3 : 'aa05346ff',
        4 : 'd488c759a',
    }
    if 'pseudo-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ]
        return test_id

    if 'test-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ] # list(test_image_id.values()) #
        return test_id

    if 'train-all'==mode:
        train_id = [ train_image_id[i] for i in [x for x in train_image_id] ] # list(test_image_id.values()) #
        return train_id

    if 'valid' in mode or 'train' in mode:
        fold = {int(x) for x in mode.split('-')[1].split(',')}
        #valid = [fold,]
        train = list({x for x in train_image_id}-fold)
        valid_id = [ train_image_id[i] for i in fold ]
        train_id = [ train_image_id[i] for i in train ]

        if 'valid' in mode: return valid_id
        if 'train' in mode: return train_id
class HuDataset_v1(Dataset):
    def __init__(self, image_id, image_dir, augment=None):
        self.augment = augment
        self.image_id = image_id
        self.image_dir = image_dir

        tile_id = []
        for i in range(len(image_dir)):
            for id in image_id[i]: 
                df = pd.read_csv(data_dir + '/tile/%s/%s.csv'% (self.image_dir[i],id) )
                tile_id += ('%s/%s/'%(self.image_dir[i],id) + df.tile_id).tolist()

        self.tile_id = tile_id
        self.len =len(self.tile_id)


    def __len__(self):
        return self.len

    def __str__(self):
        string  = ''
        string += '\tlen  = %d\n'%len(self)
        string += '\timage_dir = %s\n'%self.image_dir
        string += '\timage_id  = %s\n'%str(self.image_id)
        string += '\t          = %d\n'%sum(len(i) for i in self.image_id)
        return string


    def __getitem__(self, index):
        id = self.tile_id[index]
        image = cv2.imread(data_dir + '/tile/%s.png'%(id), cv2.IMREAD_COLOR)
        mask  = cv2.imread(data_dir + '/tile/%s.mask.png'%(id), cv2.IMREAD_GRAYSCALE)
        #print(data_dir + '/tile/%s/%s.png'%(self.image_dir,id))

        image = image.astype(np.float32) / 255
        mask  = mask.astype(np.float32) / 255
        r = {
            'index' : index,
            'tile_id' : id,
            'mask' : mask,
            'image' : image,
        }
        if self.augment is not None: r = self.augment(r)
        return r

#--------------- 
# Old version (simple fold)
#--------------- 
def make_image_id_(mode):
    train_image_id = {
        0 : '0486052bb', 1 : '095bf7a1f',
        2 : '1e2425f28', 3 : '26dc41664',
        4 : '2f6ecfcdf', 5 : '4ef6695ce',
        6 : '54f2eec69', 7 : '8242609fa',
        8 : 'aaa6a05cc', 9 : 'afa5e8098', 
        10 :'b2dc8411c', 11: 'b9a3865fc',
        12 :'c68fe75ea', 13: 'cb2d976f4',
        14 :'e79de561c'
    }

    test_image_id = {
        0 : '2ec3f1bb9', 1 : '3589adb90',
        2 : '57512b7f1', 3 : 'aa05346ff',
        4 : 'd488c759a',
    }
    if 'pseudo-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ]
        return test_id

    if 'test-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ] # list(test_image_id.values()) #
        return test_id

    if 'train-all'==mode:
        train_id = [ train_image_id[i] for i in [x for x in train_image_id] ] # list(test_image_id.values()) #
        return train_id

    if 'valid' in mode or 'train' in mode:
        fold = {int(x) for x in mode.split('-')[1].split(',')}
        #valid = [fold,]
        train = list({x for x in train_image_id}-fold)
        valid_id = [ train_image_id[i] for i in fold ]
        train_id = [ train_image_id[i] for i in train ]

        if 'valid' in mode: return valid_id
        if 'train' in mode: return train_id
class HuDataset_(Dataset):
    def __init__(self, tile_id, augment=None):
        self.augment = augment

        self.tile_id = tile_id
        self.len =len(self.tile_id)


    def __len__(self):
        return self.len

    def __str__(self):
        string  = ''
        string += '\tlen  = %d\n'%len(self)
        return string


    def __getitem__(self, index):
        id = self.tile_id[index]
        image = cv2.imread(f'{data_dir}/tile/{args.dataset}/{id}.png', cv2.IMREAD_COLOR)
        mask  = cv2.imread(f'{data_dir}/tile/{args.dataset}/{id}.mask.png', cv2.IMREAD_GRAYSCALE)
        #print(data_dir + '/tile/%s/%s.png'%(self.image_dir,id))

        image = image.astype(np.float32) / 255
        mask  = mask.astype(np.float32) / 255
        r = {
            'index' : index,
            'tile_id' : id,
            'mask' : mask,
            'image' : image,
        }
        if self.augment is not None: r = self.augment()
        
        return r

#--------------- 
# New version(image fold)
#--------------- 
def make_image_id(mode):
    train_image_id = {
        0 : '0486052bb', 1 : '095bf7a1f',
        2 : '1e2425f28', 3 : '26dc41664',
        4 : '2f6ecfcdf', 5 : '4ef6695ce',
        6 : '54f2eec69', 7 : '8242609fa',
        8 : 'aaa6a05cc', 9 : 'afa5e8098', 
        10 :'b2dc8411c', 11: 'b9a3865fc',
        12 :'c68fe75ea', 13: 'cb2d976f4',
        14 :'e79de561c'
    }

    test_image_id = {
        0 : '2ec3f1bb9', 1 : '3589adb90',
        2 : '57512b7f1', 3 : 'aa05346ff',
        4 : 'd488c759a',
    }
    if 'pseudo-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ]
        return test_id

    if 'test-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ] # list(test_image_id.values()) #
        return test_id

    if 'train-all'==mode:
        train_id = [ train_image_id[i] for i in [x for x in train_image_id] ] # list(test_image_id.values()) #
        return train_id

    if 'valid' in mode or 'train' in mode:
        fold = {int(x) for x in mode.split('-')[1].split(',')}
        #valid = [fold,]
        train = list({x for x in train_image_id}-fold)
        valid_id = [ train_image_id[i] for i in fold ]
        train_id = [ train_image_id[i] for i in train ]

        if 'valid' in mode: return valid_id
        if 'train' in mode: return train_id
class HuDataset(Dataset):
    def __init__(self, df, augment=None):
        self.augment = augment

        #self.tile_id = tile_id
        #self.len =len(self.tile_id)
        self.df = df


    def __len__(self):
        return len(self.df)

    def __str__(self):
        string  = ''
        string += '\tlen  = %d\n'%len(self)
        return string


    def __getitem__(self, index):
        id = self.df['tile_id'].loc[index]
        image = cv2.imread(f'{id}.png', cv2.IMREAD_COLOR)
        mask  = cv2.imread(f'{id}.mask.png', cv2.IMREAD_GRAYSCALE)
        #print(data_dir + '/tile/%s/%s.png'%(self.image_dir,id))

        image = image.astype(np.float32) / 255
        mask  = mask.astype(np.float32) / 255
        r = {
            'index' : index,
            'tile_id' : id,
            'mask' : mask,
            'image' : image,
        }
        if self.augment is not None: r = self.augment(r)
        #if self.augment is not None: r = self.augment(image=r['image'], mask=r['mask'])
        return r

def null_collate(batch):
    batch_size = len(batch)
    index = []
    mask = []
    image = []
    for r in batch:
        index.append(r['index'])
        mask.append(r['mask'])
        image.append(r['image'])

    image = np.stack(image)
    image = image[...,::-1]
    image = image.transpose(0,3,1,2)
    image = np.ascontiguousarray(image)

    mask  = np.stack(mask)
    mask  = np.ascontiguousarray(mask)

    #---
    image = torch.from_numpy(image).contiguous().float()
    mask  = torch.from_numpy(mask).contiguous().unsqueeze(1)
    mask  = (mask>0.5).float()

    return {
        'index' : index,
        'mask' : mask,
        'image' : image,
    }

In [7]:
#---------- augmentation ---------------------#
###############################################
#flip
def do_random_flip_transpose(image, mask):
    if np.random.rand()>0.5:
        image = cv2.flip(image,0)
        mask = cv2.flip(mask,0)
    if np.random.rand()>0.5:
        image = cv2.flip(image,1)
        mask = cv2.flip(mask,1)
    if np.random.rand()>0.5:
        image = image.transpose(1,0,2)
        mask = mask.transpose(1,0)

    image = np.ascontiguousarray(image)
    mask = np.ascontiguousarray(mask)
    return image, mask

#geometric
def do_random_crop(image, mask, size):
    height, width = image.shape[:2]
    x = np.random.choice(width -size)
    y = np.random.choice(height-size)
    image = image[y:y+size,x:x+size]
    mask  = mask[y:y+size,x:x+size]
    return image, mask

def do_random_scale_crop(image, mask, size, mag):
    height, width = image.shape[:2]

    s = 1 + np.random.uniform(-1, 1)*mag
    s =  int(s*size)

    x = np.random.choice(width -s)
    y = np.random.choice(height-s)
    image = image[y:y+s,x:x+s]
    mask  = mask[y:y+s,x:x+s]
    if s!=size:
        image = cv2.resize(image, dsize=(size,size), interpolation=cv2.INTER_LINEAR)
        mask  = cv2.resize(mask, dsize=(size,size), interpolation=cv2.INTER_LINEAR)
    return image, mask

def do_random_rotate_crop(image, mask, size, mag=30 ):
    angle = 1+np.random.uniform(-1, 1)*mag

    height, width = image.shape[:2]
    dst = np.array([
        [0,0],[size,size], [size,0], [0,size],
    ])

    c = np.cos(angle/180*2*PI)
    s = np.sin(angle/180*2*PI)
    src = (dst-size//2)@np.array([[c, -s],[s, c]]).T
    src[:,0] -= src[:,0].min()
    src[:,1] -= src[:,1].min()

    src[:,0] = src[:,0] + np.random.uniform(0,width -src[:,0].max())
    src[:,1] = src[:,1] + np.random.uniform(0,height-src[:,1].max())

    if 0: #debug
        def to_int(f):
            return (int(f[0]),int(f[1]))

        cv2.line(image, to_int(src[0]), to_int(src[1]), (0,0,1), 16)
        cv2.line(image, to_int(src[1]), to_int(src[2]), (0,0,1), 16)
        cv2.line(image, to_int(src[2]), to_int(src[3]), (0,0,1), 16)
        cv2.line(image, to_int(src[3]), to_int(src[0]), (0,0,1), 16)
        image_show_norm('image', image, min=0, max=1)
        cv2.waitKey(1)


    transform = cv2.getAffineTransform(src[:3].astype(np.float32), dst[:3].astype(np.float32))
    image = cv2.warpAffine( image, transform, (size, size), flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
    mask  = cv2.warpAffine( mask, transform, (size, size), flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return image, mask

#warp/elastic deform ...
#<todo>

#noise
def do_random_noise(image, mask, mag=0.1):
    height, width = image.shape[:2]
    noise = np.random.uniform(-1,1, (height, width,1))*mag
    image = image + noise
    image = np.clip(image,0,1)
    return image, mask


#intensity
def do_random_contast(image, mask, mag=0.3):
    alpha = 1 + random.uniform(-1,1)*mag
    image = image * alpha
    image = np.clip(image,0,1)
    return image, mask

def do_random_gain(image, mask, mag=0.3):
    alpha = 1 + random.uniform(-1,1)*mag
    image = image ** alpha
    image = np.clip(image,0,1)
    return image, mask

def do_random_hsv(image, mask, mag=[0.15,0.25,0.25]):
    image = (image*255).astype(np.uint8)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    h = hsv[:, :, 0].astype(np.float32)  # hue
    s = hsv[:, :, 1].astype(np.float32)  # saturation
    v = hsv[:, :, 2].astype(np.float32)  # value
    h = (h*(1 + random.uniform(-1,1)*mag[0]))%180
    s =  s*(1 + random.uniform(-1,1)*mag[1])
    v =  v*(1 + random.uniform(-1,1)*mag[2])

    hsv[:, :, 0] = np.clip(h,0,180).astype(np.uint8)
    hsv[:, :, 1] = np.clip(s,0,255).astype(np.uint8)
    hsv[:, :, 2] = np.clip(v,0,255).astype(np.uint8)
    image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    image = image.astype(np.float32)/255
    return image, mask


def filter_small(mask, min_size):

    m = (mask*255).astype(np.uint8)

    num_comp, comp, stat, centroid = cv2.connectedComponentsWithStats(m, connectivity=8)
    if num_comp==1: return mask

    filtered = np.zeros(comp.shape,dtype=np.uint8)
    area = stat[:, -1]
    for i in range(1, num_comp):
        if area[i] >= min_size:
            filtered[comp == i] = 255
    return filtered

In [8]:
#---------- optimizer, scheduler ---------------------#
############################################
class Lookahead(Optimizer):
    def __init__(self, optimizer, alpha=0.5, k=6):

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')

        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups
        self.alpha = alpha
        self.k = k
        for group in self.param_groups:
            group["step_counter"] = 0

        self.slow_weights = [
                [p.clone().detach() for p in group['params']]
            for group in self.param_groups]

        for w in it.chain(*self.slow_weights):
            w.requires_grad = False
        self.state = optimizer.state

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        loss = self.optimizer.step()

        for group,slow_weights in zip(self.param_groups,self.slow_weights):
            group['step_counter'] += 1
            if group['step_counter'] % self.k != 0:
                continue
            for p,q in zip(group['params'],slow_weights):
                if p.grad is None:
                    continue
                q.data.add_(p.data - q.data, alpha=self.alpha )
                p.data.copy_(q.data)
        return loss
class RAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

        return loss

#---------- scheduler ---------------------#
def get_scheduler(optimizer):
    if args.scheduler =='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0 = args.epochs//args.T_0, T_mult=1, eta_min=0, last_epoch=-1)
    elif args.scheduler == 'CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=args.T_max, eta_min=args.min_lr, last_epoch=-1)
    elif args.scheduler == 'ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.factor, patience=args.patience, verbose=True, 
                                      min_lr = args.min_lr, eps=args.eps)
    else:
        scheduler=None
        assert False, 'not implement'

    return scheduler

# Model

In [9]:
class DOWNBLOCK(nn.Module):
    def __init__(self):
        super(DOWNBLOCK, self).__init__()
        self.down_conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.down_bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.down_bn1(self.down_conv1(x)))
        return x
    
def SegModel():
    models = []
    # 다른 모덷들일때
    if args.diff_arch:
        for i in range(args.n_fold):
            en_name = args.encoders[i]
            de_name = args.decoders[i]
            # decoder별로 로드
            if de_name.lower() == "unet":
                if args.clf_head:
                    print('classification head')
                    aux_params=dict(
                        pooling='avg',             # one of 'avg', 'max'
                        dropout=0.5,               # dropout ratio, default is None
                        activation='sigmoid',      # activation function, default is None
                        classes=1,                 # define number of output labels
                    )
                    model = smp.Unet(
                        encoder_name=en_name,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                        encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
                        in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
                        classes=1,                      # model output channels (number of classes in your dataset)
                        aux_params=aux_params
                        )
                else:
                    model = smp.Unet(
                        encoder_name=en_name,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                        encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
                        in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
                        classes=1,                      # model output channels (number of classes in your dataset)
                        )
            elif de_name.lower() == "fpn":
                model = smp.FPN(
                    encoder_name=en_name,
                    encoder_weights="imagenet",
                    in_channels=3,
                    classes=1
                )
            elif de_name.lower() == "upp":
                model = smp.UnetPlusPlus(
                    encoder_name=en_name,
                    encoder_weights="imagenet",
                    in_channels=3,
                    classes=1
                )
            elif de_name.lower() == "linknet":
                model = smp.Linknet(
                    encoder_name=en_name,
                    encoder_weights="imagenet",
                    in_channels=3,
                    classes=1
                )
            else:
                raise NotImplementedError
            models.append(model)
                
        
    # 같은 모델 일 때 5개 복사
    else:
        if args.encoder in ['b0','b1','b2','b3','b4','b5','b6','b7']:
            encoder_name_ = f'efficientnet-{args.encoder}' #'timm-efficientnet-b4'
            print('encoder : ', encoder_name_)
        else:
            encoder_name_ = args.encoder
        if args.decoder =='fpn':
            print('fpn loaded')
            model = smp.FPN(
                encoder_name=encoder_name_,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
                in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
                classes=1,                      # model output channels (number of classes in your dataset)
                )
        elif args.decoder =='unet':
            print('unet loaded')
            if args.clf_head:
                print('classification head')
                aux_params=dict(
                    pooling='avg',             # one of 'avg', 'max'
                    dropout=0.5,               # dropout ratio, default is None
                    activation='sigmoid',      # activation function, default is None
                    classes=1,                 # define number of output labels
                )
                model = smp.Unet(
                    encoder_name=encoder_name_,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                    encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
                    in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
                    classes=1,                      # model output channels (number of classes in your dataset)
                    aux_params=aux_params
                    )
            else:
                model = smp.Unet(
                    encoder_name=encoder_name_,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                    encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
                    in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
                    classes=1,                      # model output channels (number of classes in your dataset)
                    )
                #list_ = [DOWNBLOCK(), model, nn.Upsample(size=640, mode='bilinear', align_corners=True)]
                #model = nn.Sequential(*list_)
        if args.encoder=='R50':
            if args.decoder=='ViT':
                vit_name='R50-ViT-B_16'
                config_vit = CONFIGS[vit_name]
                config_vit.n_classes = 1
                config_vit.n_skip = 3
                if vit_name.find('R50') != -1:
                    config_vit.patches.grid = (int(args.image_size / 16), int(args.image_size / 16))
                model = VisionTransformer(config_vit, img_size=args.image_size, num_classes=config_vit.n_classes) 
        for i in range(args.n_fold):
            models.append(model)
    return models

# train

In [21]:
# ------- image fold train version ------- #

# augmentation
#"""현재 crop 없는상태"""
def train_augment(record):
    image = record['image']
    mask  = record['mask']
    
    for fn in np.random.choice([
        lambda image, mask : do_random_rotate_crop(image, mask, size=args.crop_size, mag=45),
        lambda image, mask : do_random_scale_crop(image, mask, size=args.crop_size, mag=0.075),
        lambda image, mask : do_random_crop(image, mask, size=args.crop_size),
    ],1): image, mask = fn(image, mask)

    #if (np.random.choice(10,1)<7)[0]:
    for fn in np.random.choice([
        lambda image, mask : (image, mask),
        lambda image, mask : do_random_contast(image, mask, mag=0.8),
        lambda image, mask : do_random_gain(image, mask, mag=0.9),
        #lambda image, mask : do_random_hsv(image, mask, mag=[0.1, 0.2, 0]),
        lambda image, mask : do_random_noise(image, mask, mag=0.1),
    ],2): image, mask =  fn(image, mask)
    #if (np.random.choice(10,1)<7)[0]:
    image, mask = do_random_hsv(image, mask, mag=[0.1, 0.2, 0])
    image, mask = do_random_flip_transpose(image, mask)

    record['mask'] = mask
    record['image'] = image
    return record
#그냥 데이터 로더 3개 만들어서 이미지별로 각각 계산해서 평균하자..
def do_valid(net, valid_loader):

    valid_num = 0
    total = 0 ; dice=0 ; loss=0 ; tp = 0 ; tn = 0
    dice2=0 ; loss2=0
    valid_probability, valid_probability2, valid_probability3 = [],[],[]
    valid_mask, valid_mask2, valid_mask3 = [],[],[]

    net = net.eval()

    #start_timer = timer()
    with torch.no_grad():
        for t, batch in enumerate(valid_loader):
            batch_size = len(batch['index'])
            mask  = batch['mask']
            image = batch['image'].to(device)
            
            if args.clf_head:
                logit, _ = net(image) # seg, clf
            else:
                logit = net(image)#data_parallel(net, image) #net(input)#
            probability = torch.sigmoid(logit)
                
            valid_probability.append(probability.data.cpu().numpy())
            valid_mask.append(mask.data.cpu().numpy())

    #assert(valid_num == len(valid_loader.dataset)) # drop last True이면 assert되는거임
    probability = np.concatenate(valid_probability)
    mask = np.concatenate(valid_mask)
    if args.loss =='bce':
        loss = np_binary_cross_entropy_loss(probability, mask)
    elif args.loss =='lovasz':
        loss = 0
    
    # mean loss, dice ..
    dice = np_dice_score(probability, mask)
    tp, tn = np_accuracy(probability, mask)

    return [dice, loss,  tp, tn]

def run_train(args):
    out_dir = data_dir + f'/result/{args.dir}_{args.encoder}_{args.image_size}'

    ## setup  ----------------------------------------
    for f in ['checkpoint','train','valid'] : os.makedirs(out_dir +'/'+f, exist_ok=True)
    #backup_project_as_zip(PROJECT_PATH, out_dir +'/backup/code.train.%s.zip'%IDENTIFIER)
    log = Logger()
    log.open(out_dir+'/log.train.txt',mode='a')

    # my log argument
    print_args(args, log)

    log.write('\tout_dir  = %s\n' % out_dir)
    log.write('\n')


    log.write('** dataset setting **\n')
    #-----------dataset split --------------------#
    tile_id = []
    image_dir_ = f'{args.dataset}'#'0.25_320_160_train'
    image_dir=[image_dir_, ] # pseudo할때 뒤에 추가
    
    image_dir_val_ = f'{args.val_dataset}'#'0.25_320_320_val'
    image_dir_val=[image_dir_val_, ]
    
    for i in range(len(image_dir)):
        df = pd.read_csv(data_dir + '/tile/%s/image_id_split.csv'% (image_dir[i]) )

    for i in range(len(image_dir_val)):
        df2 = pd.read_csv(data_dir + '/tile/%s/image_id_split.csv'% (image_dir_val[i]) )
    df2['img_id'] = df2['tile_id'].apply(lambda x: x.split('/')[-2])
        
    kf = KFold(n_splits=args.n_fold, random_state=args.seed, shuffle=True)
    all_dice = []
    models = SegModel()
    for n_fold, (trn_idx, val_idx) in enumerate(kf.split(df)):
        train_df = df[df['fold']!= n_fold].reset_index(drop=True)
        val_df = df2[df2['fold']== n_fold].reset_index(drop=True).copy()
        
        # validation loader 3개 만들기 위함
        unique_value = val_df['tile_id'].apply(lambda x: x.split('/')[-2]).unique() #[valid_id1, valid_id2, valid_id3 ]
        val_img_id1 = unique_value[0] ; val_img_id2 = unique_value[1] ; val_img_id3= unique_value[2]
        val_df1= val_df[val_df['img_id']==val_img_id1].reset_index(drop=True)
        val_df2= val_df[val_df['img_id']==val_img_id2].reset_index(drop=True)
        val_df3= val_df[val_df['img_id']==val_img_id3].reset_index(drop=True)
        #####################################################
        train_dataset = HuDataset(
            df = train_df,
            augment = train_augment
        )
        train_loader  = DataLoader(
            train_dataset,
            sampler = RandomSampler(train_dataset),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 8,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        # val loader1
        valid_dataset1 = HuDataset(
            df = val_df1
            ,
        )
        valid_loader1 = DataLoader(
            valid_dataset1,
            sampler = SequentialSampler(valid_dataset1),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        # val loader2
        valid_dataset2 = HuDataset(
            df = val_df2
            ,
        )
        
        valid_loader2 = DataLoader(
            valid_dataset2,
            sampler = SequentialSampler(valid_dataset2),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        # val loader3
        valid_dataset3 = HuDataset(
            df = val_df3
            ,
        )
        valid_loader3 = DataLoader(
            valid_dataset3,
            sampler = SequentialSampler(valid_dataset3),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        log.write('fold = %s\n'%str(n_fold))
        log.write('train_dataset : \n%s\n'%(train_dataset))
        log.write('valid_dataset1 : \n%s\n'%(valid_dataset1))
        log.write('valid_dataset2 : \n%s\n'%(valid_dataset2))
        log.write('valid_dataset3 : \n%s\n'%(valid_dataset3))
        log.write('\n')

        # ------------------------
        #  Model
        # ------------------------
        log.write('** net setting **\n')

        scaler = GradScaler()
        models = SegModel()
        
        net = models[n_fold]
        net = net.to(device)
        if args.multi_gpu:
            log.write('multi gpu')
            net = nn.DataParallel(net)
        
        
        # ------------------------
        #  Optimizer
        # ------------------------
        if args.opt =='adamw':
            optimizer = torch.optim.AdamW(net.parameters(), lr = args.start_lr)

        elif args.opt =='radam_look':
            optimizer = Lookahead(RAdam(filter(lambda p: p.requires_grad, net.parameters()),lr=args.start_lr), alpha=0.5, k=5)
        if optimizer == None:
            assert False, 'no have optimizer'
        
        # ------------------------
        #  scheduler
        # ------------------------
        scheduler = get_scheduler(optimizer)


        log.write('optimizer\n  %s\n'%(optimizer))
        #log.write('schduler\n  %s\n'%(schduler))
        log.write('\n')

        ## start training here! ##############################################
        #array([0.57142857, 0.42857143])
        log.write('** start training here! **\n')
        log.write('   is_mixed_precision = %s \n'%str(args.amp))
        log.write('   batch_size = %d \n'%(args.batch_size))
        log.write('             |-------------- VALID---------|---- TRAIN/BATCH ----------------\n')
        log.write('rate  epoch  | dice   loss   tp     tn     | loss           | time           \n')
        log.write('-------------------------------------------------------------------------------------\n')
                  #0.00100   0.50  0.80 | 0.891  0.020  0.000  0.000  | 0.000  0.000   |  0 hr 02 min

        def message(mode='print'):
            if mode==('print'):
                asterisk = ' '
                loss = batch_loss
            if mode==('log'):
                asterisk = '*'
                loss = train_loss

            text = \
                '%0.5f  %s%s    | '%(rate, epoch, asterisk,) +\
                '%4.3f  %4.3f  %4.3f  %4.3f  | '%(*valid_loss,) +\
                '%4.3f  %4.3f   | '%(*loss,) +\
                '%s' % (time_to_str(timer() - start_timer,'min'))

            return text

        #----
        valid_loss = np.zeros(4,np.float32)
        train_loss = np.zeros(2,np.float32)
        batch_loss = np.zeros_like(train_loss)
        sum_train_loss = np.zeros_like(train_loss)
        sum_train = 0
        loss = torch.FloatTensor([0]).sum()


        start_timer = timer()
        rate = 0
        best_dice = 0
        for epoch in range(1, args.epochs+1):
            #print('\r',end='',flush=True)
            #log.write(message(mode='log')+'\n')
            # training
            for t, batch in enumerate(train_loader):

                # learning rate schduler -------------
                #adjust_learning_rate(optimizer, schduler(iteration))
                rate = get_learning_rate(optimizer)

                # one iteration update  -------------
                batch_size = len(batch['index'])
                net.train()

                if args.amp:
                    #image = image.half()
                    with autocast():
                        mask  = batch['mask'].to(device)
                        image = batch['image'].to(device)

                        optimizer.zero_grad()
                        #logit = data_parallel(net, image)
                        if args.clf_head:
                            logit, logit2 = net(image) # seg logit, clf logit
                        else:
                            logit = net(image)
                        if args.loss == 'bce':
                            if args.label_smoothing:
                                loss = LabelSmoothing()(logit, mask)
                            else:
                                loss = criterion_binary_cross_entropy(logit, mask)
                            if args.clf_head:
                                loss += args.clf_alpha *nn.BCEWithLogitsLoss()(logit2, (mask.sum(dim=(2,3))>0).float() )
                        elif args.loss =='lovasz':
                            #loss = LovaszHingeLoss()(logit, mask)
                            loss = symmetric_lovasz(logit, mask)
                            
                        elif args.loss == 'bce_dice':
                            loss = DiceBCELoss()(logit, mask)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else :
                    mask  = batch['mask'].to(device)
                    image = batch['image'].to(device)

                    optimizer.zero_grad()
                    #logit = data_parallel(net, image)
                    logit = net(image)
                    if args.loss == 'bce':
                        loss = criterion_binary_cross_entropy(logit, mask)
                    elif args.loss =='lovasz':
                        loss = symmetric_lovasz(logit, mask)

                    loss.backward()
                    optimizer.step()


                # print statistics  --------

                batch_loss = np.array([ loss.item(), 0 ])
                sum_train_loss += batch_loss
                sum_train += 1

                #print('\r',end='',flush=True)
                #print(message(mode='print'), end='',flush=True)
            

            # train loss
            train_loss = sum_train_loss/(sum_train+1e-12)
            sum_train_loss[...] = 0
            sum_train = 0
            print("do valid...")
            # scheudler
            valid_loss1 = do_valid(net, valid_loader1) #
            valid_loss2 = do_valid(net, valid_loader2)
            valid_loss3 = do_valid(net, valid_loader3)
            valid_loss = (np.array(valid_loss1) + np.array(valid_loss2) + np.array(valid_loss3))/3
            
            log.write(message(mode='log')+'\n')
            log.write(f'{val_img_id1} dice : {valid_loss1[0]:.5f}, {val_img_id2} dice : {valid_loss2[0]:.5f}, {val_img_id3} dice : {valid_loss3[0]:.5f}\n')
            
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(valid_loss[0])
            else:
                scheduler.step()
            
            # saved models
            #if valid_loss[0] > best_dice:
            if valid_loss[0] > best_dice:
                best_dice = valid_loss[0]
                log.write(f'\n saved best models, dice:{best_dice:.5f}\n')
                torch.save({
                    'state_dict': net.state_dict(),
                    'epoch': epoch,
                }, out_dir + f'/checkpoint/{n_fold}fold_{epoch}epoch_{best_dice:.4f}_{args.encoders[n_fold]}_{args.decoders[n_fold]}model.pth')
            
            log.write('='*80+'\n')

        log.write('\n')
        
        all_dice.append(best_dice)
    
    print(f'all dice score : {sum(all_dice)/len(all_dice) : .4f}')


In [22]:
if __name__ == '__main__':
    # set seed
    print('no set seed') if args.seed ==-1 else set_seeds(seed=args.seed)
    run_train(args)

__module__       : __main__
amp              : True
gpu              : 4,5
encoder          : b4
decoder          : unet
diff_arch        : True
encoders         : ['efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4']
decoders         : ['unet', 'unet', 'unet', 'unet', 'unet']
batch_size       : 8
weight_decay     : 1e-06
epochs           : 50
n_fold           : 5
fold             : 0
all_fold_train   : True
image_size       : 1024
crop_size        : 1024
tile_size        : 1280
tile_step        : 640
tile_scale       : 0.5
dataset          : 0.5_1280_640_train_fold
val_dataset      : 0.5_1280_1280_val_fold
dir              : 50_['efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4']_['unet', 'unet', 'unet', 'unet', 'unet']_1024_1280_640_0.5
T_max            : 10
opt              : radam_look
scheduler        : CosineAnnealingLR
loss             : bce
factor           : 0.4
patience         : 3
eps       

do valid...
0.00079  24*    | 0.934  0.009  0.930  0.998  | 0.010  0.000   |  2 hr 18 min
b9a3865fc dice : 0.94494, 0486052bb dice : 0.95140, afa5e8098 dice : 0.90527
do valid...
0.00065  25*    | 0.938  0.009  0.957  0.998  | 0.010  0.000   |  2 hr 24 min
b9a3865fc dice : 0.94124, 0486052bb dice : 0.95206, afa5e8098 dice : 0.92105
do valid...
0.00050  26*    | 0.938  0.009  0.950  0.998  | 0.009  0.000   |  2 hr 30 min
b9a3865fc dice : 0.94405, 0486052bb dice : 0.94860, afa5e8098 dice : 0.92215
do valid...
0.00035  27*    | 0.940  0.009  0.954  0.998  | 0.009  0.000   |  2 hr 35 min
b9a3865fc dice : 0.94250, 0486052bb dice : 0.95330, afa5e8098 dice : 0.92383
do valid...
0.00021  28*    | 0.938  0.009  0.948  0.998  | 0.009  0.000   |  2 hr 41 min
b9a3865fc dice : 0.94466, 0486052bb dice : 0.95250, afa5e8098 dice : 0.91564
do valid...
0.00010  29*    | 0.937  0.009  0.950  0.998  | 0.008  0.000   |  2 hr 47 min
b9a3865fc dice : 0.94477, 0486052bb dice : 0.95248, afa5e8098 dice : 0.9150

do valid...
0.00079  4*    | 0.927  0.013  0.936  0.997  | 0.013  0.000   |  0 hr 23 min
aaa6a05cc dice : 0.90628, cb2d976f4 dice : 0.94061, 4ef6695ce dice : 0.93385

 saved best models, dice:0.92692
do valid...
0.00065  5*    | 0.926  0.013  0.908  0.998  | 0.012  0.000   |  0 hr 28 min
aaa6a05cc dice : 0.89676, cb2d976f4 dice : 0.94311, 4ef6695ce dice : 0.93719
do valid...
0.00050  6*    | 0.933  0.012  0.944  0.997  | 0.011  0.000   |  0 hr 34 min
aaa6a05cc dice : 0.91819, cb2d976f4 dice : 0.94453, 4ef6695ce dice : 0.93612

 saved best models, dice:0.93294
do valid...
0.00035  7*    | 0.930  0.012  0.918  0.998  | 0.011  0.000   |  0 hr 40 min
aaa6a05cc dice : 0.91179, cb2d976f4 dice : 0.93920, 4ef6695ce dice : 0.93872
do valid...
0.00021  8*    | 0.934  0.011  0.929  0.998  | 0.010  0.000   |  0 hr 46 min
aaa6a05cc dice : 0.91889, cb2d976f4 dice : 0.94401, 4ef6695ce dice : 0.93988

 saved best models, dice:0.93426
do valid...
0.00010  9*    | 0.934  0.011  0.935  0.998  | 0.009  0.

0.00050  36*    | 0.919  0.016  0.888  0.999  | 0.008  0.000   |  3 hr 27 min
aaa6a05cc dice : 0.88036, cb2d976f4 dice : 0.93748, 4ef6695ce dice : 0.94021
do valid...
0.00065  37*    | 0.933  0.012  0.922  0.998  | 0.008  0.000   |  3 hr 33 min
aaa6a05cc dice : 0.91427, cb2d976f4 dice : 0.94055, 4ef6695ce dice : 0.94286
do valid...
0.00079  38*    | 0.927  0.013  0.912  0.998  | 0.009  0.000   |  3 hr 38 min
aaa6a05cc dice : 0.91728, cb2d976f4 dice : 0.93308, 4ef6695ce dice : 0.93134
do valid...
0.00090  39*    | 0.932  0.012  0.919  0.998  | 0.009  0.000   |  3 hr 44 min
aaa6a05cc dice : 0.92102, cb2d976f4 dice : 0.94047, 4ef6695ce dice : 0.93326
do valid...
0.00098  40*    | 0.925  0.014  0.901  0.998  | 0.010  0.000   |  3 hr 50 min
aaa6a05cc dice : 0.90784, cb2d976f4 dice : 0.94127, 4ef6695ce dice : 0.92635
do valid...
0.00100  41*    | 0.930  0.012  0.913  0.998  | 0.009  0.000   |  3 hr 56 min
aaa6a05cc dice : 0.91682, cb2d976f4 dice : 0.93649, 4ef6695ce dice : 0.93602
do valid..

do valid...
0.00050  16*    | 0.937  0.016  0.935  0.997  | 0.010  0.000   |  1 hr 37 min
e79de561c dice : 0.94037, 095bf7a1f dice : 0.93311, 1e2425f28 dice : 0.93884
do valid...
0.00065  17*    | 0.932  0.016  0.933  0.997  | 0.011  0.000   |  1 hr 43 min
e79de561c dice : 0.92435, 095bf7a1f dice : 0.93321, 1e2425f28 dice : 0.93923
do valid...
0.00079  18*    | 0.940  0.014  0.940  0.997  | 0.011  0.000   |  1 hr 49 min
e79de561c dice : 0.94266, 095bf7a1f dice : 0.93825, 1e2425f28 dice : 0.93943

 saved best models, dice:0.94011
do valid...
0.00090  19*    | 0.939  0.014  0.945  0.997  | 0.010  0.000   |  1 hr 55 min
e79de561c dice : 0.94006, 095bf7a1f dice : 0.93811, 1e2425f28 dice : 0.93972
do valid...
0.00098  20*    | 0.924  0.018  0.887  0.998  | 0.011  0.000   |  2 hr 01 min
e79de561c dice : 0.91015, 095bf7a1f dice : 0.92581, 1e2425f28 dice : 0.93527
do valid...
0.00100  21*    | 0.937  0.015  0.925  0.998  | 0.010  0.000   |  2 hr 07 min
e79de561c dice : 0.93799, 095bf7a1f dice 

do valid...
0.00010  49*    | 0.938  0.016  0.937  0.997  | 0.006  0.000   |  4 hr 58 min
e79de561c dice : 0.93496, 095bf7a1f dice : 0.93577, 1e2425f28 dice : 0.94256
do valid...
0.00003  50*    | 0.937  0.017  0.940  0.997  | 0.006  0.000   |  5 hr 04 min
e79de561c dice : 0.93432, 095bf7a1f dice : 0.93606, 1e2425f28 dice : 0.94146

fold = 3
train_dataset : 
	len  = 4733

valid_dataset1 : 
	len  = 48

valid_dataset2 : 
	len  = 118

valid_dataset3 : 
	len  = 26


** net setting **
multi gpuoptimizer
  Lookahead (
Parameter Group 0
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.001
    lr: 0.001
    step_counter: 0
    weight_decay: 0
)

** start training here! **
   is_mixed_precision = True 
   batch_size = 8 
             |-------------- VALID---------|---- TRAIN/BATCH ----------------
rate  epoch  | dice   loss   tp     tn     | loss           | time           
-------------------------------------------------------------------------------------
do valid...
0.00100  1*    |

2f6ecfcdf dice : 0.94991, 8242609fa dice : 0.95152, b2dc8411c dice : 0.94396
do valid...
0.00010  29*    | 0.952  0.007  0.929  0.999  | 0.008  0.000   |  3 hr 00 min
2f6ecfcdf dice : 0.95255, 8242609fa dice : 0.95330, b2dc8411c dice : 0.94950
do valid...
0.00003  30*    | 0.953  0.007  0.933  0.999  | 0.008  0.000   |  3 hr 06 min
2f6ecfcdf dice : 0.95334, 8242609fa dice : 0.95456, b2dc8411c dice : 0.95146
do valid...
0.00000  31*    | 0.952  0.007  0.931  0.999  | 0.008  0.000   |  3 hr 12 min
2f6ecfcdf dice : 0.95260, 8242609fa dice : 0.95387, b2dc8411c dice : 0.95070
do valid...
0.00003  32*    | 0.952  0.007  0.932  0.999  | 0.008  0.000   |  3 hr 18 min
2f6ecfcdf dice : 0.95292, 8242609fa dice : 0.95381, b2dc8411c dice : 0.95046
do valid...
0.00010  33*    | 0.952  0.007  0.933  0.999  | 0.008  0.000   |  3 hr 24 min
2f6ecfcdf dice : 0.95274, 8242609fa dice : 0.95358, b2dc8411c dice : 0.94859
do valid...
0.00021  34*    | 0.953  0.007  0.936  0.999  | 0.008  0.000   |  3 hr 31 mi

54f2eec69 dice : 0.93234, 26dc41664 dice : 0.94972, c68fe75ea dice : 0.88684
do valid...
0.00010  9*    | 0.924  0.011  0.915  0.998  | 0.010  0.000   |  0 hr 52 min
54f2eec69 dice : 0.93080, 26dc41664 dice : 0.94989, c68fe75ea dice : 0.89269
do valid...
0.00003  10*    | 0.925  0.011  0.918  0.998  | 0.010  0.000   |  0 hr 58 min
54f2eec69 dice : 0.93272, 26dc41664 dice : 0.95101, c68fe75ea dice : 0.89027
do valid...
0.00000  11*    | 0.924  0.011  0.918  0.998  | 0.010  0.000   |  1 hr 04 min
54f2eec69 dice : 0.93240, 26dc41664 dice : 0.95058, c68fe75ea dice : 0.89029
do valid...
0.00003  12*    | 0.925  0.011  0.921  0.998  | 0.009  0.000   |  1 hr 10 min
54f2eec69 dice : 0.93196, 26dc41664 dice : 0.95156, c68fe75ea dice : 0.89147
do valid...
0.00010  13*    | 0.921  0.012  0.912  0.998  | 0.009  0.000   |  1 hr 16 min
54f2eec69 dice : 0.92843, 26dc41664 dice : 0.95110, c68fe75ea dice : 0.88422
do valid...
0.00021  14*    | 0.918  0.013  0.904  0.998  | 0.010  0.000   |  1 hr 22 min

do valid...
0.00098  42*    | 0.919  0.014  0.899  0.998  | 0.009  0.000   |  4 hr 23 min
54f2eec69 dice : 0.93116, 26dc41664 dice : 0.94257, c68fe75ea dice : 0.88236
do valid...
0.00090  43*    | 0.920  0.012  0.899  0.998  | 0.009  0.000   |  4 hr 29 min
54f2eec69 dice : 0.92749, 26dc41664 dice : 0.94794, c68fe75ea dice : 0.88520
do valid...
0.00079  44*    | 0.903  0.015  0.877  0.998  | 0.008  0.000   |  4 hr 35 min
54f2eec69 dice : 0.89196, 26dc41664 dice : 0.93660, c68fe75ea dice : 0.88027
do valid...
0.00065  45*    | 0.926  0.012  0.919  0.998  | 0.008  0.000   |  4 hr 41 min
54f2eec69 dice : 0.93327, 26dc41664 dice : 0.94801, c68fe75ea dice : 0.89577

 saved best models, dice:0.92568
do valid...
0.00050  46*    | 0.927  0.013  0.920  0.998  | 0.008  0.000   |  4 hr 47 min
54f2eec69 dice : 0.93551, 26dc41664 dice : 0.95059, c68fe75ea dice : 0.89567

 saved best models, dice:0.92726
do valid...
0.00035  47*    | 0.928  0.013  0.926  0.998  | 0.008  0.000   |  4 hr 53 min
54f2eec

In [None]:
# #0.9333, deterministic하게 결정시킨것임.

# b9a3865fc dice : 0.94369, 0486052bb dice : 0.95022, afa5e8098 dice : 0.90132
# saved best models, dice:0.93174

# aaa6a05cc dice : 0.90918, cb2d976f4 dice : 0.94353, 4ef6695ce dice : 0.93610
# saved best models, dice:0.92960

# e79de561c dice : 0.93229, 095bf7a1f dice : 0.93469, 1e2425f28 dice : 0.93323
# saved best models, dice:0.93340

# 2f6ecfcdf dice : 0.95216, 8242609fa dice : 0.95276, b2dc8411c dice : 0.94840
# saved best models, dice:0.95111
 
# 54f2eec69 dice : 0.92800, 26dc41664 dice : 0.94386, c68fe75ea dice : 0.88966
# saved best models, dice:0.92050


# Train with albumentation augment version

albumentation으로 augmentation 실험할 경우 사용

In [45]:
# ------- image fold train version, my augment ------- #

# augmentation
def train_augment():

    return A.Compose([
            A.OneOf([
                A.RandomCrop(args.image_size,args.crop_size),
                A.RandomResizedCrop(args.image_size,args.crop_size)
             ], p=1),
            A.OneOf([
                #A.RandomContrast(),
                #A.RandomBrightness(),
                A.RandomGamma(),
                A.RandomBrightnessContrast()
                ], p=0.5),
            A.OneOf([
                A.CLAHE(clip_limit=2),
                A.HueSaturationValue(10,15,10),
                A.ChannelShuffle(),
                A.InvertImg()
                ], p=1),
            A.OneOf([
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                A.ShiftScaleRotate()
            ], p = 0.5 ),
        
        #A.Resize(512, 512),
        ToTensor()
    ],p=1.)
def val_augment():

    return A.Compose([
        ToTensor()
    ],p=1.)

def do_valid(net, valid_loader):

    valid_num = 0
    total = 0 ; dice=0 ; loss=0 ; tp = 0 ; tn = 0
    dice2=0 ; loss2=0
    valid_probability, valid_probability2, valid_probability3 = [],[],[]
    valid_mask, valid_mask2, valid_mask3 = [],[],[]

    net = net.eval()

    #start_timer = timer()
    with torch.no_grad():
        for t, (image,mask) in enumerate(valid_loader):
            #mask  = batch['mask']
            image = image.to(device)
            
            if args.clf_head:
                logit, _ = net(image) # seg, clf
            else:
                logit = net(image)#data_parallel(net, image) #net(input)#
            probability = torch.sigmoid(logit)
                
            valid_probability.append(probability.data.cpu().numpy())
            valid_mask.append(mask.data.cpu().numpy())

    #assert(valid_num == len(valid_loader.dataset)) # drop last True이면 assert되는거임
    probability = np.concatenate(valid_probability)
    mask = np.concatenate(valid_mask)
    if args.loss =='bce':
        loss = np_binary_cross_entropy_loss(probability, mask)
    elif args.loss =='lovasz':
        loss = 0
    
    # mean loss, dice ..
    dice = np_dice_score(probability, mask)
    tp, tn = np_accuracy(probability, mask)

    return [dice, loss,  tp, tn]

def run_train(args):
    out_dir = data_dir + f'/result/{args.dir}_{args.encoder}_{args.image_size}'

    ## setup  ----------------------------------------
    for f in ['checkpoint','train','valid'] : os.makedirs(out_dir +'/'+f, exist_ok=True)
    #backup_project_as_zip(PROJECT_PATH, out_dir +'/backup/code.train.%s.zip'%IDENTIFIER)
    log = Logger()
    log.open(out_dir+'/log.train.txt',mode='a')

    # my log argument
    print_args(args, log)

    log.write('\tout_dir  = %s\n' % out_dir)
    log.write('\n')


    log.write('** dataset setting **\n')
    #-----------dataset split --------------------#
    tile_id = []
    image_dir_ = f'{args.dataset}'#'0.25_320_160_train'
    image_dir=[image_dir_, ] # pseudo할때 뒤에 추가
    
    image_dir_val_ = f'{args.val_dataset}'#'0.25_320_320_val'
    image_dir_val=[image_dir_val_, ]
    
    for i in range(len(image_dir)):
        df = pd.read_csv(data_dir + '/tile/%s/image_id_split.csv'% (image_dir[i]) )

    for i in range(len(image_dir_val)):
        df2 = pd.read_csv(data_dir + '/tile/%s/image_id_split.csv'% (image_dir_val[i]) )
    df2['img_id'] = df2['tile_id'].apply(lambda x: x.split('/')[-2])
        
    kf = KFold(n_splits=args.n_fold, random_state=args.seed, shuffle=True)
    all_dice = []
    for n_fold, (trn_idx, val_idx) in enumerate(kf.split(df)):
        if not args.all_fold_train:
            if n_fold != args.fold:
                print(f'{n_fold} fold pass')
                continue
        if n_fold in [3, 4]:
            print(n_fold,'fold pass')
            continue
        train_df = df[df['fold']!= n_fold].reset_index(drop=True)
        val_df = df2[df2['fold']== n_fold].reset_index(drop=True).copy()
        
        # validation loader 3개 만들기 위함
        unique_value = val_df['tile_id'].apply(lambda x: x.split('/')[-2]).unique() #[valid_id1, valid_id2, valid_id3 ]
        val_img_id1 = unique_value[0] ; val_img_id2 = unique_value[1] ; val_img_id3= unique_value[2]
        val_df1= val_df[val_df['img_id']==val_img_id1].reset_index(drop=True)
        val_df2= val_df[val_df['img_id']==val_img_id2].reset_index(drop=True)
        val_df3= val_df[val_df['img_id']==val_img_id3].reset_index(drop=True)
        #####################################################
        train_dataset = HuDataset(
            df = train_df,
            augment = train_augment()
        )
        train_loader  = DataLoader(
            train_dataset,
            sampler = RandomSampler(train_dataset),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 8,
            pin_memory  = True,
        )
        # val loader1
        valid_dataset1 = HuDataset(
            df = val_df1,
            augment = val_augment()
        )
        valid_loader1 = DataLoader(
            valid_dataset1,
            sampler = SequentialSampler(valid_dataset1),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
        )
        # val loader2
        valid_dataset2 = HuDataset(
            df = val_df2,
            augment = val_augment()
        )
        
        valid_loader2 = DataLoader(
            valid_dataset2,
            sampler = SequentialSampler(valid_dataset2),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
        )
        # val loader3
        valid_dataset3 = HuDataset(
            df = val_df3,
            augment = val_augment()
        )
        valid_loader3 = DataLoader(
            valid_dataset3,
            sampler = SequentialSampler(valid_dataset3),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
        )
        log.write('fold = %s\n'%str(n_fold))
        log.write('train_dataset : \n%s\n'%(train_dataset))
        log.write('valid_dataset1 : \n%s\n'%(valid_dataset1))
        log.write('valid_dataset2 : \n%s\n'%(valid_dataset2))
        log.write('valid_dataset3 : \n%s\n'%(valid_dataset3))
        log.write('\n')

        # ------------------------
        #  Model
        # ------------------------
        log.write('** net setting **\n')

        scaler = GradScaler()
        net = SegModel() 
        net = net.to(device)
        
        # ------------------------
        #  Optimizer
        # ------------------------
        if args.opt =='adamw':
            optimizer = torch.optim.AdamW(net.parameters(), lr = args.start_lr)

        elif args.opt =='radam_look':
            optimizer = Lookahead(RAdam(filter(lambda p: p.requires_grad, net.parameters()),lr=args.start_lr), alpha=0.5, k=5)
        if optimizer == None:
            assert False, 'no have optimizer'
        
        # ------------------------
        #  scheduler
        # ------------------------
        scheduler = get_scheduler(optimizer)


        log.write('optimizer\n  %s\n'%(optimizer))
        #log.write('schduler\n  %s\n'%(schduler))
        log.write('\n')

        ## start training here! ##############################################
        #array([0.57142857, 0.42857143])
        log.write('** start training here! **\n')
        log.write('   is_mixed_precision = %s \n'%str(args.amp))
        log.write('   batch_size = %d \n'%(args.batch_size))
        log.write('             |-------------- VALID---------|---- TRAIN/BATCH ----------------\n')
        log.write('rate  epoch  | dice   loss   tp     tn     | loss           | time           \n')
        log.write('-------------------------------------------------------------------------------------\n')
                  #0.00100   0.50  0.80 | 0.891  0.020  0.000  0.000  | 0.000  0.000   |  0 hr 02 min

        def message(mode='print'):
            if mode==('print'):
                asterisk = ' '
                loss = batch_loss
            if mode==('log'):
                asterisk = '*'
                loss = train_loss

            text = \
                '%0.5f  %s%s    | '%(rate, epoch, asterisk,) +\
                '%4.3f  %4.3f  %4.3f  %4.3f  | '%(*valid_loss,) +\
                '%4.3f  %4.3f   | '%(*loss,) +\
                '%s' % (time_to_str(timer() - start_timer,'min'))

            return text

        #----
        valid_loss = np.zeros(4,np.float32)
        train_loss = np.zeros(2,np.float32)
        batch_loss = np.zeros_like(train_loss)
        sum_train_loss = np.zeros_like(train_loss)
        sum_train = 0
        loss = torch.FloatTensor([0]).sum()


        start_timer = timer()
        rate = 0
        best_dice = 0
        for epoch in range(1, args.epochs+1):
            #print('\r',end='',flush=True)
            #log.write(message(mode='log')+'\n')
            # training
            for t, (image, mask) in enumerate(train_loader):
                # learning rate schduler -------------
                #adjust_learning_rate(optimizer, schduler(iteration))
                rate = get_learning_rate(optimizer)

                # one iteration update  -------------
                #batch_size = len(batch['index'])
                net.train()

                if args.amp:
                    #image = image.half()
                    with autocast():
                        mask  = mask.to(device)
                        image = image.to(device)

                        optimizer.zero_grad()
                        #logit = data_parallel(net, image)
                        if args.clf_head:
                            logit, logit2 = net(image) # seg logit, clf logit
                        else:
                            logit = net(image)
                        if args.loss == 'bce':
                            if args.label_smoothing:
                                loss = LabelSmoothing()(logit, mask)
                            else:
                                loss = criterion_binary_cross_entropy(logit, mask)
                            if args.clf_head:
                                loss += args.clf_alpha *nn.BCEWithLogitsLoss()(logit2, (mask.sum(dim=(2,3))>0).float() )
                        elif args.loss =='lovasz':
                            #loss = LovaszHingeLoss()(logit, mask)
                            loss = symmetric_lovasz(logit, mask)
                            
                        elif args.loss == 'bce_dice':
                            loss = DiceBCELoss()(logit, mask)
                        elif args.loss == 'dice':
                            loss = DiceLoss()(logit, mask)
                            

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else :
                    mask  = batch['mask'].to(device)
                    image = batch['image'].to(device)

                    optimizer.zero_grad()
                    #logit = data_parallel(net, image)
                    logit = net(image)
                    if args.loss == 'bce':
                        loss = criterion_binary_cross_entropy(logit, mask)
                    elif args.loss =='lovasz':
                        loss = symmetric_lovasz(logit, mask)

                    loss.backward()
                    optimizer.step()


                # print statistics  --------

                batch_loss = np.array([ loss.item(), 0 ])
                sum_train_loss += batch_loss
                sum_train += 1

                #print('\r',end='',flush=True)
                #print(message(mode='print'), end='',flush=True)
            

            # train loss
            train_loss = sum_train_loss/(sum_train+1e-12)
            sum_train_loss[...] = 0
            sum_train = 0

            # scheudler
            valid_loss1 = do_valid(net, valid_loader1) #
            valid_loss2 = do_valid(net, valid_loader2)
            valid_loss3 = do_valid(net, valid_loader3)
            valid_loss = (np.array(valid_loss1) + np.array(valid_loss2) + np.array(valid_loss3))/3
            
            log.write(message(mode='log')+'\n')
            log.write(f'{val_img_id1} dice : {valid_loss1[0]:.5f}, {val_img_id2} dice : {valid_loss2[0]:.5f}, {val_img_id3} dice : {valid_loss3[0]:.5f}\n')
            
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(valid_loss[0])
            else:
                scheduler.step()

            # saved models
            #if valid_loss[0] > best_dice:
            if valid_loss[0] > best_dice:
                best_dice = valid_loss[0]
                log.write(f'\n saved best models, dice:{best_dice:.5f}\n')
                torch.save({
                    'state_dict': net.state_dict(),
                    'epoch': epoch,
                }, out_dir + f'/checkpoint/{n_fold}fold_{epoch}epoch_{best_dice:.4f}_model.pth')
            
            log.write('='*80+'\n')

        log.write('\n')
        
        all_dice.append(best_dice)
    
    print(f'all dice score : {sum(all_dice)/len(all_dice) : .4f}')
            



In [14]:
if __name__ == '__main__':
    # set seed
    print('no set seed') if args.seed ==-1 else set_seeds(seed=args.seed)
    run_train(args)

__module__       : __main__
amp              : True
gpu              : 0, 1, 2, 3
encoder          : b4
decoder          : unet
diff_arch        : True
encoders         : ['efficientnet-b4', 'efficientnet-b4', 'resnet34', 'xception', 'dpn68']
decoders         : ['unet', 'fpn', 'upp', 'unet', 'upp']
batch_size       : 64
weight_decay     : 1e-06
epochs           : 1
n_fold           : 5
fold             : 0
all_fold_train   : True
image_size       : 512
crop_size        : 512
tile_size        : 640
tile_step        : 320
tile_scale       : 0.5
dataset          : 0.5_640_320_train_fold
val_dataset      : 0.5_640_640_val_fold
dir              : 1_['efficientnet-b4', 'efficientnet-b4', 'resnet34', 'xception', 'dpn68']_['unet', 'fpn', 'upp', 'unet', 'upp']_512_640_320_0.5
T_max            : 10
opt              : radam_look
scheduler        : CosineAnnealingLR
loss             : bce
factor           : 0.4
patience         : 3
eps              : 1e-06
decay_epoch      : [4, 8, 12]
T_0        

AttributeError: 'list' object has no attribute 'to'

In [14]:
# ------- tile shuffle fold train version(use X) ----------#
"""
# augmentation
def train_augment(record):
    image = record['image']
    mask  = record['mask']
    
    for fn in np.random.choice([
        lambda image, mask : do_random_rotate_crop(image, mask, size=args.image_size, mag=45),
        lambda image, mask : do_random_scale_crop(image, mask, size=args.image_size, mag=0.075),
        lambda image, mask : do_random_crop(image, mask, size=args.image_size),
    ],1): image, mask = fn(image, mask)

    #if (np.random.choice(10,1)<7)[0]:
    for fn in np.random.choice([
        lambda image, mask : (image, mask),
        lambda image, mask : do_random_contast(image, mask, mag=0.8),
        lambda image, mask : do_random_gain(image, mask, mag=0.9),
        #lambda image, mask : do_random_hsv(image, mask, mag=[0.1, 0.2, 0]),
        lambda image, mask : do_random_noise(image, mask, mag=0.1),
    ],2): image, mask =  fn(image, mask)
    #if (np.random.choice(10,1)<7)[0]:
    image, mask = do_random_hsv(image, mask, mag=[0.1, 0.2, 0])
    image, mask = do_random_flip_transpose(image, mask)

    record['mask'] = mask
    record['image'] = image
    return record

# validation
def do_valid2(net, valid_loader):

    valid_num = 0
    total = 0 ; dice=0 ; loss=0 ; tp = 0 ; tn = 0
    dice2=0 ; loss2=0
    valid_probability, valid_probability2 = [],[]
    valid_mask = []

    net = net.eval()

    #start_timer = timer()
    with torch.no_grad():
        for t, batch in enumerate(valid_loader):
            batch_size = len(batch['index'])
            mask  = batch['mask']
            image = batch['image'].to(device)

            logit = net(image)#data_parallel(net, image) #net(input)#
            probability = torch.sigmoid(logit)
            
            # loss
            if args.loss =='bce':
                #loss += criterion_binary_cross_entropy(probability.data.cpu(), mask.data.cpu()).item()
                loss += np_binary_cross_entropy_loss(probability.cpu().numpy(), mask.cpu().numpy())
            elif args.loss =='lovasz':
                loss += symmetric_lovasz(probability.data.cpu(), mask.data.cpu()).item()
                
            # dice
            dice += dice_score(probability.data.cpu(), mask.data.cpu(), threshold = 0.5).item()
            tp_, tn_ = torch_accuracy(probability.data.cpu(), mask.data.cpu(), threshold = 0.5)
            tp+=tp_.item() ; tn += tn_.item()
            # numpy
            #dice2 += np_dice_score(probability.data.cpu().numpy(), mask.data.cpu().numpy())
            #loss2 += np_binary_cross_entropy_loss(probability.cpu().numpy(), mask.cpu().numpy())

            #valid_num += batch_size

    #assert(valid_num == len(valid_loader.dataset)) # drop last True이면 assert되는거임
    
    # mean loss, dice ..
    loss = loss/len(valid_loader)
    dice = dice/len(valid_loader)
    tp = tp/len(valid_loader) ; tn = tn/len(valid_loader)

    return [dice, loss,  tp, tn]
# append 버전
def do_valid(net, valid_loader):

    valid_num = 0
    total = 0 ; dice=0 ; loss=0 ; tp = 0 ; tn = 0
    dice2=0 ; loss2=0
    valid_probability, valid_probability2 = [],[]
    valid_mask = []

    net = net.eval()

    #start_timer = timer()
    with torch.no_grad():
        for t, batch in enumerate(valid_loader):
            batch_size = len(batch['index'])
            mask  = batch['mask']
            image = batch['image'].to(device)

            logit = net(image)#data_parallel(net, image) #net(input)#
            probability = torch.sigmoid(logit)
            
            valid_probability.append(probability.data.cpu().numpy())
            valid_mask.append(mask.data.cpu().numpy())


    
    probability = np.concatenate(valid_probability)
    mask = np.concatenate(valid_mask)
    if args.loss =='bce':
        loss = np_binary_cross_entropy_loss(probability, mask)
    elif args.loss =='lovasz':
        loss = symmetric_lovasz(probability, mask)
    
    # mean loss, dice ..
    dice = np_dice_score(probability, mask)
    tp, tn = np_accuracy(probability, mask)
    return [dice, loss,  tp, tn]

def run_train(args):
    out_dir = data_dir + f'/result/{args.dir}_fold{args.fold}_{args.encoder}_{args.image_size}'

    ## setup  ----------------------------------------
    for f in ['checkpoint','train','valid','backup'] : os.makedirs(out_dir +'/'+f, exist_ok=True)
    #backup_project_as_zip(PROJECT_PATH, out_dir +'/backup/code.train.%s.zip'%IDENTIFIER)
    log = Logger()
    log.open(out_dir+'/log.train.txt',mode='a')

    # my log argument
    print_args(args, log)

    log.write('\tout_dir  = %s\n' % out_dir)
    log.write('\n')


    log.write('** dataset setting **\n')
    #-----------dataset split --------------------#
    tile_id = []
    image_dir_ = f'{args.dataset}'#'0.25_480_240_train'
    image_dir=[image_dir_, ] # pseudo할때 뒤에 추가

    for i in range(len(image_dir)):
        df = pd.read_csv(data_dir + '/tile/%s/image_id.csv'% (image_dir[i]) )
        tile_id += ('%s/'%(image_dir[i]) + df.tile_id).tolist()

    kf = KFold(n_splits=args.n_fold, random_state=args.seed, shuffle=True)
    all_dice = []
    for n_fold, (trn_idx, val_idx) in enumerate(kf.split(tile_id)):
        if not args.all_fold_train:
            if n_fold != args.fold:
                print(f'{n_fold} fold pass')
                continue

        #####################################################
        train_dataset = HuDataset(
            tile_id = df.loc[trn_idx]['tile_id'].tolist(),
            augment = train_augment
        )
        train_loader  = DataLoader(
            train_dataset,
            sampler = RandomSampler(train_dataset),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 8,
            pin_memory  = True,
            collate_fn  = null_collate
        )

        valid_dataset = HuDataset(
            tile_id = df.loc[val_idx]['tile_id'].tolist()
            ,
        )
        valid_loader = DataLoader(
            valid_dataset,
            sampler = SequentialSampler(valid_dataset),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        log.write('fold = %s\n'%str(n_fold))
        log.write('train_dataset : \n%s\n'%(train_dataset))
        log.write('valid_dataset : \n%s\n'%(valid_dataset))
        log.write('\n')

        # ------------------------
        #  Model
        # ------------------------
        log.write('** net setting **\n')

        scaler = GradScaler()
        net = SegModel() 
        net = net.to(device)
        
        # ------------------------
        #  Optimizer
        # ------------------------
        if args.opt =='adamw':
            optimizer = torch.optim.AdamW(net.parameters(), lr = args.start_lr)

        elif args.opt =='radam_look':
            optimizer = Lookahead(RAdam(filter(lambda p: p.requires_grad, net.parameters()),lr=args.start_lr), alpha=0.5, k=5)
        if optimizer == None:
            assert False, 'no have optimizer'
        
        # ------------------------
        #  scheduler
        # ------------------------
        scheduler = get_scheduler(optimizer)


        log.write('optimizer\n  %s\n'%(optimizer))
        #log.write('schduler\n  %s\n'%(schduler))
        log.write('\n')

        ## start training here! ##############################################
        #array([0.57142857, 0.42857143])
        log.write('** start training here! **\n')
        log.write('   is_mixed_precision = %s \n'%str(args.amp))
        log.write('   batch_size = %d \n'%(args.batch_size))
        log.write('             |-------------- VALID---------|---- TRAIN/BATCH ----------------\n')
        log.write('rate  epoch  | dice   loss   tp     tn     | loss           | time           \n')
        log.write('-------------------------------------------------------------------------------------\n')
                  #0.00100   0.50  0.80 | 0.891  0.020  0.000  0.000  | 0.000  0.000   |  0 hr 02 min

        def message(mode='print'):
            if mode==('print'):
                asterisk = ' '
                loss = batch_loss
            if mode==('log'):
                asterisk = '*'
                loss = train_loss

            text = \
                '%0.5f  %s%s    | '%(rate, epoch, asterisk,) +\
                '%4.3f  %4.3f  %4.3f  %4.3f  | '%(*valid_loss,) +\
                '%4.3f  %4.3f   | '%(*loss,) +\
                '%s' % (time_to_str(timer() - start_timer,'min'))

            return text

        #----
        valid_loss = np.zeros(4,np.float32)
        train_loss = np.zeros(2,np.float32)
        batch_loss = np.zeros_like(train_loss)
        sum_train_loss = np.zeros_like(train_loss)
        sum_train = 0
        loss = torch.FloatTensor([0]).sum()


        start_timer = timer()
        rate = 0
        best_dice = 0
        for epoch in range(1, args.epochs+1):
            #print('\r',end='',flush=True)
            #log.write(message(mode='log')+'\n')
            # training
            for t, batch in enumerate(train_loader):

                # learning rate schduler -------------
                #adjust_learning_rate(optimizer, schduler(iteration))
                rate = get_learning_rate(optimizer)

                # one iteration update  -------------
                batch_size = len(batch['index'])
                net.train()

                if args.amp:
                    #image = image.half()
                    with autocast():
                        mask  = batch['mask'].to(device)
                        image = batch['image'].to(device)

                        optimizer.zero_grad()
                        #logit = data_parallel(net, image)
                        logit = net(image)
                        if args.loss == 'bce':
                            loss = criterion_binary_cross_entropy(logit, mask)
                        elif args.loss =='lovasz':
                            loss = symmetric_lovasz(logit, mask)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else :
                    mask  = batch['mask'].to(device)
                    image = batch['image'].to(device)

                    optimizer.zero_grad()
                    #logit = data_parallel(net, image)
                    logit = net(image)
                    if args.loss == 'bce':
                        loss = criterion_binary_cross_entropy(logit, mask)
                    elif args.loss =='lovasz':
                        loss = symmetric_lovasz(logit, mask)

                    loss.backward()
                    optimizer.step()


                # print statistics  --------

                batch_loss = np.array([ loss.item(), 0 ])
                sum_train_loss += batch_loss
                sum_train += 1

                #print('\r',end='',flush=True)
                #print(message(mode='print'), end='',flush=True)

            # train loss
            train_loss = sum_train_loss/(sum_train+1e-12)
            sum_train_loss[...] = 0
            sum_train = 0

            # scheudler
            valid_loss = do_valid(net, valid_loader) #
            log.write(message(mode='log')+'\n')
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(valid_loss[0])
            else:
                scheduler.step()

            # saved models
            #if valid_loss[0] > best_dice:
            if valid_loss[0] > best_dice:
                best_dice = valid_loss[0]
                print(f'\n saved best models, dice:{best_dice:.5f}')
                torch.save({
                    'state_dict': net.state_dict(),
                    'epoch': epoch,
                }, out_dir + f'/checkpoint/{n_fold}fold_{epoch}epoch_{best_dice:.4f}_model.pth')

        log.write('\n')
        all_dice.append(best_dice)
    
    print(f'all dice score : {sum(all_dice)/len(all_dice) : .4f}')
  
     """   
;

''

# validation 예측

eval mode : 모델들 불러와서 validation에 해당하는 이미지 예측후 cv측정, threshold별 dice 계산

gen_image : validation에 해당하는 이미지 예측후 visualize(저장된 이미지로 확인가능)

In [20]:
class args:
    # ---- factor ---- #
    server ='local' # ['kaggle', 'local'] local은 cv측정용도
    amp = False
    gpu = 4
    
    encoder='b4'#'resnet34'
    decoder='unet'
    n_fold = 5
    diff_arch = True
    encoders = ["efficientnet-b4", "efficientnet-b4", "efficientnet-b4", "efficientnet-b4", "efficientnet-b4"]
    decoders = ["unet", "unet", "unet", "unet", "unet"]
    batch_size=16
    #fold=0
    mode = 'eval' # ['eval', 'gen_image']
    loss = 'bce'
    clf_head=False
    dataset = '0.5_1280_640_train_fold'#'[0.25_256_128_train', '0.25_480_240_train' ]# dataset size
    val_dataset = '0.5_1280_1280_val_fold'
    
    model_path = ["./data/result/50_['efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4']_['unet', 'unet', 'unet', 'unet', 'unet']_1024_1280_640_0.5_b4_1024" +
                  "/checkpoint/" + x for x in \
                 ['0fold_49epoch_0.9421_efficientnet-b4_unetmodel.pth','1fold_47epoch_0.9374_efficientnet-b4_unetmodel.pth',
                 '2fold_45epoch_0.9418_efficientnet-b4_unetmodel.pth','3fold_10epoch_0.9536_efficientnet-b4_unetmodel.pth',
                 '4fold_50epoch_0.9283_efficientnet-b4_unetmodel.pth']]
    
    sub = '[visualize][04.05]_0.9337_models'# submission name
    
    # ---- Dataset ---- #
    
    tile_size = 1280
    tile_average_step = 640
    tile_scale = 0.5
    tile_min_score = 0.25  

#assert args.server!='local', 'not implement'
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")

In [21]:
all_dice_dict={}

def do_valid(net, valid_loader):

    valid_num = 0
    total = 0 ; dice=0 ; loss=0 ; tp = 0 ; tn = 0
    dice2=0 ; loss2=0
    valid_probability, valid_probability2, valid_probability3 = [],[],[]
    valid_mask, valid_mask2, valid_mask3 = [],[],[]

    net = net.eval()

    #start_timer = timer()
    with torch.no_grad():
        for t, batch in enumerate(valid_loader):
            mask  = batch['mask']
            image = batch['image'].to(device)
            
            if args.clf_head:
                logit, _ = net(image) # seg, clf
            else:
                logit = net(image)#data_parallel(net, image) #net(input)#
            probability = torch.sigmoid(logit)
                
            valid_probability.append(probability.data.cpu().numpy())
            valid_mask.append(mask.data.cpu().numpy())

    #assert(valid_num == len(valid_loader.dataset)) # drop last True이면 assert되는거임
    probability = np.concatenate(valid_probability)
    mask = np.concatenate(valid_mask)
    if args.loss =='bce':
        loss = np_binary_cross_entropy_loss(probability, mask)
    elif args.loss =='lovasz':
        loss = 0
    
    dice = [np_dice_score2(probability, mask, round(th, 2)) for th in np.arange(0.1, 0.7, 0.05)]
    #tp, tn = np_accuracy(probability, mask)

    return np.array(dice)#[dice_dict, loss,  tp, tn]


def gen_val_image(args):
    out_dir = args.model_path[0].split('checkpoint')[0]

    ## setup  ----------------------------------------
    for f in ['checkpoint','train','valid','backup'] : os.makedirs(out_dir +'/'+f, exist_ok=True)

    log = Logger()
    log.open(out_dir+'/log.val.txt',mode='a')

    # my log argument
    print_args(args, log)

    submit_dir = out_dir + '/valid/%s-mean'%(args.server)
    os.makedirs(submit_dir,exist_ok=True)

    #
    for fold in range(5):
        scaler = GradScaler()
        net = SegModel() 
        net = net.to(device)
        state_dict = torch.load(args.model_path[fold], map_location=lambda storage, loc: storage)['state_dict']
        net.load_state_dict(state_dict,strict=True)  #True
        net = net.eval()

        #log.write('schduler\n  %s\n'%(schduler))
        log.write('\n')

        #----      

        # make validation predict images
        tile_size = args.tile_size #320
        tile_average_step = args.tile_average_step#320 #192
        tile_scale = args.tile_scale
        tile_min_score = args.tile_min_score
        #
        a = pd.read_csv('../hubmap/tile/0.25_320_160_train_fold/image_id_split.csv')
        b = a[a['fold']==fold]
        valid_image_id = b['tile_id'].apply(lambda x : x.split('/')[-2]).unique()

        #
        start_timer = timer()
        for id in valid_image_id:
            image_file = data_dir + '/train/%s.tiff' % id
            image = read_tiff(image_file)
            height, width = image.shape[:2]

            json_file  = data_dir + '/train/%s-anatomical-structure.json' % id
            structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)   
            mask_file = data_dir + '/train/%s.mask.png' % id
            mask  = read_mask(mask_file)

            #--- predict here!  ---
            tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

            tile_image = tile['tile_image']
            tile_image = np.stack(tile_image)[..., ::-1]
            tile_image = np.ascontiguousarray(tile_image.transpose(0,3,1,2))
            tile_image = tile_image.astype(np.float32)/255
            print(tile_image.shape)
            tile_probability = []

            batch = np.array_split(tile_image, len(tile_image)//4)
            for t,m in enumerate(batch):
                print('\r %s  %d / %d   %s'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')), end='',flush=True)
                m = torch.from_numpy(m).to(device)

                p = []
                with torch.no_grad():
                    logit = net(m)
                    p.append(torch.sigmoid(logit))
                    if args.server == 'local':
                        if 0: #tta here
                            #logit = data_parallel(net, m.flip(dims=(2,)))
                            logit = net(m.flip(dims=(2,)))
                            p.append(torch.sigmoid(logit.flip(dims=(2,))))

                            #logit = data_parallel(net, m.flip(dims=(3,)))
                            logit = net(m.flip(dims=(3,)))
                            p.append(torch.sigmoid(logit.flip(dims=(3,))))
                        p = torch.cat(p)

                tile_probability.append(p.data.cpu().numpy())
            print('\r' , end='',flush=True)
            log.write('%s  %d / %d   %s\n'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')))

            tile_probability = np.concatenate(tile_probability).squeeze(1)
            height, width = tile['image_small'].shape[:2]
            probability = to_mask(tile_probability, tile['coord'], height, width,
                                  tile_scale, tile_size, tile_average_step, tile_min_score,
                                  aggregate='mean')
            #
            truth = tile['mask_small'].astype(np.float32)/255
            overlay = np.dstack([
                np.zeros_like(truth),
                probability, #green
                truth, #red
            ])

            image_small = tile['image_small'].astype(np.float32)/255
            #predict = (probability>thres).astype(np.float32)
            overlay1 = 1-(1-image_small)*(1-overlay)
            overlay2 = image_small.copy()
            overlay2 = draw_contour_overlay(overlay2, tile['structure_small'], color=(1, 1, 1), thickness=3)
            overlay2 = draw_contour_overlay(overlay2, truth, color=(0, 0, 1), thickness=8)
            overlay2 = draw_contour_overlay(overlay2, probability, color=(0, 1, 0), thickness=3)

            if 1:
                #cv2.imwrite(submit_dir+'/%s.image_small.png'%id, (image_small*255).astype(np.uint8))
                #cv2.imwrite(submit_dir+'/%s.probability.png'%id, (probability*255).astype(np.uint8))
                #cv2.imwrite(submit_dir+'/%s.predict.png'%id, (predict*255).astype(np.uint8))
                #cv2.imwrite(submit_dir+'/%s.overlay.png'%id, (overlay*255).astype(np.uint8))
                #cv2.imwrite(submit_dir+'/%s.overlay1.png'%id, (overlay1*255).astype(np.uint8))
                cv2.imwrite(submit_dir+'/%s.overlay2.png'%id, (overlay2*255).astype(np.uint8))
def eval_image(args):

    #-----------dataset split --------------------#
    tile_id = []
    image_dir_ = f'{args.dataset}'#'0.25_320_160_train'
    image_dir=[image_dir_, ] # pseudo할때 뒤에 추가
    
    image_dir_val_ = f'{args.val_dataset}'#'0.25_320_320_val'
    image_dir_val=[image_dir_val_, ]
    
    for i in range(len(image_dir)):
        df = pd.read_csv(data_dir + '/tile/%s/image_id_split.csv'% (image_dir[i]) )

    for i in range(len(image_dir_val)):
        df2 = pd.read_csv(data_dir + '/tile/%s/image_id_split.csv'% (image_dir_val[i]) )
    df2['img_id'] = df2['tile_id'].apply(lambda x: x.split('/')[-2])
        
    all_dice = []
    for n_fold in range(5):

        train_df = df[df['fold']!= n_fold].reset_index(drop=True)
        val_df = df2[df2['fold']== n_fold].reset_index(drop=True).copy()
        
        # validation loader 3개 만들기 위함
        unique_value = val_df['tile_id'].apply(lambda x: x.split('/')[-2]).unique() #[valid_id1, valid_id2, valid_id3 ]
        val_img_id1 = unique_value[0] ; val_img_id2 = unique_value[1] ; val_img_id3= unique_value[2]
        val_df1= val_df[val_df['img_id']==val_img_id1].reset_index(drop=True)
        val_df2= val_df[val_df['img_id']==val_img_id2].reset_index(drop=True)
        val_df3= val_df[val_df['img_id']==val_img_id3].reset_index(drop=True)
        #####################################################
        # val loader1
        valid_dataset1 = HuDataset(
            df = val_df1
            ,
        )
        valid_loader1 = DataLoader(
            valid_dataset1,
            sampler = SequentialSampler(valid_dataset1),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        # val loader2
        valid_dataset2 = HuDataset(
            df = val_df2
            ,
        )
        
        valid_loader2 = DataLoader(
            valid_dataset2,
            sampler = SequentialSampler(valid_dataset2),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        # val loader3
        valid_dataset3 = HuDataset(
            df = val_df3
            ,
        )
        valid_loader3 = DataLoader(
            valid_dataset3,
            sampler = SequentialSampler(valid_dataset3),
            batch_size  = args.batch_size,
            drop_last   = False,
            num_workers = 4,
            pin_memory  = True,
            collate_fn  = null_collate
        )
        # ------------------------
        #  Model
        # ------------------------

        scaler = GradScaler()
        models = SegModel() 
        net = models[n_fold].to(device)
        state_dict = torch.load(args.model_path[n_fold], map_location=lambda storage, loc: storage)['state_dict']
        # 병렬처리를 했으면 앞에 module이 붙으므로 키를 바꿔줘야 한다. 
        for key in list(state_dict.keys()):
            if "module." in key:
                state_dict[key.replace("module.", "")] = state_dict[key]
                del state_dict[key]
        net.load_state_dict(state_dict,strict=True)  #True
        net = net.eval()
        
        print("model load success!!!")
        # scheudler
        valid_loss1 = do_valid(net, valid_loader1) #
        valid_loss2 = do_valid(net, valid_loader2)
        valid_loss3 = do_valid(net, valid_loader3)
        valid_loss = (valid_loss1 + valid_loss2 + valid_loss3)/3
        
        all_dice.append(valid_loss)
        
    dice = sum(all_dice)/len(all_dice)
    for n, th in enumerate(np.arange(0.1, 0.7, 0.05)):
        th = round(th, 2)
        print(f'th:{th}, dice score : {dice[n] : .4f}')

In [22]:
"""red is real"""
if 1: #normal
    if __name__ == '__main__':
        if args.mode == 'eval':
            eval_image(args)
        elif args.mode =='gen_image':
            gen_val_image(args)

model load success!!!
model load success!!!
model load success!!!
model load success!!!
model load success!!!
th:0.1, dice score :  0.9144
th:0.15, dice score :  0.9242
th:0.2, dice score :  0.9307
th:0.25, dice score :  0.9351
th:0.3, dice score :  0.9381
th:0.35, dice score :  0.9399
th:0.4, dice score :  0.9409
th:0.45, dice score :  0.9411
th:0.5, dice score :  0.9406
th:0.55, dice score :  0.9396
th:0.6, dice score :  0.9380
th:0.65, dice score :  0.9356


# submission

In [26]:
class args:
    # ---- factor ---- #
    server ='kaggle' # ['kaggle', 'local'] local은 cv측정용도
    amp = False
    gpu = 4
    
    encoder='b4'#'resnet34'
    decoder='unet'
    
    diff_arch = True
    encoders = ["efficientnet-b4", "efficientnet-b4", "efficientnet-b4", "efficientnet-b4", "efficientnet-b4"]
    decoders = ["unet", "unet", "unet", "unet", "unet"]
    n_fold = 5
    batch_size=8
    clf_head=False
    
    threshold = 0.45
    
    model_path = '../hubmap/result/'

    en_model_path = ["./data/result/50_['efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4', 'efficientnet-b4']_['unet', 'unet', 'unet', 'unet', 'unet']_1024_1280_640_0.5_b4_1024" +
                  "/checkpoint/" + x for x in \
                 ['0fold_49epoch_0.9421_efficientnet-b4_unetmodel.pth','1fold_47epoch_0.9374_efficientnet-b4_unetmodel.pth',
                 '2fold_45epoch_0.9418_efficientnet-b4_unetmodel.pth','3fold_10epoch_0.9536_efficientnet-b4_unetmodel.pth',
                 '4fold_50epoch_0.9283_efficientnet-b4_unetmodel.pth']]
    sub = '30epoch_imagefold_0.9338_320_160'# submission name
    
    # ---- Dataset ---- #
    
    tile_size = 1280
    tile_average_step = 640
    tile_scale = 0.5
    tile_min_score = 0.25  

assert args.server!='local', 'not implement'
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")

In [27]:
thres = args.threshold

prob = []

def mask_to_csv(image_id, submit_dir):

    predicted = []
    for id in image_id:
        image_file = data_dir + '/test/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        predict_file = submit_dir + '/%s.predict.png' % id
        # predict = cv2.imread(predict_file, cv2.IMREAD_GRAYSCALE)
        predict = np.array(Image.open(predict_file))
        predict = cv2.resize(predict, dsize=(width, height), interpolation=cv2.INTER_LINEAR)
        predict = (predict > 128).astype(np.uint8) * 255

        p = rle_encode(predict)
        predicted.append(p)

    df = pd.DataFrame()
    df['id'] = image_id
    df['predicted'] = predicted
    return df

def run_submit(args):

    #fold = 6
    out_dir = args.model_path.split('checkpoint')[0]
    initial_checkpoint = out_dir + '/checkpoint' + args.model_path.split('checkpoint')[1]
    
    # local은 cv측정 용도

    server = args.server#'kaggle' , 'local'

    #---
    submit_dir = out_dir + '/test/%s-%s-mean-thres(%s)'%(server, initial_checkpoint[-18:-4],thres)
    os.makedirs(submit_dir,exist_ok=True)

    log = Logger()
    log.open(out_dir+'/log.submit.txt',mode='a')

    #---
    if server == 'local':
        valid_image_id = make_image_id('valid-%d' % fold)
    if server == 'kaggle':
        valid_image_id = make_image_id('test-all')

    if server == 'local':
        tile_size = args.tile_size #320
        tile_average_step = args.tile_average_step#320 #192
        tile_scale = args.tile_scale
        tile_min_score = args.tile_min_score
    if server == 'kaggle' :
        tile_size = args.tile_size#640#640 #320
        tile_average_step = args.tile_average_step#320#320 #192
        tile_scale = args.tile_scale#0.25
        tile_min_score = args.tile_min_score#0.25   

    log.write('tile_size = %d \n'%tile_size)
    log.write('tile_average_step = %d \n'%tile_average_step)
    log.write('tile_scale = %f \n'%tile_scale)
    log.write('tile_min_score = %f \n'%tile_min_score)
    log.write('\n')

    
    # ----- model -------
    net = SegModel() 
    net.to(device)
    state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)['state_dict']
    net.load_state_dict(state_dict,strict=True)  #True
    net = net.eval()
    
    start_timer = timer()
    for id in valid_image_id:
        if server == 'local':
            image_file = data_dir + '/train/%s.tiff' % id
            image = read_tiff(image_file)
            height, width = image.shape[:2]

            json_file  = data_dir + '/train/%s-anatomical-structure.json' % id
            structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)   
            mask_file = data_dir + '/train/%s.mask.png' % id
            mask  = read_mask(mask_file)

        if server == 'kaggle':
            image_file = data_dir + '/test/%s.tiff' % id
            json_file  = data_dir + '/test/%s-anatomical-structure.json' % id

            image = read_tiff(image_file)
            height, width = image.shape[:2]
            structure = draw_strcuture(read_json_as_df(json_file), height, width, structure=['Cortex'])

            mask = None


        #--- predict here!  ---
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        tile_image = tile['tile_image']
        tile_image = np.stack(tile_image)[..., ::-1]
        tile_image = np.ascontiguousarray(tile_image.transpose(0,3,1,2))
        tile_image = tile_image.astype(np.float32)/255
        print(tile_image.shape)
        tile_probability = []
        
        batch = np.array_split(tile_image, len(tile_image)//4)
        for t,m in enumerate(batch):
            print('\r %s  %d / %d   %s'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')), end='',flush=True)
            m = torch.from_numpy(m).to(device)

            p = []
            with torch.no_grad():
                logit = net(m)
                p.append(torch.sigmoid(logit))

                #---
                if server == 'kaggle':
                    if 1: #tta here
                        logit = net(m.flip(dims=(2,)))
                        p.append(torch.sigmoid(logit.flip(dims=(2,))))

                        logit = net(m.flip(dims=(3,)))
                        p.append(torch.sigmoid(logit.flip(dims=(3,))))
                    p = torch.stack(p).mean(0)
                if server == 'local':
                    if 0: #tta here
                        #logit = data_parallel(net, m.flip(dims=(2,)))
                        logit = net(m.flip(dims=(2,)))
                        p.append(torch.sigmoid(logit.flip(dims=(2,))))

                        #logit = data_parallel(net, m.flip(dims=(3,)))
                        logit = net(m.flip(dims=(3,)))
                        p.append(torch.sigmoid(logit.flip(dims=(3,))))
                    p = torch.cat(p)
                    #p = torch.stack(p)

            tile_probability.append(p.data.cpu().numpy())

        print('\r' , end='',flush=True)
        log.write('%s  %d / %d   %s\n'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')))

        tile_probability = np.concatenate(tile_probability).squeeze(1)
        height, width = tile['image_small'].shape[:2]
        probability = to_mask(tile_probability, tile['coord'], height, width,
                              tile_scale, tile_size, tile_average_step, tile_min_score,
                              aggregate='mean')
        

        #--- show results ---
        if server == 'local':
            truth = tile['mask_small'].astype(np.float32)/255
            truth2 = np.concatenate(tile['tile_mask']).astype(np.float32)/255
        if server == 'kaggle':
            truth = np.zeros((height, width), np.float32)

        overlay = np.dstack([
            np.zeros_like(truth),
            probability, #green
            truth, #red
        ])
        image_small = tile['image_small'].astype(np.float32)/255
        predict = (probability>thres).astype(np.float32)
        overlay1 = 1-(1-image_small)*(1-overlay)
        overlay2 = image_small.copy()
        overlay2 = draw_contour_overlay(overlay2, tile['structure_small'], color=(1, 1, 1), thickness=3)
        overlay2 = draw_contour_overlay(overlay2, truth, color=(0, 0, 1), thickness=8)
        overlay2 = draw_contour_overlay(overlay2, probability, color=(0, 1, 0), thickness=3)

        if 1:
            cv2.imwrite(submit_dir+'/%s.image_small.png'%id, (image_small*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.probability.png'%id, (probability*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.predict.png'%id, (predict*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay.png'%id, (overlay*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay1.png'%id, (overlay1*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay2.png'%id, (overlay2*255).astype(np.uint8))

        #---

        if server == 'local':

            loss = np_binary_cross_entropy_loss(probability, truth)
            dice = np_dice_score(probability, truth) # 여기는 큰이미지로 바꾼상태에서 dice
            dice2 = np_dice_score(tile_probability, truth2) # 작은이미지상태, 즉 training과 같은 cv구할려고 dice
            tp, tn = np_accuracy(probability, truth)
            log.write('submit_dir = %s \n'%submit_dir)
            log.write('initial_checkpoint = %s \n'%initial_checkpoint)
            log.write('loss   = %0.8f \n'%loss)
            log.write('dice   = %0.8f \n'%dice)
            log.write('dice2   = %0.8f \n'%dice2)
            log.write('tp, tn = %0.8f, %0.8f \n'%(tp, tn))
            log.write('\n')
            #cv2.waitKey(0)

    #-----
    if server == 'kaggle':
        csv_file = submit_dir + args.sub+'.csv'
        df = mask_to_csv(valid_image_id, submit_dir)
        df.to_csv(csv_file, index=False)
        print(df)

    zz=0
    
def run_submit_ensemble(args):

    #fold = 6
    out_dir = args.en_model_path[0].split('checkpoint')[0]
    
    
    # local은 cv측정 용도

    server = args.server#'kaggle' , 'local'

    #---
    submit_dir = out_dir + '/test/%s-%s-thres(%s)'%(server, args.sub,thres)
    os.makedirs(submit_dir,exist_ok=True)

    log = Logger()
    log.open(out_dir+'/log.submit.txt',mode='a')

    #---
    if server == 'local':
        valid_image_id = make_image_id('valid-%d' % fold)
    if server == 'kaggle':
        valid_image_id = make_image_id('test-all')

    if server == 'local':
        tile_size = args.tile_size #320
        tile_average_step = args.tile_average_step#320 #192
        tile_scale = args.tile_scale
        tile_min_score = args.tile_min_score
    if server == 'kaggle' :
        tile_size = args.tile_size#640#640 #320
        tile_average_step = args.tile_average_step#320#320 #192
        tile_scale = args.tile_scale#0.25
        tile_min_score = args.tile_min_score#0.25   

    log.write('tile_size = %d \n'%tile_size)
    log.write('tile_average_step = %d \n'%tile_average_step)
    log.write('tile_scale = %f \n'%tile_scale)
    log.write('tile_min_score = %f \n'%tile_min_score)
    log.write('\n')

    
    
    start_timer = timer()
    for id in valid_image_id:
        fold_prob = []
        models = SegModel()
        for i, m_p in enumerate(args.en_model_path):
            initial_checkpoint = m_p
            # ----- model -------
            net = models[i]
            net.to(device)
            state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)['state_dict']
            for key in list(state_dict.keys()):
                if "module." in key:
                    state_dict[key.replace("module.", "")] = state_dict[key]
                    del state_dict[key]
            net.load_state_dict(state_dict,strict=True)  #True
            net = net.eval()
            print("model load success!!!")
            if server == 'local':
                image_file = data_dir + '/train/%s.tiff' % id
                image = read_tiff(image_file)
                height, width = image.shape[:2]

                json_file  = data_dir + '/train/%s-anatomical-structure.json' % id
                structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)   
                mask_file = data_dir + '/train/%s.mask.png' % id
                mask  = read_mask(mask_file)

            if server == 'kaggle':
                image_file = data_dir + '/test/%s.tiff' % id
                json_file  = data_dir + '/test/%s-anatomical-structure.json' % id

                image = read_tiff(image_file)
                height, width = image.shape[:2]
                structure = draw_strcuture(read_json_as_df(json_file), height, width, structure=['Cortex'])

                mask = None


            #--- predict here!  ---
            tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

            tile_image = tile['tile_image']
            tile_image = np.stack(tile_image)[..., ::-1]
            tile_image = np.ascontiguousarray(tile_image.transpose(0,3,1,2))
            tile_image = tile_image.astype(np.float32)/255
            print(tile_image.shape)
            tile_probability = []

            batch = np.array_split(tile_image, len(tile_image)//4)
            for t,m in enumerate(batch):
                print('\r %s  %d / %d   %s'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')), end='',flush=True)
                m = torch.from_numpy(m).to(device)

                p = []
                with torch.no_grad():
                    logit = net(m)
                    p.append(torch.sigmoid(logit))

                    #---
                    if server == 'kaggle':
                        if 1: #tta here
                            logit = net(m.flip(dims=(2,)))
                            p.append(torch.sigmoid(logit.flip(dims=(2,))))

                            logit = net(m.flip(dims=(3,)))
                            p.append(torch.sigmoid(logit.flip(dims=(3,))))
                        p = torch.stack(p).mean(0)
                    if server == 'local':
                        if 0: #tta here
                            #logit = data_parallel(net, m.flip(dims=(2,)))
                            logit = net(m.flip(dims=(2,)))
                            p.append(torch.sigmoid(logit.flip(dims=(2,))))

                            #logit = data_parallel(net, m.flip(dims=(3,)))
                            logit = net(m.flip(dims=(3,)))
                            p.append(torch.sigmoid(logit.flip(dims=(3,))))
                        p = torch.cat(p)
                        #p = torch.stack(p)

                tile_probability.append(p.data.cpu().numpy())

            print('\r' , end='',flush=True)
            log.write('%s  %d / %d   %s\n'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')))

            tile_probability = np.concatenate(tile_probability).squeeze(1)
            height, width = tile['image_small'].shape[:2]
            probability = to_mask(tile_probability, tile['coord'], height, width,
                                  tile_scale, tile_size, tile_average_step, tile_min_score,
                                  aggregate='mean')

            fold_prob.append(probability)
        
        probability = sum(fold_prob)/len(args.en_model_path)
        #--- show results ---
        if server == 'local':
            truth = tile['mask_small'].astype(np.float32)/255
            truth2 = np.concatenate(tile['tile_mask']).astype(np.float32)/255
        if server == 'kaggle':
            truth = np.zeros((height, width), np.float32)

        overlay = np.dstack([
            np.zeros_like(truth),
            probability, #green
            truth, #red
        ])
        image_small = tile['image_small'].astype(np.float32)/255
        predict = (probability>thres).astype(np.float32)
        overlay1 = 1-(1-image_small)*(1-overlay)
        overlay2 = image_small.copy()
        overlay2 = draw_contour_overlay(overlay2, tile['structure_small'], color=(1, 1, 1), thickness=3)
        overlay2 = draw_contour_overlay(overlay2, truth, color=(0, 0, 1), thickness=8)
        overlay2 = draw_contour_overlay(overlay2, probability, color=(0, 1, 0), thickness=3)

        if 1:
            cv2.imwrite(submit_dir+'/%s.image_small.png'%id, (image_small*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.probability.png'%id, (probability*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.predict.png'%id, (predict*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay.png'%id, (overlay*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay1.png'%id, (overlay1*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay2.png'%id, (overlay2*255).astype(np.uint8))

        #---

        if server == 'local':

            loss = np_binary_cross_entropy_loss(probability, truth)
            dice = np_dice_score(probability, truth) # 여기는 큰이미지로 바꾼상태에서 dice
            dice2 = np_dice_score(tile_probability, truth2) # 작은이미지상태, 즉 training과 같은 cv구할려고 dice
            tp, tn = np_accuracy(probability, truth)
            log.write('submit_dir = %s \n'%submit_dir)
            log.write('initial_checkpoint = %s \n'%initial_checkpoint)
            log.write('loss   = %0.8f \n'%loss)
            log.write('dice   = %0.8f \n'%dice)
            log.write('dice2   = %0.8f \n'%dice2)
            log.write('tp, tn = %0.8f, %0.8f \n'%(tp, tn))
            log.write('\n')
            #cv2.waitKey(0)

    #-----
    if server == 'kaggle':
        csv_file = submit_dir +'.csv'
        df = mask_to_csv(valid_image_id, submit_dir)
        df.to_csv(csv_file, index=False)
        print(df)

    zz=0

In [28]:

if 0: #normal
    if __name__ == '__main__':
        run_submit(args)
elif 1:# ensemble
    if __name__ == '__main__':
        run_submit_ensemble(args)

tile_size = 1280 
tile_average_step = 640 
tile_scale = 0.500000 
tile_min_score = 0.250000 

model load success!!!
(364, 3, 1280, 1280)
2ec3f1bb9  90 / 91    1 min 38 secc
model load success!!!
(364, 3, 1280, 1280)
2ec3f1bb9  90 / 91    3 min 21 secc
model load success!!!
(364, 3, 1280, 1280)
2ec3f1bb9  90 / 91    5 min 05 secc
model load success!!!
(364, 3, 1280, 1280)
2ec3f1bb9  90 / 91    6 min 48 secc
model load success!!!
(364, 3, 1280, 1280)
2ec3f1bb9  90 / 91    8 min 32 secc
model load success!!!
(139, 3, 1280, 1280)
3589adb90  33 / 34   10 min 18 secc
model load success!!!
(139, 3, 1280, 1280)
3589adb90  33 / 34   10 min 59 secc
model load success!!!
(139, 3, 1280, 1280)
3589adb90  33 / 34   11 min 39 secc
model load success!!!
(139, 3, 1280, 1280)
3589adb90  33 / 34   12 min 20 secc
model load success!!!
(139, 3, 1280, 1280)
3589adb90  33 / 34   13 min 01 secc
model load success!!!
(240, 3, 1280, 1280)
57512b7f1  59 / 60   15 min 42 secc
model load success!!!
(240, 3, 1280, 