In [None]:
import os
from pathlib import Path

import openslide
import torch
import yaml
from easydict import EasyDict as edict

from histolung.legacy.heatmaps import MIL_model
from histolung.legacy.models import ModelOption
from histolung.models.models import MILModel, PretrainedModelLoader
from histolung.utils import yaml_load

In [None]:
# Directory holding the trained MIL model (config file + per-fold checkpoints).
# Set HISTOLUNG_MODEL_DIR to point elsewhere instead of editing this
# machine-specific absolute default.
modeldir = Path(
    os.environ.get(
        "HISTOLUNG_MODEL_DIR",
        "/home/valentin/workspaces/histolung/models/MIL/f_MIL_res34v2_v2_rumc_best_cosine_v3",
    )
).resolve()

In [None]:
# The config file is named after the model directory ("config_<model name>.yml"),
# so derive it rather than repeating the model name by hand.
cfg = yaml_load(modeldir / f"config_{modeldir.name}.yml")

In [None]:
# Load fold 0's checkpoint onto the CPU so this works on CPU-only machines even
# if it was saved from GPU; the networks are moved to `device` after their
# weights are loaded.
# NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
checkpoint = torch.load(modeldir / "fold_0" / "checkpoint.pt", map_location="cpu")

In [None]:
# Run on the first visible GPU if there is one, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the legacy model from the same config so its outputs can be compared
# against the refactored MILModel.
model_old = ModelOption(cfg.model.model_name,
                        cfg.model.num_classes,
                        freeze=cfg.model.freeze_weights,
                        num_freezed_layers=cfg.model.num_frozen_layers,
                        dropout=cfg.model.dropout,
                        embedding_bool=cfg.model.embedding_bool,
                        pool_algorithm=cfg.model.pool_algorithm)

hidden_space_len = cfg.model.hidden_space_len

net_old = MIL_model(model_old, hidden_space_len, cfg)

# strict=False silently skips keys that do not match. Keep the returned report
# (missing_keys / unexpected_keys) so a partially loaded model, with some layers
# still randomly initialised, is visible.
incompatible_old = net_old.load_state_dict(checkpoint["model_state_dict"], strict=False)
net_old.to(device)
net_old.eval()

# Show the load report instead of the very long module repr.
incompatible_old



In [None]:
# Refactored backbone loader, built with the same settings the legacy
# ModelOption receives so the two backbones are directly comparable.
model = PretrainedModelLoader(
    cfg.model.model_name,
    cfg.model.num_classes,
    freeze=cfg.model.freeze_weights,
    num_freezed_layers=cfg.model.num_frozen_layers,
    dropout=cfg.model.dropout,
    embedding_bool=cfg.model.embedding_bool,
    pool_algorithm=cfg.model.pool_algorithm,
)

In [None]:
# Presumably the hidden dimension of the MIL head; it is read from the same
# config key the legacy cell uses, so this cell does not depend on that one.
hidden_space_len = cfg.model.hidden_space_len
# Wrap the refactored backbone in the refactored MIL model.
net = MILModel(model, hidden_space_len, cfg)

In [None]:
# strict=False silently skips keys that do not match. Keep the returned report
# (missing_keys / unexpected_keys) so layers left randomly initialised are
# visible rather than hidden.
incompatible = net.load_state_dict(checkpoint["model_state_dict"], strict=False)
net.to(device)
net.eval()

# Show the load report instead of the very long module repr.
incompatible

In [None]:
# Reproducible random input. A dedicated seeded generator is used so the global
# torch RNG state is left untouched.
# NOTE(review): 226x226 rather than the more common 224x224 — confirm intended.
INPUT_SEED = 0
x = torch.rand(
    (1, 3, 226, 226),
    generator=torch.Generator().manual_seed(INPUT_SEED),
).to(device)

In [None]:
# Backbone-only forward pass of the legacy net. This is pure inference for a
# numerical comparison, so skip building the autograd graph.
with torch.no_grad():
    emb_old = net_old.net(x)

In [None]:
def compare_models(model1, model2):
    """Check whether two modules hold exactly the same parameters and buffers.

    The two ``state_dict()`` mappings are compared key by key. Differing key
    sets (missing, extra, or renamed tensors) count as a mismatch, instead of
    being silently truncated or paired up by position as a plain ``zip`` would.

    Args:
        model1: First ``torch.nn.Module``.
        model2: Second ``torch.nn.Module``.

    Returns:
        True if both state dicts have the same keys and every tensor is exactly
        equal (same shape and values), otherwise False. The first problem found
        is printed.
    """
    sd1 = model1.state_dict()
    sd2 = model2.state_dict()

    only_in_1 = sd1.keys() - sd2.keys()
    only_in_2 = sd2.keys() - sd1.keys()
    if only_in_1 or only_in_2:
        print(f"Key mismatch: only in model1: {sorted(only_in_1)}, "
              f"only in model2: {sorted(only_in_2)}")
        return False

    for name, param1 in sd1.items():
        # torch.equal refuses tensors on different devices, so align them first.
        param2 = sd2[name].to(param1.device)
        if not torch.equal(param1, param2):
            print(f"Mismatch found in layer: {name}")
            return False
    print("The models have the same weights.")
    return True

In [None]:
# Do the refactored and legacy backbones carry identical weights?
compare_models(model1=net.net, model2=net_old.net)

In [None]:
# Backbone-only forward pass of the refactored net. This is pure inference for a
# numerical comparison, so skip building the autograd graph.
with torch.no_grad():
    emb = net.net(x)

In [None]:
# Compare legacy vs refactored backbone embeddings. Check shapes explicitly
# first, because elementwise subtraction would broadcast a mismatch silently.
assert emb_old.shape == emb.shape, f"shape mismatch: {emb_old.shape} vs {emb.shape}"
emb_diff = emb_old - emb
# Summarise (max absolute difference, allclose) instead of dumping the tensor.
emb_diff.abs().max().item(), torch.allclose(emb_old, emb)

In [None]:
# Full forward pass of the refactored MIL model on the random input.
# NOTE(review): the legacy net is called as net_old(x, x) in the next cell,
# while this one takes a single input — confirm the signatures are meant to differ.
net(x)

In [None]:
# Full forward pass of the legacy MIL model, which is called with two inputs.
# Passing x for both is presumably a placeholder — TODO confirm what the second
# argument of histolung.legacy.heatmaps.MIL_model.forward represents.
net_old(x,x)

In [None]:
# Re-load the same checkpoint into the refactored net; the displayed return
# value lists the keys that strict=False skipped.
net.load_state_dict(checkpoint["model_state_dict"], strict=False)

In [None]:
# Checkpoint contents as name -> shape, rather than dumping every tensor's values.
{name: (tuple(t.shape) if torch.is_tensor(t) else t)
 for name, t in checkpoint["model_state_dict"].items()}

In [None]:
# Legacy model state as name -> shape. The original `net_old.state_dict` lacked
# the call parentheses and only showed the bound-method repr.
{name: tuple(t.shape) for name, t in net_old.state_dict().items()}

In [None]:
# Refactored model state as name -> shape. The original `net.state_dict` lacked
# the call parentheses and only showed the bound-method repr.
{name: tuple(t.shape) for name, t in net.state_dict().items()}