In [1]:
import torch
from pytorch_scripts.ghostnetv2 import ghostnetv2

In [3]:
def load_fi_weights(model, filename, verbose=False):
    count = 0
    new_dict = {}
    weights = torch.load(filename, weights_only=True)
    state_dict = model.state_dict()
    for name, param in state_dict.items():
        if verbose:
            print(name, param.data.shape)
        if "dummy" in name:
            if verbose:
                print("-\n")
            continue
        new_name = name.replace("conv.weight", "weight").replace("conv.bias", "bias")
        new_name = new_name.replace("linear.weight", "weight").replace(
            "linear.bias", "bias"
        )
        new_name = new_name.replace(".layers", "")
        new_weights = weights[new_name]
        if verbose:
            print(new_name, new_weights.shape, "\n")
        count += 1
        if param.data.shape != new_weights.shape:
            raise ValueError(
                f"Shape mismatch: {param.data.shape} != {new_weights.shape}"
            )
        new_dict[name] = new_weights

    print(f"Loaded {count} weights")
    model.load_state_dict(new_dict, strict=False)

In [4]:
# Seed so the model's initial weights and the random test input are reproducible.
torch.manual_seed(0)
model = ghostnetv2()

# Smoke-test the forward pass on a random batch of one image. The 5 input
# channels are what this notebook feeds the model; presumably that matches the
# model's configured input channels (TODO confirm). The model's train/eval mode
# is not changed here; if it contains BatchNorm layers, a train-mode pass will
# update their running statistics (TODO consider model.eval()).
x = torch.randn(1, 5, 224, 224)
with torch.no_grad():  # shape check only; no autograd graph needed
    y, intermediates = model(x)
print(y.shape)
print(len(intermediates))

torch.Size([1, 1000])
10


In [5]:
# Checkpoint to load; relative paths resolve against the notebook's working directory.
FI_WEIGHTS_PATH = "weights/GN_SSL_280.pt"
load_fi_weights(model, FI_WEIGHTS_PATH)

Loaded 764 weights
