Implementation of Progressive Label Correction (PLC, https://openreview.net/pdf?id=ZPa2SyGcbwh) using fastai.

In [None]:
from fastai.vision.all import *

In [None]:
# Fix the global random seed via fastai's `set_seed`; `reproducible=True` is requested so runs can be repeated.
set_seed(999, reproducible=True)

In [None]:
# Root directory of the Kaggle Cassava Leaf Disease Classification data.
datapath = Path("/kaggle/input") / "cassava-leaf-disease-classification"

In [None]:
# All image files under the training-image folder (fastai `get_image_files`).
# NOTE(review): the later get_dls calls in this notebook use `fnames`, not `files`.
files = get_image_files(datapath/'train_images')

In [None]:
# Training metadata; later cells read its `image_id` and `label` columns.
train_df = pd.read_csv(datapath/'train.csv')

In [None]:
# Class names in integer-label order: entry i is the name of label i.
vocab = [
    'Cassava Bacterial Blight (CBB)',
    'Cassava Brown Streak Disease (CBSD)',
    'Cassava Green Mottle (CGM)',
    'Cassava Mosaic Disease (CMD)',
    'Healthy',
]
# name -> integer label, and the inverse integer label -> name
vocab2id = {name: idx for idx, name in enumerate(vocab)}
id2vocab = dict(enumerate(vocab))

vocab

In [None]:
class GetLabel(DisplayedTransform):
    "Map an item to its class name by looking up `str(item)` in the mutable `fname2labels` dict."
    def __init__(self, fname2labels):
        store_attr()

    def encodes(self, o):
        # The dict is keyed by the string form of each item, so Paths and str behave the same.
        key = str(o)
        return self.fname2labels[key]

In [None]:
# Full image path (as a string) for each row of train.csv, in row order;
# a str + ndarray sum broadcasts the prefix over every image_id.
fnames = (str(datapath / 'train_images') + "/") + train_df['image_id'].values
# Class name for each row, translated from the integer label through id2vocab
labels = train_df['label'].map(id2vocab)
# Lookup table read by GetLabel: path string -> class name
fname2labels = {fname: label for fname, label in zip(fnames, labels)}

In [None]:
def get_dls(files, fname2labels, size = (512,512), bs=32, valid_pct=0.2, seed=None):
    """Build training/validation `DataLoaders` for the Cassava images.

    Args:
        files: items to load. Each is opened with `PILImage.create` and labelled by
            `GetLabel(fname2labels)` followed by `Categorize` over the module-level `vocab`.
        fname2labels: dict from `str(item)` to class name (read by `GetLabel`).
        size: output size passed to `RandomResizedCropGPU`.
        bs: training batch size; the validation loader uses `bs*2`.
        valid_pct: fraction of `files` randomly held out for validation.
        seed: optional seed for the train/validation split; None uses the current RNG state.

    Returns:
        `DataLoaders` (shuffled training loader, unshuffled validation loader) on `default_device()`.
    """
    x_tfms = [PILImage.create]
    y_tfms = [GetLabel(fname2labels), Categorize(vocab)]
    splits = RandomSplitter(valid_pct=valid_pct, seed=seed)(files)
    dsets = Datasets(files, tfms=[x_tfms, y_tfms], splits=splits)

    # GPU batch augmentations. fastai's RandTransforms are training-split only
    # (split_idx=0), which is why the same pipeline can be reused for validation.
    aug_tfms = [
        Dihedral(p=0.5),
        Rotate(p=0.5, max_deg=45),
        RandomErasing(p=0.5, sl=0.05, sh=0.05, min_aspect=1., max_count=15),
        Brightness(p=0.5, max_lighting=0.3, batch=False),
        Hue(p=0.5, max_hue=0.1, batch=False),
        Saturation(p=0.5, max_lighting=0.1, batch=False),
        RandomResizedCropGPU(size, min_scale=0.4),
    ]
    batch_tfms = [IntToFloatTensor] + aug_tfms + [Normalize.from_stats(*imagenet_stats)]
    item_tfms = [ToTensor()]

    # drop_last=False keeps the final partial batch, so every training item is visited each epoch.
    train_dl = TfmdDL(dsets.train, shuffle=True, bs=bs, after_item=item_tfms, after_batch=batch_tfms, drop_last=False)
    valid_dl = TfmdDL(dsets.valid, shuffle=False, bs=bs*2, after_item=item_tfms, after_batch=batch_tfms)

    return DataLoaders(train_dl, valid_dl, device=default_device())

In [None]:
# Quick sanity check: DataLoaders over only the first 32 images, batch size 8.
dls = get_dls(fnames[:32], fname2labels, bs=8)

In [None]:
# Class-name distribution of `dls.items`, resolved through the GetLabel transform's dict.
Counter(dls.tfms[1][0].fname2labels[str(item)] for item in dls.items)

In [None]:
# Visual check of a decoded sample batch.
dls.show_batch()

### Timm Utils

In [None]:
!pip install -q timm

In [None]:
# Source: https://github.com/walkwithfastai/walkwithfastai.github.io/blob/master/wwf/vision/timm.py#L13
# Cell
from timm import create_model
from fastai.vision.learner import _update_first_layer

# Cell
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    """Create a body (feature extractor) from any model in the `timm` library.

    Args:
        arch: timm model name passed to `create_model`.
        pretrained: load pretrained timm weights.
        cut: int -> keep `model.children()[:cut]`; callable -> return `cut(model)`;
            None -> cut at the last child for which `has_pool_type` is true.
        n_in: number of input channels (first layer adapted by `_update_first_layer`).

    Raises:
        ValueError: if `cut` is None and no child has a pooling layer, or if `cut`
            is neither an int nor a callable.
    """
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next((i for i,o in reversed(ll) if has_pool_type(o)), None)
        if cut is None:
            raise ValueError(f"could not find a pooling layer to cut `{arch}` at; pass `cut` explicitly")
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    if callable(cut): return cut(model)
    # Was `NamedError`, which is undefined and would have raised NameError instead.
    raise ValueError("cut must be either integer or function")

# Cell
def create_timm_model(arch:str, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                     concat_pool=True, **kwargs):
    """Create `nn.Sequential(body, head)` for `arch` from the `timm` library.

    Args:
        arch: timm model name.
        n_out: number of outputs of the head.
        cut: forwarded to `create_timm_body` (int index, callable, or None for automatic).
        pretrained: load pretrained timm weights.
        n_in: number of input channels.
        init: initializer applied to the head (`model[1]`); skipped when None.
        custom_head: used as the head instead of `create_head(...)` when given.
        concat_pool: forwarded to `create_head`.
        **kwargs: forwarded to `create_head` (e.g. `ps`, `y_range`).
    """
    body = create_timm_body(arch, pretrained, cut, n_in)
    if custom_head is None:
        # Concat pooling feeds the head twice as many features, hence the factor of 2.
        nf = num_features_model(nn.Sequential(*body.children())) * (2 if concat_pool else 1)
        head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)
    else: head = custom_head
    model = nn.Sequential(body, head)
    if init is not None: apply_init(model[1], init)
    return model

# Cell
from fastai.vision.learner import _add_norm

# Cell
def timm_learner(dls, arch:str, loss_func=None, pretrained=True, cut=None, splitter=None,
                y_range=None, config=None, n_out=None, normalize=True, **kwargs):
    """Build a convnet-style `Learner` from `dls` and `arch` using the `timm` library.

    Args:
        dls: DataLoaders; `get_c(dls)` supplies `n_out` when it is not given.
        arch: timm model name.
        loss_func: loss passed to `Learner`.
        pretrained: load pretrained weights and freeze the body.
        cut: forwarded to `create_timm_model`.
        splitter: parameter-group splitter; defaults to `default_split`.
        y_range: forwarded to the head; if None, taken from `config['y_range']` when present.
        config: extra keyword arguments for `create_timm_model` (not mutated).
        n_out: number of outputs.
        normalize: accepted for API compatibility; not used by this implementation.
        **kwargs: forwarded to `Learner`.
    """
    # Copy so that popping 'y_range' never mutates the caller's dict.
    config = {} if config is None else dict(config)
    # Always remove 'y_range': leaving it in `config` while also passing y_range=
    # below would raise "got multiple values for keyword argument 'y_range'".
    config_y_range = config.pop('y_range', None)
    if y_range is None: y_range = config_y_range
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    model = create_timm_model(arch, n_out, cut, pretrained, y_range=y_range, **config)
    if splitter is None: splitter = default_split
    learn = Learner(dls, model, loss_func=loss_func, splitter=splitter, **kwargs)
    if pretrained: learn.freeze()
    return learn

### ProgressiveLabelCorrection Callback

In [None]:
from fastai.vision.all import *

class ProgressiveLabelCorrection(Callback):
    """Progressive Label Correction callback (https://openreview.net/pdf?id=ZPa2SyGcbwh).

    Once training progress (`pct_train`) exceeds `warm_up`, it inspects every training
    batch after the optimizer step. For each sample whose argmax prediction differs from
    its current label, the label is replaced by the predicted class when
    |p(predicted class) - p(current label)| > theta.

    Labels are changed by updating the `fname2labels` dict of the first y transform
    (`dl.tfms[1][0]`, `GetLabel` in this notebook). The change therefore applies the
    next time that item is loaded.

    After warm-up, any epoch that changes no labels sets theta to
    `sched_func(theta_start, theta_end, pct_train)`.

    Args:
        num_classes: number of classes (size of the one-hot lookup table).
        theta_start: initial probability-gap threshold.
        theta_end: threshold the schedule moves toward.
        warm_up: fraction of training (`pct_train`) before any correction happens.
        sched_func: schedule called as `sched_func(start, end, pos)`.
    """
    run_valid=False
    def __init__(self, num_classes=5, theta_start=0.3, theta_end=0.05, warm_up=0.2, sched_func=sched_lin):
        store_attr()
        self.theta = theta_start

    def before_fit(self):
        # One-hot rows used as a mask to pick each sample's probability for its current label.
        self.eye = torch.eye(self.num_classes).to(self.dls.device)

    def before_epoch(self):
        self.epoch_tot_updated = 0

    def before_train(self):
        # Snapshot the training items once per epoch instead of re-listing them on every step.
        self._items = list(self.dl.items)

    def after_step(self):
        if self.pct_train <= self.warm_up: return
        labeller = self.dl.tfms[1][0]       # owns the mutable fname2labels dict
        vocab = self.dl.tfms[1][1].vocab    # class index -> class name

        # Items of this batch: slice the epoch's (possibly shuffled) index order.
        # Relies on fastai's private DataLoader.__idxs and consecutive bs-sized batches.
        self.idxs = self.dl._DataLoader__idxs
        start = self.iter * self.dl.bs
        self.b_items = L([self._items[i] for i in self.idxs[start:start + self.dl.bs]])

        # Samples whose argmax prediction disagrees with their current label
        mislabeled = self.pred.argmax(-1) != self.y
        if not mislabeled.any(): return

        probas = self.pred[mislabeled].softmax(-1)
        targs = self.y[mislabeled]
        predicted_probas, predicted_targs = probas.max(-1)
        actual_probas = probas[self.eye[targs].bool()]

        # Relabel only where the model is confident enough relative to the current label
        msk = (predicted_probas - actual_probas).abs() > self.theta
        self.epoch_tot_updated += int(msk.sum())

        mislabeled_items = [o for o, m in zip(self.b_items, mislabeled.tolist()) if m]
        items_to_change = [o for o, m in zip(mislabeled_items, msk.tolist()) if m]
        new_targs = [vocab[i] for i in predicted_targs[msk].tolist()]
        # Key by str(item): GetLabel looks labels up with str(o), so raw Path keys would never match.
        labeller.fname2labels.update({str(o): t for o, t in zip(items_to_change, new_targs)})

        print(f"Total changed items in this epoch: {self.epoch_tot_updated}")
        print("New Label Distribution")
        # Use this callback's own DataLoader rather than a notebook-global `dls`.
        print(Counter([labeller.fname2labels[str(o)] for o in self._items]))

    def after_epoch(self):
        if (self.pct_train > self.warm_up) and (self.epoch_tot_updated == 0):
            self.theta = self.sched_func(self.theta_start, self.theta_end, self.pct_train)
            print(f"Reduced theta to: {self.theta}")

### Learner

In [None]:
arch_name = 'efficientnet_b3a'
# One output per class in `vocab` (5); `ps` is forwarded through create_timm_model's **kwargs to create_head.
model = create_timm_model(arch_name, n_out=5, ps=0.5)

In [None]:
def _timm_split(m):
    "Two parameter groups: the timm body `m[0]` and the head `m[1]`."
    return L(m[0], m[1]).map(params)

metric = accuracy
loss_func = LabelSmoothingCrossEntropy(0.01)
# Save the best model by `accuracy` under `arch_name`; stop training if the loss becomes NaN.
cbs = [
    SaveModelCallback(metric.__name__, fname=arch_name),
    TerminateOnNaNCallback(),
]

dls = get_dls(fnames, fname2labels, bs=32)
learner = Learner(dls, model, loss_func=loss_func, metrics=metric, splitter=_timm_split, cbs=cbs)
# Mixed-precision training with PyTorch's native AMP; trailing ';' suppresses notebook output.
learner.to_native_fp16();

In [None]:
# Stage 1: train only the last parameter group (the head, per `_timm_split`) for 3 epochs at lr 1e-2.
learner.freeze_to(-1)
learner.fit_flat_cos(3, 1e-2)

In [None]:
# Stage 2: unfreeze everything and fine-tune for 3 epochs with discriminative learning rates
# (slice(5e-4, 1e-3) across the body/head groups). ProgressiveLabelCorrection is attached for
# this fit only; it starts relabelling after 20% of training (warm_up=0.2).
learner.unfreeze()
learner.fit_flat_cos(3, slice(5e-4, 1e-3), pct_start=0.0, wd=1e-2,
                     cbs=[ProgressiveLabelCorrection(theta_start=0.4, theta_end=0.1, warm_up=0.2)])

In [None]:
# Serialize the trained Learner to models/<arch_name>_export.pkl.
learner.export(f"models/{arch_name}_export.pkl")

### Upvote if you find it useful and let me know if you have any feedback!