In [1]:
from fastai.vision.all import *
import pandas as pd
import cam
import util

In [2]:
dls, labels = util.chexpert_data_loader(bs=32)

In [92]:
from libauc.losses import AUCMLoss
# from libauc.optimizers import PESG
# from fastai.optimizer import OptimWrapper

gamma = 500
weight_decay = 0
margin = 1.0

criterion = AUCMLoss()

# Create a fastai wrapper for the PESG optimizer. Reference: https://docs.fast.ai/optimizer.html#OptimWrapper
optimizer = partial(OptimWrapper, 
                    convert_groups=False,
#                     hp_map={'imratio':'imratio', 'm':'m', 'lr':'lr', 'gamma':'gamma', 
#                         'clip_value':'clip_value', 'weight_decay':'weight_decay'},
                   opt=PESG, 
                   a=criterion.a, 
                   b=criterion.b, 
                   alpha=criterion.alpha, 
                   gamma=gamma, 
                   margin=margin)

In [1]:
from libauc.losses import AUCMLoss
AUCMLoss??

In [88]:
class _BaseOptimizer():
    "Common functionality between `Optimizer` and `OptimWrapper`"
    def all_params(self, n=slice(None), with_grad=False):
        res = L((p,pg,self.state[p],hyper) for pg,hyper in zip(self.param_lists[n],self.hypers[n]) for p in pg)
        return L(o for o in res if hasattr(o[0], 'grad') and o[0].grad is not None) if with_grad else res

    def _set_require_grad(self, rg, p,pg,state,h): p.requires_grad_(rg or state.get('force_train', False))
    def freeze_to(self, n):
        self.frozen_idx = n if n >= 0 else len(self.param_lists) + n
        if self.frozen_idx >= len(self.param_lists):
            warn(f"Freezing {self.frozen_idx} groups; model has {len(self.param_lists)}; whole model is frozen.")
        for o in self.all_params(slice(n, None)): self._set_require_grad(True,  *o)
        for o in self.all_params(slice(None, n)): self._set_require_grad(False, *o)

    def freeze(self):
        assert(len(self.param_lists)>1)
        self.freeze_to(-1)

    def set_freeze(self, n, rg, ignore_force_train=False):
        for p in self.param_lists[n]: p.requires_grad_(rg or (state.get('force_train', False) and not ignore_force_train))

    def unfreeze(self): self.freeze_to(0)
    def set_hypers(self, **kwargs): L(kwargs.items()).starmap(self.set_hyper)
    def _set_hyper(self, k, v):
        for v_,h in zip(v, self.hypers): h[k] = v_

    def set_hyper(self, k, v):
        if isinstance(v, slice):
            if v.start: v = even_mults(v.start, v.stop, len(self.param_lists))
            else: v = [v.stop/10]*(len(self.param_lists)-1) + [v.stop]
        v = L(v, use_list=None)
        if len(v)==1: v = v*len(self.param_lists)
        assert len(v) == len(self.hypers), f"Trying to set {len(v)} values for {k} but there are {len(self.param_lists)} parameter groups."
        self._set_hyper(k, v)

    @property
    def param_groups(self): return [{**{'params': pg}, **hp} for pg,hp in zip(self.param_lists, self.hypers)]
    @param_groups.setter
    def param_groups(self, v):
        for pg,v_ in zip(self.param_lists,v): pg = v_['params']
        for hyper,v_ in zip(self.hypers,v):
            for k,t in v_.items():
                if k != 'params': hyper[k] = t

class OptimWrapper(_BaseOptimizer, GetAttr):
    "A wrapper class for existing PyTorch optimizers"
    _xtra=['zero_grad', 'step', 'state_dict', 'load_state_dict']
    _default='opt'
    def __init__(self, params, opt, hp_map=None, convert_groups=True, **kwargs):
        self.opt = opt(_convert_params(params), **kwargs) if convert_groups else opt(params, **kwargs)
        if hp_map is None: hp_map = pytorch_hp_map
        self.fwd_map = {k: hp_map[k] if k in hp_map else k for k in self.opt.param_groups[0].keys()}
        self.bwd_map = {v:k for k,v in self.fwd_map.items()}
        print(self.opt.param_groups[0].keys())
        self.state = defaultdict(dict, {})
        self.frozen_idx = 0

    @property
    def hypers(self):
        return [{self.fwd_map[k]:v for k,v in detuplify_pg(pg).items() if k != 'params'} for pg in self.opt.param_groups]

    def _set_hyper(self, k, v):
        for pg,v_ in zip(self.opt.param_groups,v): pg = set_item_pg(pg, self.bwd_map[k], v_)

    def clear_state(self): self.opt.state = defaultdict(dict, {})

    @property
    def param_lists(self): return [pg['params'] for pg in self.opt.param_groups]
    @param_lists.setter
    def param_lists(self, v):
        for pg,v_ in zip(self.opt.param_groups,v): pg['params'] = v_

In [33]:
PESG??

In [64]:
OptimWrapper??

In [93]:
chexpert_learner_dam = ChexpertLearner(dls, densenet121, n_out=len(labels),
                                        loss_func=criterion,
                                        opt_func=optimizer,
                                        model_dir=Path('auc_maximization_stage_2'),
                                        metrics=[RocAucMulti(average=None),
                                                 RocAucMulti(average='weighted')])

dict_keys(['params'])
dict_keys(['params'])


In [73]:
chexpert_learner_dam.find_lr()

KeyError: 'lr'

In [91]:
import copy

class PESG(Optimizer):
    def __init__(self, params=None, a=None, b=None, alpha=None, imratio=0.1, m=1.0, lr=0.1, gamma=500, clip_value=1.0, weight_decay=1e-5, **kwargs):
        self.p = imratio
        self.m = m
        self.params = params
        self.lr = lr
        self.gamma = gamma
        self.clip_value = clip_value
        self.weight_decay = weight_decay
        assert a is not None, 'You are missing variable a!'
        assert b is not None, 'You are missing variable b!'
        assert alpha is not None, 'You are missing variable alpha!'
        self.a = a 
        self.b = b 
        self.alpha = alpha 
        self.a_ref = torch.empty(self.a.shape).normal_(mean=0,std=0.01).cuda() 
        self.b_ref = torch.empty(self.b.shape).normal_(mean=0,std=0.01).cuda() 
        self.a_acc = self.a.clone().detach().requires_grad_(False)
        self.b_acc = self.b.clone().detach().requires_grad_(False)
        self.T = 0 
        
        defaults = dict(imratio=imratio, m=m, lr=lr, gamma=gamma, 
                        clip_value=clip_value, weight_decay=weight_decay)
        super(PESG, self).__init__(params, defaults)
        self.model_acc = copy.deepcopy(self.param_groups)
        for group in self.param_groups:
            model_ref = []
            for p in group['params']:
                if p.grad is not None:
                    model_ref.append(torch.empty(p.shape).normal_(mean=0, std=0.01).cuda())
            group['model_ref'] = model_ref
        print(self.param_groups[0].keys())

    def step(self):
        for i, group in enumerate(self.param_groups):
            for j, param in enumerate(group['params']):
                param.data = param.data - self.lr*( torch.clamp(param.grad.data , -self.clip_value, self.clip_value) + 1/self.gamma*(param.data - group['model_ref'][j].data)) - self.lr*self.weight_decay*param.data
                self.model_acc[i]['params'][j].data += param.data
        self.a.data = self.a.data - self.lr*(torch.clamp(self.a.grad.data, -self.clip_value, self.clip_value) + 1/self.gamma*(self.a.data - self.a_ref.data))- self.lr*self.weight_decay*self.a.data 
        self.b.data = self.b.data - self.lr*(torch.clamp(self.b.grad.data , -self.clip_value, self.clip_value) + 1/self.gamma*(self.b.data - self.b_ref.data))- self.lr*self.weight_decay*self.b.data 
        self.alpha.data = self.alpha.data + self.lr*(2*(self.m + self.b.data - self.a.data)-2*self.alpha.data)
        self.alpha.data  = torch.clamp(self.alpha.data,  0, 999)
        self.a_acc.data = self.a_acc.data + self.a.data
        self.b_acc.data = self.b_acc.data + self.b.data
        self.T = self.T + 1
        self.T = self.T + 1
    def zero_grad(self):
        super().zero_grad()
        self.a.grad = None
        self.b.grad = None
        self.alpha.grad = None

In [None]:
def train(model, trainloader, testloader, loss_func, opt_func):
    for epoch in range(100):
    
        if epoch == 50 or epoch==75:
            optimizer.lr = optimizer.lr/10
            optimizer.update_regularizer()

        train_pred = []
        train_true = []
        model.train()    
        for data, targets in trainloader:
            data, targets  = data.cuda(), targets.cuda()
            y_pred = model(data)
            loss = Loss(y_pred, targets)
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            train_pred.append(y_pred.cpu().detach().numpy())
            train_true.append(targets.cpu().detach().numpy())

        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        train_auc = roc_auc_score(train_true, train_pred) 

        model.eval()
        test_pred = []
        test_true = [] 
        for j, data in enumerate(testloader):
            test_data, test_targets = data
            test_data = test_data.cuda()
            y_pred = model(test_data)
            test_pred.append(y_pred.cpu().detach().numpy())
            test_true.append(test_targets.numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        val_auc =  roc_auc_score(test_true, test_pred) 
        model.train()

        # print results
        print("epoch: {}, train_loss: {:4f}, train_auc:{:4f}, test_auc:{:4f}, lr:{:4f}".format(epoch, loss.item(), train_auc, val_auc, optimizer.lr ))

In [11]:
class ChexpertLearner(Learner):
    """ Learner wrapper specifically
        created for CheXpert
    """


    def __init__(self, dls, arch, **kwargs):
        # For guide on all the input parameters,
        # check doc for cnn_learner
        self.path = Path('../saves/')
        self.learn = cnn_learner(dls, arch, path=self.path, **kwargs)
        self.loss_func = kwargs.get('loss_func')
        self.base_lr = 0.002
        self.saved_model_name = f'{self.learn.arch.__name__}-chexpert'


    def find_lr(self):
        # Quick way to find the optimal LRs
        # take the max of the lr_min (min loss/10)
        # and lr_steep (steepest loss/lr curve) as
        # fine_tune will use a cycle rangining from
        # base_lr/100 to base_lr

        # Refer: https://iconof.com/1cycle-learning-rate-policy/
        # Citation:
        #     Smith LN. Cyclical learning rates for training neural networks.
        #     In 2017 IEEE winter conference on applications of computer vision
        #     (WACV) 2017 Mar 24 (pp. 464-472). IEEE.
        
        torch.cuda.empty_cache()

        lr_min, lr_steep = self.learn.lr_find()
        self.base_lr = max(lr_min, lr_steep)
        print(f'lr_min/10: {lr_min}, lr_steep: {lr_steep}, base_lr: {self.base_lr}')


    def learn_model(self, use_saved=False, train_saved=False, old_learner=None, saved_model_name=None,
                    # other args for Learner.fine_tune
                    **kwargs):
        """ Load a previously saved model or train a new model """
        if saved_model_name:
            self.saved_model_name = saved_model_name

        if use_saved:
            try:
                if not train_saved:
                    self.learn.load(self.saved_model_name)
                    return
                if old_learner:
                    old_learner.load(self.saved_model_name)
                    self.learn.model[0].load_state_dict(old_learner.model[0].state_dict())
                    self.learn.loss_func = self.loss_func
                    
            except FileNotFoundError as e:
                print(f'Could not find saved model {self.saved_model_name}.')
        
        torch.cuda.empty_cache()

        # `fine_tune` first freezes the body and then updates only head weights for freeze_epochs
        # then it unfreezes the body and updates all weights for epochs.
        # Internally it uses `fit_one_cycle` for cyclical learning rates:
        # Refer: https://iconof.com/1cycle-learning-rate-policy/
        # Citation:
        #     Smith LN. Cyclical learning rates for training neural networks.
        #     In 2017 IEEE winter conference on applications of computer vision
        #     (WACV) 2017 Mar 24 (pp. 464-472). IEEE.
        
        # Using callbacks for a few things:
        callbacks = [
            ShowGraphCallback(), # Show the graph
            SaveModelCallback(fname=self.saved_model_name, with_opt=True), # Save the model if it improves
            ReduceLROnPlateau(patience=2), # If the validation loss stops improving then reduce it by a factor of 10
            CSVLogger(f'{self.saved_model_name}.csv'), # CSV file for training results
            EarlyStoppingCallback(patience=5),
        ]
        
        self.learn.fine_tune(cbs=callbacks, base_lr=self.base_lr, **kwargs)