<a href="https://colab.research.google.com/github/taguka/atlas/blob/master/protein.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Preparation

In [0]:
# Colab environment setup. libsm6/libxext6 are presumably the shared libraries
# opencv-python needs at import time -- they are installed in the same command.
!apt-get -qq install -y libsm6 libxext6 && pip install -q -U opencv-python
# PyTorch 1.0.0 wheel for CUDA 8.0 / CPython 3.6 (per the wheel filename).
# NOTE(review): this pins the notebook to a Python 3.6 runtime -- confirm the runtime matches.
!pip3 install https://download.pytorch.org/whl/cu80/torch-1.0.0-cp36-cp36m-linux_x86_64.whl
# Remaining packages are installed unpinned.
!pip install torchvision
!pip install pretrainedmodels
!pip install attrdict

In [0]:
# Generate auth tokens for Colab, then mount Google Drive.
from google.colab import auth
auth.authenticate_user()
from google.colab import drive
# Drive is mounted under /content/gdrive/; force_remount=True remounts even if
# a mount already exists, so the cell can be re-run.
drive.mount('/content/gdrive/', force_remount=True)

In [0]:
!pip3 install kaggle
from googleapiclient.discovery import build
import io, os
from googleapiclient.http import MediaIoBaseDownload

drive_service = build('drive', 'v3')
results = drive_service.files().list(
        q="name = 'kaggle.json'", fields="files(id)").execute()
kaggle_api_key = results.get('files', [])
filename = "/content/.kaggle/kaggle.json"
os.makedirs(os.path.dirname(filename), exist_ok=True)
request = drive_service.files().get_media(fileId=kaggle_api_key[0]['id'])
fh = io.FileIO(filename, 'wb')
downloader = MediaIoBaseDownload(fh, request)
done = False
while done is False:
    status, done = downloader.next_chunk()
    print("Download %d%%." % int(status.progress() * 100))
os.chmod(filename, 600)
!mkdir ~/.kaggle
!cp /content/.kaggle/kaggle.json ~/.kaggle/kaggle.json



!kaggle competitions download -c human-protein-atlas-image-classification
!unzip -qq train.zip -d train | awk 'BEGIN {ORS=" "} {if(NR%500==0) print "."}'
!unzip -qq test.zip -d test | awk 'BEGIN {ORS=" "} {if(NR%500==0) print "."}'
!rm test.zip
!rm train.zip

# Config

In [0]:
from attrdict import AttrDict

# Single experiment configuration object; fields are read by attribute access.
config = AttrDict()

# --- Paths ---
# NOTE(review): the 'gdrive/...' paths are relative, while Drive is mounted at
# '/content/gdrive/' -- they presumably rely on the working directory being /content.
config.output_path = 'gdrive/My Drive/protein/output/'
config.data_path = '/content'
config.submission_path = 'gdrive/My Drive/protein/submissions/'
config.target_file = 'gdrive/My Drive/protein/labels.csv'
config.checkpoint = None

# --- Model / task ---
config.num_classes=28
config.num_channels=3
config.model = 'resnet50'
config.opt = 'sgd'
config.loss = 'bce'
config.fold=0

# --- Training schedule and optimisation ---
config.batch_size = 8
config.img_size = 512
config.epochs=50
config.decay_epochs=15
config.ft_epochs=0
config.ft_opt='sgd'
# NOTE(review): ft_lr is set to the string 'sgd', identical to ft_opt above.
# This looks like a copy-paste slip; a numeric learning rate is presumably
# expected -- confirm against the code that consumes it.
config.ft_lr='sgd'
config.dropout=0.1
config.lr = 0.0010
config.momentum=0.9
config.weight_decay=0.0005
config.seed=10

# --- Logging / runtime switches ---
config.log_interval=1000
config.print_freq=10
config.no_cuda=False
config.external_data = True
config.use_sampler = True
config.exp='stage_1'

config.save_batches=False
config.class_weights=True
# Three parallel per-stage lists (same length); their consumers are not visible in this section.
config.num_epochs = [1, 8, 8]
config.cycles_len = [0, 2, 4]
config.lr_divs = [0, 4, 12]


In [0]:
import torch
from torch.utils.data.sampler import Sampler


class WeightedRandomOverSampler(Sampler):
    """Over-sample indices ``0 .. len(weights) - 1`` by a given factor.

    Every index is emitted at least once per pass. The remaining
    ``num_samples - len(weights)`` draws are sampled with replacement
    according to ``weights``.

    Arguments:
        weights (list): per-element weights; they need not sum to one.
        factor (float): over-sampling factor (>= 1.0). One pass yields
            ``int(len(weights) * factor)`` indices.
    """

    def __init__(self, weights, factor=2.):
        self.weights = torch.DoubleTensor(weights)
        assert factor >= 1.
        self.num_samples = int(len(self.weights) * factor)

    def __iter__(self):
        """Return an iterator over a freshly shuffled list of Python int indices."""
        base_samples = torch.arange(0, len(self.weights)).long()
        remaining = self.num_samples - len(self.weights)
        if remaining > 0:
            over_samples = torch.multinomial(self.weights, remaining, True)
            samples = torch.cat((base_samples, over_samples), dim=0)
        else:
            # factor == 1.0 (or empty weights): nothing extra to draw, and
            # torch.multinomial must not be asked for zero samples.
            samples = base_samples
        print('num samples', len(samples))
        shuffled = samples[torch.randperm(len(samples))]
        # Plain ints rather than 0-dim tensors, so any dataset __getitem__
        # (list, dict, pandas) can consume the indices.
        return iter(shuffled.tolist())

    def __len__(self):
        """Number of indices produced by one pass."""
        return self.num_samples


In [0]:
from typing import Tuple
import albumentations as album
def get_transforms(image_size: int) -> Tuple[album.Compose, album.Compose, album.Compose]:
    """Build the albumentations pipelines for training, testing and test-time augmentation.

    Args:
        image_size: side length, in pixels, of the square every image is resized to.

    Returns:
        ``(transforms_train, transforms_test, transforms_test_aug)``:
        * train: resize, random rotation, random 90-degree rotation, horizontal
          flip, brightness/contrast jitter, normalisation;
        * test: resize and normalisation only;
        * test_aug: the train pipeline without the brightness/contrast jitter.
    """
    # Local import keeps this cell runnable on its own. albumentations takes
    # OpenCV interpolation flags: PIL's Image.BICUBIC has the integer value 3,
    # which OpenCV reads as INTER_AREA, so the cv2 constant is required to get
    # bicubic resampling.
    import cv2

    interpolation = cv2.INTER_CUBIC
    # Per-channel normalisation statistics shared by all three pipelines.
    mean = [0.08069, 0.05258, 0.05487]
    std = [0.13704, 0.10145, 0.15313]

    transforms_train = album.Compose([
        album.Resize(image_size, image_size, interpolation=interpolation),
        album.Rotate(interpolation=interpolation),
        album.RandomRotate90(),
        album.HorizontalFlip(),
        album.RandomBrightnessContrast(),
        album.Normalize(mean, std)
    ])

    transforms_test = album.Compose([
        album.Resize(image_size, image_size, interpolation=interpolation),
        album.Normalize(mean, std)
    ])

    transforms_test_aug = album.Compose([
        album.Resize(image_size, image_size, interpolation=interpolation),
        album.Rotate(interpolation=interpolation),
        album.RandomRotate90(),
        album.HorizontalFlip(),
        album.Normalize(mean, std)
    ])

    return transforms_train, transforms_test, transforms_test_aug

# Optimizers and learning-rate schedulers

In [0]:
"""Reduce on Plateau Learning Rate Scheduler
Taken from pytorch master to use with 0.12 release.
"""
from torch.optim import Optimizer


class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr


class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. Default: 10.
        verbose (bool): If True, prints a message to stdout for
            each update. Default: False.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Raises:
        ValueError: if ``factor >= 1.0``, if ``min_lr`` is a sequence whose
            length differs from the number of param groups, or if ``mode`` /
            ``threshold_mode`` is not one of the accepted values.
        TypeError: if ``optimizer`` is not an ``Optimizer``.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            # A scalar lower bound applies to every param group.
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one epoch's metric and reduce the lr if it has plateaued.

        Args:
            metrics: value of the monitored quantity for this epoch.
            epoch: epoch index; defaults to ``last_epoch + 1``.
        """
        current = metrics
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor``, clamped at that group's minimum."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            # Skip updates too small to matter.
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown period is still running."""
        return self.cooldown_counter > 0

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Install the comparison function and worst-case start value for ``mode``."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            # Report the offending threshold_mode, not the (valid) mode.
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            self.is_better = lambda a, best: a < best * rel_epsilon
            self.mode_worse = float('Inf')
        elif mode == 'min' and threshold_mode == 'abs':
            self.is_better = lambda a, best: a < best - threshold
            self.mode_worse = float('Inf')
        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            self.is_better = lambda a, best: a > best * rel_epsilon
            self.mode_worse = -float('Inf')
        else:  # mode == 'max' and threshold_mode == 'abs'
            self.is_better = lambda a, best: a > best + threshold
            self.mode_worse = -float('Inf')


In [0]:
"""Yellowfin Optimizer
Sourced from: https://github.com/JianGoForIt/YellowFin_Pytorch (MIT License)
"""

import math
# for torch optim sgd
import numpy as np
import torch

class YFOptimizer(object):
  """Wrapper around ``torch.optim.SGD`` that re-tunes lr and momentum after every step.

  Each call to ``step`` runs the wrapped SGD update, refreshes running
  gradient statistics (``after_apply``) and writes the newly estimated
  learning rate and momentum back into the SGD param groups
  (``update_hyper_param``).
  """
  def __init__(self, var_list, lr=0.1, mu=0.0, clip_thresh=None, weight_decay=0.0,
    beta=0.999, curv_win_width=20, zero_debias=True, delta_mu=0.0):
    '''
    clip thresh is the threshold value on ||lr * gradient||
    delta_mu can be place holder/variable/python scalar. They are used for additional
    momentum in situations such as asynchronous-parallel training. The default is 0.0
    for basic usage of the optimizer.
    Args:
      var_list: parameters (or param-group dicts) to optimise; materialised into a list.
      lr: python scalar. The initial value of learning rate, we use 1.0 in our paper.
      mu: python scalar. The initial value of momentum, we use 0.0 in our paper.
      clip_thresh: python scalar. The clipping threshold passed to
        torch.nn.utils.clip_grad_norm. If None, no clipping will be carried out.
      weight_decay: python scalar forwarded to the wrapped SGD optimizer.
      beta: python scalar. The smoothing parameter for estimations.
      curv_win_width: number of recent squared gradient norms kept in the curvature window.
      zero_debias: if True, running averages are divided by 1 - beta ** (iter + 1).
      delta_mu: for extensions. Not necessary in the basic use. (TODO)
        NOTE(review): this argument is not read anywhere in the class.
    Other features:
      If you want to manually control the learning rates, self.lr_factor is
      an interface to the outside, it is a multiplier for the internal learning rate
      in YellowFin. It is helpful when you want to do additional hand tuning
      or some decaying scheme to the tuned learning rate in YellowFin.
      Example on using lr_factor can be found here:
      (TODO)
    '''
    self._lr = lr
    self._mu = mu
    # we convert var_list from generator to list so that
    # it can be used for multiple times
    self._var_list = list(var_list)
    self._clip_thresh = clip_thresh
    self._beta = beta
    self._curv_win_width = curv_win_width
    self._zero_debias = zero_debias
    self._optimizer = torch.optim.SGD(self._var_list, lr=self._lr, 
      momentum=self._mu, weight_decay=weight_decay)
    self._iter = 0
    # global states are the statistics
    self._global_state = {}

    # for decaying learning rate and etc.
    self._lr_factor = 1.0
    pass


  def state_dict(self):
    """Return the wrapped SGD state plus this wrapper's statistics, counters, lr and mu."""
    sgd_state_dict = self._optimizer.state_dict()
    global_state = self._global_state
    lr_factor = self._lr_factor
    iter = self._iter
    lr = self._lr
    mu = self._mu

    return {
      "sgd_state_dict": sgd_state_dict,
      "global_state": global_state,
      "lr_factor": lr_factor,
      "iter": iter,
      "lr": lr,
      "mu": mu,
    }


  def load_state_dict(self, state_dict):
    """Restore everything written by ``state_dict``."""
    self._optimizer.load_state_dict(state_dict['sgd_state_dict'])
    self._global_state = state_dict['global_state']
    self._lr_factor = state_dict['lr_factor']
    self._iter = state_dict['iter']
    self._lr = state_dict['lr']
    self._mu = state_dict['mu']


  def set_lr_factor(self, factor):
    """Set the external multiplier applied to the tuned learning rate."""
    self._lr_factor = factor
    return


  def get_lr_factor(self):
    """Return the external learning-rate multiplier."""
    return self._lr_factor


  def zero_grad(self):
    """Clear gradients via the wrapped SGD optimizer."""
    self._optimizer.zero_grad()


  def zero_debias_factor(self):
    """Bias-correction divisor for the running averages: 1 - beta ** (iter + 1)."""
    return 1.0 - self._beta ** (self._iter + 1)


  def curvature_range(self):
    """Update ``_h_min`` / ``_h_max`` from a sliding window of squared gradient norms."""
    global_state = self._global_state
    if self._iter == 0:
      global_state["curv_win"] = torch.FloatTensor(self._curv_win_width, 1).zero_()
    curv_win = global_state["curv_win"]
    grad_norm_squared = self._global_state["grad_norm_squared"]
    # The window is a ring buffer: the slot is chosen by iteration modulo width.
    curv_win[self._iter % self._curv_win_width] = grad_norm_squared
    # Only the slots written so far are considered until the window is full.
    valid_end = min(self._curv_win_width, self._iter + 1)
    beta = self._beta
    if self._iter == 0:
      global_state["h_min_avg"] = 0.0
      global_state["h_max_avg"] = 0.0
      self._h_min = 0.0
      self._h_max = 0.0
    # Exponential moving averages of the window minimum and maximum.
    global_state["h_min_avg"] = \
      global_state["h_min_avg"] * beta + (1 - beta) * torch.min(curv_win[:valid_end] )
    global_state["h_max_avg"] = \
      global_state["h_max_avg"] * beta + (1 - beta) * torch.max(curv_win[:valid_end] )
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
      self._h_min = global_state["h_min_avg"] / debias_factor
      self._h_max = global_state["h_max_avg"] / debias_factor
    else:
      self._h_min = global_state["h_min_avg"]
      self._h_max = global_state["h_max_avg"]
    return


  def grad_variance(self):
    """Update ``_grad_var``: averaged squared gradient norm minus squared norm of the averaged gradient."""
    global_state = self._global_state
    beta = self._beta
    self._grad_var = np.array(0.0, dtype=np.float32)
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        state = self._optimizer.state[p]

        if self._iter == 0:
          state["grad_avg"] = grad.new().resize_as_(grad).zero_()
          # NOTE(review): "grad_avg_squared" is initialised here but not read
          # anywhere in this class.
          state["grad_avg_squared"] = 0.0
        # Per-parameter exponential moving average of the gradient.
        state["grad_avg"].mul_(beta).add_(1 - beta, grad)
        # NOTE(review): this adds a torch scalar into a numpy scalar; presumably
        # relies on the gradients living on the CPU -- confirm for CUDA runs.
        self._grad_var += torch.sum(state["grad_avg"] * state["grad_avg"] )
        
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
    else:
      debias_factor = 1.0

    # Negate the (debiased) squared norm of the mean gradient, then add the
    # (debiased) mean squared gradient norm.
    self._grad_var /= -(debias_factor**2)
    self._grad_var += global_state['grad_norm_squared_avg'] / debias_factor
    return


  def dist_to_opt(self):
    """Update ``_dist_to_opt`` from running averages of the gradient norm."""
    global_state = self._global_state
    beta = self._beta
    if self._iter == 0:
      global_state["grad_norm_avg"] = 0.0
      global_state["dist_to_opt_avg"] = 0.0
    global_state["grad_norm_avg"] = \
      global_state["grad_norm_avg"] * beta + (1 - beta) * math.sqrt(global_state["grad_norm_squared"] )
    # Moving average of (averaged gradient norm) / (averaged squared gradient norm).
    global_state["dist_to_opt_avg"] = \
      global_state["dist_to_opt_avg"] * beta \
      + (1 - beta) * global_state["grad_norm_avg"] / global_state['grad_norm_squared_avg']
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
      self._dist_to_opt = global_state["dist_to_opt_avg"] / debias_factor
    else:
      self._dist_to_opt = global_state["dist_to_opt_avg"]
    return


  def after_apply(self):
    """Refresh all gradient statistics and, after the first iteration, re-estimate lr and mu."""
    # compute running average of gradient and norm of gradient
    beta = self._beta
    global_state = self._global_state
    if self._iter == 0:
      global_state["grad_norm_squared_avg"] = 0.0

    # Squared L2 norm of the gradient summed over every parameter that has one.
    global_state["grad_norm_squared"] = 0.0
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        # global_state['grad_norm_squared'] += torch.dot(grad, grad)
        global_state['grad_norm_squared'] += torch.sum(grad * grad)
        
    global_state['grad_norm_squared_avg'] = \
      global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared']
    # global_state['grad_norm_squared_avg'].mul_(beta).add_(1 - beta, global_state['grad_norm_squared'] )
        
    # Order matters: the three estimators read the values computed above.
    self.curvature_range()
    self.grad_variance()
    self.dist_to_opt()
    if self._iter > 0:
      # get_mu must run first: get_lr reads self._mu_t.
      self.get_mu()    
      self.get_lr()
      # Smooth the per-step targets into the values actually applied.
      self._lr = beta * self._lr + (1 - beta) * self._lr_t
      self._mu = beta * self._mu + (1 - beta) * self._mu_t
    return


  def get_lr(self):
    """Set the target learning rate ``_lr_t`` = (1 - sqrt(mu_t))**2 / h_min."""
    self._lr_t = (1.0 - math.sqrt(self._mu_t) )**2 / self._h_min
    return


  def get_mu(self):
    """Set the target momentum ``_mu_t`` from a cubic root and the curvature ratio."""
    # Cubic coefficients, highest degree first; coef[2] depends on the current statistics.
    coef = [-1.0, 3.0, 0.0, 1.0]
    coef[2] = -(3 + self._dist_to_opt**2 * self._h_min**2 / 2 / self._grad_var)
    roots = np.roots(coef)
    # Keep roots with real part strictly inside (0, 1) and imaginary part below 1e-5.
    root = roots[np.logical_and(np.logical_and(np.real(roots) > 0.0, 
      np.real(roots) < 1.0), np.imag(roots) < 1e-5) ]
    # Exactly one such root is expected; otherwise this raises AssertionError.
    assert root.size == 1
    dr = self._h_max / self._h_min
    self._mu_t = max(np.real(root)[0]**2, ( (np.sqrt(dr) - 1) / (np.sqrt(dr) + 1) )**2 )
    return 


  def update_hyper_param(self):
    """Write the current momentum and lr (scaled by ``_lr_factor``) into every SGD param group."""
    for group in self._optimizer.param_groups:
      group['momentum'] = self._mu
      group['lr'] = self._lr * self._lr_factor
    return


  def step(self):
    """Run one optimisation step: optional clipping, SGD update, statistics, hyper-parameter refresh."""
    # add weight decay
    # NOTE(review): the decayed gradient is bound to a local name only and never
    # written back to p.grad, so this loop has no visible effect. The wrapped SGD
    # was constructed with the same weight_decay, so it presumably applies the
    # decay itself -- confirm before removing.
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
            continue
        grad = p.grad.data

        if group['weight_decay'] != 0:
            grad = grad.add(group['weight_decay'], p.data)
    
    #if self._clip_thresh != None:
    #  torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh)
    if self._clip_thresh is not None:
      # var_list may hold param-group dicts; flatten them into a list of parameters.
      if isinstance(self._var_list[0], dict):
        params = []
        for p in self._var_list:
          params.extend(p['params'])
        torch.nn.utils.clip_grad_norm(params, self._clip_thresh)
      else:
        torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh)
    
    # apply update
    self._optimizer.step()

    # after appply
    self.after_apply()

    # update learning rate and momentum
    self.update_hyper_param()

    self._iter += 1
    return 



In [0]:
import numpy as np
import os
import torch
import torch.nn
from sklearn.metrics import fbeta_score
import shutil
class AverageMeter:
    """Tracks the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def calc_crop_size(target_w, target_h, angle=0.0, scale=1.0):
    """Return the crop size that still covers the target after rotation and scaling.

    Args:
        target_w, target_h: width and height of the target rectangle.
        angle: rotation in degrees; a falsy value skips the rotation step.
        scale: factor the resulting size is divided by.

    Returns:
        ``(crop_w, crop_h)`` as ints, rounded up.
    """
    if not angle:
        width, height = target_w, target_h
    else:
        half_w, half_h = target_w / 2, target_h / 2
        # Corners of the target rectangle centred on the origin (row 0: x, row 1: y).
        corners = np.array(
            [[half_w, -half_w, -half_w, half_w],
             [half_h, half_h, -half_h, -half_h]])
        sin_a = np.sin(angle * np.pi/180)
        cos_a = np.cos(angle * np.pi/180)
        rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
        rotated = np.dot(rotation, corners)
        # Axis-aligned bounding box of the rotated rectangle.
        width = 2 * np.max(np.abs(rotated[0, :]))
        height = 2 * np.max(np.abs(rotated[1, :]))
    return int(np.ceil(width / scale)), int(np.ceil(height / scale))


def crop_center(img, cx, cy, crop_w, crop_h):
    """Crop a ``crop_w`` x ``crop_h`` window centred on ``(cx, cy)``, zero-padding past the borders.

    Args:
        img: array indexed as ``img[row, col, ...]``; both 2-D (grayscale) and
            3-D (with trailing channels) inputs are accepted.
        cx, cy: centre of the window in column / row coordinates.
        crop_w, crop_h: window width and height.

    Returns:
        A view ``img[top:bottom, left:right]`` when the window lies inside the
        image. Otherwise a new zero-filled array of shape
        ``(crop_h, crop_w) + img.shape[2:]`` with the overlapping part copied in.
    """
    img_h, img_w = img.shape[:2]
    trunc_top = trunc_bottom = trunc_left = trunc_right = 0
    left = cx - crop_w//2
    if left < 0:
        trunc_left = 0 - left
        left = 0
    # Subtracting trunc_left recovers the unclipped left edge before adding the width.
    right = left - trunc_left + crop_w
    if right > img_w:
        trunc_right = right - img_w
        right = img_w
    top = cy - crop_h//2
    if top < 0:
        trunc_top = 0 - top
        top = 0
    bottom = top - trunc_top + crop_h
    if bottom > img_h:
        trunc_bottom = bottom - img_h
        bottom = img_h
    if trunc_left or trunc_right or trunc_top or trunc_bottom:
        # Keep whatever trailing axes the input has, so 2-D images work as
        # well as multi-channel ones.
        img_new = np.zeros((crop_h, crop_w) + tuple(img.shape[2:]), dtype=img.dtype)
        # Convert the truncation amounts into end indices within the padded output.
        trunc_bottom = crop_h - trunc_bottom
        trunc_right = crop_w - trunc_right
        img_new[trunc_top:trunc_bottom, trunc_left:trunc_right] = img[top:bottom, left:right]
        return img_new
    else:
        return img[top:bottom, left:right]


def crop_points_center(points, cx, cy, crop_w, crop_h):
    """Keep the rows of ``points`` that fall inside a window centred on ``(cx, cy)``.

    Column 0 of ``points`` is compared against the x range and column 1 against
    the y range; both ranges are half-open (lower bound inclusive).
    """
    x_min = cx - crop_w // 2
    y_min = cy - crop_h // 2
    x_max = x_min + crop_w
    y_max = y_min + crop_h
    xs, ys = points[:, 0], points[:, 1]
    inside = (xs >= x_min) & (xs < x_max) & (ys >= y_min) & (ys < y_max)
    return points[inside]


def crop_points(points, x, y, crop_w, crop_h):
    """Keep the rows of ``points`` inside the window whose top-left corner is ``(x, y)``.

    Column 0 must lie in ``[x, x + crop_w)`` and column 1 in ``[y, y + crop_h)``.
    """
    xs, ys = points[:, 0], points[:, 1]
    in_x = (xs >= x) & (xs < x + crop_w)
    in_y = (ys >= y) & (ys < y + crop_h)
    return points[in_x & in_y]

def adjust_learning_rate(optimizer, epoch, initial_lr, decay_epochs=30):
    """Step-decay the learning rate: ``initial_lr`` times 0.1 for every ``decay_epochs`` completed.

    ``YFOptimizer`` instances are left untouched. Every param group of any
    other optimizer is set to the decayed rate.
    """
    if not isinstance(optimizer, YFOptimizer):
        decayed_lr = initial_lr * (0.1 ** (epoch // decay_epochs))
        for group in optimizer.param_groups:
            group['lr'] = decayed_lr


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', output_dir=''):
    """Serialise ``state`` with ``torch.save`` to ``output_dir/filename``.

    When ``is_best`` is truthy the saved file is also copied to
    ``output_dir/model_best.pth.tar``.
    """
    checkpoint_file = os.path.join(output_dir, filename)
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    best_file = os.path.join(output_dir, 'model_best.pth.tar')
    shutil.copyfile(checkpoint_file, best_file)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor; ``topk`` is taken along dimension 1.
        target: tensor of class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose
        target index is among the k highest-scoring entries.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` derives from a transposed tensor and is
        # not guaranteed to be contiguous, in which case view(-1) raises.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def scores(output, target, threshold=0.2):
    """Per-sample accuracy, precision, recall and F2, averaged over the batch.

    ``output`` is binarised with ``output > threshold`` and compared with
    ``target`` row by row. A row with no true positives, false positives or
    false negatives scores 1.0 on precision/recall/F2. A row with errors but
    no true positives scores 0.0.

    Returns:
        ``(accuracy, precision, recall, fmeasure)`` as Python floats.
    """
    predicted = (output > threshold).long()
    actual = target.long()

    def _count_nonzero(values):
        # Number of non-zero entries, as a Python int.
        return int((values != 0).sum())

    acc_total = prec_total = rec_total = f2_total = 0.0
    n_rows = 0
    for pred_row, true_row in zip(predicted, actual):
        # Count true positives, true negatives, false positives and false negatives.
        tp = _count_nonzero(pred_row * true_row)
        tn = _count_nonzero((pred_row - 1) * (true_row - 1))
        fp = _count_nonzero(pred_row * (true_row - 1))
        fn = _count_nonzero((pred_row - 1) * true_row)

        row_acc = (tp + tn) / (tp + fp + fn + tn)
        if tp > 0:
            row_prec = tp / (tp + fp)
            row_rec = tp / (tp + fn)
            row_f2 = (5 * row_prec * row_rec) / (4 * row_prec + row_rec)
        elif fp == 0 and fn == 0:
            # Nothing to find and nothing predicted: count as perfect.
            row_prec = row_rec = row_f2 = 1.0
        else:
            row_prec = row_rec = row_f2 = 0.0

        acc_total += row_acc
        prec_total += row_prec
        rec_total += row_rec
        f2_total += row_f2
        n_rows += 1

    return (acc_total / n_rows, prec_total / n_rows,
            rec_total / n_rows, f2_total / n_rows)


def f2_score(output, target, threshold):
    """Sample-averaged F-beta score (beta=2) of ``output > threshold`` against ``target``."""
    binarised = output > threshold
    return fbeta_score(target, binarised, beta=2, average='samples')


def optimise_f2_thresholds(y, p, verbose=True, resolution=100):
    """ Find optimal threshold values for f2 score. Thanks Anokas
    https://www.kaggle.com/c/planet-understanding-the-amazon-from-space/discussion/32475

    Thresholds are tuned one class at a time: for class ``i`` every value
    ``k / resolution`` (k = 0 .. resolution - 1) is tried while the other
    thresholds stay fixed, and the best-scoring one is kept.

    Args:
        y: 2-D array of ground-truth labels; ``y.shape[1]`` is the class count.
        p: 2-D array of predicted scores, indexed like ``y``.
        verbose: print the chosen threshold and score for each class.
        resolution: number of candidate thresholds tried per class.

    Returns:
        ``(thresholds, best_score)``: the per-class thresholds and the score
        reached while tuning the last class (0 when there are no classes).
    """
    size = y.shape[1]

    def mf(x):
        # Binarise each class column with its own threshold and score the result.
        p2 = np.zeros_like(p)
        for i in range(size):
            # Builtin int: the np.int alias was removed from NumPy.
            p2[:, i] = (p[:, i] > x[i]).astype(int)
        score = fbeta_score(y, p2, beta=2, average='samples')
        return score

    x = [0.2] * size
    # Defined up front so an input with zero classes still returns cleanly.
    best_score = 0
    for i in range(size):
        best_i2 = 0
        best_score = 0
        for step in range(resolution):
            i2 = step / resolution
            x[i] = i2
            score = mf(x)
            if score > best_score:
                best_i2 = i2
                best_score = score
        x[i] = best_i2
        if verbose:
            print(i, best_i2, best_score)

    return x, best_score

# Loss

In [0]:
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss on raw logits for multi-label targets.

    The element-wise binary cross-entropy is scaled by
    ``exp(gamma * logsigmoid(-x * (2t - 1)))``, summed over dimension 1 and
    averaged over the remaining dimension(s).
    """

    def __init__(self, gamma: float = 2):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the scalar loss; raises ValueError when the two sizes differ."""
        if input.size() != target.size():
            raise ValueError(f'Target size ({target.size()}) must be the same as input size ({input.size()})')

        logits = input.float()
        labels = target.float()

        # Binary cross-entropy with logits, shifted by max(-x, 0) before exponentiating.
        shift = (-logits).clamp(min=0)
        log_term = ((-shift).exp() + (-logits - shift).exp()).log()
        bce = logits - logits * labels + shift + log_term

        # Modulating factor built from the log-sigmoid of the signed logit.
        log_weight = F.logsigmoid(-logits * (labels * 2.0 - 1.0))
        focal = (log_weight * self.gamma).exp() * bce

        return focal.sum(dim=1).mean()


class BCEWithLogitsLoss(nn.Module):
    """Wrapper around ``nn.BCEWithLogitsLoss`` that casts both inputs to float first."""

    def __init__(self, weight=None):
        super().__init__()
        criterion = nn.BCEWithLogitsLoss(weight)
        self.loss = criterion

    def forward(self, y_pred, y):
        """Return the wrapped loss of ``y_pred`` (logits) against ``y``."""
        logits = y_pred.float()
        labels = y.float()
        return self.loss(logits, labels)


# Utils

In [0]:
import numpy as np
import os
import torch
import torch.nn
from sklearn.metrics import fbeta_score
import shutil

class AverageMeter:
    """Tracks the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def calc_crop_size(target_w, target_h, angle=0.0, scale=1.0):
    """Return the crop size that still covers the target after rotation and scaling.

    Args:
        target_w, target_h: width and height of the target rectangle.
        angle: rotation in degrees; a falsy value skips the rotation step.
        scale: factor the resulting size is divided by.

    Returns:
        ``(crop_w, crop_h)`` as ints, rounded up.
    """
    if not angle:
        width, height = target_w, target_h
    else:
        half_w, half_h = target_w / 2, target_h / 2
        # Corners of the target rectangle centred on the origin (row 0: x, row 1: y).
        corners = np.array(
            [[half_w, -half_w, -half_w, half_w],
             [half_h, half_h, -half_h, -half_h]])
        sin_a = np.sin(angle * np.pi/180)
        cos_a = np.cos(angle * np.pi/180)
        rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
        rotated = np.dot(rotation, corners)
        # Axis-aligned bounding box of the rotated rectangle.
        width = 2 * np.max(np.abs(rotated[0, :]))
        height = 2 * np.max(np.abs(rotated[1, :]))
    return int(np.ceil(width / scale)), int(np.ceil(height / scale))


def crop_center(img, cx, cy, crop_w, crop_h):
    """Crop a ``crop_w`` x ``crop_h`` window centred on ``(cx, cy)``, zero-padding past the borders.

    Args:
        img: array indexed as ``img[row, col, ...]``; both 2-D (grayscale) and
            3-D (with trailing channels) inputs are accepted.
        cx, cy: centre of the window in column / row coordinates.
        crop_w, crop_h: window width and height.

    Returns:
        A view ``img[top:bottom, left:right]`` when the window lies inside the
        image. Otherwise a new zero-filled array of shape
        ``(crop_h, crop_w) + img.shape[2:]`` with the overlapping part copied in.
    """
    img_h, img_w = img.shape[:2]
    trunc_top = trunc_bottom = trunc_left = trunc_right = 0
    left = cx - crop_w//2
    if left < 0:
        trunc_left = 0 - left
        left = 0
    # Subtracting trunc_left recovers the unclipped left edge before adding the width.
    right = left - trunc_left + crop_w
    if right > img_w:
        trunc_right = right - img_w
        right = img_w
    top = cy - crop_h//2
    if top < 0:
        trunc_top = 0 - top
        top = 0
    bottom = top - trunc_top + crop_h
    if bottom > img_h:
        trunc_bottom = bottom - img_h
        bottom = img_h
    if trunc_left or trunc_right or trunc_top or trunc_bottom:
        # Keep whatever trailing axes the input has, so 2-D images work as
        # well as multi-channel ones.
        img_new = np.zeros((crop_h, crop_w) + tuple(img.shape[2:]), dtype=img.dtype)
        # Convert the truncation amounts into end indices within the padded output.
        trunc_bottom = crop_h - trunc_bottom
        trunc_right = crop_w - trunc_right
        img_new[trunc_top:trunc_bottom, trunc_left:trunc_right] = img[top:bottom, left:right]
        return img_new
    else:
        return img[top:bottom, left:right]


def crop_points_center(points, cx, cy, crop_w, crop_h):
    """Keep the rows of ``points`` that fall inside a window centred on ``(cx, cy)``.

    Column 0 of ``points`` is compared against the x range and column 1 against
    the y range; both ranges are half-open (lower bound inclusive).
    """
    x_min = cx - crop_w // 2
    y_min = cy - crop_h // 2
    x_max = x_min + crop_w
    y_max = y_min + crop_h
    xs, ys = points[:, 0], points[:, 1]
    inside = (xs >= x_min) & (xs < x_max) & (ys >= y_min) & (ys < y_max)
    return points[inside]


def crop_points(points, x, y, crop_w, crop_h):
    """Keep the rows of ``points`` inside the window whose top-left corner is ``(x, y)``.

    Column 0 must lie in ``[x, x + crop_w)`` and column 1 in ``[y, y + crop_h)``.
    """
    xs, ys = points[:, 0], points[:, 1]
    in_x = (xs >= x) & (xs < x + crop_w)
    in_y = (ys >= y) & (ys < y + crop_h)
    return points[in_x & in_y]

def adjust_learning_rate(optimizer, epoch, initial_lr, decay_epochs=30):
    """Step-decay the learning rate: ``initial_lr`` times 0.1 for every ``decay_epochs`` completed.

    ``YFOptimizer`` instances are left untouched. Every param group of any
    other optimizer is set to the decayed rate.
    """
    if not isinstance(optimizer, YFOptimizer):
        decayed_lr = initial_lr * (0.1 ** (epoch // decay_epochs))
        for group in optimizer.param_groups:
            group['lr'] = decayed_lr


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', output_dir=''):
    """Serialise ``state`` with ``torch.save`` to ``output_dir/filename``.

    When ``is_best`` is truthy the saved file is also copied to
    ``output_dir/model_best.pth.tar``.
    """
    checkpoint_file = os.path.join(output_dir, filename)
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    best_file = os.path.join(output_dir, 'model_best.pth.tar')
    shutil.copyfile(checkpoint_file, best_file)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor; ``topk`` is taken along dimension 1.
        target: tensor of class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose
        target index is among the k highest-scoring entries.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` derives from a transposed tensor and is
        # not guaranteed to be contiguous, in which case view(-1) raises.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def scores(output, target, threshold=0.2):
    """Per-sample accuracy, precision, recall and F2, averaged over the batch.

    ``output`` is binarised with ``output > threshold`` and compared with
    ``target`` row by row. A row with no true positives, false positives or
    false negatives scores 1.0 on precision/recall/F2. A row with errors but
    no true positives scores 0.0.

    Returns:
        ``(accuracy, precision, recall, fmeasure)`` as Python floats.
    """
    predicted = (output > threshold).long()
    actual = target.long()

    def _count_nonzero(values):
        # Number of non-zero entries, as a Python int.
        return int((values != 0).sum())

    acc_total = prec_total = rec_total = f2_total = 0.0
    n_rows = 0
    for pred_row, true_row in zip(predicted, actual):
        # Count true positives, true negatives, false positives and false negatives.
        tp = _count_nonzero(pred_row * true_row)
        tn = _count_nonzero((pred_row - 1) * (true_row - 1))
        fp = _count_nonzero(pred_row * (true_row - 1))
        fn = _count_nonzero((pred_row - 1) * true_row)

        row_acc = (tp + tn) / (tp + fp + fn + tn)
        if tp > 0:
            row_prec = tp / (tp + fp)
            row_rec = tp / (tp + fn)
            row_f2 = (5 * row_prec * row_rec) / (4 * row_prec + row_rec)
        elif fp == 0 and fn == 0:
            # Nothing to find and nothing predicted: count as perfect.
            row_prec = row_rec = row_f2 = 1.0
        else:
            row_prec = row_rec = row_f2 = 0.0

        acc_total += row_acc
        prec_total += row_prec
        rec_total += row_rec
        f2_total += row_f2
        n_rows += 1

    return (acc_total / n_rows, prec_total / n_rows,
            rec_total / n_rows, f2_total / n_rows)


def f2_score(output, target, threshold):
    """Sample-averaged F-beta score (beta=2) of ``output`` binarised at ``threshold``.

    Entries strictly greater than ``threshold`` count as positive predictions.
    """
    predictions = output > threshold
    return fbeta_score(target, predictions, beta=2, average='samples')


def optimise_f2_thresholds(y, p, verbose=True, resolution=100):
    """ Find optimal threshold values for f2 score. Thanks Anokas
    https://www.kaggle.com/c/planet-understanding-the-amazon-from-space/discussion/32475

    Greedy coordinate search: every threshold starts at 0.2, then one class at
    a time is swept over ``k / resolution`` for ``k`` in ``range(resolution)``
    while the other thresholds stay fixed, keeping the value that gives the
    highest sample-averaged F2.

    Args:
        y: 2-D array of true labels, one column per class.
        p: 2-D array of predicted scores with the same layout as ``y``.
        verbose: Print the chosen threshold and score after each class.
        resolution: Number of candidate thresholds tried per class.

    Returns:
        ``(thresholds, best_score)`` where ``thresholds`` is a list with one
        value per class and ``best_score`` is the F2 reached with them
        (0 when ``y`` has no columns).
    """
    size = y.shape[1]

    def mf(x):
        # F2 of the predictions binarised with per-class thresholds x.
        p2 = np.zeros_like(p)
        for i in range(size):
            # Builtin int: the np.int alias was removed in NumPy 1.24.
            p2[:, i] = (p[:, i] > x[i]).astype(int)
        score = fbeta_score(y, p2, beta=2, average='samples')
        return score

    x = [0.2] * size
    # Defined before the loop so an input with zero classes returns ([], 0)
    # instead of failing on an unbound name.
    best_score = 0
    for i in range(size):
        best_i2 = 0
        best_score = 0
        for i2 in range(resolution):
            i2 /= resolution
            x[i] = i2
            score = mf(x)
            if score > best_score:
                best_i2 = i2
                best_score = score
        x[i] = best_i2
        if verbose:
            print(i, best_i2, best_score)

    return x, best_score

# Dataset

In [0]:
import cv2
import torch
import torch.utils.data as data
import pandas as pd
import numpy as np
import os
from torch.utils.data.sampler import Sampler
from typing import Tuple, List, Any
import albumentations as album
from PIL import Image
from torchvision import datasets, transforms

#BASE_PATH = 'C:\\Kaggle\\atlas\\rwightman\\data'
#TRAIN_CSV = 'train.csv'

# Lower-case file extensions that find_inputs() accepts by default.
IMG_EXTENSIONS = ['.png']
# Names of the 28 label columns: '0', '1', ..., '27'.
LABELS = [str(class_id) for class_id in range(28)]
"""
def create_class_weight(labels_dict, mu=0.8):
  total = sum(labels_dict.values())
  keys = labels_dict.keys()
  class_weight = dict()
  class_weight_log = dict()

  for key in keys:
      score = total / float(labels_dict[key])
      score_log = math.log(mu * total / float(labels_dict[key]))
      class_weight[key] = round(score, 2) if score > 1.0 else round(1.0, 2)
      class_weight_log[key] = round(score_log, 2) if score_log > 1.0 else round(1.0, 2)
  return class_weight, class_weight_log

train_df = pd.read_csv(os.path.join(BASE_PATH,TRAIN_CSV))
train_df.Target = train_df.Target.map(lambda x: set(x.split()))
count = Counter()
train_df.Target.apply(lambda x: count.update(x))
labels_dict=dict(count)
true_class_weights=create_class_weight(labels_dict)[0]
log_class_weights=create_class_weight(labels_dict)[1]

ALL_WEIGHTS=[true_class_weights[key] for key in sorted(true_class_weights.keys(), 
             key=lambda x:int(x))]
    
ALL_WEIGHTS_L=[log_class_weights[key] for key in sorted(log_class_weights.keys(), 
             key=lambda x:int(x))]
"""   
# Hard-coded per-class weights, one entry per label in LABELS order (28 values).
# The numbers are consistent with `total / count` as computed by the disabled
# create_class_weight snippet kept in the string above -- presumably its saved
# output for train.csv; re-run that snippet if the label file changes.
ALL_WEIGHTS = [3.94, 40.5, 14.02, 32.53, 27.33, 20.21, 50.38, 18.0, 958.15, 
               1128.49, 1813.64, 46.46, 73.81, 94.57, 47.64, 2418.19, 95.82, 
               241.82, 56.3, 34.27, 295.24, 13.45, 63.32, 17.13, 157.71, 6.17, 
               154.82, 4616.55]

# Log-damped variant of the table above (consistent with
# `log(0.8 * total / count)`, floored at 1.0 in the snippet). This is the table
# HumanDataset.get_class_weights() returns.
ALL_WEIGHTS_L = [1.15, 3.48, 2.42, 3.26, 3.08, 2.78, 3.7, 2.67, 6.64, 6.81, 
                 7.28, 3.62, 4.08, 4.33, 3.64, 7.57, 4.34, 5.27, 3.81, 3.31, 
                 5.46, 2.38, 3.93, 2.62, 4.84, 1.6, 4.82, 8.21]


def find_inputs(folder, types=IMG_EXTENSIONS):
    """Recursively collect files below ``folder`` whose extension is in ``types``.

    The extension comparison is case-insensitive on the file side (the file's
    extension is lower-cased before the membership test).

    Returns:
        List of ``(basename_without_extension, path)`` tuples, where ``path``
        is the walked directory joined with the file name.
    """
    found = []
    for dirpath, _, filenames in os.walk(folder, topdown=False):
        found.extend(
            (stem, os.path.join(dirpath, name))
            for name in filenames
            for stem, suffix in [os.path.splitext(name)]
            if suffix.lower() in types)
    return found

def get_test_aug(factor):
    """Return the test-time augmentation table for ``factor``.

    Every entry is a ``[transpose, v-flip, h-flip]`` list of booleans.
    Supported factors are 0/None/1 (identity only), 4 and 8 (all flag
    combinations); anything else prints a message and falls back to the
    identity entry.
    """
    identity_only = [[False, False, False]]

    if not factor or factor == 1:
        return identity_only

    if factor == 4:
        # transpose, v-flip, h-flip
        return [
            [False, False, False],
            [False, False, True],
            [False, True, False],
            [True, True, True]]

    if factor == 8:
        # All 2**3 flag combinations, counting upwards in binary with the
        # transpose flag as the most significant bit.
        return [[bool(code & 4), bool(code & 2), bool(code & 1)]
                for code in range(8)]

    print('Invalid augmentation factor')
    return identity_only


In [0]:
class HumanDataset(data.Dataset):
    """Fold-split image dataset backed by a label CSV.

    The CSV must provide an ``Id`` column, a ``fold`` column and one column per
    entry of ``LABELS``. Each sample is assembled from three grayscale files
    ``<Id>_red.png``, ``<Id>_green.png`` and ``<Id>_blue.png`` under
    ``input_root``, stacked along the last axis.

    Args:
        input_root: Directory containing the image files.
        target_file: Path of the label CSV.
        train: True selects every row whose fold differs from ``fold``;
            False selects only the rows of ``fold``.
        img_size: Stored on the instance; resizing itself is left to ``transform``.
        fold: Fold number used for the split.
        test_aug: Test-time augmentation factor (see ``get_test_aug``); only
            honoured when ``train`` is False.
        num_channels: Accepted for interface compatibility; not used here.
        transform: Callable invoked as ``transform(image=array)`` returning a
            mapping with an ``'image'`` entry (albumentations style). ``None``
            skips this step.
    """

    def __init__(
            self,
            input_root,
            target_file='',
            train=True,
            img_size=512,
            fold=0,
            test_aug=0,
            num_channels=3,
            transform=None):

        inputs = find_inputs(input_root)
        if len(inputs) == 0:
            raise (RuntimeError("Found 0 images in : " + input_root))
        target_df = pd.read_csv(target_file)
        if train:
            target_df = target_df[target_df['fold'] != fold]
        else:
            target_df = target_df[target_df['fold'] == fold]
        # Non-inplace drop with an explicit axis keyword: the positional axis
        # argument is rejected by current pandas, and mutating a filtered
        # frame in place triggers SettingWithCopy warnings.
        target_df = target_df.drop('fold', axis=1)

        self.inputs = target_df['Id'].apply(lambda x: os.path.join(input_root, x)).tolist()
        # `.values` replaces DataFrame.as_matrix(), which pandas has removed.
        self.target_array = target_df[LABELS].values.astype(np.float32)
        self.target_array = torch.from_numpy(self.target_array)
        self.train = train
        self.dataset_mean = [0.0804419, 0.05262986, 0.05474701]
        self.dataset_std = [0.13000701, 0.08796628, 0.1386317]
        self.img_size = img_size
        self.my_transform = transform
        if not train:
            self.test_aug = get_test_aug(test_aug)
        else:
            self.test_aug = []

    def _load_input(self, index):
        """Read the three colour-channel files of image ``index`` as one HxWx3 array."""
        path = self.inputs[index]
        colors = ['red', 'green', 'blue']
        flags = cv2.IMREAD_GRAYSCALE
        img = []
        for color in colors:
            channel_path = path + '_' + color + '.png'
            channel = cv2.imread(channel_path, flags)
            # cv2.imread signals failure by returning None instead of raising.
            if channel is None:
                raise FileNotFoundError('Could not read image: ' + channel_path)
            img.append(channel)
        return np.stack(img, axis=-1)

    @staticmethod
    def _apply_test_aug(img, transpose, vflip, hflip):
        """Apply one test-time augmentation entry to an HxWxC array."""
        if transpose:
            img = img.transpose(1, 0, 2)
        if vflip:
            img = img[::-1]
        if hflip:
            img = img[:, ::-1]
        # The operations above only create views (possibly with negative
        # strides); hand a plain contiguous array to downstream libraries.
        return np.ascontiguousarray(img)

    def __getitem__(self, index):
        """Return ``(image_tensor, target_tensor, index_tensor)`` for item ``index``.

        With a test-time augmentation factor above 1 the dataset length is
        ``n_images * factor``; consecutive items are the augmented variants of
        one image, so ``index`` is split into an image index and an
        augmentation index. ``index_tensor`` holds the image index.
        """
        aug_factor = len(self.test_aug)
        if aug_factor > 1:
            index, aug_index = divmod(index, aug_factor)
        else:
            aug_index = 0

        input_img = self._load_input(index)
        if self.test_aug:
            transpose, vflip, hflip = self.test_aug[aug_index]
            if transpose or vflip or hflip:
                input_img = self._apply_test_aug(input_img, transpose, vflip, hflip)

        if self.target_array is not None:
            target_tensor = self.target_array[index]
        else:
            target_tensor = torch.zeros(1)

        if self.my_transform is not None:
            input_img = self.my_transform(image=input_img)['image']
        input_tensor = transforms.ToTensor()(input_img)
        index_tensor = torch.LongTensor([index])
        return input_tensor, target_tensor, index_tensor

    def __len__(self):
        return len(self.inputs) * len(self.test_aug) if self.test_aug else len(self.inputs)

    def get_aug_factor(self):
        """Number of test-time augmentation variants per image (0 in train mode)."""
        return len(self.test_aug)

    def get_class_weights(self):
        """Per-class weights as a NumPy array (the log-scaled table)."""
        return np.array(ALL_WEIGHTS_L)

    def get_sample_weights(self):
        """Weight of every sample: its largest active class weight, scaled so the minimum is 1.

        Returns:
            1-D float64 tensor with one entry per row of the label table.
        """
        class_weights = torch.as_tensor(self.get_class_weights(), dtype=torch.float32)
        # Mask the class weights with each sample's labels and keep the
        # row-wise maximum (one vectorised pass instead of a Python loop).
        per_sample, _ = (self.target_array * class_weights).max(dim=1)
        weighted_samples = per_sample.double()
        weighted_samples = weighted_samples / weighted_samples.min()
        return weighted_samples

In [0]:
class WeightedRandomOverSampler(Sampler):
    """Over-samples elements from [0,..,len(weights)-1] factor number of times.

    Each element is sampled at least once; the remaining over-sampling is
    drawn with replacement according to the weights.

    Arguments:
        weights (sequence) : per-element weights, not necessarily summing to one
        factor (float) : the oversampling factor (>= 1.0)
    """

    def __init__(self, weights, factor=2.):
        # as_tensor accepts lists, arrays and tensors of any floating dtype.
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        assert factor >= 1.
        self.num_samples = int(len(self.weights) * factor)

    def __iter__(self):
        base_samples = torch.arange(0, len(self.weights)).long()
        remaining = self.num_samples - len(self.weights)
        if remaining > 0:
            over_samples = torch.multinomial(self.weights, remaining, True)
            samples = torch.cat((base_samples, over_samples), dim=0)
        else:
            # factor == 1.0 leaves nothing to draw, so skip the multinomial
            # call rather than asking it for zero samples.
            samples = base_samples
        print('num samples', len(samples))
        return (samples[i] for i in torch.randperm(len(samples)))

    def __len__(self):
        return self.num_samples


In [0]:
def get_transforms(image_size: int) -> Tuple[album.Compose, album.Compose, album.Compose]:
    """Build the albumentations pipelines for training, plain testing and augmented testing.

    Args:
        image_size: Side length (pixels) every image is resized to.

    Returns:
        ``(transforms_train, transforms_test, transforms_test_aug)``.
    """
    # albumentations hands `interpolation` straight to OpenCV, so it has to be
    # a cv2 flag. PIL's Image.BICUBIC has the integer value 3, which OpenCV
    # reads as INTER_AREA; cv2.INTER_CUBIC is the bicubic flag.
    interpolation = cv2.INTER_CUBIC
    # Normalisation statistics shared by the three pipelines.
    mean = [0.08069, 0.05258, 0.05487]
    std = [0.13704, 0.10145, 0.15313]

    transforms_train = album.Compose([
        album.Resize(image_size, image_size, interpolation=interpolation),
        album.Rotate(interpolation=interpolation),
        album.RandomRotate90(),
        album.HorizontalFlip(),
        album.RandomBrightnessContrast(),
        album.Normalize(mean, std)
    ])

    transforms_test = album.Compose([
        album.Resize(image_size, image_size, interpolation=interpolation),
        album.Normalize(mean, std)
    ])

    transforms_test_aug = album.Compose([
        album.Resize(image_size, image_size, interpolation=interpolation),
        album.Rotate(interpolation=interpolation),
        album.RandomRotate90(),
        album.HorizontalFlip(),
        album.Normalize(mean, std)
    ])

    return transforms_train, transforms_test, transforms_test_aug

# Train epoch

In [0]:
import os
import time
import numpy as np
from collections import OrderedDict
import torch
import torch.autograd as autograd
import torch.nn
import torch.nn.functional as F
import torchvision.utils

def train_epoch(
        epoch, model, loader, optimizer, loss_fn,
        class_weights=None, output_dir='', exp=None, batch_limit=0):
    """Train ``model`` for one pass over ``loader``.

    After every batch the current learning rate is appended to
    ``summary_lr_<epoch>.csv`` both in ``output_dir`` and in the working
    directory.

    Args:
        epoch: 1-based epoch number (used for the global step and file names).
        model: Network to optimise; switched to train mode.
        loader: Iterable yielding ``(input, target, index)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        class_weights: Accepted for interface compatibility; not used here.
        output_dir: Directory for the learning-rate log and saved batch images.
        exp: Optional experiment logger exposing ``add_scalar_value``.
        batch_limit: When non-zero, stop after this many batches.

    Returns:
        OrderedDict with the average training loss under ``'train_loss'``.
    """
    epoch_step = (epoch - 1) * len(loader)
    losses_m = AverageMeter()
    lr_log_name = 'summary_lr_%d.csv' % epoch
    lr_log_paths = (os.path.join(output_dir, lr_log_name), lr_log_name)

    model.train()

    for batch_idx, (input, target, index) in enumerate(loader):
        # Global batch counter across epochs; it was referenced by the
        # experiment logging below but never defined.
        step = epoch_step + batch_idx
        if not config.no_cuda:
            input, target = input.cuda(), target.cuda()
        # Tensors are used directly: autograd.Variable is a no-op wrapper
        # since PyTorch 0.4.
        output = model(input)

        loss = loss_fn(output, target)
        losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        rowd = OrderedDict(batch_idx=batch_idx)
        rowd.update(OrderedDict([('lr', optimizer.param_groups[0]['lr'])]))

        for lr_log_path in lr_log_paths:
            with open(lr_log_path, mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                dw.writerow(rowd)

        if batch_idx % config.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]  '
                  'Loss: {loss.val:.6f} ({loss.avg:.4f})  '.format(
                epoch,
                batch_idx * len(input), len(loader.sampler),
                100. * batch_idx / len(loader),
                loss=losses_m))

            if exp is not None:
                exp.add_scalar_value('loss_train', losses_m.val, step=step)
                exp.add_scalar_value('learning_rate', optimizer.param_groups[0]['lr'], step=step)

            if config.save_batches:
                torchvision.utils.save_image(
                    input,
                    os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                    padding=0,
                    normalize=True)

        # batch_idx is 0-based, so batch_idx + 1 batches have run by now;
        # comparing batch_idx itself would run one batch more than requested.
        if batch_limit and batch_idx + 1 >= batch_limit:
            break
    return OrderedDict([('train_loss', losses_m.avg)])

In [0]:
def validate(step, model, loader, loss_fn,  threshold, output_dir='', exp=None):
    """Evaluate ``model`` on ``loader`` and re-tune the per-class thresholds.

    Args:
        step: Global step of the caller (only referenced by the disabled
            experiment logging at the end).
        model: Network to evaluate; switched to eval mode.
        loader: DataLoader whose dataset exposes ``get_aug_factor()``. With a
            factor above 1, consecutive items of a batch are treated as
            augmented variants of one image and their outputs are averaged.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        threshold: Threshold used for the running batch metrics.
        output_dir: Directory for optional saved batch images.
        exp: Unused (its logging calls are commented out).

    Returns:
        ``(metrics, new_threshold)``: an OrderedDict with ``'eval_loss'`` and
        ``'eval_f2'``, and the thresholds found by ``optimise_f2_thresholds``.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    acc_m = AverageMeter()
    f2_m = AverageMeter()

    model.eval()

    reduce_factor = loader.dataset.get_aug_factor()

    end = time.time()
    output_list = []
    target_list = []
    # `volatile=True` has had no effect since PyTorch 0.4; no_grad() is what
    # actually disables graph construction during evaluation.
    with torch.no_grad():
        for i, (input, target, _) in enumerate(loader):
            if not config.no_cuda:
                input, target = input.cuda(), target.cuda()

            # compute output
            output = model(input)

            # augmentation reduction: average each group of `reduce_factor`
            # consecutive outputs and keep one target per group
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            # calc loss on the full per-class target, exactly as train_epoch
            # does; collapsing it to an argmax index breaks both the loss and
            # the element-wise metrics in scores()
            loss = loss_fn(output, target)
            losses_m.update(loss.item(), input.size(0))

            # output non-linearities and metrics
            # NOTE(review): softmax makes the class scores compete; for
            # independent per-class labels sigmoid is the usual choice. Kept
            # as softmax because the saved thresholds are consumed by code not
            # visible here -- confirm and switch both sides together.
            output = F.softmax(output, dim=1)
            a, p, _, f2 = scores(output, target, threshold)
            acc_m.update(a, output.size(0))
            prec1_m.update(p, output.size(0))
            f2_m.update(f2, output.size(0))

            # copy to CPU and collect (targets after reduction, so rows line
            # up with the collected outputs)
            target_list.append(target.cpu().numpy())
            output_list.append(output.cpu().numpy())

            batch_time_m.update(time.time() - end)
            end = time.time()
            if i % config.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                      'Acc {acc.val:.4f} ({acc.avg:.4f})  '
                      'Prec {prec.val:.4f} ({prec.avg:.4f})  '
                      'F2 {f2.val:.4f} ({f2.avg:.4f})  '.format(
                        i, len(loader),
                        loss=losses_m,
                        acc=acc_m, prec=prec1_m, f2=f2_m))

                if config.save_batches:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'validate-batch-%d.jpg' % i),
                        padding=0,
                        normalize=True)

    output_total = np.concatenate(output_list, axis=0)
    target_total = np.concatenate(target_list, axis=0)

    new_threshold, f2 = optimise_f2_thresholds(target_total, output_total)
    metrics = [('eval_loss', losses_m.avg), ('eval_f2', f2)]

    print(f2, new_threshold)

    # Experiment logging, currently disabled:
    # if exp is not None:
    #     exp.add_scalar_value('loss_eval', losses_m.avg, step=step)
    #     exp.add_scalar_value('prec@1_eval', prec1_m.avg, step=step)
    #     exp.add_scalar_value('f2_eval', f2, step=step)

    return OrderedDict(metrics), new_threshold

# Model

In [0]:
from typing import Tuple
import torchvision.models as models
import pretrainedmodels
import torch.nn as nn
import torch

class AdaptiveConcatPool2d(torch.nn.Module):
    """Adaptive max pooling and adaptive average pooling, concatenated on dim 1.

    The output therefore has twice as many channels as the input, max-pooled
    channels first.
    """

    def __init__(self, sz: Tuple[int, int] = (1, 1)):
        """``sz`` is the target spatial size handed to both pooling layers."""
        super().__init__()
        self.average_pool = torch.nn.AdaptiveAvgPool2d(sz)
        self.max_pool = torch.nn.AdaptiveMaxPool2d(sz)

    def forward(self, x):
        pooled_max = self.max_pool(x)
        pooled_avg = self.average_pool(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)
def freeze(model: torch.nn.Module):
    """Exclude every parameter of ``model`` from gradient computation."""
    for weight in model.parameters():
        weight.requires_grad_(False)


def unfreeze(model: torch.nn.Module):
    """Re-enable gradient computation for every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(True)
        


In [0]:
def load_model(model_name: str, num_classes: int, pretrained: str):
    """Instantiate the ``pretrainedmodels`` architecture registered as ``model_name``.

    ``num_classes`` and ``pretrained`` are forwarded unchanged to the model
    factory looked up in the package namespace.
    """
    model_factory = pretrainedmodels.__dict__[model_name]
    return model_factory(num_classes=num_classes, pretrained=pretrained)

def get_model(model_name: str,
              num_classes: int,
              num_channels: int = 3,
              dropout: float = 0.5,
              frozen: bool = True):
    """Load an ImageNet-pretrained backbone and attach a new classification head.

    Args:
        model_name: Architecture name understood by ``load_model``.
        num_classes: Size of the final linear layer of the new head.
        num_channels: With 4, ``conv1`` is replaced by a 4-input-channel
            convolution whose extra channel reuses the pretrained kernels of
            input channel index 1.
        dropout: Dropout probability used twice in the head.
        frozen: Freeze all pretrained parameters before the new layers are
            attached (the new layers stay trainable).

    The head expects 4096 input features, i.e. the concatenated max/avg
    pooling of a 2048-channel feature map.
    """
    backbone = load_model(model_name, num_classes=1000, pretrained='imagenet')

    if frozen:
        freeze(backbone)

    if num_channels == 4:
        pretrained_kernels = backbone.conv1.weight
        extra_kernels = pretrained_kernels[:, 1:2, :, :]
        backbone.conv1 = nn.Conv2d(
            4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        backbone.conv1.weight = nn.Parameter(
            torch.cat((pretrained_kernels, extra_kernels), dim=1))

    backbone.avgpool = nn.Sequential(AdaptiveConcatPool2d())

    head_layers = [
        nn.BatchNorm1d(4096),
        nn.Dropout(dropout),
        nn.Linear(4096, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(dropout),
        nn.Linear(512, num_classes),
    ]
    backbone.last_linear = nn.Sequential(*head_layers)

    return backbone

# Train

In [0]:
import csv
import os
import numpy as np
from collections import OrderedDict
import torch
import torch.nn
import torch.optim as optim 
import torch.utils.data as data

def main():
    """Train and validate a model according to the global ``config``.

    Steps: build the train/eval datasets for ``config.fold``, create the
    model, optimizer, scheduler and loss, optionally resume from
    ``config.checkpoint`` or run a head-only fine-tuning phase, then loop over
    the epochs writing ``summary.csv`` and one checkpoint per epoch into the
    experiment directory. The best loss / F2 are printed when training ends
    or is interrupted with Ctrl-C.

    Raises:
        ValueError: Unknown ``config.opt`` or ``config.loss``.
        FileNotFoundError: ``config.checkpoint`` is set but is not a file.
    """
    train_input_root = os.path.join(config.data_path,'train')
    train_labels_file = config.target_file

    if config.output_path:
        output_base = config.output_path
    else:
        output_base = './output'

    exp_name = '-'.join([config.exp,
        config.model,
        str(config.img_size),
        'f'+str(config.fold)])
    output_dir = os.path.join(output_base, 'train', exp_name)
    os.makedirs(output_dir, exist_ok=True)

    batch_size = config.batch_size
    num_epochs = config.epochs
    img_size = config.img_size
    transforms_train, transforms_test, transforms_test_aug = get_transforms(img_size)

    torch.manual_seed(config.seed)

    dataset_train = HumanDataset(
        train_input_root,
        train_labels_file,
        train=True,
        img_size=img_size,
        fold=config.fold,
        transform=transforms_train
    )
    sampler = WeightedRandomOverSampler(dataset_train.get_sample_weights())
    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=1,
        pin_memory=False
    )
    dataset_eval = HumanDataset(
        train_input_root,
        train_labels_file,
        train=False,
        img_size=img_size,
        fold=config.fold,
        transform=transforms_test
    )

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        pin_memory=False
    )
    model=get_model(config.model,config.num_classes,config.num_channels,config.dropout)

    if not config.no_cuda:
        model.cuda()

    if config.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
    elif config.opt.lower() == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    elif config.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    elif config.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=config.lr, alpha=0.9, momentum=config.momentum,
            weight_decay=config.weight_decay)
    elif config.opt.lower() == 'yellowfin':
        optimizer = YFOptimizer(
            model.parameters(), lr=config.lr, weight_decay=config.weight_decay, clip_thresh=2)
    else:
        # `assert False and "..."` evaluates to `assert False`, dropping the
        # message, and disappears entirely under `python -O`.
        raise ValueError("Invalid optimizer: {}".format(config.opt))

    if not config.decay_epochs:
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=8)
    else:
        lr_scheduler = None

    if config.class_weights:
        class_weights = torch.from_numpy(dataset_train.get_class_weights()).float()
        class_weights_norm = class_weights / class_weights.sum()
        if not config.no_cuda:
            class_weights = class_weights.cuda()
            class_weights_norm = class_weights_norm.cuda()
    else:
        class_weights = None
        class_weights_norm = None

    if config.loss.lower() == 'bce':
        loss_fn = BCEWithLogitsLoss(weight=class_weights)
    elif config.loss.lower() == 'focal':
        loss_fn =FocalLoss()
    else:
        raise ValueError("Invalid loss function: {}".format(config.loss))
    if not config.no_cuda:
        loss_fn = loss_fn.cuda()

    start_epoch = 1

    if config.checkpoint:
        if os.path.isfile(config.checkpoint):
            print("=> loading checkpoint '{}'".format(config.checkpoint))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # On CPU-only runs map CUDA tensors to the CPU.
            checkpoint = torch.load(
                config.checkpoint, map_location='cpu' if config.no_cuda else None)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(config.checkpoint, checkpoint['epoch']))
            start_epoch = checkpoint['epoch']
        else:
            # Raise instead of exit(-1), which would ask a notebook kernel to
            # shut down.
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(config.checkpoint))

    # Optional fine-tune of only the final classifier weights for specified number of epochs (or part of)
    if not config.checkpoint and config.ft_epochs > 0.:
        # get_model() installs the classifier as `last_linear`; models that
        # expose get_fc() keep working.
        head = model.get_fc() if hasattr(model, 'get_fc') else model.last_linear
        if config.opt.lower() == 'adam':
            finetune_optimizer = optim.Adam(
                head.parameters(), lr=config.ft_lr, weight_decay=config.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                head.parameters(), lr=config.ft_lr, momentum=config.momentum, weight_decay=config.weight_decay)

        finetune_epochs_int = int(np.ceil(config.ft_epochs))
        finetune_final_batches = int(np.ceil((1 - (finetune_epochs_int - config.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(1, finetune_epochs_int + 1):
            if fepoch == finetune_epochs_int and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, loss_fn,
                class_weights_norm, output_dir, batch_limit=batch_limit)
            step = fepoch * len(loader_train)
            validate(step, model, loader_eval, loss_fn,  0.2, output_dir)

    score_metric = 'f2'
    best_loss = None
    best_f2 = None
    threshold = 0.2
    try:
        for epoch in range(start_epoch, num_epochs + 1):
            if config.decay_epochs:
                adjust_learning_rate(optimizer, epoch, initial_lr=config.lr, decay_epochs=config.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, loss_fn, class_weights_norm, output_dir, exp=None)

            step = epoch * len(loader_train)
            eval_metrics, latest_threshold = validate(
                step, model, loader_eval, loss_fn,  threshold, output_dir, exp=None)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            best = False
            if best_loss is None or eval_metrics['eval_loss'] < best_loss[1]:
                best_loss = (epoch, eval_metrics['eval_loss'])
                if score_metric == 'loss':
                    best = True
            if best_f2 is None or eval_metrics['eval_f2'] > best_f2[1]:
                best_f2 = (epoch, eval_metrics['eval_f2'])
                if score_metric == 'f2':
                    best = True

            save_checkpoint({
                'epoch': epoch + 1,
                'arch': config.model,
                'state_dict':  model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'threshold': latest_threshold,
                },
                is_best=best,
                filename='checkpoint-%d.pth.tar' % epoch,
                output_dir=output_dir)

    except KeyboardInterrupt:
        # Manual interruption is an expected way to stop; fall through to the
        # summary below.
        pass

    # Both stay None when training stops before the first validation pass.
    if best_loss is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
    if best_f2 is not None:
        print('*** Best f2: {0} (epoch {1})'.format(best_f2[1], best_f2[0]))
            
            
# Entry point: start training when this code runs as the __main__ module.
if __name__ == '__main__':
    main()

num samples 49528
