In [1]:
# Display the value of every bare expression in a cell, not only the last one
# (used below, e.g. to show both `torch.equal(...)` results from one cell).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
from fastai.vision import *

In [3]:
# Source dataset: Imagenette (`URLs.IMAGENETTE_160`). Train/valid split by
# folder ('val' is the validation folder), labels taken from the parent folder,
# default fastai augmentations, images resized to 128px, ImageNet normalization.
path = untar_data(URLs.IMAGENETTE_160)
data_nette = (ImageList.from_folder(path).split_by_folder(valid='val')
              .label_from_folder()
              .transform(get_transforms(), size=128)
              .databunch(bs=64)
              .normalize(imagenet_stats))

In [6]:
# Source model: pretrained MobileNetV2 body (learn.model[0]) plus a new head.
# The body starts frozen (it is explicitly unfrozen in the next cell), so these
# 3 one-cycle epochs mainly train the head.
learn = cnn_learner(data_nette, models.mobilenet_v2, concat_pool=False, pretrained=True, train_bn=True)
learn.fit_one_cycle(3)

epoch,train_loss,valid_loss,time
0,0.441065,0.201856,00:10
1,0.24497,0.138996,00:10
2,0.179425,0.131967,00:10


In [7]:
# Fine-tune the whole network with discriminative learning rates: the slice
# spreads LRs from 1e-6 (earliest layer group) up to 1e-3 (head).
learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-6,1e-3))

epoch,train_loss,valid_loss,time
0,0.140665,0.131163,00:12
1,0.180994,0.129539,00:12
2,0.131018,0.118751,00:12
3,0.10056,0.112933,00:11
4,0.081557,0.112444,00:11


In [9]:
import copy
# Independent copy of the trained body (learn.model[0]) to reuse as a feature
# extractor; deepcopy means nothing done to `learn` afterwards can affect it.
# eval() makes its BatchNorm layers use the stored running statistics.
feature_extractor = copy.deepcopy(learn.model[0])
feature_extractor.eval();

In [11]:
# Fixed random probe batch used to fingerprint the feature extractor: later
# cells re-run a body on `x` and compare the result against `feat_out`.
# Create it on the model's own device rather than hard-coding `.cuda()`, so the
# cell also runs on CPU-only machines.
x = torch.rand(2, 3, 224, 224, device=next(feature_extractor.parameters()).device)
# Reference output only - no autograd graph is needed.
with torch.no_grad():
    feat_out = feature_extractor(x)
feat_out.shape

torch.Size([2, 1280, 7, 7])

In [13]:
# Target dataset: Oxford-IIIT Pets. The label is the part of the file name
# before the trailing `_<digits>.jpg`, e.g. a file `great_pyrenees_173.jpg`
# is labelled `great_pyrenees`.
path = untar_data(URLs.PETS)
path_img = path/'images'
fnames = get_image_files(path_img)
# Escape the dot: an unescaped `.` would match any character, not just ".jpg".
pat = r'/([^/]+)_\d+\.jpg$'

data_pets = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),
                                        size=128, bs=64).normalize(imagenet_stats)

In [20]:
pets_learner = cnn_learner(data_pets, models.mobilenet_v2, train_bn=False)
# Plain assignment, no copy: pets_learner.model[0] and `feature_extractor` become
# the same module object, so training pets_learner also modifies
# `feature_extractor`, as shown below.
# NOTE(review): cnn_learner built its layer groups (used for freeze() and the
# optimizer) from the body it originally created; they presumably still point at
# that replaced module rather than at `feature_extractor` - confirm.
pets_learner.model[0] = feature_extractor
pets_learner.fit(1)
# Switching to eval mode only after fit() cannot undo changes made during
# training: the comparisons below show BatchNorm running statistics changed,
# i.e. the body ran in training mode inside fit().
pets_learner.model[0].eval();

epoch,train_loss,valid_loss,time
0,1.379401,0.716572,00:12


### `pets_learner.model[0]` is now the same module object as `feature_extractor`, so training `pets_learner` also changes `feature_extractor`

In [21]:
# Re-run the probe batch `x` and compare with `feat_out`, which was computed
# before pets training. Both are False: the two names refer to one module,
# and that module was modified by pets_learner.fit().
torch.equal(feat_out, pets_learner.model[0](x))
torch.equal(feat_out, feature_extractor(x))

False

False

In [22]:
def models_equal(model_1, model_2):
    """Compare two PyTorch modules entry-by-entry over their ``state_dict()``.

    Every state_dict entry is compared with ``torch.equal``. This covers buffers
    such as BatchNorm ``running_mean``, ``running_var`` and
    ``num_batches_tracked``, not only trainable weights. The name of each
    differing entry is printed.

    Args:
        model_1: first module.
        model_2: second module; may live on a different device than ``model_1``.

    Returns:
        True if every entry is identical (and prints 'Models match perfectly!'),
        otherwise False.

    Raises:
        ValueError: if the two state_dicts do not have the same keys in the same
            order, i.e. the architectures differ. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still work.
    """
    state_1 = model_1.state_dict()
    state_2 = model_2.state_dict()
    # zip() over the two dicts would silently stop at the shorter one and pair up
    # unrelated entries, so check that the architectures line up first.
    if list(state_1.keys()) != list(state_2.keys()):
        raise ValueError('Models being compared have different architectures')

    models_differ = 0
    for key, tensor_1 in state_1.items():
        # torch.equal requires both tensors on one device; .to() is a no-op
        # when they already are.
        tensor_2 = state_2[key].to(tensor_1.device)
        if not torch.equal(tensor_1, tensor_2):
            models_differ += 1
            print('Mismatch found at', key)
    if models_differ == 0:
        print('Models match perfectly!')
        return True
    return False

In [23]:
# The shared `feature_extractor` no longer matches the body it was copied from.
models_equal(feature_extractor, learn.model[0])

Mismtach found at 0.0.1.running_mean
Mismtach found at 0.0.1.running_var
Mismtach found at 0.0.1.num_batches_tracked
Mismtach found at 0.1.conv.0.1.running_mean
Mismtach found at 0.1.conv.0.1.running_var
Mismtach found at 0.1.conv.0.1.num_batches_tracked
Mismtach found at 0.1.conv.2.running_mean
Mismtach found at 0.1.conv.2.running_var
Mismtach found at 0.1.conv.2.num_batches_tracked
Mismtach found at 0.2.conv.0.1.running_mean
Mismtach found at 0.2.conv.0.1.running_var
Mismtach found at 0.2.conv.0.1.num_batches_tracked
Mismtach found at 0.2.conv.1.1.running_mean
Mismtach found at 0.2.conv.1.1.running_var
Mismtach found at 0.2.conv.1.1.num_batches_tracked
Mismtach found at 0.2.conv.3.running_mean
Mismtach found at 0.2.conv.3.running_var
Mismtach found at 0.2.conv.3.num_batches_tracked
Mismtach found at 0.3.conv.0.1.running_mean
Mismtach found at 0.3.conv.0.1.running_var
Mismtach found at 0.3.conv.0.1.num_batches_tracked
Mismtach found at 0.3.conv.1.1.running_mean
Mismtach found at 0.3.c

False

In [24]:
# pets_learner's body is the same object as `feature_extractor`, so it differs too.
models_equal(learn.model[0], pets_learner.model[0])

Mismtach found at 0.0.1.running_mean
Mismtach found at 0.0.1.running_var
Mismtach found at 0.0.1.num_batches_tracked
Mismtach found at 0.1.conv.0.1.running_mean
Mismtach found at 0.1.conv.0.1.running_var
Mismtach found at 0.1.conv.0.1.num_batches_tracked
Mismtach found at 0.1.conv.2.running_mean
Mismtach found at 0.1.conv.2.running_var
Mismtach found at 0.1.conv.2.num_batches_tracked
Mismtach found at 0.2.conv.0.1.running_mean
Mismtach found at 0.2.conv.0.1.running_var
Mismtach found at 0.2.conv.0.1.num_batches_tracked
Mismtach found at 0.2.conv.1.1.running_mean
Mismtach found at 0.2.conv.1.1.running_var
Mismtach found at 0.2.conv.1.1.num_batches_tracked
Mismtach found at 0.2.conv.3.running_mean
Mismtach found at 0.2.conv.3.running_var
Mismtach found at 0.2.conv.3.num_batches_tracked
Mismtach found at 0.3.conv.0.1.running_mean
Mismtach found at 0.3.conv.0.1.running_var
Mismtach found at 0.3.conv.0.1.num_batches_tracked
Mismtach found at 0.3.conv.1.1.running_mean
Mismtach found at 0.3.c

False

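### `pets_learner.model[0]` is now an independent deep copy of `feature_extractor`, yet it still changes during training despite being frozen and `train_bn=False`

This time `feature_extractor` itself stays intact. However, every mismatch listed below is a BatchNorm buffer (`running_mean`, `running_var`, `num_batches_tracked`). BatchNorm updates these buffers whenever it runs in training mode, regardless of `requires_grad`, so freezing the weights does not keep them fixed.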

In [27]:
# Retry: take a fresh copy of the original body, and give pets_learner its own
# deepcopy so that `feature_extractor` is no longer shared with the learner.
feature_extractor = copy.deepcopy(learn.model[0])
feature_extractor.eval();

pets_learner = cnn_learner(data_pets, models.mobilenet_v2, train_bn=False)
pets_learner.model[0] = copy.deepcopy(feature_extractor)
pets_learner.fit(1)
# As before, eval() after fit() cannot undo BatchNorm statistics updated during fit().
pets_learner.model[0].eval();

epoch,train_loss,valid_loss,time
0,1.363092,0.770066,00:11


In [28]:
# Now only pets_learner's private copy has drifted (False); the fresh
# `feature_extractor` still reproduces the reference output `feat_out` (True).
torch.equal(feat_out, pets_learner.model[0](x))
torch.equal(feat_out, feature_extractor(x))

False

True

In [29]:
# The fresh, unshared `feature_extractor` matches the original body exactly.
models_equal(feature_extractor, learn.model[0])

Models match perfectly!


True

In [30]:
# pets_learner's private copy still drifted during training (BatchNorm buffers).
models_equal(learn.model[0], pets_learner.model[0])

Mismtach found at 0.0.1.running_mean
Mismtach found at 0.0.1.running_var
Mismtach found at 0.0.1.num_batches_tracked
Mismtach found at 0.1.conv.0.1.running_mean
Mismtach found at 0.1.conv.0.1.running_var
Mismtach found at 0.1.conv.0.1.num_batches_tracked
Mismtach found at 0.1.conv.2.running_mean
Mismtach found at 0.1.conv.2.running_var
Mismtach found at 0.1.conv.2.num_batches_tracked
Mismtach found at 0.2.conv.0.1.running_mean
Mismtach found at 0.2.conv.0.1.running_var
Mismtach found at 0.2.conv.0.1.num_batches_tracked
Mismtach found at 0.2.conv.1.1.running_mean
Mismtach found at 0.2.conv.1.1.running_var
Mismtach found at 0.2.conv.1.1.num_batches_tracked
Mismtach found at 0.2.conv.3.running_mean
Mismtach found at 0.2.conv.3.running_var
Mismtach found at 0.2.conv.3.num_batches_tracked
Mismtach found at 0.3.conv.0.1.running_mean
Mismtach found at 0.3.conv.0.1.running_var
Mismtach found at 0.3.conv.0.1.num_batches_tracked
Mismtach found at 0.3.conv.1.1.running_mean
Mismtach found at 0.3.c

False

## How To Make Models Share The Same Feature Extractor Reliably?