In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%pylab inline

from fastai.vision import *
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [None]:
# Experiment identity: the long name names the output folder, the short name is
# embedded in the name of every retrained model (see trainscript).
experiment_name = 'retrain_base_pretrain_labch_brit_rus'
experiment_shortname = 'labch_brit_rus'

# fastai PETS dataset; the image files live in <path_base>/images.
path_base = untar_data(URLs.PETS)
path = path_base / 'images'

# All artefacts of this experiment go under ./experiments/<experiment_name>/.
path_experiments = Path('./experiments')
path_experiment = path_experiments.joinpath(experiment_name)
path_experiment.mkdir(parents=True, exist_ok=True)

# Label CSV (read relative to path_base). Judging by its filename it swaps the
# British_Shorthair / Russian_Blue labels -- TODO confirm. Alternative: 'labels.csv'.
csv_name = 'switch_British_Shorthair-Russian_Blue.csv'

# When True, every retrained model is also written with learn.save().
save_newmodels = False

In [None]:
# Hyper-parameter grid swept by trainscript for every backbone: each
# (epoch, lr, fit_type) combination is one retraining run from the base weights.
lrs_totest = [
    3e-5,
    5e-5,
    1e-4,
    3e-4,
    5e-4,
]
# 'const' trains with learn.fit; any other value trains with learn.fit_one_cycle.
fit_types = ['onecycle']
epochs = [10, 20]

### Utils

In [None]:
def fig_save_htmllink(modelsavename, plotname, fig):
    """Save ``fig`` as ``<modelsavename>_<plotname>.png`` in ``path_experiment`` and link it.

    Args:
        modelsavename: prefix of the image file name.
        plotname: suffix identifying the plot.
        fig: a matplotlib figure (anything with ``savefig``).

    Returns:
        An ``<img>`` tag (with a trailing space) whose ``src`` is the saved file's path.
    """
    # NOTE(review): the src path is relative to the current working directory,
    # not to wherever an HTML report embedding this tag is written.
    png_path = path_experiment / "{}_{}.png".format(modelsavename, plotname)
    fig.savefig(png_path)
    return '<img src="{}" /> '.format(png_path)

In [None]:
import base64
from io import BytesIO


def fig2inlinehtml(fig):
    """Encode a matplotlib figure as a self-contained inline HTML image.

    The figure is rendered to PNG in memory and embedded as a base64 ``data:`` URI,
    so the resulting HTML needs no external image files.

    Args:
        fig: a matplotlib figure (anything with ``savefig(fileobj, format=...)``).

    Returns:
        ``'<img src="data:image/png;base64,..." />'``.
    """
    with BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png')
        # getvalue() returns the whole buffer regardless of position, so no seek is needed.
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode()
    return '<img src="data:image/png;base64,{}" />'.format(encoded_png)

### Load data

In [None]:
# Labelled item list: images listed in csv_name under path_base, a seeded 40%
# random validation split, and class labels taken from the CSV's 'label' column.
src = (
    ImageList.from_csv(path_base, csv_name=csv_name)
    .split_by_rand_pct(0.4, seed=2)
    .label_from_df(cols='label')
)

In [None]:
bs = 64  # batch size used for the DataBunch below
tfms = get_transforms()  # fastai augmentation list with get_transforms' default arguments
# Seed numpy's global RNG before the data pipeline is built (the split above has its own seed).
np.random.seed(2)

In [None]:
def get_data(size, bs, padding_mode='reflection'):
    """Build a normalized DataBunch from the global item list ``src``.

    Args:
        size: target image size passed to ``src.transform``.
        bs: batch size.
        padding_mode: padding used by the transforms (default ``'reflection'``).

    Returns:
        ``src`` transformed with the global ``tfms``, batched, and normalized with
        ``imagenet_stats``.
    """
    transformed = src.transform(tfms, size=size, padding_mode=padding_mode)
    bunch = transformed.databunch(bs=bs)
    return bunch.normalize(imagenet_stats)

In [None]:
# 224px images, shared by every learner built below.
data = get_data(size=224, bs=bs)

In [None]:
# Run the garbage collector after building the DataBunch; its return value
# (number of unreachable objects found) is this cell's displayed output.
gc.collect()

### Train script

In [None]:
# Show full cell contents (the results hold long HTML <img> strings) when displayed.
pd.set_option('display.max_colwidth', None)

# Empty results table; trainscript adds one row per retraining run.
output_df = pd.DataFrame(columns=[
    'basemodeltype', 'basemodelfile', 'newmodelfile',
    'lr', 'fit_type', 'epoch',
    'orig_metric', 'retrain_metric', 'orig_most_confused',
    'fig_base_toplosses', 'fig_base_confmat', 'fig_train_confmat', 'fig_orig_lrfind',
    'last_train_loss', 'last_val_loss', 'last_val_metric',
    'fig_train_lr', 'fig_train_losses', 'fig_train_metrics',
])

In [None]:
%matplotlib agg
%matplotlib agg
def _append_result_row(output_df, row):
    """Return a new DataFrame equal to ``output_df`` with ``row`` (column -> value) appended.

    Replacement for ``DataFrame.append(row, ignore_index=True)``. That method is
    deprecated in pandas 1.4 and removed in 2.0. Existing columns keep their order;
    keys of ``row`` that are not yet columns are added at the end.
    """
    new_row = pd.DataFrame([row])
    if len(output_df) == 0:
        # Concatenating onto an empty frame warns in recent pandas, so build the
        # result directly while keeping the declared column order.
        extra = [col for col in new_row.columns if col not in output_df.columns]
        return new_row.reindex(columns=list(output_df.columns) + extra)
    return pd.concat([output_df, new_row], ignore_index=True)


def _write_results(output_df):
    """Write the results table to ``exp_modelpath`` as output.pkl, output.html and output_compact.html."""
    output_df.to_pickle(exp_modelpath / 'output.pkl')
    # escape=False so the embedded <img> tags render; all cell values are generated here.
    output_df.to_html(exp_modelpath / 'output.html', escape=False)
    # Drop by column *name*. The previous code passed the Figure objects themselves,
    # which are not column labels, so drop raised KeyError.
    compact_df = output_df.drop(columns=['fig_base_toplosses', 'fig_base_confmat'])
    compact_df.to_html(exp_modelpath / 'output_compact.html', escape=False)


def trainscript(output_df):
    """Retrain the global ``learn`` over the hyper-parameter grid and record each run.

    First evaluates the base weights ``basemodelfile``: validation metric, top-losses
    plot, confusion matrix, most-confused pairs and an LR-finder plot. Then, for every
    combination of ``epochs`` x ``lrs_totest`` x ``fit_types``, it:

    * reloads the base weights and trains (``learn.fit`` for ``'const'``,
      ``learn.fit_one_cycle`` otherwise);
    * optionally saves the model;
    * appends one row of metrics and inline-HTML figures;
    * rewrites the result files under ``exp_modelpath``.

    Args:
        output_df: results table to extend. It is expected to carry the columns
            created in the "Train script" section.

    Returns:
        A DataFrame with one additional row per training run.

    Relies on notebook globals: learn, basemodelname, basemodelfile, exp_modelpath,
    epochs, lrs_totest, fit_types, experiment_shortname, save_newmodels.
    """
    _ = learn.load(basemodelfile)
    # Index 1 of validate(): first metric after the loss (presumably error_rate,
    # as passed to cnn_learner in this notebook).
    orig_val_metric = float(learn.validate()[1])

    # Baseline diagnostics of the untouched base weights (shared by every row).
    interp = ClassificationInterpretation.from_learner(learn)
    fig_base_toplosses = interp.plot_top_losses(9, figsize=(15, 11), return_fig=True)
    fig_base_confmat = interp.plot_confusion_matrix(figsize=(12, 12), dpi=60, return_fig=True)
    most_conf_base = interp.most_confused(min_val=10)

    learn.lr_find()
    fig_orig_lrfind = learn.recorder.plot(return_fig=True)

    try:
        # The base figures never change during the sweep: encode them once, not per run.
        html_fig_base_toplosses = fig2inlinehtml(fig_base_toplosses)
        html_fig_base_confmat = fig2inlinehtml(fig_base_confmat)
        html_fig_orig_lrfind = fig2inlinehtml(fig_orig_lrfind)

        for epoch in epochs:
            for lr in lrs_totest:
                for fit_type in fit_types:
                    # Every run restarts from the same base weights.
                    _ = learn.load(basemodelfile)
                    modelsavename = '{}_{}-ep{}-{}-lr{}'.format(
                        basemodelfile, experiment_shortname, epoch, fit_type, lr)
                    print("Training: {}".format(modelsavename))
                    if fit_type == 'const':
                        learn.fit(epoch, lr=lr)
                    else:
                        learn.fit_one_cycle(epoch, max_lr=lr)
                    if save_newmodels:
                        learn.save(modelsavename)

                    last_train_loss = float(learn.recorder.losses[-1])
                    last_val_loss = float(learn.recorder.val_losses[-1])
                    last_val_metric = float(learn.recorder.metrics[-1][0])

                    fig_train_lr = learn.recorder.plot_lr(return_fig=True)
                    fig_train_losses = learn.recorder.plot_losses(return_fig=True)
                    fig_train_metrics = learn.recorder.plot_metrics(return_fig=True)
                    interp = ClassificationInterpretation.from_learner(learn)
                    fig_train_confmat = interp.plot_confusion_matrix(figsize=(12, 12), dpi=60, return_fig=True)
                    try:
                        # HTML-embedded figures.
                        html_fig_train_lr = fig2inlinehtml(fig_train_lr)
                        html_fig_train_losses = fig2inlinehtml(fig_train_losses)
                        html_fig_train_metrics = fig2inlinehtml(fig_train_metrics)
                        html_fig_train_confmat = fig2inlinehtml(fig_train_confmat)
                    finally:
                        # The per-run figures are only needed for their HTML encoding.
                        for fig in (fig_train_lr, fig_train_losses, fig_train_metrics, fig_train_confmat):
                            plt.close(fig)

                    output_df = _append_result_row(output_df, {
                        'basemodeltype': basemodelname,
                        'basemodelfile': basemodelfile,
                        'newmodelfile': modelsavename,
                        'lr': lr,
                        'fit_type': fit_type,
                        'epoch': epoch,
                        'orig_metric': orig_val_metric,
                        'retrain_metric': last_val_metric,
                        'orig_most_confused': most_conf_base,
                        'fig_base_toplosses': html_fig_base_toplosses,
                        'fig_base_confmat': html_fig_base_confmat,
                        'fig_orig_lrfind': html_fig_orig_lrfind,
                        'last_train_loss': last_train_loss,
                        'last_val_loss': last_val_loss,
                        'last_val_metric': last_val_metric,
                        'fig_train_lr': html_fig_train_lr,
                        'fig_train_losses': html_fig_train_losses,
                        'fig_train_metrics': html_fig_train_metrics,
                        'fig_train_confmat': html_fig_train_confmat,
                    })
                    # Persist after every run so partial sweeps are not lost.
                    _write_results(output_df)
                    gc.collect()
    finally:
        for fig in (fig_base_toplosses, fig_base_confmat, fig_orig_lrfind):
            plt.close(fig)
        plt.close("all")
        gc.collect()
    return output_df

### ResNet18

In [None]:
# Base checkpoint for this backbone; trainscript loads it with learn.load() and
# writes its result tables into exp_modelpath.
basemodelname = 'resnet18'
basemodelfile = 'basetrain-res18-lr3e3_1e61e4-ep8_3-pretrain'
exp_modelpath = path_experiment.joinpath(basemodelname)
exp_modelpath.mkdir(parents=True, exist_ok=True)

In [None]:
# Show full cell contents (the results hold long HTML <img> strings) when displayed.
pd.set_option('display.max_colwidth', None)

# Fresh, empty results table for this backbone.
output_df = pd.DataFrame(columns=[
    'basemodeltype', 'basemodelfile', 'newmodelfile',
    'lr', 'fit_type', 'epoch',
    'orig_metric', 'retrain_metric', 'orig_most_confused',
    'fig_base_toplosses', 'fig_base_confmat', 'fig_train_confmat', 'fig_orig_lrfind',
    'last_train_loss', 'last_val_loss', 'last_val_metric',
    'fig_train_lr', 'fig_train_losses', 'fig_train_metrics',
])

In [None]:
# Architecture only: the weights come from learn.load(basemodelfile) inside
# trainscript, so no pretrained weights are requested here.
learn = cnn_learner(
    data,
    models.resnet18,
    loss_func=nn.CrossEntropyLoss(),
    metrics=error_rate,
    pretrained=False,
)

In [None]:
# Run the full epoch x lr x fit_type sweep for this backbone, collecting one row per run.
output_df = trainscript(output_df)

In [None]:
# Tear down this backbone's learner and collect garbage before the next backbone is built.
learn.destroy()
gc.collect()

### ResNet34

In [None]:
# Base checkpoint for this backbone; trainscript loads it with learn.load() and
# writes its result tables into exp_modelpath.
basemodelname = 'resnet34'
basemodelfile = 'basetrain-res34-lr3e3_1e61e4-ep8_3-pretrain'
exp_modelpath = path_experiment.joinpath(basemodelname)
exp_modelpath.mkdir(parents=True, exist_ok=True)

In [None]:
# Show full cell contents (the results hold long HTML <img> strings) when displayed.
pd.set_option('display.max_colwidth', None)

# Fresh, empty results table for this backbone.
output_df = pd.DataFrame(columns=[
    'basemodeltype', 'basemodelfile', 'newmodelfile',
    'lr', 'fit_type', 'epoch',
    'orig_metric', 'retrain_metric', 'orig_most_confused',
    'fig_base_toplosses', 'fig_base_confmat', 'fig_train_confmat', 'fig_orig_lrfind',
    'last_train_loss', 'last_val_loss', 'last_val_metric',
    'fig_train_lr', 'fig_train_losses', 'fig_train_metrics',
])

In [None]:
# Architecture only: the weights come from learn.load(basemodelfile) inside
# trainscript, so no pretrained weights are requested here.
learn = cnn_learner(
    data,
    models.resnet34,
    loss_func=nn.CrossEntropyLoss(),
    metrics=error_rate,
    pretrained=False,
)

In [None]:
# Run the full epoch x lr x fit_type sweep for this backbone, collecting one row per run.
output_df = trainscript(output_df)

In [None]:
# Tear down this backbone's learner and collect garbage before the next backbone is built.
learn.destroy()
gc.collect()

### ResNet50

In [None]:
# Base checkpoint for this backbone; trainscript loads it with learn.load() and
# writes its result tables into exp_modelpath.
basemodelname = 'resnet50'
basemodelfile = 'basetrain-res50-lr3e3_1e61e4-ep8_3-pretrain'
exp_modelpath = path_experiment.joinpath(basemodelname)
exp_modelpath.mkdir(parents=True, exist_ok=True)

In [None]:
# Show full cell contents (the results hold long HTML <img> strings) when displayed.
pd.set_option('display.max_colwidth', None)

# Fresh, empty results table for this backbone.
output_df = pd.DataFrame(columns=[
    'basemodeltype', 'basemodelfile', 'newmodelfile',
    'lr', 'fit_type', 'epoch',
    'orig_metric', 'retrain_metric', 'orig_most_confused',
    'fig_base_toplosses', 'fig_base_confmat', 'fig_train_confmat', 'fig_orig_lrfind',
    'last_train_loss', 'last_val_loss', 'last_val_metric',
    'fig_train_lr', 'fig_train_losses', 'fig_train_metrics',
])

In [None]:
# Architecture only: the weights come from learn.load(basemodelfile) inside
# trainscript, so no pretrained weights are requested here.
learn = cnn_learner(
    data,
    models.resnet50,
    loss_func=nn.CrossEntropyLoss(),
    metrics=error_rate,
    pretrained=False,
)

In [None]:
# Run the full epoch x lr x fit_type sweep for this backbone, collecting one row per run.
output_df = trainscript(output_df)

In [None]:
# Tear down this backbone's learner and collect garbage before the next backbone is built.
learn.destroy()
gc.collect()

### ResNet101

In [None]:
# Base checkpoint for this backbone; trainscript loads it with learn.load() and
# writes its result tables into exp_modelpath.
basemodelname = 'resnet101'
basemodelfile = 'basetrain-res101-lr3e3_1e61e4-ep8_3-pretrain'
exp_modelpath = path_experiment.joinpath(basemodelname)
exp_modelpath.mkdir(parents=True, exist_ok=True)

In [None]:
# Show full cell contents (the results hold long HTML <img> strings) when displayed.
pd.set_option('display.max_colwidth', None)

# Fresh, empty results table for this backbone.
output_df = pd.DataFrame(columns=[
    'basemodeltype', 'basemodelfile', 'newmodelfile',
    'lr', 'fit_type', 'epoch',
    'orig_metric', 'retrain_metric', 'orig_most_confused',
    'fig_base_toplosses', 'fig_base_confmat', 'fig_train_confmat', 'fig_orig_lrfind',
    'last_train_loss', 'last_val_loss', 'last_val_metric',
    'fig_train_lr', 'fig_train_losses', 'fig_train_metrics',
])

In [None]:
# Architecture only: the weights come from learn.load(basemodelfile) inside
# trainscript, so no pretrained weights are requested here.
learn = cnn_learner(
    data,
    models.resnet101,
    loss_func=nn.CrossEntropyLoss(),
    metrics=error_rate,
    pretrained=False,
)

In [None]:
# Run the full epoch x lr x fit_type sweep for this backbone, collecting one row per run.
output_df = trainscript(output_df)

In [None]:
# Tear down this backbone's learner and collect garbage before the next backbone is built.
learn.destroy()
gc.collect()