In [None]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))

import torch
from models.efficientnet import EfficientNetV2

# Which EfficientNetV2 variant to convert; change the index to pick 's', 'm' or 'l'.
MODEL_SIZE = ['s', 'm', 'l'][0]

# Source checkpoint (pretrained weights) and destination for the converted weights.
CKP_PATH = f'loaded_models/efficientnet_v2_{MODEL_SIZE}.pth'
SAVE_PATH = f'../weights/imagenet_efficientnet_v2_{MODEL_SIZE}.pth'

# Freshly initialised target model that will receive the transferred weights.
model = EfficientNetV2(MODEL_SIZE)

# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# machine without CUDA; the tensors are copied into the model later anyway.
# NOTE(security): weights_only=False performs full unpickling, which can execute
# arbitrary code. Only use it for checkpoints from a trusted source; if the file
# is a plain state dict of tensors, weights_only=True should be sufficient.
old_model_ckp = torch.load(CKP_PATH, map_location='cpu', weights_only=False)

In [None]:
def _transferable_keys(state_dict, required_prefix=''):
    """Return the keys of ``state_dict`` whose tensors should be transferred.

    A key is kept when its tensor has more than one dimension, its name does
    not contain 'fc', and it starts with ``required_prefix`` (the default ''
    matches every key). The order of ``state_dict`` is preserved, because the
    caller pairs the two key lists by position.
    """
    return [
        key
        for key, value in state_dict.items()
        if value.dim() > 1 and 'fc' not in key and key.startswith(required_prefix)
    ]


cur_dict = model.state_dict()

# All norm layers have one-dimensional values, so keeping only tensors with
# more than one dimension drops them. The loaded model was trained with
# BatchNorm while this model uses LayerNorm, so norm layers must not be
# transferred. The fc layers of the SE module are skipped by name as well
# (their biases are one-dimensional too, their weights are not).
# For the loaded checkpoint, everything outside 'features' is also removed.
normless_old_keys = _transferable_keys(old_model_ckp, required_prefix='features')
normless_cur_keys = _transferable_keys(cur_dict)

# Keys are matched purely by position, and zip() stops at the shorter list.
# Make a length difference visible instead of silently leaving layers untouched.
if len(normless_cur_keys) != len(normless_old_keys):
    print(
        f'WARNING: {len(normless_cur_keys)} transferable keys in the current model '
        f'vs {len(normless_old_keys)} in the checkpoint; only the first '
        f'{min(len(normless_cur_keys), len(normless_old_keys))} pairs are transferred.'
    )

for cur_key, old_key in zip(normless_cur_keys, normless_old_keys):
    old_value = old_model_ckp[old_key]
    # A shape mismatch means the positional pairing went out of sync; fail here
    # with both key names rather than later inside load_state_dict.
    if cur_dict[cur_key].shape != old_value.shape:
        raise ValueError(
            f'Shape mismatch while transferring weights: {cur_key} has shape '
            f'{tuple(cur_dict[cur_key].shape)} but {old_key} has shape '
            f'{tuple(old_value.shape)}'
        )
    cur_dict[cur_key] = old_value

In [34]:
# Load the merged state dict into the model. load_state_dict is strict by
# default, so a missing/unexpected key or a size mismatch raises here.
model.load_state_dict(cur_dict)

# torch.save does not create missing parent directories, so make sure the
# destination folder exists before writing the converted weights.
save_dir = os.path.dirname(SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), SAVE_PATH)