# Result

When unfreezing the layers of the basemodel, one has to be careful to keep the batch normalization layers frozen. Otherwise the model forgets everything it has learned at the very first training step. This notebook demonstrates this behavior as part of a sanity check.
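Why a single step is enough can be illustrated without the actual model. The numpy sketch below assumes standard batch normalization semantics with Keras' default `epsilon=1e-3`. A trainable batch normalization layer normalizes with the mean and variance of the current batch during `fit()`. A frozen one uses the moving averages collected during pretraining. If a batch's statistics differ from the stored ones, every activation passed to the following layers shifts at once. The stored statistics and the batch are hypothetical values chosen for illustration, not taken from the EfficientNet basemodel.

```python
import numpy as np

EPS = 1e-3  # Keras' default BatchNormalization epsilon

def bn_inference(x, moving_mean, moving_var, gamma, beta):
    # Frozen / inference mode: normalize with the stored moving statistics.
    return gamma * (x - moving_mean) / np.sqrt(moving_var + EPS) + beta

def bn_training(x, gamma, beta):
    # Trainable / training mode: normalize with the statistics of this batch.
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    return gamma * (x - batch_mean) / np.sqrt(batch_var + EPS) + beta

rng = np.random.default_rng(0)
channels = 4
# Hypothetical statistics "learned" during pretraining.
moving_mean = np.full(channels, 2.0)
moving_var = np.full(channels, 4.0)
gamma, beta = np.ones(channels), np.zeros(channels)

# Hypothetical batch of 32 samples whose statistics differ from the stored ones.
x = rng.normal(loc=5.0, scale=0.5, size=(32, channels))

out_frozen = bn_inference(x, moving_mean, moving_var, gamma, beta)
out_trainable = bn_training(x, gamma, beta)
print("mean activation, frozen BN:   ", out_frozen.mean().round(2))
print("mean activation, trainable BN:", out_trainable.mean().round(2))
```

The layers after the batch normalization were trained on the "frozen" activations, so swapping them for the batch-normalized ones changes their input distribution immediately, before any weight update has had an effect.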

# Setup

In [1]:
!git init
!git remote add origin https://github.com/sgerloff/sustainable_deepfashion.git
!git pull origin main
!pip install -r requirements.txt

!make setup-preprocessed-gc CATEGORY_ID=1 MIN_PAIR_COUNT=20

Initialized empty Git repository in /content/.git/
remote: Enumerating objects: 338, done.
remote: Counting objects: 100% (338/338), done.
remote: Compressing objects: 100% (235/235), done.
remote: Total 338 (delta 170), reused 231 (delta 84), pack-reused 0
Receiving objects: 100% (338/338), 801.80 KiB | 1.32 MiB/s, done.
Resolving deltas: 100% (170/170), done.
From https://github.com/sgerloff/sustainable_deepfashion
 * branch            main       -> FETCH_HEAD
 * [new branch]      main       -> origin/main
Collecting argparse
  Downloading https://files.pythonhosted.org/packages/f2/94/3af39d34be01a24a6e65433d19e107099374224905f1e0cc6bbe1fd22a2f/argparse-1.4.0-py2.py3-none-any.whl
Collecting tensorflow_addons
  Downloading https://files.pythonhosted.org/packages/74/e3/56d2fe76f0bb7c88ed9b2a6a557e25e83e252aec08f13de34369cd850a0b/tensorflow_addons-0.12.1-cp37-cp37m-manylinux2010_x86_64.whl (703kB)
     |████████████████████████████████| 706kB 7.8MB/s
Collecting zip

Load an already trained model to experiment with:

In [2]:
!wget http://d2fcl18pl6lkip.cloudfront.net/models/effnet_freeze_basemodel.h5 -O models/effnet_freeze_basemodel.h5 

--2021-03-08 10:26:01--  http://d2fcl18pl6lkip.cloudfront.net/models/effnet_freeze_basemodel.h5
Resolving d2fcl18pl6lkip.cloudfront.net (d2fcl18pl6lkip.cloudfront.net)... 13.225.100.209, 13.225.100.77, 13.225.100.148, ...
Connecting to d2fcl18pl6lkip.cloudfront.net (d2fcl18pl6lkip.cloudfront.net)|13.225.100.209|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://d2fcl18pl6lkip.cloudfront.net/models/effnet_freeze_basemodel.h5 [following]
--2021-03-08 10:26:01--  https://d2fcl18pl6lkip.cloudfront.net/models/effnet_freeze_basemodel.h5
Connecting to d2fcl18pl6lkip.cloudfront.net (d2fcl18pl6lkip.cloudfront.net)|13.225.100.209|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 20522000 (20M) [binary/octet-stream]
Saving to: ‘models/effnet_freeze_basemodel.h5’


2021-03-08 10:26:04 (8.45 MB/s) - ‘models/effnet_freeze_basemodel.h5’ saved [20522000/20522000]



In [7]:
from src.models.efficient_net_triplet import EfficientNetTriplet
from src.utility import get_project_dir
import joblib, os

Set up the dataset:

In [8]:
effnet_triplet = EfficientNetTriplet()

train_df = joblib.load(os.path.join(get_project_dir(),
                                    "data",
                                    "processed",
                                    "category_id_1_deepfashion_train.joblib"))

dataset, train_size = effnet_triplet.get_dataset(train_df, batch_size=32, training_ratio=1.)

# Sanity Checks

## Untrained Model

The expected loss is around 0.95, but it may fluctuate because the samples in each batch are selected at random.
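One rough way to see where a value just below 1 comes from assumes the model uses a margin-based triplet loss with margin 1.0. That margin is the default of `tfa.losses.TripletSemiHardLoss`; we have not checked which loss or margin `EfficientNetTriplet` actually uses. For random, L2-normalized embeddings, anchor-positive and anchor-negative distances are nearly equal, so each triplet contributes roughly the margin. The numpy sketch below uses randomly drawn triplets instead of semi-hard mining. Semi-hard mining picks negatives slightly farther from the anchor than the positive, which would push the loss somewhat below the margin.

```python
import numpy as np

def l2_normalize(x):
    # Project each embedding onto the unit sphere.
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Plain (non-mined) triplet loss with Euclidean distances:
    # mean over triplets of max(d(a, p) - d(a, n) + margin, 0).
    d_ap = np.linalg.norm(anchor - positive, axis=1)
    d_an = np.linalg.norm(anchor - negative, axis=1)
    return np.mean(np.maximum(d_ap - d_an + margin, 0.0))

rng = np.random.default_rng(0)
n_triplets, dim = 1000, 128  # hypothetical sizes, chosen for illustration
# "Untrained" embeddings: unrelated to the labels, so positives are not closer.
a, p, n = (l2_normalize(rng.normal(size=(n_triplets, dim))) for _ in range(3))

loss = triplet_loss(a, p, n, margin=1.0)
print(f"loss for random embeddings: {loss:.3f}")  # close to the margin
```

A trained model lowers this value by pulling positives closer to the anchor than negatives, which is consistent with the 0.7 to 0.8 expected below.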

In [9]:
effnet_triplet = EfficientNetTriplet()
effnet_triplet.model.fit(dataset, epochs=1, steps_per_epoch=1)



<tensorflow.python.keras.callbacks.History at 0x7fe0c4d6eb50>

## Trained Model

For the trained model, we expect a loss of around 0.7 to 0.8:

In [10]:
effnet_triplet = EfficientNetTriplet()
effnet_triplet.load("effnet_freeze_basemodel.h5")
effnet_triplet.model.fit(dataset, epochs=1, steps_per_epoch=1)

Load model from /content/models/effnet_freeze_basemodel.h5


<tensorflow.python.keras.callbacks.History at 0x7fe055b4b810>

## Trained Model, Unfrozen Basemodel (Broken!)

If we simply unfreeze all layers of the basemodel, we run into trouble: the model forgets everything it has learned and produces a loss close to 0.95, i.e. the same as the untrained model.

In [11]:
effnet_triplet = EfficientNetTriplet()
effnet_triplet.load("effnet_freeze_basemodel.h5")

# Unfreeze all layers of the basemodel, including batch normalization
effnet_triplet.basemodel.trainable = True
for l in effnet_triplet.basemodel.layers:
  l.trainable = True
effnet_triplet.compile()

effnet_triplet.model.fit(dataset, epochs=1, steps_per_epoch=1)

Load model from /content/models/effnet_freeze_basemodel.h5


<tensorflow.python.keras.callbacks.History at 0x7fe053e753d0>

## Safe Unfreeze (Solution)

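It turns out that the culprits are the batch normalization layers, as also pointed out in the TensorFlow transfer learning tutorial:
https://www.tensorflow.org/tutorials/images/transfer_learning
Setting `trainable = False` on a `BatchNormalization` layer makes it run in inference mode, so it keeps normalizing with the moving mean and variance learned during pretraining. If these layers are trainable, they normalize with the statistics of the current batch instead, which is what breaks the model in the previous section.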
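A trainable batch normalization layer also updates its stored statistics on every training step, `moving = momentum * moving + (1 - momentum) * batch_stat`. A frozen one leaves them untouched. The numpy sketch below shows both code paths. It assumes Keras' default `momentum=0.99` and omits the learned scale and offset (`gamma`, `beta`) for brevity. The helper `bn_step`, the statistics and the batch are hypothetical and only illustrate the mechanism.

```python
import numpy as np

EPS = 1e-3  # Keras' default BatchNormalization epsilon

def bn_step(x, state, trainable, momentum=0.99):
    """One forward pass of a batch normalization layer during fit().

    `state` holds `moving_mean` and `moving_var`; returns (output, new_state).
    Scale and offset (gamma, beta) are omitted for brevity.
    """
    if not trainable:
        # Frozen (inference mode): use the stored statistics and keep them.
        mean, var = state["moving_mean"], state["moving_var"]
        return (x - mean) / np.sqrt(var + EPS), state
    # Trainable (training mode): normalize with the batch statistics ...
    mean, var = x.mean(axis=0), x.var(axis=0)
    # ... and move the stored statistics towards them.
    new_state = {
        "moving_mean": momentum * state["moving_mean"] + (1 - momentum) * mean,
        "moving_var": momentum * state["moving_var"] + (1 - momentum) * var,
    }
    return (x - mean) / np.sqrt(var + EPS), new_state

rng = np.random.default_rng(1)
state = {"moving_mean": np.zeros(3), "moving_var": np.ones(3)}
x = rng.normal(loc=3.0, scale=2.0, size=(32, 3))  # hypothetical batch

_, frozen_state = bn_step(x, state, trainable=False)
_, updated_state = bn_step(x, state, trainable=True)
print(frozen_state["moving_mean"])   # unchanged: [0. 0. 0.]
print(updated_state["moving_mean"])  # nudged towards the batch mean
```

In this sketch, a single trainable step moves the stored statistics only by 1% of the way towards the batch statistics. The immediate damage in a one-step `fit()` comes from the forward pass normalizing with batch statistics, while the drift of the moving averages accumulates over longer training.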

In [14]:
effnet_triplet = EfficientNetTriplet()
effnet_triplet.load("effnet_freeze_basemodel.h5")

# Safe unfreeze: train all layers except BatchNormalization, which stays frozen (inference mode)
effnet_triplet.basemodel.trainable=True
for l in effnet_triplet.basemodel.layers:
  if l.__class__.__name__ == "BatchNormalization":
    l.trainable = False
  else:
    l.trainable = True

effnet_triplet.compile()

effnet_triplet.model.fit(dataset, epochs=1, steps_per_epoch=1)

Load model from /content/models/effnet_freeze_basemodel.h5


<tensorflow.python.keras.callbacks.History at 0x7fe047238090>