In [None]:
# Re-import edited modules automatically before each cell runs, and render matplotlib figures inline
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
## Load in the needed packages ##

import os
import random

import numpy as np
import torch

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
from sklearn.metrics import confusion_matrix
from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler

In [None]:
## Set up parameters ##

PATH = "../../../training_data/classification_patches/"  # Root folder of the training data (train/ and valid/ sub-folders expected by from_paths — TODO confirm)
sz = 224  # Patch size in pixels, passed to tfms_from_model

batch_size = 64  # Mini-batch size
drop_out = 0.8  # Dropout probability for the classifier head (a fraction in [0, 1], NOT a percentage)
pretrained = True  # Initialise the backbone with ImageNet-pretrained weights?

arch = dn201  # Backbone architecture (fastai's DenseNet-201)

# Training is stochastic (shuffling, head initialisation, augmentation), so seed
# every RNG up front to make runs repeatable. GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
## Set up dataset and network ##

# Data loaders built from the folder tree under PATH, with image transforms
# derived from the chosen architecture and patch size.
data = ImageClassifierData.from_paths(
    PATH,
    bs=batch_size,
    tfms=tfms_from_model(arch, sz),
)
# Pretrained ConvNet with a fresh classifier head. precompute=True makes fastai
# cache backbone activations up front so head-only training is fast.
learn = ConvLearner.pretrained(
    arch,
    data,
    pretrained=pretrained,
    precompute=True,
    ps=drop_out,
)

In [None]:
## Set learning rate ##

# Single learning rate for the head-only training stage (backbone frozen).
short_lr = 1e-2

In [None]:
## Begin training on final layers only ##

# Freeze the backbone so only the classifier head is trainable.
learn.freeze()
# precompute=True: the head is trained on precomputed activations of the
# frozen layers (fast, but the backbone itself is not updated in this stage).
learn.precompute = True

# SGDR schedule (cosine annealing with warm restarts), not an adaptive optimiser:
# 5 cycles whose length doubles each time (cycle_len=1, cycle_mult=2),
# i.e. 1 + 2 + 4 + 8 + 16 = 31 epochs. fastai saves the best-scoring epoch
# under the name 'best_final_layers'.
learn.fit(short_lr, 5, cycle_len=1, cycle_mult=2, best_save_name='best_final_layers')

In [None]:
## Set learning rate ##

# Differential learning rates for the full-model stage: one rate per layer
# group, ordered from the earliest layers to the head (fastai 0.7 splits a
# pretrained ConvLearner into three groups — TODO confirm for dn201).
# NOTE(review): the head rate (7e-2) is 7x short_lr — confirm this is intended.
long_lr = np.array([
    7e-4,  # earliest layers
    7e-3,  # middle layers
    7e-2,  # classifier head
])

In [None]:
## Begin training on full model ##

# Make every layer group trainable.
learn.unfreeze()
# precompute must be False while the unfrozen model trains: cached activations
# would bypass the backbone, so its weights could never be updated.
learn.precompute = False
# long_lr supplies one learning rate per layer group (differential learning rates).
# SGDR: 3 cycles of 1, 2 and 4 epochs (7 epochs total). fastai saves the
# best-scoring epoch under the name 'best_full_model'.
learn.fit(long_lr, 3, cycle_len=1, cycle_mult=2, best_save_name='best_full_model')

In [None]:
## Confusion matrix of test-time-augmented predictions on the validation set ##

# One set of predictions per TTA pass (stacked on axis 0) plus the true labels.
# They are exponentiated below, so they are presumably log-probabilities.
log_preds, y = learn.TTA()
# Average probabilities (not log-probabilities) over the TTA passes, then take
# the most likely class for every sample.
probs = np.exp(log_preds).mean(axis=0)
preds = probs.argmax(axis=1)

cm = confusion_matrix(y, preds)
plot_confusion_matrix(cm, data.classes)

In [None]:
## Save current model ##
## Best model is already saved in training, but you can specifically save the current one here ##

# In fastai 0.7, learn.model is the head-only model while precompute is True.
# Switch it off so the whole network (backbone + head) is what gets saved.
learn.precompute = False
learn.unfreeze()
# torch.save on a module pickles the entire nn.Module, not just its weights.
# The ".h5" extension is historical: the file is a PyTorch pickle, not HDF5.
torch.save(learn.model, os.path.join(PATH, 'models', 'finalModel.h5'))

In [None]:
## Load a model ##
## Run this after the "Set up dataset and network" cell (it needs `learn`) ##

model_name = 'finalModel.h5'
# The checkpoint holds the full network, so make learn.model refer to the full
# network as well (fastai 0.7 returns only the head while precompute is True,
# which would make load_state_dict fail on mismatched keys).
learn.precompute = False
# NOTE(security): torch.load unpickles arbitrary objects — only load trusted files.
# map_location keeps tensors on the CPU, so a GPU-saved checkpoint also loads on
# CPU-only machines. load_state_dict then copies the weights into learn.model's
# own parameters.
model = torch.load(os.path.join(PATH, 'models', model_name),
                   map_location=lambda storage, loc: storage).state_dict()
learn.model.load_state_dict(model)