In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_11 import *

## Serializing the model

In [3]:
path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)

In [4]:
size = 128
bs = 64

tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]
val_tfms = [make_rgb, CenterCrop(size), np_to_float]
il = ImageList.from_files(path, tfms=tfms)
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
ll.valid.x.tfms = val_tfms
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)

In [5]:
len(il)

12954

In [6]:
loss_func = LabelSmoothingCrossEntropy()
opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)

In [7]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)

In [8]:
def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    """One-cycle schedule for a single learning rate `lr`.

    Returns `[lr_scheduler, mom_scheduler]`: the learning rate goes lr/10 -> lr -> lr/1e5 and
    the momentum goes mom_start -> mom_mid -> mom_end, over the phases built by
    `create_phases(pct_start)`.
    """
    phases = create_phases(pct_start)
    lr_curve  = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
    mom_curve = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler(name, curve)
            for name, curve in (('lr', lr_curve), ('mom', mom_curve))]

In [9]:
lr = 3e-3
pct_start = 0.5
cbsched = sched_1cycle(lr, pct_start)

In [10]:
learn.fit(40, cbsched)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,2.094667,0.266501,2.027031,0.286,00:11
1,1.893691,0.367593,1.862369,0.378,00:11
2,1.76716,0.436406,2.010226,0.352,00:11
3,1.686197,0.468685,1.731216,0.416,00:11
4,1.619955,0.501365,1.749357,0.442,00:11
5,1.570306,0.526337,1.663812,0.472,00:11
6,1.521743,0.552594,1.51677,0.54,00:11
7,1.48983,0.566726,1.486735,0.574,00:11
8,1.425645,0.595632,1.884718,0.442,00:11
9,1.38598,0.616107,1.736998,0.51,00:11


In [11]:
st = learn.model.state_dict()

In [12]:
type(st)

collections.OrderedDict

In [13]:
', '.join(st.keys())

'0.0.weight, 0.1.weight, 0.1.bias, 0.1.running_mean, 0.1.running_var, 0.1.num_batches_tracked, 1.0.weight, 1.1.weight, 1.1.bias, 1.1.running_mean, 1.1.running_var, 1.1.num_batches_tracked, 2.0.weight, 2.1.weight, 2.1.bias, 2.1.running_mean, 2.1.running_var, 2.1.num_batches_tracked, 4.0.convs.0.0.weight, 4.0.convs.0.1.weight, 4.0.convs.0.1.bias, 4.0.convs.0.1.running_mean, 4.0.convs.0.1.running_var, 4.0.convs.0.1.num_batches_tracked, 4.0.convs.1.0.weight, 4.0.convs.1.1.weight, 4.0.convs.1.1.bias, 4.0.convs.1.1.running_mean, 4.0.convs.1.1.running_var, 4.0.convs.1.1.num_batches_tracked, 4.1.convs.0.0.weight, 4.1.convs.0.1.weight, 4.1.convs.0.1.bias, 4.1.convs.0.1.running_mean, 4.1.convs.0.1.running_var, 4.1.convs.0.1.num_batches_tracked, 4.1.convs.1.0.weight, 4.1.convs.1.1.weight, 4.1.convs.1.1.bias, 4.1.convs.1.1.running_mean, 4.1.convs.1.1.running_var, 4.1.convs.1.1.num_batches_tracked, 5.0.convs.0.0.weight, 5.0.convs.0.1.weight, 5.0.convs.0.1.bias, 5.0.convs.0.1.running_mean, 5.0.convs

In [14]:
st['10.bias']

tensor([ 0.0115,  0.0590,  0.0033, -0.0029, -0.0485,  0.0239,  0.0022, -0.0411,
        -0.0158,  0.0144], device='cuda:0')

In [15]:
mdl_path = path/'models'
mdl_path.mkdir(exist_ok=True)

It's also possible to save the whole model, including the architecture (by pickling the module object), but that gets quite fiddly and we don't recommend it. Instead, save just the parameters (the `state_dict`), and recreate the model in code before loading them into it.

In [16]:
torch.save(st, mdl_path/'iw5')

## Pets

In [17]:
pets = datasets.untar_data(datasets.URLs.PETS)

In [18]:
pets.ls()

[PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/annotations'),
 PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/small-256'),
 PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/small-96'),
 PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images')]

In [19]:
pets_path = pets/'images'

In [20]:
il = ImageList.from_files(pets_path, tfms=tfms)

In [21]:
il

ImageList (7390 items)
[PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Siamese_40.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Siamese_58.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/pug_62.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/keeshond_14.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Ragdoll_166.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/British_Shorthair_141.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Persian_59.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Bombay_163.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/basset_hound_94.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/yorkshire_terrier_16.jpg')...]
Path: /home/vishaladu/.fastai/data/oxford-iiit-pet/images

In [22]:
#export
def random_splitter(fn, p_valid):
    "Send an item to the validation set with probability `p_valid`; `fn` itself is ignored (one draw from the global `random` module per call)."
    draw = random.random()
    return draw < p_valid

In [23]:
random.seed(42)

In [24]:
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [25]:
sd

SplitData
Train: ImageList (6667 items)
[PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Siamese_40.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/pug_62.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/keeshond_14.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Ragdoll_166.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/British_Shorthair_141.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/Persian_59.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/basset_hound_94.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/american_pit_bull_terrier_32.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/basset_hound_152.jpg'), PosixPath('/home/vishaladu/.fastai/data/oxford-iiit-pet/images/basset_hound_170.jpg')...]
Path: /home/vishaladu/.fastai/data/oxford-iiit-pet/images
Valid: ImageList (723 items)
[PosixPath('/ho

In [26]:
n = il.items[0].name; n

'Siamese_40.jpg'

In [27]:
re.findall(r'^(.*)_\d+.jpg$', n)[0]

'Siamese'

In [28]:
def pet_labeler(fn):
    """Return the breed label encoded in a pets image filename.

    Filenames look like `<breed>_<number>.jpg`, e.g. `Siamese_40.jpg` -> `Siamese` and
    `american_pit_bull_terrier_32.jpg` -> `american_pit_bull_terrier` (the greedy group keeps
    the underscores that belong to the breed name).

    Raises:
        ValueError: if `fn.name` does not have the form `<breed>_<number>.jpg`.
    """
    # `\.` so that only a literal ".jpg" extension matches (an unescaped `.` matches any char).
    match = re.match(r'^(.*)_\d+\.jpg$', fn.name)
    if match is None:
        raise ValueError(f"pet_labeler: cannot parse a label from {fn.name!r}")
    return match.group(1)

In [29]:
proc = CategoryProcessor()

In [30]:
ll = label_by_func(sd, pet_labeler, proc_y=proc)

In [31]:
', '.join(proc.vocab)

'Siamese, pug, keeshond, Ragdoll, British_Shorthair, Persian, basset_hound, american_pit_bull_terrier, shiba_inu, american_bulldog, Bengal, miniature_pinscher, saint_bernard, scottish_terrier, german_shorthaired, havanese, newfoundland, boxer, Bombay, yorkshire_terrier, Abyssinian, great_pyrenees, chihuahua, Maine_Coon, Birman, Russian_Blue, Egyptian_Mau, Sphynx, english_setter, beagle, japanese_chin, english_cocker_spaniel, wheaten_terrier, samoyed, pomeranian, leonberger, staffordshire_bull_terrier'

In [32]:
ll.valid.x.tfms = val_tfms

In [33]:
c_out = len(proc.vocab)

In [34]:
data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)

In [35]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)

In [36]:
learn.fit(5, cbsched)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,3.462268,0.089696,3.470064,0.076072,00:07
1,3.250136,0.142343,4.021382,0.105118,00:06
2,3.040703,0.20339,3.714852,0.102351,00:06
3,2.719277,0.297735,2.722988,0.298755,00:06
4,2.410459,0.39853,2.43293,0.383126,00:06


## Custom head

In [37]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)

In [38]:
st = torch.load(mdl_path/'iw5')

In [39]:
m = learn.model

In [40]:
m.load_state_dict(st)

<All keys matched successfully>

In [41]:
cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))
m_cut = m[:cut]

In [42]:
xb,yb = get_batch(data.valid_dl, learn)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time


In [43]:
# Probe only the body's output shape. Use eval mode + no_grad so this forward pass
# neither builds an autograd graph nor updates the pretrained BatchNorm running stats
# (a train-mode forward would), then put the body back in training mode.
with torch.no_grad(): pred = m_cut.eval()(xb)
m_cut.train();

In [44]:
pred.shape

torch.Size([128, 512, 4, 4])

In [45]:
ni = pred.shape[1]

In [46]:
#export
class AdaptiveConcatPool2d(nn.Module):
    """Pool to `sz` spatially with both adaptive max- and average-pooling and stack the results.

    The output has twice the input's channels: max-pooled channels first, then the
    average-pooled ones (concatenated along dim 1).
    """
    def __init__(self, sz=1):
        super().__init__()
        self.output_size = sz
        self.ap, self.mp = nn.AdaptiveAvgPool2d(sz), nn.AdaptiveMaxPool2d(sz)
    def forward(self, x):
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, 1)

In [47]:
# New model = pretrained body + a fresh head. AdaptiveConcatPool2d concatenates the
# max- and avg-pooled features, so the linear layer sees ni*2 inputs.
m_new = nn.Sequential(
    m_cut, AdaptiveConcatPool2d(), Flatten(),
    nn.Linear(ni*2, data.c_out))

In [48]:
learn.model = m_new

In [49]:
learn.fit(5, cbsched)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,2.922665,0.274936,2.109121,0.495159,00:06
1,1.973275,0.538323,1.953218,0.561549,00:06
2,1.699401,0.640018,1.744647,0.62379,00:06
3,1.491193,0.715914,1.488735,0.719225,00:06
4,1.287838,0.80486,1.414617,0.726141,00:06


## adapt_model and gradual unfreezing

In [50]:
def adapt_model(learn, data):
    """Replace the pooling/head of `learn.model` with a fresh head sized for `data`.

    Keeps every child of `learn.model` before its first `nn.AdaptiveAvgPool2d` (the
    pretrained body), probes one validation batch to find how many channels the body
    outputs, then sets `learn.model` to body + `AdaptiveConcatPool2d` (doubles the channels)
    + `Flatten` + a new `nn.Linear(ni*2, data.c_out)`.

    Raises:
        ValueError: if `learn.model` has no `nn.AdaptiveAvgPool2d` child to cut at
            (e.g. it has already been adapted by this function).
    """
    cut = next((i for i,o in enumerate(learn.model.children())
                if isinstance(o,nn.AdaptiveAvgPool2d)), None)
    if cut is None:
        raise ValueError("adapt_model: learn.model has no nn.AdaptiveAvgPool2d child to cut at "
                         "(was it already adapted?)")
    m_cut = learn.model[:cut]
    xb,_ = get_batch(data.valid_dl, learn)
    # The probe only needs the channel count: run it in eval mode under no_grad so it neither
    # builds an autograd graph nor updates the pretrained BatchNorm running statistics.
    # Restore each submodule's own train/eval flag afterwards.
    modes = [(mod, mod.training) for mod in m_cut.modules()]
    m_cut.eval()
    try:
        with torch.no_grad(): ni = m_cut(xb).shape[1]
    finally:
        for mod, mode in modes: mod.training = mode
    learn.model = nn.Sequential(
        m_cut, AdaptiveConcatPool2d(), Flatten(),
        nn.Linear(ni*2, data.c_out))

In [51]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))

<All keys matched successfully>

In [52]:
adapt_model(learn, data)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time


In [53]:
for p in learn.model[0].parameters(): p.requires_grad_(False)

In [54]:
learn.fit(3, sched_1cycle(1e-2, 0.5))

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,2.773465,0.295785,2.454371,0.394191,00:04
1,2.40249,0.419979,2.28928,0.459198,00:04
2,2.050727,0.528724,2.082191,0.514523,00:04


In [55]:
for p in learn.model[0].parameters(): p.requires_grad_(True)

In [56]:
learn.fit(5, cbsched, reset_opt=True)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,1.846987,0.60027,1.879508,0.585062,00:06
1,1.732402,0.636268,1.987868,0.522822,00:06
2,1.642893,0.671666,1.904087,0.564315,00:07
3,1.437848,0.753712,1.597742,0.688797,00:07
4,1.272595,0.815809,1.503559,0.71231,00:06


## Batch norm transfer

In [57]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time


In [58]:
def apply_mod(m, f):
    "Call `f` on `m` and on every descendant module, parent before children (depth-first pre-order)."
    pending = [m]
    while pending:
        module = pending.pop()
        f(module)
        # Push children reversed so the first child is visited next.
        pending.extend(reversed(list(module.children())))

def set_grad(m, b):
    """Set `requires_grad=b` on the parameters of module `m` if it has a `weight`.

    Linear and BatchNorm2d layers are always skipped, so they stay trainable when the
    rest of the network is frozen.
    """
    always_trainable = isinstance(m, (nn.Linear,nn.BatchNorm2d))
    if not always_trainable and hasattr(m, 'weight'):
        for param in m.parameters(): param.requires_grad_(b)

In [59]:
apply_mod(learn.model, partial(set_grad, b=False))

In [60]:
learn.fit(3, sched_1cycle(1e-2, 0.5))

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,2.661953,0.342283,2.178851,0.466113,00:05
1,2.082033,0.504875,2.050919,0.506224,00:05
2,1.827331,0.59787,1.833369,0.60166,00:05


In [61]:
apply_mod(learn.model, partial(set_grad, b=True))

In [62]:
learn.fit(5, cbsched, reset_opt=True)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,1.718376,0.646618,1.812805,0.605809,00:07
1,1.664148,0.662367,1.823408,0.594744,00:06
2,1.599361,0.683966,1.78428,0.591978,00:06
3,1.419813,0.760462,1.555135,0.701245,00:06
4,1.277487,0.814759,1.460849,0.721992,00:06


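PyTorch's `nn.Module.apply` already does this recursive traversal for us (it calls the function on the children before the parent, which makes no difference for `set_grad`):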

In [63]:
learn.model.apply(partial(set_grad, b=False));

## Discriminative LR and param groups

In [64]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)

In [65]:
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time


In [66]:
def bn_splitter(m):
    """Split `m`'s parameters into two groups for discriminative learning rates.

    Returns `(body, rest)`:
      * `body`: parameters of every non-BatchNorm2d module in the body `m[0]` that has a
        `weight` attribute;
      * `rest`: BatchNorm2d parameters found in `m[0]`, followed by all parameters of the
        head `m[1:]`.
    """
    def _visit(module, body, rest):
        # Pre-order walk: classify this module, then descend into its children.
        if isinstance(module, nn.BatchNorm2d):
            rest.extend(module.parameters())
        elif hasattr(module, 'weight'):
            body.extend(module.parameters())
        for child in module.children():
            _visit(child, body, rest)

    body, rest = [], []
    _visit(m[0], body, rest)
    rest.extend(m[1:].parameters())
    return body, rest

In [67]:
a,b = bn_splitter(learn.model)

In [68]:
# Check against the model that was just split (`learn.model`), not the earlier `m`
# from the "Custom head" section, which is a different model object.
test_eq(len(a)+len(b), len(list(learn.model.parameters())))

In [69]:
Learner.ALL_CBS

{'after_backward',
 'after_batch',
 'after_cancel_batch',
 'after_cancel_epoch',
 'after_cancel_train',
 'after_epoch',
 'after_fit',
 'after_loss',
 'after_pred',
 'after_step',
 'begin_batch',
 'begin_epoch',
 'begin_fit',
 'begin_validate'}

In [70]:
#export
from types import SimpleNamespace
# Namespace whose attributes are the callback event names themselves
# (e.g. cb_types.after_backward == 'after_backward'). Event names can then be
# tab-completed, and a misspelt name raises AttributeError instead of being a
# string that silently never matches.
cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})

In [71]:
cb_types.after_backward

'after_backward'

In [72]:
#export
class DebugCallback(Callback):
    """Hook into one named training event for debugging.

    When the event `cb_name` fires, calls `f(self.run)` if a (truthy) `f` was given,
    otherwise drops into the debugger with `set_trace()`. All other events are ignored.
    """
    _order = 999  # very high order; presumably so it runs after the other callbacks
    def __init__(self, cb_name, f=None):
        self.cb_name = cb_name
        self.f = f
    def __call__(self, cb_name):
        if cb_name != self.cb_name: return
        if self.f: self.f(self.run)
        else:      set_trace()

In [73]:
#export
def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    """Build one-cycle `ParamScheduler` callbacks for learning rate and momentum.

    `lrs` is either an iterable of peak learning rates (one per parameter group, for
    discriminative LRs) or a single number. This matches the scalar-only
    `sched_1cycle(lr, ...)` used earlier in this notebook, so that call style keeps working.
    Each peak `lr` gets a schedule going lr/10 -> lr -> lr/1e5 over the phases built by
    `create_phases(pct_start)`. A single momentum schedule goes
    mom_start -> mom_mid -> mom_end.
    """
    import numbers
    phases = create_phases(pct_start)
    def _lr_sched(lr): return combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
    # A bare number yields one schedule (exactly what the scalar version built);
    # an iterable yields a list with one schedule per parameter group.
    sched_lr  = _lr_sched(lrs) if isinstance(lrs, numbers.Number) else [_lr_sched(lr) for lr in lrs]
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]

In [74]:
disc_lr_sched = sched_1cycle([0,3e-2], 0.5)

In [75]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func,
                    c_out=10, norm=norm_imagenette, splitter=bn_splitter)

learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time


In [76]:
def _print_det(o):
    "Debug hook: print the optimizer's param-group count and per-group hyper-parameters, then cancel training."
    n_groups = len(o.opt.param_groups)
    print(n_groups, o.opt.hypers)
    raise CancelTrainException()

learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time


2 [{'mom': 0.9499999999999997, 'mom_sqr': 0.99, 'eps': 1e-06, 'wd': 0.01, 'lr': 0.0, 'sqr_mom': 0.99}, {'mom': 0.9499999999999997, 'mom_sqr': 0.99, 'eps': 1e-06, 'wd': 0.01, 'lr': 0.0030000000000000512, 'sqr_mom': 0.99}]


In [77]:
learn.fit(3, disc_lr_sched)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,2.48185,0.386681,2.371274,0.409405,00:06
1,2.228766,0.458077,2.173414,0.460581,00:06
2,1.95677,0.547173,1.828445,0.589212,00:06


In [78]:
disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)

In [79]:
learn.fit(5, disc_lr_sched)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,1.778767,0.614969,1.850521,0.593361,00:07
1,1.791522,0.60642,2.102978,0.500692,00:06
2,1.674013,0.659367,1.692361,0.659751,00:06
3,1.514179,0.721464,1.600683,0.669433,00:06
4,1.434761,0.748913,1.599681,0.683264,00:07


## Export

In [80]:
!./notebook2script.py 11a_transfer_learning.ipynb

Converted 11a_transfer_learning.ipynb to exp/nb_11a.py
