In [None]:
from fastai.vision.all import *
import pandas as pd
import cam
import util

In [None]:
# Build the CheXpert DataLoaders and the label list via the project helper.
# `labels` is used below to size the model output (n_out=len(labels)).
dls, labels = util.chexpert_data_loader()

In [None]:
# Sanity check: preview up to 9 samples from one batch of the loaded data.
dls.show_batch(max_n=9, figsize=(20,9))

In [None]:
# Stage 1: train on conditional probabilities, using the loss built from
# util.hierarchy_map.
# NOTE(review): sigmoid=False in accuracy_multi presumably matches outputs
# already squashed into y_range=(0, 1) -- confirm in util.ChexpertLearner.
chexpert_learner_conditional = util.ChexpertLearner(
    dls,
    densenet121,
    n_out=len(labels),
    y_range=(0, 1),
    loss_func=util.BCEFlatHLCP(hierarchy_map=util.hierarchy_map),
    metrics=[
        partial(accuracy_multi, sigmoid=False),
        RocAucMulti(average='weighted'),
    ],
)

In [None]:
# Quick way to pick a base learning rate for the conditional learner:
# take the max of lr_min (LR at the minimum loss, divided by 10)
# and lr_steep (LR where the loss-vs-LR curve is steepest), as
# fine_tune will use a cycle ranging from
# base_lr/100 to base_lr.
# NOTE(review): the two-value unpacking relies on lr_find() returning
# (lr_min, lr_steep); newer fastai releases changed the default
# suggestions -- confirm against the installed fastai version.
lr_min, lr_steep = chexpert_learner_conditional.learn.lr_find()
base_lr = max(lr_min, lr_steep)
# Bare expression so the notebook displays the chosen value.
base_lr

In [None]:
# Train the conditional learner without reusing a saved model
# (use_saved=False), with the base_lr chosen by the LR finder above.
# Exact meaning of epochs/freeze_epochs is defined by
# util.ChexpertLearner.learn_model.
chexpert_learner_conditional.learn_model(
    use_saved=False,
    epochs=30,
    freeze_epochs=2,
    base_lr=base_lr,
)

In [None]:
# Stage 2: train on unconditional probabilities (plain BCELossFlat),
# for transfer learning only.
# NOTE(review): sigmoid=False in accuracy_multi presumably matches outputs
# already squashed into y_range=(0, 1) -- confirm in util.ChexpertLearner.
chexpert_learner_unconditional = util.ChexpertLearner(
    dls,
    densenet121,
    n_out=len(labels),
    y_range=(0, 1),
    loss_func=BCELossFlat(),
    metrics=[
        partial(accuracy_multi, sigmoid=False),
        RocAucMulti(average='weighted'),
    ],
)

In [None]:
# Quick way to pick a base learning rate for the unconditional learner:
# take the max of lr_min (LR at the minimum loss, divided by 10)
# and lr_steep (LR where the loss-vs-LR curve is steepest), as
# fine_tune will use a cycle ranging from
# base_lr/100 to base_lr.
# NOTE(review): the two-value unpacking relies on lr_find() returning
# (lr_min, lr_steep); newer fastai releases changed the default
# suggestions -- confirm against the installed fastai version.
# NOTE: this overwrites the base_lr computed for the conditional learner.
lr_min, lr_steep = chexpert_learner_unconditional.learn.lr_find()
base_lr = max(lr_min, lr_steep)
# Bare expression so the notebook displays the chosen value.
base_lr

In [None]:
# Train the unconditional learner starting from a saved model
# (use_saved=True, train_saved=True); which checkpoint is loaded is decided
# inside util.ChexpertLearner.learn_model.
# NOTE(review): freeze_epochs equals epochs here (30), unlike the 2 frozen
# epochs used for the conditional stage -- confirm this is intentional.
chexpert_learner_unconditional.learn_model(
    use_saved=True,
    train_saved=True,
    epochs=30,
    freeze_epochs=30,
    base_lr=base_lr,
)

In [None]:
# Use the unconditional learner for the inspection cells that follow
# (show_results, interpretation, CAM); this is an alias, not a copy.
chexpert_learner = chexpert_learner_unconditional
chexpert_learner.learn.show_results()

In [None]:
# Build an Interpretation object from the trained learner and plot the
# 9 samples with the highest loss.
interp = Interpretation.from_learner(chexpert_learner.learn)
interp.plot_top_losses(9)

In [None]:
# Visualise the model with the project's cam module
# (presumably class activation maps -- see cam.plot_cam for details).
cam.plot_cam(chexpert_learner.learn)