In [14]:
# ------------Library--------------#
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.sampler import *

import torch.optim as optim
from torch.optim.optimizer import Optimizer, required
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts, CosineAnnealingLR
from torch.nn.parallel.data_parallel import data_parallel
from torch.nn.utils.rnn import *
from torch.cuda.amp import autocast, GradScaler

import segmentation_models_pytorch as smp

from PIL import Image
Image.MAX_IMAGE_PIXELS = None

import tifffile as tiff
import json



#
import pandas as pd
import cv2
import os
import random
import numpy as np
import math
import sys
from collections import defaultdict
import itertools as it
from timeit import default_timer as timer
from glob import glob
from tqdm import tqdm
#
from sklearn.model_selection import KFold
#from lib.net.lookahead import *
#from lib.net.radam import *
# constant #
PI  = np.pi
INF = np.inf
EPS = 1e-12



In [15]:
class args:
    # ---- factor ---- #
    amp = True  # mixed precision 나중에는 false
    gpu = 3
    dir = "100epoch_nooverlap_640_25_50"  # 로그 저장 폴더
    encoder='b4'#'resnet34'
    decoder='unet'
    batch_size= 16  #잘 안건드림 128이상은 별로 
    weight_decay=1e-6  
    n_fold=5
    fold=0 # [0, 1, 2, 3, 4] 
    all_fold_train = True # all fold training 총5번 돔 
    
    # ---- Dataset ---- #
    image_size=512
    dataset = '0.25_640_320_train' # dataset size
    overlap = False

    # ---- optimizer, scheduler .. ---- #
    epochs=100   # 바꿔보기
    opt =  'radam_look' # [adamw, radam_look]
    scheduler='CosineAnnealingLR' #'MultiStepLR' # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts']
    loss = 'bce' # lovasz
    factor=0.2 # ReduceLROnPlateau, MultiStepLR
    patience=2 # ReduceLROnPlateau
    eps=1e-6 # ReduceLROnPlateau
    T_max=10 # CosineAnnealingLR
    decay_epoch = [4, 8, 12]
    T_0=4 # CosineAnnealingWarmRestarts
    #encoder_lr=4e-4
    #decoder_lr=4e-4
    start_lr = 1e-3
    min_lr=1e-6
    
    # ---- Else ---- #
    num_workers=8
    seed=42
    
data_dir = "/home/jeonghokim/competition/HubMap/data/"
##----------------


# useful function

In [3]:
#-------evaluation metric---------#
###################################
def np_binary_cross_entropy_loss(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    #---
    logp = -np.log(np.clip(p,1e-6,1))
    logn = -np.log(np.clip(1-p,1e-6,1))
    loss = t*logp +(1-t)*logn
    loss = loss.mean()
    return loss

def np_dice_score(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    p = p>0.5
    t = t>0.5
    uion = p.sum() + t.sum()
    overlap = (p*t).sum()
    dice = 2*overlap/(uion+0.001)
    return dice

def np_accuracy(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)
    p = p>0.5
    t = t>0.5
    tp = (p*t).sum()/(t).sum()
    tn = ((1-p)*(1-t)).sum()/(1-t).sum()
    return tp, tn

def criterion_binary_cross_entropy(logit, mask):
    logit = logit.reshape(-1)
    mask = mask.reshape(-1)

    loss = F.binary_cross_entropy_with_logits(logit, mask)
    return loss

# threshold dice score
def np_dice_score2(probability, mask, threshold):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    p = p>threshold
    t = t>0.5
    uion = p.sum() + t.sum()
    overlap = (p*t).sum()
    dice = 2*overlap/(uion+0.001)
    return dice

###################################
#-------ELSE function---------#
###################################
def set_seeds(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True # for faster training, but not deterministic
class Logger(object):
    def __init__(self):
        self.terminal = sys.stdout  #stdout
        self.file = None

    def open(self, file, mode=None):
        if mode is None: mode ='w'
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1 ):
        if '\r' in message: is_file=0

        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
            #time.sleep(1)

        if is_file == 1:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass
def print_args(args, logger=None):
    for k, v in vars(args).items():
        if logger is not None:
            logger.write('{:<16} : {}'.format(k, v))
        else:
            print('{:<16} : {}'.format(k, v))
def time_to_str(t, mode='min'):
    if mode=='min':
        t  = int(t)/60
        hr = t//60
        min = t%60
        return '%2d hr %02d min'%(hr,min)

    elif mode=='sec':
        t   = int(t)
        min = t//60
        sec = t%60
        return '%2d min %02d sec'%(min,sec)

    else:
        raise NotImplementedError
def get_learning_rate(optimizer):
    lr=[]
    for param_group in optimizer.param_groups:
        lr +=[ param_group['lr'] ]

    assert(len(lr)==1) #we support only one param_group
    lr = lr[0]

    return lr

In [4]:
#-------masking & tile & decode---------#
def read_tiff(image_file):
    """
    *data size*
    e.g.) (3, w, h) or (1,1,3,w,h) or (w, h, 3)  --> transform --> (w, h, 3)
    """
    image = tiff.imread(image_file)
    if image.shape[0] == 1:
        image = image[0][0]
        image = image.transpose(1, 2, 0)
        image = np.ascontiguousarray(image)
    elif image.shape[0] == 3:
        image = image.transpose(1, 2, 0)
        image = np.ascontiguousarray(image)
    return image

def read_mask(mask_file):
    mask = np.array(Image.open(mask_file))
    return mask

def read_json_as_df(json_file):
    with open(json_file) as f:
        j = json.load(f)
    df = pd.json_normalize(j)
    return df


def draw_strcuture(df, height, width, fill=255, structure=[]):
    mask = np.zeros((height, width), np.uint8)
    for row in df.values:
        type  = row[2]  #geometry.type
        coord = row[3]  # geometry.coordinates
        name  = row[4]   # properties.classification.name

        if structure !=[]:
            if not any(s in name for s in structure): continue


        if type=='Polygon':
            pt = np.array(coord).astype(np.int32)
            #cv2.polylines(mask, [coord.reshape((-1, 1, 2))], True, 255, 1)
            cv2.fillPoly(mask, [pt.reshape((-1, 1, 2))], fill)

        if type=='MultiPolygon':
            for pt in coord:
                pt = np.array(pt).astype(np.int32)
                cv2.fillPoly(mask, [pt.reshape((-1, 1, 2))], fill)

    return mask

# resize, cvtcolor, generate mask
# 원하는 object 영역만 따오는 mask
def draw_strcuture_from_hue(image, fill=255, scale=1/32): # 0.25/32 default
    height, width, _ = image.shape
    vv = cv2.resize(image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    vv = cv2.cvtColor(vv, cv2.COLOR_RGB2HSV)
    # image_show('v[0]', v[:,:,0])
    # image_show('v[1]', v[:,:,1])
    # image_show('v[2]', v[:,:,2])
    # cv2.waitKey(0)
    mask = (vv[:, :, 1] > 32).astype(np.uint8) # rgb2hsv를 하고나서 1채널에 대해 시행하면 원하는 object만 잘따온다.
    mask = mask*fill
    mask = cv2.resize(mask, dsize=(width, height), interpolation=cv2.INTER_LINEAR) # 다시 원래사이즈로 복구

    return mask

# --- rle ---------------------------------
def rle_decode(rle, height, width , fill=255):
    s = rle.split()
    start, length = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    start -= 1
    mask = np.zeros(height*width, dtype=np.uint8)
    for i, l in zip(start, length):
        mask[i:i+l] = fill
    mask = mask.reshape(width,height).T
    mask = np.ascontiguousarray(mask)
    return mask


def rle_encode(mask):
    m = mask.T.flatten()
    m = np.concatenate([[0], m, [0]])
    run = np.where(m[1:] != m[:-1])[0] + 1
    run[1::2] -= run[::2]
    rle =  ' '.join(str(r) for r in run)
    return rle


# --- tile ---------------------------------
"""
-결국, tile_image, tile_mask만 가져다가 쓴다.
1. scale로 resize를 하고 image size와 step만큼 건너뛰며 이미지를 만든다.
2. 이때 일정 영역이 빈마스크면 데이터에서 제외한다.
3. 쌓은 image와 mask를 return
"""
def to_tile(image, mask, structure, scale, size, step, min_score): 
    half = size//2
    image_small = cv2.resize(image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) # defualt는 1/4만큼 w,h를 줄인다.
    height, width, _ = image_small.shape

    #make score
    structure_small = cv2.resize(structure, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    vv = structure_small.astype(np.float32)/255

    #make coord
    xx = np.linspace(half, width  - half, int(np.ceil((width  - size) / step)))
    yy = np.linspace(half, height - half, int(np.ceil((height - size) / step)))
    xx = [int(x) for x in xx]
    yy = [int(y) for y in yy]

    coord  = []
    reject = []
    for cy in yy:
        for cx in xx:
            cv = vv[cy - half:cy + half, cx - half:cx + half].mean() # h, w // tiling한 마스크(structure)가 평균 0.25를 안넘으면 버린다.
            if cv>min_score: # min_score ,default:0.25, 0.25의 의미?, 타일링 이미지의 1/4는 object여야 한다는 의미?
                coord.append([cx,cy,cv])
            else:
                reject.append([cx,cy,cv])
    #-----
    if 1: # resize한 image를 tiling 하여 리스트만든다
        tile_image = []
        for cx,cy,cv in coord:
            t = image_small[cy - half:cy + half, cx - half:cx + half] # resize한 image에서 indexing만 하는과정
            assert (t.shape == (size, size, 3))
            tile_image.append(t)

    if mask is not None: # mask를 resize하고 tiling하여 리스트 만든다
        mask_small = cv2.resize(mask, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        tile_mask = []
        for cx,cy,cv in coord:
            t = mask_small[cy - half:cy + half, cx - half:cx + half]
            assert (t.shape == (size, size))
            tile_mask.append(t)
    else:
        mask_small = None
        tile_mask  = None

    return {
        'image_small': image_small,
        'mask_small' : mask_small,
        'structure_small' : structure_small,
        'tile_image' : tile_image,
        'tile_mask'  : tile_mask,
        'coord'  : coord,
        'reject' : reject,
    }

"""
submission할때 쓰임
"""
def to_mask(tile, coord, height, width, scale, size, step, min_score, aggregate='mean'):

    half = size//2
    mask  = np.zeros((height, width), np.float32)

    if 'mean' in aggregate:
        w = np.ones((size,size), np.float32)

        #if 'sq' in aggregate:
        if 1:
            #https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python
            y,x = np.mgrid[-half:half,-half:half]
            y = half-abs(y)
            x = half-abs(x)
            w = np.minimum(x,y)
            w = w/w.max()#*2.5
            w = np.minimum(w,1)

        #--------------
        count = np.zeros((height, width), np.float32)
        for t, (cx, cy, cv) in enumerate(coord):
            mask [cy - half:cy + half, cx - half:cx + half] += tile[t]*w
            count[cy - half:cy + half, cx - half:cx + half] += w
               # see unet paper for "Overlap-tile strategy for seamless segmentation of arbitrary large images"
        m = (count != 0)
        mask[m] /= count[m]

    if aggregate=='max':
        for t, (cx, cy, cv) in enumerate(coord):
            mask[cy - half:cy + half, cx - half:cx + half] = np.maximum(
                mask[cy - half:cy + half, cx - half:cx + half], tile[t] )

    return mask

# --------------이 아래로 안씀 ------------------------------#



# --draw ------------------------------------------
"""
경계선을 그리게 만든다, 컨투어
하지만 안씀
"""
def mask_to_inner_contour(mask):
    mask = mask>0.5
    pad = np.lib.pad(mask, ((1, 1), (1, 1)), 'reflect')
    contour = mask & (
            (pad[1:-1,1:-1] != pad[:-2,1:-1]) \
          | (pad[1:-1,1:-1] != pad[2:,1:-1])  \
          | (pad[1:-1,1:-1] != pad[1:-1,:-2]) \
          | (pad[1:-1,1:-1] != pad[1:-1,2:])
    )
    return contour


def draw_contour_overlay(image, mask, color=(0,0,255), thickness=1):
    contour =  mask_to_inner_contour(mask)
    if thickness==1:
        image[contour] = color
    else:
        r = max(1,thickness//2)
        for y,x in np.stack(np.where(contour)).T:
            cv2.circle(image, (x,y), r, color, lineType=cv2.LINE_4 )
    return image


# make dataset

In [5]:
# ------ make dataset --------- #
#################################

# <todo> make difference scale tile

tile_scale = 0.25
tile_min_score = 0.25  # 
tile_size = 480#320  # 480 #
tile_average_step = 240#160 #240  # 160 #192


#make tile train image
# train,tiling (image,mask) png 저장용도
def run_make_train_tile():
    #tile_scale = 0.25
    #tile_min_score = 0.25
    #tile_size = 480#480#320  #480 #
    #tile_average_step = 240#240 #160 #240  # 192


    train_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step}_train' #nipa2

    df_train = pd.read_csv(data_dir + '/train.csv')
    print(df_train)
    print(df_train.shape)
    
    df_all = []
    
    os.makedirs(train_tile_dir, exist_ok=True)
    for i in range(0,len(df_train)):
        id, encoding = df_train.iloc[i]
        # 1. image 불러오고
        image_file = data_dir + '/train/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        #mask = rle_decode(encoding, height, width, 255)
        # 2. mask, target 불러오고
        mask_file = data_dir + '/train/%s.mask.png' % id
        mask = read_mask(mask_file)
        
        # 3. 일정영역,object만 표시한 mask불러오기.
        structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)
        print(id, mask_file)
        
        # make tile
        # 4. 학습할 tile image, mask를 생성한다.
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(train_tile_dir+'/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            s = '%s_y%08d_x%08d' % (id, cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            tile_mask  = tile['tile_mask'][t]
            cv2.imwrite(train_tile_dir + '/%s.png' % (s), tile_image)
            cv2.imwrite(train_tile_dir + '/%s.mask.png' % (s), tile_mask)

            #image_show('tile_image', tile_image)
            #image_show('tile_mask', tile_mask)
            #cv2.waitKey(1)


        df_image['tile_id']=tile_id
        df_all.append(df_image)
    df_all = pd.concat(df_all, 0).reset_index(drop=True)
    df_all[['tile_id','cx','cy','cv']].to_csv(train_tile_dir+'/image_id.csv', index=False)
        #------



#make tile train image
# test tiling image png 저장용도
def run_make_test_tile():
    #tile_scale = 0.25
    #tile_min_score = 0.25
    #tile_size = 480#320  # 480 #
    #tile_average_step = 240#160 #240  # 160 #192

    #test_tile_dir = '/home/ubuntu/gwang/hubmap/etc/tile/0.25_640_320_test'
    test_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step}_test'
    #---


    os.makedirs(test_tile_dir, exist_ok=True)
    assert False, 'todo modify test file'
    for id in ['c68fe75ea','afa5e8098',]:
        print(id)

        # 1. test image load
        image_file = data_dir + '/test/%s.tiff' % id
        json_file  = data_dir + '/test/%s-anatomical-structure.json' % id

        image = read_tiff(image_file)
        height, width = image.shape[:2]

        mask = None
        # 2. test structure load
        structure = draw_strcuture(read_json_as_df(json_file), height, width, structure=['Cortex'])
        # structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)

        # 3. test를 위한 tile image 생성
        #make tile
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(test_tile_dir+'/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            s = 'y%08d_x%08d' % (cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            cv2.imwrite(test_tile_dir + '/%s/%s.png' % (id, s), tile_image)
            #image_show('tile_image', tile_image)
            #cv2.waitKey(1)


        df_image['tile_id']=tile_id
        df_image[['tile_id','cx','cy','cv']].to_csv(test_tile_dir+'/%s.csv'%id, index=False)
        #------


#make tile train image
# tile이 아닌 train image의 mask생성
def run_make_train_mask():

    df_train = pd.read_csv(data_dir + '/train.csv')
    print(df_train)
    print(df_train.shape)

    for i in range(0,len(df_train)):
        id, encoding = df_train.iloc[i]

        image_file = data_dir + '/train/%s.tiff' % id
        image = read_tiff(image_file)

        if image.shape[0]==1:
            image = image[0][0]
            image = image.transpose(1, 2, 0)
            image = np.ascontiguousarray(image)
            height, width = image.shape[:2]
        elif image.shape[0] == 3:
            image = image.transpose(1, 2, 0)
            image = np.ascontiguousarray(image)
            height, width = image.shape[:2]
        else:
            height, width = image.shape[:2]
        mask = rle_decode(encoding, height, width, 255)

        cv2.imwrite(data_dir + '/train/%s.mask.png' % id, mask)


#make tile train image
def run_make_pseudo_tile():

    
    tile_scale = 0.25
    tile_min_score = 0.25
    tile_size = 480#320  #480 #
    tile_average_step = 240 #160 #240  # 192
    #---
    pseudo_tile_dir = data_dir + f'/tile/{tile_scale}_{tile_size}_{tile_average_step}_pseudo_0.95'
    #df_train = pd.read_csv(data_dir + '/train.csv')
    #df_pseudo = pd.read_csv('/root/share1/kaggle/2020/hubmap/result/resnet34/fold2/submit-fold-2-resnet34-00010000_model_lb0.837.csv')
    df_pseudo = pd.read_csv('../../submission/0.891_submission-fold6-00004000_model_thres-0.9.csv')
    
    print(df_pseudo)
    print(df_pseudo.shape)

    os.makedirs(pseudo_tile_dir, exist_ok=True)
    for i in range(0,len(df_pseudo)):
        id, encoding = df_pseudo.iloc[i]

        image_file = data_dir + '/test/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        mask = rle_decode(encoding, height, width, 255)

        #make tile
        structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)

        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)
        #to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        #mask, scale, size, step, min_score

        coord = np.array(tile['coord'])
        df_image = pd.DataFrame()
        df_image['cx']=coord[:,0].astype(np.int32)
        df_image['cy']=coord[:,1].astype(np.int32)
        df_image['cv']=coord[:,2]

        # --- save ---
        os.makedirs(pseudo_tile_dir + '/%s'%id, exist_ok=True)

        tile_id =[]
        num = len(tile['tile_image'])
        for t in range(num):
            cx,cy,cv   = tile['coord'][t]
            s = 'y%08d_x%08d' % (cy, cx)
            tile_id.append(s)

            tile_image = tile['tile_image'][t]
            tile_mask  = tile['tile_mask'][t]
            cv2.imwrite(pseudo_tile_dir + '/%s/%s.png' % (id, s), tile_image)
            cv2.imwrite(pseudo_tile_dir + '/%s/%s.mask.png' % (id, s), tile_mask)

            #image_show('tile_image', tile_image)
            #image_show('tile_mask', tile_mask)
            #cv2.waitKey(1)


        df_image['tile_id']=tile_id
        df_image[['tile_id','cx','cy','cv']].to_csv(pseudo_tile_dir+'/%s.csv'%id, index=False)
        #------


# main #################################################################
if 0:
    if __name__ == '__main__':
        #print('started run make train mask')
        # 1.
        print('started 1')
        run_make_train_mask()
        # 2.
        print('started 2')
        run_make_train_tile()
        # 3.
        #print('started 3')
        #run_make_test_tile()
        # 4. if use pseudo datasets
        #run_make_pseudo_tile()
    
    

# Dataset & augmentation

In [6]:
#--------------- Dataset ----------------#
##########################################

#--------------- 
# Old version
#--------------- 
def make_image_id_v1(mode):
    train_image_id = {
        0 : '0486052bb', 1 : '095bf7a1f',
        2 : '1e2425f28', 3 : '26dc41664',
        4 : '2f6ecfcdf', 5 : '4ef6695ce',
        6 : '54f2eec69', 7 : '8242609fa',
        8 : 'aaa6a05cc', 9 : 'afa5e8098', 
        10 :'b2dc8411c', 11: 'b9a3865fc',
        12 :'c68fe75ea', 13: 'cb2d976f4',
        14 :'e79de561c'
    }

    test_image_id = {
        0 : '2ec3f1bb9', 1 : '3589adb90',
        2 : '57512b7f1', 3 : 'aa05346ff',
        4 : 'd488c759a',
    }
    if 'pseudo-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ]
        return test_id

    if 'test-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ] # list(test_image_id.values()) #
        return test_id

    if 'train-all'==mode:
        train_id = [ train_image_id[i] for i in [x for x in train_image_id] ] # list(test_image_id.values()) #
        return train_id

    if 'valid' in mode or 'train' in mode:
        fold = {int(x) for x in mode.split('-')[1].split(',')}
        #valid = [fold,]
        train = list({x for x in train_image_id}-fold)
        valid_id = [ train_image_id[i] for i in fold ]
        train_id = [ train_image_id[i] for i in train ]

        if 'valid' in mode: return valid_id
        if 'train' in mode: return train_id
class HuDataset_v1(Dataset):
    def __init__(self, image_id, image_dir, augment=None):
        self.augment = augment
        self.image_id = image_id
        self.image_dir = image_dir

        tile_id = []
        for i in range(len(image_dir)):
            for id in image_id[i]: 
                df = pd.read_csv(data_dir + '/tile/%s/%s.csv'% (self.image_dir[i],id) )
                tile_id += ('%s/%s/'%(self.image_dir[i],id) + df.tile_id).tolist()

        self.tile_id = tile_id
        self.len =len(self.tile_id)


    def __len__(self):
        return self.len

    def __str__(self):
        string  = ''
        string += '\tlen  = %d\n'%len(self)
        string += '\timage_dir = %s\n'%self.image_dir
        string += '\timage_id  = %s\n'%str(self.image_id)
        string += '\t          = %d\n'%sum(len(i) for i in self.image_id)
        return string


    def __getitem__(self, index):
        id = self.tile_id[index]
        image = cv2.imread(data_dir + '/tile/%s.png'%(id), cv2.IMREAD_COLOR)
        mask  = cv2.imread(data_dir + '/tile/%s.mask.png'%(id), cv2.IMREAD_GRAYSCALE)
        #print(data_dir + '/tile/%s/%s.png'%(self.image_dir,id))

        image = image.astype(np.float32) / 255
        mask  = mask.astype(np.float32) / 255
        r = {
            'index' : index,
            'tile_id' : id,
            'mask' : mask,
            'image' : image,
        }
        if self.augment is not None: r = self.augment(r)
        return r

#--------------- 
# New version
#--------------- 
def make_image_id(mode):
    train_image_id = {
        0 : '0486052bb', 1 : '095bf7a1f',
        2 : '1e2425f28', 3 : '26dc41664',
        4 : '2f6ecfcdf', 5 : '4ef6695ce',
        6 : '54f2eec69', 7 : '8242609fa',
        8 : 'aaa6a05cc', 9 : 'afa5e8098', 
        10 :'b2dc8411c', 11: 'b9a3865fc',
        12 :'c68fe75ea', 13: 'cb2d976f4',
        14 :'e79de561c'
    }

    test_image_id = {
        0 : '2ec3f1bb9', 1 : '3589adb90',
        2 : '57512b7f1', 3 : 'aa05346ff',
        4 : 'd488c759a',
    }
    if 'pseudo-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ]
        return test_id

    if 'test-all'==mode:
        test_id = [ test_image_id[i] for i in [0,1,2,3,4] ] # list(test_image_id.values()) #
        return test_id

    if 'train-all'==mode:
        train_id = [ train_image_id[i] for i in [x for x in train_image_id] ] # list(test_image_id.values()) #
        return train_id

    if 'valid' in mode or 'train' in mode:
        fold = {int(x) for x in mode.split('-')[1].split(',')}
        #valid = [fold,]
        train = list({x for x in train_image_id}-fold)
        valid_id = [ train_image_id[i] for i in fold ]
        train_id = [ train_image_id[i] for i in train ]

        if 'valid' in mode: return valid_id
        if 'train' in mode: return train_id
class HuDataset(Dataset):
    def __init__(self, tile_id, augment=None):
        self.augment = augment

        self.tile_id = tile_id
        self.len =len(self.tile_id)


    def __len__(self):
        return self.len

    def __str__(self):
        string  = ''
        string += '\tlen  = %d\n'%len(self)
        return string


    def __getitem__(self, index):
        id = self.tile_id[index]
        image = cv2.imread(f'{data_dir}/tile/{args.dataset}/{id}.png', cv2.IMREAD_COLOR)
        mask  = cv2.imread(f'{data_dir}/tile/{args.dataset}/{id}.mask.png', cv2.IMREAD_GRAYSCALE)
        #print(data_dir + '/tile/%s/%s.png'%(self.image_dir,id))
        
        image = image.astype(np.float32) / 255
        mask  = mask.astype(np.float32) / 255
        r = {
            'index' : index,
            'tile_id' : id,
            'mask' : mask,
            'image' : image,
        }
        if self.augment is not None: r = self.augment(r)
        return r
    
    
# no over lapped dataset
class HuDataset_nol(Dataset):
    def __init__(self, data_dir : str, tile_info : str, fold : int, train=True, augment=None):
        self.augment = augment
        train_names = sorted(["1e2425f28", "2f6ecfcdf", "4ef6695ce", "26dc41664", "54f2eec69",
                      "095bf7a1f", "0486052bb", "8242609fa", "aaa6a05cc", "afa5e8098",
                      "b2dc8411c", "b9a3865fc", "c68fe75ea", "cb2d976f4", "e79de561c"])
        valid_names = []
        for i in range(3):
            valid_names.append(train_names.pop((fold-1)*3))
        self.img_paths = []
        self.mask_paths = []
        if train:
            for name in train_names:
                self.img_paths += glob(os.path.join(data_dir, "tile", tile_info, "{}_*.png".format(name)))
                self.mask_paths += glob(os.path.join(data_dir, "tile", tile_info, "{}_*.mask.png".format(name)))
            self.mask_paths.sort()
            self.img_paths = sorted(list(set(self.img_paths) - set(self.mask_paths)))
        else:
            for name in valid_names:
                self.img_paths += glob(os.path.join(data_dir, "tile", tile_info, "{}_*.png".format(name)))
                self.mask_paths += glob(os.path.join(data_dir, "tile", tile_info, "{}_*.mask.png".format(name)))
            self.mask_paths.sort()
            self.img_paths = sorted(list(set(self.img_paths) - set(self.mask_paths)))

        assert len(self.img_paths) == len(self.mask_paths), "different number of images!!!!"
        print("{} => # imgs : {}, # masks : {}".format("train" if train else "valid", len(self.img_paths), len(self.mask_paths)))
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img = cv2.imread(self.img_paths[idx], cv2.IMREAD_COLOR)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = img.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0
        r = {
        'index' : idx,
        #'tile_id' : id,
        'mask' : mask,
        'image' : img,
        }
        if self.augment is not None: r = self.augment(r)
        return r
        
    

def null_collate(batch):
    batch_size = len(batch)
    index = []
    mask = []
    image = []
    for r in batch:
        index.append(r['index'])
        mask.append(r['mask'])
        image.append(r['image'])

    image = np.stack(image)
    image = image[...,::-1]
    image = image.transpose(0,3,1,2)
    image = np.ascontiguousarray(image)

    mask  = np.stack(mask)
    mask  = np.ascontiguousarray(mask)

    #---
    image = torch.from_numpy(image).contiguous().float()
    mask  = torch.from_numpy(mask).contiguous().unsqueeze(1)
    mask  = (mask>0.5).float()

    return {
        'index' : index,
        'mask' : mask,
        'image' : image,
    }


In [7]:
#---------- augmentation ---------------------#
###############################################
#flip
def do_random_flip_transpose(image, mask):
    if np.random.rand()>0.5:
        image = cv2.flip(image,0)
        mask = cv2.flip(mask,0)
    if np.random.rand()>0.5:
        image = cv2.flip(image,1)
        mask = cv2.flip(mask,1)
    if np.random.rand()>0.5:
        image = image.transpose(1,0,2)
        mask = mask.transpose(1,0)

    image = np.ascontiguousarray(image)
    mask = np.ascontiguousarray(mask)
    return image, mask

#geometric
def do_random_crop(image, mask, size):
    height, width = image.shape[:2]
    x = np.random.choice(width -size)
    y = np.random.choice(height-size)
    image = image[y:y+size,x:x+size]
    mask  = mask[y:y+size,x:x+size]
    return image, mask

def do_random_scale_crop(image, mask, size, mag):
    height, width = image.shape[:2]

    s = 1 + np.random.uniform(-1, 1)*mag
    s =  int(s*size)

    x = np.random.choice(width -s)
    y = np.random.choice(height-s)
    image = image[y:y+s,x:x+s]
    mask  = mask[y:y+s,x:x+s]
    if s!=size:
        image = cv2.resize(image, dsize=(size,size), interpolation=cv2.INTER_LINEAR)
        mask  = cv2.resize(mask, dsize=(size,size), interpolation=cv2.INTER_LINEAR)
    return image, mask

def do_random_rotate_crop(image, mask, size, mag=30 ):
    angle = 1+np.random.uniform(-1, 1)*mag

    height, width = image.shape[:2]
    dst = np.array([
        [0,0],[size,size], [size,0], [0,size],
    ])

    c = np.cos(angle/180*2*PI)
    s = np.sin(angle/180*2*PI)
    src = (dst-size//2)@np.array([[c, -s],[s, c]]).T
    src[:,0] -= src[:,0].min()
    src[:,1] -= src[:,1].min()

    src[:,0] = src[:,0] + np.random.uniform(0,width -src[:,0].max())
    src[:,1] = src[:,1] + np.random.uniform(0,height-src[:,1].max())

    if 0: #debug
        def to_int(f):
            return (int(f[0]),int(f[1]))

        cv2.line(image, to_int(src[0]), to_int(src[1]), (0,0,1), 16)
        cv2.line(image, to_int(src[1]), to_int(src[2]), (0,0,1), 16)
        cv2.line(image, to_int(src[2]), to_int(src[3]), (0,0,1), 16)
        cv2.line(image, to_int(src[3]), to_int(src[0]), (0,0,1), 16)
        image_show_norm('image', image, min=0, max=1)
        cv2.waitKey(1)


    transform = cv2.getAffineTransform(src[:3].astype(np.float32), dst[:3].astype(np.float32))
    image = cv2.warpAffine( image, transform, (size, size), flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
    mask  = cv2.warpAffine( mask, transform, (size, size), flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return image, mask

#warp/elastic deform ...
#<todo>

#noise
def do_random_noise(image, mask, mag=0.1):
    height, width = image.shape[:2]
    noise = np.random.uniform(-1,1, (height, width,1))*mag
    image = image + noise
    image = np.clip(image,0,1)
    return image, mask


#intensity
def do_random_contast(image, mask, mag=0.3):
    alpha = 1 + random.uniform(-1,1)*mag
    image = image * alpha
    image = np.clip(image,0,1)
    return image, mask

def do_random_gain(image, mask, mag=0.3):
    alpha = 1 + random.uniform(-1,1)*mag
    image = image ** alpha
    image = np.clip(image,0,1)
    return image, mask

def do_random_hsv(image, mask, mag=[0.15,0.25,0.25]):
    image = (image*255).astype(np.uint8)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    h = hsv[:, :, 0].astype(np.float32)  # hue
    s = hsv[:, :, 1].astype(np.float32)  # saturation
    v = hsv[:, :, 2].astype(np.float32)  # value
    h = (h*(1 + random.uniform(-1,1)*mag[0]))%180
    s =  s*(1 + random.uniform(-1,1)*mag[1])
    v =  v*(1 + random.uniform(-1,1)*mag[2])

    hsv[:, :, 0] = np.clip(h,0,180).astype(np.uint8)
    hsv[:, :, 1] = np.clip(s,0,255).astype(np.uint8)
    hsv[:, :, 2] = np.clip(v,0,255).astype(np.uint8)
    image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    image = image.astype(np.float32)/255
    return image, mask


def filter_small(mask, min_size):

    m = (mask*255).astype(np.uint8)

    num_comp, comp, stat, centroid = cv2.connectedComponentsWithStats(m, connectivity=8)
    if num_comp==1: return mask

    filtered = np.zeros(comp.shape,dtype=np.uint8)
    area = stat[:, -1]
    for i in range(1, num_comp):
        if area[i] >= min_size:
            filtered[comp == i] = 255
    return filtered

In [8]:
#---------- optimizer ---------------------#
############################################
class Lookahead(Optimizer):
    def __init__(self, optimizer, alpha=0.5, k=6):

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')

        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups
        self.alpha = alpha
        self.k = k
        for group in self.param_groups:
            group["step_counter"] = 0

        self.slow_weights = [
                [p.clone().detach() for p in group['params']]
            for group in self.param_groups]

        for w in it.chain(*self.slow_weights):
            w.requires_grad = False
        self.state = optimizer.state

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        loss = self.optimizer.step()

        for group,slow_weights in zip(self.param_groups,self.slow_weights):
            group['step_counter'] += 1
            if group['step_counter'] % self.k != 0:
                continue
            for p,q in zip(group['params'],slow_weights):
                if p.grad is None:
                    continue
                q.data.add_(p.data - q.data, alpha=self.alpha )
                p.data.copy_(q.data)
        return loss
class RAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

        return loss


# Model

In [9]:
def SegModel():
    if args.encoder in ['b0','b1','b2','b3','b4','b5','b6','b7']:
        encoder_name_ = f'efficientnet-{args.encoder}' #'timm-efficientnet-b4'
        print('encoder : ', encoder_name_)
    else:
        encoder_name_ = args.encoder
    if args.decoder =='fpn':
        print('fpn loaded')
        model = smp.FPN(
            encoder_name=encoder_name_,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
            in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
            classes=1,                      # model output channels (number of classes in your dataset)
            )
    elif args.decoder =='unet':
        print('unet loaded')
        model = smp.Unet(
            encoder_name=encoder_name_,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
            in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
            classes=1,                      # model output channels (number of classes in your dataset)
            )
    return model

# train

In [10]:
# augmentation
def train_augment(record):
    image = record['image']
    mask  = record['mask']
    
    for fn in np.random.choice([
        lambda image, mask : do_random_rotate_crop(image, mask, size=args.image_size, mag=45),
        lambda image, mask : do_random_scale_crop(image, mask, size=args.image_size, mag=0.075),
        lambda image, mask : do_random_crop(image, mask, size=args.image_size),
    ],1): image, mask = fn(image, mask)

    #if (np.random.choice(10,1)<7)[0]:
    for fn in np.random.choice([
        lambda image, mask : (image, mask),
        lambda image, mask : do_random_contast(image, mask, mag=0.8),
        lambda image, mask : do_random_gain(image, mask, mag=0.9),
        #lambda image, mask : do_random_hsv(image, mask, mag=[0.1, 0.2, 0]),
        lambda image, mask : do_random_noise(image, mask, mag=0.1),
    ],2): image, mask =  fn(image, mask)
    #if (np.random.choice(10,1)<7)[0]:
    image, mask = do_random_hsv(image, mask, mag=[0.1, 0.2, 0])
    image, mask = do_random_flip_transpose(image, mask)

    record['mask'] = mask
    record['image'] = image
    return record

# validation
# validation
def do_valid(net, valid_loader):

    valid_num = 0
    total = 0 ; dice2=0 ; loss2=0
    valid_probability, valid_probability2 = [],[]
    valid_mask = []

    net = net.eval()
    lovasz_loss=0

    #start_timer = timer()
    with torch.no_grad():
        for t, batch in enumerate(valid_loader):
            batch_size = len(batch['index'])
            mask  = batch['mask']
            image = batch['image'].to(device)

            logit = net(image)#data_parallel(net, image) #net(input)#
            probability = torch.sigmoid(logit)

            valid_probability.append(probability.data.cpu().numpy())
            valid_mask.append(mask.data.cpu().numpy())

            valid_num += batch_size


    assert(valid_num == len(valid_loader.dataset))
    #print('')
    #------
    probability = np.concatenate(valid_probability)
    mask = np.concatenate(valid_mask)
    loss = np_binary_cross_entropy_loss(probability, mask)

    dice = np_dice_score(probability, mask)
    tp, tn = np_accuracy(probability, mask)

    return [dice, loss,  tp, tn]

def run_train(args):
    if args.overlap:
        out_dir = data_dir + f'/result/{args.dir}_fold{args.fold}_{args.encoder}_{args.image_size}'

        ## setup  ----------------------------------------
        for f in ['checkpoint','train','valid','backup'] : os.makedirs(out_dir +'/'+f, exist_ok=True)
        #backup_project_as_zip(PROJECT_PATH, out_dir +'/backup/code.train.%s.zip'%IDENTIFIER)
        log = Logger()
        log.open(out_dir+'/log.train.txt',mode='a')

        # my log argument
        print_args(args, log)

        log.write('\tout_dir  = %s\n' % out_dir)
        log.write('\n')


        log.write('** dataset setting **\n')
        #-----------dataset split --------------------#
        tile_id = []
        image_dir_ = f'{args.dataset}'#'0.25_480_240_train'
        image_dir=[image_dir_, ] # pseudo할때 뒤에 추가

        for i in range(len(image_dir)):
            df = pd.read_csv(data_dir + '/tile/%s/image_id.csv'% (image_dir[i]) )
            tile_id += ('%s/'%(image_dir[i]) + df.tile_id).tolist()

        kf = KFold(n_splits=args.n_fold, random_state=args.seed, shuffle=True)
        all_dice = []
        for n_fold, (trn_idx, val_idx) in enumerate(kf.split(tile_id)):
            if not args.all_fold_train:
                if n_fold != args.fold:
                    print(f'{n_fold} fold pass')
                    continue

            #####################################################
            train_dataset = HuDataset(
                tile_id = df.loc[trn_idx]['tile_id'].tolist(),
                augment = train_augment
            )
            train_loader  = DataLoader(
                train_dataset,
                sampler = RandomSampler(train_dataset),
                batch_size  = args.batch_size,
                drop_last   = False,
                num_workers = 8,
                pin_memory  = True,
                collate_fn  = null_collate
            )

            valid_dataset = HuDataset(
                tile_id = df.loc[val_idx]['tile_id'].tolist()
                ,
            )
            valid_loader = DataLoader(
                valid_dataset,
                sampler = SequentialSampler(valid_dataset),
                batch_size  = args.batch_size,
                drop_last   = False,
                num_workers = 4,
                pin_memory  = True,
                collate_fn  = null_collate
            )

            log.write('fold = %s\n'%str(n_fold))
            log.write('train_dataset : \n%s\n'%(train_dataset))
            log.write('valid_dataset : \n%s\n'%(valid_dataset))
            log.write('\n')

            ## net ----------------------------------------
            log.write('** net setting **\n')

            scaler = GradScaler()
            net = SegModel() 
            net = net.to(device)

            if args.opt =='adamw':
                optimizer = torch.optim.AdamW(net.parameters(), lr = args.start_lr)
                #optimizer = torch.optim.AdamW([
                #    {'params': model.decoder.parameters(), 'lr': start_lr}, 
                #    {'params': model.encoder.parameters(), 'lr': start_lr},  
                #])
            elif args.opt =='radam_look':
                optimizer = Lookahead(RAdam(filter(lambda p: p.requires_grad, net.parameters()),lr=args.start_lr), alpha=0.5, k=5)
            if optimizer == None:
                assert False, 'no have optimizer'

            #if args.scheduler == 'multistep':
            #    m_e = args.multistep.split(',')
            #    scheduler = MultiStepLR(optimizer, milestones=[int(m_e[0]), int(m_e[1])], gamma=args.multistep_gamma)
            if args.scheduler =='CosineAnnealingWarmRestarts':
                scheduler = CosineAnnealingWarmRestarts(optimizer, 
                                                 T_0 = args.epochs//args.T_0, 
                                                 T_mult=1, 
                                                 eta_min=0, 
                                                 last_epoch=-1)
            elif args.scheduler == 'CosineAnnealingLR':
                scheduler = CosineAnnealingLR(optimizer, T_max=args.T_max, eta_min=args.min_lr, last_epoch=-1)
            else:
                scheduler=None


            log.write('optimizer\n  %s\n'%(optimizer))
            #log.write('schduler\n  %s\n'%(schduler))
            log.write('\n')

            ## start training here! ##############################################
            #array([0.57142857, 0.42857143])
            log.write('** start training here! **\n')
            log.write('   is_mixed_precision = %s \n'%str(args.amp))
            log.write('   batch_size = %d \n'%(args.batch_size))
            log.write('             |-------------- VALID---------|---- TRAIN/BATCH ----------------\n')
            log.write('rate  epoch  | dice   loss   tp     tn     | loss           | time           \n')
            log.write('-------------------------------------------------------------------------------------\n')
                      #0.00100   0.50  0.80 | 0.891  0.020  0.000  0.000  | 0.000  0.000   |  0 hr 02 min

            def message(mode='print'):
                if mode==('print'):
                    asterisk = ' '
                    loss = batch_loss
                if mode==('log'):
                    asterisk = '*'
                    loss = train_loss

                text = \
                    '%0.5f  %s%s    | '%(rate, epoch, asterisk,) +\
                    '%4.3f  %4.3f  %4.3f  %4.3f  | '%(*valid_loss,) +\
                    '%4.3f  %4.3f   | '%(*loss,) +\
                    '%s' % (time_to_str(timer() - start_timer,'min'))

                return text

            #----
            valid_loss = np.zeros(4,np.float32)
            train_loss = np.zeros(2,np.float32)
            batch_loss = np.zeros_like(train_loss)
            sum_train_loss = np.zeros_like(train_loss)
            sum_train = 0
            loss = torch.FloatTensor([0]).sum()


            start_timer = timer()
            rate = 0
            best_dice = 0
            #while  iteration < num_iteration:
            for epoch in range(args.epochs):
                print('\r',end='',flush=True)
                log.write(message(mode='log')+'\n')
                # training
                for t, batch in enumerate(train_loader):

                    # learning rate schduler -------------
                    #adjust_learning_rate(optimizer, schduler(iteration))
                    rate = get_learning_rate(optimizer)

                    # one iteration update  -------------
                    batch_size = len(batch['index'])
                    net.train()

                    if args.amp:
                        #image = image.half()
                        with autocast():
                            mask  = batch['mask'].to(device)
                            image = batch['image'].to(device)

                            optimizer.zero_grad()
                            #logit = data_parallel(net, image)
                            logit = net(image)
                            if args.loss == 'bce':
                                loss = criterion_binary_cross_entropy(logit, mask)
                            elif args.loss =='lovasz':
                                loss = symmetric_lovasz(logit, mask)

                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                    else :
                        mask  = batch['mask'].to(device)
                        image = batch['image'].to(device)

                        optimizer.zero_grad()
                        #logit = data_parallel(net, image)
                        logit = net(image)
                        if args.loss == 'bce':
                            loss = criterion_binary_cross_entropy(logit, mask)
                        elif args.loss =='lovasz':
                            loss = symmetric_lovasz(logit, mask)

                        loss.backward()
                        optimizer.step()


                    # print statistics  --------

                    batch_loss = np.array([ loss.item(), 0 ])
                    sum_train_loss += batch_loss
                    sum_train += 1

                    print('\r',end='',flush=True)
                    print(message(mode='print'), end='',flush=True)

                # train loss
                train_loss = sum_train_loss/(sum_train+1e-12)
                sum_train_loss[...] = 0
                sum_train = 0

                # scheudler
                if scheduler is not None:
                    scheduler.step()
                # validation
                valid_loss = do_valid(net, valid_loader) #

                # saved models
                if valid_loss[0] > best_dice:
                    best_dice = valid_loss[0]
                    print(f'\n saved best models, dice:{best_dice:.5f}')
                    torch.save({
                        'state_dict': net.state_dict(),
                        'epoch': epoch,
                    }, out_dir + f'/checkpoint/{n_fold}fold_{epoch}epoch_{best_dice:.4f}_model.pth')

            log.write('\n')
            all_dice.append(best_dice)

            print(f'all dice score : {sum(all_dice)/len(all_dice) : .4f}')
    else:  # no overlap
        out_dir = data_dir + f'/result/{args.dir}_fold{args.fold}_{args.encoder}_{args.image_size}'

        ## setup  ----------------------------------------
        for f in ['checkpoint','train','valid','backup'] : 
            os.makedirs(out_dir +'/'+f, exist_ok=True)
        #backup_project_as_zip(PROJECT_PATH, out_dir +'/backup/code.train.%s.zip'%IDENTIFIER)
        log = Logger()
        log.open(out_dir+'/log.train.txt',mode='a')

        # my log argument
        print_args(args, log)

        log.write('\tout_dir  = %s\n' % out_dir)
        log.write('\n')


        log.write('** dataset setting **\n')
        #-----------dataset split --------------------#
        all_dice = []
        for n_fold in range(1,6):
            train_dataset = HuDataset_nol(data_dir, args.dataset, n_fold, train=True, augment=train_augment)
            train_loader  = DataLoader(
                train_dataset,
                sampler = RandomSampler(train_dataset),
                batch_size  = args.batch_size,
                drop_last   = False,
                num_workers = 8,
                pin_memory  = True,
                collate_fn  = null_collate
            )
            valid_dataset = HuDataset_nol(data_dir, args.dataset, n_fold, train=False, augment=None)
            valid_loader = DataLoader(
                valid_dataset,
                sampler = SequentialSampler(valid_dataset),
                batch_size  = args.batch_size,
                drop_last   = False,
                num_workers = 4,
                pin_memory  = True,
                collate_fn  = null_collate
            )

            log.write('fold = %s\n'%str(n_fold))
            log.write('train_dataset : \n%s\n'%(train_dataset))
            log.write('valid_dataset : \n%s\n'%(valid_dataset))
            log.write('\n')

            ## net ----------------------------------------
            log.write('** net setting **\n')

            scaler = GradScaler()
            net = SegModel() 
            net = net.to(device)

            if args.opt =='adamw':
                optimizer = torch.optim.AdamW(net.parameters(), lr = args.start_lr)
                #optimizer = torch.optim.AdamW([
                #    {'params': model.decoder.parameters(), 'lr': start_lr}, 
                #    {'params': model.encoder.parameters(), 'lr': start_lr},  
                #])
            elif args.opt =='radam_look':
                optimizer = Lookahead(RAdam(filter(lambda p: p.requires_grad, net.parameters()),lr=args.start_lr), alpha=0.5, k=5)
            if optimizer == None:
                assert False, 'no have optimizer'

            #if args.scheduler == 'multistep':
            #    m_e = args.multistep.split(',')
            #    scheduler = MultiStepLR(optimizer, milestones=[int(m_e[0]), int(m_e[1])], gamma=args.multistep_gamma)
            if args.scheduler =='CosineAnnealingWarmRestarts':
                scheduler = CosineAnnealingWarmRestarts(optimizer, 
                                                 T_0 = args.epochs//args.T_0, 
                                                 T_mult=1, 
                                                 eta_min=0, 
                                                 last_epoch=-1)
            elif args.scheduler == 'CosineAnnealingLR':
                scheduler = CosineAnnealingLR(optimizer, T_max=args.T_max, eta_min=args.min_lr, last_epoch=-1)
            else:
                scheduler=None


            log.write('optimizer\n  %s\n'%(optimizer))
            #log.write('schduler\n  %s\n'%(schduler))
            log.write('\n')

            ## start training here! ##############################################
            #array([0.57142857, 0.42857143])
            log.write('** start training here! **\n')
            log.write('   is_mixed_precision = %s \n'%str(args.amp))
            log.write('   batch_size = %d \n'%(args.batch_size))
            log.write('             |-------------- VALID---------|---- TRAIN/BATCH ----------------\n')
            log.write('rate  epoch  | dice   loss   tp     tn     | loss           | time           \n')
            log.write('-------------------------------------------------------------------------------------\n')
                      #0.00100   0.50  0.80 | 0.891  0.020  0.000  0.000  | 0.000  0.000   |  0 hr 02 min

            def message(mode='print'):
                if mode==('print'):
                    asterisk = ' '
                    loss = batch_loss
                if mode==('log'):
                    asterisk = '*'
                    loss = train_loss

                text = \
                    '%0.5f  %s%s    | '%(rate, epoch, asterisk,) +\
                    '%4.3f  %4.3f  %4.3f  %4.3f  | '%(*valid_loss,) +\
                    '%4.3f  %4.3f   | '%(*loss,) +\
                    '%s' % (time_to_str(timer() - start_timer,'min'))

                return text

            #----
            valid_loss = np.zeros(4,np.float32)
            train_loss = np.zeros(2,np.float32)
            batch_loss = np.zeros_like(train_loss)
            sum_train_loss = np.zeros_like(train_loss)
            sum_train = 0
            loss = torch.FloatTensor([0]).sum()


            start_timer = timer()
            rate = 0
            best_dice = 0
            #while  iteration < num_iteration:
            for epoch in range(args.epochs):
                print('\r',end='',flush=True)
                log.write(message(mode='log')+'\n')
                # training
                for t, batch in enumerate(train_loader):

                    # learning rate schduler -------------
                    #adjust_learning_rate(optimizer, schduler(iteration))
                    rate = get_learning_rate(optimizer)

                    # one iteration update  -------------
                    batch_size = len(batch['index'])
                    net.train()

                    if args.amp:
                        #image = image.half()
                        with autocast():
                            mask  = batch['mask'].to(device)
                            image = batch['image'].to(device)

                            optimizer.zero_grad()
                            #logit = data_parallel(net, image)
                            logit = net(image)
                            if args.loss == 'bce':
                                loss = criterion_binary_cross_entropy(logit, mask)
                            elif args.loss =='lovasz':
                                loss = symmetric_lovasz(logit, mask)

                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                    else :
                        mask  = batch['mask'].to(device)
                        image = batch['image'].to(device)

                        optimizer.zero_grad()
                        #logit = data_parallel(net, image)
                        logit = net(image)
                        if args.loss == 'bce':
                            loss = criterion_binary_cross_entropy(logit, mask)
                        elif args.loss =='lovasz':
                            loss = symmetric_lovasz(logit, mask)

                        loss.backward()
                        optimizer.step()


                    # print statistics  --------

                    batch_loss = np.array([ loss.item(), 0 ])
                    sum_train_loss += batch_loss
                    sum_train += 1

                    print('\r',end='',flush=True)
                    print(message(mode='print'), end='',flush=True)

                # train loss
                train_loss = sum_train_loss/(sum_train+1e-12)
                sum_train_loss[...] = 0
                sum_train = 0

                # scheudler
                if scheduler is not None:
                    scheduler.step()
                # validation
                valid_loss = do_valid(net, valid_loader) #

                # saved models
                if valid_loss[0] > best_dice:
                    best_dice = valid_loss[0]
                    print(f'\n saved best models, dice:{best_dice:.5f}')
                    torch.save({
                        'state_dict': net.state_dict(),
                        'epoch': epoch,
                    }, out_dir + f'/checkpoint/{n_fold}fold_{epoch}epoch_{best_dice:.4f}_model.pth')

            log.write('\n')
            all_dice.append(best_dice)

            print(f'all dice score : {sum(all_dice)/len(all_dice) : .4f}')



In [None]:
if __name__ == '__main__':
    # set seed
    print('no set seed') if args.seed ==-1 else set_seeds(seed=args.seed)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    run_train(args)

__module__       : __main__amp              : Truegpu              : 3dir              : 100epoch_nooverlap_640_25_50encoder          : b4decoder          : unetbatch_size       : 16weight_decay     : 1e-06n_fold           : 5fold             : 0all_fold_train   : Trueimage_size       : 512dataset          : 0.25_640_320_trainoverlap          : Falseepochs           : 100opt              : radam_lookscheduler        : CosineAnnealingLRloss             : bcefactor           : 0.2patience         : 2eps              : 1e-06T_max            : 10decay_epoch      : [4, 8, 12]T_0              : 4start_lr         : 0.001min_lr           : 1e-06num_workers      : 8seed             : 42__dict__         : <attribute '__dict__' of 'args' objects>__weakref__      : <attribute '__weakref__' of 'args' objects>__doc__          : None	out_dir  = /home/jeonghokim/competition/HubMap/data//result/100epoch_nooverlap_640_25_50_fold0_b4_512

** dataset setting **
train => # imgs : 4556, # masks : 4556
valid

0.00090  43*    | 0.941  0.010  0.925  0.999  | 0.008  0.000   |  2 hr 11 min
0.00079  44*    | 0.943  0.010  0.935  0.998  | 0.008  0.000   |  2 hr 14 min
0.00065  44     | 0.943  0.010  0.935  0.998  | 0.007  0.000   |  2 hr 16 min
 saved best models, dice:0.94337
0.00065  45*    | 0.943  0.010  0.934  0.998  | 0.008  0.000   |  2 hr 17 min
0.00050  46*    | 0.940  0.011  0.920  0.999  | 0.008  0.000   |  2 hr 20 min
0.00035  47*    | 0.939  0.011  0.920  0.999  | 0.008  0.000   |  2 hr 23 min
0.00021  48*    | 0.940  0.011  0.922  0.999  | 0.007  0.000   |  2 hr 26 min
0.00010  49*    | 0.942  0.011  0.929  0.998  | 0.007  0.000   |  2 hr 29 min
0.00003  50*    | 0.941  0.011  0.927  0.999  | 0.007  0.000   |  2 hr 32 min
0.00000  51*    | 0.942  0.011  0.929  0.999  | 0.007  0.000   |  2 hr 35 min
0.00003  52*    | 0.942  0.011  0.929  0.999  | 0.007  0.000   |  2 hr 38 min
0.00010  53*    | 0.942  0.011  0.928  0.999  | 0.007  0.000   |  2 hr 41 min
0.00021  54*    | 0.941  0.011 

0.00050  5     | 0.919  0.017  0.928  0.997  | 0.010  0.000   |  0 hr 18 min
 saved best models, dice:0.92071
0.00050  6*    | 0.921  0.016  0.934  0.997  | 0.015  0.000   |  0 hr 19 min
0.00035  7*    | 0.920  0.015  0.938  0.997  | 0.013  0.000   |  0 hr 22 min
0.00021  7     | 0.920  0.015  0.938  0.997  | 0.008  0.000   |  0 hr 24 min
 saved best models, dice:0.92185
0.00021  8*    | 0.922  0.015  0.934  0.997  | 0.013  0.000   |  0 hr 25 min
0.00010  8     | 0.922  0.015  0.934  0.997  | 0.014  0.000   |  0 hr 28 min
 saved best models, dice:0.92222
0.00010  9*    | 0.922  0.015  0.934  0.997  | 0.012  0.000   |  0 hr 28 min
0.00003  10*    | 0.922  0.015  0.932  0.997  | 0.012  0.000   |  0 hr 32 min
0.00000  11*    | 0.922  0.015  0.933  0.997  | 0.012  0.000   |  0 hr 35 min
0.00003  12*    | 0.921  0.015  0.933  0.997  | 0.012  0.000   |  0 hr 38 min
0.00010  12     | 0.921  0.015  0.933  0.997  | 0.008  0.000   |  0 hr 40 min
 saved best models, dice:0.92292
0.00010  13*    |

# submission

In [21]:
class args:
    # ---- factor ---- #
    server ='kaggle' # ['kaggle', 'local'] local은 cv측정용도
    amp = False
    gpu = 3
    
    encoder='b4'#'resnet34'
    decoder='unet'
    batch_size=256

    threshold = 0.4
    min_size = 0 # 총 덩어리가 픽셀 1000개가 안되는 것은 삭제한다. 2021-04-21 post_processing
    
    model_path = './data/result/new_5fold_alltraining2_fold0_b4_256/checkpoint/0fold_28epoch_0.9458_model.pth'# 모델한개
    
    en_model_path = ['./data/result/(0.924)100epoch_nooverlap_640_25_50_fold0_b4_512/checkpoint/' + x for x in \
                     
                     ['1fold_51epoch_0.9405_model.pth','2fold_61epoch_0.9437_model.pth','3fold_56epoch_0.9469_model.pth',
                     '4fold_36epoch_0.9297_model.pth','5fold_34epoch_0.9342_model.pth' ]]# if ensemble 직접 입력해야한다. 
    sub = '100epoch_5fold'# submission name
    
    # ---- Dataset ---- #
    
    tile_size = 640  # 이것도 480으로 해보기 
    tile_average_step = 320
    tile_scale = 0.25
    tile_min_score = 0.25  

device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

In [22]:
thres = args.threshold

prob = []

def mask_to_csv(image_id, submit_dir):

    predicted = []
    for id in image_id:
        image_file = data_dir + '/test/%s.tiff' % id
        image = read_tiff(image_file)

        height, width = image.shape[:2]
        predict_file = submit_dir + '/%s.predict.png' % id
        # predict = cv2.imread(predict_file, cv2.IMREAD_GRAYSCALE)
        predict = np.array(Image.open(predict_file))
        predict = cv2.resize(predict, dsize=(width, height), interpolation=cv2.INTER_LINEAR)
        predict = (predict > 128).astype(np.uint8) * 255

        p = rle_encode(predict)
        predicted.append(p)

    df = pd.DataFrame()
    df['id'] = image_id
    df['predicted'] = predicted
    return df

def run_submit(args):

    #fold = 6
    out_dir = args.model_path.split('checkpoint')[0]
    initial_checkpoint = out_dir + '/checkpoint' + args.model_path.split('checkpoint')[1]
    
    # local은 cv측정 용도

    server = args.server#'kaggle' , 'local'

    #---
    submit_dir = out_dir + '/test/%s-%s-mean-thres(%s)'%(server, initial_checkpoint[-18:-4],thres)
    os.makedirs(submit_dir,exist_ok=True)

    log = Logger()
    log.open(out_dir+'/log.submit.txt',mode='a')

    #---
    if server == 'local':
        valid_image_id = make_image_id('valid-%d' % fold)
    if server == 'kaggle':
        valid_image_id = make_image_id('test-all')

    if server == 'local':
        tile_size = args.tile_size #320
        tile_average_step = args.tile_average_step#320 #192
        tile_scale = args.tile_scale
        tile_min_score = args.tile_min_score
    if server == 'kaggle' :
        tile_size = args.tile_size#640#640 #320
        tile_average_step = args.tile_average_step#320#320 #192
        tile_scale = args.tile_scale#0.25
        tile_min_score = args.tile_min_score#0.25   

    log.write('tile_size = %d \n'%tile_size)
    log.write('tile_average_step = %d \n'%tile_average_step)
    log.write('tile_scale = %f \n'%tile_scale)
    log.write('tile_min_score = %f \n'%tile_min_score)
    log.write('\n')

    
    # ----- model -------
    net = SegModel() 
    net.to(device)
    state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)['state_dict']
    net.load_state_dict(state_dict,strict=True)  #True
    net = net.eval()
    
    start_timer = timer()
    for id in valid_image_id:
        if server == 'local':
            image_file = data_dir + '/train/%s.tiff' % id
            image = read_tiff(image_file)
            height, width = image.shape[:2]

            json_file  = data_dir + '/train/%s-anatomical-structure.json' % id
            structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)   
            mask_file = data_dir + '/train/%s.mask.png' % id
            mask  = read_mask(mask_file)

        if server == 'kaggle':
            image_file = data_dir + '/test/%s.tiff' % id
            json_file  = data_dir + '/test/%s-anatomical-structure.json' % id

            image = read_tiff(image_file)
            height, width = image.shape[:2]
            structure = draw_strcuture(read_json_as_df(json_file), height, width, structure=['Cortex'])

            mask = None


        #--- predict here!  ---
        tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

        tile_image = tile['tile_image']
        tile_image = np.stack(tile_image)[..., ::-1]
        tile_image = np.ascontiguousarray(tile_image.transpose(0,3,1,2))
        tile_image = tile_image.astype(np.float32)/255
        print(tile_image.shape)
        tile_probability = []
        
        batch = np.array_split(tile_image, len(tile_image)//4)
        for t,m in enumerate(batch):
            print('\r %s  %d / %d   %s'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')), end='',flush=True)
            m = torch.from_numpy(m).cuda()

            p = []
            with torch.no_grad():
                logit = net(m)
                p.append(torch.sigmoid(logit))

                #---
                if server == 'kaggle':
                    if 1: #tta here
                        logit = net(m.flip(dims=(2,)))
                        p.append(torch.sigmoid(logit.flip(dims=(2,))))

                        logit = net(m.flip(dims=(3,)))
                        p.append(torch.sigmoid(logit.flip(dims=(3,))))
                    p = torch.stack(p).mean(0)
                if server == 'local':
                    if 0: #tta here
                        #logit = data_parallel(net, m.flip(dims=(2,)))
                        logit = net(m.flip(dims=(2,)))
                        p.append(torch.sigmoid(logit.flip(dims=(2,))))

                        #logit = data_parallel(net, m.flip(dims=(3,)))
                        logit = net(m.flip(dims=(3,)))
                        p.append(torch.sigmoid(logit.flip(dims=(3,))))
                    p = torch.cat(p)
                    #p = torch.stack(p)

            tile_probability.append(p.data.cpu().numpy())

        print('\r' , end='',flush=True)
        log.write('%s  %d / %d   %s\n'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')))

        tile_probability = np.concatenate(tile_probability).squeeze(1)
        height, width = tile['image_small'].shape[:2]
        probability = to_mask(tile_probability, tile['coord'], height, width,
                              tile_scale, tile_size, tile_average_step, tile_min_score,
                              aggregate='mean')
        

        #--- show results ---
        if server == 'local':
            truth = tile['mask_small'].astype(np.float32)/255
            truth2 = np.concatenate(tile['tile_mask']).astype(np.float32)/255
        if server == 'kaggle':
            truth = np.zeros((height, width), np.float32)

        overlay = np.dstack([
            np.zeros_like(truth),
            probability, #green
            truth, #red
        ])
        image_small = tile['image_small'].astype(np.float32)/255
        predict = (probability>thres).astype(np.float32)
        overlay1 = 1-(1-image_small)*(1-overlay)
        overlay2 = image_small.copy()
        overlay2 = draw_contour_overlay(overlay2, tile['structure_small'], color=(1, 1, 1), thickness=3)
        overlay2 = draw_contour_overlay(overlay2, truth, color=(0, 0, 1), thickness=8)
        overlay2 = draw_contour_overlay(overlay2, probability, color=(0, 1, 0), thickness=3)

        if 1:
            cv2.imwrite(submit_dir+'/%s.image_small.png'%id, (image_small*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.probability.png'%id, (probability*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.predict.png'%id, (predict*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay.png'%id, (overlay*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay1.png'%id, (overlay1*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay2.png'%id, (overlay2*255).astype(np.uint8))

        #---

        if server == 'local':

            loss = np_binary_cross_entropy_loss(probability, truth)
            dice = np_dice_score(probability, truth) # 여기는 큰이미지로 바꾼상태에서 dice
            dice2 = np_dice_score(tile_probability, truth2) # 작은이미지상태, 즉 training과 같은 cv구할려고 dice
            tp, tn = np_accuracy(probability, truth)
            log.write('submit_dir = %s \n'%submit_dir)
            log.write('initial_checkpoint = %s \n'%initial_checkpoint)
            log.write('loss   = %0.8f \n'%loss)
            log.write('dice   = %0.8f \n'%dice)
            log.write('dice2   = %0.8f \n'%dice2)
            log.write('tp, tn = %0.8f, %0.8f \n'%(tp, tn))
            log.write('\n')
            #cv2.waitKey(0)

    #-----
    if server == 'kaggle':
        csv_file = submit_dir + args.sub+'.csv'
        df = mask_to_csv(valid_image_id, submit_dir)
        df.to_csv(csv_file, index=False)
        print(df)

    zz=0
    
def remove_small(probability, threshold, min_size):
    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]  # 먼저 들어온 예측 값에 대하여 mask 생성
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))  # 이 함수는 0인 부분은 배경, 1인 부분은 하나의 덩어리로 생각해서
    # 총 분리된 덩어리 갯수(N)와 각 덩어리에 숫자를 1,2, ..., N을 매겨서 반환한다. 
    predictions = np.zeros_like(mask)  # 최종 예측 마스크를 선언
    for c in tqdm(range(1, num_component)):
        p = (component == c)  # 각 덩어리를 체크한뒤에 이 덩어리의 픽셀 갯수가 min_size보다 크면 1로 체크한다. 
        if p.sum() > min_size:
            predictions[p] = 1
    return predictions
    
def run_submit_ensemble(args):

    #fold = 6
    out_dir = args.en_model_path[0].split('checkpoint')[0]
    
    
    # local은 cv측정 용도

    server = args.server#'kaggle' , 'local'

    #---
    submit_dir = out_dir + '/test/%s-%s-thres(%s)_remove_small'%(server, args.sub,thres)
    os.makedirs(submit_dir,exist_ok=True)

    log = Logger()
    log.open(out_dir+'/log.submit.txt',mode='a')

    #---
    if server == 'local':
        valid_image_id = make_image_id('valid-%d' % fold)
    if server == 'kaggle':
        valid_image_id = make_image_id('test-all')

    if server == 'local':
        tile_size = args.tile_size #320
        tile_average_step = args.tile_average_step#320 #192
        tile_scale = args.tile_scale
        tile_min_score = args.tile_min_score
    if server == 'kaggle' :
        tile_size = args.tile_size#640#640 #320
        tile_average_step = args.tile_average_step#320#320 #192
        tile_scale = args.tile_scale#0.25
        tile_min_score = args.tile_min_score#0.25   

    log.write('tile_size = %d \n'%tile_size)
    log.write('tile_average_step = %d \n'%tile_average_step)
    log.write('tile_scale = %f \n'%tile_scale)
    log.write('tile_min_score = %f \n'%tile_min_score)
    log.write('\n')

    
    
    start_timer = timer()
    for id in valid_image_id:
        fold_prob = []
        for m_p in args.en_model_path:
            initial_checkpoint = m_p
            # ----- model -------
            net = SegModel() 
            net.to(device)
            state_dict = torch.load(initial_checkpoint, map_location=lambda storage, loc: storage)['state_dict']
            net.load_state_dict(state_dict,strict=True)  #True
            net = net.eval()
            if server == 'local':
                image_file = data_dir + '/train/%s.tiff' % id
                image = read_tiff(image_file)
                height, width = image.shape[:2]

                json_file  = data_dir + '/train/%s-anatomical-structure.json' % id
                structure = draw_strcuture_from_hue(image, fill=255, scale=tile_scale/32)   
                mask_file = data_dir + '/train/%s.mask.png' % id
                mask  = read_mask(mask_file)

            if server == 'kaggle':
                image_file = data_dir + '/test/%s.tiff' % id
                json_file  = data_dir + '/test/%s-anatomical-structure.json' % id

                image = read_tiff(image_file)
                height, width = image.shape[:2]
                structure = draw_strcuture(read_json_as_df(json_file), height, width, structure=['Cortex'])

                mask = None


            #--- predict here!  ---
            tile = to_tile(image, mask, structure, tile_scale, tile_size, tile_average_step, tile_min_score)

            tile_image = tile['tile_image']
            tile_image = np.stack(tile_image)[..., ::-1]
            tile_image = np.ascontiguousarray(tile_image.transpose(0,3,1,2))
            tile_image = tile_image.astype(np.float32)/255
            print(tile_image.shape)
            tile_probability = []

            batch = np.array_split(tile_image, len(tile_image)//4)
            for t,m in enumerate(batch):
                print('\r %s  %d / %d   %s'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')), end='',flush=True)
                m = torch.from_numpy(m).to(device)

                p = []
                with torch.no_grad():
                    logit = net(m)
                    p.append(torch.sigmoid(logit))

                    #---
                    if server == 'kaggle':
                        if 1: #tta here
                            logit = net(m.flip(dims=(2,)))
                            p.append(torch.sigmoid(logit.flip(dims=(2,))))

                            logit = net(m.flip(dims=(3,)))
                            p.append(torch.sigmoid(logit.flip(dims=(3,))))
                        p = torch.stack(p).mean(0)
                    if server == 'local':
                        if 0: #tta here
                            #logit = data_parallel(net, m.flip(dims=(2,)))
                            logit = net(m.flip(dims=(2,)))
                            p.append(torch.sigmoid(logit.flip(dims=(2,))))

                            #logit = data_parallel(net, m.flip(dims=(3,)))
                            logit = net(m.flip(dims=(3,)))
                            p.append(torch.sigmoid(logit.flip(dims=(3,))))
                        p = torch.cat(p)
                        #p = torch.stack(p)

                tile_probability.append(p.data.cpu().numpy())

            print('\r' , end='',flush=True)
            log.write('%s  %d / %d   %s\n'%(id, t, len(batch), time_to_str(timer() - start_timer, 'sec')))

            tile_probability = np.concatenate(tile_probability).squeeze(1)
            height, width = tile['image_small'].shape[:2]
            probability = to_mask(tile_probability, tile['coord'], height, width,
                                  tile_scale, tile_size, tile_average_step, tile_min_score,
                                  aggregate='mean')


            fold_prob.append(probability)
        
        probability = sum(fold_prob)/len(args.en_model_path)
        #--- show results ---
        if server == 'local':
            truth = tile['mask_small'].astype(np.float32)/255
            truth2 = np.concatenate(tile['tile_mask']).astype(np.float32)/255
        if server == 'kaggle':
            truth = np.zeros((height, width), np.float32)

        overlay = np.dstack([
            np.zeros_like(truth),
            probability, #green
            truth, #red
        ])
        image_small = tile['image_small'].astype(np.float32)/255
        predict = remove_small(probability, args.threshold, args.min_size)
        overlay1 = 1-(1-image_small)*(1-overlay)
        overlay2 = image_small.copy()
        overlay2 = draw_contour_overlay(overlay2, tile['structure_small'], color=(1, 1, 1), thickness=3)
        overlay2 = draw_contour_overlay(overlay2, truth, color=(0, 0, 1), thickness=8)
        overlay2 = draw_contour_overlay(overlay2, probability, color=(0, 1, 0), thickness=3)

        if 1:
            cv2.imwrite(submit_dir+'/%s.image_small.png'%id, (image_small*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.probability.png'%id, (probability*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.predict.png'%id, (predict*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay.png'%id, (overlay*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay1.png'%id, (overlay1*255).astype(np.uint8))
            cv2.imwrite(submit_dir+'/%s.overlay2.png'%id, (overlay2*255).astype(np.uint8))

        #---

        if server == 'local':

            loss = np_binary_cross_entropy_loss(probability, truth)
            dice = np_dice_score(probability, truth) # 여기는 큰이미지로 바꾼상태에서 dice
            dice2 = np_dice_score(tile_probability, truth2) # 작은이미지상태, 즉 training과 같은 cv구할려고 dice
            tp, tn = np_accuracy(probability, truth)
            log.write('submit_dir = %s \n'%submit_dir)
            log.write('initial_checkpoint = %s \n'%initial_checkpoint)
            log.write('loss   = %0.8f \n'%loss)
            log.write('dice   = %0.8f \n'%dice)
            log.write('dice2   = %0.8f \n'%dice2)
            log.write('tp, tn = %0.8f, %0.8f \n'%(tp, tn))
            log.write('\n')
            #cv2.waitKey(0)

    #-----
    if server == 'kaggle':
        csv_file = submit_dir +'.csv'
        df = mask_to_csv(valid_image_id, submit_dir)
        df.to_csv(csv_file, index=False)
        print(df)

    zz=0

In [23]:

if 0: #normal
    if __name__ == '__main__':
        run_submit(args)
elif 1:# ensemble
    if __name__ == '__main__':
        run_submit_ensemble(args)

tile_size = 640 
tile_average_step = 320 
tile_scale = 0.250000 
tile_min_score = 0.250000 

encoder :  efficientnet-b4
unet loaded
(364, 3, 640, 640)
2ec3f1bb9  90 / 91    0 min 38 secc
encoder :  efficientnet-b4
unet loaded
(364, 3, 640, 640)
2ec3f1bb9  90 / 91    1 min 19 secc
encoder :  efficientnet-b4
unet loaded
(364, 3, 640, 640)
2ec3f1bb9  90 / 91    1 min 54 secc
encoder :  efficientnet-b4
unet loaded
(364, 3, 640, 640)
2ec3f1bb9  90 / 91    2 min 31 secc
encoder :  efficientnet-b4
unet loaded
(364, 3, 640, 640)
2ec3f1bb9  90 / 91    3 min 12 secc


100%|██████████| 400/400 [00:41<00:00,  9.60it/s]


encoder :  efficientnet-b4
unet loaded
(139, 3, 640, 640)
3589adb90  33 / 34    4 min 26 secc
encoder :  efficientnet-b4
unet loaded
(139, 3, 640, 640)
3589adb90  33 / 34    4 min 42 secc
encoder :  efficientnet-b4
unet loaded
(139, 3, 640, 640)
3589adb90  33 / 34    4 min 59 secc
encoder :  efficientnet-b4
unet loaded
(139, 3, 640, 640)
3589adb90  33 / 34    5 min 16 secc
encoder :  efficientnet-b4
unet loaded
(139, 3, 640, 640)
3589adb90  33 / 34    5 min 32 secc


100%|██████████| 222/222 [00:12<00:00, 17.18it/s]


encoder :  efficientnet-b4
unet loaded
(240, 3, 640, 640)
57512b7f1  59 / 60    6 min 31 secc
encoder :  efficientnet-b4
unet loaded
(240, 3, 640, 640)
57512b7f1  59 / 60    7 min 08 secc
encoder :  efficientnet-b4
unet loaded
(240, 3, 640, 640)
57512b7f1  59 / 60    7 min 46 secc
encoder :  efficientnet-b4
unet loaded
(240, 3, 640, 640)
57512b7f1  59 / 60    8 min 23 secc
encoder :  efficientnet-b4
unet loaded
(240, 3, 640, 640)
57512b7f1  59 / 60    9 min 00 secc


100%|██████████| 141/141 [00:19<00:00,  7.24it/s]


encoder :  efficientnet-b4
unet loaded
(390, 3, 640, 640)
aa05346ff  96 / 97   10 min 31 secc
encoder :  efficientnet-b4
unet loaded
(390, 3, 640, 640)
aa05346ff  96 / 97   11 min 26 secc
encoder :  efficientnet-b4
unet loaded
(390, 3, 640, 640)
aa05346ff  96 / 97   12 min 20 secc
encoder :  efficientnet-b4
unet loaded
(390, 3, 640, 640)
aa05346ff  96 / 97   13 min 15 secc
encoder :  efficientnet-b4
unet loaded
(390, 3, 640, 640)
aa05346ff  96 / 97   14 min 03 secc


100%|██████████| 307/307 [00:41<00:00,  7.32it/s]


encoder :  efficientnet-b4
unet loaded
(230, 3, 640, 640)
d488c759a  56 / 57   15 min 41 secc
encoder :  efficientnet-b4
unet loaded
(230, 3, 640, 640)
d488c759a  56 / 57   16 min 17 secc
encoder :  efficientnet-b4
unet loaded
(230, 3, 640, 640)
d488c759a  56 / 57   16 min 53 secc
encoder :  efficientnet-b4
unet loaded
(230, 3, 640, 640)
d488c759a  56 / 57   17 min 29 secc
encoder :  efficientnet-b4
unet loaded
(230, 3, 640, 640)
d488c759a  56 / 57   18 min 05 secc


100%|██████████| 115/115 [00:15<00:00,  7.54it/s]


          id                                          predicted
0  2ec3f1bb9  60762281 49 60786270 51 60810259 53 60834248 5...
1  3589adb90  68541227 61 68570659 63 68600092 64 68629524 6...
2  57512b7f1  328952562 22 328985801 24 329019041 24 3290522...
3  aa05346ff  52856686 46 52887405 48 52918125 48 52948844 5...
4  d488c759a  494044870 14 494091529 16 494138189 16 4941848...
